In [None]:
import pickle
from datasets import Dataset
import numpy as np
from sklearn.model_selection import KFold
from transformers import AutoTokenizer, EsmForTokenClassification, TrainingArguments
import torch
from sklearn.metrics import classification_report
from transformers import EsmForTokenClassification, TrainingArguments, Trainer
from transformers import pipeline

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
datas = pickle.load(open('WSAA_data_public.pkl', 'rb'))

In [None]:
# 假设datas是包含序列和标签的列表（格式见问题中的pickle数据）
sequences = [data['sequence'] for data in datas]
labels = [data['label'].squeeze().tolist() for data in datas]  # 确保标签是List[int]

# 转换为Hugging Face Dataset
dataset = Dataset.from_dict({
    "sequence": sequences,
    "labels": labels
})

In [None]:
# 加载快速分词器
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t30_150M_UR50D", use_fast=True)

def tokenize_and_align_labels(examples):
    # 对序列进行分词（非快速分词器）
    tokenized_inputs = tokenizer(
        examples["sequence"],
        truncation=True,
        padding="max_length",
        max_length=2048,
        return_tensors="pt",
        add_special_tokens=True  # 包含[CLS]和[SEP]
    )
    
    labels = []
    for i, (sequence, label) in enumerate(zip(examples["sequence"], examples["labels"])):
        # 获取分词后的tokens（包括特殊token）
        tokens = tokenizer.convert_ids_to_tokens(tokenized_inputs["input_ids"][i])
        
        # 初始化对齐后的标签列表
        aligned_labels = []
        seq_pos = 0  # 原始序列中的位置
        
        for token in tokens:
            if token in [tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token]:
                # 特殊token对应标签设为-100（被忽略）
                aligned_labels.append(-100)
            elif token.startswith("<") or token.endswith(">"):  # 其他特殊token
                aligned_labels.append(-100)
            else:
                # 确保token与原始序列中的氨基酸匹配
                if seq_pos < len(sequence) and token == sequence[seq_pos]:
                    aligned_labels.append(label[seq_pos])
                    seq_pos += 1
                else:
                    # 处理分词意外情况（如未知token）
                    aligned_labels.append(-100)
        
        labels.append(aligned_labels)
    
    tokenized_inputs["labels"] = labels
    return tokenized_inputs

In [None]:
def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # 移除忽略的标签（-100）
    true_labels = [[l for l in label if l != -100] for label in labels]
    true_predictions = [
        [p for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    # 展平列表
    true_labels_flat = np.concatenate(true_labels)
    true_predictions_flat = np.concatenate(true_predictions)

    # 计算分类报告
    report = classification_report(
        true_labels_flat,
        true_predictions_flat,
        target_names=["Not Binding Site", "Binding Site"],
        output_dict=True,
    )
    return {"precision": report["weighted avg"]["precision"], 
            "recall": report["weighted avg"]["recall"], 
            "f1": report["weighted avg"]["f1-score"]}

In [None]:
kf = KFold(n_splits=5, shuffle=True, random_state=114514)

for fold, (train_index, val_index) in enumerate(kf.split(dataset)):
    # 清除之前的显存
    torch.cuda.empty_cache()

    train_dataset = dataset.select(train_index)
    val_dataset = dataset.select(val_index)
    
    # 应用处理
    tokenized_train_dataset = train_dataset.map(
        tokenize_and_align_labels,
        batched=True,
        remove_columns=train_dataset.column_names,
        num_proc=4
    )

    tokenized_val_dataset = val_dataset.map(
        tokenize_and_align_labels,
        batched=True,
        remove_columns=val_dataset.column_names,
        num_proc=4
    )

    model = EsmForTokenClassification.from_pretrained(
        "facebook/esm2_t30_150M_UR50D",
        num_labels=2,  # 二分类
        id2label={0: 0, 1: 1},
        label2id={0: 0, 1: 1},
        ignore_mismatched_sizes=True,  # 忽略预训练头与当前头的尺寸不匹配
    )

    # model = get_peft_model(model, lora_config)
    # model.print_trainable_parameters()

    training_args = TrainingArguments(
        output_dir=f'finetuned_model/esm2-150M-L3000/cross_valid/logging/fold_{fold+1}',
        per_device_train_batch_size=4,
        gradient_accumulation_steps=3,
        # eval_accumulation_steps=3,
        per_device_eval_batch_size=8,
        # 手动设置评估和保存频率（替代 evaluation_strategy 和 save_strategy）
        eval_steps=100,  # 每100步评估一次
        save_steps=100,  # 每100步保存一次
        logging_dir="./logs",
        logging_steps=50,
        num_train_epochs=6,
        learning_rate=2e-5,
        weight_decay=0.01,
        fp16=True,
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        # save_strategy="steps",
        eval_strategy='steps',
        dataloader_num_workers=8,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train_dataset,
        eval_dataset=tokenized_val_dataset,
        compute_metrics=compute_metrics,
    )

    trainer.train()

    # 保存模型
    model.save_pretrained(f'finetuned_model/esm2-150M-L3000/cross_valid/fold_{fold+1}', max_shard_size="196MB", safe_serialization=True)
    tokenizer.save_pretrained(f'finetuned_model/esm2-150M-L3000/cross_valid/fold_{fold+1}')


Map: 100%|██████████| 1743/1743 [00:08<00:00, 214.46 examples/s]
Map: 100%|██████████| 436/436 [00:02<00:00, 214.65 examples/s]
Some weights of EsmForTokenClassification were not initialized from the model checkpoint at raw_model\esm2-8M and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss,Precision,Recall,F1
100,0.4346,0.409609,0.795558,0.813218,0.76745
200,0.4054,0.394113,0.816039,0.832138,0.812988
300,0.3041,0.405673,0.817183,0.830242,0.800623
400,0.3038,0.411733,0.820405,0.827957,0.790986
500,0.4183,0.395063,0.819757,0.831558,0.801721
600,0.5019,0.384381,0.825806,0.839565,0.824484
700,0.3764,0.390696,0.821403,0.836112,0.817218
800,0.2888,0.402979,0.820523,0.833627,0.807501
900,0.3458,0.402835,0.822604,0.83583,0.812155
1000,0.3213,0.396853,0.820198,0.834623,0.812172


Map: 100%|██████████| 1743/1743 [00:08<00:00, 206.74 examples/s]
Map: 100%|██████████| 436/436 [00:02<00:00, 214.32 examples/s]
Some weights of EsmForTokenClassification were not initialized from the model checkpoint at raw_model\esm2-8M and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss,Precision,Recall,F1
100,0.4293,0.404795,0.781226,0.80158,0.72207
200,0.3386,0.391171,0.806455,0.826875,0.804727
300,0.3861,0.392269,0.811245,0.810677,0.810958
400,0.3238,0.395321,0.805392,0.823665,0.808828
500,0.324,0.396433,0.812825,0.83139,0.808068
600,0.3067,0.408987,0.80734,0.827547,0.801434
700,0.2897,0.392636,0.812232,0.83055,0.812474
800,0.3269,0.390841,0.811407,0.829912,0.811805
900,0.3783,0.389199,0.808771,0.826064,0.812105
1000,0.3344,0.400778,0.812444,0.831088,0.810572


Map: 100%|██████████| 1743/1743 [00:08<00:00, 205.94 examples/s]
Map: 100%|██████████| 436/436 [00:02<00:00, 212.32 examples/s]
Some weights of EsmForTokenClassification were not initialized from the model checkpoint at raw_model\esm2-8M and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss,Precision,Recall,F1
100,0.4279,0.38874,0.816434,0.830118,0.770028
200,0.3454,0.375255,0.831739,0.846875,0.813109
300,0.3766,0.371507,0.834234,0.85166,0.83157
400,0.2973,0.363591,0.838908,0.854189,0.830291
500,0.3181,0.368663,0.837292,0.853456,0.830743
600,0.324,0.361445,0.839793,0.853529,0.842788
700,0.2978,0.359627,0.83764,0.853898,0.837354
800,0.3479,0.361213,0.836761,0.84986,0.840444
900,0.277,0.365044,0.837036,0.853602,0.835632
1000,0.3928,0.365563,0.838466,0.852029,0.841779


Map: 100%|██████████| 1743/1743 [00:08<00:00, 203.36 examples/s]
Map: 100%|██████████| 436/436 [00:02<00:00, 211.83 examples/s]
Some weights of EsmForTokenClassification were not initialized from the model checkpoint at raw_model\esm2-8M and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss,Precision,Recall,F1
100,0.5482,0.394245,0.842517,0.840923,0.788476
200,0.3981,0.389629,0.832181,0.823005,0.827104
300,0.3327,0.357205,0.842468,0.856898,0.832465
400,0.4131,0.351372,0.839412,0.85555,0.839842
500,0.3337,0.349768,0.844518,0.858979,0.837404
600,0.3637,0.348985,0.84493,0.859533,0.839244
700,0.3068,0.349128,0.844988,0.859518,0.838965
800,0.2859,0.350467,0.840044,0.856181,0.83967
900,0.3449,0.350187,0.839353,0.855412,0.840098
1000,0.3771,0.355796,0.836257,0.847445,0.840207


Map: 100%|██████████| 1744/1744 [00:08<00:00, 204.67 examples/s]
Map: 100%|██████████| 435/435 [00:02<00:00, 210.05 examples/s]
Some weights of EsmForTokenClassification were not initialized from the model checkpoint at raw_model\esm2-8M and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss,Precision,Recall,F1
100,0.3851,0.407397,0.772371,0.808622,0.759065
200,0.3618,0.399124,0.805138,0.826971,0.801262
300,0.328,0.395043,0.816108,0.832895,0.802286
400,0.4172,0.386223,0.822529,0.838157,0.824176
500,0.3506,0.391993,0.818427,0.828385,0.822177
600,0.3085,0.391644,0.817845,0.835592,0.812372
700,0.3211,0.387164,0.825204,0.84077,0.819045
800,0.3104,0.395529,0.819487,0.83488,0.822364
900,0.3188,0.397021,0.821931,0.838517,0.816647
1000,0.3641,0.391047,0.822904,0.838786,0.823838


In [7]:
# classifier = pipeline(
#     "token-classification",
#     model='finetuned_model/esm2-8M/lora_finetune',
#     tokenizer='finetuned_model/esm2-8M/lora_finetune',
#     device="cuda" if torch.cuda.is_available() else "cpu",
# )

In [8]:
# total_gt, total_pred = [], []

# for idx in range(len(dataset['test'])):
#     seq = dataset['test'][idx]['sequence']
#     label = dataset['test'][idx]['labels']

#     with torch.no_grad():
#         outputs = classifier(seq)
    
#     seq_res = []
#     for out in outputs:
#         if out['entity'] == 'LABEL_0':
#             seq_res.append(0)
#         else:
#             seq_res.append(1)

#     total_gt.extend(label)
#     total_pred.extend(seq_res)

In [9]:
# print(classification_report(total_gt, total_pred))