
# 03 · Multimodal LLM Decoder Training on VALOR (Phase 2a)

This notebook implements **Notebook 3 – Aligned Model + “Normal” Decoder (standard LLM)**,
adapted to your **VALOR multi-modal instruction-tuning data** stored in multiple parquet shards.

**What this notebook does:**

1. **Loads the Phase‑1 aligned model** (`VisionTextAligner`), using the best checkpoint.
2. Wraps it with `MultimodalLLM` (vision → alignment → LLM prefix → decoder).
3. Loads **all VALOR shards** (multi-frame, audio+caption) and groups rows by `video_id`.
4. Builds a **clip-level dataset**:
   - Each sample = one `video_id` with multiple `image_jpeg` frames, one `audio_wav`, one `caption`.
5. Trains only the **vision→LLM projector** (Phase‑2a), with:
   - **Single-frame mode** by default: randomly pick a frame per clip each step.
   - Easy to extend later to **multi-frame pooling** / Perceiver usage.
6. Uses **multi‑GPU** (`DataParallel` if 2 GPUs visible).
7. Logs to **Weights & Biases (wandb)**:
   - Training / validation loss
   - Learning rate
   - Qualitative generations (GT caption vs generated caption)
8. Saves:
   - `best_phase3_valor.pt` (by validation loss)
   - `final_phase3_valor.pt`

> ⚠️ You will need to edit **paths** to match your environment:
> - `VALOR_SHARDS_DIR`
> - `PHASE1_CKPT_PATH`
> - `W&B` project / run names


## 1. Imports & Environment

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:


import os
import io
import math
import random
import json
import time
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Dict, Any, List, Optional

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW

import pandas as pd
from PIL import Image as PILImage
from tqdm.auto import tqdm

import wandb

# Project imports
from imports.core import AlignmentConfig, VisionTextAligner, set_seed, count_parameters, 
from imports.train import load_checkpoint
from imports.llm_integration import LLMConfig, MultimodalLLM


SyntaxError: trailing comma not allowed without surrounding parentheses (3391005567.py, line 23)

## 2. Paths & High-Level Configuration

In [None]:

# Root directory of your project (edit if needed)
ROOT_DIR = Path.cwd()

# === Phase 1 (alignment) checkpoint ===
# This should be the *best* alignment checkpoint (Phase 1, tri-modal or vision-text only).
PHASE1_CKPT_PATH = ROOT_DIR / "checkpoints" / "phase1_multimodal" / "perceiver_mrl" / "best.pt"  # <-- EDIT

# === VALOR parquet shards (train split) ===
# Directory containing multiple parquet shards like:
#   valor32k_train_batch000_shard000.parquet
VALOR_SHARDS_DIR = ROOT_DIR / "data" / "alignment_subsets" / "valor32k_train_shards"  # <-- EDIT

assert PHASE1_CKPT_PATH.is_file(), f"Phase-1 checkpoint not found: {PHASE1_CKPT_PATH}"
assert VALOR_SHARDS_DIR.is_dir(), f"VALOR shards dir not found: {VALOR_SHARDS_DIR}"

print("ROOT_DIR       :", ROOT_DIR)
print("PHASE1_CKPT    :", PHASE1_CKPT_PATH)
print("VALOR_SHARDS   :", VALOR_SHARDS_DIR)

# === Phase 3 (LLM decoder alignment on VALOR) output dir ===
PHASE3_RUN_NAME = "valor_qwen2p5_phase3_v1"  # change per experiment
PHASE3_OUT_DIR = ROOT_DIR / "checkpoints" / "phase3_llm_valor" / PHASE3_RUN_NAME
PHASE3_OUT_DIR.mkdir(parents=True, exist_ok=True)
print("PHASE3_OUT_DIR :", PHASE3_OUT_DIR)


ROOT_DIR       : /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base
PHASE1_CKPT    : /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/checkpoints/phase1_multimodal/perceiver_mrl/best.pt
VALOR_SHARDS   : /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/data/alignment_subsets/valor32k_train_shards
PHASE3_OUT_DIR : /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/checkpoints/phase3_llm_valor/valor_qwen2p5_phase3_v1


## 3. Training Hyperparameters & Weights & Biases

In [None]:

@dataclass
class Phase3ValorConfig:
    # Wandb
    wandb_project: str = "edgeglass_model_training"
    wandb_run_name: str = PHASE3_RUN_NAME
    wandb_entity: Optional[str] = None  # set if you use a team account
    
    # LLM
    llm_model_name: str = "Qwen/Qwen2.5-3B-Instruct"
    freeze_llm: bool = True
    num_prefix_tokens: int = 8
    
    # Data
    max_clips: Optional[int] = None  # cap for debugging; None = use all
    val_ratio: float = 0.05
    max_caption_len: int = 96
    frame_limit_per_clip: Optional[int] = None  # e.g. 8; None = all frames
    
    # Training
    num_epochs: int = 3
    batch_size_per_gpu: int = 2
    learning_rate: float = 5e-5
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    warmup_ratio: float = 0.1
    
    # Logging
    log_every: int = 50
    eval_every: int = 500
    num_eval_samples: int = 8
    
    # Device / precision
    seed: int = 42
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    dtype: str = "bfloat16" if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) else "float16"

cfg = Phase3ValorConfig()

print(cfg)


Phase3ValorConfig(wandb_project='edgeglass_model_training', wandb_run_name='valor_qwen2p5_phase3_v1', wandb_entity=None, llm_model_name='Qwen/Qwen2.5-3B-Instruct', freeze_llm=True, num_prefix_tokens=8, max_clips=None, val_ratio=0.05, max_caption_len=96, frame_limit_per_clip=None, num_epochs=3, batch_size_per_gpu=2, learning_rate=5e-05, weight_decay=0.01, max_grad_norm=1.0, warmup_ratio=0.1, log_every=50, eval_every=500, num_eval_samples=8, seed=42, device='cuda', dtype='bfloat16')


## 4. Device & Seeding

In [None]:

def get_device() -> torch.device:
    return torch.device(cfg.device)

def get_dtype() -> torch.dtype:
    if cfg.dtype == "bfloat16":
        return torch.bfloat16
    if cfg.dtype == "float16":
        return torch.float16
    return torch.float32

set_seed(cfg.seed)
torch.backends.cudnn.benchmark = True
torch.cuda.empty_cache()

device = get_device()
dtype = get_dtype()

print(f"Using device: {device}, dtype: {dtype}")
print(f"GPUs available: {torch.cuda.device_count()}")
if torch.cuda.device_count() > 0:
    for i in range(torch.cuda.device_count()):
        print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")


Using device: cuda, dtype: torch.bfloat16
GPUs available: 2
  GPU 0: NVIDIA H200
  GPU 1: NVIDIA H200


## 5. Load Phase‑1 Aligner

In [None]:

# NOTE: This assumes a VisionTextAligner-based Phase 1.
# If you used a Perceiver-based tri-modal aligner, you can adapt this cell to
# import and instantiate `MultimodalAlignmentModel` instead.

phase1_cfg = MultimodalAlignmentModel()
phase1_cfg.device = device
phase1_cfg.dtype = dtype

aligner = VisionTextAligner(phase1_cfg).to(device)
print("Loading Phase‑1 checkpoint from:", PHASE1_CKPT_PATH)
_ = load_checkpoint(
    model=aligner,
    checkpoint_path=str(PHASE1_CKPT_PATH),
    load_optimizer=False,
    optimizer=None,
    device=cfg.device,
)

# Freeze aligner in Phase 2a
for p in aligner.parameters():
    p.requires_grad = False
aligner.eval()

print("Aligner trainable params:", count_parameters(aligner, trainable_only=True))
print("Aligner total params    :", count_parameters(aligner))


[VisionEncoder] Loaded openai/clip-vit-base-patch32, hidden_size=768
[TextEncoder] Loaded sentence-transformers/all-MiniLM-L6-v2, hidden_size=384
[VisionTextAligner] d_vision=768, d_text=384, d_align=512
Loading Phase‑1 checkpoint from: /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/checkpoints/phase1_multimodal/perceiver_mrl/best.pt


TypeError: load_checkpoint() got an unexpected keyword argument 'device'

## 6. Build Multimodal LLM (Aligned Model + Qwen)

In [None]:

llm_cfg = LLMConfig(
    model_name=cfg.llm_model_name,
    max_new_tokens=128,
    temperature=0.7,
    top_p=0.9,
    num_prefix_tokens=cfg.num_prefix_tokens,
    freeze_llm=cfg.freeze_llm,
)

multimodal_model = MultimodalLLM(
    aligner=aligner,
    llm_config=llm_cfg,
)

multimodal_model.to(device)
multimodal_model = multimodal_model.to(dtype=dtype)

# Wrap with DataParallel if multiple GPUs are available
if torch.cuda.device_count() > 1:
    print(f"Wrapping multimodal model in DataParallel over {torch.cuda.device_count()} GPUs")
    multimodal_model = nn.DataParallel(multimodal_model)

mm_module = multimodal_model.module if isinstance(multimodal_model, nn.DataParallel) else multimodal_model

print("Multimodal model ready.")
print("Trainable params (Phase‑2a projector):", sum(p.numel() for p in mm_module.get_trainable_params()))


## 7. Load VALOR Parquet Shards

In [None]:

shard_paths = sorted(VALOR_SHARDS_DIR.glob("*.parquet"))
assert shard_paths, f"No parquet shards found in {VALOR_SHARDS_DIR}"

print("Found", len(shard_paths), "shards:")
for p in shard_paths[:5]:
    print(" -", p.name)
if len(shard_paths) > 5:
    print(" ...")

dfs = []
for p in shard_paths:
    print("Loading shard:", p)
    df_shard = pd.read_parquet(p)
    dfs.append(df_shard)

df_all = pd.concat(dfs, ignore_index=True)
print("Total VALOR rows (frames):", len(df_all))
print("Columns:", list(df_all.columns))

# Sanity: expected columns from 00_inspect_valor.ipynb
for col in ["video_id", "caption", "image_jpeg", "audio_wav"]:
    assert col in df_all.columns, f"Missing column in VALOR data: {col}"

# Optional cap for debugging
if cfg.max_clips is not None:
    # we'll apply it after grouping by video_id
    print("max_clips specified; will cap after grouping by video_id")


## 8. Clip-Level Dataset (Multi-Frame VALOR)

In [None]:

class ValorMultiFrameDataset(Dataset):
    """
    Clip-level dataset for VALOR multi-frame, audio+caption data.
    
    Each row in df_all is a single frame from a video (`video_id`).
    We group by `video_id` so that each sample is:
      - video_id
      - list[image_jpeg] bytes (frames)
      - audio_wav bytes (first occurrence)
      - caption (first occurrence)
    
    For Phase‑3 (LLM decoder alignment), we'll use only image+caption by default,
    but audio is kept for potential future use.
    """
    def __init__(
        self,
        df_all: pd.DataFrame,
        tokenizer,
        max_length: int = 96,
        frame_limit_per_clip: Optional[int] = None,
        max_clips: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.frame_limit_per_clip = frame_limit_per_clip

        grouped = df_all.groupby("video_id")
        video_ids = list(grouped.groups.keys())
        random.shuffle(video_ids)

        if max_clips is not None:
            video_ids = video_ids[:max_clips]

        self.clips = []
        for vid in tqdm(video_ids, desc="Building clip index"):
            group = grouped.get_group(vid)
            frames = list(group["image_jpeg"])
            audio = group["audio_wav"].iloc[0]
            caption = group["caption"].iloc[0]

            if frame_limit_per_clip is not None:
                frames = frames[:frame_limit_per_clip]

            self.clips.append(
                {
                    "video_id": vid,
                    "frames": frames,   # list[bytes]
                    "audio_wav": audio,
                    "caption": str(caption),
                }
            )

        # Pre-tokenize captions for efficiency
        self._tok = []
        for clip in tqdm(self.clips, desc="Tokenizing captions"):
            text = clip["caption"]
            toks = tokenizer(
                text,
                padding="max_length",
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            )
            input_ids = toks["input_ids"][0]
            attn = toks["attention_mask"][0]
            labels = input_ids.clone()
            labels[attn == 0] = -100
            self._tok.append(
                {
                    "input_ids": input_ids,
                    "attention_mask": attn,
                    "labels": labels,
                    "text": text,
                }
            )

        print("Total clips (videos):", len(self.clips))

    def __len__(self) -> int:
        return len(self.clips)

    @staticmethod
    def _decode_image(jpeg_bytes: bytes) -> PILImage.Image:
        return PILImage.open(io.BytesIO(jpeg_bytes)).convert("RGB")

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        clip = self.clips[idx]
        toks = self._tok[idx]

        frames = [self._decode_image(b) for b in clip["frames"]]

        return {
            "video_id": clip["video_id"],
            "frames": frames,  # list[PIL.Image]
            "audio_wav": clip["audio_wav"],
            "input_ids": toks["input_ids"],
            "attention_mask": toks["attention_mask"],
            "labels": toks["labels"],
            "caption": toks["text"],
        }


def valor_multiframe_collate(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    return {
        "video_ids": [b["video_id"] for b in batch],
        "frames": [b["frames"] for b in batch],  # list[list[Image]]
        "audio_wav": [b["audio_wav"] for b in batch],
        "input_ids": torch.stack([b["input_ids"] for b in batch], dim=0),
        "attention_mask": torch.stack([b["attention_mask"] for b in batch], dim=0),
        "labels": torch.stack([b["labels"] for b in batch], dim=0),
        "captions": [b["caption"] for b in batch],
    }


## 9. Build Train / Val Datasets & DataLoaders

In [None]:

tokenizer = mm_module.llm.tokenizer

# Build a clip-level dataset from all rows
full_dataset = ValorMultiFrameDataset(
    df_all=df_all,
    tokenizer=tokenizer,
    max_length=cfg.max_caption_len,
    frame_limit_per_clip=cfg.frame_limit_per_clip,
    max_clips=cfg.max_clips,
)

num_clips = len(full_dataset)
num_val = max(1, int(num_clips * cfg.val_ratio))
num_train = num_clips - num_val
print(f"Total clips: {num_clips} -> train={num_train}, val={num_val}")

# Simple split by index (clips already shuffled)
train_indices = list(range(num_train))
val_indices = list(range(num_train, num_clips))

train_subset = torch.utils.data.Subset(full_dataset, train_indices)
val_subset = torch.utils.data.Subset(full_dataset, val_indices)

global_batch_size = cfg.batch_size_per_gpu * max(1, torch.cuda.device_count())

train_loader = DataLoader(
    train_subset,
    batch_size=global_batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
    collate_fn=valor_multiframe_collate,
)

val_loader = DataLoader(
    val_subset,
    batch_size=global_batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
    collate_fn=valor_multiframe_collate,
)

print(f"Batches per epoch: train={len(train_loader)}, val={len(val_loader)}")


## 10. Optimizer & LR Schedule

In [None]:

trainable_params = list(mm_module.get_trainable_params())
print("Trainable parameter count:", sum(p.numel() for p in trainable_params))

optimizer = AdamW(
    trainable_params,
    lr=cfg.learning_rate,
    weight_decay=cfg.weight_decay,
)

num_training_steps = len(train_loader) * cfg.num_epochs
warmup_steps = int(num_training_steps * cfg.warmup_ratio)
print(f"Total training steps: {num_training_steps}, warmup steps: {warmup_steps}")

# We'll use a simple linear warmup + cosine decay implemented manually
def get_lr(step: int) -> float:
    if step < warmup_steps:
        return cfg.learning_rate * (step + 1) / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, num_training_steps - warmup_steps)
    progress = min(max(progress, 0.0), 1.0)
    # Cosine from LR -> 0
    return cfg.learning_rate * 0.5 * (1.0 + math.cos(math.pi * progress))


## 11. Weights & Biases Init

In [None]:

wandb_run = wandb.init(
    project=cfg.wandb_project,
    name=cfg.wandb_run_name,
    entity=cfg.wandb_entity,
    config=asdict(cfg),
)

wandb.watch(mm_module.projector, log="all", log_freq=100)


## 12. Training & Evaluation Helpers

In [None]:

def sample_single_frame_per_clip(frames_batch: List[List[PILImage.Image]]) -> List[PILImage.Image]:
    """
    Given batch of clips (each clip is a list of frames),
    randomly pick ONE frame from each clip.
    """
    picked = []
    for frames in frames_batch:
        if len(frames) == 0:
            raise RuntimeError("Clip has zero frames; check VALOR data.")
        idx = random.randint(0, len(frames) - 1)
        picked.append(frames[idx])
    return picked


@torch.no_grad()
def evaluate_epoch(step: int) -> Dict[str, float]:
    mm_module.eval()
    total_loss = 0.0
    n_batches = 0

    all_samples = []

    for batch in tqdm(val_loader, desc=f"Eval @ step {step}", leave=False):
        # single-frame mode
        images = sample_single_frame_per_clip(batch["frames"])

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = mm_module(
            images=images,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        loss = outputs["loss"]
        total_loss += loss.item()
        n_batches += 1

        # qualitative samples
        if len(all_samples) < cfg.num_eval_samples:
            for img, cap in zip(images, batch["captions"]):
                if len(all_samples) >= cfg.num_eval_samples:
                    break
                try:
                    gen = mm_module.generate(
                        images=img,
                        prompt="Describe this video frame in one or two sentences:",
                        max_new_tokens=64,
                        temperature=0.7,
                    )
                except Exception as e:
                    gen = f"[GEN_ERROR: {e}]"
                all_samples.append(
                    {
                        "image": wandb.Image(img),
                        "gt_caption": cap,
                        "gen_caption": gen,
                    }
                )

    avg_loss = total_loss / max(1, n_batches)
    metrics = {"val_loss": avg_loss}

    if all_samples:
        table = wandb.Table(columns=["image", "gt_caption", "gen_caption"])
        for s in all_samples:
            table.add_data(s["image"], s["gt_caption"], s["gen_caption"])
        wandb.log({"eval_samples": table, "global_step": step})

    return metrics


## 13. Main Training Loop

In [None]:

global_step = 0
best_val_loss = math.inf

history = []

for epoch in range(cfg.num_epochs):
    print(f"\n===== Epoch {epoch+1}/{cfg.num_epochs} =====")
    mm_module.train()
    epoch_loss = 0.0
    n_batches = 0

    for batch in tqdm(train_loader, desc=f"Train Epoch {epoch+1}"):
        images = sample_single_frame_per_clip(batch["frames"])

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        # update LR
        lr = get_lr(global_step)
        for pg in optimizer.param_groups:
            pg["lr"] = lr

        optimizer.zero_grad(set_to_none=True)

        outputs = mm_module(
            images=images,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        loss = outputs["loss"]
        loss.backward()

        torch.nn.utils.clip_grad_norm_(trainable_params, cfg.max_grad_norm)
        optimizer.step()

        step_loss = loss.item()
        epoch_loss += step_loss
        n_batches += 1
        global_step += 1

        if global_step % cfg.log_every == 0:
            avg_loss = epoch_loss / max(1, n_batches)
            wandb.log(
                {
                    "train_loss": step_loss,
                    "train_loss_avg": avg_loss,
                    "lr": lr,
                    "epoch": epoch + 1,
                    "global_step": global_step,
                }
            )

        if global_step % cfg.eval_every == 0:
            print(f"\n>>> Running evaluation at step {global_step} ...")
            val_metrics = evaluate_epoch(global_step)
            val_loss = val_metrics["val_loss"]
            wandb.log({**val_metrics, "global_step": global_step})

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_path = PHASE3_OUT_DIR / "best_phase3_valor.pt"
                torch.save(
                    {
                        "step": global_step,
                        "epoch": epoch,
                        "model_state_dict": mm_module.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "cfg": asdict(cfg),
                    },
                    best_path,
                )
                print(f"[BEST] Saved new best checkpoint to {best_path} (val_loss={val_loss:.4f})")

    avg_epoch_loss = epoch_loss / max(1, n_batches)
    history.append({"epoch": epoch + 1, "avg_train_loss": avg_epoch_loss})
    print(f"Epoch {epoch+1} complete | avg_train_loss={avg_epoch_loss:.4f}")
    wandb.log({"epoch_train_loss": avg_epoch_loss, "epoch": epoch + 1, "global_step": global_step})

# Final evaluation
print("\n>>> Final evaluation after training ...")
final_metrics = evaluate_epoch(global_step)
wandb.log({**final_metrics, "global_step": global_step})

final_path = PHASE3_OUT_DIR / "final_phase3_valor.pt"
torch.save(
    {
        "step": global_step,
        "epoch": cfg.num_epochs,
        "model_state_dict": mm_module.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "cfg": asdict(cfg),
    },
    final_path,
)
print(f"Saved final Phase‑3 checkpoint to {final_path}")


## 14. Save History & Quick Inference Helper

In [None]:

# Save training history
history_path = PHASE3_OUT_DIR / "train_history_phase3_valor.json"
with open(history_path, "w") as f:
    json.dump(history, f, indent=2)
print("Saved training history to:", history_path)

if wandb_run is not None:
    wandb_run.finish()
    print("Closed W&B run.")

# Quick helper for manual inspection in the notebook
def debug_generate_for_clip(idx: int = 0, prompt: str = "Describe this video frame:") -> Dict[str, Any]:
    mm_module.eval()
    clip = full_dataset.clips[idx]
    frames = [ValorMultiFrameDataset._decode_image(b) for b in clip["frames"]]
    img = random.choice(frames)

    gen = mm_module.generate(
        images=img,
        prompt=prompt,
        max_new_tokens=96,
        temperature=0.7,
    )
    return {
        "video_id": clip["video_id"],
        "image": img,
        "caption_gt": clip["caption"],
        "caption_gen": gen,
    }

print("You can now call `debug_generate_for_clip(0)` in the notebook to inspect a sample.")
