In [1]:
import os
import torch
import numpy as np
import math
from torchvision import datasets, transforms
from torch import nn
from torch.utils.data import DataLoader

In [2]:
# get data
# download=True fetches FashionMNIST into root when the files are missing and
# reuses them when they are already there, so this cell also works on a machine
# that has never downloaded the dataset.
getTestDataset = datasets.FashionMNIST(root="./", train=False, download=True, transform=transforms.ToTensor())
getTrainDataset = datasets.FashionMNIST(root="./", train=True, download=True, transform=transforms.ToTensor())

# load data
# accuracy does not depend on the order of the samples, so the test split is
# iterated in its stored order; only the training split is reshuffled each epoch
loadTestDataset = DataLoader(dataset=getTestDataset, batch_size=32, shuffle=False)
loadTrainDataset = DataLoader(dataset=getTrainDataset, batch_size=32, shuffle=True)

In [6]:
# define neural network
class NNthing(nn.Module):
    """Fully connected classifier: 28*28 inputs -> 512 -> 512 -> 10 output scores."""

    def __init__(self):
        """Create the flatten step and the stack of linear / ReLU layers."""
        super().__init__()
        in_features, hidden, n_classes = 28 * 28, 512, 10

        # collapses everything after the batch dimension into a single dimension
        self.flatten = nn.Flatten()

        layers = [
            nn.Linear(in_features, hidden),  # flattened input -> hidden nodes
            nn.ReLU(),                       # activation function
            nn.Linear(hidden, hidden),       # hidden nodes -> hidden nodes
            nn.ReLU(),
            nn.Linear(hidden, n_classes),    # hidden nodes -> 10 categories
        ]
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten x per sample and return the 10 output scores for each sample.

        No softmax is applied here; the returned values are the raw outputs of
        the last linear layer.
        """
        flat = self.flatten(x)
        scores = self.linear_relu_stack(flat)
        return scores

In [9]:
# seed torch's global RNG before the model is built so the initial weights
# (and anything drawn from the global RNG afterwards) repeat on every
# Restart & Run All
torch.manual_seed(0)
# load model to cpu
nnthing = NNthing().to("cpu")
# define loss
crossE = nn.CrossEntropyLoss()
# define optimizer
opti = torch.optim.Adam(nnthing.parameters(), lr=0.01)

In [10]:
# training
# set training mode explicitly instead of relying on whatever mode the model
# was left in by a previously executed cell
nnthing.train()
for epoch in range(3):
    # running totals used only to report the mean loss of this epoch
    loss_sum, seen = 0.0, 0
    for images, labels in loadTrainDataset:
        images = images.to("cpu")
        labels = labels.to("cpu")
        # clear gradients left over from the previous step
        opti.zero_grad()
        pred = nnthing(images)
        loss = crossE(pred, labels)
        loss.backward()
        opti.step()
        # crossE averages over the batch, so weight by batch size to get a
        # per-sample mean over the whole epoch (the last batch may be smaller)
        loss_sum += loss.item() * labels.size(0)
        seen += labels.size(0)
    # max(..., 1) keeps an empty loader from dividing by zero
    print(f"epoch {epoch + 1}: mean train loss {loss_sum / max(seen, 1):.4f}")

In [11]:
# test
# switch to inference mode so any mode-dependent layers behave deterministically
nnthing.eval()
count, success = 0, 0
# no gradients are needed for evaluation
with torch.no_grad():
    for images, labels in loadTestDataset:
        # keep the batch on the same device as the model
        images = images.to("cpu")
        labels = labels.to("cpu")
        pred = nnthing(images)
        # index of the largest score per sample is the predicted category
        predicted = pred.argmax(dim=1)
        count += labels.size(0)
        success += (predicted == labels).sum().item()

print(f"Accuracy: {100*success/count:.1f}%")


Accuracy: 83.3%
