In [None]:
from typing import List
import numpy as np
from torch import nn
from torch.nn import functional as F
import torch


def build_mlp(layers_dims: List[int]):
    """Build a feed-forward MLP as an ``nn.Sequential``.

    Each hidden transition is ``Linear -> BatchNorm1d -> ReLU(inplace=True)``;
    the final transition is a bare ``Linear`` with no normalization/activation.

    Args:
        layers_dims: layer widths ``[in, hidden..., out]``; at least two entries
            are required (fewer raises ``IndexError``).

    Returns:
        The stacked modules wrapped in an ``nn.Sequential``.
    """
    blocks = []
    # Pair consecutive widths for every transition except the last one.
    for width_in, width_out in zip(layers_dims[:-2], layers_dims[1:-1]):
        blocks += [
            nn.Linear(width_in, width_out),
            nn.BatchNorm1d(width_out),
            nn.ReLU(True),
        ]
    # Output projection: no normalization or activation.
    blocks.append(nn.Linear(layers_dims[-2], layers_dims[-1]))
    return nn.Sequential(*blocks)


class MockModel(torch.nn.Module):
    """
    Small baseline model: a CNN frame encoder followed by a GRU predictor.

    Every frame is encoded to a ``repr_dim``-dimensional vector, concatenated with
    the action aligned to that step (zeros for the first step), and fed through a
    single-layer GRU whose per-step outputs are returned as the predictions.
    """

    def __init__(self, device="cuda", bs=64, n_steps=17, output_dim=256):
        """
        Args:
            device: stored as ``self.device``; not used to place parameters here.
            bs: stored as ``self.bs``; not used by this module.
            n_steps: stored as ``self.n_steps``; not used by this module.
            output_dim: width of the encoded/predicted representations
                (exposed as ``self.repr_dim``).
        """
        super().__init__()
        self.device = device
        self.bs = bs
        self.n_steps = n_steps
        # Representation width follows the constructor argument (default 256).
        self.repr_dim = output_dim

        # Encoder with Adaptive Pooling for dynamic image size; expects 3-channel frames.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d((1, 1)),  # Outputs [B, C, 1, 1], regardless of input size
            nn.Flatten(),
            build_mlp([128, 512, self.repr_dim])  # 128 -> 512 -> output_dim
        )

        # Predictor (GRU) over [encoded state, 2-d action] at each step.
        self.predictor = nn.GRU(input_size=self.repr_dim + 2, hidden_size=self.repr_dim, batch_first=True)

    def forward(self, states, actions):
        """
        Args:
            states: [B, T, Ch, H, W]
            actions: [B, T-1, 2]

        Output:
            predictions: [B, T, D] with D == self.repr_dim
        """
        B, T, C, H, W = states.shape

        # reshape (not view) so non-contiguous inputs, e.g. sliced trajectories, also work.
        frames = states.reshape(B * T, C, H, W)
        encoded_states = self.encoder(frames).reshape(B, T, -1)

        # Prepend one zero action along time so actions line up with the T states:
        # [B, T-1, 2] -> [B, T, 2].
        actions = F.pad(actions, (0, 0, 1, 0))
        inputs = torch.cat([encoded_states, actions], dim=-1)

        # GRU output holds the hidden state for every step -> [B, T, repr_dim].
        predictions, _ = self.predictor(inputs)
        return predictions


class Prober(torch.nn.Module):
    """MLP head mapping an embedding to ``prod(output_shape)`` features.

    Args:
        embedding: input feature width (size of the input's last dimension).
        arch: hidden layer widths joined by "-" (e.g. "512-256"); "" yields a
            single Linear layer.
        output_shape: target shape. The network emits ``prod(output_shape)``
            features per input; the output is NOT reshaped to ``output_shape``.
    """

    def __init__(
        self,
        embedding: int,
        arch: str,
        output_shape: List[int],
    ):
        super().__init__()
        # Store a plain Python int: np.prod returns a NumPy scalar, and for an empty
        # shape returns the float 1.0, which nn.Linear rejects as a layer size.
        self.output_dim = int(np.prod(output_shape))
        self.output_shape = output_shape
        self.arch = arch

        arch_list = [int(width) for width in arch.split("-")] if arch != "" else []
        f = [embedding] + arch_list + [self.output_dim]
        layers = []
        for i in range(len(f) - 2):
            layers.append(torch.nn.Linear(f[i], f[i + 1]))
            layers.append(torch.nn.ReLU(True))
        # Final projection without activation.
        layers.append(torch.nn.Linear(f[-2], f[-1]))
        self.prober = torch.nn.Sequential(*layers)

    def forward(self, e):
        """Apply the MLP to ``e`` (last dimension must equal ``embedding``)."""
        output = self.prober(e)
        return output


In [None]:
import os
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import numpy as np
from typing import Tuple
from models import MockModel

# Custom Dataset Class
class TrajectoryDataset(Dataset):
    """Map-style dataset of (state, action, target) trajectories stored as .npy files.

    A sample is the triple ``<prefix>_state.npy``, ``<prefix>_action.npy`` and
    ``<prefix>_target.npy`` found directly in ``root_dir``.
    """

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir: Path to the dataset directory containing states, actions, and targets.
            transform: Optional per-frame image transform for states. Each frame is
                passed channels-last, as an (H, W, Ch) ndarray (the layout torchvision's
                ``ToTensor`` expects), and must return a (Ch, H, W) tensor.

        Raises:
            AssertionError: if the state/action/target files do not pair up by prefix.
        """
        self.root_dir = root_dir
        self.transform = transform

        # List the directory once, then bucket files by suffix.
        file_names = os.listdir(root_dir)

        def prefixes_with(suffix):
            # Sort by shared prefix, not full filename: different suffixes can order
            # the same prefixes differently (e.g. "a_action" < "a_b_action" but
            # "a_b_state" < "a_state"), which would silently misalign samples.
            return sorted(name[: -len(suffix)] for name in file_names if name.endswith(suffix))

        state_prefixes = prefixes_with("_state.npy")
        action_prefixes = prefixes_with("_action.npy")
        target_prefixes = prefixes_with("_target.npy")

        assert len(state_prefixes) == len(action_prefixes) == len(target_prefixes), \
            "Mismatch in number of state, action, and target files."
        assert state_prefixes == action_prefixes == target_prefixes, \
            "State, action, and target files do not share the same name prefixes."

        self.state_paths = [os.path.join(root_dir, p + "_state.npy") for p in state_prefixes]
        self.action_paths = [os.path.join(root_dir, p + "_action.npy") for p in action_prefixes]
        self.target_paths = [os.path.join(root_dir, p + "_target.npy") for p in target_prefixes]

    def __len__(self):
        return len(self.state_paths)

    def __getitem__(self, idx):
        """Return ``(state, action, target)`` for sample ``idx`` as float32 tensors."""
        state = np.load(self.state_paths[idx])  # expected [T, Ch, H, W]
        action = np.load(self.action_paths[idx])  # expected [T-1, 2]
        target = np.load(self.target_paths[idx])  # expected [T, repr_dim]

        if self.transform is not None:
            # Frames are stored channels-first; hand the transform channels-last frames
            # so ToTensor-style transforms return (Ch, H, W) rather than a scrambled layout.
            state = torch.stack([
                self.transform(np.ascontiguousarray(frame.transpose(1, 2, 0)) if frame.ndim == 3 else frame)
                for frame in state
            ])

        # as_tensor avoids the copy-construct warning when `state` is already a tensor.
        return torch.as_tensor(state, dtype=torch.float32), \
               torch.as_tensor(action, dtype=torch.float32), \
               torch.as_tensor(target, dtype=torch.float32)


# Metrics
def compute_accuracy(predictions: torch.Tensor, targets: torch.Tensor, threshold: float = 0.1) -> float:
    """
    Fraction of positions whose prediction lies within ``threshold`` of the target.

    Distances are Euclidean norms over the last dimension, so every leading
    position (e.g. each batch element and time step) counts once.

    Args:
        predictions: predicted vectors; same shape as ``targets``.
        targets: ground-truth vectors.
        threshold: strict upper bound on the distance counted as correct
            (default 0.1).

    Returns:
        Accuracy in [0, 1] as a Python float.
    """
    # A metric never needs gradients; skip building an autograd graph.
    with torch.no_grad():
        distances = torch.linalg.vector_norm(predictions - targets, dim=-1)
        return (distances < threshold).float().mean().item()


# Training Loop
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``: trains when ``optimizer`` is given, else evaluates.

    Returns:
        ``(mean batch loss, mean batch accuracy)``; both NaN if ``loader`` yields no batches.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    total_loss, total_accuracy, n_batches = 0.0, 0.0, 0
    with torch.set_grad_enabled(training):
        for states, actions, targets in loader:
            states, actions, targets = states.to(device), actions.to(device), targets.to(device)
            if training:
                optimizer.zero_grad()

            predictions = model(states, actions)
            loss = criterion(predictions, targets)

            if training:
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            total_accuracy += compute_accuracy(predictions, targets)
            n_batches += 1

    # Count batches instead of len(loader): avoids ZeroDivisionError on empty loaders
    # and works for loaders without __len__.
    if n_batches == 0:
        return float("nan"), float("nan")
    return total_loss / n_batches, total_accuracy / n_batches


def train_model(model, train_loader, val_loader, epochs, criterion, optimizer, device):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Args:
        model: module called as ``model(states, actions)``.
        train_loader: iterable of ``(states, actions, targets)`` training batches.
        val_loader: iterable of ``(states, actions, targets)`` validation batches.
        epochs: number of passes over ``train_loader``.
        criterion: loss callable ``criterion(predictions, targets)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device every batch tensor is moved to.

    Returns:
        A list with one dict per epoch holding the mean per-batch ``train_loss``,
        ``train_accuracy``, ``val_loss`` and ``val_accuracy`` (NaN for a loader
        that yielded no batches).
    """
    history = []
    for epoch in range(epochs):
        train_loss, train_accuracy = _run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss, val_accuracy = _run_epoch(model, val_loader, criterion, device)

        # Log epoch metrics
        print(f"Epoch {epoch + 1}/{epochs}")
        print(f"  Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")
        print(f"  Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

        history.append({
            "train_loss": train_loss,
            "train_accuracy": train_accuracy,
            "val_loss": val_loss,
            "val_accuracy": val_accuracy,
        })
    return history


In [None]:
# Example Training Script
if __name__ == "__main__":
    # Configuration
    SEED = 0
    BATCH_SIZE = 8
    EPOCHS = 10
    LEARNING_RATE = 1e-3
    # Root of the probing datasets; override via the DL24FA_DATA_DIR environment variable.
    DATA_ROOT = os.environ.get("DL24FA_DATA_DIR", "/scratch/DL24FA")

    # Seed torch (CPU and CUDA) so weight init and shuffling are reproducible.
    torch.manual_seed(SEED)

    # Dataset paths
    train_path = os.path.join(DATA_ROOT, "probe_normal", "train")
    val_path_normal = os.path.join(DATA_ROOT, "probe_normal", "val")
    val_path_wall = os.path.join(DATA_ROOT, "probe_wall", "val")

    # Transformations
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])  # Example normalization
    ])

    # Load datasets
    train_dataset = TrajectoryDataset(train_path, transform)
    val_dataset_normal = TrajectoryDataset(val_path_normal, transform)
    val_dataset_wall = TrajectoryDataset(val_path_wall, transform)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader_normal = DataLoader(val_dataset_normal, batch_size=BATCH_SIZE, shuffle=False)
    val_loader_wall = DataLoader(val_dataset_wall, batch_size=BATCH_SIZE, shuffle=False)

    # Initialize model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = MockModel().to(device)

    # Loss and Optimizer
    criterion = nn.MSELoss()  # Measure distance in representation space
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Train model
    print("Training with Normal Validation Set:")
    train_model(model, train_loader, val_loader_normal, epochs=EPOCHS, criterion=criterion, optimizer=optimizer, device=device)

    # Evaluate only: calling train_model here would train another epoch on the
    # training set before reporting wall metrics.
    print("\nEvaluating on Wall Validation Set:")
    model.eval()
    wall_loss, wall_accuracy, n_wall_batches = 0.0, 0.0, 0
    with torch.no_grad():
        for states, actions, targets in val_loader_wall:
            states, actions, targets = states.to(device), actions.to(device), targets.to(device)
            predictions = model(states, actions)
            wall_loss += criterion(predictions, targets).item()
            wall_accuracy += compute_accuracy(predictions, targets)
            n_wall_batches += 1

    if n_wall_batches:
        print(f"  Val Loss: {wall_loss / n_wall_batches:.4f}, Val Accuracy: {wall_accuracy / n_wall_batches:.4f}")
    else:
        print("  Wall validation set is empty.")