In [99]:
import os
import random
import numpy as np
from tqdm import tqdm
from torchinfo import summary

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as tr
import torchvision.models as models

from IPython.display import clear_output

def seed_everything(seed=21):
    """
    Seed every RNG used by this notebook and make cuDNN deterministic.

    Args:
        seed (int, optional): seed applied to Python `random`, NumPy and PyTorch
            (CPU and all CUDA devices). Defaults to 21.
    """
    random.seed(seed)
    np.random.seed(seed)
    # Only affects subprocesses started later; the running interpreter's hash
    # seed is fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    # manual_seed() only seeds the current GPU; seed all of them.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True makes cuDNN choose kernels by timing them, which can differ
    # between runs and undo the determinism requested above.
    torch.backends.cudnn.benchmark = False

seed_everything()

In [None]:
# Log in to the Weights & Biases (W&B) account; run() below sends its metrics,
# example images and model watches to W&B.
import wandb
wandb.login()

In [101]:
# Loading Data - MNIST dataset
def make_loader(batch_size, train=True, shuffle=True, root='./data/MNIST', num_workers=0):
    """
    Build a DataLoader over the MNIST train or test split (downloaded on first use).

    Args:
        batch_size (int): samples per batch
        train (bool, optional): True for the training split, False for the test split. Defaults to True.
        shuffle (bool, optional): reshuffle the data every epoch. Defaults to True.
        root (str, optional): directory MNIST is downloaded to / read from. Defaults to './data/MNIST'.
        num_workers (int, optional): DataLoader worker processes; 0 loads in the main process. Defaults to 0.

    Returns:
        DataLoader yielding (images, labels) batches; images are converted with ToTensor().
    """
    dataset = torchvision.datasets.MNIST(root=root,
                                         train=train,
                                         download=True,
                                         transform=tr.ToTensor())

    return DataLoader(dataset=dataset,
                      batch_size=batch_size,
                      shuffle=shuffle,
                      num_workers=num_workers,
                      # Pinned (page-locked) host memory only speeds up host->GPU
                      # copies, so request it only when a CUDA device exists.
                      pin_memory=torch.cuda.is_available())

In [102]:
class ConvNet(nn.Module):
    """
    Two-convolution CNN that maps 1x28x28 MNIST images to 10 class logits.

    Shape trace for a (B, 1, 28, 28) batch:
        conv1 3x3 -> (B, 16, 26, 26) -> max-pool 2 -> (B, 16, 13, 13)
        conv2 3x3 -> (B, 32, 11, 11) -> max-pool 2 -> (B, 32, 5, 5)
        flatten -> (B, 800) -> fc1 -> (B, 32) -> fc2 -> (B, 10)
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 1 -> 16 -> 32 channels, 3x3 kernels, no padding.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3)
        # Classifier head over the flattened 32 x 5 x 5 feature map.
        self.fc1 = nn.Linear(in_features=32 * 5 * 5, out_features=32)
        self.fc2 = nn.Linear(in_features=32, out_features=10)

    def forward(self, x):
        # Each stage: convolution -> ReLU -> 2x2 max-pool.
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2)
        hidden = F.relu(self.fc1(x.view(-1, 32 * 5 * 5)))
        return self.fc2(hidden)

In [103]:
class MLPNet(nn.Module):
    """
    Four-layer fully connected classifier for flattened 28x28 MNIST images.

    Layer widths are 784 -> 64 -> 32 -> 16 -> 10, with ReLU between hidden layers.
    The output is raw logits with no activation, because nn.CrossEntropyLoss
    applies log-softmax itself.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 16)
        self.fc4 = nn.Linear(16, 10)

    def forward(self, x):
        # Cast to float, then flatten each 784-pixel image into one row.
        x = x.float().view(-1, 784)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        # No ReLU on the output: logits must be free to go negative. Clamping
        # them at 0 flattens the softmax and zeroes the gradient of those logits.
        return self.fc4(x)

In [104]:
# Per-layer parameter breakdown of ConvNet (expected total: 30,762 params).
summary(ConvNet())

Layer (type:depth-idx)                   Param #
ConvNet                                  --
├─Conv2d: 1-1                            160
├─Conv2d: 1-2                            4,640
├─Linear: 1-3                            25,632
├─Linear: 1-4                            330
Total params: 30,762
Trainable params: 30,762
Non-trainable params: 0

In [105]:
# Per-layer parameter breakdown of MLPNet (expected total: 53,018 params).
summary(MLPNet())

Layer (type:depth-idx)                   Param #
MLPNet                                   --
├─Linear: 1-1                            50,240
├─Linear: 1-2                            2,080
├─Linear: 1-3                            528
├─Linear: 1-4                            170
Total params: 53,018
Trainable params: 53,018
Non-trainable params: 0

In [106]:
def model_train(model,
                data_loader,
                criterion,
                optimizer,
                device,
                scheduler=None,
                tqdm_disable=False):
    """
    Train `model` for one epoch (multi-class classification).

    Args:
        model (torch model)
        data_loader (torch dataLoader)
        criterion (torch loss): assumed to return the batch-mean loss (the
            nn.CrossEntropyLoss default); it is re-weighted by batch size below.
        optimizer (torch optimizer)
        device (str): 'cpu' / 'cuda' / 'mps'
        scheduler (torch scheduler, optional): lr scheduler, stepped once per epoch. Defaults to None.
        tqdm_disable (bool, optional): if True, tqdm progress bars will be removed. Defaults to False.

    Returns:
        loss, accuracy: average loss and accuracy over the samples seen in this epoch

    Raises:
        ValueError: if data_loader yields no samples.
    """
    model.train()

    running_loss = 0.0
    correct = 0
    n_samples = 0

    for X, y in tqdm(data_loader, disable=tqdm_disable):
        X, y = X.to(device), y.to(device)

        optimizer.zero_grad()
        output = model(X)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()

        batch_size = X.size(0)
        pred = output.argmax(dim=1)  # predicted class per sample
        correct += pred.eq(y).sum().item()
        running_loss += loss.item() * batch_size  # mean -> per-sample sum
        n_samples += batch_size

    if scheduler is not None:
        scheduler.step()

    if n_samples == 0:
        raise ValueError("data_loader yielded no samples")

    # Divide by the samples actually seen, not len(dataset), so subset samplers
    # and drop_last=True give correct averages.
    accuracy = correct / n_samples
    loss = running_loss / n_samples

    return loss, accuracy


def model_evaluate(model,
                   data_loader,
                   criterion,
                   device,
                   num_examples=16):
    """
    Evaluate `model` without gradients (multi-class classification).

    Args:
        model (torch model)
        data_loader (torch dataLoader)
        criterion (torch loss): assumed to return the batch-mean loss (the
            nn.CrossEntropyLoss default); it is re-weighted by batch size below.
        device (str): 'cpu' / 'cuda' / 'mps'
        num_examples (int, optional): how many samples of the first batch to
            return for inspection. Defaults to 16.

    Returns:
        loss, accuracy, sample_batch, sample_label, sample_prediction:
            average loss and accuracy over the samples seen, plus the first
            `num_examples` inputs, labels and predictions of the first batch
            (all on `device`).

    Raises:
        ValueError: if data_loader yields no samples.
    """
    model.eval()

    running_loss = 0.0
    correct = 0
    n_samples = 0
    first_batch = None  # (inputs, labels, predictions) of the first batch

    with torch.no_grad():
        for X, y in data_loader:
            X, y = X.to(device), y.to(device)

            output = model(X)
            pred = output.argmax(dim=1)  # predicted class per sample

            batch_size = X.size(0)
            correct += pred.eq(y).sum().item()
            running_loss += criterion(output, y).item() * batch_size  # mean -> per-sample sum
            n_samples += batch_size

            if first_batch is None:
                first_batch = (X[:num_examples], y[:num_examples], pred[:num_examples])

    if n_samples == 0:
        raise ValueError("data_loader yielded no samples")

    # Divide by the samples actually seen, not len(dataset), so subset samplers
    # and drop_last=True give correct averages.
    accuracy = correct / n_samples
    loss = running_loss / n_samples
    sample_batch, sample_label, sample_prediction = first_batch

    return loss, accuracy, sample_batch, sample_label, sample_prediction

In [107]:
def map_dict_to_str(config):
    """
    Render a run config as "key: value, key: value" for use as a run name.

    The 'dataset', 'epochs' and 'batch_size' entries are left out; the
    remaining items keep the dict's insertion order.
    """
    hidden_keys = ('dataset', 'epochs', 'batch_size')
    parts = []
    for key, value in config.items():
        if key in hidden_keys:
            continue
        parts.append(f"{key}: {value}")
    return ', '.join(parts)

In [108]:
def _build_model(model_name, device):
    """
    Return a freshly initialised network for `model_name` ('CNN' or 'MLP') on `device`.

    Raises:
        ValueError: for any other model name.
    """
    if model_name == 'CNN':
        return ConvNet().to(device)
    if model_name == 'MLP':
        return MLPNet().to(device)
    raise ValueError(f"Unknown model {model_name!r}; expected 'CNN' or 'MLP'")


def _build_optimizer(optimizer_name, params, learning_rate, weight_decay):
    """
    Return the optimizer named `optimizer_name` ('sgd', 'adam' or 'adamw') over `params`.

    Raises:
        ValueError: for any other optimizer name.
    """
    optimizer_classes = {'sgd': torch.optim.SGD,
                         'adam': torch.optim.Adam,
                         'adamw': torch.optim.AdamW}
    if optimizer_name not in optimizer_classes:
        raise ValueError(f"Unknown optimizer {optimizer_name!r}; expected one of {sorted(optimizer_classes)}")
    return optimizer_classes[optimizer_name](params, lr=learning_rate, weight_decay=weight_decay)


def run(config):
    """
    Train and evaluate one hyper-parameter configuration as a single W&B run.

    Args:
        config (dict): must contain 'model' ('CNN'/'MLP'), 'optimizer'
            ('sgd'/'adam'/'adamw'), 'learning_rate', 'weight_decay', 'epochs'
            and 'batch_size'. The whole dict is stored as the W&B run config.

    Side effects:
        Logs per-epoch metrics and example predictions to W&B. Writes the
        lowest-validation-loss weights to 'Best_Model.pth' (overwritten on
        every call). Always finishes the W&B run before returning or raising.

    Returns:
        str: 'Done'
    """
    checkpoint_path = 'Best_Model.pth'
    run_name = map_dict_to_str(config)

    wandb.init(project='PROJECT NAME', config=config)
    wandb.run.name = run_name

    print('------')
    print(run_name)
    print('------\n')

    try:
        config = wandb.config
        device = "cuda" if torch.cuda.is_available() else "cpu"

        train_loader = make_loader(batch_size=config.batch_size, train=True)
        # Fixed order, so the first evaluation batch (logged as "examples") is
        # the same images every epoch.
        test_loader = make_loader(batch_size=config.batch_size, train=False, shuffle=False)

        model = _build_model(config.model, device)
        criterion = nn.CrossEntropyLoss()
        optimizer = _build_optimizer(config.optimizer, model.parameters(),
                                     config.learning_rate, config.weight_decay)

        wandb.watch(model, criterion, log="all")

        best_loss = np.inf
        checkpoint_saved = False

        for epoch in range(config.epochs):
            step = epoch + 1
            train_loss, train_acc = model_train(model, train_loader, criterion, optimizer, device, None)
            val_loss, val_acc, sample_batch, sample_label, sample_prediction = model_evaluate(model, test_loader, criterion, device)

            # .item() so captions read "Pred: 7" rather than "Pred: tensor(7, device='cuda:0')".
            examples = [wandb.Image(image.cpu(), caption=f"Pred: {pred.item()}, Label: {label.item()}")
                        for image, pred, label in zip(sample_batch, sample_prediction, sample_label)]
            wandb.log({"Train Loss": train_loss,
                       "Train Accuracy": train_acc,
                       "Validation Loss": val_loss,
                       "Validation Accuracy": val_acc,
                       "examples": examples}, step=step)

            if val_loss < best_loss:
                print(f'[INFO] val_loss has been improved from {best_loss:.5f} to {val_loss:.5f}. Save model.')
                best_loss = val_loss
                torch.save(model.state_dict(), checkpoint_path)
                checkpoint_saved = True

            print(f'epoch {step:02d}, loss: {train_loss:.5f}, acc: {train_acc:.5f}, val_loss: {val_loss:.5f}, val_accuracy: {val_acc:.5f} \n')

        if not checkpoint_saved:
            # val_loss never beat +inf (e.g. NaN every epoch, or epochs == 0).
            # Save the current weights so a previous run's checkpoint is not
            # silently reloaded and reported as this run's result.
            print('[WARN] val_loss never improved; saving the final weights instead.')
            torch.save(model.state_dict(), checkpoint_path)

        best_model = _build_model(config.model, device)
        wandb.log({'Total Params': sum(p.numel() for p in best_model.parameters())})

        best_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        best_model.eval()
        test_loss, test_acc, _, _, _ = model_evaluate(best_model, test_loader, criterion, device)

        print('Test Loss: %s' % test_loss)
        print('Test Accuracy: %s' % test_acc)
        print()

        wandb.log({"Best Test Loss": test_loss, "Best Test Accuracy": test_acc})
    finally:
        # Close this run so the next run(config) call starts a separate W&B run.
        wandb.finish()

    return 'Done'

In [None]:
# Hyper-parameter grid: every combination is trained and logged as its own W&B run.
model_list = ['CNN', 'MLP']
optimizer_list = ['sgd', 'adam', 'adamw']
learning_rate_list = [1e-2, 1e-3, 1e-4]
weight_decay_list = [0, 1e-2]

EPOCHS = 10
BATCH_SIZE = 64

for model_name in model_list:
    for optimizer_name in optimizer_list:
        for learning_rate in learning_rate_list:
            for weight_decay in weight_decay_list:
                config = {'dataset': 'MNIST',
                          'model': model_name,
                          'epochs': EPOCHS,
                          'batch_size': BATCH_SIZE,
                          'optimizer': optimizer_name,
                          'learning_rate': learning_rate,
                          'weight_decay': weight_decay}

                # Re-seed before every run so each configuration starts from the
                # same RNG state, instead of depending on how many runs came
                # before it. This makes the configurations comparable.
                seed_everything()
                run(config)
                clear_output(wait=True)