In [1]:
# Reload edited project modules automatically before each cell is executed.
%load_ext autoreload
%autoreload 2
# Work from the project root: later cells use paths relative to it (./data, ./temp).
%cd "/home/hew/python/contp"
%ls

/home/hew/python/contp
[0m[01;34mckpt[0m/     [01;34mdataset_bak[0m/         [01;34mmodel[0m/     [01;34mtemp[0m/           wget-log
[01;34mdata[0m/     [01;31mdataset.zip[0m          README.md  Untitled.ipynb
[01;34mdataset[0m/  download_dataset.sh  [01;34mscript[0m/    [01;34mutils[0m/


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [2]:
import ast
import os

import numpy as np
import pandas as pd
import torch
from tqdm.notebook import tqdm

from model.ConTP_data_module import ConTPDataModule
from model.ConTP_module import ConTPModule
from utils.dataset import ProteinDataset
from utils.lightning import LitModelInference

root_path: /home/hew/python/contp


In [3]:
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import classification_report


def compute_label_metrics(preds, labels):
    """Summarise multi-label predictions with precision/recall/F1 and accuracy.

    Args:
        preds: per-sample lists of predicted label ids, e.g. ``[[3], [0]]``.
        labels: per-sample lists of true label ids, e.g. ``[[3, 1], [0]]``.

    Returns:
        dict with ``{weighted,samples,micro}_{precision,recall,f1}`` plus
        ``samples_acc`` (exact-match accuracy) and ``micro_acc`` (accuracy
        over the flattened indicator matrix).
    """
    # Fit on the union of true and predicted labels so both share one column layout.
    binarizer = MultiLabelBinarizer()
    binarizer.fit(labels + preds)
    y_true = binarizer.transform(labels)
    y_pred = binarizer.transform(preds)

    # Accuracy variants reported next to the matching averaging mode.
    accuracies = {
        # a sample counts as correct only if its whole label set matches
        'samples': accuracy_score(y_true, y_pred),
        # every (sample, label) cell is treated as one binary decision
        'micro': accuracy_score(y_true.ravel(), y_pred.ravel()),
    }

    epoch_metrics = {}
    for average in ('weighted', 'samples', 'micro'):
        precision, recall, f1, _ = precision_recall_fscore_support(
            y_true, y_pred, average=average, zero_division=0
        )
        epoch_metrics[f'{average}_precision'] = precision
        epoch_metrics[f'{average}_recall'] = recall
        epoch_metrics[f'{average}_f1'] = f1
        if average in accuracies:
            epoch_metrics[f'{average}_acc'] = accuracies[average]
    return epoch_metrics

In [4]:
def multi_label_from_distance(dist, threshold=0.035):
    """
    Turn a sample-to-class distance matrix into multi-label predictions.

    1) A softmax over ``-dist`` gives per-class probabilities for each sample.
    2) Every class whose probability is strictly greater than ``threshold``
       is predicted.
    3) A sample with no class above the threshold falls back to its single
       most probable class (Top-1).

    Args:
        dist: (N, C) torch.Tensor of distances; smaller means closer.
        threshold: float, probability cut-off for keeping a class.

    Returns:
        final_preds: list[list[int]] with one list per sample holding column
        indices of ``dist``, ordered by decreasing probability.

    Raises:
        ValueError: if ``dist`` is not 2-D or has no class columns.
    """
    if dist.dim() != 2 or dist.shape[1] == 0:
        raise ValueError(
            f'dist must have shape (N, C) with C >= 1, got {tuple(dist.shape)}'
        )

    # 1) distances -> probabilities, most probable class first in every row
    probs = torch.softmax(-dist, dim=1)
    sorted_probs, sorted_classes = torch.sort(probs, dim=1, descending=True)

    # 2) threshold selection on the sorted probabilities
    keep = sorted_probs > threshold

    # 3) rows with nothing selected keep sorted position 0, i.e. their Top-1 class
    keep[:, 0] |= ~keep.any(dim=1)

    # 4) boolean indexing yields the kept class ids row by row; move them to
    #    Python in one transfer and split by the per-row counts
    counts = keep.sum(dim=1).tolist()
    flat_classes = sorted_classes[keep].tolist()

    final_preds = []
    start = 0
    for count in counts:
        final_preds.append(flat_classes[start:start + count])
        start += count
    return final_preds

In [5]:
task = 'substrate_classification'
# task = 'tc_classification'

# Reject unsupported tasks up front.
if task not in ('substrate_classification', 'tc_classification'):
    raise NotImplementedError

# Per-task checkpoint, scratch directory, label mapping table, label column and dataset.
if task == 'substrate_classification':
    ckpt_path = '/home/hew/python/contp/ckpt/lightning_logs/substrate/checkpoints/last.ckpt'
    temp_dir = './temp/inference_substrate/'
    label_map = pd.read_csv('./data/substrate_mapping.csv')
    label_key = 'substrate_ids'
    dataset_name = 'TCDB_substrate'
    dataset_path = '/home/hew/python/contp/dataset/TCDB_substrate'
else:  # task == 'tc_classification'
    ckpt_path = '/home/hew/python/contp/ckpt/lightning_logs/tc/checkpoints/last.ckpt'
    temp_dir = './temp/inference_tc/'
    label_map = pd.read_csv('./data/tc_mapping.csv')
    label_key = 'label_id'
    dataset_name = 'TCDB_tc'
    dataset_path = '/home/hew/python/contp/dataset/TCDB_tc'

# Label ids listed in the mapping table (same for both tasks).
select_cluster = label_map['id'].tolist()

if not os.path.exists(dataset_path):
    raise FileNotFoundError('Please download the dataset first!')
os.makedirs(temp_dir, exist_ok=True)

# Restore the model and its data module from the checkpoint, then attach the dataset.
predictor = LitModelInference(ConTPModule, ConTPDataModule, ckpt_path)
datamodule = predictor.pl_data_module
datamodule.dataset = ProteinDataset(dataset_name, dataset_path)
datamodule.dataframe = datamodule.dataset.metadata
datamodule.dataset

[loading checkpoint]: /home/hew/python/contp/ckpt/lightning_logs/substrate/checkpoints/last.ckpt


Seed set to 42


ProteinDataset[ TCDB_substrate ], size: 47420, path: /home/hew/python/contp/dataset/TCDB_substrate

In [6]:
# Prepare the data module and build the 'fit' stage; the cell output reports
# the resulting train and validation dataset sizes.
datamodule.prepare_data()
datamodule.setup('fit')

use the original split of the dataset
[prepare_data] max_len: 2000, subset_ratio: 1, number: 41475
[self.train_dataset] 28994
[self.val_dataset] 12481


In [7]:
# Per-class reference embeddings are expensive to build, so they are cached on
# disk. Delete the cache file to force a rebuild.
cache_file = f'{temp_dir}/class_embeddings.pth'
if os.path.exists(cache_file):
    # weights_only=True limits unpickling to tensors and plain containers, so a
    # tampered cache file cannot run arbitrary code; it also avoids the
    # FutureWarning that torch.load emits when the argument is omitted.
    class_embeddings = torch.load(cache_file, weights_only=True)
    idx2sample = pd.read_csv(f'{temp_dir}/idx2sample.csv').set_index('sample_id')
    select_cluster = datamodule.contrastive_dataset.unique_labels
else:
    # Collect the ESM embeddings (and sequences) of all samples, grouped by class.
    contrastive = datamodule.contrastive_dataset
    unique_labels = contrastive.unique_labels
    class_embeddings_dict = {}
    class_sequences_dict = {}
    all_samples = []
    for label in tqdm(unique_labels):
        class_samples = contrastive.label2idx[label]
        # Fetch each sample once and read both fields from the same record.
        records = [contrastive.dataset[idx] for idx in class_samples]
        class_esm = np.array([record['esm_embedding'] for record in records])
        class_sequences = np.array([record['sequence'] for record in records])
        class_embeddings_dict[label] = class_esm
        class_sequences_dict[label] = class_sequences
        print(label, len(class_samples), class_esm.shape)
        all_samples.extend(class_samples)

    # The dict is stored as a pickled object array; compute_cluster_center reads it back.
    temp_file = f'{temp_dir}/train_esm_embeddings.npy'
    np.save(temp_file, class_embeddings_dict)

    # Record which dataset sample sits at each position of the concatenated embeddings.
    idx2sample = pd.DataFrame(
        {'sample_id': all_samples, 'idx': np.arange(len(all_samples))}
    ).set_index('sample_id')
    idx2sample.to_csv(f'{temp_dir}/idx2sample.csv')

    # Project every sample with the checkpointed model.
    concat_embed, cluster_labels = predictor.ckpt_model.model.compute_cluster_center(
        temp_file, return_sample_embed=True
    )

    # One latent embedding per class: the mean of its samples' embeddings.
    select_cluster = unique_labels
    class_embeddings = torch.stack(
        [concat_embed[np.where(cluster_labels == label)[0]].mean(0) for label in select_cluster],
        dim=0,
    )
    torch.save(class_embeddings, cache_file)

class_embeddings.shape, class_embeddings.device

  class_embeddings = torch.load(cache_file)


(torch.Size([70, 320]), device(type='cuda', index=0))

In [8]:
# Build the 'test' stage of the data module and keep a handle on the test
# dataset and its size for the cells below.
datamodule.setup('test')
test_dataset = datamodule.test_dataset
test_dataset_size = len(test_dataset)
test_dataset

[self.test_dataset] 12481


ProteinDataset[ test_dataset ], size: 12481, path: /home/hew/python/contp/dataset/TCDB_substrate

In [9]:
# Gather the ESM embedding and the raw label field of every test sample in a
# single pass over the dataset.
query_esm_embedding = []
query_label = []
raw_labels = []
for i in tqdm(range(test_dataset_size)):
    sample = test_dataset[i]
    query_esm_embedding.append(torch.from_numpy(sample['esm_embedding']).float())
    raw_labels.append(sample[label_key])

test_X = torch.stack(query_esm_embedding, dim=0)
if task == 'substrate_classification':
    # Labels are stored as list literals such as "[24, 31]". literal_eval parses
    # them without executing arbitrary code the way eval() would.
    test_y = [ast.literal_eval(raw) for raw in raw_labels]
elif task == 'tc_classification':
    # Single-label task: wrap each label so both tasks yield a list per sample.
    test_y = [[raw] for raw in raw_labels]
else:
    raise NotImplementedError
test_X.shape, len(test_y), test_y[:5]

  0%|          | 0/12481 [00:00<?, ?it/s]

(torch.Size([12481, 1280]), 12481, [[26], [24, 31, 25, 26], [25], [25], [25]])

In [10]:
device = 'cuda:0'
# Inference only: skip autograd bookkeeping so no graph is kept for the whole test set.
# NOTE: this cell overwrites test_X with the model output, so it is not
# idempotent. Re-run the cell that builds test_X before re-running this one.
with torch.no_grad():
    test_X = predictor.ckpt_model.model.forward(test_X.to(device))
test_X.shape

torch.Size([12481, 320])

In [11]:
# Queries are the projected test embeddings with their true label lists;
# the candidates are the per-class embeddings.
query_X, query_label = test_X, test_y
train_C = class_embeddings

raw_pred, dist = predictor.ckpt_model.model.find_nearest_cluster(query_X, train_C, return_dist=True)

# raw_pred presumably holds one row index into train_C per query (TODO confirm);
# select_cluster turns it into a label id, wrapped in a list for the metrics helper.
pred = [[label_id] for label_id in np.array([select_cluster[i] for i in raw_pred])]
compute_label_metrics(pred, query_label)

{'weighted_precision': 0.9459579987307455,
 'weighted_recall': 0.6680472029955747,
 'weighted_f1': 0.7585857244117813,
 'samples_precision': 0.9434340197099591,
 'samples_recall': 0.8077317089074961,
 'samples_f1': 0.8451752059956531,
 'samples_acc': 0.7134844964345806,
 'micro_precision': 0.9434340197099591,
 'micro_recall': 0.6680472029955747,
 'micro_f1': 0.7822101172484804,
 'micro_acc': 0.9924948779287374}

In [12]:
if task == 'substrate_classification':
    threshold = 0.034  # determined in the training set
    # multi_label_from_distance returns column positions of dist. Map them to
    # label ids through select_cluster, as the top-1 predictions are mapped, so
    # they are comparable with query_label.
    final_preds = [
        [select_cluster[class_pos] for class_pos in sample_positions]
        for sample_positions in multi_label_from_distance(dist, threshold=threshold)
    ]
else:
    # Single-label task: keep the top-1 predictions.
    final_preds = pred
metrics = compute_label_metrics(final_preds, query_label)
metrics

{'weighted_precision': 0.9449007805223315,
 'weighted_recall': 0.6948258254850789,
 'weighted_f1': 0.7773231873093615,
 'samples_precision': 0.9416045722832039,
 'samples_recall': 0.8211788151756905,
 'samples_f1': 0.8538390416925791,
 'samples_acc': 0.7315118980850893,
 'micro_precision': 0.9419320104599292,
 'micro_recall': 0.6948258254850789,
 'micro_f1': 0.7997257411518872,
 'micro_acc': 0.9929790424302082}

In [13]:
# Column of the mapping table that holds the human-readable name for this task.
if task == 'substrate_classification':
    key = 'substrate'
elif task == 'tc_classification':
    key = 'tcid'
else:
    raise NotImplementedError

# Look names up by the mapping's 'id' column, not by row position, so the
# result stays correct even if the CSV rows are not stored in id order.
id_to_name = label_map.set_index('id')[key]
final_pred = [[id_to_name.loc[pred_y] for pred_y in pred_list] for pred_list in final_preds]

# One row per sample; shorter prediction lists are padded with missing values.
pred_df = pd.DataFrame(final_pred)
pred_df

Unnamed: 0,0,1,2
0,cation:calcium,,
1,cation:lithium,,
2,cation:potassium,,
3,cation:potassium,,
4,cation:potassium,,
...,...,...,...
12476,cation:proton,,
12477,nucleic acid:nucleotide,,
12478,cation:iron,,
12479,anion:aspartate,,
