In [1]:
%matplotlib inline
import logging

import os
import random
import shutil
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set(style='darkgrid', font_scale=1.4)
from glob import glob
import wandb

# Pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import Dataset, DataLoader, random_split, Dataset
from torchvision import datasets, transforms, models
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
from torchvision.transforms import ToTensor, RandomCrop

from tqdm import tqdm
from sklearn.manifold import TSNE
from PIL import Image
import warnings
warnings.filterwarnings('ignore')
import statistics


import torchvision
import umap
from cycler import cycler

import pytorch_metric_learning
import pytorch_metric_learning.utils.logging_presets as logging_presets
from pytorch_metric_learning import losses, miners, samplers, testers, trainers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.utils.inference import InferenceModel, MatchFinder
from pytorch_metric_learning.utils import common_functions as c_f

logging.getLogger().setLevel(logging.INFO)
logging.info("VERSION %s" % pytorch_metric_learning.__version__)

import fuzzymatcher
from fuzzymatcher import link_table, fuzzy_left_join

INFO:root:VERSION 2.1.2


In [63]:
path = '/var/scratch/mxiao/data/'
os.chdir(path)

In [78]:
# Set the image transforms
normalize = transforms.Normalize(mean=[0.6195012,0.6195012,0.6195012], std=[0.3307451,0.3307451,0.3307451])
# normalize = transforms.Normalize(mean=[0.53997546,0.53997546,0.53997546], std=[0.36844322,0.36844322,0.36844322])

train_transform = transforms.Compose([
        transforms.RandomRotation(10),      # rotate +/- 10 degrees
        transforms.RandomHorizontalFlip(),  # reverse 50% of images
        transforms.RandomVerticalFlip(p=0.5),
#         transforms.Resize(224),             # resize shortest side to 224 pixels
#         transforms.CenterCrop(224),         # crop longest side to 224 pixels at center
        transforms.RandomCrop(size=(224,224),pad_if_needed=True), 
        transforms.ToTensor(),
        normalize
    ])

test_transform = transforms.Compose([
        transforms.RandomCrop((224,224),pad_if_needed=True),
        transforms.ToTensor(),
        normalize
    ])

In [88]:
batch_size = 32

dataset1 = datasets.ImageFolder(root=("model/train"),transform=train_transform)
dataset2 = datasets.ImageFolder(root=("model/test"),transform=test_transform)

class_dict = {i: class_name for i, class_name in enumerate(dataset1.classes)}

In [89]:
train_loader = DataLoader(dataset1, batch_size=32, shuffle=True,num_workers=4)
# val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False,num_workers=4)
test_loader = DataLoader(dataset2, batch_size=32, shuffle=False,num_workers=4)

In [90]:
# class MLP(nn.Module):
#     # layer_sizes[0] is the dimension of the input
#     # layer_sizes[-1] is the dimension of the output
#     def __init__(self, layer_sizes, final_relu=False):
#         super().__init__()
#         layer_list = []
#         layer_sizes = [int(x) for x in layer_sizes]
#         num_layers = len(layer_sizes) - 1
#         final_relu_layer = num_layers if final_relu else num_layers - 1
#         for i in range(len(layer_sizes) - 1):
#             input_size = layer_sizes[i]
#             curr_size = layer_sizes[i + 1]
#             if i < final_relu_layer:
#                 layer_list.append(nn.ReLU(inplace=False))
#             layer_list.append(nn.Linear(input_size, curr_size))
#         self.net = nn.Sequential(*layer_list)
#         self.last_linear = self.net[-1]

#     def forward(self, x):
#         return self.net(x)
class MLP(nn.Module):
    # layer_sizes[0] is the dimension of the input
    # layer_sizes[-1] is the dimension of the output
    def __init__(self, layer_sizes, final_relu=False, dropout_rate=0.5): # you can adjust dropout_rate as per your needs
        super().__init__()
        layer_list = []
        layer_sizes = [int(x) for x in layer_sizes]
        num_layers = len(layer_sizes) - 1
        final_relu_layer = num_layers if final_relu else num_layers - 1
        for i in range(len(layer_sizes) - 1):
            input_size = layer_sizes[i]
            curr_size = layer_sizes[i + 1]
            if i < final_relu_layer:
                layer_list.append(nn.ReLU(inplace=False))
                layer_list.append(nn.Dropout(dropout_rate))  # add dropout layer after ReLU
            layer_list.append(nn.Linear(input_size, curr_size))
        self.net = nn.Sequential(*layer_list)
        self.last_linear = self.net[-1]

    def forward(self, x):
        return self.net(x)


In [91]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [92]:
# Resnet
# Set trunk model and replace the softmax layer with an identity function
trunk = torchvision.models.resnet50(pretrained=True).to(device)
trunk_output_size = trunk.fc.in_features
trunk.fc = nn.Identity()  #

In [93]:
import glob
best_trunk_weights = glob.glob('PML_v13/trunk_best*.pth'.format('Resnet50'))[0]
trunk.load_state_dict(torch.load(best_trunk_weights))

<All keys matched successfully>

In [109]:
# embedder = MLP([trunk_output_size, 748]).to(device)
embedder = MLP([trunk_output_size, 748], dropout_rate=0.5).to(device)  # example with custom dropout rate

In [110]:
best_embedder_weights = glob.glob('PML_v13/embedder_best*.pth'.format('Resnet50'))[0]
embedder.load_state_dict(torch.load(best_embedder_weights))

<All keys matched successfully>

In [111]:
# classifier = nn.DataParallel(MLP([748, len(class_dict)])).to(device)
classifier = nn.DataParallel(MLP([748, len(class_dict)], dropout_rate=0.3)).to(device)  # example with custom dropout rate

In [112]:
CFG = dict(
        epochs = 40
#         model = 'resnet101'
)

In [113]:
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(classifier.parameters(), lr=0.001, weight_decay=0.1)

# Learning rate scheduler
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG['epochs'])

In [114]:
def train_one_epoch(train_loader, model, criterion, optimizer, scheduler):
    # Train mode
    model.train()
    
    # Track metrics
    loss_epoch = 0
    accuracy_epoch = 0
    
    # Loop over minibatches
    for inputs, labels in tqdm(train_loader):
        # Send to device
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Forward pass
        with torch.no_grad():
            backbone_out = trunk(inputs)
            embedding_out = embedder(backbone_out)
            
        outputs = model(embedding_out)
        loss = criterion(outputs, labels)
        
#         loss.requires_grad = True
        # Backprop
        loss.backward()

        # Update parameters
        optimizer.step()

        # Zero gradients
        optimizer.zero_grad()

        # Track loss
        loss_epoch += loss.detach().item()
        
        # Accuracy
        _, preds = torch.max(outputs, 1)
        accuracy_epoch += torch.sum(preds == labels)/inputs.shape[0]
        
    # Update learning rate
    scheduler.step()
        
    return loss_epoch/len(train_loader), accuracy_epoch.item()/len(train_loader)

In [115]:
def evaluate_one_epoch(test_loader, model, criterion):
    # Eval mode
    model.eval()
    
    # Track metrics
    loss_epoch = 0
    accuracy_epoch = 0
    
    # Don't update weights
    with torch.no_grad():
        # Loop over minibatches
        for inputs, labels in tqdm(test_loader):
            # Send to device
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            # Forward pass
            backbone_out = trunk(inputs)
            embedding_out = embedder(backbone_out)
            
            outputs = model(embedding_out)
            loss = criterion(outputs, labels)
            
            # Track loss
            loss_epoch += loss.detach().item()
            
            # Accuracy
            _, preds = torch.max(outputs, 1)
            accuracy_epoch += torch.sum(preds == labels)/inputs.shape[0]
    
    return loss_epoch/len(test_loader), accuracy_epoch.item()/len(test_loader)

In [116]:
# Plot history
def plot_hist(train_loss_hist, test_loss_hist, train_acc_hist, test_acc_hist):    
    plt.figure(figsize=(15,4))
    plt.subplot(1,2,1)
    plt.plot(train_loss_hist, label='Train_Loss')
    plt.plot(test_loss_hist, label='Test_loss')
    plt.title('Cross Entropy Loss')
    plt.legend()
    
    plt.subplot(1,2,2)
    plt.plot(train_acc_hist, label='Train_Accuracy')
    plt.plot(test_acc_hist, label='Val_Accuracy')
    plt.title('Accuracy')
    plt.legend()
    plt.show()

In [117]:
def train_model(model, criterion, optimizer, scheduler, train_loader, test_loader, verbose=True):
    # Initialise outputs
    train_loss_hist = []
    test_loss_hist = []
    train_acc_hist = []
    test_acc_hist = []
    best_acc = 0.0
    model_path = './emb_res50_1024_model_v1.pth'
    
    # Loop over epochs
    for epoch in range(CFG['epochs']):
        # Train
        train_loss, train_accuracy = train_one_epoch(train_loader, model, criterion, optimizer, scheduler)
        
        # Evaluate
        test_loss, test_accuracy = evaluate_one_epoch(test_loader, model, criterion)
        
        # Track metrics
        train_loss_hist.append(train_loss)
        test_loss_hist.append(test_loss)
        train_acc_hist.append(train_accuracy)
        test_acc_hist.append(test_accuracy)
        
        # Log metrics
        wandb.log({
        'epoch': epoch,
        'train_loss': train_loss,
        'test_loss': test_loss,
        'train_accuracy': train_accuracy,
        'test_accuracy': test_accuracy
        })
        
        if test_accuracy > best_acc:
            best_acc = test_accuracy
            torch.save(model.state_dict(), model_path) 
            print('saving model with acc {:.3f}'.format(best_acc))
            
        # Print loss
        if verbose:
            if (epoch+1)%1==0:
                print(f'Epoch {epoch+1}/{CFG["epochs"]}, loss {train_loss:.5f}, test_loss {test_loss:.5f}, accuracy {train_accuracy:.5f}, test_accuracy {test_accuracy:.5f}')
    
    return train_loss_hist, test_loss_hist, train_acc_hist, test_acc_hist

In [118]:
# Initialise run
run = wandb.init(
                 project = 'comic-classification',
                 config = CFG,
                 save_code = True)

VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇██
test_accuracy,▁▃▄▅▅▆▄▅▇▇▇█▇██▇▆▇█
test_loss,▁▁▁▂▃▃▄▄▅▅▆▆▇▇▇▇███
train_accuracy,▁▄▅▆▆▇▇▇▇▇▇████████
train_loss,█▅▄▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁

0,1
epoch,18.0
test_accuracy,0.02725
test_loss,7.25258
train_accuracy,0.5663
train_loss,1.92241


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669392585754395, max=1.0…

In [119]:
# Train model
train_loss_hist, test_loss_hist, train_acc_hist, test_acc_hist = train_model(classifier, criterion, optimizer, scheduler, train_loader, test_loader, verbose=True)

100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [01:57<00:00,  9.46it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:02<00:00,  8.92it/s]


saving model with acc 0.017
Epoch 1/40, loss 5.26276, test_loss 6.04003, accuracy 0.12238, test_accuracy 0.01657


100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:06<00:00,  8.84it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:03<00:00,  8.82it/s]


saving model with acc 0.019
Epoch 2/40, loss 4.05254, test_loss 6.01599, accuracy 0.25511, test_accuracy 0.01902


100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:05<00:00,  8.86it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:03<00:00,  8.81it/s]


Epoch 3/40, loss 3.54829, test_loss 6.08539, accuracy 0.32170, test_accuracy 0.01802


100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:05<00:00,  8.90it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:03<00:00,  8.85it/s]


saving model with acc 0.023
Epoch 4/40, loss 3.25126, test_loss 6.16455, accuracy 0.35773, test_accuracy 0.02285


100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:04<00:00,  8.97it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:02<00:00,  8.92it/s]


Epoch 5/40, loss 3.04045, test_loss 6.25099, accuracy 0.38544, test_accuracy 0.02274


100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:04<00:00,  8.91it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:02<00:00,  8.93it/s]


saving model with acc 0.023
Epoch 6/40, loss 2.90461, test_loss 6.34745, accuracy 0.40239, test_accuracy 0.02313


100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:04<00:00,  8.93it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:02<00:00,  8.95it/s]


saving model with acc 0.023
Epoch 7/40, loss 2.80178, test_loss 6.42081, accuracy 0.41533, test_accuracy 0.02330


100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:04<00:00,  8.93it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:02<00:00,  8.94it/s]


saving model with acc 0.024
Epoch 8/40, loss 2.70563, test_loss 6.51785, accuracy 0.42555, test_accuracy 0.02363


100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:04<00:00,  8.95it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:02<00:00,  8.93it/s]


saving model with acc 0.025
Epoch 9/40, loss 2.64207, test_loss 6.59020, accuracy 0.43578, test_accuracy 0.02452


100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:04<00:00,  8.93it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:02<00:00,  8.94it/s]


saving model with acc 0.026
Epoch 10/40, loss 2.58473, test_loss 6.65070, accuracy 0.44066, test_accuracy 0.02558


 30%|███████████████████████▉                                                        | 333/1114 [00:37<01:28,  8.86it/s]


KeyboardInterrupt: 