In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from dataset import KmClass

# Canvas for the multi-panel figure: a 2x2 grid in which the top row takes
# 60% of the height and the left column 30% of the width.
grid_options = {
    "height_ratios": [0.6, 0.4],
    "width_ratios": [0.3, 0.7],
    "hspace": 0.3,
    "wspace": 0.25,
}
fig = plt.figure(figsize=(12, 8))
gs = fig.add_gridspec(nrows=2, ncols=2, **grid_options)

# Load the train and test tables used by the first subplot. Each CSV is
# wrapped in KmClass (project-local `dataset` module; its processing is not
# visible here) and the plotting code works on the wrapper's `dataframe`.
train_csv = "../data/csv/train_dataset_hxkm_complex_conditioned_bs.csv"
test_csv = "../data/csv/HXKm_dataset_final_new_conditioned_bs.csv"

db_train = KmClass(pd.read_csv(train_csv))
df_train = db_train.dataframe

db_test = KmClass(pd.read_csv(test_csv))
df_test = db_test.dataframe

# Panel A (top row, spanning both columns): overlaid histograms of km_value
# for the train and test frames, drawn on a logarithmic count axis.
ax1 = fig.add_subplot(gs[0, :])

# Only values strictly inside (km_lower, km_upper) are plotted.
km_lower, km_upper = 0, 100
train_km = df_train.loc[(df_train.km_value > km_lower) & (df_train.km_value < km_upper), "km_value"]
test_km = df_test.loc[(df_test.km_value > km_lower) & (df_test.km_value < km_upper), "km_value"]

# Use one set of 50 equal-width bins for both series. Passing `bins=50` to
# each hist() call separately would derive the edges from each sample's own
# min/max, so the overlaid bars would not line up or be comparable.
km_bin_edges = np.linspace(km_lower, km_upper, 51)

# The legend reports the number of values actually drawn (after filtering),
# so `n` always agrees with the histogram it labels.
ax1.hist(train_km, bins=km_bin_edges, alpha=0.7, label=f'Train (n={train_km.shape[0]:,})',
         color='blue', edgecolor='black', linewidth=0.5)
ax1.hist(test_km, bins=km_bin_edges, alpha=0.7, label=f'Test (n={test_km.shape[0]:,})',
         color='orange', edgecolor='black', linewidth=0.5)

ax1.set_yscale('log')
ax1.set_xlabel('KM value (in mM)', fontsize=18)
ax1.set_ylabel('Count (log scale)', fontsize=18)
ax1.legend(fontsize=14, loc='upper right')
ax1.grid(True, alpha=0.3)
ax1.tick_params(axis='both', which='major', labelsize=16)

# Panel letter, placed in axes coordinates just left of the plot area.
ax1.text(-0.06, 0.90, 'A', transform=ax1.transAxes, fontsize=24,
         fontweight='bold', va='bottom', ha='right')

# Panel B (bottom-left cell): bar chart of row counts per protein type.
ax2 = fig.add_subplot(gs[1, 0])

df_test2 = pd.read_csv("../data/hxkm.csv")

# Count rows whose protein_type label is exactly "wildtype" / "mutant";
# rows carrying any other label are not shown.
protein_type = df_test2.protein_type
wild_type = int((protein_type == "wildtype").sum())
mutant = int((protein_type == "mutant").sum())
categories = ["Wild type", "Mutant"]

bars = ax2.bar(
    categories,
    [wild_type, mutant],
    color=["green", "purple"],
    edgecolor='black',
    linewidth=1,
)
ax2.set_ylabel('Enzyme Count', fontsize=18)
ax2.grid(True, axis='y', alpha=0.3)
ax2.tick_params(axis='both', which='major', labelsize=16)

# Panel letter, placed in axes coordinates just above the top-left corner.
ax2.text(
    -0.03, 1.02, 'B',
    transform=ax2.transAxes,
    fontsize=24,
    fontweight='bold',
    va='bottom',
    ha='right',
)

# Panel C (bottom-right cell): for each top-level EC class, how many
# substrates occur in that class and in no other class.
ax3 = fig.add_subplot(gs[1, 1])

# Keep rows labelled "WT" only. `.copy()` detaches the filtered frame from the
# one returned by read_csv, so adding a column below is a plain assignment
# rather than a chained one (which raises SettingWithCopyWarning).
df_train3 = pd.read_csv("../data/brenda_sabio_processed.csv")
df_train3 = df_train3.loc[df_train3.protein_type == "WT"].copy()

# Top-level EC class = the text before the first "." of the EC number.
# NOTE: a missing enzyme_commission yields NaN here, and groupby skips NaN keys.
df_train3["enzyme_class"] = df_train3.enzyme_commission.str.split(".", n=1).str[0]
df_grouped = df_train3.groupby("enzyme_class")

ec_u_substrates = {}
for ec, group in df_grouped:
    # Every row that belongs to a different class than `ec`.
    not_group = df_train3.loc[df_train3.enzyme_class != ec]
    # Substrates of this class that never appear in any other class.
    # The list keeps one entry per row, so repeated substrates are repeated.
    substrates = group.loc[~group.substrate.isin(not_group.substrate), "substrate"].tolist()
    ec_u_substrates[ec] = {
        "unique_n": len(set(substrates)),  # distinct class-exclusive substrates
        "n": len(substrates),              # rows carrying such a substrate
    }

ec_classes = [f"EC_{k}" for k in ec_u_substrates]
ec_classes_n = [v["n"] for v in ec_u_substrates.values()]
ec_classes_u_n = [v["unique_n"] for v in ec_u_substrates.values()]

# Grouped bar chart: two bars per class, offset by half a bar width either
# side of the class tick.
x = np.arange(len(ec_classes))
width = 0.35

bars1 = ax3.bar(x - width/2, ec_classes_u_n, width, label='Unique substrates',
                color='lightblue', edgecolor='black', linewidth=1)
bars2 = ax3.bar(x + width/2, ec_classes_n, width, label='Total substrates',
                color='lightcoral', edgecolor='black', linewidth=1)

ax3.set_ylabel('Substrate count', fontsize=18)
ax3.set_xticks(x)
ax3.set_xticklabels(ec_classes, fontsize=16, rotation=45, ha='right')
ax3.tick_params(axis='y', labelsize=16)
ax3.legend(fontsize=14, loc='upper right')
ax3.grid(True, alpha=0.3, axis='y')

# Panel letter, placed in axes coordinates just above the top-left corner.
ax3.text(-0.05, 1.02, 'C', transform=ax3.transAxes, fontsize=24,
         fontweight='bold', va='bottom', ha='right')

# Tighten the layout, export the current figure as JPEG and TIFF at 600 dpi
# (each cropped to its tight bounding box), then display it.
plt.tight_layout()
for out_path in ("../figures/figure_2.jpg", "../figures/figure_2.tiff"):
    plt.savefig(out_path, dpi=600, bbox_inches='tight')
plt.show()