In [1]:
from torch.nn import Sequential, Module, ReLU, Conv2d, Linear, MaxPool2d, LogSoftmax, NLLLoss, Dropout, BatchNorm2d, LeakyReLU, GELU, SELU, Mish, CrossEntropyLoss
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
from torch import flatten, float, no_grad
from torch.optim import Adam
import torch
import wandb
import math

In [2]:
def get_data(param, type):
    if(type.lower() == 'train'):
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomHorizontalFlip(p=0.3),
            transforms.RandomRotation(degrees=12),
            transforms.ColorJitter(),
            transforms.CenterCrop(224),
            transforms.ToTensor(), 
            transforms.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])  
        ])
        tdataset = datasets.ImageFolder(root=param['train_data_dir'], transform=transform)
        total = len(tdataset)
        train_sample = math.ceil(total*(0.8))
        val_sample = total-train_sample
        # print(total, train_sample, val_sample)
        train_dataset, validation_dataset = torch.utils.data.random_split(tdataset, [train_sample, val_sample])
        train_dataloader = DataLoader(train_dataset, batch_size=param['batch_size'], shuffle=True)
        validation_dataloader = DataLoader(validation_dataset, batch_size=param['batch_size'], shuffle=False)
        return train_dataloader, validation_dataloader
    
    else:
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(size=224),
            transforms.ToTensor(), 
            transforms.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])  
        ])
        test_dataset = datasets.ImageFolder(root=param['test_data_dir'], transform=transform)
        test_dataloader = DataLoader(test_dataset, batch_size=param['batch_size'])
        return test_dataloader

## Transfer Learning using GoogLeNet model

In [3]:
def train():

    wandb.init()
    param = wandb.config
    wandb.run.name = f'GoogLeNet_strategy_{param.strategy}_batchSz_{param.batch_size}_epochs_{param.epochs}_layersToFreeze_{param.layers_to_freeze}'

    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    pmodel = models.googlenet(pretrained=True)

    if(param['strategy'] == 'all_freeze'):
        num_features = pmodel.fc.in_features
        pmodel.fc = Linear(num_features, 10)
        for name, par in pmodel.named_parameters():
            if name not in ['fc.weight', 'fc.bias']:
                par.requires_grad = False

    elif(param['strategy'] == 'k_freeze'):
        layers_to_freeze = list(pmodel.children())[:param['layers_to_freeze']]
        for x in layers_to_freeze:
            for y in x.parameters():
                y.requires_grad = False
        num_features = pmodel.fc.in_features
        pmodel.fc = Linear(num_features, 10)
    
    else:
        num_features = pmodel.fc.in_features
        pmodel.fc = Linear(num_features, 10)
    
    
    pmodel = pmodel.to(device)
    optimizer = Adam(pmodel.parameters())
    loss_function = CrossEntropyLoss()
    train_data_loader, validation_data_loader = get_data(param, 'train')

    for epo in range(param['epochs']):
        totalTrainLoss = 0
        totalValLoss = 0
        trainCorrect = 0
        valCorrect = 0
        train_counter=0
        validation_counter=0
        pmodel.train()
        for (image, label) in train_data_loader:
            (image, label) = (image.to(device), label.to(device))
            prediction = pmodel(image)
            loss = loss_function(prediction, label)
            loss.backward()
            optimizer.step()
            totalTrainLoss += loss
            trainCorrect += (prediction.argmax(1) == label).type(float).sum().item()
            train_counter+=1

        ## Validation
        pmodel.eval()
        with no_grad():
            for (image, label) in validation_data_loader:
                (image, label) = (image.to(device), label.to(device))
                pred = pmodel(image)
                loss = loss_function(pred, label)
                totalValLoss += loss
                valCorrect += (pred.argmax(1) == label).type(float).sum().item()
                validation_counter += 1


        tr_ls = (totalTrainLoss/train_counter).cpu().detach().numpy()
        tr_acc = trainCorrect/len(train_data_loader.dataset)
        val_ls = (totalValLoss/validation_counter).cpu().detach().numpy()
        val_acc = valCorrect/len(validation_data_loader.dataset)


        print(f"Epoch --> {epo}")
        print(f"Train Loss --> {tr_ls}")
        print(f"Train Accuracy --> {tr_acc}")
        print(f"Validation Loss --> {val_ls}")
        print(f"Validation Accuracy --> {val_acc}")
        print("-----------------------------------------------------------")
        
        lg={
            'epoch': epo+1,
            'tr_accuracy': tr_acc,
            'val_accuracy': val_acc,
            'tr_loss': tr_ls,
            'val_loss': val_ls
        }
        wandb.log(lg)


    

In [61]:
# train()

In [4]:
sweep_config = {
  "method": "grid",
  "name": "PartB GoogLeNet Sweep",
  "metric": {"goal": "maximize", "name": "val_accuracy"},
  "parameters": {
    "batch_size":{"values": [32]},
    "epochs":{"values": [10]},
    "strategy":{"values": ['k_freeze']},  ## K freeze, No Freeze, all_freeze(except last)
    "layers_to_freeze": {"values": [15]},
    "train_data_dir":{"values": ["./data/train"]},
    "test_data_dir":{"values": ["./data/val"]}
  }
}

In [5]:
# wandb.init()
sweep_id = wandb.sweep(sweep_config, project="cs6910_assignment2")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Create sweep with ID: 8aebzumu
Sweep URL: https://wandb.ai/cs23m070/cs6910_assignment2/sweeps/8aebzumu


In [6]:
wandb.agent(sweep_id, function=train, count=1)
wandb.finish()

[34m[1mwandb[0m: Agent Starting Run: fwubj5xd with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	layers_to_freeze: 15
[34m[1mwandb[0m: 	strategy: all_freeze
[34m[1mwandb[0m: 	test_data_dir: ./data/val
[34m[1mwandb[0m: 	train_data_dir: ./data/train
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcs23m070[0m. Use [1m`wandb login --relogin`[0m to force relogin




5,610,154 total parameters.
10,250 training parameters.
Epoch --> 0
Train Loss --> 9.581089973449707
Train Accuracy --> 0.30225
Validation Loss --> 19.375024795532227
Validation Accuracy --> 0.2896448224112056
-----------------------------------------------------------
Epoch --> 1
Train Loss --> 22.02076530456543
Train Accuracy --> 0.341125
Validation Loss --> 27.991174697875977
Validation Accuracy --> 0.4012006003001501
-----------------------------------------------------------
Epoch --> 2
Train Loss --> 25.492965698242188
Train Accuracy --> 0.4245
Validation Loss --> 18.1383113861084
Validation Accuracy --> 0.5052526263131566
-----------------------------------------------------------
Epoch --> 3
Train Loss --> 22.986249923706055
Train Accuracy --> 0.487
Validation Loss --> 18.42642593383789
Validation Accuracy --> 0.535767883941971
-----------------------------------------------------------
Epoch --> 4
Train Loss --> 26.38643455505371
Train Accuracy --> 0.494125
Validation Loss -->

wandb: Network error (ConnectionError), entering retry loop.
[34m[1mwandb[0m: Network error (ConnectionError), entering retry loop.


Epoch --> 5
Train Loss --> 21.692285537719727
Train Accuracy --> 0.55
Validation Loss --> 23.570619583129883
Validation Accuracy --> 0.5672836418209104
-----------------------------------------------------------
Epoch --> 6
Train Loss --> 25.332639694213867
Train Accuracy --> 0.552625
Validation Loss --> 24.127452850341797
Validation Accuracy --> 0.5692846423211606
-----------------------------------------------------------
Epoch --> 7
Train Loss --> 21.726051330566406
Train Accuracy --> 0.597375
Validation Loss --> 24.587656021118164
Validation Accuracy --> 0.6173086543271635
-----------------------------------------------------------
Epoch --> 8
Train Loss --> 27.96365737915039
Train Accuracy --> 0.593625
Validation Loss --> 27.495100021362305
Validation Accuracy --> 0.6193096548274137
-----------------------------------------------------------
Epoch --> 9
Train Loss --> 24.510570526123047
Train Accuracy --> 0.625
Validation Loss --> 21.219383239746094
Validation Accuracy --> 0.64182

0,1
epoch,▁▂▃▃▄▅▆▆▇█
tr_accuracy,▁▂▄▅▅▆▆▇▇█
tr_loss,▁▆▇▆▇▆▇▆█▇
val_accuracy,▁▃▅▆▆▇▇███
val_loss,▂█▁▁▅▅▅▆█▃

0,1
epoch,10.0
tr_accuracy,0.625
tr_loss,24.51057
val_accuracy,0.64182
val_loss,21.21938
