In [96]:
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import FashionMNIST

In [97]:
# Fashion-MNIST train and test splits. root='' is a relative path, so the
# files live under the current working directory. download=True fetches the
# data only when it is not already on disk, which lets the notebook run on a
# fresh machine instead of failing with "Dataset not found".
# ToTensor() turns each image into a tensor that the loaders can batch.
xy_train = FashionMNIST('', train=True, transform=transforms.ToTensor(), download=True)
xy_test = FashionMNIST('', train=False, transform=transforms.ToTensor(), download=True)

In [98]:
# Mini-batch iterators over the two splits.
# Training batches (60 samples) are reshuffled on every pass; evaluation
# batches (10 samples) keep the dataset order.
loadertr = torch.utils.data.DataLoader(
    dataset=xy_train,
    batch_size=60,
    shuffle=True,
)
loadertest = torch.utils.data.DataLoader(
    dataset=xy_test,
    batch_size=10,
    shuffle=False,
)

In [99]:
class Net(nn.Module):
    """Fully connected classifier: 784 -> 300 -> 100 -> 10.

    ``forward`` expects a 2-D input whose rows have 784 features and returns
    per-row log-probabilities over 10 classes.
    """

    def __init__(self):
        super().__init__()
        # The attribute names double as state_dict keys, so saved weights
        # depend on them staying fc1 / fc2 / fc3.
        self.fc1 = nn.Linear(784, 300)
        self.fc2 = nn.Linear(300, 100)
        self.fc3 = nn.Linear(100, 10)

    def forward(self, x):
        """Return log-softmax class scores for a batch of flattened inputs."""
        hidden_a = torch.relu(self.fc1(x))
        hidden_b = torch.relu(self.fc2(hidden_a))
        logits = self.fc3(hidden_b)
        # Normalise across the class axis (dim=1) so each row is a
        # log-probability distribution, as F.nll_loss expects.
        return logits.log_softmax(dim=1)

In [100]:
def train(model, device, train_loader, optimizer):
    """Run one optimisation pass over every batch in ``train_loader``.

    Args:
        model: Module that maps a ``(batch, 784)`` tensor to log-probabilities.
        device: Device each batch is moved to before the forward pass.
        train_loader: Iterable yielding ``(data, target)`` batches.
        optimizer: Optimizer holding ``model``'s parameters; stepped once per
            batch.

    Returns:
        None. ``model`` is updated in place.
    """
    model.train()
    for images, labels in train_loader:
        # Flatten every sample to a 784-element row before the linear layers.
        flat_images = images.reshape(images.shape[0], 784).to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        log_probs = model(flat_images)
        # The model emits log-probabilities, hence negative log-likelihood.
        F.nll_loss(log_probs, labels).backward()
        optimizer.step()

In [121]:
def test(model, device, test_loader, oldparams, oldloss):
    model.eval()
    test_loss = 0
    correct = 0
    correct_list = [0 for x in range(10)]
    with torch.no_grad():
        for data, target in test_loader:
            data, target = torch.reshape(data,(data.shape[0],784)).to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
            for t, p in zip(target.view(-1), pred.view(-1)):
                if t.data==p.data:
                    correct_list[t] += 1 
    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    print (correct_list)
    if test_loss < oldloss:
        return model.state_dict(), test_loss
    else:
        return oldparams, oldloss

In [122]:
# Training configuration.
epochs = 12           # number of train-then-evaluate passes
learningrate = 0.01   # SGD step size
seed = 0              # seed for torch's global random number generator
device = torch.device('cpu')

# Seed before the model is built so that weight initialisation (and anything
# else drawing from torch's global RNG) repeats on Restart & Run All.
torch.manual_seed(seed)

# Keep the model on the same device the batches are moved to in train()/test().
model = Net().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=learningrate, momentum=0.0, weight_decay=0)

# Best-so-far trackers handed to test(). state_dict() returns references to
# the live parameters, so take a deep copy to get a real snapshot.
oldparams = copy.deepcopy(model.state_dict())
oldloss = float('inf')

In [123]:
# One training pass followed by one evaluation per epoch; test() hands back
# the best (params, loss) pair seen so far, which is carried into the next epoch.
for epoch in range(epochs):
    train(model, device, loadertr, optimizer)
    oldparams, oldloss = test(model, device, loadertest, oldparams, oldloss)


Test set: Average loss: 0.8157, Accuracy: 6921/10000 (69%)

[779, 915, 578, 812, 653, 322, 133, 878, 922, 929]

Test set: Average loss: 0.6607, Accuracy: 7558/10000 (76%)

[832, 934, 761, 824, 550, 798, 133, 842, 935, 949]

Test set: Average loss: 0.5732, Accuracy: 7967/10000 (80%)

[825, 949, 694, 815, 814, 866, 280, 857, 925, 942]

Test set: Average loss: 0.5348, Accuracy: 8105/10000 (81%)

[798, 944, 736, 874, 784, 891, 321, 882, 944, 931]

Test set: Average loss: 0.5161, Accuracy: 8169/10000 (82%)

[858, 947, 668, 865, 833, 887, 351, 872, 936, 952]

Test set: Average loss: 0.4971, Accuracy: 8193/10000 (82%)

[858, 945, 633, 854, 666, 906, 563, 909, 934, 925]

Test set: Average loss: 0.4967, Accuracy: 8162/10000 (82%)

[670, 945, 670, 876, 636, 921, 691, 862, 938, 953]

Test set: Average loss: 0.4674, Accuracy: 8354/10000 (84%)

[767, 960, 763, 853, 755, 915, 547, 915, 943, 936]

Test set: Average loss: 0.4640, Accuracy: 8375/10000 (84%)

[855, 949, 793, 879, 726, 947, 455, 909, 93

In [104]:
# Display the lowest average test loss that test() has returned so far.
oldloss

0.4335938627243042

In [105]:
# Persist the model's current weights (its state dict) to disk.
# NOTE(review): despite the '.json' suffix the file is written by torch.save,
# not as JSON text; it must be read back with torch.load.
torch.save(obj=model.state_dict(), f='model_state.json')

In [106]:
# Rebuild the network and restore the weights written by the torch.save cell.
# The path must be the one that cell writes ('model_state.json'); the previous
# 'model_state.txt' is never created by this notebook.
model = Net().to(device)
# map_location keeps the load working when the file was saved from another device.
model.load_state_dict(torch.load('model_state.json', map_location=device))
model.eval()
# Re-evaluate the restored model; test() also updates the best-so-far trackers
# if this loss is lower than oldloss.
oldparams, oldloss = test(model, device, loadertest, oldparams, oldloss)


Test set: Average loss: 0.4622, Accuracy: 8365/10000 (84%)

