In [1]:
import os
import math
import random
import requests
from io import BytesIO
from dataclasses import dataclass
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence

import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from PIL import Image

# --- Imports for models ---
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    CLIPVisionModel, 
    CLIPImageProcessor, 
    WhisperModel, 
    WhisperProcessor
)
from datasets import load_dataset
import librosa



In [2]:


# Set device for adapters (Qwen manages its own device map)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using adapter device: {device}")

# ============================================================
# 1. Configuration
# ============================================================
@dataclass
class Config:
    # Checkpoint path (Ensure this matches where you saved it!)
    ckpt_path: Path = Path("./runs_perceiver_mrl_qwen/multimodal_adapter_poc.pt")
    
    # Models
    vision_model_name: str = "openai/clip-vit-base-patch32"
    audio_model_name: str = "openai/whisper-base"
    llm_model_name: str = "Qwen/Qwen2.5-7B-Instruct"
    
    # Architecture Dims (Must match training!)
    perceiver_dim: int = 512
    num_latents: int = 64
    num_perceiver_layers: int = 2
    num_attn_heads: int = 8
    mlp_ratio: float = 4.0
    
    # Filled dynamically later
    encoder_dim_vision: int = 768 
    encoder_dim_audio: int = 512 
    llm_hidden_size: int = 3584
    
    # Data for Eval
    batch_size: int = 8
    vision_max_samples: int = 100  # Small subset for fast eval
    audio_max_samples: int = 100
    seed: int = 42

cfg = Config()


Using adapter device: cuda


In [3]:

# ============================================================
# 2. Architecture Definitions (Must match training code)
# ============================================================

class ModalityAdapter(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        return self.proj(x)

class FeedForward(nn.Module):
    def __init__(self, dim: int, mlp_ratio: float = 4.0):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
    def forward(self, x): 
        return self.net(x)

class PerceiverLayer(nn.Module):
    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.ln_latents_1 = nn.LayerNorm(dim)
        self.ln_tokens = nn.LayerNorm(dim)
        self.ln_latents_2 = nn.LayerNorm(dim)
        self.ln_latents_3 = nn.LayerNorm(dim)
        self.mlp = FeedForward(dim, mlp_ratio)

    def forward(self, latents, tokens, token_mask=None):
        # Cross-Attention
        q = self.ln_latents_1(latents)
        kv = self.ln_tokens(tokens)
        # Create key_padding_mask (True for ignored positions)
        key_padding_mask = ~token_mask.bool() if token_mask is not None else None
        
        attn_out, _ = self.cross_attn(q, kv, kv, key_padding_mask=key_padding_mask, need_weights=False)
        latents = latents + attn_out
        
        # Self-Attention
        q2 = self.ln_latents_2(latents)
        self_out, _ = self.self_attn(q2, q2, q2, need_weights=False)
        latents = latents + self_out
        
        # MLP
        latents = latents + self.mlp(self.ln_latents_3(latents))
        return latents

class PerceiverResampler(nn.Module):
    def __init__(self, dim, num_latents, num_layers, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.dim = dim
        self.latents = nn.Parameter(torch.randn(num_latents, dim) / math.sqrt(dim))
        self.layers = nn.ModuleList([
            PerceiverLayer(dim, num_heads, mlp_ratio) for _ in range(num_layers)
        ])

    def forward(self, tokens, token_mask=None):
        B = tokens.shape[0]
        latents = self.latents.unsqueeze(0).expand(B, -1, -1)
        for layer in self.layers:
            latents = layer(latents, tokens, token_mask)
        return latents


In [4]:

# ============================================================
# 3. The Aligned Model Wrapper (Fixed for Device Map)
# ============================================================

class AlignedModel(nn.Module):
    def __init__(self, vision_adapter, audio_adapter, perceiver, projector, qwen_model, qwen_tokenizer):
        super().__init__()
        self.vision_adapter = vision_adapter
        self.audio_adapter = audio_adapter
        self.perceiver = perceiver
        self.projector = projector
        self.qwen_model = qwen_model
        self.qwen_tokenizer = qwen_tokenizer

    def encode_image_features(self, features, mask):
        # Move inputs to the same device as the adapter (which is on 'device')
        adapter_dev = next(self.vision_adapter.parameters()).device
        features = features.to(adapter_dev)
        mask = mask.to(adapter_dev)
        
        tokens = self.vision_adapter(features)
        latents = self.perceiver(tokens, mask)
        z_llm = self.projector(latents)
        return z_llm.mean(dim=1) # Pooling for retrieval

    def encode_audio_features(self, features, mask):
        adapter_dev = next(self.audio_adapter.parameters()).device
        features = features.to(adapter_dev)
        mask = mask.to(adapter_dev)
        
        tokens = self.audio_adapter(features)
        latents = self.perceiver(tokens, mask)
        z_llm = self.projector(latents)
        return z_llm.mean(dim=1)

    def encode_text_raw(self, texts: list[str]):
        # Find Qwen's device for embeddings
        qwen_dev = self.qwen_model.model.embed_tokens.weight.device
        
        enc = self.qwen_tokenizer(
            texts, padding=True, truncation=True, max_length=64, return_tensors="pt"
        ).to(qwen_dev)
        
        with torch.no_grad():
            token_embs = self.qwen_model.get_input_embeddings()(enc.input_ids)
            
        mask = enc.attention_mask.unsqueeze(-1)
        sum_embs = (token_embs * mask).sum(dim=1)
        count = mask.sum(dim=1).clamp(min=1)
        return sum_embs / count

    @torch.no_grad()
    def generate(self, modality_feats, modality_type, prompt_text, max_new_tokens=50):
        """
        Generates text from an image or audio feature input.
        """
        # 1. Project Modality (on adapter device)
        if modality_type == "vision":
            adapter = self.vision_adapter
        elif modality_type == "audio":
            adapter = self.audio_adapter
            
        adapter_dev = next(adapter.parameters()).device
        modality_feats = modality_feats.to(adapter_dev)
        
        # Create mask (B=1)
        mask = torch.ones(1, modality_feats.shape[1], dtype=torch.bool, device=adapter_dev)
        
        tokens = adapter(modality_feats)
        latents = self.perceiver(tokens, mask)
        inputs_embeds_modality = self.projector(latents) # (1, 64, 3584)

        # 2. Embed Text (on Qwen device)
        qwen_dev = self.qwen_model.model.embed_tokens.weight.device
        
        # Move modality embeds to Qwen device for concatenation
        inputs_embeds_modality = inputs_embeds_modality.to(qwen_dev)
        
        text_inputs = self.qwen_tokenizer(prompt_text, return_tensors="pt").to(qwen_dev)
        inputs_embeds_text = self.qwen_model.get_input_embeddings()(text_inputs.input_ids)

        # 3. Concatenate
        final_embeds = torch.cat([inputs_embeds_modality, inputs_embeds_text], dim=1)
        
        # 4. Generate
        out = self.qwen_model.generate(
            inputs_embeds=final_embeds,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            pad_token_id=self.qwen_tokenizer.eos_token_id
        )
        
        return self.qwen_tokenizer.batch_decode(out, skip_special_tokens=True)[0]


In [9]:
# ============================================================
# 4. Data Loading (Fixed: Manual Audio Decoding)
# ============================================================

import io
import librosa
from datasets import load_dataset

def build_eval_dataloaders(cfg):
    """
    Rebuilds minimal datasets for evaluation (PixMo Vision, LibriSpeech Audio).
    
    FIX APPLIED:
    - For Audio, we use .decode(False) to get raw bytes.
    - We manually load bytes via io.BytesIO and librosa to avoid 
      internal Hugging Face decoding errors (torchcodec/ffmpeg issues).
    """
    print("\n--- Building Eval Datasets ---")
    
    # --- 1. Load Encoders (CLIP + Whisper) ---
    print("Loading Encoders for Feature Extraction...")
    # Vision
    clip_proc = CLIPImageProcessor.from_pretrained(cfg.vision_model_name)
    clip_model = CLIPVisionModel.from_pretrained(cfg.vision_model_name).to(device).eval()
    
    # Audio
    whisper_proc = WhisperProcessor.from_pretrained(cfg.audio_model_name)
    whisper_model = WhisperModel.from_pretrained(cfg.audio_model_name).to(device).eval()
    
    # --- 2. Vision Dataset (PixMo) ---
    print("Loading PixMo subset...")
    # Note: PixMo handles images via URLs or PIL objects differently, usually strictly PIL or URLs.
    # We wrap in try/except to be robust.
    pixmo_ds = load_dataset("allenai/pixmo-cap", split="train", streaming=True)
    
    vision_data = []
    count = 0
    for ex in pixmo_ds:
        if count >= cfg.vision_max_samples: break
        
        # Handle diverse column names if needed
        url = ex.get("image_url")
        text = ex.get("caption")
        
        try:
            # Download image from URL
            resp = requests.get(url, timeout=2)
            img = Image.open(BytesIO(resp.content)).convert("RGB")
            
            # Extract Features
            inputs = clip_proc(images=img, return_tensors="pt").to(device)
            with torch.no_grad():
                out = clip_model(**inputs)
                feats = out.last_hidden_state.squeeze(0).cpu() # (T, D) to CPU
            
            vision_data.append({
                "features": feats, 
                "text": text,
                "image_obj": img  # Keep for visualization
            })
            count += 1
        except Exception as e:
            # Skip broken URLs
            continue
            
    print(f"Collected {len(vision_data)} Vision samples.")

    # --- 3. Audio Dataset (LibriSpeech) - MANUAL DECODE FIX ---
    print("Loading LibriSpeech subset (Manual Decoding)...")
    
    # 1. Load streaming with the specific config and split you requested
    librispeech = load_dataset(
        "openslr/librispeech_asr",
        "all",
        streaming=True,
        split="train.clean.100"
    )
    
    # 2. CRITICAL: Turn off automatic decoding to get raw bytes
    audio_stream = librispeech.decode(False)
    
    audio_data = []
    count = 0
    
    # 3. Iterate
    for ex in audio_stream:
        if count >= cfg.audio_max_samples: break
        
        try:
            # Extract raw bytes
            audio_bytes = ex["audio"]["bytes"]
            text = ex["text"]
            
            if audio_bytes is None:
                continue

            # 4. Decode bytes manually using librosa
            # librosa.load accepts a file-like object (BytesIO)
            # It automatically resamples to sr=16000 if specified
            wav, sr = librosa.load(io.BytesIO(audio_bytes), sr=16000)
            
            # 5. Extract Features with Whisper
            inputs = whisper_proc(wav, sampling_rate=16000, return_tensors="pt")
            input_features = inputs.input_features.to(device)
            
            with torch.no_grad():
                enc_out = whisper_model.encoder(input_features)
                feats = enc_out.last_hidden_state.squeeze(0).cpu() # (T, D)
            
            audio_data.append({
                "features": feats,
                "text": text,
                # We don't store raw audio array to save RAM, just features
            })
            count += 1
            
        except Exception as e:
            print(f"Skipping bad audio sample: {e}")
            continue
        
    print(f"Collected {len(audio_data)} Audio samples.")
    
    # --- 4. Collate Function ---
    def collate_fn(batch):
        feats = [ex["features"] for ex in batch]
        texts = [ex["text"] for ex in batch]
        
        # Pad features
        feats_padded = pad_sequence(feats, batch_first=True) # (B, T_max, D)
        
        # Mask
        B, T_max, _ = feats_padded.shape
        mask = torch.zeros(B, T_max, dtype=torch.bool)
        for i, f in enumerate(feats):
            mask[i, :f.shape[0]] = True
            
        return {
            "encoder_feats": feats_padded,
            "encoder_mask": mask,
            "texts": texts
        }

    # --- 5. DataLoaders ---
    # Drop last to avoid issues with batch norm layers if batch size=1
    vision_loader = DataLoader(vision_data, batch_size=cfg.batch_size, collate_fn=collate_fn, drop_last=False)
    audio_loader = DataLoader(audio_data, batch_size=cfg.batch_size, collate_fn=collate_fn, drop_last=False)
    
    return vision_loader, audio_loader, vision_data # Return raw vision data for viz

In [10]:

# ============================================================
# 5. Load Model Function
# ============================================================

def load_full_model(cfg_obj):
    print("\n--- Loading Full Inference System ---")
    print("1. Loading Frozen Qwen (Auto Device Map)...")
    tokenizer = AutoTokenizer.from_pretrained(cfg_obj.llm_model_name, use_fast=True)
    if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
        
    # Load Qwen with device_map="auto" -> Allows splitting/offloading
    qwen = AutoModelForCausalLM.from_pretrained(
        cfg_obj.llm_model_name, 
        torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
        device_map="auto",
        trust_remote_code=True
    ).eval()

    print("2. Initializing Adapters...")
    v_adapt = ModalityAdapter(cfg_obj.encoder_dim_vision, cfg_obj.perceiver_dim)
    a_adapt = ModalityAdapter(cfg_obj.encoder_dim_audio, cfg_obj.perceiver_dim)
    
    perc = PerceiverResampler(
        dim=cfg_obj.perceiver_dim, 
        num_latents=cfg_obj.num_latents,
        num_layers=cfg_obj.num_perceiver_layers,
        num_heads=cfg_obj.num_attn_heads,
        mlp_ratio=cfg_obj.mlp_ratio
    )
    
    proj = nn.Linear(cfg_obj.perceiver_dim, cfg_obj.llm_hidden_size)

    # Load Weights
    print(f"3. Loading adapter weights from {cfg_obj.ckpt_path}...")
    try:
        ckpt = torch.load(cfg_obj.ckpt_path, map_location="cpu")
        v_adapt.load_state_dict(ckpt["vision_adapter"])
        a_adapt.load_state_dict(ckpt["audio_adapter"])
        perc.load_state_dict(ckpt["perceiver"])
        proj.load_state_dict(ckpt["projector"])
        print("✅ Weights loaded successfully.")
    except FileNotFoundError:
        print("❌ Checkpoint not found! Using random init (expect garbage results).")

    # MOVE ADAPTERS TO DEVICE (Manual)
    # Qwen handles its own split via device_map="auto"
    # We assume adapters fit on the main execution device (cuda:0 or cpu)
    adapter_device = "cuda" if torch.cuda.is_available() else "cpu"
    v_adapt.to(adapter_device).eval()
    a_adapt.to(adapter_device).eval()
    perc.to(adapter_device).eval()
    proj.to(adapter_device).eval()

    # Create Wrapper (Do NOT call .to() on this wrapper!)
    model = AlignedModel(v_adapt, a_adapt, perc, proj, qwen, tokenizer)
    return model


In [13]:

# ============================================================
# 6. Evaluation Functions
# ============================================================

@torch.no_grad()
def evaluate_retrieval(model, dataloader, modality="vision", num_batches=10):
    model.eval()
    all_mod_embs = []
    all_txt_embs = []
    
    print(f"\nStarting {modality.upper()} retrieval eval...")
    # Get device where adapters live
    adapter_device = next(model.vision_adapter.parameters()).device

    for i, batch in enumerate(tqdm(dataloader, total=min(num_batches, len(dataloader)))):
        if i >= num_batches: break
        
        feats = batch["encoder_feats"].to(adapter_device)
        mask  = batch["encoder_mask"].to(adapter_device)
        texts = batch["texts"]
        
        if modality == "vision":
            z_mod = model.encode_image_features(feats, mask)
        else:
            z_mod = model.encode_audio_features(feats, mask)
            
        z_txt = model.encode_text_raw(texts)
        
        all_mod_embs.append(z_mod.cpu().float()) # Ensure float32 for matmul
        all_txt_embs.append(z_txt.cpu().float())

    z_mod_all = torch.cat(all_mod_embs, dim=0)
    z_txt_all = torch.cat(all_txt_embs, dim=0)
    
    z_mod_all = F.normalize(z_mod_all, dim=-1)
    z_txt_all = F.normalize(z_txt_all, dim=-1)
    
    # Similarity: (N, N)
    sim_matrix = z_mod_all @ z_txt_all.T
    n = sim_matrix.shape[0]
    targets = torch.arange(n)
    
    # Modality -> Text Recall
    preds = sim_matrix.argsort(dim=1, descending=True)
    r1 = (preds[:, 0] == targets).float().mean().item()
    
    # Safe R@5 check (handle small batch sizes)
    k = min(5, n)
    r5 = (preds[:, :k] == targets.unsqueeze(1)).any(dim=1).float().mean().item()
    
    print(f"[{modality.upper()}] Results on {n} samples:")
    print(f"  R@1: {r1:.4f}")
    print(f"  R@5: {r5:.4f}")

def visualize_inference(model, raw_data_list, idx=None):
    """
    Picks an item from the raw dataset list (which contains img objects),
    runs the model, and shows the result.
    """
    if idx is None: idx = random.randint(0, len(raw_data_list) - 1)
    print(f"\n--- Qualitative Inference Sample {idx} ---")
    
    sample = raw_data_list[idx]
    
    # 1. Get Image and Text
    img = sample["image_obj"]
    text_gt = sample["text"]
    
    # 2. Get Features (already computed in dataset)
    # Add batch dim: (T, D) -> (1, T, D)
    feats = sample["features"].unsqueeze(0) 
    
    # 3. Generate
    prompt = "User: Describe this image.\nAssistant:"
    feats = feats.to(dtype=torch.float32)
    gen_text = model.generate(feats, "vision", prompt, max_new_tokens=60)
    
    # 4. Show
    plt.figure(figsize=(6,6))
    plt.imshow(img)
    plt.axis("off")
    plt.title(f"Pred: {gen_text[:60]}...", fontsize=10)
    plt.show()
    
    print(f"Ground Truth: {text_gt}")
    print(f"Prediction:   {gen_text}")


In [14]:

# ============================================================
# 7. Main Execution
# ============================================================

# 1. Build Datasets & Loaders (fresh extract)
vision_loader, audio_loader, raw_vision_data = build_eval_dataloaders(cfg)

# 2. Load the Aligned System
eval_model = load_full_model(cfg)

# 3. Run Metrics
evaluate_retrieval(eval_model, vision_loader, modality="vision", num_batches=20)
evaluate_retrieval(eval_model, audio_loader, modality="audio", num_batches=20)

# 4. Run Visual Check
if len(raw_vision_data) > 0:
    visualize_inference(eval_model, raw_vision_data)


--- Building Eval Datasets ---
Loading Encoders for Feature Extraction...
Loading PixMo subset...
Collected 100 Vision samples.
Loading LibriSpeech subset (Manual Decoding)...


Resolving data files:   0%|          | 0/48 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/48 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/64 [00:00<?, ?it/s]

Collected 100 Audio samples.

--- Loading Full Inference System ---
1. Loading Frozen Qwen (Auto Device Map)...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

2. Initializing Adapters...
3. Loading adapter weights from runs_perceiver_mrl_qwen/multimodal_adapter_poc.pt...
✅ Weights loaded successfully.

Starting VISION retrieval eval...


  0%|          | 0/13 [00:00<?, ?it/s]

[VISION] Results on 100 samples:
  R@1: 0.0800
  R@5: 0.2700

Starting AUDIO retrieval eval...


  0%|          | 0/13 [00:00<?, ?it/s]

[AUDIO] Results on 100 samples:
  R@1: 0.0200
  R@5: 0.0500

--- Qualitative Inference Sample 11 ---


RuntimeError: mat1 and mat2 must have the same dtype, but got Float and BFloat16