In [1]:
import tensorflow as tf
import numpy as np
import pandas as pd
from pathlib import Path
import json
import pickle
import py3Dmol
import subprocess
import time
import more_itertools as mit
import itertools
import tempfile
import torch
import biotite.structure as struc
import biotite.structure.io as strucio
import biotite.application.dssp as dssp
from esm import Alphabet, FastaBatchedDataset, ProteinBertModel, pretrained
from sklearn.model_selection import cross_val_score
from imblearn.under_sampling import NearMiss, CondensedNearestNeighbour, RandomUnderSampler, InstanceHardnessThreshold, AllKNN
from scipy.stats import kurtosis
from scipy.stats import skew
from scipy import stats
import math
import matplotlib.pyplot as plt
from Bio import SeqIO
from scipy.stats import gmean
from statsmodels.tsa.stattools import breakvar_heteroskedasticity_test
from sklearn.model_selection import train_test_split, KFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, classification_report, roc_auc_score, precision_score, auc, make_scorer, recall_score, matthews_corrcoef, f1_score
import optuna
import catboost
from optuna.samplers import TPESampler

In [2]:
def encode_data(data, input_size):
    """
    Encode a protein sequence as a normalized, mirror-padded model input.

    Each residue is mapped to an integer code (1-20). The encoded sequence is
    centred in a window of ``input_size`` positions and the space on either
    side is filled with the reversed sequence, repeated as often as needed.
    The result is standardised with the training-set mean / std. deviation.

    Parameters
    ----------
    data : str-like
        Protein sequence using the 20 standard one-letter residue codes.
    input_size : int
        Length of the model input window; must be >= len(data).

    Returns
    -------
    fasta : np.ndarray
        Array of shape (1, input_size, 1).
    split : int
        Index of the first real residue inside the padded window.
    stop_pos : int
        Index one past the last real residue inside the padded window.

    Raises
    ------
    KeyError
        If the sequence contains a residue letter not in the dictionary.
    ValueError
        If the sequence is empty or longer than ``input_size``.
    """
    residue_dictionary = {"A": 1, "E": 2, "L": 3, "M": 4, "C": 5, "D": 6, "F": 7, "G": 8,
                          "H": 9, "K":10, "N": 11, "P": 12, "Q": 13, "R": 14, "S": 15,
                          "W": 16, "Y": 17, "T": 18, "V": 19, "I": 20}

    # Encode data
    fasta = [residue_dictionary[value] for value in str(data)]
    if not fasta:
        raise ValueError("cannot encode an empty sequence")
    if len(fasta) > input_size:
        raise ValueError(
            f"sequence length {len(fasta)} exceeds input_size {input_size}")

    split = (input_size - len(fasta)) // 2
    last_padding_len = input_size - len(fasta) - split
    stop_pos = split + len(fasta)

    # Mirror padding: the reversed sequence repeated enough times to cover
    # either side of the window, however short the sequence is.
    padding = fasta[::-1] * (input_size // len(fasta) + 1)
    # Slice from an explicit start index: padding[-split:] would return the
    # WHOLE list when split == 0 (since -0 == 0).
    padding_1 = padding[len(padding) - split:]
    padding_2 = padding[:last_padding_len]
    fasta = padding_1 + fasta + padding_2

    # Reshape data for input
    fasta = np.array(fasta).reshape(-1, input_size, 1)
    # Normalize data by subtracting training mean and dividing by training std. deviation
    fasta = (fasta - 10.108613363425793) / 6.034641898334733
    return fasta, split, stop_pos

def encode_presort_data(data, input_size):
    """
    Encode a protein sequence as a normalized, mirror-padded presort-model input.

    Identical encoding/padding to the E3P models' encoder, but standardised
    with the presort model's training statistics (mean 10.15, std. 5.98).

    Parameters
    ----------
    data : str-like
        Protein sequence using the 20 standard one-letter residue codes.
    input_size : int
        Length of the model input window; must be >= len(data).

    Returns
    -------
    fasta : np.ndarray
        Array of shape (1, input_size, 1).
    split : int
        Index of the first real residue inside the padded window.
    stop_pos : int
        Index one past the last real residue inside the padded window.

    Raises
    ------
    KeyError
        If the sequence contains a residue letter not in the dictionary.
    ValueError
        If the sequence is empty or longer than ``input_size``.
    """
    residue_dictionary = {"A": 1, "E": 2, "L": 3, "M": 4, "C": 5, "D": 6, "F": 7, "G": 8,
                          "H": 9, "K":10, "N": 11, "P": 12, "Q": 13, "R": 14, "S": 15,
                          "W": 16, "Y": 17, "T": 18, "V": 19, "I": 20}

    # Encode data
    fasta = [residue_dictionary[value] for value in str(data)]
    if not fasta:
        raise ValueError("cannot encode an empty sequence")
    if len(fasta) > input_size:
        raise ValueError(
            f"sequence length {len(fasta)} exceeds input_size {input_size}")

    split = (input_size - len(fasta)) // 2
    last_padding_len = input_size - len(fasta) - split
    stop_pos = split + len(fasta)

    # Mirror padding: the reversed sequence repeated enough times to cover
    # either side of the window, however short the sequence is.
    padding = fasta[::-1] * (input_size // len(fasta) + 1)
    # Slice from an explicit start index: padding[-split:] would return the
    # WHOLE list when split == 0 (since -0 == 0).
    padding_1 = padding[len(padding) - split:]
    padding_2 = padding[:last_padding_len]
    fasta = padding_1 + fasta + padding_2

    # Reshape data for input
    fasta = np.array(fasta).reshape(-1, input_size, 1)
    # Normalize data by subtracting training mean and dividing by training std. deviation
    fasta = (fasta - 10.15) / 5.98
    return fasta, split, stop_pos


def predict_data(fasta, model, input_size):
    """
    Run a per-position model on one sequence and return the values for its residues.

    The sequence is encoded and mirror-padded with ``encode_data``. The output of
    ``model.predict`` must contain exactly ``input_size`` values, one per padded
    position. Only the positions covered by the real residues are kept.

    Parameters
    ----------
    fasta : str-like
        Protein sequence.
    model : object with ``predict``
        Model taking an array of shape (1, input_size, 1).
    input_size : int
        Model input window length.

    Returns
    -------
    list of float
        One value per residue of ``fasta``.
    """
    data, start_pos, stop_pos = encode_data(fasta, input_size)
    # Flatten to 1-D so every element is a true scalar. Calling float() on the
    # 1-element rows of an (input_size, 1) array is deprecated since NumPy 1.25.
    prediction = model.predict(data).reshape(input_size)
    return [float(value) for value in prediction[start_pos:stop_pos]]


def presort_data(fasta, model, input_size):
    """
    Return the presort model's class-1 output for every real residue of ``fasta``.

    The sequence is encoded with ``encode_presort_data``. The first (only) sample
    of the batch is taken from ``model.predict``, and its column 1 is sliced down
    to the positions occupied by the real residues.
    """
    encoded, first, last = encode_presort_data(fasta, input_size)
    sample = model.predict(encoded)[0]
    class_one = list(sample[:, 1])
    return [float(value) for value in class_one[first:last]]

def encode_sequence(fasta):
    """
    Map each residue letter to its integer code (1-20), with no padding or scaling.

    Raises KeyError for any letter missing from the residue dictionary.
    """
    codes = {"A": 1, "E": 2, "L": 3, "M": 4, "C": 5, "D": 6, "F": 7, "G": 8,
             "H": 9, "K": 10, "N": 11, "P": 12, "Q": 13, "R": 14, "S": 15,
             "W": 16, "Y": 17, "T": 18, "V": 19, "I": 20}
    return [int(codes[residue]) for residue in str(fasta)]

def process_protein(sequence, mae_pred, plddt_pred, presort_pred, ordinal_list, model):
    """
    Classify every residue of ``sequence`` with a sliding-window model.

    For every full window of 23 positions (11 on either side of a centre
    residue), the four per-residue feature lists are sliced, concatenated in
    the order mae, plddt, presort, ordinal, and passed to ``model.predict``.

    The first and last 11 residues have no full window of their own. Each end
    is filled with 0 when the nearest window prediction is 0 AND the mean of the
    12 presort values at that end is below 0.7; otherwise it is filled with 1.

    Returns a list with one prediction per residue. Raises IndexError when the
    sequence is shorter than one window.
    """
    win_size = 11
    span = 2 * win_size + 1

    window_calls = []
    for first in range(len(sequence) - span + 1):
        last = first + span
        window_calls.append(
            model.predict(mae_pred[first:last] + plddt_pred[first:last]
                          + presort_pred[first:last] + ordinal_list[first:last])
        )

    head_fill = 0 if (window_calls[0] == 0 and np.mean(np.array(presort_pred[:12])) < 0.7) else 1
    tail_fill = 0 if (window_calls[-1] == 0 and np.mean(np.array(presort_pred[-12:])) < 0.7) else 1

    return [head_fill] * win_size + window_calls + [tail_fill] * win_size

def fold_sequence(sequence, timeout=None):
    """
    Fold a protein sequence with the ESM Atlas ``foldSequence`` API via curl.

    ``--fail`` makes curl exit non-zero on HTTP 4xx/5xx responses. With
    ``check=True``, an API error therefore raises instead of the error body
    being returned as though it were a structure. ``--silent --show-error``
    suppresses the progress meter but keeps curl's error message on stderr.

    Parameters
    ----------
    sequence : str-like
        Protein sequence (one-letter codes).
    timeout : float or None, optional
        Seconds to wait for curl before ``subprocess.TimeoutExpired`` is raised.
        ``None`` (the default) waits indefinitely.

    Returns
    -------
    str
        The response body (PDB-formatted text on success).

    Raises
    ------
    subprocess.CalledProcessError
        If curl fails, including on HTTP error status codes.
    subprocess.TimeoutExpired
        If ``timeout`` is given and exceeded.
    """
    command = [
        "curl", "--silent", "--show-error", "--fail",
        "-X", "POST", "--data", str(sequence),
        "https://api.esmatlas.com/foldSequence/v1/pdb/",
    ]
    # Arguments are passed as a list (no shell), so the sequence is never shell-interpreted.
    result = subprocess.run(command, capture_output=True, text=True, check=True, timeout=timeout)
    return result.stdout

def display_model_with_predictions(model, predictions):
    """
    Render a PDB structure in py3Dmol as a cartoon coloured by per-residue prediction.

    Position i (1-based) of ``predictions`` colours residue number i: red when
    the prediction is 0, black otherwise. This assumes the PDB residue numbering
    starts at 1 and is contiguous (TODO confirm). Hovering an atom shows a
    "RESN:RESI" label, which is removed again on un-hover.
    """
    residue_colors = {
        position: ("red" if flag == 0 else "black")
        for position, flag in enumerate(predictions, start=1)
    }

    viewer3d = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
    viewer3d.addModel(model, 'pdb')
    viewer3d.setStyle({'cartoon': {'colorscheme': {'prop': 'resi', 'map': residue_colors}}})
    viewer3d.setHoverable({}, True, '''function(atom,viewer,event,container) {
                            if(!atom.label) {
                            atom.label = viewer.addLabel(atom.resn+":"+atom.resi,{position: atom, 
                            backgroundColor: 'mintcream', fontColor:'black'});
                            }}''',
                        '''function(atom,viewer) { 
                            if(atom.label) {
                            viewer.removeLabel(atom.label);
                            delete atom.label;
                            }
                            }''')
    viewer3d.zoomTo()
    viewer3d.show()
    
    
def display_model(model):
    """
    Render a PDB structure in py3Dmol as a cartoon using the 'spectrum' colouring.

    Hovering an atom shows a "RESN:RESI" label, which is removed again on
    un-hover. The viewer is shown immediately and nothing is returned.
    """
    cartoon_style = {'cartoon': {'color': 'spectrum', 'colorscheme': 'roygb'}}

    scene = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
    scene.addModel(model, 'pdb')
    scene.setStyle(cartoon_style)
    scene.setHoverable({}, True, '''function(atom,viewer,event,container) {
                            if(!atom.label) {
                            atom.label = viewer.addLabel(atom.resn+":"+atom.resi,{position: atom, 
                            backgroundColor: 'mintcream', fontColor:'black'});
                            }}''',
                        '''function(atom,viewer) { 
                            if(atom.label) {
                            viewer.removeLabel(atom.label);
                            delete atom.label;
                            }
                            }''')
    scene.zoomTo()
    scene.show()
    
def get_mean_b_factor(model):
    """
    Return the mean B-factor over all ATOM/HETATM records of a PDB-format string.

    The B-factor (tempFactor) is read from the fixed PDB columns 61-66, so the
    result does not depend on how many header lines precede the atoms or on
    trailing TER/END records. For structures returned by ``fold_sequence`` this
    column presumably holds per-atom pLDDT (confirm against the ESM Atlas docs).

    Raises
    ------
    ValueError
        If the text contains no ATOM/HETATM records.
    """
    b_factors = [
        float(line[60:66])
        for line in model.splitlines()
        if line.startswith(("ATOM", "HETATM"))
    ]
    if not b_factors:
        raise ValueError("no ATOM/HETATM records found in PDB text")
    return float(np.mean(np.array(b_factors)))

def get_min_b_factor(model):
    """
    Return the minimum B-factor over all ATOM/HETATM records of a PDB-format string.

    The B-factor (tempFactor) is read from the fixed PDB columns 61-66, so the
    result does not depend on how many header lines precede the atoms or on
    trailing TER/END records. For structures returned by ``fold_sequence`` this
    column presumably holds per-atom pLDDT (confirm against the ESM Atlas docs).

    Raises
    ------
    ValueError
        If the text contains no ATOM/HETATM records.
    """
    b_factors = [
        float(line[60:66])
        for line in model.splitlines()
        if line.startswith(("ATOM", "HETATM"))
    ]
    if not b_factors:
        raise ValueError("no ATOM/HETATM records found in PDB text")
    return float(np.min(np.array(b_factors)))
    
def remove_predictions(sequence, predictions):
    """
    Delete every residue whose prediction is 0 and return the remaining sequence.

    ``predictions[i]`` refers to byte ``i`` of the UTF-8-encoded sequence, which
    is one residue per byte for ASCII one-letter codes. Positions beyond the end
    of ``predictions`` are kept.
    """
    buffer = bytearray(sequence, encoding='utf8')
    doomed = [position for position, flag in enumerate(predictions) if flag == 0]
    # Delete from the back so earlier indices stay valid.
    for position in reversed(doomed):
        del buffer[position]
    return str(buffer.decode())

def disorder_probability(sequence: str) -> float:
    """
    Mean per-window CatBoost score for a protein sequence.

    Four per-residue feature tracks are built:
    - E3P MAE predictions,
    - E3P pLDDT predictions divided by 100,
    - presort class-1 outputs,
    - the ordinal residue encoding.

    A 23-residue window (11 on each side of a centre residue) is slid along the
    sequence. For each window, the concatenated features are scored with
    ``cat_model.predict_proba(...)[0]``; index 0 is presumably the disorder
    class, so confirm against the training labels. The first and last 11
    residues are not scored.

    Relies on the notebook globals ``e3p_mae_model``, ``e3p_plddt_model``,
    ``presort_model`` and ``cat_model``.

    Returns
    -------
    float
        Mean score over all windows, or NaN when the sequence is shorter
        than one window (23 residues).
    """
    # window size of predictions (residues on each side of the centre)
    win_size = 11
    window = 2 * win_size + 1

    # generate encodings for sequence
    mae_pred = predict_data(sequence, e3p_mae_model, 4096)
    plddt_pred = list(np.array(predict_data(sequence, e3p_plddt_model, 4096)) / 100.0)
    presort_pred = presort_data(sequence, presort_model, 2048)
    ordinal_list = encode_sequence(sequence)

    predictions = []
    for start in range(len(sequence) - window + 1):
        stop = start + window
        features = (mae_pred[start:stop] + plddt_pred[start:stop]
                    + presort_pred[start:stop] + ordinal_list[start:stop])
        predictions.append(cat_model.predict_proba(features)[0])

    if not predictions:
        # Too short for a single window: return NaN explicitly rather than
        # triggering numpy's "Mean of empty slice" RuntimeWarning.
        return float("nan")
    return float(np.mean(np.array(predictions)))

def disorder_list(sequence: str) -> list:
    """
    Per-residue CatBoost disorder scores for a protein sequence.

    Uses the same features and 23-residue sliding window as
    ``disorder_probability``. Each centre residue gets
    ``cat_model.predict_proba(...)[0]``; index 0 is presumably the disorder
    class, so confirm against the training labels. The first and last 11
    residues, which have no full window, are filled with 0.

    Relies on the notebook globals ``e3p_mae_model``, ``e3p_plddt_model``,
    ``presort_model`` and ``cat_model``.

    Returns
    -------
    list
        One value per residue of ``sequence``. All values are 0 when the
        sequence is shorter than one window.
    """
    # window size of predictions (residues on each side of the centre)
    win_size = 11
    window = 2 * win_size + 1

    # generate encodings for sequence
    mae_pred = predict_data(sequence, e3p_mae_model, 4096)
    plddt_pred = list(np.array(predict_data(sequence, e3p_plddt_model, 4096)) / 100.0)
    presort_pred = presort_data(sequence, presort_model, 2048)
    ordinal_list = encode_sequence(sequence)

    scored = []
    for start in range(len(sequence) - window + 1):
        stop = start + window
        features = (mae_pred[start:stop] + plddt_pred[start:stop]
                    + presort_pred[start:stop] + ordinal_list[start:stop])
        # Plain Python float: under NumPy >= 2, str() of a list of np.float64
        # renders each element as "np.float64(...)", e.g. in a CSV cell.
        scored.append(float(cat_model.predict_proba(features)[0]))

    if not scored:
        # Too short for a single window: still return one value per residue.
        return [0] * len(sequence)
    return [0] * win_size + scored + [0] * win_size

def plot_disorder(sequence, name):
    """
    Plot the disorder predictions for each residue as a line chart.

    Parameters
    ----------
    sequence : str
        Protein sequence passed to ``disorder_list``.
    name : str
        Label used in the plot title.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the plot; the y-axis is fixed to [0, 1].
    """
    arr = np.array(disorder_list(sequence))
    fig, ax = plt.subplots()
    ax.set_ylim(0, 1)
    ax.set_title(f"Probability of Disorder per Residue for {name}", loc="center")
    ax.set_xlabel("Residue")
    ax.set_ylabel("Disorder Probability")
    ax.plot(arr)
    return fig

def normalize_list(input_list):
    """
    Min-max scale a sequence of numbers to the range [0, 1].

    Parameters
    ----------
    input_list : array-like of numbers

    Returns
    -------
    list of float
        Scaled values with the same shape as the input. Empty input returns
        ``[]``. When all values are equal (zero range), every output is 0.0.
    """
    values = np.asarray(input_list, dtype=float)
    if values.size == 0:
        return []
    low = values.min()
    value_range = values.max() - low
    if value_range == 0:
        # Constant input: avoid 0/0 and map everything to the bottom of the range.
        # (The previous epsilon-on-minimum hack pushed the minimum below 0 and
        # flipped the scale whenever the range was smaller than the epsilon.)
        return np.zeros_like(values).tolist()
    return ((values - low) / value_range).tolist()

def get_evo_probabilities(sequence, mutant="R"):
    """
    Per-position ESM-1v likelihood ratio of substituting ``mutant`` for each residue.

    The model runs once on the unmasked sequence. For every position i the
    result is exp(log p(mutant) - log p(wild type)), i.e. p(mutant) / p(wt),
    read from the log-softmax of the model logits at that position.

    Parameters
    ----------
    sequence : str
        Protein sequence (one-letter codes).
    mutant : str, optional
        Residue letter substituted at every position (default "R").

    Returns
    -------
    list of float
        One ratio per residue.

    Notes
    -----
    The ESM-1v weights are loaded from "esm1v_t33_650M_UR90S_1.pt" on every call.
    """
    evo_model = "esm1v_t33_650M_UR90S_1.pt"
    model, alphabet = pretrained.load_model_and_alphabet(evo_model)
    # Disable dropout so repeated calls give identical scores.
    model.eval()

    batch_converter = alphabet.get_batch_converter()
    _, _, batch_tokens = batch_converter([("sequence", sequence)])

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        token_probs = torch.log_softmax(model(batch_tokens)["logits"], dim=-1)

    mt_encoded = alphabet.get_idx(mutant)
    evo_probabilities = []
    for count, residue in enumerate(sequence):
        wt_encoded = alphabet.get_idx(residue)
        # The +1 offset presumably skips the BOS token the batch converter prepends.
        score = token_probs[0, 1 + count, mt_encoded] - token_probs[0, 1 + count, wt_encoded]
        evo_probabilities.append(math.exp(score.item()))

    return evo_probabilities

def get_sasa(sequence):
    """
    Fold ``sequence`` with ``fold_sequence`` and return its per-residue SASA.

    Atom-level solvent-accessible surface area is computed with biotite
    (``vdw_radii="Single"``) and summed per residue.

    Returns
    -------
    list of float
        One SASA value per residue, in structure order.
    """
    model = fold_sequence(sequence)
    # Use a private temporary directory instead of overwriting ./test.pdb in the
    # working directory. A directory (rather than NamedTemporaryFile) lets the
    # file be reopened by biotite on Windows.
    with tempfile.TemporaryDirectory() as tmp_dir:
        pdb_path = Path(tmp_dir) / "test.pdb"
        pdb_path.write_text(model)
        array = strucio.load_structure(str(pdb_path))
    atom_sasa = struc.sasa(array, vdw_radii="Single")
    res_sasa = struc.apply_residue_wise(array, atom_sasa, np.sum)
    return res_sasa.tolist()

In [3]:
# Trained model locations. NOTE: machine-specific absolute path; change
# MODEL_DIR to run this notebook elsewhere.
MODEL_DIR = Path(r"C:\Users\GRICHARDSON\OneDrive - Evotec\Desktop\crystallization_deletion_tool")
e3p_mae_path = str(MODEL_DIR / "e3p_mae")
e3p_plddt_path = str(MODEL_DIR / "e3p_plddt")
presort_path = str(MODEL_DIR / "presort")
catboost_path = str(MODEL_DIR / "catboost_model_win11_all4_allknn_undersampled.pkl")
input_size = 4096
presort_input = 2048
e3p_mae_model = tf.keras.models.load_model(e3p_mae_path, custom_objects=None, compile=True, options=None)
print("mae loaded")
e3p_plddt_model = tf.keras.models.load_model(e3p_plddt_path, custom_objects=None, compile=True, options=None)
print("plddt loaded")
presort_model = tf.keras.models.load_model(presort_path, custom_objects=None, compile=True, options=None)
print("presort loaded")
# pickle can execute arbitrary code on load; only load model files you trust.
# The context manager closes the file handle once the model is read.
with open(catboost_path, 'rb') as model_file:
    cat_model = pickle.load(model_file)
print("catboost loaded")

mae loaded
plddt loaded
presort loaded
catboost loaded


In [4]:
import sys
from transformers import AutoModel, AutoTokenizer, AutoModelForTokenClassification
import torch
import numpy as np
import torch.nn.functional as F
import pickle
from itertools import groupby
import pandas as pd

In [5]:
# Local DR-BERT checkpoint: tokenizer plus token-classification head.
checkpoint = r"C:\Users\GRICHARDSON\OneDrive - Evotec\Desktop\crystallization_deletion_tool\DR-BERT-final"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# eval() disables dropout and returns the module itself, so it can be chained.
bert_model = AutoModelForTokenClassification.from_pretrained(checkpoint).eval()

  return self.fget.__get__(instance, owner)()


In [6]:
# Example input 1: human aromatic-L-amino-acid decarboxylase (UniProt P20711).
# >sp|P20711|DDC_HUMAN Aromatic-L-amino-acid decarboxylase OS=Homo sapiens OX=9606 GN=DDC PE=1 SV=2
aadc = "MNASEFRRRGKEMVDYMANYMEGIEGRQVYPDVEPGYLRPLIPAAAPQEPDTFEDIINDVEKIIMPGVTHWHSPYFFAYFPTASSYPAMLADMLCGAIGCIGFSWAASPACTELETVMMDWLGKMLELPKAFLNEKAGEGGGVIQGSASEATLVALLAARTKVIHRLQAASPELTQAAIMEKLVAYSSDQAHSSVERAGLIGGVKLKAIPSDGNFAMRASALQEALERDKAAGLIPFFMVATLGTTTCCSFDNLLEVGPICNKEDIWLHVDAAYAGSAFICPEFRHLLNGVEFADSFNFNPHKWLLVNFDCSAMWVKKRTDLTGAFRLDPTYLKHSHQDSGLITDYRHWQIPLGRRFRSLKMWFVFRMYGVKGLQAYIRKHVQLSHEFESLVRQDPRFEICVEVILGLVCFRLKGSNKVNEALLQRINSAKKIHLVPCHLRDKFVLRFAICSRTVESAHVQRAWEHIKELAADVLRAERE"

In [7]:
# Example input 2: human elongation factor 1-alpha 1 (UniProt P68104).
# >sp|P68104|EF1A1_HUMAN Elongation factor 1-alpha 1 OS=Homo sapiens OX=9606 GN=EEF1A1 PE=1 SV=1
ef1a = "MGKEKTHINIVVIGHVDSGKSTTTGHLIYKCGGIDKRTIEKFEKEAAEMGKGSFKYAWVLDKLKAERERGITIDISLWKFETSKYYVTIIDAPGHRDFIKNMITGTSQADCAVLIVAAGVGEFEAGISKNGQTREHALLAYTLGVKQLIVGVNKMDSTEPPYSQKRYEEIVKEVSTYIKKIGYNPDTVAFVPISGWNGDNMLEPSANMPWFKGWKVTRKDGNASGTTLLEALDCILPPTRPTDKPLRLPLQDVYKIGGIGTVPVGRVETGVLKPGMVVTFAPVNVTTEVKSVEMHHEALSEALPGDNVGFNVKNVSVKDVRRGNVAGDSKNDPPMEAAGFTAQVIILNHPGQISAGYAPVLDCHTAHIACKFAELKEKIDRRSGKKLEDGPKFLKSGDAAIVDMVPGKPMCVESFSDYPPLGRFAVRDMRQTVAVGVIKAVDKKAAGAGKVTKSAQKAQKAK"

In [8]:
# Sequences to score, keyed by the identifier used in the output table.
named_sequences = {"aadc": aadc, "ef1a": ef1a}
sequences = list(named_sequences.values())
sequence_ids = list(named_sequences)

In [9]:
# CatBoost per-residue disorder scores, one list per input sequence.
pirate_list = [disorder_list(seq) for seq in sequences]

  prediction = [float(i) for i in prediction]


In [10]:
# DR-BERT per-token scores: softmax over the labels, column 1 kept.
bert_list = []
for sequence in sequences:
    # NOTE(review): the input is encoded as the pair ("something", sequence) and
    # the scores are trimmed with [2:-2]. Confirm this leaves exactly one score
    # per residue for the DR-BERT tokenizer.
    encoded = tokenizer.encode_plus(("something", str(sequence)), return_tensors="pt")
    with torch.no_grad():
        output = bert_model(**encoded)
    # Explicit dim=-1 (softmax over labels). The implicit-dim form is deprecated
    # and emits a UserWarning; for the 2-D (tokens, labels) tensor it resolved to dim=1.
    output = F.softmax(torch.squeeze(output['logits']), dim=-1)[2:-2, 1].detach().numpy().tolist()
    bert_list.append(output)

  output = F.softmax(torch.squeeze(output['logits']))[2:-2,1].detach().numpy().tolist()


In [11]:
# One row per input sequence; each prediction cell holds that sequence's per-residue list.
data = pd.DataFrame({
    "sequence_ids": sequence_ids,
    "pirate_preds": pirate_list,
    "dr_bert_preds": bert_list,
})

data.to_csv("aadc_efa1_preds_pirate_bert.csv")