In [None]:
import plumed
import mdtraj as md
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import json
import networkx as nx
# Notebook-wide plot styling: large tick/label fonts and a 12x8 inch default figure.
matplotlib.rcParams.update({
    'xtick.labelsize': 20,
    'ytick.labelsize': 20,
    'font.size': 20,
    'figure.figsize': (12, 8),
})
def slice(traj, selection):
    """Select atoms with an MDTraj selection string and return the atom-sliced trajectory.

    Parameters
    ----------
    traj : mdtraj.Trajectory
        Trajectory to restrict.
    selection : str
        MDTraj DSL selection, e.g. ``"resid 0 to 10"`` or ``"not water"``.

    Note: this helper's name shadows Python's built-in ``slice`` at notebook scope.
    """
    selected_atoms = traj.top.select(selection)
    return traj.atom_slice(selected_atoms)

In [None]:
# Load the reweighting COLVAR (collective variables + metadynamics bias) and the
# solute-only trajectory (water and Na/Cl ions stripped).
colvar_path = "colvar_reweight.dat"
colvar = plumed.read_as_pandas(colvar_path)
d1 = colvar['d1'].to_numpy()
rmsd = colvar['rmsd'].to_numpy()
lp = colvar['fps.lp'].to_numpy()   # funnel-axis projection
ld = colvar['fps.ld'].to_numpy()   # distance from the funnel axis

m_traj = md.load('md_whole.xtc', top='md_whole.gro')
m_traj = slice(m_traj, 'not water and not name NA and not name CL')
m_traj.center_coordinates()

# The ligand occupies the last `lig_nres` residues of the topology; everything before is protein.
ref_lig = md.load('../a_insert/lig.pdb')
lig_nres = ref_lig.top.n_residues
total_nres = m_traj.top.n_residues
prot_nres = total_nres - lig_nres
protein = slice(m_traj, f"resid 0 to {prot_nres - 1}")
ligand = slice(m_traj, f"resid >= {prot_nres}")

# Dihedral time series for the ligand's internal residues (columns phi1, psi1, ...).
# omega columns are optional: they are used only when the COLVAR file contains them.
dih_names = ['phi', 'psi', 'omega']
n_dih = 3 if 'omega1' in colvar.columns else 2
dihedrals = {name: [] for name in dih_names[:n_dih]}
for i in range(1, ligand.top.n_residues - 1):
    for name in dih_names[:n_dih]:
        dihedrals[name].append(colvar[f'{name}{i}'].to_numpy())

# Reweighting factors w ∝ exp(V_bias / kT), normalised to sum to 1.
# The maximum bias is subtracted before exponentiating so large biases cannot overflow;
# the constant shift cancels on normalisation.
bias = colvar['metad.bias'].to_numpy()
kT = 2.479  # presumably kJ/mol at ~298 K (PLUMED's default energy unit) — confirm
bias_weights = np.exp((bias - bias.max()) / kT)
bias_weights /= bias_weights.sum()


In [None]:
# 1-D free energy along the funnel-axis projection, with the funnel-volume correction
# that references it to the 1 M standard state.
# NOTE(review): 0.5915 kcal/mol corresponds to ~297.7 K while kT = 2.479 kJ/mol above is
# ~298.15 K; the difference is small but confirm which temperature is intended.
ktkc = 0.5915
N_BINS_LP = 50

hist, edges = np.histogram(lp, bins=N_BINS_LP, weights=bias_weights)
centers = 0.5 * (edges[:-1] + edges[1:])  # evaluate everything at bin centres, not left edges

with open('params.json', 'r') as fh:
    params = json.load(fh)
rcyl = float(params['rcyl'])
alpha = float(params['alpha'])
cone_length = float(params['cone_length'])

# Funnel radius: cone section below `cone_length`, cylinder of radius `rcyl` beyond it.
# NOTE(review): PLUMED's funnel uses tan(ALPHA) as the cone slope when ALPHA is an angle;
# confirm whether params['alpha'] stores the angle or the slope.
radius = np.maximum(rcyl, rcyl + alpha * (cone_length - centers))

# Empty bins get NaN (a gap in the curve) instead of +inf and a divide-by-zero warning.
log_hist = np.log(hist, out=np.full(hist.shape, np.nan), where=hist > 0)
# radius * 10 converts nm to Å; 1660 Å^3 is the volume per molecule at 1 M concentration.
potential = -ktkc * log_hist + ktkc * np.log(np.pi * np.square(radius * 10) / 1660)

plt.plot(centers, potential)
plt.xlabel("Protein-Ligand Distance (nm)")
plt.ylabel("Free Energy (kcal/mol)")
plt.show()


In [None]:
# Reweighted 1-D free-energy profiles of each ligand dihedral, split into frames where the
# ligand is bound (lp < THRESHOLD) and frames where it is in solvent.
ktkc = 0.5915
THRESHOLD = 0.6   # funnel-axis projection (nm) separating "bound" from "in solvent"
N_BINS_DIH = 50
n_lig_dih = ligand.top.n_residues - 2  # internal residues only, matching the dihedral loop


def _fes_1d(values, weights, edges):
    """Return -ktkc*ln(p) on the given bin edges, shifted to a minimum of 0.

    Empty bins are NaN; if every bin is empty the result is all-NaN.
    """
    hist, _ = np.histogram(values, bins=edges, weights=weights)
    potential = -ktkc * np.log(hist, out=np.full(hist.shape, np.nan), where=hist > 0)
    if np.all(np.isnan(potential)):
        return potential
    return potential - np.nanmin(potential)


bound = lp < THRESHOLD
solvated = lp >= THRESHOLD
# squeeze=False keeps `axes` 2-D even when there is a single column of panels.
fig, axes = plt.subplots(n_dih, n_lig_dih, figsize=[8 * n_lig_dih, 8 * n_dih], squeeze=False)
for i, dname in enumerate(dih_names[:n_dih]):
    for j in range(n_lig_dih):
        cur_d = dihedrals[dname][j]
        # Shared bin edges so the bound and solvent curves are directly comparable.
        edges = np.histogram_bin_edges(cur_d, bins=N_BINS_DIH)
        centers = 0.5 * (edges[:-1] + edges[1:])
        ax = axes[i, j]
        ax.plot(centers, _fes_1d(cur_d[bound], bias_weights[bound], edges), label='bound')
        ax.plot(centers, _fes_1d(cur_d[solvated], bias_weights[solvated], edges), label='in solvent')
        ax.set_xlabel(f"$\\{dname}_{j+1}$")
        ax.set_ylabel("Free Energy (kcal/mol)")
        ax.legend()
plt.show()

In [None]:
from tqdm import tqdm
import itertools
# Per-residue ligand contacts: a protein residue is "in contact" in a frame when any
# ligand-atom / residue-atom pair is closer than CONTACT_CUTOFF.
CONTACT_CUTOFF = 0.35  # nm (MDTraj distances are in nm)
# Trajectory frames correspond to every COLVAR_STRIDE-th COLVAR row.
# NOTE(review): taken from the original `[::20]` slicing — confirm against the output strides.
COLVAR_STRIDE = 20

lig_idxs = m_traj.top.select(f"resid >= {prot_nres}")
n_prot_res = protein.n_residues
lig_contacts = np.zeros((n_prot_res, m_traj.n_frames))
init_contacts = np.zeros(n_prot_res)
bw_contact_sums = np.zeros(n_prot_res)
# NOTE(review): atom indices from m_traj are reused on the docked structure; this assumes
# prot_lig.gro has the same protein+ligand atom ordering with nothing in front of it.
docked = md.load('../a_insert/prot_lig.gro')

frame_weights = np.asarray(bias_weights)[::COLVAR_STRIDE][:m_traj.n_frames]
assert len(frame_weights) == m_traj.n_frames, (
    f"only {len(frame_weights)} COLVAR rows available for {m_traj.n_frames} trajectory frames")
# Renormalise over the frames actually used: the full-COLVAR weights restricted to every
# COLVAR_STRIDE-th row sum to only roughly 1/COLVAR_STRIDE, so the bars were not probabilities.
frame_weights = frame_weights / frame_weights.sum()
time_ns = m_traj.time / 1000  # MDTraj stores frame times in ps

fig, ax = plt.subplots()
for i in tqdm(range(n_prot_res)):
    res_idxs = m_traj.top.select(f'resid {i}')
    # reshape keeps a valid (n_pairs, 2) array even when a selection is empty
    pairs = np.array(list(itertools.product(lig_idxs, res_idxs))).reshape(-1, 2)
    dists = md.compute_distances(m_traj, pairs)
    init_dists = md.compute_distances(docked, pairs)
    frame_contacts = np.any(dists < CONTACT_CUTOFF, axis=1)
    bw_contact_sums[i] = frame_weights[frame_contacts].sum()
    # NaN (rather than 0) for non-contact frames so they are not drawn in the scatter plot.
    lig_contacts[i] = np.where(frame_contacts, 1, np.nan)
    init_contacts[i] = 1 if np.any(init_dists < CONTACT_CUTOFF) else np.nan
    ax.scatter(time_ns, i * lig_contacts[i])
ax.set_xlabel('Simulation time (ns)')
ax.set_ylabel('Residue Number')
ax.set_title('Residues in contact with ligand')
plt.show()

fig, ax = plt.subplots()
ax.bar(np.arange(n_prot_res), bw_contact_sums)
ax.scatter(np.arange(n_prot_res), 0.01 * init_contacts, marker='x', c='red', label='Contact after docking')
ax.set_xlabel('Residue Number')
ax.set_ylabel('Probability of Contact')
ax.legend()
plt.show()

In [None]:
import nglview as nv
# Interactive 3-D view: the protein with nglview's default representation, and the
# ligand added as a second component drawn as ball+stick.
view = nv.show_mdtraj(protein)
lig_component = view.add_trajectory(ligand)
lig_component.clear_representations()
lig_component.add_representation("ball+stick")
view

In [None]:
# Funnel-axis projection for every COLVAR row, to eyeball binding/unbinding events.
fig, ax = plt.subplots()
ax.plot(lp)
ax.set_xlabel("COLVAR row")
ax.set_ylabel("Funnel axis projection (nm)")
ax.set_title("Funnel axis projection over the simulation")
plt.show()

In [None]:
def plot_fes(q1, q2, bw, xlabel='Q1', ylabel='Q2', title='FES', bins=10, vmax=6):
    """Plot a reweighted 2-D free-energy surface as a filled contour map.

    Parameters
    ----------
    q1, q2 : array-like
        Per-sample values of the two collective variables (x and y axes).
    bw : array-like
        Per-sample statistical weights (e.g. the normalised bias weights).
    xlabel, ylabel, title : str
        Axis labels and figure title.
    bins : int or sequence, optional
        Passed to ``np.histogram2d``; the default of 10 matches NumPy's own default.
    vmax : float, optional
        Upper end of the colour scale in kcal/mol; higher values saturate.

    Uses the notebook-level ``ktkc`` (kT in kcal/mol), which must be defined beforehand.
    The surface is shifted so that its minimum is 0; empty bins are left blank.
    """
    probs, xedges, yedges = np.histogram2d(q1, q2, bins=bins, weights=bw)
    # log(0) would give +inf; mark empty bins as NaN so contourf masks them without warnings.
    log_probs = np.log(probs, out=np.full(probs.shape, np.nan), where=probs > 0)
    potential = -ktkc * log_probs
    # Reference the surface to its global minimum so the 0..vmax colour scale does not
    # depend on the bin size (unshifted, the minimum is -kT*ln(p_max) > 0).
    potential -= np.nanmin(potential)
    plt.contourf(potential.T, origin='lower', extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]],
                 levels=250, vmin=0, vmax=vmax, cmap='jet')
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    c = plt.colorbar()
    c.set_label("Free energy (kcal/mol)")
    plt.show()
# 2-D FESs for each pair of funnel / distance CVs; each title names the CVs actually plotted.
plot_fes(lp, ld, bias_weights, "Funnel axis projection", "Distance from funnel axis", "Funnel projection-distance FES")
plot_fes(d1, lp, bias_weights, "Protein-Ligand distance", "Funnel axis projection", "Protein-Ligand distance / funnel projection FES")
plot_fes(d1, ld, bias_weights, "Protein-Ligand distance", "Distance from funnel axis", "Protein-Ligand distance / funnel-axis distance FES")
