<a href="https://colab.research.google.com/github/sthalles/SimCLR/blob/simclr-refactor/feature_eval/mini_batch_logistic_regression_evaluator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Use this notebook to run linear evaluation: the pretrained backbone is frozen and only the final classification layer is trained.

# Setup

## Basic Imports

In [1]:
import torch
import sys
import numpy as np
import os
import yaml
import matplotlib.pyplot as plt
import torchvision

import gdown
import shutil

import warnings
warnings.filterwarnings("ignore", category=UserWarning) 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Put the parent directory on the import path (once) so modules that live one
# level above this notebook can be imported.
module_path = os.path.abspath(os.pardir)
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)
print(sys.path)

['/home/lliu/MorphCLR/feature_eval', '/home/lliu/miniconda3/envs/simclr/lib/python37.zip', '/home/lliu/miniconda3/envs/simclr/lib/python3.7', '/home/lliu/miniconda3/envs/simclr/lib/python3.7/lib-dynload', '', '/home/lliu/miniconda3/envs/simclr/lib/python3.7/site-packages', '/home/lliu/miniconda3/envs/simclr/lib/python3.7/site-packages/IPython/extensions', '/home/lliu/.ipython', '/home/lliu/MorphCLR']


## Download Pretrained Checkpoint

In [3]:
def get_file_id_by_model(folder_name):
    """Look up the Google Drive file ID for a pretrained checkpoint folder.

    Args:
        folder_name: Checkpoint folder name, e.g. "resnet50_50-epochs_stl10".

    Returns:
        The Drive file ID string for a known checkpoint, or the literal
        string "Model not found." when the name is not in the table.
    """
    known_checkpoints = {
        "resnet18_100-epochs_stl10": "14_nH2FkyKbt61cieQDiSbBVNP8-gtwgF",
        "resnet18_100-epochs_cifar10": "1lc2aoVtrAetGn0PnTkOyFzPCIucOJq7C",
        "resnet50_50-epochs_stl10": "1ByTKAUsdm_X7tLcii6oAEl5qFRqRMZSu",
    }
    if folder_name in known_checkpoints:
        return known_checkpoints[folder_name]
    # Unknown names yield a sentinel string rather than raising.
    return "Model not found."

In [4]:
# Select which pretrained checkpoint to fetch and derive its download URL.
folder_name = "resnet50_50-epochs_stl10"
file_id = get_file_id_by_model(folder_name)
# get_file_id_by_model reports an unknown folder with a sentinel string rather
# than an exception; stop here instead of building a URL from that sentinel.
if file_id == "Model not found.":
    raise ValueError(
        "No Google Drive file ID registered for {!r}".format(folder_name)
    )
gdrive_url = "https://drive.google.com/uc?id={}".format(file_id)
print("GDrive URL: ", gdrive_url)
print("Folder Name: ", folder_name)
print("File ID: ", file_id)

GDrive URL:  https://drive.google.com/uc?id=1ByTKAUsdm_X7tLcii6oAEl5qFRqRMZSu
Folder Name:  resnet50_50-epochs_stl10
File ID:  1ByTKAUsdm_X7tLcii6oAEl5qFRqRMZSu


In [5]:
# Fetch and unpack the checkpoint archive only when the target folder is not
# already present, so re-running the notebook does not download it again.
# `gdrive_url` comes from the previous cell.
zip_name = folder_name + ".zip"
if not os.path.isdir(folder_name):
    gdown.download(gdrive_url, zip_name, quiet=False)
    shutil.unpack_archive(zip_name, folder_name)
    # The archive is no longer needed once its contents are extracted.
    os.remove(zip_name)

Downloading...
From: https://drive.google.com/uc?id=1ByTKAUsdm_X7tLcii6oAEl5qFRqRMZSu
To: /home/lliu/MorphCLR/feature_eval/resnet50_50-epochs_stl10.zip
100%|██████████| 277M/277M [00:06<00:00, 45.0MB/s] 


# Train the Classification Layer

In [6]:
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets

In [7]:
from transformers import ViTFeatureExtractor

In [8]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

Using device: cuda


In [9]:
# Hub identifier of the pretrained ViT checkpoint; it is reused when the
# classification model is loaded further down.
model_name_or_path = 'google/vit-base-patch16-224-in21k'
# Load the image preprocessing configuration that belongs to that checkpoint.
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)

In [10]:
# Display the loaded preprocessing settings (resize target, normalization).
feature_extractor

ViTFeatureExtractor {
  "do_normalize": true,
  "do_resize": true,
  "feature_extractor_type": "ViTFeatureExtractor",
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "size": 224
}

In [11]:
def get_stl10_data_loaders(download, shuffle=False, batch_size=256):
    """Build the STL10 train and test data loaders.

    Every image is converted with ``transforms.ToTensor()`` and then passed
    through the module-level ``feature_extractor``.

    Args:
        download: Forwarded to ``datasets.STL10`` as its ``download`` flag.
        shuffle: Forwarded to both the train and the test loader.
        batch_size: Batch size of the train loader; the test loader uses
            twice this value.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    dataset_root = "../datasets"

    def build_transform():
        # Each dataset gets its own Compose instance.
        return transforms.Compose([transforms.ToTensor(), feature_extractor])

    train_dataset = datasets.STL10(
        dataset_root, split="train", download=download, transform=build_transform()
    )
    test_dataset = datasets.STL10(
        dataset_root, split="test", download=download, transform=build_transform()
    )

    # The train loader loads in the main process (num_workers=0).
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,
        drop_last=False,
    )
    # The test loader uses double-sized batches and 10 worker processes.
    test_loader = DataLoader(
        test_dataset,
        batch_size=2 * batch_size,
        shuffle=shuffle,
        num_workers=10,
        drop_last=False,
    )
    return train_loader, test_loader


In [12]:
# Load the config.yml found inside the unpacked checkpoint folder.
# NOTE(review): yaml.UnsafeLoader can construct arbitrary Python objects, so
# loading executes whatever the downloaded file describes. Only use this with
# archives from a trusted source.
with open(os.path.join(folder_name, "./config.yml")) as file:
    config = yaml.load(file, Loader=yaml.UnsafeLoader)

In [13]:

# Build the STL10 loaders, downloading the dataset if it is not on disk yet.
# NOTE(review): `shuffle` is left at its default of False, so training batches
# are visited in the same order every epoch. Consider whether shuffling the
# training data is wanted here.
train_loader, test_loader = get_stl10_data_loaders(download=True)


Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to ../datasets/stl10_binary.tar.gz


2640404480it [01:50, 23985317.86it/s]                                


Files already downloaded and verified


In [14]:
from transformers import ViTForImageClassification

# Load the pretrained ViT with a 10-way classification head
# (num_labels=10).
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=10,
)

# Move the model to the selected device before training.
model = model.to(device)

Downloading: 100%|██████████| 502/502 [00:00<00:00, 698kB/s]
Downloading: 100%|██████████| 346M/346M [00:08<00:00, 42.8MB/s] 
Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probab

In [15]:
# Freeze every parameter except the classification head, so that only
# classifier.weight and classifier.bias receive gradients.
for name, param in model.named_parameters():
    if name not in ["classifier.weight", "classifier.bias"]:
        param.requires_grad = False

# Collect the parameters that are still trainable and check that exactly the
# two head tensors remain.
parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
assert len(parameters) == 2  # classifier.weight, classifier.bias

In [16]:
# Optimize only the trainable head parameters collected in `parameters` by the
# freezing cell, instead of handing every frozen backbone tensor to Adam.
optimizer = torch.optim.Adam(parameters, lr=3e-4, weight_decay=0.0008)
# Cross-entropy over the class logits.
criterion = torch.nn.CrossEntropyLoss().to(device)

In [17]:
def accuracy(output, target, topk=(1,)):
    """Count the correct top-k predictions for each requested k.

    Note that the results are raw counts of correct samples, not fractions;
    callers divide by the number of samples to obtain an accuracy.

    Args:
        output: Score tensor indexed as (sample, class); the top-k classes
            are taken along dimension 1.
        target: 1-D tensor of ground-truth class indices, one per sample.
        topk: Iterable of k values to evaluate.

    Returns:
        A list with one tensor of shape (1,) per entry in ``topk``, holding
        the number of samples whose target is among the k highest-scoring
        classes (as a float).
    """
    with torch.no_grad():
        maxk = max(topk)

        # Indices of the maxk best classes per sample, best first.
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        # Transpose so that row r holds the rank-r prediction of every sample.
        pred = pred.t()

        correct = pred.eq(target.view(1, -1).expand_as(pred))

        # A sample is a top-k hit when its target appears in the first k rows.
        return [
            correct[:k].reshape(-1).float().sum(0, keepdim=True) for k in topk
        ]

In [18]:
# Train the classification head for a fixed number of epochs. After each
# epoch, report top-1 train accuracy and top-1 / top-5 test accuracy.
epochs = 5
for epoch in range(epochs):
    model.train()
    top1_train_accuracy = 0
    for counter, (x_batch, y_batch) in enumerate(train_loader):
        # The image tensor is the first entry under 'pixel_values'
        # (presumably the collated feature-extractor output; TODO confirm).
        x_batch = x_batch['pixel_values'][0]
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        logits = model(x_batch).logits
        loss = criterion(logits, y_batch)
        # accuracy() returns counts of correct samples, so they are summed
        # here and normalized by the dataset size after the loop.
        top1 = accuracy(logits, y_batch, topk=(1,))
        top1_train_accuracy += top1[0]

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    top1_train_accuracy /= len(train_loader.dataset)

    model.eval()
    top1_accuracy = 0
    top5_accuracy = 0
    # Evaluation needs no gradients; disabling autograd avoids recording a
    # graph for every test batch.
    with torch.no_grad():
        for counter, (x_batch, y_batch) in enumerate(test_loader):
            x_batch = x_batch['pixel_values'][0]
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)

            logits = model(x_batch).logits

            top1, top5 = accuracy(logits, y_batch, topk=(1, 5))
            top1_accuracy += top1[0]
            top5_accuracy += top5[0]

    top1_accuracy /= len(test_loader.dataset)
    top5_accuracy /= len(test_loader.dataset)
    print(
        f"Epoch {epoch}\tTop1 Train accuracy {top1_train_accuracy.item()}\tTop1 Test accuracy: {top1_accuracy.item()}\tTop5 test acc: {top5_accuracy.item()}"
    )

Epoch 0	Top1 Train accuracy 0.4747999906539917	Top1 Test accuracy: 0.8175000548362732	Top5 test acc: 0.971375048160553


# Save Model

In [None]:
# Save the entire model object to disk.
# NOTE(review): a whole-object checkpoint presumably needs the same model
# class to be importable when it is loaded. Confirm this is wanted in addition
# to the weights-only file saved below.
torch.save(model, "VIT_5_epochs.pt")

In [None]:
# Save only the model's state dict (its parameter and buffer tensors).
torch.save(model.state_dict(), "VIT_5_epochs_weights.pt")