In [2]:
import numpy as np
import pandas as pd
from itertools import product
from sklearn.feature_extraction.text import CountVectorizer
import random
import torch
from transformers import AutoTokenizer, AutoModel
from rdkit import Chem, DataStructs
from rdkit.Chem import rdFingerprintGenerator as rfg
from rdkit.Chem import MACCSkeys, Descriptors
from rdkit.ML.Descriptors.MoleculeDescriptors import MolecularDescriptorCalculator

In [3]:
# Load the aptamer dataset; the first CSV column is used as the row index.
df = pd.read_csv('eremeeva_aptamers_dataset.csv', index_col=0)


## Augmentation

In [4]:
def augment_negatives_per_sequence(df: pd.DataFrame, n_neg_per_seq: int = 1, seed: int = 42,
                                   return_neg: bool = False):
    """
    Add sampled negative (sequence, canonical_smiles) pairs for every unique sequence.

    For each unique sequence, up to ``n_neg_per_seq`` molecules are drawn (without
    replacement, using ``random.Random(seed)``) from the molecules present in ``df``
    that never appear together with that sequence, under any label. Sequences that are
    already paired with every known molecule get no new rows.

    Generated rows carry:
      - label = 0, pKd_value = NaN, buffer = NaN,
        origin = 'augmented_neg', source = 'augmentation';
      - type: the first non-null ``type`` recorded for the sequence, otherwise inferred
        ("RNA" if the sequence contains U/u, else "DNA");
      - molecular_weight: the first non-null weight recorded for that SMILES, else NaN.
    Only columns that exist in ``df`` are kept; columns of ``df`` that are not generated
    are filled with NaN.

    Parameters
    ----------
    df : DataFrame with at least the columns sequence, canonical_smiles, type and
        molecular_weight. Rows with a missing sequence or SMILES are ignored when
        building the sampling pools but are still part of the returned frame.
    n_neg_per_seq : maximum number of negatives to generate per unique sequence.
    seed : seed for the sampler.
    return_neg : when True, also return the generated rows on their own.

    Returns
    -------
    df_aug : original rows followed by the generated negatives. The index is reset
        (``ignore_index=True``), so the original index of ``df`` is not preserved.
    (df_aug, df_neg) if ``return_neg`` is True.
    """
    rng = random.Random(seed)

    base = df.dropna(subset=["sequence", "canonical_smiles"])
    seq_keys = base["sequence"].astype(str).tolist()
    smi_keys = base["canonical_smiles"].astype(str).tolist()

    # First-appearance order is kept on purpose: rng.sample depends on the pool order.
    all_smiles = list(dict.fromkeys(smi_keys))
    mw_map = base.groupby("canonical_smiles")["molecular_weight"].first().to_dict()

    # One pass over the frame instead of re-filtering it for every sequence.
    used_by_seq = {}
    type_by_seq = {}
    for s, m, t in zip(seq_keys, smi_keys, base["type"].tolist()):
        used_by_seq.setdefault(s, set()).add(m)
        if s not in type_by_seq and not pd.isna(t):
            type_by_seq[s] = str(t)

    neg_rows = []
    for s, used_smiles in used_by_seq.items():
        # The pool already excludes every molecule seen with this sequence, and each
        # sequence is visited once, so sampled pairs can never collide with existing ones.
        pool = [m for m in all_smiles if m not in used_smiles]
        if not pool:
            continue

        if s in type_by_seq:
            type_val = type_by_seq[s]
        else:
            type_val = "RNA" if "U" in s.upper() else "DNA"

        for m in rng.sample(pool, k=min(n_neg_per_seq, len(pool))):
            neg_rows.append({
                "type": type_val,
                "sequence": s,
                "canonical_smiles": m,
                "pKd_value": np.nan,
                "label": 0,
                "buffer": np.nan,
                "origin": "augmented_neg",
                "source": "augmentation",
                "molecular_weight": mw_map.get(m, np.nan),
            })

    df_neg = pd.DataFrame(neg_rows)

    # Align the generated rows with the column layout of the input frame.
    for col in df.columns:
        if col not in df_neg.columns:
            df_neg[col] = np.nan
    df_neg = df_neg[df.columns.tolist()]

    df_aug = pd.concat([df, df_neg], ignore_index=True)

    return (df_aug, df_neg) if return_neg else df_aug

In [5]:
# Add up to 3 sampled negative (sequence, SMILES) pairs per unique sequence.
df_aug = augment_negatives_per_sequence(df, n_neg_per_seq=3, seed=42)
# Optional export of the augmented table (currently disabled).
#df_aug.to_csv('eremeeva_aptamers_dataset_with_negatives.csv')


In [6]:
# List the distinct SMILES strings present in the augmented frame.
df_aug['canonical_smiles'].unique().tolist()

['Nc1c(S(=O)(=O)O)cc(Nc2ccc(Nc3nc(Cl)nc(Nc4ccccc4S(=O)(=O)O)n3)c(S(=O)(=O)O)c2)c2c1C(=O)c1ccccc1C2=O',
 'Nc1c(S(=O)(=O)O)cc(Nc2ccc(S(=O)(=O)O)c(Nc3nc(Cl)nc(Cl)n3)c2)c2c1C(=O)c1ccccc1C2=O',
 'Nc1ncnc2c1ncn2C1OC(COP(=O)(O)OP(=O)(O)OP(=O)(O)O)C(O)C1O',
 'Nc1ncnc2c1ncn2C1OC(COP(=O)(O)O)C(O)C1O',
 'Cc1cc2nc3c(=O)[nH]c(=O)nc-3n(CC(O)C(O)C(O)COP(=O)(O)OP(=O)(O)OCC3OC(n4cnc5c(N)ncnc54)C(O)C3O)c2cc1C',
 'Cc1cc2nc3c(=O)[nH]c(=O)nc-3n(CC(O)C(O)C(O)COP(=O)(O)O)c2cc1C',
 'Nc1ncnc2c1ncn2C1OC(CO)C(O)C1O',
 'Cc1cc2nc3c(=O)[nH]c(=O)nc-3n(CC(O)C(O)C(O)CO)c2cc1C',
 'NC(=O)c1ccc[n+](C2OC(COP(=O)([O-])O)C(O)C2O)c1',
 'CC1=C2N=C(C=C3N=C(C(C)=C4[N-]C(C(CC(N)=O)C4(C)CCC(=O)NCC(C)OP(=O)([O-])OC4C(CO)OC(n5cnc6cc(C)c(C)cc65)C4O)C4(C)N=C1C(CCC(N)=O)C4(C)CC(N)=O)C(CCC(N)=O)C3(C)C)C(CCC(N)=O)C2(C)CC(N)=O.[C-]#N.[Co+3]',
 'C/C1=C2/[N-]C([C@H](CC(N)=O)[C@@]2(C)CCC(=O)NCC(C)O)[C@]2(C)N=C(/C(C)=C3\\N=C(/C=C4\\N=C1[C@@H](CCC(N)=O)C4(C)C)[C@@H](CCC(N)=O)[C@]3(C)CC(N)=O)[C@@H](CCC(N)=O)[C@]2(C)CC(N)=O.[C-]#N.[C-]#N.[Co]',

In [7]:
# Class balance after augmentation.
df_aug['label'].value_counts()

label
0    4578
1    1922
Name: count, dtype: int64

In [8]:
# Length of the longest sequence; the one-hot / GENA encoders below default to max_len=216.
max(df_aug['sequence'].apply(len))

216

## Sequence encoding

### One-hot

In [9]:
import numpy as np

def onehot_with_type_bit(seqs, types, max_len=216):
    """
    One-hot encode nucleotide sequences and append an RNA/DNA indicator column.

    Each sequence is upper-cased, U is mapped to T, characters outside A/C/G/T are
    removed, and the result is truncated to ``max_len`` positions. Position ``j``
    occupies columns ``4*j .. 4*j+3`` in the order A, C, G, T; shorter sequences are
    zero-padded. The last column is 1.0 when the (stripped, case-insensitive) type is
    "RNA" and 0.0 otherwise.

    Returns a float64 array of shape (len(seqs), 4*max_len + 1).
    Raises ValueError when ``seqs`` and ``types`` differ in length.
    """
    if len(seqs) != len(types):
        raise ValueError("seqs и types must be the same length")

    base_index = dict(zip("ACGT", range(4)))
    encoded = np.zeros((len(seqs), 4 * max_len + 1), dtype=np.float64)

    # Iterating the 2-D array yields row views, so writes land in `encoded`.
    for row, (seq, mol_type) in zip(encoded, zip(seqs, types)):
        normalized = (seq or "").upper().replace("U", "T")
        # Filter first, then truncate: dropped characters do not consume positions.
        codes = [base_index[base] for base in normalized if base in base_index][:max_len]
        for pos, code in enumerate(codes):
            row[4 * pos + code] = 1.0
        row[-1] = 1.0 if str(mol_type).strip().upper() == "RNA" else 0.0

    return encoded


### Kmer

In [10]:

def kmer_freq_with_type_bit(seqs, types, k=6):
    """
    Encode sequences as normalized k-mer frequencies plus an RNA/DNA indicator column.

    Sequences are upper-cased and U is mapped to T. Counts are taken over the fixed
    vocabulary of all 4**k k-mers built from 'ACGT' (in ``itertools.product`` order);
    k-mers containing any other character are not counted. Each row is divided by its
    total count; rows without any counted k-mer stay all-zero. The last column is 1.0
    when the (stripped, case-insensitive) type is "RNA" and 0.0 otherwise — the same
    rule ``onehot_with_type_bit`` applies.

    Returns a dense float64 array of shape (len(seqs), 4**k + 1).
    Raises ValueError when ``seqs`` and ``types`` differ in length.
    """
    if len(seqs) != len(types):
        raise ValueError("seqs и types must be the same length")

    vocab = [''.join(p) for p in product('ACGT', repeat=k)]
    vec = CountVectorizer(analyzer='char', ngram_range=(k, k), lowercase=False, vocabulary=vocab)

    seqs_norm = [(s or "").upper().replace("U", "T") for s in seqs]
    X = vec.fit_transform(seqs_norm).astype(np.float64).toarray()

    # Avoid division by zero for sequences shorter than k (or with no valid k-mer).
    row_sum = X.sum(axis=1, keepdims=True)
    row_sum[row_sum == 0] = 1.0
    X = X / row_sum

    # Strip whitespace so that e.g. "RNA " is encoded identically across all encoders.
    d = np.array(
        [1.0 if str(t).strip().upper() == "RNA" else 0.0 for t in types],
        dtype=np.float64,
    ).reshape(-1, 1)
    return np.concatenate([X, d], axis=1)


### Pretrained embeddings

In [11]:
@torch.no_grad()
def gena_embed(seqs, types,
               model_name='AIRI-Institute/gena-lm-bert-base-t2t-multi', max_len=216, batch_size=64,
               device='cpu'):
    """
    Embed sequences with a pretrained GENA-LM model and append an RNA/DNA indicator column.

    The tokenizer and model are loaded from ``model_name`` on every call. Sequences are
    upper-cased with U mapped to T, tokenized with padding and truncation to ``max_len``
    tokens, and the last hidden state is mean-pooled over non-padding tokens. The last
    column is 1.0 when the (stripped, case-insensitive) type is "RNA" and 0.0 otherwise —
    the same rule ``onehot_with_type_bit`` applies.

    Parameters
    ----------
    seqs, types : equally long sequences of nucleotide strings and type labels.
    model_name : Hugging Face model id; loaded with ``trust_remote_code=True``.
    max_len : maximum number of tokens per sequence passed to the tokenizer.
    batch_size : number of sequences per forward pass.
    device : torch device the model runs on (default 'cpu', as before).

    Returns an array of shape (len(seqs), hidden_size + 1).
    Raises ValueError when ``seqs`` and ``types`` differ in length.
    """
    if len(seqs) != len(types):
        raise ValueError("seqs и types must be the same length")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name, output_hidden_states=True, trust_remote_code=True).to(device)
    device = next(model.parameters()).device
    model.eval()
    out = []
    for i in range(0, len(seqs), batch_size):
        batch = [(s or "").upper().replace("U", "T") for s in seqs[i:i+batch_size]]
        enc = tokenizer(batch, padding=True, truncation=True, max_length=max_len, return_tensors="pt").to(device)
        h = model(**enc).hidden_states[-1]
        # Mask out padding tokens before averaging; clamp guards against empty rows.
        m = enc["attention_mask"].unsqueeze(-1)
        out.append(((h * m).sum(1) / m.sum(1).clamp(min=1)).cpu().numpy())
    E = np.vstack(out)
    d = np.array(
        [1.0 if str(t).strip().upper() == "RNA" else 0.0 for t in types],
        dtype=np.float64,
    ).reshape(-1, 1)
    return np.concatenate([E, d], axis=1)


## Molecule encoding

### Morgan FP

In [12]:
def morgan_fp(smiles_list, n_bits=2048, radius=2, counts=False):
    """
    Compute Morgan fingerprints for a list of SMILES strings.

    With ``counts=False`` the result is a uint8 bit matrix; with ``counts=True`` it is
    an int32 matrix holding the count stored for each fingerprint index. ``None``
    entries are parsed as the empty string. SMILES that RDKit cannot parse keep an
    all-zero row.

    Returns an array of shape (len(smiles_list), n_bits).
    """
    features = np.zeros((len(smiles_list), n_bits), dtype=np.int32 if counts else np.uint8)
    generator = rfg.GetMorganGenerator(radius=radius, fpSize=n_bits)

    for row, smiles in enumerate(smiles_list):
        text = "" if smiles is None else str(smiles)
        mol = Chem.MolFromSmiles(text)
        if mol is None:
            continue

        if not counts:
            # Bit vector is written straight into the matrix row.
            bit_fp = generator.GetFingerprint(mol)
            DataStructs.ConvertToNumpyArray(bit_fp, features[row])
            continue

        count_fp = generator.GetCountFingerprint(mol)
        for bit, count in count_fp.GetNonzeroElements().items():
            if bit < n_bits:
                features[row, bit] = count

    return features


In [13]:
# Small fixed sample of SMILES strings used to smoke-test the molecule encoders below.
smileses = ['Nc1c(S(=O)(=O)O)cc(Nc2ccc(Nc3nc(Cl)nc(Nc4ccccc4S(=O)(=O)O)n3)c(S(=O)(=O)O)c2)c2c1C(=O)c1ccccc1C2=O',
 'Nc1c(S(=O)(=O)O)cc(Nc2ccc(S(=O)(=O)O)c(Nc3nc(Cl)nc(Cl)n3)c2)c2c1C(=O)c1ccccc1C2=O',
 'Nc1ncnc2c1ncn2C1OC(COP(=O)(O)OP(=O)(O)OP(=O)(O)O)C(O)C1O',
 'Nc1ncnc2c1ncn2C1OC(COP(=O)(O)O)C(O)C1O',
 'Cc1cc2nc3c(=O)[nH]c(=O)nc-3n(CC(O)C(O)C(O)COP(=O)(O)OP(=O)(O)OCC3OC(n4cnc5c(N)ncnc54)C(O)C3O)c2cc1C',
 'Cc1cc2nc3c(=O)[nH]c(=O)nc-3n(CC(O)C(O)C(O)COP(=O)(O)O)c2cc1C',
 'Nc1ncnc2c1ncn2C1OC(CO)C(O)C1O']

In [14]:
# Smoke test: count-based Morgan fingerprints for the sample molecules.
morgan_fp(smileses, counts=True)

array([[0, 2, 0, ..., 0, 0, 0],
       [0, 2, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 3, 0, ..., 0, 0, 0],
       [0, 3, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], shape=(7, 2048), dtype=int32)

### MACCS FP

In [15]:
def maccs_fp(smiles_list):
    """
    Compute MACCS key fingerprints for a list of SMILES strings.

    ``None`` entries are parsed as the empty string. SMILES that RDKit cannot parse
    keep an all-zero row.

    Returns a uint8 array of shape (len(smiles_list), 167).
    """
    features = np.zeros((len(smiles_list), 167), dtype=np.uint8)
    for row, smiles in enumerate(smiles_list):
        text = "" if smiles is None else str(smiles)
        mol = Chem.MolFromSmiles(text)
        if mol is None:
            continue
        keys = MACCSkeys.GenMACCSKeys(mol)
        buffer = np.zeros((keys.GetNumBits(),), dtype=np.int8)
        DataStructs.ConvertToNumpyArray(keys, buffer)
        # The full bit vector is copied as-is (all 167 positions, index 0 included).
        features[row] = buffer[:]
    return features


In [16]:
# Smoke test: MACCS keys for the sample molecules.
maccs_fp(smileses)

array([[0, 0, 0, ..., 1, 1, 0],
       [0, 0, 0, ..., 1, 1, 0],
       [0, 0, 0, ..., 1, 1, 0],
       ...,
       [0, 0, 0, ..., 1, 1, 0],
       [0, 0, 0, ..., 1, 1, 0],
       [0, 0, 0, ..., 1, 1, 0]], shape=(7, 167), dtype=uint8)

### Descriptors

In [17]:
def physchem_descriptors(smiles_list, return_names=False):
    """
    Compute every descriptor listed in ``Descriptors._descList`` for each SMILES string.

    Values are computed in float64; non-finite values are treated as missing. A
    descriptor column is kept only if it is finite for every molecule that was
    successfully processed. Molecules that cannot be parsed, or whose descriptor
    calculation raises, do not take part in the column selection and get an all-zero
    row (matching the zero rows ``morgan_fp`` / ``maccs_fp`` produce for unparseable
    SMILES). Previously one such molecule left an all-NaN row and therefore removed
    every column.

    Kept values are clipped to [-1e9, 1e9] and returned as float32.

    Parameters
    ----------
    smiles_list : sequence of SMILES strings; ``None`` is parsed as the empty string.
    return_names : when True, also return the names of the kept descriptors.

    Returns
    -------
    X of shape (len(smiles_list), n_kept), or (X, kept_names) if ``return_names``.
    """
    import warnings

    # NOTE: Descriptors._descList is a private RDKit attribute; kept as in the original.
    names = [n for n, _ in Descriptors._descList]
    calc = MolecularDescriptorCalculator(names)
    N, D = len(smiles_list), len(names)

    # Work in float64 so large descriptor values do not overflow inside the loop.
    X = np.full((N, D), np.nan, dtype=np.float64)
    computed = np.zeros(N, dtype=bool)

    for i, smi in enumerate(smiles_list):
        mol = Chem.MolFromSmiles(str(smi) if smi is not None else "")
        if mol is None:
            continue
        try:
            vals = np.asarray(calc.CalcDescriptors(mol), dtype=np.float64)
        except Exception as exc:
            # Best-effort: report the failure instead of silently swallowing it.
            warnings.warn(f"Descriptor calculation failed for SMILES #{i}: {exc}")
            continue
        # Non-finite values -> NaN so that the column is filtered out below.
        vals[~np.isfinite(vals)] = np.nan
        X[i] = vals
        computed[i] = True

    # Keep columns that are NaN-free across the successfully processed molecules only.
    keep = ~np.isnan(X[computed]).any(axis=0)
    X = X[:, keep]
    X[~computed] = 0.0
    X = np.clip(X, -1e9, 1e9).astype(np.float32, copy=False)
    kept_names = [n for n, k in zip(names, keep) if k]

    return (X, kept_names) if return_names else X


### Pretrained embeddings

In [18]:
# Load the ChemBERTa tokenizer and encoder once; both are passed to chemberta_embed below.
tok = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
mdl = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
@torch.no_grad()
def chemberta_embed(smiles_list, tok, mdl,
                    batch_size=64, max_len=128, pooling="mean", device=None):
    """
    Embed SMILES strings with a tokenizer/encoder pair.

    The model is moved to ``device`` (CUDA when available if ``device`` is not given)
    and switched to eval mode. Inputs are processed in batches of ``batch_size``;
    ``None`` entries become the empty string. With ``pooling="cls"`` the hidden state
    of the first token is returned, otherwise the last hidden state is averaged over
    non-padding tokens.

    Returns an array with one row per input string.
    """
    if not device:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    mdl.to(device).eval()

    use_cls = pooling == "cls"
    chunks = []
    for start in range(0, len(smiles_list), batch_size):
        texts = ["" if smi is None else str(smi) for smi in smiles_list[start:start + batch_size]]
        encoded = tok(texts, padding=True, truncation=True, max_length=max_len, return_tensors="pt").to(device)
        hidden = mdl(**encoded).last_hidden_state
        if use_cls:
            vectors = hidden[:, 0]
        else:
            # Masked mean; the clamp avoids dividing by zero for fully padded rows.
            mask = encoded["attention_mask"].unsqueeze(-1)
            vectors = (hidden * mask).sum(1) / mask.sum(1).clamp(min=1)
        chunks.append(vectors.detach().cpu().numpy())
    return np.vstack(chunks)


In [19]:
# Smoke test: mean-pooled ChemBERTa embeddings for the sample molecules.
chemberta_embed(smileses, tok, mdl)

array([[-0.23657979, -0.39311787, -0.49733528, ...,  0.02844087,
        -0.16161488,  0.06855042],
       [ 0.01253446, -0.24209289, -0.20824078, ..., -0.03621006,
        -0.20099978,  0.1657696 ],
       [ 0.38088298, -0.3987004 , -0.76109195, ...,  0.18910536,
         0.00204533,  0.24978273],
       ...,
       [ 0.24788767, -0.43359447, -0.63609624, ...,  0.07548268,
         0.20819755,  0.34934604],
       [ 0.2242443 , -0.4126285 , -0.27724662, ...,  0.03192327,
         0.17654993,  0.2161503 ],
       [ 0.7941891 , -0.25084215,  0.42892194, ..., -0.07675749,
         0.32636815,  0.05392162]], shape=(7, 768), dtype=float32)

## Data Splitting

In [20]:
from sklearn.model_selection import StratifiedGroupKFold, GroupKFold

def stratified_group_splits(df, label_col="label", group_cols=("sequence","canonical_smiles"),
                            n_splits=5, random_state=42):
    """
    Build stratified, group-aware CV folds.

    Rows sharing the same values in ``group_cols`` (joined with "||" after casting to
    str) form one group and are kept together by ``StratifiedGroupKFold`` (shuffled,
    seeded with ``random_state``); folds are stratified on ``label_col``.

    Returns a list of (train_idx, val_idx) positional index arrays.
    """
    pair_keys = df[list(group_cols)].astype(str).agg("||".join, axis=1).values
    splitter = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    # The splitter only needs X for its length, so a zero vector is enough.
    placeholder = np.zeros(len(df))
    fold_iter = splitter.split(X=placeholder, y=df[label_col].values, groups=pair_keys)
    return [(train_idx, val_idx) for train_idx, val_idx in fold_iter]

def _group_kfold_splits(df, n_splits, group_col):
    """Return GroupKFold (train_idx, val_idx) pairs with rows grouped by ``group_col`` (cast to str)."""
    gkf = GroupKFold(n_splits=n_splits)
    groups = df[group_col].astype(str).values
    # y is passed along for API symmetry; a zero vector is used when there is no label column.
    y = df["label"].values if "label" in df.columns else np.zeros(len(df))
    return [(tr, va) for tr, va in gkf.split(X=np.zeros(len(df)), y=y, groups=groups)]


def cold_aptamer_splits(df, n_splits=5, group_col="sequence"):
    """
    Cold-aptamer CV: all rows of a given sequence fall into the same fold, so no
    validation sequence is seen during training.

    Returns a list of (train_idx, val_idx) positional index arrays.
    """
    return _group_kfold_splits(df, n_splits, group_col)


def cold_molecule_splits(df, n_splits=5, group_col="canonical_smiles"):
    """
    Cold-molecule CV: all rows of a given SMILES fall into the same fold, so no
    validation molecule is seen during training.

    Returns a list of (train_idx, val_idx) positional index arrays.
    """
    return _group_kfold_splits(df, n_splits, group_col)


def cold_both_splits(df, n_splits=5, seq_col="sequence", mol_col="canonical_smiles", random_state=42):
    """
    Build folds that hold out sequences and molecules at the same time.

    Unique sequences and unique molecules are shuffled independently (seeded with
    ``random_state``) and dealt round-robin into ``n_splits`` folds. For fold ``f``:
      - training rows have neither their sequence nor their molecule in fold ``f``;
      - validation rows are all remaining rows, i.e. sequence OR molecule in fold ``f``.

    NOTE(review): because validation uses OR, a validation row may still share its
    sequence or its molecule (not both) with training rows, and validation sets of
    different folds overlap — confirm this is the intended "cold both" protocol.

    Returns a list of (train_idx, val_idx) positional index arrays.
    """
    shuffler = np.random.RandomState(random_state)
    seq_keys = df[seq_col].astype(str)
    mol_keys = df[mol_col].astype(str)

    # Shuffle order matters for reproducibility: sequences first, then molecules.
    shuffled_seqs = seq_keys.unique()
    shuffled_mols = mol_keys.unique()
    shuffler.shuffle(shuffled_seqs)
    shuffler.shuffle(shuffled_mols)

    seq_to_fold = dict((key, pos % n_splits) for pos, key in enumerate(shuffled_seqs))
    mol_to_fold = dict((key, pos % n_splits) for pos, key in enumerate(shuffled_mols))

    row_seq_fold = seq_keys.map(seq_to_fold).to_numpy()
    row_mol_fold = mol_keys.map(mol_to_fold).to_numpy()

    folds = []
    for fold in range(n_splits):
        in_train = (row_seq_fold != fold) & (row_mol_fold != fold)
        in_val = (row_seq_fold == fold) | (row_mol_fold == fold)
        folds.append((np.where(in_train)[0], np.where(in_val)[0]))
    return folds



## Compute combined embeddings

In [21]:
import numpy as np
import pandas as pd
from sklearn.metrics import average_precision_score, roc_auc_score, f1_score
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier

# ========= wrappers, using YOUR feature functions =========

def make_apt_features(df, cfg):
    """
    Build the aptamer feature matrix selected by ``cfg``.

    Supported configs:
      {'name':'kmer',   'k':6}
      {'name':'onehot', 'max_len':216}
      {'name':'gena',   'model_name':'AIRI-Institute/gena-lm-bert-base-t2t-multi', 'max_len':216, 'batch_size':64}

    Reads the 'sequence' and 'type' columns of ``df``.
    Returns a float64 np.ndarray of shape (N, D).
    Raises ValueError for an unknown ``cfg['name']``.
    """
    kind = cfg.get('name')
    sequences = df['sequence'].tolist()
    mol_types = df['type'].tolist()

    if kind == 'kmer':
        features = kmer_freq_with_type_bit(sequences, mol_types, k=cfg.get('k', 6))
    elif kind == 'onehot':
        features = onehot_with_type_bit(sequences, mol_types, max_len=cfg.get('max_len', 216))
    elif kind == 'gena':
        features = gena_embed(
            sequences, mol_types,
            model_name=cfg.get('model_name', 'AIRI-Institute/gena-lm-bert-base-t2t-multi'),
            max_len=cfg.get('max_len', 216),
            batch_size=cfg.get('batch_size', 64)
        )
    else:
        raise ValueError(f"Unknown apt feature: {kind}")

    return features.astype(np.float64)

def make_mol_features(df, cfg):
    """
    Build the molecule feature matrix selected by ``cfg``.

    Supported configs:
      {'name':'morgan','n_bits':2048,'radius':2,'counts':False}
      {'name':'maccs'}
      {'name':'physchem'}
      {'name':'chemberta','tok':tok,'mdl':mdl,'batch_size':64,'max_len':128,'pooling':'mean'}
      {'name':'concat','parts':[ ...sub-configs as above... ]}

    Reads the 'canonical_smiles' column of ``df``.
    Returns a float64 np.ndarray of shape (N, D); a 'concat' without parts yields (N, 0).
    Raises ValueError for an unknown ``cfg['name']``.
    """
    kind = cfg.get('name')
    smiles = df['canonical_smiles'].tolist()

    if kind == 'concat':
        # Recursively build each part and stack them column-wise.
        blocks = [make_mol_features(df, part) for part in cfg.get('parts', [])]
        if not blocks:
            return np.zeros((len(df), 0), dtype=np.float64)
        return np.concatenate(blocks, axis=1).astype(np.float64)

    if kind == 'morgan':
        features = morgan_fp(
            smiles,
            n_bits=cfg.get('n_bits', 2048),
            radius=cfg.get('radius', 2),
            counts=cfg.get('counts', False)
        )
    elif kind == 'maccs':
        # maccs_fp returns all 167 bits (RDKit's bit 0 included); used as-is.
        features = maccs_fp(smiles)
    elif kind == 'physchem':
        features = physchem_descriptors(smiles)
    elif kind == 'chemberta':
        features = chemberta_embed(
            smiles, cfg['tok'], cfg['mdl'],
            batch_size=cfg.get('batch_size', 64),
            max_len=cfg.get('max_len', 128),
            pooling=cfg.get('pooling', 'mean'),
            device=cfg.get('device', None)
        )
    else:
        raise ValueError(f"Unknown mol feature: {kind}")

    return features.astype(np.float64)


def _get_splits(df, split_mode="group", n_splits=5, random_state=42):
    if split_mode == "group":
        return stratified_group_splits(df, n_splits=n_splits, random_state=random_state)
    elif split_mode == "cold_aptamer":
        return cold_aptamer_splits(df, n_splits=n_splits)
    elif split_mode == "cold_molecule":
        return cold_molecule_splits(df, n_splits=n_splits)
    elif split_mode == "cold_both":
        return cold_both_splits(df, n_splits=n_splits, random_state=random_state)
    else:
        raise ValueError(f"Unknown split_mode: {split_mode}")

# ========= models & utils =========

def _model_factory(name, random_state=42):
    name = name.lower()
    if name in ("logreg","logistic","logistic_regression"):
        return LogisticRegression(max_iter=2000, class_weight="balanced", n_jobs=-1, solver="lbfgs")
    if name in ("rf","randomforest","random_forest"):
        return RandomForestClassifier(
            n_estimators=500, max_depth=None, n_jobs=-1, class_weight="balanced_subsample", random_state=random_state
        )
    if name in ("mlp",):
        return MLPClassifier(hidden_layer_sizes=(256,128), activation="relu",
                             learning_rate_init=1e-3, alpha=1e-4,
                             max_iter=120, early_stopping=True, random_state=random_state)
    if name in ("lgbm","lightgbm"):
        try:
            from lightgbm import LGBMClassifier
        except Exception:
            return None
        return LGBMClassifier(
            n_estimators=600, learning_rate=0.05, max_depth=-1,
            subsample=0.8, colsample_bytree=0.8, reg_lambda=1.0,
            random_state=random_state, n_jobs=-1, verbosity=-1
        )
    raise ValueError(f"Unknown model: {name}")

def _cfg_name(cfg, side):
    n = cfg.get("name","?")
    if side=="apt":
        if n=="kmer":   return f"kmer(k={cfg.get('k',6)})"
        if n=="onehot": return f"onehot(L={cfg.get('max_len',216)})"
        if n=="gena":   return f"gena(mean,last)"
    else:
        if n=="morgan":    return f"morgan({cfg.get('n_bits',2048)},{cfg.get('radius',2)}{'c' if cfg.get('counts',False) else 'b'})"
        if n=="maccs":     return "maccs(167)"
        if n=="physchem":  return "physchem(full)"
        if n=="chemberta": return "chemberta(mean)"
        if n=="concat":    return "concat(" + "+".join(_cfg_name(c,'mol') for c in cfg.get('parts',[])) + ")"
    return n

def _best_f1(y_true, y_score):
    if len(np.unique(y_true)) < 2:
        return np.nan
    thr = np.linspace(0,1,201)
    best = 0.0
    for t in thr:
        p = (y_score >= t).astype(int)
        best = max(best, f1_score(y_true, p, zero_division=0))
    return best

# ========= main: screening =========

from sklearn.metrics import average_precision_score, roc_auc_score, f1_score, matthews_corrcoef
from sklearn.preprocessing import StandardScaler
import numpy as np
import pandas as pd
from tqdm import tqdm
import warnings

def screen_models(
    df,
    apt_cfgs,
    mol_cfgs,
    model_names=("logreg","rf","lgbm","mlp"),
    split_mode="group",
    n_splits=5,
    random_state=42,
    scale=True,
):
    """
    Cross-validate every (aptamer features x molecule features x model) combination.

    Features for each config are computed once on the whole frame, concatenated per
    grid cell, and evaluated on the folds produced by ``_get_splits``. When ``scale``
    is True a StandardScaler is fitted on the training part of each fold only.

    Returns dict {model_name: DataFrame}; rows are aptamer configs, columns are molecule
    configs, and each cell is the string
    "PR mean±std | ROC mean±std | F1 mean±std | MCC mean±std| nPR mean±std".
    A model whose factory returns None (LightGBM not installed) gets a table of "N/A".

    Metrics per fold:
      - PR (average precision) and ROC AUC use continuous scores (predict_proba, else
        min-max scaled decision_function, else hard predictions); both are NaN when the
        validation fold contains a single class;
      - nPR = (PR - p) / (1 - p), with p the positive rate of the validation fold;
      - F1 and MCC use ``clf.predict`` (the model's default decision threshold).
    Means and standard deviations ignore NaN folds.
    """
    df = df.reset_index(drop=True)
    y = df["label"].to_numpy().astype(int)
    splits = _get_splits(df, split_mode=split_mode, n_splits=n_splits, random_state=random_state)

    rows = [_cfg_name(a, 'apt') for a in apt_cfgs]
    cols = [_cfg_name(m, 'mol') for m in mol_cfgs]

    # Labels double as dictionary/table keys: identical labels would silently share the
    # features of the last config carrying that label.
    for side, labels in (("apt", rows), ("mol", cols)):
        if len(set(labels)) != len(labels):
            warnings.warn(
                f"Duplicate {side} config labels {labels}: configs with the same label "
                "overwrite each other's features and results.",
                stacklevel=2,
            )

    # Precompute features once per config, already in float64, so that the grid loop
    # below does not copy them again for every cell.
    apt_map = {}
    for label, cfg in zip(rows, apt_cfgs):
        apt_map[label] = np.asarray(make_apt_features(df, cfg), dtype=np.float64)

    mol_map = {}
    for label, cfg in zip(cols, mol_cfgs):
        mol_map[label] = np.asarray(make_mol_features(df, cfg), dtype=np.float64)

    def fmt(values):
        arr = np.array(values, dtype=float)
        return f"{np.nanmean(arr):.3f}±{np.nanstd(arr):.3f}"

    results = {}

    for model_name in tqdm(model_names):
        # Probe the factory once to find out whether the model is available at all.
        if _model_factory(model_name, random_state=random_state) is None:
            results[model_name] = pd.DataFrame("N/A", index=rows, columns=cols)
            continue

        table = pd.DataFrame(index=rows, columns=cols, dtype=object)

        # Inner bars are transient so the output keeps a single line per model.
        for ar in tqdm(rows, desc=f"{model_name}: apt", leave=False):
            Xa = apt_map[ar]
            for mc in tqdm(cols, desc=f"{ar}: mol", leave=False):
                X = np.concatenate([Xa, mol_map[mc]], axis=1)

                pr_list, roc_list, f1_list, mcc_list, npr_list = [], [], [], [], []
                for tr, va in splits:
                    Xtr, Xva = X[tr], X[va]
                    ytr, yva = y[tr], y[va]

                    if scale:
                        # Fit on the training fold only to avoid leaking validation statistics.
                        scaler = StandardScaler()
                        Xtr = scaler.fit_transform(Xtr)
                        Xva = scaler.transform(Xva)

                    clf = _model_factory(model_name, random_state=random_state)
                    with warnings.catch_warnings():
                        warnings.filterwarnings("ignore", message="X does not have valid feature names.*")
                        clf.fit(Xtr, ytr)

                        # Continuous score for PR/ROC.
                        if hasattr(clf, "predict_proba"):
                            s = clf.predict_proba(Xva)[:, 1]
                        elif hasattr(clf, "decision_function"):
                            d = clf.decision_function(Xva)
                            s = (d - d.min()) / (d.max() - d.min() + 1e-8)
                        else:
                            s = clf.predict(Xva).astype(float)  # fallback when no score is available

                        # Hard predictions at the model's default threshold.
                        yhat = clf.predict(Xva)

                    has_both_classes = len(np.unique(yva)) > 1
                    pr = average_precision_score(yva, s) if has_both_classes else np.nan
                    roc = roc_auc_score(yva, s) if has_both_classes else np.nan

                    # Normalized PR: 0 at the chance level p, 1 for a perfect ranking.
                    p = yva.mean()
                    npr = (pr - p) / (1 - p + 1e-12)

                    f1 = f1_score(yva, yhat, zero_division=0)
                    mcc = matthews_corrcoef(yva, yhat)

                    pr_list.append(pr)
                    roc_list.append(roc)
                    f1_list.append(f1)
                    mcc_list.append(mcc)
                    npr_list.append(npr)

                table.loc[ar, mc] = (
                    f"PR {fmt(pr_list)} | ROC {fmt(roc_list)} | F1 {fmt(f1_list)} | MCC {fmt(mcc_list)}| nPR {fmt(npr_list)}"
                )

        results[model_name] = table

    return results



## Configs

In [28]:
# Grid of aptamer feature configs consumed by screen_models.
# NOTE: the loop variables below (k, L, model_name, max_len, batch_size) stay defined
# as notebook globals after this cell has run.
apt_cfgs = []

# k-mer
for k in [3, 4, 5]:
    apt_cfgs.append({'name': 'kmer', 'k': k})

# one-hot
for L in [216]:
    apt_cfgs.append({'name': 'onehot', 'max_len': L})

# GENA-LM (the model is loaded by name inside gena_embed)
gena_models = [
    'AIRI-Institute/gena-lm-bert-base-t2t-multi']
for model_name, max_len, batch_size in product(gena_models, [216], [64]):
    apt_cfgs.append({'name': 'gena', 'model_name': model_name, 'max_len': max_len, 'batch_size': batch_size})

print(f"apt_cfgs: {len(apt_cfgs)} вариантов")

apt_cfgs: 5 вариантов


In [29]:
# Grid of molecule feature configs consumed by screen_models.
mol_cfgs = []

# Morgan fingerprints (every combination of the grid below)
for n_bits, radius, counts in product([1024], [2], [True]):
    mol_cfgs.append({'name': 'morgan', 'n_bits': n_bits, 'radius': radius, 'counts': counts})

# MACCS keys
mol_cfgs.append({'name': 'maccs'})

# Full physchem descriptors (NaN columns are filtered inside physchem_descriptors); disabled.
#mol_cfgs.append({'name': 'physchem'})

# ChemBERTa (reuses the tok/mdl objects loaded earlier).
# The token limit used to be read from `max_len`, a loop variable leaked by the
# aptamer-config cell (value 216). It is spelled out here so this cell no longer
# depends on hidden kernel state; the value is unchanged.
CHEMBERTA_MAX_LEN = 216
mol_cfgs.append({'name': 'chemberta', 'tok': tok, 'mdl': mdl,
                'batch_size': 64, 'max_len': CHEMBERTA_MAX_LEN, 'pooling': 'mean'})



print(f"mol_cfgs: {len(mol_cfgs)} вариантов")

mol_cfgs: 3 вариантов


## No augmentation

In [24]:
# Baseline without augmentation: stratified group split on (sequence, SMILES) pairs.
res_1_0_group = screen_models(
    df,
    apt_cfgs=apt_cfgs,
    mol_cfgs=mol_cfgs,
    model_names=("logreg","rf","lgbm","mlp"),
    split_mode="group",
    n_splits=5,
    scale=True,   # set True when the features contain many continuous values (emb/physchem)
)

The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  0%|          | 0/4 [00:00<?, ?it/s]
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 7/7 [01:13<00:00, 10.51s/it]

[A
[A
[A
[A
[A
[A
[A
100%|██████████| 7/7 [00:49<00:00,  7.05s/it]

[A
[A
[A
[A
[A
[A
[A
100%|██████████| 7/7 [00:57<00:00,  8.24s/it]

[A
[A
[A
[A
[A
[A
[A
100%|██████████| 7/7 [00:55<00:00,  7.92s/it]

[A
[A
[A
[A
[A
[A
[A
100%|██████████| 7/7 [00:47<00:00,  6.81s/it]
100%|██████████| 5/5 [04:43<00:00, 56.75s/it]
 25%|██▌       | 1/4 [04:43<14:11, 283.77s/it]
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 7/7 [00:39<00:00,  5.67s/it]

[A
[A
[A
[A
[A
[A
[A
100%|██████████| 7/7 [00:40<00:00,  5.77s/it]

[A
[A
[A
[A
[A
[A
[A
100%|██████████| 7/7 [00:44<00:

In [39]:
# Export the MLP table of the non-augmented group-split run.
res_1_0_group['mlp'].to_excel('res_no_aug_mlp.xlsx')

In [40]:

# Baseline without augmentation: validation sequences are unseen during training.
res_1_0_cold_aptamer = screen_models(
    df,
    apt_cfgs=apt_cfgs,
    mol_cfgs=mol_cfgs,
    model_names=("logreg","rf","lgbm","mlp"),
    split_mode="cold_aptamer",
    n_splits=5,
    scale=True,   # set True when the features contain many continuous values (emb/physchem)
)


The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  0%|          | 0/4 [00:00<?, ?it/s]
[A
[A
[A
[A
100%|██████████| 4/4 [00:55<00:00, 13.99s/it]

[A
[A
[A
[A
100%|██████████| 4/4 [00:25<00:00,  6.27s/it]

[A
[A
[A
[A
100%|██████████| 4/4 [00:27<00:00,  6.96s/it]

[A
[A
[A
[A
100%|██████████| 4/4 [00:28<00:00,  7.20s/it]

[A
[A
[A
[A
100%|██████████| 4/4 [00:27<00:00,  6.83s/it]
100%|██████████| 5/5 [02:45<00:00, 33.02s/it]
 25%|██▌       | 1/4 [02:45<08:15, 165.09s/it]
[A
[A
[A
[A
100%|██████████| 4/4 [00:23<00:00,  5.79s/it]

[A
[A
[A
[A
100%|██████████| 4/4 [00:23<00:00,  5.96s/it]

[A
[A
[A
[A
100%|██████████| 4/4 [00:25<00:00,  6.27s/it]

[A
[A
[A
[A
100%|██████████| 4/4 [00:23<00:00,  5.90s/it]

[A
[A
[A
[A
1

In [48]:

# Baseline without augmentation: validation molecules are unseen during training.
res_1_0_cold_molecule = screen_models(
    df,
    apt_cfgs=apt_cfgs,
    mol_cfgs=mol_cfgs,
    model_names=("logreg","rf","lgbm","mlp"),
    split_mode="cold_molecule",
    n_splits=5,
    scale=True,   # set True when the features contain many continuous values (emb/physchem)
)



The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  0%|          | 0/4 [00:00<?, ?it/s]
[A
[A
[A
[A
100%|██████████| 4/4 [00:52<00:00, 13.20s/it]

[A
[A
[A
[A
100%|██████████| 4/4 [00:24<00:00,  6.01s/it]

[A
[A
[A
[A
100%|██████████| 4/4 [00:27<00:00,  6.80s/it]

[A
[A
[A
[A
100%|██████████| 4/4 [00:27<00:00,  6.88s/it]

[A
[A
[A
[A
100%|██████████| 4/4 [00:27<00:00,  6.80s/it]
100%|██████████| 5/5 [02:38<00:00, 31.76s/it]
 25%|██▌       | 1/4 [02:38<07:56, 158.78s/it]
[A
[A
[A
[A
100%|██████████| 4/4 [00:22<00:00,  5.52s/it]

[A
[A
[A
[A
100%|██████████| 4/4 [00:23<00:00,  5.80s/it]

[A
[A
[A
[A
100%|██████████| 4/4 [00:25<00:00,  6.28s/it]

[A
[A
[A
[A
100%|██████████| 4/4 [00:23<00:00,  5.83s/it]

[A
[A
[A
[A
1

In [49]:
# One result table per screened model.
res_1_0_cold_aptamer.keys()

dict_keys(['logreg', 'rf', 'lgbm', 'mlp'])

In [52]:
# Baseline without augmentation: folds hold out sequences and molecules simultaneously.
res_1_0_cold_both = screen_models(
    df,
    apt_cfgs=apt_cfgs,
    mol_cfgs=mol_cfgs,
    model_names=("logreg","rf","lgbm","mlp"),
    split_mode="cold_both",
    n_splits=5,
    scale=True)  

The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  0%|          | 0/4 [00:00<?, ?it/s]
[A
[A
[A
[A
100%|██████████| 4/4 [00:50<00:00, 12.69s/it]

[A
[A
[A
[A
100%|██████████| 4/4 [00:23<00:00,  5.77s/it]

[A
[A
[A
[A
100%|██████████| 4/4 [00:27<00:00,  6.85s/it]

[A
[A
[A
[A
100%|██████████| 4/4 [00:25<00:00,  6.45s/it]

[A
[A
[A
[A
100%|██████████| 4/4 [00:24<00:00,  6.19s/it]
100%|██████████| 5/5 [02:31<00:00, 30.38s/it]
 25%|██▌       | 1/4 [02:31<07:35, 151.90s/it]
[A
[A
[A
[A
100%|██████████| 4/4 [00:19<00:00,  4.89s/it]

[A
[A
[A
[A
100%|██████████| 4/4 [00:20<00:00,  5.06s/it]

[A
[A
[A
[A
100%|██████████| 4/4 [00:21<00:00,  5.46s/it]

[A
[A
[A
[A
100%|██████████| 4/4 [00:20<00:00,  5.06s/it]

[A
[A
[A
[A
1

## With augmentation

### 1 to 4

In [50]:
# Group-split screening on data augmented with up to 4 negatives per sequence.
# NOTE: the variable is named res_1_3_* although n_neg_per_seq=4 (the export below
# uses the 'res_1_4_aug_*' file names); random forest is not part of this run.
res_1_3_group = screen_models(
    augment_negatives_per_sequence(df, n_neg_per_seq=4, seed=42),
    apt_cfgs=apt_cfgs,
    mol_cfgs=mol_cfgs,
    model_names=("logreg","lgbm","mlp"),
    split_mode="group",
    n_splits=5,
    scale=True)  

The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  0%|          | 0/3 [00:00<?, ?it/s]
[A
[A
[A
100%|██████████| 3/3 [01:05<00:00, 21.80s/it]

[A
[A
[A
100%|██████████| 3/3 [00:40<00:00, 13.38s/it]

[A
[A
[A
100%|██████████| 3/3 [01:03<00:00, 21.05s/it]

[A
[A
[A
100%|██████████| 3/3 [01:10<00:00, 23.49s/it]

[A
[A
[A
100%|██████████| 3/3 [00:53<00:00, 17.69s/it]
100%|██████████| 5/5 [04:52<00:00, 58.47s/it]
 33%|███▎      | 1/3 [04:52<09:44, 292.33s/it]
[A
[A
[A
100%|██████████| 3/3 [00:47<00:00, 15.86s/it]

[A
[A
[A
100%|██████████| 3/3 [00:55<00:00, 18.36s/it]

[A
[A
[A
100%|██████████| 3/3 [01:18<00:00, 26.20s/it]

[A
[A
[A
100%|██████████| 3/3 [00:47<00:00, 15.80s/it]

[A
[A
[A
100%|██████████| 3/3 [02:30<00:00, 50.25s

In [76]:
# Display the per-model result tables of the augmented group-split run.
res_1_3_group

{'logreg':                                                    morgan(1024,2c)  \
 kmer(k=3)        PR 0.504±0.019 | ROC 0.736±0.006 | F1 0.531±0....   
 kmer(k=4)        PR 0.494±0.018 | ROC 0.743±0.010 | F1 0.538±0....   
 kmer(k=5)        PR 0.393±0.023 | ROC 0.680±0.016 | F1 0.471±0....   
 onehot(L=216)    PR 0.446±0.015 | ROC 0.703±0.012 | F1 0.510±0....   
 gena(mean,last)  PR 0.428±0.014 | ROC 0.710±0.011 | F1 0.508±0....   
 
                                                         maccs(167)  \
 kmer(k=3)        PR 0.409±0.016 | ROC 0.701±0.017 | F1 0.474±0....   
 kmer(k=4)        PR 0.399±0.026 | ROC 0.699±0.017 | F1 0.477±0....   
 kmer(k=5)        PR 0.303±0.020 | ROC 0.572±0.019 | F1 0.364±0....   
 onehot(L=216)    PR 0.366±0.015 | ROC 0.655±0.015 | F1 0.451±0....   
 gena(mean,last)  PR 0.335±0.012 | ROC 0.627±0.016 | F1 0.412±0....   
 
                                                    chemberta(mean)  
 kmer(k=3)        PR 0.514±0.017 | ROC 0.749±0.005 | F1 0.544±0.

In [51]:
# Write one spreadsheet per model with its screening table.
for model, model_table in res_1_3_group.items():
    model_table.to_excel(f'res_1_4_aug_{model}_group.xlsx')

### 1 to 2

In [63]:
# Build the augmented dataset (up to 2 generated negatives per sequence) once.
# The augmentation is seeded, so the three screens below previously rebuilt the
# exact same frame three times.
df_aug_1_2 = augment_negatives_per_sequence(df, n_neg_per_seq=2, seed=42)

# Settings shared by every split mode of this experiment.
screen_kwargs_1_2 = dict(
    apt_cfgs=apt_cfgs,
    mol_cfgs=mol_cfgs,
    model_names=("logreg", "rf", "lgbm", "mlp"),
    n_splits=5,
    scale=True,   # use True when the features hold many continuous values (emb/physchem)
)

res_1_2_group = screen_models(df=df_aug_1_2, split_mode="group", **screen_kwargs_1_2)

res_1_2_cold_aptamer = screen_models(df=df_aug_1_2, split_mode="cold_aptamer", **screen_kwargs_1_2)

res_1_2_cold_molecule = screen_models(df=df_aug_1_2, split_mode="cold_molecule", **screen_kwargs_1_2)


The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  0%|          | 0/4 [00:00<?, ?it/s]
[A
[A
[A
[A
100%|██████████| 4/4 [01:06<00:00, 16.56s/it]

[A
[A
[A
[A
100%|██████████| 4/4 [00:36<00:00,  9.21s/it]

[A
[A
[A
[A
100%|██████████| 4/4 [00:55<00:00, 13.96s/it]

[A
[A
[A
[A
100%|██████████| 4/4 [01:04<00:00, 16.14s/it]

[A
[A
[A
[A
100%|██████████| 4/4 [00:49<00:00, 12.33s/it]
100%|██████████| 5/5 [04:33<00:00, 54.60s/it]
 25%|██▌       | 1/4 [04:33<13:39, 273.03s/it]
[A
[A
[A
[A
100%|██████████| 4/4 [00:39<00:00,  9.97s/it]

[A
[A
[A
[A
100%|██████████| 4/4 [00:41<00:00, 10.49s/it]

[A
[A
[A
[A
100%|██████████| 4/4 [00:49<00:00, 12.48s/it]

[A
[A
[A
[A
100%|██████████| 4/4 [00:41<00:00, 10.41s/it]

[A
[A
[A
[A
1

: 

: 

### 1 to 3

In [None]:
# Build the augmented dataset (up to 3 generated negatives per sequence) once.
# The augmentation is seeded, so the four screens below previously rebuilt the
# exact same frame four times.
# NOTE(review): `res_1_3_group` is also assigned by the earlier cell that runs
# with n_neg_per_seq=4; running this cell overwrites that result.
df_aug_1_3 = augment_negatives_per_sequence(df, n_neg_per_seq=3, seed=42)

# Settings shared by every split mode of this experiment.
screen_kwargs_1_3 = dict(
    apt_cfgs=apt_cfgs,
    mol_cfgs=mol_cfgs,
    model_names=("logreg", "rf", "lgbm", "mlp"),
    n_splits=5,
    scale=False,   # switch to True when the features hold many continuous values (emb/physchem)
)

res_1_3_group = screen_models(df=df_aug_1_3, split_mode="group", **screen_kwargs_1_3)

res_1_3_cold_aptamer = screen_models(df=df_aug_1_3, split_mode="cold_aptamer", **screen_kwargs_1_3)

res_1_3_cold_molecule = screen_models(df=df_aug_1_3, split_mode="cold_molecule", **screen_kwargs_1_3)

res_1_3_cold_both = screen_models(df=df_aug_1_3, split_mode="cold_both", **screen_kwargs_1_3)

### 1 to 4

### Ablation study

In [37]:
import numpy as np
from contextlib import contextmanager

# --- 1) random features with the same shape as the real ones ---
def randomize_like(X, how="permute", seed=42):
    """
    Return a float array of the same (N, D) shape as X filled with "noise".

    X: array-like, converted with np.asarray and expected to be 2-D.
    how="permute"  - every column is shuffled independently across rows, so each
                     column keeps its own set of values.
    how="gaussian" - every column is sampled from N(mean_j, std_j), where the
                     statistics ignore NaNs and a zero std is replaced by 1.0.
    seed: seed of the np.random.RandomState used for shuffling / sampling.

    Raises ValueError for any other value of `how`.
    """
    rng = np.random.RandomState(seed)
    data = np.asarray(X)
    n_rows, n_cols = data.shape

    if how == "permute":
        # Work on a float copy so the caller's array is left untouched.
        shuffled = data.astype(float)
        for col in range(n_cols):
            rng.shuffle(shuffled[:, col])   # in-place shuffle of one feature column
        return shuffled

    if how == "gaussian":
        col_mean = np.nanmean(data, axis=0)
        col_std = np.nanstd(data, axis=0)
        # Constant columns would give a zero scale; fall back to unit variance.
        col_scale = np.where(col_std == 0, 1.0, col_std)
        return rng.normal(loc=col_mean, scale=col_scale, size=(n_rows, n_cols))

    raise ValueError("how must be 'permute' or 'gaussian'")

# --- 2) context manager that temporarily swaps the feature factories ---
@contextmanager
def random_feature_ablation(side="both", how="permute", seed=42):
    """
    Temporarily replace the module-level feature factories with randomised ones.

    side: 'apt' | 'mol' | 'both' - which of make_apt_features / make_mol_features
          is wrapped.
    how, seed: forwarded to randomize_like for every wrapped call.

    Inside the `with` block the selected factories still compute the real
    features (so the output shape is unchanged) and then pass them through
    randomize_like. Both globals are restored on exit, also when the body raises.

    Raises ValueError for an unknown `side`. Previously such a value left both
    factories untouched, so the "ablation" silently ran on the real features.
    """
    global make_apt_features, make_mol_features

    if side not in ("apt", "mol", "both"):
        raise ValueError("side must be 'apt', 'mol' or 'both'")

    _orig_apt = make_apt_features
    _orig_mol = make_mol_features

    def _wrap(fn):
        # Keep the (df, cfg) call signature of the wrapped factory.
        def _inner(df, cfg):
            return randomize_like(fn(df, cfg), how=how, seed=seed)
        return _inner

    try:
        if side in ("apt", "both"):
            make_apt_features = _wrap(_orig_apt)
        if side in ("mol", "both"):
            make_mol_features = _wrap(_orig_mol)
        yield
    finally:
        make_apt_features = _orig_apt
        make_mol_features = _orig_mol

# --- 3) convenience wrapper: run screen_models with feature ablation ---
def screen_models_ablate_features(
    df,
    apt_cfgs,
    mol_cfgs,
    model_names=("logreg","rf","lgbm","mlp"),
    split_mode="group",
    n_splits=5,
    random_state=42,
    scale=True,
    side="both",
    how="permute",
    seed=42,
):
    """
    Call screen_models while random_feature_ablation(side, how, seed) is active.

    All arguments except `side`, `how` and `seed` are forwarded to screen_models
    by keyword; the value returned by screen_models is returned unchanged.
    """
    screen_kwargs = dict(
        df=df,
        apt_cfgs=apt_cfgs,
        mol_cfgs=mol_cfgs,
        model_names=model_names,
        split_mode=split_mode,
        n_splits=n_splits,
        random_state=random_state,
        scale=scale,
    )
    with random_feature_ablation(side=side, how=how, seed=seed):
        results = screen_models(**screen_kwargs)
    return results


In [38]:
# Control run: both the aptamer and the molecule features are replaced by
# column-wise permutations of themselves.
res_rand = screen_models_ablate_features(
    df=df,
    apt_cfgs=apt_cfgs,
    mol_cfgs=mol_cfgs,
    model_names=("lgbm", "logreg", "mlp"),
    split_mode="group",
    side="both",
    how="permute",
    seed=42,
)


The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  0%|          | 0/3 [00:00<?, ?it/s]
[A
[A
[A
100%|██████████| 3/3 [00:26<00:00,  8.73s/it]

[A
[A
[A
100%|██████████| 3/3 [00:31<00:00, 10.54s/it]

[A
[A
[A
100%|██████████| 3/3 [00:41<00:00, 13.67s/it]

[A
[A
[A
100%|██████████| 3/3 [00:27<00:00,  9.10s/it]

[A
[A
[A
100%|██████████| 3/3 [01:21<00:00, 27.09s/it]
100%|██████████| 5/5 [03:27<00:00, 41.49s/it]
 33%|███▎      | 1/3 [03:27<06:54, 207.45s/it]
[A
[A
[A
100%|██████████| 3/3 [00:43<00:00, 14.45s/it]

[A
[A
[A
100%|██████████| 3/3 [00:16<00:00,  5.43s/it]

[A
[A
[A
100%|██████████| 3/3 [00:17<00:00,  5.69s/it]

[A
[A
[A
100%|██████████| 3/3 [00:17<00:00,  5.70s/it]

[A
[A
[A
100%|██████████| 3/3 [00:16<00:00,  5.62s

In [39]:
# Write one spreadsheet per model for the "both sides randomised" control.
for model, model_table in res_rand.items():
    model_table.to_excel(f'res_{model}_rand_ablation_both.xlsx')

In [40]:
# Ablation: only the molecule features are replaced, here by per-column
# Gaussian samples; aptamer features stay real.
res_mol_rand = screen_models_ablate_features(
    df=df,
    apt_cfgs=apt_cfgs,
    mol_cfgs=mol_cfgs,
    model_names=("lgbm", "logreg", "mlp"),
    split_mode="group",
    side="mol",
    how="gaussian",
    seed=42,
)


The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  0%|          | 0/3 [00:00<?, ?it/s]
[A
[A
[A
100%|██████████| 3/3 [00:48<00:00, 16.28s/it]

[A
[A
[A
100%|██████████| 3/3 [01:01<00:00, 20.44s/it]

[A
[A
[A
100%|██████████| 3/3 [01:19<00:00, 26.42s/it]

[A
[A
[A
100%|██████████| 3/3 [00:57<00:00, 19.25s/it]

[A
[A
[A
100%|██████████| 3/3 [01:42<00:00, 34.30s/it]
100%|██████████| 5/5 [05:50<00:00, 70.02s/it]
 33%|███▎      | 1/3 [05:50<11:40, 350.10s/it]
[A
[A
[A
100%|██████████| 3/3 [00:42<00:00, 14.20s/it]

[A
[A
[A
100%|██████████| 3/3 [00:16<00:00,  5.47s/it]

[A
[A
[A
100%|██████████| 3/3 [00:17<00:00,  5.88s/it]

[A
[A
[A
100%|██████████| 3/3 [00:17<00:00,  5.96s/it]

[A
[A
[A
100%|██████████| 3/3 [00:17<00:00,  5.91s

In [41]:
# Write one spreadsheet per model for the molecule-side ablation.
# NOTE(review): the file name ends in "_both" although this run randomised
# only the molecule side (side="mol"); kept as-is so existing files still match.
for model, model_table in res_mol_rand.items():
    model_table.to_excel(f'res_mol_{model}_rand_ablation_both.xlsx')

In [42]:
# Ablation: only the aptamer features are replaced, here by per-column
# Gaussian samples; molecule features stay real.
res_apt_rand = screen_models_ablate_features(
    df=df,
    apt_cfgs=apt_cfgs,
    mol_cfgs=mol_cfgs,
    model_names=("lgbm", "logreg", "mlp"),
    split_mode="group",
    side="apt",
    how="gaussian",
    seed=42,
)

The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  0%|          | 0/3 [00:00<?, ?it/s]
[A
[A
[A
100%|██████████| 3/3 [00:31<00:00, 10.60s/it]

[A
[A
[A
100%|██████████| 3/3 [00:47<00:00, 15.93s/it]

[A
[A
[A
100%|██████████| 3/3 [01:38<00:00, 32.74s/it]

[A
[A
[A
100%|██████████| 3/3 [01:25<00:00, 28.56s/it]

[A
[A
[A
100%|██████████| 3/3 [01:18<00:00, 26.30s/it]
100%|██████████| 5/5 [05:42<00:00, 68.49s/it]
 33%|███▎      | 1/3 [05:42<11:24, 342.44s/it]
[A
[A
[A
100%|██████████| 3/3 [00:44<00:00, 14.69s/it]

[A
[A
[A
100%|██████████| 3/3 [00:18<00:00,  6.15s/it]

[A
[A
[A
100%|██████████| 3/3 [00:17<00:00,  5.93s/it]

[A
[A
[A
100%|██████████| 3/3 [00:17<00:00,  5.88s/it]

[A
[A
[A
100%|██████████| 3/3 [00:17<00:00,  5.91s

In [43]:
# Write one spreadsheet per model for the aptamer-side ablation.
# NOTE(review): the file name ends in "_both" although this run randomised
# only the aptamer side (side="apt"); kept as-is so existing files still match.
for model, model_table in res_apt_rand.items():
    model_table.to_excel(f'res_apt_{model}_rand_ablation_both.xlsx')

### concats

In [52]:
import numpy as np
import warnings
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import average_precision_score, roc_auc_score, f1_score, matthews_corrcoef

# ---- NEW: compact helper that fuses two feature matrices ----
def fuse_features(Xa, Xm, fusion="concat", proj_dim=256, random_state=42):
    """
    Combine aptamer features Xa (N, Da) and molecule features Xm (N, Dm).

    fusion is one of
        "concat"    [Xa, Xm]
        "sum"       P(Xa) + P(Xm)
        "hadamard"  P(Xa) * P(Xm)               (element-wise)
        "absdiff"   |P(Xa) - P(Xm)|
        "biaffine"  (Xa @ U) * (Xm @ V)         (always projected to k columns)
        "poly2"     [P(Xa), P(Xm), P(Xa)*P(Xm), |P(Xa)-P(Xm)|, P(Xa)+P(Xm)]

    P(.) is a fixed random Gaussian projection to k columns and is applied only
    when Da != Dm; with equal widths the inputs are used as they are.
    k = proj_dim, or min(Da, Dm) when proj_dim is None (for "biaffine":
    min(Da, Dm, 256)).

    The projection matrices depend only on (role, input width, k, random_state).
    Their seeds are derived with CRC32 rather than the built-in hash(): hash()
    of a tuple containing strings is salted per interpreter process, so the
    previous seeds changed on every kernel restart and results were not
    reproducible.

    Returns a float64 array with N rows.
    Raises AssertionError when the row counts differ and ValueError for an
    unknown `fusion` (checked before any projection work is done).
    """
    import zlib  # local import: only needed for the stable seed derivation

    Xa = np.asarray(Xa, dtype=np.float64)
    Xm = np.asarray(Xm, dtype=np.float64)
    Na, Da = Xa.shape
    Nm, Dm = Xm.shape
    assert Na == Nm, "Xa и Xm должны иметь одинаковое число строк (N)."

    if fusion not in ("concat", "sum", "hadamard", "absdiff", "biaffine", "poly2"):
        raise ValueError(f"Unknown fusion mode: {fusion}")

    if fusion == "concat":
        return np.concatenate([Xa, Xm], axis=1)

    def _gaussian_matrix(tag, d_in, k):
        """Random (d_in, k) matrix that is identical across processes."""
        key = repr((tag, int(d_in), int(k), random_state)).encode("utf-8")
        rng = np.random.RandomState(zlib.crc32(key) & 0x7fffffff)
        return rng.normal(0.0, 1.0 / np.sqrt(d_in), size=(d_in, k))

    if fusion == "biaffine":
        # Low-rank bilinear interaction: both sides are always projected.
        k = proj_dim if proj_dim is not None else min(Da, Dm, 256)
        return (Xa @ _gaussian_matrix("U", Da, k)) * (Xm @ _gaussian_matrix("V", Dm, k))

    # Remaining modes are element-wise, so the widths have to agree.
    if Da != Dm:
        k = proj_dim if proj_dim is not None else min(Da, Dm)
        A = Xa @ _gaussian_matrix("A", Da, k)
        M = Xm @ _gaussian_matrix("M", Dm, k)
    else:
        A, M = Xa, Xm

    if fusion == "sum":
        return A + M
    if fusion == "hadamard":
        return A * M
    if fusion == "absdiff":
        return np.abs(A - M)

    # "poly2": compact second-order interactions without a dimensionality blow-up.
    return np.concatenate([A, M, A * M, np.abs(A - M), A + M], axis=1)

# ---- Small change to screen_models: features are now mixed via fuse_features(...) ----
def screen_models(
    df,
    apt_cfgs,
    mol_cfgs,
    model_names=("logreg","rf","lgbm","mlp"),
    split_mode="group",
    n_splits=5,
    random_state=42,
    scale=True,
    fusion="concat",        # how aptamer and molecule features are combined
    proj_dim=256,           # projection size used by the non-concat fusion modes
):
    """
    Cross-validate every (aptamer config, molecule config, model) combination.

    Returns a dict {model_name: DataFrame} whose rows are the aptamer configs
    and whose columns are the molecule configs. Each cell is the string
    "PR mean±std | ROC mean±std | F1 mean±std | MCC mean±std | nPR mean±std"
    aggregated over the folds. F1 and MCC use clf.predict(X); PR and ROC use a
    continuous score. Features are combined with fuse_features(..., fusion=...).
    Models for which _model_factory returns None get a table filled with "N/A".
    """
    metric_order = ("PR", "ROC", "F1", "MCC", "nPR")
    feature_name_warning = "X does not have valid feature names.*"

    def _mean_std(values):
        values = np.array(values, dtype=float)
        return f"{np.nanmean(values):.3f}±{np.nanstd(values):.3f}"

    df = df.reset_index(drop=True)
    y = df["label"].to_numpy().astype(int)
    splits = _get_splits(df, split_mode=split_mode, n_splits=n_splits, random_state=random_state)

    # Features are computed once on the whole frame so every fold sees the same width.
    apt_map = {_cfg_name(cfg, 'apt'): make_apt_features(df, cfg) for cfg in apt_cfgs}
    mol_map = {_cfg_name(cfg, 'mol'): make_mol_features(df, cfg) for cfg in mol_cfgs}

    rows = [_cfg_name(cfg, 'apt') for cfg in apt_cfgs]
    cols = [_cfg_name(cfg, 'mol') for cfg in mol_cfgs]
    results = {}

    for model_name in model_names:
        # A None prototype means the model is not available here.
        if _model_factory(model_name, random_state=random_state) is None:
            results[model_name] = pd.DataFrame("N/A", index=rows, columns=cols)
            continue

        table = pd.DataFrame(index=rows, columns=cols, dtype=object)

        for ar in rows:
            Xa_full = apt_map[ar].astype(np.float64)
            for mc in cols:
                Xm_full = mol_map[mc].astype(np.float64)

                # Fused features are built once for the whole frame.
                X_full = fuse_features(
                    Xa_full, Xm_full,
                    fusion=fusion, proj_dim=proj_dim, random_state=random_state,
                )

                fold_metrics = {name: [] for name in metric_order}
                for tr, va in splits:
                    Xtr, Xva = X_full[tr], X_full[va]
                    ytr, yva = y[tr], y[va]

                    if scale:
                        # The scaler is fitted on the training fold only.
                        scaler = StandardScaler()
                        Xtr = scaler.fit_transform(Xtr)
                        Xva = scaler.transform(Xva)

                    clf = _model_factory(model_name, random_state=random_state)
                    with warnings.catch_warnings():
                        warnings.filterwarnings("ignore", message=feature_name_warning)
                        clf.fit(Xtr, ytr)

                    with warnings.catch_warnings():
                        warnings.filterwarnings("ignore", message=feature_name_warning)
                        # Continuous score for PR / ROC.
                        if hasattr(clf, "predict_proba"):
                            s = clf.predict_proba(Xva)[:, 1]
                        elif hasattr(clf, "decision_function"):
                            d = clf.decision_function(Xva)
                            s = (d - d.min()) / (d.max() - d.min() + 1e-8)
                        else:
                            s = clf.predict(Xva).astype(float)
                        # Hard predictions at the model's default threshold.
                        yhat = clf.predict(Xva)

                    # PR / ROC are undefined when the fold holds a single class.
                    both_classes = len(np.unique(yva)) > 1
                    pr = average_precision_score(yva, s) if both_classes else np.nan
                    p = yva.mean()
                    npr = (pr - p) / (1 - p + 1e-12)
                    roc = roc_auc_score(yva, s) if both_classes else np.nan
                    f1 = f1_score(yva, yhat, zero_division=0)
                    mcc = matthews_corrcoef(yva, yhat)

                    for name, value in zip(metric_order, (pr, roc, f1, mcc, npr)):
                        fold_metrics[name].append(value)

                table.loc[ar, mc] = " | ".join(
                    f"{name} {_mean_std(fold_metrics[name])}" for name in metric_order
                )

        results[model_name] = table

    return results


In [77]:
# Dataset used by the fusion comparison: up to 3 generated negatives per sequence.
df_aug = augment_negatives_per_sequence(df=df, n_neg_per_seq=3, seed=42)

In [78]:
# aptamer: only kmer with k=4
# NOTE(review): this rebinds the notebook-level `apt_cfgs` / `mol_cfgs` that the
# earlier screening cells also read; re-running those cells afterwards will use
# these single-config lists.
apt_cfgs = [
    {"name": "kmer", "k": 4},
]

# molecule: only ChemBERTa (uses the already loaded tok, mdl)
mol_cfgs = [
    {"name": "chemberta", "tok": tok, "mdl": mdl, "batch_size": 64, "max_len": 128, "pooling": "mean"},
]

In [89]:
 # примеры:
# Compare fusion strategies on the cold-molecule split with the same models.
_fusion_common = dict(model_names=("lgbm", "mlp"), split_mode="cold_molecule")

res_concat = screen_models(df_aug, apt_cfgs, mol_cfgs, fusion="concat", **_fusion_common)

res_hadamard = screen_models(df_aug, apt_cfgs, mol_cfgs, fusion="hadamard", proj_dim=256, **_fusion_common)

res_poly2 = screen_models(df_aug, apt_cfgs, mol_cfgs, fusion="poly2", proj_dim=256, **_fusion_common)

res_biaffine = screen_models(df_aug, apt_cfgs, mol_cfgs, fusion="biaffine", proj_dim=256, **_fusion_common)


KeyboardInterrupt: 

In [80]:
res_concat

{'lgbm':                                              chemberta(mean)
 kmer(k=4)  PR 0.853±0.015 | ROC 0.897±0.010 | F1 0.768±0....,
 'mlp':                                              chemberta(mean)
 kmer(k=4)  PR 0.796±0.027 | ROC 0.874±0.015 | F1 0.715±0....}

In [None]:
# One spreadsheet per (model, fusion mode); the model list is taken from the
# biaffine run and the write order per model is biaffine, poly2, concat, hadamard.
_fusion_results = {
    "biaffine": res_biaffine,
    "poly2": res_poly2,
    "concat": res_concat,
    "hadamard": res_hadamard,
}
for model in res_biaffine:
    for fusion_name, fusion_tables in _fusion_results.items():
        fusion_tables[model].to_excel(f'res_{model}_{fusion_name}_cold_molecule.xlsx')

## Final ML classifier

In [22]:
# ============================================
#  K-mer(4) + ChemBERTa(mean): LGBM/MLP + Optuna
#  feature selection, HPO, importance, fold-wise test stats
# ============================================

import warnings
import numpy as np
import pandas as pd
from sklearn.metrics import (
    average_precision_score, roc_auc_score, f1_score, matthews_corrcoef, confusion_matrix
)
from sklearn.preprocessing import StandardScaler

# ---- ваши функции: используем как есть ----
# make_apt_features(df, cfg)
# make_mol_features(df, cfg)
# _get_splits(df, split_mode="group", n_splits=5, random_state=42)

# ---- Optuna ----
import optuna
from optuna.samplers import TPESampler
from optuna.pruners import MedianPruner


# --------- metric / aggregation utilities ----------
def _compute_fold_metrics(y_true, scores, y_pred):
    """
    Compute the metrics of one validation fold.

    y_true: true labels; scores: continuous scores used for PR / ROC;
    y_pred: hard predictions used for F1, MCC and the confusion matrix.

    Returns a dict with keys PR, ROC, F1, MCC, nPR, TN, FP, FN, TP.
    PR, ROC and nPR are NaN when y_true holds a single class. nPR rescales PR by
    the positive rate p of the fold: (PR - p) / (1 - p).
    """
    if len(np.unique(y_true)) > 1:
        pr = average_precision_score(y_true, scores)
        roc = roc_auc_score(y_true, scores)
    else:
        pr = roc = np.nan

    p = float(np.mean(y_true))
    # The small epsilon guards the division when every label is positive.
    npr = np.nan if np.isnan(pr) else (pr - p) / (1 - p + 1e-12)

    f1 = f1_score(y_true, y_pred, zero_division=0)
    mcc = matthews_corrcoef(y_true, y_pred)

    # Fixing labels=[0, 1] keeps the matrix 2x2 even if a class is absent.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp = cm[0]
    fn, tp = cm[1]
    return dict(PR=pr, ROC=roc, F1=f1, MCC=mcc, nPR=npr, TN=tn, FP=fp, FN=fn, TP=tp)

def _agg_mean_std(arr):
    arr = np.asarray(arr, dtype=float)
    return f"{np.nanmean(arr):.3f}±{np.nanstd(arr):.3f}"

def _print_fold_val_stats(splits, y):
    print("\n[Validation (\"test\") per fold: size and class balance]")
    tot_N = tot_pos = tot_neg = 0
    for i, (_, va) in enumerate(splits, 1):
        yv = y[va]
        pos = int(np.sum(yv == 1)); neg = int(np.sum(yv == 0)); N = len(yv)
        tot_N += N; tot_pos += pos; tot_neg += neg
        print(f"  Fold {i}: N={N}, pos={pos}, neg={neg}, pos%={100*pos/max(N,1):.1f}")
    print(f"  Total (sum over folds): N={tot_N}, pos={tot_pos}, neg={tot_neg}, pos%={100*tot_pos/max(tot_N,1):.1f}")


# --------- построение фич и имён ----------
def _build_features_and_names(df, tok, mdl, chem_max_len=128, chem_batch=64):
    """KMER k=4 + ChemBERTa(mean). Возвращает X, y, names, groups."""
    apt_cfg = {"name": "kmer", "k": 4}  # фиксировано по ТЗ
    mol_cfg = {"name": "chemberta", "tok": tok, "mdl": mdl,
               "batch_size": chem_batch, "max_len": chem_max_len, "pooling": "mean"}

    Xa = make_apt_features(df, apt_cfg).astype(np.float32)
    Xm = make_mol_features(df, mol_cfg).astype(np.float32)
    X = np.concatenate([Xa, Xm], axis=1)
    y = df["label"].to_numpy().astype(int)

    # имена/группы фич (если нужных генераторов имён нет — синтетически)
    na = Xa.shape[1]; nm = Xm.shape[1]
    names = [f"kmer4_f{i}" for i in range(na)] + [f"chemberta_f{j}" for j in range(nm)]
    groups = (["apt:kmer4"] * na) + (["mol:chemberta"] * nm)
    return X, y, np.array(names, dtype=object), np.array(groups, dtype=object)


# --------- quick feature selection via LGBM gain ----------
def lgbm_select_topk(X, y, splits, topk_list=(256, 512, 1024), random_state=42):
    """
    Rank features with a draft LGBM and pick the top-k size with the best CV MCC.

    A single LGBMClassifier is fitted on all of X, y; its importances ("gain",
    falling back to "split" if that raises) order the columns. For every k in
    topk_list the first min(k, n_features) columns are evaluated over `splits`
    and the k with the highest mean MCC wins (first one on ties).

    Returns (best_idx, gain, best_k, best_mcc): the selected column indices, the
    importance vector over all columns, the requested k and its mean CV MCC.
    best_idx and best_k stay None when topk_list is empty.

    NOTE(review): the ranking model sees every row, including rows that later
    serve as validation folds, so best_mcc is likely optimistic.
    """
    from lightgbm import LGBMClassifier

    def _importance(model):
        # gain importance, with split counts as a fallback on error
        try:
            return model.booster_.feature_importance(importance_type="gain").astype(float)
        except Exception:
            return model.booster_.feature_importance(importance_type="split").astype(float)

    def _continuous_scores(model, features):
        if hasattr(model, "predict_proba"):
            return model.predict_proba(features)[:, 1]
        if hasattr(model, "decision_function"):
            raw = model.decision_function(features)
            return (raw - raw.min()) / (raw.max() - raw.min() + 1e-8)
        return model.predict(features).astype(float)

    # 1) one LGBM over the full X for a rough importance estimate
    base = LGBMClassifier(
        n_estimators=600, learning_rate=0.05, num_leaves=64,
        subsample=0.8, colsample_bytree=0.8, reg_lambda=1.0,
        random_state=random_state, n_jobs=-1, verbosity=-1
    )
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        base.fit(X, y)

    gain = _importance(base)
    ranking = np.argsort(gain)[::-1]  # most important first

    # 2) sweep the candidate sizes with a lighter per-fold model
    fold_params = dict(
        n_estimators=400, learning_rate=0.05, num_leaves=48,
        subsample=0.9, colsample_bytree=0.9, reg_lambda=1.0,
        random_state=random_state, n_jobs=-1, verbosity=-1
    )
    best_k, best_mcc, best_idx = None, -1e9, None

    for k in topk_list:
        sel = ranking[:min(k, X.shape[1])]
        mcc_scores = []
        for tr, va in splits:
            Xtr, ytr = X[tr][:, sel], y[tr]
            Xva, yva = X[va][:, sel], y[va]
            clf = LGBMClassifier(**fold_params)
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                clf.fit(Xtr, ytr)

            s = _continuous_scores(clf, Xva)
            yhat = clf.predict(Xva)
            mcc_scores.append(_compute_fold_metrics(yva, s, yhat)["MCC"])

        mcc_mean = float(np.nanmean(mcc_scores))
        if mcc_mean > best_mcc:
            best_mcc, best_k, best_idx = mcc_mean, k, sel

    return best_idx, gain, best_k, best_mcc


# --------- LGBM + Optuna ----------
def tune_lgbm_optuna(X, y, sel_idx, splits, random_state=42, n_trials=60, timeout=None):
    """
    Tune LGBMClassifier hyper-parameters with Optuna on the selected columns,
    then re-run the cross-validation with the best parameters.

    X, y: full feature matrix and labels; only X[:, sel_idx] (cast to float32)
    is used. splits: sequence of (train_idx, val_idx) pairs; it is iterated
    once per trial and once more for the final CV.
    n_trials / timeout are passed to study.optimize.

    The objective is the mean MCC over the folds. The running mean is reported
    after every fold so the MedianPruner can stop a trial early.

    Returns (best_params, folds_metrics, imp_gain):
      best_params   - parameters suggested in the best trial;
      folds_metrics - one _compute_fold_metrics dict per fold of the final CV;
      imp_gain      - importances summed over the folds of the final CV
                      ("gain", or "split" if the gain query raises), one value
                      per selected column.
    """
    from lightgbm import LGBMClassifier

    Xs = X[:, sel_idx].astype(np.float32)

    def objective(trial):
        # The order of the suggest_* calls is part of the seeded search; keep it.
        params = {
            "n_estimators": trial.suggest_int("n_estimators", 300, 1200),
            "learning_rate": trial.suggest_float("learning_rate", 1e-3, 2e-1, log=True),
            "num_leaves": trial.suggest_int("num_leaves", 16, 256, step=8),
            "max_depth": trial.suggest_int("max_depth", -1, 16),
            "min_child_samples": trial.suggest_int("min_child_samples", 5, 120),
            "subsample": trial.suggest_float("subsample", 0.5, 1.0),
            "colsample_bytree": trial.suggest_float("colsample_bytree", 0.5, 1.0),
            "reg_lambda": trial.suggest_float("reg_lambda", 1e-3, 100.0, log=True),
            "reg_alpha": trial.suggest_float("reg_alpha", 0.0, 10.0),
            "random_state": random_state,
            "n_jobs": -1,
            "verbosity": -1
        }
        mcc_scores = []
        for fold_id, (tr, va) in enumerate(splits, 1):
            Xtr, ytr = Xs[tr], y[tr]
            Xva, yva = Xs[va], y[va]
            clf = LGBMClassifier(**params)
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                clf.fit(Xtr, ytr)

            # continuous & hard predictions
            if hasattr(clf, "predict_proba"):
                s = clf.predict_proba(Xva)[:, 1]
            elif hasattr(clf, "decision_function"):
                d = clf.decision_function(Xva); s = (d - d.min())/(d.max()-d.min()+1e-8)
            else:
                s = clf.predict(Xva).astype(float)
            yhat = clf.predict(Xva)

            metr = _compute_fold_metrics(yva, s, yhat)
            mcc_scores.append(metr["MCC"])

            # Report the running mean so the pruner can compare partial trials.
            trial.report(np.nanmean(mcc_scores), step=fold_id)
            if trial.should_prune():
                raise optuna.exceptions.TrialPruned()

        return float(np.nanmean(mcc_scores))

    study = optuna.create_study(direction="maximize",
                                sampler=TPESampler(seed=random_state),
                                pruner=MedianPruner())
    study.optimize(objective, n_trials=n_trials, timeout=timeout, show_progress_bar=False)

    # trial.params holds only the suggested values; the fixed settings
    # (random_state, n_jobs, verbosity) are passed again below.
    best_params = study.best_trial.params.copy()

    # final CV to collect the metric summary and the importances
    # NOTE(review): these are the same folds the search optimised on, so the
    # reported metrics are not an independent estimate.
    folds_metrics = []
    imp_gain = np.zeros(Xs.shape[1], dtype=float)
    for (tr, va) in splits:
        Xtr, ytr = Xs[tr], y[tr]
        Xva, yva = Xs[va], y[va]
        clf = LGBMClassifier(**best_params, random_state=random_state, n_jobs=-1, verbosity=-1)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            clf.fit(Xtr, ytr)

        if hasattr(clf, "predict_proba"):
            s = clf.predict_proba(Xva)[:, 1]
        elif hasattr(clf, "decision_function"):
            d = clf.decision_function(Xva); s = (d - d.min())/(d.max()-d.min()+1e-8)
        else:
            s = clf.predict(Xva).astype(float)
        yhat = clf.predict(Xva)

        folds_metrics.append(_compute_fold_metrics(yva, s, yhat))

        # Accumulate importances over folds; fall back to split counts on error.
        try:
            imp_gain += clf.booster_.feature_importance(importance_type="gain")
        except Exception:
            imp_gain += clf.booster_.feature_importance(importance_type="split")

    return best_params, folds_metrics, imp_gain


# --------- MLP + Optuna ----------
def tune_mlp_optuna(X, y, sel_idx, splits, random_state=42, n_trials=50, timeout=None, max_iter=250):
    """
    Tune MLPClassifier hyper-parameters with Optuna on the selected columns,
    then re-run the cross-validation with the best parameters.

    X, y: full feature matrix and labels; only X[:, sel_idx] (cast to float32)
    is used. splits: sequence of (train_idx, val_idx) pairs; it is iterated
    once per trial and once more for the final CV.
    n_trials / timeout are passed to study.optimize; max_iter to MLPClassifier.

    In every fold a StandardScaler is fitted on the training part and applied
    to the validation part. The objective is the mean MCC over the folds; the
    running mean is reported after each fold so the MedianPruner can stop a
    trial early.

    Returns (best_params, folds_metrics):
      best_params   - Optuna parameter dict of the best trial, with keys
                      n_layers, width, activation, alpha, lr, batch_size;
      folds_metrics - one _compute_fold_metrics dict per fold of the final CV.
    """
    from sklearn.neural_network import MLPClassifier

    Xs = X[:, sel_idx].astype(np.float32)

    def objective(trial):
        # The order of the suggest_* calls is part of the seeded search; keep it.
        n_layers = trial.suggest_int("n_layers", 1, 3)
        width = trial.suggest_categorical("width", [64, 128, 256, 512])
        # All hidden layers share the same width.
        hidden = tuple([width] * n_layers)

        params = {
            "hidden_layer_sizes": hidden,
            "activation": trial.suggest_categorical("activation", ["relu", "tanh"]),
            "alpha": trial.suggest_float("alpha", 1e-6, 1e-2, log=True),
            "learning_rate_init": trial.suggest_float("lr", 1e-4, 5e-2, log=True),
            "batch_size": trial.suggest_categorical("batch_size", [64, 128, 256]),
            "solver": "adam",
            "max_iter": max_iter,
            "early_stopping": True,
            "n_iter_no_change": 10,
            "random_state": random_state
        }

        mcc_scores = []
        for fold_id, (tr, va) in enumerate(splits, 1):
            Xtr, ytr = Xs[tr], y[tr]
            Xva, yva = Xs[va], y[va]

            # Scale with statistics of the training fold only.
            scaler = StandardScaler()
            Xtr = scaler.fit_transform(Xtr)
            Xva = scaler.transform(Xva)

            clf = MLPClassifier(**params)
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                clf.fit(Xtr, ytr)

            if hasattr(clf, "predict_proba"):
                s = clf.predict_proba(Xva)[:, 1]
            elif hasattr(clf, "decision_function"):
                d = clf.decision_function(Xva); s = (d - d.min())/(d.max()-d.min()+1e-8)
            else:
                s = clf.predict(Xva).astype(float)
            yhat = clf.predict(Xva)

            metr = _compute_fold_metrics(yva, s, yhat)
            mcc_scores.append(metr["MCC"])

            # Report the running mean so the pruner can compare partial trials.
            trial.report(np.nanmean(mcc_scores), step=fold_id)
            if trial.should_prune():
                raise optuna.exceptions.TrialPruned()

        return float(np.nanmean(mcc_scores))

    study = optuna.create_study(direction="maximize",
                                sampler=TPESampler(seed=random_state),
                                pruner=MedianPruner())
    study.optimize(objective, n_trials=n_trials, timeout=timeout, show_progress_bar=False)

    # Keys are the Optuna names ("lr", "width", "n_layers", ...), not
    # MLPClassifier keyword names, so the estimator is rebuilt explicitly below.
    best_params = study.best_trial.params.copy()

    # Final CV with the best parameters.
    # NOTE(review): these are the same folds the search optimised on, so the
    # reported metrics are not an independent estimate.
    folds_metrics = []
    for (tr, va) in splits:
        Xtr, ytr = Xs[tr], y[tr]
        Xva, yva = Xs[va], y[va]

        scaler = StandardScaler()
        Xtr = scaler.fit_transform(Xtr)
        Xva = scaler.transform(Xva)

        hidden = tuple([best_params["width"]] * best_params["n_layers"])
        clf = MLPClassifier(
            hidden_layer_sizes=hidden,
            activation=best_params["activation"],
            alpha=best_params["alpha"],
            learning_rate_init=best_params["lr"],
            batch_size=best_params["batch_size"],
            solver="adam", max_iter=max_iter, early_stopping=True,
            n_iter_no_change=10, random_state=random_state
        )
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            clf.fit(Xtr, ytr)

        if hasattr(clf, "predict_proba"):
            s = clf.predict_proba(Xva)[:, 1]
        elif hasattr(clf, "decision_function"):
            d = clf.decision_function(Xva); s = (d - d.min())/(d.max()-d.min()+1e-8)
        else:
            s = clf.predict(Xva).astype(float)
        yhat = clf.predict(Xva)

        folds_metrics.append(_compute_fold_metrics(yva, s, yhat))

    return best_params, folds_metrics


# --------- Main wrapper: full experiment ----------
def run_kmer4_chemberta_optuna(
    df,
    tok, mdl,                 # ChemBERTa objects (the same ones used elsewhere in the notebook)
    split_mode="group",
    n_splits=5,
    random_state=42,
    # feature selection
    topk_list=(256, 512, 1024, 1536),
    # HPO
    n_trials_lgbm=100,
    n_trials_mlp=100,
    timeout_lgbm=None,
    timeout_mlp=None,
):
    """
    Run the whole experiment: splits, features, top-k selection, LGBM and MLP tuning, report.

    Steps, in order:
      1. reset the frame index and build CV splits with ``_get_splits``;
      2. build the feature matrix with ``_build_features_and_names`` and print
         per-fold validation statistics;
      3. pick a feature subset with ``lgbm_select_topk`` over ``topk_list``;
      4. tune LightGBM and then the MLP with Optuna on the selected features;
      5. print the metric table, best parameters, importances and confusion counts.

    Parameters
    ----------
    df : pandas.DataFrame
        Input data; a copy with a fresh 0..n-1 index is used internally.
    tok, mdl :
        Forwarded to ``_build_features_and_names``.
    split_mode, n_splits, random_state :
        Forwarded to ``_get_splits``; ``random_state`` also goes to the selection
        and tuning helpers.
    topk_list : sequence of int
        Candidate feature counts forwarded to ``lgbm_select_topk``.
    n_trials_lgbm, n_trials_mlp, timeout_lgbm, timeout_mlp :
        Optuna budgets forwarded to the respective tuning helper.

    Returns
    -------
    dict
        Keys: ``selected_idx``, ``selected_feature_names``, ``selected_feature_groups``,
        ``gain_all``, ``metrics_table``, ``lgbm_best_params``, ``mlp_best_params``,
        ``lgbm_group_importance_share``, ``lgbm_top20_features``, ``lgbm_confusion``,
        ``mlp_confusion``, ``splits``.
    """
    score_keys = ("PR", "ROC", "F1", "MCC", "nPR")
    count_keys = ("TN", "FP", "FN", "TP")

    def _summarize(model_label, fold_metrics):
        """Return (row of 'mean±std' strings, confusion counts summed over folds)."""
        row = {"Model": model_label}
        for key in score_keys:
            row[key] = _agg_mean_std([fold[key] for fold in fold_metrics])
        counts = {key: int(sum(fold[key] for fold in fold_metrics)) for key in count_keys}
        return row, counts

    df = df.reset_index(drop=True)

    # splits
    splits = _get_splits(df, split_mode=split_mode, n_splits=n_splits, random_state=random_state)

    # features
    X, y, feat_names, groups = _build_features_and_names(df, tok, mdl)
    _print_fold_val_stats(splits, y)

    # top-k selection
    print("\n[Feature selection] LGBM gain -> top-k sweep")
    sel_idx, gain, best_k, est_mcc = lgbm_select_topk(
        X, y, splits, topk_list=topk_list, random_state=random_state
    )
    n_selected = len(sel_idx)
    print(f"  Selected top-k = {n_selected} (requested best={best_k}), est. CV-MCC={est_mcc:.3f}")

    names_sel = feat_names[sel_idx]
    groups_sel = groups[sel_idx]

    # ---- LGBM (Optuna) ----
    print("\n[Tuning] LightGBM (Optuna)")
    lgbm_best_params, lgbm_folds, lgbm_imp_gain_sel = tune_lgbm_optuna(
        X, y, sel_idx, splits,
        random_state=random_state,
        n_trials=n_trials_lgbm,
        timeout=timeout_lgbm,
    )

    # LGBM importances: per feature, then aggregated per feature group
    df_imp = pd.DataFrame(
        {"feature": names_sel, "group": groups_sel, "gain": lgbm_imp_gain_sel}
    )
    gain_per_group = df_imp.groupby("group")["gain"].sum()
    total_gain = max(df_imp["gain"].sum(), 1e-12)  # floor keeps the division defined for zero total gain
    group_share = (gain_per_group / total_gain).sort_values(ascending=False)

    top20 = df_imp.sort_values("gain", ascending=False).head(20).reset_index(drop=True)

    # LGBM metrics
    lgbm_summary, lgbm_cm = _summarize("LGBM", lgbm_folds)

    # ---- MLP (Optuna) ----
    print("\n[Tuning] MLP (Optuna)")
    mlp_best_params, mlp_folds = tune_mlp_optuna(
        X, y, sel_idx, splits,
        random_state=random_state,
        n_trials=n_trials_mlp,
        timeout=timeout_mlp,
    )
    mlp_summary, mlp_cm = _summarize("MLP", mlp_folds)

    # final table
    metrics_table = pd.DataFrame([lgbm_summary, mlp_summary]).set_index("Model")
    print("\n=== CV metrics (kmer4 + ChemBERTa mean) ===")
    print(metrics_table)

    print("\n=== LGBM best params (Optuna) ===")
    print(lgbm_best_params)
    print("\n=== MLP best params (Optuna) ===")
    print(mlp_best_params)

    print("\n=== LGBM group importance (share over selected) ===")
    print(group_share.to_frame(name="share"))
    print("\n=== Top-20 LGBM features (gain) ===")
    print(top20)

    print("\n=== Confusion matrices (sum over folds) ===")
    print("LGBM:", lgbm_cm)
    print("MLP :", mlp_cm)

    return {
        "selected_idx": sel_idx,
        "selected_feature_names": names_sel.tolist(),
        "selected_feature_groups": groups_sel.tolist(),
        "gain_all": gain,  # raw importance before selection
        "metrics_table": metrics_table,
        "lgbm_best_params": lgbm_best_params,
        "mlp_best_params": mlp_best_params,
        "lgbm_group_importance_share": group_share.to_dict(),
        "lgbm_top20_features": top20,
        "lgbm_confusion": lgbm_cm,
        "mlp_confusion": mlp_cm,
        "splits": splits,  # kept for further analysis
    }


In [None]:
# Prerequisites from earlier cells: a DataFrame `df_aug` and the ChemBERTa
# tokenizer/model objects `tok` and `mdl`.

art = run_kmer4_chemberta_optuna(
    df=df_aug,
    tok=tok,
    mdl=mdl,
    split_mode="group",      # or "cold_aptamer" / "cold_molecule" / "cold_both"
    n_splits=5,
    random_state=42,
    topk_list=(256, 512, 1024, 1536),
    n_trials_lgbm=100,
    n_trials_mlp=100,
    timeout_lgbm=None,
    timeout_mlp=None,
)


In [None]:
# =========================================================
#  Evaluate best LGBM (from artifacts) on cold_* splits
#  (kmer4 + ChemBERTa mean, same selected_idx & best params)
# =========================================================
import numpy as np
import pandas as pd
import warnings
from sklearn.metrics import average_precision_score, roc_auc_score, f1_score, matthews_corrcoef, confusion_matrix
from sklearn.preprocessing import StandardScaler  # не обяз. для LGBM, но не мешает
from lightgbm import LGBMClassifier

# ---- helper functions (same as in the previous cell) ----
def _compute_fold_metrics(y_true, scores, y_pred):
    """
    Compute the metrics of a single validation fold.

    Parameters
    ----------
    y_true : array-like
        Ground-truth labels.
    scores : array-like
        Continuous scores used for the ranking metrics (PR, ROC).
    y_pred : array-like
        Hard predictions used for F1, MCC and the confusion matrix.

    Returns
    -------
    dict
        ``PR`` (average precision), ``ROC`` (ROC AUC), ``F1``, ``MCC``,
        ``nPR`` = (PR - p) / (1 - p + 1e-12) with p = mean(y_true), and the
        confusion counts ``TN``, ``FP``, ``FN``, ``TP`` for labels [0, 1].
        ``PR``, ``ROC`` and ``nPR`` are NaN when ``y_true`` holds a single
        distinct value.
    """
    has_both_classes = len(np.unique(y_true)) > 1

    pr = average_precision_score(y_true, scores) if has_both_classes else np.nan

    # Average precision rescaled by the positive rate of the fold.
    prevalence = float(np.mean(y_true))
    if np.isnan(pr):
        npr = np.nan
    else:
        npr = (pr - prevalence) / (1 - prevalence + 1e-12)

    roc = roc_auc_score(y_true, scores) if has_both_classes else np.nan

    f1 = f1_score(y_true, y_pred, zero_division=0)
    mcc = matthews_corrcoef(y_true, y_pred)

    # labels=[0, 1] pins the matrix to 2x2 even if a class is absent from the fold.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    return {
        "PR": pr, "ROC": roc, "F1": f1, "MCC": mcc, "nPR": npr,
        "TN": tn, "FP": fp, "FN": fn, "TP": tp,
    }

def _agg_mean_std(arr):
    arr = np.asarray(arr, dtype=float)
    return f"{np.nanmean(arr):.3f}±{np.nanstd(arr):.3f}"

def _print_fold_val_stats(splits, y, title="Validation per fold"):
    print(f"\n[{title}: size and class balance]")
    tot_N = tot_pos = tot_neg = 0
    for i, (_, va) in enumerate(splits, 1):
        yv = y[va]
        pos = int(np.sum(yv == 1)); neg = int(np.sum(yv == 0)); N = len(yv)
        tot_N += N; tot_pos += pos; tot_neg += neg
        print(f"  Fold {i}: N={N}, pos={pos}, neg={neg}, pos%={100*pos/max(N,1):.1f}")
    print(f"  Total: N={tot_N}, pos={tot_pos}, neg={tot_neg}, pos%={100*tot_pos/max(tot_N,1):.1f}")

# ---- features: kmer(4) + ChemBERTa(mean) (as before) ----
def _build_features_kmer4_chemberta(df, tok, mdl, chem_max_len=128, chem_batch=64):
    """
    Build the feature matrix (aptamer k-mer block followed by ChemBERTa block) and labels.

    Parameters
    ----------
    df : pandas.DataFrame
        Passed to ``make_apt_features`` / ``make_mol_features``; its ``label``
        column provides the target.
    tok, mdl :
        Placed into the molecule config under the keys ``tok`` and ``mdl``.
    chem_max_len : int
        ``max_len`` entry of the molecule config.
    chem_batch : int
        ``batch_size`` entry of the molecule config.

    Returns
    -------
    (X, y)
        ``X``: both float32 blocks concatenated along axis 1;
        ``y``: ``df["label"]`` as an integer array.
    """
    aptamer_block = make_apt_features(df, {"name": "kmer", "k": 4}).astype(np.float32)

    molecule_cfg = {
        "name": "chemberta",
        "tok": tok,
        "mdl": mdl,
        "batch_size": chem_batch,
        "max_len": chem_max_len,
        "pooling": "mean",
    }
    molecule_block = make_mol_features(df, molecule_cfg).astype(np.float32)

    features = np.concatenate([aptamer_block, molecule_block], axis=1)
    labels = df["label"].to_numpy().astype(int)
    return features, labels

def _parse_mean_from_cell(cell):
    """Парсим '0.706±0.015' -> 0.706 (float)."""
    try:
        return float(str(cell).split("±")[0].strip())
    except Exception:
        return np.nan

def evaluate_lgbm_on_split(df, tok, mdl, artifacts, split_mode, n_splits=5, random_state=42):
    """
    Cross-validate the tuned LGBM taken from ``artifacts`` under ``split_mode``.

    The feature subset (``artifacts["selected_idx"]``) and the hyper-parameters
    (``artifacts["lgbm_best_params"]``) are reused unchanged; only the splitting
    scheme differs. A fresh ``LGBMClassifier`` is fitted on every fold.

    Parameters
    ----------
    df : pandas.DataFrame
        Data passed to the feature builder and to ``_get_splits``. The caller's
        frame is not modified.
    tok, mdl :
        ChemBERTa objects forwarded to ``_build_features_kmer4_chemberta``.
    artifacts : dict
        Must contain ``lgbm_best_params``; ``selected_idx`` is optional (all
        feature columns are used when it is missing or None).
    split_mode : str
        Forwarded to ``_get_splits`` (e.g. 'cold_aptamer' / 'cold_molecule').
    n_splits : int
        Forwarded to ``_get_splits``.
    random_state : int
        Forwarded to ``_get_splits`` and set on the classifier.

    Returns
    -------
    (summary, cm, folds_metrics)
        ``summary``: 'mean±std' strings for PR, ROC, F1, MCC, nPR;
        ``cm``: TN/FP/FN/TP summed over folds;
        ``folds_metrics``: the per-fold dicts from ``_compute_fold_metrics``.
    """
    # Same normalisation as run_kmer4_chemberta_optuna: with a 0..n-1 index, row
    # labels and row positions coincide, so the split indices address the rows of
    # X/y no matter how _get_splits derives them.
    # NOTE(review): _get_splits is defined elsewhere; confirm which kind of index it returns.
    df = df.reset_index(drop=True)

    # features
    X, y = _build_features_kmer4_chemberta(df, tok, mdl)
    # splits
    splits = _get_splits(df, split_mode=split_mode, n_splits=n_splits, random_state=random_state)
    _print_fold_val_stats(splits, y, title=f'{split_mode} validation')

    # selected features and best parameters
    sel_idx = artifacts.get("selected_idx", None)
    if sel_idx is not None:
        Xs = X[:, sel_idx].astype(np.float32)
    else:
        Xs = X.astype(np.float32)  # fallback when no selection was stored
    best_params = dict(artifacts["lgbm_best_params"])  # copy, the artifacts stay untouched
    best_params.update(dict(random_state=random_state, n_jobs=-1, verbosity=-1))

    folds_metrics = []
    for tr, va in splits:
        Xtr, ytr = Xs[tr], y[tr]
        Xva, yva = Xs[va], y[va]

        clf = LGBMClassifier(**best_params)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            clf.fit(Xtr, ytr)

        # LGBMClassifier always provides predict_proba, so the generic
        # decision_function / predict fallbacks could never run here.
        s = clf.predict_proba(Xva)[:, 1]
        yhat = clf.predict(Xva)

        folds_metrics.append(_compute_fold_metrics(yva, s, yhat))

    # aggregation over folds
    summary = {
        key: _agg_mean_std([m[key] for m in folds_metrics])
        for key in ("PR", "ROC", "F1", "MCC", "nPR")
    }
    cm = {
        key: int(sum(m[key] for m in folds_metrics))
        for key in ("TN", "FP", "FN", "TP")
    }
    return summary, cm, folds_metrics

def evaluate_best_lgbm_on_cold(df, tok, mdl, artifacts, n_splits=5, random_state=42):
    """
    Evaluate the tuned LGBM on 'cold_aptamer' and 'cold_molecule' and compare with the group baseline.

    The baseline means are parsed from the ``LGBM`` row of
    ``artifacts["metrics_table"]``. Each cold split is evaluated with
    ``evaluate_lgbm_on_split``; its summary, confusion counts and the difference
    of means to the baseline are printed.

    Parameters
    ----------
    df, tok, mdl, artifacts, n_splits, random_state :
        Forwarded to ``evaluate_lgbm_on_split``.

    Returns
    -------
    (out, comp_df)
        ``out``: dict keyed by split mode with ``summary``, ``cm`` and
        ``delta_vs_group``;
        ``comp_df``: DataFrame indexed by ``split`` ('group', 'cold_aptamer',
        'cold_molecule') with the mean of PR, ROC, F1, MCC and nPR.
    """
    metric_keys = ["PR", "ROC", "F1", "MCC", "nPR"]

    # baseline from artifacts (group split)
    base_row = artifacts["metrics_table"].loc["LGBM"]
    base = {k: _parse_mean_from_cell(base_row[k]) for k in metric_keys}

    out = {}
    # insertion order defines the row order of the comparison table
    means_by_split = {"group": base}
    for split_mode in ["cold_aptamer", "cold_molecule"]:
        print(f"\n===== LGBM on {split_mode} =====")
        summary, cm, _ = evaluate_lgbm_on_split(
            df, tok, mdl, artifacts,
            split_mode=split_mode, n_splits=n_splits, random_state=random_state
        )
        print("\nSummary metrics:", summary)
        print("Confusion matrix (sum over folds):", cm)

        # comparison with the baseline uses the mean part of each 'mean±std' cell;
        # the same parser is used for the baseline and for the cold splits
        split_means = {k: _parse_mean_from_cell(summary[k]) for k in metric_keys}
        comp = {k: split_means[k] - base[k] for k in metric_keys}
        print("Δ vs. group (mean):", comp)

        out[split_mode] = dict(summary=summary, cm=cm, delta_vs_group=comp)
        means_by_split[split_mode] = split_means

    # compact comparison table
    rows = [
        [split_name] + [mean_values[k] for k in metric_keys]
        for split_name, mean_values in means_by_split.items()
    ]
    comp_df = pd.DataFrame(rows, columns=["split"] + metric_keys).set_index("split")
    print("\n=== Comparison (mean only) ===")
    print(comp_df)

    return out, comp_df



In [None]:

# Needs `df_aug`, `tok`, `mdl` and `art` (the result of run_kmer4_chemberta_optuna)
# from earlier cells.
# NOTE(review): the saved output of this cell is
# "NameError: name 'df_aug' is not defined" - run the cell that creates df_aug first.
out, comp_df = evaluate_best_lgbm_on_cold(
    df=df_aug,
    tok=tok,
    mdl=mdl,
    artifacts=art,
    n_splits=5,
    random_state=42,
)

NameError: name 'df_aug' is not defined