In [41]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

from torch.utils.tensorboard import SummaryWriter
import copy

In [42]:
train_set = torchvision.datasets.FashionMNIST(root='./data/FashionMNIST', train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))
test_set = torchvision.datasets.FashionMNIST(root='./data/FashionMNIST', train=False, download=True, transform=transforms.Compose([transforms.ToTensor()]))

Dataset FashionMNIST
    Number of datapoints: 60000
    Root location: ./data/FashionMNIST
    Split: Train

In [45]:
train_loader = torch.utils.data.DataLoader(train_set, batch_size=10)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=1)

data_loader = {
    'train': train_loader,
    'val': test_loader
}

data_size = {
    'train': len(train_set),
    'val': len(test_set)
}

In [46]:
class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 =nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)
        self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)
        
    def forward(self, t):
        t = F.max_pool2d(F.relu(self.conv1(t)), kernel_size=2)
        t = F.max_pool2d(F.relu(self.conv2(t)), kernel_size=2)
        t = t.reshape(-1, 12*4*4)
        t = F.relu(self.fc1(t))
        t = F.relu(self.fc2(t))
        t = self.out(t)
        #t = F.softmax(t, dim=1) # Loss function we are going to use performs the softmax function
        
        return t

In [51]:
network = Network()
lr = 0.01
epochs = 10
optimizer = optim.Adam(network.parameters(), lr=lr)
best_acc = 0.0
best_model_weights = 0.0

In [52]:
tb = SummaryWriter(comment="test_network")

In [None]:
for epoch in range(epochs):
    for phase in ['train', 'val']:
        if phase == 'train':
            network.train()
        else:
            network.eval()
        
        running_loss = 0
        num_correct = 0
        
        for batch in data_loader[phase]:
            images, labels = batch
            
            with torch.set_grad_enabled(phase == 'train'):
                predictions = network(images)
                loss = F.cross_entropy(predictions, labels)
                optimizer.zero_grad()
                
                if phase == 'train':
                    loss.backward()
                    optimizer.step()
                
            running_loss += loss.item()
            num_correct += predictions.argmax(dim=1).eq(labels).sum().item()
        
        accuracy = (num_correct / data_size[phase]) * 100
        tb.add_scalar(phase + " loss", running_loss, epoch)
        tb.add_scalar(phase + " acc", accuracy, epoch)
        
        if phase == 'val' and accuracy > best_acc:
            best_acc = accuracy
            best_model_weights = copy.deepcopy(network.state_dict())
        
        print("Epoch:", epoch, phase, "loss:", running_loss, "Number correct:", num_correct, "Accuracy:", accuracy, "%")

trained_model = network.load_state_dict(best_model_weights)

Epoch: 0 train loss: 3909.393696174724 Number correct: 45180 Accuracy: 75.3 %
Epoch: 0 val loss: 5424.159754067659 Number correct: 8050 Accuracy: 80.5 %
Epoch: 1 train loss: 3200.9323345459998 Number correct: 48580 Accuracy: 80.96666666666667 %
Epoch: 1 val loss: 5160.5871496498585 Number correct: 8198 Accuracy: 81.98 %
Epoch: 2 train loss: 3149.088037339039 Number correct: 48935 Accuracy: 81.55833333333334 %
Epoch: 2 val loss: 5289.056150257587 Number correct: 8083 Accuracy: 80.83 %
Epoch: 3 train loss: 3121.4850941256154 Number correct: 49011 Accuracy: 81.685 %
Epoch: 3 val loss: 5959.346417343244 Number correct: 8063 Accuracy: 80.63 %
Epoch: 4 train loss: 3116.6658673509955 Number correct: 48914 Accuracy: 81.52333333333334 %
Epoch: 4 val loss: 5410.351660350338 Number correct: 8241 Accuracy: 82.41000000000001 %
