Задание

Постройте модель на основе полносвязных слоёв для классификации датасета Fashion-MNIST из библиотеки torchvision (модуль `torchvision.datasets`).
Добейтесь качества (accuracy) на тестовой выборке не ниже 88%.


In [None]:
import torch
import torchvision as tv
import time

In [None]:
def _load_fashion_mnist(is_train):
    """Return one Fashion-MNIST split stored under '.', downloading it if missing.

    Images are converted to tensors with `ToTensor`; `is_train` selects the
    training split (True) or the test split (False).
    """
    return tv.datasets.FashionMNIST(
        '.',
        train=is_train,
        transform=tv.transforms.ToTensor(),
        download=True,
    )


train_ds = _load_fashion_mnist(True)
test_ds = _load_fashion_mnist(False)

In [None]:
batch_size: int = 128  # samples per mini-batch, shared by the train and test DataLoaders

In [None]:
# Training batches are reshuffled every epoch so the optimizer does not see the
# samples in the same fixed order each pass.
train = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)
# The evaluation loader must wrap the *test* split: wrapping `train_ds` here would
# silently report training-set metrics as "test" metrics.
test = torch.utils.data.DataLoader(test_ds, batch_size=batch_size, shuffle=False)

In [None]:
def _build_mlp(in_features, hidden_sizes=(1024, 512, 256), n_classes=10, p_drop=0.2):
    """Build a fully connected classifier as a `torch.nn.Sequential`.

    Layout: Flatten, then for each width in `hidden_sizes` a
    Linear -> BatchNorm1d -> ReLU -> Dropout(p_drop) group, then a final Linear
    producing `n_classes` outputs (raw logits, no softmax).
    """
    layers = [torch.nn.Flatten()]
    width = in_features
    for next_width in hidden_sizes:
        layers += [
            torch.nn.Linear(width, next_width),
            torch.nn.BatchNorm1d(next_width),
            torch.nn.ReLU(),
            torch.nn.Dropout(p_drop),
        ]
        width = next_width
    layers.append(torch.nn.Linear(width, n_classes))
    return torch.nn.Sequential(*layers)


# Input width = product of the last two dims of one sample image
# (assumes a (C, H, W) tensor with C == 1; gives 784 for Fashion-MNIST).
_image_shape = train_ds[0][0].shape
model = _build_mlp(_image_shape[1] * _image_shape[2])

In [None]:
# Display the layer stack; the module repr is rendered as this cell's output.
model

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=1024, bias=True)
  (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): ReLU()
  (4): Dropout(p=0.2, inplace=False)
  (5): Linear(in_features=1024, out_features=512, bias=True)
  (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (7): ReLU()
  (8): Dropout(p=0.2, inplace=False)
  (9): Linear(in_features=512, out_features=256, bias=True)
  (10): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (11): ReLU()
  (12): Dropout(p=0.2, inplace=False)
  (13): Linear(in_features=256, out_features=10, bias=True)
)

In [None]:
# Cross-entropy over the model's raw logits (its last layer is a plain Linear).
loss = torch.nn.CrossEntropyLoss()
# Adam over every trainable parameter of `model`, learning rate 1e-3.
trainer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
def _ratio(numerator, denominator):
  """Return numerator / denominator, or nan when the denominator is 0 (empty loader)."""
  return numerator / denominator if denominator else float('nan')


def _train_one_epoch(model, trainer, loss, loader):
  """Run one optimisation pass over `loader`.

  Returns (mean per-batch loss, accuracy over all samples seen).
  """
  model.train()
  total_loss, n_correct, n_batches, n_seen = 0., 0, 0, 0
  for X, y in loader:
    trainer.zero_grad()
    y_pred = model(X)
    l = loss(y_pred, y)
    l.backward()
    trainer.step()
    total_loss += l.item()
    n_correct += (y_pred.argmax(dim=1) == y).sum().item()
    n_batches += 1
    n_seen += len(X)
  return _ratio(total_loss, n_batches), _ratio(n_correct, n_seen)


@torch.no_grad()  # evaluation needs no autograd graph; saves memory and time
def _evaluate(model, loss, loader):
  """Score `model` on `loader` in eval mode (dropout off, BatchNorm running stats).

  Returns (mean per-batch loss, accuracy over all samples seen).
  """
  model.eval()
  total_loss, n_correct, n_batches, n_seen = 0., 0, 0, 0
  for X, y in loader:
    y_pred = model(X)
    total_loss += loss(y_pred, y).item()
    n_correct += (y_pred.argmax(dim=1) == y).sum().item()
    n_batches += 1
    n_seen += len(X)
  return _ratio(total_loss, n_batches), _ratio(n_correct, n_seen)


def train_model(epochs, batch_size, model, trainer, loss, train_ds, test_ds):
  """Train `model` for `epochs` epochs, evaluating after each one, and print metrics.

  Args:
    epochs: number of passes over `train_ds`.
    batch_size: unused -- batching is done by the loaders themselves; kept so
      existing calls keep working.
    model: torch.nn.Module mapping a batch `X` to class scores (logits).
    trainer: optimizer over `model`'s parameters.
    loss: criterion called as `loss(y_pred, y)`, returning a scalar tensor.
    train_ds: iterable of `(X, y)` mini-batches used for training (e.g. a DataLoader).
    test_ds: iterable of `(X, y)` mini-batches used for evaluation.

  Prints one line per epoch: epoch index, wall time (training + evaluation),
  mean per-batch loss and accuracy for both splits (nan for an empty loader).
  Returns None; the model is left in eval mode.
  """
  for ep in range(epochs):
    start = time.perf_counter()
    train_loss, train_acc = _train_one_epoch(model, trainer, loss, train_ds)
    test_loss, test_acc = _evaluate(model, loss, test_ds)
    print(f"epoch: {ep}",
          f"time: {time.perf_counter() - start}",
          f"train_loss: {train_loss}",
          f"train_acc: {train_acc}",
          f"test_loss: {test_loss}",
          f"test_acc: {test_acc}")

In [None]:
# Ten epochs; `train` / `test` are the DataLoaders created earlier in the notebook.
train_model(
    epochs=10,
    batch_size=batch_size,
    model=model,
    trainer=trainer,
    loss=loss,
    train_ds=train,
    test_ds=test,
)

epoch: 0 time: 28.02057194709778 train_loss: 0.4610768539755583 train_acc: 0.8341333333333333 test_loss: 0.3723439749306453 test_acc: 0.8583166666666666
epoch: 1 time: 28.212783813476562 train_loss: 0.3542959317088381 train_acc: 0.8698833333333333 test_loss: 0.30947802160213245 test_acc: 0.8839666666666667
epoch: 2 time: 28.95626449584961 train_loss: 0.31672504296434967 train_acc: 0.8824166666666666 test_loss: 0.2778182773829015 test_acc: 0.8976666666666666
epoch: 3 time: 27.987969875335693 train_loss: 0.29038264677087383 train_acc: 0.8913666666666666 test_loss: 0.2642913027676438 test_acc: 0.9013
epoch: 4 time: 27.688534021377563 train_loss: 0.27230854742308414 train_acc: 0.8984166666666666 test_loss: 0.23891326820037004 test_acc: 0.91205
epoch: 5 time: 28.251788854599 train_loss: 0.25592262172368546 train_acc: 0.9047333333333333 test_loss: 0.22428731891964035 test_acc: 0.9170166666666667
epoch: 6 time: 29.283618927001953 train_loss: 0.23999310502492543 train_acc: 0.9092333333333333 t