In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import root_mean_squared_error
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
import re
import joblib
from preprocess import pfeature_process
import csv
import torch
import yaml
import sys
sys.path.insert(1, '../')
from models.network import create_model

# Procesamiento de secuencias generadas para clasificarlas como potencialmente positivas (etiqueta 1)

In [2]:
# Input artefacts: the serialized classifier and the file of generated sequences
# (the file name suggests they were filtered with CD-HIT).
model_path, cd_hit_path = (
    '../models/dtr_model.pkl',
    '../data/processed/generated_seqs_cd_hit.txt',
)


In [3]:
# NOTE(review): joblib.load unpickles the file (arbitrary code execution risk);
# only load model files from a trusted source.
model = joblib.load(filename=model_path)


In [5]:
# Build the feature table for the generated sequences. Per the output below, rows are
# indexed by the peptide sequence and the columns are PC1..PC18.
test_seqs = pfeature_process(cd_hit_path)
test_seqs


Unnamed: 0,PC1,PC2,PC3,PC4,PC5,PC6,PC7,PC8,PC9,PC10,PC11,PC12,PC13,PC14,PC15,PC16,PC17,PC18
LDLDDWYTVDRDAMSM,1.837949,4.840107,-1.346732,-2.656623,-7.338777,-2.767022,-2.165855,1.325194,2.312694,-0.523458,3.584614,-0.935570,-3.048479,0.431662,1.877319,-2.623468,-0.521562,1.526863
KEAKEGATEWCPIVIN,-0.568820,-0.202947,-5.232013,3.124897,-1.335625,0.048295,2.267384,-0.566908,-1.088215,1.337356,-0.223225,0.119317,1.452453,1.708937,0.708616,1.848898,0.378951,0.462414
IYMYQNPQADYQKTVV,-3.017279,-0.494642,-0.121134,-1.773381,3.953475,-0.026841,2.132446,0.567820,1.115822,0.202210,1.469466,2.944362,2.261423,-0.125827,-0.614973,0.557258,-1.852524,-0.129903
YYIENVMHVAMPMYYK,-9.005871,-3.115693,6.102775,-2.664138,1.976713,-2.158311,-2.036856,-1.237692,0.749763,0.945722,0.729007,-0.265735,-0.654084,1.761730,-0.944222,-0.637844,1.203817,1.173627
DPAMEFDNAEIIDDDD,8.979129,3.274629,-8.581540,-4.372976,-6.572749,-2.914251,-2.991451,-1.163131,1.694540,0.896453,-1.788859,0.307389,-0.014008,0.050601,-0.399815,-0.678680,-2.300438,1.017219
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
YVYMMYYMYMVRMCHD,-12.379661,-5.053110,12.332297,-5.327318,-1.421743,-3.704994,-2.403009,-1.897884,2.784035,-2.361379,0.045044,1.520093,-0.672426,1.189801,-2.280145,1.661606,-0.133405,-0.691101
DNKHYYDYDTKFNYVV,0.790335,-0.722205,5.335459,-7.882551,-2.099492,1.918257,2.882968,1.351581,-0.544123,0.956583,-1.562104,1.896313,0.519547,-1.007728,-0.162434,-0.744662,-0.998843,-0.310421
WEHEQQHDNQDDGKDN,15.646699,-1.677398,-1.730841,-3.529699,0.123181,-0.802544,-1.731995,-0.947520,-2.755877,-2.185708,1.237319,0.234887,1.675939,-0.483076,0.675574,-0.101460,-1.916721,-0.279536
YYCIMNKMTDKHFFAA,-5.927433,-0.592338,7.104879,-0.959287,-3.463211,-0.251892,-0.099849,-1.084649,0.106960,-2.813652,0.429823,0.811960,1.133643,1.122421,0.237132,0.598897,-1.747828,-0.755030


In [6]:
# Classify every generated sequence; per the output below the result is an array of 0/1 labels.
predict_test = model.predict(test_seqs)
predict_test

array([1, 0, 0, ..., 0, 0, 0])

In [7]:
# Wrap the label array in a one-column DataFrame (its column is named 0).
predict_test = pd.DataFrame(data=predict_test)

In [8]:
# Label each prediction with the sequence it belongs to (the feature table's index),
# then turn that index into an ordinary column ('index' when the index is unnamed).
predict_test = predict_test.set_axis(test_seqs.index).reset_index()
predict_test.shape

(1365, 2)

In [9]:
# Give the two columns readable names; a new frame is returned rather than mutating in place.
predict_test = predict_test.rename(columns={'index': 'Secuence', 0: 'Label'})



In [10]:
# Per-sequence classifier labels ('Secuence', 'Label').
predict_test

Unnamed: 0,Secuence,Label
0,LDLDDWYTVDRDAMSM,1
1,KEAKEGATEWCPIVIN,0
2,IYMYQNPQADYQKTVV,0
3,YYIENVMHVAMPMYYK,1
4,DPAMEFDNAEIIDDDD,0
...,...,...
1360,YVYMMYYMYMVRMCHD,0
1361,DNKHYYDYDTKFNYVV,1
1362,WEHEQQHDNQDDGKDN,0
1363,YYCIMNKMTDKHFFAA,0


In [11]:
# Keep only the sequences the classifier labelled 1; the original row positions are kept as the index.
predict_positive = predict_test.loc[predict_test['Label'].eq(1)]
predict_positive

Unnamed: 0,Secuence,Label
0,LDLDDWYTVDRDAMSM,1
3,YYIENVMHVAMPMYYK,1
6,YMMMDEMCQMYPTSQA,1
11,NRQRNSNGVAMSGTAT,1
16,EYDYIPPPHHHRNKNI,1
...,...,...
1349,NDRIIGNQVEIICVVC,1
1352,IHRFHYLPLPQKNSEE,1
1355,EIQSYMECIFPIQPVT,1
1357,EFEFEYFKKDYMYNRI,1


In [12]:
# Only the sequence column is written to CSV, because the following steps need nothing else.
predict_positive.to_csv('../data/processed/predicted_positive.csv',
                        columns=['Secuence'], index=False)

# peptideBERT


In [14]:
def load_bert_model(feature, device):
    """Build a PeptideBERT model for ``feature`` and load its fine-tuned weights.

    Reads ``../models/{feature}/config.yaml`` and ``../models/{feature}/model.pt``.

    Args:
        feature: name of the model sub-directory (e.g. 'hemo', 'sol', 'nf').
        device: ``torch.device`` that is stored in ``config['device']``, used as
            the ``map_location`` for the checkpoint, and that the model is moved to.

    Returns:
        The model built by ``create_model`` with the checkpoint's
        ``model_state_dict`` loaded, on ``device`` and in eval mode.
    """
    # Context manager so the config file handle is closed right away.
    with open(f'../models/{feature}/config.yaml', 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    config['device'] = device
    model = create_model(config)

    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
    checkpoint = torch.load(f'../models/{feature}/model.pt',
                            map_location=device, weights_only=False)
    # NOTE(review): strict=False silently ignores missing/unexpected keys — confirm this is intended.
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)

    model.to(device)
    # Inference only: switch off training-mode behaviour such as dropout.
    model.eval()
    return model
  


In [16]:
def predict_peptidebert(sequences, feats=('hemo', 'sol', 'nf'),
                        output_path='../data/peptideBert_results.csv'):
    """Score peptide sequences with one PeptideBERT checkpoint per property.

    Each sequence is tokenised with the fixed vocabulary below. It is right-padded
    with the [PAD] id (0) to the length of the longest sequence and fed to the model
    one at a time, with an attention mask that hides the padding. For every name in
    ``feats``, the model from ``load_bert_model(name, device)`` is run and
    ``float(output[0])`` is recorded as the score.

    Args:
        sequences: iterable of amino-acid strings. It is copied and never modified.
        feats: model sub-directory names to run, in output-column order. The
            defaults 'hemo', 'sol', 'nf' presumably mean hemolysis, solubility
            and non-fouling — confirm against the model configs.
        output_path: CSV file the results are written to; ``None`` skips writing.

    Returns:
        DataFrame with a 'Sequence' column followed by one float column per entry
        of ``feats``. Empty input gives an empty frame with those columns.

    Raises:
        ValueError: if a sequence contains a character that is not in the vocabulary.
    """
    peptides = list(sequences)  # work on a copy; the caller's list stays intact
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Token id = position in this list; 0 is [PAD], which is also the padding id below.
    vocab = ['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]', 'L',
             'A', 'G', 'V', 'E', 'S', 'I', 'K', 'R', 'D', 'T', 'P', 'N',
             'Q', 'F', 'Y', 'M', 'H', 'C', 'W']
    mapping = {token: idx for idx, token in enumerate(vocab)}

    # Tokenise everything up front so bad input fails before any checkpoint is loaded.
    max_len = max(map(len, peptides), default=0)
    token_ids = []
    for seq in peptides:
        unknown = set(seq) - mapping.keys()
        if unknown:
            raise ValueError(
                f'Unsupported characters {sorted(unknown)} in sequence {seq!r}')
        ids = [mapping[c] for c in seq]
        ids.extend([0] * (max_len - len(ids)))  # pad to max length
        token_ids.append(ids)

    scores = {feat: [] for feat in feats}
    if token_ids:  # nothing to score -> don't load any checkpoint
        with torch.inference_mode():
            for feat in feats:
                model = load_bert_model(feat, device)
                model.eval()  # ensure dropout etc. are off whatever the loader returned
                preds = []
                for ids in token_ids:
                    input_ids = torch.tensor([ids], device=device)
                    attention_mask = (input_ids != 0).float()
                    preds.append(float(model(input_ids, attention_mask)[0]))
                scores[feat] = preds
                del model  # release this checkpoint before loading the next one

    # Build the frame once instead of concatenating a column per feature.
    results = pd.DataFrame({'Sequence': peptides, **scores})
    results = results.astype({feat: float for feat in feats})

    if output_path is not None:
        results.to_csv(output_path, index=False)
    return results
    

In [17]:
# Read back the positive sequences saved above: a header line, then one sequence per row.
pos_seqs = '../data/processed/predicted_positive.csv'
with open(pos_seqs, newline='') as fp:  # newline='' as recommended for the csv module
    reader = csv.reader(fp, delimiter=',', quotechar='"')
    next(reader, None)  # skip the header row; an empty file simply yields no sequences
    # Take the first field directly (instead of regex-stripping str(row)) and drop blank rows,
    # which would otherwise become empty '' sequences.
    seqs = [row[0].strip() for row in reader if row and row[0].strip()]
bert_results = predict_peptidebert(seqs)

In [18]:
# PeptideBERT scores (hemo, sol, nf columns) for each positively classified sequence.
bert_results

Unnamed: 0,Sequence,hemo,sol,nf
0,LDLDDWYTVDRDAMSM,0.058728,0.766240,0.154453
1,YYIENVMHVAMPMYYK,0.036115,0.240625,0.015137
2,YMMMDEMCQMYPTSQA,0.036783,0.835673,0.242257
3,NRQRNSNGVAMSGTAT,0.043448,0.920638,0.362148
4,EYDYIPPPHHHRNKNI,0.028100,0.893227,0.360850
...,...,...,...,...
424,NDRIIGNQVEIICVVC,0.234215,0.707088,0.006889
425,IHRFHYLPLPQKNSEE,0.097121,0.905724,0.189623
426,EIQSYMECIFPIQPVT,0.132906,0.856495,0.105274
427,EFEFEYFKKDYMYNRI,0.040359,0.901550,0.025191
