In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input directory recursively and print the full path of
# every file found, so the available datasets are visible in the cell output.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
from typing import NamedTuple, Optional
import torch
import numpy as np


class WallSample(NamedTuple):
    """Record returned by ``WallDataset.__getitem__``: one trajectory's tensors."""

    # Observation tensor read from ``states.npy``.
    states: torch.Tensor
    # Tensor read from ``locations.npy``; an empty tensor when the dataset
    # was built with ``probing=False``.
    locations: torch.Tensor
    # Action tensor read from ``actions.npy``.
    actions: torch.Tensor


class WallDataset:
    """Map-style dataset over trajectory arrays stored as ``.npy`` files.

    ``states.npy`` is opened as a read-only memory map and ``actions.npy`` is
    loaded eagerly from ``data_path``. When ``probing`` is true,
    ``locations.npy`` is loaded as well. Indexing returns a ``WallSample`` of
    float32 tensors placed on ``device``.

    Args:
        data_path: Directory containing the ``.npy`` files.
        probing: Whether to load ``locations.npy``.
        device: Device every returned tensor is moved to.
    """

    def __init__(
        self,
        data_path,
        probing=False,
        device="cuda",
    ):
        self.device = device
        # Memory-mapped so the state array is paged in on demand.
        self.states = np.load(f"{data_path}/states.npy", mmap_mode="r")
        self.actions = np.load(f"{data_path}/actions.npy")

        if probing:
            self.locations = np.load(f"{data_path}/locations.npy")
        else:
            self.locations = None

    def __len__(self):
        """Return the number of entries along the first axis of ``states``."""
        return len(self.states)

    def _to_tensor(self, array):
        """Copy ``array`` into a fresh float32 tensor on ``self.device``."""
        # ``np.array`` makes a writable copy. Wrapping a slice of the
        # read-only memmap directly with ``torch.from_numpy`` triggers
        # PyTorch's non-writable-array warning and, on CPU, would hand out a
        # tensor that aliases read-only memory.
        return torch.from_numpy(np.array(array)).float().to(self.device)

    def __getitem__(self, i):
        """Return the ``WallSample`` stored at index ``i``."""
        states = self._to_tensor(self.states[i])
        actions = self._to_tensor(self.actions[i])

        if self.locations is not None:
            locations = self._to_tensor(self.locations[i])
        else:
            # Placeholder so the sample always has three tensor fields.
            locations = torch.empty(0, device=self.device)

        return WallSample(states=states, locations=locations, actions=actions)


def create_wall_dataloader(
    data_path,
    probing=False,
    device="cuda",
    batch_size=64,
    train=True,
):
    """Build a ``DataLoader`` over a ``WallDataset`` rooted at ``data_path``.

    Args:
        data_path: Directory handed to ``WallDataset``.
        probing: Forwarded to ``WallDataset``.
        device: Forwarded to ``WallDataset``.
        batch_size: Number of samples per batch.
        train: When true the loader shuffles; otherwise order is kept.

    Returns:
        A ``torch.utils.data.DataLoader`` that always drops the last
        incomplete batch and does not pin memory.
    """
    dataset = WallDataset(data_path=data_path, probing=probing, device=device)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size,
        shuffle=train,
        drop_last=True,
        pin_memory=False,
    )


In [None]:
from typing import NamedTuple, List, Any, Optional, Dict
from itertools import chain
from dataclasses import dataclass
import itertools
import os
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm
import numpy as np
from matplotlib import pyplot as plt

from schedulers import Scheduler, LRSchedule
from models import Prober, build_mlp
from configs import ConfigBase

from dataset import WallDataset
from normalizer import Normalizer


@dataclass
class ProbingConfig(ConfigBase):
    """Hyper-parameters for training the prober in ``ProbingEvaluator``."""

    # Name of the batch field to probe. NOTE(review): the evaluator in this
    # file reads "locations" directly and does not consult this value.
    probe_targets: str = "locations"
    # Learning rate given to Adam and used as the scheduler's base_lr.
    lr: float = 0.00002
    # Number of passes over the probing training loader.
    epochs: int = 20
    # Learning-rate schedule handed to Scheduler.
    schedule: LRSchedule = LRSchedule.Cosine
    # When set and smaller than the number of predicted timesteps, only this
    # many randomly chosen timesteps per trajectory are used for training.
    sample_timesteps: int = 30
    # Architecture string passed to Prober as its second argument.
    prober_arch: str = "256"


class ProbeResult(NamedTuple):
    """Bundle of a probing run's outputs.

    NOTE(review): nothing in this file constructs or returns this type;
    the field meanings below are inferred from their names only.
    """

    # Presumably the trained prober module.
    model: torch.nn.Module
    # Presumably the evaluation loss averaged over all steps.
    average_eval_loss: float
    # Presumably one evaluation loss per prediction step.
    eval_losses_per_step: List[float]
    # Arbitrary plot objects.
    plots: List[Any]


# Shared default hyper-parameters; used as the default value of the
# ``config`` argument of ``ProbingEvaluator.__init__``.
default_config = ProbingConfig()


def location_losses(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Squared error between ``pred`` and ``target`` averaged over axis 0.

    Both tensors must have identical shapes (checked with ``assert``). The
    result keeps every axis except the first.
    """
    assert pred.shape == target.shape
    squared_error = torch.square(pred - target)
    return torch.mean(squared_error, dim=0)


class ProbingEvaluator:
    """Trains a prober on a frozen world model's predictions and scores it.

    The wrapped ``model`` is switched to eval mode and is never optimised
    here; only the ``Prober`` created in ``train_pred_prober`` is updated.
    """

    def __init__(
        self,
        device: Any,
        model: torch.nn.Module,
        probe_train_ds,
        probe_val_ds: dict,
        config: ProbingConfig = default_config,
        quick_debug: bool = False,
    ):
        """
        Args:
            device: Device for the prober and the probe targets (for example
                ``"cuda"``, ``"cpu"`` or a ``torch.device``).
            model: World model, called as ``model(init_states, actions)``; it
                must expose a ``repr_dim`` attribute.
            probe_train_ds: Iterable of training batches with ``states``,
                ``actions`` and ``locations`` fields; it must also expose a
                ``batch_size`` attribute.
            probe_val_ds: Mapping from split name to an iterable of
                validation batches.
            config: Probing hyper-parameters.
            quick_debug: When true, train for a single epoch and stop after
                three optimisation steps.
        """
        self.device = device
        self.config = config

        self.model = model
        self.model.eval()

        self.quick_debug = quick_debug

        self.ds = probe_train_ds
        self.val_ds = probe_val_ds

        self.normalizer = Normalizer()

    def train_pred_prober(self):
        """
        Probes whether the predicted embeddings capture the future locations.

        Returns:
            The trained prober module.
        """
        repr_dim = self.model.repr_dim
        dataset = self.ds
        model = self.model
        config = self.config

        epochs = 1 if self.quick_debug else config.epochs

        # Peek at one batch only to learn the shape of a single location target.
        test_batch = next(iter(dataset))
        prober_output_shape = test_batch.locations[0, 0].shape
        prober = Prober(
            repr_dim,
            config.prober_arch,
            output_shape=prober_output_shape,
        ).to(self.device)

        optimizer_pred_prober = torch.optim.Adam(list(prober.parameters()), config.lr)

        scheduler = Scheduler(
            schedule=config.schedule,
            base_lr=config.lr,
            data_loader=dataset,
            epochs=epochs,
            optimizer=optimizer_pred_prober,
            batch_steps=None,
            batch_size=dataset.batch_size,
        )

        step = 0

        for _epoch in tqdm(range(epochs), desc="Probe prediction epochs"):
            for batch in tqdm(dataset, desc="Probe prediction step"):
                # Forward pass through the frozen world model. Its output is
                # detached below, so no autograd graph is needed for it.
                init_states = batch.states[:, 0:1]  # BS, 1, C, H, W
                with torch.no_grad():
                    pred_encs, _ = model(init_states, batch.actions)
                # BS, T, D --> T, BS, D
                pred_encs = pred_encs.transpose(0, 1).detach()

                n_steps = pred_encs.shape[0]
                bs = pred_encs.shape[1]

                # Use the configured device instead of assuming CUDA exists.
                target = batch.locations.to(self.device)
                target = self.normalizer.normalize_location(target)

                if (
                    config.sample_timesteps is not None
                    and config.sample_timesteps < n_steps
                ):
                    # Train on a random subset of timesteps per trajectory
                    # (independently drawn for every batch element).
                    sample_shape = (config.sample_timesteps,) + pred_encs.shape[1:]
                    sampled_pred_encs = torch.empty(
                        sample_shape,
                        dtype=pred_encs.dtype,
                        device=pred_encs.device,
                    )
                    # Allocate next to the targets and mirror their trailing
                    # shape rather than hard-coding a CPU buffer of width 2.
                    sampled_target_locs = torch.empty(
                        (bs, config.sample_timesteps) + tuple(target.shape[2:]),
                        dtype=target.dtype,
                        device=target.device,
                    )

                    for i in range(bs):
                        indices = torch.randperm(n_steps)[: config.sample_timesteps]
                        sampled_pred_encs[:, i, :] = pred_encs[indices, i, :]
                        sampled_target_locs[i, :] = target[i, indices]

                    pred_encs = sampled_pred_encs
                    target = sampled_target_locs

                # The prober is applied per timestep; stacking on dim=1 gives BS, T, ...
                pred_locs = torch.stack([prober(x) for x in pred_encs], dim=1)
                losses = location_losses(pred_locs, target)
                per_probe_loss = losses.mean()

                if step % 100 == 0:
                    print(f"normalized pred locations loss {per_probe_loss.item()}")

                optimizer_pred_prober.zero_grad()
                per_probe_loss.backward()
                optimizer_pred_prober.step()

                scheduler.adjust_learning_rate(step)

                step += 1

                if self.quick_debug and step > 2:
                    break

        return prober

    @torch.no_grad()
    def evaluate_all(
        self,
        prober,
    ):
        """
        Evaluates on all the different validation datasets.

        Returns:
            Dict mapping each key of ``probe_val_ds`` to its average loss.
        """
        return {
            prefix: self.evaluate_pred_prober(
                prober=prober,
                val_ds=val_ds,
                prefix=prefix,
            )
            for prefix, val_ds in self.val_ds.items()
        }

    @torch.no_grad()
    def evaluate_pred_prober(
        self,
        prober,
        val_ds,
        prefix="",
    ):
        """
        Computes the average un-normalised probing loss over ``val_ds``.

        Args:
            prober: Trained prober; it is switched to eval mode.
            val_ds: Iterable of validation batches.
            prefix: Split name, used only in the empty-loader error message.

        Returns:
            The mean loss as a Python float.

        Raises:
            RuntimeError: If ``val_ds`` yields no batches.
        """
        model = self.model
        probing_losses = []
        prober.eval()

        for batch in tqdm(val_ds, desc="Eval probe pred"):
            init_states = batch.states[:, 0:1]  # BS, 1, C, H, W
            pred_encs, _ = model(init_states, batch.actions)
            # BS, T, D --> T, BS, D
            pred_encs = pred_encs.transpose(0, 1)

            # Use the configured device instead of assuming CUDA exists.
            target = batch.locations.to(self.device)
            target = self.normalizer.normalize_location(target)

            pred_locs = torch.stack([prober(x) for x in pred_encs], dim=1)
            losses = location_losses(pred_locs, target)
            probing_losses.append(losses.cpu())

        if not probing_losses:
            # torch.stack would fail on an empty list with a less helpful message.
            raise RuntimeError(
                f"validation loader '{prefix}' produced no batches to evaluate"
            )

        losses_t = torch.stack(probing_losses, dim=0).mean(dim=0)
        losses_t = self.normalizer.unnormalize_mse(losses_t)

        losses_t = losses_t.mean(dim=-1)
        return losses_t.mean().item()


In [None]:
import torch
from torchvision import datasets, transforms
import torch.nn as nn
import math
import torch.nn.functional as F
import copy
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from PIL import Image
from copy import deepcopy
import torchvision.models as models
from torchvision.models.resnet import BasicBlock


class CNNEncoder(nn.Module):
    """Small ResNet-style encoder mapping a 2-channel image to a flat embedding.

    Pipeline: 7x7 stride-2 convolution -> batch norm -> ReLU -> 7x7 stride-3
    max-pool -> one residual stage -> flatten -> linear projection. The linear
    layer expects exactly 6400 flattened features.
    """

    def __init__(self, embedding_dim=64):
        super().__init__()
        # Channel count fed to the next residual stage; updated by _make_layer.
        self.in_channels = 64

        # Stem
        self.conv1 = nn.Conv2d(2, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=7, stride=3, padding=1)

        # One residual stage made of a single block
        self.layer1 = self._make_layer(BasicBlock, 64, 1)

        # Projection head (no pooling: the flattened feature map is used as is)
        self.fc = nn.Linear(6400, embedding_dim)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        """Stack ``blocks`` residual blocks; only the first may change shape."""
        shortcut = None
        if not (stride == 1 and self.in_channels == out_channels):
            # 1x1 projection so the skip connection matches the block output.
            shortcut = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        head = block(self.in_channels, out_channels, stride, shortcut)
        self.in_channels = out_channels
        tail = [block(out_channels, out_channels) for _ in range(1, blocks)]

        return nn.Sequential(head, *tail)

    def forward(self, x):
        """Encode a batch of images into ``embedding_dim``-sized vectors."""
        for stage in (self.conv1, self.bn1, self.relu, self.maxpool, self.layer1):
            x = stage(x)
        return self.fc(torch.flatten(x, 1))


class Predictor(nn.Module):
    """Two-layer MLP mapping (representation, action) to a new representation."""

    def __init__(self, representation_dim=64, action_dim=2):
        super().__init__()
        hidden = nn.Linear(representation_dim + action_dim, representation_dim)
        output = nn.Linear(representation_dim, representation_dim)
        self.network = nn.Sequential(hidden, nn.ReLU(), output)

    def forward(self, prev_rep, action):
        """Concatenate ``prev_rep`` and ``action`` on the last axis and run the MLP."""
        return self.network(torch.cat([prev_rep, action], dim=-1))


class JEPAWorldModel(nn.Module):
    """
    Joint Embedding Predictive Architecture (JEPA) world model with a CNN encoder.

    The model holds an online encoder, an EMA-updated target encoder and an
    MLP predictor. Behaviour of ``forward`` depends on ``self.training``:

    * training: the predictor is applied to the online encodings and actions;
    * otherwise: predictions are coordinate features obtained by locating the
      brightest pixel of channel 0 in the first frame and adding each action
      to it (see ``_reshape``).
    """

    def __init__(self, device, representation_dim=64, action_dim=2, training=False):
        """
        Args:
            device: Device stored on the instance for callers' reference.
            representation_dim: Width of the encoder output and predictor.
            action_dim: Width of one action vector.
            training: Initial value of this module's ``training`` flag.
        """
        super().__init__()
        # Size the encoder to the requested width so the predictor input
        # always matches (the default of 64 is unchanged).
        self.encoder = CNNEncoder(embedding_dim=representation_dim)
        self.predictor = Predictor(representation_dim, action_dim)

        # Target encoder starts as a copy of the online encoder.
        self.target_encoder = deepcopy(self.encoder)

        # Synchronize target encoder with main encoder
        self.update_target_encoder()
        self.repr_dim = representation_dim
        self.device = device
        # NOTE: this assigns only this module's own flag; sub-modules keep
        # their current mode until .train() / .eval() is called.
        self.training = training

    @torch.no_grad()
    def update_target_encoder(self, tau=0.995):
        """
        Exponential Moving Average (EMA) update of the target encoder.

        Each target parameter becomes ``tau * target + (1 - tau) * online``.
        Only parameters are blended; buffers are left untouched.
        """
        for param_q, param_k in zip(self.encoder.parameters(), self.target_encoder.parameters()):
            param_k.copy_(param_k * tau + param_q * (1. - tau))

    def forward(self, observations, actions):
        """
        Args:
            observations: Tensor of shape (batch, seq_len, channels, height, width).
            actions: Tensor of actions, one per transition.

        Returns:
            ``(predicted_states, target_states)`` where ``target_states`` are
            the target-encoder embeddings of every frame except the first,
            computed without gradients.
        """
        batch_size, seq_len, channels, height, width = observations.shape
        flat_observations = observations.reshape(-1, channels, height, width)

        # Online encodings, reshaped back to (batch, sequence, features).
        encoded_all_states = self.encoder(flat_observations).view(batch_size, seq_len, -1)
        # Every state except the last is the input for predicting its successor.
        prev_states = encoded_all_states[:, :-1]

        with torch.no_grad():
            target_states = self.target_encoder(flat_observations)
            # Drop the first frame so targets align with the predictions.
            target_states = target_states.view(batch_size, seq_len, -1)[:, 1:]

        if not self.training:
            predicted_states = self._reshape(observations, actions)
        else:
            predicted_states = self.predictor(prev_states, actions)
        return predicted_states, target_states

    def compute_loss(self, predicted_states, target_states):
        """
        Multi-objective loss to prevent representation collapse.

        Returns ``mse(pred, target) + 1e-3 * (variance_loss + covariance_loss)``
        with both regularisers evaluated on ``predicted_states``.
        """
        # 1. Prediction loss: distance between predicted and target states
        pred_loss = F.mse_loss(predicted_states, target_states)

        # 2. Variance loss: keep every feature's spread above a floor
        std_loss = self.variance_loss(predicted_states)

        # 3. Covariance loss: decorrelate representation dimensions
        cov_loss = self.covariance_loss(predicted_states)

        return pred_loss + 1e-3 * (std_loss + cov_loss)

    def variance_loss(self, representations, min_std=0.1):
        """Hinge penalty on the standard deviation over axis 0 falling below ``min_std``."""
        return torch.relu(min_std - representations.std(dim=0)).mean()

    def covariance_loss(self, representations):
        """Sum of squared off-diagonal entries of the feature covariance matrix."""
        # Center over the batch axis
        repr_tensor = representations - representations.mean(dim=0)

        # Flatten everything but the batch axis
        repr_tensor = repr_tensor.view(repr_tensor.shape[0], -1)

        # Compute covariance matrix
        cov_matrix = (repr_tensor.T @ repr_tensor) / (repr_tensor.shape[0] - 1)

        # Only off-diagonal terms are penalised
        cov_matrix.fill_diagonal_(0)

        return (cov_matrix ** 2).sum()

    def _get_next(self, tensor1, tensor2):
        """Add ``tensor2`` to ``tensor1``.

        In training mode the first entry of ``tensor1`` along axis 1 is kept
        unchanged and ``tensor2`` is added to the remaining entries only.
        """
        if not self.training:
            return tensor1 + tensor2

        shifted = tensor1[:, 1:, :] + tensor2
        return torch.cat([tensor1[:, :1, :], shifted], dim=1)

    def _pad_to_repr_dim(self, coords):
        """Zero-pad the last axis of ``coords`` up to ``self.repr_dim``."""
        padding = torch.zeros(
            *coords.shape[:-1],
            self.repr_dim - coords.shape[-1],
            device=coords.device,
        )
        return torch.cat((coords, padding), dim=-1)

    def _reshape(self, observations, actions):
        """Build coordinate features from the brightest pixel of channel 0.

        Coordinates are ``(column, row)`` of the arg-max pixel; the image
        width is taken from ``observations`` rather than assumed.

        * Eval mode: uses the first frame only and returns a tensor of shape
          (batch, actions.shape[1] + 1, repr_dim) holding the initial
          coordinates followed by ``initial + action_i`` for every action.
          NOTE(review): actions are added to the initial position
          individually, not cumulatively - confirm this is intended.
        * Training mode: decodes every frame and returns
          ``_get_next(coords, actions)`` padded to ``repr_dim``.
        """
        batch = observations.shape[0]
        width = observations.shape[-1]

        if not self.training:
            # Initial position is loop-invariant: decode it once from frame 0.
            first_frame = observations[:, 0, 0]
            max_indices = first_frame.reshape(batch, -1).argmax(dim=-1)
            max_coords = torch.stack(
                [max_indices % width, max_indices // width], dim=-1
            ).unsqueeze(1)
            # Broadcasting adds every action to the same initial coordinates.
            next_coords = self._get_next(max_coords, actions)
            return self._pad_to_repr_dim(torch.cat([max_coords, next_coords], dim=1))

        channel_1 = observations[:, :, 0, :, :]
        max_indices = channel_1.reshape(batch, observations.shape[1], -1).argmax(dim=-1)
        max_coords = torch.stack([max_indices % width, max_indices // width], dim=-1)
        return self._pad_to_repr_dim(self._get_next(max_coords, actions))

class DataTransforms:
    """
    Factory for the image augmentation / preprocessing pipeline used in JEPA training.
    """
    @staticmethod
    def get_train_transforms():
        """Return the composed training pipeline.

        Steps, in order: random resized crop to 224 (scale 0.8-1.0), random
        horizontal flip, brightness/contrast jitter, tensor conversion and
        normalisation with three-channel mean/std statistics.
        """
        pipeline = [
            transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
        return transforms.Compose(pipeline)


In [None]:
# from dataset import create_wall_dataloader
# from evaluator import ProbingEvaluator
# import torch
# from models import MockModel
import glob
# from src.models.new_model import JEPAWorldModel


def get_device():
    """Return the CUDA device when available, else the CPU, and print the choice."""
    backend = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(backend)
    print("Using device:", device)
    return device


def load_data(device, data_path="/scratch/DL24FA"):
    """Create the probing train loader and the three validation loaders.

    Args:
        device: Device forwarded to every ``create_wall_dataloader`` call.
        data_path: Root directory of the datasets. Defaults to the original
            hard-coded location so existing callers are unaffected.

    Returns:
        ``(probe_train_ds, probe_val_ds)`` where ``probe_val_ds`` maps
        ``"normal"``, ``"wall"`` and ``"wall_other"`` to their loaders.
    """

    def _probe_loader(subdir, train):
        # Every loader here is a probing loader; only path and shuffling vary.
        return create_wall_dataloader(
            data_path=f"{data_path}/{subdir}",
            probing=True,
            device=device,
            train=train,
        )

    probe_train_ds = _probe_loader("probe_normal/train", train=True)

    probe_val_ds = {
        "normal": _probe_loader("probe_normal/val", train=False),
        "wall": _probe_loader("probe_wall/val", train=False),
        "wall_other": _probe_loader("probe_wall_other/val", train=False),
    }

    return probe_train_ds, probe_val_ds


def load_expert_data(device, data_path="/scratch/DL24FA"):
    """Create the expert probing train loader and its validation loader.

    Args:
        device: Device forwarded to every ``create_wall_dataloader`` call.
        data_path: Root directory of the datasets. Defaults to the original
            hard-coded location so existing callers are unaffected.

    Returns:
        ``(probe_train_expert_ds, probe_val_expert_ds)`` where the second
        item is a dict with the single key ``"expert"``.
    """
    probe_train_expert_ds = create_wall_dataloader(
        data_path=f"{data_path}/probe_expert/train",
        probing=True,
        device=device,
        train=True,
    )

    probe_val_expert_ds = {
        "expert": create_wall_dataloader(
            data_path=f"{data_path}/probe_expert/val",
            probing=True,
            device=device,
            train=False,
        )
    }

    return probe_train_expert_ds, probe_val_expert_ds


def seed_everything(seed: int):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Also exports ``PYTHONHASHSEED`` with the same value into the process
    environment.
    """
    import random, os
    import numpy as np
    import torch

    os.environ['PYTHONHASHSEED'] = str(seed)

    # Same seed for every generator, applied in the original order.
    for apply_seed in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    ):
        apply_seed(seed)

    # Trade cuDNN autotuning for reproducible kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    

def load_model(weights_path="./weights/jepa_world_model_cnn.pth"):
    """Build the JEPA world model and load its trained weights.

    Args:
        weights_path: Checkpoint holding the model's state dict. Defaults to
            the original hard-coded location.

    Returns:
        The model, moved to the device reported by ``get_device``, with the
        checkpoint loaded. The global RNGs are seeded with 40 afterwards.
    """
    # Resolve the device once so it is not printed twice.
    device = get_device()
    model = JEPAWorldModel(device=device, representation_dim=64, action_dim=2).to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    seed_everything(40)
    return model


def evaluate_model(device, model, probe_train_ds, probe_val_ds):
    """Train a prober on ``probe_train_ds`` and print one loss per validation split.

    Builds a ``ProbingEvaluator`` (with ``quick_debug`` off), trains the
    prober, evaluates it on every entry of ``probe_val_ds`` and prints
    ``"<split> loss: <value>"`` for each. Returns nothing.
    """
    evaluator = ProbingEvaluator(
        device=device,
        model=model,
        probe_train_ds=probe_train_ds,
        probe_val_ds=probe_val_ds,
        quick_debug=False,
    )

    trained_prober = evaluator.train_pred_prober()
    split_losses = evaluator.evaluate_all(prober=trained_prober)

    for split_name in split_losses:
        print(f"{split_name} loss: {split_losses[split_name]}")



In [None]:
# Pick the compute device and load the trained world model.
device = get_device()
model = load_model()

# Report the number of parameters that currently require gradients.
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total Trainable Parameters: {total_params:,}")

# Probe on the standard datasets (normal / wall / wall_other validation splits).
probe_train_ds, probe_val_ds = load_data(device)
evaluate_model(device, model, probe_train_ds, probe_val_ds)

# Repeat the probing run on the expert dataset.
probe_train_expert_ds, probe_val_expert_ds = load_expert_data(device)
evaluate_model(device, model, probe_train_expert_ds, probe_val_expert_ds)