In [None]:
import json
import os
import pathlib
import pickle

import tensorflow as tf
import numpy as np
import pandas as pd
import torch
from scipy.stats import spearmanr
import catboost
from transformers import AutoModel, AutoTokenizer, AutoModelForTokenClassification
import torch.nn.functional as F

In [None]:
def afsm12_encode_data(data, input_size, mean=10.108613363425793, std=6.034641898334733):
    """
    Encode a FASTA sequence as integers, mirror-pad it to ``input_size`` and normalize it.

    Each residue is mapped to an integer code (1-20). The encoded sequence is
    centred in a window of ``input_size`` positions. The flanks are filled by
    repeating the reversed sequence, so each flank starts with a mirror image of
    the adjacent end of the sequence (the edge residue itself is repeated).

    Parameters
    ----------
    data : str
        Amino-acid sequence in one-letter upper-case codes. Characters outside
        the residue table raise ``KeyError``.
    input_size : int
        Total length of the padded window expected by the model.
    mean, std : float
        Normalization constants. The defaults are the training mean and
        standard deviation for the models fed by this encoder.

    Returns
    -------
    fasta : numpy.ndarray
        Normalized array of shape ``(1, input_size, 1)``.
    split : int
        Index of the first real residue in the padded window.
    stop_pos : int
        Index one past the last real residue (``split + len(data)``).

    Raises
    ------
    ValueError
        If the sequence is empty or longer than ``input_size``.
    """
    residue_dictionary = {"A": 1, "E": 2, "L": 3, "M": 4, "C": 5, "D": 6, "F": 7, "G": 8,
                          "H": 9, "K":10, "N": 11, "P": 12, "Q": 13, "R": 14, "S": 15,
                          "W": 16, "Y": 17, "T": 18, "V": 19, "I": 20}

    encoded = [residue_dictionary[residue] for residue in str(data)]
    seq_len = len(encoded)
    if seq_len == 0:
        raise ValueError("cannot encode an empty sequence")
    if seq_len > input_size:
        raise ValueError(
            f"sequence length {seq_len} exceeds model input size {input_size}"
        )

    # Split the padding as evenly as possible; an odd leftover goes to the right flank.
    split = (input_size - seq_len) // 2
    last_padding_len = input_size - seq_len - split
    stop_pos = split + seq_len

    # Repeat the reversed sequence just enough times to cover the longer (right)
    # flank. The buffer is periodic, so the repeat count does not change the
    # slices taken below, only whether they are long enough.
    reversed_seq = encoded[::-1]
    padding = reversed_seq * (last_padding_len // seq_len + 1)
    # Slice from an explicit start index: ``padding[-split:]`` would return the
    # WHOLE buffer when split == 0.
    padding_1 = padding[len(padding) - split:]
    padding_2 = padding[:last_padding_len]
    fasta = padding_1 + encoded + padding_2

    # Reshape to a single-item batch and normalize with the training statistics.
    fasta = np.array(fasta).reshape(-1, input_size, 1)
    fasta = (fasta - mean) / std
    return fasta, split, stop_pos

def afsm3_encode_data(data, input_size, mean=10.15, std=5.98):
    """
    Encode a FASTA sequence as integers, mirror-pad it to ``input_size`` and normalize it.

    Same encoding and padding scheme as ``afsm12_encode_data``, but with the
    AFSM3 normalization constants as defaults. The encoded sequence is centred
    in the window. The flanks repeat the reversed sequence, so each flank starts
    with a mirror image of the adjacent end of the sequence.

    Parameters
    ----------
    data : str
        Amino-acid sequence in one-letter upper-case codes. Characters outside
        the residue table raise ``KeyError``.
    input_size : int
        Total length of the padded window expected by the model.
    mean, std : float
        Normalization constants (training mean / standard deviation).

    Returns
    -------
    fasta : numpy.ndarray
        Normalized array of shape ``(1, input_size, 1)``.
    split : int
        Index of the first real residue in the padded window.
    stop_pos : int
        Index one past the last real residue (``split + len(data)``).

    Raises
    ------
    ValueError
        If the sequence is empty or longer than ``input_size``.
    """
    residue_dictionary = {"A": 1, "E": 2, "L": 3, "M": 4, "C": 5, "D": 6, "F": 7, "G": 8,
                          "H": 9, "K":10, "N": 11, "P": 12, "Q": 13, "R": 14, "S": 15,
                          "W": 16, "Y": 17, "T": 18, "V": 19, "I": 20}

    encoded = [residue_dictionary[residue] for residue in str(data)]
    seq_len = len(encoded)
    if seq_len == 0:
        raise ValueError("cannot encode an empty sequence")
    if seq_len > input_size:
        raise ValueError(
            f"sequence length {seq_len} exceeds model input size {input_size}"
        )

    # Split the padding as evenly as possible; an odd leftover goes to the right flank.
    split = (input_size - seq_len) // 2
    last_padding_len = input_size - seq_len - split
    stop_pos = split + seq_len

    # Repeat the reversed sequence just enough times to cover the longer (right)
    # flank. The buffer is periodic, so the repeat count does not change the
    # slices taken below, only whether they are long enough.
    reversed_seq = encoded[::-1]
    padding = reversed_seq * (last_padding_len // seq_len + 1)
    # Slice from an explicit start index: ``padding[-split:]`` would return the
    # WHOLE buffer when split == 0.
    padding_1 = padding[len(padding) - split:]
    padding_2 = padding[:last_padding_len]
    fasta = padding_1 + encoded + padding_2

    # Reshape to a single-item batch and normalize with the training statistics.
    fasta = np.array(fasta).reshape(-1, input_size, 1)
    fasta = (fasta - mean) / std
    return fasta, split, stop_pos


def afsm12_predict_data(fasta, model, input_size):
    """
    Run an AFSM1/AFSM2-style model on one sequence and return per-residue outputs.

    The sequence is encoded and mirror-padded by ``afsm12_encode_data``. The
    model output is flattened to ``input_size`` values and the padded flanks
    are dropped, so the result has one value per residue. Depending on the
    model passed in, this is presumably a predicted PAE or pLDDT track.

    Parameters
    ----------
    fasta : str
        Amino-acid sequence.
    model
        Object with a ``predict(batch)`` method whose output for a single
        sequence has exactly ``input_size`` elements.
    input_size : int
        Padded window length the model expects.

    Returns
    -------
    list of float
        One prediction per residue of ``fasta``.
    """
    data, start_pos, stop_pos = afsm12_encode_data(fasta, input_size)
    # Flatten to 1-D before converting: calling float() on each 1-element row
    # of an (N, 1) array is deprecated in NumPy >= 1.25.
    prediction = np.asarray(model.predict(data)).reshape(input_size)
    return [float(value) for value in prediction[start_pos:stop_pos]]


def afsm3_predict_data(fasta, model, input_size):
    """
    Run the AFSM3 model on one sequence and return a per-residue score.

    The sequence is encoded and mirror-padded by ``afsm3_encode_data``. The
    score is column 1 of the model output for the first (only) batch item,
    presumably the crystallization probability (TODO confirm). Positions in the
    padded flanks are discarded, leaving one Python float per residue.
    """
    encoded, first, last = afsm3_encode_data(fasta, input_size)
    class_one = model.predict(encoded)[0][:, 1]
    return [float(score) for score in list(class_one)[first:last]]

def encode_sequence(fasta):
    """
    Map each residue of ``fasta`` to its integer code (1-20).

    Uses the same residue table as the AFSM encoders. The result is the raw
    codes as a list of ints, with no padding or normalization. Characters
    outside the table raise ``KeyError``.
    """
    # Codes are assigned in this fixed order: A=1, E=2, L=3, ..., I=20.
    codes = dict(zip("AELMCDFGHKNPQRSWYTVI", range(1, 21)))
    return [codes[residue] for residue in str(fasta)]

def process_protein(sequence, mae_pred, plddt_pred, presort_pred, ordinal_list, model):
    """
    Slide a fixed window along the protein and classify each window centre with ``model``.

    For every run of 2 * 11 + 1 = 23 consecutive positions, the slices of the
    four per-residue feature sequences are concatenated with ``+`` (mae, plddt,
    presort, ordinal, in that order) and passed to ``model.predict``. The inputs
    are assumed to be Python lists, since ``+`` on numpy arrays would add
    element-wise instead — TODO confirm at call sites.

    The first and last 11 positions are never a window centre. Each end is
    filled with a constant: 0 if the nearest window prediction is 0 and the mean
    of the first/last 12 ``presort_pred`` values is below 0.7, otherwise 1.

    Returns a list of length ``len(sequence)`` when the sequence has at least 23
    residues. Shorter sequences produce no windows and raise ``IndexError``.
    """
    half_window = 11
    span = 2 * half_window + 1

    core = []
    for begin in range(len(sequence) - span + 1):
        end = begin + span
        features = (mae_pred[begin:end] + plddt_pred[begin:end]
                    + presort_pred[begin:end] + ordinal_list[begin:end])
        core.append(model.predict(features))

    head_fill = 0 if core[0] == 0 and np.mean(np.array(presort_pred[:12])) < 0.7 else 1
    tail_fill = 0 if core[-1] == 0 and np.mean(np.array(presort_pred[-12:])) < 0.7 else 1

    return [head_fill] * half_window + core + [tail_fill] * half_window




def disorder_list(sequence: str) -> list:
    """
    Predict per-residue disorder for ``sequence`` with the PIRATE model.

    Four per-residue feature tracks are built first:

    - AFSM1 output, via ``afsm12_predict_data`` with a 4096-long window.
    - AFSM2 output, the same way and then divided by 100.
    - AFSM3 output, via ``afsm3_predict_data`` with a 2048-long window.
    - The integer residue codes from ``encode_sequence``.

    A 23-residue window (11 on each side of the centre) is then slid along the
    sequence. For each position, the four feature slices are concatenated and
    scored with ``pirate_model.predict_proba(...)[0]``. The 11 positions at each
    end are never a window centre and are padded with 0.

    Relies on the module-level ``afsm1_model``, ``afsm2_model``, ``afsm3_model``
    and ``pirate_model``, which must be loaded before this is called.

    NOTE(review): for sequences shorter than 23 residues no window fits, and the
    result is 22 zeros regardless of ``len(sequence)``.
    NOTE(review): the element type depends on what ``predict_proba`` returns for
    one un-batched feature list (a single probability or a probability row).
    Confirm against the PIRATE training code.
    """
    # Per-residue feature tracks.
    afsm1_pred = afsm12_predict_data(sequence, afsm1_model, 4096)
    afsm2_pred = list(np.array(afsm12_predict_data(sequence, afsm2_model, 4096)) / 100.0)
    afsm3_pred = afsm3_predict_data(sequence, afsm3_model, 2048)
    ordinal_list = encode_sequence(sequence)

    # Half-width of the sliding window around each predicted residue.
    win_size = 11
    span = 2 * win_size + 1

    predictions = []
    for start in range(len(sequence) - span + 1):
        stop = start + span
        features = (afsm1_pred[start:stop] + afsm2_pred[start:stop]
                    + afsm3_pred[start:stop] + ordinal_list[start:stop])
        predictions.append(pirate_model.predict_proba(features)[0])

    return [0] * win_size + predictions + [0] * win_size

In [None]:
# Model artifacts are expected in a "models" directory under the parent of the
# current working directory.
local_path = pathlib.Path().absolute()
model_path = str(local_path.parents[0])+"/models/"
afsm1_path = model_path+"afsm1"
afsm2_path = model_path+"afsm2"
afsm3_path = model_path+"afsm3"
pirate_path = model_path+"pirate.pkl"
# NOTE(review): disorder_list hard-codes 4096 / 2048 rather than reading these.
input_size = 4096
presort_input = 2048
afsm1_model = tf.keras.models.load_model(afsm1_path, custom_objects=None, compile=True, options=None)
print("afsm1 loaded")
afsm2_model = tf.keras.models.load_model(afsm2_path, custom_objects=None, compile=True, options=None)
print("afsm2 loaded")
afsm3_model = tf.keras.models.load_model(afsm3_path, custom_objects=None, compile=True, options=None)
print("afsm3 loaded")
# SECURITY: unpickling can execute arbitrary code; only load trusted model files.
# The context manager closes the file handle once loading finishes.
with open(pirate_path, 'rb') as pirate_file:
    pirate_model = pickle.load(pirate_file)
print("pirate loaded")

In [None]:
# Location of the local DR-BERT checkpoint directory. Set the DR_BERT_CHECKPOINT
# environment variable to point at your copy instead of editing the default below.
checkpoint = os.environ.get(
    "DR_BERT_CHECKPOINT",
    r"C:\Users\GRICHARDSON\OneDrive - Evotec\Desktop\crystallization_deletion_tool\DR-BERT-final",
)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
bert_model = AutoModelForTokenClassification.from_pretrained(checkpoint)
# Used for inference only, so put the model in evaluation mode.
bert_model = bert_model.eval()

In [None]:
# >sp|P20711|DDC_HUMAN Aromatic-L-amino-acid decarboxylase OS=Homo sapiens OX=9606 GN=DDC PE=1 SV=2
# Example input: the sequence from the FASTA record above, as one-letter residue codes.
aadc = "MNASEFRRRGKEMVDYMANYMEGIEGRQVYPDVEPGYLRPLIPAAAPQEPDTFEDIINDVEKIIMPGVTHWHSPYFFAYFPTASSYPAMLADMLCGAIGCIGFSWAASPACTELETVMMDWLGKMLELPKAFLNEKAGEGGGVIQGSASEATLVALLAARTKVIHRLQAASPELTQAAIMEKLVAYSSDQAHSSVERAGLIGGVKLKAIPSDGNFAMRASALQEALERDKAAGLIPFFMVATLGTTTCCSFDNLLEVGPICNKEDIWLHVDAAYAGSAFICPEFRHLLNGVEFADSFNFNPHKWLLVNFDCSAMWVKKRTDLTGAFRLDPTYLKHSHQDSGLITDYRHWQIPLGRRFRSLKMWFVFRMYGVKGLQAYIRKHVQLSHEFESLVRQDPRFEICVEVILGLVCFRLKGSNKVNEALLQRINSAKKIHLVPCHLRDKFVLRFAICSRTVESAHVQRAWEHIKELAADVLRAERE"

In [None]:
# >sp|P68104|EF1A1_HUMAN Elongation factor 1-alpha 1 OS=Homo sapiens OX=9606 GN=EEF1A1 PE=1 SV=1
# Example input: the sequence from the FASTA record above, as one-letter residue codes.
ef1a = "MGKEKTHINIVVIGHVDSGKSTTTGHLIYKCGGIDKRTIEKFEKEAAEMGKGSFKYAWVLDKLKAERERGITIDISLWKFETSKYYVTIIDAPGHRDFIKNMITGTSQADCAVLIVAAGVGEFEAGISKNGQTREHALLAYTLGVKQLIVGVNKMDSTEPPYSQKRYEEIVKEVSTYIKKIGYNPDTVAFVPISGWNGDNMLEPSANMPWFKGWKVTRKDGNASGTTLLEALDCILPPTRPTDKPLRLPLQDVYKIGGIGTVPVGRVETGVLKPGMVVTFAPVNVTTEVKSVEMHHEALSEALPGDNVGFNVKNVSVKDVRRGNVAGDSKNDPPMEAAGFTAQVIILNHPGQISAGYAPVLDCHTAHIACKFAELKEKIDRRSGKKLEDGPKFLKSGDAAIVDMVPGKPMCVESFSDYPPLGRFAVRDMRQTVAVGVIKAVDKKAAGAGKVTKSAQKAQKAK"

In [None]:
# Proteins to score, keyed by identifier; dict insertion order fixes the row order,
# which keeps ``sequences`` and ``sequence_ids`` aligned.
named_sequences = {"aadc": aadc, "ef1a": ef1a}
sequences = list(named_sequences.values())
sequence_ids = list(named_sequences)

In [None]:
# PIRATE per-residue predictions, one list per entry of ``sequences``.
pirate_list = [disorder_list(sequence) for sequence in sequences]

In [None]:
bert_list = []
for sequence in sequences:
    # NOTE(review): a ("something", sequence) tuple is passed as the text argument,
    # presumably encoded as a text pair. The [2:-2] slice below assumes exactly two
    # leading and two trailing non-residue tokens. Confirm the result has
    # len(sequence) entries for the DR-BERT tokenizer.
    encoded = tokenizer.encode_plus(("something", str(sequence)), return_tensors="pt")
    with torch.no_grad():
        output = bert_model(**encoded)
    # Softmax over the class axis of the squeezed (tokens, classes) logits; an explicit
    # dim avoids the deprecated implicit-dim behaviour. Column 1 is kept as the score.
    output = F.softmax(torch.squeeze(output['logits']), dim=-1)[2:-2, 1].detach().numpy().tolist()
    bert_list.append(output)

In [None]:
# One row per input protein; each prediction column holds that protein's per-residue list.
data = pd.DataFrame(
    {
        "sequence_ids": sequence_ids,
        "pirate_preds": pirate_list,
        "dr_bert_preds": bert_list,
    }
)

# NOTE(review): "efa1" in the filename looks like a typo for "ef1a"; kept unchanged so
# anything consuming this file still finds it.
data.to_csv("aadc_efa1_preds_pirate_bert.csv")