In [1]:
# =============================
# 0. Setup
# =============================
import torch
import numpy as np
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
from typing import List, Dict
import random
import time

import os
import math
import random
import requests
from io import BytesIO
from dataclasses import dataclass
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence

import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from PIL import Image

# --- Imports for models ---
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    CLIPVisionModel, 
    CLIPImageProcessor, 
    WhisperModel, 
    WhisperProcessor
)
from datasets import load_dataset
import librosa



In [5]:


# Device for the small trainable adapter stack. Qwen is placed separately by its
# own device_map when it is loaded. NOTE: load_full_model re-derives the adapter
# device with the same rule instead of reading this variable.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using adapter device: {device}")

# ============================================================
# 1. Configuration
# ============================================================
@dataclass
class Config:
    """Settings for loading the trained multimodal adapter and evaluating it.

    The architecture fields must match the values used when the checkpoint at
    ``ckpt_path`` was trained, or ``load_state_dict`` will reject the weights.
    """

    # Trained adapter/perceiver/projector checkpoint (must exist for meaningful results).
    ckpt_path: Path = Path("./runs_perceiver_mrl_qwen/multimodal_adapter_poc_1.pt")

    # Hugging Face model identifiers: frozen encoders and the frozen LLM.
    vision_model_name: str = "openai/clip-vit-base-patch32"
    audio_model_name: str = "openai/whisper-base"
    llm_model_name: str = "Qwen/Qwen2.5-7B-Instruct"

    # Perceiver resampler hyper-parameters (must equal the training run's).
    perceiver_dim: int = 512
    num_latents: int = 64
    num_perceiver_layers: int = 2
    num_attn_heads: int = 8
    mlp_ratio: float = 4.0

    # Input/output widths. These are hard-coded here rather than read from the
    # models, and are expected to equal the hidden sizes of the models named above.
    encoder_dim_vision: int = 768
    encoder_dim_audio: int = 512
    llm_hidden_size: int = 3584

    # Evaluation data settings (small subsets keep evaluation fast).
    batch_size: int = 8
    vision_max_samples: int = 100
    audio_max_samples: int = 100
    # NOTE(review): not applied to any RNG by the code in this section.
    seed: int = 42

cfg = Config()


Using adapter device: cuda


In [2]:

# ============================================================
# 2. Architecture Definitions (Must match training code)
# ============================================================

class ModalityAdapter(nn.Module):
    """Single linear projection from an encoder's feature width to the perceiver width.

    Applied position-wise to the last dimension of its input. The submodule name
    ``proj`` determines the checkpoint keys (``proj.weight`` / ``proj.bias``).
    """

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.proj = nn.Linear(in_features=in_dim, out_features=out_dim)

    def forward(self, x):
        projected = self.proj(x)
        return projected

class FeedForward(nn.Module):
    """Position-wise MLP: Linear(dim -> dim*mlp_ratio) -> GELU -> Linear(back to dim).

    The layers are kept in an ``nn.Sequential`` named ``net`` so the checkpoint
    keys stay ``net.0.*`` and ``net.2.*``.
    """

    def __init__(self, dim: int, mlp_ratio: float = 4.0):
        super().__init__()
        inner_width = int(dim * mlp_ratio)
        layers = [
            nn.Linear(dim, inner_width),
            nn.GELU(),
            nn.Linear(inner_width, dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class PerceiverLayer(nn.Module):
    """One perceiver block acting on a set of latent vectors.

    Three residual sub-layers, each fed a LayerNorm-ed input:
      1. cross-attention: latents (queries) attend to the input tokens (keys/values);
      2. self-attention among the latents;
      3. position-wise FeedForward MLP.
    Both attentions are batch-first ``nn.MultiheadAttention`` modules.
    """

    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4.0):
        super().__init__()
        # Attribute names (and creation order) define the checkpoint keys.
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.ln_latents_1 = nn.LayerNorm(dim)
        self.ln_tokens = nn.LayerNorm(dim)
        self.ln_latents_2 = nn.LayerNorm(dim)
        self.ln_latents_3 = nn.LayerNorm(dim)
        self.mlp = FeedForward(dim, mlp_ratio)

    def forward(self, latents, tokens, token_mask=None):
        # token_mask follows the "truthy = valid position" convention, while
        # nn.MultiheadAttention's key_padding_mask wants True = ignore, so invert it.
        if token_mask is None:
            ignore_positions = None
        else:
            ignore_positions = ~token_mask.bool()

        # 1) Latents read from the input tokens.
        normed_latents = self.ln_latents_1(latents)
        normed_tokens = self.ln_tokens(tokens)
        cross_out, _ = self.cross_attn(
            normed_latents,
            normed_tokens,
            normed_tokens,
            key_padding_mask=ignore_positions,
            need_weights=False,
        )
        latents = latents + cross_out

        # 2) Latents exchange information with each other.
        mixed_in = self.ln_latents_2(latents)
        mixed_out, _ = self.self_attn(mixed_in, mixed_in, mixed_in, need_weights=False)
        latents = latents + mixed_out

        # 3) Per-latent MLP with residual.
        return latents + self.mlp(self.ln_latents_3(latents))

class PerceiverResampler(nn.Module):
    """Compress a variable-length token sequence into ``num_latents`` learned vectors.

    A learned latent array is broadcast over the batch and refined by a stack of
    ``PerceiverLayer`` blocks that cross-attend to the input tokens. The output has
    one row per latent regardless of the input length.
    """

    def __init__(self, dim, num_latents, num_layers, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.dim = dim
        # Scaled random init; load_full_model overwrites it from the checkpoint.
        initial = torch.randn(num_latents, dim) / math.sqrt(dim)
        self.latents = nn.Parameter(initial)
        blocks = [PerceiverLayer(dim, num_heads, mlp_ratio) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, tokens, token_mask=None):
        batch_size = tokens.size(0)
        # Broadcast view of the shared latents: (num_latents, dim) -> (B, num_latents, dim).
        state = self.latents[None].expand(batch_size, -1, -1)
        for block in self.layers:
            state = block(state, tokens, token_mask)
        return state


In [3]:

# ============================================================
# 3. The Aligned Model Wrapper (Fixed for Device Map)
# ============================================================

class AlignedModel(nn.Module):
    """Inference wrapper joining the trained adapter stack to a frozen Qwen LLM.

    Encoder features (vision or audio) flow through the matching modality adapter,
    the shared ``perceiver`` and the ``projector``. This yields latent tokens in the
    LLM's embedding space. Those latents are either mean-pooled into one vector per
    sample (for retrieval) or prepended to a text prompt's embeddings (for generation).

    The adapter stack and the LLM may live on different devices (``load_full_model``
    loads Qwen with ``device_map="auto"``), so every method moves tensors explicitly.
    Do not call ``.to()`` on this wrapper: it registers the LLM as a submodule and
    would try to move it too.
    """

    def __init__(self, vision_adapter, audio_adapter, perceiver, projector, qwen_model, qwen_tokenizer):
        super().__init__()
        self.vision_adapter = vision_adapter
        self.audio_adapter = audio_adapter
        self.perceiver = perceiver
        self.projector = projector
        self.qwen_model = qwen_model
        self.qwen_tokenizer = qwen_tokenizer

    # ---- internal helpers --------------------------------------------------

    def _llm_embedding(self):
        """Return the LLM's token-embedding layer via the public HF accessor."""
        return self.qwen_model.get_input_embeddings()

    def _adapter_for(self, modality_type):
        """Map ``"vision"``/``"audio"`` to its adapter; reject anything else."""
        if modality_type == "vision":
            return self.vision_adapter
        if modality_type == "audio":
            return self.audio_adapter
        raise ValueError(
            f"modality_type must be 'vision' or 'audio', got {modality_type!r}"
        )

    def _project(self, adapter, features, mask):
        """Run adapter -> perceiver -> projector on the adapter's device.

        ``mask`` (truthy = valid position) may be None, in which case every
        feature position is attended to. Returns the projected latent tokens.
        """
        adapter_dev = next(adapter.parameters()).device
        features = features.to(adapter_dev)
        if mask is not None:
            mask = mask.to(adapter_dev)
        tokens = adapter(features)
        latents = self.perceiver(tokens, mask)
        return self.projector(latents)

    # ---- retrieval embeddings ----------------------------------------------

    def encode_image_features(self, features, mask=None):
        """Pooled LLM-space embedding for precomputed vision-encoder features.

        Args:
            features: vision features, presumably (B, T, encoder_dim_vision) — TODO confirm.
            mask: optional per-position validity mask (truthy = keep); None keeps all.

        Returns:
            The projected perceiver latents averaged over the latent axis (dim 1),
            on the adapter's device.
        """
        return self._project(self.vision_adapter, features, mask).mean(dim=1)

    def encode_audio_features(self, features, mask=None):
        """Pooled LLM-space embedding for precomputed audio-encoder features.

        Args:
            features: audio features, presumably (B, T, encoder_dim_audio) — TODO confirm.
            mask: optional per-position validity mask (truthy = keep); None keeps all.

        Returns:
            The projected perceiver latents averaged over the latent axis (dim 1),
            on the adapter's device.
        """
        return self._project(self.audio_adapter, features, mask).mean(dim=1)

    def encode_text_raw(self, texts: list[str]):
        """Mean of the LLM's raw input token embeddings for each text.

        Only the embedding table is used; no transformer layers run. Texts are
        padded together and truncated to 64 tokens, and padding positions are
        excluded from the mean via the attention mask.

        Returns:
            float32 tensor with one row per text, on the embedding layer's device.
        """
        embed = self._llm_embedding()
        qwen_dev = embed.weight.device

        enc = self.qwen_tokenizer(
            texts, padding=True, truncation=True, max_length=64, return_tensors="pt"
        ).to(qwen_dev)

        with torch.no_grad():
            # Upcast before pooling. Summing in bf16/fp16 loses precision, and
            # NumPy has no bfloat16, so `.numpy()` on the result would fail.
            token_embs = embed(enc.input_ids).float()

        mask = enc.attention_mask.unsqueeze(-1).to(token_embs.dtype)
        summed = (token_embs * mask).sum(dim=1)
        count = mask.sum(dim=1).clamp(min=1)  # guards against all-padding rows
        return summed / count

    # ---- generation ----------------------------------------------------------

    @torch.no_grad()
    def generate(self, modality_feats, modality_type, prompt_text, max_new_tokens=50):
        """Generate text conditioned on one vision or audio input plus a text prompt.

        The projected modality latents are prepended to the prompt's token
        embeddings and passed to ``qwen_model.generate`` as ``inputs_embeds``
        (sampling, temperature 0.7).

        Args:
            modality_feats: encoder features for a single sample, expected shape
                (1, T, encoder_dim). The prompt is embedded with batch size 1.
            modality_type: ``"vision"`` or ``"audio"``.
            prompt_text: raw prompt string (no chat template is applied here).
            max_new_tokens: generation budget.

        Returns:
            The first decoded sequence, with special tokens stripped.

        Raises:
            ValueError: if ``modality_type`` is not ``"vision"`` or ``"audio"``.
        """
        # 1. Project the modality on the adapter device; every position is valid.
        adapter = self._adapter_for(modality_type)
        mask = torch.ones(modality_feats.shape[:2], dtype=torch.bool)
        modality_embeds = self._project(adapter, modality_feats, mask)

        # 2. Embed the prompt on the LLM's embedding device.
        embed = self._llm_embedding()
        qwen_dev = embed.weight.device
        qwen_dtype = embed.weight.dtype
        # Cast before concatenating so both halves share the LLM's dtype. The
        # adapters built by load_full_model stay in float32, while Qwen is bf16/fp16.
        modality_embeds = modality_embeds.to(device=qwen_dev, dtype=qwen_dtype)

        text_inputs = self.qwen_tokenizer(prompt_text, return_tensors="pt").to(qwen_dev)
        text_embeds = embed(text_inputs.input_ids)

        # 3. Concatenate: [modality latents | prompt tokens].
        final_embeds = torch.cat([modality_embeds, text_embeds], dim=1)
        # Nothing is padded. An explicit all-ones mask is needed because
        # pad_token_id == eos_token_id prevents HF from inferring one.
        attention_mask = torch.ones(final_embeds.shape[:2], dtype=torch.long, device=qwen_dev)

        # 4. Generate.
        out = self.qwen_model.generate(
            inputs_embeds=final_embeds,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            pad_token_id=self.qwen_tokenizer.eos_token_id,
        )

        return self.qwen_tokenizer.batch_decode(out, skip_special_tokens=True)[0]


In [4]:

# ============================================================
# 5. Load Model Function
# ============================================================

def load_full_model(cfg_obj):
    """Assemble the inference system: a frozen Qwen LLM plus the trained adapter stack.

    Steps:
      1. Load the tokenizer and the LLM with ``device_map="auto"``.
      2. Build the vision/audio adapters, PerceiverResampler and projector from
         ``cfg_obj``'s dimensions.
      3. Restore their weights from ``cfg_obj.ckpt_path`` if the file exists.
      4. Wrap everything in an ``AlignedModel``.

    Args:
        cfg_obj: configuration object (e.g. ``Config``) providing model names,
            architecture dimensions and ``ckpt_path``.

    Returns:
        AlignedModel. The adapters sit on CUDA when available, else on CPU. Qwen
        keeps whatever placement ``device_map="auto"`` chose. Do not call ``.to()``
        on the wrapper, because that would also try to move the device-mapped LLM.

    Raises:
        ValueError: if ``cfg_obj.llm_hidden_size`` differs from the loaded LLM's
            ``config.hidden_size``. The projected latents could then not be
            concatenated with the LLM's token embeddings.
        KeyError / RuntimeError: propagated when the checkpoint exists but lacks a
            section or does not match the architecture.
    """
    print("\n--- Loading Full Inference System ---")
    print("1. Loading Frozen Qwen (Auto Device Map)...")
    tokenizer = AutoTokenizer.from_pretrained(cfg_obj.llm_model_name, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Prefer bf16 on GPUs that support it and fp16 otherwise.
    # is_bf16_supported() is only meaningful with a visible CUDA device, so check
    # availability first.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        llm_dtype = torch.bfloat16
    else:
        llm_dtype = torch.float16

    # device_map="auto" lets the LLM be split/offloaded across available devices.
    # NOTE: newer transformers releases prefer `dtype=`; `torch_dtype=` still works
    # there (with a deprecation warning) and is kept for older releases.
    qwen = AutoModelForCausalLM.from_pretrained(
        cfg_obj.llm_model_name,
        torch_dtype=llm_dtype,
        device_map="auto",
        trust_remote_code=True,
    ).eval()
    qwen.requires_grad_(False)  # frozen: this pipeline never trains the LLM

    # Fail fast if the projector width cannot match the LLM's embedding width.
    llm_hidden = getattr(qwen.config, "hidden_size", None)
    if llm_hidden is not None and llm_hidden != cfg_obj.llm_hidden_size:
        raise ValueError(
            f"cfg.llm_hidden_size={cfg_obj.llm_hidden_size} does not match "
            f"{cfg_obj.llm_model_name} hidden_size={llm_hidden}"
        )

    print("2. Initializing Adapters...")
    v_adapt = ModalityAdapter(cfg_obj.encoder_dim_vision, cfg_obj.perceiver_dim)
    a_adapt = ModalityAdapter(cfg_obj.encoder_dim_audio, cfg_obj.perceiver_dim)

    perc = PerceiverResampler(
        dim=cfg_obj.perceiver_dim,
        num_latents=cfg_obj.num_latents,
        num_layers=cfg_obj.num_perceiver_layers,
        num_heads=cfg_obj.num_attn_heads,
        mlp_ratio=cfg_obj.mlp_ratio,
    )

    proj = nn.Linear(cfg_obj.perceiver_dim, cfg_obj.llm_hidden_size)

    print(f"3. Loading adapter weights from {cfg_obj.ckpt_path}...")
    try:
        # NOTE(review): torch.load unpickles arbitrary objects. Only load trusted
        # checkpoints, or pass weights_only=True if the file holds only tensors/dicts.
        ckpt = torch.load(cfg_obj.ckpt_path, map_location="cpu")
    except FileNotFoundError:
        # Deliberate best-effort: keep going with random adapter weights.
        print("❌ Checkpoint not found! Using random init (expect garbage results).")
    else:
        v_adapt.load_state_dict(ckpt["vision_adapter"])
        a_adapt.load_state_dict(ckpt["audio_adapter"])
        perc.load_state_dict(ckpt["perceiver"])
        proj.load_state_dict(ckpt["projector"])
        print("✅ Weights loaded successfully.")

    # The adapters are placed manually on one device (cuda or cpu). Qwen handles
    # its own placement via device_map, and AlignedModel moves tensors between them.
    adapter_device = "cuda" if torch.cuda.is_available() else "cpu"
    for module in (v_adapt, a_adapt, perc, proj):
        module.to(adapter_device).eval()

    # Do NOT call .to() on this wrapper (it would try to move the device-mapped LLM).
    return AlignedModel(v_adapt, a_adapt, perc, proj, qwen, tokenizer)


In [6]:
# Build the full evaluation system (frozen Qwen + trained adapters) from the config.
eval_model = load_full_model(cfg_obj=cfg)


--- Loading Full Inference System ---
1. Loading Frozen Qwen (Auto Device Map)...


`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

2. Initializing Adapters...
3. Loading adapter weights from runs_perceiver_mrl_qwen/multimodal_adapter_poc_1.pt...
✅ Weights loaded successfully.


In [None]:
def compute_embeddings(modality_fn, items):
    """Embed every item with ``modality_fn`` and stack the results into one NumPy array.

    Args:
        modality_fn: callable mapping one item to a torch tensor (e.g. one of
            AlignedModel's ``encode_*`` methods). It is called under ``torch.no_grad()``.
        items: iterable of inputs, passed to ``modality_fn`` one at a time.

    Returns:
        ``np.ndarray`` produced by ``np.vstack`` over the per-item outputs.

    Raises:
        ValueError: if ``items`` yields nothing.
    """
    all_embeds = []
    with torch.no_grad():
        for x in tqdm(items):
            e = modality_fn(x)
            # NumPy has no bfloat16 dtype, so `.numpy()` on a bf16 tensor raises. Upcast first.
            if e.dtype == torch.bfloat16:
                e = e.float()
            all_embeds.append(e.cpu().numpy())
    if not all_embeds:
        raise ValueError("compute_embeddings received no items to embed")
    return np.vstack(all_embeds)


In [None]:
def retrieval_rank(query_embed, gallery_embed):
    """Rank gallery rows from most to least cosine-similar to a single query.

    Args:
        query_embed: 1-D query vector.
        gallery_embed: 2-D array with one gallery vector per row.

    Returns:
        Integer array of gallery row indices, best match first.
    """
    query_row = query_embed[np.newaxis, :]
    similarity_row = cosine_similarity(query_row, gallery_embed)[0]
    # argsort is ascending; negating the scores yields a descending ranking.
    ranking = np.argsort(-similarity_row)
    return ranking


In [None]:
def compute_recall_at_k(query_embeds, gallery_embeds, true_indices, k=1):
    """Fraction of queries whose true gallery item ranks in the top ``k`` by cosine similarity.

    Args:
        query_embeds: sequence/array of query vectors, one per query.
        gallery_embeds: 2-D array of gallery vectors, one per row.
        true_indices: for each query, the gallery row index of its correct match.
        k: how many top-ranked gallery items count as a hit (must be >= 1).

    Returns:
        Recall@k as a float in [0, 1].

    Raises:
        ValueError: if there are no queries, ``k < 1``, or ``true_indices`` and
            ``query_embeds`` have different lengths.
    """
    n_queries = len(query_embeds)
    if n_queries == 0:
        raise ValueError("compute_recall_at_k needs at least one query")
    if len(true_indices) != n_queries:
        raise ValueError(
            f"got {len(true_indices)} true indices for {n_queries} queries"
        )
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    hits = sum(
        1
        for query, target in zip(query_embeds, true_indices)
        if target in retrieval_rank(query, gallery_embeds)[:k]
    )
    return hits / n_queries


In [None]:
def linear_cka(X, Y):
    """Linear Centered Kernel Alignment (CKA) between two representations of the same samples.

    CKA(X, Y) = ||Xc^T Yc||_F^2 / (||Xc^T Xc||_F * ||Yc^T Yc||_F), where Xc and Yc
    are the column-centred inputs. The value lies in [0, 1] and is invariant to
    isotropic scaling and orthogonal transforms of either representation.

    Args:
        X: (n_samples, d_x) array.
        Y: (n_samples, d_y) array. Row i of X and row i of Y describe the same sample.

    Returns:
        The CKA similarity as a float, or ``nan`` when either centred
        representation is all zeros (zero variance).

    Raises:
        ValueError: if X and Y do not have the same number of rows.
    """
    # float64 avoids precision loss when squaring norms of float32/bf16-derived data.
    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)
    if X.shape[0] != Y.shape[0]:
        raise ValueError(
            f"X and Y must have the same number of rows, got {X.shape[0]} and {Y.shape[0]}"
        )

    X = X - X.mean(axis=0, keepdims=True)
    Y = Y - Y.mean(axis=0, keepdims=True)

    n_samples = X.shape[0]
    if n_samples < max(X.shape[1], Y.shape[1]):
        # Fewer samples than features: use the equivalent Gram-matrix form.
        # ||X^T Y||_F^2 = <X X^T, Y Y^T>_F and ||X^T X||_F = ||X X^T||_F, so only
        # n x n matrices are built instead of d x d ones.
        gram_x = X @ X.T
        gram_y = Y @ Y.T
        numerator = np.sum(gram_x * gram_y)
        denom = np.linalg.norm(gram_x, ord="fro") * np.linalg.norm(gram_y, ord="fro")
    else:
        numerator = np.linalg.norm(X.T @ Y, ord="fro") ** 2
        denom = np.linalg.norm(X.T @ X, ord="fro") * np.linalg.norm(Y.T @ Y, ord="fro")

    if denom == 0:
        # A constant representation has no variance; CKA is undefined.
        return float("nan")
    return float(numerator / denom)


In [None]:
def embedding_stats(E, name=""):
    """Print mean, standard deviation, minimum and maximum of the row-wise L2 norms of ``E``.

    Args:
        E: 2-D array with one embedding per row.
        name: label shown in the printed header.
    """
    row_norms = np.linalg.norm(E, axis=1)
    print(f"\n{name} Embedding stats:")
    reducers = (
        ("  mean norm:", np.mean),
        ("  std:", np.std),
        ("  min:", np.min),
        ("  max:", np.max),
    )
    for label, reduce_fn in reducers:
        print(label, reduce_fn(row_norms))
