## Data setup

In [2]:
## data-setup
%%writefile module/data_setup.py

import os
import torch
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

def create_dataloader(
        train_dir: str,
        test_dir: str,
        transform: transforms.Compose,
        batch_size: int,
        num_workers: int = os.cpu_count()):
    """Create training and testing DataLoaders from two image directories.

    Both directories are loaded with ``torchvision.datasets.ImageFolder`` and
    the same ``transform`` is applied to each.

    Args:
        train_dir: Path to the training image directory.
        test_dir: Path to the testing image directory.
        transform: Transform applied to every image of both datasets.
        batch_size: Number of samples per batch in each DataLoader.
        num_workers: Number of worker processes per DataLoader. Defaults to
            the machine's CPU count; falls back to 0 (load in the main
            process) when that count cannot be determined.

    Returns:
        A tuple ``(train_dataloader, test_dataloader, class_names)`` where
        ``class_names`` is the list of classes found in the training dataset.
    """
    # os.cpu_count() may return None, which DataLoader does not accept.
    if num_workers is None:
        num_workers = 0

    train_data = datasets.ImageFolder(train_dir, transform=transform)
    test_data = datasets.ImageFolder(test_dir, transform=transform)

    # Class names come from the training set only.
    class_names = train_data.classes

    train_dataloader = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,  # reshuffle training samples every epoch
        num_workers=num_workers,
        pin_memory=True,
    )

    test_dataloader = DataLoader(
        test_data,
        batch_size=batch_size,
        shuffle=False,  # keep evaluation order fixed
        num_workers=num_workers,
        pin_memory=True,
    )

    return train_dataloader, test_dataloader, class_names


Writing module/data_setup.py


## Engine setup

In [5]:
%%writefile module/engine_setup.py


import torch
from tqdm.auto import tqdm
from typing import Dict, List, Tuple

# train_step()
def train_step(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               device: torch.device) -> Tuple[float, float]:
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Model to train; it is switched to training mode.
        dataloader: Iterable of ``(inputs, targets)`` batches that supports
            ``len()``.
        loss_fn: Loss function called as ``loss_fn(logits, targets)``.
        optimizer: Optimizer that updates the model's parameters.
        device: Device each batch is moved to before the forward pass.

    Returns:
        A tuple ``(train_loss, train_acc)``: the per-batch loss and per-batch
        accuracy, each averaged over the number of batches.
    """
    model.train()  # Train mode
    train_loss, train_acc = 0.0, 0.0

    for x_train, y_train in dataloader:
        x_train, y_train = x_train.to(device), y_train.to(device)

        # 1. Forward pass
        y_pred = model(x_train)  # logits

        # 2. Loss
        loss = loss_fn(y_pred, y_train)
        train_loss += loss.item()

        # 3. Reset gradients left over from the previous batch
        optimizer.zero_grad()

        # 4. Backpropagation
        loss.backward()

        # 5. Step optimizer
        optimizer.step()

        # Accuracy: compare predicted labels with the targets (not the logits).
        # argmax over the logits equals argmax over their softmax.
        y_pred_class = y_pred.argmax(dim=1)
        train_acc += (y_pred_class == y_train).sum().item() / len(y_train)

    train_loss = train_loss / len(dataloader)
    train_acc = train_acc / len(dataloader)
    return train_loss, train_acc


# test_step()
def test_step(model : torch.nn.Module,
              dataloader : torch.utils.data.DataLoader,
              loss_fn : torch.nn.Module,
              device : torch.device) -> Tuple[float, float]:


    model.eval() # Eval mode
    with torch.inference_mode():
        test_loss, test_acc = 0, 0

        for batch, (x_test, y_tess) in enumerate(dataloader):

            x_test, y_test = x_test.to(device), y_test.to(device)

            # 1. forward pass
            test_pred = model(x_test)

            # 2. loss
            loss = loss_fn(test_pred,y_test)
            test_loss += loss.item()

            # calculate accuracy
            test_pred_labels = test_pred.argmax(dim =  1)
            test_acc += (test_pred_labels == y_test).sum().item()/len(test_pred_labels)

    test_loss += test_loss / len(dataloader)
    test_acc += test_acc / len(dataloader)
    return test_loss, test_acc


# Train
def train(model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          test_dataloader: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer,
          loss_fn: torch.nn.Module,
          epochs: int,
          device: torch.device) -> Dict[str, list]:
    """Train and evaluate ``model`` for ``epochs`` epochs.

    Each epoch runs ``train_step`` on ``train_dataloader`` and then
    ``test_step`` on ``test_dataloader``, prints the four metrics, and
    records them.

    Args:
        model: Model to train and evaluate.
        train_dataloader: Batches used for training.
        test_dataloader: Batches used for evaluation.
        optimizer: Optimizer passed to ``train_step``.
        loss_fn: Loss function passed to both steps.
        epochs: Number of epochs to run.
        device: Device passed to both steps.

    Returns:
        A dict with keys ``"train_loss"``, ``"train_acc"``, ``"test_loss"``
        and ``"test_acc"``, each mapping to a list with one value per epoch.
    """
    results = {"train_loss": [],
               "train_acc": [],
               "test_loss": [],
               "test_acc": []}

    for epoch in tqdm(range(epochs)):

        train_loss, train_acc = train_step(model=model,
                                           dataloader=train_dataloader,
                                           loss_fn=loss_fn,
                                           optimizer=optimizer,
                                           device=device)

        test_loss, test_acc = test_step(model=model,
                                        dataloader=test_dataloader,
                                        loss_fn=loss_fn,
                                        device=device)

        print(
            f"Epoch: {epoch+1} | "
            f"train_loss: {train_loss:.4f} | "
            f"train_acc: {train_acc:.4f} | "
            f"test_loss: {test_loss:.4f} | "
            f"test_acc: {test_acc:.4f}"
        )

        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)
        results["test_loss"].append(test_loss)
        results["test_acc"].append(test_acc)

    # Return only after every epoch has run; returning inside the loop would
    # stop training after the first epoch.
    return results

Overwriting module/engine_setup.py
