In [None]:
import glob
import json
import os
import subprocess
import sys
import time
import urllib.request
import warnings

import Bio
import Bio.PDB
import Bio.SeqRecord
import matplotlib.pyplot as plt
import MDAnalysis as mda
import numpy as np
import pandas as pd
from MDAnalysis.analysis import pca, align, rms
from scipy.spatial import distance_matrix
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import euclidean_distances

In [None]:

filein = open('binder_receptor_paired.a3m').readlines()[:10233]
binder_seq = ''
receptor_seq = ''

for x in glob.glob('clusters_folder/*.a3m'):
    if 'EX' in x:
        file_test = open(x).readlines()
        name = '>102\n'
        initial_seq = '-'*len(binder_seq)
        new_msa = []
        new_msa.append(name)
        for i in file_test[1:]:
            if i[:1]!='>':
                new_msa.append(initial_seq+i)
            else:
                new_msa.append(i)
        name=x.split('/')[-1]
        file=open(f'result_folder/{name}', 'w')
        file.write(f'#{len(binder_seq)},{len(receptor_seq)}\t1,1\n')
        file.write('>101\t102\n')
        file.write(f'{binder_seq}{receptor_seq}\n')
        for i in filein[1:]:
            file.write(i)
        for i in new_msa:
            file.write(i)
        file.close()

In [None]:
import time

start = time.time()
print(start)
for i in glob.glob('result_folder/EX*.a3m'):
    if 'U' not in i:
        print(i)
        os.system(f'colabfold_batch {i} result_folder/pdbs/ --num-recycle 3 --num-models 1')
end = time.time()
print(end)
print((end-start)/60)

In [None]:
def read_pdb(pdbcode, pdbfilenm):
    """
    Read a PDB structure from a file.
    :param pdbcode: A PDB ID string
    :param pdbfilenm: The PDB file
    :return: a Bio.PDB.Structure object or None if something went wrong
    """
    try:
        pdbparser = Bio.PDB.PDBParser(QUIET=True)   # suppress PDBConstructionWarning
        struct = pdbparser.get_structure(pdbcode, pdbfilenm)
        return struct
    except Exception as err:
        print(str(err), file=sys.stderr)
        return None

In [None]:
import warnings
warnings.filterwarnings("ignore")
pdbs = []
contacts = []
names = []
plddts = []
plddts_full = []
iptms = []
ipaes = []

ref_inactive = mda.Universe('inactive.pdb')
ref_inactive = ref_inactive.select_atoms('protein and name CA and (resid 297-330 or resid 360-380)')
ref_active = mda.Universe('active.pdb') 
ref_active = ref_active.select_atoms('protein and name CA and (resid 297-330 or resid 360-380)')

for i in glob.glob('result_folder/pdbs/*.pdb'):

    u = mda.Universe(i) 
    u = u.select_atoms("chainID A and name CA and (resid 297-330 or resid 360-380 or resid 145-160)")
    R_inactive = mda.analysis.rms.RMSD(u, ref_inactive,)
    R_inactive.run()
    
    R_active = mda.analysis.rms.RMSD(u, ref_active,)
    R_active.run()
    
    name = i.split('/')[-1].split('_')[1]
    
    data = open(f'result_folder/pdbs/EX_{name}_scores_rank_001_alphafold2_ptm_model_1_seed_000.json')
    #conf_prediction/a2a/a2a_gi/pdbs/EX_000_scores_rank_001_alphafold2_multimer_v3_model_1_seed_000.json
    #conf_prediction/a2a/a2a_msas/pdbs/EX_000_scores_rank_001_alphafold2_ptm_model_1_seed_000.json
    #conf_prediction/a2a/a2a_gi/pdbs/EX_000_predicted_aligned_error_v1.json
    data = json.load(data)
    plddts.append(np.mean(data['plddt']))
    plddts_full.append(data['plddt'])
    #iptms.append(data['iptm'])
    #ipaes.append(np.mean(np.array([np.min(np.array(data['pae'])[chain_a_hotspots[0]:chain_a_hotspots[1],i]) for i in chain_b_hotspots])))
    pdbs.append(np.array([R_inactive.rmsd[0][-1], R_active.rmsd[0][-1]]))
    names.append(i)
    contacts.append(euclidean_distances(u.atoms.positions,u.atoms.positions).reshape((1,-1))[0])
pdbs = np.array(pdbs).T
contacts.append(euclidean_distances(ref_inactive.atoms.positions,ref_inactive.atoms.positions).reshape((1,-1))[0])
contacts.append(euclidean_distances(ref_active.atoms.positions,ref_active.atoms.positions).reshape((1,-1))[0])

In [None]:
plt.rcParams["figure.figsize"] = (7,6)
data = pd.DataFrame({'rmsd1':pdbs[0],
                     'rmsd2':pdbs[1],
                     'plddts1':plddts,
                     'name':names})
data = data.sort_values(by='plddts1')

plt.scatter(data['rmsd1'],data['rmsd2'], c=data['plddts1'], cmap='rainbow_r', alpha=0.95, s=200, vmin=40, vmax=90)
plt.ylim(0,20)
plt.xlim(0,20)
plt.xlabel('RMSD to inactive', fontsize=14)
plt.ylabel('RMSD to active', fontsize=14)
#plt.axhline(y = 0.8, xmin = 0, xmax = 1.2)
#plt.axvline(x = 0.8, ymin = 0, ymax = 1.2)
plt.colorbar()
plt.plot([0, 20], [0,20], ls="--", c='black')
plt.savefig('results/image.png', dpi=600)