Author: Anurag Vaidya  
Date: 2/4/2022  
Lab: Polina Lab @ CSAIL  
Purpose: Create a basic classifier for the afhq dataset

## Notebook Structure
- Imports
- Args
- Dataset
- Model
- Training/ Val loop
- main()

---

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn 
import torch.nn.functional as Fun
import torchvision.transforms.functional as F
from torch.utils.data import Dataset
from torchvision import models
import torch.optim as optim
from torch.optim import lr_scheduler

import os
import sys
import random
from argparse import Namespace
import time, copy

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm

# import wandb
# wandb.init(project="cat-dog-styleSpace", entity="ajv012")


sys.path.append("./")
sys.path.append("../")

---

#### Args

In [2]:
args = Namespace(device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu"),
                 train_dir = "../data/afhq/train",
                 val_dir = "../data/afhq/val",
                 save_path = "./checkpoints",
                 seed = 7,
                 labels = ["cat", "dog"],
                 batch_size = 64,
                 epochs = 50,
                 num_workers = 0,
                 class_names = {0:"cat", 1:"dog"} ,
                 lr = 0.0001,
                 momentum = 0.9,
                 criterion = nn.CrossEntropyLoss(),
                 optimizer = "SGD",
                 scheduler = "STEP",
                 scheduler_step_size = 7,
                 scheduler_gamma = 0.1,
    )


---

#### Dataset class

In [3]:
class afhq_dataset(Dataset):
    r"""
    Take a root dir and return the transformed img and associated label with it
    """
    def __init__(self, root_dir, seed, labels, img_transform=None):

        self.seed = seed
        np.random.seed(self.seed)

        # this dir has two sub dirs cat and dog. Need to combine them
        self.root_dir = root_dir
        self.cat_names = os.listdir(os.path.join(self.root_dir, "cat"))
        self.dog_names = os.listdir(os.path.join(self.root_dir, "dog"))
        self.all_names = np.asarray(self.cat_names + self.dog_names)
        np.random.shuffle(self.all_names)
        self.img_transform = img_transform
        self.labels = {}
        for i in range(len(labels)):
            self.labels[labels[i]] = i
        

    def __len__(self):
        return len(self.all_names)

    def __getitem__(self, idx):
        curr_path = os.path.join(self.root_dir, self.all_names[idx].strip().split("_")[1], self.all_names[idx])
        curr_img = Image.open(curr_path)
        curr_label = self.labels[self.all_names[idx].strip().split("_")[1]]
        
        if self.img_transform:
            curr_img_transformed = self.img_transform(curr_img)
        
        return {"inputs" : curr_img_transformed, "labels" : curr_label} 
    
    def viz_img(self, imgs):
        r"""
        Take a tensor or list of tensors and visualize it
        """
        if not isinstance(imgs, list):
            imgs = [imgs]
        fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
        for i, img in enumerate(imgs):
            img = img.detach()
            img = F.to_pil_image(img)
            axs[0, i].imshow(np.asarray(img))
            axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    
    def what_labels_mean(self):
        return [label + ": " + str(self.labels[label]) for label in self.labels]
        

---

#### Model class

In [4]:
class clf(torch.nn.Module):
    r"""
    A simple encoder and fully connected layer for classification
    """

    def __init__(self, num_classes):
        super(clf, self).__init__()
        self.model_ft = models.resnet18(pretrained=True)
        self.num_ftrs = self.model_ft.fc.in_features
        self.model_ft.fc = nn.Linear(self.num_ftrs, num_classes)

    def forward(self, x):
        x = self.model_ft(x)
        return x


---

#### Training loop

In [5]:
def train_and_val_model(model, datasets, dataloaders, device, criterion, optimizer, scheduler, PATH, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in tqdm(range(num_epochs)):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        running_loss_train = 0.0
        running_corrects_train = 0

        running_loss_val = 0.0
        running_corrects_val = 0

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            # Iterate over data.
            for batch in dataloaders[phase]:
                inputs, labels = batch["inputs"], batch["labels"]
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                if phase == "train":
                    running_loss_train += loss.item() * inputs.size(0)
                    running_corrects_train += torch.sum(preds == labels.data)
                else:
                    running_loss_val += loss.item() * inputs.size(0)
                    running_corrects_val += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            if phase == "train":
                epoch_loss = running_loss_train / len(datasets[phase])
                epoch_acc = running_corrects_train.double() / len(datasets[phase])
                wandb.log({"train_epoch_loss": epoch_loss, "train_epoch_acc": epoch_acc})
            else:
                epoch_loss = running_loss_val / len(datasets[phase])
                epoch_acc = running_corrects_val.double() / len(datasets[phase])
                wandb.log({"val_epoch_loss": epoch_loss, "val_epoch_acc": epoch_acc})

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                # save current best model
                PATH = "{}/checkpoint_{}.pt".format(args.save_path, epoch)
                torch.save({
                            'epoch': epoch,
                            'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'loss': epoch_loss,
                            'acc' : epoch_acc,
                            }, PATH)
                wandb.log({"best_acc": best_acc})

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [6]:
def visualize_model(model, dataloaders, device, class_names, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for batch in enumerate(dataloaders['val']):
            inputs, labels = batch["inputs"], batch["labels"]
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images//2, 2, images_so_far)
                ax.axis('off')
                ax.set_title('predicted: {}'.format(class_names[preds[j]]))
                plt.imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return
        model.train(mode=was_training)

---

#### Utils

In [7]:
def def_transforms():
    train_transforms = transforms.Compose([
    transforms.Resize(512),
    transforms.CenterCrop(512),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    val_transforms = transforms.Compose([
        transforms.Resize(512),
        transforms.CenterCrop(512),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    return train_transforms, val_transforms

def def_datasets(args, train_transforms, val_transforms):
    dataset_train = afhq_dataset(args.train_dir, args.seed, args.labels, train_transforms)
    dataset_val = afhq_dataset(args.val_dir, args.seed, args.labels, val_transforms)
    datasets = {"train": dataset_train, "val": dataset_val}
    dataset_sizes = {x: datasets[x] for x in ['train', 'val']}

    return datasets, dataset_sizes

def def_dataloaders(args, dataset_train, dataset_val):
    dataloader_train = DataLoader(dataset_train, batch_size = args.batch_size, num_workers=args.num_workers)
    dataloader_val = DataLoader(dataset_val, batch_size = args.batch_size, num_workers=args.num_workers)
    dataloaders = {"train": dataloader_train, "val":dataloader_val}

    return dataloaders

---

#### main()

In [8]:
def main():

    # define transforms
    train_transforms, val_transforms = def_transforms()

    # define datasets and sizes
    datasets, dataset_sizes = def_datasets(args, train_transforms, val_transforms)

    # define dataloaders
    dataloaders = def_dataloaders(args, datasets["train"], datasets["val"])
    
    # define model
    model = clf(len(args.labels))
    model = model.to(args.device)
    
    # define criterion
    criterion = args.criterion
    
    # define optim
    if args.optimizer == "SGD":
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    
    # define lr scheduler
    if args.scheduler == "STEP":
        scheduler = lr_scheduler.StepLR(optimizer, step_size=args.scheduler_step_size, gamma=args.scheduler_gamma)

    # logging
    wandb.config = {
                    "learning_rate": args.lr,
                    "epochs": args.epochs,
                    "batch_size": args.batch_size
    }
    
    # train and val model
    model_final = train_and_val_model(model, datasets, dataloaders, args.device, 
                                     criterion, optimizer, scheduler, args.save_path, args.epochs)

    visualize_model(model_final, dataloaders, args.device, args.class_names)

In [9]:
main()

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 0/49
----------
train Loss: 0.3838 Acc: 0.9092
val Loss: 0.1920 Acc: 0.9940


  2%|▏         | 1/50 [02:37<2:08:15, 157.04s/it]


Epoch 1/49
----------
train Loss: 0.1434 Acc: 0.9929
val Loss: 0.0911 Acc: 0.9980


  4%|▍         | 2/50 [05:10<2:04:06, 155.13s/it]


Epoch 2/49
----------
train Loss: 0.0813 Acc: 0.9969
val Loss: 0.0562 Acc: 1.0000


  6%|▌         | 3/50 [07:51<2:03:25, 157.57s/it]


Epoch 3/49
----------
train Loss: 0.0554 Acc: 0.9975


  8%|▊         | 4/50 [10:21<1:58:34, 154.66s/it]

val Loss: 0.0395 Acc: 1.0000

Epoch 4/49
----------
train Loss: 0.0417 Acc: 0.9981


 10%|█         | 5/50 [12:51<1:54:36, 152.81s/it]

val Loss: 0.0301 Acc: 1.0000

Epoch 5/49
----------
train Loss: 0.0333 Acc: 0.9985


 12%|█▏        | 6/50 [15:23<1:51:59, 152.72s/it]

val Loss: 0.0241 Acc: 1.0000

Epoch 6/49
----------
train Loss: 0.0276 Acc: 0.9986


 14%|█▍        | 7/50 [17:55<1:49:11, 152.35s/it]

val Loss: 0.0200 Acc: 1.0000

Epoch 7/49
----------
train Loss: 0.0251 Acc: 0.9986


 16%|█▌        | 8/50 [20:25<1:46:14, 151.77s/it]

val Loss: 0.0197 Acc: 1.0000

Epoch 8/49
----------
train Loss: 0.0247 Acc: 0.9987


 18%|█▊        | 9/50 [22:57<1:43:48, 151.92s/it]

val Loss: 0.0194 Acc: 1.0000

Epoch 9/49
----------
train Loss: 0.0243 Acc: 0.9987


 20%|██        | 10/50 [25:29<1:41:07, 151.68s/it]

val Loss: 0.0191 Acc: 1.0000

Epoch 10/49
----------
train Loss: 0.0240 Acc: 0.9987


 22%|██▏       | 11/50 [27:59<1:38:20, 151.31s/it]

val Loss: 0.0188 Acc: 1.0000

Epoch 11/49
----------
train Loss: 0.0236 Acc: 0.9987


 24%|██▍       | 12/50 [30:30<1:35:43, 151.14s/it]

val Loss: 0.0185 Acc: 1.0000

Epoch 12/49
----------
train Loss: 0.0233 Acc: 0.9987


 26%|██▌       | 13/50 [32:59<1:32:49, 150.53s/it]

val Loss: 0.0182 Acc: 1.0000

Epoch 13/49
----------
train Loss: 0.0230 Acc: 0.9987


 28%|██▊       | 14/50 [35:28<1:30:05, 150.16s/it]

val Loss: 0.0179 Acc: 1.0000

Epoch 14/49
----------
train Loss: 0.0228 Acc: 0.9987


 30%|███       | 15/50 [37:59<1:27:37, 150.22s/it]

val Loss: 0.0179 Acc: 1.0000

Epoch 15/49
----------
train Loss: 0.0228 Acc: 0.9987


 32%|███▏      | 16/50 [40:30<1:25:18, 150.54s/it]

val Loss: 0.0178 Acc: 1.0000

Epoch 16/49
----------
train Loss: 0.0227 Acc: 0.9987


 34%|███▍      | 17/50 [43:03<1:23:09, 151.19s/it]

val Loss: 0.0178 Acc: 1.0000

Epoch 17/49
----------
train Loss: 0.0227 Acc: 0.9987


 36%|███▌      | 18/50 [45:31<1:20:13, 150.43s/it]

val Loss: 0.0178 Acc: 1.0000

Epoch 18/49
----------
train Loss: 0.0227 Acc: 0.9987


 38%|███▊      | 19/50 [48:01<1:17:39, 150.32s/it]

val Loss: 0.0178 Acc: 1.0000

Epoch 19/49
----------
train Loss: 0.0226 Acc: 0.9987


 40%|████      | 20/50 [50:32<1:15:08, 150.30s/it]

val Loss: 0.0177 Acc: 1.0000

Epoch 20/49
----------
train Loss: 0.0226 Acc: 0.9987


 42%|████▏     | 21/50 [53:01<1:12:32, 150.09s/it]

val Loss: 0.0177 Acc: 1.0000

Epoch 21/49
----------
train Loss: 0.0226 Acc: 0.9987


 44%|████▍     | 22/50 [55:32<1:10:06, 150.22s/it]

val Loss: 0.0177 Acc: 1.0000

Epoch 22/49
----------
train Loss: 0.0226 Acc: 0.9987


 46%|████▌     | 23/50 [58:01<1:07:25, 149.85s/it]

val Loss: 0.0177 Acc: 1.0000

Epoch 23/49
----------
train Loss: 0.0226 Acc: 0.9987


 48%|████▊     | 24/50 [1:00:32<1:05:11, 150.44s/it]

val Loss: 0.0177 Acc: 1.0000

Epoch 24/49
----------
train Loss: 0.0226 Acc: 0.9987


 50%|█████     | 25/50 [1:03:02<1:02:34, 150.16s/it]

val Loss: 0.0177 Acc: 1.0000

Epoch 25/49
----------
train Loss: 0.0226 Acc: 0.9987


 52%|█████▏    | 26/50 [1:05:31<59:53, 149.75s/it]  

val Loss: 0.0177 Acc: 1.0000

Epoch 26/49
----------
train Loss: 0.0226 Acc: 0.9987


 54%|█████▍    | 27/50 [1:08:01<57:25, 149.82s/it]

val Loss: 0.0177 Acc: 1.0000

Epoch 27/49
----------
train Loss: 0.0226 Acc: 0.9987


 56%|█████▌    | 28/50 [1:10:31<55:01, 150.07s/it]

val Loss: 0.0177 Acc: 1.0000

Epoch 28/49
----------
train Loss: 0.0226 Acc: 0.9987


 58%|█████▊    | 29/50 [1:13:01<52:28, 149.94s/it]

val Loss: 0.0177 Acc: 1.0000

Epoch 29/49
----------
train Loss: 0.0226 Acc: 0.9987


 60%|██████    | 30/50 [1:15:33<50:10, 150.54s/it]

val Loss: 0.0177 Acc: 1.0000

Epoch 30/49
----------
train Loss: 0.0226 Acc: 0.9987


 62%|██████▏   | 31/50 [1:18:03<47:36, 150.37s/it]

val Loss: 0.0177 Acc: 1.0000

Epoch 31/49
----------
train Loss: 0.0226 Acc: 0.9987


 64%|██████▍   | 32/50 [1:20:34<45:08, 150.49s/it]

val Loss: 0.0177 Acc: 1.0000

Epoch 32/49
----------
train Loss: 0.0226 Acc: 0.9987


 66%|██████▌   | 33/50 [1:23:04<42:38, 150.50s/it]

val Loss: 0.0177 Acc: 1.0000

Epoch 33/49
----------
train Loss: 0.0226 Acc: 0.9987


 68%|██████▊   | 34/50 [1:25:36<40:15, 150.95s/it]

val Loss: 0.0177 Acc: 1.0000

Epoch 34/49
----------
train Loss: 0.0226 Acc: 0.9987


 70%|███████   | 35/50 [1:28:07<37:42, 150.84s/it]

val Loss: 0.0177 Acc: 1.0000

Epoch 35/49
----------
train Loss: 0.0226 Acc: 0.9987


 72%|███████▏  | 36/50 [1:30:38<35:15, 151.09s/it]

val Loss: 0.0177 Acc: 1.0000

Epoch 36/49
----------
train Loss: 0.0226 Acc: 0.9987


 74%|███████▍  | 37/50 [1:33:07<32:36, 150.46s/it]

val Loss: 0.0177 Acc: 1.0000

Epoch 37/49
----------
train Loss: 0.0226 Acc: 0.9987


 76%|███████▌  | 38/50 [1:35:37<30:00, 150.05s/it]

val Loss: 0.0177 Acc: 1.0000

Epoch 38/49
----------
train Loss: 0.0226 Acc: 0.9987


 78%|███████▊  | 39/50 [1:38:07<27:33, 150.31s/it]

val Loss: 0.0177 Acc: 1.0000

Epoch 39/49
----------
train Loss: 0.0226 Acc: 0.9987


 80%|████████  | 40/50 [1:40:38<25:03, 150.34s/it]

val Loss: 0.0177 Acc: 1.0000

Epoch 40/49
----------
train Loss: 0.0226 Acc: 0.9987


 82%|████████▏ | 41/50 [1:43:08<22:31, 150.17s/it]

val Loss: 0.0177 Acc: 1.0000

Epoch 41/49
----------
train Loss: 0.0226 Acc: 0.9987


 84%|████████▍ | 42/50 [1:45:37<20:00, 150.05s/it]

val Loss: 0.0177 Acc: 1.0000

Epoch 42/49
----------
train Loss: 0.0226 Acc: 0.9987


 86%|████████▌ | 43/50 [1:48:09<17:33, 150.47s/it]

val Loss: 0.0177 Acc: 1.0000

Epoch 43/49
----------
train Loss: 0.0226 Acc: 0.9987


 88%|████████▊ | 44/50 [1:50:41<15:05, 150.93s/it]

val Loss: 0.0177 Acc: 1.0000

Epoch 44/49
----------
train Loss: 0.0226 Acc: 0.9987


 90%|█████████ | 45/50 [1:53:11<12:33, 150.71s/it]

val Loss: 0.0177 Acc: 1.0000

Epoch 45/49
----------
train Loss: 0.0226 Acc: 0.9987


 92%|█████████▏| 46/50 [1:55:47<10:08, 152.24s/it]

val Loss: 0.0177 Acc: 1.0000

Epoch 46/49
----------
train Loss: 0.0226 Acc: 0.9987


 94%|█████████▍| 47/50 [1:58:17<07:35, 151.74s/it]

val Loss: 0.0177 Acc: 1.0000

Epoch 47/49
----------
train Loss: 0.0226 Acc: 0.9987


 96%|█████████▌| 48/50 [2:00:49<05:03, 151.77s/it]

val Loss: 0.0177 Acc: 1.0000

Epoch 48/49
----------
train Loss: 0.0226 Acc: 0.9987


 98%|█████████▊| 49/50 [2:03:19<02:31, 151.28s/it]

val Loss: 0.0177 Acc: 1.0000

Epoch 49/49
----------
train Loss: 0.0226 Acc: 0.9987


100%|██████████| 50/50 [2:05:50<00:00, 151.00s/it]

val Loss: 0.0177 Acc: 1.0000

Training complete in 125m 50s
Best val Acc: 1.000000





TypeError: tuple indices must be integers or slices, not str

<Figure size 432x288 with 0 Axes>