In [None]:
# Standard library
import json
import os
import sys
from argparse import Namespace

# Third-party
import numpy as np
import torch.nn as nn

# Make the notebook's parent directory importable BEFORE the project imports
# below run; appending afterwards cannot help those imports on a fresh kernel.
# The membership test keeps sys.path free of duplicates when the cell is re-run.
if '../' not in sys.path:
    sys.path.append('../')

# Project-local
from IGEP.evaluation.evaluate import EPMPEvaluator, EGNNEvaluator
from IGEP.preprocessing.data_utils import epipredDataset

# Registry mapping the model name stored in a run's params file to the
# evaluator class that is instantiated for it.
models = {"EpiEPMP": EPMPEvaluator, "egnn": EGNNEvaluator, }

In [None]:
# Run selection: the run directory to evaluate and the directory that
# receives the exported statistics.
run_dir, save_dir = "logs/GAT64/r4", "logs/"

# File names looked up inside `run_dir`.
model_file, params_file = "best_state_dict.pth", "params.json"

# Name of the data split this notebook works on.
split = "test"

# Process the Data    

In [None]:
# Restore the hyper-parameters that were saved with the selected run.
# NOTE: remember to update the project's constants file to match the run
# (reminder carried over from the original notebook).

# Read the params through a context manager so the file handle is closed
# deterministically instead of being left to the garbage collector.
with open(os.path.join(run_dir, params_file)) as params_fh:
    args = Namespace(**json.load(params_fh))

# Instantiate the evaluator class registered under the run's model name and
# load the saved weights stored under the "model_state_dict" key.
evaluator = models[args.model](args, args.num_cdr_feats, args.num_ag_feats)
evaluator.load_state_dict(os.path.join(run_dir, model_file), key="model_state_dict")
model = evaluator.model

In [None]:
## LOAD DATA ##
# Build the test dataset from the run's own settings; random rotation is
# explicitly disabled here.
protein_list_test = epipredDataset(
    args.test_path,
    feats=args.feats,
    random_rotation=False,
    centered=args.centered,
)

# Collect the "name" entry of every sample, in dataset order.
pdb_names = []
for sample_idx in range(len(protein_list_test)):
    pdb_names.append(protein_list_test[sample_idx]["name"])
print(pdb_names)

# Analyse the test data

In [None]:
import pickle

import pandas as pd

# Run the evaluator on the test set: returns the metrics, a dataframe and
# the raw curve values.
test_metrics, df_test, curves_vals = evaluator.test_epitope(protein_list_test)

# Keep the raw curve values next to the run so they can be re-plotted later.
with open(os.path.join(run_dir, "curves.p"), "wb") as f:
    pickle.dump(curves_vals, f, protocol=2)

# `np.trapz` was renamed to `np.trapezoid` in NumPy 2.0; pick whichever the
# installed NumPy provides so the cell runs on both old and new versions.
_trapezoid = getattr(np, "trapezoid", None) or np.trapz

# Area under the precision-recall curve: integrate entry 0 of "PrecRecall"
# over entry 1, after sorting by entry 1 so the integration axis is monotonic.
for mol in ["ab", "ag"]:
    pr_curve = curves_vals[mol]["PrecRecall"]
    idx_sorted = np.argsort(pr_curve[1])
    test_metrics[mol]["aucpr"] = _trapezoid(pr_curve[0][idx_sorted], pr_curve[1][idx_sorted])

# Flatten the per-molecule metrics into a single-row table with columns
# named "<mol>_<metric>".
df_dict = dict()
for mol in ["ab", "ag"]:
    df_dict.update(
        {mol + "_" + metric: [test_metrics[mol][metric].item()] for metric in test_metrics[mol]})
df = pd.DataFrame(df_dict, index=[0])

# Tag the file with the last two components of the run directory.
# normpath drops a trailing separator that would otherwise produce an empty
# component, and os.path.join no longer relies on `save_dir` ending in "/".
run_tag = "_".join(os.path.normpath(run_dir).split(os.sep)[-2:])
if save_dir:
    os.makedirs(save_dir, exist_ok=True)
file_path = os.path.join(save_dir, f"final_stats_{run_tag}.csv")
print("Saving csv file to: ", file_path)
df.to_csv(file_path, index=True, float_format='%f')
print(df.to_string(float_format='{:.2f}'.format))

In [None]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# One subplot per molecule key; the curve is drawn with entry 1 of
# "PrecRecall" on the x axis and entry 0 on the y axis.
fig = make_subplots(rows=1, cols=2, subplot_titles=["ab", "ag"])
for column, mol in enumerate(["ab", "ag"], start=1):
    pr_values = curves_vals[mol]["PrecRecall"]
    model_trace = go.Scatter(
        x=pr_values[1],
        y=pr_values[0],
        name="Model",
        line=go.scatter.Line(color="blue"),
    )
    fig.add_traces([model_trace], rows=1, cols=column)

fig.update_layout(height=400, width=800, title_text="Precision Recall curves ")
fig.show()

# Store a static copy of the figure inside the run directory.
fig.write_image(os.path.join(run_dir, "PRcurve.png"))

In [None]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# One subplot per molecule key: the model's ROC values (entry 0 on x,
# entry 1 on y) plus the dashed diagonal as the "No Skill" reference.
fig = make_subplots(rows=1, cols=2, subplot_titles=["ab", "ag"])
for column, mol in enumerate(["ab", "ag"], start=1):
    roc_values = curves_vals[mol]["ROC"]
    model_trace = go.Scatter(
        x=roc_values[0],
        y=roc_values[1],
        name="Model",
        line=go.scatter.Line(color="blue"),
    )
    baseline_trace = go.Scatter(
        x=[0, 1],
        y=[0, 1],
        name="No Skill",
        line=go.scatter.Line(dash="dash", color="red"),
    )
    fig.add_traces([model_trace, baseline_trace], rows=1, cols=column)

fig.update_layout(height=400, width=800, title_text="ROC curves")
fig.show()

# Store a static copy of the figure inside the run directory.
fig.write_image(os.path.join(run_dir, "ROCcurve.png"))

In [None]:
# plot molecules
from Utils.plot_utils import plot_abag_pointclouds, plot_abag_3dgraphs

# Turn centering off on the dataset before the complex is fetched
# (presumably so the plots use un-centered coordinates -- TODO confirm).
protein_list_test.centered = False

# Everything about the selected complex is collected in this dict.
pdb = {"name": '4jr9'}
pdb["idx"] = pdb_names.index(pdb["name"])  # ValueError if the name is not in pdb_names

# Model outputs are passed through a sigmoid and converted to NumPy arrays
# with the trailing dimension squeezed away.
sigmoid = nn.Sigmoid()
pdb["pred_cdr"], pdb["pred_ag"] = evaluator.run_model(protein_list_test[pdb["idx"]])
for pred_key in ("pred_cdr", "pred_ag"):
    pdb[pred_key] = sigmoid(pdb[pred_key]).cpu().detach().numpy().squeeze(-1)
pdb["pred"] = np.concatenate((pdb["pred_cdr"], pdb["pred_ag"]), 0)

# Fetch the sample ONCE for all plotting inputs instead of re-indexing the
# dataset for every field. It is fetched after the model call (as before), so
# the plotted data does not depend on anything run_model may do to its input.
pdb_sample = protein_list_test[pdb["idx"]]

pdb["coords_ag"] = pdb_sample["coords_ag"].cpu().numpy()
pdb["coords_cdr"] = pdb_sample["coords_cdr"].cpu().numpy()
pdb["lbls_cdr"] = np.array(pdb_sample["cdr_lbls"].cpu())
pdb["lbls_ag"] = np.array(pdb_sample["ag_lbls"].cpu())
pdb["lbls"] = np.concatenate((pdb["lbls_cdr"], pdb["lbls_ag"],), 0)

# Keep every second column of each edge index and drop duplicate columns
# (presumably each edge is stored in both directions -- TODO confirm).
for edge_key, sample_key in (("edge_cdr", "edge_index_cdr"), ("edge_ag", "edge_index_ag")):
    edge_index = np.array(pdb_sample[sample_key].cpu())
    pdb[edge_key] = np.unique(edge_index[:, ::2], axis=1)

### CDR + AG

In [None]:
# graph
# Three panels share the same geometry and differ only in node colouring:
# labels, raw predictions, and predictions thresholded at 0.5.
coords = [dict(ab=pdb["coords_cdr"], ag=pdb["coords_ag"]) for _ in range(3)]
edges = [dict(ab=pdb["edge_cdr"], ag=pdb["edge_ag"]) for _ in range(3)]
colors = [
    dict(ab=pdb["lbls_cdr"], ag=pdb["lbls_ag"]),
    dict(ab=pdb["pred_cdr"], ag=pdb["pred_ag"]),
    dict(
        ab=(pdb["pred_cdr"] > 0.5).astype(float),
        ag=(pdb["pred_ag"] > 0.5).astype(float),
    ),
]
panel_titles = ("ground thruth", "Prediction", "Discrete Prediction")
fig = plot_abag_3dgraphs(coords, edges, colors, panel_titles, size=5, show=False)
layout_opts = dict(
    title_text="ABAG",
    title_font_family="Times New Roman",
    font_family="Times New Roman",
    plot_bgcolor='white',
)
fig.update_layout(**layout_opts)
fig

In [None]:
# point clouds
# Reuses `coords` and `colors` as they are currently defined in the kernel.
panel_titles = ("ground thruth", "Prediction", "Discrete Prediction")
fig = plot_abag_pointclouds(coords, colors, panel_titles, size=5, show=False)
layout_opts = dict(
    title_text="ABAG",
    title_font_family="Times New Roman",
    font_family="Times New Roman",
    plot_bgcolor='white',
)
fig.update_layout(**layout_opts)
fig

### CDR

In [None]:
# graph
# Same three-panel layout, restricted to the "ab" entries:
# labels, raw predictions, and predictions thresholded at 0.5.
coords = [dict(ab=pdb["coords_cdr"]) for _ in range(3)]
edges = [dict(ab=pdb["edge_cdr"]) for _ in range(3)]
colors = [
    dict(ab=pdb["lbls_cdr"]),
    dict(ab=pdb["pred_cdr"]),
    dict(ab=(pdb["pred_cdr"] > 0.5).astype(float)),
]
panel_titles = ("ground thruth", "Prediction", "Discrete Prediction")
fig = plot_abag_3dgraphs(coords, edges, colors, panel_titles, size=5, show=False)
layout_opts = dict(
    title_text="AB",
    title_font_family="Times New Roman",
    font_family="Times New Roman",
    plot_bgcolor='white',
)
fig.update_layout(**layout_opts)

In [None]:
# point clouds
# Reuses `coords` and `colors` as they are currently defined in the kernel.
panel_titles = ("ground thruth", "Prediction", "Discrete Prediction")
fig = plot_abag_pointclouds(coords, colors, panel_titles, size=5, show=False)
layout_opts = dict(
    title_text="AB",
    title_font_family="Times New Roman",
    font_family="Times New Roman",
    plot_bgcolor='white',
)
fig.update_layout(**layout_opts)

### AG

In [None]:
# graph
# Same three-panel layout, restricted to the "ag" entries:
# labels, raw predictions, and predictions thresholded at 0.5.
coords = [dict(ag=pdb["coords_ag"]) for _ in range(3)]
edges = [dict(ag=pdb["edge_ag"]) for _ in range(3)]
colors = [
    dict(ag=pdb["lbls_ag"]),
    dict(ag=pdb["pred_ag"]),
    dict(ag=(pdb["pred_ag"] > 0.5).astype(float)),
]
panel_titles = ("ground thruth", "Prediction", "Discrete Prediction")
fig = plot_abag_3dgraphs(coords, edges, colors, panel_titles, size=5, show=False)
layout_opts = dict(
    title_text="AG",
    title_font_family="Times New Roman",
    font_family="Times New Roman",
    plot_bgcolor='white',
)
fig.update_layout(**layout_opts)

In [None]:
# point clouds
# Reuses `coords` and `colors` as they are currently defined in the kernel.
panel_titles = ("ground thruth", "Prediction", "Discrete Prediction")
fig = plot_abag_pointclouds(coords, colors, panel_titles, size=5, show=False)
layout_opts = dict(
    title_text="AG",
    title_font_family="Times New Roman",
    font_family="Times New Roman",
    plot_bgcolor='white',
)
fig.update_layout(**layout_opts)