In [12]:
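# NOTE: this BYOL experiment did not really work. In the run logged below the
# BYOL training loss kept falling (0.30 -> 0.09 -> 0.06 over epochs 0-2) while
# the probe losses did not improve (normal: 197 -> 245 -> 250,
# wall: 178 -> 189 -> 189).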

import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from typing import Tuple, Dict
from dataset import WallSample, create_wall_dataloader
from tqdm.auto import tqdm
import time
from datetime import datetime
from pathlib import Path
from lightly.utils.scheduler import cosine_schedule

from evaluator import ProbingEvaluator

def load_data(device, data_path="/drive_reader/as16386/DL24FA"):
    """Build the probing dataloaders used to evaluate a trained model.

    Args:
        device: Device identifier forwarded unchanged to
            ``create_wall_dataloader`` for every split.
        data_path: Root directory of the dataset. It must contain the
            ``probe_normal/train``, ``probe_normal/val`` and
            ``probe_wall/val`` sub-directories. Defaults to the location
            that used to be hard-coded, so existing calls are unaffected.

    Returns:
        A ``(probe_train_ds, probe_val_ds)`` tuple. ``probe_val_ds`` is a
        dict with the keys ``"normal"`` and ``"wall"``.
    """
    probe_train_ds = create_wall_dataloader(
        data_path=f"{data_path}/probe_normal/train",
        probing=True,
        device=device,
        train=True,
    )

    # Validation splits are created in the same order as before
    # (normal first, then wall).
    probe_val_ds = {
        "normal": create_wall_dataloader(
            data_path=f"{data_path}/probe_normal/val",
            probing=True,
            device=device,
            train=False,
        ),
        "wall": create_wall_dataloader(
            data_path=f"{data_path}/probe_wall/val",
            probing=True,
            device=device,
            train=False,
        ),
    }

    return probe_train_ds, probe_val_ds

def evaluate_model(device, model, probe_train_ds, probe_val_ds):
    """Train a prober on ``model`` and report its validation losses.

    Args:
        device: Device identifier handed to ``ProbingEvaluator``.
        model: The model whose representations are probed.
        probe_train_ds: Dataloader used to train the prober.
        probe_val_ds: Validation dataloaders handed to the evaluator.

    Returns:
        The mapping of probe attribute -> average loss produced by
        ``ProbingEvaluator.evaluate_all``. The same values are printed, so
        callers that ignore the return value behave exactly as before.
    """
    evaluator = ProbingEvaluator(
        device=device,
        model=model,
        probe_train_ds=probe_train_ds,
        probe_val_ds=probe_val_ds,
        quick_debug=False,
    )

    prober = evaluator.train_pred_prober()

    avg_losses = evaluator.evaluate_all(prober=prober)

    for probe_attr, loss in avg_losses.items():
        print(f"{probe_attr} loss: {loss}")

    # Returned in addition to being printed so the numbers can be logged or
    # compared programmatically.
    return avg_losses

# Load the probing dataloaders once when this cell runs. train_model() reads
# these module-level names when it evaluates the model after every epoch.
probe_train_ds, probe_val_ds = load_data("cuda")

class Encoder(nn.Module):
    """Two-stage convolutional encoder producing a 32-channel feature map.

    Every stage is conv -> batch norm -> ReLU. For 65x65 inputs the spatial
    size shrinks 65 -> 22 -> 8, which is what ``repr_dim`` (32 * 8 * 8)
    is based on.
    """

    def __init__(self, input_channels=2):
        super().__init__()
        # Stage 1: 5x5 kernel, stride 3, padding 2 (65x65 -> 22x22).
        self.conv1 = nn.Conv2d(
            in_channels=input_channels,
            out_channels=8,
            kernel_size=5,
            stride=3,
            padding=2,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(num_features=8)

        # Stage 2: 3x3 kernel, stride 3, padding 1 (22x22 -> 8x8).
        self.conv2 = nn.Conv2d(
            in_channels=8,
            out_channels=32,
            kernel_size=3,
            stride=3,
            padding=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(num_features=32)

        # Length of the feature map once flattened.
        self.repr_dim = 32 * 8 * 8

    def forward(self, x):
        """Run both conv/BN/ReLU stages and return the unflattened map."""
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in stages:
            x = F.relu(norm(conv(x)))
        return x

class TransitionModel(nn.Module):
    """Residual latent transition: next_state = state + f(state, action).

    The 2-component action is broadcast over the spatial grid, embedded with
    1x1 convolutions, concatenated with the state along the channel axis and
    passed through two 3x3 convolutions that produce the residual.
    """

    def __init__(self, hidden_dim=32):
        super().__init__()
        self.hidden_dim = hidden_dim
        half_dim = hidden_dim // 2

        # 1x1 convolutions lifting the 2 action channels to hidden_dim.
        embed_layers = [
            nn.Conv2d(2, half_dim, 1),
            nn.BatchNorm2d(half_dim),
            nn.ReLU(),
            nn.Conv2d(half_dim, hidden_dim, 1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
        ]
        self.action_embed = nn.Sequential(*embed_layers)

        # Maps [state ; action embedding] to a residual of state's shape.
        # No activation after the last batch norm.
        transition_layers = [
            nn.Conv2d(hidden_dim * 2, hidden_dim, 3, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.BatchNorm2d(hidden_dim),
        ]
        self.transition = nn.Sequential(*transition_layers)

    def forward(self, state, action):
        """Return the latent state after applying ``action`` to ``state``.

        ``state`` must be 4-D; ``action`` must hold 2 values per batch item.
        """
        batch, _channels, height, width = state.shape
        # Tile the action vector over every spatial location.
        action_map = action.view(batch, 2, 1, 1).expand(-1, -1, height, width)
        embedded = self.action_embed(action_map)
        residual = self.transition(torch.cat((state, embedded), dim=1))
        return state + residual

class ProjectionHead(nn.Module):
    """MLP projector: Linear -> BatchNorm1d -> ReLU -> Linear.

    Inputs with more than two dimensions are flattened to
    ``(batch, features)`` before entering the MLP.
    """

    def __init__(self, input_dim=32*8*8, hidden_dim=2048, output_dim=2048):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Project ``x``; feature maps are flattened per sample first."""
        features = x if x.dim() <= 2 else x.reshape(x.shape[0], -1)
        return self.net(features)

class PredictionHead(nn.Module):
    """Bottleneck MLP predictor: Linear -> BatchNorm1d -> ReLU -> Linear."""

    def __init__(self, input_dim=2048, hidden_dim=512, output_dim=2048):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x`` as given (no reshaping is performed)."""
        prediction = self.net(x)
        return prediction

class BYOLWorldModel(nn.Module):
    """BYOL-style latent world model.

    The online encoder embeds the first observation and the transition
    model rolls that latent forward one action at a time. Each predicted
    latent is pushed through the projection and prediction heads and
    compared with the momentum-projected, momentum-encoded observation of
    the same timestep.
    """

    def __init__(self):
        super().__init__()

        # Online branch.
        self.encoder = Encoder(input_channels=2)
        self.predictor = TransitionModel(hidden_dim=32)
        self.projection = ProjectionHead(input_dim=32*8*8)
        self.prediction = PredictionHead()

        # Target branch: copies of the online encoder / projector that are
        # only changed through update_target().
        self.encoder_momentum = copy.deepcopy(self.encoder)
        self.projection_momentum = copy.deepcopy(self.projection)

        # The target branch never receives gradients.
        for target_net in (self.encoder_momentum, self.projection_momentum):
            for weight in target_net.parameters():
                weight.requires_grad = False

        self.repr_dim = self.encoder.repr_dim

    def forward_prediction(self, states, actions):
        """Encode the initial observation and unroll it through ``actions``.

        Args:
            states: 5-D tensor whose dim 1 is squeezed away before encoding
                (presumably (B, 1, C, H, W) - confirm against the caller).
            actions: Tensor indexed as ``actions[:, t]``; its dim 1 sets the
                number of rollout steps.

        Returns:
            Latents stacked along dim 1: the encoded initial state followed
            by one predicted latent per action.
        """
        # The unpacking doubles as a check that ``states`` is 5-D.
        _batch, _one, _chan, _height, _width = states.shape
        num_actions = actions.shape[1]

        latent = self.encoder(states.squeeze(1))
        rollout = [latent]
        for step in range(num_actions):
            latent = self.predictor(latent, actions[:, step])
            rollout.append(latent)

        return torch.stack(rollout, dim=1)

    def forward(self, states, actions):
        """Return the flattened latent rollout started from ``states[:, 0]``."""
        rollout = self.forward_prediction(states[:, 0:1], actions)
        batch, steps, _chan, _height, _width = rollout.shape
        return rollout.view(batch, steps, -1)

    def compute_byol_loss(self, pred_flat, target_flat):
        """Compute ``2 - 2 * cosine`` between online and target outputs.

        Returns:
            ``(loss, metrics)`` where ``metrics`` holds the same scalar under
            the keys ``'byol_loss'`` and ``'total_loss'``.
        """
        online_out = self.prediction(self.projection(pred_flat))

        # The target projection is treated as a constant.
        with torch.no_grad():
            target_out = self.projection_momentum(target_flat)

        online_unit = F.normalize(online_out, dim=1)
        target_unit = F.normalize(target_out, dim=1)
        cosine = (online_unit * target_unit).sum(dim=1).mean()
        loss = 2 - 2 * cosine

        loss_value = loss.item()
        return loss, {
            'byol_loss': loss_value,
            'total_loss': loss_value
        }

    def training_step(self, batch):
        """Compute the rollout loss averaged over all action steps.

        Args:
            batch: Object exposing ``states`` and ``actions`` attributes.

        Returns:
            ``(loss, predictions, metrics)``: the averaged loss tensor, the
            stacked latent rollout, and the averaged scalar metrics.
        """
        states, actions = batch.states, batch.actions
        num_steps = actions.shape[1]

        # Roll the latent forward from the first observation only.
        predictions = self.forward_prediction(states[:, 0:1], actions)

        total_loss = 0.0
        accumulated_losses = {
            'total_loss': 0.0,
            'byol_loss': 0.0
        }

        # Rollout index 0 is the encoded initial state; targets start at 1.
        for step in range(1, num_steps + 1):
            with torch.no_grad():
                target_state = self.encoder_momentum(states[:, step])

            loss, step_metrics = self.compute_byol_loss(
                predictions[:, step].flatten(start_dim=1),
                target_state.flatten(start_dim=1),
            )

            total_loss += loss
            for name in accumulated_losses:
                accumulated_losses[name] += step_metrics[name]

        # Average over the rollout length.
        total_loss = total_loss / num_steps
        for name in accumulated_losses:
            accumulated_losses[name] /= num_steps

        return total_loss, predictions, accumulated_losses

    @torch.no_grad()
    def update_target(self, current_epoch, total_epochs):
        """Move the target networks towards the online ones (EMA update).

        The momentum ``tau`` comes from
        ``cosine_schedule(current_epoch, total_epochs, 0.996, 1.0)``.
        Only parameters are blended; buffers are not touched here.
        """
        tau = cosine_schedule(current_epoch, total_epochs, 0.996, 1.0)

        network_pairs = (
            (self.encoder, self.encoder_momentum),
            (self.projection, self.projection_momentum),
        )
        for online_net, target_net in network_pairs:
            for online, target in zip(online_net.parameters(), target_net.parameters()):
                target.data = tau * target.data + (1 - tau) * online.data

def train_epoch(model, dataloader, optimizer, epoch, total_epochs):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Module providing ``training_step(batch)`` (returning
            ``(loss, predictions, component_losses)``) and
            ``update_target(epoch, total_epochs)``.
        dataloader: Sized iterable of batches exposing ``states``,
            ``actions`` and ``locations`` plus a ``_replace`` method.
        optimizer: Optimizer stepping the model's parameters.
        epoch: Index of the current epoch (used for display and for the
            momentum schedule).
        total_epochs: Total number of epochs of the run.

    Returns:
        ``(avg_loss, avg_component_losses, timing)`` where ``timing`` has the
        keys ``'epoch_time'`` and ``'avg_batch_time'`` (seconds).

    Raises:
        ValueError: If ``dataloader`` is empty (previously this surfaced as
            a ZeroDivisionError after the loop).
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("train_epoch received an empty dataloader")

    model.train()

    # Follow the device the model actually lives on instead of assuming
    # CUDA, so the CPU fallback selected by the caller keeps working.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    cuda_available = torch.cuda.is_available()

    total_loss = 0.0
    train_losses = {
        'total_loss': 0.0,
        'byol_loss': 0.0
    }

    # Progress bar for batches
    progress_bar = tqdm(enumerate(dataloader),
                       total=num_batches,
                       desc=f'Epoch {epoch}/{total_epochs}',
                       leave=True)

    start_time = time.time()
    batch_times = []

    for batch_idx, batch in progress_bar:
        batch_start = time.time()

        # Move the batch next to the model (no-op when already there).
        batch = batch._replace(
            states=batch.states.to(device),
            actions=batch.actions.to(device),
            locations=batch.locations.to(device) if batch.locations is not None else None
        )

        # Forward pass and compute loss
        optimizer.zero_grad()
        loss, _, component_losses = model.training_step(batch)

        # Backward pass with gradient-norm clipping
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Update momentum networks after every optimizer step
        model.update_target(epoch, total_epochs)

        batch_times.append(time.time() - batch_start)

        # Accumulate scalar losses for the epoch averages
        total_loss += loss.item()
        for k in train_losses:
            train_losses[k] += component_losses[k]

        # Running averages shown in the progress bar
        current_loss = total_loss / (batch_idx + 1)
        current_batch_time = sum(batch_times) / len(batch_times)
        gpu_mem_gb = torch.cuda.max_memory_allocated() / 1e9 if cuda_available else 0.0

        progress_bar.set_postfix({
            'loss': f'{current_loss:.4f}',
            'time/batch': f'{current_batch_time:.3f}s',
            'gpu_mem': f'{gpu_mem_gb:.1f}GB'
        })

    # Compute final metrics
    total_loss /= num_batches
    for k in train_losses:
        train_losses[k] /= num_batches

    epoch_time = time.time() - start_time
    # Guarded in case the loader reports a length but yields nothing.
    avg_batch_time = sum(batch_times) / len(batch_times) if batch_times else 0.0

    return total_loss, train_losses, {
        'epoch_time': epoch_time,
        'avg_batch_time': avg_batch_time
    }

def train_model(model, train_loader, num_epochs=100, learning_rate=3e-4, 
                save_dir='checkpoints/byol', save_frequency=10):
    """Train ``model``, probing it after every epoch and saving checkpoints.

    Args:
        model: Model to train; it is moved to CUDA when available, else CPU.
        train_loader: Dataloader passed to ``train_epoch`` each epoch.
        num_epochs: Number of epochs to run.
        learning_rate: Learning rate of the Adam optimizer.
        save_dir: Directory for checkpoints (created if missing).
        save_frequency: A checkpoint is written every this many epochs.

    Note:
        Evaluation reads the module-level ``probe_train_ds`` and
        ``probe_val_ds``, which must exist before this function is called.
    """
    # Setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Create checkpoint directory
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    print(f"Starting training at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Training for {num_epochs} epochs")
    print(f"Checkpoints will be saved to {save_dir}")

    for epoch in range(num_epochs):
        # Train for one epoch
        train_loss, train_losses, timing_stats = train_epoch(
            model, train_loader, optimizer, epoch, num_epochs
        )

        # Print epoch summary
        print(f"\nEpoch {epoch}/{num_epochs} Summary:")
        print(f"Train Loss: {train_loss:.4f}")
        print("Component Losses:")
        for k, v in train_losses.items():
            print(f"  {k}: {v:.4f}")
        print(f"Epoch Time: {timing_stats['epoch_time']:.1f}s")
        print(f"Avg Batch Time: {timing_stats['avg_batch_time']:.3f}s")
        # GPU statistics only make sense when training runs on CUDA.
        if device.type == 'cuda':
            print(f"GPU Memory: {torch.cuda.max_memory_allocated() / 1024**2:.0f}MB")

        # Probe on the device selected above instead of a hard-coded "cuda";
        # str(device) is still "cuda" whenever CUDA is available.
        evaluate_model(str(device), model, probe_train_ds, probe_val_ds)

        # Save checkpoint
        if (epoch + 1) % save_frequency == 0:
            checkpoint_path = save_dir / f"checkpoint_epoch_{epoch+1}.pt"
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': train_loss,
            }, checkpoint_path)
            print(f"Checkpoint saved to {checkpoint_path}")

        print("-" * 80)

if __name__ == "__main__":
    # Entry point: build the BYOL world model on the GPU, load the training
    # trajectories and launch the training loop.
    model = BYOLWorldModel().cuda()
    train_loader = create_wall_dataloader(
        data_path="/drive_reader/as16386/DL24FA/train",  # Update with your path
        train=True,
        batch_size=128,
    )

    # Run-specific settings (the learning rate is lower than train_model's
    # default of 3e-4).
    train_model(
        model,
        train_loader,
        num_epochs=100,
        learning_rate=3e-5,
        save_frequency=10,
        save_dir='byol_try_checkpoints/',
    )

Starting training at 2024-12-11 23:38:42
Training for 100 epochs
Checkpoints will be saved to byol_try_checkpoints


Epoch 0/100: 100%|██████████| 1148/1148 [01:02<00:00, 18.43it/s, loss=0.3036, time/batch=0.029s, gpu_mem=1.8GB]



Epoch 0/100 Summary:
Train Loss: 0.3036
Component Losses:
  total_loss: 0.3036
  byol_loss: 0.3036
Epoch Time: 62.3s
Avg Batch Time: 0.029s
GPU Memory: 1683MB


Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]

normalized pred locations loss 2.0951576232910156




normalized pred locations loss 1.1319832801818848


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 59.30it/s]
Probe prediction epochs:   5%|▌         | 1/20 [00:02<00:50,  2.63s/it]

normalized pred locations loss 0.9847484827041626


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.41it/s]
Probe prediction epochs:  10%|█         | 2/20 [00:04<00:44,  2.47s/it]

normalized pred locations loss 0.883996307849884




normalized pred locations loss 1.112760066986084


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.64it/s]
Probe prediction epochs:  15%|█▌        | 3/20 [00:07<00:40,  2.41s/it]

normalized pred locations loss 0.9842950105667114




normalized pred locations loss 0.9987833499908447


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.51it/s]
Probe prediction epochs:  20%|██        | 4/20 [00:09<00:38,  2.39s/it]

normalized pred locations loss 0.9879862666130066


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.60it/s]
Probe prediction epochs:  25%|██▌       | 5/20 [00:12<00:35,  2.37s/it]

normalized pred locations loss 1.0100244283676147




normalized pred locations loss 0.9937216639518738


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.88it/s]
Probe prediction epochs:  30%|███       | 6/20 [00:14<00:33,  2.37s/it]

normalized pred locations loss 1.0210224390029907


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 68.69it/s]
Probe prediction epochs:  35%|███▌      | 7/20 [00:16<00:30,  2.34s/it]

normalized pred locations loss 0.9183034300804138




normalized pred locations loss 0.936065673828125


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 68.71it/s]
Probe prediction epochs:  40%|████      | 8/20 [00:18<00:27,  2.32s/it]

normalized pred locations loss 0.9529276490211487


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.77it/s]
Probe prediction epochs:  45%|████▌     | 9/20 [00:21<00:25,  2.31s/it]

normalized pred locations loss 1.028796672821045




normalized pred locations loss 0.955379068851471


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.94it/s]
Probe prediction epochs:  50%|█████     | 10/20 [00:23<00:23,  2.32s/it]

normalized pred locations loss 1.0199918746948242




normalized pred locations loss 0.8607928156852722


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.91it/s]
Probe prediction epochs:  55%|█████▌    | 11/20 [00:25<00:20,  2.32s/it]

normalized pred locations loss 0.7730434536933899


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.19it/s]
Probe prediction epochs:  60%|██████    | 12/20 [00:28<00:18,  2.32s/it]

normalized pred locations loss 1.0162709951400757




normalized pred locations loss 0.8846650123596191


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.01it/s]
Probe prediction epochs:  65%|██████▌   | 13/20 [00:30<00:16,  2.33s/it]

normalized pred locations loss 0.8317813873291016


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.15it/s]
Probe prediction epochs:  70%|███████   | 14/20 [00:32<00:13,  2.33s/it]

normalized pred locations loss 0.738357663154602




normalized pred locations loss 0.7896367311477661


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.21it/s]
Probe prediction epochs:  75%|███████▌  | 15/20 [00:35<00:11,  2.32s/it]

normalized pred locations loss 0.7774302363395691


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.95it/s]
Probe prediction epochs:  80%|████████  | 16/20 [00:37<00:09,  2.33s/it]

normalized pred locations loss 0.7009076476097107




normalized pred locations loss 0.6467087268829346


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.20it/s]
Probe prediction epochs:  85%|████████▌ | 17/20 [00:39<00:06,  2.33s/it]

normalized pred locations loss 0.7500486969947815


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.06it/s]
Probe prediction epochs:  90%|█████████ | 18/20 [00:42<00:04,  2.33s/it]

normalized pred locations loss 0.6175337433815002




normalized pred locations loss 0.655852198600769


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.87it/s]
Probe prediction epochs:  95%|█████████▌| 19/20 [00:44<00:02,  2.34s/it]

normalized pred locations loss 0.853463351726532




normalized pred locations loss 0.7163749933242798


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.41it/s]
Probe prediction epochs: 100%|██████████| 20/20 [00:46<00:00,  2.35s/it]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 73.26it/s]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 73.49it/s]


normal loss: 197.3961181640625
wall loss: 177.8648223876953
--------------------------------------------------------------------------------


Epoch 1/100: 100%|██████████| 1148/1148 [00:57<00:00, 20.12it/s, loss=0.0917, time/batch=0.028s, gpu_mem=1.8GB]



Epoch 1/100 Summary:
Train Loss: 0.0917
Component Losses:
  total_loss: 0.0917
  byol_loss: 0.0917
Epoch Time: 57.1s
Avg Batch Time: 0.028s
GPU Memory: 1683MB


Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]

normalized pred locations loss 2.018658399581909




normalized pred locations loss 1.142062783241272


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.08it/s]
Probe prediction epochs:   5%|▌         | 1/20 [00:02<00:44,  2.36s/it]

normalized pred locations loss 1.0971879959106445


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.51it/s]
Probe prediction epochs:  10%|█         | 2/20 [00:04<00:42,  2.38s/it]

normalized pred locations loss 1.009329915046692




normalized pred locations loss 0.9262480139732361


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.93it/s]
Probe prediction epochs:  15%|█▌        | 3/20 [00:07<00:40,  2.37s/it]

normalized pred locations loss 0.9804626703262329




normalized pred locations loss 0.8319993615150452


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.96it/s]
Probe prediction epochs:  20%|██        | 4/20 [00:09<00:37,  2.36s/it]

normalized pred locations loss 1.0421959161758423


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.95it/s]
Probe prediction epochs:  25%|██▌       | 5/20 [00:11<00:35,  2.35s/it]

normalized pred locations loss 1.0353106260299683




normalized pred locations loss 1.1119009256362915


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.73it/s]
Probe prediction epochs:  30%|███       | 6/20 [00:14<00:32,  2.35s/it]

normalized pred locations loss 0.9000351428985596


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.19it/s]
Probe prediction epochs:  35%|███▌      | 7/20 [00:16<00:30,  2.34s/it]

normalized pred locations loss 1.0146270990371704




normalized pred locations loss 1.0617952346801758


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.09it/s]
Probe prediction epochs:  40%|████      | 8/20 [00:18<00:28,  2.33s/it]

normalized pred locations loss 0.9407896399497986


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.22it/s]
Probe prediction epochs:  45%|████▌     | 9/20 [00:21<00:25,  2.33s/it]

normalized pred locations loss 1.0089417695999146




normalized pred locations loss 1.0493237972259521


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.07it/s]
Probe prediction epochs:  50%|█████     | 10/20 [00:23<00:23,  2.33s/it]

normalized pred locations loss 1.0578699111938477




normalized pred locations loss 0.8196636438369751


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.46it/s]
Probe prediction epochs:  55%|█████▌    | 11/20 [00:25<00:21,  2.34s/it]

normalized pred locations loss 0.9281478524208069


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.65it/s]
Probe prediction epochs:  60%|██████    | 12/20 [00:28<00:18,  2.34s/it]

normalized pred locations loss 1.118178129196167




normalized pred locations loss 0.9972575902938843


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.97it/s]
Probe prediction epochs:  65%|██████▌   | 13/20 [00:30<00:16,  2.34s/it]

normalized pred locations loss 0.9157006740570068


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.80it/s]
Probe prediction epochs:  70%|███████   | 14/20 [00:32<00:14,  2.34s/it]

normalized pred locations loss 0.8017618060112




normalized pred locations loss 0.9662711024284363


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.14it/s]
Probe prediction epochs:  75%|███████▌  | 15/20 [00:35<00:11,  2.33s/it]

normalized pred locations loss 0.9154397249221802


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.17it/s]
Probe prediction epochs:  80%|████████  | 16/20 [00:37<00:09,  2.33s/it]

normalized pred locations loss 1.0291247367858887




normalized pred locations loss 0.8889573216438293


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.99it/s]
Probe prediction epochs:  85%|████████▌ | 17/20 [00:39<00:06,  2.33s/it]

normalized pred locations loss 0.9811620116233826


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.99it/s]
Probe prediction epochs:  90%|█████████ | 18/20 [00:42<00:04,  2.33s/it]

normalized pred locations loss 0.9510461091995239




normalized pred locations loss 0.8131243586540222


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.96it/s]
Probe prediction epochs:  95%|█████████▌| 19/20 [00:44<00:02,  2.33s/it]

normalized pred locations loss 0.8837419748306274




normalized pred locations loss 0.9298592805862427


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.01it/s]
Probe prediction epochs: 100%|██████████| 20/20 [00:46<00:00,  2.34s/it]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 82.53it/s]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 82.60it/s]


normal loss: 245.18235778808594
wall loss: 189.22793579101562
--------------------------------------------------------------------------------


Epoch 2/100: 100%|██████████| 1148/1148 [00:57<00:00, 20.09it/s, loss=0.0597, time/batch=0.028s, gpu_mem=1.8GB]



Epoch 2/100 Summary:
Train Loss: 0.0597
Component Losses:
  total_loss: 0.0597
  byol_loss: 0.0597
Epoch Time: 57.1s
Avg Batch Time: 0.028s
GPU Memory: 1683MB


Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]

normalized pred locations loss 2.363611936569214




normalized pred locations loss 1.1205543279647827


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.94it/s]
Probe prediction epochs:   5%|▌         | 1/20 [00:02<00:44,  2.33s/it]

normalized pred locations loss 1.1992921829223633


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.69it/s]
Probe prediction epochs:  10%|█         | 2/20 [00:04<00:42,  2.34s/it]

normalized pred locations loss 1.078345537185669




normalized pred locations loss 0.9601853489875793


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.85it/s]
Probe prediction epochs:  15%|█▌        | 3/20 [00:07<00:39,  2.34s/it]

normalized pred locations loss 1.111706256866455




normalized pred locations loss 1.034497857093811


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.67it/s]
Probe prediction epochs:  20%|██        | 4/20 [00:09<00:37,  2.34s/it]

normalized pred locations loss 0.9480718970298767


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.94it/s]
Probe prediction epochs:  25%|██▌       | 5/20 [00:11<00:35,  2.34s/it]

normalized pred locations loss 0.8817775845527649




normalized pred locations loss 1.0002883672714233


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.64it/s]
Probe prediction epochs:  30%|███       | 6/20 [00:14<00:32,  2.34s/it]

normalized pred locations loss 0.9805789589881897


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.86it/s]
Probe prediction epochs:  35%|███▌      | 7/20 [00:16<00:30,  2.34s/it]

normalized pred locations loss 0.952487051486969




normalized pred locations loss 1.0097156763076782


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.02it/s]
Probe prediction epochs:  40%|████      | 8/20 [00:18<00:28,  2.33s/it]

normalized pred locations loss 0.7870132327079773


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.06it/s]
Probe prediction epochs:  45%|████▌     | 9/20 [00:21<00:25,  2.33s/it]

normalized pred locations loss 1.0243088006973267




normalized pred locations loss 0.8577398657798767


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 68.61it/s]
Probe prediction epochs:  50%|█████     | 10/20 [00:23<00:23,  2.32s/it]

normalized pred locations loss 0.9811416864395142




normalized pred locations loss 0.9004069566726685


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.28it/s]
Probe prediction epochs:  55%|█████▌    | 11/20 [00:25<00:20,  2.32s/it]

normalized pred locations loss 0.8595185875892639


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.75it/s]
Probe prediction epochs:  60%|██████    | 12/20 [00:27<00:18,  2.32s/it]

normalized pred locations loss 1.060077428817749




normalized pred locations loss 0.9443929195404053


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.60it/s]
Probe prediction epochs:  65%|██████▌   | 13/20 [00:30<00:16,  2.34s/it]

normalized pred locations loss 0.8778512477874756


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.24it/s]
Probe prediction epochs:  70%|███████   | 14/20 [00:32<00:14,  2.36s/it]

normalized pred locations loss 1.1440768241882324




normalized pred locations loss 0.9987502694129944


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.32it/s]
Probe prediction epochs:  75%|███████▌  | 15/20 [00:35<00:11,  2.36s/it]

normalized pred locations loss 0.7767149806022644


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.82it/s]
Probe prediction epochs:  80%|████████  | 16/20 [00:37<00:09,  2.35s/it]

normalized pred locations loss 0.9431481957435608




normalized pred locations loss 0.9805623292922974


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.65it/s]
Probe prediction epochs:  85%|████████▌ | 17/20 [00:39<00:07,  2.35s/it]

normalized pred locations loss 0.809899628162384


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.94it/s]
Probe prediction epochs:  90%|█████████ | 18/20 [00:42<00:04,  2.35s/it]

normalized pred locations loss 0.9213096499443054




normalized pred locations loss 0.9747066497802734


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.24it/s]
Probe prediction epochs:  95%|█████████▌| 19/20 [00:44<00:02,  2.37s/it]

normalized pred locations loss 0.8831705451011658




normalized pred locations loss 0.7294135093688965


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.82it/s]
Probe prediction epochs: 100%|██████████| 20/20 [00:46<00:00,  2.34s/it]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 82.70it/s]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 82.51it/s]


normal loss: 249.85256958007812
wall loss: 188.7933349609375
--------------------------------------------------------------------------------


Epoch 3/100:  53%|█████▎    | 612/1148 [00:30<00:26, 19.96it/s, loss=0.0487, time/batch=0.029s, gpu_mem=1.8GB]


KeyboardInterrupt: 

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from typing import Tuple, Dict, NamedTuple
from tqdm.auto import tqdm
import time
from datetime import datetime
from pathlib import Path

class Encoder(nn.Module):
    """Small CNN encoder: two strided conv + batch norm + ReLU stages.

    The output is a 32-channel feature map. For 65x65 inputs it is 8x8, so
    the flattened size advertised by ``repr_dim`` is 32 * 8 * 8.
    """

    def __init__(self, input_channels=2):
        super().__init__()
        # 65x65 -> 22x22 (kernel 5, stride 3, padding 2)
        self.conv1 = nn.Conv2d(
            input_channels, 8, kernel_size=5, stride=3, padding=2, bias=False
        )
        self.bn1 = nn.BatchNorm2d(8)

        # 22x22 -> 8x8 (kernel 3, stride 3, padding 1)
        self.conv2 = nn.Conv2d(
            8, 32, kernel_size=3, stride=3, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(32)

        self.repr_dim = 32 * 8 * 8

    def forward(self, x):
        """Return the feature map of ``x`` (not flattened)."""
        first_stage = F.relu(self.bn1(self.conv1(x)))             # -> 22x22
        second_stage = F.relu(self.bn2(self.conv2(first_stage)))  # -> 8x8
        return second_stage

class TransitionModel(nn.Module):
    """Predict the next latent state as ``state + delta(state, action)``.

    ``action_embed`` turns the spatially tiled 2-D action into ``hidden_dim``
    channels; ``transition`` maps the channel-wise concatenation of state and
    action embedding to the residual ``delta``.
    """

    def __init__(self, hidden_dim=32):
        super().__init__()
        self.hidden_dim = hidden_dim
        narrow = hidden_dim // 2
        wide = hidden_dim * 2

        # Action embedding (1x1 convolutions only, no spatial mixing)
        self.action_embed = nn.Sequential(
            nn.Conv2d(2, narrow, kernel_size=1),
            nn.BatchNorm2d(narrow),
            nn.ReLU(),
            nn.Conv2d(narrow, hidden_dim, kernel_size=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
        )

        # Residual predictor (3x3 convolutions; output is left unactivated)
        self.transition = nn.Sequential(
            nn.Conv2d(wide, hidden_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_dim),
        )

    def forward(self, state, action):
        """Apply one action to a 4-D latent ``state`` and return the result."""
        n, _, rows, cols = state.shape
        # Broadcast each action vector to a constant (2, rows, cols) map.
        tiled_action = action.view(n, 2, 1, 1).expand(n, 2, rows, cols)
        joint = torch.cat([state, self.action_embed(tiled_action)], dim=1)
        return state + self.transition(joint)

class CombinedLoss(nn.Module):
    """Weighted mix of a VICReg-style loss and a Barlow-Twins-style loss.

    ``total = loss_weight * vicreg + (1 - loss_weight) * barlow`` where
    ``vicreg`` is the coefficient-weighted sum of the similarity, std and
    covariance terms, and ``barlow`` is the on-diagonal term plus
    ``barlow_lambda`` times the off-diagonal term of the cross-correlation.
    """

    def __init__(self,
                 vicreg_sim_coef=25.0,
                 vicreg_std_coef=25.0,
                 vicreg_cov_coef=1.0,
                 barlow_lambda=0.005,
                 loss_weight=0.5):  # Weight between VICReg and Barlow
        super().__init__()
        self.vicreg_sim_coef = vicreg_sim_coef
        self.vicreg_std_coef = vicreg_std_coef
        self.vicreg_cov_coef = vicreg_cov_coef
        self.barlow_lambda = barlow_lambda
        self.loss_weight = loss_weight

    def off_diagonal(self, x):
        """Return the off-diagonal entries of square matrix ``x`` as 1-D."""
        size = x.shape[0]
        # Dropping the last element and viewing the rest as
        # (size - 1, size + 1) lines every diagonal entry up in column 0.
        shifted = x.flatten()[:-1].view(size - 1, size + 1)
        return shifted[:, 1:].flatten()

    def forward(self, z_a, z_b):
        """Compute the combined loss between two (batch, dim) embeddings.

        Returns:
            ``(total_loss, metrics)`` where ``metrics`` maps each component
            name to its Python-float value.
        """
        batch_size, feat_dim = z_a.shape[0], z_a.shape[1]

        # --- VICReg: invariance term ---
        sim_loss = F.mse_loss(z_a, z_b)

        # --- VICReg: variance term (hinge on the per-dimension std) ---
        spread_a = torch.sqrt(z_a.var(dim=0) + 1e-04)
        spread_b = torch.sqrt(z_b.var(dim=0) + 1e-04)
        std_loss = torch.mean(F.relu(1 - spread_a)) + torch.mean(F.relu(1 - spread_b))

        # --- VICReg: covariance term (off-diagonal covariances) ---
        centered_a = z_a - z_a.mean(dim=0)
        centered_b = z_b - z_b.mean(dim=0)
        cov_a = (centered_a.T @ centered_a) / (batch_size - 1)
        cov_b = (centered_b.T @ centered_b) / (batch_size - 1)
        vicreg_cov_loss = (self.off_diagonal(cov_a).pow(2).sum() / feat_dim +
                           self.off_diagonal(cov_b).pow(2).sum() / feat_dim)

        # --- Barlow Twins: cross-correlation of standardized embeddings ---
        unit_a = centered_a / (z_a.std(dim=0) + 1e-6)
        unit_b = centered_b / (z_b.std(dim=0) + 1e-6)
        c = torch.mm(unit_a.T, unit_b) / batch_size

        # Diagonal is pulled towards 1, everything else towards 0.
        on_diag = (torch.diagonal(c) - 1).pow(2).sum()
        c_squared = c.pow(2)
        off_diag = torch.triu(c_squared, diagonal=1).sum() + torch.tril(c_squared, diagonal=-1).sum()
        barlow_loss = on_diag + self.barlow_lambda * off_diag

        # Combine VICReg losses
        vicreg_loss = (self.vicreg_sim_coef * sim_loss +
                       self.vicreg_std_coef * std_loss +
                       self.vicreg_cov_coef * vicreg_cov_loss)

        # Final weighted combination
        total_loss = self.loss_weight * vicreg_loss + (1 - self.loss_weight) * barlow_loss

        metrics = {
            'total_loss': total_loss.item(),
            'vicreg_sim_loss': sim_loss.item(),
            'vicreg_std_loss': std_loss.item(),
            'vicreg_cov_loss': vicreg_cov_loss.item(),
            'barlow_loss': barlow_loss.item()
        }
        return total_loss, metrics

class WorldModelCombined(nn.Module):
    """Latent world model trained with the combined VICReg / Barlow Twins loss.

    An ``Encoder`` embeds the first observation, a ``TransitionModel`` is then
    applied once per action to roll the latent forward, and ``CombinedLoss``
    scores every predicted latent against the encoding of the observation at
    the same time step.
    """

    def __init__(self, vicreg_sim_coef=25.0, vicreg_std_coef=25.0, vicreg_cov_coef=1.0,
                 barlow_lambda=0.005, loss_weight=0.5):
        """Build the encoder, the transition model and the loss.

        All arguments are forwarded unchanged to ``CombinedLoss``.
        """
        super().__init__()
        self.encoder = Encoder(input_channels=2)
        self.predictor = TransitionModel(hidden_dim=32)

        loss_settings = dict(
            vicreg_sim_coef=vicreg_sim_coef,
            vicreg_std_coef=vicreg_std_coef,
            vicreg_cov_coef=vicreg_cov_coef,
            barlow_lambda=barlow_lambda,
            loss_weight=loss_weight,
        )
        self.criterion = CombinedLoss(**loss_settings)

        # Re-exported from the encoder so callers can read it off the model.
        self.repr_dim = self.encoder.repr_dim

    def forward_prediction(self, states, actions):
        """Encode the initial frame and unroll the predictor over ``actions``.

        Args:
            states: 5-D tensor; dim 1 is squeezed away before encoding (the
                callers in this class pass a single frame there).
            actions: Tensor whose dim 1 indexes the time steps; one predictor
                call is made per step.

        Returns:
            The initial latent followed by one predicted latent per action,
            stacked along a new dim 1.
        """
        # The unpacking doubles as a check that `states` is 5-D.
        _batch, _frames, _channels, _height, _width = states.shape

        latent = self.encoder(states.squeeze(1))
        rollout = [latent]
        for step_action in actions.unbind(dim=1):
            latent = self.predictor(latent, step_action)
            rollout.append(latent)

        return torch.stack(rollout, dim=1)

    def forward(self, states, actions):
        """Return the rolled-out latents with each one flattened to a vector.

        Only the first frame of ``states`` is used; the remaining steps come
        from the predictor.
        """
        first_frame = states[:, 0:1]
        rollout = self.forward_prediction(first_frame, actions)
        # Expects a 5-D rollout; everything after (batch, time) is flattened.
        n_batch, n_steps, _c, _h, _w = rollout.shape
        return rollout.view(n_batch, n_steps, -1)

    def training_step(self, batch):
        """Compute the loss of one batch, averaged over the action steps.

        Args:
            batch: Object exposing ``states`` and ``actions`` attributes.

        Returns:
            Tuple ``(total_loss, predictions, accumulated_losses)``: the
            step-averaged loss tensor, the un-flattened rollout from
            ``forward_prediction``, and a dict with the step-averaged value of
            every loss component.
        """
        states, actions = batch.states, batch.actions
        predictions = self.forward_prediction(states[:, 0:1], actions)

        num_steps = actions.shape[1]
        component_names = (
            'total_loss',
            'vicreg_sim_loss',
            'vicreg_std_loss',
            'vicreg_cov_loss',
            'barlow_loss',
        )
        accumulated_losses = dict.fromkeys(component_names, 0.0)
        total_loss = 0.0

        # Index 0 is the encoded input frame, so targets start at step 1.
        for step in range(1, num_steps + 1):
            target_state = self.encoder(states[:, step])

            loss, component_losses = self.criterion(
                predictions[:, step].flatten(start_dim=1),
                target_state.flatten(start_dim=1),
            )

            total_loss += loss
            for name in accumulated_losses:
                accumulated_losses[name] += component_losses[name]

        total_loss = total_loss / num_steps
        accumulated_losses = {
            name: value / num_steps for name, value in accumulated_losses.items()
        }

        return total_loss, predictions, accumulated_losses

def train_epoch(model, dataloader, optimizer, epoch, total_epochs):
    """Run one optimisation pass over ``dataloader`` and average the losses.

    Args:
        model: Module exposing ``training_step(batch)`` that returns
            ``(loss, predictions, component_losses)``.
        dataloader: Sized iterable of batches; each batch must support
            ``_replace`` and carry ``states``, ``actions`` and ``locations``
            (the latter may be ``None``).
        optimizer: Optimizer stepping the model's parameters.
        epoch: Index of the current epoch, used only in the progress label.
        total_epochs: Total number of epochs, used only in the progress label.

    Returns:
        Tuple ``(total_loss, train_losses, timing)``: the mean loss over the
        processed batches, a dict with the mean of every loss component, and a
        dict with ``'epoch_time'`` and ``'avg_batch_time'`` in seconds. All
        means are 0 when the dataloader yields no batch.
    """
    model.train()

    # Move batches to wherever the model lives rather than hard-coding CUDA,
    # so this agrees with train_model, which falls back to CPU when CUDA is
    # unavailable. A parameter-less model keeps the previous CUDA behaviour.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')
    on_cuda = device.type == 'cuda'

    total_loss = 0.0
    train_losses = {
        'total_loss': 0.0,
        'vicreg_sim_loss': 0.0,
        'vicreg_std_loss': 0.0,
        'vicreg_cov_loss': 0.0,
        'barlow_loss': 0.0
    }

    progress_bar = tqdm(enumerate(dataloader),
                        total=len(dataloader),
                        desc=f'Epoch {epoch}/{total_epochs}',
                        leave=True)

    start_time = time.time()
    batch_times = []

    for batch_idx, batch in progress_bar:
        batch_start = time.time()

        batch = batch._replace(
            states=batch.states.to(device),
            actions=batch.actions.to(device),
            locations=batch.locations.to(device) if batch.locations is not None else None
        )

        optimizer.zero_grad()
        loss, _, component_losses = model.training_step(batch)

        loss.backward()
        # Cap the global gradient norm at 1.0 before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        batch_times.append(time.time() - batch_start)

        total_loss += loss.item()
        for k in train_losses:
            train_losses[k] += component_losses[k]

        # Running averages shown on the progress bar.
        current_loss = total_loss / (batch_idx + 1)
        current_batch_time = sum(batch_times) / len(batch_times)
        peak_bytes = torch.cuda.max_memory_allocated(device) if on_cuda else 0

        progress_bar.set_postfix({
            'loss': f'{current_loss:.4f}',
            'time/batch': f'{current_batch_time:.3f}s',
            'gpu_mem': f'{peak_bytes/1e9:.1f}GB'
        })

    # Average over the batches actually processed; the max() keeps an empty
    # dataloader from raising ZeroDivisionError.
    num_batches = max(len(batch_times), 1)
    total_loss /= num_batches
    for k in train_losses:
        train_losses[k] /= num_batches

    epoch_time = time.time() - start_time
    avg_batch_time = sum(batch_times) / num_batches

    return total_loss, train_losses, {
        'epoch_time': epoch_time,
        'avg_batch_time': avg_batch_time
    }

def train_model(model, train_loader, num_epochs=100, learning_rate=3e-4, 
                save_dir='checkpoints/combined', save_frequency=10):
    """Train ``model`` with Adam, checkpointing and probing after each epoch.

    After every epoch the per-component losses and timings are printed and
    ``evaluate_model`` is run against the module-level ``probe_train_ds`` and
    ``probe_val_ds`` datasets.

    Args:
        model: Module to train; it is moved to CUDA when available, otherwise
            to CPU.
        train_loader: Dataloader handed to ``train_epoch`` every epoch.
        num_epochs: Number of epochs to run.
        learning_rate: Learning rate of the Adam optimizer.
        save_dir: Directory for checkpoints; created if missing.
        save_frequency: Write a checkpoint every this many epochs. A falsy
            value (``0`` or ``None``) disables checkpointing.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    print(f"Starting training at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Training for {num_epochs} epochs")
    print(f"Checkpoints will be saved to {save_dir}")

    for epoch in range(num_epochs):
        train_loss, train_losses, timing_stats = train_epoch(
            model, train_loader, optimizer, epoch, num_epochs
        )

        print(f"\nEpoch {epoch}/{num_epochs} Summary:")
        print(f"Train Loss: {train_loss:.4f}")
        print("Component Losses:")
        for k, v in train_losses.items():
            print(f"  {k}: {v:.4f}")
        print(f"Epoch Time: {timing_stats['epoch_time']:.1f}s")
        print(f"Avg Batch Time: {timing_stats['avg_batch_time']:.3f}s")
        # Only query the CUDA allocator when training actually runs on CUDA.
        if device.type == 'cuda':
            print(f"GPU Memory: {torch.cuda.max_memory_allocated() / 1024**2:.0f}MB")

        # The falsy check keeps save_frequency=0 from raising on the modulo.
        if save_frequency and (epoch + 1) % save_frequency == 0:
            checkpoint_path = save_dir / f"checkpoint_epoch_{epoch+1}.pt"
            # 'epoch' is the 0-based index; the file name is 1-based.
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': train_loss,
            }, checkpoint_path)
            print(f"Checkpoint saved to {checkpoint_path}")

        # Probe on the device chosen above ("cuda" or "cpu") instead of a
        # hard-coded "cuda", so evaluation follows the training device.
        evaluate_model(device.type, model, probe_train_ds, probe_val_ds)

        print("-" * 80)

if __name__ == "__main__":
    from dataset import create_wall_dataloader, WallSample

    # Loss hyper-parameters for the combined VICReg / Barlow Twins objective.
    loss_config = dict(
        vicreg_sim_coef=25.0,
        vicreg_std_coef=25.0,
        vicreg_cov_coef=1.0,
        barlow_lambda=0.005,
        loss_weight=0.5,  # Equal weight between VICReg and Barlow
    )
    model = WorldModelCombined(**loss_config).cuda()

    # Training data; update the path to match your environment.
    train_loader = create_wall_dataloader(
        "/drive_reader/as16386/DL24FA/train",
        batch_size=128,
        train=True,
    )

    # Launch the run; a checkpoint is written every 10 epochs.
    train_model(
        model=model,
        train_loader=train_loader,
        num_epochs=100,
        learning_rate=3e-5,
        save_dir='checkpoints/combined_barlow_vicreg',
        save_frequency=10,
    )

Starting training at 2024-12-11 23:46:02
Training for 100 epochs
Checkpoints will be saved to checkpoints/combined_barlow_vicreg


Epoch 0/100: 100%|██████████| 1148/1148 [01:38<00:00, 11.66it/s, loss=3203.5392, time/batch=0.060s, gpu_mem=2.1GB]



Epoch 0/100 Summary:
Train Loss: 3203.5392
Component Losses:
  total_loss: 3203.5393
  vicreg_sim_loss: 19.9145
  vicreg_std_loss: 1.1452
  vicreg_cov_loss: 4325.8279
  barlow_loss: 1554.7584
Epoch Time: 98.4s
Avg Batch Time: 0.060s
GPU Memory: 1988MB


Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]

normalized pred locations loss 0.9376960396766663




normalized pred locations loss 0.6371876001358032


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.41it/s]
Probe prediction epochs:   5%|▌         | 1/20 [00:02<00:45,  2.39s/it]

normalized pred locations loss 0.25403302907943726


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.43it/s]
Probe prediction epochs:  10%|█         | 2/20 [00:04<00:42,  2.39s/it]

normalized pred locations loss 0.17966483533382416




normalized pred locations loss 0.1303916573524475


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.71it/s]
Probe prediction epochs:  15%|█▌        | 3/20 [00:07<00:40,  2.38s/it]

normalized pred locations loss 0.07833853363990784




normalized pred locations loss 0.07247161120176315


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.07it/s]
Probe prediction epochs:  20%|██        | 4/20 [00:09<00:37,  2.36s/it]

normalized pred locations loss 0.06488312035799026


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.60it/s]
Probe prediction epochs:  25%|██▌       | 5/20 [00:11<00:35,  2.35s/it]

normalized pred locations loss 0.08267024159431458




normalized pred locations loss 0.051201559603214264


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.77it/s]
Probe prediction epochs:  30%|███       | 6/20 [00:14<00:33,  2.36s/it]

normalized pred locations loss 0.04753510281443596


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.77it/s]
Probe prediction epochs:  35%|███▌      | 7/20 [00:16<00:30,  2.37s/it]

normalized pred locations loss 0.05208192020654678




normalized pred locations loss 0.057380467653274536


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.36it/s]
Probe prediction epochs:  40%|████      | 8/20 [00:18<00:28,  2.37s/it]

normalized pred locations loss 0.04069685935974121


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.91it/s]
Probe prediction epochs:  45%|████▌     | 9/20 [00:21<00:26,  2.37s/it]

normalized pred locations loss 0.06324310600757599




normalized pred locations loss 0.03087588958442211


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.61it/s]
Probe prediction epochs:  50%|█████     | 10/20 [00:23<00:23,  2.36s/it]

normalized pred locations loss 0.04728852957487106




normalized pred locations loss 0.031887322664260864


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.07it/s]
Probe prediction epochs:  55%|█████▌    | 11/20 [00:25<00:21,  2.35s/it]

normalized pred locations loss 0.04078087583184242


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.04it/s]
Probe prediction epochs:  60%|██████    | 12/20 [00:28<00:18,  2.35s/it]

normalized pred locations loss 0.031744226813316345




normalized pred locations loss 0.03322982043027878


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.16it/s]
Probe prediction epochs:  65%|██████▌   | 13/20 [00:30<00:16,  2.34s/it]

normalized pred locations loss 0.04131370782852173


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.06it/s]
Probe prediction epochs:  70%|███████   | 14/20 [00:32<00:14,  2.34s/it]

normalized pred locations loss 0.04916258901357651




normalized pred locations loss 0.057727642357349396


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.29it/s]
Probe prediction epochs:  75%|███████▌  | 15/20 [00:35<00:11,  2.34s/it]

normalized pred locations loss 0.03016722947359085


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.19it/s]
Probe prediction epochs:  80%|████████  | 16/20 [00:37<00:09,  2.35s/it]

normalized pred locations loss 0.03691316023468971




normalized pred locations loss 0.033752068877220154


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.31it/s]
Probe prediction epochs:  85%|████████▌ | 17/20 [00:40<00:07,  2.34s/it]

normalized pred locations loss 0.04532919079065323


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.50it/s]
Probe prediction epochs:  90%|█████████ | 18/20 [00:42<00:04,  2.34s/it]

normalized pred locations loss 0.02293107472360134




normalized pred locations loss 0.031723521649837494


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.29it/s]
Probe prediction epochs:  95%|█████████▌| 19/20 [00:44<00:02,  2.33s/it]

normalized pred locations loss 0.03360387310385704




normalized pred locations loss 0.02406018227338791


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 67.18it/s]
Probe prediction epochs: 100%|██████████| 20/20 [00:47<00:00,  2.35s/it]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 82.05it/s]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 82.15it/s]


normal loss: 9.413141250610352
wall loss: 13.63965892791748
--------------------------------------------------------------------------------


Epoch 1/100: 100%|██████████| 1148/1148 [01:35<00:00, 12.00it/s, loss=550.7443, time/batch=0.060s, gpu_mem=2.1GB]



Epoch 1/100 Summary:
Train Loss: 550.7443
Component Losses:
  total_loss: 550.7443
  vicreg_sim_loss: 1.0924
  vicreg_std_loss: 1.3119
  vicreg_cov_loss: 3.9262
  barlow_loss: 1037.4544
Epoch Time: 95.6s
Avg Batch Time: 0.060s
GPU Memory: 1988MB


Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]

normalized pred locations loss 1.0023376941680908




normalized pred locations loss 0.31415143609046936


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.68it/s]
Probe prediction epochs:   5%|▌         | 1/20 [00:02<00:45,  2.38s/it]

normalized pred locations loss 0.1653321236371994


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 63.62it/s]
Probe prediction epochs:  10%|█         | 2/20 [00:04<00:43,  2.42s/it]

normalized pred locations loss 0.09025569260120392




normalized pred locations loss 0.05110526829957962


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 64.80it/s]
Probe prediction epochs:  15%|█▌        | 3/20 [00:07<00:41,  2.42s/it]

normalized pred locations loss 0.04095038026571274




normalized pred locations loss 0.03888165205717087


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 64.79it/s]
Probe prediction epochs:  20%|██        | 4/20 [00:09<00:38,  2.41s/it]

normalized pred locations loss 0.045417431741952896


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.66it/s]
Probe prediction epochs:  25%|██▌       | 5/20 [00:12<00:36,  2.40s/it]

normalized pred locations loss 0.032841939479112625




normalized pred locations loss 0.03323875740170479


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 61.98it/s]
Probe prediction epochs:  30%|███       | 6/20 [00:14<00:34,  2.44s/it]

normalized pred locations loss 0.019496385008096695


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.29it/s]
Probe prediction epochs:  35%|███▌      | 7/20 [00:16<00:31,  2.42s/it]

normalized pred locations loss 0.023280518129467964




normalized pred locations loss 0.03377694636583328


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.28it/s]
Probe prediction epochs:  40%|████      | 8/20 [00:19<00:28,  2.40s/it]

normalized pred locations loss 0.02455635741353035


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.82it/s]
Probe prediction epochs:  45%|████▌     | 9/20 [00:21<00:26,  2.39s/it]

normalized pred locations loss 0.022906474769115448




normalized pred locations loss 0.019477728754281998


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.01it/s]
Probe prediction epochs:  50%|█████     | 10/20 [00:24<00:23,  2.38s/it]

normalized pred locations loss 0.02501334622502327




normalized pred locations loss 0.018419766798615456


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 64.84it/s]
Probe prediction epochs:  55%|█████▌    | 11/20 [00:26<00:21,  2.39s/it]

normalized pred locations loss 0.02357698231935501


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 64.79it/s]
Probe prediction epochs:  60%|██████    | 12/20 [00:28<00:19,  2.40s/it]

normalized pred locations loss 0.02113538421690464




normalized pred locations loss 0.014344219118356705


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.24it/s]
Probe prediction epochs:  65%|██████▌   | 13/20 [00:31<00:16,  2.40s/it]

normalized pred locations loss 0.025809047743678093


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 64.93it/s]
Probe prediction epochs:  70%|███████   | 14/20 [00:33<00:14,  2.40s/it]

normalized pred locations loss 0.019979365170001984




normalized pred locations loss 0.030396517366170883


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.35it/s]
Probe prediction epochs:  75%|███████▌  | 15/20 [00:35<00:11,  2.38s/it]

normalized pred locations loss 0.025969764217734337


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.23it/s]
Probe prediction epochs:  80%|████████  | 16/20 [00:38<00:09,  2.38s/it]

normalized pred locations loss 0.019981272518634796




normalized pred locations loss 0.01695307157933712


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 64.32it/s]
Probe prediction epochs:  85%|████████▌ | 17/20 [00:40<00:07,  2.39s/it]

normalized pred locations loss 0.025704383850097656


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.56it/s]
Probe prediction epochs:  90%|█████████ | 18/20 [00:43<00:04,  2.38s/it]

normalized pred locations loss 0.011587917804718018




normalized pred locations loss 0.01435818150639534


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.10it/s]
Probe prediction epochs:  95%|█████████▌| 19/20 [00:45<00:02,  2.37s/it]

normalized pred locations loss 0.0150483762845397




normalized pred locations loss 0.023518946021795273


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.41it/s]
Probe prediction epochs: 100%|██████████| 20/20 [00:47<00:00,  2.39s/it]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 80.90it/s]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 81.79it/s]


normal loss: 5.6003031730651855
wall loss: 9.006675720214844
--------------------------------------------------------------------------------


Epoch 2/100: 100%|██████████| 1148/1148 [01:35<00:00, 12.05it/s, loss=477.1179, time/batch=0.060s, gpu_mem=2.1GB]



Epoch 2/100 Summary:
Train Loss: 477.1179
Component Losses:
  total_loss: 477.1179
  vicreg_sim_loss: 0.8151
  vicreg_std_loss: 1.2941
  vicreg_cov_loss: 3.8491
  barlow_loss: 897.6561
Epoch Time: 95.3s
Avg Batch Time: 0.060s
GPU Memory: 1988MB


Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]

normalized pred locations loss 0.9230179190635681




normalized pred locations loss 0.32507023215293884


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.87it/s]
Probe prediction epochs:   5%|▌         | 1/20 [00:02<00:44,  2.33s/it]

normalized pred locations loss 0.1248072013258934


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.61it/s]
Probe prediction epochs:  10%|█         | 2/20 [00:04<00:42,  2.34s/it]

normalized pred locations loss 0.06462077796459198




normalized pred locations loss 0.049397021532058716


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.91it/s]
Probe prediction epochs:  15%|█▌        | 3/20 [00:07<00:39,  2.34s/it]

normalized pred locations loss 0.030735857784748077




normalized pred locations loss 0.0426497608423233


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.43it/s]
Probe prediction epochs:  20%|██        | 4/20 [00:09<00:37,  2.36s/it]

normalized pred locations loss 0.03147435560822487


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.24it/s]
Probe prediction epochs:  25%|██▌       | 5/20 [00:11<00:35,  2.37s/it]

normalized pred locations loss 0.028577011078596115




normalized pred locations loss 0.0260672178119421


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.25it/s]
Probe prediction epochs:  30%|███       | 6/20 [00:14<00:33,  2.38s/it]

normalized pred locations loss 0.024219505488872528


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.03it/s]
Probe prediction epochs:  35%|███▌      | 7/20 [00:16<00:31,  2.39s/it]

normalized pred locations loss 0.02350440062582493




normalized pred locations loss 0.02099692076444626


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.03it/s]
Probe prediction epochs:  40%|████      | 8/20 [00:18<00:28,  2.39s/it]

normalized pred locations loss 0.025054344907402992


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.13it/s]
Probe prediction epochs:  45%|████▌     | 9/20 [00:21<00:26,  2.39s/it]

normalized pred locations loss 0.018244845792651176




normalized pred locations loss 0.017944006249308586


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.94it/s]
Probe prediction epochs:  50%|█████     | 10/20 [00:23<00:23,  2.38s/it]

normalized pred locations loss 0.021033311262726784




normalized pred locations loss 0.01874559558928013


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.26it/s]
Probe prediction epochs:  55%|█████▌    | 11/20 [00:26<00:21,  2.38s/it]

normalized pred locations loss 0.023098798468708992


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.67it/s]
Probe prediction epochs:  60%|██████    | 12/20 [00:28<00:18,  2.37s/it]

normalized pred locations loss 0.016276801005005836




normalized pred locations loss 0.020195569843053818


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.79it/s]
Probe prediction epochs:  65%|██████▌   | 13/20 [00:30<00:16,  2.36s/it]

normalized pred locations loss 0.013011409901082516


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.19it/s]
Probe prediction epochs:  70%|███████   | 14/20 [00:33<00:14,  2.36s/it]

normalized pred locations loss 0.015155903063714504




normalized pred locations loss 0.01861622929573059


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.32it/s]
Probe prediction epochs:  75%|███████▌  | 15/20 [00:35<00:11,  2.37s/it]

normalized pred locations loss 0.014905309304594994


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.83it/s]
Probe prediction epochs:  80%|████████  | 16/20 [00:37<00:09,  2.36s/it]

normalized pred locations loss 0.010338465683162212




normalized pred locations loss 0.013317218981683254


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.30it/s]
Probe prediction epochs:  85%|████████▌ | 17/20 [00:40<00:07,  2.36s/it]

normalized pred locations loss 0.012805543839931488


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 64.87it/s]
Probe prediction epochs:  90%|█████████ | 18/20 [00:42<00:04,  2.37s/it]

normalized pred locations loss 0.011336009949445724




normalized pred locations loss 0.02412687987089157


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.08it/s]
Probe prediction epochs:  95%|█████████▌| 19/20 [00:45<00:02,  2.38s/it]

normalized pred locations loss 0.014533217996358871




normalized pred locations loss 0.014635167084634304


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 65.20it/s]
Probe prediction epochs: 100%|██████████| 20/20 [00:47<00:00,  2.37s/it]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 81.87it/s]
Eval probe pred: 100%|██████████| 62/62 [00:00<00:00, 81.90it/s]


normal loss: 5.108015060424805
wall loss: 9.598150253295898
--------------------------------------------------------------------------------


Epoch 3/100: 100%|██████████| 1148/1148 [01:35<00:00, 12.05it/s, loss=422.8223, time/batch=0.060s, gpu_mem=2.1GB]



Epoch 3/100 Summary:
Train Loss: 422.8223
Component Losses:
  total_loss: 422.8223
  vicreg_sim_loss: 0.6356
  vicreg_std_loss: 1.2585
  vicreg_cov_loss: 4.0150
  barlow_loss: 794.2782
Epoch Time: 95.2s
Avg Batch Time: 0.060s
GPU Memory: 1988MB


Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]

normalized pred locations loss 1.0278675556182861




normalized pred locations loss 0.2906845211982727


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.40it/s]
Probe prediction epochs:   5%|▌         | 1/20 [00:02<00:44,  2.35s/it]

normalized pred locations loss 0.09027080237865448


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 64.09it/s]
Probe prediction epochs:  10%|█         | 2/20 [00:04<00:43,  2.40s/it]

normalized pred locations loss 0.07862785458564758




normalized pred locations loss 0.04771977663040161


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.41it/s]
Probe prediction epochs:  15%|█▌        | 3/20 [00:07<00:40,  2.38s/it]

normalized pred locations loss 0.051902759820222855




normalized pred locations loss 0.04585621505975723


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.19it/s]
Probe prediction epochs:  20%|██        | 4/20 [00:09<00:37,  2.37s/it]

normalized pred locations loss 0.03556545451283455


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.40it/s]
Probe prediction epochs:  25%|██▌       | 5/20 [00:11<00:35,  2.36s/it]

normalized pred locations loss 0.027025381103157997




normalized pred locations loss 0.025769881904125214


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.72it/s]
Probe prediction epochs:  30%|███       | 6/20 [00:14<00:32,  2.36s/it]

normalized pred locations loss 0.027306167408823967


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.75it/s]
Probe prediction epochs:  35%|███▌      | 7/20 [00:16<00:30,  2.35s/it]

normalized pred locations loss 0.023149756714701653




normalized pred locations loss 0.025721168145537376


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 66.70it/s]
Probe prediction epochs:  40%|████      | 8/20 [00:18<00:28,  2.35s/it]

normalized pred locations loss 0.01777808554470539


Probe prediction step: 100%|██████████| 156/156 [00:02<00:00, 52.85it/s]
Probe prediction epochs:  45%|████▌     | 9/20 [00:21<00:27,  2.54s/it]

normalized pred locations loss 0.019694821909070015




normalized pred locations loss 0.01640874333679676


Probe prediction step: 100%|██████████| 156/156 [00:07<00:00, 19.53it/s]
Probe prediction epochs:  50%|█████     | 10/20 [00:29<00:42,  4.22s/it]

normalized pred locations loss 0.013118133880198002




normalized pred locations loss 0.024187440052628517


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.47it/s]
Probe prediction epochs:  55%|█████▌    | 11/20 [00:38<00:49,  5.51s/it]

normalized pred locations loss 0.01482419017702341


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.69it/s]
Probe prediction epochs:  60%|██████    | 12/20 [00:46<00:51,  6.38s/it]

normalized pred locations loss 0.013843890279531479




normalized pred locations loss 0.011777295731008053


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.34it/s]
Probe prediction epochs:  65%|██████▌   | 13/20 [00:55<00:49,  7.02s/it]

normalized pred locations loss 0.014471795409917831


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.30it/s]
Probe prediction epochs:  70%|███████   | 14/20 [01:03<00:44,  7.48s/it]

normalized pred locations loss 0.01123635470867157




normalized pred locations loss 0.010566522367298603


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.50it/s]
Probe prediction epochs:  75%|███████▌  | 15/20 [01:12<00:38,  7.77s/it]

normalized pred locations loss 0.02067713811993599


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.43it/s]
Probe prediction epochs:  80%|████████  | 16/20 [01:20<00:31,  7.98s/it]

normalized pred locations loss 0.011646684259176254




normalized pred locations loss 0.020898059010505676


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.39it/s]
Probe prediction epochs:  85%|████████▌ | 17/20 [01:29<00:24,  8.13s/it]

normalized pred locations loss 0.013456928543746471




normalized pred locations loss 0.014427107758820057


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.59it/s]
Probe prediction epochs:  90%|█████████ | 18/20 [01:37<00:16,  8.21s/it]

normalized pred locations loss 0.01452112477272749


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.24it/s]
Probe prediction epochs:  95%|█████████▌| 19/20 [01:45<00:08,  8.31s/it]

normalized pred locations loss 0.01448891032487154




normalized pred locations loss 0.012343911454081535


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.23it/s]
Probe prediction epochs: 100%|██████████| 20/20 [01:54<00:00,  5.73s/it]
Eval probe pred: 100%|██████████| 62/62 [00:02<00:00, 22.72it/s]
Eval probe pred: 100%|██████████| 62/62 [00:02<00:00, 22.72it/s]


normal loss: 5.2224931716918945
wall loss: 9.587729454040527
--------------------------------------------------------------------------------


Epoch 4/100: 100%|██████████| 1148/1148 [02:46<00:00,  6.88it/s, loss=387.3064, time/batch=0.081s, gpu_mem=2.1GB]



Epoch 4/100 Summary:
Train Loss: 387.3064
Component Losses:
  total_loss: 387.3064
  vicreg_sim_loss: 0.5763
  vicreg_std_loss: 1.2355
  vicreg_cov_loss: 4.1255
  barlow_loss: 725.1944
Epoch Time: 166.9s
Avg Batch Time: 0.081s
GPU Memory: 1988MB


Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]

normalized pred locations loss 0.9910621643066406




normalized pred locations loss 0.29826024174690247


Probe prediction step: 100%|██████████| 156/156 [00:15<00:00, 10.15it/s]
Probe prediction epochs:   5%|▌         | 1/20 [00:15<04:52, 15.38s/it]

normalized pred locations loss 0.10168688744306564




normalized pred locations loss 0.064509816467762


Probe prediction step: 100%|██████████| 156/156 [00:15<00:00,  9.99it/s]
Probe prediction epochs:  10%|█         | 2/20 [00:30<04:39, 15.52s/it]

normalized pred locations loss 0.04079660773277283


Probe prediction step: 100%|██████████| 156/156 [00:15<00:00,  9.86it/s]
Probe prediction epochs:  15%|█▌        | 3/20 [00:46<04:26, 15.66s/it]

normalized pred locations loss 0.04320589080452919




normalized pred locations loss 0.035622190684080124


Probe prediction step: 100%|██████████| 156/156 [00:15<00:00,  9.82it/s]
Probe prediction epochs:  20%|██        | 4/20 [01:02<04:12, 15.75s/it]

normalized pred locations loss 0.03190445154905319


Probe prediction step: 100%|██████████| 156/156 [00:12<00:00, 12.33it/s]
Probe prediction epochs:  25%|██▌       | 5/20 [01:15<03:39, 14.63s/it]

normalized pred locations loss 0.026011331006884575




normalized pred locations loss 0.02423352561891079


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.26it/s]
Probe prediction epochs:  30%|███       | 6/20 [01:27<03:11, 13.66s/it]

normalized pred locations loss 0.02321905642747879


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.39it/s]
Probe prediction epochs:  35%|███▌      | 7/20 [01:38<02:49, 13.00s/it]

normalized pred locations loss 0.016997694969177246




normalized pred locations loss 0.02190728671848774


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.57it/s]
Probe prediction epochs:  40%|████      | 8/20 [01:50<02:30, 12.52s/it]

normalized pred locations loss 0.013824081979691982




normalized pred locations loss 0.017275551334023476


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.44it/s]
Probe prediction epochs:  45%|████▌     | 9/20 [02:01<02:14, 12.24s/it]

normalized pred locations loss 0.019892405718564987


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.68it/s]
Probe prediction epochs:  50%|█████     | 10/20 [02:13<01:59, 11.98s/it]

normalized pred locations loss 0.014713984914124012




normalized pred locations loss 0.014093809761106968


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.60it/s]
Probe prediction epochs:  55%|█████▌    | 11/20 [02:24<01:46, 11.83s/it]

normalized pred locations loss 0.014956986531615257


Probe prediction step: 100%|██████████| 156/156 [00:06<00:00, 22.71it/s]
Probe prediction epochs:  60%|██████    | 12/20 [02:31<01:22, 10.32s/it]

normalized pred locations loss 0.011259238235652447




normalized pred locations loss 0.01218146737664938


Probe prediction step: 100%|██████████| 156/156 [00:06<00:00, 23.38it/s]
Probe prediction epochs:  65%|██████▌   | 13/20 [02:38<01:04,  9.22s/it]

normalized pred locations loss 0.014063125476241112


Probe prediction step: 100%|██████████| 156/156 [00:06<00:00, 23.30it/s]
Probe prediction epochs:  70%|███████   | 14/20 [02:45<00:50,  8.45s/it]

normalized pred locations loss 0.00822187028825283




normalized pred locations loss 0.010014520958065987


Probe prediction step: 100%|██████████| 156/156 [00:06<00:00, 23.16it/s]
Probe prediction epochs:  75%|███████▌  | 15/20 [02:51<00:39,  7.94s/it]

normalized pred locations loss 0.014578735455870628


Probe prediction step: 100%|██████████| 156/156 [00:06<00:00, 23.41it/s]
Probe prediction epochs:  80%|████████  | 16/20 [02:58<00:30,  7.55s/it]

normalized pred locations loss 0.010717449709773064




normalized pred locations loss 0.011882757768034935


Probe prediction step: 100%|██████████| 156/156 [00:06<00:00, 23.21it/s]
Probe prediction epochs:  85%|████████▌ | 17/20 [03:05<00:21,  7.30s/it]

normalized pred locations loss 0.012373940087854862




normalized pred locations loss 0.01212928257882595


Probe prediction step: 100%|██████████| 156/156 [00:06<00:00, 23.43it/s]
Probe prediction epochs:  90%|█████████ | 18/20 [03:11<00:14,  7.11s/it]

normalized pred locations loss 0.010053812526166439


Probe prediction step: 100%|██████████| 156/156 [00:06<00:00, 23.74it/s]
Probe prediction epochs:  95%|█████████▌| 19/20 [03:18<00:06,  6.95s/it]

normalized pred locations loss 0.010901552625000477




normalized pred locations loss 0.016516188159585


Probe prediction step: 100%|██████████| 156/156 [00:06<00:00, 23.23it/s]
Probe prediction epochs: 100%|██████████| 20/20 [03:25<00:00, 10.25s/it]
Eval probe pred: 100%|██████████| 62/62 [00:02<00:00, 25.65it/s]
Eval probe pred: 100%|██████████| 62/62 [00:02<00:00, 25.51it/s]


normal loss: 4.544832706451416
wall loss: 8.753279685974121
--------------------------------------------------------------------------------


Epoch 5/100: 100%|██████████| 1148/1148 [02:30<00:00,  7.61it/s, loss=362.9658, time/batch=0.081s, gpu_mem=2.1GB]



Epoch 5/100 Summary:
Train Loss: 362.9658
Component Losses:
  total_loss: 362.9658
  vicreg_sim_loss: 0.5497
  vicreg_std_loss: 1.2138
  vicreg_cov_loss: 4.3485
  barlow_loss: 677.4955
Epoch Time: 150.9s
Avg Batch Time: 0.081s
GPU Memory: 1988MB


Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]

normalized pred locations loss 1.1006544828414917




normalized pred locations loss 0.21354421973228455


Probe prediction step: 100%|██████████| 156/156 [00:10<00:00, 14.21it/s]
Probe prediction epochs:   5%|▌         | 1/20 [00:10<03:28, 10.98s/it]

normalized pred locations loss 0.0520792156457901




normalized pred locations loss 0.04799184575676918


Probe prediction step: 100%|██████████| 156/156 [00:15<00:00, 10.04it/s]
Probe prediction epochs:  10%|█         | 2/20 [00:26<04:05, 13.66s/it]

normalized pred locations loss 0.04357154294848442


Probe prediction step: 100%|██████████| 156/156 [00:16<00:00,  9.67it/s]
Probe prediction epochs:  15%|█▌        | 3/20 [00:42<04:11, 14.79s/it]

normalized pred locations loss 0.021923374384641647




normalized pred locations loss 0.027917806059122086


Probe prediction step: 100%|██████████| 156/156 [00:16<00:00,  9.59it/s]
Probe prediction epochs:  20%|██        | 4/20 [00:58<04:05, 15.37s/it]

normalized pred locations loss 0.02478799782693386


Probe prediction step: 100%|██████████| 156/156 [00:15<00:00,  9.83it/s]
Probe prediction epochs:  25%|██▌       | 5/20 [01:14<03:53, 15.55s/it]

normalized pred locations loss 0.021802369505167007




normalized pred locations loss 0.018761921674013138


Probe prediction step: 100%|██████████| 156/156 [00:12<00:00, 12.07it/s]
Probe prediction epochs:  30%|███       | 6/20 [01:27<03:25, 14.66s/it]

normalized pred locations loss 0.015689687803387642


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.25it/s]
Probe prediction epochs:  35%|███▌      | 7/20 [01:39<02:58, 13.72s/it]

normalized pred locations loss 0.017353303730487823




normalized pred locations loss 0.015143471769988537


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.22it/s]
Probe prediction epochs:  40%|████      | 8/20 [01:51<02:37, 13.11s/it]

normalized pred locations loss 0.019355328753590584




normalized pred locations loss 0.01628505066037178


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.31it/s]
Probe prediction epochs:  45%|████▌     | 9/20 [02:03<02:19, 12.67s/it]

normalized pred locations loss 0.019625751301646233


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.40it/s]
Probe prediction epochs:  50%|█████     | 10/20 [02:14<02:03, 12.35s/it]

normalized pred locations loss 0.013184839859604836




normalized pred locations loss 0.012028589844703674


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.28it/s]
Probe prediction epochs:  55%|█████▌    | 11/20 [02:26<01:49, 12.17s/it]

normalized pred locations loss 0.00930896308273077


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.44it/s]
Probe prediction epochs:  60%|██████    | 12/20 [02:38<01:35, 12.00s/it]

normalized pred locations loss 0.010739507153630257




normalized pred locations loss 0.010605070739984512


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.35it/s]
Probe prediction epochs:  65%|██████▌   | 13/20 [02:49<01:23, 11.90s/it]

normalized pred locations loss 0.008065637201070786


Probe prediction step: 100%|██████████| 156/156 [00:07<00:00, 20.05it/s]
Probe prediction epochs:  70%|███████   | 14/20 [02:57<01:03, 10.66s/it]

normalized pred locations loss 0.0062026637606322765




normalized pred locations loss 0.007453978061676025


Probe prediction step: 100%|██████████| 156/156 [00:06<00:00, 23.43it/s]
Probe prediction epochs:  75%|███████▌  | 15/20 [03:04<00:47,  9.45s/it]

normalized pred locations loss 0.009587853215634823


Probe prediction step: 100%|██████████| 156/156 [00:06<00:00, 23.16it/s]
Probe prediction epochs:  80%|████████  | 16/20 [03:10<00:34,  8.64s/it]

normalized pred locations loss 0.008752595633268356




normalized pred locations loss 0.007468451280146837


Probe prediction step: 100%|██████████| 156/156 [00:06<00:00, 23.56it/s]
Probe prediction epochs:  85%|████████▌ | 17/20 [03:17<00:24,  8.03s/it]

normalized pred locations loss 0.005874605383723974




normalized pred locations loss 0.008625411428511143


Probe prediction step: 100%|██████████| 156/156 [00:06<00:00, 23.24it/s]
Probe prediction epochs:  90%|█████████ | 18/20 [03:24<00:15,  7.64s/it]

normalized pred locations loss 0.007429595571011305


Probe prediction step: 100%|██████████| 156/156 [00:06<00:00, 23.03it/s]
Probe prediction epochs:  95%|█████████▌| 19/20 [03:30<00:07,  7.38s/it]

normalized pred locations loss 0.009378213435411453




normalized pred locations loss 0.008011561818420887


Probe prediction step: 100%|██████████| 156/156 [00:06<00:00, 23.37it/s]
Probe prediction epochs: 100%|██████████| 20/20 [03:37<00:00, 10.88s/it]
Eval probe pred: 100%|██████████| 62/62 [00:02<00:00, 25.31it/s]
Eval probe pred: 100%|██████████| 62/62 [00:02<00:00, 25.17it/s]


normal loss: 3.7805418968200684
wall loss: 7.500100135803223
--------------------------------------------------------------------------------


Epoch 6/100: 100%|██████████| 1148/1148 [02:28<00:00,  7.75it/s, loss=348.8506, time/batch=0.078s, gpu_mem=2.1GB]



Epoch 6/100 Summary:
Train Loss: 348.8506
Component Losses:
  total_loss: 348.8506
  vicreg_sim_loss: 0.5259
  vicreg_std_loss: 1.1961
  vicreg_cov_loss: 4.6052
  barlow_loss: 650.0479
Epoch Time: 148.1s
Avg Batch Time: 0.078s
GPU Memory: 1988MB


Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]

normalized pred locations loss 1.0211379528045654




normalized pred locations loss 0.2021535485982895


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 14.09it/s]
Probe prediction epochs:   5%|▌         | 1/20 [00:11<03:30, 11.07s/it]

normalized pred locations loss 0.06836000084877014




normalized pred locations loss 0.035102225840091705


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 14.18it/s]
Probe prediction epochs:  10%|█         | 2/20 [00:22<03:18, 11.03s/it]

normalized pred locations loss 0.03627360612154007


Probe prediction step: 100%|██████████| 156/156 [00:10<00:00, 14.21it/s]
Probe prediction epochs:  15%|█▌        | 3/20 [00:33<03:07, 11.01s/it]

normalized pred locations loss 0.03417915850877762




normalized pred locations loss 0.023671651259064674


Probe prediction step: 100%|██████████| 156/156 [00:15<00:00,  9.86it/s]
Probe prediction epochs:  20%|██        | 4/20 [00:48<03:26, 12.91s/it]

normalized pred locations loss 0.028127722442150116


Probe prediction step: 100%|██████████| 156/156 [00:16<00:00,  9.57it/s]
Probe prediction epochs:  25%|██▌       | 5/20 [01:05<03:32, 14.13s/it]

normalized pred locations loss 0.025632422417402267




normalized pred locations loss 0.019743965938687325


Probe prediction step: 100%|██████████| 156/156 [00:16<00:00,  9.66it/s]
Probe prediction epochs:  30%|███       | 6/20 [01:21<03:27, 14.82s/it]

normalized pred locations loss 0.01909780129790306


Probe prediction step: 100%|██████████| 156/156 [00:12<00:00, 12.90it/s]
Probe prediction epochs:  35%|███▌      | 7/20 [01:33<03:01, 13.93s/it]

normalized pred locations loss 0.017466865479946136




normalized pred locations loss 0.01582260988652706


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.41it/s]
Probe prediction epochs:  40%|████      | 8/20 [01:45<02:38, 13.20s/it]

normalized pred locations loss 0.022482918575406075




normalized pred locations loss 0.014700101688504219


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.27it/s]
Probe prediction epochs:  45%|████▌     | 9/20 [01:56<02:20, 12.75s/it]

normalized pred locations loss 0.017913933843374252


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.43it/s]
Probe prediction epochs:  50%|█████     | 10/20 [02:08<02:03, 12.40s/it]

normalized pred locations loss 0.01527661457657814




normalized pred locations loss 0.013658752664923668


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.19it/s]
Probe prediction epochs:  55%|█████▌    | 11/20 [02:20<01:50, 12.22s/it]

normalized pred locations loss 0.01462255697697401


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.12it/s]
Probe prediction epochs:  60%|██████    | 12/20 [02:32<01:36, 12.12s/it]

normalized pred locations loss 0.010301073081791401




normalized pred locations loss 0.012264547869563103


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.35it/s]
Probe prediction epochs:  65%|██████▌   | 13/20 [02:43<01:23, 11.99s/it]

normalized pred locations loss 0.009511919692158699


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.18it/s]
Probe prediction epochs:  70%|███████   | 14/20 [02:55<01:11, 11.94s/it]

normalized pred locations loss 0.009088568389415741




normalized pred locations loss 0.007246833760291338


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.31it/s]
Probe prediction epochs:  75%|███████▌  | 15/20 [03:07<00:59, 11.88s/it]

normalized pred locations loss 0.00748120341449976


Probe prediction step: 100%|██████████| 156/156 [00:07<00:00, 20.33it/s]
Probe prediction epochs:  80%|████████  | 16/20 [03:15<00:42, 10.61s/it]

normalized pred locations loss 0.012915411032736301




normalized pred locations loss 0.008456096053123474


Probe prediction step: 100%|██████████| 156/156 [00:06<00:00, 23.65it/s]
Probe prediction epochs:  85%|████████▌ | 17/20 [03:21<00:28,  9.41s/it]

normalized pred locations loss 0.0067851124331355095




normalized pred locations loss 0.008644829504191875


Probe prediction step: 100%|██████████| 156/156 [00:06<00:00, 23.38it/s]
Probe prediction epochs:  90%|█████████ | 18/20 [03:28<00:17,  8.58s/it]

normalized pred locations loss 0.007297718431800604


Probe prediction step: 100%|██████████| 156/156 [00:06<00:00, 23.24it/s]
Probe prediction epochs:  95%|█████████▌| 19/20 [03:35<00:08,  8.02s/it]

normalized pred locations loss 0.007748263888061047




normalized pred locations loss 0.0077883354388177395


Probe prediction step: 100%|██████████| 156/156 [00:06<00:00, 23.22it/s]
Probe prediction epochs: 100%|██████████| 20/20 [03:41<00:00, 11.09s/it]
Eval probe pred: 100%|██████████| 62/62 [00:02<00:00, 25.28it/s]
Eval probe pred: 100%|██████████| 62/62 [00:02<00:00, 25.33it/s]


normal loss: 3.5636684894561768
wall loss: 7.645246505737305
--------------------------------------------------------------------------------


Epoch 7/100: 100%|██████████| 1148/1148 [02:25<00:00,  7.89it/s, loss=336.1240, time/batch=0.077s, gpu_mem=2.1GB]



Epoch 7/100 Summary:
Train Loss: 336.1240
Component Losses:
  total_loss: 336.1240
  vicreg_sim_loss: 0.5043
  vicreg_std_loss: 1.1811
  vicreg_cov_loss: 4.8279
  barlow_loss: 625.2845
Epoch Time: 145.4s
Avg Batch Time: 0.077s
GPU Memory: 1988MB


Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]

normalized pred locations loss 1.2009226083755493




normalized pred locations loss 0.21208137273788452


Probe prediction step: 100%|██████████| 156/156 [00:10<00:00, 14.52it/s]
Probe prediction epochs:   5%|▌         | 1/20 [00:10<03:24, 10.74s/it]

normalized pred locations loss 0.08626123517751694




normalized pred locations loss 0.03508992865681648


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 14.07it/s]
Probe prediction epochs:  10%|█         | 2/20 [00:21<03:17, 10.95s/it]

normalized pred locations loss 0.04014768451452255


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 14.15it/s]
Probe prediction epochs:  15%|█▌        | 3/20 [00:32<03:06, 10.98s/it]

normalized pred locations loss 0.03663985803723335




normalized pred locations loss 0.036758679896593094


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 14.03it/s]
Probe prediction epochs:  20%|██        | 4/20 [00:43<02:56, 11.04s/it]

normalized pred locations loss 0.023179326206445694


Probe prediction step: 100%|██████████| 156/156 [00:10<00:00, 14.30it/s]
Probe prediction epochs:  25%|██▌       | 5/20 [00:54<02:44, 10.99s/it]

normalized pred locations loss 0.02041974663734436




normalized pred locations loss 0.025276266038417816


Probe prediction step: 100%|██████████| 156/156 [00:15<00:00,  9.94it/s]
Probe prediction epochs:  30%|███       | 6/20 [01:10<02:56, 12.59s/it]

normalized pred locations loss 0.016823457553982735


Probe prediction step: 100%|██████████| 156/156 [00:14<00:00, 10.43it/s]
Probe prediction epochs:  35%|███▌      | 7/20 [01:25<02:53, 13.37s/it]

normalized pred locations loss 0.022371506318449974




normalized pred locations loss 0.01757776364684105


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.30it/s]
Probe prediction epochs:  40%|████      | 8/20 [01:37<02:34, 12.84s/it]

normalized pred locations loss 0.02419067546725273




normalized pred locations loss 0.02241969108581543


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.32it/s]
Probe prediction epochs:  45%|████▌     | 9/20 [01:48<02:17, 12.49s/it]

normalized pred locations loss 0.016823401674628258


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.34it/s]
Probe prediction epochs:  50%|█████     | 10/20 [02:00<02:02, 12.25s/it]

normalized pred locations loss 0.012926500290632248




normalized pred locations loss 0.017348723486065865


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.11it/s]
Probe prediction epochs:  55%|█████▌    | 11/20 [02:12<01:49, 12.14s/it]

normalized pred locations loss 0.013151814229786396


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.26it/s]
Probe prediction epochs:  60%|██████    | 12/20 [02:24<01:36, 12.03s/it]

normalized pred locations loss 0.010352770797908306




normalized pred locations loss 0.013182537630200386


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.26it/s]
Probe prediction epochs:  65%|██████▌   | 13/20 [02:36<01:23, 11.95s/it]

normalized pred locations loss 0.01105441153049469


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.25it/s]
Probe prediction epochs:  70%|███████   | 14/20 [02:47<01:11, 11.89s/it]

normalized pred locations loss 0.011669248342514038




normalized pred locations loss 0.009606442414224148


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.33it/s]
Probe prediction epochs:  75%|███████▌  | 15/20 [02:59<00:59, 11.84s/it]

normalized pred locations loss 0.008748154155910015


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.44it/s]
Probe prediction epochs:  80%|████████  | 16/20 [03:11<00:47, 11.77s/it]

normalized pred locations loss 0.01287667267024517




normalized pred locations loss 0.010554035194218159


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.36it/s]
Probe prediction epochs:  85%|████████▌ | 17/20 [03:22<00:35, 11.74s/it]

normalized pred locations loss 0.012675700709223747




normalized pred locations loss 0.011416669003665447


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 19.43it/s]
Probe prediction epochs:  90%|█████████ | 18/20 [03:30<00:21, 10.63s/it]

normalized pred locations loss 0.010702389292418957


Probe prediction step: 100%|██████████| 156/156 [00:06<00:00, 22.96it/s]
Probe prediction epochs:  95%|█████████▌| 19/20 [03:37<00:09,  9.48s/it]

normalized pred locations loss 0.00896525103598833




normalized pred locations loss 0.008025922812521458


Probe prediction step: 100%|██████████| 156/156 [00:06<00:00, 23.37it/s]
Probe prediction epochs: 100%|██████████| 20/20 [03:44<00:00, 11.22s/it]
Eval probe pred: 100%|██████████| 62/62 [00:02<00:00, 25.15it/s]
Eval probe pred: 100%|██████████| 62/62 [00:02<00:00, 25.21it/s]


normal loss: 3.7678537368774414
wall loss: 7.248698711395264
--------------------------------------------------------------------------------


Epoch 8/100: 100%|██████████| 1148/1148 [02:22<00:00,  8.06it/s, loss=327.3117, time/batch=0.076s, gpu_mem=2.1GB]



Epoch 8/100 Summary:
Train Loss: 327.3117
Component Losses:
  total_loss: 327.3117
  vicreg_sim_loss: 0.4840
  vicreg_std_loss: 1.1750
  vicreg_cov_loss: 4.8546
  barlow_loss: 608.2954
Epoch Time: 142.4s
Avg Batch Time: 0.076s
GPU Memory: 1988MB


Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]

normalized pred locations loss 1.0482534170150757




normalized pred locations loss 0.15378132462501526


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 14.13it/s]
Probe prediction epochs:   5%|▌         | 1/20 [00:11<03:29, 11.04s/it]

normalized pred locations loss 0.06752688437700272




normalized pred locations loss 0.048029497265815735


Probe prediction step: 100%|██████████| 156/156 [00:10<00:00, 14.23it/s]
Probe prediction epochs:  10%|█         | 2/20 [00:22<03:17, 11.00s/it]

normalized pred locations loss 0.02828121930360794


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 14.08it/s]
Probe prediction epochs:  15%|█▌        | 3/20 [00:33<03:07, 11.04s/it]

normalized pred locations loss 0.037834249436855316




normalized pred locations loss 0.0202847421169281


Probe prediction step: 100%|██████████| 156/156 [00:10<00:00, 14.49it/s]
Probe prediction epochs:  20%|██        | 4/20 [00:43<02:54, 10.93s/it]

normalized pred locations loss 0.022963209077715874


Probe prediction step: 100%|██████████| 156/156 [00:10<00:00, 14.40it/s]
Probe prediction epochs:  25%|██▌       | 5/20 [00:54<02:43, 10.90s/it]

normalized pred locations loss 0.01852954737842083




normalized pred locations loss 0.017999565228819847


Probe prediction step: 100%|██████████| 156/156 [00:10<00:00, 14.35it/s]
Probe prediction epochs:  30%|███       | 6/20 [01:05<02:32, 10.89s/it]

normalized pred locations loss 0.01707381382584572


Probe prediction step: 100%|██████████| 156/156 [00:10<00:00, 14.32it/s]
Probe prediction epochs:  35%|███▌      | 7/20 [01:16<02:21, 10.89s/it]

normalized pred locations loss 0.01764286682009697




normalized pred locations loss 0.0181454885751009


Probe prediction step: 100%|██████████| 156/156 [00:13<00:00, 11.22it/s]
Probe prediction epochs:  40%|████      | 8/20 [01:30<02:22, 11.85s/it]

normalized pred locations loss 0.016440289095044136




normalized pred locations loss 0.013467704877257347


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.18it/s]
Probe prediction epochs:  45%|████▌     | 9/20 [01:42<02:10, 11.85s/it]

normalized pred locations loss 0.013533852063119411


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.28it/s]
Probe prediction epochs:  50%|█████     | 10/20 [01:53<01:58, 11.82s/it]

normalized pred locations loss 0.01879545859992504




normalized pred locations loss 0.009867779910564423


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.20it/s]
Probe prediction epochs:  55%|█████▌    | 11/20 [02:05<01:46, 11.82s/it]

normalized pred locations loss 0.017523296177387238


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.43it/s]
Probe prediction epochs:  60%|██████    | 12/20 [02:17<01:34, 11.76s/it]

normalized pred locations loss 0.018665861338377




normalized pred locations loss 0.011452412232756615


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.43it/s]
Probe prediction epochs:  65%|██████▌   | 13/20 [02:28<01:21, 11.71s/it]

normalized pred locations loss 0.011179137043654919


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.32it/s]
Probe prediction epochs:  70%|███████   | 14/20 [02:40<01:10, 11.71s/it]

normalized pred locations loss 0.006431798916310072




normalized pred locations loss 0.009408630430698395


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.34it/s]
Probe prediction epochs:  75%|███████▌  | 15/20 [02:52<00:58, 11.71s/it]

normalized pred locations loss 0.007983547635376453


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.25it/s]
Probe prediction epochs:  80%|████████  | 16/20 [03:04<00:46, 11.73s/it]

normalized pred locations loss 0.009703127667307854




normalized pred locations loss 0.012154152616858482


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.40it/s]
Probe prediction epochs:  85%|████████▌ | 17/20 [03:15<00:35, 11.70s/it]

normalized pred locations loss 0.00623256666585803




normalized pred locations loss 0.005807817447930574


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.20it/s]
Probe prediction epochs:  90%|█████████ | 18/20 [03:27<00:23, 11.74s/it]

normalized pred locations loss 0.009952612221240997


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.27it/s]
Probe prediction epochs:  95%|█████████▌| 19/20 [03:39<00:11, 11.74s/it]

normalized pred locations loss 0.009786785580217838




normalized pred locations loss 0.008031783625483513


Probe prediction step: 100%|██████████| 156/156 [00:07<00:00, 19.53it/s]
Probe prediction epochs: 100%|██████████| 20/20 [03:47<00:00, 11.37s/it]
Eval probe pred: 100%|██████████| 62/62 [00:02<00:00, 25.33it/s]
Eval probe pred: 100%|██████████| 62/62 [00:02<00:00, 25.29it/s]


normal loss: 3.4674692153930664
wall loss: 7.241121768951416
--------------------------------------------------------------------------------


Epoch 9/100: 100%|██████████| 1148/1148 [02:19<00:00,  8.21it/s, loss=318.7579, time/batch=0.076s, gpu_mem=2.1GB]



Epoch 9/100 Summary:
Train Loss: 318.7579
Component Losses:
  total_loss: 318.7579
  vicreg_sim_loss: 0.4656
  vicreg_std_loss: 1.1699
  vicreg_cov_loss: 4.8592
  barlow_loss: 591.7691
Epoch Time: 139.8s
Avg Batch Time: 0.076s
GPU Memory: 1988MB
Checkpoint saved to checkpoints/combined_barlow_vicreg/checkpoint_epoch_10.pt


Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]

normalized pred locations loss 1.0668574571609497




normalized pred locations loss 0.10879135131835938


Probe prediction step: 100%|██████████| 156/156 [00:10<00:00, 14.20it/s]
Probe prediction epochs:   5%|▌         | 1/20 [00:10<03:28, 10.99s/it]

normalized pred locations loss 0.04865381866693497




normalized pred locations loss 0.033261243253946304


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 14.07it/s]
Probe prediction epochs:  10%|█         | 2/20 [00:22<03:18, 11.05s/it]

normalized pred locations loss 0.0377991758286953


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 14.09it/s]
Probe prediction epochs:  15%|█▌        | 3/20 [00:33<03:08, 11.06s/it]

normalized pred locations loss 0.032751183956861496




normalized pred locations loss 0.02940683253109455


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 14.11it/s]
Probe prediction epochs:  20%|██        | 4/20 [00:44<02:56, 11.06s/it]

normalized pred locations loss 0.02367246523499489


Probe prediction step: 100%|██████████| 156/156 [00:10<00:00, 14.91it/s]
Probe prediction epochs:  25%|██▌       | 5/20 [00:54<02:42, 10.84s/it]

normalized pred locations loss 0.019910143688321114




normalized pred locations loss 0.021201591938734055


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 14.14it/s]
Probe prediction epochs:  30%|███       | 6/20 [01:05<02:32, 10.91s/it]

normalized pred locations loss 0.014675806276500225


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 14.12it/s]
Probe prediction epochs:  35%|███▌      | 7/20 [01:16<02:22, 10.96s/it]

normalized pred locations loss 0.018343757838010788




normalized pred locations loss 0.022522352635860443


Probe prediction step: 100%|██████████| 156/156 [00:10<00:00, 14.32it/s]
Probe prediction epochs:  40%|████      | 8/20 [01:27<02:11, 10.94s/it]

normalized pred locations loss 0.012902108952403069


Probe prediction step: 100%|██████████| 156/156 [00:07<00:00, 20.59it/s]
Probe prediction epochs:  45%|████▌     | 9/20 [01:35<01:48,  9.89s/it]

normalized pred locations loss 0.013167341239750385




normalized pred locations loss 0.010864177718758583


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.23it/s]
Probe prediction epochs:  50%|█████     | 10/20 [01:47<01:44, 10.48s/it]

normalized pred locations loss 0.008942804299294949




normalized pred locations loss 0.0123186856508255


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.20it/s]
Probe prediction epochs:  55%|█████▌    | 11/20 [01:58<01:37, 10.89s/it]

normalized pred locations loss 0.01145927980542183


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.30it/s]
Probe prediction epochs:  60%|██████    | 12/20 [02:10<01:29, 11.14s/it]

normalized pred locations loss 0.010985251516103745




normalized pred locations loss 0.008406679145991802


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.20it/s]
Probe prediction epochs:  65%|██████▌   | 13/20 [02:22<01:19, 11.35s/it]

normalized pred locations loss 0.01074040774255991


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.35it/s]
Probe prediction epochs:  70%|███████   | 14/20 [02:34<01:08, 11.45s/it]

normalized pred locations loss 0.0070217326283454895




normalized pred locations loss 0.009316515177488327


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.30it/s]
Probe prediction epochs:  75%|███████▌  | 15/20 [02:45<00:57, 11.54s/it]

normalized pred locations loss 0.007269103545695543


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.47it/s]
Probe prediction epochs:  80%|████████  | 16/20 [02:57<00:46, 11.55s/it]

normalized pred locations loss 0.009873139671981335




normalized pred locations loss 0.010334819555282593


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.43it/s]
Probe prediction epochs:  85%|████████▌ | 17/20 [03:09<00:34, 11.57s/it]

normalized pred locations loss 0.007153197191655636




normalized pred locations loss 0.006821429822593927


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.24it/s]
Probe prediction epochs:  90%|█████████ | 18/20 [03:20<00:23, 11.63s/it]

normalized pred locations loss 0.011238587088882923


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.22it/s]
Probe prediction epochs:  95%|█████████▌| 19/20 [03:32<00:11, 11.69s/it]

normalized pred locations loss 0.005495882127434015




normalized pred locations loss 0.006472495384514332


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 13.43it/s]
Probe prediction epochs: 100%|██████████| 20/20 [03:44<00:00, 11.21s/it]
Eval probe pred: 100%|██████████| 62/62 [00:04<00:00, 14.13it/s]
Eval probe pred: 100%|██████████| 62/62 [00:04<00:00, 14.09it/s]


normal loss: 3.3981287479400635
wall loss: 6.827298164367676
--------------------------------------------------------------------------------


Epoch 10/100: 100%|██████████| 1148/1148 [02:20<00:00,  8.20it/s, loss=311.6711, time/batch=0.074s, gpu_mem=2.1GB]



Epoch 10/100 Summary:
Train Loss: 311.6711
Component Losses:
  total_loss: 311.6711
  vicreg_sim_loss: 0.4498
  vicreg_std_loss: 1.1657
  vicreg_cov_loss: 4.8647
  barlow_loss: 578.0919
Epoch Time: 140.1s
Avg Batch Time: 0.074s
GPU Memory: 1988MB


Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]

normalized pred locations loss 1.0157663822174072




normalized pred locations loss 0.10308150947093964


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 14.09it/s]
Probe prediction epochs:   5%|▌         | 1/20 [00:11<03:30, 11.07s/it]

normalized pred locations loss 0.036600761115550995




normalized pred locations loss 0.037433698773384094


Probe prediction step: 100%|██████████| 156/156 [00:10<00:00, 14.34it/s]
Probe prediction epochs:  10%|█         | 2/20 [00:21<03:17, 10.96s/it]

normalized pred locations loss 0.025931725278496742


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 14.10it/s]
Probe prediction epochs:  15%|█▌        | 3/20 [00:33<03:07, 11.01s/it]

normalized pred locations loss 0.02330898679792881




normalized pred locations loss 0.023060329258441925


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 14.11it/s]
Probe prediction epochs:  20%|██        | 4/20 [00:44<02:56, 11.03s/it]

normalized pred locations loss 0.02499234862625599


Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 14.16it/s]
Probe prediction epochs:  25%|██▌       | 5/20 [00:55<02:45, 11.02s/it]

normalized pred locations loss 0.021769821643829346




normalized pred locations loss 0.012792025692760944


Probe prediction step: 100%|██████████| 156/156 [00:04<00:00, 36.45it/s]
Probe prediction epochs:  30%|███       | 6/20 [00:59<02:02,  8.73s/it]

normalized pred locations loss 0.01248762384057045


Probe prediction step: 100%|██████████| 156/156 [00:04<00:00, 36.15it/s]
Probe prediction epochs:  35%|███▌      | 7/20 [01:03<01:34,  7.29s/it]

normalized pred locations loss 0.011737323366105556




normalized pred locations loss 0.015755122527480125


Probe prediction step: 100%|██████████| 156/156 [00:04<00:00, 36.89it/s]
Probe prediction epochs:  40%|████      | 8/20 [01:07<01:15,  6.32s/it]

normalized pred locations loss 0.014278591610491276


Probe prediction step: 100%|██████████| 156/156 [00:04<00:00, 35.57it/s]
Probe prediction epochs:  45%|████▌     | 9/20 [01:12<01:02,  5.71s/it]

normalized pred locations loss 0.014791339635848999




normalized pred locations loss 0.008142620325088501


Probe prediction step: 100%|██████████| 156/156 [00:04<00:00, 36.24it/s]
Probe prediction epochs:  50%|█████     | 10/20 [01:16<00:52,  5.28s/it]

normalized pred locations loss 0.00954383797943592




normalized pred locations loss 0.0109332874417305


Probe prediction step: 100%|██████████| 156/156 [00:05<00:00, 30.32it/s]
Probe prediction epochs:  55%|█████▌    | 11/20 [01:21<00:47,  5.24s/it]

normalized pred locations loss 0.008598017506301403


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.73it/s]
Probe prediction epochs:  60%|██████    | 12/20 [01:30<00:49,  6.18s/it]

normalized pred locations loss 0.009986872784793377




normalized pred locations loss 0.006612148601561785


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.83it/s]
Probe prediction epochs:  65%|██████▌   | 13/20 [01:38<00:47,  6.82s/it]

normalized pred locations loss 0.0064377933740615845


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.67it/s]
Probe prediction epochs:  70%|███████   | 14/20 [01:46<00:43,  7.28s/it]

normalized pred locations loss 0.008437287993729115




normalized pred locations loss 0.008817914873361588


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.89it/s]
Probe prediction epochs:  75%|███████▌  | 15/20 [01:54<00:37,  7.58s/it]

normalized pred locations loss 0.006727051455527544


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.81it/s]
Probe prediction epochs:  80%|████████  | 16/20 [02:03<00:31,  7.79s/it]

normalized pred locations loss 0.007297968026250601




normalized pred locations loss 0.006364633794873953


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.85it/s]
Probe prediction epochs:  85%|████████▌ | 17/20 [02:11<00:23,  7.94s/it]

normalized pred locations loss 0.006291055586189032




normalized pred locations loss 0.005795670207589865


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.83it/s]
Probe prediction epochs:  90%|█████████ | 18/20 [02:19<00:16,  8.04s/it]

normalized pred locations loss 0.007840332575142384


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.85it/s]
Probe prediction epochs:  95%|█████████▌| 19/20 [02:28<00:08,  8.11s/it]

normalized pred locations loss 0.004583344794809818




normalized pred locations loss 0.006607231218367815


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.87it/s]
Probe prediction epochs: 100%|██████████| 20/20 [02:36<00:00,  7.82s/it]
Eval probe pred: 100%|██████████| 62/62 [00:02<00:00, 22.00it/s]
Eval probe pred: 100%|██████████| 62/62 [00:02<00:00, 21.98it/s]


normal loss: 2.9325318336486816
wall loss: 6.417964458465576
--------------------------------------------------------------------------------


Epoch 11/100: 100%|██████████| 1148/1148 [02:16<00:00,  8.41it/s, loss=303.8026, time/batch=0.070s, gpu_mem=2.1GB]



Epoch 11/100 Summary:
Train Loss: 303.8026
Component Losses:
  total_loss: 303.8026
  vicreg_sim_loss: 0.4354
  vicreg_std_loss: 1.1638
  vicreg_cov_loss: 4.8367
  barlow_loss: 562.7897
Epoch Time: 136.6s
Avg Batch Time: 0.070s
GPU Memory: 1988MB


Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]

normalized pred locations loss 1.1130744218826294




normalized pred locations loss 0.10561424493789673


Probe prediction step: 100%|██████████| 156/156 [00:04<00:00, 36.82it/s]
Probe prediction epochs:   5%|▌         | 1/20 [00:04<01:20,  4.24s/it]

normalized pred locations loss 0.043449580669403076




normalized pred locations loss 0.03518408164381981


Probe prediction step: 100%|██████████| 156/156 [00:04<00:00, 36.61it/s]
Probe prediction epochs:  10%|█         | 2/20 [00:08<01:16,  4.25s/it]

normalized pred locations loss 0.03996364027261734


Probe prediction step: 100%|██████████| 156/156 [00:04<00:00, 37.02it/s]
Probe prediction epochs:  15%|█▌        | 3/20 [00:12<01:12,  4.24s/it]

normalized pred locations loss 0.033840395510196686




normalized pred locations loss 0.022899135947227478


Probe prediction step: 100%|██████████| 156/156 [00:04<00:00, 37.62it/s]
Probe prediction epochs:  20%|██        | 4/20 [00:16<01:07,  4.20s/it]

normalized pred locations loss 0.02319304831326008


Probe prediction step: 100%|██████████| 156/156 [00:04<00:00, 37.00it/s]
Probe prediction epochs:  25%|██▌       | 5/20 [00:21<01:03,  4.21s/it]

normalized pred locations loss 0.016315165907144547




normalized pred locations loss 0.012710406444966793


Probe prediction step: 100%|██████████| 156/156 [00:04<00:00, 35.21it/s]
Probe prediction epochs:  30%|███       | 6/20 [00:25<00:59,  4.28s/it]

normalized pred locations loss 0.024172132834792137


Probe prediction step: 100%|██████████| 156/156 [00:04<00:00, 36.13it/s]
Probe prediction epochs:  35%|███▌      | 7/20 [00:29<00:55,  4.30s/it]

normalized pred locations loss 0.015023080632090569




normalized pred locations loss 0.013956958428025246


Probe prediction step: 100%|██████████| 156/156 [00:04<00:00, 35.87it/s]
Probe prediction epochs:  40%|████      | 8/20 [00:34<00:51,  4.31s/it]

normalized pred locations loss 0.016089562326669693


Probe prediction step: 100%|██████████| 156/156 [00:04<00:00, 37.23it/s]
Probe prediction epochs:  45%|████▌     | 9/20 [00:38<00:47,  4.28s/it]

normalized pred locations loss 0.012782882899045944




normalized pred locations loss 0.010631068609654903


Probe prediction step: 100%|██████████| 156/156 [00:04<00:00, 36.72it/s]
Probe prediction epochs:  50%|█████     | 10/20 [00:42<00:42,  4.27s/it]

normalized pred locations loss 0.009423069655895233




normalized pred locations loss 0.007935437373816967


Probe prediction step: 100%|██████████| 156/156 [00:04<00:00, 36.10it/s]
Probe prediction epochs:  55%|█████▌    | 11/20 [00:46<00:38,  4.28s/it]

normalized pred locations loss 0.011089473031461239


Probe prediction step: 100%|██████████| 156/156 [00:06<00:00, 24.47it/s]
Probe prediction epochs:  60%|██████    | 12/20 [00:53<00:39,  4.92s/it]

normalized pred locations loss 0.00879609677940607




normalized pred locations loss 0.00984334759414196


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.75it/s]
Probe prediction epochs:  65%|██████▌   | 13/20 [01:01<00:41,  5.95s/it]

normalized pred locations loss 0.006700395606458187


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.76it/s]
Probe prediction epochs:  70%|███████   | 14/20 [01:09<00:39,  6.67s/it]

normalized pred locations loss 0.007196655962616205




normalized pred locations loss 0.007155961357057095


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.78it/s]
Probe prediction epochs:  75%|███████▌  | 15/20 [01:18<00:35,  7.16s/it]

normalized pred locations loss 0.011353385634720325


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.83it/s]
Probe prediction epochs:  80%|████████  | 16/20 [01:26<00:29,  7.50s/it]

normalized pred locations loss 0.0060706487856805325




normalized pred locations loss 0.006199805997312069


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.82it/s]
Probe prediction epochs:  85%|████████▌ | 17/20 [01:34<00:23,  7.74s/it]

normalized pred locations loss 0.007096145302057266




normalized pred locations loss 0.007647455669939518


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.78it/s]
Probe prediction epochs:  90%|█████████ | 18/20 [01:43<00:15,  7.91s/it]

normalized pred locations loss 0.006580061744898558


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.74it/s]
Probe prediction epochs:  95%|█████████▌| 19/20 [01:51<00:08,  8.03s/it]

normalized pred locations loss 0.007776868529617786




normalized pred locations loss 0.006530216429382563


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.68it/s]
Probe prediction epochs: 100%|██████████| 20/20 [01:59<00:00,  5.99s/it]
Eval probe pred: 100%|██████████| 62/62 [00:02<00:00, 22.02it/s]
Eval probe pred: 100%|██████████| 62/62 [00:02<00:00, 21.88it/s]


normal loss: 3.094120740890503
wall loss: 6.57526159286499
--------------------------------------------------------------------------------


Epoch 12/100: 100%|██████████| 1148/1148 [02:19<00:00,  8.24it/s, loss=298.8013, time/batch=0.071s, gpu_mem=2.1GB]



Epoch 12/100 Summary:
Train Loss: 298.8013
Component Losses:
  total_loss: 298.8013
  vicreg_sim_loss: 0.4205
  vicreg_std_loss: 1.1642
  vicreg_cov_loss: 4.7536
  barlow_loss: 553.2330
Epoch Time: 139.4s
Avg Batch Time: 0.071s
GPU Memory: 1988MB


Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]

normalized pred locations loss 1.1258273124694824




normalized pred locations loss 0.2113478034734726


Probe prediction step: 100%|██████████| 156/156 [00:04<00:00, 35.48it/s]
Probe prediction epochs:   5%|▌         | 1/20 [00:04<01:23,  4.40s/it]

normalized pred locations loss 0.060419321060180664




normalized pred locations loss 0.03342978283762932


Probe prediction step: 100%|██████████| 156/156 [00:04<00:00, 37.42it/s]
Probe prediction epochs:  10%|█         | 2/20 [00:08<01:16,  4.26s/it]

normalized pred locations loss 0.028061671182513237


Probe prediction step: 100%|██████████| 156/156 [00:04<00:00, 37.09it/s]
Probe prediction epochs:  15%|█▌        | 3/20 [00:12<01:12,  4.24s/it]

normalized pred locations loss 0.028676200658082962




normalized pred locations loss 0.02208709344267845


Probe prediction step: 100%|██████████| 156/156 [00:04<00:00, 36.14it/s]
Probe prediction epochs:  20%|██        | 4/20 [00:17<01:08,  4.27s/it]

normalized pred locations loss 0.021509626880288124


Probe prediction step: 100%|██████████| 156/156 [00:04<00:00, 35.29it/s]
Probe prediction epochs:  25%|██▌       | 5/20 [00:21<01:04,  4.32s/it]

normalized pred locations loss 0.020448319613933563




normalized pred locations loss 0.030230704694986343


Probe prediction step: 100%|██████████| 156/156 [00:04<00:00, 35.27it/s]
Probe prediction epochs:  30%|███       | 6/20 [00:25<01:01,  4.36s/it]

normalized pred locations loss 0.013122377917170525


Probe prediction step: 100%|██████████| 156/156 [00:04<00:00, 36.23it/s]
Probe prediction epochs:  35%|███▌      | 7/20 [00:30<00:56,  4.34s/it]

normalized pred locations loss 0.011573469266295433




normalized pred locations loss 0.016330337151885033


Probe prediction step: 100%|██████████| 156/156 [00:04<00:00, 35.83it/s]
Probe prediction epochs:  40%|████      | 8/20 [00:34<00:52,  4.35s/it]

normalized pred locations loss 0.012184781022369862


Probe prediction step: 100%|██████████| 156/156 [00:04<00:00, 36.33it/s]
Probe prediction epochs:  45%|████▌     | 9/20 [00:38<00:47,  4.33s/it]

normalized pred locations loss 0.01512442622333765




normalized pred locations loss 0.014222881756722927


Probe prediction step: 100%|██████████| 156/156 [00:04<00:00, 31.65it/s]
Probe prediction epochs:  50%|█████     | 10/20 [00:43<00:45,  4.52s/it]

normalized pred locations loss 0.008252791129052639




normalized pred locations loss 0.01170950848609209


Probe prediction step: 100%|██████████| 156/156 [00:05<00:00, 30.23it/s]
Probe prediction epochs:  55%|█████▌    | 11/20 [00:48<00:42,  4.71s/it]

normalized pred locations loss 0.007515826728194952


Probe prediction step: 100%|██████████| 156/156 [00:04<00:00, 34.89it/s]
Probe prediction epochs:  60%|██████    | 12/20 [00:53<00:37,  4.64s/it]


normalized pred locations loss 0.0074342479929327965


Probe prediction step:  23%|██▎       | 36/156 [00:00<00:03, 35.58it/s][A

normalized pred locations loss 0.007769961375743151


Probe prediction step: 100%|██████████| 156/156 [00:06<00:00, 22.54it/s]
Probe prediction epochs:  65%|██████▌   | 13/20 [01:00<00:37,  5.33s/it]

normalized pred locations loss 0.00844247080385685


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.88it/s]
Probe prediction epochs:  70%|███████   | 14/20 [01:08<00:37,  6.22s/it]

normalized pred locations loss 0.00701413257047534




normalized pred locations loss 0.007947463542222977


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.88it/s]
Probe prediction epochs:  75%|███████▌  | 15/20 [01:16<00:34,  6.83s/it]

normalized pred locations loss 0.00826653465628624


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 19.07it/s]
Probe prediction epochs:  80%|████████  | 16/20 [01:25<00:28,  7.24s/it]

normalized pred locations loss 0.0070409285835921764




normalized pred locations loss 0.007457922678440809


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.77it/s]
Probe prediction epochs:  85%|████████▌ | 17/20 [01:33<00:22,  7.56s/it]

normalized pred locations loss 0.010132121853530407




normalized pred locations loss 0.0088250283151865


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.74it/s]
Probe prediction epochs:  90%|█████████ | 18/20 [01:41<00:15,  7.79s/it]

normalized pred locations loss 0.007649283390492201


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.74it/s]
Probe prediction epochs:  95%|█████████▌| 19/20 [01:50<00:07,  7.95s/it]

normalized pred locations loss 0.006874684244394302




normalized pred locations loss 0.007609501481056213


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 18.74it/s]
Probe prediction epochs: 100%|██████████| 20/20 [01:58<00:00,  5.92s/it]
Eval probe pred: 100%|██████████| 62/62 [00:02<00:00, 21.85it/s]
Eval probe pred: 100%|██████████| 62/62 [00:02<00:00, 21.79it/s]


normal loss: 2.9689748287200928
wall loss: 6.159099578857422
--------------------------------------------------------------------------------


Epoch 13/100:  65%|██████▍   | 744/1148 [02:07<01:09,  5.82it/s, loss=295.4510, time/batch=0.082s, gpu_mem=2.1GB]


KeyboardInterrupt: 

In [None]:
# TRYING LOSS NORMALIZATION

import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from typing import Tuple, Dict, NamedTuple
from tqdm.auto import tqdm
import time
from datetime import datetime
from pathlib import Path

class Encoder(nn.Module):
    """Two-stage strided convolutional encoder.

    Each stage is ``Conv2d`` (no bias) -> ``BatchNorm2d`` -> ``ReLU``. For a
    65x65 input the spatial size shrinks 65 -> 22 -> 8, so the output is a
    32-channel 8x8 feature map. ``repr_dim`` holds the flattened size of that
    map (32 * 8 * 8).

    Args:
        input_channels: Number of channels of the input images.
    """

    def __init__(self, input_channels=2):
        super().__init__()

        # Stage 1 (65x65 -> 22x22): 5x5 kernel, stride 3, padding 2.
        self.conv1 = nn.Conv2d(
            input_channels, 8, kernel_size=5, stride=3, padding=2, bias=False
        )
        self.bn1 = nn.BatchNorm2d(8)

        # Stage 2 (22x22 -> 8x8): 3x3 kernel, stride 3, padding 1.
        self.conv2 = nn.Conv2d(
            8, 32, kernel_size=3, stride=3, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(32)

        # Flattened size of the final feature map for a 65x65 input.
        self.repr_dim = 32 * 8 * 8

    def forward(self, x):
        """Return the feature map produced by running ``x`` through both stages."""
        features = x
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            features = F.relu(norm(conv(features)))
        return features

class TransitionModel(nn.Module):
    """Residual latent transition conditioned on a 2-component action.

    The action is broadcast over the spatial grid of the state, embedded with
    1x1 convolutions to ``hidden_dim`` channels, concatenated with the state
    along the channel axis and passed through two 3x3 convolutions. The
    result is added to the input state.

    Args:
        hidden_dim: Channel count of the state and of the action embedding.
    """

    def __init__(self, hidden_dim=32):
        super().__init__()
        self.hidden_dim = hidden_dim
        half_dim = hidden_dim // 2

        # Per-location embedding of the broadcast action (1x1 convolutions).
        self.action_embed = nn.Sequential(
            nn.Conv2d(2, half_dim, 1),
            nn.BatchNorm2d(half_dim),
            nn.ReLU(),
            nn.Conv2d(half_dim, hidden_dim, 1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
        )

        # Maps [state, action embedding] to a state update; no final ReLU so
        # the update can be negative.
        self.transition = nn.Sequential(
            nn.Conv2d(hidden_dim * 2, hidden_dim, 3, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.BatchNorm2d(hidden_dim),
        )

    def forward(self, state, action):
        """Return ``state`` plus the predicted update for ``action``.

        Args:
            state: 4-D tensor (batch, channels, height, width).
            action: Tensor viewable as (batch, 2, 1, 1).
        """
        batch_size, _channels, height, width = state.shape
        # Tile the action over every spatial location of the state.
        action_map = action.view(batch_size, 2, 1, 1).expand(-1, -1, height, width)
        fused = torch.cat([state, self.action_embed(action_map)], dim=1)
        return state + self.transition(fused)

class CombinedLoss(nn.Module):
    """Weighted blend of a VICReg-style loss and a Barlow-Twins-style loss.

    Both losses are computed on the same pair of 2-D embeddings
    (rows = samples, columns = features). When ``normalize_losses`` is set,
    each loss is divided by an exponential moving average (EMA) of its own
    past values before blending, so both terms are on a comparable scale.

    Args:
        vicreg_sim_coef: Weight of the VICReg invariance (MSE) term.
        vicreg_std_coef: Weight of the VICReg variance (hinge on std) term.
        vicreg_cov_coef: Weight of the VICReg covariance term.
        barlow_lambda: Weight of the off-diagonal Barlow Twins term.
        loss_weight: Blend factor; the total is
            ``loss_weight * vicreg + (1 - loss_weight) * barlow``.
        normalize_losses: Whether to divide each loss by its running mean.
    """

    def __init__(self,
                 vicreg_sim_coef=25.0,
                 vicreg_std_coef=25.0,
                 vicreg_cov_coef=1.0,
                 barlow_lambda=0.005,
                 loss_weight=0.5,
                 normalize_losses=True):
        super().__init__()
        self.vicreg_sim_coef = vicreg_sim_coef
        self.vicreg_std_coef = vicreg_std_coef
        self.vicreg_cov_coef = vicreg_cov_coef
        self.barlow_lambda = barlow_lambda
        self.loss_weight = loss_weight
        self.normalize_losses = normalize_losses

        # Running statistics for normalization. A value of exactly 0 means
        # "no observation yet" (see _ema_update).
        self.register_buffer('vicreg_running_mean', torch.tensor(0.0))
        self.register_buffer('barlow_running_mean', torch.tensor(0.0))
        self.momentum = 0.9  # For updating running means

    def off_diagonal(self, x):
        """Return the off-diagonal entries of the square matrix ``x``, flattened."""
        n = x.shape[0]
        # Dropping the last element and reshaping to (n-1, n+1) puts every
        # diagonal entry in column 0, which is then sliced away.
        return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

    def _ema_update(self, running, observed):
        """Return the next EMA value given the current one and a new observation.

        A zero-initialised EMA is biased toward zero: after one update it
        would be ``(1 - momentum) * observed``, so dividing by it would
        inflate the first normalised losses by ``1 / (1 - momentum)``. To
        avoid that, an unseeded statistic (exactly 0) is replaced by the
        observation itself.
        """
        observed = observed.detach().to(device=running.device, dtype=running.dtype)
        blended = self.momentum * running + (1 - self.momentum) * observed
        # torch.where keeps the seeding decision on-device (no host sync).
        return torch.where(running == 0, observed, blended)

    def update_running_means(self, vicreg_loss, barlow_loss):
        """Fold the current (detached) losses into the running means.

        Does nothing unless the module is in training mode.
        """
        if self.training:
            self.vicreg_running_mean = self._ema_update(
                self.vicreg_running_mean, vicreg_loss)
            self.barlow_running_mean = self._ema_update(
                self.barlow_running_mean, barlow_loss)

    def normalize_loss(self, vicreg_loss, barlow_loss):
        """Return both losses divided by their running means.

        If either running mean is not above ``eps`` (e.g. the module has only
        been used in eval mode so far), fixed fallback scales are used
        instead.
        """
        # Prevent division by zero
        eps = 1e-6

        # Use running means to normalize if they're non-zero
        if self.vicreg_running_mean > eps and self.barlow_running_mean > eps:
            vicreg_norm = vicreg_loss / self.vicreg_running_mean
            barlow_norm = barlow_loss / self.barlow_running_mean
        else:
            # No usable statistics yet: scale based on typical ranges
            vicreg_norm = vicreg_loss / 17.5  # Assuming typical VICReg total around 17.5
            barlow_norm = barlow_loss / 5.0   # Assuming typical Barlow total around 5.0

        return vicreg_norm, barlow_norm

    def forward(self, z_a, z_b):
        """Compute the blended loss for two batches of embeddings.

        Args:
            z_a: 2-D tensor (samples, features).
            z_b: 2-D tensor with the same shape as ``z_a``.

        Returns:
            ``(total_loss, components)`` where ``total_loss`` is the tensor
            to back-propagate and ``components`` is a dict of Python floats.
            Every component except ``'total_loss'`` is reported before
            normalisation, so the values stay comparable across steps.
        """
        N = z_a.shape[0]
        D = z_a.shape[1]

        # --- VICReg components ---
        # Invariance: mean squared error between the two embeddings.
        sim_loss = F.mse_loss(z_a, z_b)

        # Variance: hinge pushing each feature's std toward at least 1.
        std_z_a = torch.sqrt(z_a.var(dim=0) + 1e-04)
        std_z_b = torch.sqrt(z_b.var(dim=0) + 1e-04)
        std_loss = torch.mean(F.relu(1 - std_z_a)) + torch.mean(F.relu(1 - std_z_b))

        # Covariance: penalise off-diagonal entries of each covariance matrix.
        z_a_centered = z_a - z_a.mean(dim=0)
        z_b_centered = z_b - z_b.mean(dim=0)

        cov_z_a = (z_a_centered.T @ z_a_centered) / (N - 1)
        cov_z_b = (z_b_centered.T @ z_b_centered) / (N - 1)

        vicreg_cov_loss = (self.off_diagonal(cov_z_a).pow(2).sum() / D +
                           self.off_diagonal(cov_z_b).pow(2).sum() / D)

        vicreg_loss = (self.vicreg_sim_coef * sim_loss +
                       self.vicreg_std_coef * std_loss +
                       self.vicreg_cov_coef * vicreg_cov_loss)

        # --- Barlow Twins components ---
        # Standardise each feature, then build the cross-correlation matrix.
        z_a_norm = z_a_centered / (z_a.std(dim=0) + 1e-6)
        z_b_norm = z_b_centered / (z_b.std(dim=0) + 1e-6)

        c = torch.mm(z_a_norm.T, z_b_norm) / N

        # Out-of-place ops: the original in-place add_/pow_ on the diagonal
        # view silently overwrote the diagonal of ``c``.
        on_diag = (torch.diagonal(c) - 1).pow(2).sum()
        off_diag = self.off_diagonal(c).pow(2).sum()
        barlow_loss = on_diag + self.barlow_lambda * off_diag

        # Update running means with the raw (un-normalised) losses.
        self.update_running_means(vicreg_loss, barlow_loss)

        # Keep the raw losses intact for logging; only the blended terms are
        # normalised.
        if self.normalize_losses:
            vicreg_term, barlow_term = self.normalize_loss(vicreg_loss, barlow_loss)
        else:
            vicreg_term, barlow_term = vicreg_loss, barlow_loss

        total_loss = self.loss_weight * vicreg_term + (1 - self.loss_weight) * barlow_term

        return total_loss, {
            'total_loss': total_loss.item(),
            'vicreg_sim_loss': sim_loss.item(),
            'vicreg_std_loss': std_loss.item(),
            'vicreg_cov_loss': vicreg_cov_loss.item(),
            'barlow_loss': barlow_loss.item(),
            'vicreg_mean': self.vicreg_running_mean.item(),
            'barlow_mean': self.barlow_running_mean.item()
        }

class WorldModelCombined(nn.Module):
    """Latent world model: an encoder, an action-conditioned predictor and a
    combined VICReg/Barlow criterion.

    The first observation is encoded once and the predictor is then unrolled
    over the action sequence entirely in latent space. ``repr_dim`` mirrors
    the encoder's ``repr_dim``.

    Args:
        vicreg_sim_coef, vicreg_std_coef, vicreg_cov_coef, barlow_lambda,
        loss_weight: Forwarded unchanged to ``CombinedLoss``.
    """

    def __init__(self, vicreg_sim_coef=25.0, vicreg_std_coef=25.0, vicreg_cov_coef=1.0,
                 barlow_lambda=0.005, loss_weight=0.5):
        super().__init__()

        self.encoder = Encoder(input_channels=2)
        self.predictor = TransitionModel(hidden_dim=32)

        loss_kwargs = {
            'vicreg_sim_coef': vicreg_sim_coef,
            'vicreg_std_coef': vicreg_std_coef,
            'vicreg_cov_coef': vicreg_cov_coef,
            'barlow_lambda': barlow_lambda,
            'loss_weight': loss_weight,
        }
        self.criterion = CombinedLoss(**loss_kwargs)

        self.repr_dim = self.encoder.repr_dim

    def forward_prediction(self, states, actions):
        """Encode the initial observation and roll the predictor forward.

        Args:
            states: 5-D tensor whose dimension 1 is squeezed before encoding
                (callers in this class pass a single initial frame there).
            actions: Tensor indexed as ``actions[:, t]`` for each step.

        Returns:
            Latent states stacked along dim 1: the encoded initial state
            followed by one prediction per action.
        """
        # Unpacking doubles as a check that `states` is 5-D.
        _batch, _frames, _channels, _height, _width = states.shape

        latent = self.encoder(states.squeeze(1))
        rollout = [latent]
        for step in range(actions.shape[1]):
            latent = self.predictor(latent, actions[:, step])
            rollout.append(latent)

        return torch.stack(rollout, dim=1)

    def forward(self, states, actions):
        """Return the latent rollout from the first frame, flattened per step."""
        rollout = self.forward_prediction(states[:, 0:1], actions)
        num_samples, num_steps, _c, _h, _w = rollout.shape
        return rollout.view(num_samples, num_steps, -1)

    def training_step(self, batch):
        """Compute the step-averaged loss for one batch.

        For every action step the predicted latent is compared, via the
        criterion, with the encoder output of the matching observation.

        Args:
            batch: Object exposing ``states`` and ``actions`` tensors.

        Returns:
            ``(loss, predictions, component_losses)``: the loss averaged over
            steps, the unflattened latent rollout, and a dict of
            step-averaged scalar components.
        """
        states, actions = batch.states, batch.actions
        num_steps = actions.shape[1]

        rollout = self.forward_prediction(states[:, 0:1], actions)

        tracked = (
            'total_loss',
            'vicreg_sim_loss',
            'vicreg_std_loss',
            'vicreg_cov_loss',
            'barlow_loss',
        )
        component_sums = dict.fromkeys(tracked, 0.0)
        loss_sum = 0.0

        # Index 0 of the rollout is the encoded initial frame, so the
        # prediction for step t lives at index t and is compared with
        # observation t.
        for step in range(1, num_steps + 1):
            predicted = rollout[:, step].flatten(start_dim=1)
            encoded_target = self.encoder(states[:, step]).flatten(start_dim=1)

            step_loss, step_components = self.criterion(predicted, encoded_target)

            loss_sum += step_loss
            for name in tracked:
                component_sums[name] += step_components[name]

        mean_loss = loss_sum / num_steps
        mean_components = {
            name: value / num_steps for name, value in component_sums.items()
        }

        return mean_loss, rollout, mean_components

def train_epoch(model, dataloader, optimizer, epoch, total_epochs):
    """Train ``model`` for one pass over ``dataloader``.

    Each batch is moved to the device the model's parameters live on (rather
    than unconditionally to CUDA), so the function also works when training
    falls back to CPU.

    Args:
        model: Module exposing ``training_step(batch)`` that returns
            ``(loss, predictions, component_losses)``.
        dataloader: Sized iterable of batches. Each batch must provide
            ``_replace`` plus ``states``, ``actions`` and ``locations``
            fields (``locations`` may be ``None``).
        optimizer: Optimizer stepping the model's parameters.
        epoch: Epoch number shown in the progress bar (display only).
        total_epochs: Total number of epochs shown in the progress bar.

    Returns:
        ``(mean_loss, mean_component_losses, timing)`` where ``timing`` has
        the keys ``'epoch_time'`` and ``'avg_batch_time'`` in seconds. All
        averages are 0 if the dataloader yields no batches.
    """
    model.train()

    # Follow the model's device; fall back to the usual CUDA-if-available
    # choice for a parameterless model.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    cuda_available = torch.cuda.is_available()

    total_loss = 0.0
    train_losses = {
        'total_loss': 0.0,
        'vicreg_sim_loss': 0.0,
        'vicreg_std_loss': 0.0,
        'vicreg_cov_loss': 0.0,
        'barlow_loss': 0.0
    }

    progress_bar = tqdm(enumerate(dataloader),
                        total=len(dataloader),
                        desc=f'Epoch {epoch}/{total_epochs}',
                        leave=True)

    start_time = time.time()
    batch_times = []

    for batch_idx, batch in progress_bar:
        batch_start = time.time()

        batch = batch._replace(
            states=batch.states.to(device),
            actions=batch.actions.to(device),
            locations=batch.locations.to(device) if batch.locations is not None else None
        )

        optimizer.zero_grad()
        loss, _, component_losses = model.training_step(batch)

        loss.backward()
        # Clip the global gradient norm to keep updates bounded.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        batch_times.append(time.time() - batch_start)

        total_loss += loss.item()
        for k in train_losses:
            train_losses[k] += component_losses[k]

        current_loss = total_loss / (batch_idx + 1)
        current_batch_time = sum(batch_times) / len(batch_times)
        # Only query the CUDA allocator when CUDA is actually usable.
        peak_bytes = torch.cuda.max_memory_allocated() if cuda_available else 0

        progress_bar.set_postfix({
            'loss': f'{current_loss:.4f}',
            'time/batch': f'{current_batch_time:.3f}s',
            'gpu_mem': f'{peak_bytes/1e9:.1f}GB'
        })

    # Guard the averages so an empty dataloader does not divide by zero.
    num_batches = len(dataloader)
    if num_batches:
        total_loss /= num_batches
        for k in train_losses:
            train_losses[k] /= num_batches

    epoch_time = time.time() - start_time
    avg_batch_time = sum(batch_times) / len(batch_times) if batch_times else 0.0

    return total_loss, train_losses, {
        'epoch_time': epoch_time,
        'avg_batch_time': avg_batch_time
    }

def train_model(model, train_loader, num_epochs=100, learning_rate=3e-4,
                save_dir='checkpoints/combined', save_frequency=10):
    """Train ``model`` with Adam, checkpointing and probing after each epoch.

    Relies on ``train_epoch`` and on ``evaluate_model``, ``probe_train_ds``
    and ``probe_val_ds`` being defined at module (notebook) scope.

    Args:
        model: Module to train; moved to CUDA when available, else CPU.
        train_loader: Dataloader passed to ``train_epoch`` every epoch.
        num_epochs: Number of epochs to run.
        learning_rate: Adam learning rate.
        save_dir: Directory for checkpoints (created if missing).
        save_frequency: Save a checkpoint every this many epochs; a falsy
            value (0 / None) disables checkpointing.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    print(f"Starting training at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Training for {num_epochs} epochs")
    print(f"Checkpoints will be saved to {save_dir}")

    for epoch in range(num_epochs):
        # Display epochs 1-based so "N/num_epochs" ends at num_epochs and
        # matches the checkpoint file names.
        epoch_number = epoch + 1

        train_loss, train_losses, timing_stats = train_epoch(
            model, train_loader, optimizer, epoch_number, num_epochs
        )

        print(f"\nEpoch {epoch_number}/{num_epochs} Summary:")
        print(f"Train Loss: {train_loss:.4f}")
        print("Component Losses:")
        for k, v in train_losses.items():
            print(f"  {k}: {v:.4f}")
        print(f"Epoch Time: {timing_stats['epoch_time']:.1f}s")
        print(f"Avg Batch Time: {timing_stats['avg_batch_time']:.3f}s")
        # Only query the CUDA allocator when training actually runs on CUDA.
        peak_bytes = torch.cuda.max_memory_allocated() if device.type == 'cuda' else 0
        print(f"GPU Memory: {peak_bytes / 1024**2:.0f}MB")

        if save_frequency and epoch_number % save_frequency == 0:
            checkpoint_path = save_dir / f"checkpoint_epoch_{epoch_number}.pt"
            torch.save({
                'epoch': epoch,  # 0-based index, unchanged for existing loaders
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': train_loss,
            }, checkpoint_path)
            print(f"Checkpoint saved to {checkpoint_path}")

        # Probe on the device the model was trained on ("cuda" when CUDA is
        # available, exactly as before; "cpu" otherwise).
        evaluate_model(str(device), model, probe_train_ds, probe_val_ds)

        print("-" * 80)

if __name__ == "__main__":
    from dataset import create_wall_dataloader, WallSample

    # Loss hyper-parameters for the combined criterion.
    loss_settings = dict(
        vicreg_sim_coef=25.0,
        vicreg_std_coef=25.0,
        vicreg_cov_coef=1.0,
        barlow_lambda=0.005,
        loss_weight=0.5,  # Equal weight between VICReg and Barlow
    )
    model = WorldModelCombined(**loss_settings).cuda()

    # Training data (update the path for your environment).
    train_data_path = "/drive_reader/as16386/DL24FA/train"
    train_loader = create_wall_dataloader(
        train_data_path,
        batch_size=128,
        train=True,
    )

    # Run training; checkpoints go to a directory specific to this variant.
    train_model(
        model=model,
        train_loader=train_loader,
        num_epochs=100,
        learning_rate=3e-5,
        save_dir='checkpoints/combined_barlow_vicreg_normalized',
        save_frequency=10,
    )