<a href="https://colab.research.google.com/github/sthalles/SimCLR/blob/simclr-refactor/feature_eval/mini_batch_logistic_regression_evaluator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

## Basic Imports

In [1]:
import torch
import sys
import numpy as np
import os
import yaml
import matplotlib.pyplot as plt
import torchvision

  from .autonotebook import tqdm as notebook_tqdm


## Download Datasets

In [None]:
! pip install gdown

In [2]:
def get_file_id_by_model(folder_name):
    file_id = {
        "resnet18_100-epochs_stl10": "14_nH2FkyKbt61cieQDiSbBVNP8-gtwgF",
        "resnet18_100-epochs_cifar10": "1lc2aoVtrAetGn0PnTkOyFzPCIucOJq7C",
        "resnet50_50-epochs_stl10": "1ByTKAUsdm_X7tLcii6oAEl5qFRqRMZSu",
    }
    return file_id.get(folder_name, "Model not found.")

In [3]:
folder_name = "resnet50_50-epochs_stl10"
file_id = get_file_id_by_model(folder_name)
print("Folder Name: ", folder_name)
print("File ID: ", file_id)

Folder Name:  resnet50_50-epochs_stl10
File ID:  1ByTKAUsdm_X7tLcii6oAEl5qFRqRMZSu


In [4]:
# download and extract model files
pretrained_weights_dir = "./simclr_pretrained_weights"
gdrive_url = "https://drive.google.com/uc?id={}".format(file_id)
folder_full_name = os.path.join(pretrained_weights_dir, folder_name)

if not os.path.exists(pretrained_weights_dir):
    os.makedirs(pretrained_weights_dir)

os.system("cd {} && gdown {}".format(pretrained_weights_dir, gdrive_url))
os.system("unzip {} -d {}".format(folder_full_name, folder_full_name))
os.system("ls {}".format(folder_full_name))

Downloading...
From: https://drive.google.com/uc?id=1ByTKAUsdm_X7tLcii6oAEl5qFRqRMZSu
To: /home/achen353/MorphCLR/feature_eval/simclr_pretrained_weights/resnet50_50-epochs_stl10.zip
100%|██████████| 277M/277M [00:00<00:00, 413MB/s] 


Archive:  ./simclr_pretrained_weights/resnet50_50-epochs_stl10.zip
checkpoint_0040.pth.tar
config.yml
events.out.tfevents.1610927742.4cb2c837708d.2694093.0
lr_checkpoint_0100.pth.tar
training.log


replace ./simclr_pretrained_weights/resnet50_50-epochs_stl10/checkpoint_0040.pth.tar? [y]es, [n]o, [A]ll, [N]one, [r]ename:  NULL
(EOF or read error, treating as "[N]one" ...)


0

## Train the Classification Layer

In [5]:
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets

In [6]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

Using device: cuda


In [7]:
def get_stl10_data_loaders(download, shuffle=False, batch_size=256):
    train_dataset = datasets.STL10(
        "./data", split="train", download=download, transform=transforms.ToTensor()
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=0,
        drop_last=False,
        shuffle=shuffle,
    )

    test_dataset = datasets.STL10(
        "./data", split="test", download=download, transform=transforms.ToTensor()
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=2 * batch_size,
        num_workers=10,
        drop_last=False,
        shuffle=shuffle,
    )
    return train_loader, test_loader


def get_cifar10_data_loaders(download, shuffle=False, batch_size=256):
    train_dataset = datasets.CIFAR10(
        "./data", train=True, download=download, transform=transforms.ToTensor()
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=0,
        drop_last=False,
        shuffle=shuffle,
    )

    test_dataset = datasets.CIFAR10(
        "./data", train=False, download=download, transform=transforms.ToTensor()
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=2 * batch_size,
        num_workers=10,
        drop_last=False,
        shuffle=shuffle,
    )
    return train_loader, test_loader

In [8]:
with open(os.path.join(folder_full_name, "./config.yml")) as file:
    config = yaml.load(file, Loader=yaml.UnsafeLoader)

In [9]:
if config.arch == "resnet18":
    model = torchvision.models.resnet18(pretrained=False, num_classes=10).to(device)
elif config.arch == "resnet50":
    model = torchvision.models.resnet50(pretrained=False, num_classes=10).to(device)

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "


In [10]:
checkpoint = torch.load(
    os.path.join(folder_full_name, "checkpoint_0040.pth.tar"), map_location=device
)
state_dict = checkpoint["state_dict"]

for k in list(state_dict.keys()):

    if k.startswith("backbone."):
        if k.startswith("backbone") and not k.startswith("backbone.fc"):
            # remove prefix
            state_dict[k[len("backbone.") :]] = state_dict[k]
    del state_dict[k]

In [11]:
log = model.load_state_dict(state_dict, strict=False)
assert log.missing_keys == ["fc.weight", "fc.bias"]

In [12]:
if config.dataset_name == "cifar10":
    train_loader, test_loader = get_cifar10_data_loaders(download=True)
elif config.dataset_name == "stl10":
    train_loader, test_loader = get_stl10_data_loaders(download=True)
print("Dataset:", config.dataset_name)

Files already downloaded and verified
Files already downloaded and verified
Dataset: stl10


In [13]:
# freeze all layers but the last fc
for name, param in model.named_parameters():
    if name not in ["fc.weight", "fc.bias"]:
        param.requires_grad = False

parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
assert len(parameters) == 2  # fc.weight, fc.bias

In [14]:
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.0008)
criterion = torch.nn.CrossEntropyLoss().to(device)

In [15]:
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()

        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k)
        return res

In [16]:
epochs = 100
for epoch in range(epochs):
    model.train()
    top1_train_accuracy = 0
    for counter, (x_batch, y_batch) in enumerate(train_loader):
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        logits = model(x_batch)
        loss = criterion(logits, y_batch)
        top1 = accuracy(logits, y_batch, topk=(1,))
        top1_train_accuracy += top1[0]

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    top1_train_accuracy /= len(train_loader.dataset)

    model.eval()
    top1_accuracy = 0
    top5_accuracy = 0
    for counter, (x_batch, y_batch) in enumerate(test_loader):
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        logits = model(x_batch)

        top1, top5 = accuracy(logits, y_batch, topk=(1, 5))
        top1_accuracy += top1[0]
        top5_accuracy += top5[0]

    top1_accuracy /= len(test_loader.dataset)
    top5_accuracy /= len(test_loader.dataset)
    print(
        f"Epoch {epoch}\tTop1 Train accuracy {top1_train_accuracy.item()}\tTop1 Test accuracy: {top1_accuracy.item()}\tTop5 test acc: {top5_accuracy.item()}"
    )

Epoch 0	Top1 Train accuracy 0.34939998388290405	Top1 Test accuracy: 0.5040000081062317	Top5 test acc: 0.9450000524520874
Epoch 1	Top1 Train accuracy 0.5375999808311462	Top1 Test accuracy: 0.5534999966621399	Top5 test acc: 0.9563750624656677
Epoch 2	Top1 Train accuracy 0.5669999718666077	Top1 Test accuracy: 0.5712500214576721	Top5 test acc: 0.9618750214576721
Epoch 3	Top1 Train accuracy 0.5821999907493591	Top1 Test accuracy: 0.580625057220459	Top5 test acc: 0.9637500643730164
Epoch 4	Top1 Train accuracy 0.593999981880188	Top1 Test accuracy: 0.5842500329017639	Top5 test acc: 0.9652500748634338
Epoch 5	Top1 Train accuracy 0.5981999635696411	Top1 Test accuracy: 0.5868750214576721	Top5 test acc: 0.9655000567436218
Epoch 6	Top1 Train accuracy 0.6007999777793884	Top1 Test accuracy: 0.5916250348091125	Top5 test acc: 0.9670000672340393
Epoch 7	Top1 Train accuracy 0.6039999723434448	Top1 Test accuracy: 0.5957500338554382	Top5 test acc: 0.968000054359436
Epoch 8	Top1 Train accuracy 0.609399974346

# Save Model

In [None]:
module_path = os.path.abspath(os.path.join(".."))
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
from utils import save_checkpoint

lr_checkpoint_name = "lr_checkpoint_{:04d}.pth.tar".format(epochs)

save_checkpoint(
    {
        "epoch": epochs,
        "arch": config.arch,
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    },
    is_best=False,
    filename=os.path.join(folder_full_name, lr_checkpoint_name),
)