In [1]:
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

In [2]:
device = torch.device('cuda')
batch_size = 100

In [3]:
train_loader = torch.utils.data.DataLoader(datasets.CIFAR10('data-cifar', train=True, 
                                                          download=True, 
                                                          transform=transforms.ToTensor()), 
                                                       batch_size=batch_size, shuffle=True)

Files already downloaded and verified


In [4]:
test_loader = torch.utils.data.DataLoader(datasets.CIFAR10('data-cifar', train=False, 
                                                          download=True, 
                                                          transform=transforms.ToTensor()), 
                                                       batch_size=batch_size, shuffle=True)

Files already downloaded and verified


In [5]:
import torch.nn as nn
import torch.nn.functional as F

In [6]:
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        # Hout = 1 + (Hin+2×padding[0]−dilation[0]×(kernel_size[0]−1)−1)/stride[0]
        self.conv = nn.Sequential(
            # shape = (batch_size, 1, 32, 32)
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            # shape = (batch_size, 16, 28, 28)
            nn.AvgPool2d(kernel_size=2),
            # shape = (batch_size, 16, 14, 14)
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            # shape = (batch_size, 32, 10, 10),
            nn.AvgPool2d(kernel_size=2),
            # shape = (batch_size, 32, 5, 5)
            nn.Dropout(p=0.2)
        )
        self.fc = nn.Sequential(
            # (32, 5, 5) -> (256) -> (10)
            nn.Linear(32 * 5 * 5, 256),
            nn.Linear(256, 10),
            nn.LogSoftmax(dim=-1)
        )
        
    def forward(self, x):

        out = self.conv(x)
        
        out = out.view(out.size(0), -1)
        out = self.fc(out)

        return out
        

In [7]:
def train(model, train_loader, optimizer):
    loss_f = torch.nn.CrossEntropyLoss()
    model.train()
    tot_loss = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_f(output, target)
        loss.backward()
        optimizer.step()
        tot_loss += loss
    print('loss', batch_size * tot_loss.item() / len(train_loader.dataset))

In [8]:
def test(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += data.size(0)
            correct += (predicted == target).sum().item()
    print('accuracy', 100 * correct / total )

In [9]:
model = Model().to(device)

learning_rate = 0.01
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)

In [10]:
for epoch in range(45):
    scheduler.step()
    train(model, train_loader, optimizer)
    test(model, test_loader)

loss 2.686644775390625
accuracy 38.56
loss 1.67875
accuracy 36.85
loss 1.597802490234375
accuracy 40.7
loss 1.520124267578125
accuracy 47.17
loss 1.47160595703125
accuracy 48.73
loss 1.41580615234375
accuracy 49.4
loss 1.3592952880859375
accuracy 50.84
loss 1.3336580810546874
accuracy 45.72
loss 1.306218505859375
accuracy 40.72
loss 1.28177734375
accuracy 54.32
loss 1.268618408203125
accuracy 45.21
loss 1.2558446044921876
accuracy 46.36
loss 1.2445733642578125
accuracy 57.49
loss 1.227628662109375
accuracy 58.14
loss 1.2217701416015625
accuracy 56.04
loss 1.0965325927734375
accuracy 62.93
loss 1.080784912109375
accuracy 62.89
loss 1.077850341796875
accuracy 63.24
loss 1.07891064453125
accuracy 63.27
loss 1.076147705078125
accuracy 63.19
loss 1.07089990234375
accuracy 63.05
loss 1.0667532958984376
accuracy 63.49
loss 1.0637393798828125
accuracy 63.73
loss 1.0683968505859376
accuracy 63.18
loss 1.0628843994140624
accuracy 63.57
loss 1.0614315185546874
accuracy 63.29
loss 1.06230163574218

KeyboardInterrupt: 