# In [None]:
# =========================================================
# FP (TARGET-ONLY) + ORGAN TOKEN + SMILES CLIP  
# Variant A scaling: raw counts -> log1p -> delta -> clip (+ optional asinh)
# Drop first "service" element from genes/expr
# Stable gene order after selection: sort by gene_id (deterministic)
# ORGAN token injected at fixed position [CLS][ORGAN] (NO cell-line embedding)
# Positional embeddings REMOVED
# Targets: cosine vector loss + cosine-BCE (neg sampling) + InfoNCE rank
# SMILES: CLIP loss (+ optional cosine align), NO SupCon, NO MSE
# Batch is CLIP-safe: unique drug in batch (1 cell per drug)
# Skip missing/zero SMILES vectors
# OneCycleLR fixed for grad accumulation (total_steps = real updates)
# AMP + grad accumulation
# =========================================================

import os, glob, ast, random
from collections import defaultdict

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from tqdm import tqdm
import pyarrow.parquet as pq

import scanpy as sc
from scipy import sparse
from sklearn.model_selection import train_test_split


# =========================================================
# 0) PATHS / HYPERPARAMS
# =========================================================
# --- input locations -------------------------------------------------------
PARQUET_DIR    = "/data/aiffel/data/Tahoe-100M/data"
GENE_META_PATH = "/data/aiffel/data/Tahoe-100M/metadata/gene_metadata.parquet"
DRUG_META_PATH = "/data/aiffel/data/Tahoe-100M/metadata/drug_metadata.parquet"
COUNTS_CSV     = "/data/aiffel/babayakga/making_data/aiffel/babayakga/making_data/tahoe_counts_per_drug_cell_line.csv"
DMSO_PATH      = "/data/aiffel/babayakga/outputs/dmso.h5ad"
CELL_LINE_META_PATH = "/data/aiffel/data/Tahoe-100M/metadata/cell_line_metadata.parquet"

SMILES_EMB_PATH       = "/data/aiffel/babayakga/smiles_emb/drug_smiles_emb_all1.pt"
PRETRAINED_GENE_NPY   = "/data/aiffel/babayakga/pretraining/checkpoints_with_cell/gene_embeddings.npy"  # optional

# Name of the control condition; rows with this drug are used for the HVG
# selection and the baselines and are excluded from the training pairs.
CONTROL_DRUG = "DMSO_TF"
SEED = 42

# sequence
MAX_SEQ_LEN = 256   # gene tokens kept per cell (the dataset prepends 2 special tokens)
HVG_K = 4000        # number of top-variance genes taken from the control cells

# training
BATCH_SIZE  = 128
ACCUM_STEPS = 4
STEPS_PER_EPOCH = 7000   # passed to the training dataset as `steps`
VAL_STEPS       = 300    # passed to the validation dataset as `steps`
EPOCHS          = 20

LR           = 1e-4
WEIGHT_DECAY = 0.01
MAX_GRAD_NORM = 1.0

# targets loss weights
# (presumably the cosine / cosine-BCE / InfoNCE-rank terms named in the file
#  header; they are consumed by training code that is not part of this section)
lambda_cos  = 1.0
lambda_bce  = 0.05
lambda_rank = 1.0

# bce/rank knobs
bce_num_neg  = 2048
bce_pos_cap  = None
tau_bce      = 0.15

rank_num_neg = 1024
rank_num_pos = 4
tau_rank     = 0.15

# SMILES (CLIP)
lambda_smiles = 0.05
alpha_align   = 0.5
TAU_INIT      = 0.10  # initial tau (learnable temperature)

# overall mixing
lambda_targets = 1.0

# data sampling
NUM_WORKERS = 4
MIN_TRAIN_CELLS_PER_PAIR = 1000   # (drug, cell_line) pairs with fewer cells are dropped
TEST_SIZE = 0.1                   # fraction of pairs held out for validation

# Variant A scaling
USE_LOG1P_EXPR   = True
USE_ASINH_DELTA  = False
DELTA_CLIP_ABS   = 5.0

# misc
DROP_FIRST_GENE_TOKEN = True

# device
# The script prefers GPU ordinal 3. `torch.device("cuda:3")` itself never
# fails, but the first `.to(device)` does on a host with fewer than four
# visible GPUs, so the ordinal is validated here and falls back to cuda:0.
_PREFERRED_CUDA_INDEX = 3
if torch.cuda.is_available():
    _n_visible_gpus = torch.cuda.device_count()
    if _PREFERRED_CUDA_INDEX < _n_visible_gpus:
        device = torch.device(f"cuda:{_PREFERRED_CUDA_INDEX}")
    else:
        print(
            f"[device] cuda:{_PREFERRED_CUDA_INDEX} is not available "
            f"(visible GPUs: {_n_visible_gpus}); falling back to cuda:0"
        )
        device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # Older torch releases do not provide this function; test for it
    # explicitly instead of swallowing every possible exception.
    if hasattr(torch, "set_float32_matmul_precision"):
        torch.set_float32_matmul_precision("high")


# =========================================================
# 1) gene_metadata
# =========================================================
# Load the gene table, normalise the three columns used below, and order the
# rows by token id.
gene_md = (
    pd.read_parquet(GENE_META_PATH)
    .astype({"gene_symbol": str, "ensembl_id": str, "token_id": int})
    .sort_values("token_id")
    .reset_index(drop=True)
)

# Size of the full gene-token space (token ids are used as array indices).
N_GENES = 1 + int(gene_md["token_id"].max())

# Case-insensitive symbol -> Ensembl id, and Ensembl id -> token id.
# With duplicate keys the last row wins, exactly as with dict(zip(...)).
symbol_to_ensg_lower = {
    symbol: ensg
    for symbol, ensg in zip(gene_md["gene_symbol"].str.lower(), gene_md["ensembl_id"])
}
ensg_to_token_id = {
    ensg: tid
    for ensg, tid in zip(gene_md["ensembl_id"].values, gene_md["token_id"].values)
}


# =========================================================
# 2) drug_metadata -> targets
# =========================================================
def parse_targets(x):
    """Normalise one raw ``targets`` metadata cell into a list of target names.

    Accepted inputs:

    * ``None`` or a float NaN -> ``[]``
    * ``list`` / ``tuple`` / ``numpy.ndarray`` -> stripped string elements;
      empty, ``None`` and NaN elements are skipped.  (pandas returns parquet
      list columns as ``ndarray`` cells, which previously fell through to
      ``str(x)`` and produced a single bogus name.)
    * ``str`` -> a Python-literal list/tuple is parsed with
      ``ast.literal_eval``; otherwise the text is split on ``;`` if present,
      else on ``,``.  A bracketed string that is not a valid literal
      (e.g. ``"[EGFR, HER2]"``) has its outer brackets and per-token quotes
      removed before splitting.
    * anything else -> ``[str(x).strip()]``

    Returns:
        list[str]: target names in input order (duplicates are kept).
    """
    if x is None:
        return []
    if isinstance(x, float) and np.isnan(x):
        return []

    if isinstance(x, np.ndarray):
        x = x.ravel().tolist()

    if isinstance(x, (list, tuple)):
        names = []
        for item in x:
            if item is None or (isinstance(item, float) and np.isnan(item)):
                continue
            name = str(item).strip()
            if name:
                names.append(name)
        return names

    if isinstance(x, str):
        s = x.strip()
        bracketed = (s.startswith("[") and s.endswith("]")) or (s.startswith("(") and s.endswith(")"))
        if bracketed:
            try:
                parsed = ast.literal_eval(s)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                # Not a Python literal (e.g. unquoted names); handled below.
                parsed = None
            if isinstance(parsed, (list, tuple)):
                return [str(t).strip() for t in parsed if str(t).strip()]
            # Fall back to plain splitting, without the enclosing brackets.
            s = s[1:-1].strip()

        for sep in (";", ","):
            if sep in s:
                parts = [t.strip() for t in s.split(sep)]
                break
        else:
            parts = [s]

        if bracketed:
            # Leftover quotes from a half-valid literal such as "['EGFR', HER2]".
            parts = [p.strip("'\"").strip() for p in parts]
        return [p for p in parts if p]

    return [str(x).strip()]

drug_meta_df = pd.read_parquet(DRUG_META_PATH).copy()
drug_meta_df["drug"] = drug_meta_df["drug"].astype(str)


def _resolve_target_token_id(raw_name):
    """Map one target name (gene symbol or ENSG id) to a token id, or None."""
    name = str(raw_name).strip()
    if not name:
        return None
    # Names that already look like Ensembl ids are used as-is; everything
    # else goes through the case-insensitive symbol table.
    ensg = name if name.startswith("ENSG") else symbol_to_ensg_lower.get(name.lower(), None)
    if ensg is None:
        return None
    tid = ensg_to_token_id.get(ensg, None)
    if tid is None:
        return None
    tid = int(tid)
    return tid if 0 <= tid < N_GENES else None


# drug name -> sorted, de-duplicated token ids of its known targets
drug_to_target_tokenids = {}
# union of every token id that is a target of at least one drug
all_target_tokenids = set()

for _, row in drug_meta_df.iterrows():
    resolved = {
        tid
        for tid in map(_resolve_target_token_id, parse_targets(row.get("targets", None)))
        if tid is not None
    }
    tids = sorted(resolved)
    drug_to_target_tokenids[str(row["drug"])] = tids
    all_target_tokenids.update(tids)

drug_has_targets = {d: bool(tids) for d, tids in drug_to_target_tokenids.items()}
print(f"[targets] drugs total={len(drug_to_target_tokenids)}, with>=1 target={sum(drug_has_targets.values())}")


# =========================================================
# 3) HVG from DMSO
# =========================================================
def compute_hvg_token_ids_from_dmso(dmso_h5ad_path: str, control_drug: str, HVG_K: int, ensg_to_token_id: dict):
    """Select the ``HVG_K`` highest-variance genes among the control cells.

    The variance is computed per gene over the rows of ``adata.X`` whose
    ``obs["drug"]`` equals ``control_drug``, as ``E[x^2] - E[x]^2`` on the
    stored values (no log transform is applied here).  Only genes whose
    ``var_names`` entry has a token id in ``ensg_to_token_id`` take part.

    Returns:
        set[int]: token ids of the selected genes (at most ``HVG_K``).

    Raises:
        ValueError: if no control cells exist or no gene could be mapped.
    """
    adata = sc.read_h5ad(dmso_h5ad_path)

    is_control = adata.obs["drug"].astype(str).values == str(control_drug)
    control_rows = np.where(is_control)[0]
    if control_rows.size == 0:
        raise ValueError(f"No DMSO cells found: control_drug={control_drug}")

    matrix = adata.X.tocsr() if sparse.issparse(adata.X) else sparse.csr_matrix(adata.X)
    control = matrix[control_rows]

    # Per-gene first and second moments -> variance.
    first_moment = np.asarray(control.mean(axis=0)).ravel()
    second_moment = np.asarray(control.multiply(control).mean(axis=0)).ravel()
    variance = (second_moment - first_moment**2).astype(np.float32)

    # Keep only the genes that exist in the token vocabulary.
    mapped = [
        (int(ensg_to_token_id[name]), float(variance[col]))
        for col, name in enumerate(adata.var_names.astype(str).tolist())
        if ensg_to_token_id.get(name, None) is not None
    ]
    token_ids = np.asarray([tid for tid, _ in mapped], dtype=np.int64)
    vars_ = np.asarray([v for _, v in mapped], dtype=np.float32)
    if token_ids.size == 0:
        raise ValueError("ENSG mapping failed.")

    # Partial sort: the first k entries are the k largest variances (unordered).
    k = min(int(HVG_K), token_ids.size)
    top = np.argpartition(-vars_, k - 1)[:k]
    return set(token_ids[top].tolist())

# Highly variable genes, measured on the control cells only.
hvg_token_ids = compute_hvg_token_ids_from_dmso(DMSO_PATH, CONTROL_DRUG, HVG_K, ensg_to_token_id)
print("[HVG] token_ids:", len(hvg_token_ids))

# INPUT subset = HVG ∪ TARGETS (sorted so that indices are deterministic)
subset_token_ids = sorted(set(hvg_token_ids).union(all_target_tokenids))
M_SUB = len(subset_token_ids)
print("[subset] HVG ∪ TARGETS size:", M_SUB)

# OUTPUT target-only = TARGETS only
target_token_ids = sorted(set(all_target_tokenids))
M_TGT = len(target_token_ids)
print("[target-only] size:", M_TGT)

# global token id -> position inside the input subset / the target-only space
old_tid_to_subid = dict(zip(subset_token_ids, range(M_SUB)))
old_tid_to_tgtid = dict(zip(target_token_ids, range(M_TGT)))


# =========================================================
# 4) vocab + LUT (include ORGAN token)
# =========================================================
# Special tokens occupy the first vocabulary ids; gene tokens follow.
SPECIAL_TOKENS = ["[PAD]", "[CLS]", "[ORGAN]", "[MASK]"]
local_token_to_id = dict(zip(SPECIAL_TOKENS, range(len(SPECIAL_TOKENS))))
N_SPECIAL = len(SPECIAL_TOKENS)

VOCAB_SIZE = N_SPECIAL + M_SUB
PAD_ID = local_token_to_id["[PAD]"]
CLS_ID = local_token_to_id["[CLS]"]
ORGAN_TOK_ID = local_token_to_id["[ORGAN]"]

subset_token_ids_np = np.asarray(subset_token_ids, dtype=np.int64)

# Lookup table: global gene token id -> local vocabulary id, -1 if the gene is
# not part of the input subset.  The i-th subset gene gets id N_SPECIAL + i.
old_tid_to_vocab_lut = np.full((N_GENES,), -1, dtype=np.int64)
_subset_in_range = (subset_token_ids_np >= 0) & (subset_token_ids_np < N_GENES)
old_tid_to_vocab_lut[subset_token_ids_np[_subset_in_range]] = N_SPECIAL + np.flatnonzero(_subset_in_range)

print("[vocab] VOCAB_SIZE:", VOCAB_SIZE, "| N_SPECIAL:", N_SPECIAL)


# =========================================================
# 5) y_targets (drug -> TARGET-ONLY multi-hot)
# =========================================================
# drug name -> float32 multi-hot vector of length M_TGT (1.0 at each target)
drug_to_target_vec_tgt = {}
for d, tids in drug_to_target_tokenids.items():
    hot = np.asarray(
        [old_tid_to_tgtid[int(tid)] for tid in tids if int(tid) in old_tid_to_tgtid],
        dtype=np.int64,
    )
    vec = np.zeros(M_TGT, dtype=np.float32)
    vec[hot] = 1.0
    drug_to_target_vec_tgt[d] = vec

print("[targets] drugs with>=1 target vec:", sum(float(v.sum()) > 0 for v in drug_to_target_vec_tgt.values()))


# =========================================================
# 6) organ mapping: cell_line_id -> organ_id   (FIX: UNK=0, organs start at 1)
# =========================================================
cl_meta = pd.read_parquet(CELL_LINE_META_PATH).copy()

# Drop rows with a missing cell-line id or organ BEFORE casting to str.
# Casting first turns missing values into the literal string "nan", which
# makes dropna() a no-op and registers "nan" as a real organ class instead of
# letting those cell lines fall back to UNK_ORGAN_ID.
cl_meta_small = (
    cl_meta[["Cell_ID_Cellosaur", "Organ"]]
    .dropna()
    .astype(str)
    .drop_duplicates()
)

# Keep the string-typed columns on the full table as before.
cl_meta["Cell_ID_Cellosaur"] = cl_meta["Cell_ID_Cellosaur"].astype(str)
cl_meta["Organ"] = cl_meta["Organ"].astype(str)

organs = sorted(cl_meta_small["Organ"].unique().tolist())

# Id 0 is reserved for "unknown organ"; real organs are numbered from 1.
UNK_ORGAN_ID = 0
organ2id = {o: i + 1 for i, o in enumerate(organs)}  # shift by 1
NUM_ORGANS = len(organs) + 1

# cell-line id -> organ id (if a cell line is listed with several organs the
# last row wins)
cellline2organid = {
    str(cvcl): int(organ2id.get(str(org), UNK_ORGAN_ID))
    for cvcl, org in cl_meta_small.values
}
print("[organ] NUM_ORGANS:", NUM_ORGANS, "| mapped cell_lines:", len(cellline2organid), "| UNK_ORGAN_ID:", UNK_ORGAN_ID)


# =========================================================
# 7) SMILES embeddings + bank (for retrieval)
# =========================================================
# NOTE: torch.load unpickles the file; only point SMILES_EMB_PATH at trusted data.
obj = torch.load(SMILES_EMB_PATH, map_location="cpu")
assert isinstance(obj, dict) and "drug" in obj and "emb" in obj

# Saved embeddings: row i of the matrix belongs to drug_list_saved[i].
drug_list_saved = list(map(str, obj["drug"]))
emb_matrix = obj["emb"].to(dtype=torch.float32).cpu().numpy()
SMILES_DIM = int(emb_matrix.shape[1])
drug_to_smiles_np_raw = {
    name: emb_matrix[row].astype(np.float32, copy=False)
    for row, name in enumerate(drug_list_saved)
}

# Stable drug index over every drug named in the drug metadata.
drug_names_all = sorted(set(drug_meta_df["drug"].astype(str).tolist()))
drug2id = dict(zip(drug_names_all, range(len(drug_names_all))))

# Every metadata drug gets a vector; drugs without a saved embedding get zeros.
drug_to_smiles_np = {}
missing = 0
zeroed  = 0
for d in drug_names_all:
    v = drug_to_smiles_np_raw.get(d, None)
    if v is None:
        drug_to_smiles_np[d] = np.zeros((SMILES_DIM,), dtype=np.float32)
        missing += 1
        continue
    vv = v.astype(np.float32, copy=False)
    # Count (but keep) embeddings that are present yet all-zero.
    zeroed += int(np.abs(vv).sum() == 0.0)
    drug_to_smiles_np[d] = vv
print(f"[SMILES] missing={missing}/{len(drug_names_all)} | zero_vec={zeroed}")

# Row i of the bank is the embedding of drug_names_all[i] (i == drug2id[name]).
smiles_bank_np = np.stack([drug_to_smiles_np[d] for d in drug_names_all], axis=0).astype(np.float32)
print("[SMILES] bank:", smiles_bank_np.shape)


# =========================================================
# 8) DMSO baselines (Variant A: log1p BEFORE mean)
# =========================================================
def build_dmso_baselines_gene_space(dmso_h5ad_path: str, control_drug: str, N_GENES: int, ensg_to_token_id: dict, use_log1p: bool):
    """Compute mean control expression per gene, globally and per cell line.

    Only rows whose ``obs["drug"]`` equals ``control_drug`` are used.  With
    ``use_log1p`` the stored values are clipped at 0 and log1p-transformed
    before averaging.  Results live in the full gene-token space: index
    ``token_id`` holds the mean of that gene, genes without a mapping stay 0.

    Returns:
        tuple: ``(baseline_global, baseline_by_cl)`` where the first item is a
        float32 array of length ``N_GENES`` and the second maps each cell-line
        id (str) that has control cells to an array of the same shape.

    Raises:
        ValueError: if there are no control cells.
    """
    adata = sc.read_h5ad(dmso_h5ad_path)
    obs = adata.obs
    X = adata.X.tocsr() if sparse.issparse(adata.X) else sparse.csr_matrix(adata.X)

    is_control = obs["drug"].astype(str).values == str(control_drug)
    control_rows = np.where(is_control)[0]
    if control_rows.size == 0:
        raise ValueError("No DMSO cells.")

    # (token id, matrix column) for every gene known to the vocabulary.
    mapped = [
        (int(ensg_to_token_id[name]), col)
        for col, name in enumerate(adata.var_names.astype(str).tolist())
        if ensg_to_token_id.get(name, None) is not None
    ]
    token_ids = np.asarray([tid for tid, _ in mapped], dtype=np.int64)
    cols = np.asarray([col for _, col in mapped], dtype=np.int64)

    def _mean_in_gene_space(rows):
        """Mean over ``rows`` of the mapped columns, scattered to N_GENES."""
        block = X[rows][:, cols]
        if use_log1p:
            block = block.copy()
            # Only stored entries are transformed; log1p(0) == 0 keeps the
            # implicit zeros correct.
            block.data = np.log1p(np.clip(block.data, a_min=0.0, a_max=None))
        sub_mean = np.asarray(block.mean(axis=0)).ravel().astype(np.float32)
        full = np.zeros(N_GENES, dtype=np.float32)
        full[token_ids] = sub_mean
        return full

    baseline_global = _mean_in_gene_space(control_rows)

    baseline_by_cl = {}
    cell_lines = obs["cell_line_id"].astype(str).values
    for cl in np.unique(cell_lines):
        rows = np.where(is_control & (cell_lines == cl))[0]
        if rows.size:
            baseline_by_cl[str(cl)] = _mean_in_gene_space(rows)

    return baseline_global, baseline_by_cl

# Control (DMSO) baselines in the full gene-token space: one global mean
# vector plus one per cell line; log1p is applied before averaging when
# USE_LOG1P_EXPR is set.
baseline_global, baseline_by_cl = build_dmso_baselines_gene_space(
    DMSO_PATH, CONTROL_DRUG, N_GENES, ensg_to_token_id, use_log1p=USE_LOG1P_EXPR
)
print("[baseline] global:", baseline_global.shape, "| by_cl:", len(baseline_by_cl))


# =========================================================
# 9) split (drug, cell_line) pairs + weights (filter: has targets)
# =========================================================
DRUG_COL, CELL_COL, N_COL = "drug", "cell_line_id", "n_cells"

counts = pd.read_csv(COUNTS_CSV).astype({DRUG_COL: str, CELL_COL: str, N_COL: int})

# Candidate (drug, cell_line) pairs: enough cells ...
_enough_cells = counts[N_COL] >= MIN_TRAIN_CELLS_PER_PAIR
pairs_df = counts.loc[_enough_cells, [DRUG_COL, CELL_COL]].drop_duplicates().copy()

# ... not the control condition ...
pairs_df = pairs_df.loc[pairs_df[DRUG_COL] != str(CONTROL_DRUG)].copy()

# ... the drug has at least one resolvable target ...
_has_target = pairs_df[DRUG_COL].map(lambda d: drug_has_targets.get(str(d), False))
pairs_df = pairs_df[_has_target].copy()

# ... and the drug is present in the drug index.
_known_drug = pairs_df[DRUG_COL].isin(set(drug2id.keys()))
pairs_df = pairs_df[_known_drug].copy()

# Pair-level split, stratified by drug so that both sides see every drug's
# proportion of pairs.
train_df, val_df = train_test_split(
    pairs_df,
    test_size=TEST_SIZE,
    random_state=SEED,
    stratify=pairs_df[DRUG_COL],
)

train_pairs = list(train_df[[DRUG_COL, CELL_COL]].itertuples(index=False, name=None))
val_pairs   = list(val_df[[DRUG_COL, CELL_COL]].itertuples(index=False, name=None))
print("[split] train pairs:", len(train_pairs), "| val pairs:", len(val_pairs))

def make_pair_weights_from_counts(counts_df, pairs, mode="inv_sqrt", eps=1.0):
    """Return normalised sampling weights that down-weight large pairs.

    Args:
        counts_df: frame with the ``DRUG_COL``, ``CELL_COL`` and ``N_COL``
            columns (cell count per drug / cell-line pair).
        pairs: iterable of ``(drug, cell_line)``; pairs missing from
            ``counts_df`` are treated as having 0 cells.
        mode: ``"inv"`` -> 1/(n+eps), ``"inv_log"`` -> 1/log1p(n+eps),
            anything else -> 1/sqrt(n+eps).
        eps: additive smoothing applied to the count.

    Returns:
        np.ndarray: float64 weights aligned with ``pairs`` that sum to ~1.
    """
    pair2n = {}
    for d, c, n in counts_df[[DRUG_COL, CELL_COL, N_COL]].values:
        pair2n[(str(d), str(c))] = int(n)

    def _raw_weight(n):
        if mode == "inv":
            return 1.0 / (n + eps)
        if mode == "inv_log":
            return 1.0 / np.log1p(n + eps)
        return 1.0 / np.sqrt(n + eps)

    raw = [float(_raw_weight(pair2n.get((str(p[0]), str(p[1])), 0))) for p in pairs]
    weights = np.clip(np.asarray(raw, dtype=np.float64), 0.0, None)
    # The tiny constant keeps the division defined for an empty / all-zero input.
    return weights / (weights.sum() + 1e-12)

# Sampling weights per pair (1/sqrt(n_cells + eps), normalised), aligned with
# train_pairs / val_pairs.
w_train = make_pair_weights_from_counts(counts, train_pairs, mode="inv_sqrt")
w_val   = make_pair_weights_from_counts(counts, val_pairs,   mode="inv_sqrt")


# =========================================================
# 10) parquet row-group indexing
# =========================================================
# Every *.parquet file below PARQUET_DIR (searched recursively), in sorted
# order so that the index built from it is reproducible.
PARQUET_FILES = sorted(glob.glob(os.path.join(PARQUET_DIR, "**", "*.parquet"), recursive=True))
print("[parquet] files:", len(PARQUET_FILES))

def build_pair_to_locations(parquet_files, valid_pairs_set, drug_col="drug", cell_col="cell_line_id"):
    """Index which parquet row groups contain cells of each wanted pair.

    Only the two key columns of every row group are read.

    Args:
        parquet_files: iterable of parquet file paths.
        valid_pairs_set: set of ``(drug, cell_line)`` string tuples to index.
        drug_col: name of the drug column.
        cell_col: name of the cell-line column.

    Returns:
        defaultdict[list]: ``(drug, cell_line) -> [(file_path, row_group), ...]``
        for every wanted pair that occurs at least once.
    """
    out = defaultdict(list)
    for f in tqdm(parquet_files, desc="Index parquet row-groups", dynamic_ncols=True):
        pf = pq.ParquetFile(f)
        for rg in range(pf.num_row_groups):
            tbl = pf.read_row_group(rg, columns=[drug_col, cell_col])
            # De-duplicate in pandas before touching Python objects: a row
            # group holds many cells but only a handful of distinct pairs, so
            # this avoids building one tuple (and two str casts) per row.
            uniq = tbl.to_pandas()[[drug_col, cell_col]].drop_duplicates()
            pairs_here = set(zip(uniq[drug_col].astype(str), uniq[cell_col].astype(str)))
            for p in pairs_here & valid_pairs_set:
                out[p].append((f, rg))
    return out

# Index only the pairs that can actually be sampled (train or validation).
valid_pairs_set = set(train_pairs) | set(val_pairs)
# (drug, cell_line) -> list of (parquet file, row-group index) holding such cells
pair_to_locations = build_pair_to_locations(PARQUET_FILES, valid_pairs_set, drug_col=DRUG_COL, cell_col=CELL_COL)
print("[parquet] indexed pairs:", len(pair_to_locations))


# =========================================================
# 11) Dataset (unique drug batch + Variant A + stable ordering + organ_id)
# =========================================================
class TahoeFPParquetDataset_UniqueDrug(torch.utils.data.IterableDataset):
    """Iterable dataset that yields ready-made batches with one cell per drug.

    Every yielded item is a dict of tensors describing ``batch_size`` cells:

    * ``input_ids`` / ``values`` / ``attention_mask`` of shape
      ``(batch_size, 2 + max_seq_len)``, laid out as ``[CLS][ORGAN] genes...``
      where ``values`` holds the scaled expression delta of each gene token;
    * ``y_targets`` ``(batch_size, m_tgt)`` multi-hot drug targets;
    * ``smiles_emb`` ``(batch_size, smiles_dim)``;
    * ``drug_id`` and ``organ_id`` ``(batch_size,)``.

    Because batches are built here, wrap the dataset in a ``DataLoader`` with
    ``batch_size=None``.

    Behaviour notes:

    * Pairs are sampled (optionally weighted) until ``batch_size`` distinct
      drugs are found; drugs without targets, without a finite non-zero SMILES
      vector, or without an indexed parquet location are never used.
    * If fewer than ``batch_size`` distinct usable drugs exist, iteration
      raises ``RuntimeError`` instead of spinning forever.
    * ``steps`` is the number of batches per pass over the loader; with
      several DataLoader workers it is divided between them, so one epoch
      yields ``steps`` batches in total rather than ``num_workers * steps``.
    * All randomness (pair choice, row-group choice, row choice) comes from a
      generator seeded with ``seed`` (+ worker id), so a pass is reproducible.
    * When some chosen pairs cannot be turned into a row, the batch is padded
      by repeating the rows that were built (cyclically).  Such a batch
      contains repeated drugs.
    """

    def __init__(
        self,
        pair_to_locations,
        pairs,
        baseline_global,
        baseline_by_cellline,
        drug_to_target_vec_target_only,   # (M_TGT,)
        drug2id,
        drug_to_smiles_np,
        cellline2organid,
        unk_organ_id,
        n_genes_full,
        steps,
        max_seq_len=256,
        batch_size=128,
        control_drug="DMSO_TF",
        pad_id=0,
        cls_id=1,
        organtok_id=2,
        pair_weights=None,
        seed=42,
        drug_col="drug",
        cell_col="cell_line_id",
        genes_col="genes",
        expr_col="expressions",
        cap_per_pair_in_rg=None,
        max_tries_per_pair=20,
        invalid_global_gene_tids=(1, 2),
        subset_token_ids_np=None,
        old_tid_to_vocab_lut=None,
        m_tgt: int = 0,
        drop_first_gene_token: bool = True,
        use_log1p_expr: bool = True,
        use_asinh_delta: bool = False,
        delta_clip_abs: float = 5.0,
        stable_sort_selected_by_gene_id: bool = True,
        max_open_files: int = 64,
    ):
        """Store the lookup tables and sampling configuration.

        Args:
            pair_to_locations: ``(drug, cell_line) -> [(file, row_group), ...]``.
            pairs: ``(drug, cell_line)`` pairs this dataset may sample.
            baseline_global: control mean per global gene token id.
            baseline_by_cellline: optional per-cell-line control means; the
                global baseline is used for cell lines missing here.
            drug_to_target_vec_target_only: drug -> multi-hot vector of
                length ``m_tgt``.
            drug2id / drug_to_smiles_np / cellline2organid: lookup tables.
            unk_organ_id: organ id used for unmapped cell lines.
            n_genes_full: size of the global gene-token space.
            steps: batches per pass (shared between DataLoader workers).
            max_seq_len: maximum number of gene tokens per cell.
            batch_size: rows per yielded batch.
            pair_weights: optional sampling weight per entry of ``pairs``.
            cap_per_pair_in_rg: stored but not used by this class.
            max_tries_per_pair: row-group reads attempted per chosen pair.
            invalid_global_gene_tids: global token ids dropped from every row.
            subset_token_ids_np: global token ids allowed as model input.
            old_tid_to_vocab_lut: global token id -> vocabulary id (-1 = none).
            m_tgt: length of the target vector; must be > 0.
            max_open_files: upper bound on cached open parquet handles.
        """
        super().__init__()
        self.pair_to_locations = pair_to_locations
        self.pairs = list(pairs)
        self.baseline_global = np.asarray(baseline_global, dtype=np.float32)
        self.baseline_by_cellline = baseline_by_cellline or {}
        self.drug_to_target_vec_target_only = drug_to_target_vec_target_only
        self.drug2id = drug2id
        self.drug_to_smiles_np = drug_to_smiles_np
        self.cellline2organid = cellline2organid or {}
        self.unk_organ_id = int(unk_organ_id)

        self.n_genes_full = int(n_genes_full)
        self.steps = int(steps)
        self.max_seq_len = int(max_seq_len)
        self.batch_size = int(batch_size)

        self.control_drug = str(control_drug)
        self.pad_id = int(pad_id)
        self.cls_id = int(cls_id)
        self.organtok_id = int(organtok_id)

        self.drug_col = drug_col
        self.cell_col = cell_col
        self.genes_col = genes_col
        self.expr_col = expr_col

        self.cap_per_pair_in_rg = cap_per_pair_in_rg
        self.max_tries_per_pair = int(max_tries_per_pair)
        self.seed = int(seed)

        # The SMILES dimension is taken from an arbitrary stored vector.
        any_vec = next(iter(self.drug_to_smiles_np.values()))
        self.smiles_dim = int(any_vec.shape[-1])

        self.invalid_global_gene_tids = np.asarray(list(set(int(x) for x in invalid_global_gene_tids)), dtype=np.int64)

        self.m_tgt = int(m_tgt); assert self.m_tgt > 0
        self.drop_first_gene_token = bool(drop_first_gene_token)

        self.subset_token_ids_np = subset_token_ids_np
        self.old_tid_to_vocab_lut = old_tid_to_vocab_lut

        self.use_log1p_expr = bool(use_log1p_expr)
        self.use_asinh_delta = bool(use_asinh_delta)
        self.delta_clip_abs = float(delta_clip_abs)
        self.stable_sort_selected_by_gene_id = bool(stable_sort_selected_by_gene_id)

        if pair_weights is None:
            self.pair_weights = None
        else:
            w = np.asarray(pair_weights, dtype=np.float64)
            assert len(w) == len(self.pairs)
            w = np.clip(w, 0.0, None)
            w = w / (w.sum() + 1e-12)
            self.pair_weights = w

        self.max_open_files = max(1, int(max_open_files))
        # path -> open ParquetFile, kept in least-recently-used-first order
        self._pf_cache = {}

    def _get_pf(self, file_path):
        """Return an open ``ParquetFile`` for ``file_path`` (bounded LRU cache).

        The cache is capped at ``max_open_files`` so that sampling across a
        large number of files cannot accumulate open handles without limit.
        """
        cache = self._pf_cache
        pf = cache.pop(file_path, None)
        if pf is None:
            pf = pq.ParquetFile(file_path)
            while len(cache) >= self.max_open_files:
                # dicts keep insertion order: the first key is the oldest one
                cache.pop(next(iter(cache)))
        cache[file_path] = pf  # (re-)insert as most recently used
        return pf

    def _read_row_group_df(self, file_path, rg_id, columns):
        """Read one row group (restricted to ``columns``) as a DataFrame."""
        pf = self._get_pf(file_path)
        return pf.read_row_group(rg_id, columns=columns).to_pandas()

    def _scale_delta(self, delta: np.ndarray) -> np.ndarray:
        """Clip the delta to ``±delta_clip_abs`` and optionally apply asinh."""
        if self.delta_clip_abs and self.delta_clip_abs > 0:
            delta = np.clip(delta, -self.delta_clip_abs, self.delta_clip_abs)
        if self.use_asinh_delta:
            delta = np.arcsinh(delta)
        return delta.astype(np.float32, copy=False)

    def _prepare_sparse_sorted_drop0(self, genes, expr):
        """Clean one cell's sparse (gene id, expression) lists.

        Truncates both lists to a common length, optionally drops the first
        ("service") element, applies log1p to the expressions if configured,
        removes invalid / out-of-range gene ids and returns both arrays
        sorted by gene id.  Empty arrays are returned when nothing is left.
        """
        if genes is None or expr is None:
            return np.asarray([], dtype=np.int64), np.asarray([], dtype=np.float32)

        idx = np.asarray(genes, dtype=np.int64)
        val = np.asarray(expr, dtype=np.float32)
        n = min(idx.size, val.size)
        idx = idx[:n]
        val = val[:n]

        if self.drop_first_gene_token and n >= 1:
            idx = idx[1:]
            val = val[1:]

        if idx.size == 0:
            return idx, val

        # Variant A: log1p before delta
        if self.use_log1p_expr:
            val = np.log1p(np.clip(val, a_min=0.0, a_max=None))

        if self.invalid_global_gene_tids.size > 0:
            keep = ~np.isin(idx, self.invalid_global_gene_tids, assume_unique=False)
            if not keep.all():
                idx = idx[keep]
                val = val[keep]
                if idx.size == 0:
                    return idx, val

        in_range = (idx >= 0) & (idx < self.n_genes_full)
        idx = idx[in_range]
        val = val[in_range]
        if idx.size == 0:
            return idx, val

        order = np.argsort(idx)
        return idx[order], val[order]

    def _fill_one_row(self, row_genes, row_expr, baseline_vec, input_ids_row, values_row, attn_row):
        """Write one cell into the given batch row; return True on success.

        The delta to ``baseline_vec`` is computed per expressed gene, genes
        outside the input subset are discarded, and the ``max_seq_len`` genes
        with the largest absolute delta are kept (ordered by gene id when
        ``stable_sort_selected_by_gene_id`` is set).  The row buffers are only
        modified when the function returns True; positions 0 and 1 (CLS and
        ORGAN) are never touched.
        """
        idx_sorted, val_sorted = self._prepare_sparse_sorted_drop0(row_genes, row_expr)
        if idx_sorted.size == 0:
            return False

        delta = self._scale_delta(val_sorted - baseline_vec[idx_sorted])

        mask_sub = np.isin(idx_sorted, self.subset_token_ids_np, assume_unique=False)
        if not mask_sub.any():
            return False
        idx_sub = idx_sorted[mask_sub]
        del_sub = delta[mask_sub]

        # top-k by |delta| (argpartition: the first k are the k largest, unordered)
        k = min(self.max_seq_len, idx_sub.size)
        top = np.argpartition(-np.abs(del_sub), k - 1)[:k]
        sel_tid = idx_sub[top]
        sel_del = del_sub[top]

        # stable order: sort by gene_id (deterministic)
        if self.stable_sort_selected_by_gene_id:
            by_gene = np.argsort(sel_tid)
            sel_tid = sel_tid[by_gene]
            sel_del = sel_del[by_gene]

        sel_vid = self.old_tid_to_vocab_lut[sel_tid]
        has_vocab_id = sel_vid != -1
        if not has_vocab_id.any():
            return False
        sel_vid = sel_vid[has_vocab_id]
        sel_del = sel_del[has_vocab_id]

        n_tok = min(self.max_seq_len, sel_vid.size)
        if n_tok <= 0:
            return False

        # layout: [CLS][ORGAN] + genes...
        input_ids_row[2:2 + n_tok] = sel_vid[:n_tok]
        values_row[2:2 + n_tok] = sel_del[:n_tok]
        attn_row[2:2 + n_tok] = 1
        return True

    def _is_usable_pair(self, drug_name, cell_line):
        """True if the pair may be sampled (targets, SMILES and data exist)."""
        if drug_name == self.control_drug:
            return False
        y_vec = self.drug_to_target_vec_target_only.get(drug_name, None)
        if y_vec is None or float(y_vec.sum()) <= 0.0:
            return False
        sm = self.drug_to_smiles_np.get(drug_name, None)
        if sm is None or (not np.isfinite(sm).all()) or (np.abs(sm).sum() == 0.0):
            return False
        return bool(self.pair_to_locations.get((drug_name, cell_line), []))

    def _usable_pair_mask(self):
        """Boolean mask over ``self.pairs`` of pairs that can ever be drawn."""
        mask = np.zeros(len(self.pairs), dtype=bool)
        for i, (drug_name, cell_line) in enumerate(self.pairs):
            mask[i] = self._is_usable_pair(str(drug_name), str(cell_line))
        if self.pair_weights is not None:
            # pairs with zero probability can never be sampled
            mask &= self.pair_weights > 0
        return mask

    def _sample_unique_drug_pairs(self, rng, usable, draw):
        """Draw pairs until ``batch_size`` distinct drugs are found.

        Gives up after 80 rounds of ``draw`` candidates and returns what it
        has (possibly fewer than ``batch_size`` pairs).
        """
        n_pairs = len(self.pairs)
        chosen = []
        seen_drugs = set()
        for _ in range(80):
            if self.pair_weights is None:
                cand_idx = rng.integers(0, n_pairs, size=draw)
            else:
                cand_idx = rng.choice(n_pairs, size=draw, replace=True, p=self.pair_weights)
            for ii in cand_idx:
                ii = int(ii)
                if not usable[ii]:
                    continue
                drug_name, cell_line = self.pairs[ii]
                drug_name = str(drug_name)
                if drug_name in seen_drugs:
                    continue
                chosen.append((drug_name, str(cell_line)))
                seen_drugs.add(drug_name)
                if len(chosen) >= self.batch_size:
                    return chosen
        return chosen

    def _build_batch(self, chosen, rng):
        """Read one random cell per chosen pair; return the batch dict or None.

        ``None`` is returned when not a single row could be built.
        """
        cols = [self.drug_col, self.cell_col, self.genes_col, self.expr_col]
        seq_len = 2 + self.max_seq_len
        bsz = self.batch_size

        input_ids = np.full((bsz, seq_len), self.pad_id, dtype=np.int64)
        values = np.zeros((bsz, seq_len), dtype=np.float32)
        attn = np.zeros((bsz, seq_len), dtype=np.int64)

        input_ids[:, 0] = self.cls_id
        input_ids[:, 1] = self.organtok_id
        attn[:, 0:2] = 1

        y_batch = np.zeros((bsz, self.m_tgt), dtype=np.float32)
        smiles_batch = np.zeros((bsz, self.smiles_dim), dtype=np.float32)
        drug_id_batch = np.zeros((bsz,), dtype=np.int64)
        organ_id_batch = np.zeros((bsz,), dtype=np.int64)

        row_ptr = 0
        for drug_name, cell_line in chosen:
            locs = self.pair_to_locations.get((drug_name, cell_line), [])
            if not locs:
                continue

            baseline = self.baseline_by_cellline.get(cell_line, self.baseline_global)

            for _ in range(self.max_tries_per_pair):
                fpath, rg_id = locs[rng.integers(0, len(locs))]
                df = self._read_row_group_df(fpath, rg_id, columns=cols)

                is_pair = (
                    (df[self.drug_col].astype(str) == drug_name)
                    & (df[self.cell_col].astype(str) == cell_line)
                ).to_numpy()
                hits = np.flatnonzero(is_pair)
                if hits.size == 0:
                    continue

                # Row choice uses the dataset's own generator (not the global
                # NumPy state) so that a pass is reproducible from `seed`.
                pick = int(hits[rng.integers(0, hits.size)])

                built = self._fill_one_row(
                    df[self.genes_col].iloc[pick],
                    df[self.expr_col].iloc[pick],
                    baseline,
                    input_ids[row_ptr], values[row_ptr], attn[row_ptr],
                )
                if built:
                    y_batch[row_ptr] = self.drug_to_target_vec_target_only[drug_name]
                    smiles_batch[row_ptr] = self.drug_to_smiles_np[drug_name]
                    drug_id_batch[row_ptr] = int(self.drug2id.get(drug_name, 0))
                    organ_id_batch[row_ptr] = int(self.cellline2organid.get(cell_line, self.unk_organ_id))
                    row_ptr += 1
                    break  # next pair only after a row was actually written

            if row_ptr >= bsz:
                break

        if row_ptr == 0:
            return None

        if row_ptr < bsz:
            # Pad by cycling through the rows that WERE built.  Copying the
            # first `fill` rows would pull in never-filled rows (zero targets,
            # zero SMILES) whenever fill > row_ptr.
            src = np.arange(bsz - row_ptr) % row_ptr
            for arr in (input_ids, values, attn, y_batch, smiles_batch, drug_id_batch, organ_id_batch):
                arr[row_ptr:] = arr[src]

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "values": torch.tensor(values, dtype=torch.float32),
            "attention_mask": torch.tensor(attn, dtype=torch.long),
            "y_targets": torch.tensor(y_batch, dtype=torch.float32),
            "smiles_emb": torch.tensor(smiles_batch, dtype=torch.float32),
            "drug_id": torch.tensor(drug_id_batch, dtype=torch.long),
            "organ_id": torch.tensor(organ_id_batch, dtype=torch.long),
        }

    def __iter__(self):
        """Yield this worker's share of ``steps`` batches.

        Raises:
            RuntimeError: if fewer than ``batch_size`` distinct drugs are
                usable, because a drug-unique batch could never be filled.
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            base_seed = self.seed
            steps = self.steps
        else:
            base_seed = self.seed + worker_info.id
            # Every worker holds a full copy of the dataset; split the epoch
            # length so the loader yields `steps` batches in total.
            steps, extra = divmod(self.steps, worker_info.num_workers)
            if worker_info.id < extra:
                steps += 1
        if steps <= 0:
            return

        rng = np.random.default_rng(base_seed)

        usable = self._usable_pair_mask()
        usable_drugs = {str(self.pairs[i][0]) for i in np.flatnonzero(usable)}
        if len(usable_drugs) < self.batch_size:
            raise RuntimeError(
                f"Only {len(usable_drugs)} distinct usable drugs for batch_size={self.batch_size}; "
                "a unique-drug batch can never be filled."
            )

        n_pairs = len(self.pairs)
        draw = min(max(self.batch_size * 8, 512), max(n_pairs, 1))

        produced = 0
        while produced < steps:
            chosen = self._sample_unique_drug_pairs(rng, usable, draw)
            if len(chosen) < self.batch_size:
                continue  # unlucky draw; sample again

            batch = self._build_batch(chosen, rng)
            if batch is None:
                continue

            yield batch
            produced += 1


# Arguments shared by the training and validation datasets; only the pair
# list, epoch length, sampling weights and seed differ between the two.
_shared_dataset_kwargs = dict(
    pair_to_locations=pair_to_locations,
    baseline_global=baseline_global,
    baseline_by_cellline=baseline_by_cl,
    drug_to_target_vec_target_only=drug_to_target_vec_tgt,
    drug2id=drug2id,
    drug_to_smiles_np=drug_to_smiles_np,
    cellline2organid=cellline2organid,
    unk_organ_id=UNK_ORGAN_ID,
    n_genes_full=N_GENES,
    max_seq_len=MAX_SEQ_LEN,
    batch_size=BATCH_SIZE,
    control_drug=CONTROL_DRUG,
    pad_id=PAD_ID,
    cls_id=CLS_ID,
    organtok_id=ORGAN_TOK_ID,
    subset_token_ids_np=subset_token_ids_np,
    old_tid_to_vocab_lut=old_tid_to_vocab_lut,
    m_tgt=M_TGT,
    drop_first_gene_token=DROP_FIRST_GENE_TOKEN,
    use_log1p_expr=USE_LOG1P_EXPR,
    use_asinh_delta=USE_ASINH_DELTA,
    delta_clip_abs=DELTA_CLIP_ABS,
    stable_sort_selected_by_gene_id=True,
)

train_ds = TahoeFPParquetDataset_UniqueDrug(
    pairs=train_pairs,
    steps=STEPS_PER_EPOCH,
    pair_weights=w_train,
    seed=SEED,
    **_shared_dataset_kwargs,
)

val_ds = TahoeFPParquetDataset_UniqueDrug(
    pairs=val_pairs,
    steps=VAL_STEPS,
    pair_weights=w_val,
    seed=SEED + 123,
    **_shared_dataset_kwargs,
)

# The datasets emit complete batches, hence batch_size=None (no auto-collation).
_train_uses_workers = NUM_WORKERS > 0
train_loader = DataLoader(
    train_ds,
    batch_size=None,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    persistent_workers=_train_uses_workers,
    prefetch_factor=2 if _train_uses_workers else None,
)
# Validation is read in the main process.
val_loader = DataLoader(
    val_ds,
    batch_size=None,
    num_workers=0,
    pin_memory=True,
)


# =========================================================
# 12) Model: ORGAN embedding + NO positional emb + learnable CLIP temperature
# =========================================================
class FPEncoderWithOrgan(nn.Module):
    """Transformer encoder over (token id, scalar value) pairs plus an organ embedding.

    Each position is embedded as ``token_emb(id) + value_proj(value)``. An
    organ embedding is added to the position ``organ_pos`` and, optionally, a
    learned positional embedding is added to every position. The hidden state
    of position 0 is returned as the sequence summary.

    Args:
        vocab_size: Number of rows of the token embedding table.
        d_model: Embedding / hidden width.
        n_heads: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        pad_id: Token id used as ``padding_idx`` of the token embedding.
        max_len: Size of the positional table (only used if ``use_pos_emb``).
        num_organs: Number of rows of the organ embedding table.
        organ_pos: Sequence index that receives the organ embedding.
        use_pos_emb: Whether to create and add positional embeddings.
    """

    def __init__(self, vocab_size, d_model, n_heads, num_layers, pad_id,
                 max_len: int, num_organs: int, organ_pos: int = 1, use_pos_emb: bool = False):
        super().__init__()
        # Sub-modules are created in a fixed order so that seeded parameter
        # initialisation and state-dict layout stay the same.
        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
        self.value_proj = nn.Linear(1, d_model)

        self.use_pos_emb = bool(use_pos_emb)
        if self.use_pos_emb:
            self.pos_emb = nn.Embedding(max_len, d_model)

        self.organ_emb = nn.Embedding(num_organs, d_model)
        self.organ_pos = int(organ_pos)

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=n_heads,
                dim_feedforward=4 * d_model,
                dropout=0.1,
                batch_first=True,
            ),
            num_layers=num_layers,
        )

    def forward(self, input_ids, values, attention_mask, organ_id):
        """Encode a batch and return the hidden state at position 0.

        Args:
            input_ids: Integer token ids, 2-D (batch, length).
            values: Scalar value per token; a trailing axis is added before
                the linear projection.
            attention_mask: Positions equal to 0 are treated as padding.
            organ_id: Organ index per sample, or ``None`` to skip the organ
                embedding.
        """
        batch, seq_len = input_ids.shape
        dev = input_ids.device

        tok_part = self.token_emb(input_ids)
        val_part = self.value_proj(values.unsqueeze(-1))
        x = tok_part + val_part

        if self.use_pos_emb:
            positions = torch.arange(seq_len, device=dev).unsqueeze(0).expand(batch, seq_len)
            x = x + self.pos_emb(positions)

        if organ_id is not None:
            slot = self.organ_pos
            organ_vec = self.organ_emb(organ_id.to(dev)).to(x.dtype)
            x[:, slot, :] = x[:, slot, :] + organ_vec

        # True marks positions the attention layers must ignore.
        hidden = self.encoder(x, src_key_padding_mask=(attention_mask == 0))
        return hidden[:, 0, :]


class FPModelTied_OrganCLIP(nn.Module):
    """Encoder with a target-vector head, a SMILES head and a learnable CLIP scale.

    The target head (``proj``) maps the encoder summary into the same space as
    the token embeddings, whose non-special rows are exposed through
    :meth:`gene_emb_subset`. The SMILES head maps the same summary to
    ``smiles_dim``. ``logit_scale`` stores ``log(1 / tau)``.

    Args:
        vocab_size, d_model, n_heads, num_layers, pad_id, max_len, num_organs:
            Forwarded to ``FPEncoderWithOrgan``.
        smiles_dim: Output width of the SMILES head.
        n_special: Number of leading special-token rows in the embedding table.
        tau_init: Initial CLIP temperature.
    """

    def __init__(self, vocab_size, d_model, n_heads, num_layers, pad_id, smiles_dim,
                 max_len: int, num_organs: int, n_special: int, tau_init: float = 0.10):
        super().__init__()
        self.n_special = int(n_special)

        # Organ embedding goes to position 1 ([CLS][ORGAN]); positional
        # embeddings are switched off.
        encoder_kwargs = dict(
            vocab_size=vocab_size,
            d_model=d_model,
            n_heads=n_heads,
            num_layers=num_layers,
            pad_id=pad_id,
            max_len=max_len,
            num_organs=num_organs,
            organ_pos=1,
            use_pos_emb=False,
        )
        self.encoder = FPEncoderWithOrgan(**encoder_kwargs)
        self.proj = nn.Linear(d_model, d_model)

        hidden_width = 4 * d_model
        self.smiles_head = nn.Sequential(
            nn.Linear(d_model, hidden_width),
            nn.BatchNorm1d(hidden_width),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_width, smiles_dim),
        )

        # CLIP logit scale: logits = (z1 @ z2.T) * exp(logit_scale)
        initial_log_scale = np.log(1.0 / float(tau_init))
        self.logit_scale = nn.Parameter(torch.ones([]) * initial_log_scale)

    def gene_emb_subset(self):
        """Return the token-embedding rows after the ``n_special`` special tokens."""
        table = self.encoder.token_emb.weight
        return table[self.n_special:, :]

    def get_tau(self):
        """Return the temperature ``1 / exp(logit_scale)`` clamped to [0.01, 0.5]."""
        inv_scale = 1.0 / self.logit_scale.exp()
        return inv_scale.clamp(0.01, 0.5)

    def forward(self, input_ids, values, attention_mask, organ_id, return_smiles=False):
        """Run the encoder and both heads.

        Returns the target vector, or ``(target vector, SMILES vector)`` when
        ``return_smiles`` is true. Both heads are always evaluated.
        """
        summary = self.encoder(input_ids, values, attention_mask, organ_id=organ_id)
        target_vec = self.proj(summary)
        smiles_vec = self.smiles_head(summary)
        if not return_smiles:
            return target_vec
        return target_vec, smiles_vec


# Transformer size.
D_MODEL = 256   # hidden / embedding width
N_HEADS = 8     # attention heads per layer
N_LAYERS = 4    # stacked encoder layers

# Build the model and move it to the training device.
model = FPModelTied_OrganCLIP(
    vocab_size=VOCAB_SIZE,
    d_model=D_MODEL,
    n_heads=N_HEADS,
    num_layers=N_LAYERS,
    pad_id=PAD_ID,
    smiles_dim=SMILES_DIM,
    # Presumably +2 accounts for the [CLS] and [ORGAN] slots; max_len only
    # sizes the positional table, which the model disables (use_pos_emb=False).
    max_len=(2 + MAX_SEQ_LEN),
    num_organs=NUM_ORGANS,
    n_special=N_SPECIAL,
    tau_init=TAU_INIT,
).to(device)


# =========================================================
# 13) Load pretrained gene embeddings into subset token emb
# =========================================================
def load_pretrained_subset_into_token_emb(token_emb: nn.Embedding, npy_path: str, device):
    """Copy pretrained gene vectors into the subset rows of ``token_emb``.

    The ``.npy`` file is indexed by the original token id. For each entry of
    the module-level ``subset_token_ids`` (position ``sid``), the row
    ``N_SPECIAL + sid`` of ``token_emb`` is overwritten with the pretrained
    row of that original id; ids outside the file's row range are left as-is.

    Does nothing (besides printing a warning) when ``npy_path`` is ``None`` or
    does not exist.

    Raises:
        ValueError: if the file's embedding width differs from ``token_emb``.
    """
    path_missing = npy_path is None or not os.path.exists(npy_path)
    if path_missing:
        print("⚠️ PRETRAINED_GENE_NPY not found. Skip loading.")
        return

    pretrained = torch.tensor(np.load(npy_path), dtype=torch.float32, device=device)  # (N_GENES, d)
    emb_width = token_emb.weight.shape[1]
    if pretrained.shape[1] != emb_width:
        raise ValueError(f"d mismatch: npy={pretrained.shape[1]} vs token_emb={emb_width}")

    n_rows = pretrained.shape[0]
    n_copied = 0
    with torch.no_grad():
        # Vocabulary ids of subset genes start right after the special tokens.
        for vocab_id, old_tid in enumerate(subset_token_ids, start=N_SPECIAL):
            if not (0 <= old_tid < n_rows):
                continue
            token_emb.weight[vocab_id].copy_(pretrained[int(old_tid)])
            n_copied += 1
    print(f"✅ token_emb loaded: {n_copied}/{len(subset_token_ids)}")

# Initialise the subset rows of the model's token embedding from the optional
# pretrained file (the call is a no-op with a warning if the file is missing).
load_pretrained_subset_into_token_emb(model.encoder.token_emb, PRETRAINED_GENE_NPY, device=device)


# =========================================================
# 14) target_sub_ids
# =========================================================
# Translate each target gene's original token id into its index inside the
# subset, via the `old_tid_to_subid` mapping. The result is used to index the
# output of `model.gene_emb_subset()`.
target_sub_ids = torch.tensor([old_tid_to_subid[tid] for tid in target_token_ids],
                              dtype=torch.long, device=device)
print("[target_sub_ids]:", tuple(target_sub_ids.shape))


# =========================================================
# 15) pos_weight (TARGET-ONLY)
# =========================================================
# Uniform positive-class weight (all ones), one entry per target gene; it is
# passed to the BCE term of the target loss.
pos_weight = torch.ones((M_TGT,), dtype=torch.float32, device=device)


# =========================================================
# 16) Losses (Targets + SMILES CLIP)
# =========================================================
def info_nce_ranking_loss_multi_pos(
    v_pred: torch.Tensor,
    gene_emb: torch.Tensor,
    y_targets: torch.Tensor,
    num_neg: int = 256,
    num_pos: int = 8,
    tau: float = 0.1,
):
    """InfoNCE-style ranking loss with several positives per sample.

    For every sample, up to ``num_pos`` positive genes (label > 0.5) and up to
    ``num_neg`` negative genes (label < 0.5) are drawn at random. Cosine
    similarities between the prediction and these candidates are divided by
    ``tau``; a cross-entropy is computed once per positive over the full
    candidate list (the other positives stay in the denominator) and averaged.
    Samples lacking positives or negatives are skipped.

    Args:
        v_pred: Predicted vectors, one row per sample.
        gene_emb: Candidate gene embeddings, one row per label column.
        y_targets: Label matrix aligned with ``v_pred`` rows / ``gene_emb`` rows.
        num_neg: Maximum number of sampled negatives per sample.
        num_pos: Maximum number of sampled positives per sample (falsy = no cap).
        tau: Softmax temperature.

    Returns:
        Scalar tensor: mean loss over contributing samples, or 0.0 if none.
    """
    dev = v_pred.device
    n_rows, _ = v_pred.shape

    queries = F.normalize(v_pred, dim=1)
    unit_genes = F.normalize(gene_emb, dim=1)

    per_sample = []
    for row in range(n_rows):
        labels = y_targets[row]

        positives = torch.nonzero(labels > 0.5, as_tuple=True)[0]
        if positives.numel() == 0:
            continue
        negatives = torch.nonzero(labels < 0.5, as_tuple=True)[0]
        if negatives.numel() == 0:
            continue

        # Sub-sample positives first, then negatives (keeps the RNG stream order).
        if num_pos and positives.numel() > num_pos:
            keep = torch.randperm(positives.numel(), device=dev)[:num_pos]
            positives = positives[keep]
        if negatives.numel() > num_neg:
            keep = torch.randperm(negatives.numel(), device=dev)[:num_neg]
            negatives = negatives[keep]

        n_positives = positives.numel()
        candidates = torch.cat([unit_genes[positives], unit_genes[negatives]], dim=0)

        sims = (queries[row].unsqueeze(0) @ candidates.T).squeeze(0) / tau

        # One classification problem per positive: positive j sits at column j.
        tiled = sims.unsqueeze(0).repeat(n_positives, 1)
        gold = torch.arange(n_positives, device=dev, dtype=torch.long)
        per_sample.append(F.cross_entropy(tiled, gold))

    if not per_sample:
        return torch.tensor(0.0, device=dev)
    return torch.stack(per_sample).mean()


def bce_with_neg_sampling_cosine(
    pred_vec: torch.Tensor,
    y_targets: torch.Tensor,
    gene_emb: torch.Tensor,
    pos_weight_full: torch.Tensor,
    num_neg: int = 2048,
    pos_cap: int | None = None,
    tau_bce: float = 0.10,
):
    device_ = pred_vec.device
    B, _ = pred_vec.shape
    losses = []

    gene_emb = F.normalize(gene_emb, dim=1)

    for i in range(B):
        yi = y_targets[i]
        pos_idx = (yi > 0.5).nonzero(as_tuple=True)[0]
        if pos_idx.numel() == 0:
            continue

        if (pos_cap is not None) and (pos_idx.numel() > pos_cap):
            pos_idx = pos_idx[torch.randperm(pos_idx.numel(), device=device_)[:pos_cap]]

        neg_idx_all = (yi < 0.5).nonzero(as_tuple=True)[0]
        if neg_idx_all.numel() == 0:
            continue

        k = min(int(num_neg), neg_idx_all.numel())
        neg_idx = neg_idx_all[torch.randperm(neg_idx_all.numel(), device=device_)[:k]]

        idx = torch.cat([pos_idx, neg_idx], dim=0)

        v = F.normalize(pred_vec[i], dim=0)
        logits = (v @ gene_emb[idx].T) / tau_bce

        y_sub = yi[idx]
        pw_sub = pos_weight_full[idx]

        losses.append(F.binary_cross_entropy_with_logits(logits, y_sub, pos_weight=pw_sub, reduction="mean"))

    if len(losses) == 0:
        return torch.tensor(0.0, device=device_)
    return torch.stack(losses).mean()


def combined_target_loss_neg_sampling_tied(
    pred_vec: torch.Tensor,
    y_targets: torch.Tensor,
    gene_emb: torch.Tensor,
    pos_weight: torch.Tensor,
    lambda_cos: float = 1.0,
    lambda_bce: float = 0.1,
    lambda_rank: float = 0.5,
    bce_num_neg: int = 2048,
    bce_pos_cap: int | None = None,
    rank_num_neg: int = 256,
    rank_num_pos: int = 8,
    tau_rank: float = 0.1,
    tau_bce: float = 0.10,
):
    """Weighted sum of the three target losses (cosine, sampled BCE, ranking).

    * cosine: ``1 - cos(pred, centroid)`` where the centroid is the normalised
      mean of the unit gene embeddings of a sample's targets; only samples
      with a positive label sum take part.
    * BCE: ``bce_with_neg_sampling_cosine``.
    * ranking: ``info_nce_ranking_loss_multi_pos``.

    Returns:
        ``(loss, loss_cos, loss_bce, loss_rank)``; the three components are
        detached and meant for logging only.
    """
    unit_genes = F.normalize(gene_emb, dim=1)
    unit_pred = F.normalize(pred_vec, dim=1)

    # --- cosine term against the target centroid ---
    target_count = y_targets.sum(dim=1, keepdim=True)
    target_sum = y_targets @ unit_genes
    has_target = (target_count > 0).squeeze(1)

    if not has_target.any():
        loss_cos = torch.tensor(0.0, device=pred_vec.device)
    else:
        centroid = target_sum[has_target] / (target_count[has_target] + 1e-6)
        centroid = F.normalize(centroid, dim=1)
        cos_sim = (unit_pred[has_target] * centroid).sum(dim=1)
        loss_cos = 1.0 - cos_sim.mean()

    # --- sampled BCE term (evaluated before the ranking term) ---
    loss_bce = bce_with_neg_sampling_cosine(
        pred_vec=pred_vec,
        y_targets=y_targets,
        gene_emb=gene_emb,
        pos_weight_full=pos_weight,
        num_neg=bce_num_neg,
        pos_cap=bce_pos_cap,
        tau_bce=tau_bce,
    )

    # --- InfoNCE ranking term ---
    loss_rank = info_nce_ranking_loss_multi_pos(
        v_pred=pred_vec,
        gene_emb=gene_emb,
        y_targets=y_targets,
        num_neg=rank_num_neg,
        num_pos=rank_num_pos,
        tau=tau_rank,
    )

    total = lambda_cos * loss_cos + lambda_bce * loss_bce + lambda_rank * loss_rank
    components = (loss_cos.detach(), loss_bce.detach(), loss_rank.detach())
    return (total, *components)


# --- SMILES CLIP loss (+ optional cosine align) ---
def clip_loss(z_pred: torch.Tensor, z_true: torch.Tensor, tau: torch.Tensor):
    """Symmetric CLIP loss between predicted and true embeddings.

    Row ``i`` of ``z_pred`` is matched with row ``i`` of ``z_true``; every
    other row in the batch acts as a negative. Cosine similarities are divided
    by ``tau`` and the cross-entropy is averaged over both directions.
    """
    pred_unit = F.normalize(z_pred, dim=1)
    true_unit = F.normalize(z_true, dim=1)
    sim = (pred_unit @ true_unit.T) / tau
    diagonal = torch.arange(z_pred.size(0), device=z_pred.device, dtype=torch.long)
    pred_to_true = F.cross_entropy(sim, diagonal)
    true_to_pred = F.cross_entropy(sim.T, diagonal)
    return 0.5 * (pred_to_true + true_to_pred)

def smiles_align_loss_cosine(z_pred: torch.Tensor, z_true: torch.Tensor):
    """Return ``1 - mean cosine similarity`` between matching rows of the inputs."""
    pred_unit = F.normalize(z_pred, dim=1)
    true_unit = F.normalize(z_true, dim=1)
    row_cos = (pred_unit * true_unit).sum(dim=1)
    return 1.0 - row_cos.mean()


# =========================================================
# 17) Eval
# =========================================================
def compute_recall_precision_at_k(scores: torch.Tensor, y_true: torch.Tensor, k: int = 20):
    """Mean recall@k and precision@k over the rows that have positive labels.

    Args:
        scores: Score matrix (rows = samples, columns = candidates).
        y_true: Label matrix aligned with ``scores``.
        k: Cut-off; capped at the number of candidate columns.

    Returns:
        ``(recall, precision)`` as Python floats. Rows whose label sum is zero
        are ignored; ``(0.0, 0.0)`` is returned if every row is ignored.
    """
    _, n_candidates = scores.shape
    kk = min(k, n_candidates)
    top_idx = torch.topk(scores, k=kk, dim=1).indices

    # Per-row label mass inside the top-k and in total, fetched in one go.
    hits_per_row = y_true.gather(1, top_idx).sum(dim=1).tolist()
    positives_per_row = y_true.sum(dim=1).tolist()

    precision_denom = max(kk, 1)
    recalls = []
    precisions = []
    for hits, n_positive in zip(hits_per_row, positives_per_row):
        if n_positive == 0:
            continue
        recalls.append(hits / max(n_positive, 1e-6))
        precisions.append(hits / precision_denom)

    if not recalls:
        return 0.0, 0.0
    mean_recall = float(sum(recalls) / len(recalls))
    mean_precision = float(sum(precisions) / len(precisions))
    return mean_recall, mean_precision

@torch.no_grad()
def smiles_retrieval_hitk(z_pred: torch.Tensor, drug_id: torch.Tensor, smiles_bank_t: torch.Tensor, k_list=(1,5,10)):
    """Retrieval metrics of predicted SMILES vectors against an embedding bank.

    Predictions and bank rows are L2-normalised (in float32) and compared by
    dot product. ``drug_id`` holds, per prediction, the bank row of the true
    drug.

    Returns:
        Dict with ``"Hit@k"`` for each ``k`` in ``k_list`` (fraction of rows
        whose true drug is among the top-k bank rows) and ``"TrueCos"`` (mean
        cosine between each prediction and its true bank row).
    """
    queries = F.normalize(z_pred.float(), dim=1)
    bank = F.normalize(smiles_bank_t.float(), dim=1)
    sims = queries @ bank.T

    gold = drug_id.view(-1, 1)
    bank_size = sims.size(1)

    metrics = {}
    for k in k_list:
        top_idx = torch.topk(sims, k=min(k, bank_size), dim=1).indices
        hit_rate = (top_idx == gold).any(dim=1).float().mean().item()
        metrics[f"Hit@{k}"] = float(hit_rate)

    true_cos = (queries * bank[drug_id]).sum(dim=1).mean().item()
    metrics["TrueCos"] = float(true_cos)
    return metrics

@torch.no_grad()
def evaluate_fp(model, loader, device, target_sub_ids, smiles_bank_t, k_list=(5,10), hitk=(1,5,10)):
    """Evaluate target ranking and SMILES retrieval over one pass of ``loader``.

    Puts the model in eval mode (it is not switched back here).

    Target metrics: cosine scores between the predicted vector and the target
    gene embeddings are ranked; Recall@k / Precision@k are averaged with equal
    weight per batch. SMILES metrics (Hit@k, TrueCos), the CLIP loss and the
    temperature are averaged weighted by batch size.

    Returns:
        Dict with ``Recall@k`` / ``Precision@k`` for ``k_list``,
        ``SMILES_Hit@k`` for ``hitk``, ``SMILES_TrueCos``, ``SMILES_CLIP`` and
        ``tau``.
    """
    model.eval()

    target_gene_emb = model.gene_emb_subset()[target_sub_ids].to(device)
    target_gene_unit = F.normalize(target_gene_emb, dim=1)

    recall_acc = {k: 0.0 for k in k_list}
    precision_acc = {k: 0.0 for k in k_list}
    batches_seen = {k: 0 for k in k_list}

    retrieval_acc = {f"Hit@{k}": 0.0 for k in hitk}
    retrieval_acc["TrueCos"] = 0.0
    clip_acc = 0.0
    tau_acc = 0.0
    n_samples = 0

    def _to_device(tensor):
        return tensor.to(device, non_blocking=True)

    for batch in loader:
        input_ids = _to_device(batch["input_ids"])
        values = _to_device(batch["values"])
        attn = _to_device(batch["attention_mask"])
        y_targets = _to_device(batch["y_targets"])
        z_true = _to_device(batch["smiles_emb"])
        drug_id = _to_device(batch["drug_id"])
        organ_id = _to_device(batch["organ_id"])

        v_pred, z_pred = model(input_ids, values, attn, organ_id=organ_id, return_smiles=True)
        target_scores = F.normalize(v_pred, dim=1) @ target_gene_unit.T

        # Target ranking: one (recall, precision) pair per batch and cut-off.
        for k in k_list:
            recall_k, precision_k = compute_recall_precision_at_k(target_scores, y_targets, k=k)
            recall_acc[k] += recall_k
            precision_acc[k] += precision_k
            batches_seen[k] += 1

        batch_size = input_ids.size(0)

        # SMILES retrieval, weighted by the number of samples in the batch.
        retrieval = smiles_retrieval_hitk(z_pred, drug_id, smiles_bank_t, k_list=hitk)
        for k in hitk:
            retrieval_acc[f"Hit@{k}"] += retrieval[f"Hit@{k}"] * batch_size
        retrieval_acc["TrueCos"] += retrieval["TrueCos"] * batch_size

        tau = model.get_tau()
        clip_acc += float(clip_loss(z_pred, z_true, tau=tau).item()) * batch_size
        tau_acc += float(tau.item()) * batch_size

        n_samples += batch_size

    sample_denom = max(n_samples, 1)

    report = {}
    for k in k_list:
        batch_denom = max(batches_seen[k], 1)
        report[f"Recall@{k}"] = recall_acc[k] / batch_denom
        report[f"Precision@{k}"] = precision_acc[k] / batch_denom

    for k in hitk:
        report[f"SMILES_Hit@{k}"] = retrieval_acc[f"Hit@{k}"] / sample_denom
    report["SMILES_TrueCos"] = retrieval_acc["TrueCos"] / sample_denom
    report["SMILES_CLIP"] = clip_acc / sample_denom
    report["tau"] = tau_acc / sample_denom
    return report


# =========================================================
# 18) Train loop (OneCycleLR fixed + grad accumulation + AMP)
# =========================================================
def infinite_loader(loader):
    """Yield batches from ``loader`` forever, restarting it whenever it runs out.

    Args:
        loader: Any re-iterable object (e.g. a ``DataLoader``).

    Yields:
        The items of ``loader``, pass after pass.

    Raises:
        ValueError: if a complete pass over ``loader`` produces nothing.
            Without this guard the generator would loop forever without ever
            yielding, and the caller's ``next()`` would hang silently.
    """
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            raise ValueError("infinite_loader: loader yielded no batches")

# Mixed precision is used only when training on CUDA; the scaler is created
# either way but is disabled when USE_AMP is False.
USE_AMP = (device.type == "cuda")
scaler = torch.amp.GradScaler("cuda", enabled=USE_AMP)

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

# With gradient accumulation the optimizer (and therefore the scheduler) steps
# once every ACCUM_STEPS batches, so the schedule length is counted in real
# optimizer updates, not in batches.
updates_per_epoch = max(1, STEPS_PER_EPOCH // max(1, ACCUM_STEPS))
total_updates = EPOCHS * updates_per_epoch

# One-cycle schedule: start at LR / div_factor, warm up to LR over the first
# 5% of updates, then cosine-anneal down to (LR / div_factor) / final_div_factor.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=LR,
    total_steps=total_updates,
    pct_start=0.05,
    anneal_strategy="cos",
    div_factor=10.0,
    final_div_factor=100.0,
)

# SMILES embedding bank as a float32 tensor on the training device; it is
# handed to the retrieval metrics during training and evaluation.
smiles_bank_t = torch.tensor(smiles_bank_np, dtype=torch.float32, device=device)

def train_one_epoch_fixed_steps(
    model,
    train_loader,
    device,
    steps_per_epoch,
    optimizer,
    scheduler,
    scaler,
    target_sub_ids,
    pos_weight,
    smiles_bank_t,
    log_every=50,
    grad_clip=1.0,
    accum_steps=1,
):
    """Train for a fixed number of batches with gradient accumulation and optional AMP.

    Batches are drawn from an endless iterator over ``train_loader``. The loss
    is ``lambda_targets * target_loss + lambda_smiles * (clip + alpha_align *
    align)`` divided by ``accum_steps``; the optimizer and the scheduler step
    whenever the 1-based batch index is a multiple of ``accum_steps``.

    A batch whose loss is not finite is skipped: accumulated gradients are
    discarded, nothing is logged for it, and it is NOT counted in the running
    averages (previously it was counted in the denominator only, which biased
    every reported mean downwards).

    Reads these module-level names: ``USE_AMP``, ``lambda_cos``,
    ``lambda_bce``, ``lambda_rank``, ``bce_num_neg``, ``bce_pos_cap``,
    ``rank_num_neg``, ``rank_num_pos``, ``tau_rank``, ``tau_bce``,
    ``alpha_align``, ``lambda_targets``, ``lambda_smiles`` and
    ``updates_per_epoch`` (progress display only).

    Args:
        model: Model returning ``(v_pred, z_pred)`` with ``return_smiles=True``.
        train_loader: Re-iterable source of batch dicts.
        device: Device the batch tensors are moved to.
        steps_per_epoch: Number of batches to process.
        optimizer, scheduler: Stepped once per accumulated update.
        scaler: Gradient scaler, used only when ``USE_AMP`` is true.
        target_sub_ids: Indices of target genes inside ``model.gene_emb_subset()``.
        pos_weight: Positive-class weights for the BCE target term.
        smiles_bank_t: SMILES embedding bank for the retrieval metrics.
        log_every: Progress-bar refresh period, in batches.
        grad_clip: Max gradient norm; ``None`` or ``<= 0`` disables clipping.
        accum_steps: Batches per optimizer update.

    Returns:
        Dict of sample-weighted running means (``train_total``, ``train_tgt``,
        ``train_clip``, ``train_align``, ``train_hit5``, ``train_truecos``),
        plus ``rank_last``, ``tau``, ``lr``, ``updates`` and ``skipped`` (number
        of batches dropped because of a non-finite loss).
    """
    import contextlib

    def _forward_context():
        # A fresh context per step; a no-op when AMP is off, exactly like the
        # plain (non-autocast) forward pass.
        if USE_AMP:
            return torch.amp.autocast("cuda", enabled=True)
        return contextlib.nullcontext()

    model.train()
    it = infinite_loader(train_loader)

    run_total = 0.0
    run_tgt   = 0.0
    run_clip  = 0.0
    run_align = 0.0
    run_rank_last = 0.0

    run_hit5 = 0.0
    run_truecos = 0.0
    n = 0
    skipped = 0

    optimizer.zero_grad(set_to_none=True)
    update_count = 0

    pbar = tqdm(range(1, steps_per_epoch + 1), desc="Train", leave=True, dynamic_ncols=True)

    for step in pbar:
        batch = next(it)

        input_ids = batch["input_ids"].to(device, non_blocking=True)
        values    = batch["values"].to(device, non_blocking=True)
        attn      = batch["attention_mask"].to(device, non_blocking=True)
        y_targets = batch["y_targets"].to(device, non_blocking=True)
        z_true    = batch["smiles_emb"].to(device, non_blocking=True)
        drug_id   = batch["drug_id"].to(device, non_blocking=True)
        organ_id  = batch["organ_id"].to(device, non_blocking=True)

        bs = input_ids.size(0)

        # ---- forward pass + losses (single code path for AMP and non-AMP) ----
        with _forward_context():
            v_pred, z_pred = model(input_ids, values, attn, organ_id=organ_id, return_smiles=True)

            # --- Targets ---
            gene_emb = model.gene_emb_subset()[target_sub_ids]  # (M_TGT, d)
            loss_targets, _loss_cos_t, _loss_bce_t, loss_rank_t = combined_target_loss_neg_sampling_tied(
                pred_vec=v_pred,
                y_targets=y_targets,
                gene_emb=gene_emb,
                pos_weight=pos_weight,
                lambda_cos=lambda_cos,
                lambda_bce=lambda_bce,
                lambda_rank=lambda_rank,
                bce_num_neg=bce_num_neg,
                bce_pos_cap=bce_pos_cap,
                rank_num_neg=rank_num_neg,
                rank_num_pos=rank_num_pos,
                tau_rank=tau_rank,
                tau_bce=tau_bce,
            )

            # --- SMILES CLIP ---
            tau = model.get_tau()
            loss_c = clip_loss(z_pred, z_true, tau=tau)
            loss_a = smiles_align_loss_cosine(z_pred, z_true)
            loss_smiles = loss_c + alpha_align * loss_a

            loss_unscaled = lambda_targets * loss_targets + lambda_smiles * loss_smiles
            loss = loss_unscaled / float(accum_steps)

        # ---- skip numerically broken batches ----
        if not torch.isfinite(loss).all():
            optimizer.zero_grad(set_to_none=True)
            skipped += 1
            continue

        # Count the batch only now, so skipped batches do not dilute the means.
        n += bs

        # ---- backward ----
        if USE_AMP:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        # ---- optimizer / scheduler update on accumulation boundaries ----
        # The boundary is tied to the global batch index, so a window that
        # contained a skipped batch simply accumulates fewer gradients.
        if (step % accum_steps) == 0:
            if USE_AMP:
                # Gradients must be unscaled before clipping by norm.
                scaler.unscale_(optimizer)
            if grad_clip is not None and grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            if USE_AMP:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            scheduler.step()
            update_count += 1

        # ---- logging metrics (cheap) ----
        with torch.no_grad():
            m = smiles_retrieval_hitk(z_pred, drug_id, smiles_bank_t, k_list=(5,))
            run_hit5 += float(m["Hit@5"]) * bs
            run_truecos += float(m["TrueCos"]) * bs

        run_total += float(loss_unscaled.item()) * bs
        run_tgt   += float(loss_targets.item()) * bs
        run_clip  += float(loss_c.item()) * bs
        run_align += float(loss_a.item()) * bs
        run_rank_last = float(loss_rank_t.item())

        if step % log_every == 0:
            lr_now = optimizer.param_groups[0]["lr"]
            pbar.set_postfix({
                "lr": f"{lr_now:.2e}",
                "tau": f"{float(model.get_tau().item()):.3f}",
                "tot": f"{run_total/max(n,1):.4f}",
                "tgt": f"{run_tgt/max(n,1):.4f}",
                "clip": f"{run_clip/max(n,1):.4f}",
                "align": f"{run_align/max(n,1):.4f}",
                "rank(last)": f"{run_rank_last:.4f}",
                "Hit@5": f"{run_hit5/max(n,1):.3f}",
                "TrueCos": f"{run_truecos/max(n,1):.3f}",
                "upd": f"{update_count}/{updates_per_epoch}",
            })

    return {
        "train_total": run_total / max(n,1),
        "train_tgt": run_tgt / max(n,1),
        "train_clip": run_clip / max(n,1),
        "train_align": run_align / max(n,1),
        "train_hit5": run_hit5 / max(n,1),
        "train_truecos": run_truecos / max(n,1),
        "rank_last": run_rank_last,
        "tau": float(model.get_tau().item()),
        "lr": optimizer.param_groups[0]["lr"],
        "updates": update_count,
        "skipped": skipped,
    }


# =========================================================
# 19) TRAIN
# =========================================================
print(">>> TRAIN START: FP(TARGET-ONLY) + ORGAN + SMILES CLIP | Variant A(log1p->delta->clip) | NO pos_emb")

# Main driver: for every epoch, train on a fixed number of batches, print the
# training summary, then evaluate on the validation loader and print the
# resulting metrics. The same optimizer / scheduler / scaler objects are
# reused across epochs, so the one-cycle schedule spans the whole run.
for epoch in range(1, EPOCHS + 1):
    logs = train_one_epoch_fixed_steps(
        model=model,
        train_loader=train_loader,
        device=device,
        steps_per_epoch=STEPS_PER_EPOCH,
        optimizer=optimizer,
        scheduler=scheduler,
        scaler=scaler,
        target_sub_ids=target_sub_ids,
        pos_weight=pos_weight,
        smiles_bank_t=smiles_bank_t,
        log_every=50,
        grad_clip=MAX_GRAD_NORM,
        accum_steps=max(1, int(ACCUM_STEPS)),  # guard against 0 / non-int values
    )

    # One-line training summary built from the dict returned above.
    print(
        f"\n[Epoch {epoch}/{EPOCHS}] "
        f"lr={logs['lr']:.2e} | tau={logs['tau']:.3f} | "
        f"train_total={logs['train_total']:.4f} | "
        f"train_tgt={logs['train_tgt']:.4f} | "
        f"train_clip={logs['train_clip']:.4f} | "
        f"train_align={logs['train_align']:.4f} | "
        f"rank_last={logs['rank_last']:.4f} | "
        f"Hit@5={logs['train_hit5']:.3f} | TrueCos={logs['train_truecos']:.3f}"
    )

    # Validation after every epoch (target ranking + SMILES retrieval).
    valid = evaluate_fp(
        model=model,
        loader=val_loader,
        device=device,
        target_sub_ids=target_sub_ids,
        smiles_bank_t=smiles_bank_t,
        k_list=(5,10),
        hitk=(1,5,10),
    )
    print("✅ VALID:", valid)

print(">>> DONE")


[targets] drugs total=379, with>=1 target=264
[HVG] token_ids: 4000
[subset] HVG ∪ TARGETS size: 4184
[target-only] size: 278
[vocab] VOCAB_SIZE: 4188 | N_SPECIAL: 4
[targets] drugs with>=1 target vec: 264
[organ] NUM_ORGANS: 16 | mapped cell_lines: 102 | UNK_ORGAN_ID: 0
[SMILES] missing=0/379 | zero_vec=2
[SMILES] bank: (379, 768)
[baseline] global: (62713,) | by_cl: 50
[split] train pairs: 10505 | val pairs: 1168
[parquet] files: 3388


Index parquet row-groups: 100%|██████████| 3388/3388 [13:31<00:00,  4.18it/s]


[parquet] indexed pairs: 11673
✅ token_emb loaded: 4184/4184
[target_sub_ids]: (278,)
>>> TRAIN START: FP(TARGET-ONLY) + ORGAN + SMILES CLIP | Variant A(log1p->delta->clip) | NO pos_emb


Train: 100%|██████████| 7000/7000 [2:47:20<00:00,  1.43s/it, lr=1.00e-04, tau=0.106, tot=6.4425, tgt=6.1915, clip=4.8591, align=0.3241, rank(last)=5.1235, Hit@5=0.018, TrueCos=0.676, upd=1750/1750]  



[Epoch 1/20] lr=1.00e-04 | tau=0.106 | train_total=6.4425 | train_tgt=6.1915 | train_clip=4.8591 | train_align=0.3241 | rank_last=5.1235 | Hit@5=0.018 | TrueCos=0.676


  output = torch._nested_tensor_from_mask(


✅ VALID: {'Recall@5': 0.13453268006588323, 'Precision@5': 0.03774479166666671, 'Recall@10': 0.21400091495599294, 'Precision@10': 0.03110416666666666, 'SMILES_Hit@1': 0.011770833333333333, 'SMILES_Hit@5': 0.03763020833333333, 'SMILES_Hit@10': 0.055546875, 'SMILES_TrueCos': 0.761170000632604, 'SMILES_CLIP': 4.818466658592224, 'tau': 0.10604571551084518}


Train: 100%|██████████| 7000/7000 [3:02:26<00:00,  1.56s/it, lr=9.93e-05, tau=0.091, tot=6.0259, tgt=5.7825, clip=4.7293, align=0.2790, rank(last)=4.9658, Hit@5=0.062, TrueCos=0.721, upd=1750/1750]  



[Epoch 2/20] lr=9.93e-05 | tau=0.091 | train_total=6.0259 | train_tgt=5.7825 | train_clip=4.7293 | train_align=0.2790 | rank_last=4.9658 | Hit@5=0.062 | TrueCos=0.721
✅ VALID: {'Recall@5': 0.1812032657013125, 'Precision@5': 0.05100000000000001, 'Recall@10': 0.2784314728963165, 'Precision@10': 0.04054687500000003, 'SMILES_Hit@1': 0.029817708333333335, 'SMILES_Hit@5': 0.07591145833333333, 'SMILES_Hit@10': 0.11114583333333333, 'SMILES_TrueCos': 0.6834968763589859, 'SMILES_CLIP': 4.665763538678487, 'tau': 0.0910659059882164}


Train: 100%|██████████| 7000/7000 [3:06:25<00:00,  1.60s/it, lr=9.73e-05, tau=0.077, tot=5.7821, tgt=5.5438, clip=4.5923, align=0.3492, rank(last)=4.8268, Hit@5=0.104, TrueCos=0.651, upd=1750/1750]  



[Epoch 3/20] lr=9.73e-05 | tau=0.077 | train_total=5.7821 | train_tgt=5.5438 | train_clip=4.5923 | train_align=0.3492 | rank_last=4.8268 | Hit@5=0.104 | TrueCos=0.651
✅ VALID: {'Recall@5': 0.2047391628192409, 'Precision@5': 0.05968229166666662, 'Recall@10': 0.31234690584299996, 'Precision@10': 0.046921874999999995, 'SMILES_Hit@1': 0.051744791666666665, 'SMILES_Hit@5': 0.115703125, 'SMILES_Hit@10': 0.15895833333333334, 'SMILES_TrueCos': 0.620230129758517, 'SMILES_CLIP': 4.5414337539672855, 'tau': 0.07650987803936005}


Train: 100%|██████████| 7000/7000 [2:58:05<00:00,  1.53s/it, lr=9.40e-05, tau=0.066, tot=5.6183, tgt=5.3850, clip=4.4654, align=0.4042, rank(last)=4.6868, Hit@5=0.138, TrueCos=0.596, upd=1750/1750]  



[Epoch 4/20] lr=9.40e-05 | tau=0.066 | train_total=5.6183 | train_tgt=5.3850 | train_clip=4.4654 | train_align=0.4042 | rank_last=4.6868 | Hit@5=0.138 | TrueCos=0.596
✅ VALID: {'Recall@5': 0.22050381562881563, 'Precision@5': 0.06540104166666663, 'Recall@10': 0.3277380745701059, 'Precision@10': 0.050192708333333357, 'SMILES_Hit@1': 0.06661458333333334, 'SMILES_Hit@5': 0.14067708333333334, 'SMILES_Hit@10': 0.18861979166666668, 'SMILES_TrueCos': 0.5752571612596512, 'SMILES_CLIP': 4.431879811286926, 'tau': 0.06598414480686188}


Train: 100%|██████████| 7000/7000 [3:07:13<00:00,  1.60s/it, lr=8.95e-05, tau=0.057, tot=5.4982, tgt=5.2693, clip=4.3565, align=0.4428, rank(last)=4.6511, Hit@5=0.164, TrueCos=0.557, upd=1750/1750]  



[Epoch 5/20] lr=8.95e-05 | tau=0.057 | train_total=5.4982 | train_tgt=5.2693 | train_clip=4.3565 | train_align=0.4428 | rank_last=4.6511 | Hit@5=0.164 | TrueCos=0.557
✅ VALID: {'Recall@5': 0.23745490244709003, 'Precision@5': 0.07045312500000005, 'Recall@10': 0.3467131688479345, 'Precision@10': 0.05305468750000002, 'SMILES_Hit@1': 0.07565104166666667, 'SMILES_Hit@5': 0.16158854166666667, 'SMILES_Hit@10': 0.21513020833333332, 'SMILES_TrueCos': 0.5383901741107305, 'SMILES_CLIP': 4.329053503672282, 'tau': 0.05719948932528496}


Train: 100%|██████████| 7000/7000 [2:23:35<00:00,  1.23s/it, lr=8.39e-05, tau=0.050, tot=5.3934, tgt=5.1692, clip=4.2465, align=0.4753, rank(last)=4.5532, Hit@5=0.189, TrueCos=0.525, upd=1750/1750]  



[Epoch 6/20] lr=8.39e-05 | tau=0.050 | train_total=5.3934 | train_tgt=5.1692 | train_clip=4.2465 | train_align=0.4753 | rank_last=4.5532 | Hit@5=0.189 | TrueCos=0.525
✅ VALID: {'Recall@5': 0.24949122723341488, 'Precision@5': 0.07468229166666664, 'Recall@10': 0.36050634030321543, 'Precision@10': 0.055804687499999964, 'SMILES_Hit@1': 0.08911458333333333, 'SMILES_Hit@5': 0.19122395833333333, 'SMILES_Hit@10': 0.250546875, 'SMILES_TrueCos': 0.5103278501828512, 'SMILES_CLIP': 4.231426575183868, 'tau': 0.05004861205816269}


Train: 100%|██████████| 7000/7000 [2:14:16<00:00,  1.15s/it, lr=7.74e-05, tau=0.044, tot=5.3107, tgt=5.0908, clip=4.1492, align=0.4995, rank(last)=4.5271, Hit@5=0.210, TrueCos=0.501, upd=1750/1750]  



[Epoch 7/20] lr=7.74e-05 | tau=0.044 | train_total=5.3107 | train_tgt=5.0908 | train_clip=4.1492 | train_align=0.4995 | rank_last=4.5271 | Hit@5=0.210 | TrueCos=0.501
✅ VALID: {'Recall@5': 0.25910566350605413, 'Precision@5': 0.07683333333333335, 'Recall@10': 0.3705589896214895, 'Precision@10': 0.056846354166666654, 'SMILES_Hit@1': 0.09348958333333333, 'SMILES_Hit@5': 0.19731770833333334, 'SMILES_Hit@10': 0.2579427083333333, 'SMILES_TrueCos': 0.49005644301573437, 'SMILES_CLIP': 4.156260814666748, 'tau': 0.04421854019165039}


Train: 100%|██████████| 7000/7000 [3:03:08<00:00,  1.57s/it, lr=7.01e-05, tau=0.040, tot=5.2454, tgt=5.0293, clip=4.0621, align=0.5202, rank(last)=4.4614, Hit@5=0.227, TrueCos=0.480, upd=1750/1750]  



[Epoch 8/20] lr=7.01e-05 | tau=0.040 | train_total=5.2454 | train_tgt=5.0293 | train_clip=4.0621 | train_align=0.5202 | rank_last=4.4614 | Hit@5=0.227 | TrueCos=0.480
✅ VALID: {'Recall@5': 0.2628372523020961, 'Precision@5': 0.07845312499999996, 'Recall@10': 0.37603251551689054, 'Precision@10': 0.058479166666666686, 'SMILES_Hit@1': 0.10135416666666666, 'SMILES_Hit@5': 0.21856770833333333, 'SMILES_Hit@10': 0.28325520833333334, 'SMILES_TrueCos': 0.47121061543623605, 'SMILES_CLIP': 4.096646081606547, 'tau': 0.03951994702219963}


Train: 100%|██████████| 7000/7000 [3:37:14<00:00,  1.86s/it, lr=6.23e-05, tau=0.036, tot=5.1898, tgt=4.9772, clip=3.9836, align=0.5375, rank(last)=4.5038, Hit@5=0.242, TrueCos=0.463, upd=1750/1750]  



[Epoch 9/20] lr=6.23e-05 | tau=0.036 | train_total=5.1898 | train_tgt=4.9772 | train_clip=3.9836 | train_align=0.5375 | rank_last=4.5038 | Hit@5=0.242 | TrueCos=0.463
✅ VALID: {'Recall@5': 0.2748946131270349, 'Precision@5': 0.0824583333333334, 'Recall@10': 0.3857195520769744, 'Precision@10': 0.059971354166666636, 'SMILES_Hit@1': 0.11109375, 'SMILES_Hit@5': 0.23278645833333333, 'SMILES_Hit@10': 0.29932291666666666, 'SMILES_TrueCos': 0.4531615019838015, 'SMILES_CLIP': 4.0104965694745385, 'tau': 0.03574628755450249}


Train: 100%|██████████| 7000/7000 [3:58:35<00:00,  2.05s/it, lr=5.42e-05, tau=0.033, tot=5.1428, tgt=4.9334, clip=3.9123, align=0.5507, rank(last)=4.4189, Hit@5=0.256, TrueCos=0.449, upd=1750/1750]  



[Epoch 10/20] lr=5.42e-05 | tau=0.033 | train_total=5.1428 | train_tgt=4.9334 | train_clip=3.9123 | train_align=0.5507 | rank_last=4.4189 | Hit@5=0.256 | TrueCos=0.449
✅ VALID: {'Recall@5': 0.27848910081527267, 'Precision@5': 0.08344791666666675, 'Recall@10': 0.3921030132338726, 'Precision@10': 0.061320312499999995, 'SMILES_Hit@1': 0.11471354166666667, 'SMILES_Hit@5': 0.243046875, 'SMILES_Hit@10': 0.311875, 'SMILES_TrueCos': 0.4423800575733185, 'SMILES_CLIP': 3.955232696533203, 'tau': 0.032634906470775604}


Train: 100%|██████████| 7000/7000 [3:29:00<00:00,  1.79s/it, lr=4.59e-05, tau=0.030, tot=5.1065, tgt=4.8998, clip=3.8539, align=0.5616, rank(last)=4.4987, Hit@5=0.266, TrueCos=0.438, upd=1750/1750]  



[Epoch 11/20] lr=4.59e-05 | tau=0.030 | train_total=5.1065 | train_tgt=4.8998 | train_clip=3.8539 | train_align=0.5616 | rank_last=4.4987 | Hit@5=0.266 | TrueCos=0.438
✅ VALID: {'Recall@5': 0.2823159555288459, 'Precision@5': 0.08485937500000001, 'Recall@10': 0.3932350697624135, 'Precision@10': 0.06134374999999996, 'SMILES_Hit@1': 0.11958333333333333, 'SMILES_Hit@5': 0.24731770833333333, 'SMILES_Hit@10': 0.316015625, 'SMILES_TrueCos': 0.4323416962226232, 'SMILES_CLIP': 3.9188380829493203, 'tau': 0.03016114979982376}


Train: 100%|██████████| 7000/7000 [3:23:54<00:00,  1.75s/it, lr=3.78e-05, tau=0.028, tot=5.0765, tgt=4.8721, clip=3.8035, align=0.5707, rank(last)=4.3034, Hit@5=0.274, TrueCos=0.429, upd=1750/1750]  



[Epoch 12/20] lr=3.78e-05 | tau=0.028 | train_total=5.0765 | train_tgt=4.8721 | train_clip=3.8035 | train_align=0.5707 | rank_last=4.3034 | Hit@5=0.274 | TrueCos=0.429
✅ VALID: {'Recall@5': 0.28495215519434286, 'Precision@5': 0.08613541666666664, 'Recall@10': 0.4008907776251527, 'Precision@10': 0.06268229166666665, 'SMILES_Hit@1': 0.12125, 'SMILES_Hit@5': 0.2575, 'SMILES_Hit@10': 0.325234375, 'SMILES_TrueCos': 0.42219115207592645, 'SMILES_CLIP': 3.8668323413530987, 'tau': 0.02826393023133278}


Train: 100%|██████████| 7000/7000 [2:16:37<00:00,  1.17s/it, lr=3.00e-05, tau=0.027, tot=5.0510, tgt=4.8485, clip=3.7606, align=0.5776, rank(last)=4.3273, Hit@5=0.281, TrueCos=0.422, upd=1750/1750]  



[Epoch 13/20] lr=3.00e-05 | tau=0.027 | train_total=5.0510 | train_tgt=4.8485 | train_clip=3.7606 | train_align=0.5776 | rank_last=4.3273 | Hit@5=0.281 | TrueCos=0.422
✅ VALID: {'Recall@5': 0.28949838233236663, 'Precision@5': 0.08839062500000006, 'Recall@10': 0.40374855960012207, 'Precision@10': 0.0638854166666667, 'SMILES_Hit@1': 0.12619791666666666, 'SMILES_Hit@5': 0.2640104166666667, 'SMILES_Hit@10': 0.33171875, 'SMILES_TrueCos': 0.4184157766898473, 'SMILES_CLIP': 3.8417517272631327, 'tau': 0.026800617575645447}


Train: 100%|██████████| 7000/7000 [2:07:05<00:00,  1.09s/it, lr=2.27e-05, tau=0.026, tot=5.0314, tgt=4.8305, clip=3.7266, align=0.5827, rank(last)=4.2995, Hit@5=0.287, TrueCos=0.417, upd=1750/1750]  



[Epoch 14/20] lr=2.27e-05 | tau=0.026 | train_total=5.0314 | train_tgt=4.8305 | train_clip=3.7266 | train_align=0.5827 | rank_last=4.2995 | Hit@5=0.287 | TrueCos=0.417
✅ VALID: {'Recall@5': 0.2882152030550469, 'Precision@5': 0.0881875, 'Recall@10': 0.4061875874414938, 'Precision@10': 0.06410677083333335, 'SMILES_Hit@1': 0.12630208333333334, 'SMILES_Hit@5': 0.26114583333333335, 'SMILES_Hit@10': 0.33466145833333333, 'SMILES_TrueCos': 0.41290168126424154, 'SMILES_CLIP': 3.8126182357470193, 'tau': 0.025704631581902504}


Train: 100%|██████████| 7000/7000 [2:20:48<00:00,  1.21s/it, lr=1.62e-05, tau=0.025, tot=5.0146, tgt=4.8151, clip=3.6979, align=0.5870, rank(last)=4.3114, Hit@5=0.291, TrueCos=0.413, upd=1750/1750]  



[Epoch 15/20] lr=1.62e-05 | tau=0.025 | train_total=5.0146 | train_tgt=4.8151 | train_clip=3.6979 | train_align=0.5870 | rank_last=4.3114 | Hit@5=0.291 | TrueCos=0.413
✅ VALID: {'Recall@5': 0.2933345646685491, 'Precision@5': 0.08872916666666665, 'Recall@10': 0.40821422371031785, 'Precision@10': 0.06409375000000002, 'SMILES_Hit@1': 0.127265625, 'SMILES_Hit@5': 0.267734375, 'SMILES_Hit@10': 0.340703125, 'SMILES_TrueCos': 0.4091666708389918, 'SMILES_CLIP': 3.799344880580902, 'tau': 0.024924729019403458}


Train: 100%|██████████| 7000/7000 [2:27:37<00:00,  1.27s/it, lr=1.06e-05, tau=0.024, tot=5.0031, tgt=4.8044, clip=3.6800, align=0.5901, rank(last)=4.2728, Hit@5=0.295, TrueCos=0.410, upd=1750/1750]  



[Epoch 16/20] lr=1.06e-05 | tau=0.024 | train_total=5.0031 | train_tgt=4.8044 | train_clip=3.6800 | train_align=0.5901 | rank_last=4.2728 | Hit@5=0.295 | TrueCos=0.410
✅ VALID: {'Recall@5': 0.29402793040293046, 'Precision@5': 0.0888958333333333, 'Recall@10': 0.41049326382529516, 'Precision@10': 0.06436718749999999, 'SMILES_Hit@1': 0.12734375, 'SMILES_Hit@5': 0.2702864583333333, 'SMILES_Hit@10': 0.34197916666666667, 'SMILES_TrueCos': 0.406905122200648, 'SMILES_CLIP': 3.781051870981852, 'tau': 0.02439727634191513}


Train:  29%|██▉       | 2032/7000 [1:01:00<2:29:09,  1.80s/it, lr=9.22e-06, tau=0.024, tot=4.9950, tgt=4.7968, clip=3.6670, align=0.5911, rank(last)=4.3602, Hit@5=0.297, TrueCos=0.409, upd=500/1750]
Traceback (most recent call last):
  File "/data/aiffel/miniconda3/envs/babayakga/lib/python3.10/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
Traceback (most recent call last):
  File "/data/aiffel/miniconda3/envs/babayakga/lib/python3.10/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
Traceback (most recent call last):
  File "/data/aiffel/miniconda3/envs/babayakga/lib/python3.10/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/data/aiffel/miniconda3/envs/babayakga/lib/python3.10/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/data/aiffel/miniconda3/envs/babayakga/lib/python3.10/multiprocessing/util.py", line 300, in _run_finalizers
    finaliz

KeyboardInterrupt: 

## 저장

In [None]:
# Destination of the final checkpoint. The directory is created up front so
# that the save further down cannot fail on a missing folder.
CKPT_NAME = "fp_smalltargets.pt"
CKPT_DIR = "/data/aiffel/babayakga/checkpoints/f_p_final"
os.makedirs(CKPT_DIR, exist_ok=True)
ckpt_path = os.path.join(CKPT_DIR, CKPT_NAME)

# =========================
# 1) RNG states 
# =========================
def get_rng_state_bundle():
    """Collect the current state of every random number generator in use.

    Returns a dict with one entry per generator, in this order:
    ``python_random_state``, ``numpy_random_state``, ``torch_rng_state`` and
    ``torch_cuda_rng_state_all`` (``None`` when CUDA is not available).

    Collection is best-effort: if reading one generator raises, its entry is
    set to ``None`` and the ``repr`` of the exception is stored under
    ``"<key>_err"``, so a failure in one source never prevents the others from
    being captured.
    """

    def _cuda_states():
        # Only query CUDA generators when a CUDA runtime is actually present.
        if torch.cuda.is_available():
            return torch.cuda.get_rng_state_all()
        return None

    # (key, zero-arg reader) pairs; the lambdas defer attribute lookup so that
    # it happens inside the try block below.
    readers = (
        ("python_random_state", lambda: random.getstate()),
        ("numpy_random_state", lambda: np.random.get_state()),
        ("torch_rng_state", lambda: torch.get_rng_state()),
        ("torch_cuda_rng_state_all", _cuda_states),
    )

    bundle = {}
    for key, read_state in readers:
        try:
            bundle[key] = read_state()
        except Exception as exc:
            bundle[key] = None
            bundle[key + "_err"] = repr(exc)
    return bundle

# Snapshot the RNG streams first so the checkpoint reflects the state at save time.
rng_bundle = get_rng_state_bundle()

# =========================
# 2) EXTRA
# =========================
# Vocabulary layout, model dimensions and preprocessing switches, cast to plain
# Python types so the checkpoint does not depend on numpy/torch scalar objects.
EXTRA = {
    "SPECIAL_TOKENS": SPECIAL_TOKENS,
    "N_SPECIAL": int(N_SPECIAL),
    "VOCAB_SIZE": int(VOCAB_SIZE),
    "PAD_ID": int(PAD_ID),
    "CLS_ID": int(CLS_ID),
    "ORGAN_TOK_ID": int(ORGAN_TOK_ID),

    "subset_token_ids": [int(tok) for tok in subset_token_ids],
    "target_token_ids": [int(tok) for tok in target_token_ids],

    "UNK_ORGAN_ID": int(UNK_ORGAN_ID),
    "organ2id": organ2id,
    "NUM_ORGANS": int(NUM_ORGANS),

    "D_MODEL": int(D_MODEL),
    "N_HEADS": int(N_HEADS),
    "N_LAYERS": int(N_LAYERS),
    "MAX_SEQ_LEN": int(MAX_SEQ_LEN),
    "SMILES_DIM": int(SMILES_DIM),

    "SEED": int(SEED),
    "CONTROL_DRUG": str(CONTROL_DRUG),
    "HVG_K": int(HVG_K),

    "USE_LOG1P_EXPR": bool(USE_LOG1P_EXPR),
    "USE_ASINH_DELTA": bool(USE_ASINH_DELTA),
    "DELTA_CLIP_ABS": float(DELTA_CLIP_ABS),
    "DROP_FIRST_GENE_TOKEN": bool(DROP_FIRST_GENE_TOKEN),
}

# The token-id lookup table is optional: it may be absent, a numpy array or a
# torch tensor. Whatever is found is stored as an int64 CPU tensor.
_lut_source = globals().get("old_tid_to_vocab_lut")
if isinstance(_lut_source, np.ndarray):
    lut_tensor = torch.from_numpy(_lut_source.astype(np.int64, copy=False)).cpu()
elif torch.is_tensor(_lut_source):
    lut_tensor = _lut_source.detach().to(dtype=torch.int64, device="cpu")
else:
    lut_tensor = None


# =========================
# 3) PAYLOAD
# =========================
def _state_dict_or_none(global_name):
    """Return ``state_dict()`` of the named notebook global, or None if it is missing or None."""
    candidate = globals().get(global_name)
    return None if candidate is None else candidate.state_dict()


# Validation metrics are attached only if an evaluation has already been run.
_saved_metrics = {}
if "valid" in globals():
    _saved_metrics["valid"] = valid

payload = {
    "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
    "model_class": model.__class__.__name__,
    "model_state": model.state_dict(),

    "optimizer_state": _state_dict_or_none("optimizer"),
    "scheduler_state": _state_dict_or_none("scheduler"),
    "scaler_state": _state_dict_or_none("scaler"),

    "metrics": _saved_metrics,

    "extra": EXTRA,

    # ✅ LUT inside checkpoint
    "old_tid_to_vocab_lut": lut_tensor,

    # ✅ RNG states
    "rng_state": rng_bundle,
}

# =========================
# 4) ATOMIC SAVE
# =========================
# Write to a sibling temp file and rename it over the target, so an interrupted
# save never leaves a truncated checkpoint under the final name.
tmp_path = ckpt_path + ".tmp"
torch.save(payload, tmp_path)
os.replace(tmp_path, ckpt_path)

print(f"✅ Saved FULL checkpoint (model+opt+sched+scaler+RNG+LUT): {ckpt_path}")
print(f"   - LUT saved: {lut_tensor is not None}")
print(f"   - optimizer saved: {payload['optimizer_state'] is not None}")
print(f"   - scheduler saved: {payload['scheduler_state'] is not None}")
print(f"   - scaler saved: {payload['scaler_state'] is not None}")
print(f"   - cuda rng saved: {payload['rng_state'].get('torch_cuda_rng_state_all') is not None}")


✅ Saved FULL checkpoint (model+opt+sched+scaler+RNG+LUT): /data/aiffel/babayakga/checkpoints/f_p_final/fp_smalltargets.pt
   - LUT saved: True
   - optimizer saved: True
   - scheduler saved: True
   - scaler saved: True
   - cuda rng saved: True


In [6]:
from typing import Dict, Tuple, Iterable, List
# =========================================================
# Ranking metrics for targets: mAP@K, NDCG@K, Coverage ceiling
# =========================================================

@torch.no_grad()
def average_precision_at_k(scores_1d: torch.Tensor, y_true_1d: torch.Tensor, k: int) -> float:
    """
    Average precision at cutoff K for a single sample with binary labels.

    scores_1d: (M,) ranking scores, higher means ranked earlier
    y_true_1d: (M,) relevance labels in {0,1}

    Precision is accumulated only at the ranks that hold a relevant item and is
    normalised by min(#relevant, K). Returns 0.0 when the effective cutoff is
    not positive or the sample has no relevant item.
    """
    cutoff = min(int(k), int(scores_1d.numel()))
    if cutoff <= 0:
        return 0.0

    n_relevant = float(y_true_1d.sum().item())
    if n_relevant <= 0:
        return 0.0

    # Relevance of the `cutoff` best-scored items, best first.
    order = torch.topk(scores_1d, k=cutoff, dim=0).indices
    hits = y_true_1d[order].float()  # (cutoff,)

    # Running precision at every rank 1..cutoff.
    positions = torch.arange(1, cutoff + 1, device=scores_1d.device, dtype=torch.float32)
    running_precision = torch.cumsum(hits, dim=0) / positions

    normaliser = max(1.0, min(n_relevant, float(cutoff)))
    score = (running_precision * hits).sum() / normaliser
    return float(score.item())


@torch.no_grad()
def mean_average_precision_at_k(scores: torch.Tensor, y_true: torch.Tensor, k: int) -> float:
    """
    mAP@K over a batch (binary relevance), computed in one batched pass.

    scores: (B, M) ranking scores, higher means ranked earlier
    y_true: (B, M) relevance labels in {0,1}

    Per row, precision is accumulated at the ranks holding a relevant item and
    normalised by min(#relevant, K). Rows without any relevant item contribute
    an AP of 0.0 and are still counted in the mean. Returns 0.0 for an empty
    batch or a non-positive effective cutoff.

    The whole batch is ranked with a single topk/gather instead of a Python
    loop over rows, which avoids per-sample host synchronisations.
    """
    n_rows = scores.size(0)
    cutoff = min(int(k), int(scores.size(1)))
    if n_rows == 0 or cutoff <= 0:
        return 0.0

    # Relevance of the top-`cutoff` items of every row, best first.
    top_idx = torch.topk(scores, k=cutoff, dim=1).indices      # (B, cutoff)
    rel = torch.gather(y_true, 1, top_idx).float()              # (B, cutoff)

    ranks = torch.arange(1, cutoff + 1, device=scores.device, dtype=torch.float32)
    precision_at_rank = torch.cumsum(rel, dim=1) / ranks        # (B, cutoff)

    pos_total = y_true.sum(dim=1).float()                       # (B,)
    # Same normaliser as max(1, min(#relevant, cutoff)); cutoff >= 1 here.
    normaliser = pos_total.clamp(min=1.0, max=float(cutoff))

    ap = (precision_at_rank * rel).sum(dim=1) / normaliser
    ap = torch.where(pos_total > 0, ap, torch.zeros_like(ap))

    # Average in float64 to mirror summing Python floats.
    return float(ap.double().mean().item())


@torch.no_grad()
def ndcg_at_k(scores_1d: torch.Tensor, y_true_1d: torch.Tensor, k: int) -> float:
    """
    Normalised discounted cumulative gain at cutoff K for a single sample
    with binary labels.

    DCG  = sum over ranks i=1..K of rel_i / log2(i + 1)
    IDCG = DCG of the ideal ordering, i.e. min(#relevant, K) ones placed first.

    Returns 0.0 when the effective cutoff is not positive or the sample has no
    relevant item.
    """
    cutoff = min(int(k), int(scores_1d.numel()))
    if cutoff <= 0:
        return 0.0

    n_relevant = int(y_true_1d.sum().item())
    if n_relevant <= 0:
        return 0.0

    # Relevance of the `cutoff` best-scored items, best first.
    order = torch.topk(scores_1d, k=cutoff, dim=0).indices
    gains = y_true_1d[order].float()  # (cutoff,)

    # log2(i + 1) for ranks i = 1..cutoff
    discount = torch.log2(
        torch.arange(2, cutoff + 2, device=scores_1d.device, dtype=torch.float32)
    )
    actual = (gains / discount).sum()

    # Ideal ranking: every relevant item that fits occupies the leading ranks.
    n_ideal = min(n_relevant, cutoff)
    best_possible = (discount.new_ones((n_ideal,)) / discount[:n_ideal]).sum()

    ratio = actual / (best_possible + 1e-12)
    return float(ratio.item())


@torch.no_grad()
def mean_ndcg_at_k(scores: torch.Tensor, y_true: torch.Tensor, k: int) -> float:
    """
    Mean NDCG@K over a batch (binary relevance), computed in one batched pass.

    scores: (B, M) ranking scores, higher means ranked earlier
    y_true: (B, M) relevance labels in {0,1}

    Per row: DCG = sum_i rel_i / log2(i + 1) over the top-K items, divided by
    the DCG of the ideal ordering (min(#relevant, K) ones first). Rows without
    any relevant item contribute 0.0 and are still counted in the mean.
    Returns 0.0 for an empty batch or a non-positive effective cutoff.

    The whole batch is ranked with a single topk/gather instead of a Python
    loop over rows, which avoids per-sample host synchronisations.
    """
    n_rows = scores.size(0)
    cutoff = min(int(k), int(scores.size(1)))
    if n_rows == 0 or cutoff <= 0:
        return 0.0

    # Relevance of the top-`cutoff` items of every row, best first.
    top_idx = torch.topk(scores, k=cutoff, dim=1).indices      # (B, cutoff)
    rel = torch.gather(y_true, 1, top_idx).float()              # (B, cutoff)

    # log2(i + 1) for ranks i = 1..cutoff
    discount = torch.log2(
        torch.arange(2, cutoff + 2, device=scores.device, dtype=torch.float32)
    )
    dcg = (rel / discount).sum(dim=1)                           # (B,)

    # ideal_cum[j] is the ideal DCG when j + 1 relevant items fill the top ranks.
    ideal_cum = torch.cumsum(torch.ones_like(discount) / discount, dim=0)

    pos_total = y_true.sum(dim=1).long()                        # (B,)
    has_pos = pos_total > 0
    # Rows without positives are indexed at 0 only to keep the lookup valid;
    # their value is overwritten with 0.0 below.
    ideal_k = pos_total.clamp(min=1, max=cutoff)
    idcg = ideal_cum[ideal_k - 1]                               # (B,)

    ndcg = dcg / (idcg + 1e-12)
    ndcg = torch.where(has_pos, ndcg, torch.zeros_like(ndcg))

    # Average in float64 to mirror summing Python floats.
    return float(ndcg.double().mean().item())


@torch.no_grad()
def recall_precision_at_k(scores: torch.Tensor, y_true: torch.Tensor, k: int) -> Tuple[float, float]:
    """
    Batched Recall@K and Precision@K for binary labels.

    scores: (B, M) ranking scores, higher means ranked earlier
    y_true: (B, M) relevance labels in {0,1}

    Only rows that have at least one relevant item enter the averages.
    Returns (recall, precision); both are 0.0 when the effective cutoff is not
    positive or no row has a relevant item.
    """
    _, n_items = scores.shape
    cutoff = min(int(k), int(n_items))
    if cutoff <= 0:
        return 0.0, 0.0

    # Relevance of the top-`cutoff` items of every row.
    top_idx = torch.topk(scores, k=cutoff, dim=1).indices   # (B, cutoff)
    top_rel = torch.gather(y_true, 1, top_idx).float()       # (B, cutoff)
    n_relevant = y_true.sum(dim=1).float()                   # (B,)

    has_relevant = n_relevant > 0
    if not has_relevant.any():
        return 0.0, 0.0

    hits = top_rel.sum(dim=1)[has_relevant]
    mean_recall = (hits / (n_relevant[has_relevant] + 1e-12)).mean()
    mean_precision = (hits / float(cutoff)).mean()
    return float(mean_recall.item()), float(mean_precision.item())


@torch.no_grad()
def coverage_ceiling_recall_at_k(y_true: torch.Tensor, k: int) -> Dict[str, float]:
    """
    Describe how many targets each sample has and the best Recall@K any ranker
    could reach given those counts.

    Only rows with at least one target are considered. Returned keys:
    - avg_targets: mean number of targets per row
    - median_targets: median number of targets (torch.median, which picks the
      lower of the two middle values for an even number of rows)
    - frac_targets_le_k: share of rows with at most k targets
    - avg_recall_ceiling: mean of min(k, #targets) / #targets

    All values are 0.0 when no row has a target.
    """
    limit = int(k)
    counts = y_true.sum(dim=1).float()  # (B,)
    nonempty = counts > 0

    stats = {
        "avg_targets": 0.0,
        "median_targets": 0.0,
        "frac_targets_le_k": 0.0,
        "avg_recall_ceiling": 0.0,
    }
    if not nonempty.any():
        return stats

    per_row = counts[nonempty]
    # A top-`limit` list can recover at most `limit` of a row's targets.
    best_recall = per_row.clamp(max=float(limit)) / (per_row + 1e-12)
    fits_in_k = (per_row <= float(limit)).float()

    stats["avg_targets"] = float(per_row.mean().item())
    stats["median_targets"] = float(per_row.median().item())
    stats["frac_targets_le_k"] = float(fits_in_k.mean().item())
    stats["avg_recall_ceiling"] = float(best_recall.mean().item())
    return stats


# =========================================================
# Retrieval metrics for SMILES: MRR, median rank, plus Hit@K, TrueCos
# =========================================================

@torch.no_grad()
def smiles_retrieval_metrics(
    z_pred: torch.Tensor,
    drug_id: torch.Tensor,
    smiles_bank_t: torch.Tensor,
    k_list: Iterable[int] = (1, 5, 10),
) -> Dict[str, float]:
    """
    Retrieve the true drug from a bank of SMILES vectors by cosine similarity.

    z_pred: (B, D) predicted SMILES vector
    drug_id: (B,) index of the true drug in the bank (integer tensor)
    smiles_bank_t: (N, D) bank of SMILES vectors

    Returns a dict with:
    - "Hit@K" for every K in k_list, keyed by the REQUESTED K. When K exceeds
      the bank size the search is done over all N entries but the key keeps
      the requested K, so callers can always look the value up and two
      different requests can never collapse onto the same key.
    - "TrueCos": mean cosine between prediction and the true drug vector
    - "MRR", "median_rank", "mean_rank": statistics of the 1-based rank of the
      true drug
    """
    z = F.normalize(z_pred.float(), dim=1)
    b = F.normalize(smiles_bank_t.float(), dim=1)

    logits = z @ b.T  # (B, N) cosine similarities
    N = logits.size(1)
    true_col = drug_id.view(-1, 1)

    # Rank = 1 + number of bank entries scoring STRICTLY higher than the true
    # drug. Entries tied with the true drug are not counted, so ties resolve
    # in favour of the true drug (optimistic rank).
    true_scores = logits.gather(1, true_col)              # (B, 1)
    rank = (logits > true_scores).sum(dim=1) + 1          # (B,) in 1..N

    out = {}

    # Hit@K: is the true drug among the K highest-scoring bank entries?
    for k in k_list:
        requested_k = int(k)
        search_k = min(requested_k, N)  # topk cannot exceed the bank size
        topk = torch.topk(logits, k=search_k, dim=1).indices
        hit = (topk == true_col).any(dim=1).float().mean()
        out[f"Hit@{requested_k}"] = float(hit.item())

    # TrueCos: both sides are unit-normalised, so the dot product is the cosine.
    true_vec = b[drug_id]
    out["TrueCos"] = float((z * true_vec).sum(dim=1).mean().item())

    # MRR / rank statistics
    rank_f = rank.float()
    out["MRR"] = float((1.0 / rank_f).mean().item())
    out["median_rank"] = float(rank_f.median().item())
    out["mean_rank"] = float(rank_f.mean().item())

    return out


# =========================================================
# Drop-in replacement: evaluate_fp + new metrics
# =========================================================

@torch.no_grad()
def evaluate_fp_with_ranking_and_retrieval(
    model,
    loader,
    device,
    target_sub_ids,
    smiles_bank_t,
    k_list_targets: Iterable[int] = (5, 10, 20),
    k_list_smiles: Iterable[int] = (1, 5, 10),
) -> Dict[str, float]:
    """
    Run one evaluation pass over `loader` and report target-ranking and
    SMILES-retrieval metrics.

    Puts the model in eval mode (the previous mode is not restored).

    Reported keys:
      - Targets, for every K in k_list_targets: Recall@K, Precision@K, mAP@K,
        NDCG@K, Coverage@K_* statistics and Recall@K_over_ceiling.
        These are UNWEIGHTED means of per-batch values; in particular
        Coverage@K_median_targets is the mean of per-batch medians, not a
        median over the whole dataset.
      - SMILES: every key returned by smiles_retrieval_metrics (prefixed with
        "SMILES_"), plus SMILES_CLIP and tau. These are weighted by batch
        size; SMILES_median_rank is therefore a weighted mean of per-batch
        medians.
      - n_samples: number of evaluated samples (as float).

    k_list_targets / k_list_smiles may be any iterable, including one-shot
    generators; duplicate K values are evaluated once.
    """
    # Both K lists are walked several times (setup, every batch, final
    # aggregation), so materialise them once. dict.fromkeys drops duplicates
    # while keeping order, which prevents double-counting a repeated K.
    k_targets = tuple(dict.fromkeys(k_list_targets))
    k_smiles = tuple(dict.fromkeys(k_list_smiles))

    model.eval()

    # Embeddings of the candidate target genes, unit-normalised so that the
    # matrix product below yields cosine scores.
    gene_emb = model.gene_emb_subset()[target_sub_ids].to(device)  # (M_TGT, d)
    g_norm = F.normalize(gene_emb, dim=1)

    # accumulators
    target_sum = defaultdict(float)                      # sums of per-batch means
    smiles_sum = defaultdict(float)                      # sums weighted by batch size
    cov_sums = {k: defaultdict(float) for k in k_targets}
    n_batches = 0
    n_samples = 0

    clip_sum = 0.0
    tau_sum = 0.0

    for batch in loader:
        input_ids = batch["input_ids"].to(device, non_blocking=True)
        values    = batch["values"].to(device, non_blocking=True)
        attn      = batch["attention_mask"].to(device, non_blocking=True)
        y_targets = batch["y_targets"].to(device, non_blocking=True)  # (B, M_TGT)
        z_true    = batch["smiles_emb"].to(device, non_blocking=True)
        drug_id   = batch["drug_id"].to(device, non_blocking=True)
        organ_id  = batch["organ_id"].to(device, non_blocking=True)

        v_pred, z_pred = model(input_ids, values, attn, organ_id=organ_id, return_smiles=True)
        v_norm = F.normalize(v_pred, dim=1)
        scores = v_norm @ g_norm.T  # (B, M_TGT)

        B = scores.size(0)
        n_batches += 1
        n_samples += B

        # ---- Targets metrics ----
        for k in k_targets:
            r, p = recall_precision_at_k(scores, y_targets, k=k)
            target_sum[f"Recall@{k}"] += r
            target_sum[f"Precision@{k}"] += p
            target_sum[f"mAP@{k}"] += mean_average_precision_at_k(scores, y_targets, k=k)
            target_sum[f"NDCG@{k}"] += mean_ndcg_at_k(scores, y_targets, k=k)

            cov = coverage_ceiling_recall_at_k(y_targets, k=k)
            for stat_name, stat_val in cov.items():
                cov_sums[k][stat_name] += float(stat_val)

        # ---- SMILES retrieval metrics ----
        m = smiles_retrieval_metrics(z_pred, drug_id, smiles_bank_t, k_list=k_smiles)
        for key, val in m.items():
            smiles_sum[f"SMILES_{key}"] += float(val) * B  # weight by batch size

        # ---- CLIP loss / tau ----
        tau = model.get_tau()
        clip_sum += float(clip_loss(z_pred, z_true, tau=tau).item()) * B
        tau_sum  += float(tau.item()) * B

    out = {}
    batch_div = max(1, n_batches)
    sample_div = max(1, n_samples)

    # Average batch-averaged target metrics
    for k in k_targets:
        for metric in ("Recall", "Precision", "mAP", "NDCG"):
            out[f"{metric}@{k}"] = target_sum[f"{metric}@{k}"] / batch_div

        # Coverage ceiling stats (averaged over batches); skipped when the
        # loader yielded nothing.
        if n_batches > 0:
            out[f"Coverage@{k}_avg_targets"] = cov_sums[k]["avg_targets"] / n_batches
            out[f"Coverage@{k}_median_targets"] = cov_sums[k]["median_targets"] / n_batches
            out[f"Coverage@{k}_frac_targets_le_k"] = cov_sums[k]["frac_targets_le_k"] / n_batches
            out[f"Coverage@{k}_avg_recall_ceiling"] = cov_sums[k]["avg_recall_ceiling"] / n_batches
            # "normalized recall" = Recall@K / best achievable Recall@K
            ceil = out[f"Coverage@{k}_avg_recall_ceiling"]
            out[f"Recall@{k}_over_ceiling"] = out[f"Recall@{k}"] / max(ceil, 1e-9)

    # SMILES metrics averaged over samples (weighted by B above). Every key the
    # retrieval helper produced is emitted, so a custom k_list_smiles keeps
    # all of its Hit@K entries instead of only Hit@1/5/10.
    for sk, total in smiles_sum.items():
        out[sk] = total / sample_div

    out["SMILES_CLIP"] = clip_sum / sample_div
    out["tau"] = tau_sum / sample_div
    out["n_samples"] = float(n_samples)

    return out


# Extended validation pass: target ranking at K = 5/10/20 and SMILES retrieval
# at K = 1/5/10. The result is kept in `valid` and echoed for the notebook log.
valid = evaluate_fp_with_ranking_and_retrieval(
    model,
    val_loader,
    device,
    target_sub_ids,
    smiles_bank_t,
    k_list_targets=(5, 10, 20),
    k_list_smiles=(1, 5, 10),
)
print("✅ VALID:", valid)


✅ VALID: {'Recall@5': 0.2924789224068324, 'Precision@5': 0.08897396000723044, 'mAP@5': 0.19173023119471813, 'NDCG@5': 0.22381376816503082, 'Coverage@5_avg_targets': 2.095833333333333, 'Coverage@5_median_targets': 1.14, 'Coverage@5_frac_targets_le_k': 0.9517708333333333, 'Coverage@5_avg_recall_ceiling': 0.9826236150662104, 'Recall@5_over_ceiling': 0.2976510211258507, 'Recall@10': 0.4075360565384229, 'Precision@10': 0.06435156346609196, 'mAP@10': 0.21013024507517306, 'NDCG@10': 0.2646626206972481, 'Coverage@10_avg_targets': 2.095833333333333, 'Coverage@10_median_targets': 1.14, 'Coverage@10_frac_targets_le_k': 0.9886458333333333, 'Coverage@10_avg_recall_ceiling': 0.9976251810789108, 'Recall@10_over_ceiling': 0.4085061847553619, 'Recall@20': 0.5443653134504954, 'Precision@20': 0.0442695319528381, 'mAP@20': 0.2221575258799324, 'NDCG@20': 0.30403822676472675, 'Coverage@20_avg_targets': 2.095833333333333, 'Coverage@20_median_targets': 1.14, 'Coverage@20_frac_targets_le_k': 1.0, 'Coverage@20_