In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt

from torchvision import datasets
from torchvision import transforms

In [3]:
data_path = './'

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4915, 0.4823, 0.4468), (0.2470, 0.2435, 0.2616))
])

cifar10 = datasets.CIFAR10(data_path, train=True, download=True, transform=transform)
cifar10_val = datasets.CIFAR10(data_path, train=False, download=True, transform=transform)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
len(cifar10)

50000

In [5]:
isinstance(cifar10, torch.utils.data.Dataset)

True

In [7]:
img, label = cifar10[9]

In [9]:
img.shape

torch.Size([3, 32, 32])

In [6]:
train_loader = torch.utils.data.DataLoader(cifar10, batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(cifar10_val, batch_size=64, shuffle=True)

In [10]:
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(8*8*8, 32)

        self.fc2 = nn.Linear(32, 10)
    
    def forward(self, x):
        out = F.max_pool2d(torch.tanh(self.conv1(x)), 2)
        out = F.max_pool2d(torch.tanh(self.conv2(out)), 2)

        out = out.view(-1, 8*8*8)
        out = torch.tanh(self.fc1(out))
        out = self.fc2(out)

        return out

In [13]:
def training_loop(n_epechs, optimizer, model, loss_fn, train_loader, val_loader):
    for epoch in range(1, n_epechs + 1):
        losses = 0.0

        for imgs, labels in train_loader:

            out = model(imgs)
            loss = loss_fn(out, labels)
            
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            losses += loss.item()
        
        correct = 0
        with torch.no_grad():
            for data in val_loader:
                imgs, labels = data
                out = model(imgs)
                _, pred = torch.max(out, 1)
                c = (pred == labels).squeeze()
                correct += c.sum()
        
        print(epoch, losses/len(train_loader), correct/len(val_loader))

In [14]:
model = CNN()
optimizer = optim.SGD(model.parameters(), lr=3e-2)
loss_fn = nn.CrossEntropyLoss()

training_loop(
    n_epechs=30,
    optimizer=optimizer,
    model=model,
    loss_fn=loss_fn,
    train_loader=train_loader,
    val_loader=val_loader
)

1 1.8066132333882325 tensor(27.8217)
2 1.4436177260735457 tensor(31.8153)
3 1.2920267290014136 tensor(33.1720)
4 1.2032420367688474 tensor(36.0955)
5 1.1479189411910904 tensor(35.6815)
6 1.106378240048733 tensor(35.9936)
7 1.0710688121331013 tensor(38.3822)
8 1.0410017571638308 tensor(36.9236)
9 1.0139948713505054 tensor(38.7134)
10 0.9936721373701949 tensor(37.0637)
11 0.9729034135408718 tensor(38.0701)
12 0.9565913354039497 tensor(40.2675)
13 0.9403248587837609 tensor(40.5032)


KeyboardInterrupt: 