### params to tune
- weight decay
- dropout rate
- learning rate
- batch size
- loss function

In [38]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm import tqdm
import numpy as np
import sklearn.metrics as metrics
import matplotlib.pyplot as plt

# MNIST

In [31]:
train = datasets.MNIST("./347data", train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))
test = datasets.MNIST("./347data", train=False, download=True, transform=transforms.Compose([transforms.ToTensor()]))
validation_set_size = int(len(train) * 0.1)
training_set_size = len(train) - validation_set_size
train_set, validation_set = torch.utils.data.random_split(train, [training_set_size, validation_set_size])
train_set = torch.utils.data.DataLoader(train, batch_size=10, shuffle=True)
validation_set = torch.utils.data.DataLoader(validation_set, batch_size=10, shuffle=True)
test_set = torch.utils.data.DataLoader(test, batch_size=10, shuffle=True)

In [16]:
class MNIST(nn.Module):
    def __init__(self):
        super().__init__()
        self.dropout = nn.Dropout(0.2)
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.conv3 = nn.Conv2d(64, 128, 3)
        x = torch.randn(29, 29).view(-1, 1, 29, 29)
        self.to_linear = None
        self.convs(x)
        self.fc1 = nn.Linear(self.to_linear, 512)
        self.fc2 = nn.Linear(512, 10)

    def convs(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv3(x)), (2, 2))
        if self.to_linear is None:
            self.to_linear = x[0].shape[0] * x[0].shape[1] * x[0].shape[2]
        return x

    def forward(self, x):
        x = F.pad(x, (0, 0, 0, 0))
        x = self.convs(x)
        x = x.view(-1, self.to_linear)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [17]:
MNIST_net = MNIST().cuda()
optimizer = optim.Adam(MNIST_net.parameters(), lr=0.001)
EPOCHS = 3
for batch in range(EPOCHS):
    for data in tqdm(train_set):
        X, y = data
        MNIST_net.zero_grad()
        output = MNIST_net(X.cuda())
        loss = F.nll_loss(output, y.cuda())
        loss.backward()
        optimizer.step()
    print(loss)

100%|██████████| 6000/6000 [00:25<00:00, 239.56it/s]
  0%|          | 23/6000 [00:00<00:26, 224.18it/s]

tensor(0.0334, device='cuda:0', grad_fn=<NllLossBackward0>)


100%|██████████| 6000/6000 [00:25<00:00, 234.80it/s]
  0%|          | 24/6000 [00:00<00:25, 235.92it/s]

tensor(0.0045, device='cuda:0', grad_fn=<NllLossBackward0>)


100%|██████████| 6000/6000 [00:24<00:00, 240.25it/s]

tensor(0.0085, device='cuda:0', grad_fn=<NllLossBackward0>)





In [39]:
output = []
true = []
MNIST_net.eval()
with torch.no_grad():
    for data in validation_set:
        X, y = data
        for i in MNIST_net(X.cuda()):
            output.append(torch.argmax(i).cpu())
        for i in y:
            true.append(i)
MNIST_net.train()
print("Validation Accuracy:", metrics.accuracy_score(true, output))
print("Validation F1 Score:", metrics.f1_score(true, output, average="macro"))
# make one hot encoding
true = np.eye(10)[true]
output = np.eye(10)[output]
print("Validation AUC Score:", metrics.roc_auc_score(true, output, multi_class="ovo", average="macro"))
            

Validation Accuracy: 0.988
Validation F1 Score: 0.987824552068956
Validation AUC Score: 0.9932404031201022


# CIFAR-10

In [40]:
train = datasets.CIFAR10("./347data", train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))
test = datasets.CIFAR10("./347data", train=False, download=True, transform=transforms.Compose([transforms.ToTensor()]))
train_set = torch.utils.data.DataLoader(train, batch_size=10, shuffle=True)
test_set = torch.utils.data.DataLoader(test, batch_size=10, shuffle=True)

Files already downloaded and verified
Files already downloaded and verified


In [76]:
class CIFAR10(nn.Module):
    def __init__(self):
        super().__init__()
        self.dropout = nn.Dropout(0.5)
        self.conv1 = nn.Conv2d(3, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.conv3 = nn.Conv2d(64, 128, 3)
        x = torch.randn(3,32, 32).view(-1, 3, 32, 32)
        self.to_linear = None
        self.convs(x)
        self.fc1 = nn.Linear(self.to_linear, 512)
        self.fc2 = nn.Linear(512, 10)

    def convs(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv3(x)), (2, 2))
        if self.to_linear is None:
            self.to_linear = x[0].shape[0] * x[0].shape[1] * x[0].shape[2]
        return x
    
    def forward(self, x):
        x = F.pad(x, (0, 0, 0, 0))
        x = self.convs(x)
        x = x.view(-1, self.to_linear)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [75]:
CIFAR10_net = CIFAR10().cuda()
optimizer = optim.Adam(CIFAR10_net.parameters(), lr=0.001)
EPOCHS = 3
for batch in range(EPOCHS):
    for data in tqdm(train_set):
        X, y = data
        CIFAR10_net.zero_grad()
        output = CIFAR10_net(X.cuda())
        loss = F.nll_loss(output, y.cuda())
        loss.backward()
        optimizer.step()
    print(loss)

  0%|          | 18/5000 [00:00<00:28, 176.61it/s]

convs torch.Size([1, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([1, 128, 2, 2])
to_linear 512
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([

  1%|          | 53/5000 [00:00<00:28, 171.87it/s]

torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pr

  2%|▏         | 84/5000 [00:00<00:30, 159.85it/s]

post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end

  2%|▏         | 115/5000 [00:00<00:31, 153.32it/s]

torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pr

  3%|▎         | 145/5000 [00:00<00:32, 149.15it/s]

torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pr

  4%|▎         | 175/5000 [00:01<00:33, 144.64it/s]

torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pr

  4%|▍         | 205/5000 [00:01<00:33, 144.46it/s]

torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pr

  5%|▍         | 237/5000 [00:01<00:32, 148.54it/s]

torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pr

  5%|▌         | 271/5000 [00:01<00:30, 154.57it/s]

torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pr

  6%|▌         | 303/5000 [00:01<00:30, 156.20it/s]

pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32,

  7%|▋         | 335/5000 [00:02<00:30, 155.04it/s]

torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pr

  7%|▋         | 367/5000 [00:02<00:29, 156.26it/s]

torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pr

  8%|▊         | 400/5000 [00:02<00:29, 158.25it/s]

torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pr

  9%|▊         | 433/5000 [00:02<00:28, 160.17it/s]

post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end

  9%|▉         | 467/5000 [00:03<00:28, 161.39it/s]

torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pr

 10%|▉         | 484/5000 [00:03<00:28, 161.04it/s]

torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pr

 10%|█         | 518/5000 [00:03<00:27, 161.82it/s]

torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pr

 11%|█         | 552/5000 [00:03<00:27, 159.66it/s]

torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pr

 12%|█▏        | 590/5000 [00:03<00:25, 171.59it/s]

pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32,

 13%|█▎        | 635/5000 [00:03<00:22, 193.17it/s]

post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end

 14%|█▎        | 681/5000 [00:04<00:20, 208.82it/s]

pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32,

 14%|█▍        | 725/5000 [00:04<00:20, 207.85it/s]

torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pr

 15%|█▌        | 766/5000 [00:04<00:23, 178.16it/s]

post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end

 15%|█▌        | 774/5000 [00:04<00:25, 165.45it/s]


torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pre pad torch.Size([10, 3, 32, 32])
post pad torch.Size([10, 3, 32, 32])
convs torch.Size([10, 3, 32, 32])
post 1
post 2
post 3
convs end torch.Size([10, 128, 2, 2])
torch.Size([10, 3, 32, 32])
pr

KeyboardInterrupt: 

In [None]:
output = []
true = []
CIFAR10_net.eval()
with torch.no_grad():
    for data in validation_set:
        X, y = data
        for i in CIFAR10_net(X.cuda()):
            output.append(torch.argmax(i).cpu())
        for i in y:
            true.append(i)
CIFAR10_net.train()
print("Validation Accuracy:", metrics.accuracy_score(true, output))
print("Validation F1 Score:", metrics.f1_score(true, output, average="macro"))
# make one hot encoding
true = np.eye(10)[true]
output = np.eye(10)[output]
print("Validation AUC Score:", metrics.roc_auc_score(true, output, multi_class="ovo", average="macro"))
            