In [1]:
#@title Import libraries
# setup the model
import sys
import os 
sys.path.insert(0, os.path.abspath("/stornext/HPCScratch/home/iskander.j/git_projects/alphafold"))
import pickle
from alphafold.data import pipeline
from alphafold.data import parsers
from alphafold.data import templates
from alphafold.common import protein
from alphafold.relax import relax

RELAX_MAX_ITERATIONS = 0
RELAX_ENERGY_TOLERANCE = 2.39
RELAX_STIFFNESS = 10.0
RELAX_EXCLUDE_RESIDUES = []
RELAX_MAX_OUTER_ITERATIONS = 20

amber_relaxer = relax.AmberRelaxation(
      max_iterations=RELAX_MAX_ITERATIONS,
      tolerance=RELAX_ENERGY_TOLERANCE,
      stiffness=RELAX_STIFFNESS,
      exclude_residues=RELAX_EXCLUDE_RESIDUES,
      max_outer_iterations=RELAX_MAX_OUTER_ITERATIONS)

In [6]:
unrelaxed_pdb_path="/home/users/allstaff/iskander.j/scratchhome/git_projects/alphafold/input/smchd1hs2_3p/smchd1_dimer-chainB-40A.pdb"
with open(unrelaxed_pdb_path, 'r') as f:
      pdb=f.read()

In [7]:
unrelaxed_protein=protein.from_pdb_string(pdb,"A")
#t_0 = time.time()
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
#timings[f'relax_{model_name}'] = time.time() - t_0

#relaxed_pdbs[model_name] = relaxed_pdb_str

# Save the relaxed PDB.
relaxed_output_path = 'relaxed_smchd1_dimer-chainB-40A_chainA.pdb'
with open(relaxed_output_path, 'w') as f:
      f.write(relaxed_pdb_str)



In [None]:
unrelaxed_protein=protein.from_pdb_string(pdb,"B")
#t_0 = time.time()
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
#timings[f'relax_{model_name}'] = time.time() - t_0

#relaxed_pdbs[model_name] = relaxed_pdb_str

# Save the relaxed PDB.
relaxed_output_path = 'relaxed_smchd1_dimer-chainB-40A_chainB.pdb'
with open(relaxed_output_path, 'w') as f:
      f.write(relaxed_pdb_str)

In [24]:
feature_dict = pickle.load(open('features.pkl', 'rb'))

In [20]:
input_fasta_path="../../example/smchd1hs.fasta"
with open(input_fasta_path) as f:
    input_fasta_str = f.read()
    input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
    if len(input_seqs) != 1:
        raise ValueError(f'More than one input sequence found in {input_fasta_path}.')
    input_sequence = input_seqs[0]
    input_description = input_descs[0]
    num_res = len(input_sequence)

In [5]:
a3m_lines = "".join(open("msas/bfd_uniclust_hits.a3m","r").readlines())
bfd_msa, bfd_deletion_matrix = pipeline.parsers.parse_a3m(a3m_lines)
jackhmmer_uniref90_result = "".join(open("msas/uniref90_hits.sto","r").readlines())
uniref90_msa, uniref90_deletion_matrix, _ = parsers.parse_stockholm(
        jackhmmer_uniref90_result)

jackhmmer_mgnify_result = "".join(open("msas/mgnify_hits.sto","r").readlines())
mgnify_msa, mgnify_deletion_matrix, _ = parsers.parse_stockholm(
        jackhmmer_mgnify_result)

mgnify_max_hits= 501
uniref_max_hits= 10000
hhsearch_result = "".join(open("msas/pdb70_hits.hhr","r").readlines())
hhsearch_hits = parsers.parse_hhr(hhsearch_result)
mgnify_msa = mgnify_msa[:mgnify_max_hits]
mgnify_deletion_matrix = mgnify_deletion_matrix[:mgnify_max_hits]
msa=(uniref90_msa, bfd_msa, mgnify_msa)
deletion_matrix=(uniref90_deletion_matrix,
                           bfd_deletion_matrix,
                           mgnify_deletion_matrix)


msa = [i for sub in msa for i in sub]
deletion_matrix = [i for sub in deletion_matrix for i in sub]

In [6]:

len(msa), len(deletion_matrix)


(14656, 14656)

In [7]:
homooligomer=2
msas = []
deletion_matrices = []
Ln = len(feature_dict['sequence'][0])

for o in range(homooligomer):
    L = Ln * o
    R = Ln * (homooligomer-(o+1))
    msas.append(["-"*L+seq+"-"*R for seq in msa])
    deletion_matrices.append([[0]*L+mtx+[0]*R for mtx in deletion_matrix])

##create MSA features
msa_features=pipeline.make_msa_features(msas=msas,deletion_matrices=deletion_matrices)

In [8]:
msa_features

{'deletion_matrix_int': array([[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]], dtype=int32),
 'msa': array([[10,  0,  0, ..., 21, 21, 21],
        [10,  0,  0, ..., 21, 21, 21],
        [21, 21, 21, ..., 21, 21, 21],
        ...,
        [21, 21, 21, ..., 21, 21, 21],
        [21, 21, 21, ..., 21, 21, 21],
        [21, 21, 21, ..., 21, 21, 21]], dtype=int32),
 'num_alignments': array([29284, 29284, 29284, ..., 29284, 29284, 29284], dtype=int32)}

In [26]:
feature_dict.keys()


dict_keys(['aatype', 'between_segment_residues', 'domain_name', 'residue_index', 'seq_length', 'sequence', 'deletion_matrix_int', 'msa', 'num_alignments', 'template_aatype', 'template_all_atom_masks', 'template_all_atom_positions', 'template_domain_names', 'template_sequence', 'template_sum_probs'])

In [10]:
feature_dict['deletion_matrix_int']=msa_features['deletion_matrix_int']
feature_dict['msa']=msa_features['msa']

feature_dict['num_alignments']=msa_features['num_alignments']


In [21]:
feature_dict

{'aatype': array([[0, 0, 0, ..., 0, 0, 0],
        [1, 0, 0, ..., 0, 0, 0],
        [1, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 1, 0]], dtype=int32),
 'between_segment_residues': array([0, 0, 0, ..., 0, 0, 0], dtype=int32),
 'domain_name': array([b'sp|A6NHR9|SMHD1_HUMAN Structural maintenance of chromosomes flexible hinge domain-containing protein 1 OS=Homo sapiens OX=9606 GN=SMCHD1 PE=1 SV=2'],
       dtype=object),
 'residue_index': array([   0,    1,    2, ..., 2002, 2003, 2004], dtype=int32),
 'seq_length': array([2005, 2005, 2005, ..., 2005, 2005, 2005], dtype=int32),
 'sequence': array([b'MAAADGGGPGGASVGTEEDGGGVGHRTVYLFDRREKESELGDRPLQVGERSDYAGFRACVCQTLGISPEEKFVITTTSRKEITCDNFDETVKDGVTLYLLQSVNQLLLTATKERIDFLPHYDTLVKSGMYEYYASEGQNPLPFALAELIDNSLSATSRNIGVRRIQIKLLFDETQGKPAVAVIDNGRGMTSKQLNNWAVYRLSKFTRQGDFESDHSGYVRPVPVPRSLNSDISYFGVGGKQAVFFVGQSARMISKPADSQDVHELVLSKEDFEKKEKNKEAIYSGYIRNRKPSDSVHITNDDERFLHHLIIE

In [17]:
template_featurizer = templates.TemplateHitFeaturizer(
      mmcif_dir=FLAGS.template_mmcif_dir,
      max_template_date=FLAGS.max_template_date,
      max_hits=MAX_TEMPLATE_HITS,
      kalign_binary_path=FLAGS.kalign_binary_path,
      release_dates_path=None,
      obsolete_pdbs_path=FLAGS.obsolete_pdbs_path)

TypeError: get_templates() missing 1 required positional argument: 'self'

In [None]:
feature_dict = {
    **pipeline.make_sequence_features(sequence=input_sequence*homooligomer,
                                      description="none",
                                      num_res=len(input_sequence)*homooligomer),
    **pipeline.make_msa_features(msas=msas,deletion_matrices=deletion_matrices),
    **templates_result.features
}

In [7]:
f = pickle.load(open('/stornext/HPCScratch/home/iskander.j/AF/output/query_colab/features.pkl', 'rb'))

In [9]:
f["residue_index"]

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
       34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
       51, 52, 53, 54, 55, 56, 57, 58], dtype=int32)

In [2]:
f2 = pickle.load(open('/stornext/HPCScratch/home/iskander.j/AF/output/query_colab/features_h2.pkl', 'rb'))

In [3]:
len(f2["residue_index"])

118

In [15]:
import numpy as np
from absl import logging

def write_chains(chains_path,chains):
    with open(chains_path, 'w') as filehandle:
        filehandle.writelines("%s\n" % chain for chain in chains)
    
def read_chains(chains_path):
    chains = []
    # open file and read the content in a list
    with open(chains_path, 'r') as filehandle:
        filecontents = filehandle.readlines()
        for line in filecontents:
            current_place = line[:-1]
            chains.append(current_place)
    return chains

def set_bfactor(ip_path,op_path, bfac):
    I = open(ip_path,"r").readlines()
    O = open(op_path,"w")
    for line in I:
        if line[0:6] == "ATOM  ":
            seq_id = int(line[22:26].strip()) - 1
            O.write(f"{line[:60]}{bfac[seq_id]:6.2f}{line[66:]}")
    O.close()
    
    
def set_chain_bfactor(ip_path,op_path, bfac,  chains,idx_res=None, is_relaxed=False):

    #logging.info("Chains len : %d",len(chains))
    I = open(ip_path,"r").readlines()
    O = open(op_path,"w")
    for line in I:
        if line[0:6] == "ATOM  ":
          seq_id = int(line[22:26].strip()) - 1
          #logging.info("Seq_id : %d",seq_id)
          if not is_relaxed:
            seq_id = np.where(idx_res == seq_id)[0][0]
          O.write(f"{line[:21]}{chains[seq_id]}{line[22:60]}{bfac[seq_id]:6.2f}{line[66:]}")
    O.close()

In [16]:
ip_path='/stornext/HPCScratch/home/iskander.j/AF/output/query_colab/unrelaxed_model_1.pdb'
op_path='/stornext/HPCScratch/home/iskander.j/AF/output/query_colab/unrelaxed_model_1_n.pdb'
r=pickle.load(open('/stornext/HPCScratch/home/iskander.j/AF/output/query_colab/result_model_1.pkl', 'rb'))
chains=read_chains('/stornext/HPCScratch/home/iskander.j/AF/output/query_colab/chains_h2.txt')


In [17]:
set_chain_bfactor(ip_path,op_path, r['plddt'],  chains,idx_res=f2["residue_index"], is_relaxed=False)

In [18]:
ip_path='/stornext/HPCScratch/home/iskander.j/AF/output/query_colab/relaxed_model_1.pdb'
op_path='/stornext/HPCScratch/home/iskander.j/AF/output/query_colab/relaxed_model_1_n.pdb'
set_chain_bfactor(ip_path,op_path, r['plddt'],  chains,idx_res=f2["residue_index"], is_relaxed=True)

In [29]:
import os, json
plddts={}
models=['model_1', 'mode1_2','model_3','model_4','model_5']
for model_name in models:
    m_path=f'/stornext/HPCScratch/home/iskander.j/AF/output/query_colab/result_{model_name}.pkl'
    if os.path.exists(m_path):
        result=pickle.load(open(m_path, 'rb'))
        plddts[model_name] = np.mean(result['plddt'])
ranked_order = []
for idx, (model_name, _) in enumerate(sorted(plddts.items(), key=lambda x: x[1], reverse=True)):
    ranked_order.append(model_name)
output_dir="./"
ranking_output_path = os.path.join(output_dir, 'ranking_debug.json')
with open(ranking_output_path, 'w') as f:
    f.write(json.dumps({'plddts': plddts, 'order': ranked_order}, indent=4))

In [24]:
plddts

{'model_1': 96.81348613588663}

In [27]:
ranked_order

['model_1']