In [1]:
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

In [2]:
# Seed torch's global RNG (it seeds every device) so weight init, data
# shuffling and dropout masks repeat from run to run. Some cuDNN kernels may
# still be nondeterministic on GPU.
seed = 0
torch.manual_seed(seed)
# Prefer the GPU but fall back to CPU: torch.device('cuda') on its own only
# fails later, at the first .to(device), on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 10

In [3]:
# CIFAR-10 training split (downloaded into ./data-cifar if missing), images
# converted to tensors with ToTensor, served in shuffled mini-batches.
train_loader = torch.utils.data.DataLoader(
    dataset=datasets.CIFAR10(
        root='data-cifar',
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    ),
    batch_size=batch_size,
    shuffle=True,
)

Files already downloaded and verified


In [4]:
# CIFAR-10 test split, evaluated in a fixed order. Accuracy does not depend on
# sample order, so shuffling buys nothing. A fixed order keeps per-example
# results (e.g. which images were misclassified) comparable across runs.
test_loader = torch.utils.data.DataLoader(datasets.CIFAR10('data-cifar', train=False,
                                                          download=True,
                                                          transform=transforms.ToTensor()),
                                          batch_size=batch_size, shuffle=False)

Files already downloaded and verified


In [5]:
import torch.nn as nn
import torch.nn.functional as F

In [6]:
def _conv_bn_relu(in_channels, out_channels, kernel_size, padding, bias=True):
    """Return the [Conv2d(stride=1), BatchNorm2d, ReLU] modules used by Model.

    Returned as a plain list (not an nn.Sequential) so callers can splice it
    into a larger nn.Sequential. This keeps the flat, index-based submodule
    names (``transition.0``, ``transition.4``, ...) and therefore the same
    state_dict keys as a hand-written Sequential.
    """
    return [
        nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                  stride=1, padding=padding, bias=bias),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
    ]


class Model(nn.Module):
    """Two-branch CNN that maps 3-channel images to 10 log-probabilities.

    A shared stem feeds two parallel branches. Each branch runs two 3x3 convs
    and concatenates both conv outputs (128 channels) into its own transition
    block. The two branch outputs are concatenated again, reduced by a shared
    transition, flattened and classified by a small fully connected head that
    ends in LogSoftmax.

    The flattened size below assumes 32x32 inputs (e.g. CIFAR-10).
    """

    # Spatial size for a 32x32 input: stem/branch convs keep 32; per branch
    # 32 -conv5-> 28 -avgpool(2, stride 1)-> 27; shared transition
    # 27 -conv5-> 23 -avgpool2-> 11 -conv5-> 7 -avgpool2-> 3, with 128 channels.
    _FLAT_FEATURES = 128 * 3 * 3

    def __init__(self):
        super().__init__()
        # Stem: no conv bias because the following BatchNorm's shift makes it redundant.
        self.initConv = nn.Sequential(*_conv_bn_relu(3, 64, 3, padding=1, bias=False))
        # Branch 0. NOTE(review): these convs keep bias=True even though a
        # BatchNorm follows; left as-is to preserve the parameter set.
        self.conv01 = nn.Sequential(*_conv_bn_relu(64, 64, 3, padding=1))
        self.conv02 = nn.Sequential(*_conv_bn_relu(64, 64, 3, padding=1))
        # Branch 1: same layout as branch 0, independent weights.
        self.conv11 = nn.Sequential(*_conv_bn_relu(64, 64, 3, padding=1))
        self.conv12 = nn.Sequential(*_conv_bn_relu(64, 64, 3, padding=1))
        # Per-branch transitions: 128 concatenated channels -> 64, valid 5x5 conv,
        # then a 2x2 average pool with stride 1 (shrinks each side by one pixel).
        self.transition0 = nn.Sequential(
            *_conv_bn_relu(128, 64, 5, padding=0),
            nn.AvgPool2d(kernel_size=2, stride=1),
        )
        self.transition1 = nn.Sequential(
            *_conv_bn_relu(128, 64, 5, padding=0),
            nn.AvgPool2d(kernel_size=2, stride=1),
        )
        # Shared transition over both branches: two conv5 + 2x2 avg-pool stages, then dropout.
        self.transition = nn.Sequential(
            *_conv_bn_relu(128, 128, 5, padding=0),
            nn.AvgPool2d(kernel_size=2),
            *_conv_bn_relu(128, 128, 5, padding=0),
            nn.AvgPool2d(kernel_size=2),
            nn.Dropout(p=0.2),
        )
        # NOTE(review): there is no nonlinearity between the two Linear layers,
        # so they compose to a single affine map and the 256-unit hidden layer
        # adds no expressive power. Inserting an activation would shift the
        # state_dict key of the second Linear (fc.1 -> fc.2), so it is not done here.
        self.fc = nn.Sequential(
            nn.Linear(self._FLAT_FEATURES, 256),
            nn.Linear(256, 10),
            nn.LogSoftmax(dim=-1),
        )

    def forward(self, x):
        """Return (batch, 10) log-probabilities for an image batch ``x``.

        ``x`` is expected to be (batch, 3, 32, 32); other spatial sizes change
        the flattened feature count and break the first Linear layer.
        """
        stem = self.initConv(x)

        a1 = self.conv01(stem)
        a2 = self.conv02(a1)
        # Feed both conv outputs (not just the last) into the transition.
        branch0 = self.transition0(torch.cat([a1, a2], 1))

        b1 = self.conv11(stem)
        b2 = self.conv12(b1)
        branch1 = self.transition1(torch.cat([b1, b2], 1))

        merged = self.transition(torch.cat([branch0, branch1], 1))
        return self.fc(torch.flatten(merged, 1))
        

In [7]:
def train(model, train_loader, optimizer, device=None):
    """Run one training epoch, then print and return the mean per-sample loss.

    Args:
        model: module to optimize; switched to training mode here.
        train_loader: iterable of ``(data, target)`` batches.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        device: where batches are moved. Defaults to the device of the model's
            first parameter instead of relying on a notebook-global ``device``.

    Returns:
        Mean loss per sample over the epoch as a float (``nan`` if the loader
        yielded no samples).
    """
    if device is None:
        device = next(model.parameters()).device
    # CrossEntropyLoss applies log_softmax itself. For a model that already
    # ends in LogSoftmax this is redundant but harmless, because log_softmax is
    # idempotent. It also keeps this helper correct for logit-output models.
    loss_f = torch.nn.CrossEntropyLoss()
    model.train()
    loss_sum = torch.zeros((), device=device)
    n_samples = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_f(output, target)
        loss.backward()
        optimizer.step()
        # detach(): summing the raw loss would chain every batch's autograd
        # graph into one ever-growing history. Weighting by batch size keeps
        # the average exact when the final batch is short.
        loss_sum += loss.detach() * data.size(0)
        n_samples += data.size(0)
    avg_loss = loss_sum.item() / n_samples if n_samples else float('nan')
    print('loss', avg_loss)
    return avg_loss

In [8]:
def test(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += data.size(0)
            correct += (predicted == target).sum().item()
    print('accuracy', 100 * correct / total )

In [9]:
# Optimizer step size.
learning_rate = 0.001

# Instantiate the network and move its parameters/buffers to the chosen device.
model = Model().to(device)

# RMSprop over all of the network's parameters.
optimizer = torch.optim.RMSprop(params=model.parameters(), lr=learning_rate)

In [10]:
# One training pass followed by one evaluation pass per epoch.
num_epochs = 30
for epoch in range(num_epochs):
    # Label each loss/accuracy pair so the log can be read (or cut short) meaningfully.
    print(f'epoch {epoch + 1}/{num_epochs}')
    train(model, train_loader, optimizer)
    test(model, test_loader)

loss 1.48113076171875
accuracy 62.17
loss 1.008612109375
accuracy 71.2
loss 0.82045458984375
accuracy 74.88
loss 0.703790478515625
accuracy 77.18
loss 0.620142529296875
accuracy 77.73
loss 0.551052001953125
accuracy 78.33
loss 0.498493212890625
accuracy 80.52
loss 0.4496220703125
accuracy 82.19
loss 0.407627001953125
accuracy 82.81
loss 0.3702484130859375
accuracy 83.35
loss 0.339906591796875
accuracy 78.92
loss 0.3087789306640625
accuracy 82.92
loss 0.2881062744140625
accuracy 83.27
loss 0.26432919921875
accuracy 82.23
loss 0.246428759765625
accuracy 83.94
loss 0.229862158203125
accuracy 83.62
loss 0.211873876953125
accuracy 83.43
loss 0.20017645263671874
accuracy 83.94
loss 0.1875206298828125
accuracy 83.27
loss 0.1827477294921875
accuracy 84.31
loss 0.16637484130859376
accuracy 82.75
loss 0.1593916259765625
accuracy 83.89
loss 0.15486883544921876
accuracy 84.32
loss 0.14872449951171876
accuracy 83.88
loss 0.1409282470703125
accuracy 84.86


KeyboardInterrupt: 