# Load models

In [1]:
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from datasets import load_dataset
from sklearn import metrics
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import matplotlib.pyplot as plt

In [2]:
class Network(nn.Module):
    """Small classification head: two stacked linear layers trained with cross-entropy.

    There is no activation between ``linear1`` and ``linear2``, so the network as
    written is an affine map of its input. This is kept as-is, and the attribute
    names ``linear1`` / ``linear2`` / ``loss`` are left untouched, because the
    notebook later restores a pickled instance of this class with ``torch.load``.

    Args:
        inp: Size of each input feature vector.
        hidden: Width of the intermediate layer.
        output: Number of output classes (logits).
        device: Device that batches are moved to in ``train_model`` / ``test``.
            The module itself is NOT moved here; the caller is expected to do
            ``.to(device)``.
    """

    def __init__(self, inp, hidden, output, device):
        super().__init__()

        self.device = device

        self.linear1 = nn.Linear(inp, hidden)
        self.linear2 = nn.Linear(hidden, output)

        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return raw class logits for ``x`` (no softmax applied)."""
        return self.linear2(self.linear1(x))

    def train_model(self, dataset, epochs):
        """Train in place with Adam (lr=1e-4) for ``epochs`` passes over ``dataset``.

        Args:
            dataset: Iterable of ``(inputs, targets)`` batches (e.g. a DataLoader).
                Only column 0 of ``targets`` is used as the class index, so
                ``targets`` is assumed to be 2-D - TODO confirm what any other
                columns hold.
            epochs: Number of passes over ``dataset``.
        """
        self.train()
        optimizer = torch.optim.Adam(self.parameters(), lr=0.0001)  # TODO tune

        for epoch in range(epochs):
            with tqdm(dataset, unit="batch") as tepoch:
                # The label is constant for the whole epoch, so set it once
                # rather than on every batch.
                tepoch.set_description(f"Epoch {epoch + 1}")

                for inputs, targets in tepoch:
                    inputs = inputs.to(self.device)
                    # First target column -> integer class ids for CrossEntropyLoss.
                    targets = targets.to(self.device)[:, 0].long()

                    optimizer.zero_grad()
                    yhat = self(inputs)
                    loss = self.loss(yhat, targets)
                    loss.backward()
                    optimizer.step()

                    # Batch accuracy from the same forward pass (progress bar only).
                    correct = (yhat.argmax(1) == targets).type(torch.float).sum().item()
                    accuracy = correct / len(inputs)
                    tepoch.set_postfix(loss=loss.item(), accuracy=100. * accuracy)

    def test(self, dataloader):
        """Evaluate on ``dataloader``; print and return accuracy and micro-F1.

        Args:
            dataloader: Iterable of ``(inputs, targets)`` batches; column 0 of
                ``targets`` is taken as the true class index.

        Returns:
            Tuple ``(acc, f1)`` as computed by ``sklearn.metrics``.

        Raises:
            ValueError: If ``dataloader`` yields no batches.
        """
        self.eval()
        pred_label, actuals = [], []

        with torch.no_grad():
            for inputs, targets in dataloader:
                # .cpu() is required before .numpy() when the loader hands out
                # tensors that already live on an accelerator; it is a no-op
                # for CPU tensors.
                actual = targets[:, 0].long().cpu().numpy()

                yhat = self(inputs.to(self.device)).cpu().numpy().argmax(1)

                # Column vectors so the batches can be stacked row-wise.
                pred_label.append(yhat.reshape((len(yhat), 1)))
                actuals.append(actual.reshape((len(actual), 1)))

        if not pred_label:
            raise ValueError("dataloader yielded no batches to evaluate")

        pred_label, actuals = np.vstack(pred_label), np.vstack(actuals)
        print("Predictions: ", pred_label[:10])
        print("Real labels: ", actuals[:10])

        acc = metrics.accuracy_score(actuals, pred_label)
        f1 = metrics.f1_score(actuals, pred_label, average='micro', zero_division=0)
        print(f"Test metrics: \n Accuracy: {acc}, F1 score: {float(f1):>6f}\n")
        return acc, f1



In [3]:
# Restore the trained classifier from disk. The printed repr below shows it is a
# `Network` instance, so the class definition above must be in scope for unpickling.
# NOTE(review): torch.load on a full pickled model executes pickle code - only
# load checkpoints from a trusted source.
model = torch.load("ProtBertBFD_embedding_CNN_family.pth")
# Switch to inference mode (also echoes the module structure as the cell output).
model.eval()

Network(
  (linear1): Linear(in_features=1024, out_features=64, bias=True)
  (linear2): Linear(in_features=64, out_features=7, bias=True)
  (loss): CrossEntropyLoss()
)

In [4]:
from transformers import AutoTokenizer, AutoModel, pipeline
import re
import skimage.measure

2023-03-31 07:37:11.108339: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [5]:
# Register tqdm with pandas so `progress_apply` (used further down) is available.
tqdm.pandas()

In [6]:
# Tokenizer for the ProtBert-BFD checkpoint; do_lower_case=False keeps the
# upper-case residue letters as they are.
tokenizer = AutoTokenizer.from_pretrained("Rostlab/prot_bert_bfd", do_lower_case=False )

In [7]:
# Encoder from the same checkpoint (the log below shows it loads as a BertModel
# without the pre-training heads); used only to produce embeddings.
transformer = AutoModel.from_pretrained("Rostlab/prot_bert_bfd")

Some weights of the model checkpoint at Rostlab/prot_bert_bfd were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [8]:
# Feature-extraction pipeline combining the encoder and tokenizer loaded above;
# `get_embedding` calls it once per sequence.
fe = pipeline('feature-extraction', model=transformer, tokenizer=tokenizer,device=0 ) # device=0 for GPU, -1 for CPU

In [9]:
def get_embedding(seq):
    """Turn one amino-acid sequence into a single float32 feature vector.

    Steps, exactly as the pipeline is wired in this notebook:
      1. Replace the residues U, Z, O and B with X and separate every residue
         with a space before calling the `fe` feature-extraction pipeline.
      2. Drop the first token vector and keep the following token vectors.
      3. Pool over tokens with `skimage.measure.block_reduce` and return the
         first pooled row as float32.

    Args:
        seq: Protein sequence as a string of one-letter residue codes.

    Returns:
        1-D float32 numpy array (one value per embedding dimension).
    """
    spaced = " ".join(re.sub(r"[UZOB]", "X", seq))
    token_vectors = fe(spaced)[0]

    # NOTE(review): the upper slice bound uses the length of the space-separated
    # string (about twice the residue count), so for sequences of two or more
    # residues the slice simply runs to the end of the token list and keeps the
    # trailing special token - confirm this matches how the training features
    # were built before changing it.
    per_token = np.array(token_vectors[1:len(spaced) + 1])

    # NOTE(review): a (1024, 1) block averages groups of 1024 token rows. With
    # block_reduce's default zero padding, shorter inputs are effectively
    # sum / 1024 rather than a mean over residues, and only the first block is
    # kept for longer inputs. Left untouched so embeddings stay compatible with
    # the saved classifier - confirm.
    pooled = skimage.measure.block_reduce(per_token, (1024, 1), np.average)

    first_block = np.array(pooled[0], dtype=float)
    return np.float32(first_block)

---------------------------

## Load dataset

In [10]:
# Load the candidate proteins. The displayed frame has the columns
# ID, Sequence, xref_interpro and protein_families (plus the saved index column).
df = pd.read_csv("SUA5_family.csv")
#df.rename(columns = {'id2':'ID'}, inplace = True)
df

Unnamed: 0,ID,Sequence,xref_interpro,protein_families
0,Q86U90,MSPARRCRGMRAAVAASVGLSEGPAGSRSGRLFRPPSPAPAAPGAR...,IPR017945;IPR006070;,SUA5 family
1,P32579,MYLGRHFLAMTSKALFDTKILKVNPLSIIFSPDAHIDGSLPTITDP...,IPR017945;IPR006070;IPR038385;IPR005145;IPR010...,SUA5 family
2,Q3U5F4,MSTARPCAGLRAAVAAGMGLSDGPASSGRGCRLLRPPEPAPALPGA...,IPR017945;IPR006070;,SUA5 family
3,P45748,MNNNLQRDAIAAAIDVLNEERVIAYPTEAVFGVGCDPDSETAVMRL...,IPR017945;IPR006070;IPR023535;,"SUA5 family, TsaC subfamily"
4,Q499R4,MSTARPCAGLRAAVAAGMGLSDGPAGSSRGCRLLRPPAPAPALPGA...,IPR017945;IPR006070;,SUA5 family
...,...,...,...,...
28205,A0A8J4YTV4,MTGEHVYRPDKAGLAAAVAAINAGQPIVIPTDTVYGLAVRAGDPAA...,IPR000793;IPR038376;IPR005294;IPR017945;IPR035...,ATPase A chain family; ATPase alpha/beta chain...
28206,A0A8U0Q8M2,MATLDGFPASFYGEPGVLGMLANGISAFLVLLQNFNTARTGVPAEG...,IPR017945;IPR000791;IPR006070;,SUA5 family; Acetate uptake transporter (AceTr...
28207,A0A8V0YD24,MVSTAGFAAVFYSEPAVLGLLANVISAFLLCLQNFATAHTGLKPQG...,IPR017945;IPR000791;IPR006070;,SUA5 family; Acetate uptake transporter (AceTr...
28208,A0A0A2VYR0,MTWDVLPQDFSIRHSGSGAMDTQIVPDAATCPECLRELNCPADRRY...,IPR004421;IPR017945;IPR041440;IPR004688;IPR011...,NiCoT transporter (TC 2.A.52) family; SUA5 fam...


### Preprocess

In [14]:
import spyprot



In [16]:
# Number of accessions sent to UniProt per request.
SIZE = 100
# Fields requested for every accession; each one also becomes a column of `df`.
SEARCH_FIELDS = ['sequence', 'protein_families']

accessions = df['ID'].tolist()

# accession -> field values, indexed in SEARCH_FIELDS order (this is how the
# lookup below reads them).
final_res = {}

# Walk the accession list in consecutive, non-overlapping batches. Striding over
# the full length covers the (possibly shorter) last batch too, so no separate
# tail query is needed and no accession is requested twice.
for start in tqdm(range(0, len(accessions), SIZE)):
    batch = accessions[start:start + SIZE]
    final_res.update(spyprot.UniprotSearch(SEARCH_FIELDS, accessions=batch).get())

# Copy every fetched field into `df`; accessions missing from the response get "".
# `field_idx` is bound as a default argument so the lambda does not depend on the
# loop variable's later value.
for field_idx, field in enumerate(SEARCH_FIELDS):
    df[field] = df['ID'].progress_apply(
        lambda acc, field_idx=field_idx: final_res[acc][field_idx] if acc in final_res else ""
    )

100%|██████████| 2/2 [00:00<00:00,  3.88it/s]
100%|██████████| 268/268 [00:00<00:00, 222765.25it/s]
100%|██████████| 268/268 [00:00<00:00, 289336.80it/s]


## Run SPOUT classification

In [11]:
# PARAMS
# Thresholds read by `get_class` when deciding whether class 0 is a confident hit.
# Minimum raw output (logit) required for class 0.
SIGNAL_STRENGTH = 10
# Minimum softmax probability required for class 0.
THRESHOLD = 0.8
# Every other class must score below this raw output.
OTHER_SIGNAL_MAX = 3 # originally 1

In [14]:
def get_class(row):
    """Classify one protein and report whether class 0 is a confident hit.

    The sequence in ``row['Sequence']`` is embedded with ``get_embedding`` and
    passed through the module-level ``model``. Class 0 counts as a confident
    hit only when ALL of the following hold:

      * its raw output exceeds ``SIGNAL_STRENGTH``;
      * it is the largest output;
      * its softmax probability exceeds ``THRESHOLD``;
      * every other output is below ``OTHER_SIGNAL_MAX``.

    Args:
        row: Mapping / DataFrame row with a ``'Sequence'`` entry.

    Returns:
        The class-0 raw output (float) for a confident hit, otherwise the full
        list of raw outputs. Callers tell the two cases apart by type.
    """
    em = get_embedding(row['Sequence'])

    # Run on whatever device the model's weights live on instead of assuming
    # CUDA, so the function also works for a CPU-loaded model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        prediction = model(torch.from_numpy(em).to(device)).cpu().tolist()

    top = max(prediction)
    target, others = prediction[0], prediction[1:]

    # Softmax of class 0, shifted by the max logit so exp() cannot overflow.
    exp_shifted = np.exp(np.asarray(prediction) - top)
    target_prob = exp_shifted[0] / exp_shifted.sum()

    is_confident_hit = (
        target > SIGNAL_STRENGTH
        and target == top
        and target_prob > THRESHOLD
        # Applies to every non-target class, however many outputs the model has.
        and all(score < OTHER_SIGNAL_MAX for score in others)
    )
    return target if is_confident_hit else prediction

In [None]:
# Run `get_class` on every row (one embedding + classifier call per protein).
# The resulting column mixes floats (confident class-0 hits) and lists (all other rows).
df['SPOUT_classification'] = df.progress_apply(get_class, axis=1)

 50%|████▉     | 13986/28210 [22:46<12:03, 19.65it/s]  

In [16]:
df

Unnamed: 0,ID,Sequence,xref_interpro,protein_families,SPOUT_classification
0,Q86U90,MSPARRCRGMRAAVAASVGLSEGPAGSRSGRLFRPPSPAPAAPGAR...,IPR017945;IPR006070;,SUA5 family,"[5.620612621307373, -5.080528736114502, -4.030..."
1,P32579,MYLGRHFLAMTSKALFDTKILKVNPLSIIFSPDAHIDGSLPTITDP...,IPR017945;IPR006070;IPR038385;IPR005145;IPR010...,SUA5 family,"[11.250757217407227, -4.741343975067139, -2.76..."
2,Q3U5F4,MSTARPCAGLRAAVAAGMGLSDGPASSGRGCRLLRPPEPAPALPGA...,IPR017945;IPR006070;,SUA5 family,"[5.418772220611572, -4.671252250671387, -3.143..."
3,P45748,MNNNLQRDAIAAAIDVLNEERVIAYPTEAVFGVGCDPDSETAVMRL...,IPR017945;IPR006070;IPR023535;,"SUA5 family, TsaC subfamily","[4.522557735443115, -3.3533084392547607, -3.33..."
4,Q499R4,MSTARPCAGLRAAVAAGMGLSDGPAGSSRGCRLLRPPAPAPALPGA...,IPR017945;IPR006070;,SUA5 family,"[5.367674350738525, -4.529707431793213, -3.948..."
...,...,...,...,...,...
28205,A0A8J4YTV4,MTGEHVYRPDKAGLAAAVAAINAGQPIVIPTDTVYGLAVRAGDPAA...,IPR000793;IPR038376;IPR005294;IPR017945;IPR035...,ATPase A chain family; ATPase alpha/beta chain...,"[-1.4523119926452637, -5.212150573730469, -2.9..."
28206,A0A8U0Q8M2,MATLDGFPASFYGEPGVLGMLANGISAFLVLLQNFNTARTGVPAEG...,IPR017945;IPR000791;IPR006070;,SUA5 family; Acetate uptake transporter (AceTr...,"[-1.874422311782837, -15.126782417297363, -6.8..."
28207,A0A8V0YD24,MVSTAGFAAVFYSEPAVLGLLANVISAFLLCLQNFATAHTGLKPQG...,IPR017945;IPR000791;IPR006070;,SUA5 family; Acetate uptake transporter (AceTr...,"[-1.7296158075332642, -12.199769020080566, -5...."
28208,A0A0A2VYR0,MTWDVLPQDFSIRHSGSGAMDTQIVPDAATCPECLRELNCPADRRY...,IPR004421;IPR017945;IPR041440;IPR004688;IPR011...,NiCoT transporter (TC 2.A.52) family; SUA5 fam...,11.500462


## See if knotted

In [18]:
from datasets import Dataset, load_dataset
import pandas as pd
from tqdm import tqdm
import torch
import numpy as np
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollator, Trainer, TrainingArguments
from datasets import load_metric, Features, Value
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, recall_score
from math import exp

In [22]:
def tokenize_function(s):
    """Tokenize one dataset example for the knot-prediction model.

    The residues of ``s['Sequence']`` are separated by single spaces and the
    resulting string is handed to the module-level ``tokenizerM1``.

    Args:
        s: Example mapping with a ``'Sequence'`` string.

    Returns:
        Whatever ``tokenizerM1`` returns for the space-separated sequence.
    """
    residues = list(s['Sequence'])
    spaced_sequence = " ".join(residues)
    return tokenizerM1(spaced_sequence)

In [23]:
# Tokenizer and sequence-classification model for knot prediction, both taken
# from the same Hugging Face checkpoint. `tokenize_function` uses `tokenizerM1`.
tokenizerM1 = AutoTokenizer.from_pretrained("roa7n/knots_protbertBFD_alphafold")
modelM1 = AutoModelForSequenceClassification.from_pretrained("roa7n/knots_protbertBFD_alphafold")

In [25]:
# Wrap just the sequence column in a Hugging Face dataset. `Dataset` here is
# `datasets.Dataset` (re-imported in this section), not the torch Dataset
# imported at the top of the notebook.
dss = Dataset.from_pandas(df[['Sequence']])

# Tokenize every example with `tokenize_function`, using 4 worker processes.
tokenized_dataset = dss.map(tokenize_function, num_proc=4)
# Hand rows back as PyTorch tensors.
tokenized_dataset.set_format("pt")

     

#0:   0%|          | 0/7053 [00:00<?, ?ex/s]

 

#1:   0%|          | 0/7053 [00:00<?, ?ex/s]

 

#2:   0%|          | 0/7052 [00:00<?, ?ex/s]

 

#3:   0%|          | 0/7052 [00:00<?, ?ex/s]

In [26]:
# The Trainer is used purely as an inference harness: no training is run, only
# `predict`. Batch size 1 is kept from the original configuration.
training_args = TrainingArguments('outputs', fp16=True, per_device_eval_batch_size=1, report_to='none')

trainer = Trainer(
    modelM1,
    training_args,
    train_dataset=tokenized_dataset,
    eval_dataset=tokenized_dataset,
    tokenizer=tokenizerM1
)

# `predict` returns (predictions, label_ids, metrics); only the logits are needed.
predictions, _, _ = trainer.predict(tokenized_dataset)

# Softmax probability of class 1 for every row.
# The logits are upcast to float64 and shifted by the row maximum first: with
# fp16=True they can come back in half precision (the saved outputs, e.g.
# 0.996094, look like float16 values), and exp() in float16 overflows to inf for
# logits above ~11, which would turn the probability into NaN.
logits = np.asarray(predictions, dtype=np.float64)
exp_logits = np.exp(logits - logits.max(axis=1, keepdims=True))
predictions = exp_logits[:, 1] / exp_logits.sum(axis=1)
df['knotted'] = predictions

Using cuda_amp half precision backend
The following columns in the test set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: Sequence. If Sequence are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 28210
  Batch size = 1
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [27]:
# Persist the annotated table as TSV without the index.
# NOTE: `SPOUT_classification` holds floats and lists; the lists are written as
# their string representation.
df.to_csv('SUA5_SPOUT_evaluation.tsv', sep='\t', index=False)

In [31]:
# Rows where `get_class` returned a single float instead of a list, i.e. the
# proteins that passed every class-0 threshold.
df[~df['SPOUT_classification'].apply(lambda x: isinstance(x, list))]

Unnamed: 0,ID,Sequence,xref_interpro,protein_families,SPOUT_classification,knotted
5,P39153,MKTKRWFVDVTDELSTNDPQIAQAAALLRENEVVAFPTETVYGLGA...,IPR017945;IPR006070;IPR038385;IPR005145;IPR010...,SUA5 family,10.207406,0.996094
10,Q9UYB2,MTIIINVRERIEEWKIRIAAGFIREGKLVAFPTETVYGLGANALDE...,IPR017945;IPR006070;IPR038385;IPR005145;IPR010...,SUA5 family,11.630935,0.995117
187,A5HY38,MKTKVMRLDENNIDEHVISEAGDILRQGGLVVFPTETVYGLGANAL...,IPR017945;IPR006070;IPR038385;IPR005145;IPR010...,SUA5 family,10.680384,0.978516
194,Q5JIA4,MTIVINMRDGLDEKKIKVAARLILEGKLVAFPTETVYGLGADALNE...,IPR017945;IPR006070;IPR038385;IPR005145;IPR010...,SUA5 family,11.213667,0.995117
198,A0A075AQP2,MDTKILEFPPVKRKTGGKHSYLYEEIDFFENLNIKRAVKSLKEGSV...,IPR017945;IPR006070;IPR038385;IPR005145;IPR010...,SUA5 family,10.967866,0.970215
...,...,...,...,...,...,...
27807,A0A6A6D6I1,METRTLPVDASKLGKVIATPLPDDILDELDIELSLDSQDARHLKDA...,IPR017945;IPR006070;IPR038385;IPR005145;,SUA5 family,10.754261,0.997559
27971,A0A8H2ZZW0,MTVNHHPLRVLKCDPSSVNFDSPSPSTSRHVINDTETRTAIKAAAE...,IPR017945;IPR006070;IPR038385;IPR005145;,SUA5 family,10.371222,0.994141
28024,A7AVF9,MDVLSGIQRVTPVRVTLTDTSIETLDLLKSHLSIPGNLIALPTETV...,IPR017945;IPR006070;IPR038385;IPR005145;,SUA5 family,10.173036,0.991699
28046,K8YNJ8,MTVSTSASTSSSPPTIEAELTSDVDRAGDYLRAGGLVAFPTETVYG...,IPR017945;IPR006070;IPR038385;IPR005145;,SUA5 family,11.472482,0.994141


In [33]:
# Rows whose predicted class-1 ("knotted") probability exceeds 0.9.
df[df['knotted'] > 0.9]

Unnamed: 0,ID,Sequence,xref_interpro,protein_families,SPOUT_classification,knotted
0,Q86U90,MSPARRCRGMRAAVAASVGLSEGPAGSRSGRLFRPPSPAPAAPGAR...,IPR017945;IPR006070;,SUA5 family,"[5.620612621307373, -5.080528736114502, -4.030...",0.990234
1,P32579,MYLGRHFLAMTSKALFDTKILKVNPLSIIFSPDAHIDGSLPTITDP...,IPR017945;IPR006070;IPR038385;IPR005145;IPR010...,SUA5 family,"[11.250757217407227, -4.741343975067139, -2.76...",0.996094
2,Q3U5F4,MSTARPCAGLRAAVAAGMGLSDGPASSGRGCRLLRPPEPAPALPGA...,IPR017945;IPR006070;,SUA5 family,"[5.418772220611572, -4.671252250671387, -3.143...",0.991211
4,Q499R4,MSTARPCAGLRAAVAAGMGLSDGPAGSSRGCRLLRPPAPAPALPGA...,IPR017945;IPR006070;,SUA5 family,"[5.367674350738525, -4.529707431793213, -3.948...",0.989746
5,P39153,MKTKRWFVDVTDELSTNDPQIAQAAALLRENEVVAFPTETVYGLGA...,IPR017945;IPR006070;IPR038385;IPR005145;IPR010...,SUA5 family,10.207406,0.996094
...,...,...,...,...,...,...
28128,A0A3P9BZC4,MATLEGFPASFYGEPGVMGMLANGISAFLVLLQNFNTAHTGKEAYG...,IPR017945;IPR000791;IPR006070;,SUA5 family; Acetate uptake transporter (AceTr...,"[0.04365567862987518, -11.512506484985352, -7....",0.928223
28159,A0A484CYB5,MATSDGFPASFYGEPGVLGMLSNGISAFLVLLQNFNTAHTGNAPVG...,IPR017945;IPR000791;IPR006070;,SUA5 family; Acetate uptake transporter (AceTr...,"[-1.5244265794754028, -12.74111270904541, -4.4...",0.931152
28182,A0A3P8NV22,MATLEGFPASFYGEPGVLGMLANGISAFLVLLQNFNTAHTGKEAYG...,IPR017945;IPR000791;IPR006070;,SUA5 family; Acetate uptake transporter (AceTr...,"[-0.1679820567369461, -11.23070240020752, -6.4...",0.922363
28208,A0A0A2VYR0,MTWDVLPQDFSIRHSGSGAMDTQIVPDAATCPECLRELNCPADRRY...,IPR004421;IPR017945;IPR041440;IPR004688;IPR011...,NiCoT transporter (TC 2.A.52) family; SUA5 fam...,11.500462,0.997559
