In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

**Dataset**

In [2]:
# Fashion-MNIST training split. download=True fetches the files into ./data when they
# are not already present; ToTensor() converts each image to a float tensor.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

# Held-out test split, loaded with the same root directory and transform.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

# Wrap both splits in DataLoaders that yield mini-batches of 64 samples.
# NOTE(review): shuffle is left at its default (False), so every epoch iterates the
# training samples in the same order — confirm this is intended.
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)


**Custom Modules**

In [3]:
class c_Linear(nn.Module):
    """Hand-written fully connected layer computing ``x @ weights + bias``.

    Both parameters are sampled from a standard normal distribution and then
    scaled by ``sqrt(1 / in_features)``.

    Args:
        in_features: Number of rows of the weight matrix (size of the input's
            last dimension).
        out_features: Number of columns of the weight matrix (size of the
            output's last dimension).
    """

    def __init__(self, in_features, out_features):
        super().__init__()

        # Scale factor, kept as a 0-dim tensor computed entirely by torch.
        init_scale = torch.sqrt(1 / torch.tensor(in_features))

        # Sampling order (weights first, then bias) determines the values
        # obtained under a fixed random seed.
        weight_init = torch.randn((in_features, out_features)) * init_scale
        bias_init = torch.randn((1, out_features)) * init_scale

        self.weights = nn.Parameter(weight_init)
        self.bias = nn.Parameter(bias_init)

    def forward(self, x):
        """Return the affine map of ``x``; the (1, out_features) bias broadcasts over rows."""
        return x @ self.weights + self.bias

In [4]:
class c_ReLU(nn.Module):
    """Hand-written rectifier: element-wise maximum of the input and zero.

    The layer holds no parameters, so the inherited ``nn.Module`` constructor
    is sufficient.
    """

    def forward(self, x):
        """Return ``max(x, 0)`` element-wise, with the zero created on ``x``'s device."""
        zero = torch.tensor(0, device=x.device)
        return torch.maximum(x, zero)

In [5]:
class NeuralNetwork(nn.Module):
    """Multi-layer perceptron built from the custom ``c_Linear`` / ``c_ReLU`` layers.

    The input is flattened and passed through 784 -> 512 -> 512 -> 10 linear
    layers, with a ReLU after each of the first two. The final layer's output
    is returned as-is (raw logits, no softmax).
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()

        # Built in forward order so layer indices inside the Sequential
        # (and therefore the state_dict keys) are 0..4.
        layers = [
            c_Linear(28*28, 512),
            c_ReLU(),
            c_Linear(512, 512),
            c_ReLU(),
            c_Linear(512, 10),
        ]
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten ``x`` (all dims after the first) and return the stack's output logits."""
        return self.linear_relu_stack(self.flatten(x))

# Instantiate the network; the bare expression on the last line makes the notebook
# display the module tree.
model = NeuralNetwork()
model

NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): c_Linear()
    (1): c_ReLU()
    (2): c_Linear()
    (3): c_ReLU()
    (4): c_Linear()
  )
)

**Hyperparameters**

In [6]:
# Step size handed to the SGD optimizer.
learning_rate = 1e-3
# NOTE(review): the DataLoaders are built with a literal batch_size=64 rather than
# reading this value — keep the two in sync if either one changes.
batch_size = 64
# Number of train/test passes run by the training cell.
epochs = 5

**Loss Function**

In [7]:
# Cross-entropy loss, applied to the model's output logits and the dataset targets.
loss_fn = nn.CrossEntropyLoss()

**Optimizer**

In [8]:
# Plain stochastic gradient descent over every parameter of the model.
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

**Train Model**

In [9]:
def train_loop(dataloader, model, loss_fn, optimizer):
    """Run one training pass over ``dataloader`` and update ``model`` in place.

    For every mini-batch the loss is computed, back-propagated, and applied
    with ``optimizer``; gradients are cleared right after each step. Every
    100th batch (starting with the first) the current loss and the number of
    samples processed so far are printed.

    Args:
        dataloader: Iterable of ``(X, y)`` batches exposing a ``.dataset``
            with a length.
        model: Module to train; switched to training mode.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: Optimizer bound to ``model``'s parameters.

    Returns:
        None. Progress is reported via ``print``.
    """
    size = len(dataloader.dataset)
    # Set the model to training mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.train()

    # Running count of processed samples. Counting the batches actually seen keeps the
    # progress report correct for any DataLoader batch size and removes the dependency
    # on a notebook-level `batch_size` global.
    seen = 0
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        seen += len(X)
        if batch % 100 == 0:
            print(f"loss: {loss.item():>7f}  [{seen:>5d}/{size:>5d}]")

def test_loop(dataloader, model, loss_fn):
    # Set the model to evaluation mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    # Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode
    # also serves to reduce unnecessary gradient computations and memory usage for tensors with requires_grad=True
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [10]:
# Each epoch runs one full training pass followed by one evaluation pass on the test data.
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 2.283262  [   64/60000]
loss: 2.191972  [ 6464/60000]
loss: 2.038403  [12864/60000]
loss: 2.013173  [19264/60000]
loss: 1.895012  [25664/60000]
loss: 1.787168  [32064/60000]
loss: 1.755145  [38464/60000]
loss: 1.604513  [44864/60000]
loss: 1.579576  [51264/60000]
loss: 1.456357  [57664/60000]
Test Error: 
 Accuracy: 64.5%, Avg loss: 1.459977 

Epoch 2
-------------------------------
loss: 1.498345  [   64/60000]
loss: 1.475300  [ 6464/60000]
loss: 1.259487  [12864/60000]
loss: 1.362056  [19264/60000]
loss: 1.207773  [25664/60000]
loss: 1.188057  [32064/60000]
loss: 1.194746  [38464/60000]
loss: 1.106400  [44864/60000]
loss: 1.125696  [51264/60000]
loss: 1.032214  [57664/60000]
Test Error: 
 Accuracy: 67.8%, Avg loss: 1.057028 

Epoch 3
-------------------------------
loss: 1.090392  [   64/60000]
loss: 1.125175  [ 6464/60000]
loss: 0.897557  [12864/60000]
loss: 1.074420  [19264/60000]
loss: 0.939251  [25664/60000]
loss: 0.942777  [32064/600

In [11]:
# Inspect the learned parameters; keys follow the module hierarchy
# (e.g. 'linear_relu_stack.0.weights').
model.state_dict()

OrderedDict([('linear_relu_stack.0.weights',
              tensor([[-0.0818, -0.0226,  0.0455,  ..., -0.0118, -0.0375, -0.0901],
                      [ 0.0249, -0.0368,  0.0305,  ..., -0.0026,  0.0490, -0.0085],
                      [ 0.0013,  0.0122,  0.0409,  ...,  0.0526, -0.0342,  0.0186],
                      ...,
                      [-0.0565, -0.0290, -0.0123,  ..., -0.0072, -0.0196,  0.0396],
                      [ 0.0153, -0.0136,  0.0315,  ...,  0.0650,  0.0906, -0.0489],
                      [ 0.0537, -0.0432, -0.0133,  ..., -0.0268, -0.0254,  0.0125]])),
             ('linear_relu_stack.0.bias',
              tensor([[ 7.9588e-04, -1.1547e-02, -4.9352e-02,  4.5235e-03,  6.0220e-02,
                        7.7197e-03, -3.2454e-02, -2.7994e-02, -3.9177e-02, -1.0494e-02,
                        1.0827e-02,  2.3062e-02,  2.3528e-02,  1.2449e-02, -1.4877e-02,
                        4.0454e-03,  6.6616e-02,  2.4341e-02,  1.4857e-02, -1.7112e-02,
                        4.4

**Save Model Weights**

In [12]:
# Persist only the parameter tensors (the state_dict), not the model object itself.
torch.save(model.state_dict(), 'basic_nn_fashion_mnist.pth')

**Load Model**

In [13]:
# Restore the saved parameters into the existing model instance. weights_only=True
# restricts unpickling to tensors and other plain data types.
model.load_state_dict(torch.load('basic_nn_fashion_mnist.pth', weights_only=True))
# Switch to evaluation mode; as the last expression this also displays the module tree.
model.eval()

NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): c_Linear()
    (1): c_ReLU()
    (2): c_Linear()
    (3): c_ReLU()
    (4): c_Linear()
  )
)

**Save Architecture + Weights**

In [14]:
# Pickle the entire module object (architecture + weights). The file is tied to the
# class definitions: NeuralNetwork, c_Linear and c_ReLU must be defined/importable
# under the same names when it is loaded again.
torch.save(model, 'basic_nn_fashion_mnist_architecture.pth')

**Load Architecture + Weights**

In [15]:
# SECURITY: weights_only=False performs a full pickle load, which can execute arbitrary
# code embedded in the file — only load checkpoints you created or fully trust.
model = torch.load('basic_nn_fashion_mnist_architecture.pth', weights_only=False)
model

NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): c_Linear()
    (1): c_ReLU()
    (2): c_Linear()
    (3): c_ReLU()
    (4): c_Linear()
  )
)