In [1]:
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

In [2]:
device = torch.device('cuda')
batch_size = 100

In [3]:
train_loader = torch.utils.data.DataLoader(datasets.FashionMNIST('data-fashion', train=True, 
                                                          download=True, 
                                                          transform=transforms.ToTensor()), 
                                                       batch_size=batch_size, shuffle=True)

In [4]:
test_loader = torch.utils.data.DataLoader(datasets.FashionMNIST('data-fashion', train=False, 
                                                          download=True, 
                                                          transform=transforms.ToTensor()), 
                                                       batch_size=batch_size, shuffle=True)

In [5]:
import torch.nn as nn
import torch.nn.functional as F

In [6]:
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        # Hout = 1 + (Hin+2×padding[0]−dilation[0]×(kernel_size[0]−1)−1)/stride[0]
        self.conv = nn.Sequential(
            # shape = (batch_size, 1, 28, 28)
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            # shape = (batch_size, 16, 24, 24)
            nn.AvgPool2d(kernel_size=2),
            # shape = (batch_size, 16, 12, 12)
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            # shape = (batch_size, 32, 8, 8),
            nn.AvgPool2d(kernel_size=2),
            # shape = (batch_size, 32, 4, 4)
            nn.Dropout(p=0.2)
        )
        self.fc = nn.Sequential(
            # (32, 4, 4) -> (512) -> (10)
            nn.Linear(32* 4 * 4, 256),
            nn.Linear(256, 10),
            nn.LogSoftmax(dim=-1)
        )
        
    def forward(self, x):

        out = self.conv(x)
        
        out = out.view(out.size(0), -1)
        out = self.fc(out)

        return out
        

In [7]:
def train(model, train_loader, optimizer):
    loss_f = torch.nn.CrossEntropyLoss()
    model.train()
    tot_loss = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_f(output, target)
        loss.backward()
        optimizer.step()
        tot_loss += loss
    print('loss', batch_size * tot_loss.item() / len(train_loader.dataset))

In [8]:
def test(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += data.size(0)
            correct += (predicted == target).sum().item()
    print('accuracy', 100 * correct / total )

In [9]:
model = Model().to(device)

learning_rate = 0.001
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)

In [10]:
for epoch in range(30):
    train(model, train_loader, optimizer)
    test(model, test_loader)

loss 0.48859232584635415
accuracy 85.51
loss 0.36134404500325523
accuracy 86.52
loss 0.3276580556233724
accuracy 88.91
loss 0.3037409210205078
accuracy 85.26
loss 0.2882270304361979
accuracy 89.12
loss 0.2768098704020182
accuracy 89.41
loss 0.2681060791015625
accuracy 89.64
loss 0.2606243133544922
accuracy 89.39
loss 0.25417762756347656
accuracy 88.47
loss 0.24833569844563802
accuracy 90.3
loss 0.24294230143229167
accuracy 89.77
loss 0.23860209147135417
accuracy 88.33
loss 0.23364082336425782
accuracy 90.08
loss 0.22904820760091146
accuracy 90.49
loss 0.22564737955729167
accuracy 90.41
loss 0.22026178995768228
accuracy 90.17
loss 0.21675916035970053
accuracy 90.84
loss 0.21449905395507812
accuracy 89.03
loss 0.21254071553548176
accuracy 90.08
loss 0.20710435231526692
accuracy 90.1
loss 0.20669217427571615
accuracy 90.53
loss 0.2010430908203125
accuracy 90.84
loss 0.20242870330810547
accuracy 91.01
loss 0.2007204818725586
accuracy 90.91
loss 0.1983448537190755
accuracy 89.97
loss 0.1946