## Main Notebook -> For Running Alignment (Stage-1 Training)

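This notebook runs Stage-1 alignment training: precomputed vision (and, optionally, audio) features are passed through an adapter, a Perceiver latent encoder, and a projector, and these modules are trained against embeddings from a frozen text encoder.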

In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to imported project
# modules are picked up before each cell runs without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from functools import partial

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset

from imports.configs.config import setup_from_yaml
from imports.dataset import PixmoFeatureDataset, LibriSpeechFeatureDataset, collate_alignment
from imports.perceiver import PerceiverLatentEncoder, ProjectorMLP
from imports.align_training.text_encoder import HFTextEncoderConfig, HFTextEncoder
from imports.align_training.steps import AlignmentModules, AlignmentConfig
from imports.align_training.training import build_alignment_optimizer, train_alignment

In [None]:
# Load the project configuration from YAML (uses the project's Config + YAML loader)
# and pull out the torch device / dtype that every later cell moves modules onto.
cfg = setup_from_yaml("imports/configs/config.yaml")
device, dtype = cfg.torch_device, cfg.torch_dtype
print("Device:", device, "dtype:", dtype)

### 6.2 Build datasets & dataloaders

In [None]:
from transformers import AutoTokenizer


def _maybe_subset(dataset, limit):
    """Return the first ``limit`` items of ``dataset``, or ``dataset`` itself.

    Args:
        dataset: Any sized, indexable dataset.
        limit: Maximum number of leading samples to keep. A falsy value
            (``None`` / ``0``) or a value not smaller than ``len(dataset)``
            leaves the dataset untouched.

    Returns:
        A ``torch.utils.data.Subset`` over ``range(limit)`` when truncation
        applies, otherwise the original ``dataset`` object.
    """
    if limit and limit < len(dataset):
        return Subset(dataset, range(limit))
    return dataset


# Tokenizer for the LLM named in the config; bound into the collate functions below.
tokenizer = AutoTokenizer.from_pretrained(cfg.models.llm_model_name)

# Vision (PixMo) features, optionally truncated to small subsets for quick testing.
pixmo_train = _maybe_subset(
    PixmoFeatureDataset(cfg.datasets.pixmo_train_index),
    cfg.training.train_subset_size,
)
pixmo_val = _maybe_subset(
    PixmoFeatureDataset(cfg.datasets.pixmo_val_index),
    cfg.training.val_subset_size,
)


In [None]:
## Nvidia
# Multi-worker loaders (presumably intended for the NVIDIA/CUDA training machine).
#
# The collate function is a functools.partial over the imported `collate_alignment`
# instead of a lambda: with num_workers > 0 the DataLoader has to pickle
# `collate_fn` whenever worker processes are started with the "spawn" or
# "forkserver" methods, and a lambda cannot be pickled.
vision_collate_fn = partial(collate_alignment, tokenizer=tokenizer)

vision_train_loader = DataLoader(
    pixmo_train,
    batch_size=cfg.training.batch_size,
    shuffle=True,  # reshuffle training samples every epoch
    num_workers=4,
    collate_fn=vision_collate_fn,
)

vision_val_loader = DataLoader(
    pixmo_val,
    batch_size=cfg.training.batch_size,
    shuffle=False,  # keep validation order fixed
    num_workers=4,
    collate_fn=vision_collate_fn,
)



In [None]:
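# Alternative vision loaders for macOS (single-process loading, num_workers=0).
# Kept commented out for reference; uncomment to use instead of the loaders above.
# from functools import partial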

# collate_fn = partial(collate_alignment, tokenizer=tokenizer)

# vision_train_loader = DataLoader(
#     pixmo_train,
#     batch_size=cfg.training.batch_size,
#     shuffle=True,
#     num_workers=0,  # <-- IMPORTANT for now on macOS
#     collate_fn=collate_fn,
# )

# vision_val_loader = DataLoader(
#     pixmo_val,
#     batch_size=cfg.training.batch_size,
#     shuffle=False,
#     num_workers=0,  # <-- same here
#     collate_fn=collate_fn,
# )




In [None]:
# Loaders handed to the training loop, keyed by modality name.
# Only vision is registered here; an "audio" entry can be added later if wanted.
train_loaders = dict(vision=vision_train_loader)


In [None]:
# Optional audio (LibriSpeech) loader.
# Both names are always defined, so later cells can check for `None` instead of
# depending on whether this branch ran (or on values left over from an earlier run).
audio_train = None
audio_train_loader = None

# Only if you actually set librispeech_train_index in config.yaml
if cfg.datasets.use_librispeech and cfg.datasets.librispeech_train_index is not None:
    print("Loading LibriSpeech feature dataset from:", cfg.datasets.librispeech_train_index)
    audio_train = LibriSpeechFeatureDataset(cfg.datasets.librispeech_train_index)

    # Optional: subset for quick debugging
    if cfg.training.train_subset_size and cfg.training.train_subset_size < len(audio_train):
        audio_train = Subset(audio_train, range(cfg.training.train_subset_size))

    # Reuse the same collate fn as vision, with tokenizer bound
    audio_collate_fn = partial(collate_alignment, tokenizer=tokenizer)

    audio_train_loader = DataLoader(
        audio_train,
        batch_size=cfg.training.batch_size,
        shuffle=True,
        num_workers=0,             # IMPORTANT on macOS; avoids pickling issues
        collate_fn=audio_collate_fn,
    )

    train_loaders["audio"] = audio_train_loader
    print("Audio train samples:", len(audio_train))
else:
    # Drop any "audio" entry left behind by a previous run of this cell so the
    # loader dict always reflects the current config.
    train_loaders.pop("audio", None)
    print("LibriSpeech not enabled or librispeech_train_index missing in config.")


### 6.3 Build text encoder

In [None]:
# Text tower configuration: same model name as the LLM in the config,
# inputs capped at 128 tokens, weights frozen for Stage-1.
txt_cfg = HFTextEncoderConfig(model_name=cfg.models.llm_model_name, max_length=128, trainable=False)

# Instantiate the encoder on the configured device / dtype.
text_encoder = HFTextEncoder(cfg=txt_cfg, device=device, dtype=dtype)

# The projector built later maps into this hidden size.
d_text = text_encoder.hidden_size
print("Text hidden size:", d_text)


In [None]:
@torch.no_grad()
def text_embed_fn(texts: list[str], max_length: int) -> torch.Tensor:
    """Encode ``texts`` with the notebook's ``text_encoder`` without tracking gradients.

    Args:
        texts: Strings to embed.
        max_length: Accepted for interface compatibility but ignored; the
            length limit in effect is the one ``text_encoder`` was built with.

    Returns:
        The encoder output for ``texts``, moved to ``device``.
    """
    embeddings = text_encoder.encode(texts)
    return embeddings.to(device)


### 6.4 Build adapters, Perceiver, projector

In [None]:
# Read a single vision example to discover the feature width (last axis of "features").
# When the training set was wrapped in a Subset, index the underlying dataset directly.
if isinstance(pixmo_train, torch.utils.data.Subset):
    sample = pixmo_train.dataset[pixmo_train.indices[0]]
else:
    sample = pixmo_train[0]

d_feat_v = sample["features"].shape[-1]
print("Vision feature dim:", d_feat_v)


In [None]:
# Read a single audio example to discover the audio feature width.
# `audio_train` only holds a dataset when the LibriSpeech loader was built, so it is
# looked up defensively: a missing or None value means "audio disabled" and
# yields d_feat_a = None instead of a NameError on vision-only runs.
_audio_ds = globals().get("audio_train")

if _audio_ds is None:
    d_feat_a = None
    print("Audio feature dim: n/a (no LibriSpeech dataset loaded)")
else:
    sample = _audio_ds[0] if not isinstance(_audio_ds, torch.utils.data.Subset) else _audio_ds.dataset[_audio_ds.indices[0]]
    d_feat_a = sample["features"].shape[-1]
    print("Audio feature dim:", d_feat_a)

In [None]:
# Perceiver width: take the configured value, falling back to the vision
# feature width when the config leaves it unset (None / 0).
d_perceiver = cfg.architecture.perceiver_dim
if not d_perceiver:
    d_perceiver = d_feat_v
print("Perceiver dim:", d_perceiver)


In [None]:
# Linear adapter from the vision feature width to the Perceiver width.
vision_adapter = nn.Linear(d_feat_v, d_perceiver).to(device=device, dtype=dtype)

# The audio adapter needs the audio feature width, which is only known when audio
# features were loaded. Treat an undefined or None `d_feat_a` as "audio disabled"
# so vision-only runs do not fail while building an adapter they never use.
try:
    _audio_in_dim = d_feat_a
except NameError:
    _audio_in_dim = None

audio_adapter = (
    nn.Linear(_audio_in_dim, d_perceiver).to(device=device, dtype=dtype)
    if _audio_in_dim is not None
    else None
)

# Perceiver latent encoder; input and latent widths are both d_perceiver.
perceiver = PerceiverLatentEncoder(
    num_latents=cfg.architecture.num_latents,
    d_latent=d_perceiver,
    d_input=d_perceiver,
    num_layers=cfg.architecture.num_perceiver_layers,
    num_heads=cfg.architecture.num_attn_heads,
    mlp_ratio=cfg.architecture.mlp_ratio,
    dropout=0.1,
).to(device=device, dtype=dtype)

# Projector from the Perceiver width into the text encoder's hidden size.
projector = ProjectorMLP(
    d_in=d_perceiver,
    d_out=d_text,
    hidden_factor=2.0,
    dropout=0.1,
).to(device=device, dtype=dtype)

# Bundle handed to the optimizer builder and the training loop.
# The audio adapter is deliberately left out for now (vision-only Stage-1).
modules = AlignmentModules(
    vision_adapter=vision_adapter,
    audio_adapter=None,  # add audio later if needed
    perceiver=perceiver,
    projector=projector,
)


In [None]:
# The adapter, Perceiver and projector are built once, in the previous cell.
# Rebuilding them here with literal `...` placeholder arguments would pass
# Ellipsis to the constructors and replace the configured modules (or raise),
# so this cell only reports what was built.
for _name, _module in (
    ("vision_adapter", vision_adapter),
    ("perceiver", perceiver),
    ("projector", projector),
):
    _n_params = sum(p.numel() for p in _module.parameters())
    print(f"{_name}: {type(_module).__name__} ({_n_params:,} params)")


### 6.5 Alignment config + optimizer

In [None]:
# MRL radii from config; default to the full text dimension if none are configured.
_configured_mrl_dims = cfg.mrl.mrl_dims
mrl_dims = tuple(_configured_mrl_dims) if _configured_mrl_dims is not None else (d_text,)

align_cfg = AlignmentConfig(
    mrl_dims=mrl_dims,
    mrl_temperature=cfg.mrl.mrl_temp,
    max_text_length=64,  # arbitrary; text_encoder already truncates internally
)

# float(...) guards against the YAML loader returning these values as strings
# (e.g. "1e-4") — presumably the reason for the explicit casts; confirm against the loader.
optimizer = build_alignment_optimizer(
    modules=modules,
    learning_rate=float(cfg.training.learning_rate),
    weight_decay=float(cfg.training.weight_decay),
)

# Count parameters across every optimizer param group, not just the first one,
# so the figure stays correct if the builder creates several groups.
n_optimized_params = sum(
    p.numel() for group in optimizer.param_groups for p in group["params"]
)
print("Trainable params:", n_optimized_params)


### 6.6 Run training

In [None]:
# mean_epochs = cfg.training.num_epochs

# train_alignment(
#     train_loaders=train_loaders,
#     modules=modules,
#     cfg=align_cfg,
#     text_embed_fn=text_embed_fn,
#     optimizer=optimizer,
#     device=device,
#     num_epochs=mean_epochs,
#     log_every=cfg.training.log_every_steps,
#     log_fn=None,   # or pass wandb.log
#     modalities=("vision",),  # add "audio" when you wire audio
# )


In [None]:
# Number of passes over the training loaders, taken from the config.
mean_epochs = cfg.training.num_epochs

# Launch Stage-1 alignment training on the vision modality only.
train_alignment(
    modules=modules,
    optimizer=optimizer,
    cfg=align_cfg,
    train_loaders=train_loaders,
    text_embed_fn=text_embed_fn,
    modalities=("vision",),
    num_epochs=mean_epochs,
    device=device,
    log_every=cfg.training.log_every_steps,
    log_fn=None,  # no external logger attached
)
