In [1]:
import os
from pathlib import Path

from refine import *
from utils import get_random_direction

import time
import datetime as dt
import re

import pandas as pd

import matplotlib.pyplot as plt
%matplotlib inline

import pickle

import project
project.setup()

In [2]:
# Input / output locations for this benchmark run.
source_dir = project.data_path / 'benchmark' / 'modeled_pp5'
output_dir = project.output_path / 'benchmark_pp5_test'
fluctuations_dir = output_dir / 'fluctuations'
plots_dir = output_dir / 'plots'
sessions_dir = output_dir / 'sessions'
trajectories_dir = output_dir / 'trajectories'

# Root of the benchmark v5 data (Table_BM5.xlsx + structures). The absolute path is only a
# fallback for the original author's machine; set PP_BENCHMARK_DIR to use another location.
pp_benchmark_dir = Path(os.environ.get(
    'PP_BENCHMARK_DIR', '/home/semyon/mipt/GPCR-TEAM/pp_benchmark_v5'))
pp_benchmark_structures = pp_benchmark_dir / 'benchmark5' / 'structures'
pp_benchmark_table_path = pp_benchmark_dir / 'Table_BM5.xlsx'

# Force-field XML names passed to ProteinComplex; `forcefield_name` selects the one in use.
amber = 'amber14/protein.ff14SB.xml'
charmm = 'charmm36.xml'
forcefield_name = amber

In [3]:
def read_benchmark_table(table_path):
    """
    Read the benchmark summary spreadsheet into a DataFrame indexed by complex label.

    The DataFrame returned by ``pd.read_excel`` is expected to hold the real column names in
    its row at position 1, and its ``Complex`` column interleaves section-header rows such as
    ``"Rigid-body (151)"`` with complex rows such as ``"1AHW_AB:C"``.

    @param table_path: path to the ``.xlsx`` table (anything ``pd.read_excel`` accepts).
    @return: DataFrame indexed by ``Complex`` with four extra columns
        'Difficulty' - name of the closest section header above the row (None before the first),
        'Complex ID' - code before the ``_`` in the label,
        'Chains R' / 'Chains L' - chain letters before / after the ``:`` in the label
        (None when the label does not have the ``ID_CHAINS:CHAINS`` form).
        Section-header rows and rows with an empty ``Complex`` cell are removed.
    @rtype: pandas.DataFrame
    """
    regex_difficulty = re.compile(r'(?P<name>.+) \((?P<count>\d+)\)')
    regex_complex = re.compile(r'(?P<name>[0-9A-Z]+)_(?P<chainsR>[A-Z]+):(?P<chainsL>[A-Z]+)')

    table_bm5 = pd.read_excel(table_path)
    # The readable column names sit in row 1; rows 0 and 1 are header material, not data.
    table_bm5.columns = table_bm5.iloc[1]
    table_bm5 = table_bm5.drop([0, 1], axis=0).reset_index(drop=True)

    difficulty = None
    difficulties, complex_ids, chains_r, chains_l, keep = [], [], [], [], []
    for label in table_bm5['Complex']:
        # Non-string cells (e.g. NaN from empty spreadsheet cells) are treated as empty labels
        # instead of being handed to re.match, which raises TypeError on floats.
        text = label if isinstance(label, str) else ''
        m_difficulty = regex_difficulty.match(text)
        m_complex = regex_complex.match(text)
        if m_difficulty:
            # Section header: remember its name for the rows below; the header row is dropped.
            difficulty = m_difficulty.group('name')
        difficulties.append(difficulty)
        complex_ids.append(m_complex.group('name') if m_complex else None)
        chains_r.append(m_complex.group('chainsR') if m_complex else None)
        chains_l.append(m_complex.group('chainsL') if m_complex else None)
        keep.append(bool(text) and m_difficulty is None)

    table_bm5['Difficulty'] = difficulties
    table_bm5['Complex ID'] = complex_ids
    table_bm5['Chains R'] = chains_r
    table_bm5['Chains L'] = chains_l
    return table_bm5.loc[keep].set_index('Complex')

def get_structure_paths(idx, table, structure_path):
    """
    Build the four PDB file paths of one benchmark complex.

    @param idx: row label of the complex in `table` (e.g. '1DQJ_AB:C').
    @param table: DataFrame with a 'Complex ID' column, e.g. from read_benchmark_table.
    @param structure_path: directory that holds the PDB files.
    @return: list of path strings ``<dir>/<ID>_<x>_<y>.pdb`` in the order
        r_u, r_b, l_u, l_b (x in r/l, presumably receptor/ligand; y in u/b, presumably
        unbound/bound).
    """
    folder = Path(structure_path)
    complex_id = table.loc[idx]['Complex ID']
    return [
        str(folder / f'{complex_id}_{partner}_{state}.pdb')
        for partner in ('r', 'l')
        for state in ('u', 'b')
    ]

In [4]:
# Load the benchmark summary table and preview its first rows.
table_bm5 = read_benchmark_table(table_path=pp_benchmark_table_path)
table_bm5.head(n=5)

1,Cat.,PDB ID 1,Protein 1,PDB ID 2,Protein 2,I-RMSD (Å),ΔASA(Å2),BM version introduced,Difficulty,Complex ID,Chains R,Chains L
Complex,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
1AHW_AB:C,A,1FGN_LH,Fab 5g9,1TFH_A,Tissue factor,0.69,1899,2,Rigid-body,1AHW,AB,C
1BVK_DE:F,A,1BVL_BA,Fv Hulys11,3LZT_,HEW lysozyme,1.24,1321,2,Rigid-body,1BVK,DE,F
1DQJ_AB:C,A,1DQQ_CD,Fab Hyhel63,3LZT_,HEW lysozyme,0.75,1765,2,Rigid-body,1DQJ,AB,C
1E6J_HL:P,A,1E6O_HL,Fab,1A43_,HIV-1 capsid protein p24,1.05,1245,2,Rigid-body,1E6J,HL,P
1JPS_HL:T,A,1JPT_HL,Fab D3H44,1TFH_B,Tissue factor,0.51,1852,2,Rigid-body,1JPS,HL,T


In [None]:
# Benchmark PDB files live in <pp_benchmark_dir>/benchmark5/structures, i.e. pp_benchmark_structures.
get_structure_paths('1DQJ_AB:C', table_bm5, pp_benchmark_structures)

In [None]:
# Run parameters for this test case.
pdb_file = str(source_dir / '20T3_u_modeled.pdb')  # input structure
n_modes, cutoff = 10, 6.5  # per-molecule 'nmodes' / 'cutoff' given to RMRestrictionWrapper
max_rmsd = 4.0  # RMSD acceptance bound for the random perturbation search

In [None]:
# Create the protein complex with one selection per chain found in the PDB file.
omm_structure = app.PDBFile(pdb_file)
chains = list(omm_structure.topology.chains())
selections = [f"chain {chain.id}" for chain in chains]
# NOTE(review): cid '1ay7_complex' does not match the input file name - confirm it is intended.
pc = ProteinComplex(pdb_file, forcefield_name, selections, cid='1ay7_complex')

# Normal-mode restriction: one parameter set per selection, so the list always has as many
# entries as there are chains (identical to the previous two-entry list for a two-chain file).
mode_params = [{'nmodes': n_modes, 'cutoff': cutoff} for _ in selections]
rw = RMRestrictionWrapper(pc, mode_params)
# Reference position of every molecule, used later to reset the system.
zero_pos = [rw.get_position(i) for i in range(len(rw))]

In [None]:
# Plot the per-atom magnitude of every normal mode of every molecule.
# Title score = log(N) - H(p): H(p) is the Shannon entropy of the mode's per-atom magnitude
# distribution and log(N) that of a uniform distribution over the molecule's N atoms,
# so 0 means evenly spread and larger values mean the mode is concentrated on fewer atoms.
n_molecules = len(rw)
# squeeze=False keeps `ax` two-dimensional even for one mode or one molecule.
fig, ax = plt.subplots(nrows=n_modes, ncols=n_molecules, figsize=(12, 30), squeeze=False)
y_lim = 0

for j in range(n_molecules):
    # Uniform-distribution entropy for molecule j; molecules can have different atom counts.
    best_entropy = np.log(len(rw._anms[j][0].getArrayNx3()))
    for i in range(n_modes):
        mode_magnitude = np.linalg.norm(rw._anms[j][i].getArrayNx3(), axis=1)
        mode_density = mode_magnitude / np.sum(mode_magnitude)
        # Use the 0 * log(0) = 0 convention so atoms with zero displacement do not yield NaN.
        nonzero = mode_density[mode_density > 0]
        entropy = -np.sum(np.log(nonzero) * nonzero)
        y_lim = max(y_lim, np.max(mode_magnitude) * 1.1)
        axis = ax[i, j]
        axis.set_title(f'Mode: {i + 1} Molecule: {j + 1} Entropy Score = {best_entropy - entropy:.3f}')
        axis.set_xlabel('Atom Number')
        axis.set_ylabel('Mode Magnitude')
        axis.plot(mode_magnitude, color='blue')
        axis.grid()
# A common y-range keeps the panels directly comparable.
for axis in ax.flat:
    axis.set_ylim((0, y_lim))
plt.tight_layout()

In [None]:
# Bond lengths of the complex in its current (initial) state, kept as a reference.
bond_lengths_0 = pc.get_bond_lengths()

In [None]:
def multimolecule_confined_gradient_descent(
        rw, decrement=0.9, relative_bounds_r=(0.01, 3), relative_bounds_s=(0.01, 0.5),
        max_iter=100, save_history=False, extended_result=False, log=False, mode=CGDMode.BOTH):
    """
    Performs gradient descent over all molecules of the system with respect to a special confinement.

    Each iteration first derives, for every selection i, an admissible step interval from the
    recorded force on that selection (translation force, torque, inertia, mode force) and the
    requested rmsd bounds. The most restrictive interval over all selections is then used for
    every selection. Each selection is moved in turn, shrinking the step until the system energy
    falls below the recorded energy or the lower end of the interval is reached.
    The loop stops after `max_iter` iterations, or earlier (status set to False) when an
    iteration ends with a higher energy than the previous record.

    @param rw: system to optimize.
    @type rw: RMRestrictionWrapper
    @param decrement: fold step when choosing optimal step (tau and mtau are multiplied by
        decrement ** 2 after every rejected trial step)
    @type decrement: float
    @param relative_bounds_r: minimum and maximum rmsd between actual intermediate state and the next one (rigid)
    @type relative_bounds_r: tuple
    @param relative_bounds_s: minimum and maximum rmsd between actual intermediate state and the next one (modes)
    @type relative_bounds_s: tuple
    @param max_iter: maximum number of iterations
    @type max_iter: int
    @param save_history: currently unused; kept for interface compatibility.
    @type save_history: bool
    @param extended_result: if true, per-step translation/rotation/mode increments and
        rigid/flexible rmsd values are also stored in the result
    @type extended_result: bool
    @param log: if true a log will be printed to the console
    @type log: bool
    @param mode: there are three modes RIGID, FLEXIBLE and BOTH
    @type mode: CGDMode
    @return: the optimization result object built during the run
    @rtype: MCGOptimizationResult
    """
    logger.setLevel(logging.INFO if log else logging.ERROR)

    logger.info('INITIALIZATION STAGE')
    optimization_result = MCGOptimizationResult(extended=extended_result)
    optimization_result.update_main(rw)

    n_selections = len(rw)
    for k in range(max_iter):
        logger.info(f'ITERATION {k} START')

        # ---- admissible step interval of every selection ----
        tau_list = []
        mtau_list = []
        iinv_t_list = []
        w_list = []
        logger.info('Advancement region computation'.upper())
        for i in range(n_selections):
            logger.info(f'SELECTION:{i}')
            force = optimization_result.get_main(i)['force']

            f_trans = force[0]
            torque = force[1]
            inertia_inv = np.linalg.inv(force[2])
            f_modes = force[3]
            # Total weight of selection i (each selection has its own weight vector).
            w = np.sum(rw._weights[i])
            w_list.append(w)
            iinv_t = inertia_inv.dot(torque)
            iinv_t_list.append(iinv_t)
            iinv_t24 = iinv_t.dot(iinv_t) / 4
            ft24w = f_trans.dot(f_trans) / 4 / w
            tit = torque.dot(iinv_t)
            wrmsd20 = w * relative_bounds_r[0] ** 2
            wrmsd21 = w * relative_bounds_r[1] ** 2
            # For each rigid rmsd bound, tau ** 4 is the largest root of a*t**2 + b*t + c = 0.
            a = ft24w * iinv_t24
            b0 = ft24w + tit - wrmsd20 * iinv_t24
            c0 = -wrmsd20
            b1 = ft24w + tit - wrmsd21 * iinv_t24
            c1 = -wrmsd21
            tau0 = np.max(np.roots([a, b0, c0])) ** 0.25
            tau1 = np.max(np.roots([a, b1, c1])) ** 0.25
            tau_list.append((tau0, tau1))
            logger.info(f'tau: {tau0}, {tau1}')

            mcoeff = (4 * w / f_modes.dot(f_modes)) ** 0.25
            mtau0 = relative_bounds_s[0] ** 0.5 * mcoeff
            mtau1 = relative_bounds_s[1] ** 0.5 * mcoeff
            mtau_list.append((mtau0, mtau1))
            logger.info(f'mtau: {mtau0}, {mtau1}')

        # The most restrictive interval over all selections is applied to every selection.
        logger.info('MINIMAL ADVANCEMENT REGION')
        min_tau0 = np.min([t0 for t0, _ in tau_list])
        min_tau1 = np.min([t1 for _, t1 in tau_list])
        min_mtau0 = np.min([mt0 for mt0, _ in mtau_list])
        min_mtau1 = np.min([mt1 for _, mt1 in mtau_list])
        logger.info(f'TAU: {min_tau0}, {min_tau1}')
        logger.info(f'MTAU: {min_mtau0}, {min_mtau1}')
        logger.info('LINEAR SEARCH')

        # extended results
        tdiff_list = []
        rrmsd_list = []
        frmsd_list = []
        qdiff_list = []
        mdiff_list = []

        logger.info('SYSTEM EVOLUTION')
        for i in range(n_selections):
            logger.info(f'SELECTION:{i}')
            record = optimization_result.get_main(i)
            iinv_t = iinv_t_list[i]
            w = w_list[i]
            position = record['position']
            force = record['force']
            # NOTE(review): `energy` is the value recorded at the start of this iteration; after
            # earlier selections moved, rw.get_energy() may already differ - confirm intent.
            energy = record['energy']
            coords = record['coords']

            f_trans = force[0]
            f_modes = force[3]

            tau1 = min_tau1
            mtau1 = min_mtau1

            # Shrink the trial step until the energy drops or the lower bound is reached.
            while (tau1 > min_tau0 or mode == CGDMode.FLEXIBLE) and (mtau1 > min_mtau0 or mode == CGDMode.RIGID):
                tdiff = tau1 ** 2 / w / 2 * f_trans
                qdiff = np.quaternion(1, *(tau1 ** 2 / 2 * iinv_t))
                qdiff /= qdiff.norm()
                mdiff = mtau1 ** 2 / 2 * f_modes
                trans = position[0]
                quat = position[1]
                modes = position[2]
                if mode in (CGDMode.BOTH, CGDMode.RIGID):
                    trans = tdiff + trans
                    quat = qdiff * quat
                if mode in (CGDMode.BOTH, CGDMode.FLEXIBLE):
                    modes = mdiff + modes
                new_pos = [trans, quat, modes]

                if extended_result:
                    tdiff_list.append(tdiff)
                    qdiff_list.append(qdiff)
                    mdiff_list.append(mdiff)
                    logger.info(f'EXT::tdiff: {tdiff}')
                    logger.info(f'EXT::qdiff: {qdiff}')
                    logger.info(f'EXT::mdiff: {mdiff}')
                    # Apply only the rigid part of the step to measure its rmsd separately.
                    test_pos = [trans, quat, position[2]]
                    rw.set_position(i, test_pos)
                    coords2 = rw._protein_complex.get_coords(i)
                    weights = rw._weights[i]
                    rrmsd = rmsd(coords, coords2, weights)
                    rrmsd_list.append(rrmsd)
                    logger.info(f'EXT::RRMSD: {rrmsd}')
                    frmsd = ((mdiff.dot(mdiff)) / np.sum(weights)) ** 0.5
                    frmsd_list.append(frmsd)
                    logger.info(f'EXT::FRMSD: {frmsd}')

                # Move selection i (not always selection 0) to the trial position.
                rw.set_position(i, new_pos)
                new_energy = rw.get_energy()
                tau1 *= decrement * decrement
                mtau1 *= decrement * decrement
                logger.info(f'NEW ENERGY: {new_energy}')
                if new_energy < energy:
                    break

        logger.info(f'ITERATION {k} END')
        current_energy = rw.get_energy()
        previous_energy = optimization_result.get_energy()
        if previous_energy < current_energy:
            # The sweep raised the energy: report failure and stop.
            optimization_result.set_status(False)
            break
        optimization_result.update_main(rw)
        if extended_result:
            optimization_result.update_translation_diff(tdiff_list)
            optimization_result.update_rotation_diff(qdiff_list)
            optimization_result.update_mode_diff(mdiff_list)
            optimization_result.update_rigid_rmsd(rrmsd_list)
            optimization_result.update_flexible_rmsd(frmsd_list)
    return optimization_result

# Debug check: evaluate the force on molecule 0, then print 3 x the number of atoms
# in compartment 1.
rw._protein_complex.get_force(0)

atoms = [*rw._protein_complex._compartments[1].iterAtoms()]
print(3 * len(atoms))

In [None]:
# Put every molecule back at its reference position.
for idx in range(len(rw)):
    rw.set_position(idx, zero_pos[idx])

In [None]:
# Run the multi-molecule gradient descent and plot the recorded energies.
result = multimolecule_confined_gradient_descent(rw, save_history=True, extended_result=True, log=True)

fig, ax = plt.subplots()
ax.plot(result._optimization_result['energies'])
ax.set_title('Energy along the optimization path')
ax.set_xlabel('Step')
ax.set_ylabel('Energy')
ax.grid()

In [None]:
# Save the optimized complex; create the output directory first so the write cannot fail on it.
output_dir.mkdir(parents=True, exist_ok=True)
with open(output_dir / 'opt.pdb', 'w') as handle:
    pc.to_pdb(handle)

# Introduce random flexible perturbations: for each molecule, draw random mode vectors until
# the molecule is within `max_rmsd` of its native coordinates and every bond length changed by
# less than `max_rel_bond_dev` (relative).
for i in range(len(rw)):
    rw.set_position(i, zero_pos[i])
bond_lengths_0 = pc.get_bond_lengths()
angle = 10 * np.pi / 180   # rotation angle used when perturb_rigid is enabled
max_rel_bond_dev = 0.2     # acceptance threshold on the relative bond-length change
max_attempts = 10000       # safety cap; without it the search can loop forever
perturb_rigid = False      # True: also apply a random translation and rotation
for i in range(len(rw)):
    time_start = time.time()
    native_coords = pc.get_coords(i)
    # uniform weights for rmsd
    weights = np.ones(len(native_coords))
    weights /= np.sum(weights)
    partners = [j for j in range(len(rw)) if j != i]
    if perturb_rigid and partners:
        # unit vector pointing from molecule i to a randomly chosen partner molecule center
        k = np.random.choice(partners)
        towards_partner = rw._c_tensors[k] - rw._c_tensors[i]
        towards_partner = towards_partner / np.linalg.norm(towards_partner)
    for count_attempts in range(1, max_attempts + 1):
        if perturb_rigid and partners:
            translation = (towards_partner + np.random.uniform(-0.1, 0.1, 3)) * 1.5
            direction = get_random_direction()
            rotation = np.quaternion(np.cos(angle / 2), *(direction * np.sin(angle / 2)))
        else:
            translation = np.zeros_like(rw._c_tensors[i])
            rotation = np.quaternion(1, 0, 0, 0)
        # random mode direction, scaled to a fixed norm
        modes = np.random.uniform(-1.4, 1.4, n_modes)
        modes /= np.linalg.norm(modes)
        modes *= 2 * np.sum(rw._weights[i]) ** 0.5
        rw.set_position(i, [translation, rotation, modes])

        bond_lengths_1 = pc.get_bond_lengths()
        bond_relative_deviation = np.abs(bond_lengths_1 - bond_lengths_0) / bond_lengths_0
        worst_bond = np.max(bond_relative_deviation)
        current_rmsd = rmsd(native_coords, pc.get_coords(i), weights)
        if current_rmsd <= max_rmsd and worst_bond < max_rel_bond_dev:
            print(f'Final rmsd: {current_rmsd}')
            print(f'Attempt count: {count_attempts}')
            print(f'Elapsed time: {dt.timedelta(seconds=time.time() - time_start)}')
            break
        if worst_bond >= max_rel_bond_dev:
            print(f'The maximum relative deviation: {worst_bond}')
            print(f'RMSD: {current_rmsd}')
    else:
        print(f'No acceptable perturbation for molecule {i} after {max_attempts} attempts')

# Compare the current bond lengths with the reference ones, save the complex and plot the
# per-bond relative deviation.
bond_lengths_1 = pc.get_bond_lengths()
bond_deviation = bond_lengths_1 - bond_lengths_0
bond_relative_deviation = np.abs(bond_deviation) / bond_lengths_0
min_0, max_0 = np.min(bond_lengths_0), np.max(bond_lengths_0)
min_1, max_1 = np.min(bond_lengths_1), np.max(bond_lengths_1)
max_d = np.max(np.abs(bond_deviation))
rmsbd = np.mean(bond_deviation ** 2) ** 0.5
for label, value in (
        ('The shortest bond 0', min_0),
        ('The shortest bond 1', min_1),
        ('The longest bond 0', max_0),
        ('The longest bond 1', max_1),
        ('The maximum deviation', max_d),
        ('Root mean square bond deviation', rmsbd)):
    print(f'{label}: {value}')
with open(output_dir / 'a_multi_test_pc.pdb', 'w') as handle:
    pc.to_pdb(handle)

plt.plot(bond_relative_deviation)
plt.grid()
plt.show()

# Sample random flexible perturbations and record, for each sample, the maximum RELATIVE bond
# deviation (what the histogram below is labelled as).
# NOTE(review): range(len(rw) - 1) skips the last molecule - confirm this is intentional.
for i in range(len(rw)):
    rw.set_position(i, zero_pos[i])
bond_lengths_0 = pc.get_bond_lengths()
n_samples = 200
max_deviation_list = []
for i in range(len(rw) - 1):
    time_start = time.time()
    native_coords = pc.get_coords(i)
    # weight for rmsd
    weights = np.ones(len(native_coords))
    weights /= np.sum(weights)
    for count_attempts in range(1, n_samples + 1):
        translation = np.zeros_like(rw._c_tensors[i])
        rotation = np.quaternion(1, 0, 0, 0)

        modes = np.random.uniform(-1.4, 1.4, n_modes)
        modes /= np.linalg.norm(modes)
        modes *= 2 * np.sum(rw._weights[i]) ** 0.5

        rw.set_position(i, [translation, rotation, modes])

        bond_lengths_1 = pc.get_bond_lengths()
        bond_deviation = bond_lengths_1 - bond_lengths_0
        bond_relative_deviation = np.abs(bond_deviation) / bond_lengths_0
        max_d = np.max(np.abs(bond_deviation))  # absolute maximum, kept for inspection
        max_deviation_list.append(np.max(bond_relative_deviation))
        if count_attempts < n_samples:
            print(f'Iteration: {count_attempts} Elapsed: {dt.timedelta(seconds=time.time() - time_start)}', end='\r')
    # Restore molecule i so its deformation does not leak into the next molecule's samples.
    rw.set_position(i, zero_pos[i])

print('Mean:', np.mean(max_deviation_list))
print('STD:', np.std(max_deviation_list))
print('Min:', np.min(max_deviation_list))
print('Min:', np.min(max_deviation_list))

# Histogram of the sampled maximum bond deviations, saved next to the other outputs.
hist_path = project.output_path / 'max_rbd_hist' / 'hist_10_0_3rij3_2.png'
hist_path.parent.mkdir(parents=True, exist_ok=True)
plt.figure(figsize=(10, 10))
plt.hist(max_deviation_list, bins=50)
plt.title('Histogram of maximum relative bond deviation')
plt.ylabel('Count')
plt.xlabel('max_rbd')
plt.grid()
# `format` is the savefig keyword; `fmt` is not, and newer Matplotlib versions reject it.
plt.savefig(hist_path, format='png', dpi=300)

from scipy.optimize import minimize

def cost_function(x, rw, init_bond_lengths, a, rmsd_bound, ind):
    """
    Penalty for a mode-only deformation of molecule `ind`.

    Places molecule `ind` at zero translation, identity rotation and mode amplitudes `x`,
    then returns
        (rmsd_bound ** 2 - msd) ** 2 + a * mean(relative_bond_deviation ** 2)
    with msd = sum(x ** 2) / sum(weights of `ind`), the squared rmsd implied by the mode
    amplitudes. The first term pulls the deformation towards rmsd == rmsd_bound; the second
    penalizes changed bond lengths.

    Side effect: molecule `ind` is left at the evaluated position.

    @param x: mode amplitudes of molecule `ind`.
    @param rw: RMRestrictionWrapper holding the system.
    @param init_bond_lengths: reference bond lengths of the whole complex.
    @param a: weight of the bond-deviation penalty.
    @param rmsd_bound: target rmsd of the deformation.
    @param ind: index of the deformed molecule.
    @return: scalar cost.
    """
    rw.set_position(ind, [np.zeros(3), np.quaternion(1, 0, 0, 0), x])
    bond_lengths = rw._protein_complex.get_bond_lengths()
    rel_bond_dev = (bond_lengths - init_bond_lengths) / init_bond_lengths
    # sum(x**2) / sum(w) is the SQUARED rmsd, so compare it with rmsd_bound ** 2
    # (for rmsd_bound == 1 this equals the previous formula).
    msd = np.sum(x ** 2) / np.sum(rw._weights[ind])
    t0 = (rmsd_bound ** 2 - msd) ** 2
    t1 = a * np.mean(rel_bond_dev ** 2)
    return t0 + t1

def callback(xk):
    """
    Progress hook for scipy.optimize.minimize.

    Prints the iteration number and the wall-clock time since the first call on one
    self-overwriting line. State lives on the function object: `callback.iter` counts calls
    and `callback.time_start` is set when a call happens with `callback.iter == 0`.

    @param xk: current parameter vector (unused).
    """
    if not callback.iter:
        callback.time_start = time.time()
    callback.iter += 1
    elapsed = dt.timedelta(seconds=time.time() - callback.time_start)
    print(f'Iteration: {callback.iter} Elapsed: {elapsed}', end='\r')
callback.time_start = 0
callback.iter = 0

# Settings of the bond-penalized mode optimization of molecule `ind`.
a, rmsd_bound, ind = 200, 1, 0

# Start from the reference positions with a random mode vector of squared norm sum(weights).
for i in range(len(rw)):
    rw.set_position(i, zero_pos[i])
x0 = np.random.uniform(-1, 1, n_modes)
x0 = x0 / np.linalg.norm(x0) * np.sum(rw._weights[ind]) ** 0.5
init_bond_lengths = pc.get_bond_lengths()

# Repeat the bond-penalized optimization of molecule `ind` from several random start vectors.
a = 200
rmsd_bound = 1
ind = 0
n_restarts = 20
results = []
for k in range(n_restarts):
    for i in range(len(rw)):
        rw.set_position(i, zero_pos[i])
    x0 = np.random.uniform(-1, 1, n_modes)
    x0 /= np.linalg.norm(x0)
    x0 *= np.sum(rw._weights[ind]) ** 0.5
    init_bond_lengths = pc.get_bond_lengths()

    callback.iter = 0  # restart the progress counter and timer for this run
    res = minimize(cost_function, x0, args=(rw, init_bond_lengths, a, rmsd_bound, ind),
                   callback=callback,
                   options={'maxiter': 40})
    results.append(res)
    print(f'k = {k}')

# Re-apply each optimized mode vector and print its worst relative bond deviation.
for opt_res in results:
    cost_function(opt_res.x, rw, init_bond_lengths, a, rmsd_bound, ind)
    deviation = (pc.get_bond_lengths() - init_bond_lengths) / init_bond_lengths
    print(np.abs(deviation).max())

# Inspect the last random start vector x0: its cost, rmsd of molecule 0 and bond penalty.
# NOTE(review): `native_coords` / `weights` were last set by the sampling loop for molecule
# len(rw) - 2, which is molecule 0 only for a two-molecule system - confirm.
x0_cost = cost_function(x0, rw, init_bond_lengths, a, rmsd_bound, ind)
x0_rmsd = rmsd(native_coords, pc.get_coords(0), weights)
rel_bond_dev = (pc.get_bond_lengths() - init_bond_lengths) / init_bond_lengths
# Use the same bond-penalty weight `a` that cost_function uses.
x0_bond_penalty = np.mean(rel_bond_dev ** 2) * a
print(f'Cost: {x0_cost}')
print(f'RMSD: {x0_rmsd}')
print(f'Bond penalty: {x0_bond_penalty}')

fig, ax = plt.subplots()
ax.plot(np.abs(rel_bond_dev))
ax.set_xlabel('Bond index')
ax.set_ylabel('|relative bond deviation|')
ax.grid()