# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AtariFrameEncoder(nn.Module):
    """
    CNN encoder for 64x64x3 Atari frames.

    Three ReLU-activated convolutions reduce a 64x64 RGB frame to a
    128-channel 8x8 feature map, which is flattened and linearly projected
    to a single ``embed_dim``-sized embedding (256 by default) suitable as a
    transformer token.
    """

    def __init__(self, embed_dim: int = 256) -> None:
        """
        embed_dim: size of the output embedding produced per frame.
        """
        super().__init__()

        # Conv1: (3, 64, 64) -> (32, 16, 16)
        # padding=2 is required to land on 16x16:
        #   floor((64 + 2*2 - 8) / 4) + 1 = 16
        # (without padding the output would be 15x15).
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=32,
            kernel_size=8,
            stride=4,
            padding=2,
        )

        # Conv2: (32, 16, 16) -> (64, 8, 8)
        # padding=1 is required to land on 8x8:
        #   floor((16 + 2*1 - 4) / 2) + 1 = 8
        self.conv2 = nn.Conv2d(
            in_channels=32,
            out_channels=64,
            kernel_size=4,
            stride=2,
            padding=1,
        )

        # Conv3: (64, 8, 8) -> (128, 8, 8)
        self.conv3 = nn.Conv2d(
            in_channels=64,
            out_channels=128,
            kernel_size=3,
            stride=1,
            padding=1,  # keep 8x8
        )

        # Final projection: (8 * 8 * 128 = 8192) -> (embed_dim)
        self.fc = nn.Linear(8 * 8 * 128, embed_dim)

        # Explicit init for the projection layer; the conv layers keep
        # PyTorch's default initialization.
        nn.init.kaiming_normal_(self.fc.weight, nonlinearity='relu')
        nn.init.zeros_(self.fc.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (batch, 64, 64, 3) or (batch, 3, 64, 64)
        returns: (batch, embed_dim)
        """

        # Convert NHWC -> NCHW if needed. A trailing dimension of size 3 is
        # taken to be the channel axis.
        if x.shape[-1] == 3:
            x = x.permute(0, 3, 1, 2)

        x = F.relu(self.conv1(x))  # (B, 32, 16, 16)
        x = F.relu(self.conv2(x))  # (B, 64, 8, 8)
        x = F.relu(self.conv3(x))  # (B, 128, 8, 8)

        x = torch.flatten(x, start_dim=1)  # (B, 8192)
        x = self.fc(x)  # (B, embed_dim)

        return x
