In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path
import torch

### Phase-0: Load Configurations

In [3]:
from configs.config import setup_from_yaml, ModelsConfig

In [None]:

# 🔹 Phase 0 – Global setup
# Build the global config object from the project's YAML file.
cfg = setup_from_yaml("configs/config.yaml")

# Runtime device / dtype come straight from the config and are reused
# by every loader call further down the notebook.
device, dtype = cfg.torch_device, cfg.torch_dtype

print("Using device:", device)
print("Using dtype:", dtype)


[34m[1mwandb[0m: Currently logged in as: [33mvedaangchopra[0m ([33mvedaangchopra_gatech[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[Config] Device: cuda, dtype: torch.float16
[Config] root_dir: /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/edge_glass
[Config] features_dir: /storage/ice1/1/0/vchopra37/projects/edge_glass/code_base/edge_glass/features
Using device: cuda
Using dtype: torch.float16


In [5]:
# Later, when you build encoders / perceiver:
# from models_oop import ModelConfig, ImageEncoder, AudioEncoder, DecoderLLM

# vision_model_cfg = ModelConfig(
#     model_name=cfg.models.vision_model_name,
#     device=str(cfg.torch_device),
#     dtype=cfg.torch_dtype,
# )


### Phase-1: Loading the Encoders

In [6]:
from architecture.audio_encoder import load_audio_encoder
from architecture.image_encoder import load_image_encoder
from architecture.text_decoder import (
    load_decoder_llm,
    TextDecoder,
    AudioDecoder,
    VideoDecoder,
)




In [7]:
# Phase-1: Vision (image) encoder
# The checkpoint name is read from the models section of the config.
vision_model_name = cfg.models.vision_model_name

# Feature extraction settings:
#   feature_strategy - "layers_concat" here; "auto" / "last_hidden" are the alternatives
#   layer_indices    - layer 2 and the second-to-last layer
#                      (NOTE(review): indexing convention is defined by load_image_encoder - confirm)
#   pool             - "mean", i.e. mean over patches
vision_encoder = load_image_encoder(
    model_name=vision_model_name,
    device=str(device),
    dtype=dtype,
    feature_strategy="layers_concat",
    layer_indices=[2, -2],
    pool="mean",
)

print("Vision encoder loaded:", vision_model_name)


Vision encoder loaded: openai/clip-vit-base-patch32


In [8]:
# Phase-1: Audio encoder (Whisper)
# The audio model is optional: the config value may be None.
audio_model_name = cfg.models.audio_model_name

# Default to "no audio encoder"; only replaced when a model name is configured.
audio_encoder = None
if audio_model_name is None:
    print("No audio model specified in config.")
else:
    # Same feature-extraction settings as the vision encoder, plus the
    # sample rate handed to the loader via `target_sr`.
    audio_encoder = load_audio_encoder(
        model_name=audio_model_name,
        device=str(device),
        dtype=dtype,
        target_sr=16000,
        feature_strategy="layers_concat",  # "auto" is the alternative
        layer_indices=[2, -2],
        pool="mean",
    )
    print("Audio encoder loaded:", audio_model_name)


`torch_dtype` is deprecated! Use `dtype` instead!


Audio encoder loaded: openai/whisper-base


In [9]:
# Phase-1: Decoder-only LLM
# The checkpoint name comes from the models section of the config.
llm_model_name = cfg.models.llm_model_name

# `add_image_special_tokens=True` asks the loader to register image tokens;
# set it to False if you don't need them yet.
llm = load_decoder_llm(
    model_name=llm_model_name,
    device=str(device),
    dtype=dtype,
    device_map="auto",
    add_image_special_tokens=True,
)

print("LLM loaded:", llm_model_name)
# Falls back to an empty dict when the loader did not attach `image_token_ids`.
print("Image token IDs on LLM:", getattr(llm, "image_token_ids", {}))

# Task-specific decoders, all wrapping the same `llm` object.
text_decoder = TextDecoder(llm)
audio_decoder = AudioDecoder(llm)
video_decoder = VideoDecoder(llm)


`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

LLM loaded: qwen/Qwen2.5-3B-Instruct
Image token IDs on LLM: {'image_start': 151665, 'image_patch': 151666, 'image_end': 151667}


In [17]:
# ---------------------------------------------------------
# Phase-1.4: Freeze base encoders (we only train bottleneck)
# ---------------------------------------------------------
def maybe_freeze(module, name: str):
    """Disable gradients on every parameter of ``module``.

    Args:
        module: Object to freeze. ``None`` is accepted and ignored silently.
        name: Label used in the printed status message.

    Returns:
        None. The parameters are modified in place and a status line is printed.
    """
    # Nothing to do for an absent (optional) component.
    if module is None:
        return

    # Objects that do not expose `.parameters()` cannot be frozen here.
    if not hasattr(module, "parameters"):
        print(f"[Warn] {name} has no .parameters(), skipping freeze")
        return

    for param in module.parameters():
        param.requires_grad = False
    print(f"[Freeze] All params in {name} set to requires_grad=False")

# Freeze both modality encoders (the audio encoder may be None, which is a no-op).
maybe_freeze(vision_encoder, "vision_encoder")
maybe_freeze(audio_encoder, "audio_encoder")

# For this POC the LLM weights are frozen as well
# (LoRA/PEFT can be added on top later).
# If the loader returned a wrapper, freeze the wrapped `.model`;
# otherwise `llm` is treated as the raw HF model itself.
base_llm = llm.model if hasattr(llm, "model") else llm

maybe_freeze(base_llm, "decoder_llm")


[Freeze] All params in vision_encoder set to requires_grad=False
[Freeze] All params in audio_encoder set to requires_grad=False
[Freeze] All params in decoder_llm set to requires_grad=False


In [18]:
# ---------------------------------------------------------
# Phase-1.5: Infer feature dims & Perceiver config hooks
# ---------------------------------------------------------
def get_feature_dim(encoder, name: str) -> int | None:
    """
    Try a few common patterns to get the output feature dim.
    Adjust this if your encoder abstraction exposes a different attribute.
    """
    if encoder is None:
        return None

    for attr in ("output_dim", "feature_dim", "hidden_size", "embed_dim"):
        if hasattr(encoder, attr):
            dim = getattr(encoder, attr)
            print(f"[Dim] {name} {attr} = {dim}")
            return dim

    # If your encoders donâ€™t expose a dim attribute,
    # you can later swap this to run a tiny dummy forward pass instead.
    raise AttributeError(
        f"Could not infer feature dim for {name}. "
        "Expose `.output_dim` or `.feature_dim` on your encoder class."
    )

In [19]:
# Feature dimension of the vision encoder (raises if it cannot be inferred).
vision_feat_dim = get_feature_dim(vision_encoder, "vision_encoder")

# The audio encoder is optional, so its dimension stays None when it is absent.
audio_feat_dim = None
if audio_encoder is not None:
    audio_feat_dim = get_feature_dim(audio_encoder, "audio_encoder")

# LLM hidden size – used by the projector & MRL head
if (llm_hidden_dim := getattr(base_llm.config, "hidden_size", None)) is None:
    raise ValueError("Could not find `hidden_size` on LLM config.")

print(f"[Dim] LLM hidden size = {llm_hidden_dim}")


AttributeError: Could not infer feature dim for vision_encoder. Expose `.output_dim` or `.feature_dim` on your encoder class.

In [16]:
# (Optional) convenience handles on config sub-sections you'll use later.
# `perceiver` and `loss` resolve to None when the config has no such section.
models_cfg = cfg.models
perc_cfg = getattr(cfg, "perceiver", None)
loss_cfg = getattr(cfg, "loss", None)


### Phase-2: Preparing datasets