In [1]:
import numpy as np
import seaborn as sns
import mdtraj as md
import matplotlib.pyplot as plt
import nglview as nv

# https://biopython.org/docs/1.74/api/Bio.SVDSuperimposer.html
from Bio.SVDSuperimposer import SVDSuperimposer

from numpy import array, dot, set_printoptions



In [2]:
import sys
# insert at 1, 0 is the script path (or '' in REPL)
sys.path.insert(1, '/Users/thor/surfdrive/Scripts/notebooks/HNS-sequence/WorkingDir/4_dimer/')
#sys.path.insert(1, '/Users/heesch/surfdrive/Scripts/notebooks/HNS-sequence/WorkingDir')
from gen_filament import *

%load_ext autoreload
%autoreload 2

In [2]:
# link to openmm manual for how to set up a minimizer
# http://docs.openmm.org/7.2.0/userguide/application.html

In [4]:
def load_dry_trajs(directory, n_runs):
    """Load runs ``dry_{i}.xtc`` / ``dry_{i}.pdb`` for i in range(n_runs) and join them.

    Each run is stripped of solvent and the runs are concatenated with
    ``md.join``. ``directory`` is prefixed verbatim, so it must end with a
    path separator.
    """
    runs = [
        md.load(directory + f'dry_{i}.xtc', top=directory + f'dry_{i}.pdb').remove_solvent()
        for i in range(n_runs)
    ]
    return md.join(runs)


# Load H-NS s1s1 dimers (16 runs).
# (An earlier variant also prepended the dry_open / dry_closed starting structures.)
s1s1 = load_dry_trajs('./data/0_s1s1/drytrajs/', 16)

# Load H-NS s2s2 dimers (10 runs).
s2s2 = load_dry_trajs('./data/1_s2s2/drytrajs/', 10)

In [5]:
def check_selection(top,selection):
    """Return the atom indices of ``top`` matching a named selection preset.

    Presets: 'CA' (alpha carbons), 'backbone', 'sidechain'. Any other value,
    including None, selects all atoms.
    """
    presets = {
        'CA': 'name CA',
        'backbone': 'backbone',
        'sidechain': 'sidechain',
    }
    query = presets.get(selection, 'all') if isinstance(selection, str) else 'all'
    return top.select(query)

def get_monomer_domain_indices(top,domain,chain=0,selection=None):
    """Atom indices of the residues at positions ``domain`` in chain ``chain`` of ``top``.

    Parameters
    ----------
    top : topology
        Must expose ``_chains[chain]._residues`` (the residue list of each
        chain). It is also passed to ``check_selection``.
    domain : index array or slice
        Positions of the wanted residues *within the chain*. They index a
        NumPy object array of that chain's residues.
    chain : int
        Position of the chain in ``top``.
    selection : {'CA', 'backbone', 'sidechain', None}
        Atom filter forwarded to ``check_selection``. None keeps all atoms.

    Returns
    -------
    list of int
        Topology-wide atom indices, in residue order of ``domain`` and atom
        order within each residue, restricted to ``selection``.
    """
    residues = np.array(top._chains[chain]._residues)
    # Use a set for O(1) membership. `x in ndarray` scanned the whole
    # selection for every atom.
    allowed = set(np.asarray(check_selection(top,selection)).tolist())
    return [at.index for res in residues[domain] for at in res.atoms if at.index in allowed]
    

def show_domain(system,domain,domain_map=None):
    """Render frame 0 of ``system`` in nglview with one domain highlighted per chain.

    Three cartoon representations are added:

    * every atom except the chain-0 domain atoms, in cornflowerblue;
    * the chain-0 domain, in gold;
    * the chain-1 domain, in red.

    NOTE(review): the chain-1 domain atoms are included in both the blue and
    the red representation. This is kept as in the original.

    Parameters
    ----------
    system : md.Trajectory
        Structure to show. Only the first frame is rendered.
    domain : hashable
        Key into ``domain_map``.
    domain_map : dict, optional
        Maps domain names to residue positions within a chain (e.g.
        ``segments``). Defaults to the module-level ``domains``. That name is
        not defined in this notebook's cells; presumably it comes from
        ``from gen_filament import *``. TODO confirm.

    Returns
    -------
    The nglview widget.
    """
    if domain_map is None:
        domain_map = domains
    residue_ids = domain_map[domain]
    top = system.topology

    chain0 = get_monomer_domain_indices(top,residue_ids,chain=0)
    chain1 = get_monomer_domain_indices(top,residue_ids,chain=1)
    chain0_lookup = set(chain0)  # O(1) membership for the complement below

    view = nv.show_mdtraj(system[0])
    view.clear()
    view.add_representation('cartoon',selection=[i for i in top.select('all') if i not in chain0_lookup],color='cornflowerblue')
    view.add_representation('cartoon',selection=chain0,color='gold')
    view.add_representation('cartoon',selection=chain1,color='red')
    return view


def get_segment_structures(traj,segments,site='dbd'):
    """Pool segment ``site`` of chains 0 and 1 of ``traj`` into one trajectory.

    The segment atoms are sliced out of each chain separately. The two slices
    are joined with ``md.join``, chain 0's frames first and then chain 1's.
    This presumably lets both chains serve as conformational samples of the
    same segment.
    """
    residue_ids = segments[site]

    def _chain_slice(chain_id):
        # All atoms (selection=None) of the segment residues in this chain.
        atom_ids = get_monomer_domain_indices(top=traj.top, domain=residue_ids, chain=chain_id, selection=None)
        return traj.atom_slice(atom_ids)

    return md.join([_chain_slice(0), _chain_slice(1)])

def get_site_structures(traj,segements,site='s1'):
    """Slice segment ``site`` of both chains out of ``traj``, keeping the chains together.

    Unlike ``get_segment_structures``, both chains' atoms are selected in one
    ``atom_slice`` (in ascending atom-index order). Each frame therefore keeps
    the two-chain arrangement.

    Parameters
    ----------
    traj : md.Trajectory
        Source trajectory with at least two chains.
    segements : dict
        Maps site names to residue positions within a chain. The misspelled
        name is kept so existing keyword callers keep working.
    site : str
        Key into ``segements``.
    """
    # Use the argument. Previously the body read the module-level `segments`
    # and silently ignored this parameter.
    residue_ids = segements[site]
    chain_a = get_monomer_domain_indices(top=traj.top, domain=residue_ids, chain=0, selection=None)
    chain_b = get_monomer_domain_indices(top=traj.top, domain=residue_ids, chain=1, selection=None)
    return traj.atom_slice(np.sort(chain_a+chain_b))

# Consecutive segment pairs along one chain. add_dimer defines its own pair lists.
pairs = [
    ['s1', 'h3'],
    ['h3', 's2'],
    ['s2', 'l2'],
    ['l2', 'dbd'],
]
# Extra residues added on each side of every internal segment boundary, so
# neighbouring segments overlap. add_pair also reads this global.
n = 2

# Residue positions (0-based, within one chain) of each segment.
segments = {
    's1': np.arange(0, 41 + n),
    'h3': np.arange(41 - n, 53 + n),
    's2': np.arange(53 - n, 82 + n),
    'l2': np.arange(82 - n, 95 + n),
    'dbd': np.arange(95 - n, 137),
}

# Frame stride applied to every ensemble below.
k = 100

# Sites s1 and s2: both chains stay together in each frame.
s1 = get_site_structures(s1s1, segments, site='s1')[::k]
s2 = get_site_structures(s2s2, segments, site='s2')[::k]

# Other segments: frames pooled over both chains and both dimer trajectories.
h3_s1s1, h3_s2s2 = (get_segment_structures(t, segments, site='h3') for t in (s1s1, s2s2))
h3 = md.join([h3_s1s1, h3_s2s2])[::k]

l2_s1s1, l2_s2s2 = (get_segment_structures(t, segments, site='l2') for t in (s1s1, s2s2))
l2 = md.join([l2_s1s1, l2_s2s2])[::k]

dbd_s1s1, dbd_s2s2 = (get_segment_structures(t, segments, site='dbd') for t in (s1s1, s2s2))
dbd = md.join([dbd_s1s1, dbd_s2s2])[::k]

In [9]:
def add_pair(traj,pair,site_map,leading_chain=0,adding_chain=0,verbose=False,reverse=False,segment='fixed'):
    """Grow a structure by superimposing one frame of ``pair[1]`` onto it.

    The overlapping atoms of A (the structure so far) and B (the new
    segment) are used to fit B onto A. The overlap is then removed from A
    and the pieces are merged and stacked via the gen_filament helpers.

    Parameters
    ----------
    traj : md.Trajectory or None
        Structure built so far. ``None`` starts a new structure from one frame
        of ``site_map[pair[0]]``.
    pair : (str, str)
        ``(site_a, site_b)``; ``site_b`` is the segment being added.
    site_map : dict
        Maps site names to indexable ensembles. The notebook builds it from
        trajectories; presumably each item is a candidate conformation.
    leading_chain, adding_chain : int
        Chain of A (resp. B) whose terminal overlap is used for the fit. It is
        also the chain split off as the "active" part.
    verbose : bool
        Print the pair and the dimerisation flags.
    reverse : bool
        Ignored. The direction is always derived from ``get_termini``. The
        parameter is kept for backward compatibility.
    segment : {'fixed', 'random'}
        'fixed' uses frames 40/90 (new structure) or 20 (extension). 'random'
        draws frames from the global ``np.random`` state.

    Returns
    -------
    The merged and stacked structure. If ``check_overlaps`` returns a truthy
    value, that value is returned unchanged instead.

    Raises
    ------
    ValueError
        If ``segment`` is not 'fixed' or 'random'.

    Notes
    -----
    The overlap width is the module-level ``n``.
    """
    # Validate first: an unknown mode used to fall through both branches and
    # crash later with an UnboundLocalError on the frame index.
    if segment not in ('fixed', 'random'):
        raise ValueError(f"segment must be 'fixed' or 'random', got {segment!r}")

    keep_resSeq = False
    site_a, site_b = pair
    if verbose:
        print(site_a,site_b)

    # `is None` rather than `not traj`: an mdtraj Trajectory's truthiness is
    # its frame count.
    if traj is None:
        # New structure: pick one item of each site.
        if segment == 'fixed':
            x,y = 40,90
        else:
            n_a = len(site_map[site_a])
            n_b = len(site_map[site_b])
            x,y = np.random.randint(0,n_a,1)[0],np.random.randint(0,n_b,1)[0]
        A = site_map[site_a][x]
        B = site_map[site_b][y]
    else:
        # Extension: keep the current structure and pick one item of site_b.
        if segment == 'fixed':
            z = 20
        else:
            z = np.random.randint(0,len(site_map[site_b]),1)[0]
        A = traj
        B = site_map[site_b][z]

    # Check whether each site is a dimerization site.
    dimer_a = check_if_dimerization(site_a)
    dimer_b = check_if_dimerization(site_b)
    if verbose:
        print(dimer_a, dimer_b)

    # Termini of site a and b that overlap.
    terminus_a, terminus_b = get_termini(site_a,site_b)

    # Growth direction: when A's overlapping end is its C-terminus, B is
    # passed first to merge_chain_topology.
    reverse = terminus_a == 'C_terminus'

    # Atom indices of the overlapping segments (width: module-level `n`).
    overlap_A = get_overlap_indices(A.top,n,chain=leading_chain,terminus=terminus_a)
    overlap_B = get_overlap_indices(B.top,n,chain=adding_chain,terminus=terminus_b)

    # Make sure the overlapping indices are consistent.
    check = check_overlaps(overlap_A,overlap_B)
    if check:
        return check

    # Superimpose B on A, then drop the overlap used for the fit from A.
    new_B = fit_B_on_A(A,B,overlap_A,overlap_B)
    new_A = remove_overlap(A,overlap_A)

    # Split A into the leading chain (active) and the remaining chains (passive).
    A_active, A_passive = split_chain_topology(new_A,leading_chain)

    if dimer_b:
        # B is a dimer: split it as well and merge only its adding chain into A's chain.
        B_active, B_passive = split_chain_topology(new_B,adding_chain)
        if reverse:
            temp = merge_chain_topology(B_active,A_active,keep_resSeq=keep_resSeq)
        else:
            temp = merge_chain_topology(A_active,B_active,keep_resSeq=keep_resSeq)
        C_temp = temp.stack(A_passive,keep_resSeq=keep_resSeq)
        C = C_temp.stack(B_passive,keep_resSeq=keep_resSeq)
    else:
        # NOTE(review): unlike the branch above, keep_resSeq is not forwarded to
        # merge_chain_topology here. It is left unchanged because that helper's
        # default is not visible in this notebook.
        if reverse:
            temp = merge_chain_topology(new_B,A_active)
        else:
            temp = merge_chain_topology(A_active,new_B)
        # Combine the passive part with the new structure (active A + B).
        C = temp.stack(A_passive,keep_resSeq=keep_resSeq)

    return C

def add_dimer(traj, chainid = 0, verbose=False, segment='random', sites=None):
    """Extend ``traj`` by one dimer through two sequences of ``add_pair`` steps.

    First the s2-anchored sequence runs; its first step leads from chain
    ``chainid``. Then the s1-anchored sequence runs; its first step leads
    from chain ``chainid + 2``. Every later step in each sequence leads from
    chain 0.

    Parameters
    ----------
    traj : md.Trajectory or None
        Structure built so far. ``None`` starts a new one.
    chainid : int
        Leading chain for the first s2 step.
    verbose : bool
        Forwarded to ``add_pair``.
    segment : {'fixed', 'random'}
        Forwarded to ``add_pair``. 'random' uses the global ``np.random``
        state, so seed it for reproducible filaments.
    sites : dict, optional
        Site ensembles passed to ``add_pair``. Defaults to the module-level
        ``site_map``.

    Returns
    -------
    Whatever the last ``add_pair`` call returned.
    """
    if sites is None:
        sites = site_map

    s1_pairs = [['s1', 'h3'],['h3','s2'],['s2','l2'],['l2','dbd']]
    s2_pairs = [['s2','h3'],['h3','s1'],['s2', 'l2'],['l2','dbd']]

    for pairs, first_leading_chain in ((s2_pairs, chainid), (s1_pairs, chainid + 2)):
        for idx, pair in enumerate(pairs):
            leading_chain = first_leading_chain if idx == 0 else 0
            traj = add_pair(traj,pair,sites,leading_chain=leading_chain,verbose=verbose,segment=segment)
    return traj

In [8]:
# Ensemble of each structural site/segment, looked up by add_pair / add_dimer.
site_map = dict(s1=s1, h3=h3, s2=s2, l2=l2, dbd=dbd)

In [12]:
traj = None
dimers = 8  # number of dimers to assemble

for idx in range(dimers):
    print(idx)
    # add_dimer's `chainid` advances by two chains per dimer already added.
    traj = add_dimer(traj, 2 * idx, segment='fixed')

print([c.n_residues for c in traj.top.chains])

0
1
2
3
4
5
6
7
[137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 137, 33, 33]


In [13]:
# Show only chains 0 .. 2*dimers-1. Any extra chains (two 33-residue chains in
# the printed chain list) are left out.
filament_sel = f'chainid 0 to {(dimers*2)-1}'
view = nv.show_mdtraj(traj.atom_slice(traj.top.select(filament_sel)))
view

NGLWidget()

In [14]:
# Save chains 0 .. 2*dimers-1 as PDB. The "octa" in the filename assumes dimers == 8.
filament = traj.atom_slice(traj.top.select(f'chainid 0 to {(dimers*2)-1}'))
filament.save('./octa_dimer.pdb')

In [92]:
# NOTE(review): renders the same selection as the In [13] cell; consider deleting this duplicate.
filament_atoms = traj.top.select(f'chainid 0 to {(dimers*2)-1}')
view = nv.show_mdtraj(traj.atom_slice(filament_atoms))
view

NGLWidget()

In [69]:
# NOTE(review): another copy of the In [13] view cell; consider deleting it.
view = nv.show_mdtraj(
    traj.atom_slice(traj.top.select(f'chainid 0 to {(dimers*2)-1}'))
)
view

NGLWidget()

In [15]:
# Browse all frames of the joined s1s1 trajectory.
view = nv.show_mdtraj(
    s1s1
)
view

NGLWidget(max_frame=20015)

In [18]:
OPEN_FRAME = 850  # frame of s1s1 written out as the "open" dimer (TODO: record how it was chosen)
s1s1[OPEN_FRAME].save('./dimer_open.pdb')