In [50]:
# Input/output locations used throughout the notebook (Kaggle filesystem layout).
LABELS_PATH = "/kaggle/input/dataset-1/data.info.labelled"   # columns: gene_id,transcript_id,transcript_position,label
JSON_PATH   = "/kaggle/input/dataset-1/dataset0.json"    # line-delimited JSON (see assignment format)
OUTDIR      = "/kaggle/working"  # directory where checkpoints and prediction CSVs are written

In [51]:
import os, json, math, random
from collections import defaultdict

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split

from collections import defaultdict
from scipy.stats import entropy

In [52]:
def set_seed(seed: int = 4262):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) and force deterministic cuDNN.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True   # always pick deterministic cuDNN kernels
    cudnn.benchmark = False      # disable the auto-tuner, which may select different kernels per run

def ensure_dir(p):
    """Create directory `p` together with any missing parents; a no-op if it already exists."""
    os.makedirs(p, exist_ok=True)

In [53]:
def load_labels(labels_csv):
    """Read the label CSV and normalise column dtypes.

    Args:
        labels_csv: Path to a CSV containing at least the columns
            gene_id, transcript_id, transcript_position and label.

    Returns:
        DataFrame whose three id/position columns are strings and whose label column is int.

    Raises:
        ValueError: If any of the required columns is absent.
    """
    id_cols = ("gene_id", "transcript_id", "transcript_position")
    df = pd.read_csv(labels_csv)
    miss = [col for col in (*id_cols, "label") if col not in df.columns]
    if miss:
        raise ValueError(f"Missing cols in labels: {miss}")
    # Keys are compared as strings everywhere else in the notebook, so cast once here.
    for col in id_cols:
        df[col] = df[col].astype(str)
    df["label"] = df["label"].astype(int)
    return df

def build_label_index(df_labels):
    """Build a lookup {(transcript_id, transcript_position): int label}.

    When the same key occurs on several rows, the last row wins.
    """
    columns = (
        df_labels["transcript_id"],
        df_labels["transcript_position"],
        df_labels["label"],
    )
    return {(tid, pos): int(lab) for tid, pos, lab in zip(*columns)}

In [54]:
def parse_data_json(json_path, label_keys):
    """
    Parse line-delimited JSON into read-level bags.

    Each non-blank line is expected to look like
    ``{transcript_id: {position: {kmer: [[f1..f9], ...]}}}``.

    Args:
        json_path: Path to the line-delimited JSON file.
        label_keys: Container of (transcript_id, position) string tuples; only these
            keys are kept, to save time and memory.

    Returns:
        Plain dict mapping (transcript_id, position) -> np.ndarray of shape (n_reads, 9).
        Keys without any valid 9-value read are absent. A plain dict is returned so that
        indexing a missing key raises KeyError instead of silently inserting an empty list.
    """
    collected = defaultdict(list)
    with open(json_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            obj = json.loads(line)
            for tid, pos_dict in obj.items():
                if not isinstance(pos_dict, dict):
                    continue
                for pos, kmer_dict in pos_dict.items():
                    key = (str(tid), str(pos))
                    if key not in label_keys or not isinstance(kmer_dict, dict):
                        continue
                    for reads in kmer_dict.values():
                        if not isinstance(reads, list):
                            continue
                        for vec in reads:
                            try:
                                v = np.asarray(vec, dtype=float)
                            except (TypeError, ValueError):
                                # ragged or non-numeric read: skip it rather than abort the whole parse
                                continue
                            # Require a flat 9-vector; scalars (ndim 0) used to crash on v.shape[0].
                            if v.ndim == 1 and v.shape[0] == 9:
                                collected[key].append(v)
    # Lists only exist for keys that received at least one read, so every entry can be stacked.
    return {key: np.stack(rows, axis=0) for key, rows in collected.items()}

In [55]:
### Preprocess data
# 1) Summary statistics: for a (num_reads, 9) matrix, return per-column statistics
def stats_for_matrix(X):
    """Compute eight summary statistics for every column of a read matrix.

    Args:
        X: Array-like of shape (R, D). Nested lists are accepted; input that is not
            2-D is reshaped to (-1, last_dim), so a single D-vector becomes one row.

    Returns:
        np.ndarray of shape (D, 8) with columns
        [mean, std (ddof=0), median, q25, q75, min, max, iqr].
    """
    # Convert first: the original read X.ndim before np.asarray and so failed on plain lists.
    X = np.asarray(X, dtype=float)
    if X.ndim != 2:
        X = X.reshape(-1, X.shape[-1])
    q25 = np.percentile(X, 25, axis=0)
    q75 = np.percentile(X, 75, axis=0)
    columns = [
        np.mean(X, axis=0),
        np.std(X, axis=0, ddof=0),
        np.median(X, axis=0),
        q25,
        q75,
        np.min(X, axis=0),
        np.max(X, axis=0),
        q75 - q25,  # inter-quartile range
    ]
    return np.stack(columns, axis=1)  # (D, 8); (9, 8) for the 9-feature reads used here

# 2) k-mer sequence features: GC fraction, base-composition entropy, A/C/G/T counts
def extract_kmer_context_features_one(kmer: str):
    """Return six sequence features for one k-mer string.

    Returns:
        np.ndarray of length 6: [gc_fraction, entropy, count_A, count_C, count_G, count_T].
        The length used for normalisation is clamped to at least 1, so an empty string
        does not divide by zero.
    """
    length = max(1, len(kmer))
    base_counts = [kmer.count(base) for base in "ACGT"]
    n_a, n_c, n_g, n_t = base_counts
    gc_fraction = (n_g + n_c) / length
    freqs = np.array(base_counts, dtype=float) / length
    composition_entropy = entropy(freqs + 1e-12)  # small epsilon avoids log(0)
    return np.array([gc_fraction, composition_entropy, n_a, n_c, n_g, n_t], dtype=float)  # 6 dims

# 3) Aggregate the raw JSON structure into k-mer level instances
# Input: each parsed JSON line is a dict shaped like {tid: {pos: {kmer: [[...9], ...], ...}}}
# Output: bags_dict: (tid,pos) -> np.array(num_kmers, F), where F = 9*8 (stats) + 1 (n_reads) + 6 (k-mer sequence features) = 79
def build_kmer_level_bags(json_path, label_keys_set=None):
    """Build one feature row per k-mer for every (transcript_id, position) bag.

    Args:
        json_path: Path to the line-delimited JSON file.
        label_keys_set: Optional container of (tid, pos) string tuples. When given, only
            those bags are kept; when None (default) every bag in the file is built.

    Returns:
        dict mapping (tid, pos) -> np.ndarray of shape (num_kmers, 79). Bags with no
        valid k-mer (no non-empty 2-D read matrix with 9 columns) are omitted.
    """
    bags_kmer = {}
    with open(json_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            obj = json.loads(line)
            for tid, pos_dict in obj.items():
                if not isinstance(pos_dict, dict):
                    continue
                for pos, kmer_dict in pos_dict.items():
                    key = (str(tid), str(pos))
                    # filter only when a key set was supplied (keeps labelled bags only)
                    if label_keys_set is not None and key not in label_keys_set:
                        continue
                    if not isinstance(kmer_dict, dict):
                        continue
                    row_list = []
                    for kmer, reads in kmer_dict.items():
                        try:
                            X = np.asarray(reads, dtype=float)
                        except (TypeError, ValueError):
                            # ragged / non-numeric read list: skip this k-mer
                            continue
                        if X.size == 0 or X.ndim != 2 or X.shape[1] != 9:
                            continue
                        S_flat = stats_for_matrix(X).reshape(-1)         # (9,8) -> 72
                        n_reads = np.array([X.shape[0]], dtype=float)    # 1
                        kfeat = extract_kmer_context_features_one(kmer)  # 6
                        row_list.append(np.concatenate([S_flat, n_reads, kfeat], axis=0))  # 72+1+6=79
                    if row_list:
                        bags_kmer[key] = np.vstack(row_list)  # (num_kmers, 79)
    return bags_kmer

# 4) Numeric transform of the new features (log1p + standardisation) — fitted on the training set only

def fit_scaler_on_train_kmer(keys_train, bags_dict, log1p_indices=None):
    """Fit a StandardScaler on the stacked instances of the training bags.

    Args:
        keys_train: Bag keys belonging to the training split; keys absent from
            `bags_dict` are ignored.
        bags_dict: Mapping key -> 2-D feature matrix.
        log1p_indices: Optional column indices that are clipped at 0 and passed
            through log1p before fitting.

    Returns:
        A fitted StandardScaler, or an unfitted one when no training rows exist.
    """
    train_mats = [bags_dict[key] for key in keys_train if key in bags_dict]
    if len(train_mats) > 0:
        stacked = np.concatenate(train_mats, axis=0)
    else:
        # No training bag present: build an empty matrix with the feature width of any bag.
        n_cols = bags_dict[next(iter(bags_dict))].shape[1]
        stacked = np.empty((0, n_cols))
    Z = stacked.copy()
    has_rows = Z.shape[0] > 0
    if has_rows and log1p_indices is not None and len(log1p_indices) > 0:
        Z[:, log1p_indices] = np.log1p(np.clip(Z[:, log1p_indices], a_min=0, a_max=None))
    if has_rows:
        return StandardScaler().fit(Z)
    return StandardScaler()

def transform_all_bags_kmer(bags_dict, scaler, log1p_indices=None):
    """Apply the optional log1p step and the fitted scaler to every bag.

    Args:
        bags_dict: Mapping key -> 2-D feature matrix (left unmodified).
        scaler: Object exposing `transform(matrix)`.
        log1p_indices: Optional column indices clipped at 0 and log1p-transformed first.

    Returns:
        New dict with the same keys and transformed copies of the matrices;
        bags with zero rows skip the scaler.
    """
    use_log = log1p_indices is not None and len(log1p_indices) > 0
    transformed = {}
    for key, mat in bags_dict.items():
        Z = mat.copy()
        if use_log and Z.size > 0:
            Z[:, log1p_indices] = np.log1p(np.clip(Z[:, log1p_indices], a_min=0, a_max=None))
        transformed[key] = Z if Z.shape[0] == 0 else scaler.transform(Z)
    return transformed

In [56]:
def grouped_stratified_split(df, test_ratio, valid_ratio, seed, bags_raw):
    """
    Gene-grouped, stratified train/valid/test split.

    All bags of one gene_id land in the same split. Stratification is done at gene
    level on whether the gene has ANY positive bag.

    Args:
        df: Label frame with gene_id, transcript_id, transcript_position, label.
        test_ratio, valid_ratio: Fractions of genes for the test / validation splits.
        seed: Random state for both split stages and the fallback shuffle.
        bags_raw: Container of available bag keys; keys not in it are dropped.

    Returns:
        (train_keys, valid_keys, test_keys): sorted, de-duplicated lists of
        (transcript_id, transcript_position) string tuples.
    """
    gene_has_pos = (df.groupby("gene_id")["label"].max() > 0).astype(int)
    genes = gene_has_pos.index.tolist()
    strata = gene_has_pos.values
    holdout = valid_ratio + test_ratio

    # Stage 1: train vs (valid + test)
    genes_train, genes_tmp, _, y_tmp_g = train_test_split(
        genes, strata, test_size=holdout, stratify=strata, random_state=seed
    )
    # Stage 2: split the held-out genes into valid vs test
    valid_size = valid_ratio / holdout if holdout > 0 else 0.5
    genes_valid, genes_test, _, _ = train_test_split(
        genes_tmp, y_tmp_g, test_size=1-valid_size, stratify=y_tmp_g, random_state=seed
    )

    def genes_to_keys(G):
        # Collect the bag keys of the given genes that actually exist in bags_raw.
        sub = df[df["gene_id"].isin(G)]
        found = set()
        for r in sub.itertuples(index=False):
            k = (str(r.transcript_id), str(r.transcript_position))
            if k in bags_raw:
                found.add(k)
        return sorted(found)

    kt = genes_to_keys(genes_train)
    kv = genes_to_keys(genes_valid)
    kx = genes_to_keys(genes_test)

    # Fallback when any split ended up empty: shuffle genes and cut by counts.
    if not (kt and kv and kx):
        rng = np.random.RandomState(seed)
        rng.shuffle(genes)
        n = len(genes)
        n_test = max(1, int(round(n*test_ratio)))
        n_valid = max(1, int(round(n*valid_ratio)))
        n_train = max(1, n-n_test-n_valid)
        cut = n_train + n_valid
        kt = genes_to_keys(genes[:n_train])
        kv = genes_to_keys(genes[n_train:cut])
        kx = genes_to_keys(genes[cut:])

    return kt, kv, kx

In [57]:
class Cfg:
    """Default hyper-parameters, stored as plain class attributes."""
    # reproducibility
    seed = 42
    # optimisation schedule
    batch_size = 32
    epochs = 15
    lr = 1e-3
    weight_decay = 1e-4
    # model sizes
    dropout = 0.1
    enc_hid = 64
    enc_out = 128
    att_dim = 128
    # data split fractions
    valid_ratio = 0.15
    test_ratio = 0.15
    # DataLoader worker processes
    num_workers = 2
CFG = Cfg()

# Seed all RNGs, pick the compute device and make sure the output directory exists.
set_seed(CFG.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ensure_dir(OUTDIR)

# Load the label table and index it by (transcript_id, transcript_position).
print(">> Loading labels...")
df_labels = load_labels(LABELS_PATH)
label_index_all = build_label_index(df_labels)

# Parse read-level bags, restricted to keys that have a label.
print(">> Parsing data.json (only labeled keys)...")
bags_raw = parse_data_json(JSON_PATH, set(label_index_all.keys()))

# Keep only labelled rows for which a bag was actually parsed.
# NOTE(review): df_labels is rebound in place, so later cells see the filtered frame, not the raw labels.
df_labels = df_labels[df_labels.apply(lambda r: (r.transcript_id, r.transcript_position) in bags_raw, axis=1)].reset_index(drop=True)

>> Loading labels...
>> Parsing data.json (only labeled keys)...


In [58]:
def grouped_split(df, test_ratio=0.15, valid_ratio=0.15, seed=42, bags=None):
    """Gene-grouped, stratified train/valid/test split returning bag keys.

    All rows of one gene_id go to the same split; genes are stratified by whether
    they contain any positive label.

    Args:
        df: Label frame with gene_id, transcript_id, transcript_position, label.
        test_ratio, valid_ratio: Fractions of genes for test / validation.
        seed: Random state for both split stages and the fallback shuffle.
        bags: Optional container of available bag keys used to filter the result.
            Defaults to the notebook-level `bags_raw`, which is what this function
            always used implicitly before the parameter existed.

    Returns:
        (train_keys, valid_keys, test_keys): sorted, de-duplicated lists of
        (transcript_id, transcript_position) string tuples.
    """
    if bags is None:
        bags = bags_raw  # backward-compatible default (module-level global)

    gene_has_pos = (df.groupby("gene_id")["label"].max() > 0).astype(int)
    genes = gene_has_pos.index.tolist()
    strata = gene_has_pos.values
    holdout = test_ratio + valid_ratio

    # Stage 1: train vs (valid + test); stage 2: valid vs test.
    gt, gtmp, _, ytmp = train_test_split(genes, strata, test_size=holdout,
                                         stratify=strata, random_state=seed)
    vsize = valid_ratio / holdout if holdout > 0 else 0.5
    gv, gx, _, _ = train_test_split(gtmp, ytmp, test_size=1-vsize,
                                    stratify=ytmp, random_state=seed)

    def to_keys(G):
        # Bag keys of the given genes that are present in `bags`.
        sub = df[df["gene_id"].isin(G)]
        candidates = {
            (str(t), str(p))
            for t, p in zip(sub["transcript_id"], sub["transcript_position"])
        }
        return sorted(k for k in candidates if k in bags)

    kt, kv, kx = to_keys(gt), to_keys(gv), to_keys(gx)
    if not (kt and kv and kx):
        # Fallback: shuffle genes and cut by counts. The train count is clamped to >= 1
        # (as in grouped_stratified_split) so tiny gene sets cannot yield negative slice bounds.
        rng = np.random.RandomState(seed)
        rng.shuffle(genes)
        n = len(genes)
        nte = max(1, int(round(n*test_ratio)))
        nva = max(1, int(round(n*valid_ratio)))
        ntr = max(1, n - nte - nva)
        gt, gv, gx = genes[:ntr], genes[ntr:ntr+nva], genes[ntr+nva:]
        kt, kv, kx = to_keys(gt), to_keys(gv), to_keys(gx)
    return kt, kv, kx


In [59]:
from sklearn.feature_selection import RFE
from sklearn.linear_model import LogisticRegression

def rfe_select_columns(keys_train, bags_dict, label_index,
                       num_features_keep=40, random_state=42,
                       max_iter=2000,  # raised iteration cap
                       C=0.5,          # slightly stronger regularisation (the library default is 1.0)
                       solver="lbfgs"  # fine for binary problems; try "liblinear" if it does not converge
                       ):
    """Select feature columns with recursive feature elimination on a logistic regression.

    Every instance inherits the label of its bag (proxy labels), and only bags whose
    key is in `keys_train` and present in `bags_dict` are used.

    Returns:
        Sorted np.ndarray of the kept column indices. When no training bag is present,
        all column indices of an arbitrary bag are returned.
    """
    present = [key for key in keys_train if key in bags_dict]
    if not present:
        first_bag = bags_dict[next(iter(bags_dict))]
        return np.arange(first_bag.shape[1])

    X = np.vstack([bags_dict[key] for key in present])
    y = np.concatenate([
        np.full((bags_dict[key].shape[0],), label_index[key], dtype=int)
        for key in present
    ])

    estimator = LogisticRegression(
        penalty="l2",
        solver=solver,
        max_iter=max_iter,
        random_state=random_state,
        class_weight="balanced",   # compensate for class imbalance
        C=C,
        tol=1e-4
    )

    # An integer step (10% of the columns, at least 1) reduces the number of refits.
    n_drop_per_round = max(1, int(0.1 * X.shape[1]))
    selector = RFE(estimator, n_features_to_select=num_features_keep, step=n_drop_per_round)
    selector.fit(X, y)
    return np.where(selector.support_)[0]

def apply_column_selection(bags_dict, cols_keep):
    """Keep only the selected feature columns in every bag.

    Args:
        bags_dict: Mapping key -> 2-D feature matrix.
        cols_keep: Column indices to keep. An empty selection yields matrices
            with zero columns instead of raising.

    Returns:
        New dict with the same keys. A bag narrower than the largest requested index
        is passed through unchanged (original best-effort behaviour).
    """
    cols = np.asarray(cols_keep)
    if cols.size == 0:
        cols = cols.astype(int)  # an empty float array is not a valid index
        min_width = 0
    else:
        min_width = int(np.max(cols)) + 1  # hoisted: previously recomputed for every bag
    return {
        key: (X[:, cols] if X.shape[1] >= min_width else X)
        for key, X in bags_dict.items()
    }

# Rebuild the label lookup from the label frame as it stands at this point.
label_index = {(str(r.transcript_id), str(r.transcript_position)): int(r.label) for r in df_labels.itertuples(index=False)}
label_keys = set(label_index.keys())

print(">> Building k-mer level bags (this is fast)...")
bags_kmer_raw = build_kmer_level_bags(JSON_PATH, label_keys)

# Drop labelled rows without a k-mer bag, then refresh the lookup to match.
df_labels = df_labels[df_labels.apply(lambda r: (str(r.transcript_id), str(r.transcript_position)) in bags_kmer_raw, axis=1)].reset_index(drop=True)
label_index = {(str(r.transcript_id), str(r.transcript_position)): int(r.label) for r in df_labels.itertuples(index=False)}

keys_train, keys_valid, keys_test = grouped_split(df_labels, test_ratio=0.15, valid_ratio=0.15, seed=4262)
print(f"Split(k-mer bags): train={len(keys_train)}, valid={len(keys_valid)}, test={len(keys_test)}")
# Named N_FEATS rather than F so the `torch.nn.functional as F` import is not overwritten.
N_FEATS = bags_kmer_raw[next(iter(bags_kmer_raw))].shape[1]
idx_all = np.arange(N_FEATS)

# Columns that get clip-at-0 + log1p before scaling.
# NOTE(review): with the 79-column layout built above (72 stats, n_reads at 72, gc 73, entropy 74,
# A/C/G/T counts 75..78) only the else-branch runs; the >= 80 branch lists 76..79 — confirm intent.
log1p_cols = list(range(72)) + [72, 76, 77, 78, 79] if N_FEATS >= 80 else list(range(min(73, N_FEATS)))

print(">> Fitting scaler on TRAIN (k-mer instances)...")
scaler_kmer = fit_scaler_on_train_kmer(keys_train, bags_kmer_raw, log1p_indices=log1p_cols)
bags_kmer = transform_all_bags_kmer(bags_kmer_raw, scaler_kmer, log1p_indices=log1p_cols)


USE_RFE = True
TOP_K = 40
if USE_RFE:
    print(f">> Running RFE to select top {TOP_K} features (instance-level, proxy labels)...")
    cols_keep = rfe_select_columns(keys_train, bags_kmer, label_index, num_features_keep=TOP_K, random_state=42)
    print(" kept cols:", cols_keep.shape[0])
    bags_kmer = apply_column_selection(bags_kmer, cols_keep)
    IN_DIM = cols_keep.shape[0]
else:
    IN_DIM = N_FEATS

print("Final instance feature dim:", IN_DIM)


>> Building k-mer level bags (this is fast)...
Split(k-mer bags): train=85011, valid=17874, test=18953
>> Fitting scaler on TRAIN (k-mer instances)...
>> Running RFE to select top 40 features (instance-level, proxy labels)...
 kept cols: 40
Final instance feature dim: 40


In [60]:
class BagsDataset(Dataset):
    """Dataset yielding one bag (matrix of k-mer instances) per item.

    Each item is a tuple (key, X, mask, y): the bag key, the float32 instance matrix
    of shape (L, F), an all-True boolean mask of length L, and the float32 bag label.
    """

    def __init__(self, keys, bags, label_index, training=False, seed=42):
        self.keys = keys
        self.bags = bags
        self.lbl = label_index
        # `training` and `rng` are stored but not read by any method of this class.
        self.training = training
        self.rng = np.random.RandomState(seed)

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, i):
        key = self.keys[i]
        instances = torch.tensor(self.bags[key], dtype=torch.float32)  # (L, F) k-mer instance matrix
        label = torch.tensor(self.lbl[key], dtype=torch.float32)       # bag-level label
        mask = torch.ones(instances.shape[0], dtype=torch.bool)        # every real instance is valid
        return key, instances, mask, label

def pad_collate(batch):
    """Collate variable-length bags into zero-padded batch tensors.

    Args:
        batch: Sequence of (key, X, mask, y) items as produced by BagsDataset.

    Returns:
        (keys, X, M, y): list of keys, float32 tensor (B, Lmax, F) padded with zeros,
        bool tensor (B, Lmax) that is False on padding, and the stacked labels (B,).
    """
    keys, Xs, Ms, ys = zip(*batch)
    n_bags = len(Xs)
    longest = max(x.shape[0] for x in Xs)
    n_feat = Xs[0].shape[1]
    X = torch.zeros(n_bags, longest, n_feat, dtype=torch.float32)
    M = torch.zeros(n_bags, longest, dtype=torch.bool)
    for row, (x, m) in enumerate(zip(Xs, Ms)):
        n_inst = x.shape[0]
        X[row, :n_inst] = x
        M[row, :n_inst] = m
    return list(keys), X, M, torch.stack(ys)

In [61]:
### Model structure
class EncoderMLP(nn.Module):
    """Two-layer instance encoder: Linear -> ReLU -> LayerNorm -> Dropout -> Linear -> ReLU -> LayerNorm."""

    def __init__(self, in_dim, hid=64, out_dim=128, dropout=0.1):
        super().__init__()
        # Layer order is kept as-is: the positions define the state_dict keys (net.0, net.2, ...).
        layers = [
            nn.Linear(in_dim, hid),
            nn.ReLU(True),
            nn.LayerNorm(hid),
            nn.Dropout(dropout),
            nn.Linear(hid, out_dim),
            nn.ReLU(True),
            nn.LayerNorm(out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Encode the last dimension of `x` from in_dim to out_dim."""
        return self.net(x)

class GatedAttentionMIL(nn.Module):
    """Gated attention pooling over the instances of a bag."""

    def __init__(self, in_dim, att_dim=128):
        super().__init__()
        self.V = nn.Linear(in_dim, att_dim, bias=True)   # tanh branch
        self.U = nn.Linear(in_dim, att_dim, bias=True)   # sigmoid gate branch
        self.w = nn.Linear(att_dim, 1, bias=False)       # scalar attention score

    def forward(self, H, mask):
        """Pool instance embeddings with masked softmax attention.

        Args:
            H: Instance embeddings, (B, L, D).
            mask: Bool tensor (B, L); False positions receive zero attention.

        Returns:
            (z, att): pooled embeddings (B, D) and attention weights (B, L).
        """
        feat = torch.tanh(self.V(H))
        gate = torch.sigmoid(self.U(H))
        scores = self.w(feat * gate).squeeze(-1)                # (B, L)
        scores = scores.masked_fill(~mask, float("-inf"))       # padding gets -inf before softmax
        att = torch.softmax(scores, dim=1)                      # (B, L)
        pooled = torch.bmm(att.unsqueeze(1), H).squeeze(1)      # (B, D)
        return pooled, att

class ABMILModel(nn.Module):
    """Attention-based MIL classifier: instance encoder -> gated attention pooling -> linear head."""

    def __init__(self, in_dim, enc_hid=64, enc_out=128, att_dim=128, dropout=0.1):
        super().__init__()
        self.encoder = EncoderMLP(in_dim, enc_hid, enc_out, dropout)
        self.att = GatedAttentionMIL(enc_out, att_dim)
        self.cls = nn.Linear(enc_out, 1)

    def forward(self, X, mask):
        """Return one logit per bag plus the attention weights.

        Args:
            X: Padded instance features, (B, L, in_dim).
            mask: Bool tensor (B, L) marking real instances.

        Returns:
            (logit, att): bag logits (B,) and attention weights (B, L).
        """
        pooled, att = self.att(self.encoder(X), mask)   # (B, D), (B, L)
        return self.cls(pooled).squeeze(-1), att        # (B,), (B, L)

In [62]:
class FocalLoss(nn.Module):
    """Binary focal loss on logits: alpha * (1 - p_t) ** gamma * BCE.

    Args:
        gamma: Focusing exponent.
        alpha: Optional weight of the positive class (negatives get 1 - alpha). When None
            and `pos_weight` is a tensor, alpha_pos = 1 / (1 + pos_weight) is used;
            otherwise no alpha weighting is applied.
        pos_weight: Optional tensor passed to binary_cross_entropy_with_logits.
        reduction: "mean" averages, "none" returns the per-element loss; any other
            value sums (the historic behaviour for non-"mean" values).

    NOTE(review): when alpha is None, positives are scaled by pos_weight * 1/(1+pos_weight)
    and negatives by pos_weight/(1+pos_weight), i.e. both classes end up with the same
    effective weight — confirm that this cancellation is intended.
    """

    def __init__(self, gamma=2.0, alpha=None, pos_weight=None, reduction="mean"):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.pos_weight = pos_weight
        self.reduction = reduction

    def forward(self, logits, targets):
        bce = nn.functional.binary_cross_entropy_with_logits(
            logits, targets, reduction="none", pos_weight=self.pos_weight)
        p = torch.sigmoid(logits)
        p_t = p * targets + (1 - p) * (1 - targets)  # probability assigned to the true class

        if self.alpha is not None:
            alpha = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        elif isinstance(self.pos_weight, torch.Tensor):
            # Stay in tensor arithmetic: avoids a host sync from .item() on every step
            # and also works when pos_weight has more than one element.
            alpha_pos = 1.0 / (1.0 + self.pos_weight)
            alpha = alpha_pos * targets + (1 - alpha_pos) * (1 - targets)
        else:
            alpha = 1.0

        loss = alpha * (1 - p_t).pow(self.gamma) * bce
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "none":
            return loss
        return loss.sum()

def compute_pos_weight(y_arr):
    """Return the negatives-to-positives ratio as a float32 scalar tensor, floored at 1.0.

    The weight is 1.0 when there are no positives.
    """
    labels = np.asarray(y_arr, dtype=int)
    n_pos = int((labels == 1).sum())
    n_neg = int((labels == 0).sum())
    if n_pos == 0:
        weight = 1.0
    else:
        weight = max(1.0, n_neg / max(1, n_pos))
    return torch.tensor(weight, dtype=torch.float32)

@torch.no_grad()
def evaluate(model, loader, device):
    """Score `model` on every batch of `loader` and compute bag-level metrics.

    Args:
        model: Callable as model(X, mask) -> (logits, attention).
        loader: Iterable of (keys, X, mask, y) batches.
        device: Device the inputs are moved to.

    Returns:
        dict with ROC_AUC, PR_AUC (NaN when the metric is undefined, e.g. only one
        class present), ACC, CM (2x2 confusion matrix for labels [0, 1]), plus the
        per-bag keys, probs and labels in loader order.
    """
    model.eval()
    logit_chunks, label_chunks, keys_all = [], [], []
    for keys, X, M, y in loader:
        X, M = X.to(device), M.to(device)
        logits, _ = model(X, M)
        logit_chunks.append(logits.float().cpu())  # float32 even if produced under autocast
        label_chunks.append(y.cpu())
        keys_all += keys

    logits_t = torch.cat(logit_chunks)
    labels = torch.cat(label_chunks).numpy().astype(int)
    # torch.sigmoid is numerically stable; the hand-rolled 1/(1+exp(-x)) overflows for very negative logits.
    probs = torch.sigmoid(logits_t).numpy()
    preds = (probs >= 0.5).astype(int)

    # Catch only the "metric undefined" error instead of swallowing everything with a bare except.
    try:
        auc = roc_auc_score(labels, probs)
    except ValueError:
        auc = float("nan")
    try:
        pr = average_precision_score(labels, probs)
    except ValueError:
        pr = float("nan")

    acc = accuracy_score(labels, preds)
    cm = confusion_matrix(labels, preds, labels=[0, 1])
    return {"ROC_AUC": auc, "PR_AUC": pr, "ACC": acc, "CM": cm,
            "keys": keys_all, "probs": probs, "labels": labels}


In [72]:
BATCH = 64; EPOCHS = 15; LR = 1e-3; WD = 1e-4
NUM_WORKERS = 2; USE_AMP = (device.type == "cuda")  # set to False to disable mixed precision

ds_train = BagsDataset(keys_train, bags_kmer, label_index, training=True, seed=42)
ds_valid = BagsDataset(keys_valid, bags_kmer, label_index, training=False)
ds_test  = BagsDataset(keys_test,  bags_kmer, label_index, training=False)

# Options shared by all three loaders.
_loader_common = dict(num_workers=NUM_WORKERS, collate_fn=pad_collate, pin_memory=True)
# persistent_workers / prefetch_factor are only valid with worker processes; passing them
# with NUM_WORKERS == 0 makes DataLoader raise ValueError.
_train_extra = dict(persistent_workers=True, prefetch_factor=2) if NUM_WORKERS > 0 else {}

dl_train = DataLoader(ds_train, batch_size=BATCH, shuffle=True, **_loader_common, **_train_extra)
dl_valid = DataLoader(ds_valid, batch_size=BATCH*2, shuffle=False, **_loader_common)
dl_test  = DataLoader(ds_test,  batch_size=BATCH*2, shuffle=False, **_loader_common)


In [64]:
model = ABMILModel(in_dim=IN_DIM, enc_hid=64, enc_out=128, att_dim=128, dropout=0.1).to(device)

# Class-imbalance weight derived from the training bag labels.
y_train_arr = np.array([label_index[k] for k in keys_train], dtype=int)
pos_weight  = compute_pos_weight(y_train_arr).to(device)
criterion   = FocalLoss(gamma=2.0, alpha=None, pos_weight=pos_weight, reduction="mean").to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)
# mode='max': the scheduler is stepped with a validation score where higher is better.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)
# torch.cuda.amp.GradScaler is deprecated; use the device-agnostic form already used elsewhere in this notebook.
scaler    = torch.amp.GradScaler("cuda", enabled=USE_AMP)

best_score = -1.0
best_path  = os.path.join(OUTDIR, "abmil_kmer_best.pt")
print(">>> Training on:", device)

>>> Training on: cuda


  scaler    = torch.cuda.amp.GradScaler(enabled=USE_AMP)


In [67]:
# Main training loop: one pass over dl_train per epoch, then validation and checkpointing.
for epoch in range(1, EPOCHS+1):
    model.train(); total_loss=0.0; n=0
    for keys, X, M, y in dl_train:
        X = X.to(device, non_blocking=True); M = M.to(device, non_blocking=True); y = y.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        # Forward pass and loss under autocast (active only when USE_AMP is True).
        with torch.amp.autocast("cuda", enabled=USE_AMP):
            logits, _ = model(X, M)
            loss = criterion(logits, y)
        scaler.scale(loss).backward()
        # Unscale before clipping so the norm threshold applies to the true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        scaler.step(optimizer)
        scaler.update()
        # Accumulate the loss weighted by batch size to get a per-bag average below.
        total_loss += loss.item() * X.size(0); n += X.size(0)

    train_loss = total_loss / max(1,n)
    metrics_v  = evaluate(model, dl_valid, device)
    valid_pr   = metrics_v["PR_AUC"]; valid_auc = metrics_v["ROC_AUC"]

    # Step the plateau scheduler on validation PR-AUC (0.0 stands in for NaN).
    scheduler.step(valid_pr if not math.isnan(valid_pr) else 0.0)
    print(f"[{epoch}/{EPOCHS}] loss={train_loss:.4f} | valid PR-AUC={valid_pr:.4f} AUC={valid_auc:.4f} ACC={metrics_v['ACC']:.4f}")

    # Model selection: prefer PR-AUC, fall back to ROC-AUC when PR-AUC is NaN.
    score = valid_pr if not math.isnan(valid_pr) else valid_auc
    if not math.isnan(score) and score > best_score:
        best_score = score
        torch.save({"model": model.state_dict(), "in_dim": IN_DIM}, best_path)
        print(f"  >> saved best to {best_path}")


[1/15] loss=0.0275 | valid PR-AUC=0.5001 AUC=0.9154 ACC=0.9615
  >> saved best to /kaggle/working/abmil_kmer_best.pt
[2/15] loss=0.0274 | valid PR-AUC=0.4748 AUC=0.9120 ACC=0.9592
[3/15] loss=0.0272 | valid PR-AUC=0.4901 AUC=0.9146 ACC=0.9604
[4/15] loss=0.0271 | valid PR-AUC=0.4856 AUC=0.9136 ACC=0.9603
[5/15] loss=0.0264 | valid PR-AUC=0.4916 AUC=0.9136 ACC=0.9604
[6/15] loss=0.0262 | valid PR-AUC=0.4843 AUC=0.9114 ACC=0.9598
[7/15] loss=0.0261 | valid PR-AUC=0.4787 AUC=0.9132 ACC=0.9596
[8/15] loss=0.0258 | valid PR-AUC=0.4860 AUC=0.9124 ACC=0.9599
[9/15] loss=0.0255 | valid PR-AUC=0.4830 AUC=0.9123 ACC=0.9599
[10/15] loss=0.0256 | valid PR-AUC=0.4830 AUC=0.9128 ACC=0.9606
[11/15] loss=0.0253 | valid PR-AUC=0.4828 AUC=0.9118 ACC=0.9598
[12/15] loss=0.0252 | valid PR-AUC=0.4822 AUC=0.9120 ACC=0.9602
[13/15] loss=0.0253 | valid PR-AUC=0.4812 AUC=0.9119 ACC=0.9599
[14/15] loss=0.0251 | valid PR-AUC=0.4829 AUC=0.9121 ACC=0.9599
[15/15] loss=0.0251 | valid PR-AUC=0.4825 AUC=0.9124 ACC=0.

In [68]:
# Restore the best-validation weights (when a checkpoint exists) before scoring the test split.
if os.path.exists(best_path):
    ckpt = torch.load(best_path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    print(">>> Loaded best checkpoint")

metrics_t = evaluate(model, dl_test, device)
print("\n===== TEST =====")
print(f"AUC={metrics_t['ROC_AUC']:.4f} | PR-AUC={metrics_t['PR_AUC']:.4f} | ACC={metrics_t['ACC']:.4f}")
print("Confusion Matrix [[TN,FP],[FN,TP]]:\n", metrics_t["CM"])

import pandas as pd
# One record per test bag: key parts, predicted probability and true label.
rows = [
    {"transcript_id": tid, "transcript_position": pos, "score": float(prob), "label": int(lab)}
    for (tid, pos), prob, lab in zip(metrics_t["keys"], metrics_t["probs"], metrics_t["labels"])
]
pred_csv = os.path.join(OUTDIR, "kmer_test_predictions.csv")
pd.DataFrame(rows).to_csv(pred_csv, index=False)
print("Saved:", pred_csv)


>>> Loaded best checkpoint

===== TEST =====
AUC=0.9192 | PR-AUC=0.4477 | ACC=0.9657
Confusion Matrix [[TN,FP],[FN,TP]]:
 [[18063   158]
 [  493   239]]
Saved: /kaggle/working/kmer_test_predictions.csv


### Fine Tuning

In [73]:
import os, copy, math, time
from itertools import product
import numpy as np
import torch
import torch.nn as nn

# ====== 1) Single train/validate run (short schedule) ======
@torch.no_grad()
def evaluate_model(model, dl_valid, device):
    """Put `model` in eval mode and return the metrics dict produced by evaluate()."""
    model.eval()
    metrics = evaluate(model, dl_valid, device)  # reuse the shared evaluate() helper
    return metrics

def train_one_config(
    IN_DIM, device, dl_train, dl_valid, label_index, keys_train,
    LR=1e-3, WD=1e-4, DROPOUT=0.1, ENC_HID=64, ENC_OUT=128, ATT_DIM=128,
    GAMMA=2.0, MAX_EPOCHS=6, USE_AMP=True, CLIP_NORM=5.0, seed=42, verbose=False
):
    """Train a fresh ABMILModel with one hyper-parameter setting and track its best validation score.

    Args:
        IN_DIM: Instance feature dimension fed to the model.
        device: Device for the model and batches.
        dl_train, dl_valid: Loaders yielding (keys, X, mask, y) batches.
        label_index: Mapping bag key -> label, used to derive the positive-class weight.
        keys_train: Training bag keys whose labels feed compute_pos_weight.
        LR, WD: AdamW learning rate and weight decay.
        DROPOUT, ENC_HID, ENC_OUT, ATT_DIM: Model hyper-parameters.
        GAMMA: FocalLoss focusing exponent.
        MAX_EPOCHS: Number of epochs to train.
        USE_AMP: Enables autocast and gradient scaling.
        CLIP_NORM: Max gradient norm for clipping.
        seed: Seed applied to torch and NumPy before the model is built.
        verbose: Print one line per epoch when True.

    Returns:
        (best_score, best_metrics, best_state): best validation score (-1.0 if no epoch
        produced a non-NaN score), the metrics dict of that epoch (or None), and a deep
        copy of the model state_dict at that epoch (or None).
    """
    torch.cuda.empty_cache()
    torch.manual_seed(seed); np.random.seed(seed)

    # --- model & loss & opt ---
    model = ABMILModel(in_dim=IN_DIM, enc_hid=ENC_HID, enc_out=ENC_OUT, att_dim=ATT_DIM, dropout=DROPOUT).to(device)

    y_train_arr = np.array([label_index[k] for k in keys_train], dtype=int)
    pos_weight  = compute_pos_weight(y_train_arr).to(device)
    criterion   = FocalLoss(gamma=GAMMA, alpha=None, pos_weight=pos_weight, reduction="mean").to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)
    scaler    = torch.amp.GradScaler("cuda", enabled=USE_AMP)

    best_score = -1.0
    best_metrics = None
    best_state = None

    for epoch in range(1, MAX_EPOCHS+1):
        model.train(); total_loss=0.0; n=0

        for keys, X, M, y in dl_train:
            X = X.to(device, non_blocking=True)
            M = M.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", enabled=USE_AMP):
                logits, _ = model(X, M)
                loss = criterion(logits, y)

            scaler.scale(loss).backward()
            # Unscale before clipping so CLIP_NORM applies to the true gradient norm.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
            scaler.step(optimizer)
            scaler.update()

            # Batch-size weighted sum, averaged per bag after the epoch.
            total_loss += loss.item() * X.size(0); n += X.size(0)

        train_loss = total_loss / max(1, n)

        # --- validate ---
        metrics_v  = evaluate_model(model, dl_valid, device)
        valid_pr   = metrics_v.get("PR_AUC", float("nan"))
        valid_auc  = metrics_v.get("ROC_AUC", float("nan"))

        # Score: prefer PR_AUC; fall back to ROC AUC when PR_AUC is NaN
        score = valid_pr if not math.isnan(valid_pr) else valid_auc
        if not math.isnan(score) and score > best_score:
            best_score  = score
            best_metrics = dict(metrics_v)
            # Deep copy: state_dict() returns references that later training steps would overwrite.
            best_state  = copy.deepcopy(model.state_dict())

        if verbose:
            cur_lr = optimizer.param_groups[0]['lr']
            print(f"[cfg LR={LR:.2e} WD={WD:.1e} DO={DROPOUT} ATT={ATT_DIM} GAMMA={GAMMA}] "
                  f"Epoch {epoch}/{MAX_EPOCHS}  loss={train_loss:.4f}  "
                  f"PR={valid_pr:.4f}  AUC={valid_auc:.4f}  LR={cur_lr:.2e}")

    return best_score, best_metrics, best_state

# ====== 2) Grid-search entry point ======
def grid_search_abmil(
    IN_DIM, device, dl_train, dl_valid, label_index, keys_train,
    space=None, MAX_EPOCHS=6, USE_AMP=True, seed=42, save_best_to=None, verbose=False
):
    """Exhaustively evaluate every combination in `space` with train_one_config.

    Args:
        IN_DIM, device, dl_train, dl_valid, label_index, keys_train: Forwarded to
            train_one_config unchanged.
        space: Mapping hyper-parameter name -> list of candidate values; names must be
            keyword arguments of train_one_config. A default grid is used when None.
        MAX_EPOCHS, USE_AMP, seed, verbose: Forwarded to train_one_config.
        save_best_to: Optional path; the best state_dict is saved there when one exists.

    Returns:
        dict with keys "score", "cfg", "metrics", "state" describing the best configuration
        (score -1.0 and None entries when no configuration produced a valid score).
    """
    if space is None:
        space = dict(
            LR=[3e-4, 5e-4, 1e-3],
            WD=[0.0, 1e-5, 1e-4],
            DROPOUT=[0.0, 0.1, 0.2],
            ENC_HID=[64],          # add e.g. [64, 128] to search this too
            ENC_OUT=[128],         # add e.g. [128, 256] to search this too
            ATT_DIM=[64, 128],
            GAMMA=[1.0, 2.0],      # FocalLoss gamma
        )

    names = list(space.keys())
    best = {"score": -1.0, "cfg": None, "metrics": None, "state": None}

    for values in product(*(space[name] for name in names)):
        cfg = dict(zip(names, values))
        score, metrics, state = train_one_config(
            IN_DIM, device, dl_train, dl_valid, label_index, keys_train,
            MAX_EPOCHS=MAX_EPOCHS, USE_AMP=USE_AMP, seed=seed, verbose=verbose, **cfg
        )
        if verbose:
            if metrics is None:
                # train_one_config returns None metrics when every epoch scored NaN.
                print(f"[GRID] cfg={cfg} >> no valid validation score")
            else:
                print(f"[GRID] cfg={cfg} >> PR={metrics.get('PR_AUC'):.4f} AUC={metrics.get('ROC_AUC'):.4f}")

        if score > best["score"]:
            best.update(score=score, cfg=cfg, metrics=metrics, state=state)

    # Optionally persist the best model
    if save_best_to and best["state"] is not None:
        out_dir = os.path.dirname(save_best_to)
        if out_dir:  # a bare filename has no directory part; os.makedirs("") would raise
            os.makedirs(out_dir, exist_ok=True)
        torch.save({"model": best["state"], "in_dim": IN_DIM, "best_cfg": best["cfg"]}, save_best_to)
        if verbose:
            print(f">> saved best grid model to {save_best_to}")

    return best

# ====== 3) Example invocation ======
# Compare configurations quickly with a small epoch budget; once the best
# hyper-parameters are known, retrain with the full EPOCHS.
space = dict(
    LR=[2e-4, 3e-4, 5e-4, 1e-3],
    WD=[0.0, 1e-5, 1e-4],
    DROPOUT=[0.0, 0.1, 0.2],
    ENC_HID=[64],
    ENC_OUT=[128],
    ATT_DIM=[64, 128],
    GAMMA=[1.0, 2.0],
)
# 4 * 3 * 3 * 1 * 1 * 2 * 2 = 144 configurations, each trained for MAX_EPOCHS epochs.
best = grid_search_abmil(
    IN_DIM, device, dl_train, dl_valid, label_index, keys_train,
    space=space, MAX_EPOCHS=5, USE_AMP=True, seed=42,
    save_best_to=os.path.join(OUTDIR, "abmil_grid_best.pt"),
    verbose=True
)
print(">> Best config:", best["cfg"])
print(">> Best metrics:", best["metrics"])

# ====== 4) Retrain with the best configuration (full EPOCHS) ======
cfg = best["cfg"]
model = ABMILModel(in_dim=IN_DIM, enc_hid=cfg["ENC_HID"], enc_out=cfg["ENC_OUT"],
                   att_dim=cfg["ATT_DIM"], dropout=cfg["DROPOUT"]).to(device)
y_train_arr = np.array([label_index[k] for k in keys_train], dtype=int)
pos_weight  = compute_pos_weight(y_train_arr).to(device)
criterion   = FocalLoss(gamma=cfg["GAMMA"], alpha=None, pos_weight=pos_weight, reduction="mean").to(device)
optimizer   = torch.optim.AdamW(model.parameters(), lr=cfg["LR"], weight_decay=cfg["WD"])
# Follow the notebook-wide USE_AMP switch instead of hard-coding True, so CPU runs
# do not request CUDA autocast / gradient scaling.
scaler      = torch.amp.GradScaler("cuda", enabled=USE_AMP)

best_score = -1.0
# NOTE: same path as the first training run, so that checkpoint is overwritten.
best_path  = os.path.join(OUTDIR, "abmil_kmer_best.pt")
for epoch in range(1, EPOCHS+1):  # full training budget
    model.train(); total_loss=0.0; n=0
    for keys, X, M, y in dl_train:
        X, M, y = X.to(device, non_blocking=True), M.to(device, non_blocking=True), y.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast("cuda", enabled=USE_AMP):
            logits, _ = model(X, M)
            loss = criterion(logits, y)
        scaler.scale(loss).backward()
        # Unscale before clipping so the threshold applies to the true gradient norm.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        scaler.step(optimizer); scaler.update()
        total_loss += loss.item() * X.size(0); n += X.size(0)

    metrics_v  = evaluate(model, dl_valid, device)
    valid_pr   = metrics_v["PR_AUC"]; valid_auc = metrics_v["ROC_AUC"]
    # Model selection: prefer PR-AUC, fall back to ROC-AUC when PR-AUC is NaN.
    score = valid_pr if not math.isnan(valid_pr) else valid_auc
    print(f"[{epoch}/{EPOCHS}] loss={total_loss/max(1,n):.4f} | PR={valid_pr:.4f} AUC={valid_auc:.4f} ACC={metrics_v['ACC']:.4f}")

    if not math.isnan(score) and score > best_score:
        best_score = score
        torch.save({"model": model.state_dict(), "in_dim": IN_DIM}, best_path)
        print(f"  >> saved best to {best_path}")


[cfg LR=2.00e-04 WD=0.0e+00 DO=0.0 ATT=64 GAMMA=1.0] Epoch 1/5  loss=0.0683  PR=0.4423  AUC=0.8929  LR=2.00e-04
[cfg LR=2.00e-04 WD=0.0e+00 DO=0.0 ATT=64 GAMMA=1.0] Epoch 2/5  loss=0.0592  PR=0.4631  AUC=0.9000  LR=2.00e-04
[cfg LR=2.00e-04 WD=0.0e+00 DO=0.0 ATT=64 GAMMA=1.0] Epoch 3/5  loss=0.0572  PR=0.4823  AUC=0.9063  LR=2.00e-04
[cfg LR=2.00e-04 WD=0.0e+00 DO=0.0 ATT=64 GAMMA=1.0] Epoch 4/5  loss=0.0557  PR=0.4893  AUC=0.9077  LR=2.00e-04
[cfg LR=2.00e-04 WD=0.0e+00 DO=0.0 ATT=64 GAMMA=1.0] Epoch 5/5  loss=0.0549  PR=0.4921  AUC=0.9082  LR=2.00e-04
[GRID] cfg={'LR': 0.0002, 'WD': 0.0, 'DROPOUT': 0.0, 'ENC_HID': 64, 'ENC_OUT': 128, 'ATT_DIM': 64, 'GAMMA': 1.0} >> PR=0.4921 AUC=0.9082
[cfg LR=2.00e-04 WD=0.0e+00 DO=0.0 ATT=64 GAMMA=2.0] Epoch 1/5  loss=0.0363  PR=0.4492  AUC=0.8967  LR=2.00e-04
[cfg LR=2.00e-04 WD=0.0e+00 DO=0.0 ATT=64 GAMMA=2.0] Epoch 2/5  loss=0.0316  PR=0.4809  AUC=0.9040  LR=2.00e-04
[cfg LR=2.00e-04 WD=0.0e+00 DO=0.0 ATT=64 GAMMA=2.0] Epoch 3/5  loss=0.0306  PR

In [75]:
TEAM_NAME   = "nature"
CKPT_DIR    = "/kaggle/working"   # where the training and grid-search cells wrote their checkpoints
OUTDIR      = "./submissions"     # NOTE: rebinds OUTDIR; from here on it is the submission folder
os.makedirs(OUTDIR, exist_ok=True)

JSON_PATHS = {
    "dataset0": "/kaggle/input/dataset-1/dataset0.json",
    "dataset1": "/kaggle/input/test-dataset/dataset1.json",
    "dataset2": "/kaggle/input/test-dataset/dataset2.json",
}


# Checkpoint candidates in priority order. The grid-search checkpoint was saved under
# /kaggle/working (OUTDIR at that time), not under the new ./submissions folder, so look
# there first; the old ./submissions location is kept as a last resort.
_ckpt_candidates = [
    os.path.join(CKPT_DIR, "abmil_kmer_best.pt"),
    os.path.join(CKPT_DIR, "abmil_grid_best.pt"),
    os.path.join(OUTDIR, "abmil_grid_best.pt"),
]
BEST_CKPT = next((p for p in _ckpt_candidates if os.path.exists(p)), _ckpt_candidates[0])
assert os.path.exists(BEST_CKPT), "未找到 best checkpoint，请先完成训练并保存。"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [78]:
def build_kmer_level_bags(json_path, label_keys_set=None):
    """Build one k-mer-level feature matrix per (transcript, position) site.

    Every non-blank line of ``json_path`` is parsed as
    ``{transcript_id: {position: {kmer: reads}}}``. For each k-mer whose reads
    form a non-empty 2-D array with exactly 9 columns, one feature row is
    assembled by concatenating, in this order: the flattened result of
    ``stats_for_matrix``, the number of reads, and the result of
    ``extract_kmer_context_features_one``.

    Args:
        json_path: Path to the line-delimited JSON file.
        label_keys_set: Optional container of ``(transcript_id, position)``
            string tuples. When given, sites that are not in it are skipped;
            ``None`` keeps every site.

    Returns:
        dict mapping ``(str(transcript_id), str(position))`` to a 2-D array
        with one row per usable k-mer. Sites without any usable k-mer are
        left out of the dict.
    """
    site_bags = {}
    with open(json_path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            if raw_line.strip() == "":
                continue
            record = json.loads(raw_line)
            for tid, pos_dict in record.items():
                for pos, kmer_dict in pos_dict.items():
                    key = (str(tid), str(pos))
                    # Filtering only happens when a key set was supplied.
                    if label_keys_set is not None and key not in label_keys_set:
                        continue

                    feature_rows = []
                    for kmer, reads in kmer_dict.items():
                        reads_arr = np.asarray(reads, dtype=float)
                        usable = (
                            reads_arr.size != 0
                            and reads_arr.ndim == 2
                            and reads_arr.shape[1] == 9
                        )
                        if not usable:
                            continue
                        # Per the original inline notes: stats are (9, 8) -> 72 values,
                        # plus 1 read count, plus 6 k-mer context features = 79 columns.
                        stats_flat = stats_for_matrix(reads_arr).reshape(-1)
                        read_count = np.array([reads_arr.shape[0]], float)
                        context = extract_kmer_context_features_one(kmer)
                        feature_rows.append(
                            np.concatenate([stats_flat, read_count, context], axis=0)
                        )

                    if feature_rows:
                        site_bags[key] = np.vstack(feature_rows)
    return site_bags


In [79]:
print(">> Build bags for dataset0 (train)…")
bags0_raw = build_kmer_level_bags(JSON_PATHS["dataset0"], label_keys_set=None)  # read all keys

# Align df_labels to the keys that actually produced a bag. A plain boolean mask
# replaces the row-wise DataFrame.apply: same rows kept, no per-row Series built,
# and it stays a 1-D mask even when df_labels is empty.
has_bag = np.array(
    [
        (str(tid), str(pos)) in bags0_raw
        for tid, pos in zip(df_labels["transcript_id"], df_labels["transcript_position"])
    ],
    dtype=bool,
)
df_labels = df_labels.loc[has_bag].reset_index(drop=True)
train_keys_0 = get_keys_from_df(df_labels)

# Feature width of the raw bags. This was previously stored in a global named `F`,
# which silently replaced the `torch.nn.functional as F` alias imported at the top
# of the notebook; a dedicated name keeps that alias usable after this cell runs.
n_feat_kmer = bags0_raw[next(iter(bags0_raw))].shape[1]
# Columns that get log1p before scaling (rule kept unchanged from the original).
log1p_cols = (
    list(range(min(73, n_feat_kmer)))
    if n_feat_kmer < 80
    else (list(range(72)) + [72, 76, 77, 78, 79])
)

print(">> Fit scaler on dataset0 (train instances)…")
scaler_kmer = fit_scaler_on_train_kmer(train_keys_0, bags0_raw, log1p_indices=log1p_cols)

def build_and_transform(json_path):
    """Build unfiltered k-mer bags for ``json_path`` and scale them.

    Uses the notebook-level ``scaler_kmer`` and ``log1p_cols``.

    Returns:
        ``(raw, transformed)``: the bags as built by ``build_kmer_level_bags``
        and the result of ``transform_all_bags_kmer`` applied to them.
    """
    # label_keys_set=None means "no filtering": every site in the file is kept.
    raw_bags = build_kmer_level_bags(json_path, label_keys_set=None)
    scaled_bags = transform_all_bags_kmer(raw_bags, scaler_kmer, log1p_indices=log1p_cols)
    return raw_bags, scaled_bags


>> Build bags for dataset0 (train)…
>> Fit scaler on dataset0 (train instances)…


In [None]:
print(">> Build & transform for all datasets…")
# Per-dataset containers: transformed bags and the key order used when writing scores.
bags_all = {}
keys_all = {}

# dataset0 bags were already built for scaler fitting, so only the transform is needed.
# (Assigned directly: the earlier tuple unpacking bound `_` to the whole raw bag dict,
# which shadows IPython's last-output variable for no benefit.)
bags0_tr = transform_all_bags_kmer(bags0_raw, scaler_kmer, log1p_indices=log1p_cols)
bags_all["dataset0"] = bags0_tr
keys_all["dataset0"] = train_keys_0

In [96]:
# ==== Run this before creating dl_test / calling predict_scores_eval ====

# 1) Recover the training-time layer sizes from the checkpoint weights
def infer_dims_from_state_dict(sd):
    """Read the model's layer sizes from the shapes of two weight tensors.

    Args:
        sd: Mapping that contains ``"encoder.net.0.weight"`` (rows = encoder
            hidden size, columns = input width) and ``"att.V.weight"``
            (rows = attention size, columns = encoder output size).

    Returns:
        dict with keys ``ENC_HID``, ``IN_DIM``, ``ENC_OUT`` and ``ATT_DIM``.

    Note:
        A later cell in this notebook defines a function with the same name
        that returns only the input width as an int.
    """
    enc_shape = sd["encoder.net.0.weight"].shape  # [ENC_HID, IN_DIM]
    enc_hid, in_dim = enc_shape[0], enc_shape[1]
    att_shape = sd["att.V.weight"].shape          # [ATT_DIM, ENC_OUT]
    att_dim, enc_out = att_shape[0], att_shape[1]
    return {"ENC_HID": enc_hid, "IN_DIM": in_dim, "ENC_OUT": enc_out, "ATT_DIM": att_dim}

# Load the best checkpoint onto the CPU and read the layer sizes back out of its weights.
ckpt = torch.load(BEST_CKPT, map_location="cpu")
sd   = ckpt["model"]  # the state_dict is stored under the "model" key
dims = infer_dims_from_state_dict(sd)
EXPECTED_F = dims["IN_DIM"]  # input feature width of the checkpoint (40 in the recorded run)

# 2) Fetch the RFE column indices (checkpoint first, then a local file) and sanity-check them
def try_load_cols_keep(ckpt, cols_path_on_disk=None):
    """Look up the kept-column indices, first in ``ckpt`` and then on disk.

    Args:
        ckpt: Mapping that may hold the indices under ``"cols_keep"``,
            ``"rfe_cols"``, ``"COLS_KEEP"`` or ``"RFE_COLS"`` (checked in that order).
        cols_path_on_disk: Optional path of a ``.npy`` file used as a fallback.

    Returns:
        1-D integer array of column indices, or ``None`` if neither source
        provides one. Checkpoint entries that are ``None`` or not 1-D are skipped.
    """
    candidate_names = ("cols_keep", "rfe_cols", "COLS_KEEP", "RFE_COLS")
    for name in candidate_names:
        if name not in ckpt or ckpt[name] is None:
            continue
        indices = np.array(ckpt[name])
        if indices.ndim == 1:
            return indices.astype(int)

    if not cols_path_on_disk or not os.path.exists(cols_path_on_disk):
        return None
    return np.load(cols_path_on_disk).astype(int)

# Resolve the kept-column indices and make sure they match the checkpoint's input width.
COLS_PATH = os.path.join(OUTDIR, "rfe_cols_keep.npy")
cols_keep = try_load_cols_keep(ckpt, cols_path_on_disk=COLS_PATH)

if cols_keep is None:
    # The warning reads: "RFE columns not found, falling back to the first EXPECTED_F
    # columns; restore the real cols_keep as soon as possible to preserve accuracy."
    print(f"[WARN] 未找到 RFE 列，退化为取前 {EXPECTED_F} 列。为保证精度，尽快恢复真实 cols_keep。")
elif len(cols_keep) != EXPECTED_F:
    raise RuntimeError(f"cols_keep length ({len(cols_keep)}) != ckpt expected IN_DIM ({EXPECTED_F}).")

# 3) Bring every bag to EXPECTED_F feature columns (applied to dataset0/1/2 alike)
def force_feature_width(bags, expected_f, cols_keep=None):
    """Return a copy of ``bags`` in which every matrix has ``expected_f`` columns.

    Args:
        bags: dict mapping a key to a 2-D array.
        expected_f: Target number of columns.
        cols_keep: Optional 1-D index array; when given, those columns are
            selected. Otherwise the first ``expected_f`` columns are taken.

    Returns:
        New dict with the same keys. Matrices that already have
        ``expected_f`` columns are passed through unchanged (same object).

    Raises:
        RuntimeError: If a matrix is too narrow for the requested selection.
    """
    fixed = {}
    for key, mat in bags.items():
        width = mat.shape[1]
        if width == expected_f:
            # Already the right width: keep as-is, even when cols_keep is set.
            fixed[key] = mat
        elif cols_keep is None:
            if width < expected_f:
                raise RuntimeError(f"Bag {key} has only {width} cols, need {expected_f}.")
            fixed[key] = mat[:, :expected_f]
        else:
            if width < cols_keep.max() + 1:
                raise RuntimeError(f"Bag {key} has {width} cols but cols_keep max idx={cols_keep.max()}.")
            fixed[key] = mat[:, cols_keep]
    return fixed

# Narrow every dataset's bags to the width the checkpoint expects.
for dname in list(bags_all.keys()):
    bags_all[dname] = force_feature_width(bags_all[dname], EXPECTED_F, cols_keep=cols_keep)

# 4) Final safety net: spot-check the width of one bag per dataset.
for dname in bags_all:
    any_key = next(iter(bags_all[dname]))
    # Stored under its own name: a module-level `F` would overwrite the
    # `torch.nn.functional as F` alias imported at the top of the notebook.
    feat_width = bags_all[dname][any_key].shape[1]
    assert feat_width == EXPECTED_F, f"{dname} still has F={feat_width}, expected {EXPECTED_F}"

# 5) Rebuild the model with the sizes read from the checkpoint (so it matches the trained one)
model = ABMILModel(
    in_dim=EXPECTED_F,
    enc_hid=dims["ENC_HID"],
    enc_out=dims["ENC_OUT"],
    att_dim=dims["ATT_DIM"],
    # Dropout comes from ckpt["best_cfg"]["DROPOUT"] when present, otherwise 0.1.
    dropout=ckpt.get("best_cfg", {}).get("DROPOUT", 0.1),
).to(device)
# strict=False does not raise on key mismatches; they are returned and printed instead,
# so check that both lists are empty in the output.
missing, unexpected = model.load_state_dict(sd, strict=False)
print("Loaded state_dict. Missing:", missing, "Unexpected:", unexpected)
# Switch to eval mode; as the cell's last expression this also displays the model.
model.eval()


[WARN] 未找到 RFE 列，退化为取前 40 列。为保证精度，尽快恢复真实 cols_keep。
Loaded state_dict. Missing: [] Unexpected: []


ABMILModel(
  (encoder): EncoderMLP(
    (net): Sequential(
      (0): Linear(in_features=40, out_features=64, bias=True)
      (1): ReLU(inplace=True)
      (2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (3): Dropout(p=0.1, inplace=False)
      (4): Linear(in_features=64, out_features=128, bias=True)
      (5): ReLU(inplace=True)
      (6): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
  )
  (att): GatedAttentionMIL(
    (V): Linear(in_features=128, out_features=64, bias=True)
    (U): Linear(in_features=128, out_features=64, bias=True)
    (w): Linear(in_features=64, out_features=1, bias=False)
  )
  (cls): Linear(in_features=128, out_features=1, bias=True)
)

In [89]:
def write_submission(scores_dict, keys_order, out_csv):
    """Write one ``transcript_id, transcript_position, score`` row per key to CSV.

    Args:
        scores_dict: Mapping from ``(transcript_id, transcript_position)`` to a score.
        keys_order: Iterable of ``(transcript_id, transcript_position)`` pairs
            that fixes the row order.
        out_csv: Destination path of the CSV file (written without an index column).

    Keys absent from ``scores_dict`` are written with a score of 0.5.
    """
    records = []
    for tid, tpos in keys_order:
        score = float(scores_dict.get((tid, tpos), 0.5))
        records.append((tid, tpos, score))

    frame = pd.DataFrame(records, columns=["transcript_id", "transcript_position", "score"])
    frame.to_csv(out_csv, index=False)
    print("saved:", out_csv, " | N=", len(records))



=== Inference on dataset0 ===
saved: ./submissions/nature_dataset0_1.csv  | N= 121838

=== Inference on dataset1 ===


KeyError: 'dataset1'

In [93]:
# Read the training-time input width back out of the checkpoint
def infer_dims_from_state_dict(sd):
    """Return the input feature width stored in the first encoder layer.

    Args:
        sd: Mapping containing ``"encoder.net.0.weight"`` with shape
            ``[ENC_HID, IN_DIM]``.

    Returns:
        The second dimension of that weight (``IN_DIM``).

    Note:
        Another cell in this notebook defines a function with the same name
        that returns a dict of several sizes instead of a single int.
    """
    first_layer_weight = sd["encoder.net.0.weight"]
    return first_layer_weight.shape[1]

# Load the best checkpoint onto the CPU and recover the feature width it was trained with.
ckpt = torch.load(BEST_CKPT, map_location="cpu")
sd   = ckpt["model"]  # the state_dict is stored under the "model" key
EXPECTED_F = infer_dims_from_state_dict(sd)  # should be 40

def try_load_cols_keep(ckpt, cols_path_on_disk=None):
    """Return the kept-column indices from ``ckpt`` or, failing that, from disk.

    The checkpoint is searched under ``"cols_keep"``, ``"rfe_cols"``,
    ``"COLS_KEEP"`` and ``"RFE_COLS"`` (in that order); the first entry that is
    not ``None`` and converts to a 1-D array wins. If none qualifies and
    ``cols_path_on_disk`` points to an existing file, it is read with
    ``np.load``. The result is cast to ``int``; ``None`` means nothing was found.
    """
    for name in ("cols_keep", "rfe_cols", "COLS_KEEP", "RFE_COLS"):
        stored = ckpt[name] if name in ckpt else None
        if stored is None:
            continue
        as_array = np.array(stored)
        if as_array.ndim == 1:
            return as_array.astype(int)

    have_file = bool(cols_path_on_disk) and os.path.exists(cols_path_on_disk)
    return np.load(cols_path_on_disk).astype(int) if have_file else None

# Kept-column indices used by the test collate function: a 1-D numpy array, or None.
COLS_PATH = os.path.join(OUTDIR, "rfe_cols_keep.npy")
COLS_KEEP_NP = try_load_cols_keep(ckpt, cols_path_on_disk=COLS_PATH)

# When indices were found, there must be exactly one per model input feature.
if COLS_KEEP_NP is not None:
    if len(COLS_KEEP_NP) != EXPECTED_F:
        raise RuntimeError(f"cols_keep len={len(COLS_KEEP_NP)} != expected {EXPECTED_F}")


In [94]:
def make_test_loader(bags_dict, keys_order, batch_size=8, num_workers=2,
                     expected_f=None, cols_keep=None):
    """Create a non-shuffling DataLoader that pads bags and fixes their feature width.

    Args:
        bags_dict: Mapping from key to a 2-D numpy array of shape (Ni, F).
        keys_order: Sequence of keys; defines dataset order and length.
        batch_size: Number of bags per batch.
        num_workers: DataLoader worker processes.
        expected_f: Feature width the model expects. ``None`` (default) uses the
            notebook-level ``EXPECTED_F`` at collate time.
        cols_keep: 1-D index array of columns to select. ``None`` (default) uses
            the notebook-level ``COLS_KEEP_NP`` at collate time; if that is also
            ``None`` the first ``expected_f`` columns are taken.

    Returns:
        A DataLoader yielding ``(keys, X, M, y)``: a list of key tuples, a float32
        tensor (B, Lmax, expected_f) zero-padded along the instance axis, a bool
        mask (B, Lmax) that is True on real instances, and a float32 tensor of
        placeholder labels (all zeros).

    Raises:
        RuntimeError: From the collate function, if a bag is too narrow for the
            requested column selection.
    """
    class BagDatasetTest(torch.utils.data.Dataset):
        """Label-free bag dataset; items are kept as numpy until collation."""

        def __init__(self, bags_dict, keys_order):
            self.bags, self.keys = bags_dict, keys_order

        def __len__(self):
            return len(self.keys)

        def __getitem__(self, idx):
            k = self.keys[idx]
            X = self.bags[k]  # (Ni, F) numpy
            # Every instance counts as valid; swap in a real mask here if one exists.
            M = np.ones((X.shape[0],), dtype=np.bool_)
            return k, X, M, 0  # numpy out; conversion to torch happens in collate

    def pad_collate_test(batch):
        keys, Xs, Ms, ys = zip(*batch)

        # Resolve overrides lazily so the notebook globals are only touched when
        # the caller did not pass explicit values.
        exp_f = EXPECTED_F if expected_f is None else expected_f
        keep = COLS_KEEP_NP if cols_keep is None else cols_keep

        # Column selection / narrowing to exp_f happens here, per bag.
        Xs_fixed = []
        for x in Xs:
            if x.shape[1] == exp_f:
                # Already at the expected width (e.g. narrowed earlier by
                # force_feature_width). Selecting again would either raise or
                # pick columns of the already-selected matrix, so pass through.
                pass
            elif keep is not None:
                if x.shape[1] < keep.max() + 1:
                    raise RuntimeError(f"One bag has F={x.shape[1]} < max(cols_keep)+1={keep.max()+1}")
                x = x[:, keep]
            else:
                if x.shape[1] < exp_f:
                    raise RuntimeError(f"One bag has only F={x.shape[1]}, need {exp_f}")
                x = x[:, :exp_f]
            Xs_fixed.append(x)

        # Every x is now (Ni, exp_f); pad along the instance axis to the longest bag.
        lens = [x.shape[0] for x in Xs_fixed]
        B, Lmax, feat_dim = len(Xs_fixed), max(lens), Xs_fixed[0].shape[1]
        assert feat_dim == exp_f, f"Collate F={feat_dim} != EXPECTED_F={exp_f}"

        X = torch.zeros(B, Lmax, feat_dim, dtype=torch.float32)
        M = torch.zeros(B, Lmax, dtype=torch.bool)

        for i, (x_np, m_np) in enumerate(zip(Xs_fixed, Ms)):
            L = x_np.shape[0]
            X[i, :L] = torch.from_numpy(x_np).float()
            M[i, :L] = torch.from_numpy(m_np).bool()

        # Make sure keys are hashable tuples.
        keys = [tuple(k) if not isinstance(k, tuple) else k for k in keys]
        y = torch.as_tensor(ys, dtype=torch.float32)
        return list(keys), X, M, y

    ds = BagDatasetTest(bags_dict, keys_order)
    return torch.utils.data.DataLoader(
        ds, batch_size=batch_size, shuffle=False, num_workers=num_workers,
        # Pinning only helps host-to-GPU copies; skip it on CPU-only runs.
        pin_memory=torch.cuda.is_available(),
        collate_fn=pad_collate_test,
    )


In [95]:
@torch.no_grad()
def predict_scores_eval(model, dl, device):
    """Score every bag produced by ``dl`` and return ``{key: probability}``.

    Args:
        model: Callable as ``model(X, M)``; the first element of its result is
            taken as the logits.
        dl: Iterable of ``(keys, X, M, _)`` batches.
        device: torch device; autocast is enabled only when its type is "cuda".

    Returns:
        dict mapping each key (as a tuple) to the sigmoid of its bag logit.
        Logits shaped (B, 1) are squeezed; any other shape is flattened per bag
        and averaged.
    """
    model.eval()
    amp_enabled = device.type == "cuda"
    scores = {}

    for batch_keys, feats, mask, _ in dl:
        # Safety net: re-check the feature width at run time.
        assert feats.shape[-1] == EXPECTED_F, f"Runtime F={feats.shape[-1]} != {EXPECTED_F}"
        feats = feats.to(device, non_blocking=True)
        mask = mask.to(device, non_blocking=True)

        with torch.amp.autocast(device_type="cuda", enabled=amp_enabled):
            logits, *_ = model(feats, mask)

        is_single_column = logits.ndim == 2 and logits.shape[1] == 1
        bag_logits = (
            logits.squeeze(1)
            if is_single_column
            else logits.reshape(logits.shape[0], -1).mean(dim=1)
        )

        probs = torch.sigmoid(bag_logits).detach().cpu().numpy()
        for key, prob in zip(batch_keys, probs):
            if not isinstance(key, tuple):
                # Normalise to a tuple so the key stays hashable.
                try:
                    key = tuple(key)
                except TypeError:
                    key = (str(key),)
            scores[key] = float(np.asarray(prob).reshape(()))
    return scores


In [97]:
def sortable_key(k):
    """Sort key for ``(transcript_id, position)`` pairs: by id, then numeric position.

    Args:
        k: A ``(transcript_id, position)`` pair.

    Returns:
        ``(str(tid), 0, int(position))`` when the position parses as an integer,
        otherwise ``(str(tid), 1, str(position))``. The middle flag keeps the
        two forms comparable: within one transcript, numeric positions sort
        first (in integer order) and non-numeric ones follow in string order.
        Without it, a transcript with both kinds made ``sorted`` raise
        ``TypeError`` (int compared with str).
    """
    tid, tpos = k
    try:
        return (str(tid), 0, int(tpos))
    except (TypeError, ValueError):
        # Only conversion failures trigger the string fallback; anything else
        # (e.g. KeyboardInterrupt) is no longer swallowed by a bare except.
        return (str(tid), 1, str(tpos))

# -- dataset0 follows the same flow (already done): keys aligned via df_labels, scaler_kmer fitted --

# -- dataset1/2: no labels needed, every key found in the file is used --
def build_and_transform(json_path):
    """Return ``(raw, transformed)`` k-mer bags for ``json_path``.

    Bags are built without any key filtering and then passed through
    ``transform_all_bags_kmer`` with the notebook-level ``scaler_kmer`` and
    ``log1p_cols``.
    """
    unfiltered = build_kmer_level_bags(json_path, label_keys_set=None)  # no filtering
    return unfiltered, transform_all_bags_kmer(unfiltered, scaler_kmer, log1p_indices=log1p_cols)

# Build and scale the unlabeled datasets, then register their bags and key order.
for dname in ("dataset1", "dataset2"):
    raw, tr = build_and_transform(JSON_PATHS[dname])
    bags_all[dname] = tr
    # If a sample_submission_{dname}.csv is provided, it could define the row order:
    # sample_df = pd.read_csv(SAMPLE_PATHS[dname])  # three columns
    # keys_all[dname] = [(str(r.transcript_id), str(r.transcript_position)) for _, r in sample_df.iterrows() if (str(r.transcript_id), str(r.transcript_position)) in tr]
    # Without one, write every key in a deterministic order:
    keys_all[dname] = sorted(tr, key=sortable_key)

# Predict and write one submission file per dataset.
for dname in ("dataset0", "dataset1", "dataset2"):
    print(f"\n=== Inference on {dname} ===")
    dl_test = make_test_loader(
        bags_all[dname],
        keys_all[dname],
        batch_size=8,
        num_workers=2,
    )
    scores = predict_scores_eval(model, dl_test, device)
    out_csv = os.path.join(OUTDIR, "{}_{}_1.csv".format(TEAM_NAME, dname))
    write_submission(scores, keys_all[dname], out_csv)



=== Inference on dataset0 ===
saved: ./submissions/nature_dataset0_1.csv  | N= 121838

=== Inference on dataset1 ===
saved: ./submissions/nature_dataset1_1.csv  | N= 90810

=== Inference on dataset2 ===
saved: ./submissions/nature_dataset2_1.csv  | N= 1323
