In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    """Three conv -> BatchNorm -> ReLU stages producing a 32-channel feature map.

    For a 65x65 input the spatial size goes 65 -> 22 (conv1, stride 3)
    -> 8 (conv2, stride 3) -> 8 (conv3, stride 1), so the output is
    [B, 32, 8, 8]. ``repr_dim`` is the number of output channels.
    """

    def __init__(self, input_channels=2):
        super().__init__()

        def conv_bn(in_ch, out_ch, kernel, stride, pad):
            # The conv is bias-free because the BatchNorm that follows applies its own shift.
            return (
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=pad, bias=False),
                nn.BatchNorm2d(out_ch),
            )

        self.conv1, self.bn1 = conv_bn(input_channels, 16, 7, 3, 3)  # 65x65 -> 22x22
        self.conv2, self.bn2 = conv_bn(16, 32, 5, 3, 2)              # 22x22 -> 8x8
        self.conv3, self.bn3 = conv_bn(32, 32, 3, 1, 1)              # 8x8 -> 8x8 (feature mixing)

        self.repr_dim = 32

    def forward(self, x):
        """Encode ``x`` of shape [B, input_channels, H, W]; 65x65 input gives [B, 32, 8, 8]."""
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        for conv, norm in stages:
            x = F.relu(norm(conv(x)))
        return x

class TransitionModel(nn.Module):
    """Action-conditioned residual predictor over a spatial latent state.

    The 2-D action is broadcast over the state's spatial grid and embedded to
    ``hidden_dim`` channels with 1x1 convolutions. It is then concatenated with
    the state along channels and passed through two 3x3 convolutions. The
    result is added back to the state.
    """

    def __init__(self, hidden_dim=32):
        super().__init__()
        self.hidden_dim = hidden_dim
        embed_width = 16  # intermediate width of the action embedding

        # 1x1 convs act as a per-location MLP: 2 -> 16 -> hidden_dim.
        self.action_embed = nn.Sequential(
            nn.Conv2d(2, embed_width, 1), nn.BatchNorm2d(embed_width), nn.ReLU(),
            nn.Conv2d(embed_width, hidden_dim, 1), nn.BatchNorm2d(hidden_dim), nn.ReLU(),
        )

        # Input is [state, action embedding] stacked on channels. The last layer has
        # no activation, so the predicted change can take either sign.
        self.transition = nn.Sequential(
            nn.Conv2d(2 * hidden_dim, hidden_dim, 3, padding=1), nn.BatchNorm2d(hidden_dim), nn.ReLU(),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), nn.BatchNorm2d(hidden_dim),
        )

    def forward(self, state, action):
        """Predict the next latent state.

        Args:
            state: [B, hidden_dim, H, W] current state representation.
            action: [B, 2] action, e.g. (dx, dy).

        Returns:
            [B, hidden_dim, H, W]: ``state`` plus the predicted change.
        """
        batch, _, height, width = state.shape
        action_grid = action.view(batch, 2, 1, 1).expand(-1, -1, height, width)
        state_and_action = torch.cat([state, self.action_embed(action_grid)], dim=1)
        return state + self.transition(state_and_action)

class WorldModel(nn.Module):
    """Encoder + TransitionModel rolled out open-loop from the first observation."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder(input_channels=2)
        self.predictor = TransitionModel(hidden_dim=32)

    def forward(self, states, actions):
        """Encode the first frame and roll the predictor forward over ``actions``.

        Args:
            states: [B, T', 2, 65, 65] observations. Only ``states[:, 0]`` is used,
                so either just the initial frame (T' = 1) or a full trajectory
                may be passed.
            actions: [B, T-1, 2] - Sequence of T-1 actions.

        Returns:
            predictions: [B, T, 32, 8, 8] - predicted representations; index 0
            is the encoding of the initial frame.

        Raises:
            ValueError: if ``states`` is not 5-dimensional.
        """
        if states.dim() != 5:
            raise ValueError(f"expected states of shape [B, T, C, H, W], got {tuple(states.shape)}")

        # Index the first frame rather than squeeze(1), so passing the full
        # trajectory does not hand a 5-D tensor to the encoder.
        curr_state = self.encoder(states[:, 0])
        predictions = [curr_state]

        for t in range(actions.shape[1]):
            curr_state = self.predictor(curr_state, actions[:, t])
            predictions.append(curr_state)

        return torch.stack(predictions, dim=1)

In [None]:
import torch
from dataset import create_wall_dataloader

def test_model_with_real_data(data_path="/scratch/an3854/DL24FA/train", device=None):
    """Smoke-test Encoder / TransitionModel / WorldModel shapes on real batches.

    Args:
        data_path: directory passed to ``create_wall_dataloader``.
        device: torch device string. Defaults to "cuda" when available, else
            "cpu". The models and the dataloader always use the same device.

    Returns:
        (world_model, dataloader) so later cells can reuse them.

    Raises:
        AssertionError: if any output shape differs from the expected one.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Independently initialised models; this function only checks shapes.
    encoder = Encoder(input_channels=2).to(device)
    predictor = TransitionModel(hidden_dim=32).to(device)
    world_model = WorldModel().to(device)

    # Create data loader with actual data
    dataloader = create_wall_dataloader(
        data_path=data_path,
        probing=False,
        device=device,
        batch_size=256,
        train=True
    )

    # Get a batch of data
    batch = next(iter(dataloader))
    states = batch.states  # [B, T, C, H, W]
    actions = batch.actions  # [B, T-1, 2]

    print("\nInput Data Shapes:")
    print(f"States shape: {states.shape}")  # expected [B, T, 2, 65, 65]
    print(f"Actions shape: {actions.shape}")  # expected [B, T-1, 2]
    print(f"States min/max: {states.min():.2f}, {states.max():.2f}")
    print(f"Actions min/max: {actions.min():.2f}, {actions.max():.2f}")

    # These forward passes are shape checks only, so skip building autograd graphs.
    with torch.no_grad():
        # Encoder on the first timestep
        init_state = states[:, 0]
        encoded = encoder(init_state)
        print(f"\nEncoder output shape: {encoded.shape}")  # expected [B, 32, 8, 8]

        # Predictor with a single action
        single_action = actions[:, 0]
        print("\nSingle action sample:")
        print(f"Action values: {single_action[0]}")  # first sample's action

        # Reproduce TransitionModel's action broadcasting step by step
        B, _, h, w = encoded.shape
        action_spatial = single_action.view(B, 2, 1, 1).expand(-1, -1, h, w)
        print("\nAction processing:")
        print(f"Action after spatial expansion: {action_spatial.shape}")  # expected [B, 2, 8, 8]

        action_embedded = predictor.action_embed(action_spatial)
        print(f"Action after embedding: {action_embedded.shape}")  # expected [B, 32, 8, 8]

        # Single prediction step
        next_state = predictor(encoded, single_action)
        print("\nPrediction shapes:")
        print(f"Single step prediction: {next_state.shape}")  # expected [B, 32, 8, 8]

        # Full rollout from the first frame
        init_states = states[:, 0:1]  # [B, 1, C, H, W]
        predictions = world_model(init_states, actions)
        print(f"Full sequence prediction: {predictions.shape}")  # expected [B, T, 32, 8, 8]

    # Values per frame: C*H*W of the observation vs. channels*h*w of the encoding
    input_size = states.shape[-2] * states.shape[-1] * states.shape[-3]
    encoded_size = encoded.shape[-2] * encoded.shape[-1] * encoded.shape[-3]
    compression_ratio = input_size / encoded_size
    print("\nCompression Statistics:")
    print(f"Input size per frame: {input_size} values")
    print(f"Encoded size per frame: {encoded_size} values")
    print(f"Compression ratio: {compression_ratio:.2f}x")

    # Verify shapes
    assert encoded.shape[1] == 32, "Channel dimension mismatch"
    assert encoded.shape[2] == encoded.shape[3] == 8, "Spatial dimensions mismatch"
    assert predictions.shape[1] == actions.shape[1] + 1, "Sequence length mismatch"
    assert predictions.shape[2] == 32, "Output channel mismatch"
    assert predictions.shape[3] == predictions.shape[4] == 8, "Output spatial dimension mismatch"

    print("\nAll dimension tests passed!")

    # Run a few more batches end to end
    print("\nTesting multiple forward passes...")
    for i, batch in enumerate(dataloader):
        if i >= 3:  # Test with 3 batches
            break
        with torch.no_grad():
            predictions = world_model(batch.states[:, 0:1], batch.actions)
        print(f"Batch {i+1} processed successfully.")

        # Summary statistics of the first batch's predictions
        if i == 0:
            print(f"Prediction stats for batch {i+1}:")
            print(f"Mean: {predictions.mean():.3f}")
            print(f"Std: {predictions.std():.3f}")
            print(f"Min: {predictions.min():.3f}")
            print(f"Max: {predictions.max():.3f}")

    return world_model, dataloader

# In a Jupyter/IPython kernel __name__ is normally "__main__", so this guard
# runs when the cell executes. It only matters if the code is exported to a module.
# NOTE(review): this rebinds the notebook globals `model` and `dataloader`;
# a later cell reassigns `model` to a WorldModelVICReg.
if __name__ == "__main__":
    print("Running model tests with real data...")
    model, dataloader = test_model_with_real_data()

Running model tests with real data...

Input Data Shapes:
States shape: torch.Size([256, 17, 2, 65, 65])
Actions shape: torch.Size([256, 16, 2])
States min/max: 0.00, 0.09
Actions min/max: -1.76, 1.76

Encoder output shape: torch.Size([256, 32, 8, 8])

Single action sample:
Action values: tensor([-0.4032, -1.0003], device='cuda:0')

Action processing:
Action after spatial expansion: torch.Size([256, 2, 8, 8])
Action after embedding: torch.Size([256, 32, 8, 8])

Prediction shapes:
Single step prediction: torch.Size([256, 32, 8, 8])
Full sequence prediction: torch.Size([256, 17, 32, 8, 8])

Compression Statistics:
Input size per frame: 8450 values
Encoded size per frame: 2048 values
Compression ratio: 4.13x

All dimension tests passed!

Testing multiple forward passes...
Batch 1 processed successfully.
Prediction stats for batch 1:
Mean: 0.394
Std: 7.925
Min: -77.734
Max: 99.804
Batch 2 processed successfully.
Batch 3 processed successfully.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from datetime import datetime
from tqdm import tqdm
from dataset import WallSample

class VICRegLoss(nn.Module):
    """Variance-Invariance-Covariance regularisation loss between two embeddings.

    total = lambda_param * invariance + mu_param * variance + nu_param * covariance

    * invariance: mean squared error between ``z_a`` and ``z_b``.
    * variance: ``relu(1 - std)`` per feature, averaged over features, then
      summed over the two inputs. The std is taken over the batch, with 1e-4
      added to the variance before the square root.
    * covariance: squared off-diagonal entries of each input's feature
      covariance matrix, summed and divided by the feature count, then summed
      over the two inputs.
    """

    def __init__(self, lambda_param=25.0, mu_param=25.0, nu_param=1.0):
        super().__init__()
        self.lambda_param = lambda_param  # weight of the invariance term
        self.mu_param = mu_param          # weight of the variance term
        self.nu_param = nu_param          # weight of the covariance term

    def off_diagonal(self, x):
        """Return the off-diagonal entries of square matrix ``x`` as a 1-D tensor (row-major)."""
        size = x.shape[0]
        # After dropping the final element, viewing the n*n - 1 values as
        # [n-1, n+1] lines every diagonal entry up in column 0.
        rows = x.flatten()[:-1].view(size - 1, size + 1)
        return rows[:, 1:].flatten()

    def forward(self, z_a, z_b):
        """Compute the weighted VICReg loss.

        Args:
            z_a, z_b: representations of shape [N, D] (N samples, D features).

        Returns:
            (total_loss, losses), where ``total_loss`` is a scalar tensor and
            ``losses`` maps 'sim_loss', 'std_loss', 'cov_loss' and 'total_loss'
            to Python floats for logging.
        """
        n_samples, n_features = z_a.shape[0], z_a.shape[1]

        def variance_hinge(z):
            # Penalise features whose batch standard deviation drops below 1.
            return torch.mean(F.relu(1 - torch.sqrt(z.var(dim=0) + 1e-04)))

        def covariance_penalty(z):
            centered = z - z.mean(dim=0)
            cov = (centered.T @ centered) / (n_samples - 1)
            return self.off_diagonal(cov).pow_(2).sum() / n_features

        invariance = F.mse_loss(z_a, z_b)
        variance = variance_hinge(z_a) + variance_hinge(z_b)
        covariance = covariance_penalty(z_a) + covariance_penalty(z_b)

        weighted = (self.lambda_param * invariance
                    + self.mu_param * variance
                    + self.nu_param * covariance)

        return weighted, {
            'sim_loss': invariance.item(),
            'std_loss': variance.item(),
            'cov_loss': covariance.item(),
            'total_loss': weighted.item(),
        }

class WorldModelVICReg(nn.Module):
    """Encoder + action-conditioned predictor trained with a per-step VICReg loss.

    The predictor is rolled forward from the encoding of the first frame. At
    every later step the prediction is compared with the encoder's embedding of
    the observed frame using ``VICRegLoss``.
    """

    def __init__(self, lambda_param=25.0, mu_param=25.0, nu_param=1.0):
        super().__init__()
        self.encoder = Encoder(input_channels=2)
        self.predictor = TransitionModel(hidden_dim=32)
        self.criterion = VICRegLoss(lambda_param, mu_param, nu_param)

    def forward(self, states, actions):
        """Alias for ``forward_prediction`` so ``model(states, actions)`` works."""
        return self.forward_prediction(states, actions)

    def compute_vicreg_loss(self, pred_state, target_obs):
        """VICReg loss between a predicted state and the encoding of ``target_obs``.

        Both feature maps are flattened per sample, e.g. [B, 32, 8, 8] -> [B, 2048].

        Returns:
            (total_loss, component_losses) as returned by ``self.criterion``.
        """
        target_state = self.encoder(target_obs)
        pred_flat = pred_state.flatten(start_dim=1)
        target_flat = target_state.flatten(start_dim=1)
        return self.criterion(pred_flat, target_flat)

    def training_step(self, batch):
        """Roll out from ``batch.states[:, 0]`` and average the VICReg loss over steps.

        Args:
            batch: object with ``states`` [B, T, C, H, W] and ``actions`` [B, T-1, 2].

        Returns:
            (loss, predictions, component_losses): ``loss`` is the per-step mean
            loss tensor, ``predictions`` is [B, T, 32, 8, 8], and
            ``component_losses`` maps each loss name to its per-step mean float.

        Raises:
            ValueError: if the batch contains no actions (nothing to predict).
        """
        states = batch.states
        actions = batch.actions
        num_steps = actions.shape[1]
        if num_steps == 0:
            raise ValueError("training_step needs at least one action per sequence")

        predictions = self.forward_prediction(states[:, 0:1], actions)

        total_loss = 0.0
        accumulated_losses = dict.fromkeys(('sim_loss', 'std_loss', 'cov_loss', 'total_loss'), 0.0)

        # Step 0 is the encoder output itself; supervise predictions 1..T-1.
        for t in range(1, num_steps + 1):
            loss, component_losses = self.compute_vicreg_loss(predictions[:, t], states[:, t])
            total_loss = total_loss + loss
            for k in accumulated_losses:
                accumulated_losses[k] += component_losses[k]

        averaged_losses = {k: v / num_steps for k, v in accumulated_losses.items()}
        return total_loss / num_steps, predictions, averaged_losses

    def forward_prediction(self, states, actions):
        """
        Forward pass for prediction of future states.

        Args:
            states: [B, T', 2, 65, 65] observations. Only ``states[:, 0]`` is
                used, so the initial frame alone (T' = 1) or a full trajectory
                may be passed.
            actions: [B, T-1, 2] - Sequence of T-1 actions
        Returns:
            predictions: [B, T, 32, 8, 8] - Predicted representations
        Raises:
            ValueError: if ``states`` is not 5-dimensional.
        """
        if states.dim() != 5:
            raise ValueError(f"expected states of shape [B, T, C, H, W], got {tuple(states.shape)}")

        curr_state = self.encoder(states[:, 0])
        predictions = [curr_state]

        for t in range(actions.shape[1]):
            curr_state = self.predictor(curr_state, actions[:, t])
            predictions.append(curr_state)

        return torch.stack(predictions, dim=1)


In [None]:
class WorldModelTrainer:
    """Adam training loop with TensorBoard logging and epoch checkpoints.

    ``model`` must provide ``training_step(batch) -> (loss, predictions,
    component_losses)`` and ``model.encoder.repr_dim``; both are used below.
    """

    # Keys expected in the component-loss dict returned by model.training_step.
    LOSS_KEYS = ('total_loss', 'sim_loss', 'std_loss', 'cov_loss')

    def __init__(self, model, train_loader, val_loader, learning_rate=1e-3,
                 device='cuda', log_dir='runs', checkpoint_dir='checkpoints'):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.checkpoint_dir = checkpoint_dir
        self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        self.global_step = 0  # optimizer steps taken; x-axis for per-batch logging

        # Initialize tensorboard writer
        current_time = datetime.now().strftime('%Y%m%d-%H%M%S')
        self.writer = SummaryWriter(f'{log_dir}/{current_time}')

        # Save hyperparameters
        self.writer.add_text('hyperparameters', f'''
        learning_rate: {learning_rate}
        batch_size: {train_loader.batch_size}
        model_channels: {model.encoder.repr_dim}
        ''')

    def _to_device(self, batch):
        """Rebuild ``batch`` as a WallSample with states/actions on ``self.device``."""
        return WallSample(
            states=batch.states.to(self.device),
            actions=batch.actions.to(self.device),
            locations=batch.locations,
        )

    @staticmethod
    def _mean_losses(sums, num_batches, split):
        """Divide accumulated loss sums by the batch count; fail loudly on an empty loader."""
        if num_batches == 0:
            raise ValueError(f"{split} loader yielded no batches")
        return {k: v / num_batches for k, v in sums.items()}

    def _log_epoch(self, split, losses, epoch):
        """Write each mean loss to TensorBoard under ``<split>/<name>``."""
        for name, value in losses.items():
            self.writer.add_scalar(f'{split}/{name}', value, epoch)

    def validate(self, epoch):
        """Evaluate on ``val_loader`` without gradients.

        Returns:
            float: mean 'total_loss' over validation batches.

        Raises:
            ValueError: if ``val_loader`` yields no batches.
        """
        self.model.eval()
        sums = dict.fromkeys(self.LOSS_KEYS, 0.0)
        num_batches = 0

        with torch.no_grad():
            for batch in self.val_loader:
                _, _, component_losses = self.model.training_step(self._to_device(batch))
                for k in sums:
                    sums[k] += component_losses[k]
                num_batches += 1

        losses = self._mean_losses(sums, num_batches, 'val')
        self._log_epoch('val', losses, epoch)
        return losses['total_loss']

    def train_epoch(self, epoch):
        """Run one optimisation pass over ``train_loader``.

        Returns:
            dict: mean of each component loss over the epoch's batches.

        Raises:
            ValueError: if ``train_loader`` yields no batches.
        """
        self.model.train()
        sums = dict.fromkeys(self.LOSS_KEYS, 0.0)
        num_batches = 0

        for batch in tqdm(self.train_loader, desc=f"Epoch {epoch}"):
            self.optimizer.zero_grad()
            loss, _, component_losses = self.model.training_step(self._to_device(batch))
            # Backward + step are what actually update the weights.
            loss.backward()
            self.optimizer.step()

            for k in sums:
                sums[k] += component_losses[k]
            num_batches += 1
            self.writer.add_scalar('train/batch_total_loss', component_losses['total_loss'], self.global_step)
            self.global_step += 1

        losses = self._mean_losses(sums, num_batches, 'train')
        self._log_epoch('train', losses, epoch)
        return losses

    def save_checkpoint(self, epoch, loss):
        """Save model/optimizer state to ``<checkpoint_dir>/checkpoint_epoch_{epoch}.pt``."""
        import os  # local import: the notebook's import cells do not import os

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': loss,
        }
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        path = os.path.join(self.checkpoint_dir, f'checkpoint_epoch_{epoch}.pt')
        torch.save(checkpoint, path)

    def train(self, num_epochs):
        """Train for ``num_epochs``, validating after each epoch.

        A checkpoint is written whenever the validation loss improves, and
        also every 5 epochs.

        Returns:
            float: best validation loss seen (``inf`` if ``num_epochs`` is 0).
        """
        best_val_loss = float('inf')

        try:
            for epoch in range(num_epochs):
                print(f"\nEpoch {epoch + 1}/{num_epochs}")

                train_losses = self.train_epoch(epoch)
                val_loss = self.validate(epoch)

                # Print epoch summary
                print("\nEpoch Summary:")
                print("Training Losses:")
                for k, v in train_losses.items():
                    print(f"{k}: {v:.4f}")
                print(f"Validation Loss: {val_loss:.4f}")

                # Save best model
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    self.save_checkpoint(epoch, val_loss)
                    print("New best model saved!")

                # Regular checkpoint every 5 epochs
                if (epoch + 1) % 5 == 0:
                    self.save_checkpoint(epoch, val_loss)
                    print(f"Regular checkpoint saved at epoch {epoch + 1}")
        finally:
            # Make the latest scalars visible to TensorBoard even if training fails.
            self.writer.flush()

        return best_val_loss

In [None]:
# Fixed seed so weight initialisation (and anything else drawing on torch's RNG)
# is reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

DATA_PATH = "/scratch/an3854/DL24FA/train"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 128
NUM_EPOCHS = 2

# Initialize model with custom coefficients
model = WorldModelVICReg(
    lambda_param=25.0,  # invariance loss coefficient
    mu_param=25.0,      # variance loss coefficient
    nu_param=1.0        # covariance loss coefficient
)

# Create dataloaders
train_loader = create_wall_dataloader(
    data_path=DATA_PATH,
    probing=False,
    device=DEVICE,
    batch_size=BATCH_SIZE,
    train=True,
    num_samples=10000
)

# NOTE(review): validation reads the same directory as training. Whether these
# 2000 samples are disjoint from the 10000 training samples depends on how
# create_wall_dataloader uses `train` / `num_samples` — confirm. Otherwise the
# validation loss is measured on (partly) seen data.
val_loader = create_wall_dataloader(
    data_path=DATA_PATH,
    probing=False,
    device=DEVICE,
    batch_size=BATCH_SIZE,
    train=False,
    num_samples=2000
)

# Initialize trainer and train (model and data on the same device)
trainer = WorldModelTrainer(model, train_loader, val_loader, device=DEVICE)
trainer.train(num_epochs=NUM_EPOCHS)