In [None]:
import os
import pickle
import random

import numpy as np
import timm
import torch
import torch.nn.functional as F
import torchvision
import wandb
from sklearn.model_selection import KFold
from torch import nn
from torchvision import transforms
from torchvision.models.resnet import resnet34, resnet50
from tqdm.notebook import tqdm

In [None]:
# Device string used by every `.to(...)` / `move_to(...)` call in this notebook:
# 'cuda' when PyTorch can see a GPU, otherwise 'cpu'.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Notebook-level seed constant.
# NOTE(review): no visible cell passes this value to seed_everything — confirm where it is consumed.
SEED = 48
# Model archs supported in this notebook are 34 and 50 for ResNets.
model_arch = 34

In [None]:
def move_to(obj, device):
    """Recursively transfer tensors inside ``obj`` to ``device``.

    Args:
        obj: A tensor, or an arbitrarily nested combination of ``dict`` and
            ``list`` containers whose leaves are tensors.
        device: Target passed straight to ``Tensor.to``.

    Returns:
        A tensor on ``device``, or a freshly built ``dict`` / ``list`` with the
        same structure as ``obj`` whose leaves have been moved.

    Raises:
        TypeError: If ``obj`` (or any nested leaf) is not a tensor, dict or list.
    """
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, dict):
        return {key: move_to(value, device) for key, value in obj.items()}
    if isinstance(obj, list):
        return [move_to(item, device) for item in obj]
    raise TypeError("Invalid type for move_to")

In [None]:
# Channel-wise normalisation pipelines applied to image batches before they are
# fed to a model. Each one wraps a single Normalize step with a 3-channel
# mean/std; the variable names indicate which dataset the statistics belong to.
transforms_cifar = transforms.Compose(
    [transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])]
)
transforms_imagenet = transforms.Compose(
    [transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
)

In [None]:
def seed_everything(seed_value):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and,
    when available, all CUDA devices), and configures cuDNN for deterministic
    behaviour.

    Args:
        seed_value: Integer seed applied to all generators.
    """
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    # Exported for child processes; the hash seed of the already running
    # interpreter is fixed at start-up and is not changed by this assignment.
    os.environ['PYTHONHASHSEED'] = str(seed_value)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)

    # These are plain flags and are harmless on CPU-only machines.
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN autotune and pick potentially different
    # algorithms from run to run, which defeats the purpose of seeding.
    torch.backends.cudnn.benchmark = False

In [None]:
class Model(nn.Module):
    """ResNet encoder ``f`` followed by a two-layer projection head ``g``.

    The torchvision ResNet is adapted in two ways: its ``conv1`` is replaced by
    a newly created 3x3, stride-1 convolution (so that layer never carries
    pretrained weights), and the max-pool and final fully connected layers are
    left out of the encoder.

    Args:
        feature_dim: Output width of the projection head.
        pretrained: Forwarded to the torchvision ResNet constructor.
        model: ResNet depth, 34 or 50.
    """

    def __init__(self, feature_dim=128, pretrained=False, model=model_arch):
        super(Model, self).__init__()

        assert model in [34,50], "Invalid model architecture"

        # depth -> (torchvision constructor, width of the flattened encoder output)
        build_backbone, encoder_dim = {34: (resnet34, 512), 50: (resnet50, 2048)}[model]

        layers = []
        for name, module in build_backbone(pretrained=pretrained).named_children():
            if name == 'conv1':
                module = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            if isinstance(module, (nn.Linear, nn.MaxPool2d)):
                continue
            layers.append(module)

        # encoder
        self.f = nn.Sequential(*layers)
        # projection head
        self.g = nn.Sequential(
            nn.Linear(encoder_dim, 512, bias=False),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, feature_dim, bias=True),
        )

    def forward(self, x):
        """Encode ``x`` and return ``(0, q, z)``.

        ``q`` is the L2-normalised output of the first linear layer of the
        projection head, ``z`` the L2-normalised output of the full head. The
        leading ``0`` is a constant placeholder.
        """
        features = torch.flatten(self.f(x), start_dim=1)
        first_layer_out = self.g[0](features)
        projection = self.g(features)

        return 0, F.normalize(first_layer_out, dim=-1), F.normalize(projection, dim=-1)
    
class Classifier(nn.Module):
    """Linear classification head stacked on a wrapped encoder.

    Args:
        model: Module whose forward pass returns a 3-tuple; the second element
            is used as the feature vector for classification.
        num_classes: Number of output logits (default 10).
        in_features: Width of the feature vector produced by ``model``
            (default 512).
    """

    def __init__(self, model, num_classes=10, in_features=512):
        super().__init__()
        self.model = model
        self.pred = nn.Linear(in_features, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return class logits for the batch ``x``."""
        _, embedding, _ = self.model(x)
        return self.pred(self.relu(embedding))

In [None]:
def get_model(transfer=False, pretrained=False):
    """Build a ``Classifier`` whose encoder starts out frozen.

    Args:
        transfer: When true, the encoder is created with ``pretrained=True``
            (torchvision weights); otherwise it is randomly initialised.
        pretrained: When true, the encoder weights are additionally loaded from
            ``best_model_resnet{model_arch}.pt`` in the working directory.

    Returns:
        A ``Classifier`` on ``DEVICE`` in which every parameter of the wrapped
        encoder has ``requires_grad = False``; only the classification head is
        trainable.
    """
    model = Model(pretrained=bool(transfer), model=model_arch)
    model.to(DEVICE)

    if pretrained:
        # map_location lets a checkpoint that was saved on a GPU be restored on
        # a CPU-only machine instead of failing during deserialisation.
        state_dict = torch.load(f"best_model_resnet{model_arch}.pt", map_location=DEVICE)
        model.load_state_dict(state_dict)

    classifier = Classifier(model)
    classifier.to(DEVICE)

    # Freeze the network (the encoder only; the head stays trainable).
    for param in classifier.model.parameters():
        param.requires_grad = False

    return classifier

In [None]:
def train_fn(model, dl, optimizer, criterion, transforms):
    """Train ``model`` for one pass over ``dl``.

    Args:
        model: Module returning class logits for a batch.
        dl: Iterable yielding ``(X, y)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(logits, y)``. The averaging below
            assumes it returns the mean over the batch, which is the default
            reduction of the ``nn.CrossEntropyLoss`` used in this notebook.
        transforms: Callable applied to each input batch after it has been
            moved to ``DEVICE``.

    Returns:
        Tuple ``(average_loss, average_accuracy)`` over all samples seen.
    """
    model.train()
    n_samples = 0

    correct = 0
    total_loss = 0.0

    for X, y in tqdm(dl):
        X = move_to(X, DEVICE)
        y = move_to(y, DEVICE)

        X = transforms(X)

        predictions = model(X)

        loss = criterion(predictions, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = X.shape[0]
        n_samples += batch_size
        # The criterion yields a per-batch mean, so weight it by the batch size;
        # dividing a sum of batch means by the sample count would under-report
        # the loss by roughly a factor of the batch size.
        total_loss += loss.item() * batch_size

        correct += torch.sum(torch.argmax(predictions, dim=1) == y).item()

    average_loss = total_loss / n_samples
    average_accuracy = correct / n_samples
    return average_loss, average_accuracy

In [None]:
def test_fn(model, dl, criterion, transforms):
    """Evaluate ``model`` on ``dl`` without computing gradients.

    Args:
        model: Module returning class logits for a batch.
        dl: Iterable yielding ``(X, y)`` batches.
        criterion: Loss called as ``criterion(logits, y)``. The averaging below
            assumes it returns the mean over the batch, which is the default
            reduction of the ``nn.CrossEntropyLoss`` used in this notebook.
        transforms: Callable applied to each input batch after it has been
            moved to ``DEVICE``.

    Returns:
        Tuple ``(average_loss, average_accuracy)`` over all samples seen.
    """
    model.eval()
    n_samples = 0

    correct = 0
    total_loss = 0.0

    with torch.no_grad():
        for X, y in tqdm(dl):
            X = move_to(X, DEVICE)
            y = move_to(y, DEVICE)

            X = transforms(X)

            predictions = model(X)

            loss = criterion(predictions, y)

            batch_size = X.shape[0]
            n_samples += batch_size
            # Weight the per-batch mean by the batch size so that the final
            # division by n_samples yields a true per-sample average.
            total_loss += loss.item() * batch_size

            correct += torch.sum(torch.argmax(predictions, dim=1) == y).item()

    average_loss = total_loss / n_samples
    average_accuracy = correct / n_samples
    return average_loss, average_accuracy

In [None]:
import copy
def finetune_folds(min_epochs=10, max_epochs=50, min_acc=0.95, fracs=(0.01, 0.1, 1), folds=5, test=False,
                   unfreeze_epoch=1, batch_size=90):
    """Fine-tune each classifier variant with k-fold cross-validation on CIFAR-10.

    The CIFAR-10 train and test splits are concatenated and split into
    ``folds`` folds. For every classifier, every training fraction and every
    fold, the classifier is reset to its initial weights and trained until it
    has run at least ``min_epochs`` epochs and reached ``min_acc`` training
    accuracy, or until ``max_epochs`` is hit.

    Args:
        min_epochs: Minimum number of epochs per fold.
        max_epochs: Hard upper bound on epochs per fold.
        min_acc: Training accuracy that must be reached before stopping early.
        fracs: Fractions of each fold's training indices to train on.
        folds: Number of cross-validation folds.
        test: Smoke-test mode: skips the checkpoint-based classifier, keeps only
            ``folds * 2`` samples and forces ``batch_size`` to 2.
        unfreeze_epoch: Zero-based epoch at which all parameters become
            trainable; before that only the parameters left trainable by
            ``get_model`` are updated.
        batch_size: Batch size for both data loaders.

    Returns:
        Nested dict ``metrics[classifier_name]["frac_<frac>"]["fold_<n>"]`` with
        keys ``'acc_train'`` and ``'acc_test'`` holding the best accuracy seen
        in any epoch of that fold.
    """
    # The original assigned this value locally but never used it; apply it so
    # model initialisation, fold splits and subsampling are reproducible.
    seed = 42
    seed_everything(seed)

    classifier_names = ['Supervised', 'Transfer']

    if test:
        classifiers = [get_model(), get_model(transfer=True)]
    else:
        classifiers = [get_model(), get_model(transfer=True), get_model(pretrained=True)]
        classifier_names.append('Semi-Supervised')

    criterion = nn.CrossEntropyLoss()

    metrics = {classifier_name:{f"frac_{frac}":{f"fold_{fold}":{'acc_train':0, 'acc_test':0} for fold in range(1, folds+1)} for frac in fracs} for classifier_name in classifier_names}

    cifar_train = torchvision.datasets.CIFAR10('data', download=True, transform = transforms.ToTensor())
    cifar_test = torchvision.datasets.CIFAR10('data', train=False, download=True, transform = transforms.ToTensor())

    ds_combined = torch.utils.data.ConcatDataset((cifar_train, cifar_test))

    if test:
        ds_combined =  torch.utils.data.Subset(ds_combined, range(folds*2))
        batch_size = 2

    for classifier_name, classifier in zip(classifier_names, classifiers):
        # save initial model parameters to reuse on every fold
        torch.save(classifier.state_dict(), "classifier_state_dict.pt")
        # Remember which parameters start out frozen so the freeze can be
        # restored for every fold; load_state_dict resets values but not
        # requires_grad, so without this only the very first fold would begin
        # with a frozen encoder.
        initially_frozen = [param for param in classifier.parameters() if not param.requires_grad]

        # NOTE(review): the ImageNet-pretrained 'Transfer' model receives the
        # CIFAR statistics while the other models receive the ImageNet ones.
        # This looks inverted — confirm the intended pairing before changing it.
        if classifier_name == 'Transfer':
            transformations = transforms_cifar
        else:
            transformations = transforms_imagenet

        print("---------------------------------------")
        print(f"Training classifier: {classifier_name}")
        for frac in fracs:
            print(f"  Training on {frac*100}% of training data")
            # Fixed random_state: every classifier and fraction is evaluated on
            # the same folds, and the subsampling below is repeatable as well.
            kf = KFold(n_splits=folds, shuffle=True, random_state=seed)
            subset_rng = np.random.RandomState(seed)

            for fold_number, (train_idx, test_idx) in enumerate(kf.split(ds_combined), start=1):
                print(f"  Fold {fold_number}")
                print("---------------------------------------")
                train_idx = subset_rng.choice(train_idx, int(np.ceil(len(train_idx)*frac)), replace=False)

                ds_fold_train = torch.utils.data.Subset(ds_combined, train_idx)
                ds_fold_test = torch.utils.data.Subset(ds_combined, test_idx)

                dl_train = torch.utils.data.DataLoader(ds_fold_train, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
                # Evaluation metrics do not depend on sample order, so the
                # held-out fold is not shuffled.
                dl_test = torch.utils.data.DataLoader(ds_fold_test, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

                # Reset weights and the frozen/trainable split to the initial state.
                classifier.load_state_dict(torch.load("classifier_state_dict.pt"))
                for param in initially_frozen:
                    param.requires_grad = False
                    # Drop gradients left over from the previous fold so the
                    # optimizer skips this parameter until it is unfrozen.
                    param.grad = None

                optimizer = torch.optim.AdamW(classifier.parameters())

                best_acc_train = 0
                best_acc_test = 0
                epoch = 0
                while (best_acc_train < min_acc or epoch < min_epochs) and epoch < max_epochs:
                    if epoch == unfreeze_epoch:
                        for param in classifier.parameters():
                            param.requires_grad = True

                    loss_train, acc_train = train_fn(classifier, dl_train, optimizer, criterion, transformations)
                    loss_test, acc_test = test_fn(classifier, dl_test, criterion, transformations)

                    best_acc_train = max(best_acc_train, acc_train)
                    best_acc_test = max(best_acc_test, acc_test)

                    epoch += 1

                    print(f"  Epoch: {epoch}")
                    print("  Training metrics")
                    print(f"    loss_train: {loss_train}")
                    print(f"    acc_train: {acc_train}")
                    print(f"    loss_test: {loss_test}")
                    print(f"    acc_test: {acc_test}")
                    print()

                fold_metrics = metrics[classifier_name][f"frac_{frac}"][f"fold_{fold_number}"]
                fold_metrics['acc_train'] = best_acc_train
                fold_metrics['acc_test'] = best_acc_test

    return metrics

In [None]:
# Run the 10-fold fine-tuning experiment and pickle the returned metrics dict
# to a file named after the configured architecture (metrics_<model_arch>.pkl).
metrics = finetune_folds(folds=10)
with open(f"metrics_{model_arch}.pkl", "wb") as f:
    pickle.dump(metrics, f)

In [None]:
# Reload the metrics file for the configured architecture and show it as the cell output.
# NOTE: pickle.load can execute arbitrary code — only open files produced by this notebook.
with open(f"metrics_{model_arch}.pkl", "rb") as g:
    metrics = pickle.load(g)
metrics