In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt

from torchvision import datasets
from torchvision import transforms

In [2]:
data_path = './'

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4915, 0.4823, 0.4468), (0.2470, 0.2435, 0.2616))
])

cifar10 = datasets.CIFAR10(data_path, train=True, download=True, transform=transform)
cifar10_val = datasets.CIFAR10(data_path, train=False, download=True, transform=transform)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:10<00:00, 15957912.62it/s]


Extracting ./cifar-10-python.tar.gz to ./
Files already downloaded and verified


In [3]:
len(cifar10)

50000

In [4]:
isinstance(cifar10, torch.utils.data.Dataset)

True

In [5]:
img, label = cifar10[9]

In [6]:
img.shape

torch.Size([3, 32, 32])

In [7]:
train_loader = torch.utils.data.DataLoader(cifar10, batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(cifar10_val, batch_size=64, shuffle=True)

In [8]:
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(8*8*8, 32)

        self.fc2 = nn.Linear(32, 10)

    def forward(self, x):
        out = F.max_pool2d(torch.tanh(self.conv1(x)), 2)
        out = F.max_pool2d(torch.tanh(self.conv2(out)), 2)

        out = out.view(-1, 8*8*8)
        out = torch.tanh(self.fc1(out))
        out = self.fc2(out)

        return out

In [9]:
def training_loop(n_epechs, optimizer, model, loss_fn, train_loader, val_loader):
    for epoch in range(1, n_epechs + 1):
        losses = 0.0

        for imgs, labels in train_loader:

            out = model(imgs)
            loss = loss_fn(out, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            losses += loss.item()

        correct = 0
        with torch.no_grad():
            for data in val_loader:
                imgs, labels = data
                out = model(imgs)
                _, pred = torch.max(out, 1)
                c = (pred == labels).squeeze()
                correct += c.sum()

        print(epoch, losses/len(train_loader), correct/len(cifar10_val))

In [10]:
model = CNN()
optimizer = optim.SGD(model.parameters(), lr=3e-2)
loss_fn = nn.CrossEntropyLoss()

training_loop(
    n_epechs=30,
    optimizer=optimizer,
    model=model,
    loss_fn=loss_fn,
    train_loader=train_loader,
    val_loader=val_loader
)

1 1.7789772721507666 tensor(0.4373)
2 1.4481457519104413 tensor(0.5237)
3 1.2913902085607925 tensor(0.4778)
4 1.1990187265683927 tensor(0.5033)
5 1.1391978399527958 tensor(0.5573)
6 1.0939035244152675 tensor(0.5710)
7 1.0573693354568823 tensor(0.6021)
8 1.0257844892151826 tensor(0.5965)
9 1.000690243661861 tensor(0.5112)
10 0.9765799566912834 tensor(0.6036)
11 0.9570879386666485 tensor(0.6224)
12 0.9364162214729183 tensor(0.6337)
13 0.9250510257223378 tensor(0.5574)
14 0.9091725993491805 tensor(0.6296)
15 0.8987151977351254 tensor(0.6074)
16 0.8869173546581317 tensor(0.6406)
17 0.873153803065 tensor(0.6293)
18 0.8636263725550278 tensor(0.6374)
19 0.8554771773498077 tensor(0.6450)
20 0.8482040746894943 tensor(0.6471)
21 0.8426987602735114 tensor(0.5827)
22 0.8315050373296908 tensor(0.6232)
23 0.8239210305540153 tensor(0.6348)
24 0.8193803027538997 tensor(0.6330)
25 0.812749119022923 tensor(0.6349)
26 0.8096763923040131 tensor(0.6265)
27 0.8005234063662532 tensor(0.6411)
28 0.79405810659

In [20]:
f64 = ()
with torch.no_grad():
  for img, label in val_loader:
    out = model(img)
    f64 = (out, label)
    break

In [22]:
pred, label = f64

In [28]:
y_pred = torch.argmax(pred, 1).numpy()
y = label.numpy()

In [29]:
from sklearn.metrics import classification_report

print(classification_report(y, y_pred))

              precision    recall  f1-score   support

           0       0.38      0.50      0.43         6
           1       0.78      0.88      0.82         8
           2       0.29      0.33      0.31         6
           3       0.80      0.36      0.50        11
           4       0.50      0.67      0.57         3
           5       0.60      0.75      0.67         8
           6       0.67      0.67      0.67         3
           7       0.86      0.86      0.86         7
           8       1.00      0.50      0.67         6
           9       0.50      0.67      0.57         6

    accuracy                           0.61        64
   macro avg       0.64      0.62      0.61        64
weighted avg       0.66      0.61      0.61        64

