In [1]:
# Reload edited modules (e.g. model/, utils/) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# Run from the project root so the relative './data', './temp', './ckpt' paths below resolve.
# NOTE(review): absolute, machine-specific path — adjust when running on another machine.
%cd "/home/hew/python/contp"
%ls

/home/hew/python/contp
[0m[01;34mckpt[0m/  [01;34mdataset[0m/  README.md  [01;34mtemp[0m/           [01;34mutils[0m/
[01;34mdata[0m/  [01;34mmodel[0m/    [01;34mscript[0m/    Untitled.ipynb


In [2]:
import os
import torch
import numpy as np
from tqdm.notebook import tqdm
import pandas as pd
from model.ConTP_data_module import ConTPDataModule
from model.ConTP_module import ConTPModule
from utils.dataset import ProteinDataset
from utils.lightning import LitModelInference

root_path: /home/hew/python/contp


In [3]:
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import classification_report


def compute_label_metrics(preds, labels):
    """Compute multi-label classification metrics for predicted vs. true label sets.

    Each argument holds one entry per sample, and each entry is an iterable of
    label ids, e.g. ``preds=[[3], [0]]`` and ``labels=[[3, 1], [0]]``.

    Args:
        preds: predicted label ids per sample (list/tuple/array of iterables).
        labels: ground-truth label ids per sample, same length as ``preds``.

    Returns:
        dict mapping metric name to score: precision/recall/F1 under the
        'weighted', 'samples' and 'micro' averages (``zero_division=0``), plus
        ``samples_acc`` (exact-match / subset accuracy) and ``micro_acc``
        (accuracy over every flattened sample x class indicator).

    Raises:
        ValueError: if ``preds`` and ``labels`` have different lengths.
    """
    # Materialise plain lists so tuples or numpy arrays are accepted: with a numpy
    # array, `labels + preds` below would broadcast element-wise instead of concatenating.
    preds = [list(sample) for sample in preds]
    labels = [list(sample) for sample in labels]
    if len(preds) != len(labels):
        raise ValueError(
            f'preds and labels must have the same length, got {len(preds)} and {len(labels)}'
        )

    # Fit on the union so classes that appear only in predictions still get a column.
    mlb = MultiLabelBinarizer()
    mlb.fit(labels + preds)
    y_true = mlb.transform(labels)
    y_pred = mlb.transform(preds)

    def _prf(average):
        # (precision, recall, f1) for one sklearn averaging mode; the support slot is
        # None whenever an average is requested, so it is dropped.
        precision, recall, f1, _ = precision_recall_fscore_support(
            y_true, y_pred, average=average, zero_division=0
        )
        return precision, recall, f1

    weighted_precision, weighted_recall, weighted_f1 = _prf('weighted')
    samples_precision, samples_recall, samples_f1 = _prf('samples')
    micro_precision, micro_recall, micro_f1 = _prf('micro')

    # Subset accuracy: a sample counts only if its predicted label set matches exactly.
    samples_acc = accuracy_score(y_true, y_pred)
    # Micro accuracy: flatten the indicator matrices and score every
    # (sample, class) cell as an independent binary decision.
    micro_acc = accuracy_score(y_true.ravel(), y_pred.ravel())

    # Key order matches the original report layout.
    return {
        'weighted_precision': weighted_precision,
        'weighted_recall': weighted_recall,
        'weighted_f1': weighted_f1,
        'samples_precision': samples_precision,
        'samples_recall': samples_recall,
        'samples_f1': samples_f1,
        'samples_acc': samples_acc,
        'micro_precision': micro_precision,
        'micro_recall': micro_recall,
        'micro_f1': micro_f1,
        'micro_acc': micro_acc,
    }

In [4]:
def multi_label_from_distance(dist, threshold=0.03):
    """Turn a sample-to-class distance matrix into multi-label predictions.

    Steps:
      1) softmax over ``-dist`` along the class axis gives per-class probabilities
         (a smaller distance yields a higher probability);
      2) every class whose probability is strictly greater than ``threshold`` is kept;
      3) a sample with no class above the threshold falls back to its top-1 class.

    Args:
        dist: (N, C) distance matrix, one row per sample. A torch.Tensor (any
            device, with or without grad) or anything ``torch.as_tensor`` accepts.
        threshold: float probability cut-off for including a class.

    Returns:
        list[list[int]]: for each sample, the selected column indices of ``dist``
        ordered by decreasing probability. Each inner list is non-empty.
    """
    dist = torch.as_tensor(dist)
    # Softmax and sort on dist's own device (keeps tie ordering identical), then
    # move the results to CPU once: the per-row loop below would otherwise force
    # a device sync for every element when dist lives on a GPU.
    sorted_probs, sorted_indices = torch.sort(
        torch.softmax(-dist.detach(), dim=1), dim=1, descending=True
    )
    sorted_probs = sorted_probs.cpu()
    sorted_indices = sorted_indices.cpu()
    keep = sorted_probs > threshold

    final_preds = []
    for row_indices, row_keep in zip(sorted_indices, keep):
        selected = row_indices[row_keep]
        if selected.numel() == 0:
            # Nothing cleared the threshold: the first sorted position is the top-1 class.
            selected = row_indices[:1]
        final_preds.append(selected.tolist())
    return final_preds

In [5]:
# Choose which trained model to evaluate.
# task = 'substrate_classification'
task = 'tc_classification'

# Per-task settings. Paths are relative to the project root that the first cell `%cd`s
# into, rather than hard-coding the absolute /home/hew/... location.
TASK_CONFIGS = {
    'substrate_classification': {
        'ckpt_path': './ckpt/lightning_logs/substrate/checkpoints/last.ckpt',
        'temp_dir': './temp/inference_substrate/',
        'label_map_csv': './data/substrate_mapping.csv',
        'label_key': 'substrate_ids',  # presumably the metadata column with ground-truth labels
    },
    'tc_classification': {
        'ckpt_path': './ckpt/lightning_logs/tc/checkpoints/last.ckpt',
        'temp_dir': './temp/inference_tc/',
        'label_map_csv': './data/tc_mapping.csv',
        'label_key': 'label_id',
    },
}
if task not in TASK_CONFIGS:
    raise NotImplementedError(
        f'unsupported task {task!r}; expected one of {sorted(TASK_CONFIGS)}'
    )

task_config = TASK_CONFIGS[task]
ckpt_path = task_config['ckpt_path']
temp_dir = task_config['temp_dir']
label_map = pd.read_csv(task_config['label_map_csv'])
# Cluster ids in label_map order; the cache-loading cell may replace this with a cached copy.
select_cluster = label_map['id'].tolist()
label_key = task_config['label_key']

os.makedirs(temp_dir, exist_ok=True)
predictor = LitModelInference(ConTPModule, ConTPDataModule, ckpt_path)
datamodule = predictor.pl_data_module
# Optional: evaluate against the TPNet copies of the datasets instead.
# if task == 'substrate_classification':
#     datamodule.dataset = ProteinDataset(name='TCDB_substrate', path='/home/hew/python/TPNet/dataset/TCDB_substrate')
# elif task == 'tc_classification':
#     datamodule.dataset = ProteinDataset(name='TCDB_tc', path='/home/hew/python/TPNet/dataset/TCDB_tc')
# datamodule.dataframe = datamodule.dataset.metadata
datamodule.dataset

[loading checkpoint]: /home/hew/python/contp/ckpt/lightning_logs/tc/checkpoints/last.ckpt
Loading metadata from /home/hew/python/contp/dataset/TCDB_tc/metadata.csv


Seed set to 42


ProteinDataset[ TCDB_tc ], size: 99950, path: /home/hew/python/contp/dataset/TCDB_tc

In [None]:
# Prepare the data, then build the splits for each stage in turn
# (presumably 'fit' = train/val and 'test' = held-out, per Lightning DataModule conventions).
datamodule.prepare_data()
for stage in ('fit', 'test'):
    datamodule.setup(stage)

In [6]:
# Load artifacts cached by an earlier run for this task: the class (cluster) embeddings,
# the idx2sample table and the cluster order. This notebook only reads them.
cache_file = f'{temp_dir}/class_embeddings.pth'
if not os.path.exists(cache_file):
    # Fail here with a clear message instead of a NameError on class_embeddings later.
    raise FileNotFoundError(
        f'{cache_file} not found: build the class-embedding cache for task {task!r} first'
    )
# NOTE(review): newer torch versions warn about torch.load's weights_only default;
# pass weights_only=True if the cache holds only tensors.
class_embeddings = torch.load(cache_file)
idx2sample = pd.read_csv(f'{temp_dir}/idx2sample.csv').set_index('sample_id')
# The cached order replaces the label_map-derived select_cluster from the config cell,
# presumably so cluster positions line up with the rows of class_embeddings.
select_cluster = np.load(f'{temp_dir}/select_cluster.npy')

  class_embeddings = torch.load(cache_file)


In [None]:
# Held-out split built by the datamodule; keep its size for reporting.
test_dataset_size = len(test_dataset := datamodule.test_dataset)
test_dataset

In [None]:
from utils.wrapper.ESM import ESMWrapper

In [None]:
import ast

N_QUERY = 100  # number of held-out proteins used for this quick check


def _as_label_list(value):
    """Normalise one metadata label cell to a list of plain Python label ids.

    Strings such as "[3, 1]" are parsed with ast.literal_eval (never eval), list-likes
    are copied, and a single scalar id becomes a one-element list. numpy scalars are
    converted to Python scalars so the labels round-trip through the FASTA headers.
    """
    if isinstance(value, str):
        value = ast.literal_eval(value)
    if isinstance(value, (list, tuple, np.ndarray)):
        return [v.item() if isinstance(v, np.generic) else v for v in value]
    return [value.item() if isinstance(value, np.generic) else value]


# Slice first so only N_QUERY rows are parsed, and read labels from the current task's
# column (label_key) instead of always 'substrate_ids'.
query_meta = test_dataset.metadata.head(N_QUERY)
query_seqs = query_meta['sequence'].tolist()
query_labels = query_meta[label_key].map(_as_label_list).tolist()
len(query_seqs), len(query_labels)

In [None]:
from utils.file import write_fasta

# Save the query set as FASTA, storing each sample's label list in its header.
fasta_headers = [f'substrate_id: {label}' for label in query_labels]
write_fasta('./temp/example.fasta', query_seqs, fasta_headers)

In [None]:
# Use the first GPU when one is available; otherwise fall back to CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [None]:
# Wrap the ESM model stored under ./temp/esm/ on the selected device.
esm = ESMWrapper('./temp/esm/', device=device)
# Presumably loads the weights; this cell prints "[ESM] ESM model initializing..." when run.
esm.__init_submodule__()
# The repr shows the resolved model directory (esm2_t33_650M_UR50D in the saved output).
esm

In [None]:
batch_size = 20
num_query = len(query_seqs)
if num_query == 0:
    # torch.concat([]) below would otherwise fail with an opaque error.
    raise ValueError('query_seqs is empty: nothing to embed')
num_batches = -(-num_query // batch_size)  # ceil(num_query / batch_size)

query_esm = []
# no_grad is harmless if ESMWrapper already disables autograd; otherwise it stops the
# stored embeddings from keeping every batch's activation graph alive.
with torch.no_grad():
    for start in tqdm(range(0, num_query, batch_size), total=num_batches, desc='Computing ESM Embeddings'):
        batch_seqs = query_seqs[start:start + batch_size]
        # One row per sequence (rows concatenate to [num_query, 1280] in the saved output).
        query_esm.append(esm.forward(batch_seqs)['mean_representations'])

query_esm = torch.concat(query_esm, dim=0)
query_esm.shape

In [None]:
# Project the ESM embeddings into the model's embedding space (320-d in the saved output),
# on the device selected earlier rather than re-hard-coding 'cuda:0' here.
# no_grad: pure inference, so no autograd graph is needed.
# NOTE(review): assumes LitModelInference leaves the model in eval mode — confirm.
with torch.no_grad():
    query_X = predictor.ckpt_model.model.forward(query_esm.to(device))
query_X.shape

In [None]:
query_label = query_labels
train_C = class_embeddings

# raw_pred presumably holds the nearest class-embedding position per query (it indexes
# select_cluster below); dist is the query-to-cluster distance matrix.
raw_pred, dist = predictor.ckpt_model.model.find_nearest_cluster(
    query_X, train_C, return_dist=True
)
# Translate positions into cluster ids and wrap each as a one-label list for the metric API.
pred = [[cluster_id] for cluster_id in np.array([select_cluster[i] for i in raw_pred])]
compute_label_metrics(pred, query_label)

In [None]:
# Turn distances into per-class probabilities (closer cluster -> higher probability).
prob = (-dist).softmax(dim=1)
prob[0]

In [None]:
# Rank classes for every sample from most to least probable.
probs, preds = torch.sort(prob, descending=True)
probs[0], preds[0]

In [None]:
if task == 'substrate_classification':
    threshold = 0.034  # determined in the training set
    # multi_label_from_distance returns column positions of `dist`; map them to cluster
    # ids through select_cluster, exactly as `pred` was built for the single-label
    # branch, so predictions and labels share the same id space.
    final_preds = [
        [select_cluster[col] for col in row]
        for row in multi_label_from_distance(dist, threshold=threshold)
    ]
else:
    final_preds = pred
metrics = compute_label_metrics(final_preds, query_label)
metrics

In [None]:
# Inspect the per-sample predictions that were scored in the metrics cell above.
final_preds

In [7]:
from utils.file import read_fasta
from utils.wrapper.ESM import ESMWrapper

import ast

query_seqs, query_labels = read_fasta('./temp/example.fasta')
# Headers were written as "substrate_id: <list literal>"; parse the literal with
# ast.literal_eval instead of eval so the file's contents can never execute code.
query_labels = [ast.literal_eval(header.split('substrate_id: ')[-1]) for header in query_labels]
len(query_seqs), len(query_labels)

(100, 100)

In [8]:
# Use the first GPU when one is available; otherwise fall back to CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [9]:
# Wrap the ESM model stored under ./temp/esm/ on the selected device.
esm = ESMWrapper('./temp/esm/', device=device)
# Presumably loads the weights; this cell prints "[ESM] ESM model initializing..." when run.
esm.__init_submodule__()
# The repr shows the resolved model directory (esm2_t33_650M_UR50D in the saved output).
esm

[ESM] ESM model initializing...


ESMWrapper(path=/home/hew/python/contp/temp/esm/esm2_t33_650M_UR50D)

In [10]:
batch_size = 20
num_query = len(query_seqs)
if num_query == 0:
    # torch.concat([]) below would otherwise fail with an opaque error.
    raise ValueError('query_seqs is empty: nothing to embed')
num_batches = -(-num_query // batch_size)  # ceil(num_query / batch_size)

query_esm = []
# no_grad is harmless if ESMWrapper already disables autograd; otherwise it stops the
# stored embeddings from keeping every batch's activation graph alive.
with torch.no_grad():
    for start in tqdm(range(0, num_query, batch_size), total=num_batches, desc='Computing ESM Embeddings'):
        batch_seqs = query_seqs[start:start + batch_size]
        # One row per sequence (rows concatenate to [num_query, 1280] in the saved output).
        query_esm.append(esm.forward(batch_seqs)['mean_representations'])

query_esm = torch.concat(query_esm, dim=0)
query_esm.shape

Computing ESM Embeddings:   0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([100, 1280])

In [11]:
# Project the ESM embeddings into the model's embedding space (320-d in the saved output),
# on the device selected earlier rather than re-hard-coding 'cuda:0' here.
# no_grad: pure inference, so no autograd graph is needed.
# NOTE(review): assumes LitModelInference leaves the model in eval mode — confirm.
with torch.no_grad():
    query_X = predictor.ckpt_model.model.forward(query_esm.to(device))
query_X.shape

torch.Size([100, 320])

In [12]:
query_label = query_labels
train_C = class_embeddings

# raw_pred presumably holds the nearest class-embedding position per query (it indexes
# select_cluster downstream); dist is the query-to-cluster distance matrix.
raw_pred, dist = predictor.ckpt_model.model.find_nearest_cluster(
    query_X, train_C, return_dist=True
)

In [18]:
# Convert nearest-cluster results into human-readable labels and save them.
# Both branches map positions in train_C -> cluster ids (select_cluster) and then look
# the ids up in label_map's 'id' column rather than by DataFrame row position, so the
# result stays correct when select_cluster comes from the cached copy.
# NOTE(review): assumes label_map['id'] values are unique.
id_to_label = label_map.set_index('id')
if task == 'substrate_classification':
    threshold = 0.034  # determined in the training set
    preds = multi_label_from_distance(dist, threshold=threshold)
    final_pred = [
        [id_to_label.loc[select_cluster[pred_y], 'substrate'] for pred_y in pred_list]
        for pred_list in preds
    ]
    # One column per predicted label, most probable first; shorter rows are padded.
    pred_df = pd.DataFrame(final_pred)
    pred_df.columns = [f'substrate_top{col + 1}' for col in pred_df.columns.tolist()]
else:
    preds = np.array([select_cluster[int(i)] for i in raw_pred])
    final_pred = [id_to_label.loc[pred, 'tcid'] for pred in preds]
    # Build the named column directly (also valid when there are no predictions).
    pred_df = pd.DataFrame({'tcid': final_pred})
pred_df.to_csv('./temp/prediction.csv', index=False)
pred_df

Unnamed: 0,tcid
0,1.A.1.11
1,1.A.1.11
2,1.A.1.13
3,1.A.1.13
4,1.A.1.13
...,...
95,1.A.26.2
96,1.A.30.1
97,1.A.30.1
98,1.A.30.1
