In [None]:
import glob
import pandas
import torch
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pickle
from scipy import stats
from dataset import KmClass
from torcheval.metrics.functional import r2_score
from torchmetrics import PearsonCorrCoef
from copy import copy
from explore_repo import plot_ablation, learning_curves
from explore_repo import plot_all_ablation, results_plot, plot_gating_weights


In [None]:
# Ablation overview: every inference CSV is paired with the condition label
# handed to plot_all_ablation alongside it (insertion order is the plot order).
ablation_runs = {
    "../data/csv/inferences_full_features.csv": "All features",
    "../data/csv/inferences_no_gates.csv": "No gates",
    "../data/csv/inferences_conditioned_bs_no_gates.csv": "All features on no gates",
    "../data/csv/inferences_no_seqid.csv": "No positional encoding",
    "../data/csv/inferences_conditioned_bs_no_seqid.csv": "All features on no positional",
}
experiment_files = list(ablation_runs)
experiment_conditions = list(ablation_runs.values())
plot_all_ablation(experiment_files, experiment_conditions)

# Plot the gating weights.
plot_gating_weights()

# Learning curves: each experiment identifier is paired with the short title
# passed to learning_curves for it (insertion order is the order passed on).
learning_curve_panels = {
    "conditioned_bs_full_features": "C AS",
    "descriptor_free_full_features": "D free",
    "unconditioned_bs_full_features": "U AS",
    "bs_free_full_features": "AS free",
    "fingerprint_free_full_features": "FP free",
    "aa_id_free_full_features": "PE free",
    "protein_free_full_features": "P free",
    "molecule_free_full_features": "M free",
}
experiments = list(learning_curve_panels)
experiment_titles = list(learning_curve_panels.values())

big_fig = learning_curves(experiments, experiment_titles)
# Remove all outer margins and enlarge the font before exporting.
big_fig.update_layout(
    margin={"t": 0, "b": 0, "l": 0, "r": 0},
    font={"size": 26},
)
big_fig.write_image("../figures/supp_fig_2.jpg", height=2300, width=1750)

# Ablation figure for the conditioned-BS inferences with the full feature set.
# The three positional arguments look like: input CSV of inferences, caption
# text, and SVG output path (roles inferred from the values; plot_ablation is
# defined in explore_repo -- confirm there). The return value is kept as
# `grouped` and handed to results_plot below.
grouped = plot_ablation(
    "../data/csv/inferences_conditioned_bs_full_features.csv",
    "R2 and Pearson correlation score for HXKm dataset computed with conditioned BS model with all features.",
    "../figures/ablation_full_features_on_conditioned_bs.svg",
    with_table=True
)

# Main results figure ("true vs predicted"), written to ../figures/figure_3.svg.
# `r2` and `p` are presumably the R2 and Pearson scores -- confirm in explore_repo.
# NOTE: `grouped`, `fig` and `r2` are all rebound by the next cell, so the
# values assigned here are only valid until that cell runs.
fig, r2, p = results_plot(
    model_name="conditioned_bs_full_features",
    test_name="unconditioned_bs_test",
    grouped_data=grouped,
    title="HXKm database: true vs predicted",
    save_path="../figures/figure_3.svg"
)

In [None]:
# Combine the inference tables of every experiment into a single frame.


def _load_tagged_inferences(csv_path, suffix):
    """Read one inference CSV and tag its run names with an experiment suffix.

    Args:
        csv_path: Path of the CSV file to read; it must contain a ``name`` column.
        suffix: Text appended to every value of the ``name`` column.

    Returns:
        The loaded DataFrame, with ``name`` replaced by ``name + suffix``.
    """
    inferences = pandas.read_csv(csv_path)
    # Always derive the tagged name from this frame's own column. The previous
    # copy-pasted code took it from a sibling frame for the conditioned-BS
    # tables, which mislabeled rows and doubled the suffix.
    inferences["name"] = inferences["name"] + suffix
    return inferences


# all features:
full_features = _load_tagged_inferences(
    "../data/csv/inferences_full_features.csv", "_full_features"
)
full_features_conditioned_bs = _load_tagged_inferences(
    "../data/csv/inferences_conditioned_bs_full_features.csv", "_conditioned_bs_full_features"
)

# no gates
no_gates = _load_tagged_inferences(
    "../data/csv/inferences_no_gates.csv", "_no_gates"
)
no_gates_conditioned_bs = _load_tagged_inferences(
    "../data/csv/inferences_conditioned_bs_no_gates.csv", "_conditioned_bs_no_gates"
)

# no seqid
no_seqid = _load_tagged_inferences(
    "../data/csv/inferences_no_seqid.csv", "_no_seqid"
)
no_seqid_conditioned_bs = _load_tagged_inferences(
    "../data/csv/inferences_conditioned_bs_no_seqid.csv", "_conditioned_bs_no_seqid"
)

# Stack all six tables; the original row indices are kept as they are.
combined_inferences = pandas.concat([
    full_features, full_features_conditioned_bs,
    no_gates, no_gates_conditioned_bs,
    no_seqid, no_seqid_conditioned_bs
])
# Keep the rows whose run name contains "conditioned_bs_test" and does not
# contain "_train".
test_metrics = combined_inferences.loc[
    lambda frame: frame["name"].str.contains("conditioned_bs_test")
    & ~frame["name"].str.contains("_train")
]
# Mean and standard deviation of every metric, one row per run name.
grouped = test_metrics.groupby(["name"], as_index=False).agg(
    r2_mean=("r2", "mean"),
    r2_std=("r2", "std"),
    p_mean=("pearson", "mean"),
    p_std=("pearson", "std"),
    mse_mean=("mse", "mean"),
    mse_std=("mse", "std"),
)
# Tick label for each run shown in the figure, listed in plotting order.
labels = {
    "conditioned_bs_test_full_features": "All features",
    "conditioned_bs_test_no_gates": "No gates",
    "conditioned_bs_test_no_seqid": "No positional encoding",
}
# Index label of each of those runs inside `grouped`; reindexing keeps only
# these rows and puts them in plotting order.
plot_order = [grouped[grouped["name"] == run_name].index[0] for run_name in labels]
grouped = grouped.reindex(plot_order)

# R2 values of every labelled run, taken from the un-aggregated rows.
model_r2 = {
    run_name: run_rows["r2"].values
    for run_name, run_rows in test_metrics.groupby("name")
    if run_name in labels
}
# The all-features run is the reference sample; it is removed from the dict so
# that only the runs to compare against it remain.
ref_sample = model_r2.pop("conditioned_bs_test_full_features")

# Paired t-test of every remaining run against the reference. Runs with
# p < 0.05 receive a star marker; the others get no entry in `significances`.
p_values = {}
significances = {}
for model, r2 in model_r2.items():
    p_value = stats.ttest_rel(ref_sample, r2).pvalue
    p_values[model] = p_value
    # Thresholds are scanned from strictest to loosest; the first hit wins.
    for cutoff, stars in ((0.001, "***"), (0.01, "**"), (0.05, "*")):
        if p_value < cutoff:
            significances[model] = stars
            break


# Bar chart of the conditioned_bs test scores for the different feature sets.
# One trace per metric: (legend name, mean column, std column, bar color).
bar_specs = [
    ("R2", "r2_mean", "r2_std", "orange"),
    ("Pearson", "p_mean", "p_std", "purple"),
]
fig = go.Figure(data=[
    go.Bar(
        name=trace_name,
        x=grouped.name.values,
        y=grouped[mean_col].values,
        # Error bars use the per-run standard deviation column.
        error_y={
            "type": "data",
            "array": grouped[std_col].values,
            "visible": True,
            "color": "red",
        },
        # Mean value printed on the bar with three decimals.
        text=[f"{value:.3f}" for value in grouped[mean_col].values],
        textfont={"weight": 900},
        marker_color=bar_color,
    )
    for trace_name, mean_col, std_col, bar_color in bar_specs
])

# Star marker at y=0.6 (data coordinates) for every run listed in `significances`.
for model, significance in significances.items():
    fig.add_annotation(
        x=model,
        y=0.6,
        yref="y",
        text=significance,
        showarrow=False,
        font={"size": 25, "color": "red"},
    )

# Panel letter, placed in paper coordinates just left of the plotting area.
fig.add_annotation(
    text="A",
    x=-0.06,
    y=.98,
    xref="paper",
    yref="paper",
    showarrow=False,
    align="center",
    font={"weight": 800, "size": 18},
)

fig.update_layout(
    width=1000,
    height=300,
    margin={"t": 0, "l": 50, "r": 0, "b": 0},
    legend={"font": {"size": 18}},
    uniformtext_minsize=18,
    uniformtext_mode="hide",
)
# Replace the raw run names on the x axis with their display labels.
fig.update_xaxes(labelalias=labels, tickfont={"size": 18})
fig.update_yaxes(tickfont={"size": 18})
fig.update_traces(textposition="inside", insidetextanchor="middle")
fig.write_image("../figures/conditioned_bs_comparison.png", width=1000, height=300)
fig.show()