## Main Notebook -> For Running Alignment (Stage-1 Training)

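This notebook runs Stage-1 alignment training: it loads the pre-encoded vision and audio feature datasets, builds the DataLoaders, and sanity-checks the encoders and the alignment model before training starts.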

In [6]:
%load_ext autoreload
%autoreload 2

In [7]:
import torch

In [8]:
from imports.configs.config import setup_from_yaml, ModelsConfig

# Phase 0 - Global setup: parse the YAML config once; later cells reuse
# `cfg`, `device` and `dtype` from here.
cfg = setup_from_yaml("imports/configs/config.yaml")

device = cfg.torch_device
dtype = cfg.torch_dtype

# Placeholder until the DataLoader cell builds the real loader. It must be
# None rather than True: a later cell gates on `if audio_train_loader:` and
# then calls iter() on it, which would raise TypeError on a bare boolean if
# the DataLoader cell had been skipped.
audio_train_loader = None

print("Using device:", device)
print("Using dtype:", dtype)


[34m[1mwandb[0m: Currently logged in as: [33mvedaangchopra[0m ([33mvedaangchopra_gatech[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[Config] Device: mps, dtype: torch.float16
[Config] root_dir: /Users/vedaangchopra/all_data/complete_technical_work/all_projects_implemented/Edge Assistant/code_base/v1_code_base/edge_glass
[Config] features_dir: /Users/vedaangchopra/all_data/complete_technical_work/all_projects_implemented/Edge Assistant/code_base/v1_code_base/edge_glass/features
Using device: mps
Using dtype: torch.float16


### Step-1: - Load the Encoded Dataset -> For Alignment

In [9]:
from imports.dataset import PixmoFeatureDataset, LibriSpeechFeatureDataset
from imports.dataset import LibriSpeechFeatureDataset
from imports.dataset import collate_alignment

In [10]:
from functools import partial

# assuming cfg is already loaded and has cfg.datasets
# Re-read the torch device from the config (the setup cell assigns the same
# attribute), so the dataset cells below do not depend on a stale `device`.
device = cfg.torch_device


In [11]:
datasets_cfg = cfg.datasets

# Vision: train / val feature datasets at the paths given in the config.
pixmo_train_ds = PixmoFeatureDataset(datasets_cfg.pixmo_train_path)
pixmo_val_ds = PixmoFeatureDataset(datasets_cfg.pixmo_val_path)

# Audio is optional: it is built only when LibriSpeech is enabled in the
# config AND a train index path is actually set; otherwise it stays None.
librispeech_enabled = (
    datasets_cfg.use_librispeech
    and datasets_cfg.librispeech_train_path is not None
)
audio_train_ds = (
    LibriSpeechFeatureDataset(datasets_cfg.librispeech_train_path)
    if librispeech_enabled
    else None
)


#### Step-1: - Load the Pixmo-Cap Dataset

In [12]:
# NOTE(review): these index paths are hard-coded, so the datasets built here
# replace the ones created from cfg.datasets in the previous cell.
pixmo_index_paths = {
    "train": "./data/data/pixmo/train_index.json",
    "val": "./data/data/pixmo/val_index.json",
}
pixmo_train_ds = PixmoFeatureDataset(pixmo_index_paths["train"])
pixmo_val_ds = PixmoFeatureDataset(pixmo_index_paths["val"])

# Report how many examples each split contains.
for split_label, split_ds in (("PixMo train:", pixmo_train_ds), ("PixMo val:", pixmo_val_ds)):
    print(split_label, len(split_ds))


PixMo train: 873
PixMo val: 89


In [13]:
# Look at the first training example: the shape of its stored features and
# the first 80 characters of its caption text.
ex = pixmo_train_ds[0]
feature_shape = ex["features"].shape
caption_preview = ex["text"][:80]
print(feature_shape, caption_preview)

torch.Size([256, 1536]) In this meme, a close-up photograph of a bald-headed Black man is prominently fe


#### Step-2: - Load the Audio Dataset

In [14]:
# NOTE(review): hard-coded index path; this replaces any audio dataset that
# was (or was not) created from cfg.datasets earlier.
audio_train_ds = LibriSpeechFeatureDataset("./data/data/librispeech/train_index.json")
num_audio_examples = len(audio_train_ds)
print("Loaded LibriSpeech feature dataset:", num_audio_examples)

# First example: feature shape plus the first 100 characters of its text.
ex = audio_train_ds[0]
audio_feature_shape, transcript_preview = ex["features"].shape, ex["text"][:100]
print(audio_feature_shape, transcript_preview)


Loaded LibriSpeech feature dataset: 338
torch.Size([1500, 512]) IN THE EXERCISE OF THE EXECUTIVE POWER THE PRESIDENT OF THE UNITED STATES IS CONSTANTLY SUBJECT TO A


In [15]:
# Display the parsed config object as this cell's output for a quick visual check.
cfg

Config(paths=PathsConfig(root_dir='/Users/vedaangchopra/all_data/complete_technical_work/all_projects_implemented/Edge Assistant/code_base/v1_code_base/edge_glass', features_dir='/Users/vedaangchopra/all_data/complete_technical_work/all_projects_implemented/Edge Assistant/code_base/v1_code_base/edge_glass/features'), models=ModelsConfig(vision_model_name='openai/clip-vit-base-patch32', llm_model_name='qwen/Qwen2.5-3B-Instruct', audio_model_name='openai/whisper-base'), architecture=ArchitectureConfig(perceiver_dim=None, num_latents=64, num_perceiver_layers=4, num_attn_heads=8, mlp_ratio=4.0), training=TrainingConfig(batch_size=16, num_epochs=5, learning_rate='3e-4', weight_decay=0.01, warmup_steps=500, max_grad_norm=1.0, log_every_steps=50, train_subset_size=2000, val_subset_size=500), mrl=MRLConfig(mrl_dims=[1024, 512, 256], mrl_weight=1.0, mrl_temp=0.07), misc=MiscConfig(dtype='float16', seed=42, device='auto', use_wandb=True, wandb_project='edge_glass', wandb_run_name='phase0_global_

### Build the Tokenizer and DataLoaders

In [16]:
from transformers import AutoTokenizer

# The tokenizer is chosen by the LLM name in config.yaml.
llm_name = cfg.models.llm_model_name

# NOTE(review): trust_remote_code=True lets the hub repo run its own Python
# code at load time - confirm this model actually requires it.
llm_tokenizer = AutoTokenizer.from_pretrained(llm_name, trust_remote_code=True)

# Pad and truncate on the right-hand side of each sequence.
llm_tokenizer.padding_side = llm_tokenizer.truncation_side = "right"

print("Loaded tokenizer:", llm_name)


Loaded tokenizer: qwen/Qwen2.5-3B-Instruct


In [17]:
from torch.utils.data import DataLoader
from functools import partial

In [18]:
# Re-read the device from the config and expose the LLM tokenizer under a
# shorter alias for later cells.
device = cfg.torch_device
tokenizer = llm_tokenizer   # previously loaded

In [19]:
# Pre-bind the tokenizer so the DataLoaders can use this as a one-argument
# `collate_fn` (they only pass the list of samples).
alignment_collate = partial(collate_alignment, tokenizer=llm_tokenizer)

In [20]:
# Worker count and pinned memory depend only on whether we run on CUDA.
on_cuda = device.type == "cuda"

# CUDA: 4 workers (can be raised with a GPU + good CPU) and pinned memory.
# Everything else (macOS CPU/MPS, generic CPU): stay conservative, because
# num_workers=0 is the safest choice in Jupyter / macOS.
num_workers = 4 if on_cuda else 0
pin_memory = on_cuda

print(f"DataLoader config â†’ num_workers={num_workers}, pin_memory={pin_memory}")



DataLoader config â†’ num_workers=0, pin_memory=False


In [21]:
# Settings shared by every loader built in this notebook.
loader_kwargs = dict(
    batch_size=cfg.training.batch_size,
    num_workers=num_workers,
    pin_memory=pin_memory,
    collate_fn=alignment_collate,
)

# Training loaders shuffle; the validation loader keeps dataset order.
pixmo_train_loader = DataLoader(pixmo_train_ds, shuffle=True, **loader_kwargs)
pixmo_val_loader = DataLoader(pixmo_val_ds, shuffle=False, **loader_kwargs)

# The audio dataset is optional, so its loader is None when no dataset exists.
audio_train_loader = (
    DataLoader(audio_train_ds, shuffle=True, **loader_kwargs)
    if audio_train_ds is not None
    else None
)

In [22]:
# Pull a single collated vision batch and report the shapes of its feature
# and token tensors, followed by the per-sample modality ids.
batch_v = next(iter(pixmo_train_loader))
for field_label, field_key in (
    ("Vision batch encoder_feats:", "features"),
    ("Vision batch tokens:", "input_ids"),
):
    print(field_label, batch_v[field_key].shape)
print("Vision modalities:", batch_v["modality_ids"])


Vision batch encoder_feats: torch.Size([16, 256, 1536])
Vision batch tokens: torch.Size([16, 322])
Vision modalities: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])


In [23]:
# Same peek for audio; skipped when the audio loader is None (or falsy).
if audio_train_loader:
    batch_a = next(iter(audio_train_loader))
    audio_feats_shape = batch_a["features"].shape
    audio_tokens_shape = batch_a["input_ids"].shape
    print("Audio batch encoder_feats:", audio_feats_shape)
    print("Audio batch tokens:", audio_tokens_shape)
    print("Audio modalities:", batch_a["modality_ids"])

Audio batch encoder_feats: torch.Size([16, 1500, 512])
Audio batch tokens: torch.Size([16, 54])
Audio modalities: tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])


### Step-2: Load the Architecture (Encoders + Alignment Layers)

In [24]:
from imports.encoders import VisionEncoder, AudioEncoder
from imports.perceiver import PerceiverLatentEncoder, ProjectorMLP
import torch, torch.nn as nn

In [25]:
# Imported before first use. Previously `Image` was referenced inside the
# `try` before any import had run, and the bare `except:` swallowed the
# resulting NameError, so the real-image path could never execute on a
# fresh kernel.
import numpy as np
from PIL import Image

# 1) Load frozen encoders
vision_enc = VisionEncoder(
    model_name="facebook/dinov2-base",
    device=device,
    dtype=dtype,
)

# Use a real image if available
try:
    img1 = Image.open("/mnt/data/sample1.jpg")
    img2 = Image.open("/mnt/data/sample2.jpg")
    images = [img1, img2]
except OSError:
    # Only "file missing / unreadable" is expected here (FileNotFoundError and
    # PIL's UnidentifiedImageError are both OSError subclasses); any other
    # error is a real bug and should surface instead of being hidden.
    # Fallback: create dummy RGB images of size 224x224.
    dummy = (np.random.rand(224, 224, 3) * 255).astype('uint8')
    images = [Image.fromarray(dummy), Image.fromarray(dummy)]

# Run the encoder once and unpack its feature tensor and mask.
vision_out = vision_enc.encode_images(images)
vision_feats, vision_mask = vision_out["feats"], vision_out["mask"]

#### Test the encoders with random image and audio

In [26]:
# Sample rate passed to the encoder for every test waveform below.
sr = 16000

audio_enc = AudioEncoder(model_name="openai/whisper-base", device=device, dtype=torch.float16)

# 1) (B, T) tensor: B random waveforms, each 3 * 16000 samples long.
B = 2
T = 16000 * 3
waveforms = torch.randn(B, T)

In [27]:
# Case 1: encode the batched 2-D waveform tensor and report output shapes.
out = audio_enc.encode_waveforms(waveforms, sample_rates=sr)
case1_feats, case1_mask = out["feats"], out["mask"]
print("Case 1 feats:", case1_feats.shape, "mask:", case1_mask.shape)

Case 1 feats: torch.Size([2, 1500, 512]) mask: torch.Size([2, 1500])


In [28]:
# 2) (B, 1, T) tensor: same waveforms with an extra singleton dimension.
waveforms_3d = waveforms.unsqueeze(1)
out2 = audio_enc.encode_waveforms(waveforms_3d, sample_rates=sr)
case2_feats, case2_mask = out2["feats"], out2["mask"]
print("Case 2 feats:", case2_feats.shape, "mask:", case2_mask.shape)

# 3) list[Tensor] whose entries differ in length (simulating variable-length audio)
full_length_wave = torch.randn(T)
half_length_wave = torch.randn(T // 2)
waveforms_list = [full_length_wave, half_length_wave]
out3 = audio_enc.encode_waveforms(waveforms_list, sample_rates=sr)
case3_feats, case3_mask = out3["feats"], out3["mask"]
print("Case 3 feats:", case3_feats.shape, "mask:", case3_mask.shape)

Case 2 feats: torch.Size([2, 1500, 512]) mask: torch.Size([2, 1500])
Case 3 feats: torch.Size([2, 1500, 512]) mask: torch.Size([2, 1500])


#### Load the multimodal model

In [29]:
from imports.model import MultiModalAlignmentModel  # or from current cell
from PIL import Image
import torch
from torchviz import make_dot

In [30]:
# Constructor arguments for the alignment model, collected in one place.
alignment_model_kwargs = {
    "d_shared": 512,
    "d_latent": 512,
    "d_align": 1024,
    "num_latents": 32,   # smaller for viz
    "num_layers": 2,
    "num_heads": 4,
    "use_perceiver": True,
    "dtype": torch.float32,
    "device": device,
}
model = MultiModalAlignmentModel(**alignment_model_kwargs)

# Switch to eval mode; as the last expression this also displays the module tree.
model.eval()


MultiModalAlignmentModel(
  (vision_encoder): VisionEncoder(
    (model): Dinov2Model(
      (embeddings): Dinov2Embeddings(
        (patch_embeddings): Dinov2PatchEmbeddings(
          (projection): Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14))
        )
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (encoder): Dinov2Encoder(
        (layer): ModuleList(
          (0-11): 12 x Dinov2Layer(
            (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
            (attention): Dinov2Attention(
              (attention): Dinov2SelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
              )
              (output): Dinov2SelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.0, inplace

In [31]:
def describe(name, t):
    """Print a one-line summary (shape, dtype, mean, std) of a tensor.

    Args:
        name: Label printed in front of the summary.
        t: A torch tensor, or None. None is reported as ``"<name>: None"``.

    The reported dtype is always the tensor's own dtype. Mean and std are
    computed on a float32 view when the tensor is neither floating point nor
    complex, because ``Tensor.mean()`` / ``Tensor.std()`` raise on bool and
    integer tensors; this lets masks and token ids be passed in directly.
    Output for floating-point tensors is unchanged.
    """
    if t is None:
        print(f"{name}: None")
        return

    # Leave float/complex tensors untouched so their output is identical to
    # the previous behaviour; only bool/int inputs are cast for the statistics.
    if t.is_floating_point() or t.is_complex():
        stats_view = t
    else:
        stats_view = t.float()

    print(f"{name}: shape={tuple(t.shape)}, dtype={t.dtype}, "
          f"mean={stats_view.mean().item():.4f}, std={stats_view.std().item():.4f}")


In [32]:
B = 2
dummy_sr = 16000

# Vision: random tensor standing in for already-preprocessed pixel values.
dummy_images = torch.randn(B, 3, 224, 224, device=device)

# Audio: random raw waveforms (e.g. 3 seconds at 16kHz).
dummy_waveforms = torch.randn(B, 16000 * 3, device=device)


In [33]:
# Walk the vision path stage by stage (no gradients needed), printing a
# summary of every intermediate tensor.
with torch.no_grad():
    # Stage 1 - encoder output
    enc_v = model.vision_encoder.encode_images(dummy_images)
    v_mask = enc_v["mask"]                  # (B, T_v) bool
    v_feats = enc_v["feats"].to(device)     # (B, T_v, D_v)

    describe("vision/encoder_feats", v_feats)
    describe("vision/mask", v_mask.float())

    # Stage 2 - adapter into the shared dimension; the adapter is ensured
    # for the encoder's feature width before it is called.
    model._ensure_vision_adapter(v_feats.size(-1))
    v_shared = model.vision_adapter(v_feats)    # (B, T_v, d_shared)
    describe("vision/shared_feats (after adapter)", v_shared)

    if not model.use_perceiver:
        # Pooled-only path: masked-mean pool first, then project once.
        print("Perceiver disabled; using pooled-only path.")
        v_latents = None
        v_tokens = None
        v_pooled_in = model._pool_masked_mean(v_shared, v_mask)
        v_pooled = model.projector(v_pooled_in)
        describe("vision/pooled", v_pooled)
    else:
        # Stage 3 - Perceiver latents
        v_latents = model.perceiver(v_shared, encoder_mask=v_mask)  # (B, L, d_latent)
        describe("vision/perceiver_latents", v_latents)

        # Stage 4 - token-level projector
        v_tokens = model.projector(v_latents)                        # (B, L, d_align)
        describe("vision/tokens_after_projector", v_tokens)

        # Stage 5 - pooled alignment embedding: mean over the latent axis
        v_pooled = v_tokens.mean(dim=1)                              # (B, d_align)
        describe("vision/pooled", v_pooled)


vision/encoder_feats: shape=(2, 256, 1536), dtype=torch.float32, mean=-0.0202, std=1.1425
vision/mask: shape=(2, 256), dtype=torch.float32, mean=1.0000, std=0.0000
vision/shared_feats (after adapter): shape=(2, 256, 512), dtype=torch.float32, mean=-0.0095, std=0.6621
vision/perceiver_latents: shape=(2, 32, 512), dtype=torch.float32, mean=-0.0542, std=0.5422
vision/tokens_after_projector: shape=(2, 32, 1024), dtype=torch.float32, mean=0.0005, std=0.1944
vision/pooled: shape=(2, 1024), dtype=torch.float32, mean=0.0005, std=0.1938


In [34]:
# Walk the audio path stage by stage (no gradients needed), printing a
# summary of every intermediate tensor.
with torch.no_grad():
    # Stage 1 - encoder output
    enc_a = model.audio_encoder.encode_waveforms(dummy_waveforms, sample_rates=dummy_sr)
    a_mask = enc_a["mask"]                  # (B, T_a) bool
    a_feats = enc_a["feats"].to(device)     # (B, T_a, D_a)

    describe("audio/encoder_feats", a_feats)
    describe("audio/mask", a_mask.float())

    # Stage 2 - adapter into the shared dimension; the adapter is ensured
    # for the encoder's feature width before it is called.
    model._ensure_audio_adapter(a_feats.size(-1))
    a_shared = model.audio_adapter(a_feats)     # (B, T_a, d_shared)
    describe("audio/shared_feats (after adapter)", a_shared)

    if not model.use_perceiver:
        # Pooled-only path: masked-mean pool first, then project once.
        print("Perceiver disabled; using pooled-only path.")
        a_latents = None
        a_tokens = None
        a_pooled_in = model._pool_masked_mean(a_shared, a_mask)
        a_pooled = model.projector(a_pooled_in)
        describe("audio/pooled", a_pooled)
    else:
        # Stage 3 - Perceiver latents
        a_latents = model.perceiver(a_shared, encoder_mask=a_mask)  # (B, L, d_latent)
        describe("audio/perceiver_latents", a_latents)

        # Stage 4 - token-level projector
        a_tokens = model.projector(a_latents)                        # (B, L, d_align)
        describe("audio/tokens_after_projector", a_tokens)

        # Stage 5 - pooled alignment embedding: mean over the latent axis
        a_pooled = a_tokens.mean(dim=1)                              # (B, d_align)
        describe("audio/pooled", a_pooled)


audio/encoder_feats: shape=(2, 1500, 512), dtype=torch.float32, mean=-0.0190, std=1.3893
audio/mask: shape=(2, 1500), dtype=torch.float32, mean=1.0000, std=0.0000
audio/shared_feats (after adapter): shape=(2, 1500, 512), dtype=torch.float32, mean=-0.0087, std=0.7757
audio/perceiver_latents: shape=(2, 32, 512), dtype=torch.float32, mean=-0.0570, std=0.5787
audio/tokens_after_projector: shape=(2, 32, 1024), dtype=torch.float32, mean=-0.0038, std=0.1963
audio/pooled: shape=(2, 1024), dtype=torch.float32, mean=-0.0038, std=0.1962


In [35]:
# Run the model's public encode_* entry points on the same dummy inputs so
# their summaries can be compared with the manual stage-by-stage run.
with torch.no_grad():
    v_enc_api = model.encode_vision(dummy_images)
    a_enc_api = model.encode_audio(dummy_waveforms, dummy_sr)

api_reports = (
    ("encode_vision()['tokens']", v_enc_api["tokens"]),
    ("encode_vision()['pooled']", v_enc_api["pooled"]),
    ("encode_audio()['tokens']", a_enc_api["tokens"]),
    ("encode_audio()['pooled']", a_enc_api["pooled"]),
)
for report_name, report_tensor in api_reports:
    describe(report_name, report_tensor)


encode_vision()['tokens']: shape=(2, 32, 1024), dtype=torch.float32, mean=0.0005, std=0.1944
encode_vision()['pooled']: shape=(2, 1024), dtype=torch.float32, mean=0.0005, std=0.1938
encode_audio()['tokens']: shape=(2, 32, 1024), dtype=torch.float32, mean=-0.0038, std=0.1963
encode_audio()['pooled']: shape=(2, 1024), dtype=torch.float32, mean=-0.0038, std=0.1962


### Plotting the Architecture

In [36]:
from imports.model import FullAlignmentGraphWrapper


# Wrap the alignment model for graph plotting, then move the wrapper to the
# working device.
wrapper = FullAlignmentGraphWrapper(model)
wrapper = wrapper.to(device)

# Eval mode; as the last expression this also displays the module tree.
wrapper.eval()


FullAlignmentGraphWrapper(
  (core): MultiModalAlignmentModel(
    (vision_encoder): VisionEncoder(
      (model): Dinov2Model(
        (embeddings): Dinov2Embeddings(
          (patch_embeddings): Dinov2PatchEmbeddings(
            (projection): Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14))
          )
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (encoder): Dinov2Encoder(
          (layer): ModuleList(
            (0-11): 12 x Dinov2Layer(
              (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
              (attention): Dinov2Attention(
                (attention): Dinov2SelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                )
                (output): Dinov2SelfOutput(
                  (dense): Linear(in_features=768, o

In [37]:
dummy_batch_size = 2

# Vision: random pixel_values of shape (B, 3, H, W)
dummy_pixel_values = torch.randn(dummy_batch_size, 3, 224, 224, device=device)

# Audio: random raw waveforms of shape (B, T) - 3 seconds @ 16kHz
dummy_waveforms = torch.randn(dummy_batch_size, 16000 * 3, device=device)


In [38]:
# Single gradient-free forward pass through the wrapper with both modalities.
with torch.no_grad():
    z = wrapper(dummy_pixel_values, dummy_waveforms)

output_shape = z.shape
print("Output shape:", output_shape)


Output shape: torch.Size([2, 1024])


In [42]:
# Verification: print devices and dtypes of model params and inputs
print('=== Model parameter devices & dtypes (first 20) ===')
count = 0
for count, (name, p) in enumerate(model.named_parameters(), start=1):
    print(name, p.device, p.dtype)
    if count >= 20:
        # Only the first 20 parameters are listed to keep the output short.
        break

# Check specific submodules: report the first parameter of each encoder
# backbone, when the attribute chain exists.
if hasattr(model, 'vision_encoder') and hasattr(model.vision_encoder, 'model'):
    vp = next(model.vision_encoder.model.parameters(), None)
    if vp is None:
        print('vision_encoder.model has no parameters')
    else:
        print('vision_encoder.model param ->', vp.device, vp.dtype)

if hasattr(model, 'audio_encoder') and hasattr(model.audio_encoder, 'model'):
    ap = next(model.audio_encoder.model.parameters(), None)
    if ap is None:
        print('audio_encoder.model has no parameters')
    else:
        print('audio_encoder.model param ->', ap.device, ap.dtype)

# Inputs
for input_label, input_tensor in (
    ('dummy_pixel_values ->', dummy_pixel_values),
    ('dummy_waveforms ->', dummy_waveforms),
):
    print(input_label, input_tensor.device, input_tensor.dtype)

# Also ensure wrapper/core moved to device
cp = next(wrapper.core.parameters(), None)
if cp is None:
    print('wrapper.core has no parameters')
else:
    print('wrapper.core param ->', cp.device, cp.dtype)


=== Model parameter devices & dtypes (first 20) ===
vision_encoder.model.embeddings.cls_token cpu torch.float32
vision_encoder.model.embeddings.mask_token cpu torch.float32
vision_encoder.model.embeddings.position_embeddings cpu torch.float32
vision_encoder.model.embeddings.patch_embeddings.projection.weight cpu torch.float32
vision_encoder.model.embeddings.patch_embeddings.projection.bias cpu torch.float32
vision_encoder.model.encoder.layer.0.norm1.weight cpu torch.float32
vision_encoder.model.encoder.layer.0.norm1.bias cpu torch.float32
vision_encoder.model.encoder.layer.0.attention.attention.query.weight cpu torch.float32
vision_encoder.model.encoder.layer.0.attention.attention.query.bias cpu torch.float32
vision_encoder.model.encoder.layer.0.attention.attention.key.weight cpu torch.float32
vision_encoder.model.encoder.layer.0.attention.attention.key.bias cpu torch.float32
vision_encoder.model.encoder.layer.0.attention.attention.value.weight cpu torch.float32
vision_encoder.model.en

In [45]:
# from torchview import draw_graph

# graph = draw_graph(
#     wrapper,
#     input_data=(dummy_pixel_values, dummy_waveforms),
#     graph_name="EdgeGlassAlignmentFull",
#     expand_nested=True,
#     depth=3,                 # increase to 4 for more detail
#     save_graph=True,
#     directory="arch_plots",
#     filename="edgeglass_alignment_full_graph",
# )

# # graph.visual_graph

### Step-3: - Alignment Training