In [26]:
from torch.nn import Sequential, Module, ReLU, Conv2d, Linear, MaxPool2d, LogSoftmax, NLLLoss, Dropout, BatchNorm2d, LeakyReLU, GELU, SELU, Mish, CrossEntropyLoss
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
from torch import flatten, float, no_grad
from torch.optim import Adam
import torch
import wandb
import math

In [27]:
def get_data(param, type):
    if(type.lower() == 'train'):
        transform = transforms.Compose([
            transforms.Resize(256),
            # transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(p=0.3),
            transforms.RandomRotation(degrees=12),
            transforms.ColorJitter(),
            transforms.CenterCrop(224),
            transforms.ToTensor(), 
            transforms.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])  
        ])
        # print("hi")
        # print(param)
        tdataset = datasets.ImageFolder(root=param['train_data_dir'], transform=transform)
        total = len(tdataset)
        train_sample = math.ceil(total*(0.8))
        val_sample = total-train_sample
        # print(total, train_sample, val_sample)
        train_dataset, validation_dataset = torch.utils.data.random_split(tdataset, [train_sample, val_sample])
        train_dataloader = DataLoader(train_dataset, batch_size=param['batch_size'], shuffle=True)
        validation_dataloader = DataLoader(validation_dataset, batch_size=param['batch_size'], shuffle=False)
        return train_dataloader, validation_dataloader
    
    else:
        transform = transforms.Compose([
            transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
            transforms.CenterCrop(size=224),
            transforms.ToTensor(), 
            transforms.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])  
        ])
        test_dataset = datasets.ImageFolder(root=param['test_data_dir'], transform=transform)
        test_dataloader = DataLoader(test_dataset, batch_size=param['batch_size'])
        return test_dataloader

## Transfer Learning using GoogLeNet model

In [62]:
def train():

    wandb.init()
    param = wandb.config
    wandb.run.name = f'GoogLeNet_strategy_{param.strategy}_batchSz_{param.batch_size}_epochs_{param.epochs}_layersToFreeze_{param.layers_to_freeze}'

    # param = {
    #     "batch_size": 32,
    #     "epochs": 5,
    #     "train_data_dir": "./data/train",
    #     "test_data_dir": "./data/val",
    #     "strategy": "no_freeze",
    #     "layers_to_freeze": 15
    # }


    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    pmodel = models.googlenet(pretrained=True)

    if(param['strategy'] == 'all_freeze'):
        num_features = pmodel.fc.in_features
        pmodel.fc = Linear(num_features, 10)

        for name, par in pmodel.named_parameters():
            if name not in ['fc.weight', 'fc.bias']:
                par.requires_grad = False

    elif(param['strategy'] == 'k_freeze'):
        layers_to_freeze = list(pmodel.children())[:param['layers_to_freeze']]
        for x in layers_to_freeze:
            for y in x.parameters():
                y.requires_grad = False

        num_features = pmodel.fc.in_features
        pmodel.fc = Linear(num_features, 10)
    
    else:
        num_features = pmodel.fc.in_features
        pmodel.fc = Linear(num_features, 10)
    

    total_params = sum(p.numel() for p in pmodel.parameters())
    print(f'{total_params:,} total parameters.')
    total_trainable_params = sum(
        p.numel() for p in pmodel.parameters() if p.requires_grad)
    print(f'{total_trainable_params:,} training parameters.')
    
    pmodel = pmodel.to(device)
    optimizer = Adam(pmodel.parameters())
    loss_function = CrossEntropyLoss()
    train_data_loader, validation_data_loader = get_data(param, 'train')

    for epo in range(param['epochs']):
        totalTrainLoss = 0
        totalValLoss = 0
        trainCorrect = 0
        valCorrect = 0
        train_counter=0
        validation_counter=0
        pmodel.train()
        for (image, label) in train_data_loader:
            (image, label) = (image.to(device), label.to(device))
            prediction = pmodel(image)
            loss = loss_function(prediction, label)
            ## no optimize.zero_grad() ...
            loss.backward()
            optimizer.step()

            totalTrainLoss += loss
            trainCorrect += (prediction.argmax(1) == label).type(float).sum().item()
            train_counter+=1
            # print(train_counter)

        pmodel.eval()
        with no_grad():
            for (image, label) in validation_data_loader:
                (image, label) = (image.to(device), label.to(device))
                pred = pmodel(image)
                loss = loss_function(pred, label)
                totalValLoss += loss
                valCorrect += (pred.argmax(1) == label).type(float).sum().item()
                validation_counter += 1

        tr_ls = (totalTrainLoss/train_counter).cpu().detach().numpy()
        tr_acc = trainCorrect/len(train_data_loader.dataset)
        val_ls = (totalValLoss/validation_counter).cpu().detach().numpy()
        val_acc = valCorrect/len(validation_data_loader.dataset)
        print(f"Epoch --> {epo}")
        print(f"Train Loss --> {tr_ls}")
        print(f"Train Accuracy --> {tr_acc}")
        print(f"Validation Loss --> {val_ls}")
        print(f"Validation Accuracy --> {val_acc}")
        print("-----------------------------------------------------------")
        
        lg={
            'epoch': epo+1,
            'tr_accuracy': tr_acc,
            'val_accuracy': val_acc,
            'tr_loss': tr_ls,
            'val_loss': val_ls
        }
        wandb.log(lg)

    # torch.save(model, checkpoint_path)

    

In [61]:
# train()

In [69]:
sweep_config = {
  "method": "grid",  # Use grid search for this example
  "name": "PartB GoogLeNet Sweep",
  "metric": {"goal": "maximize", "name": "val_accuracy"},
  "parameters": {
    "batch_size":{"values": [32]},
    "epochs":{"values": [10]},
    "strategy":{"values": ['all_freeze']},  ## K freeze, No Freeze, all_freeze(except last)
    "layers_to_freeze": {"values": [15]},
    "train_data_dir":{"values": ["./data/train"]},
    "test_data_dir":{"values": ["./data/val"]}
  }
}

In [70]:
# wandb.init()
sweep_id = wandb.sweep(sweep_config, project="cs6910_assignment2")

Create sweep with ID: xg5b5wxw
Sweep URL: https://wandb.ai/cs23m070/cs6910_assignment2/sweeps/xg5b5wxw


In [68]:
wandb.agent(sweep_id, function=train, count=1)
wandb.finish()

[34m[1mwandb[0m: Agent Starting Run: azuzwgpj with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	layers_to_freeze: 15
[34m[1mwandb[0m: 	strategy: no_freeze
[34m[1mwandb[0m: 	test_data_dir: ./data/val
[34m[1mwandb[0m: 	train_data_dir: ./data/train
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.




5,610,154 total parameters.
5,610,154 training parameters.
Epoch --> 0
Train Loss --> 2.8371622562408447
Train Accuracy --> 0.120125
Validation Loss --> 62.41395950317383
Validation Accuracy --> 0.11055527763881941
-----------------------------------------------------------
Epoch --> 1
Train Loss --> 2.6379847526550293
Train Accuracy --> 0.104625
Validation Loss --> 5.451559543609619
Validation Accuracy --> 0.10305152576288144
-----------------------------------------------------------
Epoch --> 2
Train Loss --> 2.3632090091705322
Train Accuracy --> 0.098125
Validation Loss --> 2.8544845581054688
Validation Accuracy --> 0.11105552776388194
-----------------------------------------------------------
Epoch --> 3
Train Loss --> 2.3623974323272705
Train Accuracy --> 0.10525
Validation Loss --> 6.597944259643555
Validation Accuracy --> 0.10855427713856929
-----------------------------------------------------------
Epoch --> 4
Train Loss --> 3.009793758392334
Train Accuracy --> 0.10125
Valid

0,1
epoch,▁▂▃▃▄▅▆▆▇█
tr_accuracy,▅▂▁▂▂▂▂▅█▅
tr_loss,▆▄▂▂█▄▃▂▁▁
val_accuracy,▄▃▄▄▃▁▃▆█▆
val_loss,█▁▁▂▁▁▁▁▁▁

0,1
epoch,10.0
tr_accuracy,0.1225
tr_loss,2.30267
val_accuracy,0.12556
val_loss,2.2941
