In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    """Three-stage conv encoder: [B, input_channels, 65, 65] -> [B, 32, 8, 8].

    Each stage is Conv2d (no bias) -> BatchNorm2d -> ReLU. The spatial sizes noted
    below assume a 65x65 input.
    """

    def __init__(self, input_channels=2):
        super().__init__()
        # Registration order (conv1, bn1, conv2, bn2, conv3, bn3) defines the
        # state_dict keys that saved checkpoints rely on.
        # Stage 1: 65x65 -> 22x22
        self.conv1 = nn.Conv2d(in_channels=input_channels, out_channels=16,
                               kernel_size=7, stride=3, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(num_features=16)
        # Stage 2: 22x22 -> 8x8
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32,
                               kernel_size=5, stride=3, padding=2, bias=False)
        self.bn2 = nn.BatchNorm2d(num_features=32)
        # Stage 3: keeps 8x8, mixes channel features
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=32,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(num_features=32)

        # Channel count of the output feature map.
        self.repr_dim = 32

    def forward(self, x):
        """Encode a batch of frames; returns a [B, 32, 8, 8] feature map for 65x65 input."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in stages:
            x = F.relu(norm(conv(x)))
        return x

class TransitionModel(nn.Module):
    """Residual one-step latent dynamics: next_state = state + delta(state, action).

    The 2-D action is broadcast over every spatial location, lifted to
    ``hidden_dim`` channels by 1x1 convs, concatenated with the state, and passed
    through two 3x3 convs that produce ``delta``.
    """

    def __init__(self, hidden_dim=32):
        super().__init__()
        self.hidden_dim = hidden_dim
        mid_channels = 16

        # Sequential indices are part of the state_dict keys, so the layer order
        # inside each Sequential must stay exactly as is.
        # 1x1 convs: (dx, dy) map -> mid_channels -> hidden_dim
        self.action_embed = nn.Sequential(
            nn.Conv2d(2, mid_channels, 1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(),
            nn.Conv2d(mid_channels, hidden_dim, 1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
        )

        # 3x3 convs over [state, action_embedding] -> state update (no final ReLU)
        self.transition = nn.Sequential(
            nn.Conv2d(2 * hidden_dim, hidden_dim, 3, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.BatchNorm2d(hidden_dim),
        )

    def forward(self, state, action):
        """
        Args:
            state: [B, hidden_dim, H, W] current latent state
            action: [B, 2] (dx, dy) action
        Returns:
            [B, hidden_dim, H, W] predicted next latent state
        """
        batch, _, height, width = state.shape
        action_map = action.view(batch, 2, 1, 1).expand(-1, -1, height, width)
        features = torch.cat((state, self.action_embed(action_map)), dim=1)
        return state + self.transition(features)

class WorldModel(nn.Module):
    """Encoder + TransitionModel rolled out over an action sequence.

    Only the first frame is encoded; every later latent is the predictor's output
    for the previous latent and the next action.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder(input_channels=2)
        self.predictor = TransitionModel(hidden_dim=32)

    def forward(self, states, actions):
        """
        Args:
            states: [B, 1, 2, 65, 65] - Initial state only
            actions: [B, T-1, 2] - Sequence of T-1 actions
        Returns:
            predictions: [B, T, 32, 8, 8] - Predicted representations; index 0 is
            the encoding of ``states`` itself.
        Raises:
            ValueError: if ``states`` is not 5-D with a singleton second (time) axis.
        """
        # Explicit check instead of relying on tuple-unpacking / a conv error deep
        # inside the encoder to reject a wrongly shaped input.
        if states.dim() != 5 or states.shape[1] != 1:
            raise ValueError(
                f"states must have shape [B, 1, C, H, W], got {tuple(states.shape)}"
            )

        curr_state = self.encoder(states[:, 0])
        predictions = [curr_state]

        # Autoregressive rollout: each step feeds the previous prediction back in.
        for t in range(actions.shape[1]):
            curr_state = self.predictor(curr_state, actions[:, t])
            predictions.append(curr_state)

        return torch.stack(predictions, dim=1)

In [7]:
import torch
from dataset import create_wall_dataloader

def test_model_with_real_data(
    data_path="/drive_reader/as16386/DL24FA/train",
    batch_size=256,
    num_rollout_batches=3,
):
    """Smoke-test the Encoder, TransitionModel and WorldModel on real training data.

    Prints input/output shapes and value ranges for one batch and asserts the
    expected latent shapes. It then runs ``WorldModel`` on a few batches without
    gradients and prints prediction statistics for the first one.

    Args:
        data_path: Directory handed to ``create_wall_dataloader``.
        batch_size: Loader batch size.
        num_rollout_batches: How many batches to push through ``WorldModel`` in
            the final loop.

    Returns:
        Tuple ``(world_model, dataloader)``.

    Raises:
        AssertionError: If an encoder output is not ``[B, 32, 8, 8]`` or the
            rollout is not ``[B, T, 32, 8, 8]``.
    """
    # Models and data must live on the same device. The loader falls back to CPU
    # when CUDA is missing, so the models must follow instead of being pinned to 'cuda'.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Three independently initialised models: encoder/predictor are exercised
    # step by step, world_model runs the full rollout.
    encoder = Encoder(input_channels=2).to(device)
    predictor = TransitionModel(hidden_dim=32).to(device)
    world_model = WorldModel().to(device)

    dataloader = create_wall_dataloader(
        data_path=data_path,
        probing=False,
        device=device,
        batch_size=batch_size,
        train=True,
    )

    batch = next(iter(dataloader))
    states = batch.states    # [B, T, C, H, W]
    actions = batch.actions  # [B, T-1, 2]

    print("\nInput Data Shapes:")
    print(f"States shape: {states.shape}")    # expected [B, T, 2, 65, 65]
    print(f"Actions shape: {actions.shape}")  # expected [B, T-1, 2]
    print(f"States min/max: {states.min():.2f}, {states.max():.2f}")
    print(f"Actions min/max: {actions.min():.2f}, {actions.max():.2f}")

    # These are shape checks only, so no autograd graph is needed.
    with torch.no_grad():
        # Encoder on the first timestep.
        encoded = encoder(states[:, 0])
        print(f"\nEncoder output shape: {encoded.shape}")  # expected [B, 32, 8, 8]

        # Predictor inputs: first action of each trajectory.
        single_action = actions[:, 0]
        print("\nSingle action sample:")
        print(f"Action values: {single_action[0]}")

        # Reproduce the spatial broadcast that TransitionModel.forward performs.
        B, _, h, w = encoded.shape
        action_spatial = single_action.view(B, 2, 1, 1).expand(-1, -1, h, w)
        print("\nAction processing:")
        print(f"Action after spatial expansion: {action_spatial.shape}")  # expected [B, 2, 8, 8]

        action_embedded = predictor.action_embed(action_spatial)
        print(f"Action after embedding: {action_embedded.shape}")  # expected [B, 32, 8, 8]

        next_state = predictor(encoded, single_action)
        print("\nPrediction shapes:")
        print(f"Single step prediction: {next_state.shape}")  # expected [B, 32, 8, 8]

        # Full rollout from the first frame only.
        predictions = world_model(states[:, 0:1], actions)
        print(f"Full sequence prediction: {predictions.shape}")  # expected [B, T, 32, 8, 8]

    # Per-frame element counts: C*H*W of the input vs. of the latent.
    input_size = states.shape[-2] * states.shape[-1] * states.shape[-3]
    encoded_size = encoded.shape[-2] * encoded.shape[-1] * encoded.shape[-3]
    compression_ratio = input_size / encoded_size
    print("\nCompression Statistics:")
    print(f"Input size per frame: {input_size} values")
    print(f"Encoded size per frame: {encoded_size} values")
    print(f"Compression ratio: {compression_ratio:.2f}x")

    assert encoded.shape[1] == 32, "Channel dimension mismatch"
    assert encoded.shape[2] == encoded.shape[3] == 8, "Spatial dimensions mismatch"
    assert predictions.shape[1] == actions.shape[1] + 1, "Sequence length mismatch"
    assert predictions.shape[2] == 32, "Output channel mismatch"
    assert predictions.shape[3] == predictions.shape[4] == 8, "Output spatial dimension mismatch"

    print("\nAll dimension tests passed!")

    print("\nTesting multiple forward passes...")
    for i, batch in enumerate(dataloader):
        if i >= num_rollout_batches:
            break
        with torch.no_grad():
            predictions = world_model(batch.states[:, 0:1], batch.actions)
        print(f"Batch {i+1} processed successfully.")

        # Summary statistics for the first batch only.
        if i == 0:
            print(f"Prediction stats for batch {i+1}:")
            print(f"Mean: {predictions.mean():.3f}")
            print(f"Std: {predictions.std():.3f}")
            print(f"Min: {predictions.min():.3f}")
            print(f"Max: {predictions.max():.3f}")

    return world_model, dataloader

if __name__ == "__main__":
    # In a notebook kernel __name__ is "__main__", so this runs when the cell executes.
    print("Running model tests with real data...")
    model, dataloader = test_model_with_real_data()

Running model tests with real data...

Input Data Shapes:
States shape: torch.Size([256, 17, 2, 65, 65])
Actions shape: torch.Size([256, 16, 2])
States min/max: 0.00, 0.09
Actions min/max: -1.79, 1.77

Encoder output shape: torch.Size([256, 32, 8, 8])

Single action sample:
Action values: tensor([-0.6135, -0.4304], device='cuda:0')

Action processing:
Action after spatial expansion: torch.Size([256, 2, 8, 8])
Action after embedding: torch.Size([256, 32, 8, 8])

Prediction shapes:
Single step prediction: torch.Size([256, 32, 8, 8])
Full sequence prediction: torch.Size([256, 17, 32, 8, 8])

Compression Statistics:
Input size per frame: 8450 values
Encoded size per frame: 2048 values
Compression ratio: 4.13x

All dimension tests passed!

Testing multiple forward passes...
Batch 1 processed successfully.
Prediction stats for batch 1:
Mean: 0.391
Std: 8.042
Min: -131.621
Max: 139.406
Batch 2 processed successfully.
Batch 3 processed successfully.


In [8]:
# Necessary functions to test prober too
from dataset import create_wall_dataloader
from evaluator import ProbingEvaluator
import torch
import glob

def get_device():
    """Return CUDA if available, otherwise CPU, as a ``torch.device``; prints the choice."""
    name = "cpu"
    if torch.cuda.is_available():
        name = "cuda"
    chosen = torch.device(name)
    print("Using device:", chosen)
    return chosen


def load_data(device, data_path="/drive_reader/as16386/DL24FA"):
    """Build the probing dataloaders.

    Args:
        device: Device forwarded to ``create_wall_dataloader``.
        data_path: Root directory containing ``probe_normal/train``,
            ``probe_normal/val`` and ``probe_wall/val``. Defaults to the original
            hard-coded location.

    Returns:
        ``(probe_train_ds, probe_val_ds)`` where ``probe_val_ds`` maps ``"normal"``
        and ``"wall"`` to their validation loaders.
    """
    def _probe_loader(split_dir, train):
        # All probing loaders share these settings; only the split and train flag differ.
        return create_wall_dataloader(
            data_path=f"{data_path}/{split_dir}",
            probing=True,
            device=device,
            train=train,
        )

    probe_train_ds = _probe_loader("probe_normal/train", train=True)
    probe_val_ds = {
        "normal": _probe_loader("probe_normal/val", train=False),
        "wall": _probe_loader("probe_wall/val", train=False),
    }
    return probe_train_ds, probe_val_ds

def create_train_val_loaders(data_path, train_samples=10000, val_samples=2000,
                             device="cuda", batch_size=128):
    """Create the ``(train_loader, val_loader)`` pair for world-model training.

    Args:
        data_path: Directory passed to ``create_wall_dataloader`` for both loaders.
        train_samples: ``num_samples`` for the training loader.
        val_samples: ``num_samples`` for the validation loader (forwarded as-is,
            including ``None``).
        device: Device forwarded to both loaders (default keeps the original "cuda").
        batch_size: Batch size for both loaders.

    NOTE(review): both loaders read the same ``data_path`` and differ only in
    ``train`` and ``num_samples``. Whether the validation samples are disjoint
    from the training samples depends on ``create_wall_dataloader`` — confirm.
    """
    common = dict(data_path=data_path, probing=False, device=device, batch_size=batch_size)
    train_loader = create_wall_dataloader(**common, train=True, num_samples=train_samples)
    val_loader = create_wall_dataloader(**common, train=False, num_samples=val_samples)
    return train_loader, val_loader

def evaluate_model(device, model, probe_train_ds, probe_val_ds, quick_debug=False):
    """Train a location prober on ``model``'s predictions and report its losses.

    Args:
        device: Device for ``ProbingEvaluator``.
        model: World model under evaluation.
        probe_train_ds: Loader used to train the prober.
        probe_val_ds: Mapping of split name -> validation loader.
        quick_debug: Forwarded to ``ProbingEvaluator``.

    Returns:
        The mapping returned by ``ProbingEvaluator.evaluate_all`` (one loss per
        validation split), so callers can log or compare it.
    """
    evaluator = ProbingEvaluator(
        device=device,
        model=model,
        probe_train_ds=probe_train_ds,
        probe_val_ds=probe_val_ds,
        quick_debug=quick_debug,
    )

    prober = evaluator.train_pred_prober()

    avg_losses = evaluator.evaluate_all(prober=prober)

    for probe_attr, loss in avg_losses.items():
        print(f"{probe_attr} loss: {loss}")

    return avg_losses

# Pick the compute device once and build the probing loaders on it.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
probe_train_ds, probe_val_ds = load_data(device)


In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from datetime import datetime
from tqdm import tqdm
from dataset import WallSample

class VICRegLoss(nn.Module):
    """VICReg objective between two batches of embeddings.

    total = lambda_param * invariance + mu_param * variance + nu_param * covariance,
    where the variance and covariance terms are summed over both branches.
    """

    def __init__(self, lambda_param=25.0, mu_param=25.0, nu_param=1.0):
        super().__init__()
        # Weights of the invariance, variance and covariance terms, in that order.
        self.lambda_param = lambda_param
        self.mu_param = mu_param
        self.nu_param = nu_param

    def off_diagonal(self, x):
        """Return every element of square matrix ``x`` except the diagonal, row-major."""
        size = x.shape[0]
        # Dropping the last element and viewing as (n-1, n+1) lines the diagonal
        # entries up in column 0, which is then sliced away.
        trimmed = x.flatten()[:-1]
        return trimmed.view(size - 1, size + 1)[:, 1:].flatten()

    def forward(self, z_a, z_b):
        """
        Args:
            z_a, z_b: Batch of representations [N, D]
        Returns:
            total_loss: Combined VICReg loss (tensor)
            losses: dict of Python floats keyed 'sim_loss', 'std_loss',
                'cov_loss', 'total_loss'
        """
        n_samples, n_features = z_a.shape[0], z_a.shape[1]

        def variance_term(z):
            # Hinge pushing each feature's (eps-stabilised) std up to at least 1.
            return torch.mean(F.relu(1 - torch.sqrt(z.var(dim=0) + 1e-04)))

        def covariance_term(z):
            # Squared off-diagonal covariance, normalised by the feature count.
            centered = z - z.mean(dim=0)
            cov = (centered.T @ centered) / (n_samples - 1)
            return self.off_diagonal(cov).pow_(2).sum() / n_features

        sim_loss = F.mse_loss(z_a, z_b)
        std_loss = variance_term(z_a) + variance_term(z_b)
        cov_loss = covariance_term(z_a) + covariance_term(z_b)

        total_loss = (self.lambda_param * sim_loss
                      + self.mu_param * std_loss
                      + self.nu_param * cov_loss)

        losses = {
            'sim_loss': sim_loss.item(),
            'std_loss': std_loss.item(),
            'cov_loss': cov_loss.item(),
            'total_loss': total_loss.item(),
        }
        return total_loss, losses

class WorldModelVICReg(nn.Module):
    """World model trained with VICReg between predicted and encoded latents.

    The encoder embeds the first frame and the predictor rolls that latent forward
    one action at a time. Each predicted latent is compared, via ``VICRegLoss``,
    with the encoder's embedding of the true frame at the same step.
    """

    def __init__(self, lambda_param=25.0, mu_param=25.0, nu_param=1.0):
        super().__init__()
        self.encoder = Encoder(input_channels=2)
        self.predictor = TransitionModel(hidden_dim=32)
        self.criterion = VICRegLoss(lambda_param, mu_param, nu_param)
        # Flattened latent size: 32 channels x 8 x 8 (the encoder's output for 65x65 frames).
        self.repr_dim = 32 * 8 * 8

    def compute_vicreg_loss(self, pred_state, target_obs):
        """Compute VICReg loss between predicted and encoded target states.

        Returns ``(total_loss, component_losses)`` from the criterion.
        """
        # Gradients flow into the encoder through the target branch as well.
        target_state = self.encoder(target_obs)

        # Flatten spatial dimensions: [B, 32, 8, 8] -> [B, 2048]
        pred_flat = pred_state.flatten(start_dim=1)
        target_flat = target_state.flatten(start_dim=1)

        total_loss, component_losses = self.criterion(pred_flat, target_flat)
        return total_loss, component_losses

    def training_step(self, batch):
        """Roll out from the first frame and average VICReg losses over all predicted steps.

        Args:
            batch: object with ``states`` [B, T, C, H, W] and ``actions`` [B, T-1, 2].
        Returns:
            (mean loss tensor, predictions [B, T, ...], dict of mean component losses)
        Raises:
            ValueError: if the batch contains no actions (nothing to predict).
        """
        states = batch.states
        actions = batch.actions
        num_steps = actions.shape[1]
        if num_steps == 0:
            raise ValueError("training_step needs at least one action per trajectory")

        predictions = self.forward_prediction(states[:, 0:1], actions)

        total_loss = 0.0
        accumulated_losses = {
            'sim_loss': 0.0,
            'std_loss': 0.0,
            'cov_loss': 0.0,
            'total_loss': 0.0
        }

        # Step t's prediction (index t+1) is compared with the frame at t+1.
        for t in range(num_steps):
            loss, component_losses = self.compute_vicreg_loss(predictions[:, t + 1], states[:, t + 1])
            total_loss += loss
            for k in accumulated_losses:
                accumulated_losses[k] += component_losses[k]

        for k in accumulated_losses:
            accumulated_losses[k] /= num_steps

        return total_loss / num_steps, predictions, accumulated_losses

    def forward_prediction(self, states, actions):
        """
        Forward pass for prediction of future states
        Args:
            states: [B, 1, 2, 65, 65] - Initial state only
            actions: [B, T-1, 2] - Sequence of T-1 actions
        Returns:
            predictions: [B, T, 32, 8, 8] - Predicted representations; index 0 is
            the encoding of ``states`` itself.
        Raises:
            ValueError: if ``states`` is not 5-D with a singleton second (time) axis.
        """
        if states.dim() != 5 or states.shape[1] != 1:
            raise ValueError(
                f"states must have shape [B, 1, C, H, W], got {tuple(states.shape)}"
            )

        curr_state = self.encoder(states[:, 0])
        predictions = [curr_state]

        for t in range(actions.shape[1]):
            curr_state = self.predictor(curr_state, actions[:, t])
            predictions.append(curr_state)

        return torch.stack(predictions, dim=1)

    def forward(self, states, actions):
        """Alias for ``forward_prediction`` so ``model(states, actions)`` works like any nn.Module."""
        return self.forward_prediction(states, actions)


In [10]:
import os 


class WorldModelTrainer:
    """Training loop for a model exposing ``training_step(batch)``.

    Each epoch: optimise on ``train_loader``, optionally run the location probe
    (``evaluate_model``), compute validation losses, log both to TensorBoard, and
    checkpoint whenever the validation loss improves.
    """

    def __init__(self, model, train_loader, val_loader, learning_rate=3e-5,
                 device='cuda', log_dir='runs', probe_train_data=None, probe_val_data=None,
                 checkpoint_dir='cnn_based_checkpoints', probe_every=1):
        """
        Args:
            model: Module with ``training_step(batch) -> (loss, predictions, component_losses)``
                and an ``encoder.repr_dim`` attribute (logged as a hyperparameter).
            train_loader, val_loader: Iterables of batches with ``states``, ``actions``
                and ``locations``; ``train_loader.batch_size`` is logged.
            learning_rate: Adam learning rate (weight decay fixed at 1e-5).
            device: Device the model and batch tensors are moved to.
            log_dir: Parent directory of the timestamped TensorBoard run.
            probe_train_data, probe_val_data: Probing loaders passed to ``evaluate_model``.
                ``None`` falls back to the notebook globals ``probe_train_ds`` /
                ``probe_val_ds`` at construction time (not at class-definition time).
            checkpoint_dir: Directory ``save_checkpoint`` writes into; created if missing.
            probe_every: Run the probe after every ``probe_every``-th epoch; 0/None disables it.
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
        self.probe_train_data = probe_train_ds if probe_train_data is None else probe_train_data
        self.probe_val_data = probe_val_ds if probe_val_data is None else probe_val_data
        self.checkpoint_dir = checkpoint_dir
        self.probe_every = probe_every

        # One TensorBoard run directory per trainer instance.
        current_time = datetime.now().strftime('%Y%m%d-%H%M%S')
        self.writer = SummaryWriter(f'{log_dir}/{current_time}')

        self.writer.add_text('hyperparameters', f'''
        learning_rate: {learning_rate}
        batch_size: {train_loader.batch_size}
        model_channels: {model.encoder.repr_dim}
        ''')

    def _compute_losses(self, batch):
        """Move ``batch`` to the trainer device and return ``(loss, component_losses)``."""
        states = batch.states.to(self.device)
        actions = batch.actions.to(self.device)
        loss, _, component_losses = self.model.training_step(
            WallSample(states=states, actions=actions, locations=batch.locations)
        )
        return loss, component_losses

    def _should_probe(self, epoch):
        """Whether the location probe runs after 0-indexed ``epoch``."""
        return bool(self.probe_every) and (epoch + 1) % self.probe_every == 0

    def validate(self, epoch):
        """Average losses over ``val_loader`` without gradients.

        Returns:
            (mean total loss, dict of mean component losses)
        """
        self.model.eval()
        total_val_loss = 0.0
        val_losses = {
            'total_loss': 0.0,
            'sim_loss': 0.0,
            'std_loss': 0.0,
            'cov_loss': 0.0
        }

        with torch.no_grad():
            for batch in self.val_loader:
                loss, component_losses = self._compute_losses(batch)
                total_val_loss += loss.item()
                for k in val_losses:
                    val_losses[k] += component_losses[k]

        num_batches = len(self.val_loader)
        total_val_loss /= num_batches
        for k in val_losses:
            val_losses[k] /= num_batches

        return total_val_loss, val_losses

    def train_epoch(self, epoch):
        """One optimisation pass over ``train_loader``, then the probe if it is due.

        Returns:
            (mean total loss, dict of mean component losses)
        """
        self.model.train()
        total_train_loss = 0.0
        train_losses = {
            'total_loss': 0.0,
            'sim_loss': 0.0,
            'std_loss': 0.0,
            'cov_loss': 0.0
        }

        for batch in tqdm(self.train_loader, desc=f"Epoch {epoch}"):
            self.optimizer.zero_grad()
            loss, component_losses = self._compute_losses(batch)

            loss.backward()
            # Clip the global gradient norm to 1 before stepping.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            total_train_loss += loss.item()
            for k in train_losses:
                train_losses[k] += component_losses[k]

        num_batches = len(self.train_loader)
        total_train_loss /= num_batches
        for k in train_losses:
            train_losses[k] /= num_batches

        # Probing trains a separate prober and is expensive; throttle via probe_every.
        if self._should_probe(epoch):
            evaluate_model(self.device, self.model, self.probe_train_data, self.probe_val_data)

        return total_train_loss, train_losses

    def save_checkpoint(self, epoch, loss):
        """Save model/optimizer state to ``<checkpoint_dir>/checkpoint_epoch_<epoch>.pt``."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': loss,
        }
        # Create the directory actually written into (the save path and the
        # makedirs target must agree, otherwise torch.save fails on a fresh machine).
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        path = os.path.join(self.checkpoint_dir, f'checkpoint_epoch_{epoch}.pt')
        torch.save(checkpoint, path)

    def train(self, num_epochs):
        """Train for ``num_epochs`` epochs, checkpointing whenever validation loss improves."""
        best_val_loss = float('inf')

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch + 1}/{num_epochs}")

            train_loss, train_losses = self.train_epoch(epoch)
            val_loss, val_losses = self.validate(epoch)

            # Per-epoch loss curves in TensorBoard.
            for split, losses in (('train', train_losses), ('val', val_losses)):
                for k, v in losses.items():
                    self.writer.add_scalar(f'{split}/{k}', v, epoch)
            self.writer.flush()

            print("\nEpoch Summary:")
            print("Training Losses:")
            for k, v in train_losses.items():
                print(f"{k}: {v:.4f}")
            print("\nValidation Losses:")
            for k, v in val_losses.items():
                print(f"{k}: {v:.4f}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self.save_checkpoint(epoch, val_loss)
                print("New best model saved!")
                print("New best model saved!")

In [11]:
def create_train_val_loaders(data_path, train_samples=10000, val_samples=2000,
                             device="cuda", batch_size=128):
    """Create the ``(train_loader, val_loader)`` pair for world-model training.

    Args:
        data_path: Directory passed to ``create_wall_dataloader`` for both loaders.
        train_samples: ``num_samples`` for the training loader.
        val_samples: ``num_samples`` for the validation loader (forwarded as-is,
            including ``None``).
        device: Device forwarded to both loaders (default keeps the original "cuda").
        batch_size: Batch size for both loaders.

    NOTE(review): both loaders read the same ``data_path`` and differ only in
    ``train`` and ``num_samples``. Whether the validation samples are disjoint
    from the training samples depends on ``create_wall_dataloader`` — confirm.
    """
    common = dict(data_path=data_path, probing=False, device=device, batch_size=batch_size)
    train_loader = create_wall_dataloader(**common, train=True, num_samples=train_samples)
    val_loader = create_wall_dataloader(**common, train=False, num_samples=val_samples)
    return train_loader, val_loader

In [12]:
# Initialize model and dataloaders
# Run configuration for this training session.
VICREG_WEIGHTS = {"lambda_param": 25.0, "mu_param": 25.0, "nu_param": 1.0}
TRAIN_DATA_PATH = "/drive_reader/as16386/DL24FA/train"
NUM_EPOCHS = 100

model = WorldModelVICReg(**VICREG_WEIGHTS)

train_loader, val_loader = create_train_val_loaders(
    data_path=TRAIN_DATA_PATH,
    train_samples=50000,
    val_samples=None,  # forwarded to create_wall_dataloader as num_samples=None
)

# Initialize trainer and train
trainer = WorldModelTrainer(model, train_loader, val_loader)
trainer.train(num_epochs=NUM_EPOCHS)


Epoch 1/100


Epoch 0: 100%|██████████| 390/390 [03:52<00:00,  1.68it/s]
Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]

normalized pred locations loss 1.097769856452942


Probe prediction step:  55%|█████▌    | 86/156 [00:55<00:44,  1.56it/s]
Probe prediction epochs:   0%|          | 0/20 [00:55<?, ?it/s]


KeyboardInterrupt: 

: 

In [4]:
# BEST MODEL SO FAR!
# NOTE(review): this cell reads from 'checkpoints/' while the evaluation cells below
# read from 'cnn_based_checkpoints/' — confirm which run this checkpoint belongs to.
model = WorldModelVICReg().to(device)
# weights_only=True restricts unpickling to tensors/primitive containers (the saved
# dict holds only state dicts, an int epoch and a float loss); map_location lets a
# GPU-saved checkpoint load on whichever device is active.
checkpoint = torch.load('checkpoints/checkpoint_epoch_0.pt', map_location=device, weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])

  model.load_state_dict(torch.load('checkpoints/checkpoint_epoch_0.pt')['model_state_dict'])


<All keys matched successfully>

In [48]:
from dataset import create_wall_dataloader
from evaluator import ProbingEvaluator
import torch
import glob

def get_device():
    """Return CUDA if available, otherwise CPU, as a ``torch.device``; prints the choice."""
    name = "cpu"
    if torch.cuda.is_available():
        name = "cuda"
    chosen = torch.device(name)
    print("Using device:", chosen)
    return chosen


def load_data(device, data_path="/drive_reader/as16386/DL24FA"):
    """Build the probing dataloaders.

    Args:
        device: Device forwarded to ``create_wall_dataloader``.
        data_path: Root directory containing ``probe_normal/train``,
            ``probe_normal/val`` and ``probe_wall/val``. Defaults to the original
            hard-coded location.

    Returns:
        ``(probe_train_ds, probe_val_ds)`` where ``probe_val_ds`` maps ``"normal"``
        and ``"wall"`` to their validation loaders.
    """
    def _probe_loader(split_dir, train):
        # All probing loaders share these settings; only the split and train flag differ.
        return create_wall_dataloader(
            data_path=f"{data_path}/{split_dir}",
            probing=True,
            device=device,
            train=train,
        )

    probe_train_ds = _probe_loader("probe_normal/train", train=True)
    probe_val_ds = {
        "normal": _probe_loader("probe_normal/val", train=False),
        "wall": _probe_loader("probe_wall/val", train=False),
    }
    return probe_train_ds, probe_val_ds

In [49]:
def evaluate_model(device, model, probe_train_ds, probe_val_ds, quick_debug=False):
    """Train a location prober on ``model``'s predictions and report its losses.

    Args:
        device: Device for ``ProbingEvaluator``.
        model: World model under evaluation.
        probe_train_ds: Loader used to train the prober.
        probe_val_ds: Mapping of split name -> validation loader.
        quick_debug: Forwarded to ``ProbingEvaluator``.

    Returns:
        The mapping returned by ``ProbingEvaluator.evaluate_all`` (one loss per
        validation split), so callers can log or compare it.
    """
    evaluator = ProbingEvaluator(
        device=device,
        model=model,
        probe_train_ds=probe_train_ds,
        probe_val_ds=probe_val_ds,
        quick_debug=quick_debug,
    )

    prober = evaluator.train_pred_prober()

    avg_losses = evaluator.evaluate_all(prober=prober)

    for probe_attr, loss in avg_losses.items():
        print(f"{probe_attr} loss: {loss}")

    return avg_losses

In [50]:
# Pick the compute device once and build the probing loaders on it.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
probe_train_ds, probe_val_ds = load_data(device)

In [51]:
# Probe evaluation of the epoch-0 checkpoint.
model = WorldModelVICReg().to(device)
# weights_only=True restricts unpickling to tensors/primitive containers;
# map_location lets a GPU-saved checkpoint load on the active device.
checkpoint = torch.load('cnn_based_checkpoints/checkpoint_epoch_0.pt', map_location=device, weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])
probe_losses = evaluate_model(device, model, probe_train_ds, probe_val_ds)

  model.load_state_dict(torch.load('cnn_based_checkpoints/checkpoint_epoch_0.pt')['model_state_dict'])
Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s][A

normalized pred locations loss 1.061932921409607




normalized pred locations loss 1.007510781288147


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 60.93it/s]
Probe prediction epochs:   5%|▌         | 1/20 [00:02<00:48,  2.56s/it]

normalized pred locations loss 0.9798067808151245


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.56it/s]
Probe prediction epochs:  10%|█         | 2/20 [00:04<00:43,  2.43s/it]

normalized pred locations loss 1.1555085182189941




normalized pred locations loss 0.9349393248558044


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.62it/s]
Probe prediction epochs:  15%|█▌        | 3/20 [00:07<00:40,  2.39s/it]

normalized pred locations loss 0.8490764498710632




normalized pred locations loss 0.8910062313079834


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.84it/s]
Probe prediction epochs:  20%|██        | 4/20 [00:09<00:37,  2.37s/it]

normalized pred locations loss 0.8979362845420837


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.88it/s]
Probe prediction epochs:  25%|██▌       | 5/20 [00:11<00:35,  2.36s/it]

normalized pred locations loss 0.7030884623527527




normalized pred locations loss 0.7919564843177795


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.80it/s]
Probe prediction epochs:  30%|███       | 6/20 [00:14<00:32,  2.35s/it]

normalized pred locations loss 0.812245786190033


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.77it/s]
Probe prediction epochs:  35%|███▌      | 7/20 [00:16<00:30,  2.34s/it]

normalized pred locations loss 0.8951931595802307





normalized pred locations loss 0.8323168158531189


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.27it/s][A
Probe prediction epochs:  40%|████      | 8/20 [00:18<00:28,  2.34s/it]

normalized pred locations loss 0.812424898147583


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.28it/s]
Probe prediction epochs:  45%|████▌     | 9/20 [00:21<00:25,  2.36s/it]

normalized pred locations loss 0.8847749829292297




normalized pred locations loss 0.8874053955078125


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.15it/s]
Probe prediction epochs:  50%|█████     | 10/20 [00:23<00:23,  2.37s/it]

normalized pred locations loss 0.6922993659973145




normalized pred locations loss 0.7016913890838623


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.97it/s]
Probe prediction epochs:  55%|█████▌    | 11/20 [00:26<00:21,  2.37s/it]

normalized pred locations loss 0.8183857798576355


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.67it/s]
Probe prediction epochs:  60%|██████    | 12/20 [00:28<00:18,  2.37s/it]

normalized pred locations loss 0.6932435631752014




normalized pred locations loss 0.7029843926429749


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.46it/s]
Probe prediction epochs:  65%|██████▌   | 13/20 [00:30<00:16,  2.36s/it]

normalized pred locations loss 0.7675111889839172


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.60it/s]
Probe prediction epochs:  70%|███████   | 14/20 [00:33<00:14,  2.37s/it]

normalized pred locations loss 0.6667055487632751




normalized pred locations loss 0.6519219279289246


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.63it/s]
Probe prediction epochs:  75%|███████▌  | 15/20 [00:35<00:11,  2.37s/it]

normalized pred locations loss 0.6280651092529297


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.41it/s]
Probe prediction epochs:  80%|████████  | 16/20 [00:37<00:09,  2.37s/it]

normalized pred locations loss 0.593088686466217




normalized pred locations loss 0.6688989996910095


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.83it/s]
Probe prediction epochs:  85%|████████▌ | 17/20 [00:40<00:07,  2.37s/it]

normalized pred locations loss 0.7038297057151794


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.30it/s]
Probe prediction epochs:  90%|█████████ | 18/20 [00:42<00:04,  2.37s/it]

normalized pred locations loss 0.7995834350585938




normalized pred locations loss 0.5831136107444763


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.46it/s]
Probe prediction epochs:  95%|█████████▌| 19/20 [00:45<00:02,  2.38s/it]

normalized pred locations loss 0.7613988518714905




normalized pred locations loss 0.6997306942939758


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.34it/s]
Probe prediction epochs: 100%|██████████| 20/20 [00:47<00:00,  2.37s/it]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 74.48it/s]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 74.38it/s]

normal loss: 191.7706298828125
wall loss: 173.2938690185547





In [52]:
# Probe evaluation of the epoch-1 checkpoint.
model = WorldModelVICReg().to(device)
# weights_only=True restricts unpickling to tensors/primitive containers;
# map_location lets a GPU-saved checkpoint load on the active device.
checkpoint = torch.load('cnn_based_checkpoints/checkpoint_epoch_1.pt', map_location=device, weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])
probe_losses = evaluate_model(device, model, probe_train_ds, probe_val_ds)

  model.load_state_dict(torch.load('cnn_based_checkpoints/checkpoint_epoch_1.pt')['model_state_dict'])
Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]

normalized pred locations loss 1.0985841751098633




normalized pred locations loss 0.8447808623313904


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.56it/s]
Probe prediction epochs:   5%|▌         | 1/20 [00:02<00:45,  2.38s/it]

normalized pred locations loss 0.2625228762626648


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.37it/s]
Probe prediction epochs:  10%|█         | 2/20 [00:04<00:42,  2.39s/it]

normalized pred locations loss 0.10173250734806061




normalized pred locations loss 0.0986194759607315


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.56it/s]
Probe prediction epochs:  15%|█▌        | 3/20 [00:07<00:40,  2.38s/it]

normalized pred locations loss 0.062243182212114334




normalized pred locations loss 0.052059151232242584


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.48it/s]
Probe prediction epochs:  20%|██        | 4/20 [00:09<00:37,  2.37s/it]

normalized pred locations loss 0.05946718901395798


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.84it/s]
Probe prediction epochs:  25%|██▌       | 5/20 [00:11<00:35,  2.37s/it]

normalized pred locations loss 0.04416849836707115




normalized pred locations loss 0.03912874311208725


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.56it/s]
Probe prediction epochs:  30%|███       | 6/20 [00:14<00:33,  2.37s/it]

normalized pred locations loss 0.043851181864738464


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.90it/s]
Probe prediction epochs:  35%|███▌      | 7/20 [00:16<00:30,  2.37s/it]

normalized pred locations loss 0.03816947340965271




normalized pred locations loss 0.04594571515917778


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.61it/s]
Probe prediction epochs:  40%|████      | 8/20 [00:19<00:28,  2.37s/it]

normalized pred locations loss 0.04236822575330734


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.34it/s]
Probe prediction epochs:  45%|████▌     | 9/20 [00:21<00:26,  2.37s/it]

normalized pred locations loss 0.04284005984663963




normalized pred locations loss 0.03473106399178505


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.09it/s]
Probe prediction epochs:  50%|█████     | 10/20 [00:23<00:23,  2.36s/it]

normalized pred locations loss 0.047896623611450195




normalized pred locations loss 0.03343567997217178


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.35it/s]
Probe prediction epochs:  55%|█████▌    | 11/20 [00:26<00:21,  2.37s/it]

normalized pred locations loss 0.03401721641421318


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.83it/s]
Probe prediction epochs:  60%|██████    | 12/20 [00:28<00:18,  2.37s/it]
Probe prediction step:  27%|██▋       | 42/156 [00:00<00:01, 65.50it/s]

normalized pred locations loss 0.04117041826248169


[A

normalized pred locations loss 0.035276222974061966


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.44it/s]
Probe prediction epochs:  65%|██████▌   | 13/20 [00:30<00:16,  2.37s/it]

normalized pred locations loss 0.050069648772478104


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.46it/s]
Probe prediction epochs:  70%|███████   | 14/20 [00:33<00:14,  2.38s/it]

normalized pred locations loss 0.02914666011929512




normalized pred locations loss 0.03822823241353035


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.27it/s]
Probe prediction epochs:  75%|███████▌  | 15/20 [00:35<00:11,  2.38s/it]

normalized pred locations loss 0.02574198506772518


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.43it/s]
Probe prediction epochs:  80%|████████  | 16/20 [00:37<00:09,  2.38s/it]

normalized pred locations loss 0.02502267435193062




normalized pred locations loss 0.030432967469096184


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.10it/s]
Probe prediction epochs:  85%|████████▌ | 17/20 [00:40<00:07,  2.38s/it]

normalized pred locations loss 0.027060315012931824


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.21it/s]
Probe prediction epochs:  90%|█████████ | 18/20 [00:42<00:04,  2.38s/it]

normalized pred locations loss 0.030977550894021988




normalized pred locations loss 0.028223983943462372


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.40it/s]
Probe prediction epochs:  95%|█████████▌| 19/20 [00:45<00:02,  2.38s/it]

normalized pred locations loss 0.027337120845913887




normalized pred locations loss 0.026407107710838318


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.58it/s]
Probe prediction epochs: 100%|██████████| 20/20 [00:47<00:00,  2.37s/it]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 82.31it/s]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 84.47it/s]

normal loss: 10.164274215698242
wall loss: 14.979148864746094





In [53]:
# Evaluate the model stored in the epoch-2 checkpoint file with the probe datasets.
# The checkpoint and the model are both placed on `device` (the same device handed to
# evaluate_model) rather than a hard-coded 'cuda', so the cell stays consistent with it.
ckpt_path = 'cnn_based_checkpoints/checkpoint_epoch_2.pt'
# weights_only=True restricts unpickling to tensors / primitive containers (safer, and it
# avoids torch's FutureWarning about the weights_only default echoed in this cell's output).
# NOTE(review): if the checkpoint also stores arbitrary Python objects this will raise;
# use weights_only=False only for trusted, locally produced files.
state_dict = torch.load(ckpt_path, map_location=device, weights_only=True)['model_state_dict']
model = WorldModelVICReg().to(device)
model.load_state_dict(state_dict)
evaluate_model(device, model, probe_train_ds, probe_val_ds)

  model.load_state_dict(torch.load('cnn_based_checkpoints/checkpoint_epoch_2.pt')['model_state_dict'])
Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]



normalized pred locations loss 0.973574161529541




normalized pred locations loss 0.7800882458686829


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.41it/s]
Probe prediction epochs:   5%|▌         | 1/20 [00:02<00:45,  2.39s/it]

normalized pred locations loss 0.2721656858921051


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.39it/s]
Probe prediction epochs:  10%|█         | 2/20 [00:04<00:42,  2.39s/it]

normalized pred locations loss 0.11064087599515915




normalized pred locations loss 0.09030066430568695


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 64.95it/s]
Probe prediction epochs:  15%|█▌        | 3/20 [00:07<00:40,  2.39s/it]

normalized pred locations loss 0.06960202753543854




normalized pred locations loss 0.062197986990213394


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.01it/s]
Probe prediction epochs:  20%|██        | 4/20 [00:09<00:38,  2.40s/it]

normalized pred locations loss 0.03570416942238808


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.23it/s]
Probe prediction epochs:  25%|██▌       | 5/20 [00:11<00:35,  2.40s/it]

normalized pred locations loss 0.04206952080130577




normalized pred locations loss 0.04054495319724083


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.33it/s]
Probe prediction epochs:  30%|███       | 6/20 [00:14<00:33,  2.39s/it]

normalized pred locations loss 0.04624347761273384


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.79it/s]
Probe prediction epochs:  35%|███▌      | 7/20 [00:16<00:31,  2.39s/it]

normalized pred locations loss 0.04633660987019539




normalized pred locations loss 0.048230305314064026


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.11it/s]
Probe prediction epochs:  40%|████      | 8/20 [00:19<00:28,  2.37s/it]

normalized pred locations loss 0.038325726985931396


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.29it/s]
Probe prediction epochs:  45%|████▌     | 9/20 [00:21<00:25,  2.36s/it]

normalized pred locations loss 0.0410887636244297




normalized pred locations loss 0.03796772286295891


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.43it/s]
Probe prediction epochs:  50%|█████     | 10/20 [00:23<00:23,  2.37s/it]

normalized pred locations loss 0.03060382977128029




normalized pred locations loss 0.04682473838329315


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.86it/s]
Probe prediction epochs:  55%|█████▌    | 11/20 [00:26<00:21,  2.36s/it]

normalized pred locations loss 0.04052279144525528


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.58it/s]
Probe prediction epochs:  60%|██████    | 12/20 [00:28<00:18,  2.35s/it]

normalized pred locations loss 0.02687322534620762




normalized pred locations loss 0.03498085215687752


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.53it/s]
Probe prediction epochs:  65%|██████▌   | 13/20 [00:30<00:16,  2.35s/it]

normalized pred locations loss 0.04850255697965622


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.73it/s]
Probe prediction epochs:  70%|███████   | 14/20 [00:33<00:14,  2.35s/it]

normalized pred locations loss 0.028824012726545334




normalized pred locations loss 0.031474318355321884


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.86it/s]
Probe prediction epochs:  75%|███████▌  | 15/20 [00:35<00:11,  2.34s/it]

normalized pred locations loss 0.031141133978962898


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.72it/s]
Probe prediction epochs:  80%|████████  | 16/20 [00:37<00:09,  2.34s/it]

normalized pred locations loss 0.026927253231406212




normalized pred locations loss 0.028328124433755875


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.81it/s]
Probe prediction epochs:  85%|████████▌ | 17/20 [00:40<00:07,  2.34s/it]

normalized pred locations loss 0.03938893973827362


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.58it/s]
Probe prediction epochs:  90%|█████████ | 18/20 [00:42<00:04,  2.34s/it]

normalized pred locations loss 0.03112061880528927




normalized pred locations loss 0.031095854938030243


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.92it/s]
Probe prediction epochs:  95%|█████████▌| 19/20 [00:44<00:02,  2.34s/it]

normalized pred locations loss 0.025024792179465294




normalized pred locations loss 0.03577831760048866


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.85it/s]
Probe prediction epochs: 100%|██████████| 20/20 [00:47<00:00,  2.36s/it]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 82.49it/s]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 82.42it/s]

normal loss: 10.20442008972168
wall loss: 15.43889331817627





In [54]:
# Evaluate the model stored in the epoch-3 checkpoint file with the probe datasets.
# The checkpoint and the model are both placed on `device` (the same device handed to
# evaluate_model) rather than a hard-coded 'cuda', so the cell stays consistent with it.
ckpt_path = 'cnn_based_checkpoints/checkpoint_epoch_3.pt'
# weights_only=True restricts unpickling to tensors / primitive containers (safer, and it
# avoids torch's FutureWarning about the weights_only default echoed in this cell's output).
# NOTE(review): if the checkpoint also stores arbitrary Python objects this will raise;
# use weights_only=False only for trusted, locally produced files.
state_dict = torch.load(ckpt_path, map_location=device, weights_only=True)['model_state_dict']
model = WorldModelVICReg().to(device)
model.load_state_dict(state_dict)
evaluate_model(device, model, probe_train_ds, probe_val_ds)

  model.load_state_dict(torch.load('cnn_based_checkpoints/checkpoint_epoch_3.pt')['model_state_dict'])
Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]



normalized pred locations loss 0.998798131942749




normalized pred locations loss 0.8595345616340637


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.59it/s]
Probe prediction epochs:   5%|▌         | 1/20 [00:02<00:44,  2.34s/it]

normalized pred locations loss 0.27136605978012085


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.63it/s]
Probe prediction epochs:  10%|█         | 2/20 [00:04<00:42,  2.34s/it]

normalized pred locations loss 0.1563265472650528




normalized pred locations loss 0.10792218893766403


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.23it/s]
Probe prediction epochs:  15%|█▌        | 3/20 [00:07<00:39,  2.33s/it]

normalized pred locations loss 0.09370145946741104




normalized pred locations loss 0.06394800543785095


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.97it/s]
Probe prediction epochs:  20%|██        | 4/20 [00:09<00:37,  2.32s/it]

normalized pred locations loss 0.07776868343353271


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.04it/s]
Probe prediction epochs:  25%|██▌       | 5/20 [00:11<00:34,  2.32s/it]

normalized pred locations loss 0.0527845062315464




normalized pred locations loss 0.06921539455652237


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.02it/s]
Probe prediction epochs:  30%|███       | 6/20 [00:13<00:32,  2.32s/it]

normalized pred locations loss 0.052291274070739746


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.98it/s]
Probe prediction epochs:  35%|███▌      | 7/20 [00:16<00:30,  2.33s/it]

normalized pred locations loss 0.0575057677924633




normalized pred locations loss 0.03936103358864784


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.99it/s]
Probe prediction epochs:  40%|████      | 8/20 [00:18<00:27,  2.33s/it]

normalized pred locations loss 0.048771291971206665


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.94it/s]
Probe prediction epochs:  45%|████▌     | 9/20 [00:20<00:25,  2.33s/it]

normalized pred locations loss 0.05901194363832474




normalized pred locations loss 0.04044530168175697


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.84it/s]
Probe prediction epochs:  50%|█████     | 10/20 [00:23<00:23,  2.33s/it]

normalized pred locations loss 0.04892488941550255




normalized pred locations loss 0.035884857177734375


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.96it/s]
Probe prediction epochs:  55%|█████▌    | 11/20 [00:25<00:20,  2.33s/it]

normalized pred locations loss 0.04128429666161537


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.37it/s]
Probe prediction epochs:  60%|██████    | 12/20 [00:27<00:18,  2.33s/it]

normalized pred locations loss 0.02971440926194191




normalized pred locations loss 0.03413544222712517


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.75it/s]
Probe prediction epochs:  65%|██████▌   | 13/20 [00:30<00:16,  2.33s/it]

normalized pred locations loss 0.04855824261903763


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.12it/s]
Probe prediction epochs:  70%|███████   | 14/20 [00:32<00:13,  2.33s/it]

normalized pred locations loss 0.04119688645005226




normalized pred locations loss 0.028863398358225822


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.10it/s]
Probe prediction epochs:  75%|███████▌  | 15/20 [00:34<00:11,  2.33s/it]

normalized pred locations loss 0.0366673618555069


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.87it/s]
Probe prediction epochs:  80%|████████  | 16/20 [00:37<00:09,  2.33s/it]

normalized pred locations loss 0.02444310486316681




normalized pred locations loss 0.036303166300058365


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.73it/s]
Probe prediction epochs:  85%|████████▌ | 17/20 [00:39<00:06,  2.33s/it]

normalized pred locations loss 0.028238678351044655


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.54it/s]
Probe prediction epochs:  90%|█████████ | 18/20 [00:41<00:04,  2.34s/it]

normalized pred locations loss 0.03246837854385376




normalized pred locations loss 0.03216493874788284


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.96it/s]
Probe prediction epochs:  95%|█████████▌| 19/20 [00:44<00:02,  2.34s/it]

normalized pred locations loss 0.03197283297777176




normalized pred locations loss 0.04770766943693161


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.98it/s]
Probe prediction epochs: 100%|██████████| 20/20 [00:46<00:00,  2.33s/it]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 82.40it/s]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 82.56it/s]

normal loss: 11.363409042358398
wall loss: 16.40474510192871





In [55]:
# Evaluate the model stored in the epoch-4 checkpoint file with the probe datasets.
# The checkpoint and the model are both placed on `device` (the same device handed to
# evaluate_model) rather than a hard-coded 'cuda', so the cell stays consistent with it.
ckpt_path = 'cnn_based_checkpoints/checkpoint_epoch_4.pt'
# weights_only=True restricts unpickling to tensors / primitive containers (safer, and it
# avoids torch's FutureWarning about the weights_only default echoed in this cell's output).
# NOTE(review): if the checkpoint also stores arbitrary Python objects this will raise;
# use weights_only=False only for trusted, locally produced files.
state_dict = torch.load(ckpt_path, map_location=device, weights_only=True)['model_state_dict']
model = WorldModelVICReg().to(device)
model.load_state_dict(state_dict)
evaluate_model(device, model, probe_train_ds, probe_val_ds)

  model.load_state_dict(torch.load('cnn_based_checkpoints/checkpoint_epoch_4.pt')['model_state_dict'])
Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]
Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

[A

normalized pred locations loss 1.067325234413147




normalized pred locations loss 0.7300690412521362


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.92it/s]
Probe prediction epochs:   5%|▌         | 1/20 [00:02<00:44,  2.33s/it]

normalized pred locations loss 0.3180791139602661


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.19it/s]
Probe prediction epochs:  10%|█         | 2/20 [00:04<00:41,  2.33s/it]

normalized pred locations loss 0.17186228930950165




normalized pred locations loss 0.10247805714607239


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.26it/s]
Probe prediction epochs:  15%|█▌        | 3/20 [00:06<00:39,  2.32s/it]

normalized pred locations loss 0.08096766471862793




normalized pred locations loss 0.09119147807359695


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.54it/s]
Probe prediction epochs:  20%|██        | 4/20 [00:09<00:37,  2.33s/it]

normalized pred locations loss 0.06866149604320526


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.72it/s]
Probe prediction epochs:  25%|██▌       | 5/20 [00:11<00:35,  2.34s/it]

normalized pred locations loss 0.06514804810285568




normalized pred locations loss 0.06466719508171082


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.74it/s]
Probe prediction epochs:  30%|███       | 6/20 [00:14<00:32,  2.34s/it]

normalized pred locations loss 0.06344388425350189


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.73it/s]
Probe prediction epochs:  35%|███▌      | 7/20 [00:16<00:30,  2.34s/it]

normalized pred locations loss 0.060466304421424866




normalized pred locations loss 0.047825150191783905


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.65it/s]
Probe prediction epochs:  40%|████      | 8/20 [00:18<00:28,  2.34s/it]

normalized pred locations loss 0.060211315751075745


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.64it/s]
Probe prediction epochs:  45%|████▌     | 9/20 [00:21<00:25,  2.34s/it]

normalized pred locations loss 0.04813283309340477




normalized pred locations loss 0.04015157371759415


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.04it/s]
Probe prediction epochs:  50%|█████     | 10/20 [00:23<00:23,  2.34s/it]

normalized pred locations loss 0.051623061299324036





normalized pred locations loss 0.041874952614307404


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.33it/s][A
Probe prediction epochs:  55%|█████▌    | 11/20 [00:25<00:21,  2.34s/it]

normalized pred locations loss 0.05795025825500488


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 68.89it/s]
Probe prediction epochs:  60%|██████    | 12/20 [00:27<00:18,  2.32s/it]

normalized pred locations loss 0.02807670459151268




normalized pred locations loss 0.033789005130529404


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.51it/s]
Probe prediction epochs:  65%|██████▌   | 13/20 [00:30<00:16,  2.33s/it]

normalized pred locations loss 0.039586398750543594


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.97it/s]
Probe prediction epochs:  70%|███████   | 14/20 [00:32<00:14,  2.34s/it]

normalized pred locations loss 0.037925101816654205




normalized pred locations loss 0.031027868390083313


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.33it/s]
Probe prediction epochs:  75%|███████▌  | 15/20 [00:35<00:11,  2.35s/it]

normalized pred locations loss 0.030266599729657173


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.33it/s]
Probe prediction epochs:  80%|████████  | 16/20 [00:37<00:09,  2.36s/it]

normalized pred locations loss 0.031436268240213394




normalized pred locations loss 0.030447259545326233


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.50it/s]
Probe prediction epochs:  85%|████████▌ | 17/20 [00:39<00:07,  2.37s/it]

normalized pred locations loss 0.03314489126205444


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.18it/s]
Probe prediction epochs:  90%|█████████ | 18/20 [00:42<00:04,  2.36s/it]

normalized pred locations loss 0.03226040303707123




normalized pred locations loss 0.03413183614611626


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.68it/s]
Probe prediction epochs:  95%|█████████▌| 19/20 [00:44<00:02,  2.35s/it]

normalized pred locations loss 0.02660958096385002




normalized pred locations loss 0.03621548041701317


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.64it/s]
Probe prediction epochs: 100%|██████████| 20/20 [00:46<00:00,  2.34s/it]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 82.40it/s]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 82.54it/s]

normal loss: 11.86609172821045
wall loss: 17.112180709838867





In [56]:
# Evaluate the model stored in the epoch-5 checkpoint file with the probe datasets.
# The checkpoint and the model are both placed on `device` (the same device handed to
# evaluate_model) rather than a hard-coded 'cuda', so the cell stays consistent with it.
ckpt_path = 'cnn_based_checkpoints/checkpoint_epoch_5.pt'
# weights_only=True restricts unpickling to tensors / primitive containers (safer, and it
# avoids torch's FutureWarning about the weights_only default echoed in this cell's output).
# NOTE(review): if the checkpoint also stores arbitrary Python objects this will raise;
# use weights_only=False only for trusted, locally produced files.
state_dict = torch.load(ckpt_path, map_location=device, weights_only=True)['model_state_dict']
model = WorldModelVICReg().to(device)
model.load_state_dict(state_dict)
evaluate_model(device, model, probe_train_ds, probe_val_ds)

  model.load_state_dict(torch.load('cnn_based_checkpoints/checkpoint_epoch_5.pt')['model_state_dict'])
Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]



normalized pred locations loss 1.2110536098480225




normalized pred locations loss 0.7883133292198181


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.98it/s]
Probe prediction epochs:   5%|▌         | 1/20 [00:02<00:44,  2.33s/it]

normalized pred locations loss 0.32565274834632874


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.96it/s]
Probe prediction epochs:  10%|█         | 2/20 [00:04<00:41,  2.33s/it]

normalized pred locations loss 0.1992124617099762




normalized pred locations loss 0.10047836601734161


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.09it/s]
Probe prediction epochs:  15%|█▌        | 3/20 [00:06<00:39,  2.33s/it]

normalized pred locations loss 0.11903891712427139




normalized pred locations loss 0.0835399180650711


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.94it/s]
Probe prediction epochs:  20%|██        | 4/20 [00:09<00:37,  2.33s/it]

normalized pred locations loss 0.1258830428123474


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.01it/s]
Probe prediction epochs:  25%|██▌       | 5/20 [00:11<00:34,  2.33s/it]

normalized pred locations loss 0.07436065375804901




normalized pred locations loss 0.07101467996835709


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.02it/s]
Probe prediction epochs:  30%|███       | 6/20 [00:13<00:32,  2.33s/it]

normalized pred locations loss 0.05736861005425453


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.08it/s]
Probe prediction epochs:  35%|███▌      | 7/20 [00:16<00:30,  2.33s/it]

normalized pred locations loss 0.07594674080610275




normalized pred locations loss 0.06071389466524124


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.94it/s]
Probe prediction epochs:  40%|████      | 8/20 [00:18<00:27,  2.33s/it]

normalized pred locations loss 0.0503939688205719


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.91it/s]
Probe prediction epochs:  45%|████▌     | 9/20 [00:20<00:25,  2.33s/it]

normalized pred locations loss 0.05820417404174805




normalized pred locations loss 0.039852578192949295


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.88it/s]
Probe prediction epochs:  50%|█████     | 10/20 [00:23<00:23,  2.33s/it]

normalized pred locations loss 0.0396280400454998




normalized pred locations loss 0.04038191959261894


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 69.84it/s]
Probe prediction epochs:  55%|█████▌    | 11/20 [00:25<00:20,  2.30s/it]

normalized pred locations loss 0.04401538148522377


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 70.44it/s]
Probe prediction epochs:  60%|██████    | 12/20 [00:27<00:18,  2.28s/it]

normalized pred locations loss 0.03363671153783798




normalized pred locations loss 0.03823761269450188


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 68.77it/s]
Probe prediction epochs:  65%|██████▌   | 13/20 [00:30<00:15,  2.27s/it]

normalized pred locations loss 0.034851882606744766


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.27it/s]
Probe prediction epochs:  70%|███████   | 14/20 [00:32<00:13,  2.29s/it]

normalized pred locations loss 0.04031161591410637




normalized pred locations loss 0.04649718105792999


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.58it/s]
Probe prediction epochs:  75%|███████▌  | 15/20 [00:34<00:11,  2.31s/it]

normalized pred locations loss 0.03695002570748329


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.49it/s]
Probe prediction epochs:  80%|████████  | 16/20 [00:37<00:09,  2.33s/it]

normalized pred locations loss 0.04565272107720375




normalized pred locations loss 0.03448766842484474


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.55it/s]
Probe prediction epochs:  85%|████████▌ | 17/20 [00:39<00:07,  2.34s/it]

normalized pred locations loss 0.028735538944602013


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.84it/s]
Probe prediction epochs:  90%|█████████ | 18/20 [00:41<00:04,  2.35s/it]

normalized pred locations loss 0.03974789008498192




normalized pred locations loss 0.03584674000740051


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.09it/s]
Probe prediction epochs:  95%|█████████▌| 19/20 [00:44<00:02,  2.35s/it]

normalized pred locations loss 0.034124765545129776




normalized pred locations loss 0.036533504724502563


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.78it/s]
Probe prediction epochs: 100%|██████████| 20/20 [00:46<00:00,  2.32s/it]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 83.06it/s]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 82.62it/s]

normal loss: 12.959648132324219
wall loss: 17.97338104248047





In [57]:
# Evaluate the model stored in the epoch-6 checkpoint file with the probe datasets.
# The checkpoint and the model are both placed on `device` (the same device handed to
# evaluate_model) rather than a hard-coded 'cuda', so the cell stays consistent with it.
ckpt_path = 'cnn_based_checkpoints/checkpoint_epoch_6.pt'
# weights_only=True restricts unpickling to tensors / primitive containers (safer, and it
# avoids torch's FutureWarning about the weights_only default echoed in this cell's output).
# NOTE(review): if the checkpoint also stores arbitrary Python objects this will raise;
# use weights_only=False only for trusted, locally produced files.
state_dict = torch.load(ckpt_path, map_location=device, weights_only=True)['model_state_dict']
model = WorldModelVICReg().to(device)
model.load_state_dict(state_dict)
evaluate_model(device, model, probe_train_ds, probe_val_ds)

  model.load_state_dict(torch.load('cnn_based_checkpoints/checkpoint_epoch_6.pt')['model_state_dict'])
Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]

normalized pred locations loss 1.1338393688201904




normalized pred locations loss 0.7276270389556885


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.88it/s]
Probe prediction epochs:   5%|▌         | 1/20 [00:02<00:45,  2.37s/it]

normalized pred locations loss 0.33604541420936584


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.98it/s]
Probe prediction epochs:  10%|█         | 2/20 [00:04<00:42,  2.35s/it]

normalized pred locations loss 0.19993677735328674




normalized pred locations loss 0.157821387052536


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.55it/s]
Probe prediction epochs:  15%|█▌        | 3/20 [00:07<00:39,  2.35s/it]

normalized pred locations loss 0.12850509583950043




normalized pred locations loss 0.133024200797081


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.68it/s]
Probe prediction epochs:  20%|██        | 4/20 [00:09<00:37,  2.33s/it]

normalized pred locations loss 0.08875823020935059


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 68.51it/s]
Probe prediction epochs:  25%|██▌       | 5/20 [00:11<00:34,  2.31s/it]

normalized pred locations loss 0.08229460567235947




normalized pred locations loss 0.06706434488296509


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.37it/s]
Probe prediction epochs:  30%|███       | 6/20 [00:13<00:32,  2.31s/it]

normalized pred locations loss 0.11083176732063293


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.53it/s]
Probe prediction epochs:  35%|███▌      | 7/20 [00:16<00:30,  2.32s/it]

normalized pred locations loss 0.06925390660762787




normalized pred locations loss 0.0567888505756855


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.26it/s]
Probe prediction epochs:  40%|████      | 8/20 [00:18<00:28,  2.35s/it]

normalized pred locations loss 0.07352180033922195


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.47it/s]
Probe prediction epochs:  45%|████▌     | 9/20 [00:21<00:25,  2.35s/it]

normalized pred locations loss 0.0704197883605957




normalized pred locations loss 0.05126423016190529


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.71it/s]
Probe prediction epochs:  50%|█████     | 10/20 [00:23<00:23,  2.36s/it]

normalized pred locations loss 0.07010426372289658





normalized pred locations loss 0.06929467618465424


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.56it/s][A
Probe prediction epochs:  55%|█████▌    | 11/20 [00:25<00:21,  2.36s/it]


normalized pred locations loss 0.04148884490132332


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.58it/s]
Probe prediction epochs:  60%|██████    | 12/20 [00:28<00:18,  2.37s/it]

normalized pred locations loss 0.03877044841647148




normalized pred locations loss 0.03516725078225136


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.67it/s]
Probe prediction epochs:  65%|██████▌   | 13/20 [00:30<00:16,  2.36s/it]

normalized pred locations loss 0.038466405123472214


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.22it/s]
Probe prediction epochs:  70%|███████   | 14/20 [00:32<00:14,  2.36s/it]

normalized pred locations loss 0.0493745282292366




normalized pred locations loss 0.0530044324696064


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 68.78it/s]
Probe prediction epochs:  75%|███████▌  | 15/20 [00:35<00:11,  2.33s/it]

normalized pred locations loss 0.04887118935585022


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 68.88it/s]
Probe prediction epochs:  80%|████████  | 16/20 [00:37<00:09,  2.31s/it]

normalized pred locations loss 0.03837314993143082




normalized pred locations loss 0.04601423069834709


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.45it/s]
Probe prediction epochs:  85%|████████▌ | 17/20 [00:39<00:06,  2.31s/it]

normalized pred locations loss 0.042729754000902176


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 69.07it/s]
Probe prediction epochs:  90%|█████████ | 18/20 [00:41<00:04,  2.30s/it]

normalized pred locations loss 0.032650724053382874




normalized pred locations loss 0.03543461114168167


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 70.58it/s]
Probe prediction epochs:  95%|█████████▌| 19/20 [00:44<00:02,  2.27s/it]

normalized pred locations loss 0.049194999039173126




normalized pred locations loss 0.042543649673461914


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 68.67it/s]
Probe prediction epochs: 100%|██████████| 20/20 [00:46<00:00,  2.32s/it]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 82.13it/s]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 82.66it/s]

normal loss: 13.814647674560547
wall loss: 18.868379592895508





In [58]:
# Evaluate the model stored in the epoch-7 checkpoint file with the probe datasets.
# The checkpoint and the model are both placed on `device` (the same device handed to
# evaluate_model) rather than a hard-coded 'cuda', so the cell stays consistent with it.
ckpt_path = 'cnn_based_checkpoints/checkpoint_epoch_7.pt'
# weights_only=True restricts unpickling to tensors / primitive containers (safer, and it
# avoids torch's FutureWarning about the weights_only default echoed in this cell's output).
# NOTE(review): if the checkpoint also stores arbitrary Python objects this will raise;
# use weights_only=False only for trusted, locally produced files.
state_dict = torch.load(ckpt_path, map_location=device, weights_only=True)['model_state_dict']
model = WorldModelVICReg().to(device)
model.load_state_dict(state_dict)
evaluate_model(device, model, probe_train_ds, probe_val_ds)

  model.load_state_dict(torch.load('cnn_based_checkpoints/checkpoint_epoch_7.pt')['model_state_dict'])
Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]

normalized pred locations loss 1.0737444162368774




normalized pred locations loss 0.9109867811203003


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.02it/s]
Probe prediction epochs:   5%|▌         | 1/20 [00:02<00:45,  2.40s/it]

normalized pred locations loss 0.39509084820747375


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.58it/s]
Probe prediction epochs:  10%|█         | 2/20 [00:04<00:42,  2.37s/it]

normalized pred locations loss 0.19365723431110382




normalized pred locations loss 0.19022312760353088


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.90it/s]
Probe prediction epochs:  15%|█▌        | 3/20 [00:07<00:39,  2.35s/it]

normalized pred locations loss 0.13980799913406372




normalized pred locations loss 0.11988360434770584


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.57it/s]
Probe prediction epochs:  20%|██        | 4/20 [00:09<00:37,  2.35s/it]

normalized pred locations loss 0.11538535356521606


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.60it/s]
Probe prediction epochs:  25%|██▌       | 5/20 [00:11<00:35,  2.35s/it]

normalized pred locations loss 0.07834357768297195




normalized pred locations loss 0.06455153971910477


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.60it/s]
Probe prediction epochs:  30%|███       | 6/20 [00:14<00:32,  2.35s/it]

normalized pred locations loss 0.09472496807575226


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.31it/s]
Probe prediction epochs:  35%|███▌      | 7/20 [00:16<00:30,  2.35s/it]

normalized pred locations loss 0.06447761505842209




normalized pred locations loss 0.0626414492726326


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.61it/s]
Probe prediction epochs:  40%|████      | 8/20 [00:18<00:28,  2.35s/it]

normalized pred locations loss 0.062335770577192307


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.90it/s]
Probe prediction epochs:  45%|████▌     | 9/20 [00:21<00:25,  2.34s/it]

normalized pred locations loss 0.07613563537597656




normalized pred locations loss 0.06897138059139252


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.73it/s]
Probe prediction epochs:  50%|█████     | 10/20 [00:23<00:23,  2.34s/it]

normalized pred locations loss 0.05767959728837013




normalized pred locations loss 0.05408622696995735


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.32it/s]
Probe prediction epochs:  55%|█████▌    | 11/20 [00:25<00:21,  2.35s/it]

normalized pred locations loss 0.06066587567329407


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.52it/s]
Probe prediction epochs:  60%|██████    | 12/20 [00:28<00:18,  2.35s/it]

normalized pred locations loss 0.046086542308330536




normalized pred locations loss 0.042363908141851425


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.11it/s]
Probe prediction epochs:  65%|██████▌   | 13/20 [00:30<00:16,  2.35s/it]

normalized pred locations loss 0.04819279536604881


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.25it/s]
Probe prediction epochs:  70%|███████   | 14/20 [00:32<00:14,  2.36s/it]

normalized pred locations loss 0.044070638716220856




normalized pred locations loss 0.04252490773797035


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 64.10it/s]
Probe prediction epochs:  75%|███████▌  | 15/20 [00:35<00:11,  2.39s/it]

normalized pred locations loss 0.05031981319189072


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 63.44it/s]
Probe prediction epochs:  80%|████████  | 16/20 [00:37<00:09,  2.41s/it]

normalized pred locations loss 0.04350661486387253




normalized pred locations loss 0.04019466042518616


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.26it/s]
Probe prediction epochs:  85%|████████▌ | 17/20 [00:40<00:07,  2.40s/it]

normalized pred locations loss 0.03855203092098236


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.05it/s]
Probe prediction epochs:  90%|█████████ | 18/20 [00:42<00:04,  2.39s/it]

normalized pred locations loss 0.032743506133556366




normalized pred locations loss 0.03749482333660126


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.57it/s]
Probe prediction epochs:  95%|█████████▌| 19/20 [00:44<00:02,  2.39s/it]

normalized pred locations loss 0.03974064439535141




normalized pred locations loss 0.062215663492679596


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.59it/s]
Probe prediction epochs: 100%|██████████| 20/20 [00:47<00:00,  2.37s/it]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 82.68it/s]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 81.55it/s]

normal loss: 14.75185489654541
wall loss: 20.491424560546875





In [59]:
# Evaluate the checkpoint saved after training epoch 8 with the probing protocol.
# The model is built on `device` (the same device handed to evaluate_model) rather
# than a hardcoded 'cuda', and map_location keeps the loaded tensors on that device
# regardless of where the checkpoint was originally saved.
# NOTE(review): this torch.load call emits a warning in the cell output — likely the
# `weights_only` FutureWarning; consider weights_only=True once the checkpoint is
# confirmed to hold only tensors/primitive values.
checkpoint_path = 'cnn_based_checkpoints/checkpoint_epoch_8.pt'
checkpoint = torch.load(checkpoint_path, map_location=device)
model = WorldModelVICReg().to(device)
model.load_state_dict(checkpoint['model_state_dict'])
evaluate_model(device, model, probe_train_ds, probe_val_ds)

  model.load_state_dict(torch.load('cnn_based_checkpoints/checkpoint_epoch_8.pt')['model_state_dict'])
Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]

normalized pred locations loss 1.2150782346725464




normalized pred locations loss 0.8451034426689148


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.76it/s]
Probe prediction epochs:   5%|▌         | 1/20 [00:02<00:45,  2.37s/it]

normalized pred locations loss 0.42177751660346985


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.03it/s]
Probe prediction epochs:  10%|█         | 2/20 [00:04<00:42,  2.37s/it]

normalized pred locations loss 0.28123000264167786




normalized pred locations loss 0.21254976093769073


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.83it/s]
Probe prediction epochs:  15%|█▌        | 3/20 [00:07<00:40,  2.37s/it]

normalized pred locations loss 0.14432843029499054




normalized pred locations loss 0.125068798661232


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.39it/s]
Probe prediction epochs:  20%|██        | 4/20 [00:09<00:37,  2.36s/it]

normalized pred locations loss 0.1239105686545372


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.15it/s]
Probe prediction epochs:  25%|██▌       | 5/20 [00:11<00:35,  2.35s/it]

normalized pred locations loss 0.09253949671983719




normalized pred locations loss 0.09293122589588165


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.91it/s]
Probe prediction epochs:  30%|███       | 6/20 [00:14<00:32,  2.36s/it]

normalized pred locations loss 0.07692209631204605


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.61it/s]
Probe prediction epochs:  35%|███▌      | 7/20 [00:16<00:30,  2.36s/it]

normalized pred locations loss 0.08245932310819626




normalized pred locations loss 0.08492577821016312


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.57it/s]
Probe prediction epochs:  40%|████      | 8/20 [00:18<00:28,  2.37s/it]

normalized pred locations loss 0.08811011165380478


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.95it/s]
Probe prediction epochs:  45%|████▌     | 9/20 [00:21<00:26,  2.37s/it]

normalized pred locations loss 0.05527731403708458




normalized pred locations loss 0.05568254739046097


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.80it/s]
Probe prediction epochs:  50%|█████     | 10/20 [00:23<00:23,  2.37s/it]

normalized pred locations loss 0.04876500740647316




normalized pred locations loss 0.04643278196454048


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.74it/s]
Probe prediction epochs:  55%|█████▌    | 11/20 [00:26<00:21,  2.37s/it]

normalized pred locations loss 0.051755085587501526


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.89it/s]
Probe prediction epochs:  60%|██████    | 12/20 [00:28<00:18,  2.36s/it]

normalized pred locations loss 0.034623876214027405




normalized pred locations loss 0.05019689351320267


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.33it/s]
Probe prediction epochs:  65%|██████▌   | 13/20 [00:30<00:16,  2.36s/it]

normalized pred locations loss 0.04677186161279678


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.93it/s]
Probe prediction epochs:  70%|███████   | 14/20 [00:33<00:14,  2.35s/it]

normalized pred locations loss 0.04510131850838661




normalized pred locations loss 0.05384405329823494


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.36it/s]
Probe prediction epochs:  75%|███████▌  | 15/20 [00:35<00:11,  2.34s/it]

normalized pred locations loss 0.04763156920671463


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 68.27it/s]
Probe prediction epochs:  80%|████████  | 16/20 [00:37<00:09,  2.32s/it]

normalized pred locations loss 0.06331208348274231




normalized pred locations loss 0.056542906910181046


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.95it/s]
Probe prediction epochs:  85%|████████▌ | 17/20 [00:39<00:06,  2.32s/it]

normalized pred locations loss 0.046564195305109024


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.20it/s]
Probe prediction epochs:  90%|█████████ | 18/20 [00:42<00:04,  2.32s/it]

normalized pred locations loss 0.03560130298137665




normalized pred locations loss 0.03232751786708832


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.34it/s]
Probe prediction epochs:  95%|█████████▌| 19/20 [00:44<00:02,  2.34s/it]

normalized pred locations loss 0.04722057282924652




normalized pred locations loss 0.03814183175563812


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.42it/s]
Probe prediction epochs: 100%|██████████| 20/20 [00:47<00:00,  2.35s/it]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 83.03it/s]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 83.36it/s]

normal loss: 15.588173866271973
wall loss: 21.308504104614258





In [60]:
# Evaluate the checkpoint saved after training epoch 9 with the probing protocol.
# The model is built on `device` (the same device handed to evaluate_model) rather
# than a hardcoded 'cuda', and map_location keeps the loaded tensors on that device
# regardless of where the checkpoint was originally saved.
# NOTE(review): this torch.load call emits a warning in the cell output — likely the
# `weights_only` FutureWarning; consider weights_only=True once the checkpoint is
# confirmed to hold only tensors/primitive values.
checkpoint_path = 'cnn_based_checkpoints/checkpoint_epoch_9.pt'
checkpoint = torch.load(checkpoint_path, map_location=device)
model = WorldModelVICReg().to(device)
model.load_state_dict(checkpoint['model_state_dict'])
evaluate_model(device, model, probe_train_ds, probe_val_ds)

  model.load_state_dict(torch.load('cnn_based_checkpoints/checkpoint_epoch_9.pt')['model_state_dict'])
Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]

normalized pred locations loss 1.0515559911727905




normalized pred locations loss 0.7334337830543518


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.85it/s]
Probe prediction epochs:   5%|▌         | 1/20 [00:02<00:45,  2.37s/it]

normalized pred locations loss 0.2747701406478882


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.08it/s]
Probe prediction epochs:  10%|█         | 2/20 [00:04<00:42,  2.39s/it]

normalized pred locations loss 0.23857583105564117




normalized pred locations loss 0.15911516547203064


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.67it/s]
Probe prediction epochs:  15%|█▌        | 3/20 [00:07<00:40,  2.38s/it]

normalized pred locations loss 0.15004290640354156




normalized pred locations loss 0.15738315880298615


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.65it/s]
Probe prediction epochs:  20%|██        | 4/20 [00:09<00:37,  2.37s/it]

normalized pred locations loss 0.09789498150348663


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.55it/s]
Probe prediction epochs:  25%|██▌       | 5/20 [00:11<00:35,  2.37s/it]

normalized pred locations loss 0.11750100553035736




normalized pred locations loss 0.10049953311681747


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.10it/s]
Probe prediction epochs:  30%|███       | 6/20 [00:14<00:33,  2.37s/it]

normalized pred locations loss 0.07995134592056274


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 68.52it/s]
Probe prediction epochs:  35%|███▌      | 7/20 [00:16<00:30,  2.34s/it]

normalized pred locations loss 0.08791694790124893




normalized pred locations loss 0.0699029490351677


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 69.73it/s]
Probe prediction epochs:  40%|████      | 8/20 [00:18<00:27,  2.31s/it]

normalized pred locations loss 0.08443207293748856


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.27it/s]
Probe prediction epochs:  45%|████▌     | 9/20 [00:21<00:25,  2.31s/it]

normalized pred locations loss 0.07242199033498764




normalized pred locations loss 0.05707472562789917


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 68.23it/s]
Probe prediction epochs:  50%|█████     | 10/20 [00:23<00:23,  2.30s/it]

normalized pred locations loss 0.06585739552974701




normalized pred locations loss 0.0562778115272522


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 64.92it/s]
Probe prediction epochs:  55%|█████▌    | 11/20 [00:25<00:21,  2.34s/it]
Probe prediction step:  63%|██████▎   | 98/156 [00:01<00:00, 65.29it/s]

normalized pred locations loss 0.06038006767630577


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.17it/s]
Probe prediction epochs:  60%|██████    | 12/20 [00:28<00:18,  2.35s/it]

normalized pred locations loss 0.05878392979502678




normalized pred locations loss 0.04478451982140541


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.83it/s]
Probe prediction epochs:  65%|██████▌   | 13/20 [00:30<00:16,  2.35s/it]

normalized pred locations loss 0.04139050468802452


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.80it/s]
Probe prediction epochs:  70%|███████   | 14/20 [00:32<00:14,  2.36s/it]

normalized pred locations loss 0.061122164130210876




normalized pred locations loss 0.05034835264086723


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 62.69it/s]
Probe prediction epochs:  75%|███████▌  | 15/20 [00:35<00:11,  2.40s/it]

normalized pred locations loss 0.056024592369794846


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 64.99it/s]
Probe prediction epochs:  80%|████████  | 16/20 [00:37<00:09,  2.40s/it]

normalized pred locations loss 0.0397929847240448




normalized pred locations loss 0.0439738854765892


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.12it/s]
Probe prediction epochs:  85%|████████▌ | 17/20 [00:40<00:07,  2.38s/it]

normalized pred locations loss 0.044183652848005295


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 64.91it/s]
Probe prediction epochs:  90%|█████████ | 18/20 [00:42<00:04,  2.39s/it]

normalized pred locations loss 0.05098721757531166




normalized pred locations loss 0.03805529698729515


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.29it/s]
Probe prediction epochs:  95%|█████████▌| 19/20 [00:44<00:02,  2.39s/it]

normalized pred locations loss 0.05697118863463402




normalized pred locations loss 0.037072405219078064


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 68.20it/s]
Probe prediction epochs: 100%|██████████| 20/20 [00:47<00:00,  2.36s/it]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 84.40it/s]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 84.45it/s]

normal loss: 15.973650932312012
wall loss: 21.906774520874023



