# Import packages

In [26]:

import torch
import numpy as np
import matplotlib.pyplot as plt
import torchvision

from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets
import sys
sys.path.append("..")
from model.resnet_simclr import ResNetSimCLR
from lars import create_optimizer_lars
from tqdm.notebook import tqdm

# Configuration

In [27]:
device = 'cuda:2' if torch.cuda.is_available() else 'cpu'
print("Using device:", device)

arch = 'resnet50'
dataset_name = 'cifar10'

# model = ResNetSimCLR(arch, 10)
# mlp_in_dim = model.mlp[0].in_features
# num_classes = 10
# model.mlp = torch.nn.Linear(mlp_in_dim, num_classes)

if arch == 'resnet18':
  model = torchvision.models.resnet18(pretrained=False, num_classes=10).to(device)
elif arch == 'resnet50':
  model = torchvision.models.resnet50(pretrained=False, num_classes=10).to(device)

checkpoint_path = '../result/checkpoint/dacl/checkpoint_1000.pth.tar'

epochs = 200
batch_size = 256
dataset = 'cifar10'
lr = 1.0 * batch_size / 256
eta_min = 1e-7
w = 0

num_workers = 0

Using device: cuda:2


# Prepare DataLoaders

In [28]:
def get_stl10_data_loaders(download, shuffle=True, batch_size=batch_size):
  train_dataset = datasets.STL10('../dataset', split='train', download=download,
                                  transform=transforms.ToTensor())

  train_loader = DataLoader(train_dataset, batch_size=batch_size,
                            num_workers=num_workers, drop_last=True, shuffle=shuffle)
  
  test_dataset = datasets.STL10('../dataset', split='test', download=download,
                                  transform=transforms.ToTensor())

  test_loader = DataLoader(test_dataset, batch_size=2*batch_size,
                            num_workers=num_workers, drop_last=True, shuffle=shuffle)
  return train_loader, test_loader

def get_cifar10_data_loaders(download, shuffle=True, batch_size=batch_size):
  train_dataset = datasets.CIFAR10('../dataset', train=True, download=download,
                                  transform=transforms.ToTensor())

  train_loader = DataLoader(train_dataset, batch_size=batch_size,
                            num_workers=num_workers, drop_last=True, shuffle=shuffle)
  
  test_dataset = datasets.CIFAR10('../dataset', train=False, download=download,
                                  transform=transforms.ToTensor())

  test_loader = DataLoader(test_dataset, batch_size=2*batch_size,
                            num_workers=num_workers, drop_last=True, shuffle=shuffle)
                            
  return train_loader, test_loader

# Modify SimCLR's checkpoint state_dict to fit ResNet Classification Task.

In [29]:
# checkpoint = torch.load(checkpoint_path, map_location=device)
# state_dict = checkpoint['state_dict']
# for key in list(state_dict.keys()):
#     if key.startswith("mlp"):
#         del state_dict[key]

In [30]:
checkpoint = torch.load(checkpoint_path, map_location=device)
state_dict = checkpoint['state_dict']
for k in list(state_dict.keys()):
  if k.startswith('backbone.'):
    if k.startswith('backbone') and not k.startswith('backbone.fc'):
      # remove prefix
      state_dict[k[len("backbone."):]] = state_dict[k]
  del state_dict[k]

In [31]:
log = model.load_state_dict(state_dict, strict=False)
print(log)
assert log.missing_keys == ['fc.weight', 'fc.bias']

_IncompatibleKeys(missing_keys=['fc.weight', 'fc.bias'], unexpected_keys=[])


In [32]:
if dataset_name == 'cifar10':
  train_loader, test_loader = get_cifar10_data_loaders(download=True)
elif dataset_name == 'stl10':
  train_loader, test_loader = get_stl10_data_loaders(download=True)
print("Dataset:", dataset_name)

Files already downloaded and verified
Files already downloaded and verified
Dataset: cifar10


# Frozen ResNet parameter and train a classifier on its hidden representation to eval SimCLR's performance

In [33]:
# freeze all layers but the last fc
for name, param in model.named_parameters():
    if name not in ['fc.weight', 'fc.bias']:
        param.requires_grad = False

# filter(function, iterable)
parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
assert len(parameters) == 2  # mlp.weight, mlp.bias

In [34]:
# optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=w)
optimizer = create_optimizer_lars(model, lr, weight_decay=w)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=eta_min, last_epoch=-1, verbose=True)

criterion = torch.nn.CrossEntropyLoss().to(device)

Adjusting learning rate of group 0 to 1.0000e+00.
Adjusting learning rate of group 1 to 1.0000e+00.


In [35]:
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

def MyAccuracy(logits, labels):
    batch_size = logits.size()[0]
    outputs = torch.functional.F.log_softmax(logits, dim=1)
    predict = torch.max(outputs, dim=1)[1]
    acc_count = torch.sum(predict == labels)
    return acc_count / batch_size

In [36]:
model.to(device)
for epoch in range(1, epochs+1):
  model.train()

  top1_train_accuracy = 0
  
  for counter, (x_batch, y_batch) in enumerate(train_loader):

    x_batch = x_batch.to(device)
    y_batch = y_batch.to(device)

    logits = model(x_batch)
    loss = criterion(logits, y_batch)
    top1 = accuracy(logits, y_batch, topk=(1,))
    top1_train_accuracy += top1[0]

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  
  scheduler.step()

  top1_train_accuracy /= (counter + 1)

  top1_accuracy = 0
  top5_accuracy = 0
  top1_my_acc = 0
  model.eval()
  with torch.no_grad():
    for counter, (x_batch, y_batch) in enumerate(test_loader):

      x_batch = x_batch.to(device)
      y_batch = y_batch.to(device)

      logits = model(x_batch)
    
      top1, top5 = accuracy(logits, y_batch, topk=(1,5))
      top1_accuracy += top1[0]
      top5_accuracy += top5[0]
      top1_my_acc += MyAccuracy(logits, y_batch)
    
  top1_accuracy /= (counter + 1)
  top5_accuracy /= (counter + 1)
  top1_my_acc /= (counter + 1)
  print(f"Epoch {epoch}\tTop1 Train accuracy {top1_train_accuracy.item()}\tTop1 Test accuracy: {top1_accuracy.item()} {top1_my_acc.item()*100} \tTop5 test acc: {top5_accuracy.item()}")

Adjusting learning rate of group 0 to 9.9994e-01.
Adjusting learning rate of group 1 to 9.9994e-01.
Epoch 1	Top1 Train accuracy 29.549280166625977	Top1 Test accuracy: 35.2590446472168 35.25904715061188 	Top5 test acc: 85.62911224365234
Adjusting learning rate of group 0 to 9.9975e-01.
Adjusting learning rate of group 1 to 9.9975e-01.
Epoch 2	Top1 Train accuracy 35.66706848144531	Top1 Test accuracy: 36.75986862182617 36.7598682641983 	Top5 test acc: 86.07113647460938
Adjusting learning rate of group 0 to 9.9944e-01.
Adjusting learning rate of group 1 to 9.9944e-01.
Epoch 3	Top1 Train accuracy 36.518428802490234	Top1 Test accuracy: 37.0065803527832 37.00657784938812 	Top5 test acc: 86.62623596191406
Adjusting learning rate of group 0 to 9.9901e-01.
Adjusting learning rate of group 1 to 9.9901e-01.
Epoch 4	Top1 Train accuracy 37.18950271606445	Top1 Test accuracy: 37.047698974609375 37.047699093818665 	Top5 test acc: 86.25617218017578
Adjusting learning rate of group 0 to 9.9846e-01.
Adjus