In [156]:
from copy import deepcopy

import torch
import torch.nn.functional as F

from torch import nn
from tensorflow.keras.datasets import fashion_mnist

from torch.utils.data import DataLoader, TensorDataset

import numpy as np

In [50]:
class FMnistNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 10)
        
    def forward(self, x):
        x = x.reshape(-1, 28*28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.softmax(self.fc3(x), dim=1)
        return x

In [177]:
def acc(net_output, labels):
    predicted = net_output.argmax(dim=1)
    correct = (predicted == labels).sum()
    examples = len(labels)
    return (correct / examples).item()

In [134]:
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
x_train_tensor = torch.tensor(deepcopy(x_train), dtype=torch.float32)
x_train_tensor /= 255.
x_test_tensor = torch.tensor(deepcopy(x_test), dtype=torch.float32)
x_test_tensor /= 255.
y_train_tensor = torch.tensor(deepcopy(y_train), dtype=torch.long)
y_test_tensor = torch.tensor(deepcopy(y_test), dtype=torch.long)

In [180]:
train_set = DataLoader(TensorDataset(x_train_tensor, y_train_tensor), shuffle=False, batch_size=32)
test_set = DataLoader(TensorDataset(x_test_tensor, y_test_tensor), batch_size=len(x_test_tensor))

In [185]:
next(iter(test_set))[0].shape, next(iter(test_set))[1].shape, 

(torch.Size([10000, 28, 28]), torch.Size([10000]))

In [188]:
net = FMnistNet()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
epochs = 100

for epoch in range(epochs):
    accum_loss = []
    accum_acc = []
    
    for data, labels in train_set:
        optimizer.zero_grad()
        net_output = net(data)
        loss = loss_fn(net_output, labels)
        loss.backward()
        optimizer.step()
        
        accum_loss.append(loss.item())
        accum_acc.append(acc(net_output, labels))
        
    train_avg_loss = np.mean(accum_loss)
    train_avg_acc = np.mean(accum_acc)
    
    with torch.no_grad():
        test_data, test_labels = next(iter(test_set))
        test_output = net(test_data)
        test_loss = loss_fn(test_output, test_labels)
        test_acc = acc(test_output, test_labels)
        
        
    print(f"EPOCH {epoch+1}/{epochs} -- train_loss: {train_avg_loss:.4f}, test_loss: {test_loss:.4f}, train_acc: {train_avg_acc:.4f},  test_acc: {test_acc:.4f}")

EPOCH 0/100 -- train_loss: 2.2838, test_loss: 2.2196, train_acc: 0.2726,  test_acc: 0.3457
EPOCH 1/100 -- train_loss: 2.0394, test_loss: 1.8939, train_acc: 0.4864,  test_acc: 0.6451
EPOCH 2/100 -- train_loss: 1.8346, test_loss: 1.8099, train_acc: 0.6693,  test_acc: 0.6743
EPOCH 3/100 -- train_loss: 1.7853, test_loss: 1.7729, train_acc: 0.6921,  test_acc: 0.7129
EPOCH 4/100 -- train_loss: 1.7471, test_loss: 1.7360, train_acc: 0.7457,  test_acc: 0.7603
EPOCH 5/100 -- train_loss: 1.7160, test_loss: 1.7120, train_acc: 0.7724,  test_acc: 0.7719
EPOCH 6/100 -- train_loss: 1.6979, test_loss: 1.6982, train_acc: 0.7815,  test_acc: 0.7791
EPOCH 7/100 -- train_loss: 1.6871, test_loss: 1.6896, train_acc: 0.7876,  test_acc: 0.7842
EPOCH 8/100 -- train_loss: 1.6800, test_loss: 1.6836, train_acc: 0.7922,  test_acc: 0.7878
EPOCH 9/100 -- train_loss: 1.6748, test_loss: 1.6792, train_acc: 0.7961,  test_acc: 0.7906
EPOCH 10/100 -- train_loss: 1.6708, test_loss: 1.6759, train_acc: 0.7988,  test_acc: 0.793

EPOCH 90/100 -- train_loss: 1.6222, test_loss: 1.6422, train_acc: 0.8427,  test_acc: 0.8191
EPOCH 91/100 -- train_loss: 1.6220, test_loss: 1.6422, train_acc: 0.8428,  test_acc: 0.8191
EPOCH 92/100 -- train_loss: 1.6218, test_loss: 1.6421, train_acc: 0.8430,  test_acc: 0.8195
EPOCH 93/100 -- train_loss: 1.6215, test_loss: 1.6420, train_acc: 0.8432,  test_acc: 0.8194
EPOCH 94/100 -- train_loss: 1.6213, test_loss: 1.6418, train_acc: 0.8435,  test_acc: 0.8196
EPOCH 95/100 -- train_loss: 1.6211, test_loss: 1.6417, train_acc: 0.8437,  test_acc: 0.8196
EPOCH 96/100 -- train_loss: 1.6209, test_loss: 1.6417, train_acc: 0.8439,  test_acc: 0.8197
EPOCH 97/100 -- train_loss: 1.6206, test_loss: 1.6416, train_acc: 0.8442,  test_acc: 0.8199
EPOCH 98/100 -- train_loss: 1.6204, test_loss: 1.6416, train_acc: 0.8444,  test_acc: 0.8195
EPOCH 99/100 -- train_loss: 1.6202, test_loss: 1.6416, train_acc: 0.8446,  test_acc: 0.8192
