In [11]:
import numpy as np
import matplotlib.pyplot as plt
import glob
import json
import py3Dmol
from statistics import mean, median,variance,stdev
import seaborn as sns
import pandas as pd
import copy
import numpy as np

def fasta_reader(fasta_file):
    """Read the sequence lines of a two-lines-per-record FASTA file.

    Every second line of the file (the 2nd, 4th, ... line) is returned as a
    sequence; the lines in between (presumably the ``>`` headers) are skipped.
    Wrapped, multi-line FASTA records are therefore not supported.

    Args:
        fasta_file: Path of the FASTA file to read.

    Returns:
        list[str]: The sequences in file order, with surrounding whitespace
        (including the trailing newline) stripped.
    """
    # The context manager closes the handle even if reading raises.
    with open(fasta_file) as f:
        return [line.strip() for i, line in enumerate(f) if i % 2 == 1]

def learning_data_reader(fasta_file):
    """Read sequences from a two-lines-per-record FASTA file, dropping
    any sequence that contains an unexpected symbol.

    Every second line of the file is treated as a sequence. A sequence is
    kept only if each of its characters is one of the 20 upper-case
    one-letter amino-acid codes or the lower-case letter ``'b'``
    (presumably a padding symbol -- confirm against the data producer).

    Args:
        fasta_file: Path of the FASTA file to read.

    Returns:
        list[str]: The accepted sequences in file order, stripped of
        surrounding whitespace. ``'b'`` characters are left in place.
    """
    allowed = set('ACDEFGHIKLMNPQRSTVWY') | {'b'}
    seq_list = []
    # The context manager closes the handle even if reading raises.
    with open(fasta_file) as f:
        for i, line in enumerate(f):
            if i % 2 == 0:
                # Even-indexed lines are not sequence lines.
                continue
            seq = line.strip()
            if all(aa in allowed for aa in seq):
                seq_list.append(seq)
    return seq_list


def aa_counter(seq_list):
    """Tally how often each of the 20 amino-acid letters occurs.

    Args:
        seq_list: Iterable of sequences made of upper-case one-letter
            amino-acid codes.

    Returns:
        dict[str, int]: Occurrence count for every one of the 20 letters
        (letters that never occur map to 0), keyed in alphabetical order.

    Raises:
        KeyError: If a sequence contains any other character.
    """
    tally = dict.fromkeys('ACDEFGHIKLMNPQRSTVWY', 0)
    for sequence in seq_list:
        for residue in sequence:
            tally[residue] = tally[residue] + 1
    return tally


def aa_pair_counter(seq_list):
    """List every pair of adjacent residues in the given sequences.

    Args:
        seq_list: Iterable of sequences (strings).

    Returns:
        list[list[str]]: One ``[residue_i, residue_i+1]`` entry for each
        adjacent position of each sequence, in order of occurrence. Pairs
        never span two sequences; a sequence shorter than two residues
        contributes nothing.
    """
    aa_pair = []
    for seq in seq_list:
        # zip(seq, seq[1:]) stops at the shorter operand, so it yields
        # exactly len(seq) - 1 pairs, including the final
        # (second-to-last, last) pair.
        aa_pair.extend([aa1, aa2] for aa1, aa2 in zip(seq, seq[1:]))
    return aa_pair



def aa_count_hist_maker(aa_count_dict):
    """Show a bar chart with one bar per entry of ``aa_count_dict``.

    Args:
        aa_count_dict: Mapping from bar label to bar height, e.g. the
            result of ``aa_counter``.

    The chart is drawn on the current pyplot figure and displayed
    immediately; nothing is returned.
    """
    labels = list(aa_count_dict.keys())
    heights = aa_count_dict.values()
    plt.bar(labels, heights, color='gray', alpha=0.7, edgecolor='k')
    plt.show()

def _stride_state_rates(AF_result_dir, states, seq_len):
    """Run STRIDE on every rank-1 model and return per-model state rates.

    Args:
        AF_result_dir: Directory searched for ``*rank_001*.pdb`` files.
        states: One-letter STRIDE state codes to count (e.g. ``('E', 'H')``).
        seq_len: Number the residue count is divided by.

    Returns:
        list[float]: For each matching PDB file, in ``glob`` order, the
        number of residues assigned one of ``states`` divided by ``seq_len``.

    Raises:
        FileNotFoundError: If the ``stride`` executable is not on ``PATH``.
    """
    # Imported locally so the helper does not depend on a notebook-level import.
    import subprocess

    rates = []
    for pdb in glob.glob(AF_result_dir + '/*rank_001*.pdb'):
        # Argument-list form: the path is handed to STRIDE verbatim (no shell
        # parsing of spaces or metacharacters) and the report is captured in
        # memory, so no scratch file is left behind or left open.
        # A failing STRIDE run simply produces no 'ASG' lines (rate 0.0).
        report = subprocess.run(['stride', pdb], capture_output=True, text=True)
        hits = 0
        for line in report.stdout.splitlines():
            # 'ASG' lines are the per-residue assignment records; the
            # one-letter state code is read from character index 24.
            if 'ASG' in line and len(line) > 24 and line[24] in states:
                hits += 1
        rates.append(hits / seq_len)
    return rates


def second_st_rate_checker(AF_result_dir, seq_len=100):
    """Rate of residues STRIDE assigns 'E' or 'H', per rank-1 model.

    Args:
        AF_result_dir: Directory containing the ``*rank_001*.pdb`` models.
        seq_len: Divisor for the residue count. Defaults to 100
            (presumably the length of the modelled sequences -- confirm).

    Returns:
        list[float]: One rate per model, in ``glob`` order.
    """
    return _stride_state_rates(AF_result_dir, ('E', 'H'), seq_len)


def strand_rate_checker(AF_result_dir, seq_len=100):
    """Rate of residues STRIDE assigns 'E', per rank-1 model.

    Args:
        AF_result_dir: Directory containing the ``*rank_001*.pdb`` models.
        seq_len: Divisor for the residue count. Defaults to 100
            (presumably the length of the modelled sequences -- confirm).

    Returns:
        list[float]: One rate per model, in ``glob`` order.
    """
    return _stride_state_rates(AF_result_dir, ('E',), seq_len)


def helix_rate_checker(AF_result_dir, seq_len=100):
    """Rate of residues STRIDE assigns 'H', per rank-1 model.

    Args:
        AF_result_dir: Directory containing the ``*rank_001*.pdb`` models.
        seq_len: Divisor for the residue count. Defaults to 100
            (presumably the length of the modelled sequences -- confirm).

    Returns:
        list[float]: One rate per model, in ``glob`` order.
    """
    return _stride_state_rates(AF_result_dir, ('H',), seq_len)

def confidence_score_checker(AF_result_dir):
    """Collect pTM and mean pLDDT for every rank-1 model in a directory.

    For each ``*rank_001*.pdb`` file, the matching score file is located by
    replacing ``.pdb`` with ``.json`` and ``unrelaxed`` with ``scores`` in
    the PDB path.

    Args:
        AF_result_dir: Directory containing the PDB and JSON files.

    Returns:
        tuple[list, list]: ``(ptms, plddts)`` in ``glob`` order, where
        ``ptms`` holds each file's ``'ptm'`` value and ``plddts`` the
        arithmetic mean of its ``'plddt'`` list.

    Raises:
        FileNotFoundError: If a derived JSON path does not exist.
        ZeroDivisionError: If a score file has an empty ``'plddt'`` list.
    """
    ptms = []
    plddts = []
    for pdb in glob.glob(AF_result_dir + '/*rank_001*.pdb'):
        # NOTE(review): the replacements act on the whole path, so a
        # directory name containing '.pdb' or 'unrelaxed' is rewritten too.
        json_f = pdb.replace('.pdb', '.json').replace('unrelaxed', 'scores')
        # Close the score file as soon as it has been parsed.
        with open(json_f) as fh:
            scores = json.load(fh)
        ptms.append(scores['ptm'])
        plddts.append(sum(scores['plddt']) / len(scores['plddt']))
    return ptms, plddts




def structure_viewer(pdb_path, scale=2):
    """Print a model's confidence scores and show it in a 1x3 py3Dmol grid.

    The score file is located by replacing ``.pdb`` with ``.json`` and
    ``unrelaxed`` with ``scores`` in ``pdb_path``. The three panels use the
    stick, sphere and cartoon styles; the cartoon is coloured by the ``b``
    atom property on a ``roygb`` gradient between 50 and 90 (presumably
    the per-residue pLDDT stored in the B-factor column -- confirm).

    Args:
        pdb_path: Path of the PDB file; passed to py3Dmol as ``query``.
        scale: Multiplier applied to the base 680x250 viewer size.

    Raises:
        FileNotFoundError: If the derived JSON path does not exist.
    """
    print(f'File Path: {pdb_path}')
    json_file = pdb_path.replace('.pdb', '.json').replace('unrelaxed', 'scores')
    # Close the score file as soon as it has been parsed.
    with open(json_file) as fh:
        scores = json.load(fh)
    plddt = sum(scores['plddt']) / len(scores['plddt'])
    ptm = scores['ptm']
    print(f'pLDDT: {plddt}, pTM: {ptm}')

    view = py3Dmol.view(width=680*scale, height=250*scale, query=pdb_path, viewergrid=(1,3), linked=False)
    # One representation and one background colour per panel.
    view.setStyle({'stick': {}}, viewer=(0,0))
    view.setStyle({'sphere': {}}, viewer=(0,1))
    view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':50,'max':90}}}, viewer=(0,2))
    view.setBackgroundColor('#ebf4fb', viewer=(0,0))
    view.setBackgroundColor('#f9f4fb', viewer=(0,1))
    view.setBackgroundColor('#e1e1e1', viewer=(0,2))
    view.show()
    view.png()


def mat_maker(seq_list, fig_file_name, return_mat=True, total_residue_count=200 * 100):
    """Plot, save and optionally return the adjacent-residue log-ratio matrix.

    For every ordered pair ``(a, b)`` of amino acids the value is::

        log( sum over observed (a, b) pairs of N**2 / (count[a] * count[b])
             / number_of_all_pairs )

    with ``N = total_residue_count``. When ``N`` equals the total number of
    residues in ``seq_list`` this is the log of the observed pair frequency
    over the product of the two single-residue frequencies.

    Args:
        seq_list: Sequences made only of the 20 upper-case one-letter codes.
        fig_file_name: Path the heatmap is saved to (400 dpi).
        return_mat: If true, return the matrix as a DataFrame.
        total_residue_count: Normalisation constant ``N``. The default
            200 * 100 is the value that was previously hard-coded
            (presumably 200 sequences of 100 residues); pass
            ``sum(len(s) for s in seq_list)`` for other data sets.

    Returns:
        pandas.DataFrame or None: 20x20 frame whose column label is the
        first residue of the pair and whose row label is the second; only
        returned when ``return_mat`` is true. Pairs that never occur hold
        ``-inf``.

    Raises:
        ZeroDivisionError: If ``seq_list`` yields no adjacent pair at all.
        KeyError: If a sequence contains a symbol outside the 20 letters.
    """
    amino_acids = list('ACDEFGHIKLMNPQRSTVWY')
    aa_count_dic = aa_counter(seq_list)
    # Outer key: first residue of the pair; inner key: second residue.
    aa_pair_count_dic = {aa1: {aa2: 0 for aa2 in amino_acids} for aa1 in amino_acids}

    norm = total_residue_count ** 2
    all_pair_num = 0
    for first, second in aa_pair_counter(seq_list):
        all_pair_num += 1
        # Weight each observation by the inverse of the product of the
        # two residues' overall counts.
        aa_pair_count_dic[first][second] += norm / (aa_count_dic[first] * aa_count_dic[second])

    for aa1 in amino_acids:
        for aa2 in amino_acids:
            # log(0) for unobserved pairs gives -inf (numpy emits a warning).
            aa_pair_count_dic[aa1][aa2] = np.log(aa_pair_count_dic[aa1][aa2] / all_pair_num)

    df = pd.DataFrame(aa_pair_count_dic, columns=amino_acids, index=amino_acids)

    sns.heatmap(df, cmap="RdBu", square=True, vmin=-0.8, vmax=0.8, linewidths=.5)
    # Put one bold label at the centre of each heatmap cell.
    tick_positions = [i + 0.5 for i in range(len(amino_acids))]
    plt.xticks(tick_positions, amino_acids, fontweight='bold', fontsize=8)
    plt.yticks(tick_positions, amino_acids, fontweight='bold', fontsize=8, rotation=90)
    plt.tick_params(bottom=False, left=False, right=False, top=False)
    plt.savefig(fig_file_name, dpi=400)
    plt.show()

    if return_mat:
        return df

In [22]:

# Shuffle the accepted training sequences. A dedicated, seeded generator is
# used so that Restart & Run All reproduces the same order without touching
# NumPy's global random state. The result is a NumPy array of strings.
SHUFFLE_SEED = 0
learning_data_randomly_picked_up = np.random.default_rng(SHUFFLE_SEED).permutation(
    learning_data_reader('../fasta/train_data_from_AFDB.fasta')
)
# Filled by the next cell with the 'b'-stripped sequences.
learning_data_randomly_picked_up_parsed = []


In [23]:
# Remove every 'b' character from each shuffled sequence.
# The list is rebuilt from scratch instead of appended to, so re-running this
# cell cannot add the same sequences a second time. str() turns the NumPy
# string scalars into plain Python strings.
learning_data_randomly_picked_up_parsed = [
    str(seq).replace('b', '') for seq in learning_data_randomly_picked_up
]

In [28]:
# Keep only the sequences that are longer than 100 residues.
# NOTE(review): there is no upper bound here; a sequence longer than 200
# residues would be written unpadded at its full length.
train_data_from_AFDB_from100res_to200res = [
    seq for seq in learning_data_randomly_picked_up_parsed if len(seq) > 100
]

# Write one two-line FASTA record per sequence, right-padded with 'b' to a
# length of 200. The context manager closes the file even if a write fails.
with open('../fasta/train_data_from_AFDB_from100res_to200res.fasta', 'w') as w:
    for i, seq in enumerate(train_data_from_AFDB_from100res_to200res):
        seq = seq.ljust(200, 'b')
        w.write(f'>seq{i}\n{seq}\n')