## Main Notebook -> For Running Alignment (Stage-1 Training)

This notebook is used for alignment training of the models

In [12]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from imports.configs.config import setup_from_yaml
from imports.dataset import PixmoFeatureDataset, LibriSpeechFeatureDataset, collate_alignment
from imports.perceiver import PerceiverLatentEncoder, ProjectorMLP
from imports.align_training.text_encoder import HFTextEncoderConfig, HFTextEncoder
from imports.align_training.steps import AlignmentModules, AlignmentConfig
from imports.align_training.training import build_alignment_optimizer, train_alignment

In [14]:
cfg = setup_from_yaml("imports/configs/config.yaml")  # uses your Config + YAML loader
device = cfg.torch_device
dtype = cfg.torch_dtype
print("Device:", device, "dtype:", dtype)

[Config] Device: cuda, dtype: torch.float32
[Config] root_dir: /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v1_code_base/edge_glass
[Config] features_dir: /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v1_code_base/edge_glass/features
Device: cuda dtype: torch.float32


### 6.2 Build datasets & dataloaders

In [15]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(cfg.models.llm_model_name)

# Vision (PixMo) features
pixmo_train = PixmoFeatureDataset(cfg.datasets.pixmo_train_index)
pixmo_val   = PixmoFeatureDataset(cfg.datasets.pixmo_val_index)

# Optional small subsets for quick testing
if cfg.training.train_subset_size and cfg.training.train_subset_size < len(pixmo_train):
    from torch.utils.data import Subset
    pixmo_train = Subset(pixmo_train, range(cfg.training.train_subset_size))

if cfg.training.val_subset_size and cfg.training.val_subset_size < len(pixmo_val):
    from torch.utils.data import Subset
    pixmo_val = Subset(pixmo_val, range(cfg.training.val_subset_size))


[PixmoFeatureDataset] Loaded 872 valid entries from 873 total.
[PixmoFeatureDataset] Loaded 80 valid entries from 89 total.


In [16]:
from imports.dataset import PixmoFeatureDataset

ds_test = PixmoFeatureDataset(cfg.datasets.pixmo_train_index)
len(ds_test)  # should be <= original, but all entries valid now

# sanity check a few random samples
for i in [0, 10, 100]:
    if i >= len(ds_test):
        break
    rec = ds_test.index[i]
    print(i, rec["resolved_path"])


[PixmoFeatureDataset] Loaded 872 valid entries from 873 total.
0 data/data/pixmo/features/train_feat_64.pt
10 data/data/pixmo/features/train_feat_77.pt
100 data/data/pixmo/features/train_feat_20.pt


In [17]:
# The code snippet you provided is setting up data loaders for training and validation data using the `DataLoader` class. These data loaders are used to load batches of data for training a machine learning model.
# Nvidia

vision_train_loader = DataLoader(
    pixmo_train,
    batch_size=cfg.training.batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=lambda b: collate_alignment(b, tokenizer),
)

vision_val_loader = DataLoader(
    pixmo_val,
    batch_size=cfg.training.batch_size,
    shuffle=False,
    num_workers=4,
    collate_fn=lambda b: collate_alignment(b, tokenizer),
)



In [18]:
# from functools import partial

# collate_fn = partial(collate_alignment, tokenizer=tokenizer)

# vision_train_loader = DataLoader(
#     pixmo_train,
#     batch_size=cfg.training.batch_size,
#     shuffle=True,
#     num_workers=0,  # <-- IMPORTANT for now on macOS
#     collate_fn=collate_fn,
# )

# vision_val_loader = DataLoader(
#     pixmo_val,
#     batch_size=cfg.training.batch_size,
#     shuffle=False,
#     num_workers=0,  # <-- same here
#     collate_fn=collate_fn,
# )




In [19]:
train_loaders = {
    "vision": vision_train_loader,
    # "audio": audio_train_loader,  # add later if you want
}
val_loaders = {
    "vision": vision_val_loader,
}

### 6.3 Build text encoder

In [20]:
txt_cfg = HFTextEncoderConfig(
    model_name=cfg.models.llm_model_name,
    max_length=128,
    trainable=False,  # Stage-1: keep frozen
)

text_encoder = HFTextEncoder(
    cfg=txt_cfg,
    device=device,
    dtype=dtype,
)

d_text = text_encoder.hidden_size
print("Text hidden size:", d_text)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Text hidden size: 2048


In [21]:
def text_embed_fn(texts: list[str], max_length: int) -> torch.Tensor:
    # We could override max_length by rebuilding text_encoder, but usually
    # HFTextEncoderConfig.max_length is enough, so we ignore this arg.
    with torch.no_grad():
        return text_encoder.encode(texts).to(device)


### 6.4 Build adapters, Perceiver, projector

In [22]:
# Peek one example to get feature dim
sample = pixmo_train[0] if not isinstance(pixmo_train, torch.utils.data.Subset) else pixmo_train.dataset[pixmo_train.indices[0]]
d_feat_v = sample["features"].shape[-1]
print("Vision feature dim:", d_feat_v)


Vision feature dim: 1536


In [23]:
d_perceiver = cfg.architecture.perceiver_dim or d_feat_v
print("Perceiver dim:", d_perceiver)


Perceiver dim: 1536


In [24]:
from imports.model import MultiModalAlignmentModel  # or from current cell
from PIL import Image
import torch
from torchviz import make_dot

In [25]:
model = MultiModalAlignmentModel(
    d_shared=512,
    d_latent=512,
    d_align=1024,
    num_latents=32,   # smaller for viz
    num_layers=2,
    num_heads=4,
    use_perceiver=True,
    dtype=torch.float32,
    device=device,
)
model.eval()


MultiModalAlignmentModel(
  (vision_encoder): VisionEncoder(
    (model): Dinov2Model(
      (embeddings): Dinov2Embeddings(
        (patch_embeddings): Dinov2PatchEmbeddings(
          (projection): Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14))
        )
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (encoder): Dinov2Encoder(
        (layer): ModuleList(
          (0-11): 12 x Dinov2Layer(
            (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
            (attention): Dinov2Attention(
              (attention): Dinov2SelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
              )
              (output): Dinov2SelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.0, inplace

In [None]:
# vision_adapter = nn.Linear(d_feat_v, d_perceiver).to(device=device, dtype=dtype)

# perceiver = PerceiverLatentEncoder(
#     num_latents=cfg.architecture.num_latents,
#     d_latent=d_perceiver,
#     d_input=d_perceiver,
#     num_layers=cfg.architecture.num_perceiver_layers,
#     num_heads=cfg.architecture.num_attn_heads,
#     mlp_ratio=cfg.architecture.mlp_ratio,
#     dropout=0.1,
# ).to(device=device, dtype=dtype)

# projector = ProjectorMLP(
#     d_in=d_perceiver,
#     d_out=d_text,
#     hidden_factor=2.0,
#     dropout=0.1,
# ).to(device=device, dtype=dtype)

# modules = AlignmentModules(
#     vision_adapter=vision_adapter,
#     audio_adapter=None,  # add audio later if needed
#     perceiver=perceiver,
#     projector=projector,
# )


### 6.5 Alignment config + optimizer

In [None]:
# MRL radii from config; default to full dim if None
mrl_dims = tuple(cfg.mrl.mrl_dims) if cfg.mrl.mrl_dims is not None else (d_text,)

align_cfg = AlignmentConfig(
    mrl_dims=mrl_dims,
    mrl_temperature=cfg.mrl.mrl_temp,
    max_text_length=64,  # arbitrary; text_encoder already truncates internally
)

optimizer = build_alignment_optimizer(
    modules=modules,
    learning_rate=float(cfg.training.learning_rate),
    weight_decay=float(cfg.training.weight_decay),
)
print("Trainable params:", sum(p.numel() for p in optimizer.param_groups[0]["params"]))


### 6.6 Run training

In [None]:
import wandb
from dataclasses import asdict
from typing import Dict

if cfg.misc.use_wandb:
    wandb.init(
        project=cfg.misc.wandb_project,
        name=cfg.misc.wandb_run_name,
        config=asdict(cfg),
    )


In [None]:
print("MRL dims:", cfg.mrl.mrl_dims)
print("Text dim:", d_text)

mean_epochs = cfg.training.num_epochs

def wb_log_fn(stats: Dict[str, float]):
    wandb.log(stats)


In [None]:
train_alignment(
    train_loaders=train_loaders,
    val_loaders=val_loaders,
    modules=modules,
    cfg=align_cfg,
    text_embed_fn=text_embed_fn,
    optimizer=optimizer,
    device=device,
    num_epochs=mean_epochs,
    log_every=cfg.training.log_every_steps,
    log_fn=wb_log_fn,
    modalities=("vision",),
)

In [None]:
# mean_epochs = cfg.training.num_epochs

# train_alignment(
#     train_loaders=train_loaders,
#     modules=modules,
#     cfg=align_cfg,
#     text_embed_fn=text_embed_fn,
#     optimizer=optimizer,
#     device=device,
#     num_epochs=mean_epochs,
#     log_every=cfg.training.log_every_steps,
#     log_fn=None,
#     modalities=("vision",),
# )


### Dataset Test

In [None]:
# from imports.dataset import PixmoFeatureDataset
# import torch

# def test_pixmo_dataset(index_path, num_samples=32):
#     print("=== Testing PixMo Dataset ===")
#     ds = PixmoFeatureDataset(index_path)
    
#     print(f"Total valid samples: {len(ds)}\n")
    
#     num_checked = min(num_samples, len(ds))
#     print(f"Checking first {num_checked} samples...\n")

#     for i in range(num_checked):
#         rec = ds[i]
#         feats = rec["features"]          # Tensor (T, D_feat)
#         caption = rec["text"]
#         mod = rec["modality"]
#         file = rec["file"]

#         print(f"--- Sample {i} ---")
#         print("file:", file)
#         print("caption:", caption)
#         print("modality:", mod)
#         print("features shape:", tuple(feats.shape))

#         # NaN / Inf checks
#         if not torch.isfinite(feats).all():
#             print("❌ NON-FINITE VALUES FOUND in features! NaNs:", 
#                   torch.isnan(feats).sum().item(),
#                   "Infs:", torch.isinf(feats).sum().item())
#         else:
#             print("✓ features finite")

#         # Basic statistics
#         print("feature mean/std:", feats.mean().item(), feats.std().item())
        
#         # Check if caption is empty
#         if caption is None or caption.strip() == "":
#             print("⚠️ Empty caption!")
#         else:
#             print("✓ caption OK")
        
#         print("----------------------------------")

#     print("\nDataset test completed.\n")


# # Run test
# test_pixmo_dataset(cfg.datasets.pixmo_train_index)


In [None]:
# def scan_entire_dataset(index_path):
#     ds = PixmoFeatureDataset(index_path)
    
#     nan_files = []
#     inf_files = []
    
#     for i in range(len(ds)):
#         rec = ds.index[i]
#         path = rec["resolved_path"]
#         blob = torch.load(path, map_location="cpu")
#         feats = blob["features"]
        
#         if torch.isnan(feats).any():
#             nan_files.append(path)
#         if torch.isinf(feats).any():
#             inf_files.append(path)

#     print("=== FULL SCAN RESULTS ===")
#     print("Files with NaNs:", len(nan_files))
#     print("Files with Infs:", len(inf_files))

#     if nan_files:
#         print("NaN-containing files:")
#         for f in nan_files[:15]:
#             print("  ", f)
#     if inf_files:
#         print("Inf-containing files:")
#         for f in inf_files[:15]:
#             print("  ", f)

#     return nan_files, inf_files

# nan_list, inf_list = scan_entire_dataset(cfg.datasets.pixmo_train_index)


In [None]:
from imports.dataset import PixmoFeatureDataset
import torch

def test_pixmo_dataset(index_path, num_samples=32):
    print("=== Testing PixMo Dataset ===")
    ds = PixmoFeatureDataset(index_path)
    
    print(f"Total valid samples (after path filtering): {len(ds)}\n")
    
    num_checked = min(num_samples, len(ds))
    print(f"Checking first {num_checked} samples...\n")

    for i in range(num_checked):
        ex = ds[i]
        feats = ex["features"]          # Tensor (T, D_feat)
        text  = ex["text"]
        file  = ex["file"]

        print(f"--- Sample {i} ---")
        print("file:", file)
        print("caption:", repr(text))
        print("features shape:", tuple(feats.shape))

        # 1) zero-length check
        if feats.shape[0] == 0:
            print("❌ zero-length features!")

        # 2) finite check
        finite = torch.isfinite(feats).all()
        print("finite:", bool(finite))

        # 3) basic stats
        print("mean/std:", feats.mean().item(), feats.std().item())
        print("----------------------------------")

    print("\nDataset test completed.\n")

test_pixmo_dataset(cfg.datasets.pixmo_train_index)


In [None]:
import torch

batch = next(iter(vision_train_loader))
print("Batch keys:", batch.keys())

# Move tensors to device
features = batch["features"].to(device)        # (B, T, D_feat)
feat_mask = batch["feature_mask"].to(device)   # (B, T)
texts    = batch["raw_text"]

print("features shape:", features.shape)
print("mask shape:", feat_mask.shape)
print("any zero-length seqs? lengths:", feat_mask.sum(dim=1))

with torch.no_grad():
    # 1) Adapter
    tokens = modules.vision_adapter(features)   # (B, T, d_perceiver)
    print("tokens finite:", bool(torch.isfinite(tokens).all()))
    print("tokens mean/std:",
          tokens.mean().item(), tokens.std().item())

    # 2) Perceiver
    latents = modules.perceiver(tokens, encoder_mask=feat_mask)  # (B, L, d_perceiver)
    print("latents finite:", bool(torch.isfinite(latents).all()))
    print("latents mean/std:",
          latents.mean().item(), latents.std().item())

    # 3) Projector
    lat_llm = modules.projector(latents)        # (B, L, d_text)
    print("lat_llm finite:", bool(torch.isfinite(lat_llm).all()))
    print("lat_llm mean/std:",
          lat_llm.mean().item(), lat_llm.std().item())

    # 4) Pooled vision embedding
    h_mod = lat_llm.mean(dim=1)                 # (B, d_text)
    print("h_mod finite:", bool(torch.isfinite(h_mod).all()))
    print("h_mod mean/std:",
          h_mod.mean().item(), h_mod.std().item())


In [None]:
import torch

# Grab one batch from the train loader
batch = next(iter(vision_train_loader))
features = batch["features"].to(device)        # (B, 256, 1536)
feat_mask = batch["feature_mask"].to(device)   # (B, 256)
texts    = batch["raw_text"]

print("=== Batch Sanity ===")
print("features shape:", features.shape)
print("mask shape:", feat_mask.shape)
print("seq lengths:", feat_mask.sum(dim=1))

print("features finite:", bool(torch.isfinite(features).all()))
print("features mean/std:",
      features.mean().item(), features.std().item())

# Now just test the adapter alone
with torch.no_grad():
    va = modules.vision_adapter
    print("\n=== Vision Adapter Params ===")
    for name, p in va.named_parameters():
        print(
            name, p.shape,
            "finite:", bool(torch.isfinite(p).all()),
            "mean/std:", p.mean().item(), p.std().item(),
        )

    print("\n=== Forward through vision_adapter ===")
    tokens = va(features)        # (B, 256, d_perceiver)
    print("tokens shape:", tokens.shape)
    print("tokens finite:", bool(torch.isfinite(tokens).all()))
    print("tokens mean/std:",
          tokens.mean().item(), tokens.std().item())
