## Device

In [None]:
import torch

# Run on the default CUDA GPU when PyTorch reports one as available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Model architecture

### CNN encoder

In [None]:
import torch
from torch import nn


class Encoder(nn.Module):
    def __init__(self, in_channels: int, hidden_channels: int):
        """
        CNN feature extractor built from two conv-conv-pool stages.

        Every 3x3 convolution uses stride 1 and padding 1, so it keeps the spatial size.
        Each 2x2 max-pool halves it, which turns a 28*28 image into a
        hidden_channels x 7 x 7 feature map.
        :param in_channels: number of colour channels in the input image.
        :param hidden_channels: number of feature channels produced by every convolution.
        """
        super().__init__()

        # Despite the name this is a plain ReLU (not LeakyReLU). The single stateless
        # instance is shared by every activation slot in both stages.
        self.lrelu = nn.ReLU()

        self.block1 = self._conv_stage(in_channels, hidden_channels)
        self.block_2 = self._conv_stage(hidden_channels, hidden_channels)

    def _conv_stage(self, first_in: int, width: int) -> nn.Sequential:
        """Build conv3x3 -> act -> conv3x3 -> act -> maxpool2x2 (halves height and width)."""
        def conv(c_in: int) -> nn.Conv2d:
            return nn.Conv2d(in_channels=c_in,
                             out_channels=width,
                             kernel_size=3,
                             stride=1,
                             padding=1)

        # Arguments are evaluated left to right, so the two convolutions are created
        # (and randomly initialised) in the same order as the layers appear.
        return nn.Sequential(
            conv(first_in),
            self.lrelu,
            conv(width),
            self.lrelu,
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    def forward(self, x):
        for stage in (self.block1, self.block_2):
            x = stage(x)
        return x


### Linear decoder

In [None]:
import torch
from torch import nn


class Decoder(nn.Module):
    def __init__(self, output_classes: int, hidden_channels: int, spatial_size: int = 7):
        """
        Decodes (classifies) the encoded feature map into one of output_classes classes.

        The feature map is flattened and fed to a single linear layer, which produces
        one logit per class.
        :param output_classes: how many classes the image could have.
        :param hidden_channels: how many channels the encoded feature map has.
        :param spatial_size: side length of the square encoded feature map. The default
            of 7 corresponds to a 28*28 image halved twice by 2x2 pooling. Pass a
            different value when the encoder produces another output size.
        """
        super().__init__()

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=hidden_channels * spatial_size * spatial_size,
                      out_features=output_classes),
        )

    def forward(self, x):
        x = self.classifier(x)
        return x

### Combined model

In [None]:
import torch
from torch import nn


class Model(nn.Module):
    def __init__(self, in_channels: int, output_classes: int, hidden_channels: int):
        """
        Image classifier made of a CNN Encoder followed by a linear Decoder.
        :param in_channels: number of colour channels in the input image.
        :param output_classes: number of classes the Decoder scores.
        :param hidden_channels: channel width shared by the Encoder output and Decoder input.
        """
        super().__init__()
        self.encoder = Encoder(in_channels=in_channels, hidden_channels=hidden_channels)
        # Device placement is left to the caller (e.g. `Model(...).to(device)`). Moving only
        # the decoder here left the encoder on the CPU and split the model across devices.
        self.decoder = Decoder(hidden_channels=hidden_channels, output_classes=output_classes)

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x


# One input channel, ten output classes, ten hidden channels; then move every
# parameter onto the selected device (nn.Module.to returns the module itself).
model = Model(in_channels=1, output_classes=10, hidden_channels=10)
model = model.to(device=device)

## Training data

In [None]:
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

def _load_mnist(train: bool):
    """Return the MNIST train (train=True) or test split stored under ./mnist_data, downloading it if needed."""
    return datasets.MNIST(root='mnist_data',
                          train=train,
                          download=True,
                          transform=ToTensor(),
                          target_transform=None)


train_data = _load_mnist(train=True)
test_data = _load_mnist(train=False)

BATCH_SIZE = 32

# Only the training split is shuffled; the test split keeps a fixed order.
train_dataloader, test_dataloader = (
    DataLoader(dataset=split, batch_size=BATCH_SIZE, shuffle=is_train)
    for split, is_train in ((train_data, True), (test_data, False))
)


## Training loop

In [None]:
import torch
from torch import nn
import matplotlib.pyplot as plt
from timeit import default_timer as timer
import numpy as np

EPOCHS = 2
LEARNING_RATE = 0.01
# Gradient accumulation factor: the optimizer steps once every N training batches.
UPDATE_PARAMETERS_EVERY = 1
# Print a training progress line every N batches (and on the last batch of each epoch).
LOG_EVERY = 100

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

train_loss_history = []       # one entry per training batch
test_loss_history = []        # one entry per test batch, all epochs concatenated
test_epoch_loss_history = []  # one entry per epoch: sample-weighted mean test loss
train_acc_history = []        # one entry per epoch
test_acc_history = []         # one entry per epoch
current_batch = 0             # training batches processed across all epochs

steps_per_epoch = len(train_dataloader)

for epoch in range(EPOCHS):
    epoch_start = timer()
    model.train()
    optimizer.zero_grad()
    train_correct = 0
    train_seen = 0

    for train_batch_index, (x_train, y_train) in enumerate(train_dataloader):
        x_train, y_train = x_train.to(device=device), y_train.to(device=device)
        train_outputs = model(x_train)
        loss = loss_fn(train_outputs, y_train)
        loss_value = loss.item()
        train_loss_history.append(loss_value)
        train_correct += (train_outputs.argmax(dim=1) == y_train).sum().item()
        train_seen += y_train.size(0)
        current_batch += 1

        # Scale so gradients accumulated over N batches average rather than sum.
        # With N == 1 this is an ordinary per-batch update.
        (loss / UPDATE_PARAMETERS_EVERY).backward()
        is_last_batch = train_batch_index + 1 == steps_per_epoch
        # Also step on the last batch so a partial accumulation group is not dropped.
        if (train_batch_index + 1) % UPDATE_PARAMETERS_EVERY == 0 or is_last_batch:
            optimizer.step()
            optimizer.zero_grad()

        if (train_batch_index + 1) % LOG_EVERY == 0 or is_last_batch:
            print(f"E {epoch + 1:,}/{EPOCHS:,} | Batch {train_batch_index + 1:,}/{steps_per_epoch:,} | Loss: {loss_value:,}")

    train_acc_history.append(train_correct / max(train_seen, 1))

    model.eval()
    test_loss_sum = 0.0
    test_correct = 0
    test_seen = 0
    with torch.inference_mode():
        for x_test, y_test in test_dataloader:
            x_test, y_test = x_test.to(device=device), y_test.to(device=device)
            test_outputs = model(x_test)
            test_loss_value = loss_fn(test_outputs, y_test).item()
            test_loss_history.append(test_loss_value)
            batch_size = y_test.size(0)
            # The loss is a per-batch mean; weight it by batch size so a smaller
            # final batch does not skew the epoch average.
            test_loss_sum += test_loss_value * batch_size
            test_correct += (test_outputs.argmax(dim=1) == y_test).sum().item()
            test_seen += batch_size

    test_epoch_loss_history.append(test_loss_sum / max(test_seen, 1))
    test_acc_history.append(test_correct / max(test_seen, 1))
    print(f"E {epoch + 1:,}/{EPOCHS:,} | Train acc: {train_acc_history[-1]:.2%} | "
          f"Test loss: {test_epoch_loss_history[-1]:.4f} | Test acc: {test_acc_history[-1]:.2%} | "
          f"{timer() - epoch_start:.1f}s")

# Express training steps in (fractional) epochs so the x-axis really is "Epoch".
# Each test point sits at the end of the epoch it was measured after.
train_x = (np.arange(len(train_loss_history)) + 1) / max(steps_per_epoch, 1)
test_x = np.arange(1, len(test_epoch_loss_history) + 1)

plt.figure(figsize=(8, 5))
plt.plot(train_x, train_loss_history, label='Train Loss (per batch)', linewidth=0.8)
plt.plot(test_x, test_epoch_loss_history, label='Test Loss (epoch mean)', marker='s')

plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Train vs Test Loss")
plt.legend()
plt.grid(True)
plt.show()