In [None]:
# MAIN TODO:
#     - add a validation set (and use its loss for early stopping)
#     - tune hyperparameters with Optuna
#     - show examples of correct and incorrect predictions
#     - plot how the loss decreases across epochs (train, validation and test)


# Imports

In [16]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

from tqdm import tqdm

from sklearn import metrics

# Helper functions

In [10]:
def get_scores(targets, predictions, average="weighted"):
    """Compute standard classification metrics for a set of predictions.

    Args:
        targets: ground-truth class labels (array-like).
        predictions: predicted class labels, same length as ``targets``.
        average: averaging strategy forwarded to the multi-class F1,
            precision and recall scorers (default ``"weighted"``, i.e. the
            per-class score weighted by class support).

    Returns:
        dict mapping metric name -> plain Python ``float``. Values are cast
        with ``float`` so every entry prints the same way (sklearn returns a
        mix of Python floats and numpy scalars).
    """
    return {
        "accuracy"         : float(metrics.accuracy_score(targets, predictions)),
        "balanced_accuracy": float(metrics.balanced_accuracy_score(targets, predictions)),
        "f1_score"         : float(metrics.f1_score(targets, predictions, average=average)),
        "precision"        : float(metrics.precision_score(targets, predictions, average=average)),
        "recall"           : float(metrics.recall_score(targets, predictions, average=average)),
    }

# Setup

In [11]:
transform = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
                    ])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Def MLP

In [12]:
class MLP(nn.Module):
    def __init__(self, input_size, num_classes, activation_function):
        super(MLP, self).__init__()
        self.activation_function = activation_function
        self.fc_input   = nn.Linear(in_features=input_size, out_features=64)
        self.fc_hidden1 = nn.Linear(in_features=64, out_features=128)
        self.fc_hidden2 = nn.Linear(in_features=128, out_features=64)
        self.fc_output  = nn.Linear(in_features=64, out_features=num_classes)

    def forward(self, x):
        x = self.activation_function(self.fc_input(x))
        x = self.activation_function(self.fc_hidden1(x))
        x = self.activation_function(self.fc_hidden2(x))
        x = self.fc_output(x)
        return x 

# Def hyperparameters and Loaders

In [13]:
input_size          = 32*32*3
num_classes         = 10

# TODO: let this params manipulated
learning_rate       = 0.001
num_epochs          = 100
batch_size          = 16
activation_function = nn.ReLU()

loss_function       = nn.CrossEntropyLoss()

In [14]:
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True) # to avoid bias
test_loader  = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False)

# Main training loop

In [17]:
# Use the GPU when one is available so the notebook also runs on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG so weight initialisation and DataLoader shuffling
# are reproducible between runs of this cell.
torch.manual_seed(42)

mlp = MLP(input_size=input_size, num_classes=num_classes, activation_function=activation_function)
mlp.to(device)

optimizer = torch.optim.SGD(mlp.parameters(), lr=learning_rate)

# Early stopping: stop once the average epoch loss has not improved for
# `patience` consecutive epochs.
# NOTE(review): this monitors the *training* loss; switch to a validation
# loss once a validation split exists (see MAIN TODO).
best_loss = float('inf')
patience = 5
patience_counter = 0

# Average training loss per epoch, kept so the learning curve can be plotted.
train_losses = []

# The evaluation cell puts the model in eval mode; make re-runs start in train mode.
mlp.train()

# main loop
for epoch in tqdm(range(num_epochs)):
    epoch_loss = 0.0
    for i, (images, labels) in enumerate(train_loader):
        # Flatten each image into one input_size-long feature vector for the MLP.
        images = images.view(-1, input_size).to(device)
        labels = labels.to(device)
        optimizer.zero_grad()

        # forward pass -- call the module rather than .forward() so hooks run
        outputs = mlp(images)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        if (i+1) % 1000 == 0:
            # len(train_loader) is the true number of steps (it counts a final partial batch).
            tqdm.write(f' Epoch {epoch + 1}/{num_epochs}, Step {i+1}/{len(train_loader)}, Loss: {loss.item():.4f}')

    epoch_loss /= len(train_loader)
    train_losses.append(epoch_loss)
    tqdm.write(f'Epoch {epoch+1} average loss: {epoch_loss:.4f}')

    # early stopping
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            tqdm.write("Early stopping based on loss reduction through epochs")
            break



  0%|          | 0/100 [00:00<?, ?it/s]

  1%|          | 1/100 [00:01<02:06,  1.28s/it]

Epoch 1 average loss: 2.2991


  2%|▏         | 2/100 [00:02<01:58,  1.21s/it]

Epoch 2 average loss: 2.2928


  3%|▎         | 3/100 [00:03<01:56,  1.20s/it]

Epoch 3 average loss: 2.2856


  4%|▍         | 4/100 [00:04<01:55,  1.20s/it]

Epoch 4 average loss: 2.2760


  5%|▌         | 5/100 [00:06<01:56,  1.22s/it]

Epoch 5 average loss: 2.2622


  6%|▌         | 6/100 [00:07<01:54,  1.21s/it]

Epoch 6 average loss: 2.2421


  7%|▋         | 7/100 [00:08<01:52,  1.21s/it]

Epoch 7 average loss: 2.2138


  8%|▊         | 8/100 [00:09<01:51,  1.21s/it]

Epoch 8 average loss: 2.1796


  9%|▉         | 9/100 [00:10<01:50,  1.21s/it]

Epoch 9 average loss: 2.1450


 10%|█         | 10/100 [00:12<01:48,  1.20s/it]

Epoch 10 average loss: 2.1114


 11%|█         | 11/100 [00:13<01:45,  1.19s/it]

Epoch 11 average loss: 2.0785


 12%|█▏        | 12/100 [00:14<01:46,  1.21s/it]

Epoch 12 average loss: 2.0474


 13%|█▎        | 13/100 [00:15<01:45,  1.21s/it]

Epoch 13 average loss: 2.0191


 14%|█▍        | 14/100 [00:16<01:44,  1.21s/it]

Epoch 14 average loss: 1.9938


 15%|█▌        | 15/100 [00:18<01:44,  1.23s/it]

Epoch 15 average loss: 1.9711


 16%|█▌        | 16/100 [00:19<01:42,  1.22s/it]

Epoch 16 average loss: 1.9509


 17%|█▋        | 17/100 [00:20<01:42,  1.23s/it]

Epoch 17 average loss: 1.9325


 18%|█▊        | 18/100 [00:21<01:40,  1.23s/it]

Epoch 18 average loss: 1.9154


 19%|█▉        | 19/100 [00:23<01:40,  1.24s/it]

Epoch 19 average loss: 1.8994


 20%|██        | 20/100 [00:24<01:38,  1.23s/it]

Epoch 20 average loss: 1.8841


 21%|██        | 21/100 [00:25<01:37,  1.23s/it]

Epoch 21 average loss: 1.8690


 22%|██▏       | 22/100 [00:26<01:35,  1.22s/it]

Epoch 22 average loss: 1.8542


 23%|██▎       | 23/100 [00:28<01:35,  1.24s/it]

Epoch 23 average loss: 1.8397


 24%|██▍       | 24/100 [00:29<01:35,  1.26s/it]

Epoch 24 average loss: 1.8252


 25%|██▌       | 25/100 [00:30<01:34,  1.26s/it]

Epoch 25 average loss: 1.8109


 26%|██▌       | 26/100 [00:31<01:33,  1.26s/it]

Epoch 26 average loss: 1.7974


 27%|██▋       | 27/100 [00:33<01:32,  1.26s/it]

Epoch 27 average loss: 1.7836


 28%|██▊       | 28/100 [00:34<01:29,  1.24s/it]

Epoch 28 average loss: 1.7707


 29%|██▉       | 29/100 [00:35<01:27,  1.23s/it]

Epoch 29 average loss: 1.7579


 30%|███       | 30/100 [00:36<01:22,  1.18s/it]

Epoch 30 average loss: 1.7452


 31%|███       | 31/100 [00:37<01:19,  1.15s/it]

Epoch 31 average loss: 1.7332


 32%|███▏      | 32/100 [00:38<01:16,  1.12s/it]

Epoch 32 average loss: 1.7215


 33%|███▎      | 33/100 [00:39<01:13,  1.10s/it]

Epoch 33 average loss: 1.7101


 34%|███▍      | 34/100 [00:40<01:11,  1.08s/it]

Epoch 34 average loss: 1.6987


 35%|███▌      | 35/100 [00:41<01:10,  1.08s/it]

Epoch 35 average loss: 1.6875


 36%|███▌      | 36/100 [00:42<01:08,  1.07s/it]

Epoch 36 average loss: 1.6767


 37%|███▋      | 37/100 [00:44<01:07,  1.07s/it]

Epoch 37 average loss: 1.6658


 38%|███▊      | 38/100 [00:45<01:05,  1.06s/it]

Epoch 38 average loss: 1.6548


 39%|███▉      | 39/100 [00:46<01:04,  1.06s/it]

Epoch 39 average loss: 1.6440


 40%|████      | 40/100 [00:47<01:03,  1.07s/it]

Epoch 40 average loss: 1.6326


 41%|████      | 41/100 [00:48<01:02,  1.06s/it]

Epoch 41 average loss: 1.6219


 42%|████▏     | 42/100 [00:49<01:01,  1.06s/it]

Epoch 42 average loss: 1.6110


 43%|████▎     | 43/100 [00:50<01:02,  1.09s/it]

Epoch 43 average loss: 1.6000


 44%|████▍     | 44/100 [00:51<01:03,  1.13s/it]

Epoch 44 average loss: 1.5892


 45%|████▌     | 45/100 [00:52<01:03,  1.15s/it]

Epoch 45 average loss: 1.5792


 46%|████▌     | 46/100 [00:54<01:02,  1.17s/it]

Epoch 46 average loss: 1.5687


 47%|████▋     | 47/100 [00:55<01:01,  1.17s/it]

Epoch 47 average loss: 1.5585


 48%|████▊     | 48/100 [00:56<01:00,  1.17s/it]

Epoch 48 average loss: 1.5488


 49%|████▉     | 49/100 [00:57<00:59,  1.17s/it]

Epoch 49 average loss: 1.5387


 50%|█████     | 50/100 [00:58<00:59,  1.18s/it]

Epoch 50 average loss: 1.5302


 51%|█████     | 51/100 [01:00<00:59,  1.21s/it]

Epoch 51 average loss: 1.5194


 52%|█████▏    | 52/100 [01:01<00:58,  1.22s/it]

Epoch 52 average loss: 1.5107


 53%|█████▎    | 53/100 [01:02<00:58,  1.24s/it]

Epoch 53 average loss: 1.5006


 54%|█████▍    | 54/100 [01:03<00:56,  1.22s/it]

Epoch 54 average loss: 1.4926


 55%|█████▌    | 55/100 [01:04<00:52,  1.17s/it]

Epoch 55 average loss: 1.4833


 56%|█████▌    | 56/100 [01:05<00:50,  1.15s/it]

Epoch 56 average loss: 1.4735


 57%|█████▋    | 57/100 [01:07<00:49,  1.14s/it]

Epoch 57 average loss: 1.4640


 58%|█████▊    | 58/100 [01:08<00:49,  1.17s/it]

Epoch 58 average loss: 1.4554


 59%|█████▉    | 59/100 [01:09<00:47,  1.15s/it]

Epoch 59 average loss: 1.4452


 60%|██████    | 60/100 [01:10<00:44,  1.12s/it]

Epoch 60 average loss: 1.4364


 61%|██████    | 61/100 [01:11<00:44,  1.15s/it]

Epoch 61 average loss: 1.4279


 62%|██████▏   | 62/100 [01:12<00:44,  1.18s/it]

Epoch 62 average loss: 1.4183


 63%|██████▎   | 63/100 [01:14<00:43,  1.18s/it]

Epoch 63 average loss: 1.4084


 64%|██████▍   | 64/100 [01:15<00:42,  1.19s/it]

Epoch 64 average loss: 1.3993


 65%|██████▌   | 65/100 [01:16<00:42,  1.20s/it]

Epoch 65 average loss: 1.3895


 66%|██████▌   | 66/100 [01:17<00:40,  1.20s/it]

Epoch 66 average loss: 1.3800


 67%|██████▋   | 67/100 [01:19<00:39,  1.20s/it]

Epoch 67 average loss: 1.3710


 68%|██████▊   | 68/100 [01:20<00:37,  1.17s/it]

Epoch 68 average loss: 1.3606


 69%|██████▉   | 69/100 [01:21<00:36,  1.18s/it]

Epoch 69 average loss: 1.3517


 70%|███████   | 70/100 [01:22<00:35,  1.19s/it]

Epoch 70 average loss: 1.3414


 71%|███████   | 71/100 [01:23<00:34,  1.19s/it]

Epoch 71 average loss: 1.3322


 72%|███████▏  | 72/100 [01:25<00:34,  1.23s/it]

Epoch 72 average loss: 1.3234


 73%|███████▎  | 73/100 [01:26<00:33,  1.23s/it]

Epoch 73 average loss: 1.3131


 74%|███████▍  | 74/100 [01:27<00:30,  1.18s/it]

Epoch 74 average loss: 1.3048


 75%|███████▌  | 75/100 [01:28<00:28,  1.14s/it]

Epoch 75 average loss: 1.2940


 76%|███████▌  | 76/100 [01:29<00:26,  1.11s/it]

Epoch 76 average loss: 1.2854


 77%|███████▋  | 77/100 [01:30<00:25,  1.13s/it]

Epoch 77 average loss: 1.2756


 78%|███████▊  | 78/100 [01:31<00:25,  1.15s/it]

Epoch 78 average loss: 1.2669


 79%|███████▉  | 79/100 [01:32<00:23,  1.12s/it]

Epoch 79 average loss: 1.2568


 80%|████████  | 80/100 [01:33<00:22,  1.10s/it]

Epoch 80 average loss: 1.2464


 81%|████████  | 81/100 [01:35<00:21,  1.13s/it]

Epoch 81 average loss: 1.2377


 82%|████████▏ | 82/100 [01:36<00:20,  1.15s/it]

Epoch 82 average loss: 1.2286


 83%|████████▎ | 83/100 [01:37<00:19,  1.17s/it]

Epoch 83 average loss: 1.2199


 84%|████████▍ | 84/100 [01:38<00:19,  1.19s/it]

Epoch 84 average loss: 1.2093


 85%|████████▌ | 85/100 [01:39<00:17,  1.19s/it]

Epoch 85 average loss: 1.2004


 86%|████████▌ | 86/100 [01:41<00:16,  1.18s/it]

Epoch 86 average loss: 1.1909


 87%|████████▋ | 87/100 [01:42<00:14,  1.14s/it]

Epoch 87 average loss: 1.1799


 88%|████████▊ | 88/100 [01:43<00:13,  1.15s/it]

Epoch 88 average loss: 1.1705


 89%|████████▉ | 89/100 [01:44<00:12,  1.16s/it]

Epoch 89 average loss: 1.1610


 90%|█████████ | 90/100 [01:45<00:11,  1.19s/it]

Epoch 90 average loss: 1.1511


 91%|█████████ | 91/100 [01:46<00:10,  1.17s/it]

Epoch 91 average loss: 1.1411


 92%|█████████▏| 92/100 [01:47<00:09,  1.13s/it]

Epoch 92 average loss: 1.1322


 93%|█████████▎| 93/100 [01:48<00:07,  1.11s/it]

Epoch 93 average loss: 1.1210


 94%|█████████▍| 94/100 [01:50<00:06,  1.09s/it]

Epoch 94 average loss: 1.1126


 95%|█████████▌| 95/100 [01:51<00:05,  1.10s/it]

Epoch 95 average loss: 1.1026


 96%|█████████▌| 96/100 [01:52<00:04,  1.08s/it]

Epoch 96 average loss: 1.0909


 97%|█████████▋| 97/100 [01:53<00:03,  1.07s/it]

Epoch 97 average loss: 1.0825


 98%|█████████▊| 98/100 [01:54<00:02,  1.06s/it]

Epoch 98 average loss: 1.0709


 99%|█████████▉| 99/100 [01:55<00:01,  1.07s/it]

Epoch 99 average loss: 1.0610


100%|██████████| 100/100 [01:56<00:00,  1.16s/it]

Epoch 100 average loss: 1.0512





In [19]:
mlp.eval()
# Run inference on whatever device the model's weights live on (GPU or CPU).
device = next(mlp.parameters()).device

predictions = []
labels = []

# Run over the test set. no_grad() skips building an autograd graph we never
# backpropagate through, which saves memory and time.
with torch.no_grad():
    for images, label in test_loader:
        images = images.view(-1, input_size).to(device)

        output = mlp(images)
        predicted = output.argmax(dim=1)  # index of the highest logit = predicted class

        predictions.extend(predicted.cpu().numpy())
        # Labels are only needed on the CPU for sklearn; no need to move them to the GPU.
        labels.extend(label.numpy())

scores = get_scores(labels, predictions)
print(f"Model scores: \n{scores}")



Model scores: 
{'accuracy': 0.6439, 'balanced_accuracy': np.float64(0.6438999999999999), 'f1_score': 0.6410596187680893, 'precision': 0.6434856216467476, 'recall': 0.6439}
