# In [None]:
# 04_latent_dynamics.ipynb
# =======================================
# Action-Conditioned Latent Dynamics Training
# Uses pretrained VAE → predicts next latent from current latent + action

# ────────────────────────────────────────────────
# 1. Imports & Configuration
# ────────────────────────────────────────────────

import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np

from models.data_loader import RoboNetDataset

# Paths
# The project root defaults to the original checkout location. Set the
# WORLD_MODEL_PROJECT_ROOT environment variable to run the notebook from a
# different machine/directory without editing this file.
PROJECT_ROOT = os.environ.get(
    "WORLD_MODEL_PROJECT_ROOT",
    r"E:\NVIDIA_PROJECTS\Neural-World-Model-for-Embodied-AI-Robotics",
)
DATA_ROOT = os.path.join(PROJECT_ROOT, "data", "raw", "robonet", "hdf5")
SPLITS_PATH = os.path.join(PROJECT_ROOT, "data", "splits.json")
VAE_CHECKPOINT = os.path.join(PROJECT_ROOT, "checkpoints", "vae_pretrained.pth")
CHECKPOINT_DIR = os.path.join(PROJECT_ROOT, "checkpoints")
# Create the checkpoint directory up front so torch.save() later cannot fail
# on a missing folder.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Training hyperparameters
BATCH_SIZE = 4
NUM_WORKERS = 0
EPOCHS = 20
LR = 1e-3
LATENT_DIM = 128  # must match your pretrained VAE

# ────────────────────────────────────────────────
# 2. Load splits
# ────────────────────────────────────────────────

# Read the train/val file lists from the splits JSON. The encoding is given
# explicitly so that decoding does not depend on the platform's default
# locale encoding (JSON text is UTF-8).
with open(SPLITS_PATH, "r", encoding="utf-8") as f:
    splits = json.load(f)

# The JSON object must contain "train" and "val" entries; a missing key
# raises KeyError here, before any dataset is constructed.
train_files = splits["train"]
val_files   = splits["val"]

print(f"Train: {len(train_files)} | Val: {len(val_files)}")

# ────────────────────────────────────────────────
# 3. Datasets & Loaders
# ────────────────────────────────────────────────

train_dataset = RoboNetDataset(train_files, DATA_ROOT)
val_dataset   = RoboNetDataset(val_files,   DATA_ROOT)

# Page-locked host memory only helps host-to-GPU copies; requesting it on a
# CPU-only run is pointless and makes PyTorch emit a warning, so enable it
# only when the selected device is CUDA.
PIN_MEMORY = DEVICE.type == "cuda"

# Only the training loader is shuffled; validation keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_dataset,   batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

print(f"Train batches: {len(train_loader)} | Val batches: {len(val_loader)}")

# ────────────────────────────────────────────────
# 4. Load Pretrained VAE & Freeze
# ────────────────────────────────────────────────

class VAE(nn.Module):
    """Placeholder for the pretrained VAE architecture.

    The real class definition must be pasted here before running the
    notebook; it has to be the exact architecture used in pretraining so
    that the checkpoint's state dict loads.

    The rest of this script relies on the following interface:
      * ``VAE(latent_dim=...)`` constructs the model;
      * ``encode(x)`` returns a 3-tuple whose first element is the latent
        that gets reshaped to ``(B, T, LATENT_DIM)``;
      * ``decode(z)`` maps a latent batch back to images.
    """
    # Copy your VAE class definition here (encoder, reparameterize, decoder, forward)
    # ... paste the exact VAE class you used in pretraining ...

    def __init__(self, latent_dim):
        super().__init__()
        # Fail loudly with instructions instead of leaving an empty class
        # body, which is a SyntaxError and prevents the file from parsing.
        raise NotImplementedError(
            "VAE is a placeholder: paste the exact VAE class used in "
            "pretraining here (encoder, reparameterize, decoder, forward)."
        )

vae = VAE(latent_dim=LATENT_DIM).to(DEVICE)
# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved on a GPU still loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
vae.load_state_dict(torch.load(VAE_CHECKPOINT, map_location=DEVICE))
# eval() puts the module in inference mode for the whole run; the VAE is
# only used as a fixed encoder/decoder below.
vae.eval()
for param in vae.parameters():
    param.requires_grad = False  # freeze VAE

print("Pretrained VAE loaded and frozen")

# ────────────────────────────────────────────────
# 5. Latent Dynamics Model (RSSM-style)
# ────────────────────────────────────────────────

class LatentDynamics(nn.Module):
    """
    Simple recurrent latent predictor:
    h_t = GRU(h_{t-1}, [z_{t-1}, a_t, h_{t-1}])
    z_t ~ N(mu, std) from h_t

    Args:
        latent_dim: size of the latent vectors z.
        hidden_dim: size of the GRU hidden state h.
        action_dim: size of the action vectors a (4 by default).
    """
    def __init__(self, latent_dim, hidden_dim=256, action_dim=4):
        super().__init__()
        # The GRU input is the concatenation built in forward():
        # z_prev (latent_dim) + action (action_dim) + h_prev (hidden_dim).
        # The input size must therefore include hidden_dim, not a second
        # latent_dim, otherwise the first forward pass fails with a shape
        # mismatch whenever hidden_dim != latent_dim.
        self.gru = nn.GRUCell(latent_dim + action_dim + hidden_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, z_prev, action, h_prev=None):
        """Predict the next latent from the previous latent and the action.

        Args:
            z_prev: previous latent, shape (B, latent_dim).
            action: action taken, shape (B, action_dim).
            h_prev: previous hidden state, shape (B, hidden_dim); a zero
                state is used when omitted (start of a sequence).

        Returns:
            Tuple ``(z_next, mu, logvar, h)``: the sampled next latent, the
            predicted mean and log-variance, and the new hidden state.
        """
        if h_prev is None:
            h_prev = torch.zeros(z_prev.size(0), self.gru.hidden_size, device=z_prev.device)

        inp = torch.cat([z_prev, action, h_prev], dim=-1)
        h = self.gru(inp, h_prev)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        z_next = self.reparameterize(mu, logvar)
        return z_next, mu, logvar, h

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(logvar / 2)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std


# Build the dynamics model, move it to the target device, then hand its
# parameters to Adam (the only parameters this script optimizes).
dynamics = LatentDynamics(LATENT_DIM)
dynamics = dynamics.to(DEVICE)
optimizer = optim.Adam(
    dynamics.parameters(),
    lr=LR,
)

def dynamics_loss(z_pred, z_true, mu, logvar, beta=0.1):
    """Summed latent prediction loss plus a weighted KL regularizer.

    Args:
        z_pred: predicted next latent.
        z_true: target next latent, same shape as ``z_pred``.
        mu: predicted mean of the next-latent distribution.
        logvar: predicted log-variance of the next-latent distribution.
        beta: weight of the KL term (defaults to 0.1).

    Returns:
        Scalar tensor ``sum((z_pred - z_true)^2) + beta * KL``, where KL is
        the divergence of N(mu, exp(logvar)) from a standard normal, summed
        over all elements. Both terms are sums, not means, so the value
        scales with batch size and latent dimension.
    """
    recon_loss = nn.functional.mse_loss(z_pred, z_true, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kl_loss

# Report the total number of scalar parameters in the dynamics model.
print(
    "Latent dynamics model created. Parameters:",
    sum([weight.numel() for weight in dynamics.parameters()]),
)

# ────────────────────────────────────────────────
# 6. Training Loop
# ────────────────────────────────────────────────

best_val_loss = float('inf')
checkpoint_path = os.path.join(CHECKPOINT_DIR, "latent_dynamics_best.pth")


def _encode_frames(frames):
    """Encode a (B, seq_len, C, H, W) frame batch into latents (B, seq_len, LATENT_DIM).

    The batch and time axes are merged with ``flatten`` so the frozen VAE
    sees a plain image batch; no channel count or image size is hard-coded.
    Runs without gradient tracking because the VAE is not trained here.
    """
    B, T = frames.shape[0], frames.shape[1]
    with torch.no_grad():
        # encode() returns a tuple; its first element is the latent.
        z_flat = vae.encode(frames.flatten(0, 1))[0]  # (B*T, latent_dim)
    return z_flat.view(B, T, LATENT_DIM)


def _sequence_loss(z_all, actions):
    """Average one-step prediction loss over the T-1 transitions of a latent sequence.

    The hidden state is carried from step to step (and is not detached), so
    calling ``backward()`` on the result backpropagates through the whole
    sequence.

    Raises:
        ValueError: if the sequence has fewer than 2 frames (no transition
            to predict).
    """
    T = z_all.shape[1]
    if T < 2:
        raise ValueError(f"Need at least 2 frames per sequence, got {T}")

    total = 0.0
    h = None  # dynamics() substitutes a zero hidden state at the first step
    for t in range(T - 1):
        z_next_pred, mu, logvar, h = dynamics(z_all[:, t], actions[:, t], h)
        total = total + dynamics_loss(z_next_pred, z_all[:, t + 1], mu, logvar)
    return total / (T - 1)


for epoch in range(EPOCHS):
    dynamics.train()
    train_loss = 0.0
    num_batches = 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        frames = batch['frames'].to(DEVICE)   # (B, seq_len, C, H, W)
        actions = batch['actions'].to(DEVICE) # (B, seq_len-1, 4)

        loss = _sequence_loss(_encode_frames(frames), actions)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        num_batches += 1

    avg_train_loss = train_loss / num_batches

    # Validation: same rollout loss, no gradient tracking and no updates.
    dynamics.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            frames = batch['frames'].to(DEVICE)
            actions = batch['actions'].to(DEVICE)
            val_loss += _sequence_loss(_encode_frames(frames), actions).item()

    avg_val_loss = val_loss / len(val_loader)

    print(f"Epoch {epoch+1} | Train loss: {avg_train_loss:.4f} | Val loss: {avg_val_loss:.4f}")

    # Keep only the weights with the lowest validation loss seen so far.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(dynamics.state_dict(), checkpoint_path)
        print(f"Best dynamics model saved (val loss: {best_val_loss:.4f})")

print("Latent dynamics training finished!")
print(f"Best checkpoint: {checkpoint_path}")

# ────────────────────────────────────────────────
# 7. Quick Rollout Visualization (1-step prediction)
# ────────────────────────────────────────────────

# Show one validation example: the current frame, the decoded one-step
# prediction, and the actual next frame, side by side.
dynamics.eval()
with torch.no_grad():
    batch = next(iter(val_loader))
    frames = batch['frames'].to(DEVICE)
    actions = batch['actions'].to(DEVICE)

    B, T = frames.shape[0], frames.shape[1]
    if T < 2:
        raise ValueError(f"Need at least 2 frames per sequence, got {T}")

    # Encode all frames; flatten merges the batch and time axes without
    # hard-coding the channel count or image size.
    z_all = vae.encode(frames.flatten(0, 1))[0].view(B, T, LATENT_DIM)

    # Predict next latent from middle. The index is capped at T-2 so that
    # both actions[:, t] and frames[:, t+1] exist even when T == 2.
    t = min(T // 2, T - 2)
    z_curr = z_all[:, t]
    a_t = actions[:, t]
    # No hidden state is passed, so the prediction starts from a zero state.
    z_next_pred, _, _, _ = dynamics(z_curr, a_t)

    # Decode predicted latent
    recon_pred = vae.decode(z_next_pred)

    # Ground truth next frame
    target = frames[:, t+1]

    # Only the first sequence of the batch is plotted; permute moves the
    # channel axis last, which is the layout imshow expects.
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    axes[0].imshow(frames[0, t].cpu().permute(1, 2, 0).numpy())
    axes[0].set_title("Current (t)")
    axes[1].imshow(recon_pred[0].cpu().permute(1, 2, 0).numpy())
    axes[1].set_title("Predicted (t+1)")
    axes[2].imshow(target[0].cpu().permute(1, 2, 0).numpy())
    axes[2].set_title("Actual (t+1)")
    for ax in axes:
        ax.axis('off')
    plt.tight_layout()
    plt.show()