In [6]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

In [8]:
# Prepare training and test data
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

	
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

training_data_loader= DataLoader(training_data, batch_size=64)
test_data_loader = DataLoader(test_data, batch_size=64)

# Define network
class FashionNetwork(nn.Module):
    def __init__(self):
        super(FashionNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )
    
    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
    

def train_loop( data_loader, model, loss_function, optimizer):
    for batch, (data, labels) in enumerate(data_loader):
        # Feed data through networks and compute loss
        prediction = model(data)
        loss = loss_function(prediction, labels)

        # Zero gradients.
        optimizer.zero_grad()
 
        # Perform backpropagation and accumulate gradients.
        loss.backward()
 
        # Update network parameters.
        optimizer.step()
 

def test_loop(data_loader, model, loss_function):
    n_samples = len(data_loader.dataset)
    n_batches = len(data_loader)
    loss, n_correct = 0, 0
 
    with torch.no_grad():
        for data, labels in data_loader:
            # Feed data through network and accumulate loss.
            prediction = model(data)
            loss += loss_function(
                prediction, labels
            ).item()
            n_correct += (
                (prediction.argmax(1) == labels)
                .type(torch.float)
                .sum()
                .item()
            )
 
    print(
        f"Test Accuracy: {n_correct / n_samples:.2%}, "
        f"Test Loss: {loss / n_batches:.4}"
    )
 
 
# Initialize network, loss function, and optimizer.
model = FashionNetwork()
loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-3
optimizer = torch.optim.SGD(
    model.parameters(), lr=learning_rate
)
 
# Train the network.
n_epochs = 10
for t in range(n_epochs):
    print(f"Epoch {t + 1:02}", end=" ", flush=True)
    train_loop(
        training_data_loader, model, loss_fn, optimizer
    )
    test_loop(test_data_loader, model, loss_fn)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:23<00:00, 1115718.63it/s]


Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 1231042.68it/s]

Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz





Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:04<00:00, 1037942.15it/s]


Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 2832888.61it/s]


Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Epoch 01 Test Accuracy: 40.89%, Test Loss: 2.157
Epoch 02 Test Accuracy: 53.02%, Test Loss: 1.896
Epoch 03 Test Accuracy: 61.36%, Test Loss: 1.539
Epoch 04 Test Accuracy: 64.11%, Test Loss: 1.268
Epoch 05 Test Accuracy: 65.22%, Test Loss: 1.096
Epoch 06 Test Accuracy: 66.08%, Test Loss: 0.9843
Epoch 07 Test Accuracy: 67.21%, Test Loss: 0.9086
Epoch 08 Test Accuracy: 68.27%, Test Loss: 0.8549
Epoch 09 Test Accuracy: 69.54%, Test Loss: 0.8149
Epoch 10 Test Accuracy: 70.99%, Test Loss: 0.7835
