# 05 — Multimodal LLM Inference & Evaluation (VALOR + Perceiver Alignment)

This notebook performs **pure inference** with the trained multimodal LLM (image–text via Perceiver alignment, audio available in VALOR clips), using the Phase‑3 checkpoint trained on VALOR.

It will:
- Load the **Phase‑1 Perceiver alignment** checkpoint and **Phase‑3 multimodal LLM** checkpoint.
- Build lightweight VALOR clip‑level dataset and sample random **image + caption + audio** clips.
- Run **qualitative generations** (single‑sample debug view).
- Run **batched evaluation** on random clips and compute simple text similarity metrics.
- Produce a few **plots for explainability**, such as:
  - Distribution of generated caption lengths vs ground‑truth.
  - Simple lexical similarity scores between generated and ground‑truth captions.
  - A small table of qualitative examples.
  
> **NOTE:** This notebook assumes the same project layout as the Phase‑3 training notebook
> (`04_multimodal_llm_decoder_training_valor_perceiver.ipynb`). Adjust paths marked with `# <-- EDIT` if needed.


In [1]:
import os
import io
import math
import random
import json
from pathlib import Path
from dataclasses import asdict
from typing import Dict, Any, List, Optional

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import pandas as pd
import numpy as np
from PIL import Image as PILImage
from PIL import UnidentifiedImageError
from tqdm.auto import tqdm

import matplotlib.pyplot as plt

# Optional: enable inline plots if running in a notebook
# %matplotlib inline

# Project imports (must match your codebase)
from imports.core import VisionEncoder, set_seed, count_parameters
from imports.llm_integration import LLMConfig, MultimodalLLM
from imports.multimodal_alignment_perceiver import MultimodalAlignmentConfig, MultimodalAlignmentModel




In [2]:
# === Paths & run configuration (mirror Phase-3 training) ===

ROOT_DIR = Path.cwd()   # <-- EDIT if you typically run notebooks from a different root

# Phase‑1 Perceiver alignment checkpoint (used to rebuild aligner)
PHASE1_PERCEIVER_CKPT_PATH = ROOT_DIR / "checkpoints" / "phase1_multimodal" / "perceiver_mrl" / "best.pt"  # <-- EDIT

# VALOR shards directory (same as training)
VALOR_SHARDS_DIR = ROOT_DIR / "data" / "alignment_subsets" / "valor32k_train_shards"  # <-- EDIT

PHASE3_RUN_NAME = "valor_qwen2p5_phase3_perceiver_v1"  # <-- EDIT if you changed run name
PHASE3_OUT_DIR = ROOT_DIR / "checkpoints" / "phase3_llm_valor_perceiver" / PHASE3_RUN_NAME

BEST_CKPT_PATH = PHASE3_OUT_DIR / "best_phase3_valor_perceiver.pt"   # saved during training loop
FINAL_CKPT_PATH = PHASE3_OUT_DIR / "final_phase3_valor_perceiver.pt" # saved at the end

print("ROOT_DIR                     :", ROOT_DIR)
print("PHASE1_PERCEIVER_CKPT_PATH   :", PHASE1_PERCEIVER_CKPT_PATH)
print("VALOR_SHARDS_DIR             :", VALOR_SHARDS_DIR)
print("PHASE3_OUT_DIR               :", PHASE3_OUT_DIR)
print("BEST_CKPT_PATH               :", BEST_CKPT_PATH)
print("FINAL_CKPT_PATH              :", FINAL_CKPT_PATH)

assert PHASE1_PERCEIVER_CKPT_PATH.is_file(), f"Phase‑1 Perceiver checkpoint not found: {PHASE1_PERCEIVER_CKPT_PATH}"
assert VALOR_SHARDS_DIR.is_dir(), f"VALOR shards directory not found: {VALOR_SHARDS_DIR}"
assert BEST_CKPT_PATH.is_file() or FINAL_CKPT_PATH.is_file(), "No Phase‑3 checkpoint found (best or final)."

# Device / dtype (same heuristic as in training)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    dtype = torch.bfloat16
elif torch.cuda.is_available():
    dtype = torch.float16
else:
    dtype = torch.float32

print("Using device:", device)
print("Using dtype :", dtype)


ROOT_DIR                     : /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base
PHASE1_PERCEIVER_CKPT_PATH   : /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/checkpoints/phase1_multimodal/perceiver_mrl/best.pt
VALOR_SHARDS_DIR             : /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/data/alignment_subsets/valor32k_train_shards
PHASE3_OUT_DIR               : /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/checkpoints/phase3_llm_valor_perceiver/valor_qwen2p5_phase3_perceiver_v1
BEST_CKPT_PATH               : /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/checkpoints/phase3_llm_valor_perceiver/valor_qwen2p5_phase3_perceiver_v1/best_phase3_valor_perceiver.pt
FINAL_CKPT_PATH              : /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/checkpoints/phase3_llm_valor_perceiver/valor_qwen2p5_phase3_perceiver_v1/final_phase3_valor_perceiver.pt
Using dev

In [3]:
class ValorMultiFrameDataset(Dataset):
    """Clip-level VALOR dataset (multi-frame, audio + caption).

    This version is identical in spirit to the training notebook, but we only
    use it for **inference** and keep options for subsampling for speed.
    """
    def __init__(
        self,
        df_all: pd.DataFrame,
        tokenizer,
        max_length: int = 96,
        frame_limit_per_clip: Optional[int] = None,
        max_clips: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Group by video_id just like training
        grouped = df_all.groupby("video_id", sort=False)
        self.clips: List[Dict[str, Any]] = []
        total_frames = 0

        for vid, group in grouped:
            frames = group["image_jpegs"].iloc[0]
            audio = group["audio_wav"].iloc[0]
            caption = group["caption"].iloc[0]

            if frame_limit_per_clip is not None and frames is not None:
                frames = frames[:frame_limit_per_clip]

            total_frames += len(frames) if frames is not None else 0

            self.clips.append(
                {
                    "video_id": vid,
                    "frames": frames,   # list of jpeg bytes (or lists)
                    "audio_wav": audio,
                    "caption": str(caption),
                }
            )

            if max_clips is not None and len(self.clips) >= max_clips:
                break

        print("Total clips (videos):", len(self.clips))
        print("Total raw frames stored:", total_frames)

        # Pre-tokenize captions for potential text-based baselines
        self._tok: List[Dict[str, Any]] = []
        for clip in tqdm(self.clips, desc="Tokenizing captions"):
            text = clip["caption"]
            toks = tokenizer(
                text,
                padding="max_length",
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            )
            self._tok.append({k: v.squeeze(0) for k, v in toks.items()})

    def __len__(self) -> int:
        return len(self.clips)

    @staticmethod
    def _decode_image(jpeg_bytes: bytes) -> PILImage.Image:
        return PILImage.open(io.BytesIO(jpeg_bytes)).convert("RGB")

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        clip = self.clips[idx]
        toks = self._tok[idx]

        frames: List[PILImage.Image] = []

        for cell in clip["frames"]:
            if cell is None:
                continue

            # Handle possible nested lists from parquet
            if isinstance(cell, (list, tuple)):
                for sub in cell:
                    if sub is None:
                        continue
                    try:
                        frames.append(self._decode_image(sub))
                    except UnidentifiedImageError:
                        continue
            else:
                try:
                    frames.append(self._decode_image(cell))
                except UnidentifiedImageError:
                    continue

        if len(frames) == 0:
            # Fallback gray image if all decodes fail
            frames = [PILImage.new("RGB", (224, 224), color=(128, 128, 128))]

        return {
            "video_id": clip["video_id"],
            "frames": frames,
            "audio_wav": clip["audio_wav"],
            "caption": clip["caption"],
            "input_ids": toks["input_ids"],
            "attention_mask": toks["attention_mask"],
        }


def collate_valor(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Simple collate: keep Python lists for frames, stack token tensors."""
    video_ids = [b["video_id"] for b in batch]
    frames = [b["frames"] for b in batch]
    captions = [b["caption"] for b in batch]
    audio_wav = [b["audio_wav"] for b in batch]

    input_ids = torch.stack([b["input_ids"] for b in batch], dim=0)
    attention_mask = torch.stack([b["attention_mask"] for b in batch], dim=0)

    return {
        "video_id": video_ids,
        "frames": frames,
        "captions": captions,
        "audio_wav": audio_wav,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
    }


def sample_single_frame_per_clip(frames_batch: List[List[PILImage.Image]]) -> List[PILImage.Image]:
    """Pick exactly one random frame per clip for image‑conditioned generation."""
    images: List[PILImage.Image] = []
    for frames in frames_batch:
        if len(frames) == 0:
            images.append(PILImage.new("RGB", (224, 224), color=(128, 128, 128)))
        else:
            images.append(random.choice(frames))
    return images


In [14]:
set_seed(42)
perceiver_cfg = MultimodalAlignmentConfig()
perceiver_model = MultimodalAlignmentModel(perceiver_cfg)

# Attach cfg so MultimodalLLM can use aligner.cfg.*
perceiver_cfg.device = device
perceiver_cfg.dtype = dtype
perceiver_model.cfg = perceiver_cfg

ckpt = torch.load(PHASE1_PERCEIVER_CKPT_PATH, map_location=device, weights_only=False)

# Correct: your training script stored the real weights under "model_state"
if isinstance(ckpt, dict) and "model_state" in ckpt and isinstance(ckpt["model_state"], dict):
    state_dict = ckpt["model_state"]
    print("Using 'model_state' from checkpoint as state_dict.")
else:
    # Fallback heuristics if you ever change the saving format
    state_dict = None
    if isinstance(ckpt, dict):
        for key in ["model_state_dict", "alignment_model", "state_dict", "model"]:
            if key in ckpt and isinstance(ckpt[key], dict):
                state_dict = ckpt[key]
                print(f"Using '{key}' from checkpoint as state_dict.")
                break
        if state_dict is None:
            print("No wrapper key found — treating entire checkpoint as state_dict.")
            state_dict = ckpt
    else:
        print("Checkpoint is not a dict with wrapper keys — treating as state_dict.")
        state_dict = ckpt

missing, unexpected = perceiver_model.load_state_dict(state_dict, strict=False)
print("Loaded Perceiver alignment checkpoint.")
print("Missing keys   :", len(missing))
print("Unexpected keys:", len(unexpected))


Using 'model_state' from checkpoint as state_dict.
Loaded Perceiver alignment checkpoint.
Missing keys   : 0
Unexpected keys: 0


In [15]:
missing, unexpected = perceiver_model.load_state_dict(state_dict, strict=False)
print("Loaded Perceiver alignment checkpoint.")
print("Missing keys   :", len(missing))
print("Unexpected keys:", len(unexpected))

# Print full lists for debugging
print("\n=== Missing Keys ===")
for k in missing:
    print(k)

print("\n=== Unexpected Keys ===")
for k in unexpected:
    print(k)


Loaded Perceiver alignment checkpoint.
Missing keys   : 0
Unexpected keys: 0

=== Missing Keys ===

=== Unexpected Keys ===


In [16]:

perceiver_model.to(device)
perceiver_model = perceiver_model.to(dtype=dtype)
perceiver_model.eval()


# Vision backbone (CLIP) – same as in Phase‑1
vision_backbone = VisionEncoder(
    model_name=perceiver_cfg.vision_model_name,
    device=device,
    dtype=dtype,
)
vision_backbone.eval()
for p in vision_backbone.parameters():
    p.requires_grad = False

print("Perceiver params total :", count_parameters(perceiver_model))
print("Vision backbone hidden :", vision_backbone.hidden_size)


[VisionEncoder] Loaded openai/clip-vit-base-patch32, hidden_size=768
Perceiver params total : {'total': 21621760, 'trainable': 21621760, 'frozen': 0}
Vision backbone hidden : 768


In [18]:
# === Rebuild Multimodal LLM (Perceiver aligner + Qwen2.5) and load Phase‑3 checkpoint ===

# Keep cfg small and focused on inference
llm_cfg = LLMConfig(
    model_name="Qwen/Qwen2.5-3B-Instruct",   # <-- EDIT if you changed this in training
    max_new_tokens=128,
    temperature=0.7,
    top_p=0.9,
    num_prefix_tokens=8,                     # must match training
    freeze_llm=True,                         # we are only doing inference
    
)

# Wrap the Perceiver aligner as the "aligner" used by MultimodalLLM
aligner = perceiver_model

mm = MultimodalLLM(
    aligner=aligner,
    llm_config=llm_cfg,
)

mm.to(device)
mm = mm.to(dtype=dtype)

if torch.cuda.device_count() > 1:
    print(f"Wrapping multimodal model in DataParallel over {torch.cuda.device_count()} GPUs")
    mm = nn.DataParallel(mm)

mm_module = mm.module if isinstance(mm, nn.DataParallel) else mm

# Load best checkpoint if available, else fall back to final
ckpt_path = BEST_CKPT_PATH if BEST_CKPT_PATH.is_file() else FINAL_CKPT_PATH
print("Loading Phase‑3 checkpoint from:", ckpt_path)

ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
state_dict = ckpt.get("model_state_dict", ckpt)
missing, unexpected = mm_module.load_state_dict(state_dict, strict=False)
print("Loaded Phase‑3 state_dict.")
print("Missing keys   :", len(missing))
print("Unexpected keys:", len(unexpected))

mm_module.eval()

def count_parameters(model, trainable_only=False):
    if trainable_only:
        return sum(p.numel() for p in model.parameters() if p.requires_grad)
    else:
        return sum(p.numel() for p in model.parameters())


print("Multimodal params (trainable):", count_parameters(mm_module, trainable_only=True))


[LLMDecoder] Loading Qwen/Qwen2.5-3B-Instruct...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

[LLMDecoder] hidden_size=2048, frozen=True
[MultimodalLLM] Projector: 512 → 8 × 2048
Wrapping multimodal model in DataParallel over 2 GPUs
Loading Phase‑3 checkpoint from: /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/checkpoints/phase3_llm_valor_perceiver/valor_qwen2p5_phase3_perceiver_v1/best_phase3_valor_perceiver.pt
Loaded Phase‑3 state_dict.
Missing keys   : 149
Unexpected keys: 348
Multimodal params (trainable): 8406016


In [21]:
# === Load VALOR shards into a single DataFrame (lightweight subset) ===

shard_paths = sorted(VALOR_SHARDS_DIR.glob("*.parquet"))
assert len(shard_paths) > 0, f"No parquet shards found in {VALOR_SHARDS_DIR}"

print("Found", len(shard_paths), "VALOR shards.")

dfs = []
for p in shard_paths:
    try:
        print("Loading shard:", p)
        df_shard = pd.read_parquet(p)
        dfs.append(df_shard)
    except Exception as e:
        print(f"  Error loading shard {p}: {e}")

df_all = pd.concat(dfs, ignore_index=True)
print("Full VALOR rows:", len(df_all))
print(df_all.head())

# Optional: subsample for faster inference
MAX_ROWS_FOR_INFERENCE = 100  # <-- EDIT; None = all rows
if MAX_ROWS_FOR_INFERENCE is not None and len(df_all) > MAX_ROWS_FOR_INFERENCE:
    df_all = df_all.sample(n=MAX_ROWS_FOR_INFERENCE, random_state=42)
    print("Subsampled VALOR rows:", len(df_all))



Found 19 VALOR shards.
Loading shard: /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/data/alignment_subsets/valor32k_train_shards/valor32k_train_batch000_shard000.parquet
Loading shard: /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/data/alignment_subsets/valor32k_train_shards/valor32k_train_batch000_shard001.parquet
Loading shard: /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/data/alignment_subsets/valor32k_train_shards/valor32k_train_batch000_shard002.parquet
  Error loading shard /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/data/alignment_subsets/valor32k_train_shards/valor32k_train_batch000_shard002.parquet: Could not open Parquet input source '<Buffer>': Parquet magic bytes not found in footer. Either the file is corrupted or this is not a parquet file.
Loading shard: /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/data/alignment_subsets/valor32k_train_shards/v

In [23]:
# Build tokenizer via the LLM module
tokenizer = mm_module.llm.tokenizer

dataset = ValorMultiFrameDataset(
    df_all=df_all,
    tokenizer=tokenizer,
    max_length=96,
    frame_limit_per_clip=8,
    max_clips=512,     # cap at a few hundred clips for evaluation
)

eval_loader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_valor,
)

print("Eval loader ready with", len(dataset), "clips.")


Total clips (videos): 100
Total raw frames stored: 277


Tokenizing captions:   0%|          | 0/100 [00:00<?, ?it/s]

Eval loader ready with 100 clips.


In [28]:
def debug_generate_for_clip(idx: int = 0, prompt: str = "Describe this video frame:") -> Dict[str, Any]:
    """Qualitative inspection: pick one clip, one random frame, and generate a caption."""
    mm_module.eval()

    clip = dataset.clips[idx]
    frames = [ValorMultiFrameDataset._decode_image(b) for b in clip["frames"]]
    img = random.choice(frames) if len(frames) > 0 else PILImage.new("RGB", (224, 224), color=(128, 128, 128))

    # Display the image
    plt.figure(figsize=(4, 4))
    plt.imshow(img)
    plt.axis("off")
    plt.title(f"Video ID: {clip['video_id']}")
    plt.show()

    print("Ground‑truth caption:")
    print(clip["caption"])
    print("\nGenerated caption:")

    gen_text = mm_module.generate(
        images=img,
        prompt=prompt,
        max_new_tokens=96,
        temperature=0.7,
    )

    print(gen_text)

    return {
        "video_id": clip["video_id"],
        "image": img,
        "caption_gt": clip["caption"],
        "caption_gen": gen_text,
    }

print("You can now call:")
print("  debug_generate_for_clip(0)")
print("  debug_generate_for_clip(10, prompt='Give a detailed description:')")

from PIL import Image as PILImage
from typing import Union

@torch.no_grad()
def generate_from_pil_image(
    img: Union[PILImage.Image, str],
    prompt: str = "Describe this image:",
    max_new_tokens: int = 96,
    temperature: float = 0.7,
    top_p: float = 0.9,
    do_sample: bool = True,
) -> str:
    """
    Image-conditioned generation:

    PIL image -> CLIP vision encoder -> Perceiver aligner -> projector -> LLM.
    """
    mm_module.eval()

    # 0. Load image if path
    if isinstance(img, str):
        img = PILImage.open(img).convert("RGB")
    else:
        img = img.convert("RGB")

    # 1. Vision backbone: get vision features (same as in training)
    # If your API differs, adapt this line to whatever you used during Perceiver training:
    feats = vision_backbone.encode_images([img])   # (B, T, D) or (B, D)
    feats = feats.to(device=device, dtype=dtype)

    # 2. Perceiver: features -> aligned vision embedding
    aligner = mm_module.aligner
    z_align = aligner.encode_vision(feats)         # (B, d_align)

    # 3. Project to LLM prefix tokens
    projector = mm_module.projector
    prefix = projector(z_align)                    # (B, num_tokens, d_llm)

    # 4. LLM generate with prefix
    tokenizer = mm_module.llm.tokenizer
    llm_model = mm_module.llm.model

    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    outputs = llm_model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.eos_token_id,
        prefix_embeds=prefix,
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)


You can now call:
  debug_generate_for_clip(0)
  debug_generate_for_clip(10, prompt='Give a detailed description:')


In [41]:
from PIL import Image as PILImage
from typing import Union, Optional
from collections import Counter
import torch
import pandas as pd
import random
import matplotlib.pyplot as plt

# ---------- 1. Single-image generator (PIL → features → Perceiver → LLM) ----------

@torch.no_grad()
def generate_from_pil_image(
    img: Union[PILImage.Image, str],
    prompt: str = "Describe this image:",
    max_new_tokens: int = 96,
    temperature: float = 0.7,
    top_p: float = 0.9,
    do_sample: bool = True,
) -> str:
    """
    Image-conditioned generation:

    PIL image/path -> VisionEncoder -> Perceiver aligner -> projector -> LLM.
    """
    mm_module.eval()

    # 0. Load image if path
    if isinstance(img, str):
        img = PILImage.open(img).convert("RGB")
    else:
        img = img.convert("RGB")

    # 1. Vision backbone: get vision features (same as in training)
    # Your VisionEncoder is an nn.Module, so we just call it.
    # It may return either a tensor or a dict — handle both.
    vision_backbone.eval()
    with torch.no_grad():
        feats = vision_backbone([img])  # (B, T, D) or (B, D) or dict

    if isinstance(feats, dict):
        # Try common keys used in encoders; adapt if your VisionEncoder uses a specific one.
        for key in ["image_embeds", "features", "last_hidden_state", "pooler_output"]:
            if key in feats:
                feats = feats[key]
                break

    if not isinstance(feats, torch.Tensor):
        raise TypeError(
            f"VisionEncoder returned type {type(feats)}; expected tensor or dict of tensors. "
            "Please adapt generate_from_pil_image to match your VisionEncoder API."
        )

    feats = feats.to(device=device, dtype=dtype)

    # 2. Perceiver: features -> aligned vision embedding
    aligner = mm_module.aligner                    # MultimodalAlignmentModel
    z_align = aligner.encode_vision(feats)         # (B, d_align)

    # 3. Project to LLM prefix tokens
    projector = mm_module.projector                # VisionToLLMProjector
    prefix = projector(z_align)                    # (B, num_tokens, d_llm)

    # 4. LLM generate with prefix
    tokenizer = mm_module.llm.tokenizer
    llm_model = mm_module.llm.model

    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    outputs = llm_model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.eos_token_id,
        prefix_embeds=prefix,
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# ---------- 2. Qualitative single-clip debug ----------

def debug_generate_for_clip(idx: int = 0, prompt: str = "Describe this video frame:"):
    """Qualitative inspection: pick one clip, one random frame, and generate a caption."""
    mm_module.eval()

    clip = dataset.clips[idx]

    # Decode stored JPEG bytes to PIL images
    frames = []
    for cell in clip["frames"]:
        if cell is None:
            continue
        if isinstance(cell, (list, tuple)):
            for sub in cell:
                if sub is None:
                    continue
                try:
                    frames.append(ValorMultiFrameDataset._decode_image(sub))
                except Exception:
                    continue
        else:
            try:
                frames.append(ValorMultiFrameDataset._decode_image(cell))
            except Exception:
                continue

    img = random.choice(frames) if len(frames) > 0 else PILImage.new(
        "RGB", (224, 224), color=(128, 128, 128)
    )

    # Display the image
    plt.figure(figsize=(4, 4))
    plt.imshow(img)
    plt.axis("off")
    plt.title(f"Video ID: {clip['video_id']}")
    plt.show()

    print("Ground-truth caption:")
    print(clip["caption"])
    print("\nGenerated caption:")

    gen_text = generate_from_pil_image(
        img,
        prompt=prompt,
        max_new_tokens=96,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
    )

    print(gen_text)

    return {
        "video_id": clip["video_id"],
        "image": img,
        "caption_gt": clip["caption"],
        "caption_gen": gen_text,
    }

print("You can now call:")
print("  debug_generate_for_clip(0)")
print("  debug_generate_for_clip(10, prompt='Give a detailed description:')")


# ---------- 3. Simple lexical similarity helpers ----------

def _normalize_text(s: str) -> str:
    return " ".join(s.lower().strip().split())

def _bag_of_words_similarity(a: str, b: str) -> float:
    """Very simple BoW Jaccard similarity between two strings."""
    a_toks = _normalize_text(a).split()
    b_toks = _normalize_text(b).split()
    if not a_toks or not b_toks:
        return 0.0
    ca = Counter(a_toks)
    cb = Counter(b_toks)
    inter = sum((ca & cb).values())
    union = sum((ca | cb).values())
    return inter / union if union > 0 else 0.0


# ---------- 4. Batch eval over VALOR clips using the fixed image pipeline ----------

@torch.no_grad()
def run_batch_eval(num_batches: int = 32, max_samples: Optional[int] = None) -> pd.DataFrame:
    """
    Run generation on a subset of VALOR clips and compute simple metrics.

    Metrics per sample:
    - `len_gt`, `len_gen`: character lengths
    - `bow_jaccard`: simple bag-of-words Jaccard similarity
    """
    mm_module.eval()

    rows = []
    n_seen = 0

    for b_idx, batch in enumerate(eval_loader):
        if b_idx >= num_batches:
            break

        images = sample_single_frame_per_clip(batch["frames"])  # list[PIL.Image]
        captions_gt = batch["captions"]

        batch_gen = []
        for img in images:
            gen = generate_from_pil_image(
                img,
                prompt="Describe this video frame:",
                max_new_tokens=96,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
            )
            batch_gen.append(gen)

        for vid, gt, gen, img in zip(batch["video_id"], captions_gt, batch_gen, images):
            n_seen += 1
            len_gt = len(gt)
            len_gen = len(gen)
            sim = _bag_of_words_similarity(gt, gen)

            rows.append(
                {
                    "video_id": vid,
                    "caption_gt": gt,
                    "caption_gen": gen,
                    "len_gt": len_gt,
                    "len_gen": len_gen,
                    "bow_jaccard": sim,
                }
            )

            if max_samples is not None and n_seen >= max_samples:
                break

        if max_samples is not None and n_seen >= max_samples:
            break

    df_eval = pd.DataFrame(rows)
    print("Collected eval samples:", len(df_eval))
    return df_eval

# Example call:
# eval_results = run_batch_eval(num_batches=16, max_samples=128)
# eval_results.head()


You can now call:
  debug_generate_for_clip(0)
  debug_generate_for_clip(10, prompt='Give a detailed description:')


In [43]:
# Run a small eval by default (edit parameters)
eval_results = run_batch_eval(num_batches=16, max_samples=128)
eval_results.head()


NotImplementedError: Module [VisionEncoder] is missing the required "forward" function

In [39]:
# === Plots: length distributions and BoW similarity ===

if len(eval_results) > 0:
    plt.figure()
    plt.hist(eval_results["len_gt"], bins=20, alpha=0.5, label="GT length")
    plt.hist(eval_results["len_gen"], bins=20, alpha=0.5, label="Gen length")
    plt.xlabel("Caption length (characters)")
    plt.ylabel("Count")
    plt.title("Ground‑truth vs Generated caption lengths")
    plt.legend()
    plt.show()

    plt.figure()
    plt.hist(eval_results["bow_jaccard"], bins=20)
    plt.xlabel("BoW Jaccard similarity")
    plt.ylabel("Count")
    plt.title("Distribution of lexical similarity (GT vs Gen)")
    plt.show()

    print("Basic statistics:")
    print(eval_results[["len_gt", "len_gen", "bow_jaccard"]].describe())

    # Show a few best and worst examples by similarity
    print("\nTop 5 highest similarity examples:")
    display(eval_results.sort_values("bow_jaccard", ascending=False).head(5))

    print("\nTop 5 lowest similarity examples:")
    display(eval_results.sort_values("bow_jaccard", ascending=True).head(5))
else:
    print("No eval results available to plot.")


NameError: name 'eval_results' is not defined

In [None]:
# === Optional: inspect raw audio for a random clip ===

import matplotlib.pyplot as plt

def debug_audio_for_clip(idx: int = 0):
    """Visualize raw audio waveform (and optionally spectrogram) for one clip.

    NOTE: The current Phase‑3 decoder uses **vision + text**; audio is part of
    the aligned embedding space from Phase‑1, but is not directly consumed by
    the decoder in this notebook. This cell is purely for sanity‑checking the
    audio modality.
    """
    clip = dataset.clips[idx]
    audio_bytes = clip["audio_wav"]
    video_id = clip["video_id"]

    print("Video ID:", video_id)
    print("Caption :", clip["caption"])

    # If audio is stored as raw bytes of a WAV/OGG file, users can decode it with soundfile/torchaudio.
    # Here we just visualize the raw byte length as a placeholder; adapt decoding to your format.
    print("Audio byte length:", len(audio_bytes))

    # If you know the exact format, replace this section with proper decoding + waveform plot.

print("You can inspect audio with: debug_audio_for_clip(0)")


In [26]:
# === Text-only Q&A helper using the Phase-3 LLM decoder ===

import torch

def ask_text(
    prompt: str,
    max_new_tokens: int = 256,
    temperature: float = 0.7,
    top_p: float = 0.9,
    do_sample: bool = True,
):
    """
    Run a pure text-only generation through the underlying decoder LLM.

    This ignores images/audio and just uses the language model that was
    used in Phase-3 (e.g., Qwen2.5-3B-Instruct).
    """
    mm_module.eval()

    tokenizer = mm_module.llm.tokenizer          # already set up in the notebook
    llm_model = mm_module.llm.model          # HF causal LM inside LLMDecoder

    # You can wrap your prompt for chat-style models if you like:
    full_prompt = prompt

    inputs = tokenizer(
        full_prompt,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = llm_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.eos_token_id,
        )

    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return text

print("You can now call: ask_text('Explain what this model is doing.')")


You can now call: ask_text('Explain what this model is doing.')


In [None]:
resp = ask_text("You are a helpful AI assistant. Explain VALOR dataset in simple terms.")
print(resp)


In [33]:
import torch
import torch.nn as nn
from pathlib import Path
from typing import Optional, Union, List
from PIL import Image as PILImage

from imports.core import VisionEncoder, set_seed
from imports.llm_integration import LLMConfig, MultimodalLLM
from imports.multimodal_alignment_perceiver import (
    MultimodalAlignmentConfig,
    MultimodalAlignmentModel,
)


class EdgeGlassEngine:
    """
    vLLM-style inference wrapper around your Perceiver + VisionEncoder + LLM stack.

    Supports:
      - text-only generation
      - image-conditioned generation (image -> CLIP -> Perceiver -> projector -> LLM)
    """

    def __init__(
        self,
        root_dir: Union[str, Path],
        phase1_perceiver_ckpt: Union[str, Path],
        phase3_ckpt: Union[str, Path],
        llm_name: str = "Qwen/Qwen2.5-3B-Instruct",   # match your training
        num_prefix_tokens: int = 8,                    # match your training
        seed: int = 42,
        device: Optional[torch.device] = None,
    ) -> None:
        set_seed(seed)

        self.root_dir = Path(root_dir)

        # --- device / dtype ---
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
            self.dtype = torch.bfloat16
        elif torch.cuda.is_available():
            self.dtype = torch.float16
        else:
            self.dtype = torch.float32

        print(f"[Engine] device = {self.device}, dtype = {self.dtype}")

        # --- 1. Perceiver alignment model ---
        self.align_cfg = MultimodalAlignmentConfig()
        self.aligner = MultimodalAlignmentModel(self.align_cfg)

        # attach cfg like training expects
        self.align_cfg.device = self.device
        self.align_cfg.dtype = self.dtype
        self.aligner.cfg = self.align_cfg

        phase1_perceiver_ckpt = Path(phase1_perceiver_ckpt)
        ckpt1 = torch.load(phase1_perceiver_ckpt, map_location=self.device, weights_only=False)

        # your tri-modal trainer uses `model_state` key
        if isinstance(ckpt1, dict) and "model_state" in ckpt1:
            state_dict = ckpt1["model_state"]
            print("[Engine] Using 'model_state' from Phase-1 checkpoint.")
        else:
            state_dict = ckpt1
            print("[Engine] Phase-1 checkpoint has no 'model_state' wrapper; using full dict.")

        missing, unexpected = self.aligner.load_state_dict(state_dict, strict=False)
        print(f"[Engine] Perceiver loaded. missing={len(missing)}, unexpected={len(unexpected)}")

        self.aligner.to(self.device, dtype=self.dtype)
        self.aligner.eval()

        # --- 2. Vision backbone (CLIP / SigLIP etc.) ---
        self.vision = VisionEncoder(
            model_name=self.align_cfg.vision_model_name,
            device=self.device,
            dtype=self.dtype,
        )
        self.vision.eval()
        for p in self.vision.parameters():
            p.requires_grad = False

        # --- 3. LLM + projector ---
        self.llm_cfg = LLMConfig(
            model_name=llm_name,
            max_new_tokens=128,
            temperature=0.7,
            top_p=0.9,
            num_prefix_tokens=num_prefix_tokens,
            freeze_llm=True,
        )

        mm = MultimodalLLM(
            aligner=self.aligner,
            llm_config=self.llm_cfg,
        ).to(self.device, dtype=self.dtype)

        if torch.cuda.device_count() > 1:
            print(f"[Engine] Using DataParallel over {torch.cuda.device_count()} GPUs")
            mm = nn.DataParallel(mm)

        self.mm = mm
        self.mm_module = mm.module if isinstance(mm, nn.DataParallel) else mm

        # --- 4. Load Phase-3 checkpoint (decoder training) ---
        phase3_ckpt = Path(phase3_ckpt)
        ckpt3 = torch.load(phase3_ckpt, map_location=self.device, weights_only=False)

        if isinstance(ckpt3, dict) and "model_state_dict" in ckpt3:
            mm_state = ckpt3["model_state_dict"]
            print("[Engine] Using 'model_state_dict' from Phase-3 checkpoint.")
        else:
            mm_state = ckpt3
            print("[Engine] Phase-3 checkpoint has no 'model_state_dict'; using full dict.")

        missing, unexpected = self.mm_module.load_state_dict(mm_state, strict=False)
        print(f"[Engine] Multimodal LLM loaded. missing={len(missing)}, unexpected={len(unexpected)}")

        self.mm_module.eval()

        # shortcuts
        self.tokenizer = self.mm_module.llm.tokenizer
        self.llm_model = self.mm_module.llm.model
        self.projector = self.mm_module.projector  # VisionToLLMProjector

    # ------------------------------------------------------------------
    # Text-only generation (no Perceiver / vision needed)
    # ------------------------------------------------------------------
    @torch.no_grad()
    def generate_text(
        self,
        prompt: str,
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        top_p: float = 0.9,
        do_sample: bool = True,
    ) -> str:
        self.mm_module.eval()
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        outputs = self.llm_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    # ------------------------------------------------------------------
    # Image-conditioned generation: PIL / path → CLIP → Perceiver → LLM
    # ------------------------------------------------------------------
    @torch.no_grad()
    def generate_image(
        self,
        image: Union[PILImage.Image, str],
        question: str = "Describe this image.",
        max_new_tokens: int = 128,
        temperature: float = 0.7,
        top_p: float = 0.9,
        do_sample: bool = True,
    ) -> str:
        self.mm_module.eval()

        # 0. Load / normalize image
        if isinstance(image, str):
            img = PILImage.open(image).convert("RGB")
        else:
            img = image.convert("RGB")

        # 1. Vision encoder forward (same as in training)
        self.vision.eval()
        feats = self.vision([img])   # may be tensor or dict

        if isinstance(feats, dict):
            # Adapt key choice to however VisionEncoder returns features
            for key in ["image_embeds", "features", "last_hidden_state", "pooler_output"]:
                if key in feats:
                    feats = feats[key]
                    break

        if not isinstance(feats, torch.Tensor):
            raise TypeError(
                f"VisionEncoder returned {type(feats)}; "
                "update EdgeGlassEngine.generate_image to match your VisionEncoder API."
            )

        feats = feats.to(self.device, dtype=self.dtype)

        # 2. Perceiver aligner: features -> aligned embedding
        z_align = self.aligner.encode_vision(feats)      # (B, d_align)

        # 3. Project to LLM prefix tokens
        prefix = self.projector(z_align)                 # (B, num_tokens, d_llm)

        # 4. LLM generate with prefix
        inputs = self.tokenizer(question, return_tensors="pt").to(self.device)

        outputs = self.llm_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=self.tokenizer.eos_token_id,
            prefix_embeds=prefix,
        )

        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    # ------------------------------------------------------------------
    # Unified API like vLLM: choose modality by arguments
    # ------------------------------------------------------------------
    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        image: Optional[Union[PILImage.Image, str]] = None,
        # audio_features: Optional[torch.Tensor] = None,  # TODO: once you wire audio→LLM
        **gen_kwargs,
    ) -> str:
        """
        If only prompt is given -> text-only.
        If image is given -> image-conditioned generation.
        Later you can extend for audio_features as well.
        """
        if image is None:
            return self.generate_text(prompt, **gen_kwargs)
        else:
            return self.generate_image(image, question=prompt, **gen_kwargs)


In [35]:
ROOT_DIR = "/home/hice1/vchopra37/scratch/projects/edge_glass/code_base/v2_code_base"
PHASE1 = f"{ROOT_DIR}/checkpoints/phase1_multimodal/perceiver_mrl/best.pt"
PHASE3 = f"{ROOT_DIR}/checkpoints/phase3_llm_valor_perceiver/valor_qwen2p5_phase3_perceiver_v1/best_phase3_valor_perceiver.pt"

engine = EdgeGlassEngine(
    root_dir=ROOT_DIR,
    phase1_perceiver_ckpt=PHASE1,
    phase3_ckpt=PHASE3,
)

# 1) Text-only question
print(engine.generate("Explain what the VALOR dataset is in simple terms."))

# 2) Image question
# from PIL import Image
# img = Image.open("/path/to/some/frame.jpg")
# print(engine.generate("What is happening in this scene?", image=img))


[Engine] device = cuda, dtype = torch.bfloat16
[Engine] Using 'model_state' from Phase-1 checkpoint.
[Engine] Perceiver loaded. missing=0, unexpected=0
[VisionEncoder] Loaded openai/clip-vit-base-patch32, hidden_size=768
[LLMDecoder] Loading Qwen/Qwen2.5-3B-Instruct...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

[LLMDecoder] hidden_size=2048, frozen=True
[MultimodalLLM] Projector: 512 → 8 × 2048
[Engine] Using DataParallel over 2 GPUs
[Engine] Using 'model_state_dict' from Phase-3 checkpoint.
[Engine] Multimodal LLM loaded. missing=149, unexpected=348
Explain what the VALOR dataset is in simple terms. The VALOR dataset is a collection of images used for training computer vision models, particularly those that can recognize and understand objects or scenes in photographs. Think of it like a big photo album with lots of pictures of different things, but these aren't just any photos - they're specifically chosen to help machines learn to identify and describe various objects or situations accurately.

Imagine you have a robot that needs to recognize toys in a room. The VALOR dataset would provide it with many photos of toys in different positions and settings, helping the robot to better understand and recognize toys in real-world scenarios. This way, when it sees a toy, it can more reliably tell

In [37]:
# 1) Text-only question
print(engine.generate("Hi"))

Hi, I have a question. I need to find the derivative of the function f(x) = x^2 * e^x using the product rule. How can I do that? Sure! To find the derivative of the function \( f(x) = x^2 \cdot e^x \) using the product rule, you'll follow these steps:

1. Identify the two functions being multiplied together. In this case, we have:
   \[
   u(x) = x^2 \quad \text{and} \quad v(x) = e^x
   \]

2. Find the derivatives of these two functions:
   \[
   u'(x) = \frac{d}{dx}(x^2) = 2x
   \]
   \[
   v'(x) = \frac{d}{dx}(e^x) = e^x
   \]

3. Apply the product rule, which states:
   \[
   (u \cdot v)' = u' \cdot v + u \cdot v'
   \]

4. Substitute \( u(x) \), \( u'(x) \), \( v(x) \), and \( v'(x) \) into the product rule formula


In [38]:
img = "/home/hice1/vchopra37/scratch/projects/edge_glass/temp.jpg"
print(engine.generate("What is happening here?", image=img))


NotImplementedError: Module [VisionEncoder] is missing the required "forward" function