In [None]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
from IPython import display as disp
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.lr_scheduler import CyclicLR
from sklearn.metrics import confusion_matrix
import seaborn as sns


# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Helper that turns a finite iterable into an endless stream of items.
def cycle(iterable):
    """Yield the items of ``iterable`` forever.

    Nothing is cached: each time ``iterable`` is exhausted it is iterated
    again from scratch, so ``next()`` on the resulting generator always
    returns another item.
    """
    while True:
        yield from iterable
            
  
# Training-time augmentation: random flips and a small rotation, followed by
# conversion to a tensor and per-channel normalisation with mean 0.5 / std 0.5.
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Test-time preprocessing: no augmentation, same normalisation as training.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Class names for images; the position in this list is used as the class index
# by the plotting and confusion-matrix cells below.
class_names = ['apple','aquarium_fish','baby','bear','beaver','bed','bee','beetle','bicycle','bottle','bowl','boy','bridge','bus','butterfly','camel','can','castle','caterpillar','cattle','chair','chimpanzee','clock','cloud','cockroach','couch','crab','crocodile','cup','dinosaur','dolphin','elephant','flatfish','forest','fox','girl','hamster','house','kangaroo','computer_keyboard','lamp','lawn_mower','leopard','lion','lizard','lobster','man','maple_tree','motorcycle','mountain','mouse','mushroom','oak_tree','orange','orchid','otter','palm_tree','pear','pickup_truck','pine_tree','plain','plate','poppy','porcupine','possum','rabbit','raccoon','ray','road','rocket','rose','sea','seal','shark','shrew','skunk','skyscraper','snail','snake','spider','squirrel','streetcar','sunflower','sweet_pepper','table','tank','telephone','television','tiger','tractor','train','trout','tulip','turtle','wardrobe','whale','willow_tree','wolf','woman','worm',]

# Training loader: shuffled, and the final partial batch is dropped so every
# optimisation step sees a full batch of 320 images.
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100('data', train=True, download=True, transform=transform_train),
    batch_size=320, shuffle=True, drop_last=True)

# Test loader: keep the final partial batch (drop_last=False). With
# drop_last=True any test images beyond the last full batch of 320 were
# silently excluded from every evaluation (CIFAR-100's 10,000 test images
# are not a multiple of 320), even though the evaluation cells are meant to
# cover the entire test set.
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100('data', train=False, download=True, transform=transform_test),
    batch_size=320, shuffle=False, drop_last=False)

# Endless iterators so a batch can always be fetched with next()
train_iterator = iter(cycle(train_loader))
test_iterator = iter(cycle(test_loader))

print(f'> Size of training dataset {len(train_loader.dataset)}')
print(f'> Size of test dataset {len(test_loader.dataset)}')

In [None]:
# Preview a 5x5 grid of test images (dataset indices 100-124) labelled with their class names
plt.rcParams['figure.dpi'] = 100
plt.figure(figsize=(10,10))
for cell, sample_idx in enumerate(range(100, 125), start=1):
    image, label = test_loader.dataset[sample_idx]
    plt.subplot(5, 5, cell)
    plt.xticks([])
    plt.yticks([])
    # channels-first tensor -> channels-last array, then undo the
    # Normalize(mean=0.5, std=0.5) applied by transform_test for display
    plt.imshow(image.numpy().transpose(1, 2, 0) * 0.5 + 0.5)
    plt.xlabel(class_names[label])
plt.show()

In [None]:
# ResNet CNN model.
# The model is inspired by https://github.com/akamaster/pytorch_resnet_cifar10 (see top for references)

# Basic residual building block for the ResNet model.
class BasicBlock(nn.Module):
    """Two 3x3 convolution + batch-norm stages with an additive shortcut.

    ELU is applied after each batch-norm stage and again after the shortcut
    has been added.

    Args:
        in_planes: number of input channels.
        planes: number of channels produced by both 3x3 convolutions.
        stride: stride of the first convolution and, when a projection
            shortcut is needed, of its 1x1 convolution.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main path
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Note: nn.Dropout(p=0.2) was tried here and did not improve results.

        # Shortcut path: identity when input and output shapes already agree,
        # otherwise a 1x1 convolution + batch-norm projection.
        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Run the main path, add the shortcut of ``x`` and apply a final ELU."""
        features = F.elu(self.bn1(self.conv1(x)))
        features = F.elu(self.bn2(self.conv2(features)))
        features += self.shortcut(x)
        return F.elu(features)

# Defines the ResNet model.
class Classifier(nn.Module):
    """Small ResNet-style classifier with an extra skip from the raw input.

    A 3x3 stem is followed by four stages of ``block`` modules (12, 19, 32
    and 64 planes). Before the last stage, a strided 5x5 convolution of the
    raw input is added to the stage-3 features. The result of stage 4 is
    average-pooled with a 4x4 window, flattened and fed to a linear layer.

    Args:
        block: block class; called as ``block(in_planes, planes, stride)``
            and expected to expose an ``expansion`` attribute.
        num_blocks: sequence of four ints, the number of blocks per stage.
        num_classes: size of the output layer.
    """

    def __init__(self, block, num_blocks, num_classes=100):
        super().__init__()
        # Running channel count; read and advanced by _make_layer.
        self.in_planes = 8

        # Stem
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(8)

        # Residual stages; stages 2-4 use stride 2 in their first block
        self.layer1 = self._make_layer(block, 12, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 19, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 32, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 64, num_blocks[3], stride=2)
        self.linear = nn.Linear(64*block.expansion, num_classes)

        # Side branch applied directly to the input x: 32 channels at stride 4,
        # so it can be summed with the output of layer3 in forward().
        self.conv1_adjusted = nn.Conv2d(3, 32, kernel_size=5, stride=4, padding=1, bias=False)
        self.bn1_adjusted = nn.BatchNorm2d(32)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest stride 1.

        Updates ``self.in_planes`` after every block so that the next block
        (and the next stage) receives the right number of input channels.
        """
        stage = []
        for block_stride in [stride, *([1] * (num_blocks - 1))]:
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        features = F.elu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            features = stage(features)

        # Add the down-sampled projection of the raw input to the stage-3 features
        input_skip = F.elu(self.bn1_adjusted(self.conv1_adjusted(x)))
        features += input_skip

        features = self.layer4(features)
        pooled = F.avg_pool2d(features, 4)
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)

# Build the network: 2 BasicBlocks in each of layers 1 and 2, 1 BasicBlock in each of layers 3 and 4.
N = Classifier(BasicBlock, [2,2,1,1]).to(device)
# Wrap the model in DataParallel to enable parallel processing across multiple GPUs.
N = torch.nn.DataParallel(N)
# cudnn benchmark mode is left disabled
# cudnn.benchmark = True


# Count every parameter element once and reuse the total for the budget check
num_parameters = sum(param.numel() for param in N.parameters())
print(f'> Number of parameters {num_parameters}')

if num_parameters > 100000:
    print("> Warning: you have gone over your parameter budget and will have a grade penalty!")

# initialise the optimiser
optimiser = torch.optim.Adam(N.parameters(), lr=0.01)

# Training state shared with the training cell: accuracy history and step counter
plot_data = []
steps = 0


In [None]:
# Gradients are clipped element-wise to [-grad_clip, grad_clip]; a falsy value disables clipping
grad_clip = 0.1

# Scheduler that multiplies the learning rate by 0.95 every 1000 optimisation steps
scheduler = StepLR(optimiser, step_size=1000, gamma=0.95)

# Per-step learning-rate history. Values are collected in Python lists (cheap
# appends) and mirrored into the NumPy arrays lr_arr / steps_arr, which the
# learning-rate plotting cell reads, at the end of every round.
lr_history = []
step_history = []
lr_arr = np.zeros(0)
steps_arr = np.zeros(0)

# keep within our optimisation step budget
while (steps < 10000):

    # metrics for this round of 1000 training steps
    train_losses = []
    train_accs = []
    test_batch_accs = []
    test_correct = 0
    test_seen = 0

    # ---- training: iterate through some of the train dataset ----
    N.train()
    for _ in range(1000):
        x, t = next(train_iterator)
        x, t = x.to(device), t.to(device)

        # Zero the optimiser gradients
        optimiser.zero_grad()
        # Make predictions on the input data and compare them to the target labels
        p = N(x)
        loss = F.cross_entropy(p, t)

        # Compute the gradient of the loss with respect to the model parameters
        loss.backward()

        # Clip gradients to prevent them from becoming too large
        if grad_clip:
            nn.utils.clip_grad_value_(N.parameters(), grad_clip)

        # Update model parameters
        optimiser.step()
        steps += 1

        # Record the learning rate used for this step before the scheduler changes it
        lr_history.append(optimiser.param_groups[0]['lr'])
        step_history.append(steps)

        # Adjust the learning rate based on the schedule
        scheduler.step()

        train_losses.append(loss.item())
        train_accs.append((p.argmax(dim=1) == t).float().mean().item())

    # ---- evaluation: iterate over the test loader ----
    # eval() makes BatchNorm use its running statistics instead of the
    # statistics of each test batch (and stops test data from updating them);
    # no_grad() avoids building autograd graphs during evaluation.
    N.eval()
    with torch.no_grad():
        for m, l in test_loader:
            m, l = m.to(device), l.to(device)
            hits = N(m).argmax(dim=1) == l
            test_batch_accs.append(hits.float().mean().item())
            test_correct += hits.sum().item()
            test_seen += l.size(0)
    N.train()

    train_loss_arr = np.array(train_losses)
    train_acc_arr = np.array(train_accs)
    test_acc_arr = np.array(test_batch_accs)
    # Sample-weighted accuracy, so a smaller final batch is not over-weighted;
    # identical to the mean of batch accuracies when all batches have equal size.
    test_acc = test_correct / max(test_seen, 1)

    lr_arr = np.asarray(lr_history)
    steps_arr = np.asarray(step_history)

    # print your loss and accuracy data
    print('steps: {:.2f}, train loss: {:.3f}, train acc: {:.3f}±{:.3f}, test acc: {:.3f}±{:.3f}'.format(
        steps, train_loss_arr.mean(), train_acc_arr.mean(), train_acc_arr.std(), test_acc, test_acc_arr.std()))

    # plot your accuracy graph (mean line with a ±1 std band across batches)
    plot_data.append([steps, train_acc_arr.mean(), train_acc_arr.std(), test_acc, test_acc_arr.std()])
    plt.plot([x[0] for x in plot_data], [x[1] for x in plot_data], '-', color='tab:grey', label="Train accuracy")
    plt.fill_between([x[0] for x in plot_data], [x[1]-x[2] for x in plot_data], [x[1]+x[2] for x in plot_data], alpha=0.2, color='tab:grey')
    plt.plot([x[0] for x in plot_data], [x[3] for x in plot_data], '-', color='tab:purple', label="Test accuracy")
    plt.fill_between([x[0] for x in plot_data], [x[3]-x[4] for x in plot_data], [x[3]+x[4] for x in plot_data], alpha=0.2, color='tab:purple')
    plt.xlabel('Steps')
    plt.ylabel('Accuracy')
    plt.legend(loc="upper left")
    plt.show()
    # wait=True defers clearing until the next output arrives, so the latest plot stays visible
    disp.clear_output(wait=True)

In [None]:
# Save model.
# model_path = 'Models/disciminative.pth'
# torch.save(N, model_path)

In [None]:
# # Load model.
# model_path = 'Models/disciminative.pth'
# N = torch.load(model_path)


In [None]:
# Learning rate recorded at every optimisation step during training
fig, ax = plt.subplots()
ax.plot(steps_arr, lr_arr, label='Learning Rate')
ax.set_xlabel('Steps')
ax.set_ylabel('Learning Rate')
ax.legend()
plt.show()

In [None]:
# Plot confusion matrix for a selection of classes taken from different superclasses

different_class_names = ['bear', 'sea', 'man', 'motorcycle', 'apple', 'rose', 'oak_tree', 'table', 'bottle']


# Indices of the selected classes within class_names
different_indices = [class_names.index(different) for different in different_class_names]


# Collect predictions and true labels over the test loader
all_preds = []
all_labels = []

# Predict in eval mode (BatchNorm uses its running statistics and is not
# updated by test data) and without autograd; the previous mode is restored.
was_training = N.training
N.eval()
try:
    with torch.no_grad():
        for m, l in test_loader:
            m, l = m.to(device), l.to(device)
            pred = N(m).argmax(dim=1)
            all_preds.extend(pred.cpu().numpy())
            all_labels.extend(l.cpu().numpy())
finally:
    N.train(was_training)

# Tick labels for the selected classes
different_labels = [class_names[idx] for idx in different_indices]

# labels= pins row/column k to class index k even if some class never
# appears among the true or predicted labels.
conf_matrix = confusion_matrix(all_labels, all_preds, labels=np.arange(len(class_names)))

# Keep only the rows and columns of the selected classes
different_conf_matrix = conf_matrix[different_indices][:, different_indices]

# Plot the filtered confusion matrix
plt.figure(figsize=(10, 5))
sns.heatmap(different_conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=different_labels, yticklabels=different_labels)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title("Different Superclass's Confusion Matrix")
plt.show()



In [None]:
# Plot confusion matrix for animal superclass

animal_class_names = ['bear', 'beaver', 'bee', 'butterfly', 'cattle', 'chimpanzee',
                      'cockroach', 'crocodile', 'dinosaur', 'dolphin', 'hamster',
                      'kangaroo', 'leopard', 'otter', 'porcupine', 'possum',
                      'rabbit', 'raccoon', 'ray', 'seal', 'shark', 'shrew', 'skunk', 'squirrel',
                      'tiger', 'turtle', 'whale', 'wolf']

# Indices of the animal classes within class_names
animal_indices = [class_names.index(animal) for animal in animal_class_names]


# Collect predictions and true labels over the test loader
all_preds = []
all_labels = []

# Predict in eval mode (BatchNorm uses its running statistics and is not
# updated by test data) and without autograd; the previous mode is restored.
was_training = N.training
N.eval()
try:
    with torch.no_grad():
        for m, l in test_loader:
            m, l = m.to(device), l.to(device)
            pred = N(m).argmax(dim=1)
            all_preds.extend(pred.cpu().numpy())
            all_labels.extend(l.cpu().numpy())
finally:
    N.train(was_training)

# Tick labels for the animal classes
animal_labels = [class_names[idx] for idx in animal_indices]

# labels= pins row/column k to class index k even if some class never
# appears among the true or predicted labels.
conf_matrix = confusion_matrix(all_labels, all_preds, labels=np.arange(len(class_names)))

# Keep only the rows and columns of the animal classes
animal_conf_matrix = conf_matrix[animal_indices][:, animal_indices]

# Plot the filtered confusion matrix
plt.figure(figsize=(15, 10))
sns.heatmap(animal_conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=animal_labels, yticklabels=animal_labels)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Animal Superclass Confusion Matrix')

# Save the plot as a JPEG image
plt.savefig('animal_confusion_matrix.jpg', format='jpeg')
plt.show()


In [None]:
# Plot confusion matrix for scenery superclass

outdoor_class_names = ['plain', 'sea', 'cloud', 'forest', 'mountain']

# Indices of the scenery classes within class_names
outdoor_indices = [class_names.index(outdoor) for outdoor in outdoor_class_names]


# Collect predictions and true labels over the test loader
all_preds = []
all_labels = []

# Predict in eval mode (BatchNorm uses its running statistics and is not
# updated by test data) and without autograd; the previous mode is restored.
was_training = N.training
N.eval()
try:
    with torch.no_grad():
        for m, l in test_loader:
            m, l = m.to(device), l.to(device)
            pred = N(m).argmax(dim=1)
            all_preds.extend(pred.cpu().numpy())
            all_labels.extend(l.cpu().numpy())
finally:
    N.train(was_training)


# Tick labels for the scenery classes
outdoor_labels = [class_names[idx] for idx in outdoor_indices]

# labels= pins row/column k to class index k even if some class never
# appears among the true or predicted labels.
conf_matrix = confusion_matrix(all_labels, all_preds, labels=np.arange(len(class_names)))

# Keep only the rows and columns of the scenery classes
outdoor_conf_matrix = conf_matrix[outdoor_indices][:, outdoor_indices]

# Plot the filtered confusion matrix
plt.figure(figsize=(10, 5))
sns.heatmap(outdoor_conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=outdoor_labels, yticklabels=outdoor_labels)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Outdoor Scenes Superclass Confusion Matrix')
plt.show()


In [None]:
# Plot confusion matrix for vehicle superclass

vehicle_class_names = ['motorcycle', 'lawn_mower', 'rocket', 'streetcar', 'tank', 'tractor', 'bus', 'bicycle', 'train' ]

# Indices of the vehicle classes within class_names
vehicle_indices = [class_names.index(vehicle) for vehicle in vehicle_class_names]


# Collect predictions and true labels over the test loader
all_preds = []
all_labels = []

# Predict in eval mode (BatchNorm uses its running statistics and is not
# updated by test data) and without autograd; the previous mode is restored.
was_training = N.training
N.eval()
try:
    with torch.no_grad():
        for m, l in test_loader:
            m, l = m.to(device), l.to(device)
            pred = N(m).argmax(dim=1)
            all_preds.extend(pred.cpu().numpy())
            all_labels.extend(l.cpu().numpy())
finally:
    N.train(was_training)

# Tick labels for the vehicle classes
vehicle_labels = [class_names[idx] for idx in vehicle_indices]

# labels= pins row/column k to class index k even if some class never
# appears among the true or predicted labels.
conf_matrix = confusion_matrix(all_labels, all_preds, labels=np.arange(len(class_names)))

# Keep only the rows and columns of the vehicle classes
vehicle_conf_matrix = conf_matrix[vehicle_indices][:, vehicle_indices]

# Plot the filtered confusion matrix
plt.figure(figsize=(10, 5))
sns.heatmap(vehicle_conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=vehicle_labels, yticklabels=vehicle_labels)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Vehicle Superclass Confusion Matrix')
plt.show()


In [None]:
# Plot confusion matrix for people superclass

people_class_names = ['man', 'woman', 'boy', 'girl', 'baby']

# Indices of the people classes within class_names
people_indices = [class_names.index(people) for people in people_class_names]


# Collect predictions and true labels over the test loader
all_preds = []
all_labels = []

# Predict in eval mode (BatchNorm uses its running statistics and is not
# updated by test data) and without autograd; the previous mode is restored.
was_training = N.training
N.eval()
try:
    with torch.no_grad():
        for m, l in test_loader:
            m, l = m.to(device), l.to(device)
            pred = N(m).argmax(dim=1)
            all_preds.extend(pred.cpu().numpy())
            all_labels.extend(l.cpu().numpy())
finally:
    N.train(was_training)


# Tick labels for the people classes
people_labels = [class_names[idx] for idx in people_indices]

# labels= pins row/column k to class index k even if some class never
# appears among the true or predicted labels.
conf_matrix = confusion_matrix(all_labels, all_preds, labels=np.arange(len(class_names)))

# Keep only the rows and columns of the people classes
people_conf_matrix = conf_matrix[people_indices][:, people_indices]

# Plot the filtered confusion matrix
plt.figure(figsize=(10, 5))
sns.heatmap(people_conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=people_labels, yticklabels=people_labels)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('People Superclass Confusion Matrix')
plt.show()


In [None]:
def plot_image(i, predictions_array, true_label, img):
    """Draw image ``i`` on the current axes with a prediction caption.

    The caption reads "<predicted class> <confidence>% (<true class>)" and is
    coloured '#335599' when the prediction matches the true label and
    '#ee4433' otherwise.

    Args:
        i: index of the sample within the batch.
        predictions_array: per-sample prediction scores, indexed as [i].
        true_label: per-sample class indices, indexed as [i].
        img: batch of images, indexed as [i]; each image is transposed with
            axes (1, 2, 0) and rescaled by ``*0.5+0.5`` before display.
    """
    scores, label, image = predictions_array[i], true_label[i], img[i]

    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(np.transpose(image, (1, 2, 0)) * 0.5 + 0.5)

    predicted_label = np.argmax(scores)
    if predicted_label == label:
        color = '#335599'
    else:
        color = '#ee4433'

    caption = "{} {:2.0f}% ({})".format(class_names[predicted_label],
                                        100*np.max(scores),
                                        class_names[label])
    plt.xlabel(caption, color=color)

def plot_value_array(i, predictions_array, true_label):
    """Draw the prediction scores of sample ``i`` as a bar chart on the current axes.

    All bars are grey except the predicted class ('#ee4433') and the true
    class ('#335599'). The true-class colour is applied last, so a correct
    prediction shows a single '#335599' bar.

    Args:
        i: index of the sample within the batch.
        predictions_array: per-sample prediction scores, indexed as [i]; one
            bar is drawn per score, so any number of classes is supported.
        true_label: per-sample class indices, indexed as [i].
    """
    predictions_array, true_label = predictions_array[i], true_label[i]
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    # One bar per class score rather than a hard-coded 100
    thisplot = plt.bar(range(len(predictions_array)), predictions_array, color="#777777")
    plt.ylim([0, 1])
    predicted_label = np.argmax(predictions_array)

    thisplot[predicted_label].set_color('#ee4433')
    thisplot[true_label].set_color('#335599')

# Fetch one batch of test data and compute class probabilities for it
test_images, test_labels = next(test_iterator)
test_images, test_labels = test_images.to(device), test_labels.to(device)

# Predict in eval mode (BatchNorm uses its running statistics and is not
# updated by this batch) and without autograd; the previous mode is restored.
was_training = N.training
N.eval()
try:
    with torch.no_grad():
        test_preds = torch.softmax(N(test_images), dim=1).cpu().numpy()
finally:
    N.train(was_training)

# Convert once, outside the plotting loop, instead of once per subplot
test_images_np = test_images.cpu().numpy()
test_labels_np = test_labels.cpu().numpy()

# Grid of (image, score bar chart) pairs for the first num_rows*num_cols samples
num_rows = 8
num_cols = 4
num_images = num_rows*num_cols
plt.figure(figsize=(2*2*num_cols, 2*num_rows))
for i in range(num_images):
    plt.subplot(num_rows, 2*num_cols, 2*i+1)
    plot_image(i, test_preds, test_labels_np, test_images_np)
    plt.subplot(num_rows, 2*num_cols, 2*i+2)
    plot_value_array(i, test_preds, test_labels_np)
plt.show()