# ECE 271B Final Project

## GPU status

In [None]:
!nvidia-smi

## Import packages

In [None]:
import numpy as np
import pandas as pd
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.models import resnet18, ResNet18_Weights, vgg16, VGG16_Weights

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

from APINet.models import API_Net 

## Model training

In [None]:
def train(model, dataset, config, csv_path, checkpoint_dir=None):
    """Train ``model`` with Adam and validate it after every epoch.

    Args:
        model: ``nn.Module`` that maps an input batch to class logits.
        dataset: dict with ``'train'`` and ``'val'`` entries, each a dataset
            yielding ``(input, label)`` pairs.
        config: dict providing ``"lr"``, ``"epoch"``, ``"batch_size_train"``
            and ``"batch_size_val"``. An optional ``"num_workers"`` entry
            overrides the default of 8 DataLoader workers.
        csv_path: destination of the per-epoch loss/accuracy table.
        checkpoint_dir: optional directory. If it contains a file named
            ``checkpoint``, model and optimizer state are restored from it.
            A ``checkpoint_<epoch>`` file is written there every 5 epochs.

    Returns:
        ``(acc_train, acc_val, loss_train, loss_val, model)``: four per-epoch
        lists and ``model`` itself, carrying the weights of the epoch with the
        highest validation accuracy (or its final weights if validation
        accuracy never rose above 0).
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config["lr"], weight_decay=0.00002)

    # Resume only when a checkpoint file is actually present, so that an empty
    # checkpoint directory can be used for a first run.
    if checkpoint_dir:
        checkpoint = os.path.join(checkpoint_dir, "checkpoint")
        if os.path.isfile(checkpoint):
            model_state, optimizer_state = torch.load(checkpoint, map_location=device)
            model.load_state_dict(model_state)
            optimizer.load_state_dict(optimizer_state)

    num_workers = config.get("num_workers", 8)
    dataloaders = {
        'train': DataLoader(dataset['train'], batch_size=config["batch_size_train"],
                            shuffle=True, drop_last=True, num_workers=num_workers),
        # Validation must see every sample exactly once: no shuffling needed
        # and the last (possibly smaller) batch is kept.
        'val': DataLoader(dataset['val'], batch_size=config["batch_size_val"],
                          shuffle=False, drop_last=False, num_workers=num_workers),
    }

    acc_train, acc_val, loss_train, loss_val = [], [], [], []
    best_state = None
    best_acc = 0

    for epoch in range(config["epoch"]):
        print(f"Epoch: {epoch+1}")
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)

            seen, loss_sum, hits = 0, 0.0, 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                n_samples = labels.size(0)

                if is_train:
                    optimizer.zero_grad()

                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    preds = outputs.argmax(dim=1)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # The criterion returns a batch mean; weight it by the batch
                # size so that dividing by `seen` gives a per-sample average.
                seen += n_samples
                loss_sum += loss.item() * n_samples
                hits += (preds == labels).sum().item()

            # max(..., 1) avoids a ZeroDivisionError when a loader is empty.
            epoch_loss = loss_sum / max(seen, 1)
            epoch_acc = hits / max(seen, 1)

            if is_train:
                train_loss, train_acc = epoch_loss, epoch_acc
                acc_train.append(train_acc)
                loss_train.append(train_loss)
            else:
                val_loss, val_acc = epoch_loss, epoch_acc
                acc_val.append(val_acc)
                loss_val.append(val_loss)
                if val_acc > best_acc:
                    best_acc = val_acc
                    # Snapshot the weights: keeping a reference to `model`
                    # would silently track later epochs.
                    best_state = {name: tensor.detach().cpu().clone()
                                  for name, tensor in model.state_dict().items()}

        print(f"train_acc: {train_acc:.4f}", f"train_loss: {train_loss:.4f}",
              f"val_acc: {val_acc:.4f}", f"val_loss: {val_loss:.4f}")

        if checkpoint_dir and (epoch + 1) % 5 == 0:
            path = os.path.join(checkpoint_dir, f"checkpoint_{epoch+1}")
            torch.save((model.state_dict(), optimizer.state_dict()), path)

    record = pd.DataFrame({
        'train_loss': loss_train,
        'train_acc': acc_train,
        'val_loss': loss_val,
        'val_acc': acc_val,
    })
    record.to_csv(csv_path)

    if best_state is not None:
        model.load_state_dict(best_state)

    return acc_train, acc_val, loss_train, loss_val, model

        

## Model evaluation

In [None]:
def eval(model, dataset):
    """Evaluate a classifier on ``dataset`` and print summary metrics.

    NOTE: the name shadows the builtin ``eval``; it is kept because other
    cells call the function under this name.

    Args:
        model: ``nn.Module`` that maps an input batch to class logits.
        dataset: dataset yielding ``(input, label)`` pairs; labels are treated
            as binary by the precision/recall/F1 computations.

    Returns:
        None. Accuracy, mean loss, precision, F1 and recall are printed.

    Raises:
        ValueError: if ``dataset`` yields no samples.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model.to(device)

    criterion = nn.CrossEntropyLoss()

    # One sample per batch, so `.item()` below yields one prediction each.
    loader = DataLoader(dataset, batch_size=1,
                        shuffle=False, drop_last=True, num_workers=8)

    all_preds = []
    all_labels = []
    running_loss = 0.0

    model.eval()

    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            running_loss += criterion(outputs, labels).item()

            all_preds.append(outputs.argmax(dim=1).item())
            all_labels.append(labels.item())

    if not all_labels:
        raise ValueError("dataset is empty: nothing to evaluate")

    # scikit-learn metrics take (y_true, y_pred); reversing them swaps
    # precision and recall.
    acc = accuracy_score(all_labels, all_preds)
    prec = precision_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds)
    rec = recall_score(all_labels, all_preds)

    print(f"val_acc: {acc:.4f}", f"val_loss: {running_loss/len(all_preds):.4f}",
          f'precision: {prec:.4f}', f'f1 score: {f1:.4f}', f'recall: {rec:.4f}')

    return

        

## Model with non-augmented images

In [None]:
# Build ResNet-18 and VGG-16 with the IMAGENET1K_V1 pretrained weights.
resnet = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
VGG = vgg16(weights=VGG16_Weights.IMAGENET1K_V1)

# Replace `resnet.fc` and `VGG.classifier[6]` with 2-output linear layers that
# keep the input width of the layers they replace.
resnet.fc = nn.Linear(resnet.fc.in_features, 2)
VGG.classifier[6] = nn.Linear(VGG.classifier[6].in_features, 2)

In [None]:
# Hyper-parameters read by train() and by the Resize transform below.
config = {
    "img_size": (128, 128), #768,
    "epoch": 20,
    "lr": 0.0001,
    "batch_size_train": 32,
    "batch_size_val": 32,
}

In [None]:
# No augmentation: resize to config['img_size'] and convert to a tensor.
transform=transforms.Compose([
    transforms.Resize(config['img_size']),
    transforms.ToTensor(),
    ])

In [None]:
# Both splits are loaded with the same (non-augmenting) transform.
dataset_train = ImageFolder('./data_mhist/dataset/train', transform=transform)
dataset_test = ImageFolder('./data_mhist/dataset/val', transform=transform)

# train() reads the splits from the keys 'train' and 'val'.
dataset = {"train": dataset_train, "val": dataset_test}

In [None]:
# Fine-tune ResNet-18; per-epoch metrics go to the CSV, the returned model's
# weights are saved to disk.
acc_train_res, acc_val_res, loss_train_res, loss_val_res, resnet = train(resnet, dataset, config, f'./log_resnet_noAug.csv')
torch.save(resnet.state_dict(), f'./resnet_noAug.pt')

In [None]:
# Same procedure for VGG-16.
acc_train_vgg, acc_val_vgg, loss_train_vgg, loss_val_vgg, VGG = train(VGG, dataset, config, f'./log_VGG_noAug.csv')
torch.save(VGG.state_dict(), f'./vgg_noAug.pt')

In [None]:
# Reload the weights written by the two training cells above.
resnet.load_state_dict(torch.load(f'./resnet_noAug.pt'))
VGG.load_state_dict(torch.load(f'./vgg_noAug.pt'))

In [None]:
# Print validation metrics for the fine-tuned ResNet-18.
eval(resnet, dataset['val'])

In [None]:
# Print validation metrics for the fine-tuned VGG-16.
eval(VGG, dataset['val'])

In [None]:
# Total number of parameter elements in the ResNet-18 model.
pytorch_total_params = sum(p.numel() for p in resnet.parameters())
pytorch_total_params

## Baseline model

In [None]:
# Baseline: the same architectures built without a `weights` argument, i.e.
# no pretrained checkpoint is requested here.
resnet_base = resnet18()
VGG_base = vgg16()

# Replace `fc` / `classifier[6]` with 2-output linear layers.
resnet_base.fc = nn.Linear(resnet_base.fc.in_features, 2)
VGG_base.classifier[6] = nn.Linear(VGG_base.classifier[6].in_features, 2)

In [None]:
# Same settings as the pretrained run except for the much longer schedule
# (300 epochs).
config = {
    "img_size": (128, 128), #768,
    "epoch": 300,
    "lr": 0.0001,
    "batch_size_train": 32,
    "batch_size_val": 32,
}

In [None]:
# No augmentation: resize to config['img_size'] and convert to a tensor.
transform=transforms.Compose([
    transforms.Resize(config['img_size']),
    transforms.ToTensor(),
    ])

In [None]:
# Both splits are loaded with the same (non-augmenting) transform.
dataset_train = ImageFolder('./data_mhist/dataset/train', transform=transform)
dataset_test = ImageFolder('./data_mhist/dataset/val', transform=transform)

# train() reads the splits from the keys 'train' and 'val'.
dataset = {"train": dataset_train, "val": dataset_test}

In [None]:
# Train the baseline ResNet-18; metrics go to the CSV, weights to disk.
acc_train_res_base, acc_val_res_base, loss_train_res_base, loss_val_res_base, resnet_base = train(resnet_base, dataset, config, f'./log_resnet_noAug_base.csv')
torch.save(resnet_base.state_dict(), f'./resnet_noAug_base.pt')

In [None]:
# Train the baseline VGG-16. The returned model is bound back to `VGG_base`
# (not `VGG`), so the pretrained model from the earlier section is not
# overwritten by this cell.
acc_train_vgg_base, acc_val_vgg_base, loss_train_vgg_base, loss_val_vgg_base, VGG_base = train(VGG_base, dataset, config, './log_VGG_noAug_base.csv')
torch.save(VGG_base.state_dict(), './vgg_noAug_base.pt')

In [None]:
# Reload the baseline weights written by the two training cells above.
resnet_base.load_state_dict(torch.load(f'./resnet_noAug_base.pt'))
VGG_base.load_state_dict(torch.load(f'./vgg_noAug_base.pt'))

In [None]:
# Print validation metrics for the baseline ResNet-18.
eval(resnet_base, dataset['val'])

In [None]:
# Print validation metrics for the baseline VGG-16.
eval(VGG_base, dataset['val'])

## Baseline model with augmentation

In [None]:
# Baseline architectures (no `weights` argument) that will be trained on
# augmented images.
resnet_base_aug = resnet18()
VGG_base_aug = vgg16()

# Replace `fc` / `classifier[6]` with 2-output linear layers.
resnet_base_aug.fc = nn.Linear(resnet_base_aug.fc.in_features, 2)
VGG_base_aug.classifier[6] = nn.Linear(VGG_base_aug.classifier[6].in_features, 2)

In [None]:
# Short schedule (5 epochs) with a smaller learning rate than the
# non-augmented baseline run.
config = {
    "img_size": (128, 128), #768,
    "epoch": 5,
    "lr": 0.00001,
    "batch_size_train": 32,
    "batch_size_val": 32,
}

In [None]:
# Training-time augmentation: resize, random rotation (argument 45), random
# horizontal and vertical flips (p=0.5 each), then conversion to a tensor.
transform_aug=transforms.Compose([
    transforms.Resize(config['img_size']),
    transforms.RandomRotation(45),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ToTensor(),
    ])

In [None]:
# Only the training split is augmented. The validation split uses `transform`,
# which is not defined in this section and must come from an earlier cell.
aug_dataset_train = ImageFolder('./data_mhist/dataset/train', transform=transform_aug)
aug_dataset_test = ImageFolder('./data_mhist/dataset/val', transform=transform)

aug_dataset = {"train": aug_dataset_train, "val": aug_dataset_test}

In [None]:
# Train the baseline ResNet-18 on augmented data; metrics go to the CSV,
# weights to disk.
acc_train_res_base_aug, acc_val_res_base_aug, loss_train_res_base_aug, loss_val_res_base_aug, resnet_base_aug = train(resnet_base_aug, aug_dataset, config, f'./log_resnet_base_aug.csv')
torch.save(resnet_base_aug.state_dict(), f'./resnet_base_aug.pt')

In [None]:
# Same procedure for the baseline VGG-16.
acc_train_vgg_base_aug, acc_val_vgg_base_aug, loss_train_vgg_base_aug, loss_val_vgg_base_aug, VGG_base_aug = train(VGG_base_aug, aug_dataset, config, f'./log_VGG_base_aug.csv')
torch.save(VGG_base_aug.state_dict(), f'./vgg_base_aug.pt')

In [None]:
# Reload the weights written by the two training cells above.
resnet_base_aug.load_state_dict(torch.load(f'./resnet_base_aug.pt'))
VGG_base_aug.load_state_dict(torch.load(f'./vgg_base_aug.pt'))

In [None]:
# NOTE(review): this evaluates on `dataset['val']` from an earlier cell rather
# than `aug_dataset['val']` defined above - confirm that is intended.
eval(resnet_base_aug, dataset['val'])

In [None]:
# Same validation split as the previous cell, for VGG-16.
eval(VGG_base_aug, dataset['val'])

## Model with augmented images

In [None]:
# ResNet-18 and VGG-16 with the IMAGENET1K_V1 pretrained weights, to be
# fine-tuned on augmented images.
resnet_aug = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
VGG_aug = vgg16(weights=VGG16_Weights.IMAGENET1K_V1)

# Replace `fc` / `classifier[6]` with 2-output linear layers.
resnet_aug.fc = nn.Linear(resnet_aug.fc.in_features, 2)
VGG_aug.classifier[6] = nn.Linear(VGG_aug.classifier[6].in_features, 2)

In [None]:
# Hyper-parameters read by train() and by the Resize transform below.
config = {
    "img_size": (128, 128), #768,
    "epoch": 20,
    "lr": 0.0001,
    "batch_size_train": 32,
    "batch_size_val": 32,
}

In [None]:
# Training-time augmentation: resize, random rotation (argument 45), random
# horizontal and vertical flips (p=0.5 each), then conversion to a tensor.
transform_aug=transforms.Compose([
    transforms.Resize(config['img_size']),
    transforms.RandomRotation(45),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ToTensor(),
    ])

In [None]:
# Only the training split is augmented. The validation split uses `transform`,
# which is not defined in this section and must come from an earlier cell.
aug_dataset_train = ImageFolder('./data_mhist/dataset/train', transform=transform_aug)
aug_dataset_test = ImageFolder('./data_mhist/dataset/val', transform=transform)

aug_dataset = {"train": aug_dataset_train, "val": aug_dataset_test}

In [None]:
# Fine-tune ResNet-18 on augmented data; metrics go to the CSV, weights to
# disk.
acc_train_res_aug, acc_val_res_aug, loss_train_res_aug, loss_val_res_aug, resnet_aug = train(resnet_aug, aug_dataset, config, f'./log_resnet_aug.csv')
torch.save(resnet_aug.state_dict(), f'./resnet_aug.pt')

In [None]:
# Same procedure for VGG-16.
acc_train_vgg_aug, acc_val_vgg_aug, loss_train_vgg_aug, loss_val_vgg_aug, VGG_aug = train(VGG_aug, aug_dataset, config, f'./log_VGG_aug.csv')
torch.save(VGG_aug.state_dict(), f'./vgg_aug.pt')

In [None]:
resnet_base_aug.load_state_dict(torch.load(f'./resnet_aug.pt'))
VGG_base_aug.load_state_dict(torch.load(f'./vgg_aug.pt'))

In [None]:
eval(resnet_base_aug, dataset['val'])

In [None]:
eval(VGG_base_aug, dataset['val'])

## Fine-grained model (API-Net) with non-augmented and augmented images

In [None]:
def correct_fg(scores, targets, k):
    """
    Count the samples whose true label is among the k highest-scoring classes.

    :param scores: per-class scores from the model, one row per sample
    :param targets: true labels, one per sample
    :param k: how many top-scoring classes to accept per sample
    :return: number of top-k hits in the batch, as a Python float
    """
    # Column indices of the k largest scores in each row.
    top_idx = scores.topk(k, dim=1, largest=True, sorted=True)[1]
    # Repeat every label across its k candidate columns and compare.
    wanted = targets.view(-1, 1).expand_as(top_idx)
    hits = top_idx == wanted
    return hits.float().sum().item()

In [None]:
def train_fine_grained(model, dataset, config, csv_path, checkpoint_dir=None):
    """Train a pairwise fine-grained model and validate it after every epoch.

    The model is called as ``model(inputs, labels, flag='train')`` in both
    phases and must return ``(logit1_self, logit1_other, logit2_self,
    logit2_other, labels1, labels2)``. The loss is the cross-entropy over all
    four logit groups plus a margin ranking loss (margin 0.05) that asks the
    "self" score of the true class to exceed the "other" score.

    Args:
        model: ``nn.Module`` with the call contract described above.
        dataset: dict with ``'train'`` and ``'val'`` entries, each a dataset
            yielding ``(input, label)`` pairs.
        config: dict providing ``"lr"``, ``"epoch"``, ``"batch_size_train"``
            and ``"batch_size_val"``. An optional ``"num_workers"`` entry
            overrides the default of 8 DataLoader workers.
        csv_path: destination of the per-epoch loss/accuracy table.
        checkpoint_dir: optional directory. If it contains a file named
            ``checkpoint``, model and optimizer state are restored from it.
            A ``checkpoint_<epoch>`` file is written there every 5 epochs.

    Returns:
        ``(acc_train, acc_val, loss_train, loss_val, model)``: four per-epoch
        lists and ``model`` itself, carrying the weights of the epoch with the
        highest validation accuracy (or its final weights if validation
        accuracy never rose above 0).
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    rank_criterion = nn.MarginRankingLoss(margin=0.05)
    softmax_layer = nn.Softmax(dim=1).to(device)
    optimizer = optim.Adam(model.parameters(), lr=config["lr"], weight_decay=0.00002)

    # Resume only when a checkpoint file is actually present, so that an empty
    # checkpoint directory can be used for a first run.
    if checkpoint_dir:
        checkpoint = os.path.join(checkpoint_dir, "checkpoint")
        if os.path.isfile(checkpoint):
            model_state, optimizer_state = torch.load(checkpoint, map_location=device)
            model.load_state_dict(model_state)
            optimizer.load_state_dict(optimizer_state)

    num_workers = config.get("num_workers", 8)
    # The validation loader keeps shuffle/drop_last as in training because the
    # model is called with flag='train' in both phases.
    # NOTE(review): presumably the model pairs samples within a batch, so
    # batches should be full and mix classes - confirm against API_Net.
    dataloaders = {
        'train': DataLoader(dataset['train'], batch_size=config["batch_size_train"],
                            shuffle=True, drop_last=True, num_workers=num_workers),
        'val': DataLoader(dataset['val'], batch_size=config["batch_size_val"],
                          shuffle=True, drop_last=True, num_workers=num_workers),
    }

    acc_train, acc_val, loss_train, loss_val = [], [], [], []
    best_state = None
    best_acc = 0

    for epoch in range(config["epoch"]):
        print(f"Epoch: {epoch+1}")
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)

            total = 0
            running_loss, running_corrects = 0.0, 0.0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                if is_train:
                    optimizer.zero_grad()

                with torch.set_grad_enabled(is_train):
                    (logit1_self, logit1_other, logit2_self, logit2_other,
                     labels1, labels2) = model(inputs, labels, flag='train')

                    # Concatenating keeps whatever class count the model
                    # emits instead of hard-coding two columns.
                    self_logits = torch.cat([logit1_self, logit2_self], dim=0)
                    other_logits = torch.cat([logit1_other, logit2_other], dim=0)
                    pair_labels = torch.cat([labels1, labels2], dim=0)

                    logits = torch.cat([self_logits, other_logits], dim=0)
                    targets = torch.cat([pair_labels, pair_labels], dim=0)
                    softmax_loss = criterion(logits, targets)

                    # Probability assigned to the true class by each branch.
                    rows = torch.arange(pair_labels.size(0), device=device)
                    self_scores = softmax_layer(self_logits)[rows, pair_labels]
                    other_scores = softmax_layer(other_logits)[rows, pair_labels]
                    # Target +1: "self" should rank above "other".
                    flag = torch.ones_like(self_scores)
                    rank_loss = rank_criterion(self_scores, other_scores, flag)

                    loss = softmax_loss + rank_loss

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # The loss is a batch mean; weight it by the number of targets
                # so that dividing by `total` gives a per-target average.
                n_targets = targets.size(0)
                total += n_targets
                running_loss += loss.item() * n_targets
                running_corrects += correct_fg(logits, targets, 1)

            # max(..., 1) avoids a ZeroDivisionError when a loader is empty.
            epoch_loss = running_loss / max(total, 1)
            epoch_acc = running_corrects / max(total, 1)

            if is_train:
                train_loss, train_acc = epoch_loss, epoch_acc
                acc_train.append(train_acc)
                loss_train.append(train_loss)
            else:
                val_loss, val_acc = epoch_loss, epoch_acc
                acc_val.append(val_acc)
                loss_val.append(val_loss)
                if val_acc > best_acc:
                    best_acc = val_acc
                    # Snapshot the weights: keeping a reference to `model`
                    # would silently track later epochs.
                    best_state = {name: tensor.detach().cpu().clone()
                                  for name, tensor in model.state_dict().items()}

        print(f"train_acc: {train_acc:.4f}", f"train_loss: {train_loss:.4f}",
              f"val_acc: {val_acc:.4f}", f"val_loss: {val_loss:.4f}")

        if checkpoint_dir and (epoch + 1) % 5 == 0:
            path = os.path.join(checkpoint_dir, f"checkpoint_{epoch+1}")
            torch.save((model.state_dict(), optimizer.state_dict()), path)

    record = pd.DataFrame({
        'train_loss': loss_train,
        'train_acc': acc_train,
        'val_loss': loss_val,
        'val_acc': acc_val,
    })
    record.to_csv(csv_path)
    print(f'Best Acc: {best_acc}')

    if best_state is not None:
        model.load_state_dict(best_state)

    return acc_train, acc_val, loss_train, loss_val, model

In [None]:
def eval_fine_grained(model, dataset):
    """Evaluate a fine-grained model on ``dataset`` and print summary metrics.

    The model is called one sample at a time as
    ``model(inputs, targets=None, flag='val')`` and must return class logits.

    Args:
        model: ``nn.Module`` with the call contract described above.
        dataset: dataset yielding ``(input, label)`` pairs; labels are treated
            as binary by the precision/recall/F1 computations.

    Returns:
        None. Accuracy, mean loss, precision, F1 and recall are printed.

    Raises:
        ValueError: if ``dataset`` yields no samples.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model.to(device)

    criterion = nn.CrossEntropyLoss()

    loader = DataLoader(dataset, batch_size=1,
                        shuffle=False, drop_last=True, num_workers=8)

    all_preds = []
    all_labels = []
    running_loss = 0.0

    model.eval()

    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            logits = model(inputs, targets=None, flag='val')
            # A single-sample batch may come back without its batch axis;
            # restore it only in that case.
            if logits.dim() == 1:
                logits = logits.unsqueeze(0)

            running_loss += criterion(logits, labels).item()

            # Plain Python ints, not 0-d tensors, for scikit-learn.
            all_preds.extend(logits.argmax(dim=1).tolist())
            all_labels.extend(labels.tolist())

    if not all_labels:
        raise ValueError("dataset is empty: nothing to evaluate")

    # scikit-learn metrics take (y_true, y_pred); reversing them swaps
    # precision and recall.
    acc = accuracy_score(all_labels, all_preds)
    prec = precision_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds)
    rec = recall_score(all_labels, all_preds)

    print(f"val_acc: {acc:.4f}", f"val_loss: {running_loss/len(all_preds):.4f}",
          f'precision: {prec:.4f}', f'f1 score: {f1:.4f}', f'recall: {rec:.4f}')

    return

In [None]:
# Fine-grained runs use larger images (512x512), a smaller batch size (8) and
# a smaller learning rate than the ResNet/VGG sections.
config = {
    "img_size": (512, 512), #768,
    "epoch": 5,
    "lr": 0.000001,
    "batch_size_train": 8,
    "batch_size_val": 8,
}

In [None]:
# No augmentation: resize to config['img_size'] and convert to a tensor.
transform=transforms.Compose([
    transforms.Resize(config['img_size']),
    transforms.ToTensor(),
    ])

In [None]:
# Both splits are loaded with the same (non-augmenting) transform.
dataset_train = ImageFolder('./data_mhist/dataset/train', transform=transform)
dataset_test = ImageFolder('./data_mhist/dataset/val', transform=transform)

# train_fine_grained() reads the splits from the keys 'train' and 'val'.
dataset = {"train": dataset_train, "val": dataset_test}

In [None]:
# API-Net built with default arguments; its `fc` layer is replaced with a
# 2-output linear layer of the same input width.
API_net = API_Net()
API_net.fc = nn.Linear(API_net.fc.in_features, 2)

In [None]:
# Train on non-augmented images; metrics go to the CSV, weights to disk.
acc_train_api, acc_val_api, loss_train_api, loss_val_api, API_net = train_fine_grained(API_net, dataset, config, f'./log_api_noaug.csv')
torch.save(API_net.state_dict(), f'./api_noaug.pt')

In [None]:
# Training-time augmentation: resize, random rotation (argument 45), random
# horizontal and vertical flips (p=0.5 each), then conversion to a tensor.
transform_aug=transforms.Compose([
    transforms.Resize(config['img_size']),
    transforms.RandomRotation(45),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ToTensor(),
    ])

In [None]:
# Only the training split is augmented; validation uses `transform`.
aug_dataset_train = ImageFolder('./data_mhist/dataset/train', transform=transform_aug)
aug_dataset_test = ImageFolder('./data_mhist/dataset/val', transform=transform)

aug_dataset = {"train": aug_dataset_train, "val": aug_dataset_test}

In [None]:
# A second, independent API-Net for the augmented run.
API_net_aug = API_Net()
API_net_aug.fc = nn.Linear(API_net_aug.fc.in_features, 2)

In [None]:
# Train on augmented images; metrics go to the CSV, weights to disk.
acc_train_api_aug, acc_val_api_aug, loss_train_api_aug, loss_val_api_aug, API_net_aug = train_fine_grained(API_net_aug, aug_dataset, config, f'./log_api_aug.csv')
torch.save(API_net_aug.state_dict(), f'./api_aug.pt')

In [None]:
# Reload the weights written by the two API-Net training cells. The
# non-augmented model was saved as './api_noaug.pt', so the extension is
# required here as well.
API_net_aug.load_state_dict(torch.load('./api_aug.pt'))
API_net.load_state_dict(torch.load('./api_noaug.pt'))

In [None]:
# Print validation metrics for the API-Net trained without augmentation.
eval_fine_grained(API_net, dataset['val'])

In [None]:
# Print validation metrics for the API-Net trained with augmentation, on the
# same validation split.
eval_fine_grained(API_net_aug, dataset['val'])