In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from tqdm import tqdm
from typing import Tuple, Dict
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.utils.scheduler import cosine_schedule
from dataset import create_wall_dataloader

# Prefer the GPU whenever PyTorch can see one; otherwise stay on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Layout of a single observation frame: 2 channels of 65x65 pixels.
INPUT_SHAPE = (2, 65, 65)
ACTION_DIM = 2      # length of one action vector
STATE_DIM = 256     # encoded state dimension
HIDDEN_DIM = 600    # GRU hidden dimension
BATCH_SIZE = 32     # batch size passed to the training dataloader
MOMENTUM = 0.996    # presumably the momentum-update (EMA) coefficient -- TODO confirm

In [5]:
# Dataloader over the training directory, batched at BATCH_SIZE.
train_loader = create_wall_dataloader("/scratch/an3854/DL24FA/train", batch_size=BATCH_SIZE, train=True)

In [6]:
# Pull a single batch off the training loader so its layout can be inspected.
train_iter = iter(train_loader)
batch = next(train_iter)

# The batch container itself: its Python type and its named-tuple fields.
print("Batch type:", type(batch))
print("\nBatch is a named tuple with fields:", batch._fields)

# Report the same three facts (shape, dtype, device) for each tensor of interest.
for _heading, _tensor in (
    ("\nStates tensor:", batch.states),
    ("\nActions tensor:", batch.actions),
):
    print(_heading)
    print("- Shape:", _tensor.shape)
    print("- Type:", _tensor.dtype)
    print("- Device:", _tensor.device)

Batch type: <class 'dataset.WallSample'>

Batch is a named tuple with fields: ('states', 'locations', 'actions')

States tensor:
- Shape: torch.Size([32, 17, 2, 65, 65])
- Type: torch.float32
- Device: cuda:0

Actions tensor:
- Shape: torch.Size([32, 16, 2])
- Type: torch.float32
- Device: cuda:0


In [36]:
# PROBE
from dataset import create_wall_dataloader
from evaluator import ProbingEvaluator
import torch
import glob

def get_device():
    """Select CUDA when PyTorch reports it available, else CPU.

    Prints the chosen device and returns it as a ``torch.device``.
    """
    if torch.cuda.is_available():
        backend = "cuda"
    else:
        backend = "cpu"
    chosen = torch.device(backend)
    print("Using device:", chosen)
    return chosen


def load_data(device, data_path="/scratch/an3854/DL24FA"):
    """Create the probing dataloaders for training and validation.

    Args:
        device: Passed through unchanged to ``create_wall_dataloader``.
        data_path: Root directory that contains the ``probe_normal/train``,
            ``probe_normal/val`` and ``probe_wall/val`` sub-directories.
            Defaults to the location that used to be hard-coded here, so
            existing ``load_data(device)`` calls behave as before.

    Returns:
        A ``(probe_train_ds, probe_val_ds)`` pair, where ``probe_val_ds`` is a
        dict with the keys ``"normal"`` and ``"wall"`` mapping to the
        corresponding validation loaders.
    """

    def _probe_loader(subdir, train):
        # All three loaders differ only in their sub-directory and train flag.
        return create_wall_dataloader(
            data_path=f"{data_path}/{subdir}",
            probing=True,
            device=device,
            train=train,
        )

    probe_train_ds = _probe_loader("probe_normal/train", train=True)

    probe_val_ds = {
        "normal": _probe_loader("probe_normal/val", train=False),
        "wall": _probe_loader("probe_wall/val", train=False),
    }

    return probe_train_ds, probe_val_ds

In [37]:
def evaluate_model(device, model, probe_train_ds, probe_val_ds):
    """Run the probing evaluation for ``model`` and print one loss per attribute.

    A ``ProbingEvaluator`` is built from the arguments (with
    ``quick_debug=False``), its ``train_pred_prober()`` result is handed to
    ``evaluate_all``, and every ``(attribute, loss)`` pair of the returned
    mapping is printed. Nothing is returned.
    """
    evaluator_kwargs = {
        "device": device,
        "model": model,
        "probe_train_ds": probe_train_ds,
        "probe_val_ds": probe_val_ds,
        "quick_debug": False,
    }
    probing = ProbingEvaluator(**evaluator_kwargs)

    trained_prober = probing.train_pred_prober()
    losses_by_attr = probing.evaluate_all(prober=trained_prober)

    for attr_name, attr_loss in losses_by_attr.items():
        print(f"{attr_name} loss: {attr_loss}")

In [38]:
# Device name kept as a plain string: "cuda" if PyTorch can use a GPU, else "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Probing train loader plus the {"normal", "wall"} validation loaders.
probe_train_ds, probe_val_ds = load_data(device)

In [39]:
class MockModel(torch.nn.Module):
    """
    Does nothing useful: returns random noise of the expected output shape.
    Just for testing.

    Args:
        device: Device on which the random predictions are created.
        bs: Fallback batch size, used only when ``states`` is not a tensor.
        n_steps: Number of time steps ``T`` in the returned predictions.
        output_dim: Representation size ``D``; exposed as ``repr_dim``.
    """

    def __init__(self, device="cuda", bs=64, n_steps=17, output_dim=256):
        super().__init__()
        self.device = device
        self.bs = bs
        self.n_steps = n_steps
        # Honour the constructor argument; it used to be ignored in favour of
        # a hard-coded 256 (which is still the default).
        self.repr_dim = output_dim

    def forward(self, states, actions):
        """
        Args:
            states: [B, 1, Ch, H, W]
            actions: [B, T-1, 2]

        Output:
            predictions: [B, T, D]
        """
        # Follow the incoming batch size so the output really is [B, T, D];
        # fall back to the configured ``bs`` when no tensor is supplied.
        if torch.is_tensor(states) and states.dim() > 0:
            batch_size = states.shape[0]
        else:
            batch_size = self.bs
        # Create the noise directly on the target device instead of copying it.
        return torch.randn(
            (batch_size, self.n_steps, self.repr_dim), device=self.device
        )

In [40]:
# Hand the selected device to the mock as well: its constructor defaults to
# "cuda", which would make forward() fail on a CPU-only machine even though
# `device` resolved to "cpu".
model = MockModel(device=device).to(device)
evaluate_model(device, model, probe_train_ds, probe_val_ds)

Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]

Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 1.1027911901474
normalized pred locations loss 0.986847460269928


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 1.042397141456604


KeyboardInterrupt: 