In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from typing import Tuple, Dict
from dataset import WallSample, create_wall_dataloader
from tqdm.auto import tqdm
import time
from datetime import datetime
from pathlib import Path
from lightly.utils.scheduler import cosine_schedule

from evaluator import ProbingEvaluator

def load_data(device, data_path="/drive_reader/as16386/DL24FA"):
    """Build the probing dataloaders used to evaluate a trained model.

    Args:
        device: Device specifier forwarded unchanged to
            ``create_wall_dataloader``.
        data_path: Root directory of the dataset. The sub-directories
            ``probe_normal/train``, ``probe_normal/val`` and
            ``probe_wall/val`` are read from it. Defaults to the location
            that used to be hard-coded, so existing callers are unaffected.

    Returns:
        A ``(probe_train_ds, probe_val_ds)`` tuple, where ``probe_val_ds``
        is a dict with the keys ``"normal"`` and ``"wall"``.
    """

    def _probe_loader(subdir, train):
        # All three loaders differ only in sub-directory and train flag.
        return create_wall_dataloader(
            data_path=f"{data_path}/{subdir}",
            probing=True,
            device=device,
            train=train,
        )

    probe_train_ds = _probe_loader("probe_normal/train", train=True)

    probe_val_ds = {
        "normal": _probe_loader("probe_normal/val", train=False),
        "wall": _probe_loader("probe_wall/val", train=False),
    }

    return probe_train_ds, probe_val_ds

def evaluate_model(device, model, probe_train_ds, probe_val_ds):
    """Train a probe on the model's representations and report probe losses.

    Args:
        device: Device specifier forwarded to ``ProbingEvaluator``.
        model: The model whose representations are probed.
        probe_train_ds: Dataset used to train the prober.
        probe_val_ds: Validation dataset(s) the prober is evaluated on.

    Returns:
        The mapping returned by ``ProbingEvaluator.evaluate_all`` (probe
        attribute -> average loss). Each entry is also printed. Returning the
        mapping lets callers log or compare the numbers instead of having to
        scrape stdout.
    """
    evaluator = ProbingEvaluator(
        device=device,
        model=model,
        probe_train_ds=probe_train_ds,
        probe_val_ds=probe_val_ds,
        quick_debug=False,
    )

    prober = evaluator.train_pred_prober()

    avg_losses = evaluator.evaluate_all(prober=prober)

    for probe_attr, loss in avg_losses.items():
        print(f"{probe_attr} loss: {loss}")

    return avg_losses

# Build the probing dataloaders once at notebook scope. train_model reads these
# module-level names when it runs the probing evaluation after each epoch.
probe_train_ds, probe_val_ds = load_data("cuda")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from typing import Tuple, Dict
from tqdm.auto import tqdm
import time
from datetime import datetime
from pathlib import Path
from dataset import create_wall_dataloader, WallSample

class Encoder(nn.Module):
    """Two-stage strided convolutional encoder.

    Each stage is Conv2d (no bias) -> BatchNorm2d -> ReLU. The first stage
    maps ``input_channels`` to 8 channels, the second maps 8 to 32 channels;
    both use stride 3.

    Attributes:
        repr_dim: Fixed at ``32 * 8 * 8``. This equals the flattened output
            size only when the final feature map is 8x8 (e.g. 65x65 inputs);
            the actual input resolution is not visible here.
    """

    def __init__(self, input_channels=2):
        super().__init__()
        # Attribute names and registration order are part of the state_dict
        # layout, so they stay exactly conv1, bn1, conv2, bn2.
        self.conv1 = nn.Conv2d(
            in_channels=input_channels,
            out_channels=8,
            kernel_size=5,
            stride=3,
            padding=2,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(num_features=8)

        self.conv2 = nn.Conv2d(
            in_channels=8,
            out_channels=32,
            kernel_size=3,
            stride=3,
            padding=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(num_features=32)

        self.repr_dim = 32 * 8 * 8

    def forward(self, x):
        """Run both conv/BN/ReLU stages and return the spatial feature map."""
        features = x
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            features = F.relu(norm(conv(features)))
        return features

class TransitionModel(nn.Module):
    """Residual convolutional predictor of the next latent state.

    A 2-component action is broadcast over the spatial grid of the state,
    embedded to ``hidden_dim`` channels with 1x1 convolutions, concatenated
    with the state along the channel axis, and turned into an additive update
    by two 3x3 convolutions.
    """

    def __init__(self, hidden_dim=32):
        super().__init__()
        self.hidden_dim = hidden_dim
        half_dim = hidden_dim // 2

        # Layers are created in the same order as before so parameter
        # initialisation and state_dict keys are unchanged.
        embed_layers = [
            nn.Conv2d(2, half_dim, kernel_size=1),
            nn.BatchNorm2d(half_dim),
            nn.ReLU(),
            nn.Conv2d(half_dim, hidden_dim, kernel_size=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
        ]
        self.action_embed = nn.Sequential(*embed_layers)

        update_layers = [
            nn.Conv2d(hidden_dim * 2, hidden_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_dim),
        ]
        self.transition = nn.Sequential(*update_layers)

    def forward(self, state, action):
        """Return ``state + delta`` where ``delta`` depends on state and action.

        Args:
            state: 4-D tensor; its channel count must equal ``hidden_dim`` for
                the concatenation to match the first transition convolution.
            action: Tensor holding two values per batch element.
        """
        batch, _, height, width = state.shape
        # Tile the action over every spatial location of the state.
        action_map = action.view(batch, 2, 1, 1).expand(-1, -1, height, width)
        features = torch.cat((state, self.action_embed(action_map)), dim=1)
        return state + self.transition(features)

class NormalizedCombinedLoss(nn.Module):
    """Weighted sum of a VICReg loss and a Barlow Twins loss.

    Each of the two losses is divided by an exponential moving average of its
    own value before they are mixed, so neither term dominates purely because
    of its raw scale.

    Args:
        vicreg_sim_coef: Weight of the VICReg invariance (MSE) term.
        vicreg_std_coef: Weight of the VICReg variance (hinge on std) term.
        vicreg_cov_coef: Weight of the VICReg covariance term.
        barlow_lambda: Weight of the off-diagonal Barlow Twins term.
        loss_weight: Mixing weight; the result is
            ``loss_weight * vicreg_norm + (1 - loss_weight) * barlow_norm``.

    Note:
        Because every term is divided by a running mean of itself, the
        returned ``total_loss`` stays close to 1 and is not a useful absolute
        progress metric; read the raw ``vicreg_loss`` / ``barlow_loss``
        entries of the component dict instead.
    """

    def __init__(self,
                 vicreg_sim_coef=25.0,
                 vicreg_std_coef=25.0,
                 vicreg_cov_coef=1.0,
                 barlow_lambda=0.005,
                 loss_weight=0.5):
        super().__init__()
        self.vicreg_sim_coef = vicreg_sim_coef
        self.vicreg_std_coef = vicreg_std_coef
        self.vicreg_cov_coef = vicreg_cov_coef
        self.barlow_lambda = barlow_lambda
        self.loss_weight = loss_weight

        # Running statistics for normalization (saved with the state_dict).
        self.register_buffer('vicreg_mean', torch.tensor(0.0))
        self.register_buffer('barlow_mean', torch.tensor(0.0))
        self.register_buffer('count', torch.tensor(0))
        self.momentum = 0.9

    def off_diagonal(self, x):
        """Return the off-diagonal entries of a square matrix as a 1-D tensor."""
        n = x.shape[0]
        # Dropping the last element and viewing as (n-1, n+1) puts every
        # diagonal entry in column 0, which is then sliced away.
        return x.flatten()[:-1].view(n-1, n+1)[:, 1:].flatten()

    def update_means(self, vicreg_loss, barlow_loss):
        """Update the running means of both raw losses (no gradient).

        The first call seeds the means with the given values; later calls
        apply an exponential moving average with ``self.momentum``.
        """
        with torch.no_grad():
            if self.count == 0:
                self.vicreg_mean = vicreg_loss.detach()
                self.barlow_mean = barlow_loss.detach()
            else:
                self.vicreg_mean = self.momentum * self.vicreg_mean + (1 - self.momentum) * vicreg_loss.detach()
                self.barlow_mean = self.momentum * self.barlow_mean + (1 - self.momentum) * barlow_loss.detach()
            self.count += 1

    def forward(self, z_a, z_b):
        """Compute the normalised combined loss for two embedding batches.

        Args:
            z_a: 2-D tensor ``(N, D)`` of embeddings.
            z_b: 2-D tensor ``(N, D)`` of embeddings paired with ``z_a``.
                ``N`` must be at least 2; the variance and covariance
                estimates divide by ``N - 1``.

        Returns:
            ``(total_loss, component_losses)`` where ``total_loss`` is a scalar
            tensor that carries gradients and ``component_losses`` is a dict
            of Python floats with the keys ``total_loss``, ``vicreg_loss``,
            ``vicreg_norm``, ``barlow_loss``, ``barlow_norm``, ``vicreg_mean``
            and ``barlow_mean``.
        """
        N = z_a.shape[0]
        D = z_a.shape[1]

        # ---- VICReg components ----
        sim_loss = F.mse_loss(z_a, z_b)

        std_z_a = torch.sqrt(z_a.var(dim=0) + 1e-04)
        std_z_b = torch.sqrt(z_b.var(dim=0) + 1e-04)
        std_loss = torch.mean(F.relu(1 - std_z_a)) + torch.mean(F.relu(1 - std_z_b))

        z_a_c = z_a - z_a.mean(dim=0)
        z_b_c = z_b - z_b.mean(dim=0)

        cov_z_a = (z_a_c.T @ z_a_c) / (N - 1)
        cov_z_b = (z_b_c.T @ z_b_c) / (N - 1)

        vicreg_cov_loss = (self.off_diagonal(cov_z_a).pow(2).sum() / D +
                           self.off_diagonal(cov_z_b).pow(2).sum() / D)

        vicreg_loss = (self.vicreg_sim_coef * sim_loss +
                       self.vicreg_std_coef * std_loss +
                       self.vicreg_cov_coef * vicreg_cov_loss)

        # ---- Barlow Twins components ----
        # Standardise with the population (biased) std so that it matches the
        # division by N below. With the unbiased std the diagonal of `c` for
        # perfectly correlated inputs is (N - 1) / N and can never reach the
        # target value of 1.
        z_a_norm = z_a_c / (z_a.std(dim=0, unbiased=False) + 1e-6)
        z_b_norm = z_b_c / (z_b.std(dim=0, unbiased=False) + 1e-6)

        c = torch.mm(z_a_norm.T, z_b_norm) / N

        # Out-of-place ops: `torch.diagonal` returns a view, so in-place
        # arithmetic on it would silently rewrite the diagonal of `c`.
        on_diag = (torch.diagonal(c) - 1).pow(2).sum()
        off_diag = self.off_diagonal(c).pow(2).sum()
        barlow_loss = on_diag + self.barlow_lambda * off_diag

        # Update running means first; after this call `count` is always >= 1,
        # so the normalisation below needs no "no statistics yet" fallback.
        self.update_means(vicreg_loss, barlow_loss)

        # Normalize losses using running means
        eps = 1e-6
        vicreg_loss_norm = vicreg_loss / (self.vicreg_mean + eps)
        barlow_loss_norm = barlow_loss / (self.barlow_mean + eps)

        # Combine normalized losses
        total_loss = self.loss_weight * vicreg_loss_norm + (1 - self.loss_weight) * barlow_loss_norm

        component_losses = {
            'total_loss': total_loss.item(),
            'vicreg_loss': vicreg_loss.item(),
            'vicreg_norm': vicreg_loss_norm.item(),
            'barlow_loss': barlow_loss.item(),
            'barlow_norm': barlow_loss_norm.item(),
            'vicreg_mean': self.vicreg_mean.item(),
            'barlow_mean': self.barlow_mean.item()
        }

        return total_loss, component_losses

class WorldModelCombined(nn.Module):
    """World model: encoder + action-conditioned predictor + combined loss.

    The first observation is encoded, then the predictor is unrolled over the
    action sequence in latent space. During training each predicted latent is
    compared with the encoding of the corresponding real observation.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder(input_channels=2)
        self.predictor = TransitionModel(hidden_dim=32)
        self.criterion = NormalizedCombinedLoss()
        self.repr_dim = self.encoder.repr_dim

    def forward_prediction(self, states, actions):
        """Unroll the predictor from the first observation.

        Args:
            states: Tensor indexed as ``(batch, time, ...)``. Only the frame at
                time index 0 is used, so either a single-frame slice or a full
                trajectory may be passed.
            actions: Tensor indexed as ``(batch, step, ...)``; one predictor
                step is taken per action.

        Returns:
            Tensor of latent states stacked along dim 1, holding the encoded
            first frame followed by one prediction per action.
        """
        # Index the first frame explicitly. `squeeze(1)` only worked when the
        # caller had already sliced to one frame and was a silent no-op
        # otherwise, which surfaced later as a confusing convolution error.
        curr_state = self.encoder(states[:, 0])
        predictions = [curr_state]

        for t in range(actions.shape[1]):
            curr_state = self.predictor(curr_state, actions[:, t])
            predictions.append(curr_state)

        return torch.stack(predictions, dim=1)

    def forward(self, states, actions):
        """Return predicted latents flattened to ``(batch, time, features)``."""
        predictions = self.forward_prediction(states, actions)
        return predictions.flatten(start_dim=2)

    def training_step(self, batch):
        """Compute the training loss for one batch.

        Args:
            batch: Object exposing ``states`` and ``actions`` attributes.

        Returns:
            ``(total_loss, predictions, accumulated_losses)``: the criterion
            loss averaged over prediction steps (carries gradients), the
            unflattened predicted latents, and a dict of per-component losses
            averaged over steps (Python floats).
        """
        states = batch.states
        actions = batch.actions
        num_steps = actions.shape[1]

        predictions = self.forward_prediction(states, actions)

        total_loss = 0.0
        accumulated_losses = {
            'total_loss': 0.0,
            'vicreg_loss': 0.0,
            'vicreg_norm': 0.0,
            'barlow_loss': 0.0,
            'barlow_norm': 0.0
        }

        for t in range(num_steps):
            # Prediction for step t+1 versus the encoding of the real frame.
            pred_state = predictions[:, t+1]
            target_state = self.encoder(states[:, t+1])

            pred_flat = pred_state.flatten(start_dim=1)
            target_flat = target_state.flatten(start_dim=1)

            loss, component_losses = self.criterion(pred_flat, target_flat)

            total_loss += loss
            for k in accumulated_losses:
                accumulated_losses[k] += component_losses[k]

        total_loss = total_loss / num_steps
        for k in accumulated_losses:
            accumulated_losses[k] /= num_steps

        return total_loss, predictions, accumulated_losses

def train_model(model, train_loader, num_epochs=100, learning_rate=3e-4, 
                save_dir='checkpoints/combined_normalized', save_frequency=10):
    """Train ``model`` with Adam, checkpoint periodically, probe every epoch.

    Args:
        model: Module exposing ``training_step(batch)`` that returns
            ``(loss, predictions, component_losses)`` and a ``criterion`` with
            ``vicreg_mean`` / ``barlow_mean`` tensors.
        train_loader: Sized iterable of batches; each batch must support
            ``_replace`` and expose ``states``, ``actions`` and ``locations``.
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: Adam learning rate.
        save_dir: Directory for checkpoints (created if missing).
        save_frequency: A checkpoint is written every this many epochs.

    Note:
        After each epoch this calls ``evaluate_model`` with the module-level
        ``probe_train_ds`` and ``probe_val_ds``, which must already exist.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    
    print(f"Starting training at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Training for {num_epochs} epochs")
    print(f"Checkpoints will be saved to {save_dir}")
    
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        # Defined up front so an empty loader cannot cause a NameError in the
        # checkpoint / summary code below.
        current_loss = float('nan')
        
        progress_bar = tqdm(enumerate(train_loader), 
                          total=len(train_loader),
                          desc=f'Epoch {epoch}/{num_epochs}')
        
        for batch_idx, batch in progress_bar:
            # Move the batch to the same device as the model. Using `.cuda()`
            # here would defeat the CPU fallback chosen above.
            batch = batch._replace(
                states=batch.states.to(device),
                actions=batch.actions.to(device),
                locations=batch.locations.to(device) if batch.locations is not None else None
            )
            
            # Forward pass and compute loss
            optimizer.zero_grad()
            loss, _, _ = model.training_step(batch)
            
            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            # Update progress bar with the running average loss
            total_loss += loss.item()
            current_loss = total_loss / (batch_idx + 1)
            
            progress_bar.set_postfix({
                'loss': f'{current_loss:.4f}',
                'vicreg_mean': f'{model.criterion.vicreg_mean.item():.4f}',
                'barlow_mean': f'{model.criterion.barlow_mean.item():.4f}'
            })
            
        # Save checkpoint
        if (epoch + 1) % save_frequency == 0:
            checkpoint_path = save_dir / f"checkpoint_epoch_{epoch+1}.pt"
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': current_loss,
            }, checkpoint_path)
            print(f"Checkpoint saved to {checkpoint_path}")
        
        # Probe in eval mode: in train mode BatchNorm would normalise with
        # batch statistics and fold the probe data into its running stats.
        was_training = model.training
        model.eval()
        try:
            evaluate_model(device.type, model, probe_train_ds, probe_val_ds)
        finally:
            model.train(was_training)
        
        print(f"\nEpoch {epoch} Summary:")
        print(f"Loss: {current_loss:.4f}")
        print(f"VICReg Mean: {model.criterion.vicreg_mean.item():.4f}")
        print(f"Barlow Mean: {model.criterion.barlow_mean.item():.4f}")
        print("-" * 80)

if __name__ == "__main__":
    # Instantiate the world model directly on the GPU.
    model = WorldModelCombined().cuda()

    # Loader over the training trajectories (dataset location is hard-coded).
    train_loader = create_wall_dataloader(
        "/drive_reader/as16386/DL24FA/train",
        train=True,
        batch_size=32,
    )

    # Launch training; a checkpoint is written every 10 epochs.
    train_model(
        model,
        train_loader,
        num_epochs=100,
        learning_rate=3e-4,
        save_frequency=10,
        save_dir='checkpoints/combined_barlow_vicreg_normalized',
    )

Starting training at 2024-12-11 23:55:14
Training for 100 epochs
Checkpoints will be saved to checkpoints/combined_barlow_vicreg_normalized


Epoch 0/100: 100%|██████████| 1148/1148 [02:20<00:00,  8.18it/s, loss=0.9949, time/batch=0.077s, gpu_mem=1.4GB]



Epoch 0/100 Summary:
Train Loss: 0.9949
Component Losses:
  total_loss: 0.9949
  vicreg_sim_loss: 24.1968
  vicreg_std_loss: 1.1923
  vicreg_cov_loss: 4545.7000
  barlow_loss: 1.0009
Epoch Time: 140.3s
Avg Batch Time: 0.077s
GPU Memory: 1353MB


Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]

normalized pred locations loss 1.1327272653579712




normalized pred locations loss 0.33933916687965393


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 17.43it/s]
Probe prediction epochs:   5%|▌         | 1/20 [00:08<02:50,  8.95s/it]

normalized pred locations loss 0.16049417853355408




normalized pred locations loss 0.08976524323225021


Probe prediction step: 100%|██████████| 156/156 [00:08<00:00, 17.63it/s]
Probe prediction epochs:  10%|█         | 2/20 [00:17<02:40,  8.89s/it]

normalized pred locations loss 0.06290233880281448


Probe prediction step: 100%|██████████| 156/156 [00:16<00:00,  9.72it/s]
Probe prediction epochs:  15%|█▌        | 3/20 [00:33<03:26, 12.16s/it]

normalized pred locations loss 0.047875333577394485


Probe prediction step:  31%|███       | 48/156 [00:05<00:11,  9.29it/s]
Probe prediction epochs:  15%|█▌        | 3/20 [00:39<03:41, 13.01s/it]


KeyboardInterrupt: 