In [None]:
# Imports
import os
import pickle

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
from tqdm import tqdm
%matplotlib inline

# ---- Hyperparameters ----
# Optimisation settings shared by every training run in this notebook.
NUM_EPOCHS, LR = 20, 1e-3

# Data-loading settings: mini-batch size handed to every DataLoader.
batch_size = 64

# Classification loss used both for training and for evaluation.
criterion = nn.CrossEntropyLoss()

# Seeded RNG handed to random_split so the train/validation partition can be
# reproduced from a fresh kernel.
generator = torch.Generator()
generator.manual_seed(42)

In [None]:
# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# Load Datasets

# Generic pipeline: tensor conversion followed by a (0.5, 0.5) per-channel
# normalisation.
_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Per-channel mean / std shared by the train and test pipelines below.
# NOTE(review): these look like the commonly quoted CIFAR-10 statistics -- confirm.
_CHANNEL_MEAN = (0.4914, 0.4822, 0.4465)
_CHANNEL_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: padded random 32-pixel crop and random horizontal flip,
# then tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
])

# Evaluation pipeline: tensor conversion and normalisation only.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
])

# Placeholder for the number of classes (10 or 101 depending on the dataset).
NUM_CLASSES = 0

def getTrainingSet(dataset_name):
  """Build the train, test and validation DataLoaders for one dataset.

  Args:
    dataset_name: 'CIFAR-10', 'STL10' or 'Caltech101'.

  Returns:
    A (trainloader, testloader, validloader) tuple -- note that the test
    loader is the SECOND element.  Only the training loader shuffles.

  Raises:
    ValueError: if dataset_name is not one of the three supported names.

  Side effects:
    Sets the module-level NUM_CLASSES to the class count of the dataset and,
    for the torchvision datasets, downloads the data into ./data if needed.
  """
  # Without this declaration the assignments below only create a dead local
  # and the module-level NUM_CLASSES stays at its placeholder value.
  global NUM_CLASSES

  def _split_train_valid(full_trainset):
    # 80/20 train/validation partition driven by the shared seeded generator.
    n_train = int(len(full_trainset) * 0.8)
    return torch.utils.data.random_split(
        full_trainset, [n_train, len(full_trainset) - n_train],
        generator=generator)

  if dataset_name == 'CIFAR-10':
    NUM_CLASSES = 10
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform_test)
    trainset, validset = _split_train_valid(trainset)
  elif dataset_name == 'STL10':
    NUM_CLASSES = 10
    trainset = torchvision.datasets.STL10(root='./data', split='train',
                                          download=True, transform=transform_train)
    # NOTE(review): the test split also goes through transform_train (random
    # crop + flip).  transform_test contains no crop, so swapping it in would
    # presumably change the image size fed to the network -- confirm intent.
    testset = torchvision.datasets.STL10(root='./data', split='test',
                                         download=True, transform=transform_train)
    trainset, validset = _split_train_valid(trainset)
  elif dataset_name == 'Caltech101':
    NUM_CLASSES = 101
    # Data download (run once from a notebook cell):
    #     !gdown https://drive.google.com/uc?id=1DX_XeKHn3yXtZ18DD7qc1wf-Jy5lnhD5
    #     !unzip -qq '101_ObjectCategories.zip'
    PATH = '101_ObjectCategories/'
    # NOTE(review): images go through transform_train, i.e. a 32-pixel random
    # crop plus flip rather than a resize of the whole picture -- confirm that
    # this is what is wanted for Caltech101.
    totalset = torchvision.datasets.ImageFolder(PATH, transform=transform_train)
    # Materialise the dataset once: every image is transformed a single time
    # here, so the random crop/flip drawn now is reused in every epoch.
    X, y = zip(*totalset)
    # Fixed random_state: without it every kernel session draws a different
    # split, so a model saved in one session could later be "tested" on images
    # it was trained on.  42 mirrors the seed of the shared generator.
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=0.3, stratify=y, random_state=42)
    X_val, X_test, y_val, y_test = train_test_split(
        X_val, y_val, test_size=0.5, stratify=y_val, random_state=42)
    trainset = list(zip(X_train, y_train))
    validset = list(zip(X_val, y_val))
    testset = list(zip(X_test, y_test))
  else:
    # Previously an unknown name fell through to an UnboundLocalError below.
    raise ValueError(
        "Unknown dataset_name %r; expected 'CIFAR-10', 'STL10' or 'Caltech101'"
        % (dataset_name,))

  loader_kwargs = dict(batch_size=batch_size, num_workers=2)
  trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_kwargs)
  validloader = torch.utils.data.DataLoader(validset, shuffle=False, **loader_kwargs)
  testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_kwargs)
  return trainloader, testloader, validloader

In [None]:
class VGGStyleNet(nn.Module):
    """VGG-flavoured CNN: eight 3x3 convolutions followed by a three-layer MLP head.

    Only the first two convolutions are followed by a 2x2 max-pool.  The pools
    after conv3_2, conv4_2 and conv5_2 use kernel size 1 and therefore leave the
    spatial size untouched.  The head expects 512*8*8 flattened features, which
    is what e.g. 32x32 inputs produce after the two halving pools.

    Args:
        num_classes: width of the final linear layer.
        init_weights: when True, re-initialise conv/linear layers via
            ``_initialize_weights`` after construction.
    """

    def __init__(self, num_classes: int = 10, init_weights: bool = True):
        super().__init__()
        # Pools first, then convolutions, then the head: the registration order
        # is kept exactly as before so module iteration order does not change.
        self.maxpool = nn.MaxPool2d(kernel_size=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)

        conv_specs = (
            ("conv1", 3, 64),
            ("conv2", 64, 128),
            ("conv3_1", 128, 256),
            ("conv3_2", 256, 256),
            ("conv4_1", 256, 512),
            ("conv4_2", 512, 512),
            ("conv5_1", 512, 512),
            ("conv5_2", 512, 512),
        )
        for layer_name, in_ch, out_ch in conv_specs:
            # setattr goes through nn.Module.__setattr__, so each layer is
            # registered as a sub-module under its usual attribute name.
            setattr(self, layer_name,
                    nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1))

        head_layers = [
            nn.Linear(512 * 8 * 8, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        if init_weights:
            self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Kaiming-normal init for convolutions, N(0, 0.01) for linear layers, zero biases."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0, std=0.01)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Return the classifier outputs for a batch of images ``x``."""
        # (convolution, pool applied after its ReLU or None) in execution order.
        pipeline = (
            (self.conv1, self.maxpool2),
            (self.conv2, self.maxpool2),
            (self.conv3_1, None),
            (self.conv3_2, self.maxpool),
            (self.conv4_1, None),
            (self.conv4_2, self.maxpool),
            (self.conv5_1, None),
            (self.conv5_2, self.maxpool),
        )
        features = x
        for conv, pool in pipeline:
            features = F.relu(conv(features))
            if pool is not None:
                features = pool(features)
        # Keep the batch dimension, flatten everything else for the MLP head.
        return self.classifier(features.flatten(1))

In [None]:
# Model
# The loaders built below serve Caltech101, which this notebook treats as a
# 101-class problem, so the classifier head must be sized accordingly (the
# constructor default of 10 outputs would make the loss fail on labels >= 10).
net = VGGStyleNet(num_classes=101)
net.to(device)
trainloader, testloader, validloader = getTrainingSet("Caltech101")

In [None]:
# SGD with momentum and L2 weight decay over every parameter of `net`.
sgd_settings = dict(lr=LR, momentum=0.999, weight_decay=5e-4)
optimizer = torch.optim.SGD(net.parameters(), **sgd_settings)

In [None]:
def test_validation():
    val_loss = 0
    total_images = 0
    correct_images = 0
    net.eval()
    with torch.no_grad():
      for batch_index, (images, labels) in enumerate(validloader):
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        loss = criterion(outputs, labels)
        val_loss += loss.item()
        _, predicted = outputs.max(1)
        total_images += labels.size(0)
        correct_images += predicted.eq(labels).sum().item()
    val_accuracy = 100.*correct_images/total_images
    return val_loss/(batch_index+1), val_accuracy

In [None]:
def train(epoch):
  """Run one optimisation pass over the global ``trainloader`` and print a summary.

  Uses the global ``net``, ``optimizer``, ``criterion`` and ``device``.

  Args:
    epoch: epoch number, used only in the printed summary line.
  """
  net.train()
  running_loss = 0
  n_seen = 0
  n_correct = 0
  # Counting from 1 makes `step` the number of batches processed so far.
  for step, (images, labels) in enumerate(tqdm(trainloader), start=1):
    images = images.to(device)
    labels = labels.to(device)

    optimizer.zero_grad()
    outputs = net(images)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    running_loss += loss.item()
    predicted = outputs.max(1)[1]
    n_seen += labels.size(0)
    n_correct += (predicted == labels).sum().item()

  summary = ('Epoch: %d, Loss: %.3f, '
             'Accuracy: %.3f%% (%d/%d)')
  print(summary % (epoch, running_loss / step,
                   100. * n_correct / n_seen, n_correct, n_seen))

In [None]:
def test():
    test_loss = 0
    total_images = 0
    correct_images = 0
    net.eval()
    with torch.no_grad():
      for batch_index, (images, labels) in enumerate(tqdm(testloader)):
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        _, predicted = outputs.max(1)
        total_images += labels.size(0)
        correct_images += predicted.eq(labels).sum().item()
    test_accuracy = 100.*correct_images/total_images
    print("Loss on Test Set is", test_loss/(batch_index+1))
    print("Accuracy on Test Set is",test_accuracy)

In [None]:
# train different models across momentum values
momentum_values = [0.999, 0.995, 0.99, 0.9]
save_dir = '../models/caltech101/'
# Create the output folder up front so a finished training run is never lost
# to a missing directory at save time.
os.makedirs(save_dir, exist_ok=True)
trainloader, testloader, validloader = getTrainingSet("Caltech101")
criterion = nn.CrossEntropyLoss()
for v in momentum_values:
    # Fresh 101-way model and Nesterov-momentum SGD optimizer for each value;
    # train() and test_validation() pick both up through the global names.
    net = VGGStyleNet(num_classes=101)
    net.to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=LR, momentum=v, weight_decay=5e-4, nesterov=True)
    model_name = 'vgg_nag'+str(v)
    # One (validation loss, validation accuracy) pair per epoch.
    history = []
    for epoch in range(NUM_EPOCHS):
        train(epoch)
        history.append(test_validation())
    # The whole module is pickled (not just its state_dict), so loading it
    # later requires the VGGStyleNet class to be defined.
    torch.save(net, save_dir + model_name + '.pt')
    # Context manager guarantees the file is closed even if dumping fails.
    with open(save_dir + model_name + '_hist.pt', 'wb') as outfile:
        pickle.dump(history, outfile)

In [None]:
# train single model
# NOTE(review): this cell trains whatever `net` / `optimizer` currently live in
# the kernel.  If the momentum-sweep cell ran before it, those are the objects
# of the last sweep iteration, not a fresh momentum=0.999 model -- re-run the
# model and optimizer cells first if a clean 'vgg_999' run is wanted.
model_name = 'vgg_999'
# Create the output folder up front so the run is not lost at save time.
os.makedirs('../models/', exist_ok=True)
# One (validation loss, validation accuracy) pair per epoch.
history = []
for epoch in range(NUM_EPOCHS):
  train(epoch)
  history.append(test_validation())
torch.save(net, '../models/' + model_name + '.pt')
# Context manager guarantees the file is closed even if dumping fails.
with open('../models/' + model_name + '_hist.pt','wb') as outfile:
  pickle.dump(history, outfile)

In [None]:
# load trained model
# The checkpoint is a fully pickled nn.Module, so the VGGStyleNet class must be
# defined before loading.  Unpickling executes code: only load files you trust.
checkpoint_path = '../models/caltech101/vgg_nag0.99.pt'
try:
    # map_location lets a model saved on a GPU be restored on a CPU-only
    # machine.  weights_only=False is required on PyTorch releases whose
    # torch.load defaults to weights_only=True (2.6 and later), which
    # otherwise refuse whole-module pickles.
    net = torch.load(checkpoint_path, map_location=device, weights_only=False)
except TypeError:
    # Older PyTorch releases do not accept the weights_only keyword.
    net = torch.load(checkpoint_path, map_location=device)
test()

In [None]:
# Plot Loss
# Each history file is a pickled list with one (loss, accuracy) entry per epoch
# (presumably written by the training cells of this project).
# NOTE: pickle.load can execute arbitrary code -- only open files you produced.
nag_history_paths = {
    'baseline': "../models/cifar-10/vgg_0_hist.pt",
    '0.999': "../models/cifar-10/vgg_nag0.999_hist.pt",
    '0.995': "../models/cifar-10/vgg_nag0.995_hist.pt",
    '0.99': "../models/cifar-10/vgg_nag0.99_hist.pt",
    '0.9': "../models/cifar-10/vgg_nag0.9_hist.pt",
}
nag_histories = {}
for run_label, hist_path in nag_history_paths.items():
    # `with` closes every handle; the bare open() calls used before leaked them.
    with open(hist_path, "rb") as hist_file:
        nag_histories[run_label] = pickle.load(hist_file)
# Keep the individual names bound for any cell that refers to them.
baseline, vgg_999, vgg_995, vgg_99, vgg_9 = nag_histories.values()
for run_label, run_history in nag_histories.items():
    # Element 0 of each entry is the loss.
    plt.plot([v[0] for v in run_history], '-x', label=run_label)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(loc='upper right')
plt.title('Loss vs. No. of epochs - NAG');

In [None]:
# Plot Accuracy
# Each history file is a pickled list with one (loss, accuracy) entry per epoch
# (presumably written by the training cells of this project).
# NOTE: pickle.load can execute arbitrary code -- only open files you produced.
nag_history_paths = {
    'baseline': "../models/cifar-10/vgg_0_hist.pt",
    '0.999': "../models/cifar-10/vgg_nag0.999_hist.pt",
    '0.995': "../models/cifar-10/vgg_nag0.995_hist.pt",
    '0.99': "../models/cifar-10/vgg_nag0.99_hist.pt",
    '0.9': "../models/cifar-10/vgg_nag0.9_hist.pt",
}
nag_histories = {}
for run_label, hist_path in nag_history_paths.items():
    # `with` closes every handle; the bare open() calls used before leaked them.
    with open(hist_path, "rb") as hist_file:
        nag_histories[run_label] = pickle.load(hist_file)
# Keep the individual names bound for any cell that refers to them.
baseline, vgg_999, vgg_995, vgg_99, vgg_9 = nag_histories.values()
for run_label, run_history in nag_histories.items():
    # Element 1 of each entry is the accuracy.
    plt.plot([v[1] for v in run_history], '-x', label=run_label)
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend(loc='lower right')
plt.title('Accuracy (Val) vs. No. of epochs - NAG');