# In [None]:
'''Plot the convergence of the average ∆G as the number of replicates is increased incrementally.'''
import os
import numpy as np
import matplotlib.pyplot as plt
import re
# Render every piece of figure text in Times New Roman.
plt.rcParams.update({'font.family': 'Times New Roman'})

# Placeholder root: data is read from
# <root_dir>/<forcefield>/<molecule><suffix>/<method>/.
root_dir = "/Dircetory"

# Force-field result folders to scan.
forcefields = [
    "mmpbsa_data_gaff",
    "mmpbsa_data_gaff2",
]
# Systems to scan; each is combined with every entry of `suffixes`.
molecules = [
    "4ntj_azj",
    "4ntj_psb",
    "4ntj_tiq",
    "4pxz",
    "4py0",
]
# "_muta" selects the mutant folder; "" selects the plain one.
suffixes = [
    "_muta",
    "",
]
# Sub-folders for each simulation method.
methods = [
    "STM",
    "MTM",
]

def extract_delta_total(filepath):
    """Return the value from the first "DELTA TOTAL" line of *filepath*.

    The file is scanned line by line. The third whitespace-separated field
    of the first line that contains "DELTA TOTAL" is parsed as a float.

    Returns:
        The parsed float, or None if no line contains "DELTA TOTAL".

    Raises:
        IndexError / ValueError if the matching line has fewer than three
        fields or its third field is not numeric.
    """
    with open(filepath, 'r') as handle:
        hits = (row for row in handle if "DELTA TOTAL" in row)
        first_hit = next(hits, None)
    if first_hit is None:
        return None
    return float(first_hit.split()[2])

# Collect the per-replicate DELTA TOTAL values for every
# forcefield / molecule(+suffix) / method directory. Each directory's values
# become a running average over the first 2, 4, 6, ... replicates.

# Replicate files end in a letter, then the replicate number, then ".dat"
# (e.g. "run7.dat"). group(1) captures that number.
_REPLICATE_FILE_RE = re.compile(r".*[a-zA-Z](\d+)\.dat$")
# Only replicates numbered 1.._MAX_REPLICATES are used.
_MAX_REPLICATES = 20
# Number of replicates added to the averaged set per curve point.
_REPLICATE_STEP = 2


def _replicate_number(filename):
    """Return the replicate number encoded in *filename*, or None if it does not match."""
    match = _REPLICATE_FILE_RE.match(filename)
    return int(match.group(1)) if match else None


def _collect_delta_totals(data_dir):
    """Return DELTA TOTAL values from replicates 1.._MAX_REPLICATES in *data_dir*.

    Values are ordered by replicate number. Files without a DELTA TOTAL
    line are skipped.
    """
    numbered = []
    for filename in os.listdir(data_dir):
        num = _replicate_number(filename)
        if num is not None and 1 <= num <= _MAX_REPLICATES:
            numbered.append((num, filename))
    # Stable sort on the replicate number only.
    numbered.sort(key=lambda item: item[0])

    delta_totals = []
    for _, filename in numbered:
        delta = extract_delta_total(os.path.join(data_dir, filename))
        if delta is not None:
            delta_totals.append(delta)
    return delta_totals


def _running_means(values, step=_REPLICATE_STEP):
    """Return the means of the first step, 2*step, ... values (complete groups only).

    >>> _running_means([1.0, 3.0, 5.0, 7.0, 9.0])
    [2.0, 4.0]
    """
    return [float(np.mean(values[:size]))
            for size in range(step, len(values) + 1, step)]


results = {}

for forcefield in forcefields:
    for molecule in molecules:
        for suffix in suffixes:
            group = f"{molecule}{suffix}"
            for method in methods:
                data_dir = os.path.join(root_dir, forcefield, group, method)

                # isdir (not exists) so that a stray regular file with this
                # name is skipped instead of crashing os.listdir.
                if not os.path.isdir(data_dir):
                    print(f"Warning: {data_dir} does not exist.")
                    continue

                delta_totals = _collect_delta_totals(data_dir)
                # Convergence curve: mean over the first 2, 4, 6, ... replicates.
                avg_deltas = _running_means(delta_totals)
                if avg_deltas:
                    key = group
                    # NOTE(review): the key omits forcefield and method, so with
                    # several of either only the last directory's curve survives.
                    # Downstream code relies on plain group keys, so we warn
                    # instead of changing the key.
                    if key in results:
                        print(f"Warning: replacing earlier results for {key} with {data_dir}.")
                    results[key] = avg_deltas

# Ten-colour palette for the plotted curves.
color = [
    "#e60049", "#0bb4ff", "#50e991", "#e6d800", "#9b19f5",
    "#ffa300", "#dc0ab4", "#b3d4ff", "#00bfa0", "#1c314e",
]
# Open a 10 x 6 inch figure that the later plt.* calls draw into.
plt.figure(figsize=(10, 6))
# Swap the results of the mutated and non-mutated systems. In the
# simulations, the wild-type P2Y12R is actually the "_muta" system (and vice
# versa). For each "<group>" / "<group>_muta" pair the two curves therefore
# trade keys. An entry without a partner keeps its original key; unpaired
# non-"_muta" entries are kept rather than silently dropped.
_MUTA_SUFFIX = "_muta"
swapped_results = {}
for key, avg_deltas in results.items():
    if key.endswith(_MUTA_SUFFIX):
        # Strip only the trailing suffix, not every "_muta" occurrence.
        base_key = key[:-len(_MUTA_SUFFIX)]
        if base_key in results:
            swapped_results[base_key] = avg_deltas
            swapped_results[key] = results[base_key]
        else:
            swapped_results[key] = avg_deltas
    elif f"{key}{_MUTA_SUFFIX}" not in results:
        # A paired base key is handled when its "_muta" partner is visited.
        swapped_results[key] = avg_deltas
results = swapped_results

# Print, per system, the change between the last two running averages.
# This is a quick indicator of whether the curve has levelled off.
for key, avg_deltas in results.items():
    if len(avg_deltas) < 2:
        # With fewer than two averages there is no difference to report;
        # indexing [-2] here would raise IndexError.
        print(f'{key}: not enough points to assess convergence ({len(avg_deltas)} average(s))')
        continue
    print(f'{key}: {avg_deltas[-1] - avg_deltas[-2]}')

# Legend-only renames: the "azj" data folders are shown as "azd".
legend_aliases = {
    "4ntj_azj": "4ntj_azd",
    "4ntj_azj_muta": "4ntj_azd_muta",
}

# One line per entry in `results`, coloured in order from the palette
# (wrapping around when there are more entries than colours).
for line_no, (name, curve) in enumerate(results.items()):
    legend_label = legend_aliases.get(name, name).upper()
    # The i-th value of the curve is placed at x = 2, 4, 6, ...
    replicate_axis = range(2, 2 * len(curve) + 2, 2)
    plt.plot(
        replicate_axis,
        curve,
        label=legend_label,
        linewidth=3.5,
        color=color[line_no % len(color)],
    )

# Fixed tick positions: x at 2..20 step 2, y at -120..0 step 20.
plt.xticks(list(range(2, 21, 2)), fontsize=14)
plt.yticks(np.arange(-120, 1, 20), fontsize=14)

plt.xlabel('Number of Replicates', fontsize=16)
plt.ylabel('Average ∆G (kcal/mol)', fontsize=16)
# Legend sits outside the axes, to the right of the upper-left corner.
plt.legend(loc='upper left', bbox_to_anchor=(1, 1), fontsize=12)
plt.tight_layout()
plt.savefig('***.png', dpi=300, bbox_inches='tight')
plt.show()