# In [None]:
'''
This script extracts binding free energies (DELTA TOTAL) from MMPBSA calculations of complexes and compares them with experimental values.
The results can optionally be combined with entropy calculations. A plot of the linear fit is generated for each force field/method combination.
'''
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import linregress
# Render every figure with the same serif font.
plt.rcParams['font.family'] = 'Times New Roman'

# Data layout on disk: <root_dir>/<forcefield>/<molecule><suffix>/<method>/*.dat
root_dir = "/Directory"
forcefields = [
    "mmpbsa_data_gaff",
    "mmpbsa_data_gaff2",
]
molecules = [
    "4ntj_azj",
    "4ntj_psb",
    "4ntj_tiq",
    "4pxz",
    "4py0",
]
# Used for antagonist systems
# molecules = ["4ntj_azj", "4ntj_psb", "4ntj_tiq"]

# Each molecule is processed once per suffix and once per method.
suffixes = [
    "",
    "_muta",
]
methods = [
    "STM",
    "MTM",
]

def extract_delta_total(filepath):
    """Return the value on the first "DELTA TOTAL" line of a file.

    The file is scanned from the top; on the first line that contains the
    text "DELTA TOTAL", the third whitespace-separated token is converted
    to ``float`` and returned. Later matching lines are ignored.

    Args:
        filepath: Path of the text file to scan.

    Returns:
        The parsed value as a float, or ``None`` when no line contains
        "DELTA TOTAL".
    """
    with open(filepath, 'r') as handle:
        # Stop reading as soon as the first matching row is found.
        hit = next((row for row in handle if "DELTA TOTAL" in row), None)

    if hit is None:
        return None

    # Tokens are: "DELTA", "TOTAL", <value>, ...
    tokens = hit.split()
    return float(tokens[2])

# Corrections added to the averaged DELTA TOTAL, keyed by
# (forcefield, molecule). They come from the energy penalty of Mg2+ unbinding.
mg_unbinding_corrections = {
    ("mmpbsa_data_gaff2", "4pxz"): 9.47088,
    ("mmpbsa_data_gaff2", "4py0"): 10.320759999999998,
    ("mmpbsa_data_gaff", "4pxz"): 7.1838,
    ("mmpbsa_data_gaff", "4py0"): 8.538059999999998,
}

# Maps "<forcefield>_<molecule><suffix>_<method>" to the mean DELTA TOTAL over
# all .dat files of that run, or to None when the run directory exists but no
# value could be extracted. Runs whose directory is missing get no entry.
results = {}

for forcefield in forcefields:
    for molecule in molecules:
        # None for systems that need no correction.
        correction = mg_unbinding_corrections.get((forcefield, molecule))
        for suffix in suffixes:
            group = f"{molecule}{suffix}"
            for method in methods:
                data_dir = os.path.join(root_dir, forcefield, group, method)

                if not os.path.exists(data_dir):
                    print(f"Warning: {data_dir} does not exist.")
                    continue

                delta_totals = []
                # Sorted so the averaging order does not depend on the
                # arbitrary order in which os.listdir() returns entries.
                for filename in sorted(os.listdir(data_dir)):
                    if not filename.endswith(".dat"):
                        continue
                    delta = extract_delta_total(os.path.join(data_dir, filename))
                    if delta is not None:
                        delta_totals.append(delta)

                key = f"{forcefield}_{group}_{method}"
                if not delta_totals:
                    # Nothing to average. Record the gap explicitly instead of
                    # trying to add a correction to None (a TypeError).
                    print(f"Warning: no DELTA TOTAL value found in {data_dir}.")
                    results[key] = None
                    continue

                avg_delta = np.mean(delta_totals)
                if correction is not None:
                    avg_delta = avg_delta + correction
                results[key] = avg_delta

# Echo every averaged value. Entries without data are stored as None, which
# cannot be formatted with ":.2f", so they are printed as-is.
for key, value in results.items():
    if value is None:
        print(f"{key}: None")
    else:
        print(f"{key}: {value:.2f}")

# Persist the unrounded values, one "<key>: <value>" line per run.
output_file = os.path.join(root_dir, "binding_energy_results.txt")
with open(output_file, 'w') as f:
    for key, value in results.items():
        f.write(f"{key}: {value}\n")

print(f"Results saved to {output_file}")

# Experimental reference values, one per system (plotted on the axis labelled
# "∆G from experiments (kcal/mol)"). The plotting loop pairs them BY POSITION
# with the MMPBSA values, which it collects by iterating `molecules` (outer)
# and `suffixes` (inner). The order here must therefore be:
# <molecule 1>, <molecule 1>_muta, <molecule 2>, <molecule 2>_muta, ...
# and the number of entries must equal len(molecules) * len(suffixes).
'''Complete array of experimental values'''
y_values = [-10.03189878083093, -10.097444252847287, -9.277297563761575, -9.284889393049388, -9.189800281871705, -8.895832909220593, -9.398224730584465, -9.598301607755392, -9.58610503903495, -9.772775856543346]

# Shorter variant holding only the first six entries of the list above;
# enable it together with the antagonist-only `molecules` list.
'''Array of experimental values of antagonist systems'''
# y_values = [-10.03189878083093, -10.097444252847287, -9.277297563761575, -9.284889393049388, -9.189800281871705, -8.895832909220593]

# Optional entropy terms, one list per force field/method combination, in the
# same positional order as y_values. When the entropy code in the plotting
# loop is enabled, each entry is subtracted from the matching MMPBSA value.
'''Complete Results of Truncated NMA''' 
# entropy_gaff2_STM = [-19.706475633070603, -15.509997954049975, -15.794099949689755, -16.705910984403822, -20.47708274358545, -18.73464480965957, -25.379030320308573, -25.139216032198558, -27.304346201576386, -25.49291299681369]
# entropy_gaff2_MTM = [-20.004925406674495, -14.409341975515682, -14.975811839678013, -16.402403756498405, -18.873190072111353, -18.067850444407178, -38.30870209626026, -38.616792017440886, -40.414961898373306, -36.63002206942815]
# entropy_gaff_STM = [-13.643612409860808, -15.900626329029015, -15.794099949689755, -18.101658426966292, -19.38846570518196, -19.365702163340604, -25.717205098104984, -25.54585134999162, -28.033948951869867, -27.875628643300352]
# entropy_gaff_MTM = [-15.738323025322824, -16.479360020124098, -14.975811839678013, -17.055415059533793, -19.090100888814355, -20.16550457823243, -39.814629079322486, -40.727549555592816, -40.8463547543183, -42.81409428140198]

# Six-entry variants of the entropy lists for the antagonist-only setup.
'''Results of Truncated NMA of antagonist systems'''
# entropy_gaff2_STM = [-19.706475633070603, -15.509997954049975, -15.794099949689755, -16.705910984403822, -20.47708274358545, -18.73464480965957]
# entropy_gaff2_MTM = [-20.004925406674495, -14.409341975515682, -14.975811839678013, -16.402403756498405, -18.873190072111353, -18.067850444407178]
# entropy_gaff_STM = [-13.643612409860808, -15.900626329029015, -15.794099949689755, -18.101658426966292, -19.38846570518196, -19.365702163340604]
# entropy_gaff_MTM = [-15.738323025322824, -16.479360020124098, -14.975811839678013, -17.055415059533793, -19.090100888814355, -20.16550457823243]

# One regression plot is produced per "<force field>_<method>" condition.
conditions = ["gaff_STM", "gaff_MTM", "gaff2_STM", "gaff2_MTM"]

# Molecules drawn with the "Agonist Systems" marker; every other molecule is
# drawn with the "Antagonist Systems" marker. Deriving the split from the
# molecule name keeps the colouring correct when `molecules` is shortened.
agonist_molecules = {"4pxz", "4py0"}

default_color = '#1a80bb'
highlight_color = '#36b700'

for condition in conditions:
    # Help distribute the entropy values to the corresponding conditions
    # if condition == "gaff_STM":
    #     entropy_condition = entropy_gaff_STM
    # elif condition == "gaff_MTM":
    #     entropy_condition = entropy_gaff_MTM
    # elif condition == "gaff2_STM":
    #     entropy_condition = entropy_gaff2_STM
    # elif condition == "gaff2_MTM":
    #     entropy_condition = entropy_gaff2_MTM
    # else:
    #     print(f"Unknown condition: {condition}. Skipping.")
    #     continue

    forcefield_tag, method = condition.split('_')

    # Collect the MMPBSA values in the same molecule/suffix order that
    # y_values uses, remembering which points belong to agonist systems.
    x_values = []
    is_agonist = []
    for molecule in molecules:
        for suffix in suffixes:
            key = f"mmpbsa_data_{forcefield_tag}_{molecule}{suffix}_{method}"
            value = results.get(key)
            if value is None:
                print(f"Warning: no result available for {key}.")
                continue
            x_values.append(value)
            is_agonist.append(molecule in agonist_molecules)

    # If need to include entropy in results, add entropy to x_values
    # if len(x_values) != len(entropy_condition):
    #     print(f"Warning: Mismatch in x_values and entropy_condition for {condition}. Skipping plot.")
    #     continue

    # x_values = [x - entropy for x, entropy in zip(x_values, entropy_condition)]

    # x and y are paired by position, so a missing result would both break
    # linregress and shift every later point onto the wrong experiment.
    if len(x_values) != len(y_values):
        print(
            f"Warning: Mismatch in data points for {condition} "
            f"({len(x_values)} calculated vs {len(y_values)} experimental). "
            "Skipping plot."
        )
        continue

    slope, intercept, r_value, _, _ = linregress(x_values, y_values)
    line = [slope * x + intercept for x in x_values]
    r_squared = r_value**2

    plt.figure(figsize=(8, 6))
    plt.plot(x_values, line, color='#a00000', linewidth=2, label=f'R = {r_value:.2f}')

    # Draw each group only when it has points, so the legend never lists an
    # empty series (e.g. for the antagonist-only setup).
    groups = (
        (False, default_color, "Antagonist Systems"),
        (True, highlight_color, "Agonist Systems"),
    )
    for wanted, color, label in groups:
        xs = [x for x, flag in zip(x_values, is_agonist) if flag == wanted]
        ys = [y for y, flag in zip(y_values, is_agonist) if flag == wanted]
        if xs:
            plt.scatter(xs, ys, color=color, s=40, label=label)

    plt.xlabel("∆G from MMPBSA calculation (kcal/mol)", fontsize=14)
    plt.ylabel("∆G from experiments (kcal/mol)", fontsize=14)
    plt.xlim(min(x_values) - 0.5, max(x_values) + 0.5)
    plt.ylim(min(y_values) - 0.5, max(y_values) + 0.5)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.legend(fontsize=12)

    # Save to the per-condition path that the message below reports; saving
    # must happen before show(), which may leave the figure empty afterwards.
    plot_file = os.path.join(root_dir, f"{condition}_linear_fit.png")
    plt.savefig(plot_file, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()
    print(f"Plot saved to {plot_file}, R^2: {r_squared:.2f}")