# Mount Google Drive

In [None]:
# Mount Google Drive at /content/drive; the trained checkpoint is copied there at the end.
from google.colab import drive
drive.mount('/content/drive')
# Show the current working directory: the dataset is downloaded and unpacked here.
%pwd

In [None]:
# Show the NVIDIA GPU(s) visible to this runtime (fails if no GPU is attached).
!nvidia-smi

# Import libraries

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import shutil

import torch
from torch import optim, nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, Dataset
from torchvision.utils import make_grid
from torchvision import transforms as T
from torch.optim import lr_scheduler
import torchvision.datasets as datasets
import torchvision.models as models
from torch import linalg as LA

# Define device (GPU or CPU)

In [None]:
# Run on the GPU whenever CUDA is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print(device)

# Download Tiny ImageNet data

In [None]:
# Remove any earlier copy of the dataset (folder and archive) so the download
# below starts from a clean working directory.
existing_entries = os.listdir()

if 'tiny-imagenet-200' not in existing_entries:
    print('tiny-imagenet-200 not existed')
else:
    shutil.rmtree('tiny-imagenet-200')

if 'tiny-imagenet-200.zip' not in existing_entries:
    print('tiny-imagenet-200.zip not existed')
else:
    os.remove('tiny-imagenet-200.zip')

In [None]:
# Download the Tiny ImageNet archive into the current working directory.
!wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
  
# Unzip quietly (-qq); this is expected to create the tiny-imagenet-200/ folder
# that DATA_DIR points to below.
!unzip -qq 'tiny-imagenet-200.zip'

# Define directory

In [None]:
# Root of the unpacked dataset; the original images are 3x64x64.
DATA_DIR = 'tiny-imagenet-200'

# Sub-directories holding the training and validation splits.
TRAIN_DIR, VALID_DIR = (os.path.join(DATA_DIR, split) for split in ('train', 'val'))

# Create validation labels from val_annotations.txt

In [None]:
val_img_dir = os.path.join(VALID_DIR, 'images')

# Parse val_annotations.txt (tab-separated): the first field is the image file
# name, the second its class id, which is used below as the class sub-folder name.
# The context manager closes the file even if parsing raises.
val_img_dict = {}
with open(os.path.join(VALID_DIR, 'val_annotations.txt'), 'r') as fp:
    for line in fp:
        if not line.strip():
            # Skip blank lines (e.g. a trailing empty line) instead of raising IndexError.
            continue
        words = line.split('\t')
        val_img_dict[words[0]] = words[1]

# Show val_img_dict (first 5)
for i, (k, v) in enumerate(val_img_dict.items()):
    print(k, ":", v)
    if i == 4:
        break

# Create subfolders (if not present)

In [None]:
# Move each validation image into a sub-folder named after its class, so that
# ImageFolder can derive labels from the directory layout (as it does for train/).
for image_name, class_folder in val_img_dict.items():
    class_dir = os.path.join(val_img_dir, class_folder)
    if not os.path.exists(class_dir):
        os.makedirs(class_dir)

    source_path = os.path.join(val_img_dir, image_name)
    # Images already moved by a previous run are simply skipped.
    if os.path.exists(source_path):
        os.rename(source_path, os.path.join(class_dir, image_name))

# Define transformation sequence

In [None]:
# Channel statistics used with the ImageNet-pretrained weights loaded below.
# (For a model trained from scratch, mean = std = [0.5, 0.5, 0.5] would be used instead.)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: upscale the 64x64 images, centre-crop to 224x224,
# random horizontal flip, convert to tensor, normalise.
preprocess_transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Validation pipeline: same geometry and normalisation, but deterministic (no flip).
preprocess_transform_validation = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Define functions to create datasets and dataloaders

In [None]:
def generate_dataset(data, transform):
    """Build an ImageFolder dataset rooted at ``data``.

    Args:
        data: Root directory whose sub-folders are the classes.
        transform: Transform applied to every image; when ``None`` the images
            are only converted to tensors with ``T.ToTensor()``.

    Returns:
        A ``datasets.ImageFolder`` instance.
    """
    chosen_transform = T.ToTensor() if transform is None else transform
    return datasets.ImageFolder(data, transform=chosen_transform)

def generate_dataloader(dataset, name):
    """Wrap ``dataset`` in a DataLoader.

    The batch size is read from the notebook-global ``batch_size``. Batches are
    shuffled only when ``name == "train"``. When CUDA is in use, pinned memory
    and a single worker process are enabled.

    Args:
        dataset: Any map-style dataset.
        name: Split name; ``"train"`` turns shuffling on.

    Returns:
        A ``DataLoader`` over ``dataset``.
    """
    loader_kwargs = dict(pin_memory=True, num_workers=1) if use_cuda else {}
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(name == "train"),
        **loader_kwargs,
    )

# Create dataloaders

In [None]:
# Mini-batch size; generate_dataloader reads this global.
batch_size = 128

# Training split uses the augmenting transform, validation the deterministic one.
train_dataset, val_dataset = (
    generate_dataset(TRAIN_DIR, preprocess_transform),
    generate_dataset(val_img_dir, preprocess_transform_validation),
)

# Only the "train" loader shuffles.
train_loader, val_loader = (
    generate_dataloader(train_dataset, "train"),
    generate_dataloader(val_dataset, "validation"),
)

# Define model architecture

In [None]:
# Start from ImageNet-pretrained ResNet18 weights.
# NOTE(review): newer torchvision (0.13+) deprecates `pretrained=` in favour of `weights=`.
model = models.resnet18(pretrained = True)

# Replace the final fully-connected layer with a 200-way classifier for Tiny ImageNet.
# No layers are frozen, so the whole network is fine-tuned, not just the head.
model.fc = nn.Linear(model.fc.in_features, 200)

# Move the model to the selected device (the GPU when one is available, e.g. on Colab).
model = model.to(device)

# ResNet reconstruction (expose features and logits)

In [None]:
# Create model class
class SplittedResNet18(nn.Module):
    """ResNet18 wrapper whose forward pass returns the features and the logits.

    The backbone reuses every child module of ``resnet18`` except the last one,
    and ``fc`` is the original head. Parameters are therefore shared with the
    wrapped model, not copied.

    Args:
        resnet18: A ResNet-style module whose last child is its ``fc`` head
            (true for torchvision ResNets — assumed here, not checked).
    """

    def __init__(self, resnet18):
        super().__init__()
        # Every child except the final classifier head.
        self.cnn = nn.Sequential(*list(resnet18.children())[:-1])
        self.flatten = nn.Flatten()
        self.fc = resnet18.fc

    def forward(self, x: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """Return ``(representation, output)``.

        ``representation`` is the flattened backbone output, and ``output`` is
        ``fc(representation)`` (the class logits).
        """
        representation = self.flatten(self.cnn(x))
        output = self.fc(representation)
        return representation, output

# Create model instance: wraps the fine-tuned ResNet18 so the forward pass also
# returns the flattened features used by the orthogonality loss. The wrapper
# shares the same submodules, so it is already on `device`.
splitted_model = SplittedResNet18(model)
# Drop the old name; the modules stay alive through splitted_model.
del model

# Define loss function

In [None]:
class CrossEntropyOSLoss(nn.Module):
    """Cross-entropy plus an orthogonality penalty on the feature representation.

    loss = CE(output, target) + alpha * || R^T R - I ||_F

    Here R is ``representation`` with each row L2-normalised (``dim=1``), and
    I is the identity whose size equals the number of columns of R.

    NOTE(review): rows are normalised but the Gram matrix is taken over columns.
    With fewer rows (batch) B than columns D, the penalty is bounded below by
    alpha * sqrt(D - B) and never reaches 0. Confirm this is the intended
    formulation.

    Args:
        regularization_param: Weight ``alpha`` of the orthogonality term.
    """

    def __init__(self, regularization_param):
        super().__init__()
        self.Cross_Entropy_Loss = nn.CrossEntropyLoss()
        self.alpha = regularization_param

    def forward(self, output, representation, target):
        """Return the combined scalar loss.

        Args:
            output: Logits passed to ``nn.CrossEntropyLoss``.
            representation: 2-D feature matrix, presumably (batch, features)
                as returned by SplittedResNet18 — rows are normalised.
            target: Class indices for ``nn.CrossEntropyLoss``.
        """
        # Calculate Cross Entropy Loss value
        ce_value = self.Cross_Entropy_Loss(output, target)

        # Orthogonality term: distance of the column Gram matrix from the identity.
        normalized_representation = F.normalize(representation, p=2, dim=1)
        gram = torch.matmul(normalized_representation.t(), normalized_representation)
        # Build I on the input's own device/dtype rather than a notebook-global `device`.
        identity = torch.eye(gram.size(0), device=gram.device, dtype=gram.dtype)
        os_value = self.alpha * LA.norm(gram - identity, ord="fro")

        return ce_value + os_value

# Define hyperparameters and settings

In [None]:
# Optimisation settings
lr, momentum = 0.001, 0.9      # SGD learning rate and momentum
num_epochs = 30                # number of passes over the training set
log_interval = 500             # iterations between progress logs
step_size, gamma = 7, 0.1      # StepLR: scale the lr by gamma every step_size scheduler steps
alpha = 0.01                   # weight of the orthogonality term in CrossEntropyOSLoss

# Loss, optimiser and learning-rate schedule
loss_func = CrossEntropyOSLoss(regularization_param=alpha)
optimizer = optim.SGD(params=splitted_model.parameters(), lr=lr, momentum=momentum)
scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

# Training

In [None]:
%%time

# Per-iteration training losses (Python floats) and per-epoch validation accuracies (%).
loss_list = []
val_accuracy_list = []
best_accuracy = 0
# Checkpoint file name encodes the orthogonality weight alpha.
model_name = 'ResNet18_OS_' + str(alpha) + '.pth'

for epoch in range(1, num_epochs+1):

    # Training step
    splitted_model.train()
    for i, (input_images, labels) in enumerate(train_loader):

        input_images = input_images.to(device)
        labels = labels.to(device)

        representation, output = splitted_model(input_images)
        loss = loss_func(output, representation, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Store a host-side float. Keeping the (CUDA) tensor would pile up
        # thousands of device tensors and make plt.plot(loss_list) fail, since
        # CUDA tensors cannot be converted to NumPy.
        loss_value = loss.item()
        loss_list.append(loss_value)

        # Periodic progress report.
        if (i + 1) % log_interval == 0:
            print(f"Epoch {epoch}, iteration {i + 1}: training loss {loss_value:.4f}")

    # One scheduler step per epoch.
    scheduler.step()

    # Validation step
    splitted_model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for input_images, labels in val_loader:

            input_images = input_images.to(device)
            labels = labels.to(device)

            _, predicted_values = splitted_model(input_images)
            # Predicted class = index of the largest logit.
            _, predicted_labels = torch.max(predicted_values, 1)

            total += labels.shape[0]
            correct += (predicted_labels == labels).sum().item()

    current_accuracy = (correct/total)*100
    val_accuracy_list.append(current_accuracy)
    print("Accuracy of epoch", epoch, "is", f"{current_accuracy:.3f}%")

    # Save model (only when validation accuracy improves)
    if current_accuracy > best_accuracy:
        torch.save(splitted_model.state_dict(), model_name)
        best_accuracy = current_accuracy
        print('Model of epoch', epoch, 'was saved.')
    else:
        print('Model was not saved.')

print(" --------------- Train complete ---------------")

# Plot training loss and validation accuracy

In [None]:
# Left: training loss per iteration. Right: validation accuracy (%) per epoch.
fig, ax = plt.subplots(1, 2, figsize=(12, 4))

panels = [
    (loss_list, "Iteration", 'Training loss value'),
    (val_accuracy_list, "Epoch", 'Validation accuracy'),
]
for axis, (series, x_label, y_label) in zip(ax, panels):
    axis.plot(series)
    axis.set_xlabel(x_label)
    axis.set_ylabel(y_label)

plt.show()

# Save model to Google Drive 

In [None]:
# Move the best checkpoint into Google Drive (the drive must be mounted, see the first cell).
drive_dir = '/content/drive/MyDrive/work/orthgonal_constraint/'
# shutil.move fails when the destination folder does not exist, so create it first.
os.makedirs(drive_dir, exist_ok=True)
shutil.move(model_name, os.path.join(drive_dir, model_name))