In [None]:
%%capture
!pip -q install --upgrade pip
!pip -q install transformers sentencepiece protobuf safetensors scikit-learn pandas tqdm numpy scipy nltk tabulate torchmetrics matplotlib opencv-python accelerate
!pip -q install torch-geometric -f https://data.pyg.org/whl/torch-2.4.0+cu121.html

import os, random, zipfile, shutil
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from tqdm.notebook import tqdm
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics.pairwise import cosine_similarity

from google.colab import drive
drive.mount("/content/drive")

# Central experiment configuration; later cells read values from here by key.
CONFIG = {
    # Seed fed to the Python, NumPy and PyTorch RNGs right after this dict.
    "random_seed": 171717,
    # CUDA when a GPU is visible, otherwise CPU.
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),

    # Asset directories; each is joined onto `drive_root` below.
    "sigclip_drive_dir": "sigmoidclip_finetuned_bce",
    "clip_drive_dir": "clip_finetuned_softmax",
    "llama_sigclip_drive_dir": "llamasigclip_assets",

    # Drive root plus the data files that are joined onto it.
    "drive_root": "/content/drive/MyDrive/FolkArt",
    "csv_name": "all_labels.csv",
    "zip_name": "images.zip",
    "pref_name": "user_pref_profiles.npz",

    # Checkpoint locations, relative to `drive_root`.
    "vae_ckpt_relpath": "vae_checkpoints/multimodal_vae_400.pth",
    "llamavae_ckpt_relpath": "llamavae_checkpoints/llamavae_text_vae.pth",

    # `pref_thresh` binarises the preference matrix when profiles are loaded.
    # NOTE(review): consumers of top_k / gcn_epochs are not visible in this part
    # of the notebook; `gcn_percentile` is used by the percentile graph builder.
    "top_k": 5,
    "pref_thresh": 0.2,
    "gcn_epochs": 500,
    "gcn_percentile": 95,

    # Graph construction switches read by build_graph().
    "graph_use_knn": True,
    "graph_knn_k": 25,
    "graph_sim_floor": 0.0,

    # NOTE(review): presumably VGAE architecture sizes; usage not shown here.
    "vgae_hidden_dim": 256,
    "vgae_latent_dim": 128,
    "vgae_dropout": 0.20,

    # NOTE(review): presumably VGAE optimisation settings; usage not shown here.
    "vgae_lr": 3e-3,
    "vgae_weight_decay": 5e-4,
    "vgae_kl_beta_max": 1.0,
    "vgae_kl_warmup_frac": 0.35,
    "vgae_grad_clip": 2.0,

    # Validation / test fractions passed to RandomLinkSplit.
    "link_val_frac": 0.05,
    "link_test_frac": 0.05,

    # Prompt-rewriting LLM and the tree-edge augmentation knobs.
    "llama_prompt_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "llama_prompt_max_new_tokens": 96,
    "tree_edge_aug": True,
    "tree_edge_aug_k": 25,
    "tree_edge_aug_max_nodes": 2500,

    # `use_pos_weight` gates compute_pos_weight() in the trainer.
    # NOTE(review): consumers of ssl_lambda / focal settings are not visible here.
    "ssl_lambda": 0.60,
    "use_pos_weight": True,
    "use_focal": True,
    "focal_gamma": 2.0,
    "focal_alpha": 0.25,
}

# Module-wide device handle used by every model and tensor transfer below.
DEVICE = CONFIG["device"]
print("Using device:", DEVICE)

# Seed every RNG the notebook touches with the same configured value.
random.seed(CONFIG["random_seed"])
np.random.seed(CONFIG["random_seed"])
torch.manual_seed(CONFIG["random_seed"])
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(CONFIG["random_seed"])
    # Ask cuDNN for deterministic kernels and disable its autotuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Absolute Drive paths derived from CONFIG.
DRIVE_ROOT = CONFIG["drive_root"]
CSV_PATH   = os.path.join(DRIVE_ROOT, CONFIG["csv_name"])
ZIP_PATH   = os.path.join(DRIVE_ROOT, CONFIG["zip_name"])
PREF_PATH  = os.path.join(DRIVE_ROOT, CONFIG["pref_name"])

# Model asset directories on Drive.
SIGCLIP_DRIVE_DIR = os.path.join(DRIVE_ROOT, CONFIG["sigclip_drive_dir"])
CLIP_DRIVE_DIR    = os.path.join(DRIVE_ROOT, CONFIG["clip_drive_dir"])
LLAMA_SIGCLIP_DRIVE_DIR = os.path.join(DRIVE_ROOT, CONFIG["llama_sigclip_drive_dir"])

# Checkpoint files on Drive.
VAE_CKPT_PATH      = os.path.join(DRIVE_ROOT, CONFIG["vae_ckpt_relpath"])
LLAMAVAE_CKPT_PATH = os.path.join(DRIVE_ROOT, CONFIG["llamavae_ckpt_relpath"])

# Local scratch area; images are expected under WORKDIR/images.
WORKDIR = "/content/work"
IMG_DIR = os.path.join(WORKDIR, "images")
os.makedirs(WORKDIR, exist_ok=True)
os.makedirs(IMG_DIR, exist_ok=True)

# Extract the image archive only when the local image directory is still empty.
# The archive is unpacked into WORKDIR (not IMG_DIR).
# NOTE(review): presumably the zip has a top-level "images/" folder - confirm.
if not os.listdir(IMG_DIR):
    print("Unzipping images.zip...")
    with zipfile.ZipFile(ZIP_PATH, "r") as z:
        z.extractall(WORKDIR)
    print("Done.")

# Copy the SigCLIP directory from Drive when the local copy is missing or empty.
SIGCLIP_LOCAL_DIR = os.path.join(WORKDIR, CONFIG["sigclip_drive_dir"])
if (not os.path.exists(SIGCLIP_LOCAL_DIR)) or (len(os.listdir(SIGCLIP_LOCAL_DIR)) == 0):
    # copytree refuses to write into an existing directory, so drop the empty one.
    if os.path.exists(SIGCLIP_LOCAL_DIR): shutil.rmtree(SIGCLIP_LOCAL_DIR)
    shutil.copytree(SIGCLIP_DRIVE_DIR, SIGCLIP_LOCAL_DIR)

# Same sync for the CLIP directory.
CLIP_LOCAL_DIR = os.path.join(WORKDIR, CONFIG["clip_drive_dir"])
if (not os.path.exists(CLIP_LOCAL_DIR)) or (len(os.listdir(CLIP_LOCAL_DIR)) == 0):
    if os.path.exists(CLIP_LOCAL_DIR): shutil.rmtree(CLIP_LOCAL_DIR)
    shutil.copytree(CLIP_DRIVE_DIR, CLIP_LOCAL_DIR)

# The LlamaSigCLIP assets are optional: synced only when present on Drive.
LLAMA_SIGCLIP_LOCAL_DIR = os.path.join(WORKDIR, CONFIG["llama_sigclip_drive_dir"])
if os.path.exists(LLAMA_SIGCLIP_DRIVE_DIR):
    if (not os.path.exists(LLAMA_SIGCLIP_LOCAL_DIR)) or (len(os.listdir(LLAMA_SIGCLIP_LOCAL_DIR)) == 0):
        if os.path.exists(LLAMA_SIGCLIP_LOCAL_DIR): shutil.rmtree(LLAMA_SIGCLIP_LOCAL_DIR)
        shutil.copytree(LLAMA_SIGCLIP_DRIVE_DIR, LLAMA_SIGCLIP_LOCAL_DIR)

# Load the label table and build a per-panel id of the form "<scroll_id>_<panel_id>".
df = pd.read_csv(CSV_PATH)
df["id"] = df["scroll_id"].astype(str) + "_" + df["panel_id"].astype(str)
# Columns used as the multi-label targets later in the notebook.
label_cols = ["animal_label", "myth_label", "tree_label"]
print("Data:", df.shape)
display(df.head())

def remap_path(old_path: str) -> str:
    """Translate an image path stored in the CSV to its local location.

    Resolution order:
      1. A path that already exists is returned as-is.
      2. If any "/"-separated segment starts with ``s1_`` or ``s2_``, the first
         such segment is taken as the scroll folder and the result is
         ``WORKDIR/images/<folder>/img/<filename>``.
      3. A path of the form ``images/<folder>/.../<filename>`` is mapped the
         same way.
      4. Anything else is returned unchanged (as a string).
    """
    path = str(old_path)
    if os.path.exists(path):
        return path

    segments = path.split("/")
    scroll_folder = next(
        (seg for seg in segments if seg.startswith(("s1_", "s2_"))),
        None,
    )
    if scroll_folder is not None:
        return os.path.join(WORKDIR, "images", scroll_folder, "img", segments[-1])

    prefix = "images/"
    if path.startswith(prefix):
        tail = path[len(prefix):].split("/")
        if len(tail) > 1:
            return os.path.join(WORKDIR, "images", tail[0], "img", tail[-1])

    return path

# 80/20 split, seeded from CONFIG so it is repeatable across runs.
train_df, test_df = train_test_split(df, test_size=0.2, random_state=CONFIG["random_seed"])
# Fresh 0..n-1 indices so positional and label-based indexing agree.
train_df = train_df.reset_index(drop=True)
test_df  = test_df.reset_index(drop=True)
# Point every image path at the locally extracted copy.
train_df["image_path"] = train_df["image_path"].apply(remap_path)
test_df["image_path"]  = test_df["image_path"].apply(remap_path)

# Smoke test: opening the first training image fails fast if the remap is wrong.
_ = Image.open(train_df["image_path"].iloc[0]).convert("RGB")
print("Paths OK | Train:", len(train_df), "Test:", len(test_df))

# Load the saved user preference profiles.
# NOTE(security): allow_pickle=True can execute arbitrary code while loading;
# only use it with an .npz file from a trusted source.
pref_data = np.load(PREF_PATH, allow_pickle=True)
interacted_ids = pref_data["interacted_ids"]
preferred_mat  = pref_data["preferred_mat"]
# Binarise: entries strictly above the configured threshold become 1.
preferred_mat_bin = (preferred_mat > CONFIG["pref_thresh"]).astype(int)

print("Loaded FIXED profiles:")
print("  interacted_ids:", interacted_ids.shape)
print("  preferred_mat :", preferred_mat.shape)
print("  preferred_bin :", preferred_mat_bin.shape, "| avg prefs/user:", preferred_mat_bin.sum(axis=1).mean())

# Map each training panel id to its row position in train_df.
# If an id repeats, the last occurrence wins.
id_to_train_idx = {pid: i for i, pid in enumerate(train_df["id"].values)}


In [None]:
from torch.utils.data import Dataset, DataLoader
from transformers import (
    T5Tokenizer, T5EncoderModel, T5ForConditionalGeneration,
    AutoTokenizer, AutoModelForCausalLM
)
from torchvision import models, transforms
# Settings for the multimodal VAE feature extraction (tokenizer model, token
# sequence length and DataLoader batch size).
VAE_CONFIG = {"model_name": "t5-small", "seq_len": 64, "batch_size": 32}
tokenizer = T5Tokenizer.from_pretrained(VAE_CONFIG["model_name"])

# Image preprocessing for the VAE: resize to 224x224 and convert to a tensor.
# No mean/std normalisation is applied here.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

class ScrollsDataset(Dataset):
    """Dataset yielding ``(image, input_ids, attention_mask, id)`` per panel.

    Args:
        dataframe: Table with at least ``image_path`` and ``id`` columns; a
            ``text`` column is used when present.
        tokenizer: Tokenizer used to encode the panel text. It must expose
            ``pad_token`` and be callable like a Hugging Face tokenizer.
        seq_len: Fixed token length (texts are padded / truncated to it).
    """

    def __init__(self, dataframe, tokenizer, seq_len):
        # Positional access below relies on a clean 0..n-1 index.
        self.df = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.seq_len = seq_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # Context manager closes the file handle as soon as the pixels are
        # decoded; convert() returns an independent in-memory image.
        with Image.open(row["image_path"]) as handle:
            image = handle.convert("RGB")
        image = image_transform(image)

        text = row.get("text", "")
        if not isinstance(text, str) or text.strip() == "":
            # Use the tokenizer this dataset was built with, not whichever
            # module-level `tokenizer` happens to exist.
            text = self.tokenizer.pad_token

        tokens = self.tokenizer(
            text, padding="max_length", truncation=True, max_length=self.seq_len, return_tensors="pt"
        )
        return image, tokens.input_ids.squeeze(0), tokens.attention_mask.squeeze(0), row["id"]

def reparameterize(mu, logvar):
    """Draw ``mu + eps * sigma`` with ``eps ~ N(0, I)`` and ``sigma = exp(logvar / 2)``."""
    sigma = torch.exp(0.5 * logvar)
    noise = torch.randn_like(sigma)
    return mu + noise * sigma

def product_of_experts(mus, logvars):
    """Fuse Gaussian experts by precision weighting.

    Each expert's precision is ``1 / (exp(logvar) + 1e-8)``. The combined mean
    is the precision-weighted average of the expert means and the combined
    log-variance is ``log(1 / sum(precisions) + 1e-8)``.

    Returns:
        Tuple ``(mu_comb, logvar_comb)``.
    """
    eps = 1e-8
    precisions = []
    for logvar in logvars:
        precisions.append(1.0 / (torch.exp(logvar) + eps))

    total_precision = sum(precisions)
    weighted_means = sum(prec * mean for prec, mean in zip(precisions, mus))

    mu_comb = weighted_means / total_precision
    logvar_comb = torch.log(1.0 / sum(precisions) + eps)
    return mu_comb, logvar_comb

class ImageEncoder(nn.Module):
    """Image branch of the VAE: ResNet-18 trunk followed by two linear heads.

    The trunk is an untrained ``resnet18`` with its last two child modules
    removed. The heads expect a flattened ``512 * 7 * 7`` feature map.
    """

    def __init__(self, latent_dim):
        super().__init__()
        backbone = models.resnet18(weights=None)
        trunk_layers = list(backbone.children())[:-2]
        self.cnn = nn.Sequential(*trunk_layers)
        self.flatten = nn.Flatten()
        flat_dim = 512 * 7 * 7
        self.fc_mu = nn.Linear(flat_dim, latent_dim)
        self.fc_logvar = nn.Linear(flat_dim, latent_dim)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Return ``(mu, logvar)``; dropout is applied to ``mu`` only."""
        feats = self.flatten(self.cnn(x))
        mu = self.dropout(self.fc_mu(feats))
        logvar = self.fc_logvar(feats)
        return mu, logvar

class TextEncoder(nn.Module):
    """Text branch of the VAE: T5 encoder followed by two linear heads."""

    def __init__(self, model_name, latent_dim):
        super().__init__()
        self.encoder = T5EncoderModel.from_pretrained(model_name)
        hidden_size = self.encoder.config.d_model
        self.fc_mu = nn.Linear(hidden_size, latent_dim)
        self.fc_logvar = nn.Linear(hidden_size, latent_dim)
        self.dropout = nn.Dropout(0.3)

    def forward(self, input_ids, attention_mask):
        """Return ``(mu, logvar)`` for a batch of token ids.

        If the attention mask of the *whole batch* sums to zero, the encoder
        is skipped and all-zero ``mu`` / ``logvar`` tensors are returned.
        Otherwise the hidden state of the first token feeds both heads, with
        dropout applied to ``mu`` only.
        """
        if attention_mask.sum().item() == 0:
            batch = input_ids.size(0)
            mu = torch.zeros(batch, self.fc_mu.out_features, device=input_ids.device)
            return mu, torch.zeros_like(mu)

        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        first_token = encoded.last_hidden_state[:, 0, :]
        mu = self.dropout(self.fc_mu(first_token))
        logvar = self.fc_logvar(first_token)
        return mu, logvar

class ImageDecoder(nn.Module):
    """Decode a latent vector into a 3-channel image with values in [0, 1].

    A linear layer expands the latent to ``512 * 7 * 7`` values, which are
    reshaped to ``(512, 7, 7)`` and passed through five stride-2 transposed
    convolutions (ReLU between stages, sigmoid at the end).
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 512 * 7 * 7)

        channels = [512, 256, 128, 64, 32, 3]
        last_stage = len(channels) - 2
        stages = [nn.Unflatten(1, (512, 7, 7))]
        for stage, (c_in, c_out) in enumerate(zip(channels, channels[1:])):
            stages.append(nn.ConvTranspose2d(c_in, c_out, 4, 2, 1))
            stages.append(nn.Sigmoid() if stage == last_stage else nn.ReLU())
        self.decoder = nn.Sequential(*stages)

        self.dropout = nn.Dropout(0.3)

    def forward(self, z):
        """Return the reconstructed image batch for latents ``z``."""
        expanded = self.fc(z)
        expanded = self.dropout(expanded)
        return self.decoder(expanded)

class TextDecoder(nn.Module):
    """Text reconstruction head: T5 conditioned on a latent "prefix" embedding.

    The latent is projected to the model width and prepended as one extra
    position in front of the token embeddings fed to the T5 encoder.
    """

    def __init__(self, model_name, latent_dim):
        super().__init__()
        self.decoder = T5ForConditionalGeneration.from_pretrained(model_name)
        self.latent_to_prefix = nn.Linear(latent_dim, self.decoder.config.d_model)
        self.dropout = nn.Dropout(0.3)

    def forward(self, z, input_ids=None, attention_mask=None):
        """Run T5 with the latent prefix and ``input_ids`` as labels.

        Args:
            z: Latent batch.
            input_ids: Token ids; required even though the signature keeps a
                ``None`` default for backward compatibility.
            attention_mask: Optional mask for ``input_ids``; extended by one
                always-visible position for the prefix.

        Returns:
            The T5 model output (includes the loss, since labels are passed).

        Raises:
            ValueError: If ``input_ids`` is ``None``.
        """
        if input_ids is None:
            # The tokens are both the encoder input and the labels, so there
            # is nothing meaningful to compute without them.
            raise ValueError("TextDecoder.forward requires input_ids")

        prefix_emb = self.dropout(self.latent_to_prefix(z)).unsqueeze(1)
        input_embeds = self.decoder.encoder.embed_tokens(input_ids)
        input_embeds = torch.cat([prefix_emb, input_embeds], dim=1)
        if attention_mask is not None:
            # Match the caller's mask dtype so the concatenation does not
            # silently promote an integer mask to float.
            prefix_mask = torch.ones(
                (attention_mask.size(0), 1),
                dtype=attention_mask.dtype,
                device=attention_mask.device,
            )
            attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)
        return self.decoder(inputs_embeds=input_embeds, attention_mask=attention_mask, labels=input_ids)

class MultiModalVAE(nn.Module):
    """Container bundling the image/text encoders and decoders of the VAE.

    It defines no ``forward``; callers use the four sub-modules directly.
    """

    def __init__(self, latent_dim, model_name):
        super().__init__()
        self.image_enc = ImageEncoder(latent_dim=latent_dim)
        self.text_enc = TextEncoder(model_name=model_name, latent_dim=latent_dim)
        self.image_dec = ImageDecoder(latent_dim=latent_dim)
        self.text_dec = TextDecoder(model_name=model_name, latent_dim=latent_dim)

# Fail early with a clear message when the trained VAE is missing on Drive.
if not os.path.exists(VAE_CKPT_PATH):
    raise FileNotFoundError(f"VAE checkpoint not found at: {VAE_CKPT_PATH}")

# The checkpoint carries its own config (latent size, T5 model name) plus weights.
ckpt = torch.load(VAE_CKPT_PATH, map_location=DEVICE)
vae_model = MultiModalVAE(latent_dim=ckpt["config"]["latent_dim"], model_name=ckpt["config"]["model_name"]).to(DEVICE)
vae_model.load_state_dict(ckpt["model"])
# Inference only: eval() turns off the dropout layers in the encoders.
vae_model.eval()
print("Loaded VAE from Drive:", VAE_CKPT_PATH)

@torch.no_grad()
def extract_mu_multimodal_vae(df_in, batch_size):
    """Attach the fused (image + text) VAE posterior mean to every row.

    Rows are streamed through ``vae_model`` in their original order
    (``shuffle=False``). The image and text posteriors are combined with
    ``product_of_experts`` and the mean is stored as one NumPy vector per row
    in a new ``vae_mu`` column.

    The vectors are assigned positionally rather than merged on ``id``: a
    merge duplicates rows whenever an id occurs more than once, which would
    silently misalign the embeddings with the labels.

    Args:
        df_in: Frame with the columns required by ``ScrollsDataset``.
        batch_size: DataLoader batch size.

    Returns:
        A copy of ``df_in`` (index reset to 0..n-1) with a ``vae_mu`` column.

    Raises:
        RuntimeError: If the number of extracted vectors differs from the
            number of rows.
    """
    loader = DataLoader(
        ScrollsDataset(df_in, tokenizer, VAE_CONFIG["seq_len"]),
        batch_size=batch_size, shuffle=False, num_workers=2,
        pin_memory=torch.cuda.is_available()
    )
    vectors = []
    for img, input_ids, attn_mask, _ids in tqdm(loader, desc="Extracting VAE mu"):
        img = img.to(DEVICE, non_blocking=True)
        input_ids = input_ids.to(DEVICE, non_blocking=True)
        attn_mask = attn_mask.to(DEVICE, non_blocking=True)

        img_mu, img_logvar = vae_model.image_enc(img)
        txt_mu, txt_logvar = vae_model.text_enc(input_ids, attn_mask)
        mu, _ = product_of_experts([img_mu, txt_mu], [img_logvar, txt_logvar])

        # One 1-D vector per row, kept in loader (= frame) order.
        vectors.extend(mu.detach().cpu().numpy())

    df_out = df_in.reset_index(drop=True).copy()
    if len(vectors) != len(df_out):
        raise RuntimeError(
            f"Extracted {len(vectors)} VAE vectors for {len(df_out)} rows"
        )
    df_out["vae_mu"] = vectors
    return df_out

# Add the fused VAE mean ("vae_mu") to both splits.
train_df = extract_mu_multimodal_vae(train_df, batch_size=VAE_CONFIG["batch_size"])
test_df  = extract_mu_multimodal_vae(test_df,  batch_size=VAE_CONFIG["batch_size"])

# Stack the per-row vectors into 2-D arrays (rows follow the frame order).
emb_train_vae = np.stack(train_df["vae_mu"].values)
emb_test_vae  = np.stack(test_df["vae_mu"].values)
print("VAE mu shapes:", emb_train_vae.shape, emb_test_vae.shape)

from sklearn.feature_extraction.text import TfidfVectorizer

def extract_tfidf_features(train_df, test_df, max_features=512):
    """Add a dense TF-IDF vector per row in a ``tfidf_features`` column.

    The vectorizer is fitted on the training texts only and then applied to
    the test texts. Missing texts are treated as empty strings.

    Returns:
        ``(train_copy, test_copy)``; the input frames are not modified.
    """
    def _corpus(frame):
        return frame["text"].fillna("").astype(str).tolist()

    vectorizer = TfidfVectorizer(max_features=max_features)
    train_matrix = vectorizer.fit_transform(_corpus(train_df)).toarray()
    test_matrix = vectorizer.transform(_corpus(test_df)).toarray()

    train_out = train_df.copy()
    test_out = test_df.copy()
    train_out["tfidf_features"] = list(train_matrix)
    test_out["tfidf_features"] = list(test_matrix)
    return train_out, test_out

# Add 512-dimensional (at most) TF-IDF features to both splits.
train_df, test_df = extract_tfidf_features(train_df, test_df, max_features=512)

from torchvision import models as tv_models
from torchvision import transforms as tv_transforms
import cv2

# Preprocessing for the ResNet-50 extractor. The input is an array (hence
# ToPILImage), resized to 224x224 and normalised per channel.
# NOTE(review): the mean/std values look like the usual ImageNet statistics - confirm.
resnet_transform = tv_transforms.Compose([
    tv_transforms.ToPILImage(),
    tv_transforms.Resize((224, 224)),
    tv_transforms.ToTensor(),
    tv_transforms.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
])

@torch.no_grad()
def extract_resnet_features(df_in):
    """Add a ``resnet_features`` column holding ResNet-50 features per image.

    A ResNet-50 with default pretrained weights is built on every call and its
    final ``fc`` layer is replaced by an identity, so the model returns the
    features that would have fed the classifier. Images are processed one at a
    time.

    Returns:
        A copy of ``df_in`` with the new column.

    Raises:
        FileNotFoundError: If OpenCV cannot read one of the image paths.
    """
    out_df = df_in.copy()

    backbone = tv_models.resnet50(weights=tv_models.ResNet50_Weights.DEFAULT)
    backbone.fc = torch.nn.Identity()
    backbone = backbone.to(DEVICE).eval()

    def _embed(path):
        bgr = cv2.imread(path)
        if bgr is None:
            raise FileNotFoundError(f"cv2.imread failed for: {path}")
        # Reverse the channel axis before handing the array to the transform.
        rgb = bgr[:, :, ::-1]
        batch = resnet_transform(rgb).unsqueeze(0).to(DEVICE)
        return backbone(batch).squeeze(0).detach().cpu().numpy()

    paths = out_df["image_path"].tolist()
    out_df["resnet_features"] = [_embed(p) for p in tqdm(paths, desc="ResNet50 feats")]
    return out_df

# Add ResNet-50 image features to both splits.
train_df = extract_resnet_features(train_df)
test_df  = extract_resnet_features(test_df)

from transformers import CLIPProcessor, CLIPModel
from safetensors.torch import load_file

class FeatureExtractor:
    """Minimal base class: an extractor is identified by its ``name``."""

    def __init__(self, name):
        self.name = name

class FinetunedCLIPExtractor(FeatureExtractor):
    """Joint image/text embedding from a fine-tuned CLIP checkpoint.

    The processor is loaded from ``model_dir``. The weights in
    ``model_dir/model.safetensors`` are loaded non-strictly on top of the
    public ``openai/clip-vit-base-patch32`` model.

    Args:
        model_dir: Directory holding the processor files and ``model.safetensors``.
        name: Column name under which the features are stored.
        device: Device the model runs on.

    Raises:
        FileNotFoundError: If ``model.safetensors`` is missing.
    """

    def __init__(self, model_dir, name, device=DEVICE):
        super().__init__(name=name)
        self.device = device
        self.processor = CLIPProcessor.from_pretrained(model_dir)
        base = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        sd_path = os.path.join(model_dir, "model.safetensors")
        if not os.path.exists(sd_path):
            raise FileNotFoundError(f"Missing model.safetensors in: {model_dir}")
        sd = load_file(sd_path)
        # strict=False tolerates key mismatches, but a mismatch must not pass
        # silently: if no key matches, the "fine-tuned" extractor is just base
        # CLIP. Report what was skipped so that case is visible.
        load_result = base.load_state_dict(sd, strict=False)
        n_missing = len(load_result.missing_keys)
        n_unexpected = len(load_result.unexpected_keys)
        if n_missing or n_unexpected:
            print(
                f"[WARN] {name}: non-strict load from {sd_path} | "
                f"missing keys: {n_missing}, unexpected keys: {n_unexpected} "
                f"(checkpoint tensors: {len(sd)})"
            )
        self.model = base.to(self.device).eval()

    @torch.no_grad()
    def extract(self, image_path, text):
        """Return the mean of the L2-normalised image and text embeddings.

        Args:
            image_path: Path of the image to embed.
            text: Text paired with the image (truncated by the processor).

        Returns:
            1-D NumPy array on the CPU.
        """
        # Close the file handle promptly; convert() yields an in-memory copy.
        with Image.open(image_path) as handle:
            img = handle.convert("RGB")
        inputs = self.processor(text=[text], images=img, return_tensors="pt", padding=True, truncation=True).to(self.device)
        out = self.model(**inputs)
        im = F.normalize(out.image_embeds.squeeze(0), dim=-1)
        tx = F.normalize(out.text_embeds.squeeze(0), dim=-1)
        return ((im + tx) / 2.0).detach().cpu().numpy()

def extract_features_chunked(df_in, extractor, num_chunks=3, chunk_size=100, text_col="text"):
    """Embed every row with ``extractor`` and average over random text chunks.

    For texts longer than ``chunk_size`` characters, ``num_chunks`` windows of
    ``chunk_size`` characters are drawn at random positions (seeded from
    ``CONFIG["random_seed"]``), each is embedded with the row's image, and the
    embeddings are averaged. Texts that fit in one chunk are embedded once:
    repeating the identical input ``num_chunks`` times only repeats the same
    forward pass. Missing or blank texts are embedded as ``""``.

    Args:
        df_in: Frame with ``image_path`` and, optionally, ``text_col``.
        extractor: Object with a ``name`` attribute and an
            ``extract(image_path, text)`` method returning an array.
        num_chunks: Number of random windows for long texts (>= 1).
        chunk_size: Window length in characters (>= 1).
        text_col: Column holding the text.

    Returns:
        A copy of ``df_in`` with a new column named ``extractor.name``.

    Raises:
        ValueError: If ``num_chunks`` or ``chunk_size`` is smaller than 1.
    """
    if num_chunks < 1 or chunk_size < 1:
        raise ValueError("num_chunks and chunk_size must both be >= 1")

    df = df_in.copy()
    rng = np.random.default_rng(CONFIG["random_seed"])
    feats = []
    for _, row in tqdm(df.iterrows(), total=len(df), desc=f"Extracting {extractor.name}"):
        text = row.get(text_col, "")
        if not isinstance(text, str) or not text.strip():
            # Covers None, NaN and any other non-string cell.
            chunks = [""]
        else:
            text = text.strip()
            if len(text) <= chunk_size:
                chunks = [text]
            else:
                last_start = len(text) - chunk_size
                starts = [rng.integers(0, last_start + 1) for _chunk in range(num_chunks)]
                chunks = [text[s:s + chunk_size] for s in starts]
        embedded = [extractor.extract(row["image_path"], chunk) for chunk in chunks]
        feats.append(np.mean(embedded, axis=0))
    df[extractor.name] = feats
    return df

# Embed both splits with the fine-tuned SigCLIP weights, using the raw "text" column.
sigclip_extractor = FinetunedCLIPExtractor(SIGCLIP_LOCAL_DIR, "sigclip_img_text_finetune_features", device=DEVICE)
train_df = extract_features_chunked(train_df, sigclip_extractor, num_chunks=3, chunk_size=100, text_col="text")
test_df  = extract_features_chunked(test_df,  sigclip_extractor, num_chunks=3, chunk_size=100, text_col="text")

def _safe_text(x):
    if x is None or (isinstance(x, float) and np.isnan(x)):
        return ""
    if not isinstance(x, str):
        x = str(x)
    return x.strip()

def load_prompt_llm(model_name):
    """Load the tokenizer and causal LM used for prompt rewriting.

    This is best-effort: any failure is reported with a warning and
    ``(None, None)`` is returned, so callers fall back to the original text.
    With CUDA available the model is loaded in float16 with
    ``device_map="auto"``; otherwise in float32 with no device map.

    Returns:
        ``(tokenizer, model)`` on success, ``(None, None)`` on failure.
    """
    try:
        llm_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        if llm_tokenizer.pad_token is None:
            # Reuse EOS as the padding token when the tokenizer has none.
            llm_tokenizer.pad_token = llm_tokenizer.eos_token

        load_kwargs = {
            "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
            "device_map": "auto" if torch.cuda.is_available() else None,
        }
        llm = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
        llm.eval()
    except Exception as e:
        print(f"[WARN] Could not load Llama prompt model '{model_name}'. Falling back to original text.\n  Error: {e}")
        return None, None
    else:
        return llm_tokenizer, llm

# Both values are None when the prompt model could not be loaded.
LLAMA_TOK, LLAMA_MDL = load_prompt_llm(CONFIG["llama_prompt_model"])

@torch.no_grad()
def rewrite_text_llama(raw_text: str) -> str:
    """Rewrite a panel description into a short keyword list with the LLM.

    Fallbacks:
      * blank / missing text returns ``""``;
      * no LLM loaded returns the cleaned original text;
      * a prompt that reaches the 512-token limit returns the cleaned original
        text, because truncation may have cut off the trailing ``KEYWORDS:`` cue;
      * an empty generation returns the cleaned original text.

    Only the newly generated tokens are decoded. Searching the full decode
    for ``KEYWORDS:`` picks the wrong span whenever the source text itself
    contains that marker.
    """
    raw_text = _safe_text(raw_text)
    if raw_text == "":
        return ""
    if LLAMA_TOK is None or LLAMA_MDL is None:
        return raw_text

    prompt = (
        "You are helping build a recommender system for folk art panels.\n"
        "Rewrite the given text into a short, descriptive list of keywords and entities.\n"
        "Keep it factual. Avoid extra sentences.\n\n"
        f"TEXT: {raw_text}\n"
        "KEYWORDS:"
    )
    max_prompt_tokens = 512
    inputs = LLAMA_TOK(
        prompt, return_tensors="pt", truncation=True, max_length=max_prompt_tokens
    ).to(LLAMA_MDL.device)
    prompt_len = inputs["input_ids"].shape[1]
    if prompt_len >= max_prompt_tokens:
        # The instruction cue may be gone; skip the generation entirely.
        return raw_text

    # Greedy decoding. `temperature` is not passed: it is ignored when
    # do_sample=False and only produces an "invalid generation flags" warning.
    out = LLAMA_MDL.generate(
        **inputs,
        max_new_tokens=CONFIG["llama_prompt_max_new_tokens"],
        do_sample=False,
        pad_token_id=LLAMA_TOK.eos_token_id,
    )
    generated = out[0][prompt_len:]
    rewritten = LLAMA_TOK.decode(generated, skip_special_tokens=True).strip()
    return rewritten if rewritten else raw_text

def add_llamasigclip_text_column(df_in, new_col="llama_text"):
    """Return a copy of ``df_in`` with ``new_col`` holding the LLM-rewritten ``text``."""
    out_df = df_in.copy()
    source_texts = out_df["text"].tolist()
    out_df[new_col] = [
        rewrite_text_llama(text)
        for text in tqdm(source_texts, desc="Llama rewrite text")
    ]
    return out_df

# Rewrite the panel texts with the LLM into a new "llama_text" column.
train_df = add_llamasigclip_text_column(train_df, new_col="llama_text")
test_df  = add_llamasigclip_text_column(test_df,  new_col="llama_text")

# Same SigCLIP weights as above, fed with the rewritten text and 140-char chunks.
llamasigclip_extractor = FinetunedCLIPExtractor(SIGCLIP_LOCAL_DIR, "llamasigclip_features", device=DEVICE)
train_df = extract_features_chunked(train_df, llamasigclip_extractor, num_chunks=3, chunk_size=140, text_col="llama_text")
test_df  = extract_features_chunked(test_df,  llamasigclip_extractor, num_chunks=3, chunk_size=140, text_col="llama_text")

class SentenceT5VAEEncoder(nn.Module):
    """Text-only VAE encoder: T5 encoder, masked mean pooling, two linear heads."""

    def __init__(self, st5_name="sentence-transformers/sentence-t5-base", latent_dim=768):
        super().__init__()
        self.st5 = T5EncoderModel.from_pretrained(st5_name)
        width = self.st5.config.d_model
        self.proj_mu = nn.Linear(width, latent_dim)
        self.proj_logvar = nn.Linear(width, latent_dim)

    def forward(self, input_ids, attention_mask):
        """Return ``(mu, logvar)``; ``logvar`` is clamped to ``[-10, 10]``."""
        hidden = self.st5(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        weights = attention_mask.unsqueeze(-1).float()
        # Mean over real tokens only; the clamp guards against an all-zero mask.
        token_sum = (hidden * weights).sum(dim=1)
        token_count = torch.clamp(weights.sum(dim=1), min=1.0)
        pooled = token_sum / token_count
        mu = self.proj_mu(pooled)
        logvar = torch.clamp(self.proj_logvar(pooled), -10, 10)
        return mu, logvar

def load_llamavae_or_fallback(latent_dim=768):
    """Build the text VAE encoder and try to load its checkpoint.

    The checkpoint may be either a bare state dict or a dict with a
    ``"model"`` entry; it is loaded non-strictly. A missing checkpoint or any
    load failure is reported, and the freshly built encoder is returned
    instead.

    Returns:
        ``(encoder_in_eval_mode, checkpoint_loaded)``.
    """
    encoder = SentenceT5VAEEncoder(latent_dim=latent_dim).to(DEVICE)
    ckpt_loaded = False

    if not os.path.exists(LLAMAVAE_CKPT_PATH):
        print("[INFO] LlamaVAE ckpt not found; using SentenceT5 embeddings as LlamaVAE features.")
    else:
        try:
            payload = torch.load(LLAMAVAE_CKPT_PATH, map_location=DEVICE)
            wrapped = isinstance(payload, dict) and "model" in payload
            state = payload["model"] if wrapped else payload
            encoder.load_state_dict(state, strict=False)
            print("Loaded LlamaVAE encoder ckpt:", LLAMAVAE_CKPT_PATH)
            ckpt_loaded = True
        except Exception as e:
            print(f"[WARN] Failed to load LlamaVAE ckpt, falling back to SentenceT5 embeddings.\n  Error: {e}")

    encoder.eval()
    return encoder, ckpt_loaded

# Encoder plus a flag telling whether trained VAE weights were actually loaded.
LLAMAVAE_ENC, LLAMAVAE_HAS_CKPT = load_llamavae_or_fallback(latent_dim=768)
# Tokenizer matching the encoder's default SentenceT5 model name.
LLAMAVAE_TOK = AutoTokenizer.from_pretrained("sentence-transformers/sentence-t5-base")

@torch.no_grad()
def extract_llamavae_features_text_only(df_in, batch_size=64, max_len=128):
    """Embed the ``text`` column with the text VAE encoder (or its backbone).

    With a loaded checkpoint the feature is the encoder's ``mu``. Without one
    the projection heads are untrained, so the masked mean-pooled T5 hidden
    state is used instead. Only the branch that is needed runs, so each batch
    costs one encoder pass.

    Args:
        df_in: Frame with a ``text`` column; blank texts are replaced by the
            tokenizer's pad token.
        batch_size: Number of texts per forward pass.
        max_len: Token truncation length.

    Returns:
        2-D float array with one row per input row. For an empty frame the
        shape is ``(0, feature_dim)``.
    """
    texts = []
    for raw in df_in["text"].tolist():
        cleaned = _safe_text(raw)
        texts.append(cleaned if cleaned != "" else LLAMAVAE_TOK.pad_token)

    vecs = []
    for i in tqdm(range(0, len(texts), batch_size), desc="Extracting LlamaVAE (text)"):
        batch = texts[i:i+batch_size]
        tok = LLAMAVAE_TOK(batch, return_tensors="pt", padding=True, truncation=True, max_length=max_len)
        input_ids = tok["input_ids"].to(DEVICE)
        attn = tok["attention_mask"].to(DEVICE)

        if LLAMAVAE_HAS_CKPT:
            vec, _ = LLAMAVAE_ENC(input_ids, attn)
        else:
            out = LLAMAVAE_ENC.st5(input_ids=input_ids, attention_mask=attn)
            mask = attn.unsqueeze(-1).float()
            vec = (out.last_hidden_state * mask).sum(dim=1) / torch.clamp(mask.sum(dim=1), min=1.0)
        vecs.append(vec.detach().cpu().numpy())

    if not vecs:
        # np.concatenate rejects an empty list; return a well-shaped empty array.
        if LLAMAVAE_HAS_CKPT:
            feature_dim = LLAMAVAE_ENC.proj_mu.out_features
        else:
            feature_dim = LLAMAVAE_ENC.st5.config.d_model
        return np.zeros((0, feature_dim), dtype=np.float32)
    return np.concatenate(vecs, axis=0)

# Text-only features; returned directly as 2-D arrays rather than frame columns.
llamavae_train = extract_llamavae_features_text_only(train_df, batch_size=64, max_len=128)
llamavae_test  = extract_llamavae_features_text_only(test_df,  batch_size=64, max_len=128)

# Stack each per-row feature column into a 2-D array (rows follow frame order).
emb_train_sigclip  = np.stack(train_df["sigclip_img_text_finetune_features"].values)
emb_test_sigclip   = np.stack(test_df["sigclip_img_text_finetune_features"].values)

emb_train_llamasig = np.stack(train_df["llamasigclip_features"].values)
emb_test_llamasig  = np.stack(test_df["llamasigclip_features"].values)

emb_train_resnet = np.stack(train_df["resnet_features"].values)
emb_test_resnet  = np.stack(test_df["resnet_features"].values)

emb_train_tfidf  = np.stack(train_df["tfidf_features"].values)
emb_test_tfidf   = np.stack(test_df["tfidf_features"].values)

emb_train_llamavae = llamavae_train
emb_test_llamavae  = llamavae_test

# Labels for all nodes: training rows first, then test rows.
labels_all = pd.concat([train_df[label_cols], test_df[label_cols]], ignore_index=True)
labels_all = torch.tensor(labels_all.values.astype(np.float32))

# Shape report for a quick visual sanity check.
print("Shapes:")
print("VAE            :", emb_train_vae.shape, emb_test_vae.shape)
print("SigCLIP        :", emb_train_sigclip.shape, emb_test_sigclip.shape)
print("LlamaSigCLIP   :", emb_train_llamasig.shape, emb_test_llamasig.shape)
print("ResNet         :", emb_train_resnet.shape, emb_test_resnet.shape)
print("TF-IDF         :", emb_train_tfidf.shape, emb_test_tfidf.shape)
print("LlamaVAE(text) :", emb_train_llamavae.shape, emb_test_llamavae.shape)
print("Labels all     :", labels_all.shape)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Loaded VAE from Drive: /content/drive/MyDrive/FolkArt/vae_checkpoints/multimodal_vae_400.pth


Extracting VAE mu:   0%|          | 0/5 [00:00<?, ?it/s]

Extracting VAE mu:   0%|          | 0/2 [00:00<?, ?it/s]

VAE mu shapes: (151, 4096) (38, 4096)
Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 251MB/s]


ResNet50 feats:   0%|          | 0/151 [00:00<?, ?it/s]

ResNet50 feats:   0%|          | 0/38 [00:00<?, ?it/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

Extracting sigclip_img_text_finetune_features:   0%|          | 0/151 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/605M [00:00<?, ?B/s]

Extracting sigclip_img_text_finetune_features:   0%|          | 0/38 [00:00<?, ?it/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Llama rewrite text:   0%|          | 0/151 [00:00<?, ?it/s]

The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Llama rewrite text:   0%|          | 0/38 [00:00<?, ?it/s]

Extracting llamasigclip_features:   0%|          | 0/151 [00:00<?, ?it/s]

Extracting llamasigclip_features:   0%|          | 0/38 [00:00<?, ?it/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/219M [00:00<?, ?B/s]

[INFO] LlamaVAE ckpt not found; using SentenceT5 embeddings as LlamaVAE features.


tokenizer_config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

Extracting LlamaVAE (text):   0%|          | 0/3 [00:00<?, ?it/s]

Extracting LlamaVAE (text):   0%|          | 0/1 [00:00<?, ?it/s]

Shapes:
VAE            : (151, 4096) (38, 4096)
SigCLIP        : (151, 512) (38, 512)
LlamaSigCLIP   : (151, 512) (38, 512)
ResNet         : (151, 2048) (38, 2048)
TF-IDF         : (151, 512) (38, 512)
LlamaVAE(text) : (151, 768) (38, 768)
Labels all     : torch.Size([189, 3])


In [None]:
import torch
import numpy as np
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.nn.models import VGAE
from torch_geometric.transforms import RandomLinkSplit
from sklearn.metrics.pairwise import cosine_similarity
from tqdm.notebook import tqdm

def _l2_normalize_np(x, eps=1e-12):
    n = np.linalg.norm(x, axis=1, keepdims=True)
    return x / np.clip(n, eps, None)

def build_graph_knn(features_np, k=25, sim_floor=0.0):
    """Build an undirected k-nearest-neighbour graph on cosine similarity.

    Each node is linked to its (at most) ``k`` most similar other nodes, and
    every link is stored in both directions.

    Compared with a plain top-k on a zero-filled similarity matrix:
      * pairs below ``sim_floor`` are excluded outright, not merely zeroed
        (a zeroed pair can still land in the top-k);
      * a node can never select itself (no self-loops);
      * duplicate directed edges from mutual neighbours are removed, so no
        edge is counted twice in message passing or link splitting.

    Args:
        features_np: 2-D array with one feature row per node.
        k: Maximum number of neighbours selected per node.
        sim_floor: Minimum cosine similarity for a pair to be eligible.

    Returns:
        ``Data`` with ``x`` (row-normalised float features) and a
        de-duplicated, symmetric ``edge_index``.
    """
    X = _l2_normalize_np(features_np.astype(np.float32))
    sim = cosine_similarity(X)
    n = sim.shape[0]

    eligible = sim >= sim_floor
    np.fill_diagonal(eligible, False)
    # Ineligible pairs get -inf so they are picked last by the top-k selection.
    scores = np.where(eligible, sim, -np.inf)

    kk = min(k, n - 1)
    if kk > 0:
        idx = np.argpartition(-scores, kth=kk, axis=1)[:, :kk]
        rows = np.repeat(np.arange(n), kk)
        cols = idx.reshape(-1)
        # Nodes with fewer than kk eligible neighbours had to pick some
        # ineligible ones; drop those picks.
        keep = eligible[rows, cols]
        rows, cols = rows[keep], cols[keep]
    else:
        rows = np.empty(0, dtype=np.int64)
        cols = np.empty(0, dtype=np.int64)

    src = np.concatenate([rows, cols], axis=0)
    dst = np.concatenate([cols, rows], axis=0)
    pairs = np.stack([src, dst], axis=0).astype(np.int64)
    if pairs.shape[1] > 0:
        pairs = np.unique(pairs, axis=1)

    edge_index = torch.from_numpy(pairs).long()
    x = torch.from_numpy(X).float()
    return Data(x=x, edge_index=edge_index)

def build_graph_percentile(features_np, percentile=95):
    """Build a graph linking pairs whose similarity reaches a percentile cut-off.

    Negative cosine similarities are set to zero, the threshold is the given
    percentile of all off-diagonal similarities, and every off-diagonal pair
    at or above the threshold becomes a directed edge.

    Returns:
        ``Data`` with ``x`` (row-normalised float features) and ``edge_index``.
    """
    X = _l2_normalize_np(features_np.astype(np.float32))
    sim = cosine_similarity(X)
    sim[sim < 0] = 0.0

    off_diagonal = ~np.eye(len(sim), dtype=bool)
    threshold = np.percentile(sim[off_diagonal], percentile)

    selected = (sim >= threshold) & off_diagonal
    src, dst = np.nonzero(selected)

    edge_index = torch.from_numpy(np.stack([src, dst], axis=0)).long()
    return Data(x=torch.from_numpy(X).float(), edge_index=edge_index)

def build_graph(features_np):
    """Dispatch to the k-NN or percentile graph builder according to CONFIG."""
    if not CONFIG["graph_use_knn"]:
        return build_graph_percentile(features_np, percentile=CONFIG["gcn_percentile"])
    return build_graph_knn(
        features_np,
        k=CONFIG["graph_knn_k"],
        sim_floor=CONFIG["graph_sim_floor"],
    )

def augment_tree_edges(graph: Data, labels_all: torch.Tensor, n_train: int, tree_col: int):
    """
    Add extra edges among TREE-positive TRAIN nodes only.
    This makes the sparse Tree signal much easier to preserve in message passing.

    Each tree-positive training node is linked (in both directions) to its
    most similar tree-positive training nodes, using cosine similarity of the
    graph's node features. ``graph`` is modified in place and also returned.

    Safeguards:
      * a node can never pick itself (the diagonal is excluded with -inf, so
        negative similarities cannot lose to a zeroed self-similarity);
      * the merged edge list is de-duplicated, so edges already present (or
        added twice by mutual neighbours) are stored once;
      * new edges are moved to the device of the existing ``edge_index``.

    Args:
        graph: Graph whose ``x`` holds node features, training nodes first.
        labels_all: Label matrix aligned with the graph's nodes.
        n_train: Number of leading rows that are training nodes.
        tree_col: Column of ``labels_all`` holding the tree label.

    Returns:
        The same ``graph`` object. It is left unchanged when augmentation is
        disabled, fewer than 3 positives exist, or the neighbour count is < 1.
    """
    if not CONFIG["tree_edge_aug"]:
        return graph

    tree = labels_all[:n_train, tree_col].cpu().numpy().astype(int)
    pos_idx = np.where(tree > 0)[0]
    if len(pos_idx) < 3:
        return graph

    max_nodes = CONFIG["tree_edge_aug_max_nodes"]
    if len(pos_idx) > max_nodes:
        # Subsample deterministically to bound the pairwise similarity cost.
        rng = np.random.default_rng(CONFIG["random_seed"])
        pos_idx = rng.choice(pos_idx, size=max_nodes, replace=False)

    kk = min(CONFIG["tree_edge_aug_k"], len(pos_idx) - 1)
    if kk < 1:
        return graph

    X = graph.x.cpu().numpy()
    Xpos = X[pos_idx]
    S = cosine_similarity(Xpos, Xpos)
    np.fill_diagonal(S, -np.inf)

    nbr = np.argpartition(-S, kth=kk, axis=1)[:, :kk]

    src_local = np.repeat(np.arange(len(pos_idx)), kk)
    dst_local = nbr.reshape(-1)

    src = pos_idx[src_local]
    dst = pos_idx[dst_local]
    src2 = np.concatenate([src, dst], axis=0)
    dst2 = np.concatenate([dst, src], axis=0)

    aug_edges = torch.from_numpy(np.stack([src2, dst2], axis=0)).long()
    aug_edges = aug_edges.to(graph.edge_index.device)
    merged = torch.cat([graph.edge_index, aug_edges], dim=1)
    # Unique columns = unique directed edges.
    graph.edge_index = torch.unique(merged, dim=1)
    return graph

def compute_pos_weight(labels_train: torch.Tensor):
    """Return the per-label ``negatives / positives`` ratio.

    Both counts are clamped to at least 1, so a label without positives (or
    without negatives) does not produce a division by zero or a zero weight.
    """
    n_rows = labels_train.size(0)
    positives = torch.clamp(labels_train.sum(dim=0), min=1.0)
    negatives = torch.clamp(n_rows - positives, min=1.0)
    return negatives / positives

def focal_bce_with_logits(logits, targets, pos_weight=None, gamma=2.0, alpha=0.25):
    """Mean focal-weighted binary cross-entropy on logits.

    The element-wise BCE (optionally with ``pos_weight``) is scaled by
    ``(1 - p_t) ** gamma`` and by an alpha factor: ``alpha`` where the target
    is 1 and ``1 - alpha`` where it is 0.
    """
    per_element = F.binary_cross_entropy_with_logits(
        logits, targets, pos_weight=pos_weight, reduction="none"
    )
    prob = torch.sigmoid(logits)
    negatives = 1 - targets
    # Probability the model assigns to the true class.
    prob_true = prob * targets + (1 - prob) * negatives
    modulator = torch.clamp(1 - prob_true, min=0.0) ** gamma
    alpha_weight = alpha * targets + (1 - alpha) * negatives
    weighted = alpha_weight * modulator * per_element
    return weighted.mean()

class VGAEEncoder(torch.nn.Module):
    """
    Graph-convolutional encoder that outputs a mean and a log-std per node.

    One shared ``GCNConv`` layer (ReLU + dropout) feeds two parallel
    ``GCNConv`` heads: one for ``mu`` and one for ``logstd``.
    """

    def __init__(self, in_channels, hidden_channels=256, latent_channels=128, dropout=0.2):
        super().__init__()
        self.dropout = dropout
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv_mu = GCNConv(hidden_channels, latent_channels)
        self.conv_logstd = GCNConv(hidden_channels, latent_channels)

    def forward(self, x, edge_index):
        """Return ``(mu, logstd)`` for the given node features and edges."""
        hidden = F.relu(self.conv1(x, edge_index))
        # Dropout is active only while the module is in training mode.
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)

        mu = self.conv_mu(hidden, edge_index)
        # Bound the log-std to [-10, 10].
        logstd = self.conv_logstd(hidden, edge_index).clamp(min=-10, max=10)
        return mu, logstd

class LabelHead(nn.Module):
    """
    Small MLP mapping a latent vector of size ``z_dim`` to ``out_dim`` logits.

    Layout: Linear(z_dim, z_dim) -> ReLU -> Dropout -> Linear(z_dim, out_dim).
    """

    def __init__(self, z_dim, out_dim, dropout=0.15):
        super().__init__()
        layers = [
            nn.Linear(z_dim, z_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(z_dim, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Return the raw (pre-sigmoid) outputs for ``z``."""
        return self.net(z)

def _kl_beta(epoch, epochs, beta_max=1.0, warmup_frac=0.35):
    w = int(max(1, warmup_frac * epochs))
    if epoch >= w:
        return beta_max
    return beta_max * (epoch / w)

def train_balanced_ssl_vgae(vgae, label_head, data, labels_all, n_train, epochs=500):
    """
    Jointly train a VGAE and a label head on one graph, then restore the
    parameters recorded at the epoch with the lowest training loss.

    The per-epoch objective is
    ``recon + beta * kl + CONFIG["ssl_lambda"] * supervised`` where ``beta``
    follows the ``_kl_beta`` warm-up and the supervised term (focal BCE or
    plain BCE, selected by ``CONFIG["use_focal"]``) is evaluated only on the
    first ``n_train`` nodes.

    Args:
        vgae: Model exposing ``encode``, ``recon_loss`` and ``kl_loss``
            (presumably a torch_geometric ``VGAE`` - confirm with caller).
        label_head: Module mapping latent vectors to label logits.
        data: Graph passed to ``RandomLinkSplit``; only the resulting training
            split is used, the validation/test splits are discarded.
        labels_all: Label matrix with one row per node; rows ``[:n_train]``
            are treated as the labelled training nodes.
        n_train: Number of leading nodes that contribute to the supervised loss.
        epochs: Number of full-graph optimisation steps.

    Returns:
        ``(vgae, label_head)``, both on ``DEVICE`` and switched to eval mode.
    """
    splitter = RandomLinkSplit(
        num_val=CONFIG["link_val_frac"],
        num_test=CONFIG["link_test_frac"],
        is_undirected=True,
        add_negative_train_samples=False,
        split_labels=False
    )
    train_data, _, _ = splitter(data)

    vgae = vgae.to(DEVICE)
    label_head = label_head.to(DEVICE)
    train_data = train_data.to(DEVICE)
    labels_all = labels_all.to(DEVICE)

    train_mask = torch.zeros(labels_all.size(0), dtype=torch.bool, device=DEVICE)
    train_mask[:n_train] = True

    pos_weight = None
    if CONFIG["use_pos_weight"]:
        pos_weight = compute_pos_weight(labels_all[train_mask]).to(DEVICE)

    # One shared parameter list for both the optimiser and gradient clipping.
    params = list(vgae.parameters()) + list(label_head.parameters())
    opt = torch.optim.Adam(
        params,
        lr=CONFIG["vgae_lr"],
        weight_decay=CONFIG["vgae_weight_decay"]
    )
    grad_clip = CONFIG["vgae_grad_clip"]

    best = float("inf")
    best_state = None

    for ep in tqdm(range(epochs), desc="Training Balanced SSL-VGAE"):
        vgae.train()
        label_head.train()
        opt.zero_grad(set_to_none=True)

        z = vgae.encode(train_data.x, train_data.edge_index)
        recon = vgae.recon_loss(z, train_data.edge_index)
        beta = _kl_beta(ep, epochs, beta_max=CONFIG["vgae_kl_beta_max"], warmup_frac=CONFIG["vgae_kl_warmup_frac"])
        kl = vgae.kl_loss()
        vgae_loss = recon + beta * kl
        logits = label_head(z)

        # Supervised term is restricted to the labelled training nodes.
        if CONFIG["use_focal"]:
            sup_loss = focal_bce_with_logits(
                logits[train_mask],
                labels_all[train_mask],
                pos_weight=pos_weight,
                gamma=CONFIG["focal_gamma"],
                alpha=CONFIG["focal_alpha"]
            )
        else:
            sup_loss = F.binary_cross_entropy_with_logits(
                logits[train_mask], labels_all[train_mask], pos_weight=pos_weight
            )

        loss = vgae_loss + CONFIG["ssl_lambda"] * sup_loss
        loss.backward()

        if grad_clip is not None and grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(params, grad_clip)

        opt.step()

        loss_value = loss.item()
        if loss_value < best:
            best = loss_value
            # copy=True is essential: on a CPU run `.cpu()` would return a
            # view of the live parameters, so later optimiser steps would
            # silently overwrite the "best" snapshot.
            # NOTE(review): the loss was measured before opt.step(), while the
            # snapshot holds the parameters after that step.
            best_state = {
                "vgae": {k: v.detach().to("cpu", copy=True) for k, v in vgae.state_dict().items()},
                "head": {k: v.detach().to("cpu", copy=True) for k, v in label_head.state_dict().items()},
            }

    if best_state is not None:
        vgae.load_state_dict(best_state["vgae"])
        label_head.load_state_dict(best_state["head"])

    vgae.eval()
    label_head.eval()
    return vgae, label_head

def embed_vgae(vgae, data):
    """
    Encode every node of ``data`` with ``vgae`` and return a NumPy array.

    Runs without gradient tracking. The features and edges are moved to
    ``DEVICE`` first; the result is brought back to the CPU.
    NOTE(review): the model's train/eval mode is not changed here - the caller
    is presumably expected to have put it in eval mode.
    """
    with torch.no_grad():
        node_feats = data.x.to(DEVICE)
        edges = data.edge_index.to(DEVICE)
        latent = vgae.encode(node_feats, edges)
    return latent.detach().cpu().numpy()

def get_balanced_vgae_features(train_feats, test_feats, labels_all, epochs):
    """
    Build one graph over train+test features, train an SSL-VGAE on it, and
    return the learned node embeddings split back into train and test parts.

    Args:
        train_feats: Feature rows of the training items (stacked first).
        test_feats: Feature rows of the test items (stacked after the train rows).
        labels_all: Label matrix passed through to edge augmentation and
            training (assumed to follow the same train-then-test row order -
            TODO confirm with caller).
        epochs: Number of training epochs.

    Returns:
        ``(z_train, z_test)``: the first ``len(train_feats)`` embedding rows
        and the remaining rows.
    """
    n_tr = len(train_feats)
    stacked = np.concatenate([train_feats, test_feats], axis=0)

    # Base graph plus extra edges among tree-positive training nodes.
    graph_all = augment_tree_edges(
        build_graph(stacked),
        labels_all,
        n_train=n_tr,
        tree_col=label_cols.index("tree_label"),
    )

    latent_dim = CONFIG["vgae_latent_dim"]
    vgae = VGAE(
        VGAEEncoder(
            in_channels=graph_all.x.shape[1],
            hidden_channels=CONFIG["vgae_hidden_dim"],
            latent_channels=latent_dim,
            dropout=CONFIG["vgae_dropout"],
        )
    )
    head = LabelHead(z_dim=latent_dim, out_dim=labels_all.shape[1])

    vgae, _ = train_balanced_ssl_vgae(
        vgae, head, graph_all, labels_all, n_train=n_tr, epochs=epochs
    )
    z_all = embed_vgae(vgae, graph_all)

    return z_all[:n_tr], z_all[n_tr:]

# Train one balanced SSL-VGAE per base embedding family; every run uses the
# same label matrix and the same epoch budget.
_n_vgae_epochs = CONFIG["gcn_epochs"]

sigclip_vgae_train, sigclip_vgae_test = get_balanced_vgae_features(
    emb_train_sigclip, emb_test_sigclip, labels_all, epochs=_n_vgae_epochs
)
vae_vgae_train, vae_vgae_test = get_balanced_vgae_features(
    emb_train_vae, emb_test_vae, labels_all, epochs=_n_vgae_epochs
)
resnet_vgae_train, resnet_vgae_test = get_balanced_vgae_features(
    emb_train_resnet, emb_test_resnet, labels_all, epochs=_n_vgae_epochs
)
llamasig_vgae_train, llamasig_vgae_test = get_balanced_vgae_features(
    emb_train_llamasig, emb_test_llamasig, labels_all, epochs=_n_vgae_epochs
)
llamavae_vgae_train, llamavae_vgae_test = get_balanced_vgae_features(
    emb_train_llamavae, emb_test_llamavae, labels_all, epochs=_n_vgae_epochs
)

# Sanity check: report the (train, test) embedding shapes of every variant.
print("Balanced SSL-VGAE embedding shapes:")
for _tag, _z_train, _z_test in (
    ("SigCLIP+VGAE     :", sigclip_vgae_train, sigclip_vgae_test),
    ("VAE+VGAE         :", vae_vgae_train, vae_vgae_test),
    ("ResNet+VGAE      :", resnet_vgae_train, resnet_vgae_test),
    ("LlamaSigCLIP+VGAE:", llamasig_vgae_train, llamasig_vgae_test),
    ("LlamaVAE+VGAE    :", llamavae_vgae_train, llamavae_vgae_test),
):
    print(_tag, _z_train.shape, _z_test.shape)

Training Balanced SSL-VGAE:   0%|          | 0/500 [00:00<?, ?it/s]

Training Balanced SSL-VGAE:   0%|          | 0/500 [00:00<?, ?it/s]

Training Balanced SSL-VGAE:   0%|          | 0/500 [00:00<?, ?it/s]

Training Balanced SSL-VGAE:   0%|          | 0/500 [00:00<?, ?it/s]

Training Balanced SSL-VGAE:   0%|          | 0/500 [00:00<?, ?it/s]

Balanced SSL-VGAE embedding shapes:
SigCLIP+VGAE     : (151, 128) (38, 128)
VAE+VGAE         : (151, 128) (38, 128)
ResNet+VGAE      : (151, 128) (38, 128)
LlamaSigCLIP+VGAE: (151, 128) (38, 128)
LlamaVAE+VGAE    : (151, 128) (38, 128)


In [None]:
from sklearn.metrics.pairwise import cosine_similarity

def evaluate_recommendation_fixed_profiles(
    emb_train_dict,
    emb_test_dict,
    df_train,
    df_test,
    label_cols,
    interacted_ids,
    preferred_mat,
    id_to_train_idx,
    top_k=5,
):
    """
    Compute per-label precision@k of embedding-based recommendations for a
    fixed set of user profiles.

    For every user, the user vector is the mean training embedding of the
    items they interacted with; the ``top_k`` test items with the highest
    cosine similarity to it are retrieved. A retrieved item counts as relevant
    for a label when both the item's label and the user's preference for that
    label are positive.

    Args:
        emb_train_dict: Mapping ``model name -> training embeddings``.
        emb_test_dict: Mapping ``model name -> test embeddings``; must contain
            every key of ``emb_train_dict``.
        df_train: Unused; kept for interface compatibility.
        df_test: Test items; row position ``i`` is assumed to correspond to
            row ``i`` of each test embedding matrix.
        label_cols: Label column names read from ``df_test``.
        interacted_ids: Per-user collections of interacted item ids.
        preferred_mat: Per-user preference rows, indexed like ``interacted_ids``
            and aligned with ``label_cols``.
        id_to_train_idx: Mapping ``item id -> row index`` in the training
            embeddings; ids missing from it are ignored.
        top_k: Number of test items retrieved per user (must be positive).

    Returns:
        Mapping ``model name -> list`` with one per-label precision vector for
        each evaluated user. Users with no preferences or no known interacted
        items are skipped.

    Raises:
        ValueError: If ``top_k`` is not positive.
    """
    if top_k <= 0:
        raise ValueError("top_k must be a positive integer")

    all_model_precisions = {model: [] for model in emb_train_dict}

    # Loop-invariant: convert the test labels once instead of doing a pandas
    # row lookup for every (user, model, retrieved item) triple.
    test_is_pos = df_test[list(label_cols)].to_numpy().astype(int) > 0

    for u in tqdm(range(len(interacted_ids)), desc="Fixed-Profile Eval"):
        pref = preferred_mat[u].astype(int)
        if pref.sum() == 0:
            continue
        pref_is_pos = pref > 0

        idx_u = [id_to_train_idx[x] for x in interacted_ids[u] if x in id_to_train_idx]
        if not idx_u:
            continue
        idx_u = np.array(idx_u, dtype=int)

        for model_name, emb_train in emb_train_dict.items():
            emb_test = emb_test_dict[model_name]
            user_emb = emb_train[idx_u].mean(axis=0, keepdims=True)
            sims = cosine_similarity(user_emb, emb_test)[0]
            # Indices of the top_k most similar test items, best first.
            top_idx = np.argsort(sims)[-top_k:][::-1]

            # Rows: retrieved items; columns: labels.
            hits = test_is_pos[top_idx] & pref_is_pos
            all_model_precisions[model_name].append(hits.sum(axis=0) / top_k)

    return all_model_precisions

# One row per evaluated representation: (name, train embeddings, test embeddings).
# Building both dicts from this table keeps their keys and order in sync.
_embedding_table = [
    ("vae_features", emb_train_vae, emb_test_vae),
    ("sigclip_features", emb_train_sigclip, emb_test_sigclip),
    ("llamasigclip_features", emb_train_llamasig, emb_test_llamasig),
    ("llamavae_features", emb_train_llamavae, emb_test_llamavae),
    ("resnet_features", emb_train_resnet, emb_test_resnet),
    ("tfidf_features", emb_train_tfidf, emb_test_tfidf),

    ("sigclip+vgae_balanced", sigclip_vgae_train, sigclip_vgae_test),
    ("vae+vgae_balanced", vae_vgae_train, vae_vgae_test),
    ("resnet+vgae_balanced", resnet_vgae_train, resnet_vgae_test),
    ("llamasigclip+vgae_balanced", llamasig_vgae_train, llamasig_vgae_test),
    ("llamavae+vgae_balanced", llamavae_vgae_train, llamavae_vgae_test),
]

emb_train_dict = {name: train_emb for name, train_emb, _ in _embedding_table}

emb_test_dict = {name: test_emb for name, _, test_emb in _embedding_table}

# Evaluate every embedding variant on the fixed user profiles.
results = evaluate_recommendation_fixed_profiles(
    emb_train_dict=emb_train_dict,
    emb_test_dict=emb_test_dict,
    df_train=train_df,
    df_test=test_df,
    label_cols=label_cols,
    interacted_ids=interacted_ids,
    preferred_mat=preferred_mat_bin,
    id_to_train_idx=id_to_train_idx,
    top_k=CONFIG["top_k"],
)

label_map = {"animal_label": "Animal", "myth_label": "Mythology", "tree_label": "Tree"}
table_columns = ["Model", "Animal", "Mythology", "Tree"]
rows = []

# Single pass: print the per-model summary line and collect the table row,
# so mean/std are computed only once per model.
for model_name, vals in results.items():
    arr = np.array(vals)
    if len(arr) == 0:
        print(model_name, ":", "no valid users")
        continue
    mean = arr.mean(axis=0)
    std = arr.std(axis=0)

    print(
        model_name,
        ":",
        "  ".join(f"{col}={mean[i]:.2f} ± {std[i]:.2f}" for i, col in enumerate(label_cols)),
    )

    row = {"Model": model_name}
    for i, col in enumerate(label_cols):
        row[label_map[col]] = f"{mean[i]:.2f} ± {std[i]:.2f}"
    rows.append(row)

# Passing `columns` keeps the column order and also yields a valid (empty)
# table when no model had any valid user, instead of raising KeyError.
results_df = pd.DataFrame(rows, columns=table_columns)
pd.set_option("display.max_colwidth", None)
pd.set_option("display.width", 140)
pd.set_option("display.colheader_justify", "center")

# Report the k that was actually used rather than a hard-coded 5.
print(f"\nPrecision@{CONFIG['top_k']} Recommendation Results (mean ± std)\n")
display(results_df)

Fixed-Profile Eval:   0%|          | 0/1500 [00:00<?, ?it/s]

vae_features : animal_label=0.44 ± 0.23  myth_label=0.60 ± 0.21  tree_label=0.27 ± 0.21
sigclip_features : animal_label=0.56 ± 0.26  myth_label=0.59 ± 0.16  tree_label=0.13 ± 0.12
llamasigclip_features : animal_label=0.52 ± 0.25  myth_label=0.44 ± 0.17  tree_label=0.14 ± 0.12
llamavae_features : animal_label=0.56 ± 0.27  myth_label=0.67 ± 0.13  tree_label=0.11 ± 0.11
resnet_features : animal_label=0.38 ± 0.20  myth_label=0.40 ± 0.19  tree_label=0.37 ± 0.20
tfidf_features : animal_label=0.49 ± 0.22  myth_label=0.76 ± 0.17  tree_label=0.32 ± 0.18
sigclip+vgae_balanced : animal_label=0.50 ± 0.26  myth_label=0.66 ± 0.22  tree_label=0.60 ± 0.27
vae+vgae_balanced : animal_label=0.50 ± 0.25  myth_label=0.54 ± 0.15  tree_label=0.55 ± 0.33
resnet+vgae_balanced : animal_label=0.49 ± 0.25  myth_label=0.46 ± 0.18  tree_label=0.47 ± 0.22
llamasigclip+vgae_balanced : animal_label=0.51 ± 0.22  myth_label=0.69 ± 0.16  tree_label=0.48 ± 0.21
llamavae+vgae_balanced : animal_label=0.62 ± 0.25  myth_label

Unnamed: 0,Model,Animal,Mythology,Tree
0,vae_features,0.44 ± 0.23,0.60 ± 0.21,0.27 ± 0.21
1,sigclip_features,0.56 ± 0.26,0.59 ± 0.16,0.13 ± 0.12
2,llamasigclip_features,0.52 ± 0.25,0.44 ± 0.17,0.14 ± 0.12
3,llamavae_features,0.56 ± 0.27,0.67 ± 0.13,0.11 ± 0.11
4,resnet_features,0.38 ± 0.20,0.40 ± 0.19,0.37 ± 0.20
5,tfidf_features,0.49 ± 0.22,0.76 ± 0.17,0.32 ± 0.18
6,sigclip+vgae_balanced,0.50 ± 0.26,0.66 ± 0.22,0.60 ± 0.27
7,vae+vgae_balanced,0.50 ± 0.25,0.54 ± 0.15,0.55 ± 0.33
8,resnet+vgae_balanced,0.49 ± 0.25,0.46 ± 0.18,0.47 ± 0.22
9,llamasigclip+vgae_balanced,0.51 ± 0.22,0.69 ± 0.16,0.48 ± 0.21
