In [3]:
# -*- coding: utf-8 -*-
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import linregress
import os
from typing import List, Dict, Literal

# --- 1. Global Configurations ---
# Base path to the data files
# Ensure this path is correct; all CSV files are stored here
BASE_DATA_PATH = "../../Data/"

# Mapping from model names to their display colors
MODEL_COLOR_MAP = {
    'DLKcat': '#022061',
    'Boost-KM': '#2D5F89',
    'DLTKcat': '#BF0001',
    'DeepEnzyme': '#71B4E6',
    'UniKP': '#EB700C',
    'EITLEM-Kinetics': '#A2CCA4',
    'CataPro': '#DB97E4',
    'MolGen&ProtT5': '#022061'
}

# Matplotlib global font and style settings
plt.rcParams.update({
    'font.size': 7,
    'font.family': 'Arial',
    'mathtext.fontset': 'custom',
    'mathtext.rm': 'Arial',
    'mathtext.it': 'Arial:italic',
    'mathtext.bf': 'Arial:bold',
    'pdf.fonttype': 42,
    'ps.fonttype': 42
})

# --- 2. Core Calculation Function ---

def calculate_r2_in_bins(file_path: str, bin_col: str, bins: List[float]) -> tuple:
    """
    Read data from a CSV file, apply log10 transformation to 'Label' and 'P_Label',
    and compute R² and sample counts within specified bins.

    Args:
        file_path (str): Full path to the CSV data file.
        bin_col (str): Column used for binning ('identity' or 'similarity').
        bins (List[float]): Boundaries defining the bins (e.g., [0, 0.6, 0.8, 0.99, 1.0]).

    Returns:
        tuple: A tuple containing (r2_scores: List[float], counts: List[int]).
               R² values are NaN if computation is not possible.
    """
    try:
        df = pd.read_csv(file_path)
    except FileNotFoundError:
        print(f"Error: File not found: {file_path}")
        return [], []

    # Check for required columns
    required_cols = ['P_Label', 'Label', bin_col]
    if not all(col in df.columns for col in required_cols):
        print(f"Error: File {file_path} is missing required columns. "
              f"Required: {required_cols}, Found: {list(df.columns)}")
        return [], []

    r2_scores = []
    counts = []

    for i in range(len(bins) - 1):
        lower_bound = bins[i]
        upper_bound = bins[i + 1]

        # Use closed interval for the last bin
        if i == len(bins) - 2:
            subset = df[(df[bin_col] >= lower_bound) & (df[bin_col] <= upper_bound)].copy()
        else:
            subset = df[(df[bin_col] >= lower_bound) & (df[bin_col] < upper_bound)].copy()

        count = len(subset)
        counts.append(count)

        if count > 1:
            # Filter out rows where Label or P_Label <= 0 (log10 undefined)
            # Also remove NaN values
            subset = subset[['Label', 'P_Label']].dropna()
            subset = subset[(subset['Label'] > 0) & (subset['P_Label'] > 0)]

            if len(subset) > 1:
                log_label = np.log10(subset['Label'])
                log_p_label = np.log10(subset['P_Label'])

                # Perform linear regression on log-transformed data
                slope, intercept, r_value, p_value, std_err = linregress(log_label, log_p_label)
                r2_scores.append(r_value ** 2)
            else:
                r2_scores.append(np.nan)
        else:
            r2_scores.append(np.nan)

    return r2_scores, counts

# --- 3. Core Plotting Function ---

def plot_comparison_chart(
    model_configs: List[Dict],
    plot_type: Literal['identity', 'similarity'],
    title: str,
    figname: str,
    y_lim: tuple = (0.0, 0.9)
):
    """
    Plot a line chart comparing the performance of multiple models.

    Args:
        model_configs (List[Dict]): List of model configurations, each containing 'name', 'file'.
        plot_type (Literal['identity', 'similarity']): Determines the x-axis (sequence identity or substrate similarity).
        title (str): Chart title (typically in LaTeX format for kinetics constants).
        figname (str): Output filename for the saved figure.
        y_lim (tuple): Y-axis limits.
    """
    fig, ax = plt.subplots(figsize=(2, 2), dpi=600)

    # Define bin edges and labels
    bins = [0, 0.6, 0.8, 0.99, 1.0]
    bin_labels = ['0-60', '60-80', '80-99', '99-100']
    ticks = np.arange(len(bin_labels))

    # Vertical offset for count annotations to avoid overlap
    annotation_offsets = np.linspace(8, 8 + 5 * (len(model_configs) - 1), len(model_configs))

    for i, config in enumerate(model_configs):
        model_name = config['name']
        file_path = os.path.join(BASE_DATA_PATH, config['file'])

        r2_values, counts = calculate_r2_in_bins(
            file_path=file_path,
            bin_col=plot_type,
            bins=bins
        )

        if not r2_values:
            continue

        color = MODEL_COLOR_MAP.get(model_name, '#000000')  # Default to black if color not found

        # Plot line and markers
        ax.plot(ticks, r2_values, color=color, linewidth=1, label=model_name, zorder=2)
        ax.scatter(ticks, r2_values, color=color, marker="o", s=17, zorder=3)

        # Annotate sample counts above data points
        if model_name == 'UniKP' or len(model_configs) == 1:
            for j, count in enumerate(counts):
                if not np.isnan(r2_values[j]) and count > 0:
                    ax.annotate(
                        int(count),
                        (ticks[j], r2_values[j]),
                        textcoords="offset points",
                        xytext=(0, annotation_offsets[i]),
                        ha='center',
                        fontsize=6,
                        color=color,
                        zorder=4
                    )

    # Configure axis labels
    if plot_type == 'identity':
        ax.set_xlabel('Enzyme sequence identity (%)', fontsize=7)
    else:
        ax.set_xlabel('Substrate similarity (%)', fontsize=7)

    ax.set_ylabel('R²', fontsize=7)
    # ax.set_title(title, fontsize=7)

    ax.set_xticks(ticks)
    ax.set_xticklabels(bin_labels, fontsize=7)
    ax.tick_params(axis='x', length=0, rotation=0)  # Rotate labels to prevent overlap

    ax.set_ylim(y_lim)
    ax.set_xlim(-0.5, len(ticks) - 0.5)

    # Adjust label positions
    ax.yaxis.set_label_coords(-0.18, 0.5)
    ax.xaxis.set_label_coords(0.5, -0.18)

    # Legend is omitted for single-model plots
    # if len(model_configs) > 1:
    #     ax.legend(loc="lower right", fontsize=6, frameon=False)

    plt.tight_layout()
    plt.savefig(figname, dpi=600, bbox_inches='tight')
    print(f"Figure generated: {figname}")
    plt.close(fig)

# --- 4. Main Execution Logic ---
if __name__ == "__main__":

    # Model configuration: keys are internal identifiers, 'name' is the display name
    ALL_MODELS_CONFIG = {
        'UniKP_kcat': {'name': 'UniKP', 'file': 'UniKP_KCAT_prediction_new.csv'},
        'UniKP_km': {'name': 'UniKP', 'file': 'UniKP_KM_prediction_new.csv'},
        'UniKP_kkm': {'name': 'UniKP', 'file': 'UniKP_KKM_prediction_new.csv'},
        'EITLEM_kcat': {'name': 'EITLEM-Kinetics', 'file': 'EITLEM_KCAT_prediction_new.csv'},
        'EITLEM_km': {'name': 'EITLEM-Kinetics', 'file': 'EITLEM_KM_prediction_new.csv'},
        'EITLEM_kkm': {'name': 'EITLEM-Kinetics', 'file': 'EITLEM_KKM_prediction_new.csv'},
        'CataPro_kcat': {'name': 'CataPro', 'file': 'CataPro_KCAT_prediction_new.csv'},
        'CataPro_km': {'name': 'CataPro', 'file': 'CataPro_KM_prediction_new.csv'},
        'CataPro_kkm': {'name': 'CataPro', 'file': 'CataPro_KKM_prediction_new.csv'},
        'DLKcat': {'name': 'DLKcat', 'file': 'DLKcat_prediction_new.csv'},
        'DeepEnzyme': {'name': 'DeepEnzyme', 'file': 'DeepEnzyme_prediction_new.csv'},
        'Boost_KM': {'name': 'Boost-KM', 'file': 'Boost_KM_prediction_new.csv'},
        'DLTKcat': {'name': 'DLTKcat', 'file': 'DLTKcat_prediction_new.csv'},
        'TurNuP': {'name': 'TurNuP', 'file': 'TurNuP_prediction_new1.csv'},
        'molgen_prott5_kcat': {'name': 'MolGen&ProtT5', 'file': 'molgen_prott5_KCAT.csv'},
        'molgen_prott5_km': {'name': 'MolGen&ProtT5', 'file': 'molgen_prott5_KM.csv'},
        'molgen_prott5_kkm': {'name': 'MolGen&ProtT5', 'file': 'molgen_prott5_KKM.csv'}
    }

    # Define model groups for plotting
    KCAT_MODELS = ['UniKP_kcat', 'EITLEM_kcat', 'CataPro_kcat', 'DLKcat', 'DeepEnzyme']
    KM_MODELS = ['UniKP_km', 'EITLEM_km', 'CataPro_km', 'Boost_KM']
    KKM_MODELS = ['UniKP_kkm', 'EITLEM_kkm', 'CataPro_kkm']
    DLTKCAT_MODEL = ['DLTKcat']
    TurNuP_MODEL = ['TurNuP']
    MolGen_MODELS_KCAT = ['molgen_prott5_kcat', 'UniKP_kcat']
    MolGen_MODELS_KM = ['molgen_prott5_km', 'UniKP_km']
    MolGen_MODELS_KKM = ['molgen_prott5_kkm', 'UniKP_kkm']

    dltkcat_config = [ALL_MODELS_CONFIG[m] for m in DLTKCAT_MODEL]
    plot_comparison_chart(dltkcat_config, 'identity', r'Performance of DLTKcat', '../../Figure/figS2_DLTKcat_identity.pdf')
    plot_comparison_chart(dltkcat_config, 'similarity', r'Performance of DLTKcat', '../../Figure/figS2_DLTKcat_similarity.pdf')

Figure generated: ../../Figure/figS2_DLTKcat_identity.pdf
Figure generated: ../../Figure/figS2_DLTKcat_similarity.pdf


In [4]:
# -*- coding: utf-8 -*-
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import linregress
import os
from typing import List, Dict, Literal
from matplotlib.lines import Line2D

# --- 1. (Global Configurations) ---

BASE_DATA_PATH = "../../Data/"

MODEL_COLOR_MAP = {
    'DLKcat': '#022061',
    'Boost-KM': '#2D5F89',
    'DLTKcat': '#BF0001',
    'DeepEnzyme': '#71B4E6',
    'UniKP': '#EB700C',
    'EITLEM-Kinetics': '#A2CCA4',
    'CataPro': '#DB97E4',
    'MolGen&ProtT5': '#022061'
}

plt.rcParams.update({
    'font.size': 7,
    'font.family': 'Arial',
    'mathtext.fontset': 'custom',
    'mathtext.rm': 'Arial',
    'mathtext.it': 'Arial:italic',
    'mathtext.bf': 'Arial:bold',
    'pdf.fonttype': 42,
    'ps.fonttype': 42
})

# --- 2. Core Calculation Function ---

def calculate_r2_in_bins(file_path: str, bin_col: str, bins: List[float]):
    try:
        df = pd.read_csv(file_path)
    except FileNotFoundError:
        print(f"Error: File not found {file_path}")
        return [], []

    required_cols = ['P_Label', 'Label', bin_col]
    if not all(col in df.columns for col in required_cols):
        print(f"Error: File {file_path} is missing required columns. Required {required_cols}, but only found {list(df.columns)}")
        return [], []

    r2_scores = []
    counts = []

    for i in range(len(bins) - 1):
        lower_bound = bins[i]
        upper_bound = bins[i+1]
        
        if i == len(bins) - 2:
            subset = df[(df[bin_col] >= lower_bound) & (df[bin_col] <= upper_bound)].copy()
        else:
            subset = df[(df[bin_col] >= lower_bound) & (df[bin_col] < upper_bound)].copy()
            
        count = len(subset)
        counts.append(count)

        if count > 1:
            subset = subset[['Label', 'P_Label']].dropna()
            subset = subset[(subset['Label'] > 0) & (subset['P_Label'] > 0)]
            
            if len(subset) > 1:
                log_label = np.log10(subset['Label'])
                log_p_label = np.log10(subset['P_Label'])
                slope, intercept, r_value, p_value, std_err = linregress(log_label, log_p_label)
                r2_scores.append(r_value**2)
            else:
                r2_scores.append(np.nan)
        else:
            r2_scores.append(np.nan)
            
    return r2_scores, counts

# --- 3. Core Plotting Function ---
from matplotlib.lines import Line2D

def plot_combined_figure(
    plot_type: Literal['identity', 'similarity'],
    kcat_configs: List[Dict],
    km_configs: List[Dict],
    kkm_configs: List[Dict],
    figname: str,
    y_lim: tuple = (0.0, 0.95)
):
    """
    Generates a combined plot with three subplots (kcat, km, kcat/km).
    Each subplot compares the performance of multiple models on a specified data split (identity or similarity).

    Args:
        plot_type (Literal['identity', 'similarity']): The type of plot, which determines the x-axis.
        kcat_configs (List[Dict]): A list of configurations for the k_cat models.
        km_configs (List[Dict]): A list of configurations for the K_m models.
        kkm_configs (List[Dict]): A list of configurations for the k_cat/K_m models.
        figname (str): The filename for the output image.
        y_lim (tuple): The range for the Y-axis.
    """
    # Create the figure and subplots
    fig, axes = plt.subplots(1, 3, figsize=(7, 2.5), dpi=600, sharey=True)

    # Define subplot information
    subplot_info = [
        {'title': r'$\mathbf{\mathit{k}_\mathrm{cat}}$', 'configs': kcat_configs},
        {'title': r'$\mathbf{\mathit{K}_\mathrm{m}}$', 'configs': km_configs},
        {'title': r'$\mathbf{\mathit{k}_\mathrm{cat}/\mathit{K}_\mathrm{m}}$', 'configs': kkm_configs}
    ]

    # Define X-axis bins and labels
    bins = [0, 0.6, 0.8, 0.99, 1.0]
    bin_labels = ['0-60', '60-80', '80-99', '99-100']
    ticks = np.arange(len(bin_labels))
    
    # To store legend handles and avoid duplicates
    legend_handles = {}

    # Loop through each subplot
    for i, info in enumerate(subplot_info):
        ax = axes[i]
        
        # Loop through each model in the subplot
        for config in info['configs']:
            model_name = config['name']
            file_path = os.path.join(BASE_DATA_PATH, config['file'])
            
            # --- MODIFICATION 1: Capture the 'counts' variable ---
            r2_values, counts = calculate_r2_in_bins(
                file_path=file_path,
                bin_col=plot_type,
                bins=bins
            )

            if not r2_values:
                continue

            color = MODEL_COLOR_MAP.get(model_name, '#000000')
            
            # Plot lines and data points
            ax.plot(ticks, r2_values, color=color, linewidth=1, label=model_name, zorder=2)
            ax.scatter(ticks, r2_values, color=color, marker="o", s=17, zorder=3)
            
            # --- MODIFICATION 2: Add conditional annotation for 'UniKP' ---
            if model_name == 'UniKP':
                for j, count in enumerate(counts):
                    if not np.isnan(r2_values[j]) and count > 0:
                        ax.annotate(
                            int(count),
                            (ticks[j], r2_values[j]),
                            textcoords="offset points",
                            xytext=(0, 8),  # A fixed vertical offset
                            ha='center',
                            fontsize=6,
                            color='#000000', # Use the color of the UniKP line
                            zorder=4
                        )
            # --- END OF MODIFICATION ---

            # Create a legend handle if it doesn't exist yet
            if model_name not in legend_handles:
                 legend_handles[model_name] = Line2D([0], [0], color=color, lw=1.5, label=model_name)

        # --- Configure subplot styling ---
        ax.set_title(info['title'], fontsize=9)
        ax.set_xticks(ticks)
        ax.set_xticklabels(bin_labels, fontsize=7, rotation=0)
        ax.tick_params(axis='x', length=0)

        # Set styles
        ax.spines['top'].set_linewidth(0.5)
        ax.spines['bottom'].set_linewidth(0.5)
        ax.spines['left'].set_linewidth(0.5)
        ax.spines['right'].set_linewidth(0.5)
        ax.tick_params(axis='both', direction='out', width=0.5, which='both', length=2, pad=2)
        ax.tick_params(axis='y', labelsize=7)

        ax.set_ylim(y_lim)
        ax.set_xlim(-0.5, len(ticks) - 0.5)

    # --- Configure shared labels for the entire figure ---
    axes[0].set_ylabel('R²', fontsize=7)
    
    if plot_type == 'identity':
        xlabel = 'Enzyme sequence identity (%)'
    else:
        xlabel = 'Substrate similarity (%)'
    # Place the x-axis label under the middle subplot
    axes[1].set_xlabel(xlabel, fontsize=7)

    # --- Create and place the shared legend ---
    # Sort the legend according to the order in MODEL_COLOR_MAP for consistency
    all_appearing_models = [name for name in MODEL_COLOR_MAP if name in legend_handles]
    # You might want to add a legend back if you need it. Example:
    # sorted_handles = [legend_handles[name] for name in all_appearing_models]
    # fig.legend(handles=sorted_handles, loc='upper center', bbox_to_anchor=(0.5, 0.98), ncol=len(sorted_handles), frameon=False, fontsize=7)
       
    # --- Final adjustments and saving ---
    # Use tight_layout and provide a rectangle to make space for the legend and title
    fig.tight_layout(rect=[0, 0.1, 1, 0.95]) # rect=[left, bottom, right, top]

    # Saving with bbox_inches='tight' can sometimes interfere with tight_layout,
    # but it's often useful. Keep as is.
    plt.savefig(figname, dpi=600, bbox_inches='tight')
    print(f"Combined plot generated: {figname}")
    plt.close(fig)


# --- 4. Main Execution Logic - REVISED ---
if __name__ == "__main__":

    # Define configurations for all models
    ALL_MODELS_CONFIG = {
        'UniKP_kcat': {'name': 'UniKP', 'file': 'UniKP_KCAT_prediction_new.csv'},
        'UniKP_km': {'name': 'UniKP', 'file': 'UniKP_KM_prediction_new.csv'},
        'UniKP_kkm': {'name': 'UniKP', 'file': 'UniKP_KKM_prediction_new.csv'},
        'EITLEM_kcat': {'name': 'EITLEM-Kinetics', 'file': 'EITLEM_KCAT_prediction_new.csv'},
        'EITLEM_km': {'name': 'EITLEM-Kinetics', 'file': 'EITLEM_KM_prediction_new.csv'},
        'EITLEM_kkm': {'name': 'EITLEM-Kinetics', 'file': 'EITLEM_KKM_prediction_new.csv'},
        'CataPro_kcat': {'name': 'CataPro', 'file': 'CataPro_KCAT_prediction_new.csv'},
        'CataPro_km': {'name': 'CataPro', 'file': 'CataPro_KM_prediction_new.csv'},
        'CataPro_kkm': {'name': 'CataPro', 'file': 'CataPro_KKM_prediction_new.csv'},
        'DLKcat': {'name': 'DLKcat', 'file': 'DLKcat_prediction_new.csv'},
        'DeepEnzyme': {'name': 'DeepEnzyme', 'file': 'DeepEnzyme_prediction_new.csv'},
        'Boost_KM': {'name': 'Boost-KM', 'file': 'Boost_KM_prediction_new.csv'},
        'DLTKcat': {'name': 'DLTKcat', 'file': 'DLTKcat_prediction_new.csv'},
        'TurNuP': {'name': 'TurNuP', 'file': 'TurNuP_prediction_new1.csv'},
        'molgen_prott5_kcat': {'name': 'MolGen&ProtT5', 'file': 'molgen_prott5_KCAT.csv'},
        'molgen_prott5_km': {'name': 'MolGen&ProtT5', 'file': 'molgen_prott5_KM.csv'},
        'molgen_prott5_kkm': {'name': 'MolGen&ProtT5', 'file': 'molgen_prott5_KKM.csv'}
    }

    # Define model groups
    KCAT_MODELS = ['UniKP_kcat', 'EITLEM_kcat', 'CataPro_kcat', 'DLKcat', 'DeepEnzyme']
    KM_MODELS = ['UniKP_km', 'EITLEM_km', 'CataPro_km', 'Boost_KM']
    KKM_MODELS = ['UniKP_kkm', 'EITLEM_kkm', 'CataPro_kkm']

    # Extract configurations for each group
    kcat_configs = [ALL_MODELS_CONFIG[m] for m in KCAT_MODELS]
    km_configs = [ALL_MODELS_CONFIG[m] for m in KM_MODELS]
    kkm_configs = [ALL_MODELS_CONFIG[m] for m in KKM_MODELS]

    # --- Generate Identity Combined Plot ---
    plot_combined_figure(
        plot_type='identity',
        kcat_configs=kcat_configs,
        km_configs=km_configs,
        kkm_configs=kkm_configs,
        figname='../../Figure/fig2a_combined_identity.pdf'
    )

    # --- Generate Similarity Combined Plot ---
    plot_combined_figure(
        plot_type='similarity',
        kcat_configs=kcat_configs,
        km_configs=km_configs,
        kkm_configs=kkm_configs,
        figname='../../Figure/fig2b_combined_similarity.pdf'
    )

  X -= avg[:, None]


Combined plot generated: ../../Figure/fig2a_combined_identity.pdf


  X -= avg[:, None]


Combined plot generated: ../../Figure/fig2b_combined_similarity.pdf
