# test p2go partial component mfo best model

In [1]:
import pandas as pd
import os
import torch

from torch.utils.data import DataLoader
from deepfold.utils.make_graph import build_graph
from deepfold.utils.model import load_model_checkpoint

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch.nn as nn
import math
from torch.nn import BCEWithLogitsLoss
from deepfold.models.esm_model import MLPLayer,MLPLayer3D
from deepfold.models.gnn_model import GCN



# attention  module
def sequence_mask(X, valid_len, value=0):
    """Overwrite, in place, every entry of ``X`` past its row's valid length.

    Row ``i`` keeps columns ``0 .. valid_len[i] - 1``; all later columns are
    set to ``value``. The mutated ``X`` itself is returned.
    Adapted from D2L (:numref:`sec_seq2seq_decoder`).
    """
    positions = torch.arange(X.size(1), dtype=torch.float32, device=X.device)
    keep = positions.unsqueeze(0) < valid_len.unsqueeze(1)
    X[~keep] = value
    return X

def masked_softmax(X, valid_lens):
    """Softmax over the last axis, ignoring positions beyond each valid length.

    Args:
        X: 3-D score tensor.
        valid_lens: ``None`` (plain softmax), a 1-D tensor with one length per
            leading-axis entry (shared by all rows of that entry), or a 2-D
            tensor with one length per row.
    """
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    shape = X.shape
    flat_lens = (torch.repeat_interleave(valid_lens, shape[1])
                 if valid_lens.dim() == 1 else valid_lens.reshape(-1))
    # Masked logits become a huge negative number, so their softmax weight is ~0.
    masked = sequence_mask(X.reshape(-1, shape[-1]), flat_lens, value=-1e6)
    return nn.functional.softmax(masked.reshape(shape), dim=-1)

class DotProductAttention(nn.Module):
    """Scaled dot-product attention with an extra scale factor ``lambd``.

    Scores are ``Q @ K^T / sqrt(lambd * d)``, where ``d`` is the query feature
    size. ``lambd=None`` is treated as 1, i.e. standard 1/sqrt(d) scaling.
    """
    def __init__(self, dropout, lambd=1, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        # MultiHeadAttention passes lambd=None by default; multiplying None by d
        # in forward would raise TypeError, so fall back to the standard scale.
        self.lambd = 1 if lambd is None else lambd
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        """Attend ``queries`` over ``keys``/``values``.

        Shapes: queries (batch, n_queries, d); keys (batch, n_pairs, d);
        values (batch, n_pairs, value_dim); valid_lens (batch,) or
        (batch, n_queries).

        Returns:
            (attended values, attention weights); the weights are also cached
            on ``self.attention_weights``.
        """
        d = queries.shape[-1]
        # keys.transpose(1, 2) swaps the last two axes of keys for the batched matmul.
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(self.lambd * d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values), self.attention_weights

def transpose_qkv(X, num_heads):
    """Split the feature axis into heads and fold the heads into the batch axis.

    (batch, n, num_hiddens) -> (batch * num_heads, n, num_hiddens / num_heads)
    """
    batch, n = X.shape[0], X.shape[1]
    # (batch, n, heads, per_head) -> (batch, heads, n, per_head)
    per_head = X.reshape(batch, n, num_heads, -1).permute(0, 2, 1, 3)
    return per_head.reshape(-1, per_head.shape[2], per_head.shape[3])

def transpose_output(X, num_heads):
    """Undo ``transpose_qkv``: unfold heads from the batch axis into the last axis.

    (batch * num_heads, n, k) -> (batch, n, num_heads * k)
    """
    grouped = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    regrouped = grouped.permute(0, 2, 1, 3)
    batch, n = regrouped.shape[0], regrouped.shape[1]
    return regrouped.reshape(batch, n, -1)

class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``DotProductAttention``.

    Queries, keys and values are projected to ``num_hiddens`` channels, split
    into ``num_heads`` heads, attended per head, and the heads concatenated.

    NOTE(review): ``W_o`` is created (so it is part of the state_dict) but is
    never applied in ``forward``; confirm whether that is intended.
    """
    def __init__(self,
                 key_size,
                 query_size,
                 value_size,
                 num_hiddens,
                 num_heads,
                 dropout,
                 bias=False,
                 lambd=None,
                 **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        # DotProductAttention multiplies by lambd, so the default None must be
        # mapped to 1 (standard scaling) instead of being forwarded as-is.
        self.attention = DotProductAttention(dropout, lambd=1 if lambd is None else lambd)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens=None, output_attentions=True):
        """Run multi-head attention.

        Args:
            queries, keys, values: (batch, n_queries or n_pairs, *_size).
            valid_lens: None, (batch,) or (batch, n_queries).
            output_attentions: also return the attention weights when True.

        Returns:
            ``(output_concat, weight_concat)`` or ``(output_concat,)``.
            ``output_concat`` is (batch, n_queries, num_hiddens);
            ``weight_concat`` places the per-head weights side by side on the
            last axis: (batch, n_queries, num_heads * n_pairs).
        """
        # After projection and head folding each tensor is
        # (batch * num_heads, n, num_hiddens / num_heads).
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # Repeat each entry num_heads times along axis 0 so every head of
            # a given example sees that example's valid length.
            valid_lens = torch.repeat_interleave(valid_lens,
                                                 repeats=self.num_heads,
                                                 dim=0)

        # output: (batch * num_heads, n_queries, num_hiddens / num_heads)
        output, weight = self.attention(queries, keys, values, valid_lens)

        # output_concat: (batch, n_queries, num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        weight_concat = transpose_output(weight, self.num_heads)
        outputs = (output_concat, weight_concat) if output_attentions else (output_concat,)
        return outputs

class LabelBasedAttention(nn.Module):
    """Label-wise attention over tokens, routed through a learned component bank.

    Each label scores the components (softmax over components); each component
    scores every token; their product gives label->token scores that are
    temperature-scaled and softmax-normalised over the first ``lengths + 1``
    token positions. The result is one pooled token embedding per label.

    Args:
        label_dim: feature size of the incoming label embeddings.
        word_dim: feature size of the incoming per-token embeddings.
        latent_dim: shared projection size for labels, tokens and components.
        nb_labels: number of labels; only used to size the component bank.
        nb_words: only used to size the component bank.
        components_factor: must be > 0; larger values give fewer components.
        temperature: label->token scores are divided by this before the softmax.
    """
    def __init__(self, label_dim, word_dim, latent_dim, nb_labels, nb_words,components_factor=2,temperature=2):
        super().__init__()
        self.label_dim = label_dim
        self.word_dim = word_dim
        self.latent_dim = latent_dim
        self.fc_label = MLPLayer(self.label_dim, self.latent_dim)
        self.fc_word = MLPLayer3D(self.word_dim, self.latent_dim)
        self.nb_labels = nb_labels
        self.nb_words = nb_words
        assert components_factor > 0
        # nb_components = int(nb_labels / (components_factor * (1 + int(nb_labels / nb_words)))).
        # NOTE(review): this truncates to 0 when nb_labels is small relative to the divisor.
        self.nb_components = int(self.nb_labels/(components_factor*(1+int(self.nb_labels/self.nb_words))))
        # Learnable component bank of shape (nb_components, latent_dim).
        self.component_embedding = nn.Parameter(torch.randn((self.nb_components,self.latent_dim)),requires_grad=True)
        self.temperature = temperature

    def forward(self, label_embedding, word_embedding, lengths):
        """Pool token embeddings per label.

        Args:
            label_embedding: label embeddings; P2GO passes its 2-D
                ``terms_embedding`` parameter.
            word_embedding: per-token features; P2GO passes ESM layer-33
                representations (commented there as [B, L, C]).
            lengths: 1-D per-example residue counts; the softmax keeps the
                first ``lengths + 1`` token positions.

        Returns:
            ``(label_level_embedding, l_w)``: per-label pooled token embeddings
            and the label->token attention weights. Presumably shaped
            (B, nb_labels, latent_dim) and (B, nb_labels, L), assuming
            fc_label / fc_word project the last axis to latent_dim (TODO confirm).
        """
        word_embedding = self.fc_word(word_embedding)
        label_embedding = self.fc_label(label_embedding)
        word_embedding = word_embedding.transpose(-1,-2)
        # Share the same component bank across the batch.
        component_embedding = self.component_embedding.repeat((word_embedding.shape[0], 1, 1))
        # component -> token scores
        c_w = torch.bmm(component_embedding,word_embedding)
        label_embedding = label_embedding.repeat((word_embedding.shape[0], 1, 1))
        # label -> component scores, normalised over components
        l_c = torch.bmm(label_embedding,component_embedding.transpose(-1,-2))
        l_c = l_c.softmax(dim=-1)
        # label -> token scores obtained through the components
        l_w = torch.bmm(l_c,c_w)
        l_w /= self.temperature
        # Softmax over tokens, masking positions at or beyond lengths + 1.
        l_w = masked_softmax(l_w, lengths+1)
        # Attention-weighted sum of the projected tokens (transposed back to token-major).
        label_level_embedding = torch.bmm(l_w,word_embedding.transpose(-1,-2))

        return label_level_embedding, l_w

class P2GO(nn.Module):
    """GO-term predictor: frozen ESM backbone + label attention + GCN over GO terms.

    Per-label token embeddings come from ``LabelBasedAttention`` over the ESM
    layer-33 representations. GO-term embeddings are refined by a GCN on
    ``adj``. The two are combined with a per-label dot product and passed
    through ``post_mlp`` to give one logit per GO term.

    Args:
        terms_embedding: initial GO-term embedding matrix (nb_terms, terms_dim),
            wrapped as a trainable parameter.
        adj: GO-graph adjacency; ``adj.shape[0]`` is the number of classes.
            Stored as a plain attribute, so ``model.cuda()`` does not move it.
        model_dir: ESM model name for ``esm.pretrained.load_model_and_alphabet``.
        aa_dim: feature size of the ESM token representations.
        latent_dim: shared hidden size.
        dropout_rate: accepted for interface compatibility; unused here.
    """
    def __init__(self,
                 terms_embedding,adj,
                 model_dir: str = 'esm1b_t33_650M_UR50S',
                 aa_dim=1280,
                 latent_dim=256,
                 dropout_rate=0.1):
        super().__init__()
        self.terms_embedding = nn.Parameter(terms_embedding, requires_grad=True)
        self.terms_dim = terms_embedding.shape[1]
        self.latent_dim = latent_dim
        self.adj = adj

        # backbone
        backbone, _ = esm.pretrained.load_model_and_alphabet(model_dir)
        unfreeze_layers = None  # e.g. [32] to fine-tune a layer (layers are 0-32)
        self.backbone = self.unfreeze(backbone, unfreeze_layers)
        self.nb_classes = adj.shape[0]
        # label based attention
        self.label_attention = LabelBasedAttention(label_dim=self.terms_dim, word_dim=aa_dim,
                                                   latent_dim=latent_dim, nb_labels=self.nb_classes,
                                                   nb_words=1024, components_factor=4,
                                                   temperature=2)
        # go transform (attribute name kept for checkpoint compatibility)
        self.go_transfrom = MLPLayer(self.terms_dim, self.latent_dim)
        # gnn module
        self.gcn = GCN(latent_dim, latent_dim)
        # output layer: fuses [raw GO embedding, GCN output]
        self.go_transform_post = MLPLayer(int(2 * latent_dim), latent_dim)
        # post mlp over the per-class scores
        self.post_mlp = MLPLayer(self.nb_classes, self.nb_classes)

    def unfreeze(self, backbone, unfreeze_layers: list):
        """Freeze every backbone parameter, then optionally re-enable some.

        When ``unfreeze_layers`` is not None, all parameters whose name contains
        ``'lm_head'`` and all parameters of ``backbone.layers[idx]`` for each
        listed index become trainable again. Returns ``backbone``.
        """
        for param in backbone.parameters():
            param.requires_grad = False
        if unfreeze_layers is not None:
            # Check every parameter name, not only the last one seen above.
            for name, param in backbone.named_parameters():
                if 'lm_head' in name:
                    param.requires_grad = True
            for idx in unfreeze_layers:
                for p in backbone.layers[idx].parameters():
                    p.requires_grad = True
        return backbone

    def forward(self, input_ids, lengths, labels, output_attention_weights=True):
        """Compute GO-term logits, plus loss and/or attention weights.

        Returns:
            - labels given, output_attention_weights=True:  (loss, logits, weights)
            - labels None,  output_attention_weights=True:  (logits, weights)
            - labels given, output_attention_weights=False: (logits, loss)
            - labels None,  output_attention_weights=False: (logits,)
        """
        # x [B, L, C]: ESM layer-33 token representations
        x = self.backbone(input_ids, repr_layers=[33])['representations'][33]
        label_level_embedding, weights = self.label_attention(self.terms_embedding, x, lengths)
        go_embedding = self.go_transfrom(self.terms_embedding)
        # go embedding refined over the GO graph
        go_out = self.gcn(go_embedding, self.adj)
        # output layer
        go_out = torch.cat((go_embedding, go_out), dim=1)
        go_out = self.go_transform_post(go_out)
        go_out = go_out.repeat((x.shape[0], 1, 1))
        logits = self.post_mlp(torch.sum(go_out * label_level_embedding, dim=-1))
        outputs = (logits,)
        if labels is not None:
            loss_fct = BCEWithLogitsLoss()
            labels = labels.float()
            loss = loss_fct(logits.view(-1, self.nb_classes),
                            labels.view(-1, self.nb_classes))
            outputs = (logits, loss)
        if output_attention_weights:
            # `loss` only exists when labels were supplied.
            if labels is not None:
                outputs = (loss, logits, weights)
            else:
                outputs = (logits, weights)

        return outputs

In [3]:
class Args:
    """Hard-coded configuration for evaluating the MFO checkpoint."""
    def __init__(self) -> None:
        self.data_path = '/share/home/niejianzheng/xbiome/datasets/protein/cafa3/'
        self.resume = '/home/wangbin/X-DeepGO/work_dir/p2go_partial_component_mfo__new_adj_top2down_t2f4/checkpoint_4.pth'
        self.batch_size = 4
        self.workers = 1
        self.namespace = 'mfo'
        # Plain int; a trailing comma here would silently make it the tuple (2,).
        self.temperature = 2
        self.components_factor = 4


# Evaluation configuration
args = Args()

In [4]:
# Test set: one row per protein with its sequence and annotated residue sites (`res`).
# NOTE: read_pickle executes arbitrary code from the file; only load trusted pickles.
test_df = pd.read_pickle('test_res.pkl')

test_df

Unnamed: 0,pdb_id,res,sequence
0,12as,"[46, 100, 116, 235, 248, 251]",MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQD...
1,13pk,"[39, 219, 376, 399]",EKKSINECDLKGKKVLIRVDFNVPVKNGKITNDYRIRSALPTLKKV...
2,1a05,"[140, 190, 222, 246, 250]",MKKIAIFAGDGIGPEIVAAARQVLDAVDQAAHLGLRCTEGLVGGAA...
3,1a0i,"[34, 238, 240]",VNIKTNPFKAVSFVESAIKKALDNAGYLIAEIKYDGVRGNICVDNT...
4,1a16,"[38, 243, 260, 271, 350, 354, 361, 383, 387, 4...",SEISRQEFQRRRQALVEQMQPGSAALIFAAPEVTRSADSEYPYRQN...
...,...,...,...
946,7enl,"[39, 159, 168, 211, 246, 295, 320, 345, 373, 396]",AVSKVYARSVYDSRGNPTVEVELTTEKGVFRSIVPSGASTGVHEAL...
947,7nn9,"[151, 220, 277, 371, 412]",RDFNNLTKGLCTINSWHIYGKDNAVRIGEDSDVLVTREPYVSCDPD...
948,7odc,"[69, 197, 274]",MSSFTKDEFDCHILDEGFTAKDILDQKINEVSSSDDKDAFYVADLG...
949,8pch,"[19, 25, 159]",YPPSMDWRKKGNFVSPVKNQGSCGSCWTFSTTGALESAVAIATGKM...


In [5]:
from torch.utils.data.dataset import Dataset
from typing import Dict
import esm
import gc
import random

class TestAttentionDataset(Dataset):
    """Test sequences with annotated residue sites, tokenized for ESM.

    Reads a pickled DataFrame with ``sequence`` and ``res`` columns. Labels are
    all-zero placeholders (one slot per entry of ``label_map``); ``res`` is
    carried through unchanged so predictions can be compared with the sites.
    """
    def __init__(self,
                 test_file_name,
                 label_map,
                 model_dir: str = 'esm1b_t33_650M_UR50S',
                 max_length: int = 1024,
                 truncate: bool = True,
                 random_crop: bool = False):
        super().__init__()

        self.seqs, self.res = self.load_dataset(test_file_name)
        self.terms_dict = label_map
        self.num_classes = len(self.terms_dict)
        self.max_length = max_length
        self.truncate = truncate
        self.random_crop = random_crop  # stored only; not used by this class

        # Only the alphabet (tokenizer) is needed; the esm API returns the model
        # too, so it is released immediately.
        esm_model, self.alphabet = esm.pretrained.load_model_and_alphabet(
            model_dir)
        self.batch_converter = self.alphabet.get_batch_converter()
        # Drop this frame's reference *before* collecting; otherwise the
        # gc.collect() in free_memory runs while the model is still reachable.
        del esm_model
        self.free_memory()

    def free_memory(self, esm_model=None):
        """Release a model reference and run the garbage collector.

        Only this function's binding is deleted; callers must drop their own
        references for the memory to actually be reclaimed.
        """
        del esm_model
        gc.collect()
        print('Delete the esm model, free memory!')

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        """Return ``(sequence, length, multilabel, res)`` for row ``idx``.

        With ``truncate`` the sequence is cut to ``max_length - 2`` residues,
        presumably to leave room for the start/end tokens added by the ESM
        batch converter.
        """
        sequence = self.seqs[idx]
        if self.truncate:
            sequence = sequence[:self.max_length - 2]
        length = len(sequence)
        multilabel = [0] * self.num_classes
        return sequence, length, multilabel, self.res[idx]

    def load_dataset(self, test_df_file):
        """Return the ``sequence`` and ``res`` columns of a pickled DataFrame as lists.

        NOTE: read_pickle can execute code from the file; only load trusted data.
        """
        df = pd.read_pickle(test_df_file)
        seq = list(df['sequence'])
        res = list(df['res'])
        return seq, res

    def collate_fn(self, examples) -> Dict[str, object]:
        """Tokenize and pad a list of ``__getitem__`` tuples into a batch dict.

        Returns a dict with ``input_ids`` (int32, right-padded with the
        alphabet's padding index to ``max_length`` columns), ``lengths`` and
        ``labels`` (int32 tensors) and ``res`` (the raw list of site lists).
        """
        sequences_list = [ex[0] for ex in examples]
        lengths = [ex[1] for ex in examples]
        multilabel_list = [ex[2] for ex in examples]
        res_list = [ex[3] for ex in examples]

        _, _, all_tokens = self.batch_converter([
            ('', sequence) for sequence in sequences_list
        ])

        # The model is trained on truncated sequences and passing longer ones in at
        # inference will cause an error. See https://github.com/facebookresearch/esm/issues/21
        if self.truncate:
            all_tokens = all_tokens[:, :self.max_length]

        # Right-pad to a fixed width, using the tokenizer's own padding id and
        # the tokens' dtype (no float round-trip).
        pad_width = self.max_length - all_tokens.shape[1]
        if pad_width > 0:
            padding = all_tokens.new_full((all_tokens.shape[0], pad_width),
                                          self.alphabet.padding_idx)
            all_tokens = torch.cat([all_tokens, padding], dim=1)

        return {
            'input_ids': all_tokens.int(),
            'lengths': torch.tensor(lengths, dtype=torch.int),
            'labels': torch.tensor(multilabel_list, dtype=torch.int),
            'res': res_list,
        }


In [6]:


# NOTE(review): args.gpu is not read by any of the visible cells.
args.gpu = 0
# GO graph and label maps for args.namespace (the fifth return value is unused
# here), plus the test sequences with their annotated residue sites.
adj, multi_hot_vector, label_map, label_map_ivs,_ = build_graph(
    data_path=args.data_path, namespace=args.namespace)
test_dataset = TestAttentionDataset(test_file_name='test_res.pkl',label_map=label_map)
# Unshuffled loader, so batch order follows the rows of test_res.pkl (same file as test_df).
test_loader = DataLoader(test_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True,
                            collate_fn=test_dataset.collate_fn)

Number of annotated terms: 6367
number of edges:7965




Delete the esm model, free memory!


In [7]:
# Pull one batch to sanity-check tokenization and model outputs.
batch = next(iter(test_loader))

In [8]:
# Display the collated batch (input_ids, lengths, labels, res).
batch

{'input_ids': tensor([[ 0, 20, 15,  ...,  1,  1,  1],
         [ 0,  9, 15,  ...,  1,  1,  1],
         [ 0, 20, 15,  ...,  1,  1,  1],
         [ 0,  7, 17,  ...,  1,  1,  1]], dtype=torch.int32),
 'lengths': tensor([330, 415, 358, 348], dtype=torch.int32),
 'labels': tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]], dtype=torch.int32),
 'res': [[46, 100, 116, 235, 248, 251],
  [39, 219, 376, 399],
  [140, 190, 222, 246, 250],
  [34, 238, 240]]}

In [9]:
# input_ids is a plain torch.Tensor (only tensor entries are later moved to the GPU).
type(batch['input_ids'])

torch.Tensor

In [10]:
# Token ids of the first batch, for the index-alignment checks below.
input_ids = batch['input_ids']

In [11]:
# Token id at position 46 of protein 0 (46 is its first annotated site).
input_ids[0][46].item()

13

In [12]:

# Alignment check: token position i holds residue i-1 of the raw sequence,
# because the converter prepends a start token (id 0 in input_ids above).
test_dataset.alphabet.get_tok(input_ids[0][100].item()),test_df.sequence[0][99]

('R', 'R')

In [13]:
import numpy as np

# model
# GO-term embedding matrix for this namespace: join the namespace's term list
# with the precomputed partial-order embeddings of all terms on their shared
# column(s) (inner merge keeps the order of `terms`), then stack the per-term
# vectors row-wise. NOTE(review): row order must match build_graph's label
# indices; confirm.
terms_all = pd.read_pickle(os.path.join(args.data_path,'all_terms_partial_order_embeddings.pkl'))
terms = pd.read_pickle(os.path.join(args.data_path,args.namespace,args.namespace + '_terms.pkl'))
terms_embedding = terms.merge(terms_all)
embeddings = np.concatenate([np.array(embedding,ndmin=2) for embedding in terms_embedding.embeddings.values])
terms_embedding = torch.Tensor(embeddings)
terms_embedding = terms_embedding.cuda()
# adj is a plain attribute of P2GO, so it must be on the GPU before construction.
adj = adj.cuda()
model = P2GO(terms_embedding, adj, model_dir= 'esm1b_t33_650M_UR50S', aa_dim= 1280, latent_dim = 256, dropout_rate=0.1)
# Restore trained weights; the optimizer state is not needed for evaluation.
if args.resume is not None:
    model_state, optimizer_state = load_model_checkpoint(args.resume)
    model.load_state_dict(model_state)

model = model.cuda()



=> loading checkpoint '/home/wangbin/X-DeepGO/work_dir/p2go_partial_component_mfo__new_adj_top2down_t2f4/checkpoint_4.pth'


In [14]:
# Forward the sample batch. Non-tensor entries (e.g. `res`) are dropped
# before the call, since the model only takes tensors.
model.eval()
batch = {key: val.cuda() for key, val in batch.items() if isinstance(val, torch.Tensor)}
# Inference only: skip building the autograd graph.
with torch.no_grad():
    (loss, logits, weights) = model(**batch)

In [15]:
# BCE against the all-zero placeholder labels from the dataset: not a meaningful metric.
loss

tensor(0.0043, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)

In [16]:
# Convert logits to per-term probabilities on the CPU (name `logits` is reused).
logits = logits.sigmoid().detach().cpu()

In [17]:
# Display the per-term probabilities.
logits

tensor([[2.0182e-05, 1.0657e-05, 2.9552e-05,  ..., 8.0472e-06, 3.0107e-05,
         3.5217e-05],
        [1.6076e-05, 8.4371e-06, 2.8903e-05,  ..., 1.4782e-05, 8.8768e-05,
         4.2436e-05],
        [1.4485e-05, 9.1310e-06, 3.5951e-05,  ..., 9.7281e-06, 1.1790e-04,
         8.1055e-05],
        [1.6528e-05, 2.4249e-05, 9.5311e-06,  ..., 1.0609e-05, 3.5822e-05,
         1.9297e-05]])

In [18]:
# Number of GO terms per protein with probability above 0.4.
torch.sum(logits > 0.4,dim=1)

tensor([ 4, 21,  6,  8])

In [19]:
# Indices of protein 0's terms with probability > 0.40, most probable first.
idxs = logits[0].sort(descending=True).indices[logits[0].sort(descending=True).values>0.40]

In [20]:
# Print the GO IDs of those terms.
for idx in idxs:
    print(label_map_ivs[idx.item()])

GO:0003674
GO:0003824
GO:0016874
GO:0005488


In [21]:
# Label->token attention weights on the CPU.
weights = weights.detach().cpu()

In [22]:
# Attention of term idxs[3] over the token positions of protein 0.
weights[0][idxs[3]]

tensor([9.2074e-05, 7.1975e-05, 1.4078e-03,  ..., 0.0000e+00, 0.0000e+00,
        0.0000e+00])

In [23]:
# Same weights, highest first, with their token positions.
weights[0][idxs[3]].sort(descending=True)

torch.return_types.sort(
values=tensor([0.1193, 0.0694, 0.0579,  ..., 0.0000, 0.0000, 0.0000]),
indices=tensor([181, 183, 220,  ..., 565, 566, 567]))

In [24]:
# Residue counts of the batch (still on the GPU).
length = batch['lengths']

In [25]:
# Residue count of protein 0.
length[0]

tensor(330, device='cuda:0', dtype=torch.int32)

In [26]:
# Uniform attention weight over protein 0's 330 residues, for reference.
1/330

0.0030303030303030303

In [28]:
# Attention of the top predicted term idxs[0] for protein 0, highest first.
sorted_weights = weights[0][idxs[0]].sort(descending=True)
sorted_weights

torch.return_types.sort(
values=tensor([0.5551, 0.2011, 0.1488,  ..., 0.0000, 0.0000, 0.0000]),
indices=tensor([116, 110, 102,  ..., 565, 566, 567]))

In [29]:
# Top int(0.1 * length + 1) token positions by attention (about 10% of the residues).
sorted_weights.indices[:int(0.1*(length[0])+1)]

tensor([116, 110, 102, 213, 252, 214, 108, 215, 105, 220, 212, 265, 299,  74,
         46, 264, 292,  71, 211, 255, 181, 206, 210, 216, 247, 217, 207, 104,
        251,  78,   9, 106, 103, 112])

In [27]:
# Ground-truth sites for protein 0: [46, 100, 116, 235, 248, 251]
print(label_map_ivs[idxs[0].item()])
# Token positions that term idxs[0] attends to with weight > 0.012 (highest first).
weights[0][idxs[0]].sort(descending=True).indices[weights[0][idxs[0]].sort(descending=True).values > 0.012]

# [115, 109, 101, 212, 251]

GO:0003674


tensor([116, 110, 102, 213, 252])

In [28]:
# Ground-truth sites for protein 0: [46, 100, 116, 235, 248, 251]
print(label_map_ivs[idxs[1].item()])
# Token positions that term idxs[1] attends to with weight > 0.012 (highest first).
weights[0][idxs[1]].sort(descending=True).indices[weights[0][idxs[1]].sort(descending=True).values > 0.012]

# [298,  52,  49, 213, 263, 292,  48, 268]

GO:0003824


tensor([299,  53,  50, 214, 264, 293,  49, 269])

In [84]:
# Ground-truth sites for protein 0: [46, 100, 116, 235, 248, 251]
# Overlap between those sites and term idxs[3]'s high-attention positions,
# shifted by -1 from token index to residue index. `.tolist()` yields Python
# ints: iterating the tensor gives 0-d tensors, which hash by identity and so
# never match the ints in `res` (the intersection would always be empty).
set(
    (
        weights[0][idxs[3]].sort(descending=True).indices[
            weights[0][idxs[3]].sort(descending=True).values > 0.012
        ] - 1
    ).tolist()
).intersection(set(test_df.res[0]))

set()

In [29]:
# Ground-truth sites for protein 0: [46, 100, 116, 235, 248, 251]
print(label_map_ivs[idxs[2].item()])
# Token positions that term idxs[2] attends to with weight > 0.012 (highest first).
weights[0][idxs[2]].sort(descending=True).indices[weights[0][idxs[2]].sort(descending=True).values > 0.012]

# [184,  52, 298, 283,  38, 297,  51]

GO:0016874


tensor([185,  53, 299, 284,  39, 298,  52])

In [30]:
# Ground-truth sites for protein 0: [46, 100, 116, 235, 248, 251]
print(label_map_ivs[idxs[3].item()])
# Token positions that term idxs[3] attends to with weight > 0.012 (highest first).
weights[0][idxs[3]].sort(descending=True).indices[weights[0][idxs[3]].sort(descending=True).values > 0.012]

# [180, 182, 219, 194, 179, 199, 110, 208, 101, 202,  38, 169, 151, 250,
#         106, 207, 185, 177, 201, 193, 186, 147, 181]

GO:0005488


tensor([181, 183, 220, 195, 180, 200, 111, 209, 102, 203,  39, 170, 152, 251,
        107, 208, 186, 178, 202, 194, 187, 148, 182])

In [30]:
# Overlap between protein 0's sites and term idxs[1]'s high-attention positions
# (token index - 1 = residue index). The threshold mask must come from the same
# term's sorted weights as the indices it filters, and `.tolist()` is needed so
# the elements are ints that can match the ints in `res`.
set(
    (
        weights[0][idxs[1]].sort(descending=True).indices[
            weights[0][idxs[1]].sort(descending=True).values > 0.012
        ] - 1
    ).tolist()
).intersection(set(test_df.res[0]))

set()

In [33]:
# Site-level evaluation: for every protein, take each GO term predicted with
# probability > 0.40, collect the top int(0.1 * length + 1) token positions by
# that term's attention (shifted by -1 to residue indices), and score the union
# of those positions against the annotated sites with the Jaccard index.
model.eval()
score_total = []
with torch.no_grad():  # inference only: no autograd graph
    for batch in test_loader:
        lengths = batch['lengths']
        ress = batch['res']
        batch = {key: val.cuda() for key, val in batch.items() if isinstance(val, torch.Tensor)}
        (_, logits, weights) = model(**batch)
        preds = logits.sigmoid().cpu()
        weights = weights.cpu()
        for pred, weight, length, res in zip(preds, weights, lengths, ress):
            ranked_terms = pred.sort(descending=True)
            idxs = ranked_terms.indices[ranked_terms.values > 0.40]

            pred_sites = set()
            top_k = int(0.1 * length + 1)
            for idx in idxs:
                sorted_weight = weight[idx].sort(descending=True)
                pred_sites.update((sorted_weight.indices[:top_k] - 1).tolist())
            true_sites = set(res)
            union = pred_sites | true_sites
            # Empty prediction and empty annotation: score 0 instead of dividing by zero.
            score_total.append(len(pred_sites & true_sites) / len(union) if union else 0.0)
print(sum(score_total) / len(score_total))


0.018049491175613497


# test p2go partial component postmlp fv bpo t4f8

In [1]:
import torch.nn as nn
import math
from torch.nn import BCEWithLogitsLoss
from deepfold.models.esm_model import MLPLayer,MLPLayer3D
from deepfold.models.gnn_model import GCN

# attention  module
def sequence_mask(X, valid_len, value=0):
    """Overwrite, in place, every entry of ``X`` past its row's valid length.

    Row ``i`` keeps columns ``0 .. valid_len[i] - 1``; all later columns are
    set to ``value``. The mutated ``X`` itself is returned.
    Adapted from D2L (:numref:`sec_seq2seq_decoder`).
    """
    positions = torch.arange(X.size(1), dtype=torch.float32, device=X.device)
    keep = positions.unsqueeze(0) < valid_len.unsqueeze(1)
    X[~keep] = value
    return X

def masked_softmax(X, valid_lens):
    """Softmax over the last axis, ignoring positions beyond each valid length.

    Args:
        X: 3-D score tensor.
        valid_lens: ``None`` (plain softmax), a 1-D tensor with one length per
            leading-axis entry (shared by all rows of that entry), or a 2-D
            tensor with one length per row.
    """
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    shape = X.shape
    flat_lens = (torch.repeat_interleave(valid_lens, shape[1])
                 if valid_lens.dim() == 1 else valid_lens.reshape(-1))
    # Masked logits become a huge negative number, so their softmax weight is ~0.
    masked = sequence_mask(X.reshape(-1, shape[-1]), flat_lens, value=-1e6)
    return nn.functional.softmax(masked.reshape(shape), dim=-1)



class LabelBasedAttention(nn.Module):
    """Label-wise attention over tokens via a learned component bank (q/k/v variant).

    Labels (queries, ``fc_q``) score the components (softmax over components);
    components score the token keys (``fc_k``); their product gives label->token
    scores that are temperature-scaled and softmax-normalised over the first
    ``lengths + 1`` token positions. The weights then pool a separate value
    projection of the tokens (``fc_v``).

    Args:
        label_dim: feature size of the incoming label embeddings.
        word_dim: feature size of the incoming per-token embeddings.
        latent_dim: shared projection size for queries, keys, values, components.
        nb_labels: number of labels; only used to size the component bank.
        nb_words: only used to size the component bank.
        components_factor: must be > 0; larger values give fewer components.
        temperature: label->token scores are divided by this before the softmax.
    """
    def __init__(self, label_dim, word_dim, latent_dim, nb_labels, nb_words,components_factor=2,temperature=2):
        super().__init__()
        self.label_dim = label_dim
        self.word_dim = word_dim
        self.latent_dim = latent_dim
        self.fc_q = MLPLayer(self.label_dim, self.latent_dim)
        self.fc_k = MLPLayer3D(self.word_dim, self.latent_dim)
        self.nb_labels = nb_labels
        self.nb_words = nb_words
        assert components_factor > 0
        # nb_components = int(nb_labels / (components_factor * (1 + int(nb_labels / nb_words)))).
        # NOTE(review): this truncates to 0 when nb_labels is small relative to the divisor.
        self.nb_components = int(self.nb_labels/(components_factor*(1+int(self.nb_labels/self.nb_words))))
        # Learnable component bank of shape (nb_components, latent_dim).
        self.component_embedding = nn.Parameter(torch.randn((self.nb_components,self.latent_dim)),requires_grad=True)
        self.temperature = temperature
        # Value projection of the tokens, pooled by the attention weights.
        self.fc_v = MLPLayer3D(self.word_dim, self.latent_dim)

    def forward(self, label_embedding, word_embedding, lengths):
        """Pool token value embeddings per label.

        Args:
            label_embedding: label embeddings; P2GO passes its 2-D
                ``terms_embedding`` parameter.
            word_embedding: per-token features; P2GO passes ESM layer-33
                representations (commented there as [B, L, C]).
            lengths: 1-D per-example residue counts; the softmax keeps the
                first ``lengths + 1`` token positions.

        Returns:
            ``(label_level_embedding, l_w)``: per-label pooled value embeddings
            and the label->token attention weights. Presumably shaped
            (B, nb_labels, latent_dim) and (B, nb_labels, L), assuming the MLP
            layers project the last axis to latent_dim (TODO confirm).
        """
        k_embedding = self.fc_k(word_embedding)
        q_embedding = self.fc_q(label_embedding)
        k_embedding = k_embedding.transpose(-1,-2)
        # Share the same component bank across the batch.
        component_embedding = self.component_embedding.repeat((word_embedding.shape[0], 1, 1))
        # component -> token scores
        c_w = torch.bmm(component_embedding,k_embedding)
        q_embedding = q_embedding.repeat((word_embedding.shape[0], 1, 1))
        # label -> component scores, normalised over components
        l_c = torch.bmm(q_embedding,component_embedding.transpose(-1,-2))
        l_c = l_c.softmax(dim=-1)
        # label -> token scores obtained through the components
        l_w = torch.bmm(l_c,c_w)
        l_w /= self.temperature
        # Softmax over tokens, masking positions at or beyond lengths + 1.
        l_w = masked_softmax(l_w, lengths+1)
        # Attention-weighted sum of the token value projections.
        label_level_embedding = torch.bmm(l_w,self.fc_v(word_embedding))

        return label_level_embedding, l_w

class P2GO(nn.Module):
    """GO-term predictor: frozen ESM backbone + label attention + GCN over GO terms.

    Per-label token embeddings come from ``LabelBasedAttention`` over the ESM
    layer-33 representations. GO-term embeddings are refined by a GCN on
    ``adj``. The two are combined with a per-label dot product and passed
    through ``post_mlp`` to give one logit per GO term.

    Args:
        terms_embedding: initial GO-term embedding matrix (nb_terms, terms_dim),
            wrapped as a trainable parameter.
        adj: GO-graph adjacency; ``adj.shape[0]`` is the number of classes.
            Stored as a plain attribute, so ``model.cuda()`` does not move it.
        model_dir: ESM model name for ``esm.pretrained.load_model_and_alphabet``.
        aa_dim: feature size of the ESM token representations.
        latent_dim: shared hidden size.
        dropout_rate: passed to the GO-side MLP layers.
        temperature: label-attention softmax temperature.
        components_factor: controls the size of the label-attention component bank.
    """
    def __init__(self,
                 terms_embedding,adj,
                 model_dir: str = 'esm1b_t33_650M_UR50S',
                 aa_dim=1280,
                 latent_dim=256,
                 dropout_rate=0.1,temperature=2,components_factor=2):
        super().__init__()
        self.terms_embedding = nn.Parameter(terms_embedding, requires_grad=True)
        self.terms_dim = terms_embedding.shape[1]
        self.latent_dim = latent_dim
        self.adj = adj

        # backbone
        backbone, _ = esm.pretrained.load_model_and_alphabet(model_dir)
        unfreeze_layers = None  # e.g. [32] to fine-tune a layer (layers are 0-32)
        self.backbone = self.unfreeze(backbone, unfreeze_layers)
        self.nb_classes = adj.shape[0]
        # label based attention
        self.label_attention = LabelBasedAttention(label_dim=self.terms_dim,
                                                   word_dim=aa_dim, latent_dim=latent_dim,
                                                   nb_labels=self.nb_classes, nb_words=1024,
                                                   components_factor=components_factor,
                                                   temperature=temperature)
        # go transform (attribute name kept for checkpoint compatibility)
        self.go_transfrom = MLPLayer(self.terms_dim, self.latent_dim, dropout_rate)
        # gnn module
        self.gcn = GCN(latent_dim, latent_dim)
        # output layer: fuses [raw GO embedding, GCN output]
        self.go_transform_post = MLPLayer(int(2 * latent_dim), latent_dim, dropout_rate)
        # post mlp over the per-class scores
        self.post_mlp = MLPLayer(self.nb_classes, self.nb_classes, dropout_rate)

    def unfreeze(self, backbone, unfreeze_layers: list):
        """Freeze every backbone parameter, then optionally re-enable some.

        When ``unfreeze_layers`` is not None, all parameters whose name contains
        ``'lm_head'`` and all parameters of ``backbone.layers[idx]`` for each
        listed index become trainable again. Returns ``backbone``.
        """
        for param in backbone.parameters():
            param.requires_grad = False
        if unfreeze_layers is not None:
            # Check every parameter name, not only the last one seen above.
            for name, param in backbone.named_parameters():
                if 'lm_head' in name:
                    param.requires_grad = True
            for idx in unfreeze_layers:
                for p in backbone.layers[idx].parameters():
                    p.requires_grad = True
        return backbone

    def forward(self, input_ids, lengths, labels, output_attention_weights=True):
        """Compute GO-term logits, plus loss and/or attention weights.

        Returns:
            - labels given, output_attention_weights=True:  (loss, logits, weights)
            - labels None,  output_attention_weights=True:  (logits, weights)
            - labels given, output_attention_weights=False: (logits, loss)
            - labels None,  output_attention_weights=False: (logits,)
        """
        # x [B, L, C]: ESM layer-33 token representations
        x = self.backbone(input_ids, repr_layers=[33])['representations'][33]
        label_level_embedding, weights = self.label_attention(self.terms_embedding, x, lengths)
        go_embedding = self.go_transfrom(self.terms_embedding)
        # go embedding refined over the GO graph
        go_out = self.gcn(go_embedding, self.adj)
        # output layer
        go_out = torch.cat((go_embedding, go_out), dim=1)
        go_out = self.go_transform_post(go_out)
        go_out = go_out.repeat((x.shape[0], 1, 1))
        logits = self.post_mlp(torch.sum(go_out * label_level_embedding, dim=-1))
        outputs = (logits,)
        if labels is not None:
            loss_fct = BCEWithLogitsLoss()
            labels = labels.float()
            loss = loss_fct(logits.view(-1, self.nb_classes),
                            labels.view(-1, self.nb_classes))
            outputs = (logits, loss)
        if output_attention_weights:
            # `loss` only exists when labels were supplied.
            if labels is not None:
                outputs = (loss, logits, weights)
            else:
                outputs = (logits, weights)

        return outputs

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Args:
    """Hard-coded configuration for evaluating the BPO (t4f8) checkpoint."""
    def __init__(self) -> None:
        self.data_path = '/share/home/niejianzheng/xbiome/datasets/protein/cafa3/'
        self.resume = '/home/wangbin/X-DeepGO/work_dir/p2go_partial_component_postmlp_fv_bpo__new_adj_top2down_t4f8/checkpoint_10.pth'
        self.batch_size = 4
        self.workers = 1
        self.namespace = 'bpo'
        # Passed to P2GO's label attention; must match the checkpoint (t4f8).
        self.temperature = 4
        self.components_factor = 8


# Evaluation configuration
args = Args()

In [3]:
import pandas as pd
import os
import torch

from torch.utils.data import DataLoader
from deepfold.utils.make_graph import build_graph
from deepfold.utils.model import load_model_checkpoint
from torch.utils.data.dataset import Dataset
from typing import Dict
import esm
import gc
import random

class TestAttentionDataset(Dataset):
    """Test sequences with annotated residue sites, tokenized for ESM.

    Reads a pickled DataFrame with ``sequence`` and ``res`` columns. Labels are
    all-zero placeholders (one slot per entry of ``label_map``); ``res`` is
    carried through unchanged so predictions can be compared with the sites.
    """
    def __init__(self,
                 test_file_name,
                 label_map,
                 model_dir: str = 'esm1b_t33_650M_UR50S',
                 max_length: int = 1024,
                 truncate: bool = True,
                 random_crop: bool = False):
        super().__init__()

        self.seqs, self.res = self.load_dataset(test_file_name)
        self.terms_dict = label_map
        self.num_classes = len(self.terms_dict)
        self.max_length = max_length
        self.truncate = truncate
        self.random_crop = random_crop  # stored only; not used by this class

        # Only the alphabet (tokenizer) is needed; the esm API returns the model
        # too, so it is released immediately.
        esm_model, self.alphabet = esm.pretrained.load_model_and_alphabet(
            model_dir)
        self.batch_converter = self.alphabet.get_batch_converter()
        # Drop this frame's reference *before* collecting; otherwise the
        # gc.collect() in free_memory runs while the model is still reachable.
        del esm_model
        self.free_memory()

    def free_memory(self, esm_model=None):
        """Release a model reference and run the garbage collector.

        Only this function's binding is deleted; callers must drop their own
        references for the memory to actually be reclaimed.
        """
        del esm_model
        gc.collect()
        print('Delete the esm model, free memory!')

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        """Return ``(sequence, length, multilabel, res)`` for row ``idx``.

        With ``truncate`` the sequence is cut to ``max_length - 2`` residues,
        presumably to leave room for the start/end tokens added by the ESM
        batch converter.
        """
        sequence = self.seqs[idx]
        if self.truncate:
            sequence = sequence[:self.max_length - 2]
        length = len(sequence)
        multilabel = [0] * self.num_classes
        return sequence, length, multilabel, self.res[idx]

    def load_dataset(self, test_df_file):
        """Return the ``sequence`` and ``res`` columns of a pickled DataFrame as lists.

        NOTE: read_pickle can execute code from the file; only load trusted data.
        """
        df = pd.read_pickle(test_df_file)
        seq = list(df['sequence'])
        res = list(df['res'])
        return seq, res

    def collate_fn(self, examples) -> Dict[str, object]:
        """Tokenize and pad a list of ``__getitem__`` tuples into a batch dict.

        Returns a dict with ``input_ids`` (int32, right-padded with the
        alphabet's padding index to ``max_length`` columns), ``lengths`` and
        ``labels`` (int32 tensors) and ``res`` (the raw list of site lists).
        """
        sequences_list = [ex[0] for ex in examples]
        lengths = [ex[1] for ex in examples]
        multilabel_list = [ex[2] for ex in examples]
        res_list = [ex[3] for ex in examples]

        _, _, all_tokens = self.batch_converter([
            ('', sequence) for sequence in sequences_list
        ])

        # The model is trained on truncated sequences and passing longer ones in at
        # inference will cause an error. See https://github.com/facebookresearch/esm/issues/21
        if self.truncate:
            all_tokens = all_tokens[:, :self.max_length]

        # Right-pad to a fixed width, using the tokenizer's own padding id and
        # the tokens' dtype (no float round-trip).
        pad_width = self.max_length - all_tokens.shape[1]
        if pad_width > 0:
            padding = all_tokens.new_full((all_tokens.shape[0], pad_width),
                                          self.alphabet.padding_idx)
            all_tokens = torch.cat([all_tokens, padding], dim=1)

        return {
            'input_ids': all_tokens.int(),
            'lengths': torch.tensor(lengths, dtype=torch.int),
            'labels': torch.tensor(multilabel_list, dtype=torch.int),
            'res': res_list,
        }


In [4]:
# NOTE(review): args.gpu is not read by any of the visible cells.
args.gpu = 0
# GO graph and label maps for args.namespace (the fifth return value is unused
# here), plus the test sequences with their annotated residue sites.
adj, multi_hot_vector, label_map, label_map_ivs,_ = build_graph(
    data_path=args.data_path, namespace=args.namespace)
test_dataset = TestAttentionDataset(test_file_name='test_res.pkl',label_map=label_map)
# Unshuffled loader, so batch order follows the rows of test_res.pkl.
test_loader = DataLoader(test_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True,
                            collate_fn=test_dataset.collate_fn)

Number of annotated terms: 19901
number of edges:38021




Delete the esm model, free memory!


In [5]:
import numpy as np

# model
terms_all = pd.read_pickle(os.path.join(args.data_path, 'all_terms_partial_order_embeddings.pkl'))
terms = pd.read_pickle(os.path.join(args.data_path, args.namespace, args.namespace + '_terms.pkl'))
terms_embedding = terms.merge(terms_all)
# DataFrame.merge defaults to an inner join. A term missing from terms_all
# would be silently dropped, and a duplicated one repeated. Either way the
# embedding table no longer has one row per term. Fail loudly instead.
# NOTE(review): the row order presumably has to match the label indices used
# by adj / label_map -- confirm.
if len(terms_embedding) != len(terms):
    raise ValueError(
        f'expected one embedding per term ({len(terms)} terms), '
        f'got {len(terms_embedding)} rows after merging')
embeddings = np.concatenate([np.array(embedding, ndmin=2) for embedding in terms_embedding.embeddings.values])
terms_embedding = torch.Tensor(embeddings).cuda()
adj = adj.cuda()
model = P2GO(terms_embedding, adj, model_dir='esm1b_t33_650M_UR50S', aa_dim=1280, latent_dim=256, dropout_rate=0.1,
             temperature=args.temperature, components_factor=args.components_factor)

if args.resume is not None:
    model_state, optimizer_state = load_model_checkpoint(args.resume)
    model.load_state_dict(model_state)

model = model.cuda()



=> loading checkpoint '/home/wangbin/X-DeepGO/work_dir/p2go_partial_component_postmlp_fv_bpo__new_adj_top2down_t4f8/checkpoint_10.pth'


In [7]:
model.eval()
score_total = []
# Pure inference: no_grad() stops autograd from recording a graph for every
# batch. The original loop only detached the outputs afterwards.
with torch.no_grad():
    for i, batch in enumerate(test_loader):
        lengths = batch['lengths']
        ress = batch['res']
        # Only tensors are moved to the GPU; 'res' is a plain Python list.
        batch = {key: val.cuda() for key, val in batch.items() if isinstance(val, torch.Tensor)}
        (_, logits, weights) = model(**batch)
        preds = logits.sigmoid().cpu()
        weights = weights.cpu()
        for pred, weight, length, res in zip(preds, weights, lengths, ress):
            # Term indices with probability > 0.40, most confident first (sorted once).
            ranked = pred.sort(descending=True)
            idxs = ranked.indices[ranked.values > 0.40]
            res_set = set(res)
            score_sample = []
            for idx in idxs:
                # Positions whose attention weight exceeds 4/length, i.e. 4x a
                # uniform 1/length weight. The -1 presumably maps token
                # positions to residue indices (token 0 being the prepended
                # start token) -- confirm against the 'res' convention.
                above = torch.nonzero(weight[idx] > 4 / length, as_tuple=True)[0] - 1
                candidate = set(above.tolist())
                union = candidate | res_set
                if not union:
                    # Jaccard is undefined for two empty sets; skip rather than divide by zero.
                    continue
                score_sample.append(len(candidate & res_set) / len(union))
            # A sample with no term above the threshold has nothing to score;
            # skip it instead of raising ZeroDivisionError.
            if score_sample:
                score_total.append(sum(score_sample) / len(score_sample))
# Mean Jaccard over scored samples; nan if nothing could be scored.
print(sum(score_total) / len(score_total) if score_total else float('nan'))

0.005874340211484921


# test p2go partial component postmlp t2f4 cco

In [4]:
import pandas as pd
import os
import torch

from torch.utils.data import DataLoader
from deepfold.utils.make_graph import build_graph
from deepfold.utils.model import load_model_checkpoint
import torch.nn as nn
import math
from torch.nn import BCEWithLogitsLoss
from deepfold.models.esm_model import MLPLayer,MLPLayer3D
from deepfold.models.gnn_model import GCN

# attention  module
def sequence_mask(X, valid_len, value=0):
    """Overwrite, in place, every entry of ``X`` that lies past its row's valid length.

    Entry ``X[i, j]`` is replaced by ``value`` unless ``j < valid_len[i]``.
    Defined in :numref:`sec_seq2seq_decoder`.

    Returns:
        The same tensor object ``X``, after modification.
    """
    column_ids = torch.arange(X.size(1), dtype=torch.float32, device=X.device)
    keep = column_ids.unsqueeze(0) < valid_len.unsqueeze(1)
    X[~keep] = value
    return X

def masked_softmax(X, valid_lens):
    """Softmax over the last axis of ``X`` that gives padded positions ~zero weight.

    Args:
        X: 3-D score tensor.
        valid_lens: ``None`` for a plain softmax. Otherwise either a 1-D
            tensor with one length per batch entry, shared by all rows along
            axis 1, or a 2-D tensor with one length per row.

    Positions beyond the valid length are set to -1e6 before the softmax, so
    their probability underflows to (numerically) zero. NOTE: when
    ``X.reshape`` returns a view, the caller's ``X`` is overwritten in place too.
    """
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)

    shape = X.shape
    if valid_lens.dim() == 1:
        # One length per batch entry -> repeat it for each of the shape[1] rows.
        row_lens = torch.repeat_interleave(valid_lens, shape[1])
    else:
        row_lens = valid_lens.reshape(-1)
    flat_scores = sequence_mask(X.reshape(-1, shape[-1]), row_lens, value=-1e6)
    return nn.functional.softmax(flat_scores.reshape(shape), dim=-1)

class DotProductAttention(nn.Module):
    """Scaled dot-product attention with optional length masking.

    Scores are ``queries @ keys^T / sqrt(lambd * d)``, where ``d`` is the last
    dimension of ``queries``. They are normalised with ``masked_softmax`` and
    used (after dropout) to weight ``values``.
    """

    def __init__(self, dropout, lambd=1, **kwargs):
        super().__init__(**kwargs)
        self.lambd = lambd  # extra multiplier inside the square-root scale
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        """Attend ``queries`` over ``keys``/``values``.

        Args:
            queries: ``(batch, n_queries, d)``.
            keys: ``(batch, n_kv, d)``.
            values: ``(batch, n_kv, value_dim)``.
            valid_lens: ``None``, ``(batch,)`` or ``(batch, n_queries)`` lengths.

        Returns:
            ``(output, attention_weights)``. The weights are also cached on
            ``self.attention_weights``.
        """
        raw_scores = torch.bmm(queries, keys.transpose(1, 2))
        scores = raw_scores / math.sqrt(self.lambd * queries.shape[-1])
        self.attention_weights = masked_softmax(scores, valid_lens)
        attended = torch.bmm(self.dropout(self.attention_weights), values)
        return attended, self.attention_weights

def transpose_qkv(X, num_heads):
    """Split the feature axis into heads and fold the heads into the batch axis.

    ``(batch, n, num_hiddens)`` -> ``(batch * num_heads, n, num_hiddens / num_heads)``.
    """
    batch_size, seq_len = X.shape[0], X.shape[1]
    per_head = X.reshape(batch_size, seq_len, num_heads, -1)  # (batch, n, heads, head_dim)
    heads_first = per_head.transpose(1, 2)  # (batch, heads, n, head_dim)
    return heads_first.reshape(-1, heads_first.shape[2], heads_first.shape[3])

def transpose_output(X, num_heads):
    """Undo ``transpose_qkv``: unfold heads from the batch axis and concatenate them.

    ``(batch * num_heads, n, head_dim)`` -> ``(batch, n, num_heads * head_dim)``.
    """
    per_head = X.reshape(-1, num_heads, X.shape[1], X.shape[2])  # (batch, heads, n, head_dim)
    seq_first = per_head.transpose(1, 2)  # (batch, n, heads, head_dim)
    return seq_first.reshape(seq_first.shape[0], seq_first.shape[1], -1)

class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``DotProductAttention``.

    Queries, keys and values are linearly projected to ``num_hiddens``
    features and split into ``num_heads`` heads. Each head is attended
    independently, and the head outputs are concatenated back to
    ``num_hiddens`` features.

    Note: ``W_o`` is created (so it appears in the state dict) but is not
    applied in ``forward``.
    """

    def __init__(self,
                 key_size,
                 query_size,
                 value_size,
                 num_hiddens,
                 num_heads,
                 dropout,
                 bias=False,
                 lambd=None,
                 **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        # DotProductAttention scales scores by sqrt(lambd * d). Passing the
        # default None through made the first forward pass raise TypeError, so
        # None now means the neutral scale factor 1.
        self.attention = DotProductAttention(dropout, lambd=1 if lambd is None else lambd)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens=None, output_attentions=True):
        """Run multi-head attention.

        Args:
            queries, keys, values: ``(batch, n, features)`` tensors.
            valid_lens: ``None``, ``(batch,)`` or ``(batch, n_queries)`` lengths.
            output_attentions: also return the attention weights.

        Returns:
            ``(output, weights)`` or ``(output,)``. ``output`` is
            ``(batch, n_queries, num_hiddens)``. ``weights`` keeps the heads
            side by side on the last axis: ``(batch, n_queries, num_heads * n_kv)``.
        """
        # After projection + split: (batch * num_heads, n, num_hiddens / num_heads)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # Repeat each batch entry's lengths once per head so they line up
            # with the head-folded batch axis.
            valid_lens = torch.repeat_interleave(valid_lens,
                                                 repeats=self.num_heads,
                                                 dim=0)

        output, weight = self.attention(queries, keys, values, valid_lens)

        output_concat = transpose_output(output, self.num_heads)
        if not output_attentions:
            return (output_concat,)
        weight_concat = transpose_output(weight, self.num_heads)
        return (output_concat, weight_concat)

class LabelBasedAttention(nn.Module):
    """Label-wise attention over residue features, routed through learned components.

    Every label attends to a bank of learned component vectors (softmax over
    components). Every component scores every residue. Composing the two gives
    a label-to-residue score map, which is temperature-scaled, length-masked,
    softmax-normalised and used to pool residue features per label.
    """

    def __init__(self, label_dim, word_dim, latent_dim, nb_labels, nb_words, components_factor=2, temperature=2):
        super().__init__()
        self.label_dim = label_dim
        self.word_dim = word_dim
        self.latent_dim = latent_dim
        self.fc_label = MLPLayer(self.label_dim, self.latent_dim)
        self.fc_word = MLPLayer3D(self.word_dim, self.latent_dim)
        self.nb_labels = nb_labels
        self.nb_words = nb_words
        assert components_factor > 0
        # Component count shrinks with components_factor and with the
        # labels-per-word ratio.
        ratio_term = 1 + int(self.nb_labels / self.nb_words)
        self.nb_components = int(self.nb_labels / (components_factor * ratio_term))
        self.component_embedding = nn.Parameter(torch.randn((self.nb_components, self.latent_dim)), requires_grad=True)
        self.temperature = temperature

    def forward(self, label_embedding, word_embedding, lengths):
        """Pool residue features per label.

        Returns:
            ``(label_level_embedding, attention)``. ``attention`` holds the
            label-to-residue weights after masking to ``lengths + 1`` positions
            (presumably the residues plus one leading special token -- confirm).
        """
        words = self.fc_word(word_embedding)      # presumably (B, L, latent)
        labels = self.fc_label(label_embedding)   # presumably (C, latent)
        batch_size = words.shape[0]

        components = self.component_embedding.repeat((batch_size, 1, 1))
        component_to_word = torch.bmm(components, words.transpose(-1, -2))

        labels_batched = labels.repeat((batch_size, 1, 1))
        label_to_component = torch.bmm(labels_batched, components.transpose(-1, -2)).softmax(dim=-1)

        label_to_word = torch.bmm(label_to_component, component_to_word)
        label_to_word /= self.temperature
        label_to_word = masked_softmax(label_to_word, lengths + 1)

        pooled = torch.bmm(label_to_word, words)
        return pooled, label_to_word

class P2GO(nn.Module):
    """ESM backbone + component-routed label attention + GCN over GO terms.

    Args:
        terms_embedding: initial term-embedding matrix, one row per term
            (wrapped as a trainable ``nn.Parameter``).
        adj: term adjacency handed to the GCN; ``adj.shape[0]`` sets the
            number of output classes.
        model_dir: name of the pretrained ESM model used as the backbone.
        aa_dim: feature size of the backbone's residue representations.
        latent_dim: shared latent size of the attention and GO branches.
        dropout_rate: accepted for API compatibility; not used here.
        components_factor: forwarded to ``LabelBasedAttention`` (default 4,
            the value previously hard-coded).
        temperature: forwarded to ``LabelBasedAttention`` (default 2, the
            value previously hard-coded).
    """

    def __init__(self,
                 terms_embedding, adj,
                 model_dir: str = 'esm1b_t33_650M_UR50S',
                 aa_dim=1280,
                 latent_dim=256,
                 dropout_rate=0.1,
                 components_factor=4,
                 temperature=2):
        super().__init__()
        self.terms_embedding = nn.Parameter(terms_embedding, requires_grad=True)
        self.terms_dim = terms_embedding.shape[1]
        self.latent_dim = latent_dim
        # Plain attribute, not a registered buffer: model.cuda() does not move
        # it, so callers must place adj on the right device themselves.
        self.adj = adj

        # backbone
        backbone, _ = esm.pretrained.load_model_and_alphabet(
            model_dir)
        unfreeze_layers = None  # e.g. [32]; backbone layers are indexed 0-32
        self.backbone = self.unfreeze(backbone, unfreeze_layers)
        self.nb_classes = adj.shape[0]
        # label based attention
        self.label_attention = LabelBasedAttention(label_dim=self.terms_dim, word_dim=aa_dim, latent_dim=latent_dim,
                                                   nb_labels=self.nb_classes, nb_words=1024,
                                                   components_factor=components_factor, temperature=temperature)
        # go transform (attribute name kept as-is: it is part of the state_dict keys)
        self.go_transfrom = MLPLayer(self.terms_dim, self.latent_dim)
        # gnn module
        self.gcn = GCN(latent_dim, latent_dim)
        # output layer
        self.go_transform_post = MLPLayer(int(2 * latent_dim), latent_dim)
        # post mlp
        self.post_mlp = MLPLayer(self.nb_classes, self.nb_classes)

    def unfreeze(self, backbone, unfreeze_layers: list):
        """Freeze all backbone parameters, then optionally re-enable some.

        When ``unfreeze_layers`` is not None, parameters whose name contains
        ``'lm_head'`` and all parameters of ``backbone.layers[idx]`` for each
        listed ``idx`` get ``requires_grad=True`` again.
        """
        for param in backbone.parameters():
            param.requires_grad = False
        if unfreeze_layers is not None:
            # The lm_head check must run for every parameter. Previously it sat
            # outside the loop and only ever inspected the last parameter name.
            for name, param in backbone.named_parameters():
                if 'lm_head' in name:
                    param.requires_grad = True
            for idx in unfreeze_layers:
                for p in backbone.layers[idx].parameters():
                    p.requires_grad = True
        return backbone

    def forward(self, input_ids, lengths, labels, output_attention_weights=True):
        """Predict term logits for a batch of token sequences.

        Returns:
            ``(loss, logits, weights)`` when ``output_attention_weights``
            (``loss`` is None if ``labels`` is None). Otherwise ``(logits, loss)``
            when ``labels`` is given, else ``(logits,)``.
        """
        # backbone: residue representations from layer 33
        x = self.backbone(input_ids, repr_layers=[33])['representations'][33]
        label_level_embedding, weights = self.label_attention(self.terms_embedding, x, lengths)
        go_embedding = self.go_transfrom(self.terms_embedding)
        # go embedding refined over the term graph
        go_out = self.gcn(go_embedding, self.adj)
        # output layer
        go_out = torch.cat((go_embedding, go_out), dim=1)
        go_out = self.go_transform_post(go_out)
        go_out = go_out.repeat((x.shape[0], 1, 1))
        logits = self.post_mlp(torch.sum(go_out * label_level_embedding, dim=-1))

        # Defined up front so the attention-weights branch works without labels
        # (it used to raise UnboundLocalError).
        loss = None
        outputs = (logits,)
        if labels is not None:
            loss_fct = BCEWithLogitsLoss()
            labels = labels.float()
            loss = loss_fct(logits.view(-1, self.nb_classes),
                            labels.view(-1, self.nb_classes))
            outputs = (logits, loss)
        if output_attention_weights:
            outputs = (loss, logits, weights)

        return outputs

class Args:
    """Hard-coded configuration for the CCO attention-attribution evaluation.

    Attributes:
        data_path: dataset root containing the term/embedding pickles.
        resume: checkpoint to load into the model (None to skip).
        batch_size: test DataLoader batch size.
        workers: DataLoader worker processes.
        namespace: ontology namespace sub-directory / file prefix.
        components_factor, temperature: label-attention settings. They mirror
            the values this cell's P2GO uses (4 and 2; "t2f4" in the checkpoint
            path) so callers can pass them explicitly.
    """

    def __init__(self) -> None:
        # (a dead `data_path = 'esm'` store that was immediately overwritten is gone)
        self.data_path = '/share/home/niejianzheng/xbiome/datasets/protein/cafa3/'
        self.resume = '/home/wangbin/X-DeepGO/work_dir/p2go_partial_component_cco__new_adj_top2down_t2f4/checkpoint_8.pth'
        self.batch_size = 4
        self.workers = 1
        self.namespace = 'cco'
        self.components_factor = 4
        self.temperature = 2


# Experiment configuration
args = Args()

from torch.utils.data.dataset import Dataset
from typing import Dict
import esm
import gc
import random

class TestAttentionDataset(Dataset):
    """Inference dataset of protein sequences with per-sample ``res`` annotations.

    Items are ``(sequence, length, multilabel, res)``:
    - ``sequence`` is optionally truncated to ``max_length - 2`` residues.
    - ``multilabel`` is an all-zero placeholder of length ``len(label_map)``.
    - ``res`` is the ``res`` entry of the test DataFrame.

    Args:
        test_file_name: pickled DataFrame with ``sequence`` and ``res`` columns
            (``pd.read_pickle`` -- only load trusted files).
        label_map: term-to-index mapping; only its length is used here.
        model_dir: pretrained ESM name whose alphabet tokenizes the sequences.
        max_length: token budget per sequence; batches are padded to it.
        truncate: cut sequences / token rows to fit ``max_length``.
        random_crop: stored but not used by this class.
    """

    def __init__(self,
                 test_file_name,
                 label_map,
                 model_dir: str = 'esm1b_t33_650M_UR50S',
                 max_length: int = 1024,
                 truncate: bool = True,
                 random_crop: bool = False):
        super().__init__()

        self.seqs, self.res = self.load_dataset(test_file_name)
        self.terms_dict = label_map
        self.num_classes = len(self.terms_dict)
        self.max_length = max_length
        self.truncate = truncate
        self.random_crop = random_crop

        # Only the alphabet (tokenizer) is needed. Taking element [1] of the
        # returned pair means no local name keeps the model alive, so it is
        # unreachable before free_memory() runs gc.collect(). Previously
        # __init__ still held `esm_model` while free_memory deleted only its
        # own parameter.
        self.alphabet = esm.pretrained.load_model_and_alphabet(model_dir)[1]
        self.batch_converter = self.alphabet.get_batch_converter()
        self.free_memory()

    def free_memory(self, esm_model=None):
        """Drop this function's reference to ``esm_model`` and run garbage collection.

        The model is reclaimed only once no caller still references it.
        """
        del esm_model
        gc.collect()
        print('Delete the esm model, free memory!')

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        sequence = self.seqs[idx]
        if self.truncate:
            # leave room for the two special tokens added by the batch converter
            sequence = sequence[:self.max_length - 2]
        length = len(sequence)
        multilabel = [0] * self.num_classes
        return sequence, length, multilabel, self.res[idx]

    def load_dataset(self, test_df_file):
        """Return ``(sequences, res)`` lists read from a pickled DataFrame."""
        df = pd.read_pickle(test_df_file)
        seq = list(df['sequence'])
        res = list(df['res'])
        return seq, res

    def collate_fn(self, examples) -> Dict[str, torch.Tensor]:
        """Tokenize and pad a list of items into a batch dict.

        Returns:
            dict with ``input_ids`` (int32, ``[B, W]``, ``W >= max_length``),
            ``lengths`` (int32, ``[B]``), ``labels`` (int32,
            ``[B, num_classes]``) and ``res`` (list, passed through).
        """
        sequences_list = [ex[0] for ex in examples]
        lengths = [ex[1] for ex in examples]
        multilabel_list = [ex[2] for ex in examples]
        res_list = [ex[3] for ex in examples]

        _, _, all_tokens = self.batch_converter([
            ('', sequence) for sequence in sequences_list
        ])

        # The model is trained on truncated sequences and passing longer ones in at
        # inference will cause an error. See https://github.com/facebookresearch/esm/issues/21
        if self.truncate:
            all_tokens = all_tokens[:, :self.max_length]

        # Right-pad to max_length with the alphabet's own padding index, built
        # in the token dtype so no float round-trip through torch.cat is needed.
        missing = self.max_length - all_tokens.shape[1]
        if missing > 0:
            pad = all_tokens.new_full((all_tokens.shape[0], missing), self.alphabet.padding_idx)
            all_tokens = torch.cat([all_tokens, pad], dim=1)
        all_tokens = all_tokens.int().to('cpu')

        encoded_inputs = {
            'input_ids': all_tokens,
        }
        encoded_inputs['lengths'] = torch.tensor(lengths, dtype=torch.int)
        encoded_inputs['labels'] = torch.tensor(multilabel_list,
                                                dtype=torch.int)
        encoded_inputs['res'] = res_list
        return encoded_inputs

args.gpu = 0

# Graph/label metadata for args.namespace. adj is later handed to the model's
# GCN, and label_map sizes each sample's placeholder label vector.
adj, multi_hot_vector, label_map, label_map_ivs, _ = build_graph(
    data_path=args.data_path,
    namespace=args.namespace,
)

# Residue-annotated test set, iterated in file order (no shuffling).
test_dataset = TestAttentionDataset(
    test_file_name='test_res.pkl',
    label_map=label_map,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.workers,
    pin_memory=True,
    collate_fn=test_dataset.collate_fn,
)
import numpy as np
# model
terms_all = pd.read_pickle(os.path.join(args.data_path, 'all_terms_partial_order_embeddings.pkl'))
terms = pd.read_pickle(os.path.join(args.data_path, args.namespace, args.namespace + '_terms.pkl'))
terms_embedding = terms.merge(terms_all)
# DataFrame.merge defaults to an inner join. A term missing from terms_all
# would be silently dropped, and a duplicated one repeated. Either way the
# embedding table no longer has one row per term. Fail loudly instead.
# NOTE(review): the row order presumably has to match the label indices used
# by adj / label_map -- confirm.
if len(terms_embedding) != len(terms):
    raise ValueError(
        f'expected one embedding per term ({len(terms)} terms), '
        f'got {len(terms_embedding)} rows after merging')
embeddings = np.concatenate([np.array(embedding, ndmin=2) for embedding in terms_embedding.embeddings.values])
terms_embedding = torch.Tensor(embeddings).cuda()
adj = adj.cuda()
model = P2GO(terms_embedding, adj, model_dir='esm1b_t33_650M_UR50S', aa_dim=1280, latent_dim=256, dropout_rate=0.1)
if args.resume is not None:
    model_state, optimizer_state = load_model_checkpoint(args.resume)
    model.load_state_dict(model_state)

model = model.cuda()

model.eval()
score_total = []
# Pure inference: no_grad() stops autograd from recording a graph for every
# batch. The original loop only detached the outputs afterwards.
with torch.no_grad():
    for i, batch in enumerate(test_loader):
        lengths = batch['lengths']
        ress = batch['res']
        # Only tensors are moved to the GPU; 'res' is a plain Python list.
        batch = {key: val.cuda() for key, val in batch.items() if isinstance(val, torch.Tensor)}
        (_, logits, weights) = model(**batch)
        preds = logits.sigmoid().cpu()
        weights = weights.cpu()
        for pred, weight, length, res in zip(preds, weights, lengths, ress):
            # Term indices with probability > 0.5, most confident first (sorted once).
            ranked = pred.sort(descending=True)
            idxs = ranked.indices[ranked.values > 0.5]
            res_set = set(res)
            score_sample = []
            for idx in idxs:
                # Positions whose attention weight exceeds 4/length, i.e. 4x a
                # uniform 1/length weight. The -1 presumably maps token
                # positions to residue indices (token 0 being the prepended
                # start token) -- confirm against the 'res' convention.
                above = torch.nonzero(weight[idx] > 4 / length, as_tuple=True)[0] - 1
                candidate = set(above.tolist())
                union = candidate | res_set
                if not union:
                    # Jaccard is undefined for two empty sets; skip rather than divide by zero.
                    continue
                score_sample.append(len(candidate & res_set) / len(union))
            # A sample with no term above the threshold has nothing to score;
            # skip it instead of raising ZeroDivisionError.
            if score_sample:
                score_total.append(sum(score_sample) / len(score_sample))
# Mean Jaccard over scored samples; nan if nothing could be scored.
print(sum(score_total) / len(score_total) if score_total else float('nan'))


Number of annotated terms: 2470
number of edges:3676




Delete the esm model, free memory!




=> loading checkpoint '/home/wangbin/X-DeepGO/work_dir/p2go_partial_component_cco__new_adj_top2down_t2f4/checkpoint_8.pth'
0.005390249583437039
