# Train File
This notebook trains the model(s) so that they are ready for evaluation and prediction.

In [1]:
import torch 
from torch import nn

torch.__version__

'2.2.2'

In [2]:
import torchvision

torchvision.__version__

'0.17.2'

In [3]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cpu'

In [4]:
from importnb import Notebook
with Notebook():
    import dataloader, model

../data/intel-image-classification.zip already exists. Skipping download.
../data already exists and is not empty. Skipping extraction.
There are 3 directories and 0 images in '../data/intel-image-classification'.
There are 1 directories and 0 images in '../data/intel-image-classification/seg_test'.
There are 6 directories and 0 images in '../data/intel-image-classification/seg_test/seg_test'.
There are 0 directories and 474 images in '../data/intel-image-classification/seg_test/seg_test/forest'.
There are 0 directories and 437 images in '../data/intel-image-classification/seg_test/seg_test/buildings'.
There are 0 directories and 553 images in '../data/intel-image-classification/seg_test/seg_test/glacier'.
There are 0 directories and 501 images in '../data/intel-image-classification/seg_test/seg_test/street'.
There are 0 directories and 525 images in '../data/intel-image-classification/seg_test/seg_test/mountain'.
There are 0 directories and 510 images in '../data/intel-image-classific

In [5]:
from helper_functions import accuracy_fn
from tqdm.auto import tqdm
from timeit import default_timer as timer 

In [6]:
# Bind the names this notebook uses from the imported `dataloader` and
# `model` notebooks to local aliases.

# Aliases taken from the `dataloader` notebook.
train_dataloader, test_dataloader = (
    dataloader.train_dataloader,
    dataloader.test_dataloader,
)

# Aliases taken from the `model` notebook.
BATCH_SIZE, ResidualBlock, RestNet, model_0 = (
    model.BATCH_SIZE,
    model.ResidualBlock,
    model.RestNet,
    model.model_0,
)

In [7]:
def print_train_time(start: float, end: float, device: torch.device = None):
    """Print and return the elapsed time between two timer readings.

    Args:
        start: Timer reading taken before the timed work.
        end: Timer reading taken after the timed work.
        device: Value shown in the printed message as the device the work
            ran on; it is only interpolated into the text.

    Returns:
        ``end - start``.
    """
    total_time = end - start
    message = f"Train time on {device}: {total_time:.3f} seconds"
    print(message)
    return total_time

In [8]:
def train_step(model: torch.nn.Module,
               data_loader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               accuracy_fn,
               device: torch.device = device):
    """Run one training pass over ``data_loader`` and return the mean metrics.

    Args:
        model: Network to optimise; it is put into training mode.
        data_loader: Iterable yielding ``(X, y)`` batches.
        loss_fn: Callable invoked as ``loss_fn(y_pred, y)`` on the raw model
            outputs.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        accuracy_fn: Callable invoked as
            ``accuracy_fn(y_true=y, y_pred=predicted_labels)``.
        device: Device every batch is moved to before the forward pass.

    Returns:
        A ``(train_loss, train_accuracy)`` tuple, each divided by the number
        of batches in ``data_loader``.
    """
    train_loss, train_accuracy = 0.0, 0.0

    model.train()

    for X, y in data_loader:
        X, y = X.to(device), y.to(device)

        y_pred = model(X)  # Forward pass, outputs raw logits
        loss = loss_fn(y_pred, y)

        # Accumulate a plain Python float: adding the tensor itself would keep
        # every batch's autograd graph alive until the end of the epoch.
        train_loss += loss.item()
        # argmax over dim=1 turns the logits into predicted class labels.
        train_accuracy += accuracy_fn(y_true=y, y_pred=y_pred.argmax(dim=1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # We adjust to get these metrics per batch, and not the total
    train_loss /= len(data_loader)
    train_accuracy /= len(data_loader)
    print(f"Train loss: {train_loss:.5f} | Train acc: {train_accuracy:.2f}%\n")

    # The caller (`train`) unpacks two values from this function.
    return train_loss, train_accuracy


In [9]:
def test_step(model: torch.nn.Module,
              data_loader: torch.utils.data.DataLoader,
              loss_fn: torch.nn.Module,
              accuracy_fn,
              device: torch.device = device):
    """Evaluate ``model`` on ``data_loader`` and return the mean metrics.

    Args:
        model: Network to evaluate; it is put into evaluation mode.
        data_loader: Iterable yielding ``(X, y)`` batches.
        loss_fn: Callable invoked as ``loss_fn(logits, y)``.
        accuracy_fn: Callable invoked as
            ``accuracy_fn(y_true=y, y_pred=predicted_labels)``.
        device: Device every batch is moved to before the forward pass.

    Returns:
        A ``(test_loss, test_accuracy)`` tuple, each divided by the number of
        batches in ``data_loader``.
    """
    test_loss, test_accuracy = 0.0, 0.0

    model.eval()  # Turns off different settings in the model not needed for evaluation/testing (dropout/batch norm layers)

    with torch.inference_mode():  # Turns off gradient tracking and a couple more things behind the scenes
        # Iterate the `data_loader` argument, not the global `dataloader`
        # module imported at the top of the notebook.
        for X, y in data_loader:
            X, y = X.to(device), y.to(device)

            test_pred_logits = model(X)
            loss = loss_fn(test_pred_logits, y)
            test_loss += loss.item()
            # argmax over dim=1 turns the logits into predicted class labels.
            test_accuracy += accuracy_fn(y_true=y, y_pred=test_pred_logits.argmax(dim=1))

    test_loss = test_loss / len(data_loader)
    test_accuracy = test_accuracy / len(data_loader)
    return test_loss, test_accuracy


In [10]:
def train(model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          test_dataloader: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer,
          loss_fn: torch.nn.Module = nn.CrossEntropyLoss(),
          epochs: int = 5,
          device=device):
    """Alternate training and evaluation passes for ``epochs`` epochs.

    Each epoch calls ``train_step`` on ``train_dataloader`` and ``test_step``
    on ``test_dataloader``, prints the four resulting metrics, and records
    them. The module-level ``accuracy_fn`` is passed to both steps.

    Returns:
        A dict with keys ``"train_loss"``, ``"train_acc"``, ``"test_loss"``
        and ``"test_acc"``, each mapping to a list with one entry per epoch.
    """
    metric_names = ("train_loss", "train_acc", "test_loss", "test_acc")
    results = {name: [] for name in metric_names}

    for epoch in tqdm(range(epochs)):
        # One optimisation pass over the training data.
        train_loss, train_acc = train_step(
            model=model,
            data_loader=train_dataloader,
            loss_fn=loss_fn,
            optimizer=optimizer,
            accuracy_fn=accuracy_fn,
            device=device,
        )
        # One evaluation pass over the test data.
        test_loss, test_acc = test_step(
            model=model,
            data_loader=test_dataloader,
            loss_fn=loss_fn,
            accuracy_fn=accuracy_fn,
            device=device,
        )

        print(f"Epoch: {epoch} | Train loss: {train_loss:.4f} | Train acc: {train_acc:.4f} | Test loss: {test_loss:.4f} | Test acc: {test_acc:.4f}")

        # Record this epoch's metrics in the same order as `metric_names`.
        epoch_values = (train_loss, train_acc, test_loss, test_acc)
        for name, value in zip(metric_names, epoch_values):
            results[name].append(value)

    return results

In [None]:
# Seed the CPU and CUDA random number generators before training starts.
torch.manual_seed(42)
torch.cuda.manual_seed(42)

NUM_EPOCHS = 5

# Loss function and optimizer for model_0.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model_0.parameters(), lr=0.001)

# Time the whole training run.
start_time = timer()
model_0_results = train(
    model=model_0,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=NUM_EPOCHS,
    device=device,
)
end_time = timer()

print(f"Total training time: {end_time-start_time:.3f} seconds")

  0%|          | 0/5 [00:00<?, ?it/s]