In [1]:
import os, random, zipfile, numpy as np, torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm.notebook import tqdm
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt

from transformers import CLIPProcessor, CLIPModel
from safetensors.torch import load_file
from collections import defaultdict
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import TfidfVectorizer
import torchvision.models as models
import torchvision.transforms as T

# Central experiment configuration shared by the notebook cells.
# In the cells visible here only base_seed, device, batch_size, the
# *_finetuned_drive_dir entries, the *_ckpt_relpath entries, tfidf_max_features
# and the llama_prompt_* settings are read; the remaining keys are presumably
# consumed by later evaluation cells -- TODO confirm.
CONFIG = {
    "base_seed": 171717,
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),

    # repeated trials (not read in the cells shown here)
    "num_trials": 5,
    "trial_seed_stride": 1000,

    # feature extraction batch size (reused as the VAE DataLoader batch size)
    "batch_size": 32,

    # TF-IDF vocabulary cap
    "tfidf_max_features": 512,

    # directories holding the finetuned model files; the name says "Drive" but
    # they are used as paths relative to the working directory in this notebook
    "clip_finetuned_drive_dir": "clip_finetuned_softmax",
    "sigclip_finetuned_drive_dir": "sigmoidclip_finetuned_bce",
    "llama_sigclip_drive_dir": "llamasigclip_assets",

    # checkpoint file names; joined with a parent directory where they are loaded
    "vae_ckpt_relpath": "multimodal_vae_400.pth",
    "llamavae_ckpt_relpath": "llamavae_text_vae.pth",

    # fixed preference profiles (not read in the cells shown here)
    "num_profiles": 1500,
    "interaction_k": 5,
    "pref_threshold": 0.2,

    # evaluation (not read in the cells shown here)
    "top_k": 5,
    "kfold_splits": 5,
    "rec_std_mode": "profiles",

    # causal LM used to rewrite panel text into keywords, and its generation budget
    "llama_prompt_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "llama_prompt_max_new_tokens": 96,
}


# Module-level shortcut used as the default device by the extractors below.
DEVICE = CONFIG["device"]
print("Using device:", DEVICE)


# REPRODUCIBILITY HELPERS
def set_global_seed(seed: int):
    """Seed Python's `random`, NumPy's legacy global RNG and PyTorch with `seed`.

    When CUDA is available, all CUDA devices are seeded as well and cuDNN is
    switched to deterministic mode with benchmarking disabled.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_global_seed(CONFIG["base_seed"])

Using device: cpu


In [2]:
CSV_PATH   = "all_labels.csv"

In [3]:
# Load the panel-level label table (one row per scroll panel).
df = pd.read_csv(CSV_PATH)
# Build a per-panel key by joining scroll_id and panel_id with "_".
df["id"] = df["scroll_id"].astype(str) + "_" + df["panel_id"].astype(str)
# Label columns present in the CSV; not referenced again in this cell.
label_cols = ["animal_label", "myth_label", "tree_label"]
print("Data shape:", df.shape)
display(df.head())
# Random 80/20 split seeded from CONFIG. No `stratify` argument is passed, so
# label proportions are not explicitly balanced between the two parts.
train_df, test_df = train_test_split(
    df,
    test_size=0.2,
    random_state=CONFIG["base_seed"],
)
# Replace the shuffled original index with a fresh 0..n-1 index in each split.
train_df = train_df.reset_index(drop=True)
test_df  = test_df.reset_index(drop=True)
print("Train size:", len(train_df), " Test size:", len(test_df))

Data shape: (189, 10)


Unnamed: 0.1,Unnamed: 0,identifier,scroll_id,panel_id,animal_label,myth_label,tree_label,image_path,text,id
0,0,scroll_001-pulin-panel_01,1_1,1,1,1,0,/workspace/folk/all_data/s1_1/img/s1.jpg,"Hail Durga, Ma Tara, destroyer of sorrows. Inv...",1_1_1
1,1,scroll_001-pulin-panel_02,1_1,2,0,1,1,/workspace/folk/all_data/s1_1/img/s2.jpg,"One day, the goddess Durga bestowed her grace ...",1_1_2
2,2,scroll_001-pulin-panel_03,1_1,3,0,1,0,/workspace/folk/all_data/s1_1/img/s3.jpg,"Shouting “Hail Durga,” Srimonto boarded the bo...",1_1_3
3,3,scroll_001-pulin-panel_04,1_1,4,0,0,0,/workspace/folk/all_data/s1_1/img/s4.jpg,"Finally, the boat docked at Ratnamala quay to ...",1_1_4
4,4,scroll_001-pulin-panel_05,1_1,5,0,0,0,/workspace/folk/all_data/s1_1/img/s5.jpg,The king Shalbahon was seated on a jewelled th...,1_1_5


Train size: 151  Test size: 38


In [4]:
import numpy as np
# Load precomputed user preference profiles from a local .npz archive.
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code -- only use this with a trusted file.
data = np.load("user_pref_profiles.npz", allow_pickle=True)
# Presumably panel ids per profile (the next cell flattens this array and
# checks its entries against train_df["id"]) -- TODO confirm shape.
interacted_ids = data["interacted_ids"]
# NOTE(review): not used in the cells shown here; semantics unverified.
preferred_mat = data["preferred_mat"]


In [5]:
# Sanity check: collect every id referenced by the preference profiles that
# does not appear among the training panel ids.
train_ids = set(train_df["id"].to_numpy())
missing = list(filter(lambda panel_id: panel_id not in train_ids, interacted_ids.flatten()))
print("Missing interacted ids in train_df:", len(missing))


Missing interacted ids in train_df: 0


In [6]:
from pathlib import Path
REAL_PREFIX = str(Path("workspace/folk").resolve())
OLD_PREFIX  = "/workspace/folk"

def fix_image_path(p):
    """Map a stored image path onto a location that exists on this machine.

    Candidates are tried in this order:
      1. the path exactly as given;
      2. the path with OLD_PREFIX replaced by REAL_PREFIX;
      3. the path with leading "/" characters removed, i.e. interpreted
         relative to the current working directory (returned resolved).

    If none of them exists, the prefix-replaced candidate (2) is returned
    unchanged, so the result is not guaranteed to exist.
    """
    original = str(p)
    if Path(original).exists():
        return original

    remapped = original.replace(OLD_PREFIX, REAL_PREFIX)
    relative = Path(original.lstrip("/"))
    if not Path(remapped).exists() and relative.exists():
        return str(relative.resolve())
    return remapped

# Rewrite the image_path column of both splits in place using fix_image_path.
train_df["image_path"] = train_df["image_path"].apply(fix_image_path)
test_df["image_path"]  = test_df["image_path"].apply(fix_image_path)

In [8]:
# Persist both splits (with the rewritten image paths) as CSV files in the
# working directory; the index is not written.
train_df.to_csv("train_df_fixed_paths.csv", index=False)
test_df.to_csv("test_df_fixed_paths.csv", index=False)
print("Saved fixed-path dfs.")

Saved fixed-path dfs.


In [9]:
import pandas as pd

# Reload the splits from the CSV files written by the previous cell, replacing
# the in-memory train_df / test_df with the on-disk versions.
train_df = pd.read_csv("train_df_fixed_paths.csv")
test_df  = pd.read_csv("test_df_fixed_paths.csv")

# Later cells rely on the "id" column being present in both splits.
assert "id" in train_df.columns and "id" in test_df.columns

In [10]:
# Root directory for everything this notebook writes (ids and feature matrices).
FEATURE_ROOT = "features_pgl5"
os.makedirs(FEATURE_ROOT, exist_ok=True)
# Store the row order of each split as id arrays next to the features.
# NOTE(review): if the "id" column is an object-dtype column of strings,
# np.save pickles it and readers need allow_pickle=True -- confirm with the
# code that loads these files.
np.save(f"{FEATURE_ROOT}/train_ids.npy", train_df["id"].values)
np.save(f"{FEATURE_ROOT}/test_ids.npy",  test_df["id"].values)

Extract features from each model (the resulting arrays are saved to disk in a later cell)

In [11]:
class FeatureExtractor:
    """Minimal base class for feature extractors: it only stores a name."""

    def __init__(self, name):
        # `name` identifies the feature set produced by the extractor.
        self.name = name

class CLIPExtractor(FeatureExtractor):
    """Joint image+text embedding extractor built on CLIP ViT-B/32.

    With ``model_dir=None`` the stock ``openai/clip-vit-base-patch32`` model and
    processor are used. Otherwise the processor is loaded from ``model_dir`` and
    the weights in ``<model_dir>/model.safetensors`` are loaded (non-strictly)
    on top of the base architecture. If anything in that path fails, the
    extractor falls back to the base model and prints the reason.

    Attributes:
        processor: the CLIPProcessor used to prepare inputs.
        model: the CLIPModel, moved to ``device`` and put in eval mode.
        device: device the model runs on.
        is_finetuned: True only if weights from ``model_dir`` were loaded, so
            callers can detect a silent fallback to the base model.
    """

    BASE_MODEL_NAME = "openai/clip-vit-base-patch32"

    def __init__(self, model_dir=None, name="clip_features", device=DEVICE):
        super().__init__(name=name)
        self.device = device
        self.is_finetuned = False

        if model_dir is None:
            self._load_base_clip()
            print("Base CLIP ready.")
            return

        try:
            print(f"Loading processor from: {model_dir}")
            self.processor = CLIPProcessor.from_pretrained(model_dir)

            print(f"Instantiating base CLIP model '{self.BASE_MODEL_NAME}'...")
            base_model = CLIPModel.from_pretrained(self.BASE_MODEL_NAME)

            safetensor_path = os.path.join(model_dir, "model.safetensors")
            print(f"Loading finetuned weights from: {safetensor_path}")
            state_dict = load_file(safetensor_path)

            # strict=False tolerates key mismatches; report them so that a
            # checkpoint that matched nothing does not go unnoticed.
            missing, unexpected = base_model.load_state_dict(state_dict, strict=False)
            if missing: print("Missing keys:", len(missing))
            if unexpected: print("Unexpected keys:", len(unexpected))

            self.model = base_model.to(self.device)
            self.model.eval()
            self.is_finetuned = True
            # The message used to say "SigCLIP" for every checkpoint, including
            # the softmax-finetuned CLIP one; name the directory instead.
            print(f"Finetuned model ready: {model_dir}")
        except Exception as e:
            # Deliberate best-effort: keep the notebook running on base CLIP,
            # but leave is_finetuned False so the fallback is detectable.
            print("Could not load finetuned model, falling back to base CLIP.")
            print("Reason:", repr(e))
            self._load_base_clip()

    def _load_base_clip(self):
        """Load the stock processor and model, move the model to the device, set eval mode."""
        self.processor = CLIPProcessor.from_pretrained(self.BASE_MODEL_NAME)
        self.model = CLIPModel.from_pretrained(self.BASE_MODEL_NAME).to(self.device)
        self.model.eval()

    @torch.no_grad()
    def extract(self, image_path, text):
        """Return one fused embedding for an (image, text) pair as a NumPy array.

        The image and text embeddings are L2-normalised separately and then
        averaged; the average itself is not re-normalised.

        Args:
            image_path: path of the image file to open.
            text: text passed to the processor (truncated by the processor).
        """
        # Close the file handle promptly; convert() returns an independent image.
        with Image.open(image_path) as raw_image:
            img = raw_image.convert("RGB")

        inputs = self.processor(
            text=[text],
            images=img,
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(self.device)

        outputs = self.model(**inputs)
        image_embeds = F.normalize(outputs.image_embeds.squeeze(0), dim=-1)
        text_embeds  = F.normalize(outputs.text_embeds.squeeze(0),  dim=-1)

        combined = (image_embeds + text_embeds) / 2.0
        return combined.detach().cpu().numpy()


# Feature names handed to each extractor as `name`. extract_features_chunked
# stores its vectors in a DataFrame column named after extractor.name, so these
# strings double as column names.
SIGCLIP_FEAT_NAME = "sigclip_finetuned_image_text_embedding__avgchunk"
CLIP_FT_FEAT_NAME = "clip_finetuned_image_text_embedding__avgchunk"
CLIP_BASE_FEAT_NAME = "clip_base_image_text_embedding__avgchunk"

# Short keys for the same three feature sets (used as output sub-directory
# names when features are saved).
SIGCLIP_FEAT_KEY = "sigclip_ft"
CLIP_FT_FEAT_KEY = "clip_ft"
CLIP_BASE_FEAT_KEY = "clip_base"

# Finetuned checkpoint loaded from the sigmoid/BCE directory named in CONFIG.
sigclip_extractor = CLIPExtractor(
    model_dir=CONFIG["sigclip_finetuned_drive_dir"],     
    name=SIGCLIP_FEAT_NAME,
    device=DEVICE,
)

# Finetuned checkpoint loaded from the softmax directory named in CONFIG.
clip_ft_extractor = CLIPExtractor(
    model_dir=CONFIG["clip_finetuned_drive_dir"],        
    name=CLIP_FT_FEAT_NAME,
    device=DEVICE,
)

# model_dir=None selects the stock base CLIP model.
clip_base_extractor = CLIPExtractor(
    model_dir=None,                            
    name=CLIP_BASE_FEAT_NAME,
    device=DEVICE,
)


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Loading processor from: sigmoidclip_finetuned_bce
Instantiating base CLIP model 'openai/clip-vit-base-patch32'...
Loading finetuned weights from: sigmoidclip_finetuned_bce/model.safetensors
SigCLIP finetuned model ready.
Loading processor from: clip_finetuned_softmax
Instantiating base CLIP model 'openai/clip-vit-base-patch32'...
Loading finetuned weights from: clip_finetuned_softmax/model.safetensors
SigCLIP finetuned model ready.
Base CLIP ready.


In [12]:
def extract_features_chunked(df_in, extractor: CLIPExtractor, num_chunks=3, chunk_size=100, text_col="text", seed=171717):
    """Embed every row of `df_in`, averaging over random character windows of its text.

    For each row the text in `text_col` is stripped. Texts longer than
    `chunk_size` characters are sampled `num_chunks` times with a random
    contiguous window of `chunk_size` characters; each window is embedded
    together with the row's image via `extractor.extract` and the embeddings
    are averaged. Texts that fit in one window are embedded once: repeating
    the identical (image, text) pair only repeats the same forward pass. The
    result can differ from averaging identical copies by float rounding only.
    Missing, non-string or blank text is embedded as the empty string.

    Args:
        df_in: DataFrame with an "image_path" column and, optionally, `text_col`.
        extractor: object exposing `.name` and `.extract(image_path, text)`.
        num_chunks: number of random windows per long text (must be >= 1).
        chunk_size: window length in characters (must be >= 1).
        text_col: name of the text column.
        seed: seed of the NumPy Generator that picks window start positions.

    Returns:
        A copy of `df_in` with one extra column, named `extractor.name`,
        holding one averaged embedding per row.

    Raises:
        ValueError: if `num_chunks` or `chunk_size` is smaller than 1.
    """
    if num_chunks < 1:
        raise ValueError(f"num_chunks must be >= 1, got {num_chunks}")
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    df_local = df_in.copy()
    rng = np.random.default_rng(seed)
    results = []

    for _, row in tqdm(df_local.iterrows(), total=len(df_local), desc=f"Extracting {extractor.name}"):
        text = row.get(text_col, "")
        image_path = row["image_path"]

        # NaN and any other non-string value fail the isinstance check.
        if not isinstance(text, str) or len(text.strip()) == 0:
            text_chunks = [""]
        else:
            text = text.strip()
            if len(text) <= chunk_size:
                # The random generator is not consulted for short texts, so
                # the window positions drawn for later rows are unaffected.
                text_chunks = [text]
            else:
                text_chunks = []
                for _ in range(num_chunks):
                    # Scalar draws, one per window, keep the sampling sequence
                    # identical to the previous implementation.
                    start_idx = rng.integers(0, len(text) - chunk_size + 1)
                    text_chunks.append(text[start_idx:start_idx + chunk_size])

        chunk_features = [extractor.extract(image_path, chunk) for chunk in text_chunks]
        avg_feature = np.mean(chunk_features, axis=0)
        results.append(avg_feature)

    df_local[extractor.name] = results
    return df_local


# Window sampling settings shared by the three CLIP-style extractions below.
NUM_CHUNKS = 3
CHUNK_SIZE = 100

# --- Finetuned SigCLIP features ---
train_sig = extract_features_chunked(
    train_df, sigclip_extractor,
    num_chunks=NUM_CHUNKS, chunk_size=CHUNK_SIZE,
    text_col="text",
    seed=CONFIG["base_seed"]
)
test_sig  = extract_features_chunked(
    test_df,  sigclip_extractor,
    num_chunks=NUM_CHUNKS, chunk_size=CHUNK_SIZE,
    text_col="text",
    seed=CONFIG["base_seed"]
)

# Stack the per-row vectors stored in the feature column into a 2-D matrix.
emb_train_sigclip = np.stack(train_sig[SIGCLIP_FEAT_NAME].values)
emb_test_sigclip  = np.stack(test_sig[SIGCLIP_FEAT_NAME].values)

print("SigCLIP train embedding shape:", emb_train_sigclip.shape)
print("SigCLIP test embedding shape :", emb_test_sigclip.shape)

# --- Finetuned CLIP (softmax) features ---
# text_col and seed are left at the function defaults here ("text", 171717);
# the default seed has the same value as CONFIG["base_seed"].
train_cft = extract_features_chunked(train_df, clip_ft_extractor, num_chunks=NUM_CHUNKS, chunk_size=CHUNK_SIZE)
test_cft  = extract_features_chunked(test_df,  clip_ft_extractor, num_chunks=NUM_CHUNKS, chunk_size=CHUNK_SIZE)

emb_train_clipFT = np.stack(train_cft[CLIP_FT_FEAT_NAME].values)
emb_test_clipFT  = np.stack(test_cft[CLIP_FT_FEAT_NAME].values)

print("CLIP-FT train embedding shape:", emb_train_clipFT.shape)
print("CLIP-FT test embedding shape :", emb_test_clipFT.shape)

# --- Base CLIP features (same defaults as above) ---
train_cb = extract_features_chunked(train_df, clip_base_extractor, num_chunks=NUM_CHUNKS, chunk_size=CHUNK_SIZE)
test_cb  = extract_features_chunked(test_df,  clip_base_extractor, num_chunks=NUM_CHUNKS, chunk_size=CHUNK_SIZE)

emb_train_clipBASE = np.stack(train_cb[CLIP_BASE_FEAT_NAME].values)
emb_test_clipBASE  = np.stack(test_cb[CLIP_BASE_FEAT_NAME].values)

print("CLIP-BASE train embedding shape:", emb_train_clipBASE.shape)
print("CLIP-BASE test embedding shape :", emb_test_clipBASE.shape)

Extracting sigclip_finetuned_image_text_embedding__avgchunk:   0%|          | 0/151 [00:00<?, ?it/s]

Extracting sigclip_finetuned_image_text_embedding__avgchunk:   0%|          | 0/38 [00:00<?, ?it/s]

SigCLIP train embedding shape: (151, 512)
SigCLIP test embedding shape : (38, 512)


Extracting clip_finetuned_image_text_embedding__avgchunk:   0%|          | 0/151 [00:00<?, ?it/s]

Extracting clip_finetuned_image_text_embedding__avgchunk:   0%|          | 0/38 [00:00<?, ?it/s]

CLIP-FT train embedding shape: (151, 512)
CLIP-FT test embedding shape : (38, 512)


Extracting clip_base_image_text_embedding__avgchunk:   0%|          | 0/151 [00:00<?, ?it/s]

Extracting clip_base_image_text_embedding__avgchunk:   0%|          | 0/38 [00:00<?, ?it/s]

CLIP-BASE train embedding shape: (151, 512)
CLIP-BASE test embedding shape : (38, 512)


In [13]:
from transformers import AutoTokenizer, AutoModelForCausalLM

def _safe_text(x):
    if x is None or (isinstance(x, float) and np.isnan(x)):
        return ""
    if not isinstance(x, str):
        x = str(x)
    return x.strip()

def load_prompt_llm(model_name):
    """Load the tokenizer and causal LM used for text rewriting.

    On CUDA machines the model is loaded in float16 with device_map="auto";
    otherwise in float32 with no device map. The tokenizer's pad token falls
    back to its EOS token when none is defined.

    Returns:
        (tokenizer, model) with the model in eval mode, or (None, None) if
        loading fails for any reason (a warning is printed in that case).
    """
    try:
        prompt_tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        if prompt_tok.pad_token is None:
            prompt_tok.pad_token = prompt_tok.eos_token

        has_cuda = torch.cuda.is_available()
        load_kwargs = {
            "torch_dtype": torch.float16 if has_cuda else torch.float32,
            "device_map": "auto" if has_cuda else None,
        }
        prompt_model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
        prompt_model.eval()
    except Exception as err:
        print(f"[WARN] Could not load Llama prompt model '{model_name}'. Falling back to original text.\n  Error: {err}")
        return None, None
    return prompt_tok, prompt_model

LLAMA_TOK, LLAMA_MDL = load_prompt_llm(CONFIG["llama_prompt_model"])

@torch.no_grad()
def rewrite_text_llama(raw_text: str) -> str:
    """Rewrite `raw_text` into a short keyword list using the prompt LLM.

    Returns "" for missing or blank input. Returns the cleaned input unchanged
    when the LLM is unavailable, when the prompt had to be truncated so that
    its trailing "KEYWORDS:" cue was cut off, or when the model generates
    nothing. Otherwise returns the stripped text generated after the prompt.

    Decoding is greedy (do_sample=False). No `temperature` is passed, because
    the generation API reports it as an invalid flag in greedy mode.
    """
    raw_text = _safe_text(raw_text)
    if raw_text == "":
        return ""
    if LLAMA_TOK is None or LLAMA_MDL is None:
        return raw_text

    prompt = (
        "You are helping build a recommender system for folk art panels.\n"
        "Rewrite the given text into a short, descriptive list of keywords and entities.\n"
        "Keep it factual. Avoid extra sentences.\n\n"
        f"TEXT: {raw_text}\n"
        "KEYWORDS:"
    )
    inputs = LLAMA_TOK(prompt, return_tensors="pt", truncation=True, max_length=512).to(LLAMA_MDL.device)
    prompt_len = inputs["input_ids"].shape[1]

    # Truncation removes tokens from the end of the prompt, i.e. the
    # "KEYWORDS:" cue itself. Without the cue the model would just continue
    # the source text, so keep the original text in that case.
    prompt_seen = LLAMA_TOK.decode(inputs["input_ids"][0], skip_special_tokens=True)
    if not prompt_seen.rstrip().endswith("KEYWORDS:"):
        return raw_text

    out = LLAMA_MDL.generate(
        **inputs,
        max_new_tokens=CONFIG["llama_prompt_max_new_tokens"],
        do_sample=False,
        pad_token_id=LLAMA_TOK.eos_token_id,
    )

    # For a causal LM the returned sequence starts with the prompt tokens.
    # Decode only what follows them instead of searching the full decoded
    # string for "KEYWORDS:", which picks the wrong split point whenever the
    # source text itself contains that marker.
    rewritten = LLAMA_TOK.decode(out[0][prompt_len:], skip_special_tokens=True).strip()
    return rewritten if rewritten else raw_text

def add_llama_text_column(df_in, src_col="text", new_col="llama_text"):
    """Return a copy of `df_in` with `new_col` holding the LLM rewrite of each `src_col` value."""
    augmented = df_in.copy()
    source_texts = augmented[src_col].tolist()
    augmented[new_col] = [
        rewrite_text_llama(source_text)
        for source_text in tqdm(source_texts, desc="Llama rewrite text")
    ]
    return augmented

# Rebind both splits to copies that carry the extra "llama_text" column.
train_df = add_llama_text_column(train_df, src_col="text", new_col="llama_text")
test_df  = add_llama_text_column(test_df,  src_col="text", new_col="llama_text")

# Quick side-by-side look at the original and rewritten text of the first two rows.
print(train_df[["text", "llama_text"]].head(2))

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Llama rewrite text:   0%|          | 0/151 [00:00<?, ?it/s]

The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Llama rewrite text:   0%|          | 0/38 [00:00<?, ?it/s]

  text llama_text
0  NaN           
1  NaN           


In [14]:
LLAMASIG_FEAT_NAME = "llamasigclip_image_text_embedding__avgchunk"
LLAMASIG_FEAT_KEY  = "llamasigclip"
# Character window used for the rewritten text.
LLAMASIG_CHUNK_SIZE = 140

# Same finetuned SigCLIP extractor as before, but fed the LLM-rewritten text.
train_lls = extract_features_chunked(
    train_df, sigclip_extractor,
    num_chunks=NUM_CHUNKS, chunk_size=LLAMASIG_CHUNK_SIZE,
    text_col="llama_text",
    seed=CONFIG["base_seed"]
)
test_lls = extract_features_chunked(
    test_df, sigclip_extractor,
    num_chunks=NUM_CHUNKS, chunk_size=LLAMASIG_CHUNK_SIZE,
    text_col="llama_text",
    seed=CONFIG["base_seed"]
)

# extract_features_chunked names its output column after extractor.name. Read
# the column through the extractor rather than through the SIGCLIP_FEAT_NAME
# global: that global is rebound to a different string by the feature-saving
# cell, after which re-running this cell would fail with a KeyError.
llamasig_column = sigclip_extractor.name
emb_train_llamasig = np.stack(train_lls[llamasig_column].values)
emb_test_llamasig  = np.stack(test_lls[llamasig_column].values)

print("Llama-SigCLIP train shape:", emb_train_llamasig.shape)
print("Llama-SigCLIP test  shape:", emb_test_llamasig.shape)

Extracting sigclip_finetuned_image_text_embedding__avgchunk:   0%|          | 0/151 [00:00<?, ?it/s]

Extracting sigclip_finetuned_image_text_embedding__avgchunk:   0%|          | 0/38 [00:00<?, ?it/s]

Llama-SigCLIP train shape: (151, 512)
Llama-SigCLIP test  shape: (38, 512)


In [15]:
from torch.utils.data import Dataset, DataLoader
from transformers import T5Tokenizer, T5EncoderModel, T5ForConditionalGeneration
# Settings for running the multimodal VAE: tokenizer name, token sequence
# length used by the dataset, and the DataLoader batch size.
VAE_CONFIG = {
    "model_name": "t5-small",
    "seq_len": 64,
    "batch_size": CONFIG["batch_size"],
}

# The checkpoint is looked up under a hard-coded "Pruthvi" directory relative
# to the working directory.
VAE_CKPT_PATH = os.path.join("Pruthvi", CONFIG["vae_ckpt_relpath"])  

# Module-level tokenizer handed to ScrollsDataset when VAE features are extracted.
tokenizer = T5Tokenizer.from_pretrained(VAE_CONFIG["model_name"])

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [18]:
# Image preprocessing for the VAE: resize to 224x224 and convert to a tensor.
# No mean/std normalisation is applied here.
# NOTE(review): this presumably mirrors the preprocessing used when the VAE
# checkpoint was trained -- confirm against the training code.
image_transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
])

class ScrollsDataset(Dataset):
    """Dataset yielding (image tensor, input_ids, attention_mask, id) per DataFrame row.

    Args:
        dataframe: frame with "image_path" and "id" columns and, optionally, "text".
        tokenizer: tokenizer used to encode the text (must expose `pad_token`).
        seq_len: fixed token length; texts are padded or truncated to it.
    """

    def __init__(self, dataframe, tokenizer, seq_len):
        # Positional access below relies on a clean 0..n-1 index.
        self.df = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.seq_len = seq_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # Close the file handle promptly; convert() returns an independent image.
        with Image.open(row["image_path"]) as raw_image:
            image = image_transform(raw_image.convert("RGB"))

        text = row.get("text", "")
        if not isinstance(text, str) or text.strip() == "":
            # Use the tokenizer given to this dataset. The previous code read a
            # module-level `tokenizer` here, which breaks (or silently uses the
            # wrong pad token) when the dataset is built with another tokenizer.
            text = self.tokenizer.pad_token

        tokens = self.tokenizer(
            text, padding="max_length", truncation=True,
            max_length=self.seq_len, return_tensors="pt"
        )
        return image, tokens.input_ids.squeeze(0), tokens.attention_mask.squeeze(0), row["id"]

def product_of_experts(mus, logvars):
    """Fuse Gaussian experts given as parallel lists of means and log-variances.

    Each expert is weighted by its precision 1 / (exp(logvar) + 1e-8). The
    fused mean is the precision-weighted average of the means, and the fused
    log-variance is log(1 / total_precision + 1e-8).

    Returns:
        (mu_comb, logvar_comb)
    """
    eps = 1e-8
    weights = []
    for log_variance in logvars:
        weights.append(1.0 / (torch.exp(log_variance) + eps))

    total_precision = sum(weights)
    weighted_means = sum(weight * mean for weight, mean in zip(weights, mus))

    fused_mu = weighted_means / total_precision
    fused_logvar = torch.log(1.0 / total_precision + eps)
    return fused_mu, fused_logvar

class ImageEncoder(nn.Module):
    """Image branch of the VAE: a ResNet-18 trunk followed by linear mu / logvar heads."""

    def __init__(self, latent_dim):
        super().__init__()
        # weights=None: the trunk is created without pretrained weights.
        resnet = models.resnet18(weights=None)
        # Keep all ResNet children except the last two as the feature trunk.
        self.cnn = nn.Sequential(*list(resnet.children())[:-2])
        self.flatten = nn.Flatten()
        # The heads expect a flattened 512*7*7 feature map -- presumably from
        # 224x224 inputs, matching the Resize in image_transform; TODO confirm.
        self.fc_mu = nn.Linear(512 * 7 * 7, latent_dim)
        self.fc_logvar = nn.Linear(512 * 7 * 7, latent_dim)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Return (mu, logvar); dropout is applied to mu only."""
        x = self.cnn(x)
        x = self.flatten(x)
        return self.dropout(self.fc_mu(x)), self.fc_logvar(x)

class TextEncoder(nn.Module):
    """Text branch of the VAE: a T5 encoder followed by linear mu / logvar heads."""

    def __init__(self, model_name, latent_dim):
        super().__init__()
        self.encoder = T5EncoderModel.from_pretrained(model_name)
        self.fc_mu = nn.Linear(self.encoder.config.d_model, latent_dim)
        self.fc_logvar = nn.Linear(self.encoder.config.d_model, latent_dim)
        self.dropout = nn.Dropout(0.3)

    def forward(self, input_ids, attention_mask):
        """Return (mu, logvar) for a batch of token ids; dropout is applied to mu only."""
        # Shortcut taken only when the attention mask of the WHOLE batch sums
        # to zero: return all-zero mu and logvar without running the encoder.
        if attention_mask.sum().item() == 0:
            b = input_ids.size(0)
            mu = torch.zeros(b, self.fc_mu.out_features, device=input_ids.device)
            logvar = torch.zeros_like(mu)
            return mu, logvar
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # The hidden state at sequence position 0 is used as the summary vector.
        cls_repr = out.last_hidden_state[:, 0, :]
        return self.dropout(self.fc_mu(cls_repr)), self.fc_logvar(cls_repr)

class ImageDecoder(nn.Module):
    """Decode a latent vector into a 3-channel image with values in (0, 1)."""

    def __init__(self, latent_dim):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 512 * 7 * 7)
        # Five transposed convolutions (kernel 4, stride 2, padding 1), each of
        # which doubles the spatial size: 7x7 -> 224x224. The final Sigmoid
        # bounds the output to (0, 1).
        self.decoder = nn.Sequential(
            nn.Unflatten(1, (512, 7, 7)),
            nn.ConvTranspose2d(512, 256, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),  nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),   nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, 2, 1),    nn.Sigmoid()
        )
        self.dropout = nn.Dropout(0.3)

    def forward(self, z):
        """Project z to a 512x7x7 map (with dropout) and upsample it to an image."""
        x = self.dropout(self.fc(z))
        return self.decoder(x)

class TextDecoder(nn.Module):
    """Text reconstruction head: a T5 model conditioned on the latent via a prefix embedding."""

    def __init__(self, model_name, latent_dim):
        super().__init__()
        self.decoder = T5ForConditionalGeneration.from_pretrained(model_name)
        # Maps the latent vector to one embedding of the T5 model width.
        self.latent_to_prefix = nn.Linear(latent_dim, self.decoder.config.d_model)
        self.dropout = nn.Dropout(0.3)

    def forward(self, z, input_ids=None, attention_mask=None):
        """Run T5 on [latent prefix; token embeddings] with `input_ids` as labels.

        NOTE(review): `input_ids` defaults to None but is used unconditionally
        below, so callers presumably always pass it -- confirm.
        """
        # One prefix position per sample, placed in front of the token embeddings.
        prefix_emb = self.dropout(self.latent_to_prefix(z)).unsqueeze(1)
        input_embeds = self.decoder.encoder.embed_tokens(input_ids)
        input_embeds = torch.cat([prefix_emb, input_embeds], dim=1)
        if attention_mask is not None:
            # Extend the mask by one always-attended position for the prefix.
            prefix_mask = torch.ones((attention_mask.size(0), 1), device=attention_mask.device)
            attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)
        return self.decoder(inputs_embeds=input_embeds, attention_mask=attention_mask, labels=input_ids)

class MultiModalVAE(nn.Module):
    """Container for the image/text encoders and decoders of the multimodal VAE.

    No forward() is defined here; the notebook calls the sub-modules directly.
    The attribute names are part of the checkpoint's state_dict keys.
    """

    def __init__(self, latent_dim, model_name):
        super().__init__()
        self.image_enc = ImageEncoder(latent_dim)
        self.text_enc  = TextEncoder(model_name, latent_dim)
        self.image_dec = ImageDecoder(latent_dim)
        self.text_dec  = TextDecoder(model_name, latent_dim)

# Fail fast with a clear message when the checkpoint file is absent.
if not os.path.exists(VAE_CKPT_PATH):
    raise FileNotFoundError(f"VAE checkpoint not found: {VAE_CKPT_PATH}")

# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
ckpt = torch.load(VAE_CKPT_PATH, map_location=DEVICE)

# The checkpoint carries its own hyperparameters under "config" and the
# weights under "model"; the model is rebuilt from the former.
vae_model = MultiModalVAE(
    latent_dim=ckpt["config"]["latent_dim"],
    model_name=ckpt["config"]["model_name"],
).to(DEVICE)

# strict=True: any missing or unexpected parameter name raises an error.
vae_model.load_state_dict(ckpt["model"], strict=True)
# eval() disables the dropout layers for feature extraction.
vae_model.eval()
print("Loaded VAE ckpt:", VAE_CKPT_PATH)

@torch.no_grad()
def extract_mu_multimodal_vae(df_in, batch_size):
    """Attach the fused (product-of-experts) VAE mean vector to every row of `df_in`.

    Image and text are encoded separately and their (mu, logvar) pairs are
    fused with `product_of_experts`; only the fused mean is kept.

    The vectors are attached by row position: the loader is unshuffled and
    ScrollsDataset resets the index, so the k-th vector belongs to the k-th
    row. The previous implementation merged on "id", which duplicated rows
    when ids repeat and produced "vae_mu_x"/"vae_mu_y" columns (breaking later
    `["vae_mu"]` lookups) when the input already had a "vae_mu" column, e.g.
    on a cell re-run.

    Args:
        df_in: DataFrame accepted by ScrollsDataset.
        batch_size: DataLoader batch size.

    Returns:
        A new DataFrame with a fresh 0..n-1 index and a "vae_mu" column (placed
        last) holding one NumPy vector per row. An existing "vae_mu" column is
        replaced.

    Raises:
        RuntimeError: if the number of extracted vectors differs from the
            number of rows.
    """
    loader = DataLoader(
        ScrollsDataset(df_in, tokenizer, VAE_CONFIG["seq_len"]),
        batch_size=batch_size, shuffle=False,
        num_workers=0, pin_memory=torch.cuda.is_available()
    )

    vectors = []
    for img, input_ids, attn_mask, _ids in tqdm(loader, desc="Extracting VAE mu"):
        img = img.to(DEVICE, non_blocking=True)
        input_ids = input_ids.to(DEVICE, non_blocking=True)
        attn_mask = attn_mask.to(DEVICE, non_blocking=True)

        img_mu, img_logvar = vae_model.image_enc(img)
        txt_mu, txt_logvar = vae_model.text_enc(input_ids, attn_mask)
        mu, _ = product_of_experts([img_mu, txt_mu], [img_logvar, txt_logvar])

        # Iterating the (batch, latent) array yields one vector per sample.
        vectors.extend(mu.detach().cpu().numpy())

    out_df = df_in.reset_index(drop=True).drop(columns=["vae_mu"], errors="ignore")
    if len(vectors) != len(out_df):
        raise RuntimeError(
            f"Expected {len(out_df)} VAE vectors but extracted {len(vectors)}"
        )
    # An object Series keeps each vector as a single cell value.
    out_df["vae_mu"] = pd.Series(vectors, index=out_df.index, dtype=object)
    return out_df

# Rebind both splits to the frames that carry the extra "vae_mu" column.
train_df = extract_mu_multimodal_vae(train_df, batch_size=VAE_CONFIG["batch_size"])
test_df  = extract_mu_multimodal_vae(test_df,  batch_size=VAE_CONFIG["batch_size"])

# Stack the per-row vectors into float32 matrices, one row per panel.
emb_train_vae = np.stack(train_df["vae_mu"].values).astype(np.float32)
emb_test_vae  = np.stack(test_df["vae_mu"].values).astype(np.float32)

print("VAE mu shapes:", emb_train_vae.shape, emb_test_vae.shape)

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Loaded VAE ckpt: Pruthvi/multimodal_vae_400.pth


Extracting VAE mu:   0%|          | 0/5 [00:00<?, ?it/s]

Extracting VAE mu:   0%|          | 0/2 [00:00<?, ?it/s]

VAE mu shapes: (151, 4096) (38, 4096)


In [19]:
LLAMAVAE_CKPT_PATH = os.path.join("Pruthvi", CONFIG["llamavae_ckpt_relpath"])  
class SentenceT5VAEEncoder(nn.Module):
    """Text-only VAE encoder: a T5 encoder, masked mean pooling, then linear mu / logvar heads."""

    def __init__(self, st5_name="sentence-transformers/sentence-t5-base", latent_dim=768):
        super().__init__()
        self.st5 = T5EncoderModel.from_pretrained(st5_name)
        self.latent_dim = latent_dim
        self.proj_mu = nn.Linear(self.st5.config.d_model, latent_dim)
        self.proj_logvar = nn.Linear(self.st5.config.d_model, latent_dim)

    def forward(self, input_ids, attention_mask):
        """Return (mu, logvar) for a batch of tokenised texts."""
        out = self.st5(input_ids=input_ids, attention_mask=attention_mask)
        # Mean over non-padded positions; the clamp avoids division by zero
        # for rows whose mask is entirely zero.
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (out.last_hidden_state * mask).sum(dim=1) / torch.clamp(mask.sum(dim=1), min=1.0)
        mu = self.proj_mu(pooled)
        # The log-variance is clamped to [-10, 10].
        logvar = torch.clamp(self.proj_logvar(pooled), -10, 10)
        return mu, logvar

def load_llamavae_or_fallback(latent_dim=768):
    """Build the text VAE encoder and try to restore it from LLAMAVAE_CKPT_PATH.

    Returns:
        (model, has_ckpt). `has_ckpt` is True only when the checkpoint was
        loaded AND it supplied the `proj_mu` / `proj_logvar` heads. It is False
        when the file is absent, when loading raises, or when the heads were
        not in the checkpoint. Because the load is non-strict, a checkpoint
        with mismatching key names would otherwise be reported as loaded while
        `mu` came from randomly initialised projections. The model is returned
        in eval mode in every case.
    """
    model = SentenceT5VAEEncoder(latent_dim=latent_dim).to(DEVICE)
    has_ckpt = False

    if os.path.exists(LLAMAVAE_CKPT_PATH):
        try:
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            ck = torch.load(LLAMAVAE_CKPT_PATH, map_location=DEVICE)
            sd = ck["model"] if isinstance(ck, dict) and "model" in ck else ck
            missing, unexpected = model.load_state_dict(sd, strict=False)
            if missing: print("Missing keys:", len(missing))
            if unexpected: print("Unexpected keys:", len(unexpected))

            heads_missing = [k for k in missing if k.startswith(("proj_mu.", "proj_logvar."))]
            if heads_missing:
                print(
                    "[WARN] LlamaVAE ckpt did not provide the projection heads "
                    f"({heads_missing}); falling back to SentenceT5 embeddings."
                )
            else:
                print("Loaded LlamaVAE encoder ckpt:", LLAMAVAE_CKPT_PATH)
                has_ckpt = True
        except Exception as e:
            print(f"[WARN] Failed to load LlamaVAE ckpt, falling back to SentenceT5 embeddings.\n  Error: {e}")
    else:
        print("[INFO] LlamaVAE ckpt not found; using SentenceT5 embeddings as LlamaVAE features.")

    model.eval()
    return model, has_ckpt

LLAMAVAE_ENC, LLAMAVAE_HAS_CKPT = load_llamavae_or_fallback(latent_dim=768)
LLAMAVAE_TOK = AutoTokenizer.from_pretrained("sentence-transformers/sentence-t5-base")

@torch.no_grad()
def extract_llamavae_features_text_only(df_in, batch_size=64, max_len=128, text_col="text"):
    """Embed the `text_col` column of `df_in` with the text VAE encoder.

    With a loaded checkpoint (LLAMAVAE_HAS_CKPT) the encoder's `mu` is used.
    Otherwise the masked mean of the T5 encoder's last hidden state is used
    and the projection heads are not run at all. Previously the full encoder
    was executed a second time just to obtain a `mu` that was then discarded.
    Missing or blank texts are replaced by the tokenizer's pad token.

    Args:
        df_in: DataFrame containing `text_col`.
        batch_size: number of texts per forward pass.
        max_len: maximum token length (longer texts are truncated).
        text_col: name of the text column.

    Returns:
        float32 array with one row per input row; shape (0, dim) for an empty
        frame (the previous code raised in np.concatenate in that case).
    """
    texts = [(_safe_text(t) or LLAMAVAE_TOK.pad_token) for t in df_in[text_col].tolist()]

    if not texts:
        dim = LLAMAVAE_ENC.latent_dim if LLAMAVAE_HAS_CKPT else LLAMAVAE_ENC.st5.config.d_model
        return np.zeros((0, dim), dtype=np.float32)

    vecs = []
    for i in tqdm(range(0, len(texts), batch_size), desc="Extracting LlamaVAE (text)"):
        batch = texts[i:i+batch_size]
        tok = LLAMAVAE_TOK(batch, return_tensors="pt", padding=True, truncation=True, max_length=max_len)
        input_ids = tok["input_ids"].to(DEVICE)
        attn = tok["attention_mask"].to(DEVICE)

        if LLAMAVAE_HAS_CKPT:
            vec, _ = LLAMAVAE_ENC(input_ids, attn)
        else:
            out = LLAMAVAE_ENC.st5(input_ids=input_ids, attention_mask=attn)
            mask = attn.unsqueeze(-1).float()
            # Mean over non-padded positions; clamp guards against empty masks.
            vec = (out.last_hidden_state * mask).sum(dim=1) / torch.clamp(mask.sum(dim=1), min=1.0)

        vecs.append(vec.detach().cpu().numpy())

    return np.concatenate(vecs, axis=0).astype(np.float32)

# Text-only features computed from the original "text" column (not from the
# LLM-rewritten "llama_text" column).
emb_train_llamavae = extract_llamavae_features_text_only(train_df, batch_size=64, max_len=128, text_col="text")
emb_test_llamavae  = extract_llamavae_features_text_only(test_df,  batch_size=64, max_len=128, text_col="text")

print("LlamaVAE shapes:", emb_train_llamavae.shape, emb_test_llamavae.shape)

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/219M [00:00<?, ?B/s]

[INFO] LlamaVAE ckpt not found; using SentenceT5 embeddings as LlamaVAE features.


tokenizer_config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

Extracting LlamaVAE (text):   0%|          | 0/3 [00:00<?, ?it/s]

Extracting LlamaVAE (text):   0%|          | 0/1 [00:00<?, ?it/s]

LlamaVAE shapes: (151, 768) (38, 768)


In [20]:
print("Shapes summary:")
print("SigCLIP       :", emb_train_sigclip.shape, emb_test_sigclip.shape)
print("LlamaSigCLIP  :", emb_train_llamasig.shape, emb_test_llamasig.shape)
print("VAE           :", emb_train_vae.shape, emb_test_vae.shape)
print("LlamaVAE      :", emb_train_llamavae.shape, emb_test_llamavae.shape)


Shapes summary:
SigCLIP       : (151, 512) (38, 512)
LlamaSigCLIP  : (151, 512) (38, 512)
VAE           : (151, 4096) (38, 4096)
LlamaVAE      : (151, 768) (38, 768)


In [21]:
import os, json, time
import numpy as np


FEATURE_ROOT = "features_pgl5"
os.makedirs(FEATURE_ROOT, exist_ok=True)

np.save(f"{FEATURE_ROOT}/train_ids.npy", train_df["id"].values)
np.save(f"{FEATURE_ROOT}/test_ids.npy",  test_df["id"].values)

def save_feature_block(feature_key, feature_name, X_train, X_test, meta_extra=None):
    """Write one feature set to `<FEATURE_ROOT>/<feature_key>/`.

    Creates the directory if needed and writes X_train.npy and X_test.npy
    (both cast to float32) plus a meta.json describing the feature set.

    Args:
        feature_key: sub-directory name for this feature set.
        feature_name: descriptive name recorded in the metadata.
        X_train, X_test: arrays to save.
        meta_extra: optional dict merged into the metadata; its entries
            override the standard ones on key collisions.
    """
    out_dir = f"{FEATURE_ROOT}/{feature_key}"
    os.makedirs(out_dir, exist_ok=True)

    for file_stem, matrix in (("X_train", X_train), ("X_test", X_test)):
        np.save(f"{out_dir}/{file_stem}.npy", matrix.astype(np.float32))

    meta = dict(
        feature_key=feature_key,
        feature_name=feature_name,
        train_shape=list(X_train.shape),
        test_shape=list(X_test.shape),
        timestamp=time.strftime("%Y-%m-%d %H:%M:%S"),
    )
    meta.update(meta_extra or {})

    with open(f"{out_dir}/meta.json", "w") as meta_file:
        json.dump(meta, meta_file, indent=2)

    print(f"[Saved] {feature_key} -> {out_dir} | train {X_train.shape} test {X_test.shape}")

# Seed recorded in the metadata of the chunk-sampled feature sets.
BASE_SEED = int(CONFIG.get("base_seed", 0))

# NOTE(review): the *_FEAT_NAME globals assigned below replace the values set
# when the extractors were created ("...__avgchunk" without "__avgfusion").
# Any cell that depends on the earlier values behaves differently if it is
# re-run after this cell.
# NOTE(review): num_chunks / chunk_size in the metadata below are literals,
# not read from NUM_CHUNKS / CHUNK_SIZE; keep them in sync by hand.
SIGCLIP_FEAT_KEY  = "sigclip_ft"
SIGCLIP_FEAT_NAME = "sigclip_finetuned_image_text_embedding__avgfusion__avgchunk"

save_feature_block(
    SIGCLIP_FEAT_KEY,
    SIGCLIP_FEAT_NAME,
    emb_train_sigclip,
    emb_test_sigclip,
    meta_extra={
        "type": "clip_like",
        "fusion": "(image_embeds + text_embeds)/2",
        "num_chunks": 3,
        "chunk_size": 100,
        "seed_for_chunk_sampling": BASE_SEED,
        # "sigclip_model_dir" is not a key of the CONFIG dict defined at the
        # top of the notebook, so the fallback key supplies the value.
        "model_dir": str(CONFIG.get("sigclip_model_dir", CONFIG.get("sigclip_finetuned_drive_dir", ""))),
        "base_model": "openai/clip-vit-base-patch32",
    }
)

VAE_FEAT_KEY  = "vae_mu"
VAE_FEAT_NAME = "multimodal_vae_mu__poe__latent"

save_feature_block(
    VAE_FEAT_KEY,
    VAE_FEAT_NAME,
    emb_train_vae,
    emb_test_vae,
    meta_extra={
        "type": "vae_mu",
        "poe": True,
        "seq_len": int(VAE_CONFIG.get("seq_len", 0)) if "VAE_CONFIG" in globals() else None,
        "batch_size": int(VAE_CONFIG.get("batch_size", 0)) if "VAE_CONFIG" in globals() else None,
        "ckpt_relpath": str(CONFIG.get("vae_ckpt_relpath", "")),
    }
)

LLAMASIGCLIP_FEAT_KEY  = "llamasigclip"
LLAMASIGCLIP_FEAT_NAME = "llamasigclip_embedding__avgfusion__avgchunk"

save_feature_block(
    LLAMASIGCLIP_FEAT_KEY,
    LLAMASIGCLIP_FEAT_NAME,
    emb_train_llamasig,
    emb_test_llamasig,
    meta_extra={
        "type": "clip_like",
        "fusion": "(image_embeds + text_embeds)/2",
        "num_chunks": 3,
        "chunk_size": 140,
        "seed_for_chunk_sampling": BASE_SEED,
        "llama_prompt_model": str(CONFIG.get("llama_prompt_model", "")),
        "llama_prompt_max_new_tokens": int(CONFIG.get("llama_prompt_max_new_tokens", 0)),
        "sigclip_model_dir": str(CONFIG.get("sigclip_model_dir", CONFIG.get("sigclip_finetuned_drive_dir", ""))),
        "base_model": "openai/clip-vit-base-patch32",
    }
)

LLAMAVAE_FEAT_KEY  = "llamavae"
LLAMAVAE_FEAT_NAME = "llamavae_text_embedding__mu_or_pooled"

# NOTE(review): the metadata does not record whether the checkpoint was
# actually loaded (LLAMAVAE_HAS_CKPT), i.e. whether these are mu vectors or
# pooled encoder states.
save_feature_block(
    LLAMAVAE_FEAT_KEY,
    LLAMAVAE_FEAT_NAME,
    emb_train_llamavae,
    emb_test_llamavae,
    meta_extra={
        "type": "llamavae_text",
        "ckpt_relpath": str(CONFIG.get("llamavae_ckpt_relpath", "")),
    }
)

CLIP_FT_FEAT_KEY  = "clip_ft"
CLIP_FT_FEAT_NAME = "clip_finetuned_image_text_embedding__avgfusion__avgchunk"

save_feature_block(
    CLIP_FT_FEAT_KEY,
    CLIP_FT_FEAT_NAME,
    emb_train_clipFT,
    emb_test_clipFT,
    meta_extra={
        "type": "clip_like",
        "fusion": "(image_embeds + text_embeds)/2",
        "num_chunks": 3,
        "chunk_size": 100,
        "seed_for_chunk_sampling": BASE_SEED,
        "model_dir": str(CONFIG.get("clip_model_dir", CONFIG.get("clip_finetuned_drive_dir", ""))),
        "base_model": "openai/clip-vit-base-patch32",
    }
)

CLIP_BASE_FEAT_KEY  = "clip_base"
CLIP_BASE_FEAT_NAME = "clip_base_image_text_embedding__avgfusion__avgchunk"

save_feature_block(
    CLIP_BASE_FEAT_KEY,
    CLIP_BASE_FEAT_NAME,
    emb_train_clipBASE,
    emb_test_clipBASE,
    meta_extra={
        "type": "clip_like",
        "fusion": "(image_embeds + text_embeds)/2",
        "num_chunks": 3,
        "chunk_size": 100,
        "seed_for_chunk_sampling": BASE_SEED,
        "model": "openai/clip-vit-base-patch32",
        "model_dir": None,
    }
)


print("All feature sets saved under:", FEATURE_ROOT)

[Saved] sigclip_ft -> features_pgl5/sigclip_ft | train (151, 512) test (38, 512)
[Saved] vae_mu -> features_pgl5/vae_mu | train (151, 4096) test (38, 4096)
[Saved] llamasigclip -> features_pgl5/llamasigclip | train (151, 512) test (38, 512)
[Saved] llamavae -> features_pgl5/llamavae | train (151, 768) test (38, 768)
[Saved] clip_ft -> features_pgl5/clip_ft | train (151, 512) test (38, 512)
[Saved] clip_base -> features_pgl5/clip_base | train (151, 512) test (38, 512)
✅ All 5 feature sets saved under: features_pgl5


# ===== 4) TF-IDF =====
# NOTE(review): emb_train_tfidf / emb_test_tfidf are not defined anywhere in
# the notebook source visible here; unless another cell creates them, this
# call raises NameError on a fresh Restart & Run All -- TODO confirm.
TFIDF_FEAT_KEY  = "tfidf"
TFIDF_FEAT_NAME = f"tfidf_text_vector__max{int(CONFIG.get('tfidf_max_features', 0) or 0)}"

save_feature_block(
    TFIDF_FEAT_KEY,
    TFIDF_FEAT_NAME,
    emb_train_tfidf,
    emb_test_tfidf,
    meta_extra={
        "type": "tfidf",
        "tfidf_max_features": int(CONFIG.get("tfidf_max_features", 0)),
    }
)

# ===== 5) ResNet50 =====
# NOTE(review): emb_train_resnet / emb_test_resnet are likewise not defined in
# the visible source. The metadata values below (weights, pooling,
# l2_normalized) describe an extraction step that is not shown, so they cannot
# be verified from here.
RESNET_FEAT_KEY  = "resnet50"
RESNET_FEAT_NAME = "resnet50_imagenet_image_embedding__l2norm"

save_feature_block(
    RESNET_FEAT_KEY,
    RESNET_FEAT_NAME,
    emb_train_resnet,
    emb_test_resnet,
    meta_extra={
        "type": "resnet50",
        "weights": "IMAGENET1K_V1",
        "pooling": "avgpool",
        "l2_normalized": True,
    }
)