
# 01 — Vision–Text Alignment on PixMo-Cap (Phase 1, Improved)

This notebook trains a **vision–text alignment model** on a **local PixMo-Cap Parquet subset**
(created by `00_build_alignment_datasets.ipynb`).

It incorporates the key Phase‑1 features from your original code:

- Uses the same **`AlignmentConfig` + `VisionTextAligner`** (CLIP-style model)
- Uses **Matryoshka (MRL) + CLIP** contrastive losses from `core.py`
- Adds a **warmup + cosine learning rate schedule** (Phase‑1 improvement)
- Uses **in‑memory image–text dataset** for fast training
- Evaluates **image ↔ text retrieval** on a validation split each epoch
- Saves **`last`** and **`best`** checkpoints into a common directory:
  - `artifacts/phase1_alignment/vision_text/`
- Optionally resumes from an existing checkpoint
- Logs training & validation metrics to **Weights & Biases (W&B)**
- Supports **multi‑GPU** via `torch.nn.DataParallel` (1–2 GPUs)

These checkpoints will be the **Phase‑1 encoder** artifacts used by Phase‑2 experiments
(normal LLM decoder, TRM decoder, MoE decoder).


In [1]:
%load_ext autoreload
%autoreload 2

In [2]:

import os
import math
import time
from dataclasses import asdict
from pathlib import Path

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import LambdaLR

from datasets import load_dataset
import wandb

# Local modules
from imports.core import AlignmentConfig, VisionTextAligner, compute_retrieval_metrics, set_seed, get_device
from imports.in_memory_datasets import InMemoryImageTextDataset, collate_in_memory_images
from imports.train import save_checkpoint, load_checkpoint

# ---- Paths ----
PROJECT_ROOT = Path.cwd()
DATA_DIR = PROJECT_ROOT / "data" / "alignment_subsets"
ARTIFACTS_DIR = PROJECT_ROOT / "artifacts"
PHASE1_DIR = ARTIFACTS_DIR / "phase1_alignment"
VISION_TEXT_DIR = PHASE1_DIR / "vision_text"

VISION_TEXT_DIR.mkdir(parents=True, exist_ok=True)

print(f"Project root      : {PROJECT_ROOT}")
print(f"Data dir          : {DATA_DIR}")
print(f"Phase 1 dir       : {PHASE1_DIR}")
print(f"Vision-text dir   : {VISION_TEXT_DIR}")




Project root      : /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base
Data dir          : /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/data/alignment_subsets
Phase 1 dir       : /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/artifacts/phase1_alignment
Vision-text dir   : /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/v2_code_base/artifacts/phase1_alignment/vision_text


In [3]:

# =========================
# Dataset & training config
# =========================

class VTDataConfig:
    'Config for PixMo-Cap subset used in Phase-1 alignment.'
    # Path created by 00_build_alignment_datasets.ipynb
    pixmocap_parquet = DATA_DIR / 'pixmocap_train_subset_50000.parquet'  # adjust if needed

    # Subsampling and split
    max_train_samples = 40_000
    max_val_samples = 5_000
    val_ratio = 0.1

    # DataLoader
    num_workers = 8
    image_size = (224, 224)


class VTTrainConfig:
    'High-level training hyperparameters for Phase-1.'
    seed = 42
    batch_size = 32
    num_epochs = 5
    lr = 1e-4
    weight_decay = 0.01
    max_grad_norm = 1.0
    use_amp = True

    temperature = 0.07
    mrl_weight = 1.0
    clip_weight = 0.5

    # LR scheduler: warmup fraction of total steps
    warmup_fraction = 0.1

    # Logging
    use_wandb = True
    wandb_project = 'edgeglass_phase1'
    wandb_entity = None  # or your W&B username/org
    wandb_group = 'vision_text_alignment'
    wandb_run_name = 'vt_align_pixmocap_phase1'

    # Checkpointing / resume
    resume_from = None  # e.g., VISION_TEXT_DIR / 'vision_text_best.pt'
    save_every_epochs = 1


data_cfg = VTDataConfig()
train_cfg = VTTrainConfig()


In [4]:
print("VTDataConfig:")
for k, v in data_cfg.__dict__.items():
    print(f"  {k}: {v}")
print("\nVTTrainConfig:")
for k, v in train_cfg.__dict__.items():
    print(f"  {k}: {v}")


VTDataConfig:

VTTrainConfig:


In [6]:

# =========================
# AlignmentConfig (model)
# =========================

# You can adjust these to match your original Phase-1 experiments.
align_cfg = AlignmentConfig(
    # Keep your chosen encoders here (examples):
    # vision_model_name='openai/clip-vit-base-patch32',
    # text_model_name='sentence-transformers/all-MiniLM-L6-v2',
    batch_size=train_cfg.batch_size,
    num_epochs=train_cfg.num_epochs,
    learning_rate=train_cfg.lr,
    weight_decay=train_cfg.weight_decay,
    warmup_ratio=train_cfg.warmup_fraction,
    mrl_temperature=train_cfg.temperature,
    clip_temperature=train_cfg.temperature,
    # Temperature & MRL settings from your original experiments:
    mrl_dims=(64, 128, 256, 512),
    mrl_weight=train_cfg.mrl_weight,
    clip_weight=train_cfg.clip_weight,
)

print('AlignmentConfig:')
for k, v in asdict(align_cfg).items():
    print(f'  {k}: {v}')


TypeError: AlignmentConfig.__init__() got an unexpected keyword argument 'temperature'

In [None]:

# =========================
# Weights & Biases init
# =========================

run = None
if train_cfg.use_wandb:
    def _to_serializable(v):
        if isinstance(v, Path):
            return str(v)
        if 'torch' in globals() and (isinstance(v, torch.device) or isinstance(v, torch.dtype)):
            return str(v)
        return v

    wandb_kwargs = dict(
        project=train_cfg.wandb_project,
        name=train_cfg.wandb_run_name,
        group=train_cfg.wandb_group,
        config={
            'phase': 'phase1_alignment',
            'task': 'vision_text',
            'data_cfg': {k: _to_serializable(v) for k, v in data_cfg.__dict__.items()},
            'align_cfg': {k: _to_serializable(v) for k, v in asdict(align_cfg).items()},
            'train_cfg': {k: _to_serializable(v) for k, v in train_cfg.__dict__.items()},
        },
    )
    if train_cfg.wandb_entity is not None:
        wandb_kwargs['entity'] = train_cfg.wandb_entity

    run = wandb.init(**wandb_kwargs)
    print('✅ W&B run initialized:', run.name)
else:
    print('W&B logging disabled.')


In [None]:

# =========================
# Load PixMo-Cap subset from Parquet
# =========================

assert data_cfg.pixmocap_parquet.exists(), f"Parquet file not found: {data_cfg.pixmocap_parquet}"

print(f"📥 Loading PixMo-Cap subset from: {data_cfg.pixmocap_parquet}")
pixmo_ds = load_dataset(
    "parquet",
    data_files={"train": str(data_cfg.pixmocap_parquet)},
)["train"]

print(f"Total PixMo-Cap subset size: {len(pixmo_ds):,}")

# Train/val split
split = pixmo_ds.train_test_split(test_size=data_cfg.val_ratio, seed=train_cfg.seed)
train_ds = split["train"]
val_ds = split["test"]

if data_cfg.max_train_samples is not None and data_cfg.max_train_samples < len(train_ds):
    train_ds = train_ds.shuffle(seed=train_cfg.seed).select(range(data_cfg.max_train_samples))

if data_cfg.max_val_samples is not None and data_cfg.max_val_samples < len(val_ds):
    val_ds = val_ds.shuffle(seed=train_cfg.seed).select(range(data_cfg.max_val_samples))

print(f"Train split: {len(train_ds):,} samples")
print(f"Val split  : {len(val_ds):,} samples")


In [None]:

# =========================
# Build In-Memory Datasets & DataLoaders
# =========================

train_mem = InMemoryImageTextDataset(
    hf_dataset=train_ds,
    img_col="image_url",
    txt_col="caption",
    max_samples=None,  # already subselected
    image_size=data_cfg.image_size,
    num_workers=data_cfg.num_workers,
)

val_mem = InMemoryImageTextDataset(
    hf_dataset=val_ds,
    img_col="image_url",
    txt_col="caption",
    max_samples=None,
    image_size=data_cfg.image_size,
    num_workers=data_cfg.num_workers,
)

train_loader = DataLoader(
    train_mem,
    batch_size=train_cfg.batch_size,
    shuffle=True,
    num_workers=data_cfg.num_workers,
    collate_fn=collate_in_memory_images,
    pin_memory=True,
)

val_loader = DataLoader(
    val_mem,
    batch_size=train_cfg.batch_size,
    shuffle=False,
    num_workers=data_cfg.num_workers,
    collate_fn=collate_in_memory_images,
    pin_memory=True,
)

print("Train batches:", len(train_loader))
print("Val batches  :", len(val_loader))


In [None]:

# =========================
# Model, optimizer, scheduler, device
# =========================

set_seed(train_cfg.seed)
device = get_device()
align_cfg.device = device
align_cfg.dtype = torch.float16 if train_cfg.use_amp and device.type == 'cuda' else torch.float32
print('Using device:', device)
print('Using dtype:', align_cfg.dtype)

model = VisionTextAligner(align_cfg)
model.to(device)

# Multi-GPU (DataParallel) if 2 GPUs are visible
if torch.cuda.device_count() > 1:
    print(f'✅ Using DataParallel on {torch.cuda.device_count()} GPUs')
    model = nn.DataParallel(model)
else:
    print('Using single GPU or CPU.')

vt_model = model.module if isinstance(model, nn.DataParallel) else model

optimizer = AdamW(
    vt_model.get_trainable_params(),
    lr=train_cfg.lr,
    weight_decay=train_cfg.weight_decay,
)

# ----- Warmup + cosine scheduler -----
num_training_steps = train_cfg.num_epochs * len(train_loader)
warmup_steps = int(train_cfg.warmup_fraction * num_training_steps)

def lr_lambda(step: int):
    if step < warmup_steps:
        return float(step) / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, num_training_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

scheduler = LambdaLR(optimizer, lr_lambda)

scaler = GradScaler(enabled=train_cfg.use_amp)

start_epoch = 0
global_step = 0

# Optionally resume
if train_cfg.resume_from is not None and Path(train_cfg.resume_from).exists():
    print(f'🔄 Resuming from checkpoint: {train_cfg.resume_from}')
    start_epoch = load_checkpoint(train_cfg.resume_from, vt_model, optimizer=optimizer, load_optimizer=True)
    # Note: scheduler/global_step are not restored here; adjust if you save them in the future.
else:
    print('No resume checkpoint specified.')


In [None]:

# =========================
# Validation / retrieval evaluation
# =========================

@torch.no_grad()
def evaluate_retrieval(model, loader, device):
    model.eval()
    vt_model = model.module if isinstance(model, nn.DataParallel) else model

    all_vision = []
    all_text = []

    for batch in loader:
        images = batch["images"]
        captions = batch["captions"]

        z_v = vt_model.encode_vision(images)
        z_t = vt_model.encode_text(captions)

        all_vision.append(z_v.cpu())
        all_text.append(z_t.cpu())

    z_v_all = torch.cat(all_vision, dim=0)
    z_t_all = torch.cat(all_text, dim=0)

    metrics_i2t = compute_retrieval_metrics(z_v_all, z_t_all)
    metrics_t2i = compute_retrieval_metrics(z_t_all, z_v_all)

    metrics = {}
    for k, v in metrics_i2t.items():
        metrics[f"i2t_{k}"] = v
    for k, v in metrics_t2i.items():
        metrics[f"t2i_{k}"] = v

    return metrics


In [None]:

# =========================
# Training loop (Phase-1)
# =========================

best_r1 = 0.0

print("Trainable parameter count:", vt_model.count_trainable_params())

for epoch in range(start_epoch, train_cfg.num_epochs):
    model.train()
    epoch_loss = 0.0
    epoch_mrl = 0.0
    epoch_clip = 0.0
    num_steps = 0

    t0 = time.time()

    for step, batch in enumerate(train_loader):
        images = batch["images"]
        captions = batch["captions"]

        optimizer.zero_grad(set_to_none=True)

        with autocast(device_type=device.type if device.type != "mps" else "cuda",
                      enabled=train_cfg.use_amp):
            outputs = model(images=images, texts=captions)
            loss = outputs["loss"]
            loss_mrl = outputs.get("loss_mrl", torch.tensor(0.0, device=device))
            loss_clip = outputs.get("loss_clip", torch.tensor(0.0, device=device))

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(vt_model.get_trainable_params(), train_cfg.max_grad_norm)
        scaler.step(optimizer)
        scaler.update()

        scheduler.step()
        global_step += 1

        epoch_loss += loss.item()
        epoch_mrl += loss_mrl.item()
        epoch_clip += loss_clip.item()
        num_steps += 1

        if train_cfg.use_wandb and step % 10 == 0:
            wandb.log({
                "train/loss": loss.item(),
                "train/loss_mrl": loss_mrl.item(),
                "train/loss_clip": loss_clip.item(),
                "train/lr": optimizer.param_groups[0]["lr"],
                "train/epoch_progress": epoch + (step + 1) / len(train_loader),
                "global_step": global_step,
            })

        if step % 50 == 0:
            print(f"[Epoch {epoch+1}/{train_cfg.num_epochs}] "
                  f"Step {step}/{len(train_loader)} | "
                  f"Loss: {loss.item():.4f} | MRL: {loss_mrl.item():.4f} | CLIP: {loss_clip.item():.4f} | "
                  f"LR: {optimizer.param_groups[0]['lr']:.2e}")

    epoch_time = time.time() - t0
    avg_loss = epoch_loss / max(num_steps, 1)
    avg_mrl = epoch_mrl / max(num_steps, 1)
    avg_clip = epoch_clip / max(num_steps, 1)

    print(f"\n=== Epoch {epoch+1} finished in {epoch_time/60:.2f} min ===")
    print(f"Train avg loss: {avg_loss:.4f} | MRL: {avg_mrl:.4f} | CLIP: {avg_clip:.4f}")

    # Validation
    val_metrics = evaluate_retrieval(model, val_loader, device)
    print("Val retrieval metrics:")
    for k, v in val_metrics.items():
        print(f"  {k}: {v:.4f}")

    r1_mean = 0.5 * (val_metrics["i2t_r@1"] + val_metrics["t2i_r@1"])

    if train_cfg.use_wandb:
        wandb.log({
            "epoch": epoch + 1,
            "train/avg_loss": avg_loss,
            "train/avg_loss_mrl": avg_mrl,
            "train/avg_loss_clip": avg_clip,
            "val/i2t_r@1": val_metrics["i2t_r@1"],
            "val/i2t_r@5": val_metrics["i2t_r@5"],
            "val/i2t_r@10": val_metrics["i2t_r@10"],
            "val/t2i_r@1": val_metrics["t2i_r@1"],
            "val/t2i_r@5": val_metrics["t2i_r@5"],
            "val/t2i_r@10": val_metrics["t2i_r@10"],
            "val/r1_mean": r1_mean,
        })

    # Save last checkpoint
    save_checkpoint(
        vt_model,
        optimizer,
        epoch=epoch + 1,
        save_dir=str(VISION_TEXT_DIR),
        name="vision_text_last",
    )

    # Save best checkpoint
    if r1_mean > best_r1:
        best_r1 = r1_mean
        print(f"✅ New best mean R@1: {best_r1:.4f} — saving best checkpoint")
        save_checkpoint(
            vt_model,
            optimizer,
            epoch=epoch + 1,
            save_dir=str(VISION_TEXT_DIR),
            name="vision_text_best",
        )
    else:
        print(f"No improvement over best mean R@1: {best_r1:.4f}")

print("\nTraining complete.")
print(f"Best mean R@1 achieved: {best_r1:.4f}")

if run is not None:
    run.finish()
