# Классификация на FashionMNIST

In [1]:
import torch
import numpy as np
import torchvision as tv
import time
import torch.nn as nn

In [2]:
BATCH_SIZE=256

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [4]:
train_dataset = tv.datasets.FashionMNIST('.', train=True, transform=tv.transforms.ToTensor(), download=True)
test_dataset = tv.datasets.FashionMNIST('.', train=False, transform=tv.transforms.ToTensor(), download=True)
train = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE)
test = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE)
train_dataset.classes

['T-shirt/top',
 'Trouser',
 'Pullover',
 'Dress',
 'Coat',
 'Sandal',
 'Shirt',
 'Sneaker',
 'Bag',
 'Ankle boot']

In [5]:
train_dataset[0][0].shape

torch.Size([1, 28, 28])

In [6]:
model = nn.Sequential(nn.BatchNorm2d(1),
                      nn.ReLU(),
                      nn.Flatten(),
                      nn.Linear(in_features=28*28, out_features=2304),
                      nn.BatchNorm1d(2304), nn.ReLU(),
                      nn.Linear(in_features=2304, out_features=120),
                      nn.BatchNorm1d(120),nn.Dropout(), nn.ReLU(),
                      nn.Linear(in_features=120, out_features=84),
                      nn.BatchNorm1d(84), nn.Dropout(.7), 
                      nn.Linear(in_features=84, out_features=10)
                      )
model = model.to(device)

In [7]:
loss_f = nn.CrossEntropyLoss()
trainer_1 = torch.optim.AdamW(model.parameters(), lr=.01)
trainer_2 = torch.optim.SGD(model.parameters(), lr=.001)
num_epochs = 15

In [8]:
def training(X, y, model, dict, trainer=False):
    X, y = X.to(device), y.to(device)
    if trainer:
        trainer.zero_grad()
    predictions = model(X)
    loss = loss_f(predictions, y)
    loss.backward()
    if trainer:
        trainer.step()
    dict = {'loss':(dict['loss']+loss.item()), 
            'tp':(dict['tp']+(predictions.argmax(dim=1) == y).sum().item()), 
            'iters':(dict['iters']+1), 
            'len':(dict['len']+len(X))}
    return dict    

In [9]:
print('|{: ^8}|{: ^9}|{: ^17}|{: ^16}|{: ^11}|{: ^16}|{: ^15}|{: ^10}|'\
      .format('Epochs','Time','Train Adam loss','Train SGD loss',
              'Test loss','Train Adam acc','Train SGD acc','Test acc'))
for epoch in range(num_epochs):
    start=time.time()
    model.train()
    first_train = {'loss':0, 'tp':0, 'iters':0, 'len':0}
    for X,y in train:
        first_train = training(X, y, model, first_train, trainer_1)
        
    second_train = {'loss':0, 'tp':0, 'iters':0, 'len':0}
    for X,y in train:
        second_train = training(X, y, model, second_train, trainer_2)
    
    model.eval()
    last_train = {'loss':0, 'tp':0, 'iters':0, 'len':0}
    for X,y in test:
        last_train = training(X, y, model, last_train, )
        
    print('|{: ^8}|{: ^9.4f}|{: ^17.4f}|{: ^16.4f}|{: ^11.4f}|{: ^16.4%}|{: ^15.4%}|{: ^10.4%}|'\
          .format(epoch+1, time.time()-start,  
                  first_train['loss']/first_train['iters'],
                  second_train['loss']/second_train['iters'], 
                  last_train['loss']/last_train['iters'],
                  first_train['tp']/first_train['len'],
                  second_train['tp']/second_train['len'],
                  last_train['tp']/last_train['len']))

| Epochs |  Time   | Train Adam loss | Train SGD loss | Test loss | Train Adam acc | Train SGD acc | Test acc |
|   1    | 17.0505 |     0.6584      |     0.4846     |  0.4220   |    76.8733%    |   83.7017%    | 84.9400% |
|   2    | 18.3635 |     0.4667      |     0.4209     |  0.3863   |    84.2917%    |   85.8450%    | 86.2900% |
|   3    | 18.5925 |     0.4126      |     0.3755     |  0.3587   |    86.0683%    |   86.8783%    | 87.1600% |
|   4    | 18.9740 |     0.3809      |     0.3461     |  0.3435   |    87.0617%    |   88.0617%    | 87.5500% |
|   5    | 18.2400 |     0.3581      |     0.3345     |  0.3416   |    87.7283%    |   88.4717%    | 87.5400% |
|   6    | 18.6515 |     0.3364      |     0.3107     |  0.3378   |    88.4717%    |   89.3617%    | 88.3100% |
|   7    | 18.6004 |     0.3202      |     0.3029     |  0.3292   |    89.0883%    |   89.4933%    | 88.3900% |
|   8    | 18.3830 |     0.3101      |     0.2969     |  0.3407   |    89.4150%    |   89.6967%    | 88.