In [1]:
!pip install open-clip-torch

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.1.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [2]:
import pandas as pd
pd.set_option("display.max_columns", None) 

df = pd.read_csv("jl_fs/train.csv")
df["image_path"] = df["sample_id"].apply(lambda x : f"jl_fs/images/train/{x}.jpg")
df.to_csv("train_updated.csv", index = False)

In [3]:
df.head()

Unnamed: 0,sample_id,catalog_content,image_link,price,image_path
0,33127,"Item Name: La Victoria Green Taco Sauce Mild, ...",https://m.media-amazon.com/images/I/51mo8htwTH...,4.89,jl_fs/images/train/33127.jpg
1,198967,"Item Name: Salerno Cookies, The Original Butte...",https://m.media-amazon.com/images/I/71YtriIHAA...,13.12,jl_fs/images/train/198967.jpg
2,261251,"Item Name: Bear Creek Hearty Soup Bowl, Creamy...",https://m.media-amazon.com/images/I/51+PFEe-w-...,1.97,jl_fs/images/train/261251.jpg
3,55858,Item Name: Judee’s Blue Cheese Powder 11.25 oz...,https://m.media-amazon.com/images/I/41mu0HAToD...,30.34,jl_fs/images/train/55858.jpg
4,292686,"Item Name: kedem Sherry Cooking Wine, 12.7 Oun...",https://m.media-amazon.com/images/I/41sA037+Qv...,66.49,jl_fs/images/train/292686.jpg


In [None]:
# %%
import os
import math
import random
import json
from dataclasses import dataclass
from typing import Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image

from transformers import (
    CLIPModel,
    AutoProcessor,
    get_linear_schedule_with_warmup,
)

# --------------------------- Config ---------------------------
CSV_PATH        = os.environ.get("TRAIN_CSV", "train_updated.csv")   # must contain text + price + image path
TEXT_COL        = os.environ.get("TEXT_COL", "catalog_content")
PRICE_COL       = os.environ.get("PRICE_COL", "price")
IMG_COL         = os.environ.get("IMG_COL",  "image_path")         # <-- new: local jpg path column

MODEL_ID        = os.environ.get("MODEL_ID", "openai/clip-vit-large-patch14-336")
OUTPUT_DIR      = os.environ.get("OUTPUT_DIR", "price_clip_vit_99_1_split")

SEED            = int(os.environ.get("SEED", "42"))
VAL_FRAC        = float(os.environ.get("VAL_FRAC", "0.01"))         # set 0.5 for 50/50
MAX_LEN         = int(os.environ.get("MAX_LEN", "64"))             # CLIP text context is shorter
BATCH_SIZE      = int(os.environ.get("BATCH_SIZE", "16"))
LR              = float(os.environ.get("LR", "2e-5"))
WEIGHT_DECAY    = float(os.environ.get("WEIGHT_DECAY", "0.01"))
EPOCHS          = int(os.environ.get("EPOCHS", "10"))
WARMUP_RATIO    = float(os.environ.get("WARMUP_RATIO", "0.06"))
GRAD_ACCUM      = int(os.environ.get("GRAD_ACCUM", "2"))
MAX_GRAD_NORM   = float(os.environ.get("MAX_GRAD_NORM", "1.0"))
FP16            = os.environ.get("FP16", "true").lower() == "true"

ALPHA_CONTRAST  = float(os.environ.get("ALPHA_CONTRAST", "0.25"))  # weight for contrastive loss
TAU             = float(os.environ.get("TAU", "0.07"))             # temperature
HUBER_DELTA     = float(os.environ.get("HUBER_DELTA", "1.0"))

EARLY_STOP_ROUNDS = int(os.environ.get("EARLY_STOP_ROUNDS", "3"))
MIN_PRICE       = float(os.environ.get("MIN_PRICE", "1e-6"))

IMG_MISSING_POLICY = os.environ.get("IMG_MISSING_POLICY", "zero").lower()  # zero | text_only | drop
assert IMG_MISSING_POLICY in {"zero", "text_only", "drop"}

os.makedirs(OUTPUT_DIR, exist_ok=True)

# --------------------------- Utils ---------------------------
def set_seed(seed: int = SEED):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(SEED)

def smape_np(y_true, y_pred, eps=1e-8):
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    denom = (np.abs(y_true) + np.abs(y_pred) + eps) / 2.0
    return 100.0 * np.mean(np.abs(y_pred - y_true) / denom)

def log2_price(p: np.ndarray) -> np.ndarray:
    return np.log2(np.clip(p, MIN_PRICE, None))

def delog2(x: np.ndarray) -> np.ndarray:
    return np.power(2.0, x)

def split_train_val(df: pd.DataFrame, frac_val: float = VAL_FRAC, seed: int = SEED):
    df = df.sample(frac=1.0, random_state=seed).reset_index(drop=True)
    n_val = int(len(df) * frac_val)
    df_val = df.iloc[:n_val].reset_index(drop=True)
    df_tr  = df.iloc[n_val:].reset_index(drop=True)
    return df_tr, df_val

# --------------------------- Dataset & Collate ---------------------------
class ClipPriceDataset(Dataset):
    """
    Policies:
      - 'zero':      returns dummy pixel_values for missing images; vision forward done; features zeroed later.
      - 'text_only': returns pixel_values=None for missing images; vision forward skipped; text features only.
      - 'drop':      rows with missing images removed at dataset build time.
    Yields dict with: input_ids, attention_mask, (pixel_values), img_missing, (target)
    """
    def __init__(self, df: pd.DataFrame, text_col: str, img_col: str, prices_log2: Optional[np.ndarray],
                 processor: AutoProcessor, max_len: int, policy: str):
        self.processor = processor
        self.max_len = max_len
        self.policy = policy

        df = df.reset_index(drop=True).copy()
        df[text_col] = df[text_col].fillna("").astype(str)

        if policy == "drop":
            before = len(df)
            df = df[df[img_col].apply(lambda p: isinstance(p, str) and len(p) > 0 and os.path.exists(p))]
            self.dropped_missing = before - len(df)
        else:
            self.dropped_missing = 0

        self.texts = df[text_col].tolist()
        self.img_paths = df[img_col].fillna("").astype(str).tolist()
        self.prices_log2 = prices_log2
        self.missing_img_count = 0

        # Dummy pixel for consistent shapes (use processor to derive size)
        dummy = self.processor(images=Image.new("RGB", (224, 224)), return_tensors="pt")
        self._dummy_pixel = dummy["pixel_values"].squeeze(0)  # (3,H,W)

    def __len__(self): return len(self.texts)

    def _load_image(self, path: str):
        if isinstance(path, str) and path and os.path.exists(path):
            try:
                return Image.open(path).convert("RGB")
            except Exception:
                pass
        self.missing_img_count += 1
        return None

    def __getitem__(self, idx):
        text = self.texts[idx]
        img  = self._load_image(self.img_paths[idx])

        enc_text = self.processor(text=[text], padding=False, truncation=True,
                                  max_length=self.max_len, return_tensors="pt")

        img_missing = 0
        pixel_values = None

        if img is None:
            img_missing = 1
            if self.policy == "zero":
                pixel_values = self._dummy_pixel.clone()
            elif self.policy == "text_only":
                pixel_values = None
        else:
            enc_img = self.processor(images=img, return_tensors="pt")
            pixel_values = enc_img["pixel_values"].squeeze(0)

        item = {
            "input_ids": enc_text["input_ids"].squeeze(0),
            "attention_mask": enc_text["attention_mask"].squeeze(0),
            "img_missing": torch.tensor(img_missing, dtype=torch.uint8),
        }
        if pixel_values is not None:
            item["pixel_values"] = pixel_values
        if self.prices_log2 is not None:
            item["target"] = torch.tensor(self.prices_log2[idx], dtype=torch.float32)
        return item

@dataclass
class CollateClip:
    processor: AutoProcessor
    def __call__(self, batch):
        # pad text
        input_ids = [b["input_ids"] for b in batch]
        attention = [b["attention_mask"] for b in batch]
        text_padded = self.processor.tokenizer.pad(
            {"input_ids": input_ids, "attention_mask": attention},
            padding=True, return_tensors="pt"
        )
        # images: some may be absent (text_only policy)
        has_pix = [("pixel_values" in b) for b in batch]
        pixel_values = None
        if any(has_pix):
            shapes = [b["pixel_values"].shape for b in batch if "pixel_values" in b]
            C,H,W = shapes[0]
            stacked = []
            for b in batch:
                if "pixel_values" in b:
                    stacked.append(b["pixel_values"])
                else:
                    stacked.append(torch.zeros((C,H,W), dtype=torch.float32))
            pixel_values = torch.stack(stacked, dim=0)

        res = {
            "input_ids": text_padded["input_ids"],
            "attention_mask": text_padded["attention_mask"],
            "img_missing": torch.stack([b["img_missing"] for b in batch], dim=0),
        }
        if pixel_values is not None:
            res["pixel_values"] = pixel_values
        if "target" in batch[0]:
            res["target"] = torch.stack([b["target"] for b in batch], dim=0)
        return res

# --------------------------- Model & Loss ---------------------------
class ClipRegressionHead(nn.Module):
    def __init__(self, embed_dim: int, dropout: float = 0.1):
        super().__init__()
        in_dim = 2 * embed_dim  # concat image+text
        self.net = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_dim, in_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(in_dim, 1),
        )
    def forward(self, x): return self.net(x).squeeze(-1)

def info_nce(z_img: torch.Tensor, z_txt: torch.Tensor, tau: float = TAU) -> torch.Tensor:
    z_img = F.normalize(z_img, dim=-1)
    z_txt = F.normalize(z_txt, dim=-1)
    logits = torch.matmul(z_img, z_txt.t()) / tau  # (B,B)
    labels = torch.arange(z_img.size(0), device=z_img.device)
    loss_i = F.cross_entropy(logits, labels)
    loss_t = F.cross_entropy(logits.t(), labels)
    return 0.5 * (loss_i + loss_t)

def huber_loss(pred, target, delta=HUBER_DELTA):
    return F.huber_loss(pred, target, delta=delta)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# --------------------------- Load data ---------------------------
print(f"🔧 Loading CSV: {CSV_PATH}")
df = pd.read_csv(CSV_PATH)

for col, name in [(TEXT_COL, "TEXT_COL"), (PRICE_COL, "PRICE_COL"), (IMG_COL, "IMG_COL")]:
    if col not in df.columns:
        raise ValueError(f"{name} '{col}' not in CSV columns={df.columns.tolist()}")

# clean / guard
df[TEXT_COL] = df[TEXT_COL].fillna("").astype(str).str.strip()
df = df.loc[pd.to_numeric(df[PRICE_COL], errors="coerce").notnull()].copy()
df[PRICE_COL] = df[PRICE_COL].astype(float)
df = df.loc[df[PRICE_COL] >= 0.0].reset_index(drop=True)

# split
df_tr, df_va = split_train_val(df, frac_val=VAL_FRAC, seed=SEED)
print(f"📊 Split: train={len(df_tr)} | valid={len(df_va)}")

y_tr_log = log2_price(df_tr[PRICE_COL].values)
y_va_log = log2_price(df_va[PRICE_COL].values)

# --------------------------- CLIP backbone ---------------------------
processor = AutoProcessor.from_pretrained(MODEL_ID)
clip_model = CLIPModel.from_pretrained(MODEL_ID)

train_ds = ClipPriceDataset(df_tr, TEXT_COL, IMG_COL, y_tr_log, processor, MAX_LEN, IMG_MISSING_POLICY)
val_ds   = ClipPriceDataset(df_va, TEXT_COL, IMG_COL, y_va_log, processor, MAX_LEN, IMG_MISSING_POLICY)
collate  = CollateClip(processor)

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model = clip_model.to(device)
price_head = ClipRegressionHead(embed_dim=clip_model.config.projection_dim, dropout=0.1).to(device)

print(f"🖥️ Device: {device}")
print(f"🧮 Trainable params clip={count_parameters(clip_model):,} | head={count_parameters(price_head):,}")
for p in clip_model.parameters():
    p.requires_grad = True  # end-to-end; set False to freeze

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, collate_fn=collate)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True, collate_fn=collate)

# Optimizer & scheduler
no_decay = ["bias", "LayerNorm.weight"]
params = list(clip_model.named_parameters()) + [(f"head.{n}", p) for n, p in price_head.named_parameters()]
grouped = [
    {"params": [p for n, p in params if not any(nd in n for nd in no_decay)], "weight_decay": WEIGHT_DECAY},
    {"params": [p for n, p in params if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(grouped, lr=LR)
num_training_steps = EPOCHS * max(1, math.ceil(len(train_loader) / max(1, GRAD_ACCUM)))
num_warmup = int(num_training_steps * WARMUP_RATIO)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=num_warmup, num_training_steps=num_training_steps
)
scaler = torch.cuda.amp.GradScaler(enabled=FP16)

# --------------------------- Train ---------------------------
best_smape = float("inf")
best_path = os.path.join(OUTPUT_DIR, "best_clip.pt")
patience = 0

print("🔎 Warmup batch to tally missing images…")
if len(train_loader) > 0:
    _ = next(iter(train_loader))
if len(val_loader) > 0:
    _ = next(iter(val_loader))
print(f"⚠️ Missing images counted (train/val): {train_ds.missing_img_count}/{val_ds.missing_img_count}")
print(f"🗑️ Dropped due to policy=drop (train/val): {getattr(train_ds,'dropped_missing',0)}/{getattr(val_ds,'dropped_missing',0)}")

for epoch in range(1, EPOCHS + 1):
    clip_model.train(); price_head.train()
    train_loss_running = 0.0
    reg_loss_running = 0.0
    con_loss_running = 0.0

    optimizer.zero_grad(set_to_none=True)

    for step, batch in enumerate(train_loader, 1):
        input_ids      = batch["input_ids"].to(device, non_blocking=True)
        attention_mask = batch["attention_mask"].to(device, non_blocking=True)
        targets        = batch["target"].to(device, non_blocking=True).float()
        img_missing    = batch["img_missing"].to(device)

        with torch.cuda.amp.autocast(enabled=FP16):
            # Text features
            txt_feat = clip_model.get_text_features(input_ids=input_ids, attention_mask=attention_mask)

            # Image features (guarded by policy/batch content)
            do_vision = ("pixel_values" in batch) and (IMG_MISSING_POLICY != "text_only")
            if do_vision:
                pixel_values = batch["pixel_values"].to(
                    device,
                    dtype=next(clip_model.vision_model.parameters()).dtype,
                    non_blocking=True
                )
                img_feat = clip_model.get_image_features(pixel_values=pixel_values)
                if img_missing.any():
                    img_feat = img_feat * (1.0 - img_missing.unsqueeze(1).float())
            else:
                img_feat = torch.zeros_like(txt_feat)

            # Normalize and fuse
            txt_n = F.normalize(txt_feat, dim=-1)
            img_n = F.normalize(img_feat, dim=-1)
            fused = torch.cat([img_n, txt_n], dim=-1)

            # Losses
            pred_log = price_head(fused)
            reg_loss = huber_loss(pred_log, targets, delta=HUBER_DELTA)

            con_loss = torch.tensor(0.0, device=device, dtype=txt_n.dtype)
            valid_idx = (img_missing == 0).nonzero(as_tuple=False).squeeze(-1)
            if do_vision and valid_idx.numel() > 1:
                con_loss = info_nce(img_n[valid_idx], txt_n[valid_idx], tau=TAU)

            loss = (1.0 - ALPHA_CONTRAST) * reg_loss + ALPHA_CONTRAST * con_loss

        scaler.scale(loss).backward()

        if step % GRAD_ACCUM == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(list(clip_model.parameters()) + list(price_head.parameters()), MAX_GRAD_NORM)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()

        train_loss_running += float(loss.item())
        reg_loss_running   += float(reg_loss.item())
        con_loss_running   += float(con_loss.item()) if isinstance(con_loss, torch.Tensor) else float(con_loss)

        if step % 200 == 0:
            print(f"epoch {epoch} step {step}/{len(train_loader)} "
                  f"loss={train_loss_running/step:.4f} reg={reg_loss_running/step:.4f} con={con_loss_running/step:.4f}")

    # ------------------ Validation ------------------
    clip_model.eval(); price_head.eval()
    preds_log = []
    with torch.no_grad():
        for batch in val_loader:
            input_ids      = batch["input_ids"].to(device, non_blocking=True)
            attention_mask = batch["attention_mask"].to(device, non_blocking=True)
            img_missing    = batch["img_missing"].to(device)

            txt_feat = clip_model.get_text_features(input_ids=input_ids, attention_mask=attention_mask)

            do_vision = ("pixel_values" in batch) and (IMG_MISSING_POLICY != "text_only")
            if do_vision:
                pixel_values = batch["pixel_values"].to(
                    device,
                    dtype=next(clip_model.vision_model.parameters()).dtype,
                    non_blocking=True
                )
                img_feat = clip_model.get_image_features(pixel_values=pixel_values)
                if img_missing.any():
                    img_feat = img_feat * (1.0 - img_missing.unsqueeze(1).float())
            else:
                img_feat = torch.zeros_like(txt_feat)

            txt_n = F.normalize(txt_feat, dim=-1)
            img_n = F.normalize(img_feat, dim=-1)
            fused = torch.cat([img_n, txt_n], dim=-1)
            pred  = price_head(fused)
            preds_log.append(pred.detach().float().cpu().numpy())

    preds_log = np.concatenate(preds_log, axis=0) if len(preds_log) else np.array([])
    if len(preds_log):
        va_preds  = delog2(preds_log)
        va_true   = delog2(y_va_log)
        smape = smape_np(va_true, va_preds)
    else:
        smape = float("inf")
    print(f"✅ Epoch {epoch}: VAL SMAPE = {smape:.3f}% | missing_imgs (train/val) = {train_ds.missing_img_count}/{val_ds.missing_img_count}")

    # Save best
    if smape < best_smape - 1e-6:
        best_smape = smape
        patience = 0
        torch.save(
            {
                "clip_state": clip_model.state_dict(),
                "head_state": price_head.state_dict(),
                "model_id": MODEL_ID,
                "config": {
                    "ALPHA_CONTRAST": ALPHA_CONTRAST,
                    "TAU": TAU,
                    "MAX_LEN": MAX_LEN,
                    "projection_dim": clip_model.config.projection_dim,
                    "IMG_MISSING_POLICY": IMG_MISSING_POLICY,
                },
                "columns": {"text": TEXT_COL, "image": IMG_COL, "price": PRICE_COL},
                "val_frac": VAL_FRAC,
            },
            best_path
        )
        print(f"💾 Saved new best to {best_path}")
    else:
        patience += 1
        print(f"⏸️ No improvement. Patience {patience}/{EARLY_STOP_ROUNDS}")
        if patience >= EARLY_STOP_ROUNDS:
            print("🛑 Early stopping triggered.")
            break

print(f"🏁 Best VAL SMAPE: {best_smape:.3f}% | Checkpoint: {best_path}")

# Save final artifacts
with open(os.path.join(OUTPUT_DIR, "metrics_clip.json"), "w") as f:
    json.dump({
        "best_val_smape": float(best_smape),
        "train_missing_images": int(train_ds.missing_img_count),
        "valid_missing_images": int(val_ds.missing_img_count),
        "dropped_train": int(getattr(train_ds, "dropped_missing", 0)),
        "dropped_valid": int(getattr(val_ds, "dropped_missing", 0)),
        "missing_policy": IMG_MISSING_POLICY,
        "val_frac": VAL_FRAC,
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE,
        "lr": LR,
        "weight_decay": WEIGHT_DECAY
    }, f, indent=2)
# %%

🔧 Loading CSV: train_updated.csv
📊 Split: train=74250 | valid=750


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


🖥️ Device: cuda
🧮 Trainable params clip=427,944,193 | head=2,362,369
🔎 Warmup batch to tally missing images…


  scaler = torch.cuda.amp.GradScaler(enabled=FP16)
You're using a CLIPTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a CLIPTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a CLIPTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a CLIPTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


⚠️ Missing images counted (train/val): 0/0
🗑️ Dropped due to policy=drop (train/val): 0/0


You're using a CLIPTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a CLIPTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  with torch.cuda.amp.autocast(enabled=FP16):


epoch 1 step 200/4641 loss=2.4999 reg=3.2356 con=0.2928


In [None]:
print('gelllo')

In [None]:
print('melllo')

In [None]:
print('melllo')

In [8]:
print('jello')

jello


In [10]:
best_smape

np.float64(42.69014756236535)

In [11]:
best_path

'price_clip_vit_99_1_split/best_clip.pt'

In [12]:
import pandas as pd
pd.set_option("display.max_columns", None) 

df = pd.read_csv("jl_fs/test.csv")
df["image_path"] = df["sample_id"].apply(lambda x : f"jl_fs/images/train/{x}.jpg")
df.to_csv("test_updated.csv", index = False)

In [13]:
# %% [markdown]
# --- Inference: load best checkpoint and predict on TEST_CSV ---

# %%
import os, json, math
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import CLIPModel, AutoProcessor

# ---- Config / paths ----
TEST_CSV        = os.environ.get("TEST_CSV", "test_updated.csv")   # must contain ID + text + image path
ID_COL          = os.environ.get("ID_COL", "sample_id")
TEXT_COL        = os.environ.get("TEXT_COL", "catalog_content")
IMG_COL         = os.environ.get("IMG_COL",  "image_path")

OUTPUT_DIR      = os.environ.get("OUTPUT_DIR", "price_clip_vit_99_1_split")
CKPT_PATH       = os.environ.get("CKPT_PATH", os.path.join(OUTPUT_DIR, "best_clip.pt"))

BATCH_SIZE      = int(os.environ.get("INF_BATCH_SIZE", "64"))
MAX_LEN_ENV     = os.environ.get("MAX_LEN", None)  # if you want to override tokenizer max len
FP16            = os.environ.get("FP16", "true").lower() == "true"

assert os.path.exists(CKPT_PATH), f"Checkpoint not found at {CKPT_PATH}"
assert os.path.exists(TEST_CSV),  f"Test CSV not found at {TEST_CSV}"

device = "cuda" if torch.cuda.is_available() else "cpu"

# ---- Load checkpoint ----
ckpt = torch.load(CKPT_PATH, map_location="cpu")
model_id = ckpt.get("model_id", "openai/clip-vit-large-patch14-336")
cfg = ckpt.get("config", {})
projection_dim = cfg.get("projection_dim")
img_missing_policy = cfg.get("IMG_MISSING_POLICY", "zero")
max_len = int(cfg.get("MAX_LEN", 64)) if MAX_LEN_ENV is None else int(MAX_LEN_ENV)

print(f"📦 Loaded checkpoint from: {CKPT_PATH}")
print(f"🔤 MODEL_ID={model_id} | projection_dim={projection_dim} | IMG_MISSING_POLICY={img_missing_policy} | MAX_LEN={max_len}")

# ---- Recreate processor & models ----
processor = AutoProcessor.from_pretrained(model_id)
clip_model = CLIPModel.from_pretrained(model_id)
clip_model.load_state_dict(ckpt["clip_state"], strict=True)
clip_model.to(device).eval()

# Recreate and load regression head (same class as training cell)
price_head = ClipRegressionHead(embed_dim=projection_dim, dropout=0.0)
price_head.load_state_dict(ckpt["head_state"], strict=True)
price_head.to(device).eval()

# ---- Load test data ----
dft = pd.read_csv(TEST_CSV)
dft["image_path"] = dft["sample_id"].apply(lambda x : f"jl_fs/images/test/{x}.jpg")
for col, name in [(ID_COL, "ID_COL"), (TEXT_COL, "TEXT_COL"), (IMG_COL, "IMG_COL")]:
    if col not in dft.columns:
        raise ValueError(f"{name} '{col}' missing from test CSV. Columns={dft.columns.tolist()}")

# Basic clean
dft[TEXT_COL] = dft[TEXT_COL].fillna("").astype(str).str.strip()
dft[IMG_COL]  = dft[IMG_COL].fillna("").astype(str)

# Build dataset/dataloader with no targets
test_ds = ClipPriceDataset(
    df=dft[[ID_COL, TEXT_COL, IMG_COL]].copy(),
    text_col=TEXT_COL,
    img_col=IMG_COL,
    prices_log2=None,
    processor=processor,
    max_len=max_len,
    policy=img_missing_policy
)
collate = CollateClip(processor)

dl_te = DataLoader(
    test_ds, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=2, pin_memory=True, collate_fn=collate
)

print(f"🖥 Device: {device}")
print(f"🧪 Test rows: {len(test_ds)} | Missing images encountered (during getitem): {test_ds.missing_img_count}")
print(f"🗑 Dropped due to policy=drop: {getattr(test_ds, 'dropped_missing', 0)}")

# ---- Inference loop ----
clip_model_dtype = next(clip_model.vision_model.parameters()).dtype
preds_log2 = []

with torch.no_grad():
    for batch in tqdm(dl_te, total = len(dl_te)):
        input_ids      = batch["input_ids"].to(device, non_blocking=True)
        attention_mask = batch["attention_mask"].to(device, non_blocking=True)
        img_missing    = batch["img_missing"].to(device)

        # Text features
        with torch.cuda.amp.autocast(enabled=FP16):
            txt_feat = clip_model.get_text_features(input_ids=input_ids, attention_mask=attention_mask)

            # Image features depending on policy
            do_vision = ("pixel_values" in batch) and (img_missing_policy != "text_only")
            if do_vision:
                pixel_values = batch["pixel_values"].to(device, dtype=clip_model_dtype, non_blocking=True)
                img_feat = clip_model.get_image_features(pixel_values=pixel_values)
                if img_missing.any():
                    img_feat = img_feat * (1.0 - img_missing.unsqueeze(1).float())
            else:
                img_feat = torch.zeros_like(txt_feat)

            # Normalize + fuse
            txt_n = F.normalize(txt_feat, dim=-1)
            img_n = F.normalize(img_feat, dim=-1)
            fused = torch.cat([img_n, txt_n], dim=-1)

            # Predict log2(price)
            pred_log = price_head(fused)
            preds_log2.append(pred_log.detach().float().cpu().numpy())

📦 Loaded checkpoint from: price_clip_vit_99_1_split/best_clip.pt
🔤 MODEL_ID=openai/clip-vit-large-patch14-336 | projection_dim=768 | IMG_MISSING_POLICY=zero | MAX_LEN=64
🖥 Device: cuda
🧪 Test rows: 75000 | Missing images encountered (during getitem): 0
🗑 Dropped due to policy=drop: 0


  0%|          | 0/1172 [00:00<?, ?it/s]

You're using a CLIPTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a CLIPTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  with torch.cuda.amp.autocast(enabled=FP16):
