In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# Build the train and test splits with identical settings (same root
# directory, download enabled, ToTensor transform); only the `train` flag
# differs. The training split is constructed first, then the test split.
training_data, test_data = (
    datasets.FashionMNIST(
        root='../data',
        train=is_train_split,
        download=True,
        transform=ToTensor(),
    )
    for is_train_split in (True, False)
)

In [None]:
# DataLoader wraps an iterable over our dataset, and supports automatic
# batching, sampling, shuffling and multiprocess data loading.
batch_size = 64
# Reshuffle the training set at every epoch so successive epochs do not
# replay the exact same sequence of mini-batches. The test loader is left
# in dataset order.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

# Peek at the first test batch only, to confirm tensor shapes and dtypes.
for X, y in test_dataloader:
    # N: Batch size, C: Channels, H: Height, W: Width
    print(f'Size of X[N,C,H,W]: {X.shape}, type: {X.dtype}')
    print(f'Size of y: {y.shape}, type: {y.dtype}')
    break

In [None]:
class NeuralNetwork(nn.Module):
    """Fully connected classifier: flatten, two hidden ReLU layers, linear head.

    All constructor arguments are optional; the defaults reproduce the
    original fixed architecture (784 -> 512 -> 512 -> 10, dropout 0.3).

    Args:
        in_features: Number of values per sample after flattening.
        hidden_features: Width of both hidden layers.
        num_classes: Number of output scores per sample.
        dropout: Drop probability applied after the first hidden activation.
    """

    def __init__(self, in_features=28 * 28, hidden_features=512,
                 num_classes=10, dropout=0.3):
        super().__init__()
        self.flatten = nn.Flatten()
        # The layer order is fixed so the state_dict keys
        # (linear_relu_stack.0 / .3 / .5) do not depend on the arguments.
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, num_classes),
        )

    def forward(self, x):
        """Flatten each sample and return its raw (unnormalised) class scores."""
        return self.linear_relu_stack(self.flatten(x))

# `mps_device` was referenced without ever being defined, so this cell
# raised NameError on a fresh kernel. Use the MPS backend when it is
# available and fall back to the CPU otherwise, so the notebook still runs
# end to end on machines without it.
mps_device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

model = NeuralNetwork().to(mps_device)
print(model)

In [None]:
# Classification objective applied to the model's output scores.
loss_fn = nn.CrossEntropyLoss()

# Adam over every parameter of `model`, with a fixed step size of 1e-3.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=1e-3,
)

def train(dataloader, model, loss_fn, optimizer):
    """Run one optimisation pass over ``dataloader``.

    Args:
        dataloader: Iterable of ``(X, y)`` batches that also exposes a
            ``dataset`` attribute; its length is used for progress output.
        model: Module to optimise. It is switched to training mode and must
            own at least one parameter (its device decides where batches go).
        loss_fn: Callable mapping ``(pred, y)`` to a scalar loss tensor.
        optimizer: Optimiser that updates ``model``'s parameters.

    Returns:
        None. On every 100th batch (starting with the first) the current
        loss and the fraction of the dataset processed so far are printed.
    """
    size = len(dataloader.dataset)
    # Follow the model's own device instead of relying on a notebook-level
    # global, so the function works wherever the model was placed.
    device = next(model.parameters()).device
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        loss = loss_fn(model(X), y)
        loss.backward()
        optimizer.step()
        # Clear gradients so they do not accumulate into the next batch.
        optimizer.zero_grad()
        if batch % 100 == 0:
            # Keep `loss` bound to the tensor; use a separate name for the float.
            loss_value = loss.item()
            # Uses the current batch's length, so the count is approximate
            # when the reported batch is smaller than the preceding ones.
            current = (batch + 1) * len(X)
            print(f'Loss: {loss_value:>.8f}, [{(current / size)}]')

In [None]:
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches =  len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(mps_device), y.to(mps_device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f'Test error: \n Accuracy {correct*100:>.1f}% ,avg loss: {test_loss:>.8f}')

In [None]:
# Number of full passes over the training set.
epochs = 10

# Each epoch: one optimisation pass, then an evaluation on the test loader.
for epoch in range(epochs):
    print(f'Epoch {epoch + 1}\n')
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)

print('Done!')