In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from torch.utils.data import DataLoader
from typing import Tuple
from tqdm import tqdm
from lightly.models.utils import deactivate_requires_grad
from lightly.utils.scheduler import cosine_schedule
from dataset import create_wall_dataloader

# Configurations
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 32
STATE_CHANNELS = 2
STATE_H = 65
STATE_W = 65
ACTION_DIM = 2
SPATIAL_DIM = 8        # Spatial dimension for latent features
OUTPUT_CHANNELS = 64    # Output channels for latent representation
PROJ_DIM = 256
INIT_MOMENTUM = 0.996
LR = 3e-4
NUM_EPOCHS = 100

# Dataloader creation (assuming create_wall_dataloader returns a loader with (current_states, next_states, actions)
train_loader = create_wall_dataloader(
    "/scratch/an3854/DL24FA/train",
    batch_size=BATCH_SIZE,
    train=True, 
)

class SpatialStateEncoder(nn.Module):
    """
    Encodes an input state into a spatial feature map of size [C, spatial_dim, spatial_dim].
    """
    def __init__(self, input_channels=2, output_channels=64, spatial_dim=8):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(input_channels, 32, 4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),

            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
        # After these layers: 65x65 -> ~4x4
        # Map 4x4x256 -> 4x4x64 and then upsample to 8x8 if needed
        self.spatial_predictor = nn.Sequential(
            nn.Conv2d(256, output_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Upsample(size=(spatial_dim, spatial_dim), mode='bilinear', align_corners=False) 
        )
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.encoder(x)  
        x = self.spatial_predictor(x)  
        return x

class ConvDetTransition(nn.Module):
    """
    Convolutional deterministic transition model:
    Given current spatial latent and action, predicts next spatial latent.
    """
    def __init__(self, state_channels=64, action_dim=2, hidden_channels=128):
        super().__init__()
        self.action_proj = nn.Sequential(
            nn.Linear(action_dim, hidden_channels),
            nn.ReLU()
        )
        self.transition = nn.Sequential(
            nn.Conv2d(state_channels + hidden_channels, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, state_channels, kernel_size=3, padding=1)
        )
        
    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        B, C, H, W = state.shape
        action_feat = self.action_proj(action)  # [B, hidden_channels]
        action_feat = action_feat.unsqueeze(-1).unsqueeze(-1)  # [B, hidden_channels, 1, 1]
        action_feat = action_feat.expand(B, action_feat.shape[1], H, W)  # broadcast over spatial dims
        x = torch.cat([state, action_feat], dim=1)
        next_state = self.transition(x)
        return next_state

class ProjectionHead(nn.Module):
    """
    BYOL projection head:
    Flattens spatial representation and projects to a lower-dimensional space.
    """
    def __init__(self, input_channels=64, spatial_dim=8, proj_dim=256):
        super().__init__()
        input_dim = input_channels * spatial_dim * spatial_dim
        self.net = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, proj_dim)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.flatten(start_dim=1)  # [B, C*H*W]
        return self.net(x)

class PredictionHead(nn.Module):
    """
    BYOL prediction head:
    Takes the projection and predicts a representation closer to the target projection.
    """
    def __init__(self, proj_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(proj_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, proj_dim)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

class BYOLConvDet(nn.Module):
    """
    BYOL model using a convolutional deterministic state transition.
    """
    def __init__(self, 
                 state_channels=STATE_CHANNELS, 
                 spatial_dim=SPATIAL_DIM, 
                 output_channels=OUTPUT_CHANNELS, 
                 action_dim=ACTION_DIM, 
                 proj_dim=PROJ_DIM,
                 momentum=INIT_MOMENTUM):
        super().__init__()
        
        # Online networks
        self.online_encoder = SpatialStateEncoder(input_channels=state_channels, output_channels=output_channels, spatial_dim=spatial_dim)
        self.online_transition = ConvDetTransition(state_channels=output_channels, action_dim=action_dim)
        self.online_projection = ProjectionHead(input_channels=output_channels, spatial_dim=spatial_dim, proj_dim=proj_dim)
        self.online_prediction = PredictionHead(proj_dim=proj_dim)
        
        # Target networks
        self.target_encoder = copy.deepcopy(self.online_encoder)
        self.target_projection = copy.deepcopy(self.online_projection)
        deactivate_requires_grad(self.target_encoder)
        deactivate_requires_grad(self.target_projection)
        
        self.momentum = momentum

    @torch.no_grad()
    def update_target(self, m: float):
        """Momentum update for target networks."""
        for o, t in zip(self.online_encoder.parameters(), self.target_encoder.parameters()):
            t.data = t.data * m + o.data * (1. - m)
        for o, t in zip(self.online_projection.parameters(), self.target_projection.parameters()):
            t.data = t.data * m + o.data * (1. - m)

    def forward(self, current_state: torch.Tensor, next_state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Encode current state
        z1 = self.online_encoder(current_state)  # [B, C, H, W]

        # Predict next latent
        z2_pred_spatial = self.online_transition(z1, action)  # [B, C, H, W]

        # Online pipeline
        z2_pred_proj = self.online_projection(z2_pred_spatial)  # [B, proj_dim]
        p1 = self.online_prediction(z2_pred_proj)               # [B, proj_dim]

        # Target pipeline (no grad)
        with torch.no_grad():
            z3_spatial = self.target_encoder(next_state)
            z3_proj = self.target_projection(z3_spatial)  # [B, proj_dim]
        
        return p1, z3_proj

def byol_loss(p: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    p = F.normalize(p, dim=-1)
    z = F.normalize(z, dim=-1)
    return 2 - 2 * (p * z).sum(dim=-1).mean()

model = BYOLConvDet().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

def train_epoch(model: BYOLConvDet, 
                dataloader: DataLoader, 
                optimizer: torch.optim.Optimizer,
                epoch: int,
                total_epochs: int) -> float:
    model.train()
    total_loss = 0.0
    for batch_idx, batch in enumerate(tqdm(dataloader)):
        # batch = (current_states, next_states, actions)
        # current_states: [B, T, C, H, W]
        # next_states:    [B, T, C, H, W]
        # actions:        [B, T, action_dim]
        current_states, next_states, actions = batch
        current_states = current_states.to(DEVICE)
        next_states = next_states.to(DEVICE)
        actions = actions.to(DEVICE)
        
        # Compute current momentum
        m = cosine_schedule(epoch, total_epochs, INIT_MOMENTUM, 1.0)
        
        # We now loop over all time steps T and sum the losses.
        B, T, C, H, W = current_states.shape
        step_loss = 0.0
        for t in range(T):
            cs = current_states[:, t]  # [B, C, H, W]
            ns = next_states[:, t]     # [B, C, H, W]
            a = actions[:, t]          # [B, action_dim]
            
            p1, z3 = model(cs, ns, a)
            l = byol_loss(p1, z3)
            step_loss += l
        
        # Average loss over T steps
        step_loss = step_loss / T
        
        optimizer.zero_grad()
        step_loss.backward()
        optimizer.step()
        
        # Update target networks
        model.update_target(m)
        
        total_loss += step_loss.item()
        
    return total_loss / len(dataloader)

# Training loop
num_epochs = NUM_EPOCHS
for epoch in range(num_epochs):
    avg_loss = train_epoch(model, train_loader, optimizer, epoch, num_epochs)
    print(f"Epoch {epoch}/{num_epochs}, Loss: {avg_loss:.4f}")

# Optionally save the model after training
torch.save(model.state_dict(), "model_weights.pth")
print("Training complete. model_weights.pth saved.")


  0%|          | 0/4594 [00:01<?, ?it/s]


IndexError: index 0 is out of bounds for dimension 1 with size 0

Batch type: <class 'dataset.WallSample'>

Batch is a named tuple with fields: ('states', 'locations', 'actions')

States tensor:
- Shape: torch.Size([32, 17, 2, 65, 65])
- Type: torch.float32
- Device: cuda:0

Actions tensor:
- Shape: torch.Size([32, 16, 2])
- Type: torch.float32
- Device: cuda:0


tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.2752e-43, 2.8403e-39,
        3.4870e-35, 2.3690e-31, 8.9062e-28, 1.8529e-24, 2.1331e-21, 1.3590e-18,
        4.7909e-16, 9.3466e-14, 1.0090e-11, 6.0282e-10, 1.9929e-08, 3.6459e-07,
        3.6910e-06, 2.0678e-05, 6.4104e-05, 1.0997e-04, 1.0440e-04, 5.4847e-05,
        1.5945e-05, 2.5651e-06, 2.2836e-07, 1.1250e-08, 3.0669e-10, 4.6267e-12,
        3.8624e-14, 1.7843e-16, 4.5615e-19, 6.4530e-22, 5.0517e-25, 2.1884e-28,
        5.2463e-32, 6.9598e-36, 5.1091e-40, 2.1019e-44, 0.0000e+00],
       device='cuda:0')

In [25]:
# PROBE
from dataset import create_wall_dataloader
from evaluator import ProbingEvaluator
import torch
import glob

def get_device():
    """Check for GPU availability."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)
    return device


def load_data(device):
    data_path = "/scratch/an3854/DL24FA"

    probe_train_ds = create_wall_dataloader(
        data_path=f"{data_path}/probe_normal/train",
        probing=True,
        device=device,
        train=True,
    )

    probe_val_normal_ds = create_wall_dataloader(
        data_path=f"{data_path}/probe_normal/val",
        probing=True,
        device=device,
        train=False,
    )

    probe_val_wall_ds = create_wall_dataloader(
        data_path=f"{data_path}/probe_wall/val",
        probing=True,
        device=device,
        train=False,
    )

    probe_val_ds = {"normal": probe_val_normal_ds, "wall": probe_val_wall_ds}

    return probe_train_ds, probe_val_ds

In [26]:
def evaluate_model(device, model, probe_train_ds, probe_val_ds):
    evaluator = ProbingEvaluator(
        device=device,
        model=model,
        probe_train_ds=probe_train_ds,
        probe_val_ds=probe_val_ds,
        quick_debug=False,
    )

    prober = evaluator.train_pred_prober()

    avg_losses = evaluator.evaluate_all(prober=prober)

    for probe_attr, loss in avg_losses.items():
        print(f"{probe_attr} loss: {loss}")

In [31]:
device = "cuda" if torch.cuda.is_available() else "cpu"
probe_train_ds, probe_val_ds = load_data(device)

In [32]:
class MockModel(torch.nn.Module):
    """
    Does nothing. Just for testing.
    """

    def __init__(self, device="cuda", bs=64, n_steps=17, output_dim=256):
        super().__init__()
        self.device = device
        self.bs = bs
        self.n_steps = n_steps
        self.repr_dim = 256

    def forward(self, states, actions):
        """
        Args:
            states: [B, 1, Ch, H, W]
            actions: [B, T-1, 2]

        Output:
            predictions: [B, T, D]
        """
        return torch.randn((self.bs, self.n_steps, self.repr_dim)).to(self.device)

In [35]:
model = MockModel().to(device)
evaluate_model(device, model, probe_train_ds, probe_val_ds)

Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]

Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 1.1928168535232544
normalized pred locations loss 1.043358564376831


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.8633273839950562


KeyboardInterrupt: 