In [9]:
import os, math, random, json
from dataclasses import dataclass
from typing import List, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup

In [10]:
# Files (change here or via env)
TRAIN_CSV       = os.environ.get("TRAIN_CSV", "jl_fs/train.csv")
TEST_CSV        = os.environ.get("TEST_CSV", "jl_fs/test.csv")          # set if you have a test file
ID_COL          = os.environ.get("ID_COL", "sample_id")
TEXT_COL        = os.environ.get("TEXT_COL", "catalog_content")
PRICE_COL       = os.environ.get("PRICE_COL", "price")

# Model/output
MODEL_ID        = os.environ.get("MODEL_ID", "google-bert/bert-large-uncased")
OUTPUT_DIR      = os.environ.get("OUTPUT_DIR", "price_large_bert_contrastive")
BEST_CKPT_NAME  = "best_val.pt"
FULL_CKPT_NAME  = "final_full.pt"
PRED_CSV_NAME   = "predictions.csv"

# Training knobs
SEED            = 42
VAL_FRAC        = 0.0          # set to 0.0 to skip validation and just train on 100%
MAX_LEN         = 192
BATCH_SIZE      = 16
LR              = 3e-5
WEIGHT_DECAY    = 0.01
EPOCHS          = 5            # epochs for Stage A (with validation)
EPOCHS_FULL     = 2            # extra epochs when training on 100% (Stage B)
WARMUP_RATIO    = 0.06
GRAD_ACCUM      = 1
MAX_GRAD_NORM   = 1.0
FP16            = True

# Loss mixing
ALPHA_CONTRAST  = 0.25         # 0..1
TAU             = 0.05         # InfoNCE temperature

# Light data augmentation for the 2nd view
WORD_MASK_P     = 0.08
DROPOUT_PROB    = 0.1

EARLY_STOP_ROUNDS = 3          # only used if VAL_FRAC > 0.0
MIN_PRICE       = 1e-6         # for log2 transform
# --------------------------------------------------------------

In [11]:
ALPHA_CONTRAST  = 0.25              # weight for contrastive loss (0..1). 0.25 is a good start.
TAU             = 0.05                  # contrastive temperature

# Augmentations
WORD_MASK_P     = 0.08                  # probability to randomly mask an input token (2nd view only)
DROPOUT_PROB    = 0.1                   # Dropout already inside the transformer; can adjust here in heads.

EARLY_STOP_ROUNDS = 3                   # stop if val SMAPE hasn't improved for these epochs

# Price clipping for log transform
MIN_PRICE       = 1e-6

In [12]:
def set_seed(seed: int = SEED):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(SEED)

def smape_np(y_true, y_pred, eps=1e-8):
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    denom = (np.abs(y_true) + np.abs(y_pred) + eps) / 2.0
    return 100.0 * np.mean(np.abs(y_pred - y_true) / denom)

def log2_price(p: np.ndarray) -> np.ndarray:
    return np.log2(np.clip(p, MIN_PRICE, None))

def delog2(x: np.ndarray) -> np.ndarray:
    return np.power(2.0, x)

def split_train_val(df: pd.DataFrame, frac_val: float = 0.1, seed: int = SEED):
    if frac_val <= 0.0:
        return df.copy().reset_index(drop=True), pd.DataFrame(columns=df.columns)
    df = df.sample(frac=1.0, random_state=seed).reset_index(drop=True)
    n_val = int(len(df) * frac_val)
    df_val = df.iloc[:n_val].reset_index(drop=True)
    df_tr  = df.iloc[n_val:].reset_index(drop=True)
    return df_tr, df_val

class PriceTextDataset(Dataset):
    def __init__(self, texts: List[str], prices_log2: Optional[np.ndarray], tokenizer, max_len: int):
        self.texts = texts
        self.prices_log2 = prices_log2
        self.tok = tokenizer
        self.max_len = max_len

    def _tokenize(self, text: str):
        return self.tok(
            text if isinstance(text, str) else "",
            truncation=True,
            max_length=self.max_len,
            padding=False,
            return_tensors="pt"
        )

    def _random_word_mask(self, input_ids: torch.Tensor, mask_token_id: int, prob: float) -> torch.Tensor:
        if prob <= 0.0:
            return input_ids
        ids = input_ids.clone()
        special = set(self.tok.all_special_ids)
        for i in range(ids.size(0)):
            for j in range(ids.size(1)):
                if ids[i, j].item() in special:
                    continue
                if random.random() < prob:
                    ids[i, j] = mask_token_id
        return ids

    def __len__(self): return len(self.texts)

    def __getitem__(self, idx):
        enc1 = self._tokenize(self.texts[idx])
        enc2 = {k: v.clone() for k, v in enc1.items()}
        enc2["input_ids"] = self._random_word_mask(enc2["input_ids"], self.tok.mask_token_id, WORD_MASK_P)
        item = {
            "input_ids_1": enc1["input_ids"].squeeze(0),
            "attention_mask_1": enc1["attention_mask"].squeeze(0),
            "input_ids_2": enc2["input_ids"].squeeze(0),
            "attention_mask_2": enc2["attention_mask"].squeeze(0),
        }
        if self.prices_log2 is not None:
            item["target"] = torch.tensor(self.prices_log2[idx], dtype=torch.float32)
        return item


In [15]:
@dataclass
class Collate:
    pad_id: int
    def __call__(self, batch):
        keys1 = ["input_ids_1", "attention_mask_1"]
        keys2 = ["input_ids_2", "attention_mask_2"]

        def pad_stack(keylist):
            maxlen = max(x[keylist[0]].size(0) for x in batch)
            out = {}
            for k in keylist:
                pad_val = self.pad_id if "input_ids" in k else 0
                tensors = []
                for x in batch:
                    v = x[k]
                    if v.size(0) < maxlen:
                        pad = torch.full((maxlen - v.size(0),), pad_val, dtype=v.dtype)
                        v = torch.cat([v, pad], dim=0)
                    tensors.append(v.unsqueeze(0))
                out[k] = torch.cat(tensors, dim=0)
            return out

        out1 = pad_stack(keys1)
        out2 = pad_stack(keys2)
        res = {**out1, **out2}
        if "target" in batch[0]:
            res["target"] = torch.stack([x["target"] for x in batch], dim=0)
        return res

In [19]:
class BertPriceModel(nn.Module):
    def __init__(self, model_id: str, proj_dim: int = 256, dropout: float = DROPOUT_PROB):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_id)
        hidden = self.backbone.config.hidden_size
        self.regressor = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1)
        )
        self.proj = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.GELU(),
            nn.Linear(hidden, proj_dim)
        )

    def forward_once(self, input_ids, attention_mask):
        out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        cls = out.last_hidden_state[:, 0, :]
        yhat = self.regressor(cls).squeeze(-1)
        z = self.proj(cls)
        return yhat, z

    def forward(self, input_ids_1, attention_mask_1, input_ids_2, attention_mask_2):
        y1, z1 = self.forward_once(input_ids_1, attention_mask_1)
        y2, z2 = self.forward_once(input_ids_2, attention_mask_2)
        return (y1 + y2) / 2.0, z1, z2

def info_nce(z1: torch.Tensor, z2: torch.Tensor, tau: float = TAU) -> torch.Tensor:
    z1 = F.normalize(z1, dim=-1)
    z2 = F.normalize(z2, dim=-1)
    logits = torch.matmul(z1, z2.t()) / tau
    labels = torch.arange(z1.size(0), device=z1.device)
    loss1 = F.cross_entropy(logits, labels)
    loss2 = F.cross_entropy(logits.t(), labels)
    return 0.5 * (loss1 + loss2)

def huber_loss(pred, target, delta=1.0):
    return F.huber_loss(pred, target, delta=delta)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def make_loader(texts, y_log2, tokenizer, batch_size, shuffle):
    ds = PriceTextDataset(texts=texts, prices_log2=y_log2, tokenizer=tokenizer, max_len=MAX_LEN)
    collate = Collate(pad_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id)
    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=2, pin_memory=True, collate_fn=collate)

In [20]:
def train_one_stage(model, loader_tr, loader_val, tokenizer, device, out_dir, track_val=True, epochs=EPOCHS):
    no_decay = ["bias", "LayerNorm.weight"]
    grouped = [
        {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": WEIGHT_DECAY},
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = torch.optim.AdamW(grouped, lr=LR)
    steps_per_epoch = max(1, math.ceil(len(loader_tr) / GRAD_ACCUM))
    num_training_steps = epochs * steps_per_epoch
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=int(num_training_steps * WARMUP_RATIO),
                                                num_training_steps=num_training_steps)
    scaler = torch.cuda.amp.GradScaler(enabled=FP16)

    best_smape = float("inf")
    patience = 0
    metrics_history = []  # epoch-wise SMAPE logging

    for epoch in range(1, epochs + 1):
        # ---- Train ----
        model.train()
        run_loss = run_reg = run_con = 0.0
        for step, batch in enumerate(loader_tr, 1):
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.cuda.amp.autocast(enabled=FP16):
                yhat, z1, z2 = model(
                    input_ids_1=batch["input_ids_1"],
                    attention_mask_1=batch["attention_mask_1"],
                    input_ids_2=batch["input_ids_2"],
                    attention_mask_2=batch["attention_mask_2"],
                )
                loss_reg = huber_loss(yhat, batch["target"])
                loss_con = info_nce(z1, z2, tau=TAU)
                loss = (1.0 - ALPHA_CONTRAST) * loss_reg + ALPHA_CONTRAST * loss_con
            scaler.scale(loss).backward()

            if step % GRAD_ACCUM == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
                scheduler.step()

            run_loss += loss.item(); run_reg += loss_reg.item(); run_con += loss_con.item()
            if step % 300 == 0:
                print(f"epoch {epoch} step {step}/{len(loader_tr)} loss={run_loss/step:.4f} reg={run_reg/step:.4f} con={run_con/step:.4f}")

        # ---- Evaluate (val or train) ----
        model.eval()
        preds_log = []
        tgts_log  = []
        with torch.no_grad():
            loader_eval = loader_val if (track_val and loader_val is not None and len(loader_val) > 0) else loader_tr
            for batch in loader_eval:
                batch = {k: v.to(device) for k, v in batch.items()}
                yhat, _, _ = model(
                    input_ids_1=batch["input_ids_1"],
                    attention_mask_1=batch["attention_mask_1"],
                    input_ids_2=batch["input_ids_2"],
                    attention_mask_2=batch["attention_mask_2"],
                )
                preds_log.append(yhat.detach().float().cpu().numpy())
                tgts_log.append(batch["target"].detach().float().cpu().numpy())
        preds_log = np.concatenate(preds_log, axis=0)
        tgts_log  = np.concatenate(tgts_log, axis=0)
        preds = delog2(preds_log); tgts = delog2(tgts_log)
        smape = smape_np(tgts, preds)
        print(f"✅ Epoch {epoch}: {'VAL' if (track_val and loader_val is not None and len(loader_val)>0) else 'TRAIN'} SMAPE = {smape:.3f}%")

        # ---- Log epoch-wise SMAPE to history JSON ----
        metrics_history.append({"epoch": int(epoch), "val_smape" if (track_val and loader_val is not None and len(loader_val)>0) else "train_smape": float(smape)})
        with open(os.path.join(out_dir, "metrics_history.json"), "w") as f:
            json.dump(metrics_history, f, indent=2)

        # ---- Save best (only when tracking val) ----
        if track_val and loader_val is not None and len(loader_val) > 0:
            if smape < best_smape - 1e-6:
                best_smape = smape
                torch.save({"model_state": model.state_dict(), "tokenizer": MODEL_ID}, os.path.join(out_dir, BEST_CKPT_NAME))
                patience = 0
                print(f"💾 Saved new best to {os.path.join(out_dir, BEST_CKPT_NAME)}")
            else:
                patience += 1
                print(f"⏸️ No improvement. Patience {patience}/{EARLY_STOP_ROUNDS}")
                if patience >= EARLY_STOP_ROUNDS:
                    print("🛑 Early stopping.")
                    break

    # Final metrics file (best val when available, else last epoch train SMAPE)
    with open(os.path.join(out_dir, "metrics.json"), "w") as f:
        payload = {"history": metrics_history}
        if track_val and loader_val is not None and len(loader_val) > 0:
            payload["best_val_smape"] = float(min([h["val_smape"] for h in metrics_history]))
        else:
            payload["last_train_smape"] = float(metrics_history[-1]["train_smape"])
        json.dump(payload, f, indent=2)

In [21]:
os.makedirs(OUTPUT_DIR, exist_ok=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🖥️ Device: {device}")
print(f"🔧 Loading train CSV: {TRAIN_CSV}")
df = pd.read_csv(TRAIN_CSV)

# Checks & cleaning
for col in [ID_COL, TEXT_COL, PRICE_COL]:
    if col not in df.columns:
        raise ValueError(f"Column '{col}' missing in {TRAIN_CSV}. Found: {df.columns.tolist()}")
df[TEXT_COL] = df[TEXT_COL].fillna("").astype(str).str.strip()
df = df.loc[pd.to_numeric(df[PRICE_COL], errors="coerce").notnull()].copy()
df[PRICE_COL] = df[PRICE_COL].astype(float)
df = df.loc[df[PRICE_COL] >= 0.0].reset_index(drop=True)

# Optional validation split (Stage A)
df_tr, df_va = split_train_val(df, frac_val=VAL_FRAC, seed=SEED)
print(f"📊 Split -> train={len(df_tr)} | valid={len(df_va)} (VAL_FRAC={VAL_FRAC})")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
if tokenizer.mask_token is None:
    tokenizer.add_special_tokens({"mask_token": "[MASK]"})

# Stage A: train with validation (if VAL_FRAC>0)
model = BertPriceModel(MODEL_ID).to(device)
if tokenizer.vocab_size != model.backbone.get_input_embeddings().weight.size(0):
    model.backbone.resize_token_embeddings(len(tokenizer))

print(f"🧮 Trainable params: {count_parameters(model):,}")

y_tr_log = log2_price(df_tr[PRICE_COL].values)
loader_tr = make_loader(df_tr[TEXT_COL].tolist(), y_tr_log, tokenizer, BATCH_SIZE, shuffle=True)

if len(df_va) > 0:
    y_va_log = log2_price(df_va[PRICE_COL].values)
    loader_va = make_loader(df_va[TEXT_COL].tolist(), y_va_log, tokenizer, BATCH_SIZE, shuffle=False)
else:
    loader_va = None

train_one_stage(model, loader_tr, loader_va, tokenizer, device, OUTPUT_DIR,
                track_val=(len(df_va) > 0), epochs=EPOCHS)

# If we had a val set, reload best val checkpoint before Stage B
if len(df_va) > 0 and os.path.exists(os.path.join(OUTPUT_DIR, BEST_CKPT_NAME)):
    ckpt = torch.load(os.path.join(OUTPUT_DIR, BEST_CKPT_NAME), map_location=device)
    model.load_state_dict(ckpt["model_state"])
    print("✅ Loaded best validation checkpoint for full-data training.")

# Stage B: train on 100% of data (no validation) for a couple of epochs
print("🔁 Retraining on 100% of training data...")
y_all_log = log2_price(df[PRICE_COL].values)
loader_all = make_loader(df[TEXT_COL].tolist(), y_all_log, tokenizer, BATCH_SIZE, shuffle=True)
train_one_stage(model, loader_all, loader_val=None, tokenizer=tokenizer, device=device,
                out_dir=OUTPUT_DIR, track_val=False, epochs=EPOCHS_FULL)

# Save final full-data checkpoint
torch.save({"model_state": model.state_dict(), "tokenizer": MODEL_ID},
           os.path.join(OUTPUT_DIR, FULL_CKPT_NAME))
print(f"💾 Saved final full-data model to {os.path.join(OUTPUT_DIR, FULL_CKPT_NAME)}")

# -------------------- Stage C: Predict on test file --------------------
if TEST_CSV and os.path.exists(TEST_CSV):
    print(f"🔮 Loading test CSV: {TEST_CSV}")
    dft = pd.read_csv(TEST_CSV)
    if ID_COL not in dft.columns or TEXT_COL not in dft.columns:
        raise ValueError(f"Test file must contain '{ID_COL}' and '{TEXT_COL}'. Found: {dft.columns.tolist()}")
    dft[TEXT_COL] = dft[TEXT_COL].fillna("").astype(str).str.strip()

    # Inference loader (no targets)
    class InferDataset(Dataset):
        def __init__(self, texts, tokenizer, max_len):
            self.texts = texts; self.tok = tokenizer; self.max_len = max_len
        def __len__(self): return len(self.texts)
        def __getitem__(self, idx):
            enc = self.tok(self.texts[idx], truncation=True, max_length=self.max_len, padding=False, return_tensors="pt")
            return {
                "input_ids": enc["input_ids"].squeeze(0),
                "attention_mask": enc["attention_mask"].squeeze(0),
            }
    @dataclass
    class InferCollate:
        pad_id: int
        def __call__(self, batch):
            maxlen = max(x["input_ids"].size(0) for x in batch)
            def pad(key, pad_val):
                arr = []
                for x in batch:
                    v = x[key]
                    if v.size(0) < maxlen:
                        pad = torch.full((maxlen - v.size(0),), pad_val, dtype=v.dtype)
                        v = torch.cat([v, pad], dim=0)
                    arr.append(v.unsqueeze(0))
                return torch.cat(arr, dim=0)
            return {
                "input_ids": pad("input_ids", self.pad_id),
                "attention_mask": pad("attention_mask", 0),
            }

    infer_ds = InferDataset(dft[TEXT_COL].tolist(), tokenizer, MAX_LEN)
    infer_loader = DataLoader(
        infer_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2,
        pin_memory=True, collate_fn=InferCollate(pad_id=tokenizer.pad_token_id or tokenizer.eos_token_id)
    )

    # Build a small head-only forward for inference
    model.eval()
    preds = []
    with torch.no_grad():
        for batch in infer_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            # two passes for stability (reusing model.forward_once)
            y1, _ = model.forward_once(input_ids, attention_mask)
            y2, _ = model.forward_once(input_ids, attention_mask)
            yhat_log = ((y1 + y2) / 2.0).detach().float().cpu().numpy()
            preds.append(yhat_log)
    preds = np.concatenate(preds, axis=0)
    price_pred = delog2(preds)

    out_df = pd.DataFrame({
        ID_COL: dft[ID_COL].values,
        "price_pred": price_pred
    })
    out_path = os.path.join(OUTPUT_DIR, PRED_CSV_NAME)
    out_df.to_csv(out_path, index=False)
    print(f"📤 Wrote predictions to: {out_path}")
else:
    print("ℹ️ TEST_CSV not set or file not found — skipping prediction.")

🖥️ Device: cuda
🔧 Loading train CSV: jl_fs/train.csv
📊 Split -> train=75000 | valid=0 (VAL_FRAC=0.0)
🧮 Trainable params: 337,504,513


  scaler = torch.cuda.amp.GradScaler(enabled=FP16)
  with torch.cuda.amp.autocast(enabled=FP16):


epoch 1 step 300/4688 loss=1.3099 reg=1.4714 con=0.8252
epoch 1 step 600/4688 loss=0.9119 reg=1.0753 con=0.4217
epoch 1 step 900/4688 loss=0.7697 reg=0.9312 con=0.2851
epoch 1 step 1200/4688 loss=0.6848 reg=0.8408 con=0.2168
epoch 1 step 1500/4688 loss=0.6321 reg=0.7843 con=0.1754
epoch 1 step 1800/4688 loss=0.5938 reg=0.7423 con=0.1482
epoch 1 step 2100/4688 loss=0.5659 reg=0.7114 con=0.1294
epoch 1 step 2400/4688 loss=0.5424 reg=0.6852 con=0.1141
epoch 1 step 2700/4688 loss=0.5235 reg=0.6639 con=0.1020
epoch 1 step 3000/4688 loss=0.5086 reg=0.6473 con=0.0925
epoch 1 step 3300/4688 loss=0.4960 reg=0.6330 con=0.0849
epoch 1 step 3600/4688 loss=0.4852 reg=0.6207 con=0.0786
epoch 1 step 3900/4688 loss=0.4750 reg=0.6090 con=0.0731
epoch 1 step 4200/4688 loss=0.4661 reg=0.5986 con=0.0684
epoch 1 step 4500/4688 loss=0.4578 reg=0.5890 con=0.0642
✅ Epoch 1: TRAIN SMAPE = 46.763%
epoch 2 step 300/4688 loss=0.2995 reg=0.3976 con=0.0053
epoch 2 step 600/4688 loss=0.2969 reg=0.3941 con=0.0055
epo

KeyboardInterrupt: 

In [22]:
torch.save({"model_state": model.state_dict(), "tokenizer": MODEL_ID},
           os.path.join(OUTPUT_DIR, FULL_CKPT_NAME))

In [24]:
# -------------------- Stage C: Load checkpoint & predict on test file --------------------
import glob

# Resolve checkpoint path
if "FULL_CKPT_NAME" in globals() and FULL_CKPT_NAME:
    CKPT_PATH = os.path.join(OUTPUT_DIR, FULL_CKPT_NAME)
else:
    # fallback: pick most recent .pt in OUTPUT_DIR
    cands = sorted(glob.glob(os.path.join(OUTPUT_DIR, "*.pt")), key=os.path.getmtime)
    if not cands:
        raise FileNotFoundError(f"No .pt checkpoints found in {OUTPUT_DIR}")
    CKPT_PATH = cands[-1]

print(f"🧩 Loading checkpoint: {CKPT_PATH}")
ckpt = torch.load(CKPT_PATH, map_location="cpu")

# Rebuild tokenizer/model from checkpoint metadata
saved_model_id = ckpt.get("tokenizer", MODEL_ID)  # we stored MODEL_ID under 'tokenizer'
print(f"🔤 Using tokenizer/model id: {saved_model_id}")
tokenizer = AutoTokenizer.from_pretrained(saved_model_id, use_fast=True)

# If pad token is missing for any reason, set one
if tokenizer.pad_token is None:
    if tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token
    else:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})

# Recreate model and align embedding matrix size to tokenizer vocab BEFORE loading state_dict
model = BertPriceModel(saved_model_id).to(device)
if tokenizer.vocab_size != model.backbone.get_input_embeddings().weight.size(0):
    print(f"ℹ️ Resizing token embeddings to match tokenizer: {tokenizer.vocab_size}")
    model.backbone.resize_token_embeddings(len(tokenizer))

# Load weights (strict=False in case resize introduced size diffs in embeddings)
missing, unexpected = model.load_state_dict(ckpt["model_state"], strict=False)
if missing or unexpected:
    print(f"⚠️ load_state_dict: missing={len(missing)} unexpected={len(unexpected)}")
    if missing:   print("  missing:", missing[:8], "..." if len(missing) > 8 else "")
    if unexpected:print("  unexpected:", unexpected[:8], "..." if len(unexpected) > 8 else "")

model.eval()

# Names/defaults
if "PRED_CSV_NAME" not in globals() or not PRED_CSV_NAME:
    PRED_CSV_NAME = "predictions.csv"
if "ID_COL" not in globals() or not ID_COL:
    ID_COL = "sample_id"

# Inference
if "TEST_CSV" in globals() and TEST_CSV and os.path.exists(TEST_CSV):
    print(f"🔮 Loading test CSV: {TEST_CSV}")
    dft = pd.read_csv(TEST_CSV)
    if ID_COL not in dft.columns or TEXT_COL not in dft.columns:
        raise ValueError(f"Test file must contain '{ID_COL}' and '{TEXT_COL}'. Found: {dft.columns.tolist()}")
    dft[TEXT_COL] = dft[TEXT_COL].fillna("").astype(str).str.strip()

    class InferDataset(Dataset):
        def __init__(self, texts, tokenizer, max_len):
            self.texts = texts; self.tok = tokenizer; self.max_len = max_len
        def __len__(self): return len(self.texts)
        def __getitem__(self, idx):
            enc = self.tok(self.texts[idx], truncation=True, max_length=self.max_len, padding=False, return_tensors="pt")
            return {
                "input_ids": enc["input_ids"].squeeze(0),
                "attention_mask": enc["attention_mask"].squeeze(0),
            }

    @dataclass
    class InferCollate:
        pad_id: int
        def __call__(self, batch):
            maxlen = max(x["input_ids"].size(0) for x in batch)
            def pad(key, pad_val):
                arr = []
                for x in batch:
                    v = x[key]
                    if v.size(0) < maxlen:
                        pad = torch.full((maxlen - v.size(0),), pad_val, dtype=v.dtype)
                        v = torch.cat([v, pad], dim=0)
                    arr.append(v.unsqueeze(0))
                return torch.cat(arr, dim=0)
            return {
                "input_ids": pad("input_ids", self.pad_id),
                "attention_mask": pad("attention_mask", 0),
            }

    infer_ds = InferDataset(dft[TEXT_COL].tolist(), tokenizer, MAX_LEN)
    infer_loader = DataLoader(
        infer_ds,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        collate_fn=InferCollate(pad_id=tokenizer.pad_token_id),
    )

    preds_log2 = []
    with torch.no_grad():
        for batch in infer_loader:
            input_ids = batch["input_ids"].to(device, non_blocking=True)
            attention_mask = batch["attention_mask"].to(device, non_blocking=True)
            # Two stochastic passes (dropout) and average for stability
            y1, _ = model.forward_once(input_ids, attention_mask)
            y2, _ = model.forward_once(input_ids, attention_mask)
            yhat = ((y1 + y2) / 2.0).detach().float().cpu().numpy()
            preds_log2.append(yhat)

    preds_log2 = np.concatenate(preds_log2, axis=0)
    price_pred = delog2(preds_log2)

    out_df = pd.DataFrame({ID_COL: dft[ID_COL].values, "price": price_pred})
    out_path = os.path.join(OUTPUT_DIR, PRED_CSV_NAME)
    out_df.to_csv("prediction-bert-large.csv", index=False)
    print(f"📤 Wrote predictions to: {out_path}")
else:
    print("ℹ️ TEST_CSV not set or file not found — skipping prediction.")


🧩 Loading checkpoint: price_large_bert_contrastive/final_full.pt
🔤 Using tokenizer/model id: google-bert/bert-large-uncased
🔮 Loading test CSV: jl_fs/test.csv
📤 Wrote predictions to: price_large_bert_contrastive/predictions.csv
