<a href="https://colab.research.google.com/github/sthalles/SimCLR/blob/simclr-refactor/feature_eval/mini_batch_logistic_regression_evaluator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

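Use this notebook to evaluate the pretrained baseline and MorphCLR encoders on STL10 with the linear evaluation protocol: the backbone is frozen and only the final linear classifier is trained.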

# Setup

## Basic Imports

In [None]:
import torch
import sys
import numpy as np
import os
import yaml
import matplotlib.pyplot as plt
import torchvision

import gdown
import shutil

import warnings

warnings.filterwarnings("ignore", category=UserWarning)

In [None]:
# Make the parent of the current working directory importable, so project
# packages such as `models`, `utils` and `Edge_images` resolve.
module_path = os.path.abspath(os.pardir)
already_importable = module_path in sys.path
if not already_importable:
    sys.path.append(module_path)
print(sys.path)

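## Pretraining Checkpoints

Map each experiment name to the run directory (under `../runs`) that holds its pretraining checkpoint.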

In [None]:
# Experiment key -> pretraining run name. The run name is both the directory
# under the checkpoint root and part of the checkpoint file name.
checkpoints = dict(
    # MorphCLRSingle edge-model checkpoints
    canny_single_pretrained="stl10_canny_single_pretrained_resnet18_0050",
    canny_single_random="stl10_canny_single_random_resnet18_0050",
    dexined_single_pretrained="stl10_dexined_single_pretrained_resnet18_0050",
    dexined_single_random="stl10_dexined_single_random_resnet18_0050",
    # Plain-image baselines (also the non-edge half of the single models)
    baseline_pretrained="stl10_pretrained_resnet18_0050",
    baseline_random="stl10_random_resnet18_0050",
    # MorphCLRDual checkpoints
    canny_dual_pretrained="stl10_canny_dual_pretrained_resnet18_0050",
    canny_dual_random="stl10_canny_dual_random_resnet18_0050",
    dexined_dual_pretrained="stl10_dexined_dual_pretrained_resnet18_0050",
    dexined_dual_random="stl10_dexined_dual_random_resnet18_0050",
)

# Train the Classification Layer and Save the Checkpoint

In [None]:
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets
from Edge_images.generate_datasets import (
    STL10,
    CannyDataset,
    DexiNedTrainDataset,
    DexiNedTestDataset,
    DualDataset,
)
from models.morphclr import MorphCLRSingleEval, MorphCLRDualEval
from utils import save_checkpoint

In [None]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print("Using device: ", device)

In [None]:
# Dataset and pretraining-run directories, relative to this notebook's working directory.
data_root, checkpoint_root = "../datasets", "../runs"
for label, path in (("Data root: ", data_root), ("Checkpoint root: ", checkpoint_root)):
    print(label, path)

In [None]:
def _make_loaders(train_dataset, test_dataset, shuffle, batch_size):
    """Wrap a train/test dataset pair in the DataLoaders used by this notebook.

    The train loader uses ``batch_size`` and loads in the main process
    (``num_workers=0``); the test loader uses ``2 * batch_size`` and 10 worker
    processes. Neither loader drops the last incomplete batch.
    """
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=0,
        drop_last=False,
        shuffle=shuffle,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=2 * batch_size,
        num_workers=10,
        drop_last=False,
        shuffle=shuffle,
    )
    return train_loader, test_loader


def get_stl10_data_loaders(download, shuffle=False, batch_size=256):
    """Return ``(train_loader, test_loader)`` over the labelled STL10 images.

    Args:
        download: forwarded to ``torchvision.datasets.STL10``.
        shuffle: whether both loaders shuffle their samples.
        batch_size: train batch size; the test loader uses twice this.
    """
    # Read from the shared `data_root` (not a duplicated literal path) so the
    # baseline uses the same directory as the edge-image loaders.
    train_dataset = datasets.STL10(
        data_root, split="train", download=download, transform=transforms.ToTensor()
    )
    test_dataset = datasets.STL10(
        data_root, split="test", download=download, transform=transforms.ToTensor()
    )
    return _make_loaders(train_dataset, test_dataset, shuffle, batch_size)


def get_stl10_canny_dual_data_loaders(download, shuffle=False, batch_size=256):
    """Return ``(train_loader, test_loader)`` pairing Canny edge images with STL10 images.

    Batches are unpacked downstream as ``(edge_images, images, labels)``.
    ``download`` is accepted only for signature parity with
    ``get_stl10_data_loaders`` and is ignored; data is read from ``data_root``.
    """
    train_dataset = DualDataset(
        CannyDataset(root=data_root, split="train", transform=transforms.ToTensor()),
        STL10(root=data_root, split="train", transform=transforms.ToTensor()),
    )
    test_dataset = DualDataset(
        CannyDataset(root=data_root, split="test", transform=transforms.ToTensor()),
        STL10(root=data_root, split="test", transform=transforms.ToTensor()),
    )
    return _make_loaders(train_dataset, test_dataset, shuffle, batch_size)


def get_stl10_dexined_dual_data_loaders(download, shuffle=False, batch_size=256):
    """Return ``(train_loader, test_loader)`` pairing DexiNed edge images with STL10 images.

    The edge images are listed by the ``labels.csv`` files under
    ``../Edge_images/Dexi/{train,test}``. Batches are unpacked downstream as
    ``(edge_images, images, labels)``. ``download`` is accepted only for
    signature parity with ``get_stl10_data_loaders`` and is ignored.
    """
    train_dataset = DualDataset(
        DexiNedTrainDataset(
            csv_file="../Edge_images/Dexi/train/labels.csv",
            root_dir="../Edge_images/Dexi/train",
            transform=transforms.ToTensor(),
        ),
        STL10(root=data_root, split="train", transform=transforms.ToTensor()),
    )
    test_dataset = DualDataset(
        DexiNedTestDataset(
            csv_file="../Edge_images/Dexi/test/labels.csv",
            root_dir="../Edge_images/Dexi/test",
            transform=transforms.ToTensor(),
        ),
        STL10(root=data_root, split="test", transform=transforms.ToTensor()),
    )
    return _make_loaders(train_dataset, test_dataset, shuffle, batch_size)

In [None]:
def accuracy(output, target, topk=(1,)):
    """Count correct top-k predictions for each k in ``topk``.

    Despite the name, this returns raw *counts* of correctly classified
    samples, not fractions. Callers accumulate the counts over a loader and
    divide by the dataset size themselves.

    Args:
        output: logits of shape ``(batch, num_classes)``.
        target: integer class labels with ``batch`` elements.
        topk: values of k to evaluate; each must not exceed ``num_classes``.

    Returns:
        A list with one float tensor of shape ``(1,)`` per k, holding the number
        of samples whose label is among the k highest-scoring classes.
    """
    with torch.no_grad():
        maxk = max(topk)
        # After the transpose, pred is (maxk, batch): row i holds each sample's
        # i-th best class, so correct[:k] covers the top-k guesses.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        return [correct[:k].reshape(-1).float().sum(0, keepdim=True) for k in topk]

In [None]:
# NOTE: every run name in `checkpoints` embeds "resnet18"; changing `arch`
# requires matching pretraining checkpoints.
arch = "resnet18"
print("Backbone arch: ", arch)

# Experiments to evaluate, in order; each must be a key of `checkpoints`.
exps = [
    "baseline_pretrained",
    "baseline_random",
    "canny_single_pretrained",
    "canny_single_random",
    "dexined_single_pretrained",
    "dexined_single_random",
    "canny_dual_pretrained",
    "canny_dual_random",
    "dexined_dual_pretrained",
    "dexined_dual_random",
]

# Results are appended across runs, so write the CSV header only when the file
# is new or empty. Writing it unconditionally would insert a duplicate header
# row into the data on every re-run of this notebook.
results_csv = "morphclr_finetune_result.csv"
if not os.path.exists(results_csv) or os.path.getsize(results_csv) == 0:
    with open(results_csv, "a") as f:
        f.write("exp,top_1_train_acc,top_1_test_acc,top_5_test_acc\n")

# Number of labelled STL10 classes; sizes the baseline ResNet's linear head.
NUM_CLASSES = 10
# Linear-evaluation epochs per experiment (also recorded in the checkpoint name).
epochs = 100
FINETUNE_CHECKPOINT_DIR = "../checkpoints/finetune/"


def _checkpoint_path(key):
    """Return the pretraining checkpoint file for ``checkpoints[key]`` under ``checkpoint_root``."""
    run_name = checkpoints[key]
    return os.path.join(checkpoint_root, run_name, "checkpoint_" + run_name + ".pth.tar")


def _build_baseline_model(exp):
    """Build a torchvision ResNet (``arch``) and load the backbone weights of ``exp``.

    The checkpoint's ``backbone.fc*`` weights are dropped (presumably the
    pretraining projection head), so the returned model's ``fc`` is a freshly
    initialised ``NUM_CLASSES``-way linear classifier.

    Raises:
        ValueError: if ``arch`` is neither ``"resnet18"`` nor ``"resnet50"``.
    """
    # torchvision's default head is 1000-way; size it for STL10's classes.
    if arch == "resnet18":
        model = torchvision.models.resnet18(pretrained=False, num_classes=NUM_CLASSES)
    elif arch == "resnet50":
        model = torchvision.models.resnet50(pretrained=False, num_classes=NUM_CLASSES)
    else:
        # Without this, `eval_model` could be undefined or a stale model from a
        # previous experiment.
        raise ValueError("Unsupported arch: {}".format(arch))
    model = model.to(device)

    checkpoint = torch.load(_checkpoint_path(exp), map_location=device)
    # Keep backbone weights except its fc head, with the "backbone." prefix stripped.
    state_dict = {
        k[len("backbone."):]: v
        for k, v in checkpoint["state_dict"].items()
        if k.startswith("backbone.") and not k.startswith("backbone.fc")
    }
    log = model.load_state_dict(state_dict, strict=False)
    assert log.missing_keys == ["fc.weight", "fc.bias"]
    return model


def _prepare_batch(batch_data, is_baseline):
    """Move one loader batch to ``device`` and return ``(x_batch, y_batch)``.

    Baseline loaders yield ``(images, labels)``. Dual loaders yield
    ``(edge_images, images, labels)``. The edge images are given a channel
    dimension if missing and repeated to 3 channels if grayscale, then stacked
    with the plain images along a new leading dimension (edge first).
    Labels are flattened.
    """
    if is_baseline:
        x_batch, y_batch = batch_data
        x_batch = x_batch.to(device)
    else:
        x_edge_batch, x_non_edge_batch, y_batch = batch_data
        x_edge_batch = x_edge_batch.to(device)
        if len(x_edge_batch.shape) == 3:
            x_edge_batch = x_edge_batch.unsqueeze(1)
        # If the image is grayscale, repeat it to create 3 channels
        if x_edge_batch.shape[1] == 1:
            x_edge_batch = x_edge_batch.repeat(1, 3, 1, 1)
        x_non_edge_batch = x_non_edge_batch.to(device)
        x_batch = torch.stack([x_edge_batch, x_non_edge_batch], dim=0)
    return x_batch, y_batch.flatten().to(device)


for exp in exps:
    print("Experiment: ", exp)
    is_baseline = exp.startswith("baseline")

    if is_baseline:
        eval_model = _build_baseline_model(exp)
    elif "single" in exp:
        # Single models combine the edge checkpoint with the baseline of the same init.
        init_type = exp.split("_")[-1]
        eval_model = MorphCLRSingleEval(
            arch,
            _checkpoint_path(exp),
            _checkpoint_path("baseline_{}".format(init_type)),
            device,
        )
    else:
        eval_model = MorphCLRDualEval(arch, _checkpoint_path(exp), device)

    if is_baseline:
        get_data_loader_fn = get_stl10_data_loaders
    elif "canny" in exp:
        get_data_loader_fn = get_stl10_canny_dual_data_loaders
    else:
        get_data_loader_fn = get_stl10_dexined_dual_data_loaders

    train_loader, test_loader = get_data_loader_fn(download=True)

    # Freeze everything except the linear classifier: `fc` on torchvision
    # ResNets, `linear` on the MorphCLR eval models.
    trainable_param_names = (
        {"fc.weight", "fc.bias"} if is_baseline else {"linear.weight", "linear.bias"}
    )
    for name, param in eval_model.named_parameters():
        if name not in trainable_param_names:
            param.requires_grad = False

    parameters = [p for p in eval_model.parameters() if p.requires_grad]
    assert len(parameters) == 2

    # Frozen parameters never receive gradients, so Adam skips them; passing the
    # full list keeps the saved optimizer-state layout unchanged.
    optimizer = torch.optim.Adam(eval_model.parameters(), lr=3e-4, weight_decay=0.0008)
    criterion = torch.nn.CrossEntropyLoss().to(device)

    for epoch in range(epochs):
        # NOTE(review): train() also puts any BatchNorm layers of the frozen
        # backbone in training mode, so their running statistics keep updating
        # even though their weights are frozen. Confirm this is intended for
        # linear evaluation.
        eval_model.train()
        top1_train_accuracy = 0
        for batch_data in train_loader:
            x_batch, y_batch = _prepare_batch(batch_data, is_baseline)
            logits = eval_model(x_batch)
            loss = criterion(logits, y_batch)
            top1_train_accuracy += accuracy(logits, y_batch, topk=(1,))[0]

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        top1_train_accuracy /= len(train_loader.dataset)

        eval_model.eval()
        top1_accuracy = 0
        top5_accuracy = 0
        # Evaluation needs no gradients; avoid building autograd graphs.
        with torch.no_grad():
            for batch_data in test_loader:
                x_batch, y_batch = _prepare_batch(batch_data, is_baseline)
                logits = eval_model(x_batch)
                top1, top5 = accuracy(logits, y_batch, topk=(1, 5))
                top1_accuracy += top1[0]
                top5_accuracy += top5[0]

        top1_accuracy /= len(test_loader.dataset)
        top5_accuracy /= len(test_loader.dataset)
        print(
            "Epoch {}\tTop1 Train accuracy {}\tTop1 Test accuracy: {}\tTop5 test acc: {}".format(
                epoch, top1_train_accuracy.item(), top1_accuracy.item(), top5_accuracy.item()
            )
        )

    with open("morphclr_finetune_result.csv", "a") as f:
        f.write(
            "{},{},{},{}\n".format(
                exp, top1_train_accuracy.item(), top1_accuracy.item(), top5_accuracy.item()
            )
        )

    lr_checkpoint_name = "morphclr_{}_{}_50-epochs_stl10_{}-epochs.pt".format(exp, arch, epochs)

    # Ensure the output directory exists before saving.
    os.makedirs(FINETUNE_CHECKPOINT_DIR, exist_ok=True)
    save_checkpoint(
        {
            "epoch": epochs,
            # Record the backbone actually used rather than a hard-coded name.
            "arch": arch,
            "state_dict": eval_model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        is_best=False,
        filename=os.path.join(FINETUNE_CHECKPOINT_DIR, lr_checkpoint_name),
    )