In [1]:
import glob
import pandas
import torch
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pickle
from scipy import stats
from dataset import KmClass
from torcheval.metrics.functional import r2_score
from torchmetrics import PearsonCorrCoef
from copy import copy
from explore_repo import plot_scatter, plot_ablation, learning_curves
from explore_repo import plot_all_ablation

In [2]:
# Architecture ablation overview: each inference CSV is paired with the label
# it gets in the combined ablation plot, so the two can never drift apart.
ablation_runs = (
    ("../data/csv/inferences_full_features.csv", "All features"),
    ("../data/csv/inferences_no_gates.csv", "No gates"),
    ("../data/csv/inferences_conditioned_bs_no_gates.csv", "All features on no gates"),
    ("../data/csv/inferences_no_seqid.csv", "No positional encoding"),
    ("../data/csv/inferences_conditioned_bs_no_seqid.csv", "All features on no positional"),
)
experiment_files = [csv_path for csv_path, _ in ablation_runs]
experiment_conditions = [label for _, label in ablation_runs]
plot_all_ablation(experiment_files, experiment_conditions)

In [3]:
# Learning curves for every feature-ablation variant, each run paired with the
# short title shown above its panel.
learning_curve_runs = (
    ("conditioned_bs_full_features", "C AS"),
    ("descriptor_free_full_features", "D free"),
    ("unconditioned_bs_full_features", "U AS"),
    ("bs_free_full_features", "AS free"),
    ("fingerprint_free_full_features", "FP free"),
    ("aa_id_free_full_features", "PE free"),
    ("protein_free_full_features", "P free"),
    ("molecule_free_full_features", "M free"),
)
experiments = [run_name for run_name, _ in learning_curve_runs]
experiment_titles = [title for _, title in learning_curve_runs]
big_fig = learning_curves(experiments, experiment_titles)

# Feature-ablation bars (plus LaTeX table) for the conditioned-BS model trained
# on all features.
ablation_csv = "../data/csv/inferences_conditioned_bs_full_features.csv"
ablation_title = "R2 and Pearson correlation score for HXKm dataset computed with conditioned BS model with all features."
ablation_jpg = "../figures/ablation_full_features_on_conditioned_bs.jpg"
grouped = plot_ablation(ablation_csv, ablation_title, ablation_jpg, with_table=True)

# True-vs-predicted scatter (and value histogram) for the same model.
scatter_fig, scatter_r2, scatter_p, hist_fig = plot_scatter(
    "conditioned_bs_full_features",  # model name
    "unconditioned_bs_test",  # set of features
    grouped,  # data to plot mean and std R2
    "HXKm database: true vs predicted.",  # plot title
    "../figures/scatter_on_conditioned_bs_full_features.jpg",  # output path
)

16
\hline 
Protein free& 0.253 $\pm$ 0.019& 0.519 $\pm$ 0.018& 0.022 $\pm$ 0.001& 431& 9.477e-02\\
\hline 
AA identity free& 0.321 $\pm$ 0.042& 0.571 $\pm$ 0.037& 0.020 $\pm$ 0.001& 420& 1.882e-03\\
\hline 
Molecule free& -0.117 $\pm$ 0.103& 0.270 $\pm$ 0.054& 0.033 $\pm$ 0.003& 420& 1.712e-08\\
\hline 
AS free& 0.282 $\pm$ 0.033& 0.548 $\pm$ 0.024& 0.021 $\pm$ 0.001& 431& 8.881e-01\\
\hline 
Conditioned AS& 0.284 $\pm$ 0.048& 0.560 $\pm$ 0.035& 0.021 $\pm$ 0.001& 420& ref\\
\hline 
Descriptor free& 0.182 $\pm$ 0.081& 0.497 $\pm$ 0.043& 0.024 $\pm$ 0.002& 420& 1.761e-05\\
\hline 
Fingerprint free& 0.136 $\pm$ 0.054& 0.443 $\pm$ 0.045& 0.026 $\pm$ 0.002& 420& 1.040e-08\\
\hline 
Unconditioned AS& 0.325 $\pm$ 0.052& 0.585 $\pm$ 0.038& 0.020 $\pm$ 0.002& 420& 4.683e-04\\

Before drop: 431
After drop: 420
Loaded Km class. Size of the database: 420
Number of descriptor features when fitting: 196



To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



In [4]:
# Supplementary figure 2: the learning-curve grid returned by `learning_curves`,
# with zero margins and a larger font, exported at print size.
big_fig.update_layout(
    font={"size": 26},
    margin={"t": 0, "b": 0, "l": 0, "r": 0},
)
big_fig.write_image("../figures/supp_fig_2.jpg", width=1750, height=2300)

In [5]:
# Combine the architecture-ablation inference results (all features / no gates /
# no positional encoding). Compare R2 on the conditioned-BS test split against
# the all-features model with a paired t-test, then plot the comparison.


def load_inferences(csv_path, suffix):
    """Read an inference CSV and append ``suffix`` to every value of its ``name`` column.

    The suffix tags which ablation the rows came from, so the frames can be
    concatenated without run names colliding. Each frame is suffixed from its
    own ``name`` column, so rows keep their true run names.
    """
    inferences = pandas.read_csv(csv_path)
    inferences["name"] = inferences["name"] + suffix
    return inferences


def significance_stars(p_value):
    """Return '***', '**' or '*' for p < 0.001 / 0.01 / 0.05, and '' otherwise."""
    for threshold, stars in ((0.001, "***"), (0.01, "**"), (0.05, "*")):
        if p_value < threshold:
            return stars
    return ""


# all features:
full_features = load_inferences("../data/csv/inferences_full_features.csv", "_full_features")
full_features_conditioned_bs = load_inferences(
    "../data/csv/inferences_conditioned_bs_full_features.csv", "_conditioned_bs_full_features"
)

# no gates
no_gates = load_inferences("../data/csv/inferences_no_gates.csv", "_no_gates")
no_gates_conditioned_bs = load_inferences(
    "../data/csv/inferences_conditioned_bs_no_gates.csv", "_conditioned_bs_no_gates"
)

# no seqid (positional encoding)
no_seqid = load_inferences("../data/csv/inferences_no_seqid.csv", "_no_seqid")
no_seqid_conditioned_bs = load_inferences(
    "../data/csv/inferences_conditioned_bs_no_seqid.csv", "_conditioned_bs_no_seqid"
)

combined_inferences = pandas.concat([
    full_features, full_features_conditioned_bs,
    no_gates, no_gates_conditioned_bs,
    no_seqid, no_seqid_conditioned_bs
])
# Test-set rows only. "conditioned_bs_test" also matches "unconditioned_bs_test";
# only the run names listed in `labels` below are actually plotted and tested.
test_metrics = combined_inferences.loc[
    combined_inferences.name.str.contains("conditioned_bs_test")
    & ~combined_inferences.name.str.contains("_train")
]
grouped = test_metrics.groupby(["name"], as_index=False).agg(
    r2_mean=pandas.NamedAgg(column="r2", aggfunc="mean"),
    r2_std=pandas.NamedAgg(column="r2", aggfunc="std"),
    p_mean=pandas.NamedAgg(column="pearson", aggfunc="mean"),
    p_std=pandas.NamedAgg(column="pearson", aggfunc="std"),
    mse_mean=pandas.NamedAgg(column="mse", aggfunc="mean"),
    mse_std=pandas.NamedAgg(column="mse", aggfunc="std")
)
# bar order on the x axis
plot_order = [
    grouped[grouped.name == "conditioned_bs_test_full_features"].index[0],
    grouped[grouped.name == "conditioned_bs_test_no_gates"].index[0],
    grouped[grouped.name == "conditioned_bs_test_no_seqid"].index[0],
]
# x-axis tick labels (run name -> display label)
labels = {
    "conditioned_bs_test_full_features": "All features",
    "conditioned_bs_test_no_gates": "No gates",
    "conditioned_bs_test_no_seqid": "No positional encoding",
}
grouped = grouped.reindex(plot_order)

# Paired t-test of each ablation's per-run R2 against the all-features reference.
# NOTE(review): ttest_rel pairs values by position, so this assumes runs appear
# in the same (fold) order in every CSV — confirm.
model_r2 = {group: state.r2.values for group, state in test_metrics.groupby("name") if group in labels.keys()}
ref_sample = model_r2.pop("conditioned_bs_test_full_features")
p_values = {}
significances = {}
for model, r2 in model_r2.items():
    _, p_value = stats.ttest_rel(ref_sample, r2)
    p_values[model] = p_value
    stars = significance_stars(p_value)
    if stars:  # non-significant models get no annotation
        significances[model] = stars


# plot conditioned_bs in different feature sets:
fig = go.Figure(data=[
    go.Bar(
        name="R2", x=grouped.name.values, y=grouped.r2_mean.values,
        error_y={
            "type":"data",
            "array": grouped.r2_std.values,
            "visible": True,
            "color": "red"
        },
        text = [f"{t:.3f}" for t in grouped.r2_mean.values],
        textfont=dict(weight=900),
        marker_color="orange"
    ),
    go.Bar(
        name="Pearson", x=grouped.name.values, y=grouped.p_mean.values,
        text = [f"{t:.3f}" for t in grouped.p_mean.values],
        error_y={
            "type":"data",
            "array": grouped.p_std.values,
            "visible": True,
            "color": "red"
        },
        textfont=dict(weight=900),
        marker_color="purple"
    )
])
for model, significance in significances.items():
    fig.add_annotation(
        x=model,
        y=0.6,
        text=significance,
        showarrow=False,
        font={"size":25, "color": "red"},
        yref="y"
    )
# panel letter
fig.add_annotation(
    x=-0.06,
    y=.98,
    showarrow=False,
    yref="paper",
    xref="paper",
    text="A",
    align="center",
    font={"weight": 800, "size":18}
)
fig.update_layout(
    width=1000, height=300,
    margin=dict(t=0, l=50, r=0, b=0),
    legend={
        "font": {"size": 18}
    },
    uniformtext_minsize=18, uniformtext_mode="hide",
)
fig.update_xaxes(labelalias=labels, tickfont=dict(size=18))
fig.update_yaxes(tickfont=dict(size=18))
fig.update_traces(textposition="inside", insidetextanchor="middle")
fig.write_image("../figures/conditioned_bs_comparison.png", width=1000, height=300)
fig.show()

In [6]:
# Wild type vs mutant performance: for every cross-validation fold, split the
# test predictions by protein type, compute R2 and Pearson per group, then test
# whether R2 differs between the groups with a paired t-test.
outputs_files = glob.glob("../data/models/conditioned_bs_full_features/*_fold_unconditioned_bs_test_outputs.pkl")
r2_mutants = []
r2_wildtypes = []
pcc = PearsonCorrCoef()
pcc_mutants = []
pcc_wildtypes = []

hxkm = pandas.read_csv("../data/hxkm.csv")
df_test = pandas.read_csv("../data/csv/HXKm_dataset_final_new_conditioned_bs.csv")
test_db = KmClass(df_test).dataframe

# sequence -> protein_type lookup, built once instead of scanning `hxkm` for every
# test row of every fold. keep="first" matches picking the first match when a
# sequence appears more than once.
protein_type_by_sequence = (
    hxkm.drop_duplicates("sequence", keep="first").set_index("sequence")["protein_type"]
)

for file in outputs_files:
    # NOTE: pickle can execute arbitrary code; only load trusted model outputs.
    with open(file, "rb") as handle:
        outputs = pickle.load(handle)

    # as_tensor reuses existing tensors instead of copy-constructing them
    y = torch.as_tensor(outputs["y_unscaled"])
    preds = torch.as_tensor(outputs["preds_unscaled"])
    indices = outputs["all_idx"].flatten()

    # protein type of every prediction, positionally aligned with `preds`
    test_sequences = test_db.iloc[indices]["sequence"]
    missing = ~test_sequences.isin(protein_type_by_sequence.index)
    if missing.any():
        raise KeyError(f"{file}: {int(missing.sum())} test sequences not found in ../data/hxkm.csv")
    enzyme_types = test_sequences.map(protein_type_by_sequence).reset_index(drop=True)
    mutant_indices = enzyme_types[enzyme_types == "mutant"].index.tolist()
    wildtype_indices = enzyme_types[enzyme_types == "wildtype"].index.tolist()

    # compute r2 and pearson for mutants and wildtypes:
    mutant_targets = y[mutant_indices]
    mutant_preds = preds[mutant_indices]
    r2_mutants.append(r2_score(mutant_preds, mutant_targets))
    pcc_mutants.append(pcc(mutant_preds, mutant_targets))

    wildtype_targets = y[wildtype_indices]
    wildtype_preds = preds[wildtype_indices]
    r2_wildtypes.append(r2_score(wildtype_preds, wildtype_targets))
    pcc_wildtypes.append(pcc(wildtype_preds, wildtype_targets))

# paired across folds: fold i of both lists comes from the same outputs file
_, p_value = stats.ttest_rel(r2_mutants, r2_wildtypes)
if p_value < 0.001:
    significance = "***"
elif p_value < 0.01:
    significance = "**"
elif p_value < 0.05:
    significance = "*"
else:
    # not significant: empty label, so `significance` is always defined here and
    # in the Figure 3 cell (previously a NameError or a stale value)
    significance = ""

r2_mutants = torch.tensor(r2_mutants)
pcc_mutants = torch.tensor(pcc_mutants)
r2_wildtypes = torch.tensor(r2_wildtypes)
pcc_wildtypes = torch.tensor(pcc_wildtypes)

# mean/std across folds, ordered [wild type, mutant] to match `x`
r2_mean = [r2_wildtypes.mean().item(), r2_mutants.mean().item()]
r2_std = [r2_wildtypes.std().item(), r2_mutants.std().item()]
pcc_mean = [pcc_wildtypes.mean().item(), pcc_mutants.mean().item()]
pcc_std = [pcc_wildtypes.std().item(), pcc_mutants.std().item()]

x = ["Wild Type", "Mutant"]

# grouped bar chart: R2 and Pearson per genetic construct, significance over "Mutant"
fig = go.Figure(data=[
    go.Bar(
        name="R2", x=x, y=r2_mean,
        error_y={
            "type":"data",
            "array": r2_std,
            "visible": True,
            "color": "red"
        },
        text = [f"{t:.3f}" for t in r2_mean],
        textfont=dict(weight=900),
        marker_color="orange"
    ),
    go.Bar(
        name="Pearson", x=x, y=pcc_mean,
        error_y={
            "type":"data",
            "array": pcc_std,
            "visible": True,
            "color": "red"
        },
        text = [f"{t:.3f}" for t in pcc_mean],
        textfont=dict(weight=900),
        marker_color="purple"
    )
])
fig.add_annotation(
    x=x[1],
    y=0.6,
    text=significance,
    showarrow=False,
    font={"size":25, "color": "red"},
    yref="y"
)
fig.update_layout(
    width=600, height=500,
    title={
        "text":"Mutant VS Wild Type.",
        "font": {"size": 23}
    },
    legend={
        "font": {"size": 18}
    },
    uniformtext_minsize=18, uniformtext_mode="hide",
    xaxis={
        "title": {
            "text": "Genetic Construct",
        },
        "title_font": {"size": 18}
    },
    yaxis={
        "title": {
            "text": "Value",
        },
        "title_font": {"size": 18}
    }
)
fig.update_xaxes(tickfont=dict(size=18))
fig.update_yaxes(tickfont=dict(size=18))
fig.update_traces(textposition="inside", insidetextanchor="middle")
fig.write_image("../figures/mutant_vs_wildtype.jpg", width=600, height=500)
fig.show()
mutant_vs_wt_fig = copy(fig)

Before drop: 431
After drop: 420
Loaded Km class. Size of the database: 420
Number of descriptor features when fitting: 196


In [7]:
# Figure 3: (A) true vs predicted scatter over the left 2x2 block, (B) Km value
# histogram top-right, (C) wild type vs mutant bars bottom-right.
# Uses scatter_fig / hist_fig (cell 3) and mutant_vs_wt_fig, x, significance (cell 6).
fig = make_subplots(
    rows=2, cols=3,
    horizontal_spacing=.12,
    vertical_spacing=0.15,
    specs=[
        [{"rowspan": 2, "colspan": 2}, None, {}],  # row 1: panel A (2 rows x 2 cols) + panel B
        [None, None, {}],                          # row 2: cells covered by A + panel C
    ],
)
# Copy each source figure's traces into its panel, one legend group per panel.
panel_sources = (
    (scatter_fig, "group1", "Panel A", 1, 1),
    (hist_fig, "group2", "Panel B", 1, 3),
    (mutant_vs_wt_fig, "group3", "Panel C", 2, 3),
)
for source_fig, group, group_title, row, col in panel_sources:
    for trace in source_fig.data:
        trace.legendgroup = group
        trace.legendgrouptitle = {"text": group_title}
        fig.add_trace(trace, row=row, col=col)

fig.update_yaxes(row=1, col=1, title={"text": "Predicted Km normalized value"})
fig.update_xaxes(row=1, col=1, title={"text": "Experimental Km normalized value"})
fig.update_xaxes(row=1, col=3, title={"text": "Km normalized value"})
fig.update_yaxes(row=1, col=3, title={"text": "Count"})
fig.update_yaxes(row=2, col=3, title={"text": "R2 and Pearson"})
fig.update_layout(
    font={"size": 17},
    height=550,
    width=1200,
    margin=dict(t=0, b=0, l=50, r=30),  # Tighter margins
)
# Legend-only marker entries (no data points) for the panel A colour key.
fig.add_trace(go.Scatter(
    x=[None], y=[None],  # No actual data points
    mode='markers',
    marker=dict(color='purple', size=10, symbol='square'),
    name='Mutant',
    legendgroup="group1",
    showlegend=True
))

fig.add_trace(go.Scatter(
    x=[None], y=[None],
    mode='markers',
    marker=dict(color='green', size=10, symbol='square'),
    name='Wild Type',
    legendgroup="group1",
    showlegend=True
))
# significance stars over the "Mutant" bars of panel C
fig.add_annotation(
    x=x[1],
    y=0.6,
    text=significance,
    showarrow=False,
    font={"size":25, "color": "red"},
    yref="y",
    col=3, row=2
)
# panel letters, positioned in paper coordinates
for letter, x_pos, y_pos in (("A", -0.05, .95), ("B", 0.68, .95), ("C", 0.68, .40)):
    fig.add_annotation(
        x=x_pos,
        y=y_pos,
        showarrow=False,
        yref="paper",
        xref="paper",
        text=letter,
        align="center",
        font={"weight": 800}
    )
fig.show()
# NOTE(review): exported at 1100x600 while the layout is 1200x550 — confirm intended.
fig.write_image("../figures/figure_3.jpg", width=1100, height=600)

In [8]:
# Feature importance: read the first layer's gate weights (summed over dim 0 and
# softmax-normalised), plot them per input feature for fold 1, then aggregate
# them by feature group across all folds (Figure 6 and supplementary figure 4).
from model import Network
from hyperparameters import hyperparameters
device = ("cpu", "cuda")[torch.cuda.is_available()]
net = Network(
    hidden_dim1=hyperparameters["hidden_dim1"],
    hidden_dim2=hyperparameters["hidden_dim2"],
    hidden_dim3=hyperparameters["hidden_dim3"],
    dropout1=hyperparameters["dropout1"],
    dropout2=hyperparameters["dropout2"]
).to(device)

# Input layout implied by the slicing below (the group sums add up to 1; see the
# sanity check): [enzyme: 1024 residues x 3 interleaved features | descriptors | fingerprint]
res_feats = 1024*3 # each residue described by 3 features
n_descriptors = 196  # "Number of descriptor features when fitting: 196"
n_fingerprint_bits = 2048


def gate_importance(model_path):
    """Load a fold checkpoint into `net` and return its softmax-normalised gate weights on CPU."""
    # map_location lets CUDA-saved checkpoints load on CPU-only machines
    params = torch.load(model_path, map_location=device)["model_state_dict"]
    net.load_state_dict(params)
    return net.net[0].gate.weight.data.sum(dim=0).softmax(dim=0).detach().cpu()


# contribution of each input feature (fold 1):
gated_layer = gate_importance("../data/models/conditioned_bs_full_features/1_fold_model.pth")
fig = go.Figure(data=[
    go.Scatter(
        x=list(range(res_feats)),
        y=gated_layer[:res_feats],
        line={
            "color": "red"
        },
        name="Enzyme"
    ),
    go.Scatter(
        x=list(range(res_feats, gated_layer.shape[0])),
        y=gated_layer[res_feats:],
        line={
            "color": "blue"
        },
        name="Substrate"
    ),
])
fig.write_image("../figures/weight_contributions_scatter.jpg", width=1200, height=500)
# Own name: `hist_fig` is the Figure 3 histogram and must not be overwritten here.
weights_scatter_fig = copy(fig)

# per-fold importance of each feature group:
bs_weights = []
aa_weights = []
seqid_weights = []
protein_weights = []

descriptors_weights = []
fingerprint_weights = []
molecule_weights = []

model_files = glob.glob("../data/models/conditioned_bs_full_features/*_fold_model.pth")
if not model_files:
    raise FileNotFoundError("no *_fold_model.pth found in ../data/models/conditioned_bs_full_features/")
for model_file in model_files:
    gated_layer = gate_importance(model_file)

    # the three per-residue features are interleaved with stride 3
    bs_weight = gated_layer[1:res_feats:3].sum()
    aa_weight = gated_layer[0:res_feats:3].sum()
    seqid_weight = gated_layer[2:res_feats:3].sum()
    descriptors_weight = gated_layer[res_feats:res_feats + n_descriptors].sum()
    fingerprint_weight = gated_layer[-n_fingerprint_bits:].sum()

    bs_weights.append(bs_weight)
    aa_weights.append(aa_weight)
    seqid_weights.append(seqid_weight)
    protein_weights.append(bs_weight + aa_weight + seqid_weight)

    descriptors_weights.append(descriptors_weight)
    fingerprint_weights.append(fingerprint_weight)
    molecule_weights.append(descriptors_weight + fingerprint_weight)

bs_weights = torch.stack(bs_weights)
aa_weights = torch.stack(aa_weights)
seqid_weights = torch.stack(seqid_weights)
protein_weights = torch.stack(protein_weights)
descriptors_weights = torch.stack(descriptors_weights)
fingerprint_weights = torch.stack(fingerprint_weights)
molecule_weights = torch.stack(molecule_weights)

# Enzyme vs substrate (main figure); mean/std across folds.
weights_mean_main = [
    protein_weights.mean().item(),
    molecule_weights.mean().item(),
]
weights_std_main = [
    protein_weights.std().item(),
    molecule_weights.std().item(),
]
# Sub-groups (supplementary figure). Means, stds, labels and colours must all be
# in the same order; previously the protein/molecule means were prepended here,
# shifting every bar onto the wrong group.
weights_mean_supp = [
    bs_weights.mean().item(),
    aa_weights.mean().item(),
    seqid_weights.mean().item(),
    descriptors_weights.mean().item(),
    fingerprint_weights.mean().item(),
]
weights_std_supp = [
    bs_weights.std().item(),
    aa_weights.std().item(),
    seqid_weights.std().item(),
    descriptors_weights.std().item(),
    fingerprint_weights.std().item(),
]

fig = go.Figure(data=[
    go.Bar(
        x=[
            "Enzyme",
            "Substrate",
        ],
        y=weights_mean_main,
        error_y={
            "type":"data",
            "array": weights_std_main,
            "visible": True
        },
        marker_color = ["red", "blue"]
    )
])
fig.write_image("../figures/weight_contributions_bar.jpg", width=1200, height=500)
bar_fig = copy(fig)
# softmax output sums to 1, so the groups must cover every input feature
print("Sanity check:", protein_weights + molecule_weights)
assert torch.allclose(
    protein_weights + molecule_weights, torch.ones_like(protein_weights), atol=1e-4
), "feature groups do not cover the whole gate layer"

# make supp figure:
fig = go.Figure(data=[
    go.Bar(
        x=[
            "Active Site",
            "AA Identity",
            "Positional Encoding",
            "Descriptors",
            "Fingerprints"
        ],
        y=weights_mean_supp,
        error_y={
            "type":"data",
            "array": weights_std_supp,
            "visible": True
        },
        marker_color = ["red", "red", "red", "blue", "blue"]
    )
])
fig.update_layout(
    margin=dict(t=0,b=0,r=0,l=0),
    font=dict(size=19),
    height=250,
    width=800
)
fig.update_yaxes(title={"text": "Importance Score"})
fig.show()
fig.write_image("../figures/supp_fig_4.jpg", width=800, height=250)

# Figure 6: per-feature scatter (A, two columns wide) + group bars (B).
fig = make_subplots(
    rows=1, cols=3,
    horizontal_spacing=.1,
    specs=[
        [{"colspan": 2}, None, {}]
    ]
)
for trace in weights_scatter_fig.data:
    trace.legendgroup="group1"
    fig.add_trace(trace, row=1, col=1)

for trace in bar_fig.data:
    trace.showlegend=False
    fig.add_trace(trace, row=1, col=3)
fig.update_yaxes(title={"text": "Importance Score","font": {"size": 18}},)
fig.add_annotation(
    x=-.16,
    y=.95,
    showarrow=False,
    yref="paper",
    xref="paper",
    text="A",
    align="center",
    font={"weight": 800}
)
fig.add_annotation(
    x=0.68,
    y=.95,
    showarrow=False,
    yref="paper",
    xref="paper",
    text="B",
    align="center",
    font={"weight": 800}
)
fig.update_layout(
    margin=dict(t=0,b=0,r=0,l=130),
    font=dict(size=19),
    height=250,
    width=1000
)
fig.show()
fig.write_image("../figures/figure_6.jpg", width=1000, height=250)


Sanity check: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000])
