In [None]:
!pip install torchmetrics torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116

In [77]:
import torch
from timeit import default_timer as timer

def train_model(model: torch.nn.Module, data_loader: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, acc_fn, loss_fn: torch.nn.Module, device="cpu"):
  """Train ``model`` for one full pass over ``data_loader``.

  The model itself is not moved to ``device`` here; only the batches are.

  Args:
    model: Network to optimise; it is switched to training mode.
    data_loader: Yields ``(features, labels)`` batches.
    optimizer: Optimizer that is zeroed and stepped once per batch.
    acc_fn: Callable taking ``(predicted_class_indices, labels)`` and returning
      a scalar (a float or a single-element tensor).
    loss_fn: Callable taking ``(model_output, labels)`` and returning the loss tensor.
    device: Device each batch is moved to before the forward pass.

  Returns:
    Dict with ``"train_loss"`` and ``"train_acc"`` (per-batch values averaged
    over the number of batches, both plain floats) and ``"train_total_time"``
    (elapsed seconds).

  Raises:
    ZeroDivisionError: If ``data_loader`` contains no batches.
  """
  train_loss_cum = 0.0
  train_acc_cum = 0.0
  # Progress is reported in samples; count them explicitly so a smaller final
  # batch does not distort either side of the "seen/total" message.
  samples_seen = 0
  total_samples = len(data_loader.dataset)
  train_start_time = timer()
  model.train()
  for train_batch_number, (train_features, train_labels) in enumerate(data_loader):
    train_features = train_features.to(device)
    train_labels = train_labels.to(device)

    train_pred_labels = model(train_features)

    train_loss = loss_fn(train_pred_labels, train_labels)

    # Accumulate plain Python floats so the returned metrics do not keep
    # device tensors alive and have the same type for loss and accuracy.
    train_loss_cum += train_loss.item()
    train_acc_cum += float(acc_fn(train_pred_labels.argmax(dim=1), train_labels))

    optimizer.zero_grad()

    train_loss.backward()

    optimizer.step()

    if train_batch_number % 500 == 0:
      print(f"Looked at {samples_seen}/{total_samples} samples")
    samples_seen += len(train_features)

  train_loss = train_loss_cum / len(data_loader)
  train_acc = train_acc_cum / len(data_loader)
  train_end_time = timer()
  train_total_time = train_end_time - train_start_time
  return {
      "train_loss": train_loss,
      "train_acc": train_acc,
      "train_total_time": train_total_time
  }

In [72]:
from timeit import default_timer as timer

def eval_model(model: torch.nn.Module, data_loader: torch.utils.data.DataLoader, acc_fn, loss_fn: torch.nn.Module, device="cpu"):
  """Evaluate ``model`` on every batch of ``data_loader`` without gradients.

  The model itself is not moved to ``device`` here; only the batches are.

  Args:
    model: Network to evaluate; it is switched to evaluation mode.
    data_loader: Yields ``(features, labels)`` batches.
    acc_fn: Callable taking ``(predicted_class_indices, labels)`` and returning
      a scalar (a float or a single-element tensor).
    loss_fn: Callable taking ``(model_output, labels)`` and returning the loss tensor.
    device: Device each batch is moved to before the forward pass.

  Returns:
    Dict with ``"eval_total_time"`` (elapsed seconds) plus ``"eval_loss"`` and
    ``"eval_acc"`` (per-batch values averaged over the number of batches, both
    plain floats).

  Raises:
    ZeroDivisionError: If ``data_loader`` contains no batches.
  """
  eval_loss_cum = 0.0
  eval_acc_cum = 0.0
  eval_start_time = timer()
  model.eval()
  with torch.inference_mode():
    for eval_features, eval_labels in data_loader:
      eval_features = eval_features.to(device)
      eval_labels = eval_labels.to(device)
      eval_pred_labels = model(eval_features)
      # Convert to Python floats so the result matches train_model's float
      # metrics instead of returning inference-mode tensors.
      eval_loss_cum += loss_fn(eval_pred_labels, eval_labels).item()
      eval_acc_cum += float(acc_fn(eval_pred_labels.argmax(dim=1), eval_labels))

  eval_loss = eval_loss_cum / len(data_loader)
  eval_acc = eval_acc_cum / len(data_loader)
  eval_end_time = timer()
  eval_total_time = eval_end_time - eval_start_time
  return {
      "eval_total_time": eval_total_time,
      "eval_loss": eval_loss,
      "eval_acc": eval_acc
  }

In [80]:
## Training/evaluation loop: the model is updated once per batch of data rather than once per epoch
from tqdm.auto import tqdm
from timeit import default_timer as timer
import torch
from torchmetrics import Accuracy
from torch.utils.data import DataLoader

def train_eval_loop(model: torch.nn.Module, train_data: torch.utils.data.Dataset, test_data: torch.utils.data.Dataset, random_state: int=42, device: str="cpu", lr: float=0.01, batch_size: int=32, epochs: int=10):
  """Train ``model`` with SGD and evaluate it after every epoch, printing metrics.

  Args:
    model: Network to train; it is moved to ``device``.
    train_data: Training dataset; must expose a ``classes`` attribute, whose
      length is used as the number of classes for the accuracy metric.
    test_data: Dataset evaluated after each training epoch.
    random_state: Seed passed to ``torch.manual_seed`` before the data loaders
      are created. The model is constructed by the caller, so its initial
      weights are not affected by this seed.
    device: Device used for the model, the metric and every batch.
    lr: Learning rate of the SGD optimizer.
    batch_size: Batch size of both data loaders.
    epochs: Number of train + eval passes to run.
  """

  loss_fn = torch.nn.CrossEntropyLoss()

  optimizer = torch.optim.SGD(params=model.parameters(), lr=lr)

  acc_fn = Accuracy(task="multiclass", num_classes=len(train_data.classes)).to(device)

  torch.manual_seed(random_state)
  # Also match indexed device strings such as "cuda:0", not only plain "cuda".
  if str(device).startswith("cuda"):
    torch.cuda.manual_seed(random_state)

  train_dataloader = DataLoader(
      train_data,
      batch_size=batch_size,
      shuffle=True,
  )

  eval_dataloader = DataLoader(
      test_data,
      batch_size=batch_size,
      shuffle=False,
  )

  model.to(device)

  for epoch in tqdm(range(epochs)):
    print(f"Epoch: {epoch + 1} -----")

    # Keep the per-epoch results in their own names so the ``train_data``
    # dataset argument is not overwritten by a metrics dict.
    train_results = train_model(
        model=model,
        data_loader=train_dataloader,
        acc_fn=acc_fn,
        loss_fn=loss_fn,
        device=device,
        optimizer=optimizer
    )
    print(f'''train time on {device}: {train_results["train_total_time"]:.4f} seconds''')

    eval_results = eval_model(
        model=model,
        data_loader=eval_dataloader,
        acc_fn=acc_fn,
        loss_fn=loss_fn,
        device=device
    )
    print(f'''eval time on {device}: {eval_results["eval_total_time"]:.4f} seconds''')

    # The accuracy metric is a fraction in [0, 1]; scale it so the "%" suffix is correct.
    print(f'''Train loss: {train_results["train_loss"]:.4f} | Train accuracy: {train_results["train_acc"] * 100:.2f}% | Eval loss: {eval_results["eval_loss"]:.4f} | Eval accuracy: {eval_results["eval_acc"] * 100:.2f}%''')
    print()

In [None]:
from torchvision import datasets
from torchvision.transforms import ToTensor
# FashionMNIST training split, downloaded into ./data when not already present.
train_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=None,
)

# FashionMNIST test split, stored under the same root directory.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
    target_transform=None,
)

In [None]:
from torch import nn

# NN without any non-linear activation function
class FashionMNISTModelV0(nn.Module):
  """Flatten the input and pass it through two stacked linear layers.

  There is no activation between the layers.

  Args:
    input_shape: Number of features after flattening.
    hidden_units: Width of the hidden linear layer.
    output_shape: Number of output values (one per class).
  """

  def __init__(self, input_shape: int, hidden_units: int, output_shape: int):
    super().__init__()
    self.layers = nn.Sequential(
        nn.Flatten(),
        nn.Linear(
            in_features=input_shape,
            out_features=hidden_units
        ),
        nn.Linear(
            in_features=hidden_units,
            out_features=output_shape
        )
    )

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Return the output of the layer stack for the batch ``x``."""
    return self.layers(x)


# This loop must live at module level: inside the class body it would run
# while the class is still being created, before its name is bound.
for device in ["cpu", "cuda"]:
  if device == "cuda" and not torch.cuda.is_available():
    print("CUDA is not available, skipping the cuda run")
    continue
  # Seed before building the model so every device starts from the same
  # initial weights (train_eval_loop only seeds after the model exists).
  torch.manual_seed(42)
  train_eval_loop(
      FashionMNISTModelV0(
        input_shape=28 * 28,
        output_shape=len(train_data.classes),
        hidden_units=10
      ),
      train_data=train_data,
      test_data=test_data,
      epochs=5,
      device=device
  )

In [None]:
from torch import nn

# Model with non-linear activation functions
class FashionMNISTModelV1(nn.Module):
  """Flatten the input and pass it through linear -> ReLU -> linear.

  Args:
    input_shape: Number of features after flattening.
    hidden_units: Width of the hidden linear layer.
    output_shape: Number of output values (one per class).
  """

  def __init__(self, input_shape: int, hidden_units: int, output_shape: int):
    super().__init__()
    self.layers = nn.Sequential(
        nn.Flatten(),
        nn.Linear(
            in_features=input_shape,
            out_features=hidden_units
        ),
        nn.ReLU(),
        nn.Linear(
            in_features=hidden_units,
            out_features=output_shape
        )
    )

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Return the output of the layer stack for the batch ``x``."""
    return self.layers(x)


# Train and evaluate the non-linear model on the CPU and, when one is
# available, on the GPU, so the per-device timings can be compared.
for device in ["cpu", "cuda"]:
  if device == "cuda" and not torch.cuda.is_available():
    print("CUDA is not available, skipping the cuda run")
    continue
  # Seed before building the model so every device starts from the same
  # initial weights (train_eval_loop only seeds after the model exists).
  torch.manual_seed(42)
  train_eval_loop(
      FashionMNISTModelV1(
        input_shape=28 * 28,
        output_shape=len(train_data.classes),
        hidden_units=10
      ),
      train_data=train_data,
      test_data=test_data,
      epochs=5,
      device=device
  )