# Data Setup

The first thing to do is implement a dataset class to load rotated CIFAR10 images with matching labels. Since there is already a CIFAR10 dataset class implemented in `torchvision`, we will extend this class and modify the `__getitem__` method appropriately to load rotated images.

Each rotation label should be an integer in the set {0, 1, 2, 3}, corresponding to a rotation of 0, 90, 180, or 270 degrees respectively (counter-clockwise, following the angle convention of `torchvision.transforms.functional.rotate`).

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import random
import os

def rotate_img(img, rot):
    """Rotate ``img`` by the angle encoded in the rotation label ``rot``.

    Args:
        img: image accepted by ``transforms.functional.rotate`` (in this
            notebook, the output of the dataset's transform).
        rot: rotation label in {0, 1, 2, 3}, meaning 0, 90, 180 or 270 degrees.

    Returns:
        ``img`` itself (not a copy) for label 0, otherwise the rotated image.

    Raises:
        ValueError: if ``rot`` is not one of the four rotation labels.
    """
    if rot == 0:  # 0 degrees: nothing to do, hand back the input unchanged
        return img
    elif rot == 1:
        return transforms.functional.rotate(img, 90)
    elif rot == 2:
        return transforms.functional.rotate(img, 180)
    elif rot == 3:
        return transforms.functional.rotate(img, 270)
    # ``rot`` is a label index, not a degree value, so the message says so.
    raise ValueError(
        f'rotation label should be 0, 1, 2, or 3 '
        f'(i.e. 0, 90, 180, or 270 degrees), got {rot!r}'
    )


class CIFAR10Rotation(torchvision.datasets.CIFAR10):
    """CIFAR10 dataset that also yields a randomly rotated copy of each image.

    Every item is ``(image, rotated_image, rotation_label, class_label)``.
    ``rotation_label`` is drawn uniformly from {0, 1, 2, 3} on every access
    (see ``rotate_img``). Both labels are returned as int64 tensors. Because
    the rotation is re-drawn each time, repeated passes over the same split
    see different rotations.
    """

    def __init__(self, root, train, download, transform) -> None:
        super().__init__(
            root=root,
            train=train,
            download=download,
            transform=transform,
        )

    def __len__(self):
        # One item per stored image.
        return len(self.data)

    def __getitem__(self, index: int):
        img, target = super().__getitem__(index)

        # Fresh random rotation every time the item is fetched.
        rot = random.choice([0, 1, 2, 3])
        img_rot = rotate_img(img, rot)

        return (
            img,
            img_rot,
            torch.tensor(rot).long(),
            torch.tensor(target).long(),
        )


In [None]:
from collections.abc import Callable
from typing import Any, Optional
from PIL import Image
# Imagenette: ImageNet synset id -> readable class name (from the fastai docs).
# Dataset download: https://github.com/fastai/imagenette
lbl_dict = {
    'n01440764': 'tench',
    'n02102040': 'English springer',
    'n02979186': 'cassette player',
    'n03000684': 'chain saw',
    'n03028079': 'church',
    'n03394916': 'French horn',
    'n03417042': 'garbage truck',
    'n03425413': 'gas pump',
    'n03445777': 'golf ball',
    'n03888257': 'parachute',
}

# Per-channel statistics were computed by dataset_stat.py, with single-channel
# images excluded.
# NOTE(review): the second tuple was described as the per-channel *variance*,
# but transforms.Normalize divides by the standard deviation. If these numbers
# really are variances, they should be square-rooted. Confirm against
# dataset_stat.py before relying on them.
transform_train = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4601, 0.4564, 0.4305), (0.0818, 0.0795, 0.0930)),
    ]
)

# Evaluation: same resize/crop/normalization, no random flip.
transform_test = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.4601, 0.4564, 0.4305), (0.0818, 0.0795, 0.0930)),
    ]
)

batch_size = 64
def pil_loader(path: str) -> Image.Image:
    """Open the image file at ``path`` and return it converted to RGB.

    The conversion happens while the file is still open, so the returned
    image does not depend on the (by then closed) file handle.
    """
    with open(path, "rb") as fh:
        return Image.open(fh).convert("RGB")
    
class rotate_Image_folder(torchvision.datasets.ImageFolder):
    """ImageFolder variant that also yields a randomly rotated copy of each image.

    Items are ``(image, rotated_image, rotation_label, class_label)``, the same
    layout as ``CIFAR10Rotation``. Both labels are int64 tensors, and the
    rotation label is re-drawn from {0, 1, 2, 3} on every access.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = pil_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        # Arguments are forwarded positionally, in ImageFolder's own order.
        super().__init__(root, transform, target_transform, loader, is_valid_file)

    def __len__(self) -> int:
        # One item per entry of ``self.samples``.
        return len(self.samples)

    def __getitem__(self, index: int):
        img, target = super().__getitem__(index)
        rot = random.choice([0, 1, 2, 3])
        img_rot = rotate_img(img, rot)
        return img, img_rot, torch.tensor(rot).long(), torch.tensor(target).long()
        

# Imagenette train/val splits; every item also carries a rotated copy.
trainset = rotate_Image_folder(
    root='data/imagenette2/train',
    transform=transform_train,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
)

testset = rotate_Image_folder(
    root='data/imagenette2/val',
    transform=transform_test,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,  # fixed order for evaluation
    num_workers=8,
)




In [None]:
# CIFAR10: random crop + flip augmentation for training, normalization only
# for testing. imshow() below undoes this same normalization for display.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

batch_size = 128

# NOTE: these names replace the Imagenette datasets/loaders if that cell ran first.
trainset = CIFAR10Rotation(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=8
)

testset = CIFAR10Rotation(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=8
)



Show some example images and rotated images with labels:

In [None]:
import matplotlib.pyplot as plt

# CIFAR10 class names, indexed by class label.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

# Display names (degrees) for rotation labels 0..3.
rot_classes = ('0', '90', '180', '270')


def imshow(img, mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)):
    """Undo ``transforms.Normalize`` on a channel-first image tensor and show it.

    Args:
        img: normalized image tensor laid out as (channels, height, width),
            e.g. the output of ``torchvision.utils.make_grid``.
        mean: per-channel means that were subtracted during normalization.
        std: per-channel standard deviations that were divided out.
            The defaults are the CIFAR10 values used by this notebook's CIFAR
            transforms; pass the Imagenette statistics when visualizing that
            dataset.
    """
    mean = torch.as_tensor(mean, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    # Inverse of Normalize: x * std + mean. Clamp to the displayable [0, 1]
    # range so matplotlib does not have to clip out-of-range floats itself.
    img = (img * std + mean).clamp(0.0, 1.0)
    npimg = img.cpu().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# Grab one training batch and display the first few originals and rotations.
n_show = 4
dataiter = iter(trainloader)
images, rot_images, rot_labels, labels = next(dataiter)
n_show = min(n_show, images.shape[0])  # a batch may hold fewer than 4 images

# imshow only draws and returns nothing, so its result is not stored.
imshow(torchvision.utils.make_grid(images[:n_show], padding=0))
print('Class labels: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(n_show)))
imshow(torchvision.utils.make_grid(rot_images[:n_show], padding=0))
print('Rotation labels: ', ' '.join(f'{rot_classes[rot_labels[j]]:5s}' for j in range(n_show)))

# Evaluation code

In [None]:
import time

def run_test(net, testloader, criterion, task, device=None):
    """Evaluate ``net`` on ``testloader`` for the rotation or classification task.

    Each batch must be ``(images, images_rotated, rotation_labels,
    class_labels)``, as produced by the rotation datasets in this notebook.

    - ``task == 'rotation'``: the network sees the rotated images and is
      scored against the rotation labels.
    - ``task == 'classification'``: it sees the original images and is scored
      against the class labels.

    Args:
        net: model producing one logit per class along dim 1 (assumed shape
            (batch, num_classes)).
        testloader: iterable of 4-tuples as described above.
        criterion: loss module. Assumed to return the *mean* loss of a batch,
            e.g. ``nn.CrossEntropyLoss()`` with its default reduction.
        task: ``'rotation'`` or ``'classification'``.
        device: device each batch is moved to. Defaults to the device of the
            network's first parameter (CPU if it has none).

    Returns:
        Accuracy in percent.

    Raises:
        ValueError: for an unknown ``task`` or a ``testloader`` with no samples.
    """
    if task not in ('rotation', 'classification'):
        raise ValueError(f"task must be 'rotation' or 'classification', got {task!r}")
    if device is None:
        first_param = next(net.parameters(), None)
        device = first_param.device if first_param is not None else 'cpu'

    net.eval()  # BatchNorm running statistics / no dropout while testing
    correct = 0
    total = 0
    loss_sum = 0.0
    # Evaluation only: no gradients needed.
    with torch.no_grad():
        for images, images_rotated, rot_labels, cls_labels in testloader:
            if task == 'rotation':
                inputs, labels = images_rotated.to(device), rot_labels.to(device)
            else:
                inputs, labels = images.to(device), cls_labels.to(device)

            outputs = net(inputs)
            # The class with the highest logit is the prediction.
            predicted = outputs.argmax(dim=1)

            batch = labels.size(0)
            # Weight each batch's mean loss by its size so a short final batch
            # does not skew the per-image average.
            loss_sum += criterion(outputs, labels).item() * batch
            total += batch
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError('testloader yielded no samples')

    accuracy = 100 * correct / total
    print('TESTING:')
    print(f'Accuracy of the network on the {total} test images: {accuracy:.2f} %')
    print(f'Average loss on the {total} test images: {loss_sum / total:.3f}')
    return accuracy

In [None]:
def adjust_learning_rate(optimizer, epoch, init_lr, decay_epochs=30):
    """Apply step decay: divide the LR by 10 every ``decay_epochs`` epochs.

    The new rate is ``init_lr * 0.1 ** (epoch // decay_epochs)`` and is
    written into every parameter group of ``optimizer``.

    Args:
        optimizer: optimizer whose ``param_groups`` are updated in place.
        epoch: zero-based epoch index.
        init_lr: learning rate used for epochs ``0 .. decay_epochs - 1``.
        decay_epochs: number of epochs between successive 10x decays; must be
            positive.

    Returns:
        The learning rate that was applied.

    Raises:
        ValueError: if ``decay_epochs`` is not positive.
    """
    if decay_epochs <= 0:
        raise ValueError(f'decay_epochs must be positive, got {decay_epochs}')
    lr = init_lr * (0.1 ** (epoch // decay_epochs))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

# Train a ResNet on the rotation task

In this section, we will train a ResNet model (ResNet-101 in the code below) on the rotation task. The input is a rotated image and the model predicts the rotation label. See the Data Setup section for details.

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
# Apple-silicon alternative (disabled):
# device = 'mps' if torch.backends.mps.is_available() and torch.backends.mps.is_built() else 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device  # echo the choice in the notebook output

In [None]:
import torch.nn as nn
import torch.nn.functional as F

from torchvision.models import resnet18, resnet50, resnet101, resnext101_32x8d
import torch._dynamo

# Rotation pretext model: ResNet-101 with a 4-way head (one logit per rotation label).
# torch.set_float32_matmul_precision('high')
# torch._dynamo.reset()
net = resnet101(num_classes=4).to(device)
# net = torch.compile(net, mode="reduce-overhead")

In [None]:
import torch.optim as optim

# SGD learning rate; the optimizer below reads it from here, so the variable
# and the value actually used can no longer disagree.
lr = 1e-2
# Length (in epochs) of the first cosine cycle; each later cycle doubles (T_mult=2).
restart_epoch = 50
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=net.parameters(), lr=lr, momentum=0.9, weight_decay=1e-3)
# NOTE(review): with T_0=50 and T_mult=2 the LR restarts after epochs 50, 150,
# 350 and 750. A checkpoint period passed to train() (e.g. restart_epoch=100)
# does not line up with these restarts.
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=restart_epoch,
    T_mult=2,
    eta_min=1e-8,
    last_epoch=-1,
    verbose=True,
)
scaler = torch.cuda.amp.GradScaler()

In [None]:
# Both the self-supervised rotation task and supervised CIFAR10 classification are
# trained with the CrossEntropyLoss, so we can use the training loop code.

def train(net, criterion, optimizer, scaler, num_epochs, decay_epochs, init_lr, task, scheduler=None, weight_best=None, weight_interval=None, restart_epoch=None):
    """Train ``net`` with float16 autocast and evaluate it after every epoch.

    Reads the notebook globals ``trainloader``, ``testloader`` and ``device``.
    Batches are ``(images, images_rotated, rotation_labels, class_labels)``;
    ``task`` selects which input/label pair is used.

    Args:
        net, criterion, optimizer: model, loss and optimizer to use.
        scaler: gradient scaler for the float16 autocast forward pass.
        num_epochs: number of passes over ``trainloader``.
        decay_epochs, init_lr: step-decay settings for ``adjust_learning_rate``.
            They are only used when ``scheduler`` is None.
        task: ``'rotation'`` or ``'classification'``.
        scheduler: optional LR scheduler, stepped once per epoch.
        weight_best: if given, the state dict is saved here whenever test
            accuracy improves.
        weight_interval: path prefix for periodic checkpoints
            (``f'{weight_interval}_epoch{N}.pt'``). Written every
            ``restart_epoch`` epochs when both are given.
        restart_epoch: checkpoint period in epochs.

    Raises:
        ValueError: if ``task`` is not one of the two supported values.
    """
    if task not in ('rotation', 'classification'):
        raise ValueError(f"task must be 'rotation' or 'classification', got {task!r}")

    print_freq = 50  # report running statistics every this many mini-batches
    best_acc = -10
    best_epoch = -10

    for epoch in range(num_epochs):  # loop over the dataset multiple times
        # The step-decay LR depends only on the epoch, so set it once per
        # epoch instead of on every mini-batch.
        if scheduler is None:
            adjust_learning_rate(optimizer, epoch, init_lr, decay_epochs)

        running_loss = 0.0
        running_correct = 0.0
        running_total = 0.0
        start_time = time.time()

        net.train()

        for i, (imgs, imgs_rotated, rotation_label, cls_label) in enumerate(trainloader, 0):
            if task == 'rotation':
                images, labels = imgs_rotated.to(device), rotation_label.to(device)
            else:
                images, labels = imgs.to(device), cls_label.to(device)

            optimizer.zero_grad()

            # Forward pass under float16 autocast; the scaler scales the loss
            # for the backward pass and unscales gradients before stepping.
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                outputs = net(images)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            predicted = outputs.argmax(axis=1)

            # Running statistics since the last report.
            running_loss += loss.item()
            running_total += labels.size(0)
            running_correct += (predicted == labels).sum().item()

            if i % print_freq == (print_freq - 1):  # report every print_freq mini-batches
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / print_freq:.3f} acc: {100*running_correct / running_total:.2f} time: {time.time() - start_time:.2f}')
                running_loss, running_correct, running_total = 0.0, 0.0, 0.0
                start_time = time.time()

        if scheduler is not None:
            scheduler.step()

        # Periodic checkpoint, independent of test accuracy.
        if restart_epoch is not None and weight_interval is not None:
            if (epoch + 1) % restart_epoch == 0:
                torch.save(net.state_dict(), f'{weight_interval}_epoch{epoch+1}.pt')

        net.eval()
        test_acc = run_test(net, testloader, criterion, task)
        if test_acc > best_acc:
            best_acc = test_acc
            best_epoch = epoch + 1
            if weight_best is not None:
                torch.save(net.state_dict(), weight_best)

    print('Finished Training')
    print(f'Best Acc: {best_acc} at epoch: {best_epoch}')

In [None]:
# Checkpoint names for the rotation pre-training run.
pretrained_weight_interval = 'ssl_pretrained_resnet101_sfd_cosine_v5'
pretrained_weight_best = f'{pretrained_weight_interval}_best.pt'

# A scheduler is supplied, so train() does not use decay_epochs / init_lr.
# restart_epoch=100 here only sets the periodic checkpoint interval.
train(
    net,
    criterion,
    optimizer,
    scaler,
    num_epochs=1000,
    decay_epochs=50,
    init_lr=0.001,
    task='rotation',
    scheduler=scheduler,
    weight_best=pretrained_weight_best,
    weight_interval=pretrained_weight_interval,
    restart_epoch=100,
)

# resnet18 78.74%
# resnet101 (SGD + CosineAnnealing) 90.76 %


In [None]:
from glob import glob
import time

# Hard-voting ensemble of several rotation checkpoints, evaluated on the test set.
ensemble_weights = [
    'ssl_pretrained_resnet101_sfd_cosine_v5_best.pt',
    'ssl_pretrained_resnet101_sfd_cosine_v5_epoch700.pt',
    'ssl_pretrained_resnet101_sfd_cosine_v4_best.pt',
]  # choose trained NN candidates you desire

# Load every checkpoint once instead of re-reading it from disk for each batch.
# The models cannot simply be run one after another over the whole loader:
# the test set draws a new random rotation on every pass, so all models must
# see the same batch. The weights are therefore swapped in per batch.
ensemble_state_dicts = [torch.load(w, map_location=device) for w in ensemble_weights]

correct = 0
total = 0

net.eval()  # BatchNorm must use its running statistics during evaluation
with torch.no_grad():
    for images, images_rotated, labels, cls_labels in testloader:
        images, labels = images_rotated.to(device), labels.to(device)

        votes = []
        for state_dict in ensemble_state_dicts:
            net.load_state_dict(state_dict)
            outputs = net(images)
            votes.append(outputs.argmax(dim=1))

        # Majority vote per sample: count how often each class was predicted
        # and take the most frequent one. argmax returns the first maximum,
        # so ties go to the smallest label.
        votes = torch.stack(votes)  # (num_models, batch)
        counts = torch.nn.functional.one_hot(votes, num_classes=outputs.shape[1]).sum(dim=0)
        predicted = counts.argmax(dim=1)

        total += labels.shape[0]
        correct += (predicted == labels).sum().item()

print('TESTING:')
print(f'Accuracy of the network on the {total} test images: {100 * correct / total:.2f} %')
# resnet101: 90.69%
# NOTE: afterwards ``net`` holds the weights of the last entry in ensemble_weights.




In [None]:
# Rotation pre-training with step decay (no scheduler object): the LR starts
# at 0.01 and drops 10x every 30 epochs.
train(
    net, criterion, optimizer, scaler,
    num_epochs=190,
    decay_epochs=30,
    init_lr=0.01,
    task='rotation',
)

# Persist the final (not necessarily the best) weights.
pretrained_weight = 'ssl_pretrained_imagenette.pt'
torch.save(net.state_dict(), pretrained_weight)
# 65.09% imagenette 


## Fine-tuning on the pre-trained model

In this section, we will load the pre-trained ResNet model (ResNet-101 in the code below) and fine-tune it on the classification task. We will freeze all layers except the 'layer4' block and the 'fc' layer.

In [None]:
import torch.nn as nn
import torch.nn.functional as F

from torchvision.models import resnet18, ResNet18_Weights
from gc import collect

del net
collect()  # drop the previous model so its memory can be reclaimed
# Rebuild the rotation network so the checkpoint's 4-way head matches, load the
# self-supervised weights, then swap in a fresh 10-way classifier head.
net = resnet101(num_classes=4).to(device)
# map_location lets a checkpoint saved on GPU load on whichever device is active.
net.load_state_dict(torch.load('ssl_pretrained_resnet101_sfd_cosine_v5_best.pt', map_location=device))
# Size the new head from the existing one rather than hard-coding 2048, so the
# cell also works for backbones with a different feature width (e.g. ResNet-18).
net.fc = nn.Linear(net.fc.in_features, 10).to(device)
# net = torch.compile(net, mode="reduce-overhead")


In [None]:
# Freeze everything except the last residual stage ('layer4') and the head ('fc').
# NOTE(review): requires_grad=False does not stop BatchNorm layers in the frozen
# stages from updating their running statistics while the model is in train mode.
for name, param in net.named_parameters():
    if 'fc' not in name and 'layer4' not in name:
        param.requires_grad = False


In [None]:
# Collect, and print, the parameters that remain trainable after freezing.
print("Params to learn:")
params_to_update = []
for name, param in net.named_parameters():
    if param.requires_grad:
        params_to_update.append(param)
        print("\t", name)

In [None]:
# Loss, optimizer and AMP scaler for fine-tuning. All parameters are passed;
# frozen ones never receive a gradient, and Adam skips parameters whose .grad is None.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

In [None]:
# Fine-tune layer4 + fc of the self-supervised backbone.
# With no scheduler, train() applies step decay from init_lr=0.001, dropping
# it 10x every 10 epochs.
train(
    net,
    criterion,
    optimizer,
    scaler,
    num_epochs=20,
    decay_epochs=10,
    init_lr=0.001,
    task='classification',
)
# 82.75% (if freeze layer4 and fc) (super weird...) -> for next section, fc and layer4 should have higher lr, otherwise lower
# resnet 18: 61.2

## Fine-tuning on the randomly initialized model
In this section, we will randomly initialize a ResNet model (ResNet-101 in the code below) and fine-tune it on the classification task. We will freeze all layers except the 'layer4' block and the 'fc' layer.

In [None]:
import torch.nn as nn
import torch.nn.functional as F

from torchvision.models import resnet18

# Drop the previous model, then build a randomly initialized 10-class ResNet-101.
del net
collect()
net = resnet101(num_classes=10)
net = net.to(device)
# net = torch.compile(net, mode="reduce-overhead")

In [None]:
# Freeze everything except the last residual stage ('layer4') and the head ('fc').
# NOTE(review): requires_grad=False does not stop BatchNorm layers in the frozen
# stages from updating their running statistics while the model is in train mode.
for name, param in net.named_parameters():
    if 'fc' not in name and 'layer4' not in name:
        param.requires_grad = False

In [None]:
# Collect, and print, the parameters that remain trainable after freezing.
print("Params to learn:")
params_to_update = []
for name, param in net.named_parameters():
    if param.requires_grad:
        params_to_update.append(param)
        print("\t", name)

In [None]:
# Loss, optimizer and AMP scaler for fine-tuning. All parameters are passed;
# frozen ones never receive a gradient, and Adam skips parameters whose .grad is None.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

In [None]:
# Same layer4 + fc fine-tuning, starting from random weights (baseline).
# Step decay inside train(): init_lr=0.01, dropping 10x every 10 epochs.
train(
    net,
    criterion,
    optimizer,
    scaler,
    num_epochs=20,
    decay_epochs=10,
    init_lr=0.01,
    task='classification',
)
# 83.21% (if freeze layer4 and fc) (super weird...) 
# resnet 18: 46.23

## Supervised training on the pre-trained model
In this section, we will load the pre-trained ResNet model (ResNet-101 in the code below) and re-train the whole model on the classification task.

In [None]:
import torch.nn as nn
import torch.nn.functional as F

from torchvision.models import resnet18

# Reload the self-supervised rotation backbone (4-way head), then replace the
# head with a fresh 10-way classifier; the whole network is trained below.
# net = resnet18(weights=ResNet18_Weights.DEFAULT).to(device)
net = resnet101(num_classes=4).to(device)
# map_location lets a checkpoint saved on GPU load on whichever device is active.
net.load_state_dict(torch.load('ssl_pretrained_resnet101_sfd_cosine_v5_best.pt', map_location=device))
# Size the new head from the existing one rather than hard-coding 2048.
net.fc = nn.Linear(net.fc.in_features, 10).to(device)
# net = torch.compile(net, mode="reduce-overhead")

In [None]:
# TODO: Define criterion and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(params=net.parameters(), lr=1e-3, weight_decay=1e-3)
scaler = torch.cuda.amp.GradScaler()

In [None]:
# Train the whole network from the self-supervised initialization.
# Step decay inside train(): init_lr=0.0005 replaces the optimizer's lr and
# drops 10x every 40 epochs.
train(
    net,
    criterion,
    optimizer,
    scaler,
    num_epochs=100,
    decay_epochs=40,
    init_lr=0.0005,
    task='classification',
)
# resnet18 84.2%
# resnet101 90.32%

## Supervised training on the randomly initialized model
In this section, we will randomly initialize a ResNet model (ResNet-101 in the code below) and train the whole model on the classification task.

In [None]:
import torch.nn as nn
import torch.nn.functional as F

from torchvision.models import resnet18, ResNet18_Weights

# Randomly initialized ResNet-101 with a 10-way head, trained end to end.
del net
collect()
# num_classes=10 matches the 10-class dataset. The torchvision default (1000)
# would add 990 logits that no label ever targets but that argmax can still
# pick at evaluation time.
net = resnet101(num_classes=10).to(device)
# net = resnet18(weights=ResNet18_Weights.DEFAULT).to(device)
# net.fc = nn.Linear(512, 10).to(device)
# net = torch.compile(net, mode="reduce-overhead")

In [None]:
# TODO: Define criterion and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(params=net.parameters(), lr=1e-3, weight_decay=1e-3)
scaler = torch.cuda.amp.GradScaler()

In [None]:
# Train the whole randomly initialized network.
# Step decay inside train(): init_lr=0.01 replaces the optimizer's lr and
# drops 10x every 40 epochs.
train(
    net,
    criterion,
    optimizer,
    scaler,
    num_epochs=210,
    decay_epochs=40,
    init_lr=0.01,
    task='classification',
)
# resnet18 82.45%
# resnet101 87.55