In [None]:
import sys

sys.path.append("../../")
from syntemp.SynUtils.utils import load_from_pickle

## 1. Split data

In [None]:
# import pandas as pd
# from syntemp.SynUtils.utils import train_val_test_split_df, save_database

# original_data = load_database("../../Data/Temp/data_aam.json.gz")
# original_data = pd.DataFrame(original_data)

# train, test, valid = train_val_test_split_df(original_data, target="class")
# train, test, valid = (
#     train.to_dict("records"),
#     test.to_dict("records"),
#     valid.to_dict("records"),
# )

# save_database(train, "../../Data/Temp/Benchmark/train.json.gz")
# save_database(test, "../../Data/Temp/Benchmark/test.json.gz")
# save_database(valid, "../../Data/Temp/Benchmark/valid.json.gz")

# Count number of templates

In [None]:
# Load the template collections stored in the Raw/ and Complete/ benchmark folders.
# NOTE(review): the helper name suggests these files are unpickled; only load files
# from a trusted source, since unpickling untrusted data can execute arbitrary code.
raw = load_from_pickle("../../Data/Temp/Benchmark/Raw/templates.pkl.gz")
complete = load_from_pickle("../../Data/Temp/Benchmark/Complete/templates.pkl.gz")


def calculate(data):
    """Return the length of every entry of ``data``, in index order.

    ``data`` is read positionally (``data[0]`` up to ``data[len(data) - 1]``) and
    each entry must itself support ``len``.
    """
    return [len(data[position]) for position in range(len(data))]


# Entry sizes of the raw and the complete template collections, printed in that order.
raw_result, complete_result = calculate(raw), calculate(complete)

for counts in (raw_result, complete_result):
    print(counts)

# Analyze descriptors

In [None]:
# Templates and clustered reaction records from the Complete/ benchmark folder.
complete = load_from_pickle("../../Data/Temp/Benchmark/Complete/templates.pkl.gz")
data_cluster = load_from_pickle(
    "../../Data/Temp/Benchmark/Complete/data_cluster.pkl.gz"
)
# Only the first entry of the template collection is analysed in the cells below.
temp_0 = complete[0]

## Template percentage and DPO rule

In [None]:
import matplotlib.pyplot as plt

# Render all figure text through LaTeX (requires a working LaTeX installation);
# amsmath is loaded in the preamble. These rc settings are global and persist
# for every later figure in the kernel.
plt.rc("text", usetex=True)
plt.rc("text.latex", preamble=r"\usepackage{amsmath}")


from _analysis._plot_analysis import plot_top_rules_with_seaborn

fig, ax = plt.subplots(figsize=(16, 10))  # single axes that the helper draws onto

plot_top_rules_with_seaborn(temp_0, top_n=20, ax=ax)  # top_n=20 is passed to the helper
plt.tight_layout(pad=0.5)

# Save before plt.show(): the figure is written to the Docs/Analysis/fig folder
# with a tight bounding box and no padding.
plt.savefig(
    "../../Docs/Analysis/fig/FigS2A_rule_distribution.pdf",
    dpi=600,
    bbox_inches="tight",
    pad_inches=0,
)
plt.show()

## Descriptors Analysis

In [None]:
from _analysis._plot_analysis import calculate_value_percentage

In [None]:
# "Reaction Type" percentages: first for temp_0, then for data_cluster.
for records in (temp_0, data_cluster):
    print(calculate_value_percentage(records, "Reaction Type"))

In [None]:
# "Topo Type" percentages: first for temp_0, then for data_cluster.
for records in (temp_0, data_cluster):
    print(calculate_value_percentage(records, "Topo Type"))

In [None]:
# "Reaction Step" percentages: first for temp_0, then for data_cluster.
for records in (temp_0, data_cluster):
    print(calculate_value_percentage(records, "Reaction Step"))

In [None]:
# Split temp_0 by "Topo Type" (one list per label, same order as the names on the
# left), then report the "Rings" percentages of each subset.
acyl, single, combo, comp = (
    [value for value in temp_0 if value["Topo Type"] == topo_label]
    for topo_label in (
        "Acyclic",
        "Single Cyclic",
        "Combinatorial Cyclic",
        "Complex Cyclic",
    )
)

for heading, subset in (
    ("Acyclic", acyl),
    ("Single Cyclic", single),
    ("Combinatorial Cyclic", combo),
    ("Complex Cyclic", comp),
):
    print(heading, calculate_value_percentage(subset, "Rings"))

In [None]:
# Split data_cluster by "Topo Type" (one list per label, same order as the names on
# the left), then report the "Rings" percentages of each subset.
acyl, single, combo, comp = (
    [value for value in data_cluster if value["Topo Type"] == topo_label]
    for topo_label in (
        "Acyclic",
        "Single Cyclic",
        "Combinatorial Cyclic",
        "Complex Cyclic",
    )
)

for heading, subset in (
    ("Acyclic", acyl),
    ("Single Cyclic", single),
    ("Combinatorial Cyclic", combo),
    ("Complex Cyclic", comp),
):
    print(heading, calculate_value_percentage(subset, "Rings"))

## Descriptors Visualization

### Pie chart

In [None]:
# Rename topology labels of temp_0 in place for the figures below.
# NOTE: this mutates the records, so cells that filter on "Acyclic" find no matches
# if they are re-run after this one.
# NOTE(review): other cells in this notebook filter on "Complex Cyclic"; if the data
# never holds the bare label "Complex", the second branch never fires — confirm.
for record in temp_0:
    topo_label = record["Topo Type"]
    if topo_label == "Acyclic":
        record["Topo Type"] = "Acyclic Graph"
    elif topo_label == "Complex":
        record["Topo Type"] = "Hybrid Graph"

In [None]:
# Rename topology labels of data_cluster in place for the figures below.
# NOTE: this mutates the records, so cells that filter on "Acyclic" find no matches
# if they are re-run after this one.
# NOTE(review): other cells in this notebook filter on "Complex Cyclic"; if the data
# never holds the bare label "Complex", the second branch never fires — confirm.
for record in data_cluster:
    topo_label = record["Topo Type"]
    if topo_label == "Acyclic":
        record["Topo Type"] = "Acyclic Graph"
    elif topo_label == "Complex":
        record["Topo Type"] = "Hybrid Graph"

In [None]:
# from _analysis._plot_analysis import create_pie_chart
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


def create_pie_chart(data, column, ax=None, title=None, color_pallet="pastel"):
    """
    Draw a donut-style pie chart of the value shares of ``column``.

    Percentage labels are printed inside the slices; category names appear only in
    an external legend. All text is rendered through LaTeX (this function switches
    ``text.usetex`` on globally), so ``title`` may contain LaTeX markup.

    Parameters:
    - data (list of dict): Records to summarise; converted to a DataFrame.
    - column (str): Column name to plot percentages for.
    - ax (matplotlib.axes.Axes, optional): Matplotlib axis object to plot on.
      A new figure and axis are created when omitted.
    - title (str, optional): Title for the pie chart, supports LaTeX formatted
      strings. Defaults to "Pie Chart of <column>".
    - color_pallet (str, optional): Seaborn palette name used for the slice colours.
      Default is 'pastel'.

    Returns:
    - matplotlib.axes.Axes: The axis with the pie chart.
    """
    # NOTE: rc changes are global matplotlib state, not limited to this axis.
    plt.rc("text", usetex=True)
    plt.rc("font", family="serif")

    # Share (in percent) of every distinct value found in the column
    df = pd.DataFrame(data)
    percentage = df[column].value_counts(normalize=True) * 100

    # One colour per slice
    colors = sns.color_palette(color_pallet, len(percentage))

    if ax is None:
        fig, ax = plt.subplots()

    wedges, texts, autotexts = ax.pie(
        percentage,
        startangle=90,
        colors=colors,
        # "%" starts a comment in TeX. Because usetex is enabled above, the percent
        # sign has to reach LaTeX as "\%"; "%%" is the printf-style literal percent.
        autopct=r"%1.1f\%%",
        pctdistance=0.85,
        explode=[0.05] * len(percentage),
    )

    # A white disc over the centre turns the pie into a donut
    centre_circle = plt.Circle((0, 0), 0.70, fc="white")
    ax.add_artist(centre_circle)

    # Equal aspect ratio ensures that the pie is drawn as a circle
    ax.axis("equal")

    # Legend carries the category names only (no percentages)
    ax.legend(
        wedges,
        [f"{label}" for label in percentage.index],
        title=column,
        loc="lower right",
        bbox_to_anchor=(0.6, 0.1, 0.68, 1),
        prop={"size": 20},
        title_fontsize=24,
    )

    # Caller-supplied title, or a generic fallback
    if title:
        ax.set_title(title, fontsize=32)
    else:
        ax.set_title(f"Pie Chart of {column}", fontsize=32)

    # Make the in-slice percentage labels readable
    for autotext in autotexts:
        autotext.set_color("black")
        autotext.set_fontsize(20)

    return ax


fig, axs = plt.subplots(2, 2, figsize=(18, 10))

# One entry per panel: (records, column, grid position, title, palette).
# The entries keep the original drawing order (B, A, D, C).
pie_panels = [
    (temp_0, "Reaction Type", (0, 1), r"B. Template library", "pastel"),
    (data_cluster, "Reaction Type", (0, 0), r"A. Database", "pastel"),
    (temp_0, "Topo Type", (1, 1), r"D. Template library", "coolwarm"),
    (data_cluster, "Topo Type", (1, 0), r"C. Database", "coolwarm"),
]
for records, column_name, (row, col), panel_title, palette in pie_panels:
    create_pie_chart(
        records,
        column_name,
        ax=axs[row, col],
        title=panel_title,
        color_pallet=palette,
    )


plt.tight_layout()
# Write the figure to disk before showing it.
plt.savefig(
    "../../Docs/Analysis/fig/Fig8_Analysis_rtype_topo.pdf",
    dpi=600,
    bbox_inches="tight",
    pad_inches=0,
)
plt.show()

### Distribution

In [None]:
from _analysis._plot_analysis import count_column_values

In [None]:
# Split the template library (temp_0) by "Reaction Type" and count the "Rings"
# values of each part.
# The second list is called `complicated` so that the builtin `complex` type is no
# longer shadowed for the rest of the kernel session.
element = [value for value in temp_0 if value["Reaction Type"] == "Elementary"]
complicated = [value for value in temp_0 if value["Reaction Type"] == "Complicated"]
element_count = count_column_values(element, "Rings")
complex_count = count_column_values(complicated, "Rings")

In [None]:
# Same split as for the template library, applied to data_cluster:
# one record list per "Reaction Type", then the "Rings" counts of each list.
element_all, complex_all = (
    [value for value in data_cluster if value["Reaction Type"] == reaction_type]
    for reaction_type in ("Elementary", "Complicated")
)
element_count_all, complex_count_all = (
    count_column_values(subset, "Rings") for subset in (element_all, complex_all)
)

In [None]:
from typing import *
from matplotlib.axes import Axes


def plot_rules_distribution(
    rules: Dict[str, int],
    rule_type: str = "single",
    ax: Optional[Axes] = None,
    title: Optional[str] = None,
    refinement: bool = False,
    threshold: float = 1,
    remove: bool = True,
    color_pallet: str = "pastel",
) -> None:
    """
    Plots the distribution of rules as a bar chart of percentages, optionally
    merging all entries below `threshold` into a single aggregated category.

    All text is rendered through LaTeX (this function switches `text.usetex` on
    globally), so every literal percent sign written by this function is escaped.

    Parameters:
    - rules (Dict[str, int]): Dictionary with rule counts keyed by rule name,
    where the values are counts.
    - rule_type (str, optional): Only used to build the default title.
    Default is 'single'.
    - ax (matplotlib.axes.Axes, optional): Matplotlib axis object to plot on.
    If None, a new figure is created and shown.
    - title (str, optional): Optional title for the chart. If None,
    a default title based on `rule_type` is used.
    - refinement (bool, optional): If True, combines all percentages under
    the threshold into one category 'Under <threshold>%'. Default is False.
    - threshold (float, optional): The percentage under which categories are
    merged when `refinement` is True. Default is 1.
    - remove (bool, optional): If True, the aggregated 'Under <threshold>%'
    category is left out of the plot. Categories at or above the threshold are
    never dropped. Only relevant when `refinement` is True. Default is True.
    - color_pallet (str, optional): Color palette to use for the plot.
    Default is 'pastel'.

    Returns:
    - None: The function directly modifies the `ax` object or creates a new plot.

    Raises:
    - ZeroDivisionError: If `rules` is non-empty but its counts sum to zero.
    """
    # Calculate total counts for the rules
    total_rules = sum(rules.values())

    # Convert counts to percentages and optionally combine small values
    if refinement:
        refined_rules = {}
        small_value_aggregate = 0
        for key, value in rules.items():
            percentage = value / total_rules * 100
            if percentage < threshold:
                small_value_aggregate += percentage
            else:
                refined_rules[key] = percentage
        # Fix: `remove` used to chop the last entry unconditionally, which discarded
        # a real category whenever nothing fell under the threshold. Now only the
        # aggregated bucket is affected, and its label reflects the actual threshold
        # (with the percent sign escaped for LaTeX).
        if small_value_aggregate > 0 and not remove:
            refined_rules[rf"Under {threshold:g}\%"] = small_value_aggregate
        percentages = list(refined_rules.values())
        types_of_rules = list(refined_rules.keys())
    else:
        percentages = [value / total_rules * 100 for value in rules.values()]
        types_of_rules = list(rules.keys())

    # Set style (global seaborn/matplotlib state)
    sns.set(style="whitegrid")

    # Enable LaTeX rendering in matplotlib
    plt.rc("text", usetex=True)
    plt.rc("text.latex", preamble=r"\usepackage{amsmath}")  # Ensure amsmath is loaded

    # Remember whether the axis is ours: `ax` is rebound below, so testing
    # `ax is None` at the end of the function could never be true.
    created_axes = ax is None
    if created_axes:
        fig, ax = plt.subplots(figsize=(10, 6), dpi=120)

    # Plot the data
    sns.barplot(ax=ax, x=types_of_rules, y=percentages, palette=color_pallet)
    if title:
        ax.set_title(rf"{title}", fontsize=24)
    else:
        ax.set_title(f"Distribution of {rule_type.capitalize()} Rules", fontsize=16)
    ax.set_xlabel("Cycle length", fontsize=18)
    ax.set_ylabel(r"Percentage (\%)", fontsize=18)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")

    # Set font size for x-tick and y-tick labels
    ax.tick_params(axis="x", labelsize=16)
    ax.tick_params(axis="y", labelsize=16)

    # Add text labels above the bars. "%" is TeX's comment character, so it is
    # written as "\%" (same convention as the y-axis label above).
    # NOTE(review): labels are placed by list position; this assumes seaborn draws
    # the bars in the order of `types_of_rules` — confirm for non-string keys.
    for index, value in enumerate(percentages):
        ax.text(
            index, value + 0.5, rf"{value:.1f}\%", ha="center", va="bottom", fontsize=18
        )

    # Only show the plot when the figure was created here
    if created_axes:
        plt.show()

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_theme(style="whitegrid")


# Render text through LaTeX; amsmath is loaded in the preamble
plt.rc("text", usetex=True)
plt.rc("text.latex", preamble=r"\usepackage{amsmath}")
fig, axs = plt.subplots(2, 2, figsize=(16, 12))


# One entry per panel: (ring counts, target axis, keyword options).
# The entries keep the original drawing order (B, A, D, C).
ring_panels = [
    (
        element_count,
        axs[0, 1],
        dict(
            title=r"B. Elementary reactions in template library",
            color_pallet="pastel",
        ),
    ),
    (
        element_count_all,
        axs[0, 0],
        dict(title=r"A. Elementary reactions in Database", color_pallet="pastel"),
    ),
    (
        complex_count,
        axs[1, 1],
        dict(
            title=r"D. Complicated reactions in template library",
            refinement=True,
            color_pallet="coolwarm",
        ),
    ),
    (
        complex_count_all,
        axs[1, 0],
        dict(
            title=r"C. Complicated reactions in database",
            refinement=True,
            color_pallet="coolwarm",
            threshold=0.3,
        ),
    ),
]
for ring_counts, panel_ax, options in ring_panels:
    plot_rules_distribution(ring_counts, ax=panel_ax, **options)


plt.tight_layout()
# Write the figure to disk before showing it.
plt.savefig(
    "../../Docs/Analysis/fig/Fig9_rings_type.pdf",
    dpi=600,
    bbox_inches="tight",
    pad_inches=0,
)
plt.show()

## Time Benchmark for Rule Clustering

In [None]:
# Clustering timings per strategy, one value for each of R0..R3.
# presumably seconds, divided by 60 to get minutes — TODO confirm units
cluster_timings = {
    "Hierarchical": (58.9, 47.21, 90.92, 101.66),
    "Empirical": (57.82, 275.02, 1807.58, 5675.54),
}

# One row per strategy; values are kept as two-decimal strings.
data = [
    {
        "Type": strategy,
        **{f"R{level}": f"{timing / 60:.2f}" for level, timing in enumerate(timings)},
    }
    for strategy, timings in cluster_timings.items()
]

In [None]:
import matplotlib.pyplot as plt

plt.rc("text", usetex=True)  # render text through LaTeX
plt.rc("font", family="serif")  # serif font family
from _analysis._plot_analysis import plot_bar_compare, plot_cumulative_line

fig, axes = plt.subplots(1, 2, figsize=(14, 6))
radius = [0, 1, 2, 3]
# Raw timings divided by 60 and rounded to two decimals, one value per radius.
# NOTE(review): these numbers repeat the ones used to build `data`; keep them in sync.
hier = [round(timing / 60, 2) for timing in (58.9, 47.21, 90.92, 101.66)]
emp = [round(timing / 60, 2) for timing in (57.82, 275.02, 1807.58, 5675.54)]

# Left panel: bar comparison built from `data`; right panel: line plot per radius
plot_bar_compare(data, axes[0])
plot_cumulative_line(axes[1], radius, hier, emp)

# Tighten the layout, write the figure to disk, then show it
plt.tight_layout()
plt.savefig(
    "../../Docs/Analysis/fig/FigA3_time_cluster.pdf",
    dpi=600,
    bbox_inches="tight",
    pad_inches=0,
)
plt.show()

## Rule Composition

In [None]:
import sys

sys.path.append("../../")
from syntemp.pipeline import extract_its, rule_extract, write_gml

In [None]:
# Two toy reactions given as atom-mapped reaction SMILES under the "rsmi" key.
# NOTE: this rebinds `data`, which earlier in the notebook held the timing table.
data = [
    {"R-id": "Alkyne-reduction", "rsmi": "[CH:1]#[CH:2].[H:3][H:4]>>[CH2:1]=[CH2:2]"},
    {"R-id": "Alkene-reduction", "rsmi": "[CH2:1]=[CH2:2].[H:3][H:4]>>[CH3:1]-[CH3:2]"},
]

In [None]:
# Run the pipeline on the toy reactions: extract_its -> rule_extract -> write_gml.
# The "rsmi" column is the only mapper type used, processed with a single job.
its_correct, its_incorrect, all_uncertain_hydrogen = extract_its(
    data, mapper_types=["rsmi"], n_jobs=1
)

# Only the first output of extract_its (its_correct) is passed on.
reaction_dicts, templates, hier_templates = rule_extract(
    its_correct,
)


# NOTE(review): the positional arguments (None, "Cluster_id", "RC", True) are not
# explained here; check write_gml's signature and consider passing them by keyword.
gml_rules = write_gml(templates, None, "Cluster_id", "RC", True)

In [None]:
from syntemp.SynComp.rule_compose import RuleCompose
from mod import *

# Turn the first two GML strings of the first gml_rules entry into rule objects.
# ruleGMLString is not imported by name; presumably it comes from `from mod import *`.
rule_0 = ruleGMLString(gml_rules[0][0])
rule_1 = ruleGMLString(gml_rules[0][1])

# Compose the two rules.
# NOTE: this rebinds `combo`, which earlier held the "Combinatorial Cyclic" records.
combo = RuleCompose._compose(rule_0, rule_1)

In [None]:
# Show the GML text of the first composed rule.
print(combo[0].getGMLString())

In [None]:
from syntemp.SynUtils.utils import load_from_pickle

# Reload the same two files as in the descriptor analysis above. This gives fresh
# objects, so the in-place "Topo Type" relabelling done earlier does not carry over.
complete = load_from_pickle("../../Data/Temp/Benchmark/Complete/templates.pkl.gz")
data_cluster = load_from_pickle(
    "../../Data/Temp/Benchmark/Complete/data_cluster.pkl.gz"
)

In [None]:
# Work with the first entry of the reloaded template collection.
temp_0 = complete[0]

In [None]:
# Templates grouped by their "Reaction Step" value (1, 2 and 3 respectively).
# NOTE: `single` is rebound here; earlier it held the "Single Cyclic" records.
single, double, triple = (
    [value for value in temp_0 if value["Reaction Step"] == step_count]
    for step_count in (1, 2, 3)
)

In [None]:
# Number of templates whose "Reaction Step" equals 1.
len(single)

In [None]:
# Output folders for the one-step and two-step template sets.
single_path = "../../Data/Temp/RuleComp/Single"
double_path = "../../Data/Temp/RuleComp/Double"

In [None]:
from syntemp.pipeline import write_gml

In [None]:
# Export the two-step templates; the list is wrapped in another list before being
# passed to write_gml, with double_path as the second argument.
write_gml([double], double_path)

In [None]:
# Export the one-step templates; the list is wrapped in another list before being
# passed to write_gml, with single_path as the second argument.
write_gml([single], single_path)

In [None]:
import glob

# Every path in the RuleComp/Compose folder that matches the "*gml" pattern.
compose = list(glob.glob("../../Data/Temp/RuleComp/Compose/*gml"))

In [None]:
# Number of files found in the Compose folder.
len(compose)