In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import math

In [6]:


class Encoder2D(nn.Module):
    """Convolutional encoder producing a single-channel square 2D embedding.

    The network is a stack of stride-2, 3x3 convolutions (each followed by a
    ReLU) that shrink a square ``input_size x input_size`` image down to
    ``output_side x output_side``, where ``output_side ** 2 == repr_dim``.
    A final 1x1 convolution collapses the feature maps to one channel.

    The first convolution uses ``padding=0`` and the remaining ones use
    ``padding=1``; with the default ``input_size=65`` this gives
    65 -> 32 -> 16 for ``repr_dim=256``.

    Args:
        repr_dim: Number of elements in the embedding. Must be a positive
            perfect square (e.g. 256 -> 16x16).
        input_size: Side length of the square input image.

    Raises:
        ValueError: If ``repr_dim`` is not a positive perfect square, or if the
            stride-2 stack cannot reduce ``input_size`` to exactly
            ``output_side``.

    Attributes:
        repr_dim: The requested embedding size.
        output_side: Side length of the square embedding.
        num_conv_blocks: Number of stride-2 conv + ReLU blocks.
        conv: The assembled ``nn.Sequential`` network.
    """

    _KERNEL_SIZE = 3
    _STRIDE = 2
    _INPUT_CHANNELS = 2   # Input has 2 channels (agent and wall)
    _BASE_CHANNELS = 32   # Width of the first conv block
    _MAX_CHANNELS = 256   # Channel width stops doubling here

    def __init__(self, repr_dim, input_size=65):
        super().__init__()

        if repr_dim < 1:
            raise ValueError("repr_dim must be a positive perfect square.")
        # isqrt is exact for integers; comparing the square against the
        # original value also rejects non-integral inputs such as 256.5.
        output_side = math.isqrt(int(repr_dim))
        if output_side * output_side != repr_dim:
            raise ValueError("repr_dim must be a positive perfect square.")

        self.repr_dim = repr_dim
        self.output_side = output_side

        # Walk the real Conv2d size arithmetic instead of estimating with
        # log2: floor((size + 2 * padding - kernel) / stride) + 1 per block.
        # This is what guarantees forward() really yields output_side.
        paddings = []
        size = input_size
        while size > output_side:
            padding = 0 if not paddings else 1
            size = (size + 2 * padding - self._KERNEL_SIZE) // self._STRIDE + 1
            paddings.append(padding)
        if size != output_side:
            raise ValueError("Cannot evenly reduce input_size to output_side using stride-2 convolutions.")

        self.num_conv_blocks = len(paddings)

        layers = []
        in_channels = self._INPUT_CHANNELS
        out_channels = self._BASE_CHANNELS
        for padding in paddings:
            # Each block roughly halves the spatial dimensions.
            layers.append(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=self._KERNEL_SIZE,
                    stride=self._STRIDE,
                    padding=padding,
                )
            )
            layers.append(nn.ReLU())
            in_channels = out_channels
            out_channels = min(out_channels * 2, self._MAX_CHANNELS)

        # Final 1x1 convolution reduces to the single-channel embedding.
        layers.append(nn.Conv2d(in_channels, 1, kernel_size=1))

        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Tensor of shape ``(B, 2, input_size, input_size)``.

        Returns:
            Tensor of shape ``(B, 1, output_side, output_side)``.
        """
        return self.conv(x)

# Smoke test: build the encoder for a 65x65, two-channel input and check that a
# dummy batch is reduced to a square single-channel embedding.
INPUT_SIZE = 65  # side length of the square input image
REPR_DIM = 256   # embedding size; 256 elements -> 16x16

encoder = Encoder2D(repr_dim=REPR_DIM, input_size=INPUT_SIZE)
print(encoder)

# Dummy input: batch size of 1, 2 channels, INPUT_SIZE x INPUT_SIZE.
input_tensor = torch.randn(1, 2, INPUT_SIZE, INPUT_SIZE)
output = encoder(input_tensor)

# Fail loudly if the network does not produce the advertised embedding shape,
# rather than relying on a visual check of the displayed value.
assert output.shape == (1, 1, encoder.output_side, encoder.output_side), output.shape

output.shape

Encoder2D(
  (conv): Sequential(
    (0): Conv2d(2, 32, kernel_size=(3, 3), stride=(2, 2))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1))
  )
)


torch.Size([1, 1, 16, 16])