# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Number of output logits produced by Model (the size of fc3's output).
NUM_CLASSES: int = 10


class ConvBlock(nn.Module):
    """Conv2d (valid padding) -> BatchNorm2d -> ReLU, optionally followed by 2x2 max pooling.

    Args:
        in_channels: number of channels in the incoming feature map.
        out_channels: number of channels produced by the convolution.
        kernel_size: convolution kernel size.
        pool: when truthy, halve the spatial size with MaxPool2d(2, stride=2);
            otherwise ``self.pool`` is None and no pooling is applied.
    """

    def __init__(self, in_channels, out_channels, kernel_size, pool=True):
        super().__init__()

        # padding=0 mirrors Burn's PaddingConfig2d::Valid: with the default
        # stride/dilation of 1, each spatial side shrinks by kernel_size - 1.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=0, bias=True)
        self.norm = nn.BatchNorm2d(num_features=out_channels)
        self.activation = nn.ReLU(inplace=False)

        if pool:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        else:
            self.pool = None

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU, then pool if pooling is enabled."""
        features = self.activation(self.norm(self.conv(x)))
        return features if self.pool is None else self.pool(features)


class Model(nn.Module):
    """Small CNN classifier: two ConvBlocks followed by a three-layer MLP head.

    fc1 expects 64 * 5 * 5 features, which is what two valid 3x3 conv + 2x2
    pool stages produce from a 28x28 single-channel input. Other spatial
    sizes will fail at fc1 unless they also reduce to 5x5.

    Args:
        num_classes: number of output logits. When None (the default), the
            module-level NUM_CLASSES is read at construction time, so
            ``Model()`` behaves exactly as before.
    """

    def __init__(self, num_classes=None):
        super().__init__()
        if num_classes is None:
            # Resolved at call time (not as a default argument) so a later
            # reassignment of NUM_CLASSES is still honoured.
            num_classes = NUM_CLASSES

        # Conv blocks: a valid 3x3 conv shrinks each side by 2, and the 2x2 pool
        # halves it (flooring).
        self.conv1 = ConvBlock(1, 64, kernel_size=3, pool=True)
        # Output after conv1 + pool (28x28 input): [B, 64, 13, 13]

        self.conv2 = ConvBlock(64, 64, kernel_size=3, pool=True)
        # Output after conv2 + pool: [B, 64, 5, 5]

        # Fully connected layers
        self.fc1 = nn.Linear(64 * 5 * 5, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, num_classes)

        # GELU and Dropout hold no parameters, so one instance of each is
        # shared by both FC blocks.
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(p=0.25)

        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        """
        Compute class logits for a batch of single-channel images.

        x: Tensor of shape [B, H, W]; an already channel-first [B, 1, H, W]
           tensor is also accepted. H and W must reduce to 5x5 after the two
           conv blocks (e.g. 28x28).

        Returns: unnormalized logits of shape [B, num_classes].

        Raises:
            ValueError: if x is neither 3-D nor 4-D.
        """
        if x.dim() == 3:
            # [B, H, W] -> [B, 1, H, W]
            x = x.unsqueeze(1)
        elif x.dim() != 4:
            raise ValueError(
                "expected input of shape [B, H, W] or [B, 1, H, W], "
                f"got {tuple(x.shape)}"
            )

        # Match Burn's `.detach()`: no gradient flows back into the input.
        x = x.detach()

        x = self.conv1(x)
        x = self.conv2(x)

        # Flatten to [B, 64 * 5 * 5]. Unlike view(B, -1), flatten also works for
        # an empty batch (B == 0), where -1 would be ambiguous.
        x = torch.flatten(x, start_dim=1)

        # FC block 1
        x = self.fc1(x)
        x = self.activation(x)
        x = self.dropout(x)

        # FC block 2
        x = self.fc2(x)
        x = self.activation(x)
        x = self.dropout(x)

        # Output logits
        x = self.fc3(x)
        return x

    def forward_classification(self, images, targets):
        """
        Run the model and compute the cross-entropy loss (mirrors Burn's
        forward_classification).

        images: input accepted by ``forward`` ([B, H, W] or [B, 1, H, W]).
        targets: passed unchanged to nn.CrossEntropyLoss, typically integer
            class indices of shape [B].

        Returns: dict with keys "loss" (scalar tensor), "output" (logits of
        shape [B, num_classes]) and "targets" (the targets as given).
        """
        logits = self.forward(images)
        loss = self.loss_fn(logits, targets)

        return {
            "loss": loss,
            "output": logits,
            "targets": targets,
        }
