# Multi-GPU Alignment Training (Jupyter Optimized)

This notebook implements a strictly isolated multi-GPU training loop compatible with `accelerate.notebook_launcher`. 

**RULE:** Do not move model loading or CUDA operations to the global scope. Everything must remain inside `training_function` to avoid `RuntimeError: Cannot re-initialize CUDA`.

In [1]:
import os
import math
import io
import random
from dataclasses import dataclass
from typing import Tuple, List, Dict, Any
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import librosa
from PIL import Image
import requests
from io import BytesIO
from tqdm.auto import tqdm

from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.optim import AdamW
from torchaudio import transforms as T_audio

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    CLIPVisionModel,
    CLIPImageProcessor,
    WhisperProcessor,
    WhisperModel,
    get_cosine_schedule_with_warmup
)
from accelerate import Accelerator, notebook_launcher
from accelerate.utils import set_seed
import wandb



In [2]:
# ============================================
# 1. Configuration (Pure Python, No CUDA init)
# ============================================

@dataclass
class Config:
    # --- Model names ---
    vision_model_name: str = "openai/clip-vit-base-patch32"
    audio_model_name: str = "openai/whisper-base"
    llm_model_name: str = "Qwen/Qwen2.5-7B-Instruct"

    # --- Dimensions ---
    encoder_dim_vision: int = 768     # CLIP-base dim
    encoder_dim_audio: int = 512      # Whisper-base dim
    llm_hidden_size: int = 3584       # Qwen 7B dim
    
    # --- Model Capacity ---
    perceiver_dim: int = 768
    num_latents: int = 64
    num_perceiver_layers: int = 6
    num_attn_heads: int = 8
    mlp_ratio: float = 4.0

    # --- Matryoshka loss ---
    mrl_dims: Tuple[int, ...] = (128, 256, 512, 768, 3584)
    mrl_temperature: float = 0.07

    # --- Training Dynamics ---
    batch_size_vision: int = 32 # Adjust based on GPU VRAM
    batch_size_audio: int = 32
    max_train_steps_vision: int = 200 # Short for demo
    max_train_steps_audio: int = 200
    learning_rate: float = 5e-4
    weight_decay: float = 0.01
    num_rounds: int = 1
    seed: int = 42
    
    # --- Data (Subsets for Demo) ---
    audio_samples: int = 200
    vision_samples: int = 200

cfg = Config()
print("Config loaded.")

Config loaded.


In [3]:
# ============================================
# 2. Module Definitions (Adapters & Perceiver)
# ============================================

class ModalityAdapter(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

class FeedForward(nn.Module):
    def __init__(self, dim: int, mlp_ratio: float = 4.0):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

class PerceiverLayer(nn.Module):
    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.ln_latents_1 = nn.LayerNorm(dim)
        self.ln_tokens    = nn.LayerNorm(dim)
        self.ln_latents_2 = nn.LayerNorm(dim)
        self.ln_latents_3 = nn.LayerNorm(dim)
        self.mlp = FeedForward(dim, mlp_ratio=mlp_ratio)

    def forward(self, latents, tokens, token_mask=None):
        q = self.ln_latents_1(latents)
        kv = self.ln_tokens(tokens)
        key_padding_mask = ~token_mask.bool() if token_mask is not None else None

        attn_out, _ = self.cross_attn(q, kv, kv, key_padding_mask=key_padding_mask, need_weights=False)
        latents = latents + attn_out
        
        q2 = self.ln_latents_2(latents)
        self_attn_out, _ = self.self_attn(q2, q2, q2, need_weights=False)
        latents = latents + self_attn_out
        
        latents = latents + self.mlp(self.ln_latents_3(latents))
        return latents

class PerceiverResampler(nn.Module):
    def __init__(self, dim, num_latents, num_layers, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.dim = dim
        self.latents = nn.Parameter(torch.randn(num_latents, dim) / math.sqrt(dim))
        self.layers = nn.ModuleList([
            PerceiverLayer(dim, num_heads, mlp_ratio) for _ in range(num_layers)
        ])

    def forward(self, tokens, token_mask=None):
        B = tokens.shape[0]
        latents = self.latents.unsqueeze(0).expand(B, -1, -1)
        for layer in self.layers:
            latents = layer(latents, tokens, token_mask)
        return latents

In [4]:
# ============================================
# 3. Dataset Classes (Self-contained)
# ============================================

class PixmoVisionDataset(Dataset):
    def __init__(self, data_list, vision_model, vision_processor, device):
        self.data = data_list
        self.vision_model = vision_model
        self.vision_processor = vision_processor
        self.device = device

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        ex = self.data[idx]
        # Handle image loading
        img = ex.get("image")
        if img is None:
            img = Image.open(Requests.get(ex["image_url"], stream=True).raw).convert("RGB")
        if not isinstance(img, Image.Image):
            img = img.convert("RGB")
            
        # Process
        # NOTE: In a real large-scale scenario, you might want to pre-compute features.
        # Here we compute on-the-fly, which is slower but simpler for the script.
        inputs = self.vision_processor(images=img, return_tensors="pt")
        pixel_values = inputs["pixel_values"].to(self.device)
        
        with torch.no_grad():
            out = self.vision_model(pixel_values=pixel_values)
            feats = out.last_hidden_state.squeeze(0).to("cpu") # Move to CPU to save GPU RAM in queue
            
        return {"features": feats, "text": ex["caption"]}

class SimpleAudioDataset(Dataset):
    def __init__(self, data_list):
        self.data = data_list # Expects pre-computed feature dicts

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def collate_features_with_text(batch):
    feats = [b["features"] for b in batch]
    encoder_feats = pad_sequence(feats, batch_first=True)
    
    B, T_max, _ = encoder_feats.shape
    encoder_mask = torch.zeros(B, T_max, dtype=torch.bool)
    for i, f in enumerate(feats):
        encoder_mask[i, :f.shape[0]] = True
        
    texts = [b["text"] for b in batch]
    return {
        "encoder_feats": encoder_feats, 
        "encoder_mask": encoder_mask, 
        "texts": texts
    }

def matryoshka_contrastive_loss(z_mod, z_txt, trunc_dims):
    losses = []
    targets = torch.arange(z_mod.size(0), device=z_mod.device)
    for d in trunc_dims:
        zm = F.normalize(z_mod[:, :d], dim=-1)
        zt = F.normalize(z_txt[:, :d], dim=-1)
        logits = zm @ zt.T / cfg.mrl_temperature
        loss = 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets))
        losses.append(loss)
    return sum(losses) / len(losses)

In [5]:
# ============================================
# 4. THE TRAINING FUNCTION
# ============================================

def training_function():
    # --- A. Initialization ---
    accelerator = Accelerator(mixed_precision="bf16", log_with="wandb")
    device = accelerator.device
    set_seed(cfg.seed)

    if accelerator.is_main_process:
        print(f"Process {accelerator.process_index}: Initialized. Device: {device}")

    # --- B. Load Models (Local to Process) ---
    # 1. Vision
    vision_processor = CLIPImageProcessor.from_pretrained(cfg.vision_model_name)
    vision_model = CLIPVisionModel.from_pretrained(cfg.vision_model_name, torch_dtype=torch.bfloat16).to(device)
    vision_model.eval()

    # 2. Audio
    audio_processor = WhisperProcessor.from_pretrained(cfg.audio_model_name)
    audio_model = WhisperModel.from_pretrained(cfg.audio_model_name, torch_dtype=torch.float32).to(device)
    audio_model.eval()

    # 3. LLM (Qwen)
    qwen_tokenizer = AutoTokenizer.from_pretrained(cfg.llm_model_name, use_fast=True)
    if qwen_tokenizer.pad_token is None:
        qwen_tokenizer.pad_token = qwen_tokenizer.eos_token
    qwen_model = AutoModelForCausalLM.from_pretrained(
        cfg.llm_model_name, 
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2"
    ).to(device)
    qwen_model.eval()

    # 4. Trainable Modules
    vision_adapter = ModalityAdapter(cfg.encoder_dim_vision, cfg.perceiver_dim).to(device)
    audio_adapter = ModalityAdapter(cfg.encoder_dim_audio, cfg.perceiver_dim).to(device)
    perceiver = PerceiverResampler(
        cfg.perceiver_dim, cfg.num_latents, cfg.num_perceiver_layers, cfg.num_attn_heads
    ).to(device)
    projector = nn.Linear(cfg.perceiver_dim, cfg.llm_hidden_size).to(device)

    # --- C. Data Prep (Robust Loading) ---
    # To avoid API limits, main process loads list, broadcast isn't easy with lists in accel,
    # so we'll just load a small slice independently on each for this demo.
    
    # Vision Data
    if accelerator.is_main_process: print("Loading Vision Data...")
    pixmo_ds = load_dataset("allenai/pixmo-cap", split="train", streaming=True)
    pixmo_data = list(pixmo_ds.take(cfg.vision_samples))
    train_v_ds = PixmoVisionDataset(pixmo_data, vision_model, vision_processor, device)
    train_v_loader = DataLoader(train_v_ds, batch_size=cfg.batch_size_vision, collate_fn=collate_features_with_text)

    # Audio Data (Pre-process audio to features to save VRAM/Compute during loop)
    if accelerator.is_main_process: print("Loading Audio Data...")
    libri_ds = load_dataset("openslr/librispeech_asr", "all", split="train.clean.100", streaming=True)
    processed_audio = []
    
    # Helper for local processing
    def process_audio_entry(ex):
        wav_bytes = ex["audio"]["bytes"]
        wav, sr = librosa.load(io.BytesIO(wav_bytes), sr=16000)
        dur = len(wav)/sr
        if dur > 12.0: return None
        
        inputs = audio_processor(wav, sampling_rate=16000, return_tensors="pt")
        input_features = inputs["input_features"].to(device)
        with torch.no_grad():
            enc_out = audio_model.encoder(input_features)
            hidden = enc_out.last_hidden_state.squeeze(0)
            
        # Slice padding
        valid_frames = min(int(dur * 50), 1500)
        feats = hidden[:valid_frames, :].to("cpu")
        return {"features": feats, "text": ex["text"].lower().capitalize(), "duration": dur}

    count = 0
    for ex in libri_ds:
        if count >= cfg.audio_samples: break
        p = process_audio_entry(ex)
        if p: 
            processed_audio.append(p)
            count += 1
            
    train_a_ds = SimpleAudioDataset(processed_audio)
    train_a_loader = DataLoader(train_a_ds, batch_size=cfg.batch_size_audio, collate_fn=collate_features_with_text)

    # --- D. Optimizer & Prepare ---
    params = list(vision_adapter.parameters()) + list(audio_adapter.parameters()) + \
             list(perceiver.parameters()) + list(projector.parameters())
    
    optimizer = AdamW(params, lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
    
    total_steps = (len(train_v_loader) + len(train_a_loader)) * cfg.num_rounds
    scheduler = get_cosine_schedule_with_warmup(optimizer, int(0.1*total_steps), total_steps)

    (vision_adapter, audio_adapter, perceiver, projector, optimizer, train_v_loader, train_a_loader, scheduler) = \
        accelerator.prepare(vision_adapter, audio_adapter, perceiver, projector, optimizer, train_v_loader, train_a_loader, scheduler)

    # --- E. Helper for Text Encoding (Local Qwen) ---
    def encode_text_local(texts):
        enc = qwen_tokenizer(texts, padding=True, truncation=True, max_length=64, return_tensors="pt").to(device)
        with torch.no_grad():
            emb = qwen_model.get_input_embeddings()(enc.input_ids)
        mask = enc.attention_mask.unsqueeze(-1)
        # Mean pool
        pooled = (emb * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1)
        return pooled

    # --- F. Training Loop ---
    for round_idx in range(cfg.num_rounds):
        if accelerator.is_main_process:
            print(f"\n=== Round {round_idx+1} ===")

        # 1. Vision Loop
        for batch in tqdm(train_v_loader, disable=not accelerator.is_main_process, desc="Vision"):
            optimizer.zero_grad()
            
            tokens = vision_adapter(batch["encoder_feats"])
            latents = perceiver(tokens, batch["encoder_mask"])
            z_llm = projector(latents)
            h_mod = z_llm.mean(dim=1)
            
            h_txt = encode_text_local(batch["texts"])
            
            # Gather for Global Loss calculation
            h_mod_g = accelerator.gather(h_mod)
            h_txt_g = accelerator.gather(h_txt)
            
            loss = matryoshka_contrastive_loss(h_mod_g, h_txt_g, cfg.mrl_dims)
            accelerator.backward(loss)
            
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(params, 1.0)
            
            optimizer.step()
            scheduler.step()
            
            if accelerator.is_main_process:
                wandb.log({"vision_loss": loss.item()})

        # 2. Audio Loop
        for batch in tqdm(train_a_loader, disable=not accelerator.is_main_process, desc="Audio"):
            optimizer.zero_grad()
            
            tokens = audio_adapter(batch["encoder_feats"])
            latents = perceiver(tokens, batch["encoder_mask"])
            z_llm = projector(latents)
            h_mod = z_llm.mean(dim=1)
            
            h_txt = encode_text_local(batch["texts"])
            
            h_mod_g = accelerator.gather(h_mod)
            h_txt_g = accelerator.gather(h_txt)
            
            loss = matryoshka_contrastive_loss(h_mod_g, h_txt_g, cfg.mrl_dims)
            accelerator.backward(loss)
            
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(params, 1.0)

            optimizer.step()
            scheduler.step()
            
            if accelerator.is_main_process:
                wandb.log({"audio_loss": loss.item()})

    if accelerator.is_main_process:
        print("Training Complete. Saving models...")
        # Save logic here if needed (unwrapped models)
        # accelerator.save_state(cfg.save_dir)

In [6]:
# ============================================
# 5. LAUNCH
# ============================================

notebook_launcher(training_function, num_processes=torch.cuda.device_count())

Launching training on 2 CUDAs.


E1124 16:34:58.556000 451342 /storage/ice1/1/0/vchopra37/projects/edge_glass/edge_glass_env/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py:742] failed (exitcode: 1) local_rank: 0 (pid: 453380) of fn: training_function (start_method: fork)
E1124 16:34:58.556000 451342 /storage/ice1/1/0/vchopra37/projects/edge_glass/edge_glass_env/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py:742] Traceback (most recent call last):
E1124 16:34:58.556000 451342 /storage/ice1/1/0/vchopra37/projects/edge_glass/edge_glass_env/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py:742]   File "/home/hice1/vchopra37/scratch/projects/edge_glass/edge_glass_env/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 697, in _poll
E1124 16:34:58.556000 451342 /storage/ice1/1/0/vchopra37/projects/edge_glass/edge_glass_env/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py:742

ChildFailedError: 
============================================================
training_function FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2025-11-24_16:34:58
  host      : atl1-1-03-017-23-0.pace.gatech.edu
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 453380)
  error_file: /home/hice1/vchopra37/scratch/models/tmp/torchelastic_1sp9p812/none_bc6ujcpc/attempt_0/0/error.json
  traceback : Traceback (most recent call last):
    File "/home/hice1/vchopra37/scratch/projects/edge_glass/edge_glass_env/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 357, in wrapper
      return f(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^
    File "/home/hice1/vchopra37/scratch/models/tmp/ipykernel_451342/632741400.py", line 7, in training_function
      accelerator = Accelerator(mixed_precision="bf16", log_with="wandb")
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/hice1/vchopra37/scratch/projects/edge_glass/edge_glass_env/lib/python3.12/site-packages/accelerate/accelerator.py", line 461, in __init__
      self.state = AcceleratorState(
                   ^^^^^^^^^^^^^^^^^
    File "/home/hice1/vchopra37/scratch/projects/edge_glass/edge_glass_env/lib/python3.12/site-packages/accelerate/state.py", line 912, in __init__
      PartialState(cpu, **kwargs)
    File "/home/hice1/vchopra37/scratch/projects/edge_glass/edge_glass_env/lib/python3.12/site-packages/accelerate/state.py", line 301, in __init__
      self.set_device()
    File "/home/hice1/vchopra37/scratch/projects/edge_glass/edge_glass_env/lib/python3.12/site-packages/accelerate/state.py", line 838, in set_device
      device_module.set_device(self.device)
    File "/home/hice1/vchopra37/scratch/projects/edge_glass/edge_glass_env/lib/python3.12/site-packages/torch/cuda/__init__.py", line 567, in set_device
      torch._C._cuda_setDevice(device)
    File "/home/hice1/vchopra37/scratch/projects/edge_glass/edge_glass_env/lib/python3.12/site-packages/torch/cuda/__init__.py", line 398, in _lazy_init
      raise RuntimeError(
  RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
  
============================================================