In [1]:
import torch as th

from torchvision import transforms

import matplotlib.pyplot as plt

In [2]:
from src.dataloader import customDataloader
from tqdm import trange

In [3]:
train_loader = customDataloader(train=True,  transform=None, batch_size=64, shuffle=True)
test_loader  = customDataloader(train=False, transform=None, batch_size=64, shuffle=True)

In [4]:
class Baseline(th.nn.Module):
    def __init__(self):
        super().__init__()

        self.net = th.nn.Sequential(
            th.nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3),
            th.nn.ReLU(),
            th.nn.MaxPool2d(kernel_size=2, stride=2),
            th.nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3),
            th.nn.ReLU(),
            th.nn.MaxPool2d(kernel_size=2, stride=2),
            th.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3),
            th.nn.ReLU(),
            th.nn.Flatten(start_dim=1),
            th.nn.Linear(in_features=32*12*12, out_features=15, bias=True),
        )

    def forward(self, x):
        return self.net(x)

In [8]:
def train_loop(model, train_loader, test_loader, criterion, optimizer, num_epochs=10):
    train_losses     = []
    test_losses      = []
    train_accuracies = []
    test_accuracies  = []
    
    # model.to(device)
    bar = trange(num_epochs)
    
    for epoch in bar:
        model.train()
        running_loss        = 0.0
        correct_predictions = 0.0
        total_samples = 0.0
        
        for X, y in train_loader:
            X, y = X.float(), y.float()
            
            # Zero the parameter gradients
            optimizer.zero_grad()
            
            # Forward pass
            y_hat = model(X)
            loss = criterion(y_hat, y)
            
            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            y_hat = y_hat.argmax(1)
            y_true = y.argmax(1)
            correct_predictions += (y_hat == y_true).sum().item()
            total_samples += y_true.size(0)
            
            running_loss += loss.item()

        # Calculate and print epoch loss
        train_loss = running_loss / len(train_loader)
        train_acc  = correct_predictions / total_samples

        running_loss        = 0.0
        correct_predictions = 0.0
        total_samples = 0.0
        
        with th.no_grad():
            for X, y in test_loader:
                X, y = X.float(), y.float()
                
                # Zero the parameter gradients
                optimizer.zero_grad()
                
                # Forward pass
                y_hat = model(X)
                loss = criterion(y_hat, y)
                
                y_hat = y_hat.argmax(1)
                y_true = y.argmax(1)
                correct_predictions += (y_hat == y_true).sum().item()
                total_samples += y_true.size(0)
                
                running_loss += loss.item()
                
        test_loss = running_loss / len(test_loader)
        test_acc  = correct_predictions / total_samples

        train_losses.append(train_loss)
        test_losses.append(test_loss)
        train_accuracies.append(train_acc)
        test_accuracies.append(test_acc)

        bar.set_postfix({"Train Loss": train_loss, 
                        "Test Loss": test_loss,
                        "Train Acc": train_acc, 
                        "Test Acc": test_acc})

    return train_losses, test_losses, train_accuracies, test_accuracies

In [6]:
model = Baseline()
criterion = th.nn.CrossEntropyLoss()
optimizer = th.optim.Adam(model.parameters(), lr=0.001)

In [9]:
train_loop(model, train_loader, test_loader, criterion, optimizer, num_epochs=10)

100%|█| 10/10 [00:20<00:00,  2.04s/it, Train Loss=0.0738, Test Loss=4.11, Train Acc=0.989, T


([2.1095303942759833,
  1.5138242443402607,
  1.0294386098782222,
  0.7595674420396487,
  0.5013731122016907,
  0.34685775140921277,
  0.22356335942943892,
  0.17118993618836006,
  0.10774062775696318,
  0.07377193480109175],
 [2.4137590743125754,
  2.375284747874483,
  2.6018372748760465,
  2.7619812995829482,
  3.1502200999158494,
  3.3758714858521808,
  3.81685009915778,
  3.786969281257467,
  4.156646063987245,
  4.11318536007658],
 [0.33466666666666667,
  0.528,
  0.68,
  0.7693333333333333,
  0.86,
  0.9073333333333333,
  0.9506666666666667,
  0.964,
  0.982,
  0.9886666666666667],
 [0.2455611390284757,
  0.26532663316582916,
  0.27872696817420434,
  0.29279731993299835,
  0.28509212730318256,
  0.29346733668341707,
  0.28710217755443884,
  0.3082077051926298,
  0.2984924623115578,
  0.30954773869346736])