In [None]:
import os
import time
import numpy as np
import pandas as pd
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
from PIL import Image

from dataset_cifar10 import CIFAR10

import sys

sys.path.insert(0, "..") # to include ../helper_utils.py etc.

from helper_evaluate import compute_accuracy
from helper_data import get_dataloaders_cifar10
from helper_train import train_classifier_simple
from helper_utils import set_all_seeds, set_deterministic

# When a CUDA device is visible to PyTorch, switch on cuDNN's deterministic
# flag so that repeated runs are more likely to produce the same results.
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True

### parameters

In [None]:
# Reproducibility: seed handed to set_all_seeds() in the setup cell.
random_seed = 123

# Optimisation settings used when building the optimizer and data loaders.
learning_rate = 0.0001
batch_size = 128
epochs = 10

# CIFAR-10 has ten target classes; this sizes the final linear layer.
num_classes = 10

# Name of the compute device. Fall back to the CPU when no CUDA device is
# available, so the constant never names a device that does not exist.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

### setting

#### Setting a random seed

#### Setting cuDNN and PyTorch algorithmic behavior to deterministic

In [None]:
# Seed the random number generators with the notebook-wide seed.
# NOTE(review): which generators are seeded depends on helper_utils.set_all_seeds,
# which is not shown here -- confirm it covers random, numpy and torch.
set_all_seeds(random_seed)

# Request deterministic algorithm behaviour (see helper_utils.set_deterministic).
set_deterministic()

### data

In [None]:
# Training pipeline: upscale to 70x70, take a random 64x64 crop
# (light augmentation), then convert the image to a tensor.
train_transforms = transforms.Compose([
    transforms.Resize((70, 70)),
    transforms.RandomCrop((64, 64)),
    transforms.ToTensor(),
])

# Evaluation pipeline: same resize, but a fixed centre crop so that
# validation/test inputs do not vary between runs.
test_transforms = transforms.Compose([
    transforms.Resize((70, 70)),
    transforms.CenterCrop((64, 64)),
    transforms.ToTensor(),
])

# Build the three loaders; 10% of the training data is held out for validation.
train_loader, valid_loader, test_loader = get_dataloaders_cifar10(
    CIFAR10,
    batch_size,
    train_transforms=train_transforms,
    test_transforms=test_transforms,
    num_workers=2,
    validation_fraction=0.1,
)

# Sanity-check the first mini-batch of every split: print the tensor shapes
# and the first ten labels. Each header is paired with the loader it
# describes, so the "Testing Set" preview really reads from test_loader
# (it previously iterated train_loader by mistake).
for split_header, split_loader in (('Training Set:\n', train_loader),
                                   ('\nValidation Set:', valid_loader),
                                   ('\nTesting Set:', test_loader)):
    print(split_header)
    for images, labels in split_loader:
        print('Image batch dimensions:', images.size())
        print('Image label dimensions:', labels.size())
        print(labels[:10])
        break  # one batch per split is enough for a shape check

### model

In [None]:
class AlexNet(nn.Module):
    """AlexNet-style convolutional image classifier.

    The network consists of a five-layer convolutional feature extractor
    (``features``), an adaptive average pool that fixes the spatial size of
    the feature map at 6x6 regardless of the input resolution (``avgpool``),
    and a three-layer fully connected head (``classifier``).

    Args:
        num_classes: Number of target classes, i.e. the size of the logit
            vector produced for each example.
    """

    def __init__(self, num_classes):
        super().__init__()

        self.num_classes = num_classes

        # Convolutional feature extractor; expects 3-channel input.
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

        # Fully connected head operating on the flattened 256x6x6 feature map.
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes)
        )

        # Guarantees the 6x6 spatial size the first linear layer relies on.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

    def forward(self, x):
        """Compute unnormalised class scores for a batch of images.

        Args:
            x: Input image tensor with 3 channels.

        Returns:
            Tensor of raw logits with ``num_classes`` columns. No softmax is
            applied; apply ``F.softmax(logits, dim=1)`` at the call site if
            probabilities are needed.
        """
        x = self.features(x)
        x = self.avgpool(x)
        x = x.view(-1, 256 * 6 * 6)
        return self.classifier(x)

    def _initialize_weights(self):
        """Re-initialise all convolutional and linear layers in place.

        Convolution weights get Kaiming-normal (fan-out, ReLU) initialisation,
        linear weights are drawn from N(0, 0.01^2), and every bias is set to
        zero. This method is not invoked by ``__init__``; call it explicitly
        if this initialisation scheme is wanted.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # The keyword is `mode`; the former `model=` raised TypeError.
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Build the network and move its parameters to the selected device
# (nn.Module.to returns the module itself, so the call can be chained).
model = AlexNet(num_classes=num_classes).to(device)

# Adam over every model parameter, using the notebook-wide learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

### train

In [None]:
# Run the training loop. The returned dictionary is indexed by the plotting
# cells below (e.g. 'train_loss_per_batch').
log_dict = train_classifier_simple(
    epochs=epochs,
    model=model,
    optimizer=optimizer,
    device=device,
    train_loader=train_loader,
    valid_loader=valid_loader,
    logging_interval=50,
)

In [None]:
# Per-minibatch training loss recorded during the training run.
minibatch_loss = log_dict['train_loss_per_batch']

# Number of consecutive minibatches averaged for each point of the smoothed curve.
window = 200

plt.plot(minibatch_loss, label='Minibatch cost')

# The running average must smooth the same per-batch series drawn above;
# the per-epoch validation loss lives on a different axis and was used here
# by mistake. Skip the curve when there are fewer points than the window,
# because np.convolve(mode='valid') would then swap its operands and return
# a meaningless result.
if len(minibatch_loss) >= window:
    plt.plot(np.convolve(minibatch_loss, np.ones(window,)/window, mode='valid'),
             label='Running average')

plt.ylabel('Cross Entropy')
plt.xlabel('Iteration')
plt.legend()
plt.show()

In [None]:
# Accuracy curves, one x position per epoch (1..epochs).
# NOTE(review): the keys are named '*_per_batch' but are plotted against an
# epoch axis; if the helper stores one value per batch the lengths will not
# match and plt.plot will raise -- confirm the key names against helper_train.
epoch_axis = np.arange(1, epochs+1)
for series_key, series_label in (('train_acc_per_batch', 'Training'),
                                 ('valid_acc_per_batch', 'Validation')):
    plt.plot(epoch_axis, log_dict[series_key], label=series_label)

plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

### evaluate

In [None]:
# Switch to inference mode so the dropout layers are disabled while scoring;
# otherwise the reported accuracies would be randomly perturbed.
model.eval()

# Gradients are not needed for evaluation.
with torch.no_grad():

    # Each accuracy is computed on its own split: train_acc previously read
    # from test_loader, which silently duplicated the test result.
    train_acc = compute_accuracy(model=model,
                                 data_loader=train_loader,
                                 device=device)

    valid_acc = compute_accuracy(model=model,
                                 data_loader=valid_loader,
                                 device=device)

    test_acc = compute_accuracy(model=model,
                                data_loader=test_loader,
                                device=device)


# The "Train ACC" line previously printed valid_acc.
print(f'Train ACC: {train_acc:.2f}%')
print(f'Validation ACC: {valid_acc:.2f}%')
print(f'Test ACC: {test_acc:.2f}%')