In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas
from dataset import KmClass
from copy import copy

# Panel source 1: overlaid histograms of Km values for the train and test sets.
# Each raw CSV is wrapped in KmClass, and the frame that KmClass exposes through
# `.dataframe` is the one used for plotting and for the counts in the title.
db_train = KmClass(
    pandas.read_csv("../data/csv/train_dataset_hxkm_complex_conditioned_bs.csv")
)
df_train = db_train.dataframe

db_test = KmClass(
    pandas.read_csv("../data/csv/HXKm_dataset_final_new_conditioned_bs.csv")
)
df_test = db_test.dataframe


def _km_in_plot_range(frame):
    """Return the km_value entries of `frame` that lie strictly between 0 and 100."""
    km = frame["km_value"]
    return km[(km > 0) & (km < 100)]


fig = go.Figure(
    data=[
        go.Histogram(x=_km_in_plot_range(frame), name=label)
        for label, frame in (("train", df_train), ("test", df_test))
    ]
)
# Counts are shown on a log axis.
fig.update_yaxes(type="log")
# The n values in the title are the full frame sizes, not the number of
# values that fall inside the plotted (0, 100) window.
fig.update_layout(
    title={
        "text": f"Km values for the BRENDA SABIO-RK (n={df_train.shape[0]:,}) and HXKm databases (n={df_test.shape[0]:,}).",
        "x": 0.5,
        "font": {"size": 23},
    },
    xaxis_title="KM value (in mM)",
    yaxis_title="Count",
    font={"size": 18},
    width=900,
    height=500,
)
fig.show()
fig.write_image("../figures/brenda_data_distribution.jpg", width=900, height=500)
# Kept under its own name so a later cell can reuse its traces.
fig1 = copy(fig)

Before drop: 19414
After drop: 18541
Loaded Km class. Size of the database: 18,541
Number of descriptor features when fitting: 196
Before drop: 431
After drop: 420
Loaded Km class. Size of the database: 420
Number of descriptor features when fitting: 196


In [4]:
# Panel source 2: number of wild-type vs. mutant entries in the HXKm CSV.
# Note that this rebinds `df_test` to the frame read from hxkm.csv.
df_test = pandas.read_csv("../data/hxkm.csv")
categories = ["Wild type", "Mutant"]

# Protein: row counts for each protein_type label.
wild_type = len(df_test[df_test["protein_type"] == "wildtype"])
mutant = len(df_test[df_test["protein_type"] == "mutant"])

fig = go.Figure(
    data=[
        go.Bar(
            x=categories,
            y=[wild_type, mutant],
            marker_color=["green", "purple"],
        ),
    ]
)
fig.update_layout(
    title={
        "text": "HXKm database enzyme type.",
        "x": 0.5,
        "font": {"size": 23},
    },
    yaxis_title="Count",
    font={"size": 18},
    width=700,
    height=500,
)
# NOTE(review): the export uses height=400 while the layout above uses 500, and
# the file name says "brenda" although the data is read from hxkm.csv - confirm
# both are intended.
fig.write_image("../figures/brenda_enzyme_categories.jpg", width=700, height=400)
fig.show()
# Kept under its own name so a later cell can reuse its traces.
fig2 = copy(fig)

In [5]:
# Panel source 3: for every top-level EC class, count the wild-type rows whose
# substrate appears in no other EC class ("n") and how many distinct substrates
# those rows cover ("unique_n"), then plot both as grouped bars.
# Note that this rebinds `df_train` to the frame read from
# brenda_sabio_processed.csv.
df_train = pandas.read_csv("../data/brenda_sabio_processed.csv")
# .copy() makes the filtered frame independent of the CSV frame, so adding a
# column below does not write into a slice (avoids SettingWithCopyWarning).
df_train = df_train.loc[df_train.protein_type == "WT"].copy()
# First dot-separated field of the EC number, e.g. "1.2.3.4" -> "1".
# The .str accessor leaves a missing EC number as NaN instead of raising;
# groupby then skips such rows when forming the classes.
df_train["enzyme_class"] = df_train.enzyme_commission.str.split(".").str[0]
df_grouped = df_train.groupby("enzyme_class")

ec_u_substrates = {}
for ec, group in df_grouped:
    # Substrates seen in any row that does not belong to this EC class.
    other_substrates = df_train.loc[df_train.enzyme_class != ec, "substrate"]
    # Rows of this class whose substrate never occurs outside the class.
    substrates = group.loc[~group.substrate.isin(other_substrates), "substrate"].tolist()
    ec_u_substrates[ec] = {
        "unique_n": len(set(substrates)),
        "n": len(substrates)
    }

ec_classes = [f"EC_{k}" for k in ec_u_substrates.keys()]
ec_classes_n = [v["n"] for v in ec_u_substrates.values()]
ec_classes_u_n = [v["unique_n"] for v in ec_u_substrates.values()]

fig = go.Figure(data=[
    go.Bar(name="Unique substrates", x=ec_classes, y=ec_classes_u_n),
    go.Bar(name="Total substrates", x=ec_classes, y=ec_classes_n),
])
fig.update_layout(
    width=800, height=500,
    title={
        "text": "Number of substrate distribution per Enzyme Class.",
        "font": {
            "size": 23
        },
    },
    uniformtext_minsize=18, uniformtext_mode="hide",
    yaxis={
        "title": {
            "text": "Count",
        },
        "title_font": {
            "size": 18
        },
    },
    xaxis={
        "title": {
            "text": "Enzyme Class",
        },
        "title_font": {
            "size": 18
        },
    },
    legend={
        "font": {
            "size": 18
        }
    }
)
fig.update_xaxes(tickfont=dict(size=16))
fig.update_yaxes(tickfont=dict(size=16))
# Text placement settings; the bars currently carry no text labels.
fig.update_traces(textposition="inside", insidetextanchor="middle")
# Exported wider (1000) than the on-screen layout width (800).
fig.write_image(
    "../figures/n_substrates_per_classes.jpg",
    width=1000, height=500
)
fig.show()
# Kept under its own name so a later cell can reuse its traces.
fig3 = copy(fig)

In [6]:
# Figure 2: combine the three earlier figures (fig1, fig2, fig3) into one
# 2x2 grid. Panel A spans the whole first row; panels B and C share the second.
fig = make_subplots(
    rows=2, cols=2,
    # No subplot titles: panels are lettered by the annotations added below.
    subplot_titles=(),
    specs=[
        [{"colspan": 2}, None],  # First row: one subplot spanning both columns
        [{}, {}],  # Second row: two regular subplots
    ],
    row_heights=[0.55, 0.45],
    vertical_spacing=0.15,
    horizontal_spacing=.15
)

# Add traces from individual figures.
# The attribute assignments below modify the trace objects held by
# fig1/fig2/fig3 themselves before they are added to the grid.
for trace in fig1.data:
    trace.legendgroup = "group1"
    trace.legendgrouptitle = {"text": "Panel A"}
    fig.add_trace(trace, row=1, col=1)

for trace in fig2.data:
    trace.showlegend = False
    fig.add_trace(trace, row=2, col=2)

for trace in fig3.data:
    trace.legendgroup = "group2"
    trace.legendgrouptitle = {"text": "Panel B"}
    fig.add_trace(trace, row=2, col=1)

# Axis titles, one call per axis.
fig.update_yaxes(type="log", row=1, col=1, title={"text": "count (log scale)"})
fig.update_xaxes(row=1, col=1, title={"text": "Km value (mM)"})
fig.update_yaxes(row=2, col=1, title={"text": "Substrate count"})
fig.update_yaxes(row=2, col=2, title={"text": "Enzyme count"})
fig.update_layout(
    font={"size": 19},
    legend={"font": {"size": 14}},
    uniformtext_minsize=18, uniformtext_mode="hide",
    height=600,
    width=1000,
    margin=dict(t=0, b=0, l=100, r=30),  # Tighter margins
)

# Add figure letters: (letter, x, y) in paper coordinates.
for letter, x_pos, y_pos in (("A", -0.13, 1), ("B", -0.13, .38), ("C", 0.5, 0.38)):
    fig.add_annotation(
        x=x_pos,
        y=y_pos,
        showarrow=False,
        yref="paper",
        xref="paper",
        text=letter,
        align="center",
        font={"weight": 800}
    )
fig.update_annotations(font=dict(size=21))
for annotation in fig.layout.annotations:
    annotation.y += 0.01  # Nudge every annotation up by 0.01 (adjust as needed)

fig.show()
fig.write_image("../figures/figure_2.jpg", width=1000, height=600)