In [None]:
# =========================================================
# f_r FULL TRAINING SCRIPT (fixed)
# - ‚úÖ filter token_id < 3 everywhere (so no "genes" 1,2)
# - ‚úÖ NO [TARGET] token in input
# - ‚úÖ input prefix: [CLS][DRUG][CELL] + gene tokens
# - ‚úÖ load pretrained gene embeddings (gene_embeddings.npy)
# - ‚úÖ load pretrained cell-line embeddings (cell_embeddings.npy) + mapping (cell2id.csv)
# =========================================================

import os, glob, math, random
from collections import defaultdict
from datetime import datetime
from itertools import islice

import numpy as np
import pandas as pd
from torch.amp import autocast
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import IterableDataset, DataLoader
from torch.cuda.amp import GradScaler

import pyarrow.parquet as pq
from tqdm import tqdm

import scanpy as sc
from scipy import sparse
from sklearn.model_selection import train_test_split


# =========================================================
# 0) PATHS / CONFIG
# =========================================================
GENE_META_PATH = "/data/aiffel/data/Tahoe-100M/metadata/gene_metadata.parquet"
DRUG_META_PATH = "/data/aiffel/data/Tahoe-100M/metadata/drug_metadata.parquet"
COUNTS_CSV     = "/data/aiffel/babayakga/making_data/aiffel/babayakga/making_data/tahoe_counts_per_drug_cell_line.csv"

PARQUET_DIR    = "/data/aiffel/data/Tahoe-100M/data"
DMSO_H5AD      = "/data/aiffel/babayakga/outputs/dmso.h5ad"

# pretrained embeddings
PRETRAINED_GENE_NPY = "/data/aiffel/babayakga/pretraining/checkpoints_with_cell/gene_embeddings.npy"

CELL_CKPT_DIR  = "/data/aiffel/babayakga/pretraining/checkpoints_with_cell"
CELL2ID_CSV    = os.path.join(CELL_CKPT_DIR, "cell2id.csv")
CELL_EMB_NPY   = os.path.join(CELL_CKPT_DIR, "cell_embeddings.npy")

# smiles embedding for drugs (must match drug_metadata row order)
SMILES_EMB_PATH = "/data/aiffel/babayakga/f_p module/f_r/drug_smiles_emb_all.pt"

# training
CONTROL_DRUG = "DMSO_TF"
SEED = 42

MIN_GENE_TOKEN_ID = 3   # ‚úÖ IMPORTANT: exclude 0/1/2 (not real genes)

TOP_K      = 1000
MAX_LEN    = 512        # gene tokens length (not counting prefix)
BATCH_SIZE = 16

TOTAL_EPOCHS     = 8
WARMUP_EPOCHS    = 2
lambda_rank_main = 0.2

STEPS_PER_EPOCH  = 10000
VAL_STEPS        = 900

GRAD_CLIP = 1.0
LR = 3e-4

CKPT_DIR   = "/data/aiffel/babayakga/checkpoints/f_r_withcellline"
SAVE_EVERY = 2

device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


# =========================================================
# 1) VOCAB (special + ENSG) from gene_metadata
# =========================================================
def build_vocab_from_gene_metadata(gene_meta_path: str):
    """
    vocab-space: special tokens + all ensembl_id
    gene-space: Tahoe token_id (0..N_GENES-1)
    """
    SPECIAL_TOKENS = ["[PAD]", "[CLS]", "[DRUG]", "[TARGET]", "[CELL]", "[MASK]"]

    gene_md = pd.read_parquet(gene_meta_path).copy()
    gene_md["ensembl_id"] = gene_md["ensembl_id"].astype(str)
    gene_md["token_id"]   = gene_md["token_id"].astype(int)
    gene_md = gene_md.sort_values("token_id").reset_index(drop=True)

    N_GENES = int(gene_md["token_id"].max()) + 1

    local_token_to_id = {tok: i for i, tok in enumerate(SPECIAL_TOKENS)}
    for ensg in gene_md["ensembl_id"].tolist():
        if ensg not in local_token_to_id:
            local_token_to_id[ensg] = len(local_token_to_id)

    token_id_to_vocab_id = {
        int(tid): int(local_token_to_id[str(ensg)])
        for tid, ensg in zip(gene_md["token_id"].values, gene_md["ensembl_id"].values)
    }

    ensg_to_token_id = {
        str(ensg): int(tid)
        for ensg, tid in zip(gene_md["ensembl_id"].values, gene_md["token_id"].values)
    }

    PAD_ID  = local_token_to_id["[PAD]"]
    CLS_ID  = local_token_to_id["[CLS]"]
    DRUG_ID = local_token_to_id["[DRUG]"]
    CELL_ID = local_token_to_id["[CELL]"]

    return local_token_to_id, token_id_to_vocab_id, ensg_to_token_id, N_GENES, SPECIAL_TOKENS, PAD_ID, CLS_ID, DRUG_ID, CELL_ID


local_token_to_id, token_id_to_vocab_id, ensg_to_token_id, N_GENES, SPECIAL_TOKENS, PAD_ID, CLS_ID, DRUGTOK_ID, CELLTOK_ID = \
    build_vocab_from_gene_metadata(GENE_META_PATH)

VOCAB_SIZE = len(local_token_to_id)
print("VOCAB_SIZE(vocab-space):", VOCAB_SIZE)
print("N_GENES(gene-space):", N_GENES)


# =========================================================
# 2) LOAD cell2id mapping (MUST match pretrained cell embeddings)
# =========================================================
if not os.path.exists(CELL2ID_CSV):
    raise FileNotFoundError(f"cell2id.csv not found: {CELL2ID_CSV}")
if not os.path.exists(CELL_EMB_NPY):
    raise FileNotFoundError(f"cell_embeddings.npy not found: {CELL_EMB_NPY}")

cell2id_df = pd.read_csv(CELL2ID_CSV)
cell2id_df["cell_line_id"] = cell2id_df["cell_line_id"].astype(str)
cell_line2id = {c: int(i) for c, i in zip(cell2id_df["cell_line_id"], cell2id_df["cell_id"])}

NUM_CELL_LINE = len(cell_line2id)
W_cell = np.load(CELL_EMB_NPY)
print("NUM_CELL_LINE(from cell2id.csv):", NUM_CELL_LINE, "| cell_emb rows:", W_cell.shape[0])

assert W_cell.shape[0] == NUM_CELL_LINE, f"cell2id size {NUM_CELL_LINE} != cell_emb rows {W_cell.shape[0]}"


# =========================================================
# 3) SPLIT PAIRS (drug, cell_line) from COUNTS_CSV
# =========================================================
DRUG_COL, CELL_COL, N_COL = "drug", "cell_line_id", "n_cells"

counts = pd.read_csv(COUNTS_CSV)
counts[DRUG_COL] = counts[DRUG_COL].astype(str)
counts[CELL_COL] = counts[CELL_COL].astype(str)

MIN_TRAIN = 1000
MIN_EVAL  = 1000

train_pool = counts[counts[N_COL] >= MIN_TRAIN].copy()
eval_pool  = counts[counts[N_COL] >= MIN_EVAL].copy()

pairs_df = train_pool[[DRUG_COL, CELL_COL]].drop_duplicates()

train_df, val_df = train_test_split(
    pairs_df,
    test_size=0.1,
    random_state=SEED,
    stratify=pairs_df[DRUG_COL] if len(pairs_df) else None,
)

train_df = train_df[train_df[DRUG_COL] != CONTROL_DRUG]
val_df   = val_df[val_df[DRUG_COL]   != CONTROL_DRUG]

train_pairs = list(zip(train_df[DRUG_COL], train_df[CELL_COL]))
val_pairs   = list(zip(val_df[DRUG_COL],   val_df[CELL_COL]))

eval_pairs_df = eval_pool[[DRUG_COL, CELL_COL]].drop_duplicates()
eval_pairs_df = eval_pairs_df[eval_pairs_df[DRUG_COL] != CONTROL_DRUG]
eval_pairs = list(zip(eval_pairs_df[DRUG_COL], eval_pairs_df[CELL_COL]))

print("train pairs:", len(train_pairs))
print("val pairs:", len(val_pairs))
print(f"eval pairs (>={MIN_EVAL}):", len(eval_pairs))


# =========================================================
# 4) INDEX PARQUET row-groups for valid pairs
# =========================================================
PARQUET_FILES = sorted(glob.glob(os.path.join(PARQUET_DIR, "**", "*.parquet"), recursive=True))
print("parquet files found:", len(PARQUET_FILES))

PARQUET_DRUG_COL = "drug"
PARQUET_CELL_COL = "cell_line_id"

def build_pair_to_locations(parquet_files, valid_pairs_set, drug_col, cell_col):
    out = defaultdict(list)
    for f in tqdm(parquet_files, desc="Index parquet row-groups"):
        try:
            pf = pq.ParquetFile(f)
        except Exception:
            continue
        for rg in range(pf.num_row_groups):
            try:
                tbl = pf.read_row_group(rg, columns=[drug_col, cell_col])
                df = tbl.to_pandas()
            except Exception:
                continue

            pairs_here = set(zip(df[drug_col].astype(str), df[cell_col].astype(str)))
            inter = pairs_here.intersection(valid_pairs_set)
            for p in inter:
                out[p].append((f, rg))
    return dict(out)

valid_pairs_set = set(train_pairs) | set(val_pairs)
pair_to_locations = build_pair_to_locations(
    parquet_files=PARQUET_FILES,
    valid_pairs_set=valid_pairs_set,
    drug_col=PARQUET_DRUG_COL,
    cell_col=PARQUET_CELL_COL
)
print("indexed pairs:", len(pair_to_locations))


# =========================================================
# 5) DMSO baselines + topK variance genes (gene-space token_id)
#    ‚úÖ filter token_id < 3 here too
# =========================================================
def build_dmso_baselines_gene_space(dmso_h5ad_path: str, control_drug: str, N_GENES: int, ensg_to_token_id: dict,
                                    drug_col="drug", cell_col="cell_line_id", dtype=np.float32):
    adata = sc.read_h5ad(dmso_h5ad_path)
    obs = adata.obs
    X = adata.X.tocsr() if sparse.issparse(adata.X) else sparse.csr_matrix(adata.X)

    m = (obs[drug_col].astype(str).values == control_drug)
    idx_ctrl = np.where(m)[0]
    if idx_ctrl.size == 0:
        raise ValueError(f"No control rows: {control_drug}")

    ensgs = adata.var_names.astype(str).tolist()
    token_ids, cols = [], []
    for j, ensg in enumerate(ensgs):
        tid = ensg_to_token_id.get(ensg, None)
        if tid is None:
            continue
        tid = int(tid)
        if tid < MIN_GENE_TOKEN_ID:        # ‚úÖ filter 0/1/2
            continue
        token_ids.append(tid)
        cols.append(j)

    token_ids = np.asarray(token_ids, dtype=np.int64)
    cols      = np.asarray(cols, dtype=np.int64)

    Xc = X[idx_ctrl][:, cols]
    mean_global_sub = np.asarray(Xc.mean(axis=0)).ravel().astype(dtype)

    baseline_global = np.zeros(N_GENES, dtype=dtype)
    baseline_global[token_ids] = mean_global_sub

    baseline_by_cl = {}
    cls_all = obs[cell_col].astype(str).values
    for cl in np.unique(cls_all):
        cl_idx = np.where(m & (cls_all == cl))[0]
        if cl_idx.size == 0:
            continue
        Xcl = X[cl_idx][:, cols]
        mean_cl_sub = np.asarray(Xcl.mean(axis=0)).ravel().astype(dtype)
        v = np.zeros(N_GENES, dtype=dtype)
        v[token_ids] = mean_cl_sub
        baseline_by_cl[str(cl)] = v

    return baseline_global, baseline_by_cl


def topk_by_variance_gene_space(dmso_h5ad_path: str, control_drug: str, N_GENES: int, ensg_to_token_id: dict,
                               drug_col="drug", top_k=1000):
    adata = sc.read_h5ad(dmso_h5ad_path)
    obs = adata.obs
    X = adata.X.tocsr() if sparse.issparse(adata.X) else sparse.csr_matrix(adata.X)

    m = (obs[drug_col].astype(str).values == control_drug)
    idx = np.where(m)[0]
    if idx.size == 0:
        raise ValueError(f"No control rows: {control_drug}")

    ensgs = adata.var_names.astype(str).tolist()
    token_ids, cols = [], []
    for j, ensg in enumerate(ensgs):
        tid = ensg_to_token_id.get(ensg, None)
        if tid is None:
            continue
        tid = int(tid)
        if tid < MIN_GENE_TOKEN_ID:        # ‚úÖ filter 0/1/2
            continue
        token_ids.append(tid)
        cols.append(j)

    token_ids = np.asarray(token_ids, dtype=np.int64)
    cols      = np.asarray(cols, dtype=np.int64)

    Xc = X[idx][:, cols]
    ex  = np.asarray(Xc.mean(axis=0)).ravel()
    ex2 = np.asarray(Xc.power(2).mean(axis=0)).ravel()
    var = ex2 - ex**2

    top_local = np.argsort(-var)[:top_k]
    top_gene_token_ids = token_ids[top_local]
    return top_gene_token_ids.astype(np.int64)


baseline_global, baseline_by_cl = build_dmso_baselines_gene_space(
    dmso_h5ad_path=DMSO_H5AD,
    control_drug=CONTROL_DRUG,
    N_GENES=N_GENES,
    ensg_to_token_id=ensg_to_token_id,
)

sorted_gene_token_ids = topk_by_variance_gene_space(
    dmso_h5ad_path=DMSO_H5AD,
    control_drug=CONTROL_DRUG,
    N_GENES=N_GENES,
    ensg_to_token_id=ensg_to_token_id,
    top_k=TOP_K,
)

assert (sorted_gene_token_ids >= MIN_GENE_TOKEN_ID).all(), "TOP_K contains token_id < 3 !"

print("baseline_global:", baseline_global.shape, "baseline_by_cl:", len(baseline_by_cl))
print("sorted_gene_token_ids:", sorted_gene_token_ids.shape, sorted_gene_token_ids[:10])


# =========================================================
# 6) DRUG -> id + SMILES embeddings
# =========================================================
drug_meta_df = pd.read_parquet(DRUG_META_PATH).copy()
drug_meta_df["drug"] = drug_meta_df["drug"].astype(str)
drugs = drug_meta_df["drug"].tolist()
drug2id = {d: i for i, d in enumerate(drugs)}
print("num drugs:", len(drug2id))

smiles_tensor = torch.load(SMILES_EMB_PATH, map_location="cpu").to(torch.float32)
assert smiles_tensor.shape[0] == len(drug_meta_df), "SMILES rows != drug_metadata rows"

drug_to_smiles_emb = {d: smiles_tensor[i] for i, d in enumerate(drugs)}
smiles_dim = int(smiles_tensor.shape[-1])
print("smiles_dim:", smiles_dim)


# =========================================================
# 7) pair weights
# =========================================================
def make_pair_weights_from_counts(counts_df, pairs, drug_col="drug", cell_col="cell_line_id", n_col="n_cells",
                                  mode="inv_sqrt", eps=1.0):
    tmp = counts_df[[drug_col, cell_col, n_col]].copy()
    tmp[drug_col] = tmp[drug_col].astype(str)
    tmp[cell_col] = tmp[cell_col].astype(str)

    pair2n = {(d, c): int(n) for d, c, n in tmp.values}

    w = []
    for p in pairs:
        n = pair2n.get(p, 0)
        if mode == "inv":
            ww = 1.0 / (n + eps)
        elif mode == "inv_log":
            ww = 1.0 / np.log1p(n + eps)
        else:
            ww = 1.0 / np.sqrt(n + eps)
        w.append(float(ww))

    w = np.asarray(w, dtype=np.float64)
    w = np.clip(w, 0.0, None)
    w = w / (w.sum() + 1e-12)
    pair2w = {p: float(wi) for p, wi in zip(pairs, w)}
    return w, pair2w


# =========================================================
# 8) DATASET (Iterable, aligned)  ‚úÖ filters token_id<3
# =========================================================
class FRSeqExpressionParquetDatasetAligned(IterableDataset):
    """
    input_ids: [CLS][DRUG][CELL] + gene tokens (vocab-space ids)
    values:    delta (val - baseline[cell or global]) aligned with tokens
    y_topk:    true expression on TOP_K genes (gene-space token_id list)
    """

    def __init__(
        self,
        pair_to_locations,
        pairs,
        token_id_to_vocab_id,
        sorted_gene_token_ids,
        baseline_global,
        baseline_by_cellline,
        cell_line2id,
        drug2id,
        drug_to_smiles_emb,
        pair_weights=None,
        seed=42,
        max_gene_len=512,
        top_k=1000,
        batch_size=16,
        pad_id=0,
        cls_id=1,
        drugtok_id=2,
        celltok_id=4,
        drug_col="drug",
        cell_col="cell_line_id",
        genes_col="genes",
        expr_col="expressions",
        cap_per_pair_in_rg=None,
        max_tries=30,
        shuffle=False,
    ):
        super().__init__()
        self.pair_to_locations = pair_to_locations
        self.pairs = list(pairs)

        self.token_id_to_vocab_id = token_id_to_vocab_id
        self.q = np.asarray(sorted_gene_token_ids, dtype=np.int64)

        self.baseline_global = np.asarray(baseline_global, dtype=np.float32)
        self.baseline_by_cellline = baseline_by_cellline or {}

        self.cell_line2id = cell_line2id
        self.drug2id = drug2id
        self.drug_to_smiles_emb = drug_to_smiles_emb

        self.max_gene_len = int(max_gene_len)
        self.top_k = int(top_k)
        self.batch_size = int(batch_size)

        self.pad_id = int(pad_id)
        self.cls_id = int(cls_id)
        self.drugtok_id = int(drugtok_id)
        self.celltok_id = int(celltok_id)

        self.drug_col = drug_col
        self.cell_col = cell_col
        self.genes_col = genes_col
        self.expr_col = expr_col

        self.cap_per_pair_in_rg = cap_per_pair_in_rg
        self.max_tries = int(max_tries)
        self.shuffle = bool(shuffle)

        self.num_prefix = 3  # ‚úÖ [CLS][DRUG][CELL]
        self.seq_len = self.num_prefix + self.max_gene_len

        any_vec = next(iter(self.drug_to_smiles_emb.values()))
        self.smiles_dim = int(any_vec.shape[-1])

        # weights
        self.seed = int(seed)
        if pair_weights is None:
            self.pair_weights = None
        elif isinstance(pair_weights, dict):
            w = np.asarray([pair_weights.get(p, 0.0) for p in self.pairs], dtype=np.float64)
            w = np.clip(w, 0.0, None)
            w = w / (w.sum() + 1e-12)
            self.pair_weights = w
        else:
            w = np.asarray(pair_weights, dtype=np.float64)
            assert len(w) == len(self.pairs), "pair_weights length must match pairs length"
            w = np.clip(w, 0.0, None)
            w = w / (w.sum() + 1e-12)
            self.pair_weights = w

    def _read_row_group_df(self, file_path, rg_id, columns):
        pf = pq.ParquetFile(file_path)
        return pf.read_row_group(rg_id, columns=columns).to_pandas()

    def _prepare_sparse(self, genes, expr):
        idx = np.asarray(genes, dtype=np.int64)
        val = np.asarray(expr, dtype=np.float32)
        if idx.size == 0:
            return idx, val

        # ‚úÖ FILTER: remove non-gene tokens 0/1/2 (and anything <3)
        keep = idx >= MIN_GENE_TOKEN_ID
        idx = idx[keep]
        val = val[keep]
        if idx.size == 0:
            return idx, val

        order = np.argsort(idx)
        return idx[order], val[order]

    def _make_y_true_topk(self, idx_sorted, val_sorted):
        q = self.q
        if idx_sorted.size == 0:
            return np.zeros(q.shape[0], dtype=np.float32)

        pos = np.searchsorted(idx_sorted, q)
        y = np.zeros(q.shape[0], dtype=np.float32)

        in_bounds = (pos < idx_sorted.size)
        match = np.zeros(q.shape[0], dtype=bool)
        match[in_bounds] = (idx_sorted[pos[in_bounds]] == q[in_bounds])
        ok = in_bounds & match
        y[ok] = val_sorted[pos[ok]]
        return y.astype(np.float32)

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        rng = np.random.default_rng(self.seed if worker_info is None else (self.seed + worker_info.id))

        pairs = self.pairs
        weights = self.pair_weights
        cols = [self.drug_col, self.cell_col, self.genes_col, self.expr_col]

        while True:
            # weighted pair sampling
            if weights is None:
                drug_name, cell_line = pairs[rng.integers(0, len(pairs))]
            else:
                idxp = rng.choice(len(pairs), p=weights)
                drug_name, cell_line = pairs[idxp]

            locs = self.pair_to_locations.get((drug_name, cell_line), [])
            if not locs:
                continue

            for _ in range(self.max_tries):
                fpath, rg_id = locs[rng.integers(0, len(locs))]

                try:
                    df = self._read_row_group_df(fpath, rg_id, columns=cols)
                except Exception:
                    continue

                df = df[(df[self.drug_col].astype(str) == str(drug_name)) &
                        (df[self.cell_col].astype(str) == str(cell_line))]
                if len(df) < self.batch_size:
                    continue

                if self.cap_per_pair_in_rg is not None and len(df) > self.cap_per_pair_in_rg:
                    df = df.sample(self.cap_per_pair_in_rg, replace=False, random_state=int(rng.integers(0, 1e9)))

                df = df.sample(self.batch_size, replace=False, random_state=int(rng.integers(0, 1e9)))

                # baseline (cell-specific if exists)
                baseline = self.baseline_by_cellline.get(cell_line, self.baseline_global)

                # cell id must exist in pretrained mapping
                if cell_line not in self.cell_line2id:
                    continue
                cell_id = self.cell_line2id[cell_line]

                drug_id = self.drug2id.get(drug_name, 0)

                sm = self.drug_to_smiles_emb.get(drug_name, None)
                if sm is None:
                    smiles_emb = torch.zeros(self.smiles_dim, dtype=torch.float32)
                else:
                    smiles_emb = sm.detach().clone().to(torch.float32) if isinstance(sm, torch.Tensor) \
                        else torch.tensor(sm, dtype=torch.float32)

                input_ids = np.full((self.batch_size, self.seq_len), self.pad_id, dtype=np.int64)
                values    = np.zeros((self.batch_size, self.seq_len), dtype=np.float32)
                mask      = np.zeros((self.batch_size, self.seq_len), dtype=np.int64)
                y_topk    = np.zeros((self.batch_size, self.top_k), dtype=np.float32)

                cell_batch   = np.full((self.batch_size,), cell_id, dtype=np.int64)
                drug_batch   = np.full((self.batch_size,), drug_id, dtype=np.int64)
                smiles_batch = np.stack([smiles_emb.numpy()] * self.batch_size, axis=0).astype(np.float32)

                # prefix: [CLS][DRUG][CELL]
                input_ids[:, 0] = self.cls_id
                input_ids[:, 1] = self.drugtok_id
                input_ids[:, 2] = self.celltok_id
                mask[:, :self.num_prefix] = 1

                for b, (_, r) in enumerate(df.iterrows()):
                    idx, val = self._prepare_sparse(r[self.genes_col], r[self.expr_col])

                    # y_true on TOP_K
                    y_topk[b] = self._make_y_true_topk(idx, val)

                    if idx.size == 0:
                        continue

                    # delta
                    base_vals = baseline[idx]
                    delta = (val - base_vals)

                    # choose top genes by |delta|
                    k = min(self.max_gene_len, idx.size)
                    if k <= 0:
                        continue

                    if k == idx.size:
                        top_pos = np.argsort(-np.abs(delta))
                    else:
                        top_pos = np.argpartition(-np.abs(delta), k - 1)[:k]
                        top_pos = top_pos[np.argsort(-np.abs(delta[top_pos]))]

                    sel_gene_token_ids = idx[top_pos]
                    sel_delta = delta[top_pos]

                    # token_id -> vocab_id (drop missing)
                    sel_vocab_ids = np.asarray(
                        [self.token_id_to_vocab_id.get(int(t), -1) for t in sel_gene_token_ids],
                        dtype=np.int64
                    )
                    ok = sel_vocab_ids != -1
                    sel_vocab_ids = sel_vocab_ids[ok]
                    sel_delta = sel_delta[ok]

                    L = min(self.max_gene_len, sel_vocab_ids.size)
                    if L <= 0:
                        continue

                    start = self.num_prefix
                    input_ids[b, start:start+L] = sel_vocab_ids[:L]
                    values[b,    start:start+L] = sel_delta[:L]
                    mask[b,      start:start+L] = 1

                yield (
                    torch.tensor(input_ids, dtype=torch.long),
                    torch.tensor(values, dtype=torch.float32),
                    torch.tensor(mask, dtype=torch.long),
                    torch.tensor(y_topk, dtype=torch.float32),
                    torch.tensor(cell_batch, dtype=torch.long),
                    torch.tensor(drug_batch, dtype=torch.long),
                    torch.tensor(smiles_batch, dtype=torch.float32),
                )
                break


# =========================================================
# 9) DataLoaders
# =========================================================
train_w, _ = make_pair_weights_from_counts(counts, train_pairs, mode="inv_sqrt")

train_ds = FRSeqExpressionParquetDatasetAligned(
    pair_to_locations=pair_to_locations,
    pairs=train_pairs,
    pair_weights=train_w,
    token_id_to_vocab_id=token_id_to_vocab_id,
    sorted_gene_token_ids=sorted_gene_token_ids,
    baseline_global=baseline_global,
    baseline_by_cellline=baseline_by_cl,
    cell_line2id=cell_line2id,
    drug2id=drug2id,
    drug_to_smiles_emb=drug_to_smiles_emb,
    batch_size=BATCH_SIZE,
    max_gene_len=MAX_LEN,
    top_k=TOP_K,
    shuffle=False,
)

val_ds = FRSeqExpressionParquetDatasetAligned(
    pair_to_locations=pair_to_locations,
    pairs=val_pairs,
    pair_weights=None,
    token_id_to_vocab_id=token_id_to_vocab_id,
    sorted_gene_token_ids=sorted_gene_token_ids,
    baseline_global=baseline_global,
    baseline_by_cellline=baseline_by_cl,
    cell_line2id=cell_line2id,
    drug2id=drug2id,
    drug_to_smiles_emb=drug_to_smiles_emb,
    batch_size=BATCH_SIZE,
    max_gene_len=MAX_LEN,
    top_k=TOP_K,
    shuffle=False,
)

train_loader = DataLoader(train_ds, batch_size=None, num_workers=0, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=None, num_workers=0, pin_memory=True)


# =========================================================
# 10) MODEL (Cell2Sentence-like Encoder for f_r)
# =========================================================
class Cell2SentenceEncoderFR(nn.Module):
    """
    prefix: [CLS][DRUG][CELL] + gene tokens
    - token_emb: vocab-space ids
    - values: delta
    - inject smiles into position 1 ([DRUG])
    - inject cell_line embedding into position 2 ([CELL]) ‚úÖ pretrained
    """
    def __init__(self, vocab_size, d_model, n_heads, num_layers, max_len_with_prefix, smiles_dim, num_cell_lines, dropout=0.1):
        super().__init__()
        self.d_model = d_model

        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=PAD_ID)
        self.value_proj = nn.Sequential(
            nn.Linear(1, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )
        self.pos_emb = nn.Embedding(max_len_with_prefix, d_model)

        self.cell_line_emb = nn.Embedding(num_cell_lines, d_model)
        self.smiles_proj = nn.Linear(smiles_dim, d_model)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads,
            dim_feedforward=4*d_model,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)

    def forward(self, input_ids, values, attention_mask, cell_line_id, smiles_emb):
        B, L = input_ids.shape
        device_ = input_ids.device

        x = self.token_emb(input_ids) + self.value_proj(values.unsqueeze(-1))

        pos = torch.arange(L, device=device_).unsqueeze(0).expand(B, L)
        x = x + self.pos_emb(pos)

        # inject drug / cell info
        x[:, 1, :] = x[:, 1, :] + self.smiles_proj(smiles_emb.to(device=device_, dtype=torch.float32)).to(x.dtype)
        x[:, 2, :] = x[:, 2, :] + self.cell_line_emb(cell_line_id.to(device=device_)).to(x.dtype)

        key_padding_mask = (attention_mask == 0)
        h = self.encoder(x, src_key_padding_mask=key_padding_mask)
        return h[:, 0, :]  # CLS


class FRModelExpression(nn.Module):
    def __init__(self, encoder, d_model, out_dim):
        super().__init__()
        self.encoder = encoder
        self.head = nn.Linear(d_model, out_dim)

    def forward(self, input_ids, values, mask, cell_line_id, smiles_emb):
        h = self.encoder(input_ids, values, mask, cell_line_id, smiles_emb)
        return self.head(h)


# =========================================================
# 11) Load pretrained gene + cell embeddings
# =========================================================
def load_pretrained_token_emb_from_gene_metadata(token_emb: nn.Embedding, npy_path: str, gene_meta_path: str, local_token_to_id: dict, device):
    W = np.load(npy_path)  # (N_genes, d_model)
    Wt = torch.tensor(W, dtype=torch.float32, device=device)

    gene_md = pd.read_parquet(gene_meta_path).copy()
    gene_md["ensembl_id"] = gene_md["ensembl_id"].astype(str)
    gene_md["token_id"] = gene_md["token_id"].astype(int)
    gene_md = gene_md.sort_values("token_id").reset_index(drop=True)

    if Wt.shape[1] != token_emb.weight.shape[1]:
        raise ValueError(f"d mismatch: npy d={Wt.shape[1]} vs token_emb d={token_emb.weight.shape[1]}")

    # NOTE: original training saved gene_embeddings.npy as "special removed" (index 0 corresponds to gene_md row 0)
    n = min(len(gene_md), Wt.shape[0])
    loaded = 0
    with torch.no_grad():
        for i in range(n):
            ensg = gene_md.loc[i, "ensembl_id"]
            vid = local_token_to_id.get(ensg, None)
            if vid is None:
                continue
            token_emb.weight[vid].copy_(Wt[i])
            loaded += 1
    print(f"‚úÖ Loaded pretrained gene token_emb: {loaded} genes")


def load_pretrained_cell_emb(cell_emb: nn.Embedding, cell_emb_npy: str, device):
    W = np.load(cell_emb_npy)  # (num_cell_lines, d_model)
    Wt = torch.tensor(W, dtype=torch.float32, device=device)

    if Wt.shape != cell_emb.weight.shape:
        raise ValueError(f"cell_emb shape mismatch: npy={tuple(Wt.shape)} vs emb={tuple(cell_emb.weight.shape)}")

    with torch.no_grad():
        cell_emb.weight.copy_(Wt)
    print(f"‚úÖ Loaded pretrained cell_line_emb: {tuple(Wt.shape)}")



def sanity_check_gene_emb_mapping(
    gene_meta_path,
    local_token_to_id,
    token_emb: torch.nn.Embedding,
    pretrained_gene_npy,
    n_check=20,
    seed=0,
):
    gene_md = pd.read_parquet(gene_meta_path).copy()
    gene_md["ensembl_id"] = gene_md["ensembl_id"].astype(str)
    gene_md["token_id"]   = gene_md["token_id"].astype(int)
    gene_md = gene_md.sort_values("token_id").reset_index(drop=True)

    W = np.load(pretrained_gene_npy)  # (N_genes, d_model)
    assert W.shape[1] == token_emb.weight.shape[1]

    rng = np.random.default_rng(seed)
    idxs = rng.integers(0, min(len(gene_md), W.shape[0]), size=n_check)

    max_abs = 0.0
    bad = 0

    with torch.no_grad():
        for i in idxs:
            ensg = gene_md.loc[i, "ensembl_id"]
            vid = local_token_to_id.get(ensg, None)
            if vid is None:
                continue

            a = token_emb.weight[vid].detach().cpu().numpy()
            b = W[i]

            diff = np.max(np.abs(a - b))
            max_abs = max(max_abs, float(diff))
            if diff > 1e-6:
                bad += 1
                print("Mismatch:", "i=", i, "ensg=", ensg, "vid=", vid, "max_abs_diff=", diff)

    print(f"[sanity] checked={n_check}, bad={bad}, max_abs_diff={max_abs}")


# =========================================================
# 12) Init model
# =========================================================
D_MODEL = 256
assert W_cell.shape[1] == D_MODEL, f"cell_emb dim {W_cell.shape[1]} != D_MODEL {D_MODEL}"

encoder = Cell2SentenceEncoderFR(
    vocab_size=VOCAB_SIZE,
    d_model=D_MODEL,
    n_heads=8,
    num_layers=4,
    max_len_with_prefix=(3 + MAX_LEN),   # prefix 3 + gene_len
    smiles_dim=smiles_dim,
    num_cell_lines=NUM_CELL_LINE,
    dropout=0.1,
).to(device)

# gene embeddings
load_pretrained_token_emb_from_gene_metadata(
    token_emb=encoder.token_emb,
    npy_path=PRETRAINED_GENE_NPY,
    gene_meta_path=GENE_META_PATH,
    local_token_to_id=local_token_to_id,
    device=device,
)

# ‚úÖ cell embeddings (pretrained)
load_pretrained_cell_emb(
    cell_emb=encoder.cell_line_emb,
    cell_emb_npy=CELL_EMB_NPY,
    device=device
)

fr_model = FRModelExpression(encoder=encoder, d_model=D_MODEL, out_dim=TOP_K).to(device)
optimizer = torch.optim.AdamW(fr_model.parameters(), lr=LR, weight_decay=0.01)
scaler = GradScaler(enabled=(device.type == "cuda"))

print("‚úÖ f_r model ready")

sanity_check_gene_emb_mapping(GENE_META_PATH, local_token_to_id, encoder.token_emb, PRETRAINED_GENE_NPY)


gene_md = pd.read_parquet(GENE_META_PATH).copy()
gene_md["token_id"] = gene_md["token_id"].astype(int)
bad = gene_md.index.values != gene_md["token_id"].values
print("‚ùå rows where index != token_id:", bad.sum())

if bad.any():
    print("‚ùó MISALIGNMENT: gene_md index ‚â† token_id")
else:
    print("‚úÖ OK: gene_md index == token_id")


with torch.no_grad():
    emb_table = encoder.token_emb.weight.detach().cpu().numpy()

# –ø—Ä–æ–≤–µ—Ä–∏–º –Ω–µ—Å–∫–æ–ª—å–∫–æ —Å–ª—É—á–∞–π–Ω—ã—Ö gene token_id
rng = np.random.default_rng(42)
test_ids = rng.choice(gene_md["token_id"].values, size=10, replace=False)

for tid in test_ids:
    ensg = gene_md.loc[gene_md["token_id"] == tid, "ensembl_id"].values[0]
    vocab_id = local_token_to_id[ensg]

    diff = np.linalg.norm(emb_table[vocab_id] - W[tid])
    print(f"token_id={tid} | vocab_id={vocab_id} | diff={diff:.6f}")

# =========================================================
# 13) Losses (MSE + ranking)
# =========================================================
mse_loss = nn.MSELoss()

baseline_vec = torch.tensor(baseline_global[sorted_gene_token_ids], dtype=torch.float32, device=device)  # (TOP_K,)

def expr_ranking_loss(y_pred, y_true, baseline_vec, top_pos=30, num_neg=80, margin=0.0):
    device_ = y_pred.device
    B, K = y_pred.shape

    base = baseline_vec.view(1, K).expand(B, K).to(device=device_, dtype=y_pred.dtype)
    dt = y_true - base
    dp = y_pred - base

    losses = []
    for b in range(B):
        order = torch.argsort(dt[b].abs(), descending=True)
        P = min(top_pos, K)
        pos_idx = order[:P]
        neg_candidates = order[P:]
        if neg_candidates.numel() == 0:
            continue

        if neg_candidates.numel() > num_neg:
            neg_idx = neg_candidates[torch.randperm(neg_candidates.numel(), device=device_)[:num_neg]]
        else:
            neg_idx = neg_candidates

        pos_scores = dp[b, pos_idx]   # (P,)
        neg_scores = dp[b, neg_idx]   # (N,)

        diff = pos_scores.view(-1, 1) - neg_scores.view(1, -1)
        loss_mat = F.relu(margin - diff)
        losses.append(loss_mat.mean())

    if len(losses) == 0:
        return torch.tensor(0.0, device=device_, dtype=y_pred.dtype)
    return torch.stack(losses).mean()


# =========================================================
# 14) Checkpoint utils
# =========================================================
def save_fr_checkpoint(save_dir, fr_model, optimizer, scaler, epoch, metrics=None, extra=None, prefix="fr"):
    os.makedirs(save_dir, exist_ok=True)
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")

    ckpt = {
        "epoch": int(epoch),
        "model_state": fr_model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scaler_state": scaler.state_dict() if scaler is not None else None,
        "metrics": metrics or {},
        "extra": extra or {},
    }

    path = os.path.join(save_dir, f"{prefix}_epoch{epoch}_{ts}.pt")
    torch.save(ckpt, path)
    print(f"üíæ saved checkpoint: {path}")
    return path


# =========================================================
# 15) Eval helpers
# =========================================================

@torch.no_grad()
def eval_mse(fr_model, val_loader, steps, device):
    fr_model.eval()
    total = 0.0
    n = 0
    for batch in islice(val_loader, steps):
        input_ids, values, mask, y_true, cell_id, drug_id, smiles = batch
        input_ids = input_ids.to(device)
        values    = values.to(device)
        mask      = mask.to(device)
        y_true    = y_true.to(device)
        cell_id   = cell_id.to(device)
        smiles    = smiles.to(device)

        with autocast(device_type="cuda", enabled=(device.type == "cuda")):
            y_pred = fr_model(input_ids, values, mask, cell_id, smiles)
            loss = mse_loss(y_pred, y_true)

        bs = y_true.size(0)
        total += loss.item() * bs
        n += bs
    return total / max(1, n)

@torch.no_grad()
def baseline_mse(val_loader, steps, baseline_vec, device):
    total = 0.0
    n = 0
    baseline_vec = baseline_vec.to(device)
    for batch in islice(val_loader, steps):
        _, _, _, y_true, _, _, _ = batch
        y_true = y_true.to(device)
        bs = y_true.size(0)
        pred = baseline_vec.view(1, -1).expand(bs, -1)
        loss = F.mse_loss(pred, y_true)
        total += loss.item() * bs
        n += bs
    return total / max(1, n)


# =========================================================
# 16) TRAIN
# =========================================================
base_mse = baseline_mse(val_loader, steps=VAL_STEPS, baseline_vec=baseline_vec, device=device)
print(f"Baseline Valid MSE (DMSO) = {base_mse:.6f}")

print("üöÄ f_r training start")

for epoch in range(1, TOTAL_EPOCHS + 1):
    lambda_rank = 0.0 if epoch <= WARMUP_EPOCHS else lambda_rank_main

    fr_model.train()
    run_mse = 0.0
    run_rank = 0.0
    run_total = 0.0
    n = 0

    pbar = tqdm(
        islice(train_loader, STEPS_PER_EPOCH),
        total=STEPS_PER_EPOCH,
        desc=f"[Epoch {epoch}] Train",
        leave=True,
        dynamic_ncols=True
    )

    for batch in pbar:
        input_ids, values, mask, y_true, cell_id, drug_id, smiles = batch

        input_ids = input_ids.to(device, non_blocking=True)
        values    = values.to(device, non_blocking=True)
        mask      = mask.to(device, non_blocking=True)
        y_true    = y_true.to(device, non_blocking=True)
        cell_id   = cell_id.to(device, non_blocking=True)
        smiles    = smiles.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with autocast(device_type="cuda", enabled=(device.type == "cuda")):
            y_pred = fr_model(input_ids, values, mask, cell_id, smiles)
            loss_m = mse_loss(y_pred, y_true)

            if lambda_rank > 0:
                loss_r = expr_ranking_loss(
                    y_pred, y_true, baseline_vec,
                    top_pos=30, num_neg=80, margin=0.0
                )
            else:
                loss_r = torch.tensor(0.0, device=device)

            loss = loss_m + lambda_rank * loss_r

        if not torch.isfinite(loss):
            continue

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(fr_model.parameters(), GRAD_CLIP)
        scaler.step(optimizer)
        scaler.update()

        bs = y_true.size(0)
        run_mse   += loss_m.item() * bs
        run_rank  += loss_r.item() * bs
        run_total += loss.item() * bs
        n += bs

        pbar.set_postfix({
            "mse": f"{loss_m.item():.4f}",
            "rank": f"{loss_r.item():.4f}",
            "Œª_rank": float(lambda_rank),
        })

    train_mse   = run_mse   / max(1, n)
    train_rank  = run_rank  / max(1, n)
    train_total = run_total / max(1, n)

    val_mse = eval_mse(fr_model, val_loader, steps=VAL_STEPS, device=device)

    print(
        f"[Epoch {epoch}] "
        f"Train total={train_total:.6f}, mse={train_mse:.6f}, rank={train_rank:.6f} (Œª_rank={lambda_rank}) | "
        f"Valid mse={val_mse:.6f} | Baseline(DMSO) mse={base_mse:.6f}"
    )

    if (epoch % SAVE_EVERY == 0) or (epoch == TOTAL_EPOCHS):
        save_fr_checkpoint(
            save_dir=CKPT_DIR,
            fr_model=fr_model,
            optimizer=optimizer,
            scaler=scaler,
            epoch=epoch,
            metrics={
                "train_total": float(train_total),
                "train_mse": float(train_mse),
                "train_rank": float(train_rank),
                "val_mse": float(val_mse),
                "baseline_mse": float(base_mse),
                "lambda_rank": float(lambda_rank),
            },
            extra={
                "TOP_K": int(baseline_vec.numel()),
                "STEPS_PER_EPOCH": int(STEPS_PER_EPOCH),
                "VAL_STEPS": int(VAL_STEPS),
                "WARMUP_EPOCHS": int(WARMUP_EPOCHS),
                "lambda_rank_main": float(lambda_rank_main),
                "baseline_vec": baseline_vec.detach().float().cpu(),
                "CELL2ID_CSV": CELL2ID_CSV,
                "CELL_EMB_NPY": CELL_EMB_NPY,
                "sorted_gene_token_ids": sorted_gene_token_ids.astype(np.int64)
            },
            prefix="fr",
        )

print("‚úÖ DONE")

gene_md = pd.read_parquet(GENE_META_PATH)[["token_id","ensembl_id"]].copy()
tid2ensg = dict(zip(gene_md["token_id"].astype(int), gene_md["ensembl_id"].astype(str)))
topk_ensg = np.array([tid2ensg[int(t)] for t in sorted_gene_token_ids], dtype=object)

np.save(os.path.join(CKPT_DIR, f"topk_ensg_k{TOP_K}.npy"), topk_ensg)


## fast option

In [1]:
# =========================================================
# f_r FULL TRAINING SCRIPT (SAFE-FAST DATALOADER)
# - ‚úÖ token_id < 3 filtered everywhere
# - ‚úÖ NO [TARGET] token in input
# - ‚úÖ input prefix: [CLS][DRUG][CELL] + gene tokens
# - ‚úÖ load pretrained gene embeddings (gene_embeddings.npy)
# - ‚úÖ load pretrained cell-line embeddings (cell_embeddings.npy) + mapping (cell2id.csv)
#
# SAFE-FAST changes (—á—Ç–æ–±—ã —É—Å–∫–æ—Ä–∏—Ç—å –∏ –Ω–µ —É–±–∏—Ç—å CPU / kernel):
# - ‚úÖ —á—Ç–µ–Ω–∏–µ row_group –ë–ï–ó pandas (pyarrow + numpy mask)
# - ‚úÖ –Ω–µ–±–æ–ª—å—à–æ–π num_workers=2 + prefetch_factor=1 + persistent_workers=True
# - ‚úÖ LRU-–∫—ç—à ParquetFile –≤–Ω—É—Ç—Ä–∏ worker
# - ‚úÖ cap_per_pair_in_rg=256 (–∫–æ–Ω—Ç—Ä–æ–ª—å CPU –Ω–∞–≥—Ä—É–∑–∫–∏)
# - ‚úÖ (–æ–ø—Ü–∏–æ–Ω–∞–ª—å–Ω–æ) –æ–≥—Ä–∞–Ω–∏—á–µ–Ω–∏–µ –ø–æ—Ç–æ–∫–æ–≤ BLAS/OMP
# =========================================================

import os, glob, math, random
from collections import defaultdict, OrderedDict
from datetime import datetime
from itertools import islice

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import IterableDataset, DataLoader
from torch.cuda.amp import GradScaler
from torch.amp import autocast

import pyarrow as pa
import pyarrow.parquet as pq
from tqdm import tqdm

import scanpy as sc
from scipy import sparse
from sklearn.model_selection import train_test_split

In [2]:
# ---------------------------------------------------------
# (–û–ü–¶–ò–û–ù–ê–õ–¨–ù–û, –ù–û –†–ï–ö–û–ú–ï–ù–î–£–Æ) –æ–≥—Ä–∞–Ω–∏—á–∏–≤–∞–µ–º —á–∏—Å–ª–æ –ø–æ—Ç–æ–∫–æ–≤
# —á—Ç–æ–±—ã CPU –Ω–µ "–≤–∑–ª–µ—Ç–∞–ª" –∏–∑-–∑–∞ OpenMP/BLAS –∏ –Ω–µ —Ä–æ–Ω—è–ª kernel
# ---------------------------------------------------------
os.environ.setdefault("OMP_NUM_THREADS", "4")
os.environ.setdefault("MKL_NUM_THREADS", "4")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "4")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "4")
torch.set_num_threads(4)


# =========================================================
# 0) PATHS / CONFIG
# =========================================================
GENE_META_PATH = "/data/aiffel/data/Tahoe-100M/metadata/gene_metadata.parquet"
DRUG_META_PATH = "/data/aiffel/data/Tahoe-100M/metadata/drug_metadata.parquet"
COUNTS_CSV     = "/data/aiffel/babayakga/making_data/aiffel/babayakga/making_data/tahoe_counts_per_drug_cell_line.csv"

PARQUET_DIR    = "/data/aiffel/data/Tahoe-100M/data"
DMSO_H5AD      = "/data/aiffel/babayakga/outputs/dmso.h5ad"

# pretrained embeddings
PRETRAINED_GENE_NPY = "/data/aiffel/babayakga/pretraining/checkpoints_with_cell/gene_embeddings.npy"

CELL_CKPT_DIR  = "/data/aiffel/babayakga/pretraining/checkpoints_with_cell"
CELL2ID_CSV    = os.path.join(CELL_CKPT_DIR, "cell2id.csv")
CELL_EMB_NPY   = os.path.join(CELL_CKPT_DIR, "cell_embeddings.npy")

# smiles embedding for drugs (must match drug_metadata row order)
SMILES_EMB_PATH = "/data/aiffel/babayakga/f_p module/f_r/drug_smiles_emb_all.pt"

# training
CONTROL_DRUG = "DMSO_TF"
SEED = 42

MIN_GENE_TOKEN_ID = 3   # ‚úÖ exclude 0/1/2 (not real genes)

TOP_K      = 1000
MAX_LEN    = 512        # gene tokens length (not counting prefix)
BATCH_SIZE = 16

TOTAL_EPOCHS     = 8
WARMUP_EPOCHS    = 2
lambda_rank_main = 0.2

STEPS_PER_EPOCH  = 10000
VAL_STEPS        = 900

GRAD_CLIP = 1.0
LR = 3e-4

CKPT_DIR   = "/data/aiffel/babayakga/checkpoints/f_r_withcellline"
SAVE_EVERY = 2

# —Ç–≤–æ–π –≤—ã–±–æ—Ä GPU
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


# =========================================================
# 1) VOCAB (special + ENSG) from gene_metadata
# =========================================================
def build_vocab_from_gene_metadata(gene_meta_path: str):
    """
    vocab-space: special tokens + all ensembl_id
    gene-space: Tahoe token_id (0..N_GENES-1)
    """
    SPECIAL_TOKENS = ["[PAD]", "[CLS]", "[DRUG]", "[TARGET]", "[CELL]", "[MASK]"]

    gene_md = pd.read_parquet(gene_meta_path).copy()
    gene_md["ensembl_id"] = gene_md["ensembl_id"].astype(str)
    gene_md["token_id"]   = gene_md["token_id"].astype(int)
    gene_md = gene_md.sort_values("token_id").reset_index(drop=True)

    N_GENES = int(gene_md["token_id"].max()) + 1

    local_token_to_id = {tok: i for i, tok in enumerate(SPECIAL_TOKENS)}
    for ensg in gene_md["ensembl_id"].tolist():
        if ensg not in local_token_to_id:
            local_token_to_id[ensg] = len(local_token_to_id)

    token_id_to_vocab_id = {
        int(tid): int(local_token_to_id[str(ensg)])
        for tid, ensg in zip(gene_md["token_id"].values, gene_md["ensembl_id"].values)
    }

    ensg_to_token_id = {
        str(ensg): int(tid)
        for ensg, tid in zip(gene_md["ensembl_id"].values, gene_md["token_id"].values)
    }

    PAD_ID  = local_token_to_id["[PAD]"]
    CLS_ID  = local_token_to_id["[CLS]"]
    DRUG_ID = local_token_to_id["[DRUG]"]
    CELL_ID = local_token_to_id["[CELL]"]

    return local_token_to_id, token_id_to_vocab_id, ensg_to_token_id, N_GENES, SPECIAL_TOKENS, PAD_ID, CLS_ID, DRUG_ID, CELL_ID


local_token_to_id, token_id_to_vocab_id, ensg_to_token_id, N_GENES, SPECIAL_TOKENS, PAD_ID, CLS_ID, DRUGTOK_ID, CELLTOK_ID = \
    build_vocab_from_gene_metadata(GENE_META_PATH)

VOCAB_SIZE = len(local_token_to_id)
print("VOCAB_SIZE(vocab-space):", VOCAB_SIZE)
print("N_GENES(gene-space):", N_GENES)


# =========================================================
# 2) LOAD cell2id mapping (MUST match pretrained cell embeddings)
# =========================================================
if not os.path.exists(CELL2ID_CSV):
    raise FileNotFoundError(f"cell2id.csv not found: {CELL2ID_CSV}")
if not os.path.exists(CELL_EMB_NPY):
    raise FileNotFoundError(f"cell_embeddings.npy not found: {CELL_EMB_NPY}")

cell2id_df = pd.read_csv(CELL2ID_CSV)
cell2id_df["cell_line_id"] = cell2id_df["cell_line_id"].astype(str)
cell_line2id = {c: int(i) for c, i in zip(cell2id_df["cell_line_id"], cell2id_df["cell_id"])}

NUM_CELL_LINE = len(cell_line2id)
W_cell = np.load(CELL_EMB_NPY)
print("NUM_CELL_LINE(from cell2id.csv):", NUM_CELL_LINE, "| cell_emb rows:", W_cell.shape[0])

assert W_cell.shape[0] == NUM_CELL_LINE, f"cell2id size {NUM_CELL_LINE} != cell_emb rows {W_cell.shape[0]}"


# =========================================================
# 3) SPLIT PAIRS (drug, cell_line) from COUNTS_CSV
# =========================================================
DRUG_COL, CELL_COL, N_COL = "drug", "cell_line_id", "n_cells"

counts = pd.read_csv(COUNTS_CSV)
counts[DRUG_COL] = counts[DRUG_COL].astype(str)
counts[CELL_COL] = counts[CELL_COL].astype(str)

MIN_TRAIN = 1000
MIN_EVAL  = 1000

train_pool = counts[counts[N_COL] >= MIN_TRAIN].copy()
eval_pool  = counts[counts[N_COL] >= MIN_EVAL].copy()

pairs_df = train_pool[[DRUG_COL, CELL_COL]].drop_duplicates()

train_df, val_df = train_test_split(
    pairs_df,
    test_size=0.1,
    random_state=SEED,
    stratify=pairs_df[DRUG_COL] if len(pairs_df) else None,
)

train_df = train_df[train_df[DRUG_COL] != CONTROL_DRUG]
val_df   = val_df[val_df[DRUG_COL]   != CONTROL_DRUG]

train_pairs = list(zip(train_df[DRUG_COL], train_df[CELL_COL]))
val_pairs   = list(zip(val_df[DRUG_COL],   val_df[CELL_COL]))

eval_pairs_df = eval_pool[[DRUG_COL, CELL_COL]].drop_duplicates()
eval_pairs_df = eval_pairs_df[eval_pairs_df[DRUG_COL] != CONTROL_DRUG]
eval_pairs = list(zip(eval_pairs_df[DRUG_COL], eval_pairs_df[CELL_COL]))

print("train pairs:", len(train_pairs))
print("val pairs:", len(val_pairs))
print(f"eval pairs (>={MIN_EVAL}):", len(eval_pairs))


# =========================================================
# 4) INDEX PARQUET row-groups for valid pairs
# =========================================================
PARQUET_FILES = sorted(glob.glob(os.path.join(PARQUET_DIR, "**", "*.parquet"), recursive=True))
print("parquet files found:", len(PARQUET_FILES))

PARQUET_DRUG_COL = "drug"
PARQUET_CELL_COL = "cell_line_id"

def build_pair_to_locations(parquet_files, valid_pairs_set, drug_col, cell_col):
    out = defaultdict(list)
    for f in tqdm(parquet_files, desc="Index parquet row-groups"):
        try:
            pf = pq.ParquetFile(f)
        except Exception:
            continue
        for rg in range(pf.num_row_groups):
            try:
                tbl = pf.read_row_group(rg, columns=[drug_col, cell_col])
                # –í–ù–ò–ú–ê–ù–ò–ï: —Ç—É—Ç pandas –û–ö, –ø–æ—Ç–æ–º—É —á—Ç–æ —ç—Ç–æ —Ä–∞–∑–æ–≤–∞—è –∏–Ω–¥–µ–∫—Å–∞—Ü–∏—è (–æ–¥–∏–Ω —Ä–∞–∑)
                df = tbl.to_pandas()
            except Exception:
                continue

            pairs_here = set(zip(df[drug_col].astype(str), df[cell_col].astype(str)))
            inter = pairs_here.intersection(valid_pairs_set)
            for p in inter:
                out[p].append((f, rg))
    return dict(out)

valid_pairs_set = set(train_pairs) | set(val_pairs)
pair_to_locations = build_pair_to_locations(
    parquet_files=PARQUET_FILES,
    valid_pairs_set=valid_pairs_set,
    drug_col=PARQUET_DRUG_COL,
    cell_col=PARQUET_CELL_COL
)
print("indexed pairs:", len(pair_to_locations))


# =========================================================
# 5) DMSO baselines + topK variance genes (gene-space token_id)
#    ‚úÖ filter token_id < 3 here too
# =========================================================
def build_dmso_baselines_gene_space(dmso_h5ad_path: str, control_drug: str, N_GENES: int, ensg_to_token_id: dict,
                                    drug_col="drug", cell_col="cell_line_id", dtype=np.float32):
    adata = sc.read_h5ad(dmso_h5ad_path)
    obs = adata.obs
    X = adata.X.tocsr() if sparse.issparse(adata.X) else sparse.csr_matrix(adata.X)

    m = (obs[drug_col].astype(str).values == control_drug)
    idx_ctrl = np.where(m)[0]
    if idx_ctrl.size == 0:
        raise ValueError(f"No control rows: {control_drug}")

    ensgs = adata.var_names.astype(str).tolist()
    token_ids, cols = [], []
    for j, ensg in enumerate(ensgs):
        tid = ensg_to_token_id.get(ensg, None)
        if tid is None:
            continue
        tid = int(tid)
        if tid < MIN_GENE_TOKEN_ID:
            continue
        token_ids.append(tid)
        cols.append(j)

    token_ids = np.asarray(token_ids, dtype=np.int64)
    cols      = np.asarray(cols, dtype=np.int64)

    Xc = X[idx_ctrl][:, cols]
    mean_global_sub = np.asarray(Xc.mean(axis=0)).ravel().astype(dtype)

    baseline_global = np.zeros(N_GENES, dtype=dtype)
    baseline_global[token_ids] = mean_global_sub

    baseline_by_cl = {}
    cls_all = obs[cell_col].astype(str).values
    for cl in np.unique(cls_all):
        cl_idx = np.where(m & (cls_all == cl))[0]
        if cl_idx.size == 0:
            continue
        Xcl = X[cl_idx][:, cols]
        mean_cl_sub = np.asarray(Xcl.mean(axis=0)).ravel().astype(dtype)
        v = np.zeros(N_GENES, dtype=dtype)
        v[token_ids] = mean_cl_sub
        baseline_by_cl[str(cl)] = v

    return baseline_global, baseline_by_cl


def topk_by_variance_gene_space(dmso_h5ad_path: str, control_drug: str, N_GENES: int, ensg_to_token_id: dict,
                               drug_col="drug", top_k=1000):
    adata = sc.read_h5ad(dmso_h5ad_path)
    obs = adata.obs
    X = adata.X.tocsr() if sparse.issparse(adata.X) else sparse.csr_matrix(adata.X)

    m = (obs[drug_col].astype(str).values == control_drug)
    idx = np.where(m)[0]
    if idx.size == 0:
        raise ValueError(f"No control rows: {control_drug}")

    ensgs = adata.var_names.astype(str).tolist()
    token_ids, cols = [], []
    for j, ensg in enumerate(ensgs):
        tid = ensg_to_token_id.get(ensg, None)
        if tid is None:
            continue
        tid = int(tid)
        if tid < MIN_GENE_TOKEN_ID:
            continue
        token_ids.append(tid)
        cols.append(j)

    token_ids = np.asarray(token_ids, dtype=np.int64)
    cols      = np.asarray(cols, dtype=np.int64)

    Xc = X[idx][:, cols]
    ex  = np.asarray(Xc.mean(axis=0)).ravel()
    ex2 = np.asarray(Xc.power(2).mean(axis=0)).ravel()
    var = ex2 - ex**2

    top_local = np.argsort(-var)[:top_k]
    top_gene_token_ids = token_ids[top_local]
    return top_gene_token_ids.astype(np.int64)


baseline_global, baseline_by_cl = build_dmso_baselines_gene_space(
    dmso_h5ad_path=DMSO_H5AD,
    control_drug=CONTROL_DRUG,
    N_GENES=N_GENES,
    ensg_to_token_id=ensg_to_token_id,
)

sorted_gene_token_ids = topk_by_variance_gene_space(
    dmso_h5ad_path=DMSO_H5AD,
    control_drug=CONTROL_DRUG,
    N_GENES=N_GENES,
    ensg_to_token_id=ensg_to_token_id,
    top_k=TOP_K,
)

assert (sorted_gene_token_ids >= MIN_GENE_TOKEN_ID).all(), "TOP_K contains token_id < 3 !"

print("baseline_global:", baseline_global.shape, "baseline_by_cl:", len(baseline_by_cl))
print("sorted_gene_token_ids:", sorted_gene_token_ids.shape, sorted_gene_token_ids[:10])


# =========================================================
# 6) DRUG -> id + SMILES embeddings
# =========================================================
drug_meta_df = pd.read_parquet(DRUG_META_PATH).copy()
drug_meta_df["drug"] = drug_meta_df["drug"].astype(str)
drugs = drug_meta_df["drug"].tolist()
drug2id = {d: i for i, d in enumerate(drugs)}
print("num drugs:", len(drug2id))

smiles_tensor = torch.load(SMILES_EMB_PATH, map_location="cpu").to(torch.float32)
assert smiles_tensor.shape[0] == len(drug_meta_df), "SMILES rows != drug_metadata rows"

drug_to_smiles_emb = {d: smiles_tensor[i] for i, d in enumerate(drugs)}
smiles_dim = int(smiles_tensor.shape[-1])
print("smiles_dim:", smiles_dim)


# =========================================================
# 7) pair weights
# =========================================================
def make_pair_weights_from_counts(counts_df, pairs, drug_col="drug", cell_col="cell_line_id", n_col="n_cells",
                                  mode="inv_sqrt", eps=1.0):
    tmp = counts_df[[drug_col, cell_col, n_col]].copy()
    tmp[drug_col] = tmp[drug_col].astype(str)
    tmp[cell_col] = tmp[cell_col].astype(str)

    pair2n = {(d, c): int(n) for d, c, n in tmp.values}

    w = []
    for p in pairs:
        n = pair2n.get(p, 0)
        if mode == "inv":
            ww = 1.0 / (n + eps)
        elif mode == "inv_log":
            ww = 1.0 / np.log1p(n + eps)
        else:
            ww = 1.0 / np.sqrt(n + eps)
        w.append(float(ww))

    w = np.asarray(w, dtype=np.float64)
    w = np.clip(w, 0.0, None)
    w = w / (w.sum() + 1e-12)
    pair2w = {p: float(wi) for p, wi in zip(pairs, w)}
    return w, pair2w


# =========================================================
# 8) DATASET (Iterable, aligned)  ‚úÖ SAFE-FAST (NO PANDAS)
# =========================================================
class _PFCache:
    """–ü—Ä–æ—Å—Ç–æ–π LRU-–∫—ç—à ParquetFile –≤–Ω—É—Ç—Ä–∏ worker (–±–µ–∑–æ–ø–∞—Å–Ω–æ –ø–æ –ø–∞–º—è—Ç–∏)."""
    def __init__(self, max_items=32):
        self.max_items = int(max_items)
        self.cache = OrderedDict()

    def get(self, path: str) -> pq.ParquetFile:
        pf = self.cache.get(path, None)
        if pf is not None:
            self.cache.move_to_end(path)
            return pf
        pf = pq.ParquetFile(path)
        self.cache[path] = pf
        if len(self.cache) > self.max_items:
            self.cache.popitem(last=False)
        return pf


class FRSeqExpressionParquetDatasetAligned(IterableDataset):
    """
    input_ids: [CLS][DRUG][CELL] + gene tokens (vocab-space ids)
    values:    delta (val - baseline[cell or global]) aligned with tokens
    y_topk:    true expression on TOP_K genes (gene-space token_id list)
    """

    def __init__(
        self,
        pair_to_locations,
        pairs,
        token_id_to_vocab_id,
        sorted_gene_token_ids,
        baseline_global,
        baseline_by_cellline,
        cell_line2id,
        drug2id,
        drug_to_smiles_emb,
        pair_weights=None,
        seed=42,
        max_gene_len=512,
        top_k=1000,
        batch_size=16,
        pad_id=0,
        cls_id=1,
        drugtok_id=2,
        celltok_id=4,
        drug_col="drug",
        cell_col="cell_line_id",
        genes_col="genes",
        expr_col="expressions",
        cap_per_pair_in_rg=256,       # ‚úÖ –≤–∞–∂–Ω—ã–π "–ø—Ä–µ–¥–æ—Ö—Ä–∞–Ω–∏—Ç–µ–ª—å"
        max_tries=30,
        shuffle=False,
        pf_cache_size=32,             # ‚úÖ —Ä–∞–∑–º–µ—Ä LRU –∫—ç—à–∞ —Ñ–∞–π–ª–æ–≤
    ):
        super().__init__()
        self.pair_to_locations = pair_to_locations
        self.pairs = list(pairs)

        self.token_id_to_vocab_id = token_id_to_vocab_id
        self.q = np.asarray(sorted_gene_token_ids, dtype=np.int64)

        self.baseline_global = np.asarray(baseline_global, dtype=np.float32)
        self.baseline_by_cellline = baseline_by_cellline or {}

        self.cell_line2id = cell_line2id
        self.drug2id = drug2id
        self.drug_to_smiles_emb = drug_to_smiles_emb

        self.max_gene_len = int(max_gene_len)
        self.top_k = int(top_k)
        self.batch_size = int(batch_size)

        self.pad_id = int(pad_id)
        self.cls_id = int(cls_id)
        self.drugtok_id = int(drugtok_id)
        self.celltok_id = int(celltok_id)

        self.drug_col = drug_col
        self.cell_col = cell_col
        self.genes_col = genes_col
        self.expr_col = expr_col

        self.cap_per_pair_in_rg = int(cap_per_pair_in_rg) if cap_per_pair_in_rg is not None else None
        self.max_tries = int(max_tries)
        self.shuffle = bool(shuffle)

        self.num_prefix = 3  # ‚úÖ [CLS][DRUG][CELL]
        self.seq_len = self.num_prefix + self.max_gene_len

        any_vec = next(iter(self.drug_to_smiles_emb.values()))
        self.smiles_dim = int(any_vec.shape[-1])

        self.seed = int(seed)
        self.pf_cache_size = int(pf_cache_size)

        # weights
        if pair_weights is None:
            self.pair_weights = None
        elif isinstance(pair_weights, dict):
            w = np.asarray([pair_weights.get(p, 0.0) for p in self.pairs], dtype=np.float64)
            w = np.clip(w, 0.0, None)
            w = w / (w.sum() + 1e-12)
            self.pair_weights = w
        else:
            w = np.asarray(pair_weights, dtype=np.float64)
            assert len(w) == len(self.pairs), "pair_weights length must match pairs length"
            w = np.clip(w, 0.0, None)
            w = w / (w.sum() + 1e-12)
            self.pair_weights = w

    @staticmethod
    def _to_numpy_str(chunked_arr):
        # zero_copy_only=False –±–µ–∑–æ–ø–∞—Å–Ω–µ–µ (–∏–Ω–∞—á–µ –∏–Ω–æ–≥–¥–∞ –ø–∞–¥–∞–µ—Ç –Ω–∞ chunked)
        return chunked_arr.combine_chunks().to_numpy(zero_copy_only=False).astype(str)

    def _prepare_sparse(self, genes, expr):
        idx = np.asarray(genes, dtype=np.int64)
        val = np.asarray(expr, dtype=np.float32)
        if idx.size == 0:
            return idx, val

        keep = idx >= MIN_GENE_TOKEN_ID
        idx = idx[keep]
        val = val[keep]
        if idx.size == 0:
            return idx, val

        order = np.argsort(idx)
        return idx[order], val[order]

    def _make_y_true_topk(self, idx_sorted, val_sorted):
        q = self.q
        if idx_sorted.size == 0:
            return np.zeros(q.shape[0], dtype=np.float32)

        pos = np.searchsorted(idx_sorted, q)
        y = np.zeros(q.shape[0], dtype=np.float32)

        in_bounds = (pos < idx_sorted.size)
        match = np.zeros(q.shape[0], dtype=bool)
        match[in_bounds] = (idx_sorted[pos[in_bounds]] == q[in_bounds])
        ok = in_bounds & match
        y[ok] = val_sorted[pos[ok]]
        return y.astype(np.float32)

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        wid = 0 if worker_info is None else int(worker_info.id)
        rng = np.random.default_rng(self.seed + 1337 * wid)

        pairs = self.pairs
        weights = self.pair_weights

        cols = [self.drug_col, self.cell_col, self.genes_col, self.expr_col]

        # ‚úÖ –∫—ç—à ParquetFile –≤–Ω—É—Ç—Ä–∏ worker
        pf_cache = _PFCache(max_items=self.pf_cache_size)

        while True:
            # weighted pair sampling
            if weights is None:
                drug_name, cell_line = pairs[rng.integers(0, len(pairs))]
            else:
                idxp = rng.choice(len(pairs), p=weights)
                drug_name, cell_line = pairs[idxp]

            locs = self.pair_to_locations.get((drug_name, cell_line), [])
            if not locs:
                continue

            # baseline (cell-specific if exists)
            baseline = self.baseline_by_cellline.get(cell_line, self.baseline_global)

            # cell id must exist in pretrained mapping
            if cell_line not in self.cell_line2id:
                continue
            cell_id = self.cell_line2id[cell_line]

            drug_id = self.drug2id.get(drug_name, 0)

            sm = self.drug_to_smiles_emb.get(drug_name, None)
            if sm is None:
                smiles_emb = torch.zeros(self.smiles_dim, dtype=torch.float32)
            else:
                smiles_emb = sm.detach().to(torch.float32) if isinstance(sm, torch.Tensor) \
                    else torch.tensor(sm, dtype=torch.float32)

            for _ in range(self.max_tries):
                fpath, rg_id = locs[rng.integers(0, len(locs))]

                try:
                    pf = pf_cache.get(fpath)
                    table = pf.read_row_group(rg_id, columns=cols)
                except Exception:
                    continue

                # --- –§–∏–ª—å—Ç—Ä –ø–æ (drug, cell) –±–µ–∑ pandas ---
                try:
                    drug_arr = self._to_numpy_str(table[self.drug_col])
                    cell_arr = self._to_numpy_str(table[self.cell_col])
                except Exception:
                    continue

                mask_pair = (drug_arr == str(drug_name)) & (cell_arr == str(cell_line))
                idxs = np.where(mask_pair)[0]
                if idxs.size < self.batch_size:
                    continue

                # ‚úÖ cap –¥–ª—è –∫–æ–Ω—Ç—Ä–æ–ª—è CPU (–Ω–µ –¥–∞—ë–º –æ–≥—Ä–æ–º–Ω—ã–º row_group –≥—Ä—É–∑–∏—Ç—å –Ω–∞—Å)
                if self.cap_per_pair_in_rg is not None and idxs.size > self.cap_per_pair_in_rg:
                    idxs = rng.choice(idxs, size=self.cap_per_pair_in_rg, replace=False)

                if idxs.size < self.batch_size:
                    continue

                choose = rng.choice(idxs, size=self.batch_size, replace=False)

                # –∫–æ–ª–æ–Ω–∫–∏ genes/expr (ChunkedArray)
                genes_col = table[self.genes_col].combine_chunks()
                exprs_col = table[self.expr_col].combine_chunks()

                # allocate batch arrays
                input_ids = np.full((self.batch_size, self.seq_len), self.pad_id, dtype=np.int64)
                values    = np.zeros((self.batch_size, self.seq_len), dtype=np.float32)
                mask      = np.zeros((self.batch_size, self.seq_len), dtype=np.int64)
                y_topk    = np.zeros((self.batch_size, self.top_k), dtype=np.float32)

                cell_batch   = np.full((self.batch_size,), cell_id, dtype=np.int64)
                drug_batch   = np.full((self.batch_size,), drug_id, dtype=np.int64)

                # smiles -> numpy (–æ–¥–∏–Ω —Ä–∞–∑)
                sm_np = smiles_emb.detach().cpu().numpy() if isinstance(smiles_emb, torch.Tensor) else np.asarray(smiles_emb, dtype=np.float32)
                smiles_batch = np.repeat(sm_np[None, :], repeats=self.batch_size, axis=0).astype(np.float32)

                # prefix: [CLS][DRUG][CELL]
                input_ids[:, 0] = self.cls_id
                input_ids[:, 1] = self.drugtok_id
                input_ids[:, 2] = self.celltok_id
                mask[:, :self.num_prefix] = 1

                # build each sample
                for b, j in enumerate(choose):
                    try:
                        genes = genes_col[int(j)].as_py()
                        expr  = exprs_col[int(j)].as_py()
                    except Exception:
                        continue

                    idx, val = self._prepare_sparse(genes, expr)

                    # y_true on TOP_K
                    y_topk[b] = self._make_y_true_topk(idx, val)

                    if idx.size == 0:
                        continue

                    # delta
                    base_vals = baseline[idx]
                    delta = (val - base_vals)

                    # choose top genes by |delta|
                    k = min(self.max_gene_len, idx.size)
                    if k <= 0:
                        continue

                    if k == idx.size:
                        top_pos = np.argsort(-np.abs(delta))
                    else:
                        top_pos = np.argpartition(-np.abs(delta), k - 1)[:k]
                        top_pos = top_pos[np.argsort(-np.abs(delta[top_pos]))]

                    sel_gene_token_ids = idx[top_pos]
                    sel_delta = delta[top_pos]

                    # token_id -> vocab_id (drop missing)
                    sel_vocab_ids = np.asarray(
                        [self.token_id_to_vocab_id.get(int(t), -1) for t in sel_gene_token_ids],
                        dtype=np.int64
                    )
                    ok = sel_vocab_ids != -1
                    sel_vocab_ids = sel_vocab_ids[ok]
                    sel_delta = sel_delta[ok]

                    L = min(self.max_gene_len, sel_vocab_ids.size)
                    if L <= 0:
                        continue

                    start = self.num_prefix
                    input_ids[b, start:start+L] = sel_vocab_ids[:L]
                    values[b,    start:start+L] = sel_delta[:L]
                    mask[b,      start:start+L] = 1

                yield (
                    torch.tensor(input_ids, dtype=torch.long),
                    torch.tensor(values, dtype=torch.float32),
                    torch.tensor(mask, dtype=torch.long),
                    torch.tensor(y_topk, dtype=torch.float32),
                    torch.tensor(cell_batch, dtype=torch.long),
                    torch.tensor(drug_batch, dtype=torch.long),
                    torch.tensor(smiles_batch, dtype=torch.float32),
                )
                break


# =========================================================
# 9) DataLoaders (SAFE)
# =========================================================
train_w, _ = make_pair_weights_from_counts(counts, train_pairs, mode="inv_sqrt")

train_ds = FRSeqExpressionParquetDatasetAligned(
    pair_to_locations=pair_to_locations,
    pairs=train_pairs,
    pair_weights=train_w,
    token_id_to_vocab_id=token_id_to_vocab_id,
    sorted_gene_token_ids=sorted_gene_token_ids,
    baseline_global=baseline_global,
    baseline_by_cellline=baseline_by_cl,
    cell_line2id=cell_line2id,
    drug2id=drug2id,
    drug_to_smiles_emb=drug_to_smiles_emb,
    batch_size=BATCH_SIZE,
    max_gene_len=MAX_LEN,
    top_k=TOP_K,
    shuffle=False,
    cap_per_pair_in_rg=256,     # ‚úÖ safety knob
    pf_cache_size=32,
)

val_ds = FRSeqExpressionParquetDatasetAligned(
    pair_to_locations=pair_to_locations,
    pairs=val_pairs,
    pair_weights=None,
    token_id_to_vocab_id=token_id_to_vocab_id,
    sorted_gene_token_ids=sorted_gene_token_ids,
    baseline_global=baseline_global,
    baseline_by_cellline=baseline_by_cl,
    cell_line2id=cell_line2id,
    drug2id=drug2id,
    drug_to_smiles_emb=drug_to_smiles_emb,
    batch_size=BATCH_SIZE,
    max_gene_len=MAX_LEN,
    top_k=TOP_K,
    shuffle=False,
    cap_per_pair_in_rg=256,
    pf_cache_size=32,
)

# ‚úÖ –±–µ–∑–æ–ø–∞—Å–Ω—ã–µ –ø–∞—Ä–∞–º–µ—Ç—Ä—ã: –Ω–µ–º–Ω–æ–≥–æ –≤–æ—Ä–∫–µ—Ä–æ–≤ + –º–∞–ª–µ–Ω—å–∫–∏–π prefetch
NUM_WORKERS = 2
PREFETCH = 1

train_loader = DataLoader(
    train_ds,
    batch_size=None,
    num_workers=NUM_WORKERS,
    prefetch_factor=PREFETCH,
    persistent_workers=True,
    pin_memory=True,
)

val_loader = DataLoader(
    val_ds,
    batch_size=None,
    num_workers=1,
    prefetch_factor=1,
    persistent_workers=True,
    pin_memory=True,
)


# =========================================================
# 10) MODEL (Cell2Sentence-like Encoder for f_r)
# =========================================================
class Cell2SentenceEncoderFR(nn.Module):
    """
    prefix: [CLS][DRUG][CELL] + gene tokens
    - token_emb: vocab-space ids
    - values: delta
    - inject smiles into position 1 ([DRUG])
    - inject cell_line embedding into position 2 ([CELL]) ‚úÖ pretrained
    """
    def __init__(self, vocab_size, d_model, n_heads, num_layers, max_len_with_prefix, smiles_dim, num_cell_lines, dropout=0.1):
        super().__init__()
        self.d_model = d_model

        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=PAD_ID)
        self.value_proj = nn.Sequential(
            nn.Linear(1, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )
        self.pos_emb = nn.Embedding(max_len_with_prefix, d_model)

        self.cell_line_emb = nn.Embedding(num_cell_lines, d_model)
        self.smiles_proj = nn.Linear(smiles_dim, d_model)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads,
            dim_feedforward=4*d_model,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)

    def forward(self, input_ids, values, attention_mask, cell_line_id, smiles_emb):
        B, L = input_ids.shape
        device_ = input_ids.device

        x = self.token_emb(input_ids) + self.value_proj(values.unsqueeze(-1))

        pos = torch.arange(L, device=device_).unsqueeze(0).expand(B, L)
        x = x + self.pos_emb(pos)

        # inject drug / cell info
        x[:, 1, :] = x[:, 1, :] + self.smiles_proj(smiles_emb.to(device=device_, dtype=torch.float32)).to(x.dtype)
        x[:, 2, :] = x[:, 2, :] + self.cell_line_emb(cell_line_id.to(device=device_)).to(x.dtype)

        key_padding_mask = (attention_mask == 0)
        h = self.encoder(x, src_key_padding_mask=key_padding_mask)
        return h[:, 0, :]  # CLS


class FRModelExpression(nn.Module):
    def __init__(self, encoder, d_model, out_dim):
        super().__init__()
        self.encoder = encoder
        self.head = nn.Linear(d_model, out_dim)

    def forward(self, input_ids, values, mask, cell_line_id, smiles_emb):
        h = self.encoder(input_ids, values, mask, cell_line_id, smiles_emb)
        return self.head(h)


# =========================================================
# 11) Load pretrained gene + cell embeddings
# =========================================================
def load_pretrained_token_emb_from_gene_metadata(token_emb: nn.Embedding, npy_path: str, gene_meta_path: str, local_token_to_id: dict, device):
    W = np.load(npy_path)  # (N_genes, d_model)
    Wt = torch.tensor(W, dtype=torch.float32, device=device)

    gene_md = pd.read_parquet(gene_meta_path).copy()
    gene_md["ensembl_id"] = gene_md["ensembl_id"].astype(str)
    gene_md["token_id"] = gene_md["token_id"].astype(int)
    gene_md = gene_md.sort_values("token_id").reset_index(drop=True)

    if Wt.shape[1] != token_emb.weight.shape[1]:
        raise ValueError(f"d mismatch: npy d={Wt.shape[1]} vs token_emb d={token_emb.weight.shape[1]}")

    n = min(len(gene_md), Wt.shape[0])
    loaded = 0
    with torch.no_grad():
        for i in range(n):
            ensg = gene_md.loc[i, "ensembl_id"]
            vid = local_token_to_id.get(ensg, None)
            if vid is None:
                continue
            token_emb.weight[vid].copy_(Wt[i])
            loaded += 1
    print(f"‚úÖ Loaded pretrained gene token_emb: {loaded} genes")


def load_pretrained_cell_emb(cell_emb: nn.Embedding, cell_emb_npy: str, device):
    W = np.load(cell_emb_npy)  # (num_cell_lines, d_model)
    Wt = torch.tensor(W, dtype=torch.float32, device=device)

    if Wt.shape != cell_emb.weight.shape:
        raise ValueError(f"cell_emb shape mismatch: npy={tuple(Wt.shape)} vs emb={tuple(cell_emb.weight.shape)}")

    with torch.no_grad():
        cell_emb.weight.copy_(Wt)
    print(f"‚úÖ Loaded pretrained cell_line_emb: {tuple(Wt.shape)}")


def sanity_check_gene_emb_mapping(
    gene_meta_path,
    local_token_to_id,
    token_emb: torch.nn.Embedding,
    pretrained_gene_npy,
    n_check=20,
    seed=0,
):
    gene_md = pd.read_parquet(gene_meta_path).copy()
    gene_md["ensembl_id"] = gene_md["ensembl_id"].astype(str)
    gene_md["token_id"]   = gene_md["token_id"].astype(int)
    gene_md = gene_md.sort_values("token_id").reset_index(drop=True)

    W = np.load(pretrained_gene_npy)  # (N_genes, d_model)
    assert W.shape[1] == token_emb.weight.shape[1]

    rng = np.random.default_rng(seed)
    idxs = rng.integers(0, min(len(gene_md), W.shape[0]), size=n_check)

    max_abs = 0.0
    bad = 0

    with torch.no_grad():
        for i in idxs:
            ensg = gene_md.loc[i, "ensembl_id"]
            vid = local_token_to_id.get(ensg, None)
            if vid is None:
                continue

            a = token_emb.weight[vid].detach().cpu().numpy()
            b = W[i]

            diff = np.max(np.abs(a - b))
            max_abs = max(max_abs, float(diff))
            if diff > 1e-6:
                bad += 1
                print("Mismatch:", "i=", i, "ensg=", ensg, "vid=", vid, "max_abs_diff=", diff)

    print(f"[sanity] checked={n_check}, bad={bad}, max_abs_diff={max_abs}")


# =========================================================
# 12) Init model
# =========================================================
D_MODEL = 256
assert W_cell.shape[1] == D_MODEL, f"cell_emb dim {W_cell.shape[1]} != D_MODEL {D_MODEL}"

encoder = Cell2SentenceEncoderFR(
    vocab_size=VOCAB_SIZE,
    d_model=D_MODEL,
    n_heads=8,
    num_layers=4,
    max_len_with_prefix=(3 + MAX_LEN),
    smiles_dim=smiles_dim,
    num_cell_lines=NUM_CELL_LINE,
    dropout=0.1,
).to(device)

load_pretrained_token_emb_from_gene_metadata(
    token_emb=encoder.token_emb,
    npy_path=PRETRAINED_GENE_NPY,
    gene_meta_path=GENE_META_PATH,
    local_token_to_id=local_token_to_id,
    device=device,
)

load_pretrained_cell_emb(
    cell_emb=encoder.cell_line_emb,
    cell_emb_npy=CELL_EMB_NPY,
    device=device
)

fr_model = FRModelExpression(encoder=encoder, d_model=D_MODEL, out_dim=TOP_K).to(device)
optimizer = torch.optim.AdamW(fr_model.parameters(), lr=LR, weight_decay=0.01)
scaler = GradScaler(enabled=(device.type == "cuda"))

print("‚úÖ f_r model ready")
sanity_check_gene_emb_mapping(GENE_META_PATH, local_token_to_id, encoder.token_emb, PRETRAINED_GENE_NPY)


# =========================================================
# 13) Losses (MSE + ranking)
# =========================================================
mse_loss = nn.MSELoss()
baseline_vec = torch.tensor(baseline_global[sorted_gene_token_ids], dtype=torch.float32, device=device)  # (TOP_K,)

def expr_ranking_loss(y_pred, y_true, baseline_vec, top_pos=30, num_neg=80, margin=0.0):
    device_ = y_pred.device
    B, K = y_pred.shape

    base = baseline_vec.view(1, K).expand(B, K).to(device=device_, dtype=y_pred.dtype)
    dt = y_true - base
    dp = y_pred - base

    losses = []
    for b in range(B):
        order = torch.argsort(dt[b].abs(), descending=True)
        P = min(top_pos, K)
        pos_idx = order[:P]
        neg_candidates = order[P:]
        if neg_candidates.numel() == 0:
            continue

        if neg_candidates.numel() > num_neg:
            neg_idx = neg_candidates[torch.randperm(neg_candidates.numel(), device=device_)[:num_neg]]
        else:
            neg_idx = neg_candidates

        pos_scores = dp[b, pos_idx]
        neg_scores = dp[b, neg_idx]

        diff = pos_scores.view(-1, 1) - neg_scores.view(1, -1)
        loss_mat = F.relu(margin - diff)
        losses.append(loss_mat.mean())

    if len(losses) == 0:
        return torch.tensor(0.0, device=device_, dtype=y_pred.dtype)
    return torch.stack(losses).mean()


# =========================================================
# 14) Checkpoint utils
# =========================================================
def save_fr_checkpoint(save_dir, fr_model, optimizer, scaler, epoch, metrics=None, extra=None, prefix="fr"):
    os.makedirs(save_dir, exist_ok=True)
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")

    ckpt = {
        "epoch": int(epoch),
        "model_state": fr_model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scaler_state": scaler.state_dict() if scaler is not None else None,
        "metrics": metrics or {},
        "extra": extra or {},
    }

    path = os.path.join(save_dir, f"{prefix}_epoch{epoch}_{ts}.pt")
    torch.save(ckpt, path)
    print(f"üíæ saved checkpoint: {path}")
    return path


# =========================================================
# 15) Eval helpers
# =========================================================
@torch.no_grad()
def eval_mse(fr_model, val_loader, steps, device):
    fr_model.eval()
    total = 0.0
    n = 0
    for batch in islice(val_loader, steps):
        input_ids, values, mask, y_true, cell_id, drug_id, smiles = batch
        input_ids = input_ids.to(device, non_blocking=True)
        values    = values.to(device, non_blocking=True)
        mask      = mask.to(device, non_blocking=True)
        y_true    = y_true.to(device, non_blocking=True)
        cell_id   = cell_id.to(device, non_blocking=True)
        smiles    = smiles.to(device, non_blocking=True)

        with autocast(device_type="cuda", enabled=(device.type == "cuda")):
            y_pred = fr_model(input_ids, values, mask, cell_id, smiles)
            loss = mse_loss(y_pred, y_true)

        bs = y_true.size(0)
        total += loss.item() * bs
        n += bs
    return total / max(1, n)

@torch.no_grad()
def baseline_mse(val_loader, steps, baseline_vec, device):
    total = 0.0
    n = 0
    baseline_vec = baseline_vec.to(device)
    for batch in islice(val_loader, steps):
        _, _, _, y_true, _, _, _ = batch
        y_true = y_true.to(device, non_blocking=True)
        bs = y_true.size(0)
        pred = baseline_vec.view(1, -1).expand(bs, -1)
        loss = F.mse_loss(pred, y_true)
        total += loss.item() * bs
        n += bs
    return total / max(1, n)


# =========================================================
# 16) TRAIN
# =========================================================
base_mse = baseline_mse(val_loader, steps=VAL_STEPS, baseline_vec=baseline_vec, device=device)
print(f"Baseline Valid MSE (DMSO) = {base_mse:.6f}")

print("üöÄ f_r training start")

for epoch in range(1, TOTAL_EPOCHS + 1):
    lambda_rank = 0.0 if epoch <= WARMUP_EPOCHS else lambda_rank_main

    fr_model.train()
    run_mse = 0.0
    run_rank = 0.0
    run_total = 0.0
    n = 0

    pbar = tqdm(
        islice(train_loader, STEPS_PER_EPOCH),
        total=STEPS_PER_EPOCH,
        desc=f"[Epoch {epoch}] Train",
        leave=True,
        dynamic_ncols=True
    )

    for batch in pbar:
        input_ids, values, mask, y_true, cell_id, drug_id, smiles = batch

        input_ids = input_ids.to(device, non_blocking=True)
        values    = values.to(device, non_blocking=True)
        mask      = mask.to(device, non_blocking=True)
        y_true    = y_true.to(device, non_blocking=True)
        cell_id   = cell_id.to(device, non_blocking=True)
        smiles    = smiles.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with autocast(device_type="cuda", enabled=(device.type == "cuda")):
            y_pred = fr_model(input_ids, values, mask, cell_id, smiles)
            loss_m = mse_loss(y_pred, y_true)

            if lambda_rank > 0:
                loss_r = expr_ranking_loss(
                    y_pred, y_true, baseline_vec,
                    top_pos=30, num_neg=80, margin=0.0
                )
            else:
                loss_r = torch.tensor(0.0, device=device)

            loss = loss_m + lambda_rank * loss_r

        if not torch.isfinite(loss):
            continue

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(fr_model.parameters(), GRAD_CLIP)
        scaler.step(optimizer)
        scaler.update()

        bs = y_true.size(0)
        run_mse   += loss_m.item() * bs
        run_rank  += loss_r.item() * bs
        run_total += loss.item() * bs
        n += bs

        pbar.set_postfix({
            "mse": f"{loss_m.item():.4f}",
            "rank": f"{loss_r.item():.4f}",
            "Œª_rank": float(lambda_rank),
        })

    train_mse   = run_mse   / max(1, n)
    train_rank  = run_rank  / max(1, n)
    train_total = run_total / max(1, n)

    val_mse = eval_mse(fr_model, val_loader, steps=VAL_STEPS, device=device)

    print(
        f"[Epoch {epoch}] "
        f"Train total={train_total:.6f}, mse={train_mse:.6f}, rank={train_rank:.6f} (Œª_rank={lambda_rank}) | "
        f"Valid mse={val_mse:.6f} | Baseline(DMSO) mse={base_mse:.6f}"
    )

    if (epoch % SAVE_EVERY == 0) or (epoch == TOTAL_EPOCHS):
        save_fr_checkpoint(
            save_dir=CKPT_DIR,
            fr_model=fr_model,
            optimizer=optimizer,
            scaler=scaler,
            epoch=epoch,
            metrics={
                "train_total": float(train_total),
                "train_mse": float(train_mse),
                "train_rank": float(train_rank),
                "val_mse": float(val_mse),
                "baseline_mse": float(base_mse),
                "lambda_rank": float(lambda_rank),
            },
            extra={
                "TOP_K": int(baseline_vec.numel()),
                "STEPS_PER_EPOCH": int(STEPS_PER_EPOCH),
                "VAL_STEPS": int(VAL_STEPS),
                "WARMUP_EPOCHS": int(WARMUP_EPOCHS),
                "lambda_rank_main": float(lambda_rank_main),
                "baseline_vec": baseline_vec.detach().float().cpu(),
                "CELL2ID_CSV": CELL2ID_CSV,
                "CELL_EMB_NPY": CELL_EMB_NPY,
                "sorted_gene_token_ids": sorted_gene_token_ids.astype(np.int64)
            },
            prefix="fr",
        )

print("‚úÖ DONE")

gene_md = pd.read_parquet(GENE_META_PATH)[["token_id","ensembl_id"]].copy()
tid2ensg = dict(zip(gene_md["token_id"].astype(int), gene_md["ensembl_id"].astype(str)))
topk_ensg = np.array([tid2ensg[int(t)] for t in sorted_gene_token_ids], dtype=object)

np.save(os.path.join(CKPT_DIR, f"topk_ensg_k{TOP_K}.npy"), topk_ensg)


VOCAB_SIZE(vocab-space): 62716
N_GENES(gene-space): 62713
NUM_CELL_LINE(from cell2id.csv): 50 | cell_emb rows: 50
train pairs: 15118
val pairs: 1680
eval pairs (>=1000): 16798
parquet files found: 3388


Index parquet row-groups: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3388/3388 [07:54<00:00,  7.14it/s]


indexed pairs: 16798
baseline_global: (62713,) baseline_by_cl: 50
sorted_gene_token_ids: (1000,) [39721 21437 21401 37295  4423  3916   455 17902  6378  4185]
num drugs: 379
smiles_dim: 768
‚úÖ Loaded pretrained gene token_emb: 62710 genes
‚úÖ Loaded pretrained cell_line_emb: (50, 256)
‚úÖ f_r model ready
[sanity] checked=20, bad=0, max_abs_diff=0.0


  scaler = GradScaler(enabled=(device.type == "cuda"))


Baseline Valid MSE (DMSO) = 20.909212
üöÄ f_r training start


[Epoch 1] Train: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 10000/10000 [2:22:11<00:00,  1.17it/s, mse=6.0742, rank=0.0000, Œª_rank=0]   


[Epoch 1] Train total=11.203143, mse=11.203143, rank=0.000000 (Œª_rank=0.0) | Valid mse=7.016501 | Baseline(DMSO) mse=20.909212


[Epoch 2] Train: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 10000/10000 [2:14:59<00:00,  1.23it/s, mse=4.0167, rank=0.0000, Œª_rank=0]   


[Epoch 2] Train total=4.970266, mse=4.970266, rank=0.000000 (Œª_rank=0.0) | Valid mse=4.112769 | Baseline(DMSO) mse=20.909212
üíæ saved checkpoint: /data/aiffel/babayakga/checkpoints/f_r_withcellline/fr_epoch2_20251224_234816.pt


[Epoch 3] Train: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 10000/10000 [2:11:19<00:00,  1.27it/s, mse=3.5937, rank=2.4922, Œª_rank=0.2]  


[Epoch 3] Train total=4.101466, mse=3.625382, rank=2.380414 (Œª_rank=0.2) | Valid mse=3.341106 | Baseline(DMSO) mse=20.909212


[Epoch 4] Train: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 10000/10000 [2:05:18<00:00,  1.33it/s, mse=3.9715, rank=2.3105, Œª_rank=0.2]  


[Epoch 4] Train total=3.505763, mse=3.028709, rank=2.385272 (Œª_rank=0.2) | Valid mse=2.866567 | Baseline(DMSO) mse=20.909212
üíæ saved checkpoint: /data/aiffel/babayakga/checkpoints/f_r_withcellline/fr_epoch4_20251225_042620.pt


[Epoch 5] Train: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 10000/10000 [2:04:22<00:00,  1.34it/s, mse=4.6376, rank=2.0918, Œª_rank=0.2]  


[Epoch 5] Train total=3.175252, mse=2.697655, rank=2.387986 (Œª_rank=0.2) | Valid mse=2.604213 | Baseline(DMSO) mse=20.909212


[Epoch 6] Train: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 10000/10000 [1:58:09<00:00,  1.41it/s, mse=2.4020, rank=2.6074, Œª_rank=0.2]  


[Epoch 6] Train total=2.787942, mse=2.308350, rank=2.397962 (Œª_rank=0.2) | Valid mse=2.180845 | Baseline(DMSO) mse=20.909212
üíæ saved checkpoint: /data/aiffel/babayakga/checkpoints/f_r_withcellline/fr_epoch6_20251225_085916.pt


[Epoch 7] Train: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 10000/10000 [2:33:13<00:00,  1.09it/s, mse=2.3121, rank=2.6582, Œª_rank=0.2]  


[Epoch 7] Train total=2.566124, mse=2.085309, rank=2.404073 (Œª_rank=0.2) | Valid mse=1.992714 | Baseline(DMSO) mse=20.909212


[Epoch 8] Train: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 10000/10000 [1:35:49<00:00,  1.74it/s, mse=2.1831, rank=2.8555, Œª_rank=0.2]  


[Epoch 8] Train total=2.440751, mse=1.959138, rank=2.408071 (Œª_rank=0.2) | Valid mse=2.055772 | Baseline(DMSO) mse=20.909212
üíæ saved checkpoint: /data/aiffel/babayakga/checkpoints/f_r_withcellline/fr_epoch8_20251225_132613.pt
‚úÖ DONE


In [3]:
import os, math
from dataclasses import dataclass
from typing import Dict, Iterable, Optional, Tuple, List, Any

import torch
import torch.nn.functional as F
from torch.amp import autocast

try:
    import pandas as pd
except Exception:
    pd = None


# -----------------------------
# Correlations
# -----------------------------
def pearson_corr(x: torch.Tensor, y: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    x = x - x.mean(dim=1, keepdim=True)
    y = y - y.mean(dim=1, keepdim=True)
    num = (x * y).sum(dim=1)
    den = torch.sqrt((x * x).sum(dim=1).clamp_min(eps)) * torch.sqrt((y * y).sum(dim=1).clamp_min(eps))
    return num / den.clamp_min(eps)

def _rankdata(x: torch.Tensor) -> torch.Tensor:
    order = torch.argsort(x, dim=1, descending=False)
    ranks = torch.empty_like(order, dtype=torch.float32)
    idx = torch.arange(x.size(1), device=x.device).view(1, -1).expand_as(order)
    ranks.scatter_(1, order, idx.to(torch.float32))
    return ranks

def spearman_corr(x: torch.Tensor, y: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    rx = _rankdata(x)
    ry = _rankdata(y)
    return pearson_corr(rx, ry, eps=eps)


# -----------------------------
# Top-|Œî| ranking metrics
# -----------------------------
def topk_precision_recall_ndcg(
    pred_scores: torch.Tensor,   # (B, K), e.g. |d_pred|
    true_scores: torch.Tensor,   # (B, K), e.g. |d_true|
    k: int,
    p_pos: int,
    eps: float = 1e-8,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    B, Kdim = pred_scores.shape
    k = min(k, Kdim)
    p_pos = min(p_pos, Kdim)

    gt_pos_idx = torch.topk(true_scores, k=p_pos, dim=1, largest=True).indices  # (B, p_pos)
    pred_topk_idx = torch.topk(pred_scores, k=k, dim=1, largest=True).indices  # (B, k)

    gt_mask = torch.zeros((B, Kdim), device=pred_scores.device, dtype=torch.bool)
    gt_mask.scatter_(1, gt_pos_idx, True)

    hits = gt_mask.gather(1, pred_topk_idx)  # (B, k)
    hit_count = hits.sum(dim=1).to(torch.float32)

    precision = hit_count / float(k)
    recall = hit_count / float(p_pos)

    pos = torch.arange(k, device=pred_scores.device, dtype=torch.float32)
    denom = torch.log2(pos + 2.0)
    dcg = (hits.to(torch.float32) / denom.view(1, -1)).sum(dim=1)

    ideal_ones = min(p_pos, k)
    idcg = (torch.ones((ideal_ones,), device=pred_scores.device, dtype=torch.float32) /
            torch.log2(torch.arange(ideal_ones, device=pred_scores.device, dtype=torch.float32) + 2.0)).sum()

    ndcg = dcg / idcg.clamp_min(eps)
    return precision, recall, ndcg

def sign_accuracy_on_top_pos(d_pred: torch.Tensor, d_true: torch.Tensor, p_pos: int = 30, eps: float = 1e-8) -> torch.Tensor:
    B, Kdim = d_true.shape
    p_pos = min(p_pos, Kdim)
    idx = torch.topk(d_true.abs(), k=p_pos, dim=1, largest=True).indices  # (B, p_pos)

    tp = d_true.gather(1, idx)
    pp = d_pred.gather(1, idx)

    valid = tp.abs() > eps
    match = (torch.sign(tp) == torch.sign(pp)) & valid
    denom = valid.sum(dim=1).clamp_min(1)
    return match.sum(dim=1).to(torch.float32) / denom.to(torch.float32)


# -----------------------------
# Config
# -----------------------------
@dataclass
class FREvalConfig:
    steps: Optional[int] = None
    amp: bool = True
    top_pos: int = 30
    eval_ks: Tuple[int, ...] = (10, 30, 50, 100)
    per_stratum: bool = True
    max_groups_report: int = 30
    min_group_size: int = 50


# -----------------------------
# Accumulator
# -----------------------------
class _MetricAccum:
    def __init__(self, eval_ks: Tuple[int, ...], top_pos: int):
        self.eval_ks = eval_ks
        self.top_pos = top_pos
        self.n = 0
        self.sum_mse = 0.0
        self.sum_mae = 0.0
        self.sum_cos = 0.0
        self.sum_pear = 0.0
        self.sum_spear = 0.0
        self.sum_sign = 0.0
        self.sum_prec = {k: 0.0 for k in eval_ks}
        self.sum_rec  = {k: 0.0 for k in eval_ks}
        self.sum_ndcg = {k: 0.0 for k in eval_ks}

    def add_batch(self, y_pred: torch.Tensor, y_true: torch.Tensor, base: torch.Tensor):
        B = y_true.size(0)

        mse_each = F.mse_loss(y_pred, y_true, reduction="none").mean(dim=1)
        mae_each = (y_pred - y_true).abs().mean(dim=1)

        d_true = y_true - base
        d_pred = y_pred - base

        cos_each  = F.cosine_similarity(d_pred, d_true, dim=1)
        pear_each = pearson_corr(d_pred, d_true)
        spear_each = spearman_corr(d_pred, d_true)

        abs_dt = d_true.abs()
        abs_dp = d_pred.abs()

        sign_each = sign_accuracy_on_top_pos(d_pred, d_true, p_pos=self.top_pos)

        self.n += B
        self.sum_mse += float(mse_each.sum().item())
        self.sum_mae += float(mae_each.sum().item())
        self.sum_cos += float(cos_each.sum().item())
        self.sum_pear += float(pear_each.sum().item())
        self.sum_spear += float(spear_each.sum().item())
        self.sum_sign += float(sign_each.sum().item())

        for k in self.eval_ks:
            p, r, nd = topk_precision_recall_ndcg(abs_dp, abs_dt, k=k, p_pos=self.top_pos)
            self.sum_prec[k] += float(p.sum().item())
            self.sum_rec[k]  += float(r.sum().item())
            self.sum_ndcg[k] += float(nd.sum().item())

    def to_dict(self) -> Dict[str, float]:
        n = max(1, self.n)
        out = {
            "n_samples": int(self.n),
            "mse": self.sum_mse / n,
            "rmse": math.sqrt(self.sum_mse / n),
            "mae": self.sum_mae / n,
            "cosine_d": self.sum_cos / n,
            "pearson_d": self.sum_pear / n,
            "spearman_d": self.sum_spear / n,
            f"signacc_top{self.top_pos}": self.sum_sign / n,
        }
        for k in self.eval_ks:
            out[f"precision@{k}"] = self.sum_prec[k] / n
            out[f"recall@{k}"]    = self.sum_rec[k] / n
            out[f"ndcg@{k}"]      = self.sum_ndcg[k] / n
        return out


def _pretty_line(name: str, m: Dict[str, float], cfg: FREvalConfig) -> str:
    parts = [
        f"{name:>18}",
        f"n={m['n_samples']}",
        f"RMSE={m['rmse']:.4f}",
        f"MAE={m['mae']:.4f}",
        f"Cos(d)={m['cosine_d']:.4f}",
        f"Pear(d)={m['pearson_d']:.4f}",
        f"Spear(d)={m['spearman_d']:.4f}",
        f"Sign@{cfg.top_pos}={m[f'signacc_top{cfg.top_pos}']:.4f}",
    ]
    for k in cfg.eval_ks:
        parts += [f"R@{k}={m[f'recall@{k}']:.4f}", f"NDCG@{k}={m[f'ndcg@{k}']:.4f}"]
    return " | ".join(parts)


def _group_report(group_name: str, group_dict: Dict[int, _MetricAccum], cfg: FREvalConfig, sort_by: str = "n_samples"):
    rows = []
    for gid, acc in group_dict.items():
        d = acc.to_dict()
        d[group_name] = int(gid)
        rows.append(d)

    rows_print = [r for r in rows if r["n_samples"] >= cfg.min_group_size]
    if len(rows_print) > 0 and sort_by in rows_print[0]:
        rows_print.sort(key=lambda r: r[sort_by], reverse=True)
    else:
        rows_print.sort(key=lambda r: r["n_samples"], reverse=True)

    return rows_print[: cfg.max_groups_report], rows


@torch.no_grad()
def eval_fr_with_strata(fr_model, loader: Iterable, baseline_vec: torch.Tensor, device: torch.device, cfg: FREvalConfig):
    fr_model.eval()
    baseline_vec = baseline_vec.to(device=device)

    global_acc = _MetricAccum(cfg.eval_ks, cfg.top_pos)
    cell_acc: Dict[int, _MetricAccum] = {}
    drug_acc: Dict[int, _MetricAccum] = {}

    for step, batch in enumerate(loader):
        if cfg.steps is not None and step >= cfg.steps:
            break

        input_ids, values, mask, y_true, cell_id, drug_id, smiles = batch

        input_ids = input_ids.to(device, non_blocking=True)
        values    = values.to(device, non_blocking=True)
        mask      = mask.to(device, non_blocking=True)
        y_true    = y_true.to(device, non_blocking=True)
        cell_id   = cell_id.to(device, non_blocking=True)
        drug_id   = drug_id.to(device, non_blocking=True)
        smiles    = smiles.to(device, non_blocking=True)

        B, Kdim = y_true.shape
        base = baseline_vec.view(1, Kdim).expand(B, Kdim)

        with autocast(device_type=device.type, enabled=(cfg.amp and device.type == "cuda")):
            # ‚úÖ –Ω–æ–≤–∞—è –º–æ–¥–µ–ª—å –∂–¥—ë—Ç cell_line_id –∏ smiles_emb -> —Å–æ–≤–ø–∞–¥–∞–µ—Ç
            y_pred = fr_model(input_ids, values, mask, cell_id, smiles)

        global_acc.add_batch(y_pred, y_true, base)

        if cfg.per_stratum:
            # per cell
            for cid in torch.unique(cell_id).tolist():
                cid = int(cid)
                m = (cell_id == cid)
                if cid not in cell_acc:
                    cell_acc[cid] = _MetricAccum(cfg.eval_ks, cfg.top_pos)
                cell_acc[cid].add_batch(y_pred[m], y_true[m], base[m])

            # per drug
            for did in torch.unique(drug_id).tolist():
                did = int(did)
                m = (drug_id == did)
                if did not in drug_acc:
                    drug_acc[did] = _MetricAccum(cfg.eval_ks, cfg.top_pos)
                drug_acc[did].add_batch(y_pred[m], y_true[m], base[m])

    result = {"global": global_acc.to_dict(), "per_cell": None, "per_drug": None}
    if cfg.per_stratum:
        cell_top, cell_all = _group_report("cell_id", cell_acc, cfg, sort_by="n_samples")
        drug_top, drug_all = _group_report("drug_id", drug_acc, cfg, sort_by="n_samples")
        result["per_cell"] = {"top_report": cell_top, "all_rows": cell_all}
        result["per_drug"] = {"top_report": drug_top, "all_rows": drug_all}
    return result


@torch.no_grad()
def eval_baseline_dmso_with_strata(loader: Iterable, baseline_vec: torch.Tensor, device: torch.device, cfg: FREvalConfig):
    baseline_vec = baseline_vec.to(device=device)

    global_acc = _MetricAccum(cfg.eval_ks, cfg.top_pos)
    cell_acc: Dict[int, _MetricAccum] = {}
    drug_acc: Dict[int, _MetricAccum] = {}

    for step, batch in enumerate(loader):
        if cfg.steps is not None and step >= cfg.steps:
            break

        _, _, _, y_true, cell_id, drug_id, _ = batch
        y_true  = y_true.to(device, non_blocking=True)
        cell_id = cell_id.to(device, non_blocking=True)
        drug_id = drug_id.to(device, non_blocking=True)

        B, Kdim = y_true.shape
        base = baseline_vec.view(1, Kdim).expand(B, Kdim)
        y_pred = base

        global_acc.add_batch(y_pred, y_true, base)

        if cfg.per_stratum:
            for cid in torch.unique(cell_id).tolist():
                cid = int(cid)
                m = (cell_id == cid)
                if cid not in cell_acc:
                    cell_acc[cid] = _MetricAccum(cfg.eval_ks, cfg.top_pos)
                cell_acc[cid].add_batch(y_pred[m], y_true[m], base[m])

            for did in torch.unique(drug_id).tolist():
                did = int(did)
                m = (drug_id == did)
                if did not in drug_acc:
                    drug_acc[did] = _MetricAccum(cfg.eval_ks, cfg.top_pos)
                drug_acc[did].add_batch(y_pred[m], y_true[m], base[m])

    result = {"global": global_acc.to_dict(), "per_cell": None, "per_drug": None}
    if cfg.per_stratum:
        cell_top, cell_all = _group_report("cell_id", cell_acc, cfg, sort_by="n_samples")
        drug_top, drug_all = _group_report("drug_id", drug_acc, cfg, sort_by="n_samples")
        result["per_cell"] = {"top_report": cell_top, "all_rows": cell_all}
        result["per_drug"] = {"top_report": drug_top, "all_rows": drug_all}
    return result


# ============================
# RUN
# ============================
eval_cfg = FREvalConfig(
    steps=VAL_STEPS,     # –∏–ª–∏ None
    amp=True,
    top_pos=30,
    eval_ks=(10, 30, 50, 100),
    per_stratum=True,
    max_groups_report=20,
    min_group_size=50,
)

base_res = eval_baseline_dmso_with_strata(val_loader, baseline_vec, device, eval_cfg)
print("=== GLOBAL BASELINE ===")
print(_pretty_line("Baseline(DMSO)", base_res["global"], eval_cfg))

fr_res = eval_fr_with_strata(fr_model, val_loader, baseline_vec, device, eval_cfg)
print("=== GLOBAL f_r ===")
print(_pretty_line("f_r", fr_res["global"], eval_cfg))

print("\n=== TOP CELL LINES (by n) ‚Äî f_r ===")
for row in fr_res["per_cell"]["top_report"]:
    print(f"cell_id={row['cell_id']:>6} n={row['n_samples']:>6} RMSE={row['rmse']:.4f} Cos={row['cosine_d']:.4f} R@30={row.get('recall@30', float('nan')):.4f}")

print("\n=== TOP DRUGS (by n) ‚Äî f_r ===")
for row in fr_res["per_drug"]["top_report"]:
    print(f"drug_id={row['drug_id']:>6} n={row['n_samples']:>6} RMSE={row['rmse']:.4f} Cos={row['cosine_d']:.4f} R@30={row.get('recall@30', float('nan')):.4f}")

# optional save
if pd is not None:
    out_dir = "/data/aiffel/babayakga/eval_outputs/fr_withcell"
    os.makedirs(out_dir, exist_ok=True)

    df_cell = pd.DataFrame(fr_res["per_cell"]["all_rows"])
    df_drug = pd.DataFrame(fr_res["per_drug"]["all_rows"])

    df_cell.to_csv(os.path.join(out_dir, "per_cell_metrics.csv"), index=False)
    df_drug.to_csv(os.path.join(out_dir, "per_drug_metrics.csv"), index=False)
    print(f"‚úÖ saved CSVs to: {out_dir}")

=== GLOBAL BASELINE ===
    Baseline(DMSO) | n=14400 | RMSE=4.5727 | MAE=1.0536 | Cos(d)=0.0000 | Pear(d)=0.0000 | Spear(d)=0.1993 | Sign@30=0.0000 | R@10=0.2143 | NDCG@10=0.6216 | R@30=0.3374 | NDCG@30=0.2937 | R@50=0.3945 | NDCG@50=0.4885 | R@100=0.4989 | NDCG@100=0.5440
=== GLOBAL f_r ===
               f_r | n=14400 | RMSE=1.4338 | MAE=0.7684 | Cos(d)=0.8992 | Pear(d)=0.8972 | Spear(d)=0.3740 | Sign@30=0.8374 | R@10=0.2097 | NDCG@10=0.7251 | R@30=0.3539 | NDCG@30=0.4712 | R@50=0.4326 | NDCG@50=0.5195 | R@100=0.5443 | NDCG@100=0.5789

=== TOP CELL LINES (by n) ‚Äî f_r ===
cell_id=    21 n=   608 RMSE=0.9280 Cos=0.9205 R@30=0.3249
cell_id=    22 n=   576 RMSE=1.2060 Cos=0.8987 R@30=0.2634
cell_id=    28 n=   560 RMSE=1.4188 Cos=0.8706 R@30=0.3307
cell_id=    19 n=   560 RMSE=1.3351 Cos=0.8901 R@30=0.3883
cell_id=    31 n=   560 RMSE=1.3043 Cos=0.9092 R@30=0.3905
cell_id=    46 n=   544 RMSE=1.3893 Cos=0.9046 R@30=0.3049
cell_id=    13 n=   528 RMSE=1.3350 Cos=0.9050 R@30=0.3628
cell_

In [4]:
print("üöÄ f_r training start")
TOTAL_EPOCHS=6
for epoch in range(1, TOTAL_EPOCHS + 1):
    lambda_rank = 0.0 if epoch <= WARMUP_EPOCHS else lambda_rank_main

    fr_model.train()
    run_mse = 0.0
    run_rank = 0.0
    run_total = 0.0
    n = 0

    pbar = tqdm(
        islice(train_loader, STEPS_PER_EPOCH),
        total=STEPS_PER_EPOCH,
        desc=f"[Epoch {epoch}] Train",
        leave=True,
        dynamic_ncols=True
    )

    for batch in pbar:
        input_ids, values, mask, y_true, cell_id, drug_id, smiles = batch

        input_ids = input_ids.to(device, non_blocking=True)
        values    = values.to(device, non_blocking=True)
        mask      = mask.to(device, non_blocking=True)
        y_true    = y_true.to(device, non_blocking=True)
        cell_id   = cell_id.to(device, non_blocking=True)
        smiles    = smiles.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with autocast(device_type="cuda", enabled=(device.type == "cuda")):
            y_pred = fr_model(input_ids, values, mask, cell_id, smiles)
            loss_m = mse_loss(y_pred, y_true)

            if lambda_rank > 0:
                loss_r = expr_ranking_loss(
                    y_pred, y_true, baseline_vec,
                    top_pos=30, num_neg=80, margin=0.0
                )
            else:
                loss_r = torch.tensor(0.0, device=device)

            loss = loss_m + lambda_rank * loss_r

        if not torch.isfinite(loss):
            continue

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(fr_model.parameters(), GRAD_CLIP)
        scaler.step(optimizer)
        scaler.update()

        bs = y_true.size(0)
        run_mse   += loss_m.item() * bs
        run_rank  += loss_r.item() * bs
        run_total += loss.item() * bs
        n += bs

        pbar.set_postfix({
            "mse": f"{loss_m.item():.4f}",
            "rank": f"{loss_r.item():.4f}",
            "Œª_rank": float(lambda_rank),
        })

    train_mse   = run_mse   / max(1, n)
    train_rank  = run_rank  / max(1, n)
    train_total = run_total / max(1, n)

    val_mse = eval_mse(fr_model, val_loader, steps=VAL_STEPS, device=device)

    print(
        f"[Epoch {epoch}] "
        f"Train total={train_total:.6f}, mse={train_mse:.6f}, rank={train_rank:.6f} (Œª_rank={lambda_rank}) | "
        f"Valid mse={val_mse:.6f} | Baseline(DMSO) mse={base_mse:.6f}"
    )

    if (epoch % SAVE_EVERY == 0) or (epoch == TOTAL_EPOCHS):
        save_fr_checkpoint(
            save_dir=CKPT_DIR,
            fr_model=fr_model,
            optimizer=optimizer,
            scaler=scaler,
            epoch=epoch,
            metrics={
                "train_total": float(train_total),
                "train_mse": float(train_mse),
                "train_rank": float(train_rank),
                "val_mse": float(val_mse),
                "baseline_mse": float(base_mse),
                "lambda_rank": float(lambda_rank),
            },
            extra={
                "TOP_K": int(baseline_vec.numel()),
                "STEPS_PER_EPOCH": int(STEPS_PER_EPOCH),
                "VAL_STEPS": int(VAL_STEPS),
                "WARMUP_EPOCHS": int(WARMUP_EPOCHS),
                "lambda_rank_main": float(lambda_rank_main),
                "baseline_vec": baseline_vec.detach().float().cpu(),
                "CELL2ID_CSV": CELL2ID_CSV,
                "CELL_EMB_NPY": CELL_EMB_NPY,
                "sorted_gene_token_ids": sorted_gene_token_ids.astype(np.int64)
            },
            prefix="fr",
        )

print("‚úÖ DONE")

üöÄ f_r training start


[Epoch 1] Train: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 10000/10000 [3:01:43<00:00,  1.09s/it, mse=2.3042, rank=0.0000, Œª_rank=0]  


[Epoch 1] Train total=1.847853, mse=1.847853, rank=0.000000 (Œª_rank=0.0) | Valid mse=1.970388 | Baseline(DMSO) mse=20.909212


[Epoch 2] Train: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 10000/10000 [2:52:08<00:00,  1.03s/it, mse=2.1046, rank=0.0000, Œª_rank=0]  


[Epoch 2] Train total=1.791824, mse=1.791824, rank=0.000000 (Œª_rank=0.0) | Valid mse=1.850703 | Baseline(DMSO) mse=20.909212
üíæ saved checkpoint: /data/aiffel/babayakga/checkpoints/f_r_withcellline/fr_epoch2_20251226_185142.pt


[Epoch 3] Train: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 10000/10000 [2:13:33<00:00,  1.25it/s, mse=2.1844, rank=2.8223, Œª_rank=0.2]  


[Epoch 3] Train total=2.278205, mse=1.793413, rank=2.423961 (Œª_rank=0.2) | Valid mse=1.829237 | Baseline(DMSO) mse=20.909212


[Epoch 4] Train: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 10000/10000 [2:40:48<00:00,  1.04it/s, mse=2.1488, rank=2.6523, Œª_rank=0.2]  


[Epoch 4] Train total=2.230828, mse=1.747259, rank=2.417844 (Œª_rank=0.2) | Valid mse=1.768221 | Baseline(DMSO) mse=20.909212
üíæ saved checkpoint: /data/aiffel/babayakga/checkpoints/f_r_withcellline/fr_epoch4_20251227_001233.pt


[Epoch 5] Train: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 10000/10000 [2:17:48<00:00,  1.21it/s, mse=2.1317, rank=2.6191, Œª_rank=0.2]  


[Epoch 5] Train total=2.206838, mse=1.723612, rank=2.416120 (Œª_rank=0.2) | Valid mse=1.716062 | Baseline(DMSO) mse=20.909212


[Epoch 6] Train: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 10000/10000 [2:28:52<00:00,  1.12it/s, mse=2.2265, rank=2.4688, Œª_rank=0.2]  


[Epoch 6] Train total=2.179475, mse=1.696342, rank=2.415661 (Œª_rank=0.2) | Valid mse=1.749337 | Baseline(DMSO) mse=20.909212
üíæ saved checkpoint: /data/aiffel/babayakga/checkpoints/f_r_withcellline/fr_epoch6_20251227_052053.pt
‚úÖ DONE
