In [1]:
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

In [2]:
# Seed PyTorch's global RNG so weight init, shuffling and dropout are repeatable
# across "Restart & Run All" (GPU kernels may still add small nondeterminism).
seed = 0
torch.manual_seed(seed)

# Use the GPU when one is present, otherwise fall back to the CPU so the notebook
# still runs end-to-end on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Mini-batch size shared by the training and test data loaders.
batch_size = 10

In [3]:
# Training split of FashionMNIST, stored under 'data-fashion' (download=True
# fetches it when it is not already there); ToTensor turns each image into a
# tensor before batching.
train_dataset = datasets.FashionMNIST(
    root='data-fashion',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# Reshuffle every epoch so consecutive mini-batches differ between passes.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [4]:
# Held-out split of FashionMNIST, stored under 'data-fashion' (download=True
# fetches it when it is not already there); ToTensor turns each image into a
# tensor before batching.
# Evaluation sums over the whole split, so batch order cannot affect the result:
# shuffling is disabled to keep the pass deterministic and avoid needless work.
test_loader = torch.utils.data.DataLoader(datasets.FashionMNIST('data-fashion', train=False,
                                                          download=True,
                                                          transform=transforms.ToTensor()),
                                                       batch_size=batch_size, shuffle=False)

In [5]:
import torch.nn as nn
import torch.nn.functional as F

In [6]:
class Model(nn.Module):
    """Two-branch convolutional classifier that returns per-class log-probabilities.

    A shared stem feeds two parallel branches. Each branch stacks two 3x3
    conv blocks, concatenates both intermediate feature maps along the channel
    axis and compresses them with a 5x5 "transition" block. The two branch
    outputs are concatenated again, reduced by a final transition stack and
    classified by a small fully connected head.

    The head's input width (128 * 2 * 2) is tied to the spatial size produced by
    28x28 inputs (e.g. FashionMNIST); the shape comments below assume that size.

    Args:
        in_channels: number of channels of the input images (default 1).
        num_classes: number of output classes (default 10).
    """

    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()

        # Stem: (B, in_channels, 28, 28) -> (B, 64, 28, 28).
        self.initConv = nn.Sequential(
            *self._conv_bn_relu(in_channels, 64, kernel_size=3, padding=1, bias=False)
        )

        # Branch 0 and branch 1: two size-preserving 3x3 blocks each.
        # NOTE: these convs keep their bias even though BatchNorm follows; it is
        # redundant but retained so existing checkpoints (state_dict keys) still load.
        self.conv01 = nn.Sequential(*self._conv_bn_relu(64, 64, kernel_size=3, padding=1))
        self.conv02 = nn.Sequential(*self._conv_bn_relu(64, 64, kernel_size=3, padding=1))
        self.conv11 = nn.Sequential(*self._conv_bn_relu(64, 64, kernel_size=3, padding=1))
        self.conv12 = nn.Sequential(*self._conv_bn_relu(64, 64, kernel_size=3, padding=1))

        # Per-branch transition: (B, 128, 28, 28) -> conv5 -> (B, 64, 24, 24)
        # -> 2x2 average pool with stride 1 -> (B, 64, 23, 23).
        self.transition0 = nn.Sequential(
            *self._conv_bn_relu(128, 64, kernel_size=5),
            nn.AvgPool2d(kernel_size=2, stride=1),
        )
        self.transition1 = nn.Sequential(
            *self._conv_bn_relu(128, 64, kernel_size=5),
            nn.AvgPool2d(kernel_size=2, stride=1),
        )

        # Final transition: (B, 128, 23, 23) -> conv5 -> 19x19 -> pool -> 9x9
        # -> conv5 -> 5x5 -> pool -> (B, 128, 2, 2), followed by dropout.
        self.transition = nn.Sequential(
            *self._conv_bn_relu(128, 128, kernel_size=5),
            nn.AvgPool2d(kernel_size=2),
            *self._conv_bn_relu(128, 128, kernel_size=5),
            nn.AvgPool2d(kernel_size=2),
            nn.Dropout(p=0.2),
        )

        # Classifier head on the flattened (128 * 2 * 2 = 512) feature vector.
        # NOTE(review): there is no activation between the two Linear layers, so
        # together they act as a single linear map; left as-is to keep results
        # and checkpoints unchanged.
        self.fc = nn.Sequential(
            nn.Linear(128 * 2 * 2, 256),
            nn.Linear(256, num_classes),
            nn.LogSoftmax(dim=-1),
        )

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels, kernel_size, padding=0, bias=True):
        """Return the layers ``[Conv2d, BatchNorm2d, ReLU]`` for one stride-1 conv block."""
        return [
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=kernel_size, stride=1, padding=padding, bias=bias),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]

    def forward(self, x):
        """Map a batch of images ``(B, in_channels, 28, 28)`` to log-probabilities ``(B, num_classes)``."""
        stem = self.initConv(x)

        # Branch 0: keep both intermediate maps and fuse them channel-wise.
        conv01 = self.conv01(stem)
        conv02 = self.conv02(conv01)
        out0 = self.transition0(torch.cat([conv01, conv02], dim=1))

        # Branch 1: same layout with its own weights, fed from the same stem.
        conv11 = self.conv11(stem)
        conv12 = self.conv12(conv11)
        out1 = self.transition1(torch.cat([conv11, conv12], dim=1))

        # Merge the branches, reduce spatially, then classify.
        out = self.transition(torch.cat([out0, out1], dim=1))
        out = torch.flatten(out, start_dim=1)
        return self.fc(out)
        

In [7]:
def train(model, train_loader, optimizer):
    """Run one optimisation pass over ``train_loader`` and report the mean loss.

    Args:
        model: network to optimise; it is switched to training mode and each
            batch is moved to the device of its first parameter.
        train_loader: iterable yielding ``(data, target)`` batches.
        optimizer: optimiser already bound to ``model``'s parameters.

    Returns:
        The mean per-sample cross-entropy over the epoch as a float
        (``nan`` if the loader yielded no samples). The same value is printed
        as ``loss <value>``.
    """
    # CrossEntropyLoss applies log-softmax internally; on outputs that are already
    # log-probabilities that extra normalisation leaves them unchanged, so the
    # criterion works for models returning either logits or log-probabilities.
    loss_f = torch.nn.CrossEntropyLoss()
    model.train()
    model_device = next(model.parameters()).device

    running_loss = 0.0
    samples_seen = 0
    for data, target in train_loader:
        data, target = data.to(model_device), target.to(model_device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_f(output, target)
        loss.backward()
        optimizer.step()

        # Detach before accumulating so the running total does not keep every
        # iteration's autograd graph alive, and weight by the real batch size so
        # a smaller final batch is averaged correctly.
        n = data.size(0)
        running_loss = running_loss + loss.detach() * n
        samples_seen += n

    mean_loss = float(running_loss) / samples_seen if samples_seen else float('nan')
    print('loss', mean_loss)
    return mean_loss

In [8]:
def test(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += data.size(0)
            correct += (predicted == target).sum().item()
    print('accuracy', 100 * correct / total )

In [9]:
# Step size handed to the optimiser.
learning_rate = 0.001

# Instantiate the network, move its parameters to the selected device, and only
# then build the optimiser so it tracks the parameters on that device.
model = Model()
model = model.to(device)
optimizer = torch.optim.RMSprop(params=model.parameters(), lr=learning_rate)

In [10]:
# Number of full passes over the training set.
num_epochs = 30

for epoch in range(num_epochs):
    # One optimisation pass, then accuracy on the held-out split.
    train(model, train_loader, optimizer)
    test(model, test_loader)

loss 0.478175048828125
accuracy 87.7
loss 0.30797979736328124
accuracy 89.87
loss 0.26075779215494793
accuracy 90.85
loss 0.23486002604166667
accuracy 90.75
loss 0.21347941080729166
accuracy 91.43
loss 0.19481313069661457
accuracy 91.38
loss 0.17891276041666668
accuracy 92.11
loss 0.16549793497721355
accuracy 92.3
loss 0.15044680786132814
accuracy 92.65
loss 0.13979664103190104
accuracy 92.6
loss 0.12978335571289062
accuracy 92.63
loss 0.11841788736979167
accuracy 92.41
loss 0.10996072387695313
accuracy 92.88
loss 0.10048176066080729
accuracy 92.87
loss 0.09331232706705729
accuracy 93.0


KeyboardInterrupt: 