# Import Libraries

In [None]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset

import warnings
warnings.filterwarnings("ignore")

# Implement custom PyTorch Dataset Class

In [None]:
from torchvision.io import read_image

class CarsDataset(Dataset):
    """Map-style dataset backed by a CSV annotation table and an image folder.

    For every row of the table, the first column is an image path relative to
    ``root_dir`` and the last column is the class label (converted to ``int``).
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """Read the annotation table and store the dataset configuration.

        Args:
            csv_file: Path handed to ``pandas.read_csv``.
            root_dir: Directory the per-row image paths are joined onto.
            transform: Optional callable applied to each loaded image.
        """
        self.annotations = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the annotation table."""
        return len(self.annotations)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair described by row ``index``.

        The image is loaded with ``read_image`` and, when a transform was
        supplied, passed through it before being returned.
        """
        relative_path = self.annotations.iloc[index, 0]
        picture = read_image(os.path.join(self.root_dir, relative_path))
        target = int(self.annotations.iloc[index, -1])
        return (self.transform(picture) if self.transform else picture), target

# PyTorch data transforms

In [None]:
from torchvision.transforms.transforms import ToPILImage
from torch.utils import data

# Preprocessing pipelines keyed by split name. Both start from a tensor image
# (converted to PIL first), end as a normalized float tensor of a 224x224 crop,
# and use the same per-channel mean / std for normalization.
data_transforms = dict(
    # Training: RandomResizedCrop(224) is disabled here; a deterministic
    # Resize(256) + CenterCrop(224) is used instead (not part of the original
    # recipe), followed by a random horizontal flip as the only augmentation.
    train=transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    # Validation: same resize / crop / normalization, without any augmentation.
    val=transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
)

# Create Dataset

In [None]:
# Build the training set and the held-out set from their annotation files.
train_dataset = CarsDataset('train_w_converted.txt', './train_real', data_transforms['train'])
held_out_dataset = CarsDataset('test_w_converted.txt', './test_real', data_transforms['val'])

# Split the held-out images into validation and test halves. The sizes are
# derived from the dataset length (validation receives the extra sample when
# the length is odd) so random_split does not raise if the annotation file
# changes size.
val_size = (len(held_out_dataset) + 1) // 2
test_size = len(held_out_dataset) - val_size
val_dataset, test_dataset = torch.utils.data.random_split(held_out_dataset, [val_size, test_size])

# Sample counts per training phase.
dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}

print(f'Length of train_dataset is: {len(train_dataset)}')
print(f'Length of val_dataset is: {len(val_dataset)}')
# The third line reports the test split; it was previously mislabeled as val.
print(f'Length of test_dataset is: {len(test_dataset)}')

# Create Dataloaders

In [None]:
# Number of samples per batch, shared by all three loaders.
batch_size = 32

# Only the training loader shuffles; validation and test keep dataset order.
# Every loader uses 16 worker processes.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=16)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=16)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=16)

# Loaders looked up by phase name during training.
dataloaders = dict(train=train_loader, val=val_loader)

# Import Model

In [None]:
from torchvision import models

# Load torchvision's ResNet-152 with pretrained weights (download progress shown).
resnet = models.resnet152(pretrained=True, progress=True)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Freeze weights and place new FC layer

In [None]:
# Freeze every existing parameter of the network.
resnet.requires_grad_(False)

# Swap the final fully connected layer for a new 431-output linear layer that
# takes the same number of input features as the layer it replaces.
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(num_ftrs, 431)

# Define cost function and optimizer

In [None]:
# Loss: cross-entropy between the network outputs and the integer labels.
criterion = nn.CrossEntropyLoss()

# Default optimizer: SGD with momentum over the model's parameters.
optimizer = optim.SGD(resnet.parameters(), lr=0.001, momentum=0.9)

# Define the training function

In [None]:
import copy
from tqdm import tqdm

def train_model(model, criterion, optimizer, dataloaders, device, num_epochs=50, lrscheduler=None):
    """Train ``model`` and evaluate it on the validation loader after every epoch.

    Args:
        model: Network to optimise; it is moved to ``device`` before training.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped after every training batch.
        dataloaders: Mapping with ``'train'`` and ``'val'`` entries, each an
            iterable that yields ``(inputs, labels)`` batches.
        device: Device the model and every batch are moved to.
        num_epochs: Number of train + validation passes to run.
        lrscheduler: Optional scheduler. Only ``OneCycleLR`` instances are
            stepped (once per training batch); anything else is ignored.

    Returns:
        The tuple ``(epoch_list, train_loss, train_acc, val_loss, val_acc,
        best_acc, best_model_wts, best_epoch)``. The loss lists hold floats,
        the accuracy lists hold 0-dim double tensors, ``best_model_wts`` is a
        deep copy of the state dict at the best validation accuracy, and
        ``best_epoch`` is ``None`` when no epoch beat the initial accuracy of 0.

    Raises:
        ValueError: If a loader yields no samples during an epoch.
    """
    # The progress bar is cosmetic: fall back to plain iteration without tqdm.
    try:
        from tqdm import tqdm as progress
    except ImportError:
        def progress(iterable):
            return iterable

    model = model.to(device)  # Send model to GPU if available

    # OneCycleLR is defined per batch, so it is stepped inside the batch loop.
    steps_per_batch = isinstance(lrscheduler, optim.lr_scheduler.OneCycleLR)

    epoch_list = []
    train_loss, train_acc = [], []
    val_loss, val_acc = [], []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0
    # Stays None unless some epoch improves on best_acc, so the return
    # statement never hits an unbound name.
    best_epoch = None

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        epoch_list.append(epoch)

        # Each epoch has a training and a validation phase.
        for phase in ['train', 'val']:
            is_training = phase == 'train'
            model.train(is_training)  # train(False) is evaluation mode

            running_loss = 0.0
            running_corrects = 0
            # Count the samples actually seen instead of relying on a
            # module-level size table.
            n_samples = 0

            for inputs, labels in progress(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Track gradients only while training.
                with torch.set_grad_enabled(is_training):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_training:
                        loss.backward()
                        optimizer.step()
                        if steps_per_batch:
                            lrscheduler.step()

                batch_len = inputs.size(0)
                # Scale the batch loss by the batch size so the epoch value
                # is a per-sample mean even when the last batch is smaller.
                running_loss += loss.item() * batch_len
                running_corrects += torch.sum(preds == labels)
                n_samples += batch_len

            if n_samples == 0:
                raise ValueError(f"dataloaders[{phase!r}] yielded no samples")

            epoch_loss = running_loss / n_samples
            epoch_acc = running_corrects.double() / n_samples
            # Built-in round() keeps the printed format without needing numpy.
            print(f'{phase} Loss: {round(epoch_loss, 2)} Acc: {round(epoch_acc.item() * 100, 2)}%')

            if is_training:
                train_loss.append(epoch_loss)
                train_acc.append(epoch_acc)
            else:
                val_loss.append(epoch_loss)
                val_acc.append(epoch_acc)
                # Keep a snapshot of the weights with the best validation accuracy.
                if epoch_acc.item() > best_acc:
                    best_acc = epoch_acc.item()
                    best_model_wts = copy.deepcopy(model.state_dict())
                    best_epoch = epoch

    return epoch_list, train_loss, train_acc, val_loss, val_acc, best_acc, best_model_wts, best_epoch

# Train and evaluate (run only one of the following cells). You must re-download the pretrained model and replace its FC layer again before training a different version of the model.

### SGD, no LR scheduler:

In [None]:
# Run 50 epochs with the current optimizer and no learning-rate scheduler.
(epoch_list, train_loss, train_acc, val_loss, val_acc,
 best_acc, best_model_wts, best_epoch) = train_model(
    resnet, criterion, optimizer, dataloaders, device,
    num_epochs=50, lrscheduler=None,
)

### Train and evaluate SGD w/ OneCycleLR scheduler

In [None]:
# train_model steps a OneCycleLR scheduler once per training batch, and
# OneCycleLR raises once it is stepped more than epochs * steps_per_epoch
# times. The schedule is therefore sized from the same epoch count that is
# passed to train_model and from the real number of training batches, instead
# of hard-coded values that can drift from the actual run length.
onecycle_epochs = 50
lr_scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.30,
    base_momentum=0.6,
    steps_per_epoch=len(train_loader),
    epochs=onecycle_epochs,
)

(epoch_list, train_loss, train_acc, val_loss, val_acc,
 best_acc, best_model_wts, best_epoch) = train_model(
    resnet, criterion, optimizer, dataloaders, device,
    num_epochs=onecycle_epochs, lrscheduler=lr_scheduler,
)

### Train and evaluate AdamW policy (better than Adam), no LR scheduler

In [None]:
# Replace the optimizer with AdamW, then run 50 epochs without a scheduler.
optimizer = optim.AdamW(resnet.parameters(), lr=0.001)

(epoch_list, train_loss, train_acc, val_loss, val_acc,
 best_acc, best_model_wts, best_epoch) = train_model(
    resnet, criterion, optimizer, dataloaders, device,
    num_epochs=50, lrscheduler=None,
)

# Visualize results

### Visualize loss:

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

fig = plt.figure(figsize=(15, 9))
# Plot each series explicitly against epoch_list. Passing both series to one
# plot() call draws the second one against implicit indices 0..N-1, which
# lines up with the epoch axis only by coincidence.
plt.plot(epoch_list, train_loss, label='train')
plt.plot(epoch_list, val_loss, label='val')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.title('Loss')
plt.legend()

### Visualize accuracy:

In [None]:
# The accuracy histories hold 0-dim tensors; convert them to plain floats.
train_acc_final = [acc.item() for acc in train_acc]
val_acc_final = [acc.item() for acc in val_acc]

fig = plt.figure(figsize=(15, 9))
# Plot each series explicitly against epoch_list. Passing both series to one
# plot() call draws the second one against implicit indices 0..N-1, which
# lines up with the epoch axis only by coincidence.
plt.plot(epoch_list, train_acc_final, label='train')
plt.plot(epoch_list, val_acc_final, label='val')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.title('Accuracy')
plt.legend()