In [7]:
import os
import math
import random
import requests
from io import BytesIO
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image

# Imports for models
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    CLIPVisionModel, 
    CLIPImageProcessor, 
    WhisperModel, 
    WhisperProcessor
)


In [8]:

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# ============================================================
# 1. Configuration
# ============================================================
@dataclass
class Config:
    # Checkpoint path (Ensure this matches where you saved it!)
    ckpt_path: str = "./runs_perceiver_mrl_qwen/multimodal_adapter_poc.pt"
    
    # Models
    vision_model_name: str = "openai/clip-vit-base-patch32"
    audio_model_name: str = "openai/whisper-base"
    llm_model_name: str = "Qwen/Qwen2.5-7B-Instruct"
    
    # Architecture Dims (Must match training!)
    perceiver_dim: int = 512
    num_latents: int = 64
    num_perceiver_layers: int = 2
    num_attn_heads: int = 8
    mlp_ratio: float = 4.0
    
    # Filled dynamically
    encoder_dim_vision: int = 768 
    encoder_dim_audio: int = 512 
    llm_hidden_size: int = 3584
    
    # Data
    batch_size: int = 8
    seed: int = 42

cfg = Config()


Using device: cuda


In [9]:

# ============================================================
# 2. Architecture Definitions (Must match training code)
# ============================================================

class ModalityAdapter(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        return self.proj(x)

class FeedForward(nn.Module):
    def __init__(self, dim: int, mlp_ratio: float = 4.0):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
    def forward(self, x): 
        return self.net(x)

class PerceiverLayer(nn.Module):
    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.ln_latents_1 = nn.LayerNorm(dim)
        self.ln_tokens = nn.LayerNorm(dim)
        self.ln_latents_2 = nn.LayerNorm(dim)
        self.ln_latents_3 = nn.LayerNorm(dim)
        self.mlp = FeedForward(dim, mlp_ratio)

    def forward(self, latents, tokens, token_mask=None):
        # Cross-Attention
        q = self.ln_latents_1(latents)
        kv = self.ln_tokens(tokens)
        key_padding_mask = ~token_mask.bool() if token_mask is not None else None
        attn_out, _ = self.cross_attn(q, kv, kv, key_padding_mask=key_padding_mask, need_weights=False)
        latents = latents + attn_out
        
        # Self-Attention
        q2 = self.ln_latents_2(latents)
        self_out, _ = self.self_attn(q2, q2, q2, need_weights=False)
        latents = latents + self_out
        
        # MLP
        latents = latents + self.mlp(self.ln_latents_3(latents))
        return latents

class PerceiverResampler(nn.Module):
    def __init__(self, dim, num_latents, num_layers, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.dim = dim
        self.latents = nn.Parameter(torch.randn(num_latents, dim) / math.sqrt(dim))
        self.layers = nn.ModuleList([
            PerceiverLayer(dim, num_heads, mlp_ratio) for _ in range(num_layers)
        ])

    def forward(self, tokens, token_mask=None):
        B = tokens.shape[0]
        latents = self.latents.unsqueeze(0).expand(B, -1, -1)
        for layer in self.layers:
            latents = layer(latents, tokens, token_mask)
        return latents


In [10]:

# ============================================================
# 3. The Aligned Model Wrapper
# ============================================================

class AlignedModel(nn.Module):
    def __init__(self, vision_adapter, audio_adapter, perceiver, projector, qwen_model, qwen_tokenizer):
        super().__init__()
        self.vision_adapter = vision_adapter
        self.audio_adapter = audio_adapter
        self.perceiver = perceiver
        self.projector = projector
        self.qwen_model = qwen_model
        self.qwen_tokenizer = qwen_tokenizer

    def encode_image_features(self, features, mask):
        # Move inputs to the same device as the adapter
        device = next(self.vision_adapter.parameters()).device
        features = features.to(device)
        mask = mask.to(device)
        
        tokens = self.vision_adapter(features)
        latents = self.perceiver(tokens, mask)
        z_llm = self.projector(latents)
        return z_llm.mean(dim=1) # Pooling for retrieval

    def encode_audio_features(self, features, mask):
        device = next(self.audio_adapter.parameters()).device
        features = features.to(device)
        mask = mask.to(device)
        
        tokens = self.audio_adapter(features)
        latents = self.perceiver(tokens, mask)
        z_llm = self.projector(latents)
        return z_llm.mean(dim=1)

    def encode_text_raw(self, texts: list[str]):
        # Find Qwen's device (it might be spread across GPUs, get input embedding device)
        qwen_device = self.qwen_model.model.embed_tokens.weight.device
        
        enc = self.qwen_tokenizer(
            texts, padding=True, truncation=True, max_length=64, return_tensors="pt"
        ).to(qwen_device)
        
        with torch.no_grad():
            token_embs = self.qwen_model.get_input_embeddings()(enc.input_ids)
            
        mask = enc.attention_mask.unsqueeze(-1)
        sum_embs = (token_embs * mask).sum(dim=1)
        count = mask.sum(dim=1).clamp(min=1)
        return sum_embs / count

    @torch.no_grad()
    def generate(self, modality_feats, modality_type, prompt_text, max_new_tokens=50):
        """
        Generates text from an image or audio feature input.
        """
        # 1. Project Modality
        if modality_type == "vision":
            adapter = self.vision_adapter
        elif modality_type == "audio":
            adapter = self.audio_adapter
            
        device = next(adapter.parameters()).device
        modality_feats = modality_feats.to(device)
        
        # Create mask (B=1)
        mask = torch.ones(1, modality_feats.shape[1], dtype=torch.bool, device=device)
        
        # Forward pass through adapters
        tokens = adapter(modality_feats)
        latents = self.perceiver(tokens, mask)
        inputs_embeds_modality = self.projector(latents) # (1, 64, 3584)

        # 2. Embed Text
        qwen_device = self.qwen_model.model.embed_tokens.weight.device
        
        # Move modality embeds to Qwen device for concatenation
        inputs_embeds_modality = inputs_embeds_modality.to(qwen_device)
        
        text_inputs = self.qwen_tokenizer(prompt_text, return_tensors="pt").to(qwen_device)
        inputs_embeds_text = self.qwen_model.get_input_embeddings()(text_inputs.input_ids)

        # 3. Concatenate
        final_embeds = torch.cat([inputs_embeds_modality, inputs_embeds_text], dim=1)
        
        # 4. Generate
        out = self.qwen_model.generate(
            inputs_embeds=final_embeds,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            pad_token_id=self.qwen_tokenizer.eos_token_id
        )
        
        return self.qwen_tokenizer.batch_decode(out, skip_special_tokens=True)[0]


In [11]:

# ============================================================
# 4. Data Loading Logic (Placeholder / Re-used from your nb)
# ============================================================

# Assuming 'vision_loader' and 'audio_loader' exist from your previous cells.
# If running this as a standalone script, you would need to recreate the datasets here.
# For the purpose of this "Eval Only" block, we will assume variables exist.


### Phase-3: - Load the Dataset

#### Load the Audio Dataset

In [12]:
# ============================================
# Part 3 – LibriSpeech (Streaming) Audio–Text Dataset
# ============================================

from datasets import load_dataset
import io
import librosa
import numpy as np

print("\nLoading LibriSpeech ASR (streaming mode)...")

# Load only train.clean.100 from the giant 124GB dataset
librispeech_raw = load_dataset(
    "openslr/librispeech_asr",
    "all",
    streaming=True,
    split="train.clean.100"
)

print("Loaded streaming dataset:", librispeech_raw)

# Disable automatic decoding → we want raw bytes for librosa
audio_stream = librispeech_raw.decode(False)

# We will collect up to cfg.librispeech_max_samples
max_samples = cfg.librispeech_max_samples  # rename in your config if needed
subset = []

print(f"\nTaking up to {max_samples} examples in streaming mode...")

for ex in audio_stream:
    subset.append(ex)
    if len(subset) >= max_samples:
        break

print("\nSubset collected:", len(subset))
print("Keys:", subset[0].keys())
print("Example 0:", subset[0])



Loading LibriSpeech ASR (streaming mode)...


Resolving data files:   0%|          | 0/48 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/48 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/64 [00:00<?, ?it/s]

Loaded streaming dataset: IterableDataset({
    features: ['file', 'audio', 'text', 'speaker_id', 'chapter_id', 'id'],
    num_shards: 14
})


AttributeError: 'Config' object has no attribute 'librispeech_max_samples'

In [None]:
# Helper: convert LibriSpeech streaming example → waveform
def load_waveform_from_streaming_example(example, target_sr=16000):
    audio_info = example["audio"]

    audio_bytes = audio_info["bytes"]
    if audio_bytes is None:
        raise ValueError("No audio bytes in example.")

    # Convert raw bytes → file-like object
    audio_file = io.BytesIO(audio_bytes)

    # librosa loads PCM data and resamples to target_sr
    wav, sr = librosa.load(audio_file, sr=target_sr)

    return wav, sr


# Helper: compute duration in seconds
def compute_duration(wav, sr):
    return len(wav) / float(sr)


In [None]:
# We'll filter to keep only clips <= cfg.max_audio_duration_s
filtered = []

print("\nFiltering by duration ≤", cfg.max_audio_duration_s, "seconds...")

for ex in subset:
    wav, sr = load_waveform_from_streaming_example(ex, cfg.audio_sample_rate)
    dur = compute_duration(wav, sr)

    if dur <= cfg.max_audio_duration_s:
        filtered.append({
            "waveform": wav,
            "sampling_rate": sr,
            "duration": dur,
            "text": ex["text"]
        })

print("After duration filtering:", len(filtered), "examples")



Filtering by duration ≤ 12.0 seconds...
After duration filtering: 726 examples


In [None]:
print("\nShowing a few filtered samples...")

for i in range(min(5, len(filtered))):
    ex = filtered[i]
    print(f"\nSample {i}:")
    print("  Duration:", round(ex["duration"], 2), "s")
    print("  Transcript:", ex["text"])
    print("  Waveform shape:", ex["waveform"].shape)



Showing a few filtered samples...

Sample 0:
  Duration: 11.12 s
  Transcript: ASSUMED ALL AT ONCE AN APPEARANCE OF NOISE AND DISORDER NEVER BELIEVE HOWEVER DISINTERESTED THE LOVE OF A KEPT WOMAN MAY BE THAT IT WILL COST ONE NOTHING
  Waveform shape: (178000,)

Sample 1:
  Duration: 8.3 s
  Transcript: WHOSE ONLY DEFECT IS THAT THEY HAVE NOT TWO HUNDRED THOUSAND FRANCS A YEAR I NEED NOT TELL YOU OF THOSE WHO CHEAT AT PLAY
  Waveform shape: (132880,)

Sample 2:
  Duration: 2.96 s
  Transcript: IT WAS IMPOSSIBLE TO RESIST AN EXISTENCE
  Waveform shape: (47360,)

Sample 3:
  Duration: 8.61 s
  Transcript: WHEN AN ADROIT GAMBLER WOULD HAVE LEFT IT SETTLING ONE THING AGAINST ANOTHER I FOUND MYSELF IN POSSESSION OF SOME TEN THOUSAND FRANCS
  Waveform shape: (137680,)

Sample 4:
  Duration: 9.14 s
  Transcript: MARGUERITE WAS AWAKENED BY THE SUNLIGHT POURING INTO HER ROOM AND JUMPING OUT OF BED ASKED ME IF I WOULD TAKE HER INTO THE COUNTRY FOR THE WHOLE DAY
  Waveform shape: (146320,)


In [None]:

# ============================================
# New PixmoVisionDataset (uses HF 'image' column if available)
# ============================================

from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import requests
from io import BytesIO
import random

print("\nLoading PixMo-Cap vision–text dataset (allenai/pixmo-cap)...")

pixmo_raw = load_dataset("allenai/pixmo-cap", split="train")
print("PixMo-Cap split size:", len(pixmo_raw))
print("PixMo columns:", pixmo_raw.column_names)

# We only need a small subset for the POC
vision_max = getattr(cfg, "vision_max_samples", 2048)
if len(pixmo_raw) > vision_max:
    pixmo_subset = pixmo_raw.shuffle(seed=cfg.seed).select(range(vision_max))
else:
    pixmo_subset = pixmo_raw

print("PixMo subset size:", len(pixmo_subset))

# Fields from the dataset card:
#  - "image_url": URL to the image
#  - "caption": long caption text
img_col = "image_url"
txt_col = "caption"

cols = pixmo_raw.column_names
HAS_IMAGE_COL = "image" in cols

if HAS_IMAGE_COL:
    img_col = "image"
else:
    img_col = "image_url"

txt_col = "caption"

print(f"Using image column: {img_col}")




Loading PixMo-Cap vision–text dataset (allenai/pixmo-cap)...
PixMo-Cap split size: 717042
PixMo columns: ['image_url', 'caption', 'transcripts']
PixMo subset size: 2048
Using image column: image_url


In [None]:

class PixmoVisionDataset(Dataset):
    """
    On-the-fly image loading + CLIP feature extraction.

    If 'image' column exists: uses HF-managed images (no manual HTTP).
    Else: falls back to 'image_url' with robust skipping of bad URLs.

    Returns:
        {
          "features": Tensor(T, d_vision),
          "text": str
        }
    """
    def __init__(self, hf_dataset, vision_model, vision_processor, max_retries: int = 5):
        self.ds = hf_dataset
        self.vision_model = vision_model
        self.vision_processor = vision_processor
        self.max_retries = max_retries

    def __len__(self):
        return len(self.ds)

    def _load_image_from_url(self, url: str) -> Image.Image:
        resp = requests.get(url, timeout=10)
        # do NOT let this propagate; we'll catch in __getitem__
        resp.raise_for_status()
        img = Image.open(BytesIO(resp.content)).convert("RGB")
        return img

    def _encode_image(self, img: Image.Image):
        proc = self.vision_processor(images=img, return_tensors="pt")
        pixel_values = proc["pixel_values"].to(device)

        with torch.no_grad():
            out = self.vision_model(pixel_values=pixel_values)
            # (1, T, d_vision)
            feats = out.last_hidden_state.squeeze(0).to("cpu")  # (T, d_vision)
        return feats

    def _get_example(self, idx: int):
        ex = self.ds[idx]
        caption = ex[txt_col]

        if HAS_IMAGE_COL:
            # HF has already downloaded/cached images; this is usually a PIL.Image
            img = ex[img_col]
            if not isinstance(img, Image.Image):
                img = img.convert("RGB")
        else:
            url = ex[img_col]
            img = self._load_image_from_url(url)

        feats = self._encode_image(img)
        return {
            "features": feats,
            "text": caption,
        }

    def __getitem__(self, idx: int):
        """
        Try up to max_retries times with different indices if something fails
        (HTTP error, decoding error, etc).
        """
        n = len(self.ds)
        attempt = 0
        cur_idx = idx

        while attempt < self.max_retries:
            try:
                return self._get_example(cur_idx)
            except Exception as e:
                # print(f"[PixmoVisionDataset] Failed idx={cur_idx}, attempt={attempt+1}, err={e}")
                attempt += 1
                cur_idx = (cur_idx + 1) % n

        # Final fallback: try random indices
        for _ in range(self.max_retries):
            j = random.randint(0, n - 1)
            try:
                return self._get_example(j)
            except Exception:
                continue

        raise RuntimeError("PixmoVisionDataset: could not load any valid images after multiple retries.")


### Part-4:- 

In [None]:
# ============================================
# Part 4 – Audio features dataset (LibriSpeech + Whisper)
# ============================================

from torch.utils.data import Dataset

# We assume:
#  - `filtered` has been built in Part 3 (streaming LibriSpeech)
#  - Each entry: {"waveform": np.ndarray, "sampling_rate": int, "duration": float, "text": str}
print("\nBuilding LibriSpeech audio–text dataset from filtered streaming subset...")
print("Filtered LibriSpeech examples:", len(filtered))


def whisper_encode_sequence(wav: np.ndarray, sr: int):
    """
    wav: 1D numpy array (time,)
    sr:  sampling rate (expected 16k)
    Returns:
        feats: Tensor(T_enc, d_audio) on CPU (float16)
    """
    # WhisperProcessor: raw waveform -> log-Mel spectrogram features
    inputs = audio_processor(
        wav,
        sampling_rate=sr,
        return_tensors="pt",
    )
    input_features = inputs["input_features"].to(device)  # (1, T_mel, 80)

    with torch.no_grad():
        enc_out = audio_model.encoder(input_features)
        hidden = enc_out.last_hidden_state  # (1, T_enc, d_audio)

    feats = hidden.squeeze(0).to(torch.float16).cpu()  # (T_enc, d_audio)
    return feats


class LibriSpeechAudioDataset(Dataset):
    """
    Dataset over the in-memory filtered LibriSpeech examples.
    Returns:
        {
          "features": Tensor(T_enc, d_audio),
          "text": str,
          "duration": float
        }
    """
    def __init__(self, examples, max_len: int | None = None):
        self.examples = examples
        if max_len is not None and max_len < len(examples):
            # Optionally cut down further for faster experiments
            self.examples = examples[:max_len]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx: int):
        ex = self.examples[idx]
        wav = ex["waveform"]
        sr = ex["sampling_rate"]
        text = ex["text"]
        dur = ex["duration"]

        feats = whisper_encode_sequence(wav, sr)  # (T_enc, d_audio)

        return {
            "features": feats,
            "text": text,
            "duration": dur,
        }


audio_max = getattr(cfg, "librispeech_max_samples", len(filtered))
audio_dataset = LibriSpeechAudioDataset(filtered, max_len=audio_max)

print("Audio dataset ready. Example:")
sample_a = audio_dataset[0]
print("  features shape:", sample_a["features"].shape)
print("  duration:", round(sample_a["duration"], 2), "s")
print("  text:", sample_a["text"])



Building LibriSpeech audio–text dataset from filtered streaming subset...
Filtered LibriSpeech examples: 726
Audio dataset ready. Example:
  features shape: torch.Size([1500, 512])
  duration: 11.12 s
  text: ASSUMED ALL AT ONCE AN APPEARANCE OF NOISE AND DISORDER NEVER BELIEVE HOWEVER DISINTERESTED THE LOVE OF A KEPT WOMAN MAY BE THAT IT WILL COST ONE NOTHING


In [None]:

# ============================================================
# 5. Load Model Function (THE FIX IS HERE)
# ============================================================

def load_full_model(cfg_obj):
    print("Loading Frozen Qwen...")
    tokenizer = AutoTokenizer.from_pretrained(cfg_obj.llm_model_name, use_fast=True)
    if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
        
    # Load Qwen with device_map="auto"
    qwen = AutoModelForCausalLM.from_pretrained(
        cfg_obj.llm_model_name, 
        torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
        device_map="auto"
    ).eval()

    print("Initializing Adapters...")
    v_adapt = ModalityAdapter(cfg_obj.encoder_dim_vision, cfg_obj.perceiver_dim)
    a_adapt = ModalityAdapter(cfg_obj.encoder_dim_audio, cfg_obj.perceiver_dim)
    
    perc = PerceiverResampler(
        dim=cfg_obj.perceiver_dim, 
        num_latents=cfg_obj.num_latents,
        num_layers=cfg_obj.num_perceiver_layers,
        num_heads=cfg_obj.num_attn_heads,
        mlp_ratio=cfg_obj.mlp_ratio
    )
    
    proj = nn.Linear(cfg_obj.perceiver_dim, cfg_obj.llm_hidden_size)

    # Load Weights
    print(f"Loading weights from {cfg_obj.ckpt_path}...")
    try:
        ckpt = torch.load(cfg_obj.ckpt_path, map_location="cpu")
        v_adapt.load_state_dict(ckpt["vision_adapter"])
        a_adapt.load_state_dict(ckpt["audio_adapter"])
        perc.load_state_dict(ckpt["perceiver"])
        proj.load_state_dict(ckpt["projector"])
        print("✅ Weights loaded successfully.")
    except FileNotFoundError:
        print("❌ Checkpoint not found! Using random init for testing.")

    # MOVE ADAPTERS TO DEVICE (Manual)
    # We assume standard single-GPU inference for the adapters for simplicity
    # Qwen handles its own split via device_map="auto"
    adapter_device = "cuda" if torch.cuda.is_available() else "cpu"
    v_adapt.to(adapter_device).eval()
    a_adapt.to(adapter_device).eval()
    perc.to(adapter_device).eval()
    proj.to(adapter_device).eval()

    # Create Wrapper (Do NOT call .to() on this wrapper!)
    model = AlignedModel(v_adapt, a_adapt, perc, proj, qwen, tokenizer)
    return model


In [None]:

# ============================================================
# 6. Evaluation Functions
# ============================================================

@torch.no_grad()
def evaluate_retrieval(model, dataloader, modality="vision", num_batches=10):
    model.eval()
    all_mod_embs = []
    all_txt_embs = []
    
    print(f"\nStarting {modality.upper()} retrieval eval...")
    adapter_device = next(model.vision_adapter.parameters()).device

    for i, batch in enumerate(tqdm(dataloader, total=num_batches)):
        if i >= num_batches: break
        
        feats = batch["encoder_feats"].to(adapter_device)
        mask  = batch["encoder_mask"].to(adapter_device)
        texts = batch["texts"]
        
        if modality == "vision":
            z_mod = model.encode_image_features(feats, mask)
        else:
            z_mod = model.encode_audio_features(feats, mask)
            
        z_txt = model.encode_text_raw(texts)
        
        all_mod_embs.append(z_mod.cpu().float()) # Ensure float32 for matmul
        all_txt_embs.append(z_txt.cpu().float())

    z_mod_all = torch.cat(all_mod_embs, dim=0)
    z_txt_all = torch.cat(all_txt_embs, dim=0)
    
    z_mod_all = F.normalize(z_mod_all, dim=-1)
    z_txt_all = F.normalize(z_txt_all, dim=-1)
    
    sim_matrix = z_mod_all @ z_txt_all.T
    n = sim_matrix.shape[0]
    targets = torch.arange(n)
    
    preds = sim_matrix.argsort(dim=1, descending=True)
    r1 = (preds[:, 0] == targets).float().mean().item()
    r5 = (preds[:, :5] == targets.unsqueeze(1)).any(dim=1).float().mean().item()
    
    print(f"[{modality.upper()}] R@1: {r1:.4f} | R@5: {r5:.4f}")

def visualize_inference(model, dataset, raw_dataset, idx=None):
    if idx is None: idx = random.randint(0, len(dataset) - 1)
    print(f"\n--- Qualitative Inference Sample {idx} ---")
    
    # 1. Get Image
    raw_sample = raw_dataset[idx]
    img = None
    if "image" in raw_sample and raw_sample["image"]:
        img = raw_sample["image"]
    elif "image_url" in raw_sample:
        try:
            resp = requests.get(raw_sample["image_url"], timeout=2)
            img = Image.open(BytesIO(resp.content)).convert("RGB")
        except: pass
            
    # 2. Get Features
    model_input = dataset[idx]
    feats = model_input["features"].unsqueeze(0) # (1, T, D)
    
    # 3. Generate
    prompt = "User: Describe this image.\nAssistant:"
    gen_text = model.generate(feats, "vision", prompt)
    
    # 4. Show
    if img:
        plt.figure(figsize=(5,5))
        plt.imshow(img)
        plt.axis("off")
        plt.title(f"Pred: {gen_text[:50]}...", fontsize=10)
        plt.show()
    
    print(f"GT:   {model_input['text']}")
    print(f"Pred: {gen_text}")


In [None]:

# ============================================================
# 7. Main Execution
# ============================================================

# 1. Load the system
eval_model = load_full_model(cfg)

# 2. Run Qualitative Check (Vision)
# Requires 'vision_dataset' and 'pixmo_subset' from previous cells
if 'vision_dataset' in globals():
    visualize_inference(eval_model, vision_dataset, pixmo_subset)

# 3. Run Metrics
# Requires 'vision_loader' and 'audio_loader'
if 'vision_loader' in globals():
    evaluate_retrieval(eval_model, vision_loader, modality="vision", num_batches=20)
if 'audio_loader' in globals():
    evaluate_retrieval(eval_model, audio_loader, modality="audio", num_batches=20)