In [3]:
# Directory holding the benchmark CSV files (hard-coded absolute path).
BENCHMARKS_DIR = '/home/nemophila/projects/protein_bert/anticrispr_benchmarks'

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import pandas as pd
from IPython.display import display

from tensorflow import keras

from sklearn.model_selection import train_test_split

from proteinbert import OutputType, OutputSpec, FinetuningModelGenerator, load_pretrained_model, finetune, evaluate_by_len
from proteinbert.conv_and_global_attention_model import get_model_with_hidden_layers_as_outputs

# ===================== 1. Benchmark name (the prefix of the dataset file names) =====================
BENCHMARK_NAME = 'anticrispr_binary'  # replaces the original signalP_binary prefix

# A local (non-global) binary output
# NOTE(review): the first OutputType argument is False and the data below carries one label per row;
# confirm against proteinbert.OutputType whether False really selects a local (per-residue) output
# or a global (per-sequence) one - the comment above may be inaccurate.
OUTPUT_TYPE = OutputType(False, 'binary')
UNIQUE_LABELS = [0, 1]  # this dataset is also binary (0/1), so no change is needed
OUTPUT_SPEC = OutputSpec(OUTPUT_TYPE, UNIQUE_LABELS)

# ===================== 2. Dataset root directory (the key change) =====================
# Replaces the original BENCHMARKS_DIR; absolute path of the anticrispr_benchmarks folder
BENCHMARKS_DIR = '/home/nemophila/projects/protein_bert/anticrispr_benchmarks'

# Loading the dataset
# ===================== 3. Load the custom train/test sets (paths adapted) =====================
# Load the training set (anticrispr_binary.train.csv); rows with missing values and duplicate rows are dropped
train_set_file_path = os.path.join(BENCHMARKS_DIR, '%s.train.csv' % BENCHMARK_NAME)
train_set = pd.read_csv(train_set_file_path).dropna().drop_duplicates()
# Split a validation set (10%) off the training set, stratified by label, with a fixed seed
train_set, valid_set = train_test_split(train_set, stratify = train_set['label'], test_size = 0.1, random_state = 0)

# Load the test set (anticrispr_binary.test.csv), cleaned the same way as the training set
test_set_file_path = os.path.join(BENCHMARKS_DIR, '%s.test.csv' % BENCHMARK_NAME)
test_set = pd.read_csv(test_set_file_path).dropna().drop_duplicates()

# Print the dataset sizes (to check that loading succeeded)
print(f'{len(train_set)} training set records, {len(valid_set)} validation set records, {len(test_set)} test set records.')

# ===================== Nothing below needs changing (the training/evaluation logic is generic) =====================
# Loading the pre-trained model and fine-tuning it on the loaded dataset
pretrained_model_generator, input_encoder = load_pretrained_model()

# get_model_with_hidden_layers_as_outputs gives the model output access to the hidden layers (on top of the output)
model_generator = FinetuningModelGenerator(pretrained_model_generator, OUTPUT_SPEC, pretraining_model_manipulation_function = \
        get_model_with_hidden_layers_as_outputs, dropout_rate = 0.5)

# These callbacks are also reused by train_cross_attention in a later cell.
training_callbacks = [
    keras.callbacks.ReduceLROnPlateau(patience = 1, factor = 0.25, min_lr = 1e-05, verbose = 1),
    keras.callbacks.EarlyStopping(patience = 2, restore_best_weights = True),
]

finetune(model_generator, input_encoder, OUTPUT_SPEC, train_set['seq'], train_set['label'], valid_set['seq'], valid_set['label'], \
        seq_len = 512, batch_size = 32, max_epochs_per_stage = 40, lr = 1e-04, begin_with_frozen_pretrained_layers = True, \
        lr_with_frozen_pretrained_layers = 1e-02, n_final_epochs = 1, final_seq_len = 1024, final_lr = 1e-05, callbacks = training_callbacks)

# Evaluating the performance on the test-set
# NOTE: a later cell imports sklearn.metrics.confusion_matrix under the same name as the
# `confusion_matrix` variable assigned here, so whichever cell runs last decides what the name refers to.
results, confusion_matrix = evaluate_by_len(model_generator, input_encoder, OUTPUT_SPEC, test_set['seq'], test_set['label'], \
        start_seq_len = 512, start_batch_size = 32)

print('Test-set performance:')
display(results)

print('Confusion matrix:')
display(confusion_matrix)

2026-02-05 12:28:56.960820: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2026-02-05 12:28:56.960842: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


996 training set records, 111 validation set records, 286 test set records.
[2026_02_05-12:28:58] Training set: Filtered out 0 of 996 (0.0%) records of lengths exceeding 510.
[2026_02_05-12:28:58] Validation set: Filtered out 0 of 111 (0.0%) records of lengths exceeding 510.
[2026_02_05-12:28:58] Training with frozen pretrained layers...


2026-02-05 12:28:58.242530: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
2026-02-05 12:28:58.243363: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcuda.so.1
2026-02-05 12:28:58.249904: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720] Found device 0 with properties: 
pciBusID: 0000:ab:00.0 name: NVIDIA L40S computeCapability: 8.9
coreClock: 2.52GHz coreCount: 142 deviceMemorySize: 44.53GiB deviceMemoryBandwidth: 804.75GiB/s
2026-02-05 12:28:58.249979: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2026-02-05 12:28:58.250020: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcublas.so.11'; dlerror: libcublas.so.11: cannot open shared object file: No such file or

Epoch 1/40
Epoch 2/40
Epoch 3/40

Epoch 00003: ReduceLROnPlateau reducing learning rate to 0.0024999999441206455.
Epoch 4/40

Epoch 00004: ReduceLROnPlateau reducing learning rate to 0.0006249999860301614.
[2026_02_05-12:29:56] Training the entire fine-tuned model...
[2026_02_05-12:30:03] Incompatible number of optimizer weights - will not initialize them.
Epoch 1/40
Epoch 2/40

Epoch 00002: ReduceLROnPlateau reducing learning rate to 2.499999936844688e-05.
Epoch 3/40
Epoch 4/40

Epoch 00004: ReduceLROnPlateau reducing learning rate to 1e-05.
Epoch 5/40
[2026_02_05-12:31:58] Training on final epochs of sequence length 1024...
[2026_02_05-12:31:58] Training set: Filtered out 0 of 996 (0.0%) records of lengths exceeding 1022.
[2026_02_05-12:31:58] Validation set: Filtered out 0 of 111 (0.0%) records of lengths exceeding 1022.
Test-set performance:


Unnamed: 0_level_0,# records,AUC
Model seq len,Unnamed: 1_level_1,Unnamed: 2_level_1
512,286,0.884024
All,286,0.884024


Confusion matrix:


Unnamed: 0,0,1
0,256,4
1,17,9


In [19]:
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, confusion_matrix, precision_score, recall_score
from proteinbert.finetuning import filter_dataset_by_len, split_dataset_by_len
from proteinbert.feature_extraction import extract_features
from proteinbert.cross_attention_head import build_cross_attention_model

def compute_sample_weights(labels):
    """Return one class-balancing weight per sample.

    Every class receives the weight ``n_samples / (2 * class_count)``; a class
    with no members is treated as having one so the division is always defined.
    Each sample then gets the weight of its own class.

    Args:
        labels: One-dimensional sequence of class labels; values are cast to ``int``.

    Returns:
        A ``float32`` NumPy array holding one weight per entry of ``labels``.
    """
    class_ids = np.array(labels, dtype = int)
    class_counts = np.bincount(class_ids, minlength = 2)
    class_weights = class_counts.sum() / (2.0 * np.maximum(class_counts, 1))
    # Look up every sample's class weight in a single indexing step.
    return class_weights[class_ids].astype(np.float32)

def focal_loss(alpha = 0.75, gamma = 2.0):
    """Create a binary focal-loss function usable as a Keras ``loss``.

    Args:
        alpha: Weight applied to positive targets; negatives get ``1 - alpha``.
        gamma: Exponent of the ``(1 - p_t)`` factor that scales each element's
            cross-entropy, where ``p_t`` is the probability assigned to the true class.

    Returns:
        A function ``(y_true, y_pred) -> scalar`` returning the mean focal loss.
    """
    def _loss(y_true, y_pred):
        backend = keras.backend
        eps = backend.epsilon()
        target = backend.cast(y_true, "float32")
        # Keep probabilities away from 0 and 1 so the logarithms stay finite.
        prob = backend.clip(y_pred, eps, 1.0 - eps)
        cross_entropy = -(target * backend.log(prob) + (1.0 - target) * backend.log(1.0 - prob))
        prob_of_truth = target * prob + (1.0 - target) * (1.0 - prob)
        class_balance = target * alpha + (1.0 - target) * (1.0 - alpha)
        modulator = backend.pow(1.0 - prob_of_truth, gamma)
        return backend.mean(class_balance * modulator * cross_entropy)
    return _loss

def filter_with_features(df, features, seq_len, name):
    """Length-filter a dataset and keep its feature rows in step with it.

    A temporary positional column is attached to a copy of ``df`` so that the
    rows surviving ``filter_dataset_by_len`` can be used to select the matching
    rows of ``features``.

    Args:
        df: DataFrame to filter (not modified).
        features: Array indexed along its first axis by row position; assumed to
            be row-aligned with ``df``.
        seq_len: Sequence length passed to ``filter_dataset_by_len``.
        name: Dataset name passed to ``filter_dataset_by_len`` for its log output.

    Returns:
        ``(filtered_df, filtered_features)``, where ``filtered_df`` has a fresh
        ``0..n-1`` index and no helper column.
    """
    tagged = df.copy()
    tagged["__idx"] = np.arange(len(tagged))
    kept = filter_dataset_by_len(tagged, seq_len = seq_len, dataset_name = name, verbose = True)
    kept_positions = kept["__idx"].values
    kept_features = features[kept_positions]
    cleaned = kept.drop(columns = ["__idx"]).reset_index(drop = True)
    return cleaned, kept_features

def train_cross_attention(train_df, valid_df, train_features, valid_features, seq_len = 512, batch_size = 32, max_epochs = 40,
        begin_with_frozen = True, lr_frozen = 1e-02, lr = 1e-04, n_heads = 4, key_dim = 64,
        final_seq_len = 1024, n_final_epochs = 1, final_lr = 1e-05):
    """Train the cross-attention model in up to three stages and return it.

    Stages, in order:
      1. (only if ``begin_with_frozen``) fit with every layer of the base model frozen, at ``lr_frozen``;
      2. fit with every layer of the base model trainable, at ``lr``;
      3. (only if ``n_final_epochs > 0``) rebuild the model for ``final_seq_len``, copy the
         weights over and fit for ``n_final_epochs`` at ``final_lr`` with a proportionally smaller batch size.

    Every stage uses the focal loss (alpha 0.75, gamma 2.0), class-balancing sample weights for both
    the training and validation data, and the module-level ``training_callbacks``. The function also
    relies on the module-level ``pretrained_model_generator`` and ``input_encoder``.

    Note that ``n_heads`` defaults to 4 here, while the evaluation helpers in this file default to 6;
    pass the same value to both.

    Args:
        train_df, valid_df: DataFrames with ``"seq"`` and ``"label"`` columns.
        train_features, valid_features: Feature arrays, one row per record of the matching DataFrame.
        seq_len, batch_size, max_epochs: Settings for stages 1 and 2.
        begin_with_frozen, lr_frozen, lr: Stage toggles and learning rates.
        n_heads, key_dim: Passed to ``build_cross_attention_model``.
        final_seq_len, n_final_epochs, final_lr: Settings for stage 3.

    Returns:
        The trained Keras model (the ``final_seq_len`` model when stage 3 ran).
    """
    feature_dim = train_features.shape[1]

    def build(length):
        # Builds (model, base_model) for the given sequence length.
        return build_cross_attention_model(pretrained_model_generator, length, feature_dim, n_heads = n_heads, key_dim = key_dim, dropout_rate = 0.5)

    def prepare(frame, feats, length):
        # Returns (model inputs, targets, sample weights) for one split.
        encoded = input_encoder.encode_X(frame["seq"], length)
        targets = frame["label"].values.astype(float)
        return encoded + [feats], targets, compute_sample_weights(targets)

    def set_base_trainable(base, flag):
        for layer in base.layers:
            layer.trainable = flag

    def compile_and_fit(net, learning_rate, train_pack, valid_pack, n_batch, n_epochs):
        # Compile after the trainable flags are set, then fit.
        net.compile(optimizer = keras.optimizers.Adam(learning_rate = learning_rate), loss = loss_fn)
        inputs, targets, weights = train_pack
        net.fit(inputs, targets, sample_weight = weights, validation_data = valid_pack,
                batch_size = n_batch, epochs = n_epochs, callbacks = training_callbacks)

    model, base_model = build(seq_len)
    train_pack = prepare(train_df, train_features, seq_len)
    valid_pack = prepare(valid_df, valid_features, seq_len)
    loss_fn = focal_loss(alpha = 0.75, gamma = 2.0)

    if begin_with_frozen:
        set_base_trainable(base_model, False)
        compile_and_fit(model, lr_frozen, train_pack, valid_pack, batch_size, max_epochs)

    set_base_trainable(base_model, True)
    compile_and_fit(model, lr, train_pack, valid_pack, batch_size, max_epochs)

    if n_final_epochs > 0:
        # Shrink the batch size by the same factor the sequence length grows.
        final_batch_size = max(int(batch_size / (final_seq_len / seq_len)), 1)
        train_f, train_feat_f = filter_with_features(train_df, train_features, seq_len = final_seq_len, name = "Training set (final)")
        valid_f, valid_feat_f = filter_with_features(valid_df, valid_features, seq_len = final_seq_len, name = "Validation set (final)")
        final_train_pack = prepare(train_f, train_feat_f, final_seq_len)
        final_valid_pack = prepare(valid_f, valid_feat_f, final_seq_len)
        final_model, final_base = build(final_seq_len)
        final_model.set_weights(model.get_weights())
        set_base_trainable(final_base, True)
        compile_and_fit(final_model, final_lr, final_train_pack, final_valid_pack, final_batch_size, n_final_epochs)
        model = final_model

    return model

def _collect_metrics(y_true, y_pred, y_pred_class):
    """Summarise binary-classification performance as a dict.

    AUC and AUPRC are only computed when ``y_true`` contains exactly two
    distinct values; otherwise both are reported as NaN. F1 is always computed
    from the hard predictions, with ``zero_division = 0``.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted scores.
        y_pred_class: Hard 0/1 predictions derived from ``y_pred``.

    Returns:
        Dict with the keys ``"# records"``, ``"AUC"``, ``"AUPRC"`` and ``"F1"``.
    """
    has_both_classes = len(np.unique(y_true)) == 2
    return {
        "# records": len(y_true),
        "AUC": roc_auc_score(y_true, y_pred) if has_both_classes else np.nan,
        "AUPRC": average_precision_score(y_true, y_pred) if has_both_classes else np.nan,
        "F1": f1_score(y_true, y_pred_class, zero_division = 0),
    }

def _predict_by_len(model_weights, df, features, start_seq_len = 512, start_batch_size = 32, increase_factor = 2, n_heads = 6, key_dim = 64):
    """Predict scores for every record, processing the data in sequence-length buckets.

    ``split_dataset_by_len`` yields ``(subset, seq_len, batch_size)`` triples. For every
    non-empty subset a model is built for that ``seq_len``, loaded with ``model_weights``
    and used to predict. Relies on the module-level ``input_encoder`` and
    ``pretrained_model_generator``.

    Args:
        model_weights: Weight list as returned by ``model.get_weights()``; ``n_heads`` and
            ``key_dim`` must match the model it came from, otherwise ``set_weights`` cannot load it.
        df: DataFrame with ``"seq"`` and ``"label"`` columns (not modified).
        features: Feature array indexed by row position of ``df``.
        start_seq_len, start_batch_size, increase_factor: Passed to ``split_dataset_by_len``.
        n_heads, key_dim: Passed to ``build_cross_attention_model``.

    Returns:
        ``(y_true, y_pred)`` as 1-D arrays in bucket order (not necessarily the row order
        of ``df``). Both are empty when there is nothing to predict.
    """
    dataset = df.copy()
    # Remember each row's position so its feature row can be found after bucketing.
    dataset["idx"] = np.arange(len(dataset))
    all_true = []
    all_pred = []
    for len_matching_dataset, seq_len, batch_size in split_dataset_by_len(dataset, start_seq_len = start_seq_len, start_batch_size = start_batch_size,
            increase_factor = increase_factor):
        if len(len_matching_dataset) == 0:
            continue
        idx = len_matching_dataset["idx"].values
        feats = features[idx]
        X = input_encoder.encode_X(len_matching_dataset["seq"], seq_len)
        # A separate model is built for each bucket's sequence length.
        model, _ = build_cross_attention_model(pretrained_model_generator, seq_len, feats.shape[1], n_heads = n_heads, key_dim = key_dim, dropout_rate = 0.5)
        model.set_weights(model_weights)
        y_true = len_matching_dataset["label"].values.astype(int)
        y_pred = model.predict(X + [feats], batch_size = batch_size).flatten()
        all_true.append(y_true)
        all_pred.append(y_pred)
    if not all_true:
        # np.concatenate raises ValueError on an empty list; return empty results instead.
        return np.zeros(0, dtype = int), np.zeros(0, dtype = np.float32)
    return np.concatenate(all_true, axis = 0), np.concatenate(all_pred, axis = 0)

def find_best_threshold(valid_df, valid_features, model_weights, start_seq_len = 512, start_batch_size = 32, min_recall = 0.6,
        n_heads = 6, key_dim = 64):
    """Choose a decision threshold by precision, subject to a minimum recall.

    Nineteen evenly spaced thresholds from 0.05 to 0.95 are scanned in ascending
    order. Among those whose recall is at least ``min_recall``, the one with the
    highest precision wins; on ties the lowest such threshold is kept.

    If no threshold satisfies the recall constraint with a precision above zero,
    a fallback is returned: threshold 0.5 with the placeholder metrics
    ``f1 = -1.0``, ``precision = 0.0`` and ``recall = 0.0``.

    Args:
        valid_df: Validation DataFrame with ``"seq"`` and ``"label"`` columns.
        valid_features: Feature array, one row per record of ``valid_df``.
        model_weights: Weight list as returned by ``model.get_weights()``.
        start_seq_len, start_batch_size: Passed to ``_predict_by_len``.
        min_recall: Lowest acceptable recall.
        n_heads, key_dim: Attention settings of the trained model, passed to
            ``_predict_by_len`` so the rebuilt model matches ``model_weights``.

    Returns:
        Dict with the keys ``"thr"``, ``"f1"``, ``"precision"``, ``"recall"`` and
        ``"auprc"`` (the AUPRC does not depend on the threshold).
    """
    y_true, y_pred = _predict_by_len(model_weights, valid_df, valid_features, start_seq_len = start_seq_len, start_batch_size = start_batch_size,
            n_heads = n_heads, key_dim = key_dim)
    # The AUPRC does not depend on the threshold, so compute it once.
    auprc = average_precision_score(y_true, y_pred)
    best = {"thr": 0.5, "f1": -1.0, "precision": 0.0, "recall": 0.0, "auprc": auprc}
    for thr in np.linspace(0.05, 0.95, 19):
        y_pred_class = (y_pred >= thr).astype(int)
        recall = recall_score(y_true, y_pred_class, zero_division = 0)
        if recall < min_recall:
            continue
        # The recall constraint is met; prefer the highest precision (fewer false positives).
        precision = precision_score(y_true, y_pred_class, zero_division = 0)
        if precision > best["precision"]:
            f1 = f1_score(y_true, y_pred_class, zero_division = 0)
            best = {"thr": thr, "f1": f1, "precision": precision, "recall": recall, "auprc": auprc}
    return best

def evaluate_by_len_custom(model_weights, df, features, start_seq_len = 512, start_batch_size = 32, increase_factor = 2, threshold = 0.5, n_heads = 6, key_dim = 64):
    """Evaluate the cross-attention model per sequence-length bucket and overall.

    ``split_dataset_by_len`` yields ``(subset, seq_len, batch_size)`` triples. For every
    non-empty subset a model is built for that ``seq_len``, loaded with ``model_weights``
    and scored with ``_collect_metrics``. A final ``"All"`` row and the confusion matrix
    are computed over all predictions together. Relies on the module-level
    ``input_encoder`` and ``pretrained_model_generator``.

    Args:
        model_weights: Weight list as returned by ``model.get_weights()``; ``n_heads`` and
            ``key_dim`` must match the model it came from.
        df: DataFrame with ``"seq"`` and ``"label"`` columns (not modified).
        features: Feature array indexed by row position of ``df``.
        start_seq_len, start_batch_size, increase_factor: Passed to ``split_dataset_by_len``.
        threshold: Scores ``>= threshold`` are classified as 1.
        n_heads, key_dim: Passed to ``build_cross_attention_model``.

    Returns:
        ``(results_df, cm_df)``: metrics indexed by ``"Model seq len"`` (one row per bucket
        plus ``"All"``), and a 2x2 confusion-matrix DataFrame with rows and columns ``"0"``, ``"1"``.
    """
    # Bind sklearn's function locally: the global name `confusion_matrix` is also assigned
    # as a result variable elsewhere in this notebook, which would make it non-callable here.
    from sklearn.metrics import confusion_matrix as sk_confusion_matrix

    dataset = df.copy()
    # Remember each row's position so its feature row can be found after bucketing.
    dataset["idx"] = np.arange(len(dataset))
    results = []
    results_names = []
    all_true = []
    all_pred = []
    for len_matching_dataset, seq_len, batch_size in split_dataset_by_len(dataset, start_seq_len = start_seq_len, start_batch_size = start_batch_size,
            increase_factor = increase_factor):
        if len(len_matching_dataset) == 0:
            continue
        idx = len_matching_dataset["idx"].values
        feats = features[idx]
        X = input_encoder.encode_X(len_matching_dataset["seq"], seq_len)
        # A separate model is built for each bucket's sequence length.
        model, _ = build_cross_attention_model(pretrained_model_generator, seq_len, feats.shape[1], n_heads = n_heads, key_dim = key_dim, dropout_rate = 0.5)
        model.set_weights(model_weights)
        y_true = len_matching_dataset["label"].values.astype(int)
        y_pred = model.predict(X + [feats], batch_size = batch_size).flatten()
        y_pred_class = (y_pred >= threshold).astype(int)
        results.append(_collect_metrics(y_true, y_pred, y_pred_class))
        results_names.append(seq_len)
        all_true.append(y_true)
        all_pred.append(y_pred)
    # Overall metrics over the predictions of all buckets together.
    y_true = np.concatenate(all_true, axis = 0)
    y_pred = np.concatenate(all_pred, axis = 0)
    y_pred_class = (y_pred >= threshold).astype(int)
    all_results = _collect_metrics(y_true, y_pred, y_pred_class)
    cm = sk_confusion_matrix(y_true, y_pred_class, labels = [0, 1])
    results.append(all_results)
    results_names.append("All")
    results_df = pd.DataFrame(results, index = results_names)
    results_df.index.name = "Model seq len"
    cm_df = pd.DataFrame(cm, index = ["0", "1"], columns = ["0", "1"])
    return results_df, cm_df

# Extra per-sequence feature matrices for each split, computed by the project's extract_features helper.
train_features = extract_features(train_set["seq"], nc_len = 20, paac_lambda = 3)
valid_features = extract_features(valid_set["seq"], nc_len = 20, paac_lambda = 3)
test_features = extract_features(test_set["seq"], nc_len = 20, paac_lambda = 3)

# Apply the seq_len = 512 length filter to each split, keeping the feature rows aligned with the surviving records.
train_filtered, train_features_f = filter_with_features(train_set, train_features, seq_len = 512, name = "Training set")
valid_filtered, valid_features_f = filter_with_features(valid_set, valid_features, seq_len = 512, name = "Validation set")
test_filtered, test_features_f = filter_with_features(test_set, test_features, seq_len = 512, name = "Test set")

# Sanity checks: build a throwaway model (with build_cross_attention_model's own default
# n_heads / key_dim) to inspect the feature dimension, a layer name and the base model's output shapes.
print("Debug: manual feature dim:", train_features_f.shape[1])
_debug_model, _debug_base = build_cross_attention_model(pretrained_model_generator, 512, train_features_f.shape[1], dropout_rate = 0.5)
print("Debug: cross-attention layer:", any(layer.name == "cross-attention" for layer in _debug_model.layers))
print("Debug: seq kv dim:", _debug_base.output[0].shape)
print("Debug: global dim:", _debug_base.output[1].shape)

# n_heads = 6 and key_dim = 64 equal the defaults of _predict_by_len and evaluate_by_len_custom,
# which rebuild the model and load these weights; the calls below do not pass those values explicitly,
# so changing them here requires passing them to the evaluation helpers too.
cross_attn_model = train_cross_attention(train_filtered, valid_filtered, train_features_f, valid_features_f, seq_len = 512, batch_size = 32, max_epochs = 40,
        begin_with_frozen = True, lr_frozen = 1e-02, lr = 1e-04, n_heads = 6, key_dim = 64, final_seq_len = 1024, n_final_epochs = 1, final_lr = 1e-05)

# The decision threshold is tuned on the validation split only, then applied to the test split.
model_weights = cross_attn_model.get_weights()
best = find_best_threshold(valid_filtered, valid_features_f, model_weights, start_seq_len = 512, start_batch_size = 32, min_recall = 0.6)
print("Debug: best threshold:", best["thr"], "AUPRC(valid):", round(best["auprc"], 4),
      "P:", round(best["precision"], 4), "R:", round(best["recall"], 4), "F1:", round(best["f1"], 4))

# Evaluation takes the unfiltered test_set / test_features; evaluate_by_len_custom buckets the records by length itself.
results, confusion_matrix_df = evaluate_by_len_custom(model_weights, test_set[["seq", "label"]].copy(), test_features, start_seq_len = 512, start_batch_size = 32, threshold = best["thr"])

print("Test-set performance:")
display(results)

print("Confusion matrix:")
display(confusion_matrix_df)

[2026_02_05-14:13:09] Training set: Filtered out 0 of 996 (0.0%) records of lengths exceeding 510.
[2026_02_05-14:13:09] Validation set: Filtered out 0 of 111 (0.0%) records of lengths exceeding 510.
[2026_02_05-14:13:09] Test set: Filtered out 0 of 286 (0.0%) records of lengths exceeding 510.
Debug: manual feature dim: 230
Debug: cross-attention layer: True
Debug: seq kv dim: (None, 512, 1562)
Debug: global dim: (None, 15599)
Epoch 1/40
Epoch 2/40
Epoch 3/40

Epoch 00003: ReduceLROnPlateau reducing learning rate to 0.0024999999441206455.
Epoch 4/40
Epoch 5/40

Epoch 00005: ReduceLROnPlateau reducing learning rate to 0.0006249999860301614.
Epoch 6/40
Epoch 7/40

Epoch 00007: ReduceLROnPlateau reducing learning rate to 0.00015624999650754035.
Epoch 8/40

Epoch 00008: ReduceLROnPlateau reducing learning rate to 3.9062499126885086e-05.
Epoch 1/40
Epoch 2/40

Epoch 00002: ReduceLROnPlateau reducing learning rate to 2.499999936844688e-05.
Epoch 3/40

Epoch 00003: ReduceLROnPlateau reducing 

Unnamed: 0_level_0,# records,AUC,AUPRC,F1
Model seq len,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
512,286,0.905473,0.619707,0.521739
All,286,0.905473,0.619707,0.521739


Confusion matrix:


Unnamed: 0,0,1
0,235,25
1,8,18
