In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
from torch.utils.data import Dataset, DataLoader, Subset

import numpy as np
import os
import gc
import copy
import time as tm

from tqdm.autonotebook import tqdm
from itertools import chain
import matplotlib.pyplot as plt

# Directory that holds the checkpoints and saved weights written by this notebook.
# The path is built once and bound only to `working_dir`; binding it to `dir`
# as well would shadow the builtin `dir()` for the rest of the kernel session.
working_dir = os.path.join(os.getcwd(), "LRDOCC_rot")
# exist_ok=True removes the check-then-create race and keeps the cell re-runnable.
os.makedirs(working_dir, exist_ok=True)
print(os.listdir(working_dir))
print(os.getcwd())

## Dataset Generation

In [None]:
class RepeatInterleave3D(object):
    """Transform that repeats every entry along dim 0 three times.

    When the leading dimension has size 1 (a single grey channel) the result
    has three identical entries along that dimension.
    """

    def __call__(self, tensor_sample):
        return torch.repeat_interleave(tensor_sample, 3, dim=0)

class SqueezeChannelDim(object):
    """Transform that inserts a new leading axis of size 1.

    NOTE: despite the class name this *un*squeezes: an input of shape
    ``(d0, d1, ...)`` comes back as ``(1, d0, d1, ...)``.
    """

    def __call__(self, tensor_sample):
        return torch.unsqueeze(tensor_sample, 0)

class ResizeTensor224(object):
    """Transform that resizes the two trailing dims of a tensor to 224 x 224.

    Delegates to ``F.interpolate`` with its default interpolation mode.  A
    two-element target size means the input must be 4-D, i.e. a batch shaped
    ``(N, C, H, W)``.
    """

    def __call__(self, tensor_sample):
        target_size = (224, 224)
        return F.interpolate(tensor_sample, target_size)

# Per-channel mean / std for 3-channel inputs (these are the values commonly
# published as the ImageNet statistics in torchvision's model documentation).
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD)

class AugmentedFMNIST(Dataset):
  """Fashion-MNIST training images, each stored in four rotated copies.

  Every source image contributes four ``(image, label)`` entries, appended in
  the order 0, 90, 180 and 270 degrees (clockwise).  The label is the number
  of clockwise quarter turns (0..3) or, when ``one_hot=True``, the matching
  length-4 one-hot integer tensor.  Images are 3-channel tensors in which the
  single grey channel has been replicated.

  All entries are built eagerly in ``__init__`` and kept in ``self.data``.

  Args:
    one_hot: if True, labels are one-hot tensors instead of plain ints.
  """

  def __init__(self, one_hot=False):
    transform_augment = transforms.Compose(
                                  [transforms.ToTensor(),
                                   SqueezeChannelDim(),
                                   RepeatInterleave3D(),
                                   ])

    dataset = torchvision.datasets.FashionMNIST(root='./FashionMNIST', train=True, download=True, transform=transform_augment)

    self.data = []
    for instance in dataset:
      # The transform pipeline leaves a singleton axis between the channel
      # and spatial dims; squeeze() drops it so the image is (3, H, W).
      image = instance[0].squeeze()
      for quarter_turns, rotated in enumerate(self._rotated_views(image)):
        if one_hot:
          label = torch.tensor([int(j == quarter_turns) for j in range(4)])
        else:
          label = quarter_turns
        self.data.append((rotated, label))

  @staticmethod
  def _rotated_views(image):
    """Return the 0/90/180/270-degree clockwise rotations of ``image``.

    The rotation acts on the last two (spatial) dims only, so it is correct
    for channel-first ``(C, H, W)`` tensors as well as plain ``(H, W)`` ones.
    The previous transpose/flip formulas assumed a 2-D tensor: on a 3-D
    tensor ``transpose(1, -2)`` is a no-op and ``flip(0)`` flips the
    (identical) channels, so the "270" image equalled the "0" image and the
    "180" image equalled the "90" image.

    ``torch.rot90`` rotates counter-clockwise for positive ``k``; ``-k`` is
    used to keep the clockwise convention of the labels.
    """
    return [torch.rot90(image, -k, dims=(-2, -1)) for k in range(4)]

  def __len__(self):
    """Number of stored (image, label) entries: four per source image."""
    return len(self.data)

  def __getitem__(self, i):
    """Return the ``(image, label)`` pair stored at position ``i``."""
    return (self.data[i][0], self.data[i][1])
  

In [None]:
dataset = AugmentedFMNIST()

# Shuffle every index exactly once with a fixed seed.  np.random.choice draws
# WITH replacement by default, which repeats some samples, drops others and
# lets the same sample land in more than one split (train/val/test leakage).
SPLIT_SEED = 0
idx = np.random.default_rng(SPLIT_SEED).permutation(len(dataset)).tolist()
batch_size = 32

# Split boundaries inside the shuffled index list.
TRAIN_END = 200000
VAL_END = 210000

train = Subset(dataset, idx[:TRAIN_END])
val = Subset(dataset, idx[TRAIN_END:VAL_END])
test = Subset(dataset, idx[VAL_END:])

# Only the training loader reshuffles, so each epoch sees a new batch order;
# evaluation loaders keep a fixed order.
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val, batch_size=batch_size, num_workers=4)
test_loader  = DataLoader(test , batch_size=batch_size, num_workers=4)

dataset_sizes = {"train":len(train),
                 "val":len(val),
                 "test":len(test)}

## Model definition

In [None]:
ROT_CLASSES = 4   # number of rotation classes the head predicts
HEAD_DIMS = 512   # width of the backbone output and of the hidden head layers
# Use the GPU when one is available; fall back to the CPU instead of crashing
# on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Build the base network (resnet18() with no arguments, so no pretrained
# weights are requested) and replace its final layer with a
# HEAD_DIMS -> HEAD_DIMS linear layer so it emits a feature vector.
resnet18_rot_pred_model = resnet18()
resnet18_rot_pred_model.avgpool = nn.AdaptiveAvgPool2d((1,1))
resnet18_rot_pred_model.fc = nn.Linear(HEAD_DIMS, HEAD_DIMS)

class LRDOCCRotNet18(nn.Module):
    """Backbone followed by an MLP head that scores ROT_CLASSES classes.

    Args:
        base_model: module whose output has HEAD_DIMS features per sample.

    The head is three ``Linear(HEAD_DIMS, HEAD_DIMS) + ReLU`` stages followed
    by ``Linear(HEAD_DIMS, ROT_CLASSES)``; no softmax is applied, so
    ``forward`` returns raw scores.
    """

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

        # Layers are created in the same order as positions 0..6 of the
        # Sequential, so parameter names in the state dict stay
        # ``projection_head.{0,2,4,6}.*``.
        head_layers = []
        for _ in range(3):
            head_layers.append(nn.Linear(HEAD_DIMS, HEAD_DIMS))
            head_layers.append(nn.ReLU())
        head_layers.append(nn.Linear(HEAD_DIMS, ROT_CLASSES))
        self.projection_head = nn.Sequential(*head_layers)

    def forward(self, x):
        features = self.base_model(x)
        return self.projection_head(features)

In [None]:
def train_model(model, criterion, trainloader, valloader, optimizer, scheduler, num_epochs=2048):
    """Train ``model`` with a per-epoch validation pass and keep the best weights.

    Each epoch runs a training phase over ``trainloader`` (with gradient
    updates, followed by one ``scheduler.step()``) and an evaluation phase
    over ``valloader``.  Every batch is resized to 224x224 and passed through
    the module-level ``normalize`` transform before the forward pass.  A
    checkpoint is written to the module-level ``working_dir`` on every 10th
    epoch, and the weights with the highest validation accuracy are restored
    into ``model`` before returning.

    Args:
        model: network to train; batches are moved to the device of its
            parameters.
        criterion: loss called as ``criterion(outputs, labels)``.
        trainloader: iterable of ``(inputs, labels)`` training batches.
        valloader: iterable of ``(inputs, labels)`` validation batches.
        optimizer: optimizer updating ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        num_epochs: number of epochs to run.

    Returns:
        ``(model, train_acc, val_acc, train_loss, val_loss)`` where the four
        lists hold one Python float per epoch.
    """
    since = tm.time()
    val_loss = []
    train_loss = []
    val_acc = []
    train_acc = []
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    # Follow the model's own device so the function does not depend on a
    # global device constant.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
                dataloader = trainloader
            else:
                model.eval()   # Set model to evaluate mode
                dataloader = valloader

            running_loss = 0.0
            running_corrects = 0
            running_samples = 0

            # Iterate over data.
            for data in tqdm(dataloader):
                inputs = data[0].to(device)
                labels = data[1].to(device)
                inputs = F.interpolate(inputs, (224,224))
                inputs = normalize(inputs)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward; track history only in train
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # statistics (kept as Python numbers, not device tensors)
                batch_count = inputs.size(0)
                running_loss += loss.item() * batch_count
                running_corrects += torch.sum(preds == labels.data).item()
                running_samples += batch_count

            # scheduler step
            if is_train:
                scheduler.step()

            # Normalise by the number of samples actually seen in this phase
            # rather than by a global size table; guard against empty loaders.
            seen = max(running_samples, 1)
            epoch_loss = running_loss / seen
            epoch_acc = running_corrects / seen

            # Record loss and accuracy in their own histories (the accuracy
            # lists previously received the loss value by mistake).
            if is_train:
                train_loss.append(epoch_loss)
                train_acc.append(epoch_acc)
            else:
                val_loss.append(epoch_loss)
                val_acc.append(epoch_acc)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            gc.collect()

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        # Periodic checkpoint, written once per saving epoch after both phases
        # (inside the phase loop the same file was written twice).
        if epoch%10==0:
            torch.save(model.state_dict(), os.path.join(working_dir,f"rot_net_sched_lr009_100_epochs{50+epoch}.pth"))

        print()

    time_elapsed = tm.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)

    return model, train_acc, val_acc, train_loss, val_loss

In [None]:

lrdocc_rot_net18 = LRDOCCRotNet18(resnet18_rot_pred_model).to(DEVICE)

loss_function = nn.CrossEntropyLoss()
optimiser = optim.SGD(lrdocc_rot_net18.parameters(),lr=0.009, momentum=0.9, weight_decay=0.0003)
cos_decay_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimiser, 7, verbose=True)
# train_model returns (model, train_acc, val_acc, train_loss, val_loss).
lrdocc_trained_tuple = train_model(lrdocc_rot_net18, loss_function, train_loader, val_loader,optimiser,cos_decay_scheduler)

# Save model: element 0 of the returned tuple is the trained network.  The
# save line previously used the undefined name `lrdocc_trained`, which raised
# NameError only after training had already finished.
torch.save(lrdocc_trained_tuple[0].state_dict(), os.path.join(working_dir,"fmnist_rotnet.pth"))

In [None]:
"""Load Weights"""
weights_file = os.path.join(working_dir,'fmnist_rotnet.pth')
# map_location lets a checkpoint saved on one device be loaded on whichever
# device DEVICE names.
weights = torch.load(weights_file, map_location=DEVICE)

# Rebuild the same architecture the checkpoint was saved from.
overfit_net = resnet18()
overfit_net.avgpool = nn.AdaptiveAvgPool2d((1,1))
overfit_net.fc = nn.Linear(HEAD_DIMS, HEAD_DIMS)
model_warmed_up = LRDOCCRotNet18(overfit_net).to(DEVICE)
model_warmed_up.load_state_dict(weights)

# Attempt to train after initial 50 epochs
loss_function = nn.CrossEntropyLoss()
optimiser_v2 = optim.SGD(model_warmed_up.parameters(),lr=3.4549e-04, momentum=0.9, weight_decay=0.0003)
cos_decay_scheduler_v2 = optim.lr_scheduler.CosineAnnealingLR(optimiser_v2, 7, verbose=True)
# train_model takes (model, criterion, trainloader, valloader, optimizer,
# scheduler); the two loaders were missing from this call.
lrdocc_trained_v2, train_acc, val_acc, train_loss, val_loss = train_model(model_warmed_up, loss_function, train_loader, val_loader, optimiser_v2, cos_decay_scheduler_v2)