In [1]:
import torch 
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from tqdm import tqdm
import datetime
from matplotlib import pyplot as plt

%matplotlib inline
%run ./data_loading.ipynb

In [2]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=6, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=0)

        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.fc1 = nn.Linear(in_features=10*10*12, out_features=300)
        self.fc2 = nn.Linear(in_features=300, out_features=60)
        self.fc3 = nn.Linear(in_features=60, out_features=9)
        
    def forward(self, X):
        X = self.relu(self.conv1(X))
        X = self.pool(self.relu(self.conv2(X)))
        X = self.relu(self.conv3(X))
        X = self.pool(self.relu(self.conv4(X)))
        X = X.view(-1, 10*10*12)
        X = self.relu(self.fc1(X))
        X = self.relu(self.fc2(X))
        X = self.fc3(X)
        return X

In [3]:
def calculate_accuracy(dataloader):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in dataloader:
            images, labels = data
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    acc = 100 * correct / total
    return acc

In [4]:
def train(dataset, epochs, batch_size=5, to_cuda=True, calculate_acc=False):
    train_acc = []
    val_acc = []
    if to_cuda:
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
    else:
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    for epoch in tqdm(range(epochs)):
        running_loss = 0.0
        for i, data in enumerate(dataloader, 0):
            inputs, labels = data
            if to_cuda:
                inputs = inputs.cuda()
                labels = labels.cuda()

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = loss_f(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
        print('epoch:', epoch, ', loss:', running_loss * batch_size / dataloader.__len__())
        if calculate_acc:
            t = calculate_accuracy(dataloader)
            v = calculate_accuracy(val_loader)
            print("train acc:", t, "val acc:", v)
            train_acc.append(t)
            val_acc.append(v)
        if epoch % 1 == 0: # every x epochs reshuffle data by creating a new loader
            if to_cuda:
                dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
            else:
                dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    return train_acc[-1], val_acc[-1]

In [5]:
cr_test = ClassroomDataset(range_10=[7,10])
# Dividing data into 50% train, 40% val, 30% test
hyper_scores = {}
for wd in [0.0001, 0.001, 0.01, 0.1]:
    scores = []
    for i in range(1):
        net = Net().float()
        if torch.cuda.is_available():
            net = net.cuda()
        loss_f = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.parameters(), lr=0.01, weight_decay=wd)
        if i == 0:
            cr_train = ClassroomDataset(range_10=[0,5])
            cr_val = ClassroomDataset(range_10=[5,7])
        else:
            cr_train = ClassroomDataset(range_10=[2,7])
            cr_val = ClassroomDataset(range_10=[0,2])
        val_loader = DataLoader(cr_val, batch_size=1, shuffle=True)
        train_accuracy, val_accuracy = train(cr_train, to_cuda=False, epochs=20, calculate_acc=True)
        print(train_accuracy, val_accuracy)
        scores.append((train_accuracy, val_accuracy))
    hyper_scores[wd] = scores

  0%|                                                                                   | 0/20 [00:00<?, ?it/s]

epoch: 0 , loss: 10.05963523616083
train acc: 62.03125 val acc: 59.84375


  5%|███▊                                                                       | 1/20 [00:19<06:19, 20.00s/it]

epoch: 1 , loss: 4.0547609763052606
train acc: 69.125 val acc: 66.25


 10%|███████▌                                                                   | 2/20 [00:37<05:46, 19.27s/it]

epoch: 2 , loss: 2.5736156189159374
train acc: 87.84375 val acc: 84.21875


 15%|███████████▎                                                               | 3/20 [00:56<05:25, 19.12s/it]

epoch: 3 , loss: 1.9215270133236544
train acc: 86.96875 val acc: 82.34375


 20%|███████████████                                                            | 4/20 [01:14<04:59, 18.74s/it]

epoch: 4 , loss: 1.3029365969648552
train acc: 94.8125 val acc: 90.078125


 25%|██████████████████▊                                                        | 5/20 [01:31<04:35, 18.39s/it]

epoch: 5 , loss: 1.0459089065748035
train acc: 95.5 val acc: 89.53125


 30%|██████████████████████▌                                                    | 6/20 [01:48<04:11, 17.94s/it]

epoch: 6 , loss: 0.7551473406243989
train acc: 97.71875 val acc: 91.5625


 35%|██████████████████████████▎                                                | 7/20 [02:05<03:48, 17.61s/it]

epoch: 7 , loss: 0.5927228126885176
train acc: 97.4375 val acc: 90.9375


 40%|██████████████████████████████                                             | 8/20 [02:23<03:32, 17.70s/it]

epoch: 8 , loss: 0.4723205506438061
train acc: 97.9375 val acc: 91.25


 45%|█████████████████████████████████▊                                         | 9/20 [02:45<03:30, 19.16s/it]

epoch: 9 , loss: 0.4219254073790695
train acc: 98.5 val acc: 91.875


 50%|█████████████████████████████████████                                     | 10/20 [03:07<03:18, 19.86s/it]

epoch: 10 , loss: 0.29928346799359473
train acc: 98.65625 val acc: 91.953125


 55%|████████████████████████████████████████▋                                 | 11/20 [03:29<03:04, 20.45s/it]

epoch: 11 , loss: 0.35236670496003286
train acc: 98.96875 val acc: 91.484375


 60%|████████████████████████████████████████████▍                             | 12/20 [03:51<02:46, 20.84s/it]

epoch: 12 , loss: 0.1871524242728363
train acc: 99.4375 val acc: 91.875


 65%|████████████████████████████████████████████████                          | 13/20 [04:11<02:25, 20.86s/it]

epoch: 13 , loss: 0.17987331065891454
train acc: 99.75 val acc: 92.578125


 70%|███████████████████████████████████████████████████▊                      | 14/20 [04:29<01:59, 19.96s/it]

epoch: 14 , loss: 0.06080484003781583
train acc: 99.9375 val acc: 93.203125


 75%|███████████████████████████████████████████████████████▌                  | 15/20 [04:48<01:38, 19.68s/it]

epoch: 15 , loss: 0.02020101397642249
train acc: 99.96875 val acc: 92.890625


 80%|███████████████████████████████████████████████████████████▏              | 16/20 [05:09<01:20, 20.11s/it]

epoch: 16 , loss: 0.010572648131224494
train acc: 99.96875 val acc: 93.359375


 85%|██████████████████████████████████████████████████████████████▉           | 17/20 [05:30<01:00, 20.20s/it]

epoch: 17 , loss: 0.01142047476633512
train acc: 99.96875 val acc: 93.28125


 90%|██████████████████████████████████████████████████████████████████▌       | 18/20 [05:48<00:38, 19.48s/it]

epoch: 18 , loss: 0.005160542957626468
train acc: 99.96875 val acc: 93.125


 95%|██████████████████████████████████████████████████████████████████████▎   | 19/20 [06:05<00:18, 18.95s/it]

epoch: 19 , loss: 0.0056029237278956145
train acc: 100.0 val acc: 93.125


100%|██████████████████████████████████████████████████████████████████████████| 20/20 [06:23<00:00, 18.44s/it]


100.0 93.125


  0%|                                                                                   | 0/20 [00:00<?, ?it/s]

epoch: 0 , loss: 10.202701816568151
train acc: 57.34375 val acc: 56.171875


  5%|███▊                                                                       | 1/20 [00:18<05:54, 18.64s/it]

epoch: 1 , loss: 4.409250341632287
train acc: 78.84375 val acc: 76.484375


 10%|███████▌                                                                   | 2/20 [00:37<05:37, 18.76s/it]

epoch: 2 , loss: 2.8797251306732505
train acc: 86.96875 val acc: 84.765625


 15%|███████████▎                                                               | 3/20 [00:54<05:07, 18.07s/it]

epoch: 3 , loss: 2.167702009680397
train acc: 89.34375 val acc: 86.640625


 20%|███████████████                                                            | 4/20 [01:10<04:41, 17.61s/it]

epoch: 4 , loss: 1.5477017913822806
train acc: 90.65625 val acc: 86.640625


 25%|██████████████████▊                                                        | 5/20 [01:27<04:22, 17.52s/it]

epoch: 5 , loss: 1.2577872406702681
train acc: 94.46875 val acc: 89.6875


 30%|██████████████████████▌                                                    | 6/20 [01:45<04:03, 17.37s/it]

epoch: 6 , loss: 1.0332342510754842
train acc: 96.03125 val acc: 89.296875


 35%|██████████████████████████▎                                                | 7/20 [02:01<03:42, 17.08s/it]

epoch: 7 , loss: 0.7838175966161955
train acc: 96.78125 val acc: 90.46875


 40%|██████████████████████████████                                             | 8/20 [02:17<03:22, 16.84s/it]

epoch: 8 , loss: 0.5358308075646059
train acc: 96.5 val acc: 89.296875


 45%|█████████████████████████████████▊                                         | 9/20 [02:33<03:01, 16.54s/it]

epoch: 9 , loss: 0.5411396319553017
train acc: 97.75 val acc: 91.5625


 50%|█████████████████████████████████████                                     | 10/20 [02:49<02:43, 16.32s/it]

epoch: 10 , loss: 0.5482095829653573
train acc: 98.8125 val acc: 91.328125


 55%|████████████████████████████████████████▋                                 | 11/20 [03:05<02:25, 16.16s/it]

epoch: 11 , loss: 0.4432037364687112
train acc: 98.1875 val acc: 90.46875


 60%|████████████████████████████████████████████▍                             | 12/20 [03:24<02:17, 17.22s/it]

epoch: 12 , loss: 0.2709654934482977
train acc: 97.90625 val acc: 90.078125


 65%|████████████████████████████████████████████████                          | 13/20 [03:43<02:04, 17.76s/it]

epoch: 13 , loss: 0.36391195767905193
train acc: 98.3125 val acc: 91.09375


 70%|███████████████████████████████████████████████████▊                      | 14/20 [04:01<01:46, 17.82s/it]

epoch: 14 , loss: 0.21319290415318548
train acc: 99.375 val acc: 91.09375


 75%|███████████████████████████████████████████████████████▌                  | 15/20 [04:19<01:28, 17.73s/it]

epoch: 15 , loss: 0.13470771792209568
train acc: 99.6875 val acc: 92.5


 80%|███████████████████████████████████████████████████████████▏              | 16/20 [04:36<01:10, 17.59s/it]

epoch: 16 , loss: 0.16427539331235308
train acc: 99.6875 val acc: 92.34375


 85%|██████████████████████████████████████████████████████████████▉           | 17/20 [04:53<00:52, 17.50s/it]

epoch: 17 , loss: 0.37945052264898305
train acc: 99.71875 val acc: 92.578125


 90%|██████████████████████████████████████████████████████████████████▌       | 18/20 [05:12<00:35, 17.70s/it]

epoch: 18 , loss: 0.06167690349170307
train acc: 99.96875 val acc: 92.109375


 95%|██████████████████████████████████████████████████████████████████████▎   | 19/20 [05:33<00:18, 18.79s/it]

epoch: 19 , loss: 0.019312533588864755
train acc: 99.9375 val acc: 92.5


100%|██████████████████████████████████████████████████████████████████████████| 20/20 [05:55<00:00, 19.79s/it]


99.9375 92.5


  0%|                                                                                   | 0/20 [00:00<?, ?it/s]

epoch: 0 , loss: 10.8325803745538
train acc: 16.3125 val acc: 15.3125


  5%|███▊                                                                       | 1/20 [00:22<07:04, 22.33s/it]

epoch: 1 , loss: 10.657263843342662
train acc: 12.6875 val acc: 12.109375


 10%|███████▌                                                                   | 2/20 [00:45<06:44, 22.47s/it]

epoch: 2 , loss: 10.574763344600797
train acc: 12.5625 val acc: 12.34375


 15%|███████████▎                                                               | 3/20 [01:07<06:20, 22.40s/it]

epoch: 3 , loss: 10.510924216359854
train acc: 12.5 val acc: 12.34375


 20%|███████████████                                                            | 4/20 [01:27<05:49, 21.85s/it]

epoch: 4 , loss: 10.458101315423846
train acc: 12.6875 val acc: 12.109375


 25%|██████████████████▊                                                        | 5/20 [01:48<05:22, 21.52s/it]

epoch: 5 , loss: 10.431637533940375
train acc: 23.75 val acc: 23.125


 30%|██████████████████████▌                                                    | 6/20 [02:09<04:58, 21.31s/it]

epoch: 6 , loss: 8.162157019134611
train acc: 69.0625 val acc: 68.125


 35%|██████████████████████████▎                                                | 7/20 [02:30<04:37, 21.31s/it]

epoch: 7 , loss: 3.9789322159049334
train acc: 78.21875 val acc: 76.640625


 40%|██████████████████████████████                                             | 8/20 [02:51<04:14, 21.21s/it]

epoch: 8 , loss: 2.8961363037815318
train acc: 83.625 val acc: 82.578125


 45%|█████████████████████████████████▊                                         | 9/20 [03:12<03:52, 21.11s/it]

epoch: 9 , loss: 2.397207286063349
train acc: 88.125 val acc: 85.625


 50%|█████████████████████████████████████                                     | 10/20 [03:33<03:30, 21.04s/it]

epoch: 10 , loss: 1.961322449025829
train acc: 91.34375 val acc: 87.1875


 55%|████████████████████████████████████████▋                                 | 11/20 [03:54<03:08, 20.92s/it]

epoch: 11 , loss: 1.5234145479555536
train acc: 92.59375 val acc: 88.046875


 60%|████████████████████████████████████████████▍                             | 12/20 [04:15<02:47, 20.96s/it]

epoch: 12 , loss: 1.2733469896461855
train acc: 84.59375 val acc: 79.53125


 65%|████████████████████████████████████████████████                          | 13/20 [04:36<02:26, 20.97s/it]

epoch: 13 , loss: 1.0795340426112574
train acc: 92.125 val acc: 85.46875


 70%|███████████████████████████████████████████████████▊                      | 14/20 [04:56<02:05, 20.86s/it]

epoch: 14 , loss: 0.9357955938955413
train acc: 95.71875 val acc: 89.53125


 75%|███████████████████████████████████████████████████████▌                  | 15/20 [05:17<01:44, 20.82s/it]

epoch: 15 , loss: 0.759788554298817
train acc: 96.3125 val acc: 89.84375


 80%|███████████████████████████████████████████████████████████▏              | 16/20 [05:38<01:23, 20.75s/it]

epoch: 16 , loss: 0.748578826294116
train acc: 96.625 val acc: 89.765625


 85%|██████████████████████████████████████████████████████████████▉           | 17/20 [05:58<01:02, 20.74s/it]

epoch: 17 , loss: 0.6706141641020449
train acc: 97.34375 val acc: 90.15625


 90%|██████████████████████████████████████████████████████████████████▌       | 18/20 [06:18<00:40, 20.28s/it]

epoch: 18 , loss: 0.6131773453387277
train acc: 96.9375 val acc: 90.3125


 95%|██████████████████████████████████████████████████████████████████████▎   | 19/20 [06:37<00:20, 20.13s/it]

epoch: 19 , loss: 0.5562750710962465
train acc: 98.375 val acc: 90.703125


100%|██████████████████████████████████████████████████████████████████████████| 20/20 [06:59<00:00, 20.48s/it]


98.375 90.703125


  0%|                                                                                   | 0/20 [00:00<?, ?it/s]

epoch: 0 , loss: 10.816189860925078
train acc: 12.5625 val acc: 12.34375


  5%|███▊                                                                       | 1/20 [00:20<06:36, 20.86s/it]

epoch: 1 , loss: 10.73833079636097
train acc: 12.5625 val acc: 12.34375


 10%|███████▌                                                                   | 2/20 [00:41<06:16, 20.90s/it]

epoch: 2 , loss: 10.718037093058228
train acc: 12.6875 val acc: 12.109375


 15%|███████████▎                                                               | 3/20 [01:05<06:08, 21.70s/it]

epoch: 3 , loss: 10.711636437103152
train acc: 12.5625 val acc: 12.734375


 20%|███████████████                                                            | 4/20 [01:26<05:46, 21.63s/it]

epoch: 4 , loss: 10.7090948484838
train acc: 12.6875 val acc: 12.109375


 25%|██████████████████▊                                                        | 5/20 [01:48<05:22, 21.48s/it]

epoch: 5 , loss: 10.709710283204913
train acc: 12.6875 val acc: 12.109375


 30%|██████████████████████▌                                                    | 6/20 [02:09<04:58, 21.33s/it]

epoch: 6 , loss: 10.708941826596856
train acc: 12.6875 val acc: 12.109375


 35%|██████████████████████████▎                                                | 7/20 [02:29<04:32, 20.93s/it]

epoch: 7 , loss: 10.708839537575841
train acc: 12.6875 val acc: 12.109375


 40%|██████████████████████████████                                             | 8/20 [02:46<03:57, 19.83s/it]

epoch: 8 , loss: 10.708970563486218
train acc: 12.5 val acc: 12.34375


 45%|█████████████████████████████████▊                                         | 9/20 [03:07<03:41, 20.16s/it]

epoch: 9 , loss: 10.709294075146317
train acc: 12.46875 val acc: 12.34375


 50%|█████████████████████████████████████                                     | 10/20 [03:28<03:24, 20.42s/it]

epoch: 10 , loss: 10.708764111623168
train acc: 12.5625 val acc: 12.734375


 55%|████████████████████████████████████████▋                                 | 11/20 [03:49<03:05, 20.65s/it]

epoch: 11 , loss: 10.709291152656078
train acc: 12.6875 val acc: 12.109375


 60%|████████████████████████████████████████████▍                             | 12/20 [04:10<02:45, 20.70s/it]

epoch: 12 , loss: 10.70929560996592
train acc: 12.5625 val acc: 12.65625


 65%|████████████████████████████████████████████████                          | 13/20 [04:28<02:18, 19.85s/it]

epoch: 13 , loss: 10.709473349153996
train acc: 12.6875 val acc: 12.109375


 70%|███████████████████████████████████████████████████▊                      | 14/20 [04:49<02:01, 20.33s/it]

epoch: 14 , loss: 10.709057139232755
train acc: 12.5625 val acc: 12.65625


 75%|███████████████████████████████████████████████████████▌                  | 15/20 [05:11<01:44, 20.83s/it]

epoch: 15 , loss: 10.709235338494182
train acc: 12.5625 val acc: 12.34375


 80%|███████████████████████████████████████████████████████████▏              | 16/20 [05:33<01:24, 21.18s/it]

epoch: 16 , loss: 10.709364784881473
train acc: 12.5625 val acc: 12.734375


 85%|██████████████████████████████████████████████████████████████▉           | 17/20 [05:55<01:04, 21.35s/it]

epoch: 17 , loss: 10.708666322752833
train acc: 12.5625 val acc: 12.34375


 90%|██████████████████████████████████████████████████████████████████▌       | 18/20 [06:16<00:42, 21.44s/it]

epoch: 18 , loss: 10.709505533799529
train acc: 12.5625 val acc: 12.34375


 95%|██████████████████████████████████████████████████████████████████████▎   | 19/20 [06:38<00:21, 21.34s/it]

epoch: 19 , loss: 10.709252683445811
train acc: 12.5625 val acc: 12.34375


100%|██████████████████████████████████████████████████████████████████████████| 20/20 [06:59<00:00, 21.34s/it]


12.5625 12.34375


In [6]:
hyper_scores

{0.0001: [(100.0, 93.125)],
 0.001: [(99.9375, 92.5)],
 0.01: [(98.375, 90.703125)],
 0.1: [(12.5625, 12.34375)]}

In [7]:
def class_wise_accuracy(dataloader):
    correct = {}
    total = {}
    with torch.no_grad():
        for data in dataloader:
            images, labels = data
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            for l in range(len(labels)):
                label = int(labels[l].cpu().numpy())
                pred = predicted[l]
                if label in correct:
                    correct[label] += int(pred == label)
                    total[label] += 1
                else:
                    correct[label] = int(pred == label)
                    total[label] = 1
    acc = {}
    for label in correct:
        acc[label] = 100 * correct[label] / total[label]
    return acc

In [8]:
class_wise_accuracy(val_loader)

{8: 0.0, 2: 0.0, 6: 0.0, 5: 0.0, 4: 0.0, 7: 0.0, 3: 0.0, 1: 100.0}