In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import math

In [6]:


class Encoder2D(nn.Module):
    def __init__(self, repr_dim, input_size=65):
        super().__init__()
        self.repr_dim = repr_dim
        self.output_side = int(math.sqrt(repr_dim))  # Calculate the side of the 2D embedding

        # Determine the number of convolutional blocks required
        self.num_conv_blocks = int(math.log2(input_size / self.output_side))
        if 2 ** self.num_conv_blocks * self.output_side != 2 ** int(math.log2(input_size)):
            raise ValueError("Cannot evenly reduce input_size to output_side using stride-2 convolutions.")

        layers = []
        in_channels = 2  # Input has 2 channels (agent and wall)
        out_channels = 32  # Start with 32 output channels
        for i in range(self.num_conv_blocks):
            layers.append(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=2,
                    padding=0,
                ) if i == 0 else
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                )
            )  # Halve the spatial dimensions
            layers.append(nn.ReLU())
            in_channels = out_channels
            out_channels = min(out_channels * 2, 256)  # Cap channels at 256

        # Final convolution to reduce to single-channel output
        layers.append(nn.Conv2d(in_channels, 1, kernel_size=1))  # Single-channel embedding

        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        # Input: (B, 2, 65, 65)
        x = self.conv(x)  # Dynamically reduce to (B, 1, output_side, output_side)
        return x  # Output shape: (B, 1, output_side, output_side)

# Instantiate the Encoder2D with input size 65x65 and repr_dim 256
encoder = Encoder2D(repr_dim=256, input_size=65)
print(encoder)

# Test with a dummy input
input_tensor = torch.randn(1, 2, 65, 65)  # Batch size of 1, 2 channels, 65x65 input
output = encoder(input_tensor)

output.shape

Encoder2D(
  (conv): Sequential(
    (0): Conv2d(2, 32, kernel_size=(3, 3), stride=(2, 2))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1))
  )
)


torch.Size([1, 1, 16, 16])

In [None]:
import timm
import torch
import torch.nn as nn
import math

class FlexibleEncoder2D(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.repr_dim = config.embed_dim

        # Ensure output size is consistent
        self.output_side = int(math.sqrt(self.repr_dim))  # Should be 16 for config.embed_dim=256
        if self.output_side != 16:
            raise ValueError("Output side must be 16 for config.embed_dim=256.")

        # Dynamically select backbone based on config.encoder_backbone
        self.backbone = timm.create_model(
            config.encoder_backbone,  # Example: 'resnet18.a1_in1k'
            pretrained=False,  # No pretraining allowed
            num_classes=0,  # No classifier head
            in_chans=2,  # Input has 2 channels
            features_only=True,  # Extract spatial features
        )

        # Inspect available feature maps
        self.feature_channels = [info['num_chs'] for info in self.backbone.feature_info]
        self.feature_shapes = [info['reduction'] for info in self.backbone.feature_info]  # Spatial size reductions

        # Select the layer closest to 16x16
        self.closest_layer_index = self._find_closest_layer()

        # Final adjustment to 16x16
        self.adjust_to_target = nn.Sequential(
            nn.Conv2d(self.feature_channels[self.closest_layer_index], 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        )

    def _find_closest_layer(self):
        # Find the layer whose spatial size is closest to output_side
        input_size = 65  # Assumes input spatial dimensions (H, W) = 65x65
        reductions = [input_size // red for red in self.feature_shapes]
        closest_index = min(range(len(reductions)), key=lambda i: abs(reductions[i] - self.output_side))
        return closest_index

    def forward(self, x):
        # Pass input through the backbone and select the appropriate layer
        features = self.backbone(x)
        x = features[self.closest_layer_index]  # Closest layer to 16x16

        # Adjust to target shape
        x = self.adjust_to_target(x)
        return x


# Define the configuration class
class Config:
    embed_dim = 256  # Output embedding size
    encoder_backbone = 'resnet18.a1_in1k'  # Use ResNet-18 as the backbone

# Instantiate and test the model
config = Config()
model = FlexibleEncoder2D(config)

# Generate a random input tensor
input_tensor = torch.randn(4, 2, 65, 65)  # Example input (B, 2, 65, 65)

# Run the model and check the output size
output = model(input_tensor)
output.shape


torch.Size([4, 1, 17, 17])

In [28]:
import timm
import torch
import torch.nn as nn
import math

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.convnext = timm.create_model('resnet18.a1_in1k', pretrained=False, num_classes=0, in_chans=2, features_only=True)
    
    def forward(self, x):
        # Reshape input to merge batch and trajectory dimensions
        original_shape = x.shape
        x = x.view(-1, *original_shape[-3:])  # Reshape to [batch*trajectory, channels, height, width]
        features = self.convnext(x)[1]
        
        # Reshape features back to original trajectory structure
        features = features.view(original_shape[0], original_shape[1], *features.shape[-3:])
        return features


# Define the configuration class
class Config:
    embed_dim = 256  # Output embedding size
    encoder_backbone = 'resnet18.a1_in1k'  # Use ResNet-18 as the backbone

# Instantiate and test the model
config = Config()
model = Encoder()

# Generate a random input tensor
input_tensor = torch.randn(4, 4, 2, 65, 65)  # Example input (B, 2, 65, 65)

# Run the model and check the output size
output = model(input_tensor)
output.shape


torch.Size([4, 4, 64, 17, 17])

In [37]:
class Predictor(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.predictor = nn.Sequential(
            nn.Conv2d(input_dim, input_dim-2, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(input_dim-2, output_dim, kernel_size=3, padding=1)
        )
    
    def forward(self, encoded_o_t, action):
        # Reshape inputs
        batch_size, trajectory_length = encoded_o_t.shape[:2]
        
        # Reshape action to match encoded_o_t dimensions
        action = action.view(batch_size, trajectory_length-1, 2, 1, 1)
        action = action.repeat(1, 1, 1, encoded_o_t.size(3), encoded_o_t.size(4))
        
        # Prepare inputs for prediction
        predictions = []
        for t in range(trajectory_length - 1):
            # Concatenate current encoded state with action
            x = torch.cat([encoded_o_t[:, t], action[:, t]], dim=1)
            pred = self.predictor(x)
            predictions.append(pred)
        
        return torch.stack(predictions, dim=1)
    
# Generate a random action tensor
action_tensor = torch.randn(4, 3, 2)  # Example action (B, T-1, 2)

# Pass the encoded output and action through the predictor
predictor = Predictor(input_dim=66, output_dim=64)  # Assuming input_dim=3 (encoded_o_t channels + action channels) and output_dim=1
predicted_output = predictor(output, action_tensor)

# Check the output size
predicted_output.shape

torch.Size([4, 3, 64, 17, 17])