<a href="https://colab.research.google.com/github/ArthurCTLin/Workbook/blob/main/SimCLR/SimCLR_Eval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Import Libraries


In [None]:
import torch
import sys
import numpy as np
import os
import yaml
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import torchvision

# "ConcatDataset" and "Subset" are possibly useful when doing semi-supervised learning.
from torch.utils.data import ConcatDataset, DataLoader, Subset, Dataset
from torchvision.datasets import DatasetFolder, STL10
from torchvision import datasets, transforms, models

# This is for the progress bar.
from tqdm import tqdm

# Seed every random number generator used here so repeated runs are reproducible.
myseed = 42069
# Request deterministic cuDNN algorithms and turn off cuDNN benchmark mode.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(myseed)
torch.manual_seed(myseed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(myseed)


# os.cpu_count() may return None when the CPU count cannot be determined;
# fall back to 0 so the DataLoaders always receive an integer worker count
# (0 means data is loaded in the main process).
NUM_WORKERS = os.cpu_count() or 0
# Use the first GPU when one is available, otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device: ", device)
print("Number of workers: ", NUM_WORKERS)

### Model
* SimCLR: pretrained model
  * Encoder: ResNet-50
  * Projection head: two-layer MLP (Linear → ReLU → Linear) with a 128-dimensional output

In [None]:
class SimCLR(nn.Module):
    """Wrap an encoder and turn its final ``fc`` layer into an MLP head.

    The encoder's existing ``fc`` layer is kept as the last stage of a
    ``Linear -> ReLU -> fc`` stack; the added linear layer maps the encoder
    feature size onto itself.

    Args:
        encoder: module exposing an ``fc`` attribute that has
            ``in_features``. The module is modified in place.
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

        # Hold on to the original output layer before it is replaced so it
        # can be reused as the last stage of the new head.
        output_layer = self.encoder.fc
        hidden_dim = output_layer.in_features
        self.encoder.fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            output_layer,
        )

    def forward(self, x):
        """Return the encoder output for ``x``, including the MLP head."""
        return self.encoder(x)

In [None]:
class DownstreamModel(nn.Module):
    """Classifier made of a pretrained model followed by one linear layer.

    Args:
        premodel: module whose output is fed to the classification layer.
        num_classes: number of output classes (size of the logits).
        feature_dim: size of the last dimension of ``premodel``'s output.
            Defaults to 128, the value that used to be hard-coded here, so
            existing two-argument calls behave exactly as before.
    """

    def __init__(self, premodel, num_classes, feature_dim=128):
        super().__init__()

        self.premodel = premodel
        self.num_classes = num_classes
        self.feature_dim = feature_dim
        # Linear layer mapping the premodel output to class logits.
        self.lastlayer = nn.Linear(self.feature_dim, self.num_classes)

    def forward(self, x):
        """Return class logits for the batch ``x``."""
        x = self.premodel(x)
        x = self.lastlayer(x)
        return x

In [None]:
# Checkpoint file holding the SimCLR weights that are loaded below.
model_path = "../input/stl10-50epoch/model.ckpt"

# The ResNet-50 encoder is built without pretrained weights; its final layer
# outputs `projection_dim` values.
projection_dim = 128
encoder = models.resnet50(pretrained=False, num_classes=projection_dim)

premodel = SimCLR(encoder)
# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved on a GPU can still be loaded on a CPU-only machine.
premodel.load_state_dict(torch.load(model_path, map_location=device))

# Downstream classifier with 10 output classes, moved to the active device.
ds_model = DownstreamModel(premodel, 10).to(device)

### Downstream Dataset
* The STL10 training split is divided 8:2 into a training set and a validation set.
* Because the goal is to evaluate the SimCLR pretrained model, no data augmentation is applied.

In [None]:
# Preprocessing without augmentation: resize to 128x128, convert to a
# tensor, then normalize each of the three channels.
channel_mean = (0.485, 0.456, 0.406)
channel_std = (0.229, 0.224, 0.225)

preprocessing_steps = [
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
]
transform = transforms.Compose(preprocessing_steps)

In [None]:
batch_size = 256

# STL10 'train' split, divided 80/20 into training and validation subsets.
dataset = STL10(root='STL10', split='train', download=True, transform=transform)
train_size = int(0.8 * len(dataset))
valid_size = len(dataset) - train_size
train_set, valid_set = torch.utils.data.random_split(dataset, [train_size, valid_size])

# STL10 'test' split, used only for the final evaluation.
test_set = STL10(root='STL10', split='test', download=True, transform=transform)

# Settings shared by all three loaders.
loader_kwargs = dict(batch_size=batch_size, pin_memory=True, num_workers=NUM_WORKERS)

# Only the training data is shuffled. Validation and test data are read in a
# fixed order so their metrics do not depend on the batch composition and
# evaluation does not consume random numbers.
train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(valid_set, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)

### Optimizer & criterion

In [None]:
# Number of training epochs and the file that receives the best checkpoint.
n_epochs = 100
SimCLR_DS_save_path = "SimCLR_DS_model.ckpt"

# Cross-entropy loss on the logits; Adam over every parameter of ds_model.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    params=ds_model.parameters(),
    lr=5e-4,
    weight_decay=1e-4,
)

### Training

In [None]:
# Best validation accuracy seen so far; a checkpoint is written whenever it improves.
best_acc = 0.0
valid_acc = 0.0
valid_loss = 0.0

# Per-epoch history, stored as plain Python floats.
train_loss_record = []
valid_loss_record = []
train_acc_record = []
valid_acc_record = []

for epoch in range(n_epochs):
    # ---------- Training ----------
    ds_model.train()

    # Running totals over samples (not over batches), so a smaller final
    # batch contributes in proportion to its size.
    train_loss_sum = 0.0
    train_correct = 0
    train_seen = 0

    # Iterate the training set by batches.
    for imgs, labels in tqdm(train_loader):
        # Move the batch to the device once and reuse it below.
        imgs, labels = imgs.to(device), labels.to(device)

        logits = ds_model(imgs)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The criterion returns the batch mean; scale it back to a sum.
        n_samples = labels.size(0)
        train_loss_sum += loss.item() * n_samples
        train_correct += (logits.argmax(dim=-1) == labels).sum().item()
        train_seen += n_samples

    # Per-sample averages over the whole training set.
    train_loss = train_loss_sum / train_seen
    train_acc = train_correct / train_seen

    # Print the information.
    print(f"[ Train | {epoch + 1:03d}/{n_epochs:03d} ] loss = {train_loss:.5f}, acc = {train_acc:.5f}")

    # ---------- Validation ----------
    ds_model.eval()

    valid_loss_sum = 0.0
    valid_correct = 0
    valid_seen = 0

    # No gradients are needed while validating.
    with torch.no_grad():
        # Iterate the validation set by batches.
        for imgs, labels in tqdm(valid_loader):
            imgs, labels = imgs.to(device), labels.to(device)

            logits = ds_model(imgs)
            loss = criterion(logits, labels)

            n_samples = labels.size(0)
            valid_loss_sum += loss.item() * n_samples
            valid_correct += (logits.argmax(dim=-1) == labels).sum().item()
            valid_seen += n_samples

    # Per-sample averages over the whole validation set.
    valid_loss = valid_loss_sum / valid_seen
    valid_acc = valid_correct / valid_seen

    # Print the information.
    print(f"[ Valid | {epoch + 1:03d}/{n_epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f}")

    # ------Record every time information------
    # Keep only the weights with the best validation accuracy so far.
    if valid_acc > best_acc:
        best_acc = valid_acc
        torch.save(ds_model.state_dict(), SimCLR_DS_save_path)
    train_loss_record.append(train_loss)
    valid_loss_record.append(valid_loss)
    train_acc_record.append(train_acc)
    valid_acc_record.append(valid_acc)

### Evaluation

In [None]:
# Restore the checkpoint saved during training. map_location remaps the
# stored tensors onto the active device, so a checkpoint written on a GPU
# can also be loaded on a CPU-only machine.
ds_model.load_state_dict(torch.load(SimCLR_DS_save_path, map_location=device))

ds_model.eval()

# Running totals over samples (not over batches), so a smaller final batch
# contributes in proportion to its size.
test_loss_sum = 0.0
test_correct = 0
test_seen = 0

# No gradients are needed while evaluating.
with torch.no_grad():
    for imgs, labels in tqdm(test_loader):
        imgs, labels = imgs.to(device), labels.to(device)

        logits = ds_model(imgs)
        loss = criterion(logits, labels)

        # The criterion returns the batch mean; scale it back to a sum.
        n_samples = labels.size(0)
        test_loss_sum += loss.item() * n_samples
        test_correct += (logits.argmax(dim=-1) == labels).sum().item()
        test_seen += n_samples

# Per-sample averages over the whole test set, as plain Python floats.
test_loss = test_loss_sum / test_seen
test_acc = test_correct / test_seen
print(f"loss = {test_loss:.5f}, acc = {test_acc:.5f}")

### Baseline of ResNet50
A ResNet50 model without pretraining is used as the baseline against which the performance of the SimCLR model is compared.

In [None]:
# ResNet50 baseline: built without pretrained weights, with every parameter
# marked trainable and the final layer replaced by a 10-output linear layer.
model_resnet50 = models.resnet50(pretrained=False)
model_resnet50.requires_grad_(True)
baseline_fc_inputs = model_resnet50.fc.in_features
model_resnet50.fc = torch.nn.Linear(baseline_fc_inputs, 10)
model_resnet50 = model_resnet50.to(device)

In [None]:
# File that receives the best baseline checkpoint.
ResNet50_save_path = "ResNet50_model.ckpt"

# Cross-entropy loss on the logits; Adam over every parameter of the baseline.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    params=model_resnet50.parameters(),
    lr=5e-4,
    weight_decay=1e-4,
)

In [None]:
# Best validation accuracy seen so far; a checkpoint is written whenever it improves.
best_acc = 0.0
valid_acc = 0.0
valid_loss = 0.0

# Per-epoch history, stored as plain Python floats.
train_loss_record = []
valid_loss_record = []
train_acc_record = []
valid_acc_record = []

for epoch in range(n_epochs):
    # ---------- Training ----------
    model_resnet50.train()

    # Running totals over samples (not over batches), so a smaller final
    # batch contributes in proportion to its size.
    train_loss_sum = 0.0
    train_correct = 0
    train_seen = 0

    # Iterate the training set by batches.
    for imgs, labels in tqdm(train_loader):
        # Move the batch to the device once and reuse it below.
        imgs, labels = imgs.to(device), labels.to(device)

        logits = model_resnet50(imgs)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The criterion returns the batch mean; scale it back to a sum.
        n_samples = labels.size(0)
        train_loss_sum += loss.item() * n_samples
        train_correct += (logits.argmax(dim=-1) == labels).sum().item()
        train_seen += n_samples

    # Per-sample averages over the whole training set.
    train_loss = train_loss_sum / train_seen
    train_acc = train_correct / train_seen

    # Print the information.
    print(f"[ Train | {epoch + 1:03d}/{n_epochs:03d} ] loss = {train_loss:.5f}, acc = {train_acc:.5f}")

    # ---------- Validation ----------
    model_resnet50.eval()

    valid_loss_sum = 0.0
    valid_correct = 0
    valid_seen = 0

    # No gradients are needed while validating.
    with torch.no_grad():
        # Iterate the validation set by batches.
        for imgs, labels in tqdm(valid_loader):
            imgs, labels = imgs.to(device), labels.to(device)

            logits = model_resnet50(imgs)
            loss = criterion(logits, labels)

            n_samples = labels.size(0)
            valid_loss_sum += loss.item() * n_samples
            valid_correct += (logits.argmax(dim=-1) == labels).sum().item()
            valid_seen += n_samples

    # Per-sample averages over the whole validation set.
    valid_loss = valid_loss_sum / valid_seen
    valid_acc = valid_correct / valid_seen

    # Print the information.
    print(f"[ Valid | {epoch + 1:03d}/{n_epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f}")

    # ------Record every time information------
    # Keep only the weights with the best validation accuracy so far.
    if valid_acc > best_acc:
        best_acc = valid_acc
        torch.save(model_resnet50.state_dict(), ResNet50_save_path)
    train_loss_record.append(train_loss)
    valid_loss_record.append(valid_loss)
    train_acc_record.append(train_acc)
    valid_acc_record.append(valid_acc)

In [None]:
# Restore the checkpoint saved during training. map_location remaps the
# stored tensors onto the active device, so a checkpoint written on a GPU
# can also be loaded on a CPU-only machine.
model_resnet50.load_state_dict(torch.load(ResNet50_save_path, map_location=device))

model_resnet50.eval()

# Running totals over samples (not over batches), so a smaller final batch
# contributes in proportion to its size.
test_loss_sum = 0.0
test_correct = 0
test_seen = 0

# No gradients are needed while evaluating.
with torch.no_grad():
    for imgs, labels in tqdm(test_loader):
        imgs, labels = imgs.to(device), labels.to(device)

        logits = model_resnet50(imgs)
        loss = criterion(logits, labels)

        # The criterion returns the batch mean; scale it back to a sum.
        n_samples = labels.size(0)
        test_loss_sum += loss.item() * n_samples
        test_correct += (logits.argmax(dim=-1) == labels).sum().item()
        test_seen += n_samples

# Per-sample averages over the whole test set, as plain Python floats.
test_loss = test_loss_sum / test_seen
test_acc = test_correct / test_seen
print(f"loss = {test_loss:.5f}, acc = {test_acc:.5f}")