## Main Notebook -> For Running Alignment (Stage-1 Training)

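This notebook runs Stage-1 alignment training. It loads the pre-extracted PixMo-Cap image features and LibriSpeech audio features, then sets up the frozen encoders and the multimodal alignment model.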

In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

### Step-1: Load the Encoded Datasets for Alignment

In [2]:
import json
import torch
from torch.utils.data import Dataset
from pathlib import Path

#### Step-1.1: Load the PixMo-Cap Dataset

In [3]:
import json
import torch
from torch.utils.data import Dataset
from pathlib import Path

class PixmoFeatureDataset(Dataset):
    """
    Loads the pre-extracted image features.
    It also fixes old paths like 'data/pixmo/...' -> 'data/data/pixmo/...'
    """
    def __init__(self, index_file: str | Path):
        index_file = Path(index_file)
        with open(index_file, "r") as f:
            self.index = json.load(f)

        # Base dir where your index lives (e.g. ./data/data/pixmo)
        self.base_dir = index_file.parent

    def _fix_path(self, raw_path: str) -> Path:
        p = Path(raw_path)

        # If it's already absolute and exists, just return
        if p.is_absolute() and p.exists():
            return p

        # Common case: path stored as "data/pixmo/features/xxx.pt"
        # but actual is "data/data/pixmo/features/xxx.pt"
        s = str(p)

        if "data/pixmo" in s and not p.exists():
            s = s.replace("data/pixmo", "data/data/pixmo")
            p2 = Path(s)
            if p2.exists():
                return p2

        # Otherwise, try resolving relative to the index directory
        p3 = (self.base_dir / p.name)  # fallback: same dir, same filename
        if p3.exists():
            return p3

        # Last resort: just return the original; will raise if missing
        return p

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        meta = self.index[idx]
        raw_path = meta["file"]
        file_path = self._fix_path(raw_path)

        blob = torch.load(file_path)

        return {
            "features": blob["features"],          # (num_patches, feat_dim)
            "caption": blob["caption"],            # raw caption text
            "num_patches": meta["num_patches"],
            "orig_idx": meta["orig_idx"],
        }


In [4]:
# Build the PixMo train/val feature datasets from their JSON index files.
pixmo_train_ds = PixmoFeatureDataset("./data/data/pixmo/train_index.json")
pixmo_val_ds   = PixmoFeatureDataset("./data/data/pixmo/val_index.json")

# Sanity check: number of indexed examples in each split.
print("PixMo train:", len(pixmo_train_ds))
print("PixMo val:", len(pixmo_val_ds))


PixMo train: 26366
PixMo val: 4376


In [5]:
# Inspect one example
ex = pixmo_train_ds[0]
print(ex["features"].shape, ex["caption"][:80])

torch.Size([256, 1536]) The image depicts a pixelated, retro-style space shooter game set against a nigh


#### Step-1.2: Load the Audio Dataset

In [6]:
import json
import torch
from torch.utils.data import Dataset
from pathlib import Path

class LibriSpeechFeatureDataset(Dataset):
    """
    Loads pre-extracted Whisper features.
    Fixes broken paths like 'data/librispeech/...' -> 'data/data/librispeech/...'
    just like PixmoFeatureDataset does.
    """

    def __init__(self, index_file: str | Path):
        index_file = Path(index_file)
        with open(index_file, "r") as f:
            self.index = json.load(f)

        # Base directory where index.json lives
        self.base_dir = index_file.parent

    def _fix_path(self, raw_path: str) -> Path:
        """
        Try multiple strategies to fix incorrect dataset paths.
        1. Use raw path if absolute + exists
        2. Fix common 'data/librispeech' → 'data/data/librispeech'
        3. Try rewriting relative to index dir
        """
        p = Path(raw_path)

        # 1. If fully absolute and exists → OK
        if p.is_absolute() and p.exists():
            return p

        s = str(p)

        # 2. Common mismatch:
        #    raw: "data/librispeech/features/train_feat_123.pt"
        #    actual: "data/data/librispeech/features/train_feat_123.pt"
        if "data/librispeech" in s and not p.exists():
            s2 = s.replace("data/librispeech", "data/data/librispeech")
            p2 = Path(s2)
            if p2.exists():
                return p2

        # 3. Fallback: resolve relative to the index directory
        #    (useful if someone moved the index folder)
        p3 = self.base_dir / p.name
        if p3.exists():
            return p3

        # ❌ Last fallback → just return original (torch.load will raise)
        return p

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        meta = self.index[idx]

        raw_path = meta["file"]
        file_path = self._fix_path(raw_path)

        blob = torch.load(file_path)

        return {
            "features": blob["features"],      # (T_enc, d_audio)
            "text": blob["text"],
            "duration": blob["duration"],
            "sampling_rate": blob["sampling_rate"],
            "orig_idx": blob["orig_idx"],
        }


In [7]:
# Build the LibriSpeech training feature dataset from its JSON index file.
audio_train_ds = LibriSpeechFeatureDataset("./data/data/librispeech/train_index.json")

# Sanity check: number of indexed examples.
print("Loaded LibriSpeech feature dataset:", len(audio_train_ds))

# Inspect one example: feature tensor shape and the first 100 characters of
# its transcript. NOTE: this rebinds `ex`, which held the PixMo example above.
ex = audio_train_ds[0]
print(ex["features"].shape, ex["text"][:100])


Loaded LibriSpeech feature dataset: 8677
torch.Size([1500, 512]) CONCERNING THE DISEASE THAT HEROD FELL INTO AND THE SEDITION WHICH THE JEWS RAISED THEREUPON


### Step-2: Load the Architecture (Encoders + Alignment Layers)

In [8]:
from imports.encoders import VisionEncoder, AudioEncoder
from imports.perceiver import PerceiverLatentEncoder, ProjectorMLP
import torch, torch.nn as nn



In [9]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half precision is only a sensible default on CUDA; on CPU float16 kernels
# are slow or missing for some ops, so use float32 there.
dtype = torch.float16 if device.type == "cuda" else torch.float32  # or bfloat16 depending on your env

In [10]:
# 1) Load frozen encoders
# Import here so the cell also works on a fresh kernel: previously
# `Image.open` ran before PIL was imported, and the bare `except:` hid the
# resulting NameError, so the real-image branch could never succeed.
import numpy as np
from PIL import Image

vision_enc = VisionEncoder(
    model_name="facebook/dinov2-base",
    device=device,
    dtype=dtype,
)

# Use real images if available
try:
    # The context managers close the files; convert() returns an in-memory
    # RGB copy, so the images stay usable after the files are closed.
    with Image.open("/mnt/data/sample1.jpg") as f1, Image.open("/mnt/data/sample2.jpg") as f2:
        img1 = f1.convert("RGB")
        img2 = f2.convert("RGB")
    images = [img1, img2]
except OSError:
    # Missing or unreadable files (FileNotFoundError and PIL's
    # UnidentifiedImageError are both OSError subclasses).
    # fallback: create dummy RGB images of size 224x224
    dummy = (np.random.rand(224, 224, 3) * 255).astype('uint8')
    images = [Image.fromarray(dummy), Image.fromarray(dummy)]

# Smoke-test the encoder on the two images and unpack its outputs.
vision_out = vision_enc.encode_images(images)
vision_feats, vision_mask = vision_out["feats"], vision_out["mask"]

#### Test the encoders with random image and audio

In [11]:
# Build the audio encoder with the same notebook-wide `dtype` as the vision
# encoder, instead of a separately hard-coded precision.
audio_enc = AudioEncoder(
    model_name="openai/whisper-base",
    device=device,
    dtype=dtype,
)

# 1) (B, T) tensor: B random waveforms, each 3 seconds long at `sr` Hz
sr = 16000
B, T = 2, sr * 3
waveforms = torch.randn(B, T)

In [12]:
# Case 1: encode the (B, T) batch tensor and report the shapes of the
# returned features and mask.
out = audio_enc.encode_waveforms(waveforms, sample_rates=sr)
print("Case 1 feats:", out["feats"].shape, "mask:", out["mask"].shape)

Case 1 feats: torch.Size([2, 1500, 512]) mask: torch.Size([2, 1500])


In [13]:
# 2) (B, 1, T) tensor
# Same waveforms as case 1, with an extra singleton dimension inserted at
# position 1.
waveforms_3d = waveforms.unsqueeze(1)
out2 = audio_enc.encode_waveforms(waveforms_3d, sample_rates=sr)
print("Case 2 feats:", out2["feats"].shape, "mask:", out2["mask"].shape)

# 3) list[Tensor] with slightly different shapes (simulating variable length)
# One waveform of T samples and one of T // 2 samples.
waveforms_list = [torch.randn(T), torch.randn(T // 2)]
out3 = audio_enc.encode_waveforms(waveforms_list, sample_rates=sr)
print("Case 3 feats:", out3["feats"].shape, "mask:", out3["mask"].shape)

Case 2 feats: torch.Size([2, 1500, 512]) mask: torch.Size([2, 1500])
Case 3 feats: torch.Size([2, 1500, 512]) mask: torch.Size([2, 1500])


#### Load the multimodal model

In [23]:
from imports.model import MultiModalAlignmentModel  # or from current cell
from PIL import Image
import torch
from torchviz import make_dot

In [37]:
# Instantiate the alignment model in float32 on `device`.
# NOTE(review): the meaning of each size argument is defined in
# imports/model.py, which is not visible here — confirm there before tuning.
model = MultiModalAlignmentModel(
    d_shared=512,
    d_latent=512,
    d_align=1024,
    num_latents=32,   # smaller for viz
    num_layers=2,
    num_heads=4,
    use_perceiver=True,
    dtype=torch.float32,
    device=device,
)


In [33]:
# Import numpy here: this cell previously relied on `np` having been bound by
# the fallback branch of an earlier cell, which is not guaranteed to run.
import numpy as np
from torchview import draw_graph

# Two identical random 224x224 RGB images.
# NOTE(review): `images` is built here but is not what gets passed to
# draw_graph below (a `[None] * 2` placeholder is used instead).
dummy_img = (np.random.rand(224, 224, 3) * 255).astype("uint8")
img1 = Image.fromarray(dummy_img)
img2 = Image.fromarray(dummy_img)
images = [img1, img2]


# Two random waveforms of 16000 samples each, with their sample rate.
dummy_wave = torch.randn(2, 16000)
dummy_sr = 16000

# NOTE(review): the recorded run of this cell failed with
# "RuntimeError: Failed to run torchgraph". A possible culprit is the
# `[None] * 2` image placeholder (or an input/model device mismatch) —
# TODO confirm against MultiModalAlignmentModel.forward before relying on it.
graph = draw_graph(
    model, 
    input_data=( [None]*2, dummy_wave, dummy_sr ),  # images=None placeholder
    expand_nested=True,
    save_graph=True,
    directory="arch_plots",
    filename="alignment_model_torchview",
)
graph.visual_graph

RuntimeError: Failed to run torchgraph see error message