In [1]:
%matplotlib inline
import logging

import os
import random
import shutil
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set(style='darkgrid', font_scale=1.4)
from glob import glob
import wandb

# Pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import Dataset, DataLoader, random_split, Dataset
from torchvision import datasets, transforms, models
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
from torchvision.transforms import ToTensor, RandomCrop

from tqdm import tqdm
from sklearn.manifold import TSNE
from PIL import Image
import warnings
warnings.filterwarnings('ignore')
import statistics


import torchvision
import umap
from cycler import cycler

import pytorch_metric_learning
import pytorch_metric_learning.utils.logging_presets as logging_presets
from pytorch_metric_learning import losses, miners, samplers, testers, trainers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.utils.inference import InferenceModel, MatchFinder
from pytorch_metric_learning.utils import common_functions as c_f

logging.getLogger().setLevel(logging.INFO)
logging.info("VERSION %s" % pytorch_metric_learning.__version__)

import fuzzymatcher
from fuzzymatcher import link_table, fuzzy_left_join

INFO:root:VERSION 2.1.2


In [63]:
path = '/var/scratch/mxiao/data/'
os.chdir(path)

In [78]:
# Set the image transforms
normalize = transforms.Normalize(mean=[0.6195012,0.6195012,0.6195012], std=[0.3307451,0.3307451,0.3307451])
# normalize = transforms.Normalize(mean=[0.53997546,0.53997546,0.53997546], std=[0.36844322,0.36844322,0.36844322])

train_transform = transforms.Compose([
        transforms.RandomRotation(10),      # rotate +/- 10 degrees
        transforms.RandomHorizontalFlip(),  # reverse 50% of images
        transforms.RandomVerticalFlip(p=0.5),
#         transforms.Resize(224),             # resize shortest side to 224 pixels
#         transforms.CenterCrop(224),         # crop longest side to 224 pixels at center
        transforms.RandomCrop(size=(224,224),pad_if_needed=True), 
        transforms.ToTensor(),
        normalize
    ])

test_transform = transforms.Compose([
        transforms.RandomCrop((224,224),pad_if_needed=True),
        transforms.ToTensor(),
        normalize
    ])

In [88]:
batch_size = 32

dataset1 = datasets.ImageFolder(root=("model/train"),transform=train_transform)
dataset2 = datasets.ImageFolder(root=("model/test"),transform=test_transform)

class_dict = {i: class_name for i, class_name in enumerate(dataset1.classes)}

In [89]:
train_loader = DataLoader(dataset1, batch_size=32, shuffle=True,num_workers=4)
# val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False,num_workers=4)
test_loader = DataLoader(dataset2, batch_size=32, shuffle=False,num_workers=4)

In [90]:
# class MLP(nn.Module):
#     # layer_sizes[0] is the dimension of the input
#     # layer_sizes[-1] is the dimension of the output
#     def __init__(self, layer_sizes, final_relu=False):
#         super().__init__()
#         layer_list = []
#         layer_sizes = [int(x) for x in layer_sizes]
#         num_layers = len(layer_sizes) - 1
#         final_relu_layer = num_layers if final_relu else num_layers - 1
#         for i in range(len(layer_sizes) - 1):
#             input_size = layer_sizes[i]
#             curr_size = layer_sizes[i + 1]
#             if i < final_relu_layer:
#                 layer_list.append(nn.ReLU(inplace=False))
#             layer_list.append(nn.Linear(input_size, curr_size))
#         self.net = nn.Sequential(*layer_list)
#         self.last_linear = self.net[-1]

#     def forward(self, x):
#         return self.net(x)
class MLP(nn.Module):
    # layer_sizes[0] is the dimension of the input
    # layer_sizes[-1] is the dimension of the output
    def __init__(self, layer_sizes, final_relu=False, dropout_rate=0.5): # you can adjust dropout_rate as per your needs
        super().__init__()
        layer_list = []
        layer_sizes = [int(x) for x in layer_sizes]
        num_layers = len(layer_sizes) - 1
        final_relu_layer = num_layers if final_relu else num_layers - 1
        for i in range(len(layer_sizes) - 1):
            input_size = layer_sizes[i]
            curr_size = layer_sizes[i + 1]
            if i < final_relu_layer:
                layer_list.append(nn.ReLU(inplace=False))
                layer_list.append(nn.Dropout(dropout_rate))  # add dropout layer after ReLU
            layer_list.append(nn.Linear(input_size, curr_size))
        self.net = nn.Sequential(*layer_list)
        self.last_linear = self.net[-1]

    def forward(self, x):
        return self.net(x)


In [91]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [92]:
# Resnet
# Set trunk model and replace the softmax layer with an identity function
trunk = torchvision.models.resnet50(pretrained=True).to(device)
trunk_output_size = trunk.fc.in_features
trunk.fc = nn.Identity()  #

In [93]:
import glob
best_trunk_weights = glob.glob('PML_v13/trunk_best*.pth'.format('Resnet50'))[0]
trunk.load_state_dict(torch.load(best_trunk_weights))

<All keys matched successfully>

In [136]:
# embedder = MLP([trunk_output_size, 748]).to(device)
embedder = MLP([trunk_output_size, 748], dropout_rate=0.3).to(device)  # example with custom dropout rate

In [137]:
best_embedder_weights = glob.glob('PML_v13/embedder_best*.pth'.format('Resnet50'))[0]
embedder.load_state_dict(torch.load(best_embedder_weights))

<All keys matched successfully>

In [138]:
# classifier = nn.DataParallel(MLP([748, len(class_dict)])).to(device)
classifier = nn.DataParallel(MLP([748, len(class_dict)], dropout_rate=0)).to(device)  # example with custom dropout rate

In [139]:
CFG = dict(
        epochs = 40
#         model = 'resnet101'
)

In [140]:
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(classifier.parameters(), lr=0.001, weight_decay=0.1)

# Learning rate scheduler
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG['epochs'])

In [141]:
def train_one_epoch(train_loader, model, criterion, optimizer, scheduler):
    # Train mode
    model.train()
    
    # Track metrics
    loss_epoch = 0
    accuracy_epoch = 0
    
    # Loop over minibatches
    for inputs, labels in tqdm(train_loader):
        # Send to device
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Forward pass
        with torch.no_grad():
            backbone_out = trunk(inputs)
            embedding_out = embedder(backbone_out)
            
        outputs = model(embedding_out)
        loss = criterion(outputs, labels)
        
#         loss.requires_grad = True
        # Backprop
        loss.backward()

        # Update parameters
        optimizer.step()

        # Zero gradients
        optimizer.zero_grad()

        # Track loss
        loss_epoch += loss.detach().item()
        
        # Accuracy
        _, preds = torch.max(outputs, 1)
        accuracy_epoch += torch.sum(preds == labels)/inputs.shape[0]
        
    # Update learning rate
    scheduler.step()
        
    return loss_epoch/len(train_loader), accuracy_epoch.item()/len(train_loader)

In [142]:
def evaluate_one_epoch(test_loader, model, criterion):
    # Eval mode
    model.eval()
    
    # Track metrics
    loss_epoch = 0
    accuracy_epoch = 0
    
    # Don't update weights
    with torch.no_grad():
        # Loop over minibatches
        for inputs, labels in tqdm(test_loader):
            # Send to device
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            # Forward pass
            backbone_out = trunk(inputs)
            embedding_out = embedder(backbone_out)
            
            outputs = model(embedding_out)
            loss = criterion(outputs, labels)
            
            # Track loss
            loss_epoch += loss.detach().item()
            
            # Accuracy
            _, preds = torch.max(outputs, 1)
            accuracy_epoch += torch.sum(preds == labels)/inputs.shape[0]
    
    return loss_epoch/len(test_loader), accuracy_epoch.item()/len(test_loader)

In [143]:
# Plot history
def plot_hist(train_loss_hist, test_loss_hist, train_acc_hist, test_acc_hist):    
    plt.figure(figsize=(15,4))
    plt.subplot(1,2,1)
    plt.plot(train_loss_hist, label='Train_Loss')
    plt.plot(test_loss_hist, label='Test_loss')
    plt.title('Cross Entropy Loss')
    plt.legend()
    
    plt.subplot(1,2,2)
    plt.plot(train_acc_hist, label='Train_Accuracy')
    plt.plot(test_acc_hist, label='Val_Accuracy')
    plt.title('Accuracy')
    plt.legend()
    plt.show()

In [144]:
def train_model(model, criterion, optimizer, scheduler, train_loader, test_loader, verbose=True):
    # Initialise outputs
    train_loss_hist = []
    test_loss_hist = []
    train_acc_hist = []
    test_acc_hist = []
    best_acc = 0.0
    model_path = './emb_res50_1024_model_v1.pth'
    
    # Loop over epochs
    for epoch in range(CFG['epochs']):
        # Train
        train_loss, train_accuracy = train_one_epoch(train_loader, model, criterion, optimizer, scheduler)
        
        # Evaluate
        test_loss, test_accuracy = evaluate_one_epoch(test_loader, model, criterion)
        
        # Track metrics
        train_loss_hist.append(train_loss)
        test_loss_hist.append(test_loss)
        train_acc_hist.append(train_accuracy)
        test_acc_hist.append(test_accuracy)
        
        # Log metrics
        wandb.log({
        'epoch': epoch,
        'train_loss': train_loss,
        'test_loss': test_loss,
        'train_accuracy': train_accuracy,
        'test_accuracy': test_accuracy
        })
        
        if test_accuracy > best_acc:
            best_acc = test_accuracy
            torch.save(model.state_dict(), model_path) 
            print('saving model with acc {:.3f}'.format(best_acc))
            
        # Print loss
        if verbose:
            if (epoch+1)%1==0:
                print(f'Epoch {epoch+1}/{CFG["epochs"]}, loss {train_loss:.5f}, test_loss {test_loss:.5f}, accuracy {train_accuracy:.5f}, test_accuracy {test_accuracy:.5f}')
    
    return train_loss_hist, test_loss_hist, train_acc_hist, test_acc_hist

In [145]:
# Initialise run
run = wandb.init(
                 project = 'comic-classification',
                 config = CFG,
                 save_code = True)

0,1
epoch,▁▂▃▄▅▅▆▇█
test_accuracy,▃▃▁█▅██▆▄
test_loss,▁▄▅▆▇▇▇██
train_accuracy,▁▅▆▇▇▇███
train_loss,█▃▂▂▂▁▁▁▁

0,1
epoch,8.0
test_accuracy,0.02213
test_loss,8.55639
train_accuracy,0.41289
train_loss,2.40332


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666931832830111, max=1.0)…

In [146]:
# Train model
train_loss_hist, test_loss_hist, train_acc_hist, test_acc_hist = train_model(classifier, criterion, optimizer, scheduler, train_loader, test_loader, verbose=True)

100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:03<00:00,  9.03it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:03<00:00,  8.81it/s]


saving model with acc 0.018
Epoch 1/40, loss 3.77444, test_loss 7.00733, accuracy 0.23849, test_accuracy 0.01757


100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:04<00:00,  8.91it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:03<00:00,  8.91it/s]


saving model with acc 0.023
Epoch 2/40, loss 2.80619, test_loss 7.62297, accuracy 0.34503, test_accuracy 0.02280


100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:04<00:00,  8.94it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:02<00:00,  8.96it/s]


saving model with acc 0.023
Epoch 3/40, loss 2.62595, test_loss 8.01164, accuracy 0.37155, test_accuracy 0.02291


100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:04<00:00,  8.93it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:03<00:00,  8.92it/s]


Epoch 4/40, loss 2.54976, test_loss 8.19067, accuracy 0.38811, test_accuracy 0.02230


100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:05<00:00,  8.90it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:03<00:00,  8.83it/s]


saving model with acc 0.023
Epoch 5/40, loss 2.50819, test_loss 8.33836, accuracy 0.39418, test_accuracy 0.02335


100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:04<00:00,  8.94it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:02<00:00,  8.96it/s]


Epoch 6/40, loss 2.48648, test_loss 8.38862, accuracy 0.39559, test_accuracy 0.02285


100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:04<00:00,  8.94it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:02<00:00,  8.95it/s]


Epoch 7/40, loss 2.46241, test_loss 8.42371, accuracy 0.40348, test_accuracy 0.02219


100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:05<00:00,  8.91it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:05<00:00,  8.89it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:02<00:00,  8.93it/s]


saving model with acc 0.025
Epoch 9/40, loss 2.42950, test_loss 8.50039, accuracy 0.41092, test_accuracy 0.02491


100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:05<00:00,  8.91it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:03<00:00,  8.91it/s]


Epoch 10/40, loss 2.39755, test_loss 8.53018, accuracy 0.41390, test_accuracy 0.02285


 27%|█████████████████████▊                                                          | 304/1114 [00:34<01:29,  9.05it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:04<00:00,  8.92it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:03<00:00,  8.90it/s]


Epoch 13/40, loss 2.35364, test_loss 8.43852, accuracy 0.42166, test_accuracy 0.02374


100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:05<00:00,  8.89it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:03<00:00,  8.91it/s]


Epoch 14/40, loss 2.30708, test_loss 8.48941, accuracy 0.43252, test_accuracy 0.02430


 34%|███████████████████████████▏                                                    | 379/1114 [00:42<01:23,  8.85it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:03<00:00,  8.82it/s]


saving model with acc 0.025
Epoch 17/40, loss 2.27792, test_loss 8.35888, accuracy 0.44058, test_accuracy 0.02524


100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:06<00:00,  8.83it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:02<00:00,  8.93it/s]


Epoch 18/40, loss 2.25146, test_loss 8.36309, accuracy 0.44671, test_accuracy 0.02458


100%|██████████████████████████████████████████████████████████████████████████████▊| 1112/1114 [02:04<00:00,  9.06it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:05<00:00,  8.88it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:02<00:00,  8.95it/s]


saving model with acc 0.026
Epoch 22/40, loss 2.17794, test_loss 8.29599, accuracy 0.46440, test_accuracy 0.02563


100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:04<00:00,  8.92it/s]
 59%|███████████████████████████████████████████████▌                                 | 330/562 [00:37<00:25,  8.99it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:04<00:00,  8.94it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:02<00:00,  8.94it/s]


Epoch 26/40, loss 2.11210, test_loss 8.17989, accuracy 0.48440, test_accuracy 0.02563


100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:04<00:00,  8.93it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:03<00:00,  8.92it/s]


saving model with acc 0.027
Epoch 27/40, loss 2.10346, test_loss 8.15655, accuracy 0.48945, test_accuracy 0.02691


 49%|██████████████████████████████████████▊                                         | 541/1114 [01:00<01:03,  8.95it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:03<00:00,  8.91it/s]


Epoch 30/40, loss 2.05462, test_loss 8.11727, accuracy 0.50348, test_accuracy 0.02658


100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:04<00:00,  8.91it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:03<00:00,  8.89it/s]


Epoch 31/40, loss 2.05977, test_loss 8.10156, accuracy 0.50310, test_accuracy 0.02769


 78%|██████████████████████████████████████████████████████████████▌                 | 871/1114 [01:37<00:27,  8.98it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:06<00:00,  8.83it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:03<00:00,  8.79it/s]


Epoch 35/40, loss 2.01132, test_loss 8.06296, accuracy 0.51840, test_accuracy 0.02608


100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:05<00:00,  8.87it/s]
 90%|████████████████████████████████████████████████████████████████████████▉        | 506/562 [00:56<00:06,  8.93it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:04<00:00,  8.91it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:02<00:00,  8.93it/s]


Epoch 39/40, loss 1.98838, test_loss 8.04946, accuracy 0.53277, test_accuracy 0.02725


100%|███████████████████████████████████████████████████████████████████████████████| 1114/1114 [02:04<00:00,  8.95it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 562/562 [01:03<00:00,  8.81it/s]

Epoch 40/40, loss 1.98873, test_loss 8.03331, accuracy 0.53008, test_accuracy 0.02869



