Author: Anurag Vaidya  
Date: 2/4/2022  
Lab: Polina Lab @ CSAIL  
Purpose: Create a basic cat-vs-dog classifier for the AFHQ dataset

## Notebook Structure
- Imports
- Args
- Dataset
- Model
- Training/validation loop
- Utils
- main()

---

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn 
import torch.nn.functional as Fun
import torchvision.transforms.functional as F
from torch.utils.data import Dataset
from torchvision import models
import torch.optim as optim
from torch.optim import lr_scheduler

import os
import sys
import random
from argparse import Namespace
import time, copy

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm

# import wandb
# wandb.init(project="cat-dog-styleSpace", entity="ajv012")


sys.path.append("./")
sys.path.append("../")

---

#### Args

In [15]:
# Prefer GPU index 1 (the original hard-coded choice) when the host actually
# has a second GPU; otherwise fall back to the first GPU and finally to the
# CPU, so that moving the model to `args.device` does not fail on
# single-GPU machines.
if torch.cuda.device_count() > 1:
    _device = torch.device("cuda:1")
elif torch.cuda.is_available():
    _device = torch.device("cuda:0")
else:
    _device = torch.device("cpu")

# All run-time settings of the notebook, read by the cells below.
args = Namespace(
    device=_device,
    # directories that each contain a `cat` and a `dog` sub-directory
    train_dir="../../data/afhq/train",
    val_dir="../../data/afhq/val",
    # directory that receives the checkpoints written during training
    save_path="./checkpoints",
    # seed used to shuffle the file names inside the dataset class
    seed=7,
    # class names; the position in the list is the class index
    labels=["cat", "dog"],
    batch_size=64,
    epochs=50,
    num_workers=0,
    # class index -> class name, used when titling predictions
    class_names={0: "cat", 1: "dog"},
    lr=0.0001,
    momentum=0.9,
    criterion=nn.CrossEntropyLoss(),
    # names of the optimizer / lr scheduler that main() builds
    optimizer="SGD",
    scheduler="STEP",
    scheduler_step_size=7,
    scheduler_gamma=0.1,
)


---

#### Dataset class

In [3]:
class afhq_dataset(Dataset):
    r"""
    Cat/dog image dataset read from a directory with one sub-directory per class.

    ``root_dir`` must contain the sub-directories ``cat`` and ``dog``. The class
    of a file is the second ``_``-separated token of its file name; that token
    is also the name of the sub-directory the file is read from.

    Every item is a dict ``{"inputs": image, "labels": class_index}``.
    """

    def __init__(self, root_dir, seed, labels, img_transform=None):
        r"""
        Args:
            root_dir: directory holding the ``cat`` and ``dog`` sub-directories
            seed: seed given to the global NumPy RNG before shuffling the file names
            labels: sequence of class names; the position of a name is its class index
            img_transform: optional callable applied to every loaded RGB image
        """
        self.seed = seed
        np.random.seed(self.seed)

        # this dir has two sub dirs cat and dog. Need to combine them
        self.root_dir = root_dir
        # os.listdir returns entries in arbitrary order; sorting first makes the
        # seeded shuffle below produce the same order on every machine
        self.cat_names = sorted(os.listdir(os.path.join(self.root_dir, "cat")))
        self.dog_names = sorted(os.listdir(os.path.join(self.root_dir, "dog")))
        self.all_names = np.asarray(self.cat_names + self.dog_names)
        np.random.shuffle(self.all_names)
        self.img_transform = img_transform
        self.labels = {name: index for index, name in enumerate(labels)}

    def __len__(self):
        return len(self.all_names)

    def __getitem__(self, idx):
        r"""
        Load the ``idx``-th image and return it together with its class index.

        The image is converted to RGB and, when a transform was given, passed
        through it; without a transform the RGB PIL image itself is returned.
        """
        file_name = str(self.all_names[idx])
        class_name = file_name.strip().split("_")[1]
        curr_path = os.path.join(self.root_dir, class_name, file_name)

        # convert() loads the pixel data, so the file can be closed right away
        with Image.open(curr_path) as raw_img:
            curr_img = raw_img.convert("RGB")
        curr_label = self.labels[class_name]

        if self.img_transform is not None:
            curr_img = self.img_transform(curr_img)

        return {"inputs": curr_img, "labels": curr_label}

    def viz_img(self, imgs):
        r"""
        Take a tensor or list of tensors and visualize it
        """
        if not isinstance(imgs, list):
            imgs = [imgs]
        fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
        for i, img in enumerate(imgs):
            img = img.detach()
            img = F.to_pil_image(img)
            axs[0, i].imshow(np.asarray(img))
            axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    def what_labels_mean(self):
        r"""
        Return one ``"<class name>: <class index>"`` string per known class
        """
        return ["{}: {}".format(name, index) for name, index in self.labels.items()]

---

#### Model class

In [4]:
class clf(torch.nn.Module):
    r"""
    ResNet-18 loaded with ``pretrained=True`` whose last fully connected layer
    is swapped for a new linear layer producing ``num_classes`` outputs
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        # width of the features entering the original classification head
        self.num_ftrs = backbone.fc.in_features
        backbone.fc = nn.Linear(self.num_ftrs, num_classes)
        self.model_ft = backbone

    def forward(self, x):
        # the wrapped network does all the work; return its output unchanged
        return self.model_ft(x)


---

#### Training loop

In [5]:
def _log_metrics(metrics):
    r"""
    Send ``metrics`` to Weights & Biases when it is imported and a run is active.

    The ``wandb`` import of this notebook is optional (it is commented out in
    the imports cell), so logging silently does nothing when it is unavailable.
    """
    tracker = globals().get("wandb")
    if tracker is not None and getattr(tracker, "run", None) is not None:
        tracker.log(metrics)


def train_and_val_model(model, datasets, dataloaders, device, criterion, optimizer, scheduler, PATH, num_epochs=25):
    r"""
    Train ``model`` and return it loaded with the weights of its best validation epoch.

    Every epoch runs a ``train`` phase followed by a ``val`` phase. Whenever the
    validation accuracy improves, the weights are remembered and a checkpoint
    ``checkpoint_<epoch>.pt`` is written into ``PATH``.

    Args:
        model: network to optimise; updated in place
        datasets: dict with keys ``"train"`` and ``"val"``; only the lengths are used
        dataloaders: dict with keys ``"train"`` and ``"val"`` yielding batches that
            are dicts with the keys ``"inputs"`` and ``"labels"``
        device: device the batches are moved to
        criterion: loss called as ``criterion(outputs, labels)``
        optimizer: optimizer stepped after every training batch
        scheduler: lr scheduler stepped once after every training phase
        PATH: directory that receives the checkpoints (created when missing)
        num_epochs: number of epochs to run

    Returns:
        ``model`` with the best validation weights loaded.
    """
    since = time.time()

    # torch.save does not create missing directories
    os.makedirs(PATH, exist_ok=True)

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in tqdm(range(num_epochs)):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # train mode for 'train', eval mode for 'val'

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for batch in dataloaders[phase]:
                inputs = batch["inputs"].to(device)
                labels = batch["labels"].to(device)

                if is_train:
                    # zero the parameter gradients
                    optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # statistics, kept as plain Python numbers so that nothing
                # stored or logged below is tied to a device
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()

            if is_train:
                scheduler.step()

            num_samples = len(datasets[phase])
            epoch_loss = running_loss / num_samples
            epoch_acc = running_corrects / num_samples
            _log_metrics({
                "{}_epoch_loss".format(phase): epoch_loss,
                "{}_epoch_acc".format(phase): epoch_acc,
            })

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                # save current best model into the directory given by the caller
                checkpoint_path = os.path.join(PATH, "checkpoint_{}.pt".format(epoch))
                torch.save({
                            'epoch': epoch,
                            'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'loss': epoch_loss,
                            'acc' : epoch_acc,
                            }, checkpoint_path)
                _log_metrics({"best_acc": best_acc})

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [6]:
def visualize_model(model, dataloaders, device, class_names, num_images=6,
                    mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    r"""
    Plot the first ``num_images`` validation images titled with the predicted class.

    Args:
        model: network used for the predictions; its train/eval mode is restored on exit
        dataloaders: dict whose ``"val"`` entry yields batches that are dicts with
            the key ``"inputs"``
        device: device the inputs are moved to
        class_names: mapping from class index (int) to class name
        num_images: number of images to draw, laid out in two columns
        mean, std: per-channel statistics used to undo the input normalisation
            before plotting; the defaults are the values passed to
            ``transforms.Normalize`` in ``def_transforms``. Pass ``None`` for
            either one to plot the inputs as they are.
    """
    if num_images <= 0:
        return

    was_training = model.training
    model.eval()
    images_so_far = 0
    num_rows = (num_images + 1) // 2  # round up so an odd count still fits the grid
    plt.figure()

    try:
        with torch.no_grad():
            for batch in dataloaders['val']:
                inputs = batch["inputs"].to(device)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)

                images = inputs.cpu()
                if mean is not None and std is not None:
                    # undo Normalize: x * std + mean, broadcast over (N, C, H, W)
                    scale = torch.tensor(std, dtype=images.dtype).view(1, -1, 1, 1)
                    shift = torch.tensor(mean, dtype=images.dtype).view(1, -1, 1, 1)
                    images = images * scale + shift
                images = images.clamp(0, 1)

                for j in range(images.size(0)):
                    images_so_far += 1
                    ax = plt.subplot(num_rows, 2, images_so_far)
                    ax.axis('off')
                    # .item() turns the 0-dim tensor into the int key of class_names
                    ax.set_title('predicted: {}'.format(class_names[preds[j].item()]))
                    # imshow expects (H, W, C); the tensors are (C, H, W)
                    ax.imshow(images[j].permute(1, 2, 0).numpy())

                    if images_so_far == num_images:
                        return
    finally:
        # restore the caller's mode on every exit path, including errors
        model.train(mode=was_training)

---

#### Utils

In [7]:
def def_transforms():
    r"""
    Build the image pipelines for training and validation.

    Both pipelines consist of the same steps: resize to 512, center crop to
    512, conversion to a tensor and per-channel normalisation. Two separate
    ``Compose`` objects are returned as ``(train_transforms, val_transforms)``.
    """
    def build_pipeline():
        steps = [
            transforms.Resize(512),
            transforms.CenterCrop(512),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        return transforms.Compose(steps)

    train_pipeline = build_pipeline()
    val_pipeline = build_pipeline()
    return train_pipeline, val_pipeline

def def_datasets(args, train_transforms, val_transforms):
    r"""
    Create the training and validation datasets.

    Args:
        args: settings providing ``train_dir``, ``val_dir``, ``seed`` and ``labels``
        train_transforms: transform applied to the training images
        val_transforms: transform applied to the validation images

    Returns:
        ``(datasets, dataset_sizes)``: two dicts keyed by ``"train"`` / ``"val"``
        holding the dataset objects and their number of samples respectively.
    """
    dataset_train = afhq_dataset(args.train_dir, args.seed, args.labels, train_transforms)
    dataset_val = afhq_dataset(args.val_dir, args.seed, args.labels, val_transforms)
    datasets = {"train": dataset_train, "val": dataset_val}
    # sizes are sample counts, not the dataset objects themselves
    dataset_sizes = {phase: len(dataset) for phase, dataset in datasets.items()}

    return datasets, dataset_sizes

def def_dataloaders(args, dataset_train, dataset_val):
    r"""
    Wrap both datasets in DataLoaders and return them keyed by phase name.

    Batch size and worker count come from ``args.batch_size`` and
    ``args.num_workers``. No shuffling is requested, so batches follow the
    order of the underlying datasets.
    """
    loader_options = {"batch_size": args.batch_size, "num_workers": args.num_workers}
    return {
        "train": DataLoader(dataset_train, **loader_options),
        "val": DataLoader(dataset_val, **loader_options),
    }

---

#### main()

In [18]:
def main():
    r"""
    Run the whole experiment described by the module-level ``args``.

    Builds the transforms, datasets, dataloaders, model, optimizer and lr
    scheduler, trains the model and finally plots predictions on validation
    images.

    Raises:
        ValueError: if ``args.optimizer`` or ``args.scheduler`` names an option
            this function does not know how to build.
    """
    # define transforms
    train_transforms, val_transforms = def_transforms()

    # define datasets (the sizes returned alongside are not needed here)
    datasets, _ = def_datasets(args, train_transforms, val_transforms)

    # define dataloaders
    dataloaders = def_dataloaders(args, datasets["train"], datasets["val"])

    # define model
    model = clf(len(args.labels))
    model = model.to(args.device)

    # define criterion
    criterion = args.criterion

    # define optim; fail loudly instead of hitting an unbound variable later
    if args.optimizer == "SGD":
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    else:
        raise ValueError("Unsupported optimizer: {!r}".format(args.optimizer))

    # define lr scheduler
    if args.scheduler == "STEP":
        scheduler = lr_scheduler.StepLR(optimizer, step_size=args.scheduler_step_size, gamma=args.scheduler_gamma)
    else:
        raise ValueError("Unsupported scheduler: {!r}".format(args.scheduler))

    # logging: the wandb import is optional in this notebook, so only record
    # the configuration when the module has actually been imported
    tracker = globals().get("wandb")
    if tracker is not None:
        tracker.config = {
                        "learning_rate": args.lr,
                        "epochs": args.epochs,
                        "batch_size": args.batch_size
        }

    # train and val model
    model_final = train_and_val_model(model, datasets, dataloaders, args.device,
                                     criterion, optimizer, scheduler, args.save_path, args.epochs)

    visualize_model(model_final, dataloaders, args.device, args.class_names)