In [53]:
import os
import sys
import esm
import time
import torch
import random
import pickle
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


from Bio import SeqIO
from torch import einsum
from pathlib import Path
from einops import rearrange
from torch.utils.data import DataLoader

In [54]:
class Args():
    """Minimal stand-in for an argparse namespace holding the run options.

    Attributes:
        In: path of the feature file to score.
        Out: path of the CSV file for the predictions.
        weight: path of the classifier state_dict.
        cutoff: probability threshold; left as None here so that
            predict_from_1D_rep derives it from the weight file name.
    """

    def __init__(self, In=None, Out=None, weight=None, cutoff=None):
        self.In, self.Out, self.weight, self.cutoff = In, Out, weight, cutoff


args = Args(
    In='Test_clef_rep.pkl',
    Out='Test_result.csv',
    weight='./pretrained_model/T6classifier-CLEF-DP+MSA+3Di+AT-0.7cutoff.pt',
)

args

<__main__.Args at 0x1d828e20910>

In [55]:
class test_dnn(nn.Module):
    """MLP classifier head over a fixed-length per-protein embedding.

    Args:
        num_embeds: width of the input vector (last dimension of ``x['feature']``).
        finial_drop: dropout probability applied before the output head.
        out_dim: number of output units. With ``out_dim == 1`` forward() returns
            sigmoid probabilities with the trailing singleton axis squeezed away;
            otherwise it returns the raw output of the final Linear layer.

    The submodule attribute names (dnn, out, Dropout, ln) define the
    state_dict keys, so they must stay as-is for saved weights to load.
    """

    def __init__(self, num_embeds = 1280, finial_drop = 0.5, out_dim = 1):
        super().__init__()
        expanded = 2 * num_embeds
        # Expand-then-project block; its output is LayerNorm-ed in forward().
        self.dnn = nn.Sequential(
            nn.Linear(num_embeds, expanded),
            nn.ReLU(),
            nn.Linear(expanded, num_embeds),
        )
        self.out = nn.Sequential(
            nn.Linear(num_embeds, 128),
            nn.ReLU(),
            nn.Linear(128, out_dim),
        )
        self.binaryclass = bool(out_dim == 1)
        self.Dropout = nn.Dropout(finial_drop)
        self.ln = nn.LayerNorm(num_embeds)

    def forward(self, x):
        """Score a batch dict; only the ``'feature'`` entry is read."""
        hidden = self.ln(self.dnn(x['feature']))
        logits = self.out(self.Dropout(hidden))
        if not self.binaryclass:
            return logits
        return torch.sigmoid(logits).squeeze(-1)

In [56]:
def check_hidden_layer_dimensions(data_dict):
    """Return the last-axis length shared by every array in ``data_dict``.

    Values are checked in iteration order:
      * a value that is not a numpy array raises ValueError;
      * the first array whose last dimension differs from the first one seen
        makes the function return None immediately.
    An empty dict also yields None.
    """
    expected = None
    for name, array in data_dict.items():
        if not isinstance(array, np.ndarray):
            raise ValueError(f"Value for key '{name}' is not a numpy array.")
        width = array.shape[-1]
        if expected is None:
            expected = width
            continue
        if width != expected:
            return None
    return expected

In [57]:

def load_feature_from_local(feature_path, silence=False):
    '''
    Load a feature dict from a local file, trying pickle.load() first and
    falling back to torch.load().

    The expected content is a dictionary like:
        {
          Protein_ID : feature_array [a 1D numpy array]
        }

    Args:
        feature_path: path of the file to read.
        silence: when True, do not print which loader succeeded.

    Returns:
        The loaded object, or None when neither loader can read the file.

    Security: pickle.load() (and torch.load(), depending on the installed
    torch version's weights_only default) can execute arbitrary code stored
    in the file. Only load files from trusted sources.
    '''
    # Try pickle.load() function
    try:
        with open(feature_path, 'rb') as f:
            obj = pickle.load(f)
        if not silence:
            print("File is loaded using pickle.load()")
        return obj
    except (pickle.UnpicklingError, EOFError):
        pass

    # Build the torch failure tuple defensively: naming an attribute that does
    # not exist inside an `except` clause raises AttributeError at the moment an
    # exception is being handled, which would mask the real load error.
    torch_errors = (RuntimeError, EOFError, ValueError, pickle.UnpicklingError)
    optional_error = getattr(torch.serialization, 'UnsupportedPackageTypeError', None)
    if isinstance(optional_error, type) and issubclass(optional_error, BaseException):
        torch_errors += (optional_error,)

    # Try torch.load() function (e.g. files written by torch.save()).
    try:
        obj = torch.load(feature_path)
        if not silence:
            print("File is loaded using torch.load()")
        return obj
    except torch_errors:
        pass

    print("Unable to load file.")
    return None

In [58]:
class Potein_rep_datasets:
    """In-memory dataset of per-protein features with a simple batch generator.

    Samples are kept in ``sequence_data`` as ``{ID: {feature_key: value, ..., label_tag: label}}``.
    ``data_indices`` maps integer positions to IDs; ``train_range`` / ``test_range``
    hold the positions used by ``Dataloader``.
    """

    def __init__(self, input_path, train_range = None, test_range = None, label_tag = 'label'):
        '''
        [input_path] maps a feature key to its source, e.g.
        {'feature': './path/to/your/feature_file'}.

        A source may be a file path (str or os.PathLike, read with
        load_feature_from_local) or an already-loaded dict {ID: feature}.
        Only IDs present in every source are kept. If [label_tag] is not one
        of the keys, every sample gets a mock label of 0. Any failure while
        loading is reported and results in an empty dataset.
        '''
        sequence_data = {}
        try:
            for key, value in input_path.items():
                if isinstance(value, (str, os.PathLike)):
                    print(f"try to load feature from path:{value}")
                    tmp = load_feature_from_local(value)
                elif isinstance(value, dict):
                    print(f"try to load feature from dict")
                    tmp = value
                elif isinstance(value, np.ndarray):
                    # NOTE(review): an ndarray has no .items(), so this source
                    # always ends in the except branch below (empty dataset).
                    print(f"try to load feature from numpy_array")
                    tmp = value
                else:
                    print(f"can not load feature {key}")
                    continue
                for ID, feat in tmp.items():
                    # A fresh per-ID dict, so later label mocking never mutates caller data.
                    sequence_data.setdefault(ID, {})[key] = feat

            self.sequence_data = {}
            for ID, feats in sequence_data.items():
                if len(feats) < len(input_path):
                    print(f"imcomplete feature ID {ID} removed")
                else:
                    self.sequence_data[ID] = feats

            if label_tag not in input_path:
                print(f"Add mock label [{label_tag}] of 0 for each sample")
                for feats in self.sequence_data.values():
                    feats[label_tag] = 0
        except Exception:
            # Best-effort loading: report and fall back to an empty dataset.
            print(f"No valid [input_path] to load : {input_path}, return an empty dataset")
            self.sequence_data = {}

        self.data_indices = dict(enumerate(self.sequence_data))

        self.label_tag = label_tag

        print(f"total {len(self.data_indices)} sample loaded")
        # input_path may be invalid (e.g. None) when the except branch above ran.
        self.feature_list = list(input_path.keys()) if hasattr(input_path, 'keys') else []

        # `len(...) == 0` instead of truthiness so numpy index arrays are accepted.
        self.train_range = train_range
        self.test_range = test_range
        if self.train_range is None or len(self.train_range) == 0:
            self.train_range = range(len(self.data_indices))
        if self.test_range is None or len(self.test_range) == 0:
            self.test_range = range(len(self.data_indices))

    def Dataloader(self, batch_size, shuffle = True,
                   test = False,
                   max_num_padding = None,
                   device = 'cpu'):
        '''
        Yield batches drawn from test_range (test=True) or train_range.

        Each batch is a dict with 'labels' (int64 tensor on `device`), 'ID'
        (list of sample IDs) and one float32 tensor per feature key other than
        the label. With `max_num_padding` set, each feature is passed through
        pad_to_max_length; if any sample has a non-zero valid length, the
        lengths are stored under 'valid_lens'. The last batch may be smaller
        than `batch_size`.
        '''
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        indices = list(self.test_range if test else self.train_range)
        if shuffle:
            random.shuffle(indices)

        for start in range(0, len(indices), batch_size):
            IDs = [self.data_indices[i] for i in indices[start:start + batch_size]]
            datasets = [self.sequence_data[ID] for ID in IDs]

            try:
                labels = torch.tensor([x[self.label_tag] for x in datasets]).to(torch.long)
            except Exception:
                print(f"feature <{self.label_tag}> is not a valid label value, using mock labels instead.")
                labels = torch.zeros(len(datasets), dtype=torch.long)
            batch = {
                'labels': labels.to(device),
                'ID': IDs
            }

            for key in self.feature_list:
                if key == self.label_tag:
                    continue
                if max_num_padding:
                    # Pad each sample once and reuse both the array and its valid length.
                    padded = [pad_to_max_length(x[key], max_num_padding) for x in datasets]
                    seq_input = np.stack([seq for seq, _ in padded])
                    valid_lens = torch.tensor([length for _, length in padded], dtype=torch.long)
                    batch[key] = torch.from_numpy(seq_input).to(torch.float32).to(device)
                    if valid_lens.max().item() > 0:
                        valid_lens_dev = valid_lens.to(device)
                        # Compare on the same device and against the real batch
                        # length (the last batch can be shorter than batch_size).
                        if 'valid_lens' in batch and not torch.equal(batch['valid_lens'], valid_lens_dev):
                            print(f"Warning: please make sure valid lens of 2D tensors is the same \n{batch['valid_lens']}\n{valid_lens}")
                        batch['valid_lens'] = valid_lens_dev
                else:
                    seq_input = np.vstack([x[key] for x in datasets])
                    batch[key] = torch.from_numpy(seq_input).to(torch.float32).to(device)

            yield batch

    def __len__(self):
        return len(self.sequence_data)

    def split_test(self, test_size = 0.1):
        '''
        Randomly split sample positions into train_range and test_range.

        test_size is either a float fraction (the test set gets
        ceil(test_size * n) samples) or an int count. The shuffle is a fixed
        np.random.RandomState(42) permutation. This is intended to mirror
        sklearn's train_test_split(..., random_state=42), which is not imported
        in this notebook. NOTE(review): verify parity if exact splits matter.
        '''
        n_samples = len(self.sequence_data)
        if n_samples <= 1:
            print(f"Number of {len(self)} data can not be splited.")
            return
        if isinstance(test_size, float):
            n_test = int(np.ceil(test_size * n_samples))
        else:
            n_test = int(test_size)
        if not 0 < n_test < n_samples:
            raise ValueError(
                f"test_size={test_size} leaves an empty train or test split for {n_samples} samples")
        permutation = np.random.RandomState(42).permutation(n_samples)
        self.test_range = permutation[:n_test].tolist()
        self.train_range = permutation[n_test:].tolist()
  

        


def pad_to_max_length(seq, max_length, len_dim = 0, feat_dim = 1):
    """Zero-pad or truncate `seq` along axis `len_dim` to exactly `max_length`.

    Args:
        seq: array with a `.shape`; inputs with fewer than 2 dimensions are
            returned unchanged.
        max_length: target length along `len_dim`.
        len_dim: axis holding the sequence length.
        feat_dim: kept for backward compatibility; all non-length axes are
            preserved as-is, so it is no longer needed.

    Returns:
        (array, valid_length): valid_length is the number of real (unpadded)
        positions, or 0 for inputs with fewer than 2 dimensions. Padded outputs
        are new float64 arrays; truncated outputs are slices of `seq`.
    """
    if len(seq.shape) < 2:
        return seq, 0
    seq_length = seq.shape[len_dim]
    if seq_length < max_length:
        padded_shape = list(seq.shape)
        padded_shape[len_dim] = max_length
        padded_seq = np.zeros(padded_shape)
        target = [slice(None)] * len(padded_shape)
        target[len_dim] = slice(0, seq_length)
        padded_seq[tuple(target)] = seq
        return padded_seq, seq_length
    # Truncate along len_dim (the original always sliced axis 0).
    keep = [slice(None)] * len(seq.shape)
    keep[len_dim] = slice(0, max_length)
    return seq[tuple(keep)], max_length

In [59]:

def predict_from_1D_rep(input_file, initial_model, params_path,
                        output_file = None,cutoff = None,
                        Return = True):
    """Score every protein in a feature file with a pretrained classifier.

    Args:
        input_file: path of a dict {ID: feature array} readable by
            load_feature_from_local; all arrays must share their last dimension.
        initial_model: model class/factory, called as initial_model(num_hidden)
            with that shared last dimension.
        params_path: path of the state_dict to load. Its file name also sets:
            * the output column: the text before 'classifier' (after its last
              '-') upper-cased + 'SE' when it is 't3'/'t4'/'t6', else 'Effector';
            * the cutoff when `cutoff` is None: the text before 'cutoff' (after
              its last '-') parsed as a float, falling back to 0.5.
        output_file: optional CSV path. If writing fails, the CSV is written
            next to input_file instead.
        cutoff: probability threshold; scores >= cutoff are labelled 'Yes'.
            An explicit 0 is honoured (only None triggers the file-name lookup).
        Return: when True, return the predictions DataFrame.

    Returns:
        DataFrame with columns 'ID', 'pred' and the effector-type column, or
        None when Return is False.
    """
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    features = load_feature_from_local(input_file, silence=True)
    assert features is not None, f"Unable to load features from {input_file}"
    num_hidden = check_hidden_layer_dimensions(features)
    assert num_hidden, f"Dimension numbers of the last dimension is not same; {input_file}"

    model = initial_model(num_hidden).to(device)
    weight_name = os.path.split(params_path)[-1].lower()
    eff_type = weight_name.split('classifier')[0].split('-')[-1]
    eff_type = f'{eff_type.upper()}SE' if eff_type in ['t3', 't4', 't6'] else 'Effector'
    Dataset = Potein_rep_datasets({'feature':input_file})
    if cutoff is None:
        try:
            cutoff = float(weight_name.split('cutoff')[0].split('-')[-1])
        except ValueError:
            cutoff = 0.5
    print(f'Binary cutoff of {cutoff} used.')

    model.load_state_dict(torch.load(params_path, map_location=torch.device('cpu')))
    model.eval()
    Dataset.test_range = range(len(Dataset))

    ids, pred_batches = [], []
    with torch.no_grad():
        for batch in Dataset.Dataloader(batch_size=32, shuffle=False, max_num_padding=None, test=True, device=device):
            y_pred = model(batch)
            ids.extend(batch['ID'])
            pred_batches.append(y_pred.to('cpu').numpy())
    # np.concatenate rejects an empty list, so guard the no-sample case.
    preds = list(np.concatenate(pred_batches, 0)) if pred_batches else []

    output = pd.DataFrame({
        'ID': ids,
        'pred': preds,
        eff_type: ['Yes' if x >= cutoff else 'No' for x in preds],
    })

    if output_file:
        try:
            output.to_csv(output_file)
            print(f'Predictions saved as {output_file}')
        except OSError:
            print(f'Predictions failed to save as {output_file}')
            tag = os.path.split(input_file)[-1]
            output_file = os.path.join(os.path.dirname(input_file), f"./{tag}_{eff_type}_prediction.csv")
            output.to_csv(output_file)
            print(f'Predictions saved as {output_file}')
    if Return:
        return output

In [60]:
# Map the CLI-style options onto predict_from_1D_rep's keyword arguments.
input_file = args.In
output_file = args.Out
cutoff = args.cutoff
classifier_path = args.weight
classifer = test_dnn  # model class; instantiated inside predict_from_1D_rep
config = dict(
    input_file=input_file,
    output_file=output_file,
    initial_model=classifer,
    params_path=classifier_path,
    cutoff=cutoff,
    Return=True,
)


In [61]:
# predict_from_1D_rep already prints where the CSV was written (including its
# fallback path), so no extra "saved at" message is printed here; that message
# was wrong when output_file was None or the fallback location was used.
output_dict = predict_from_1D_rep(**config)
output_dict.head(10)

try to load feature from path:Test_clef_rep.pkl
File is loaded using pickle.load()
Add mock label [label] of 0 for each sample
total 10 sample loaded
Binary cutoff of 0.7 used.
Predictions saved as Test_result.csv
               ID      pred T6SE
0     NP_250554.1  0.993134  Yes
1     NP_249515.1  0.999430  Yes
2          4F0V_A  0.991892  Yes
3     YP_898952.1  0.998459  Yes
4      CBG37356.1  0.999968  Yes
5  WP_151253718.1  0.000017   No
6      ADZ63249.1  0.000073   No
7          P46922  0.000089   No
8          Q9TYU9  0.000034   No
9          Q92F67  0.001061   No
Predictions saved at Test_result.csv
