# Imports

In [None]:
from pathlib import Path
import pandas as pd, numpy as np
import plotly.express as px
from datetime import datetime
from tqdm import tqdm

In [None]:
main_path = Path("/Users/alexpayne/lilac-mount-point/asap-datasets/retro_docking/sars_fragalysis_retrospective/20230411")

In [None]:
csv = main_path/"all_results_cleaned.csv"

In [None]:
csv.exists()

## load as pandas df

In [None]:
df = pd.read_csv(csv, index_col=0)

In [None]:
df.head()

# Keep only P structures (Mpro-P complexes with Mpro-P compounds)

In [None]:
p_df = df[(df.Complex_ID.apply(lambda x: "Mpro-P" in x)) & (df.Compound_Source.apply(lambda x: "Mpro-P" in x))]

In [None]:
len(p_df.Compound_ID.unique())

In [None]:
p_df["Structure_Name"] = p_df.Structure_Source.apply(lambda x: x.split("_")[0])

# Load Mpro_soaks.csv

In [None]:
mpro_soaks = Path("/Users/alexpayne/lilac-mount-point/asap-datasets/current/sars_00_fragalysis/extra_files/Mpro_soaks.csv")

In [None]:
date_df = pd.read_csv(mpro_soaks)

In [None]:
ddf = date_df.loc[:, ["Sample Name", "Data Collection Date"]]

In [None]:
ddf.head()

In [None]:
def date_processor(date_string, formats=("%Y-%m-%d %H:%M:%S", "%d/%m/%Y %H:%M")):
    """Parse a raw collection-date value into a ``datetime``.

    Dates may appear in more than one layout, so each entry of *formats* is
    tried in order and the first one that parses wins.

    Args:
        date_string: Raw cell value. Anything that is not a string (e.g. a NaN
            float for an empty CSV cell) and the literal string ``'None'`` are
            treated as missing.
        formats: ``datetime.strptime`` format strings to try, in priority
            order. The default reproduces the two layouts handled originally.

    Returns:
        The parsed ``datetime``, or ``None`` for a missing value.

    Raises:
        ValueError: If the string matches none of *formats*.
    """
    # isinstance (not type() ==) so str subclasses such as numpy.str_ are parsed too.
    if not isinstance(date_string, str) or date_string == "None":
        return None
    last_error = None
    for fmt in formats:
        try:
            return datetime.strptime(date_string, fmt)
        except ValueError as err:
            last_error = err
    raise ValueError(
        f"unrecognised date {date_string!r}; tried formats {tuple(formats)}"
    ) from last_error

In [None]:
ddf['Sanitized_Date'] = ddf["Data Collection Date"].apply(date_processor)

In [None]:
ddf.columns = ["Structure_Name", "Data_Collection_Date", "Structure_Date"]

In [None]:
merged = pd.merge(p_df, ddf, on="Structure_Name")

In [None]:
merged["Structure_Date"] = merged.Structure_Date.apply(lambda x: x.date())

In [None]:
merged.sort_values("Structure_Date")

In [None]:
merged.to_csv(csv.parent / "results_with_structure_dates.csv")

# Implement a temporal split by reference-structure collection date

In [None]:
merged_csv = main_path / "results_with_structure_dates.csv"
merged_csv.exists()

In [None]:
merged = pd.read_csv(merged_csv, index_col=0)

In [None]:
merged.head()

In [None]:
dates = merged.Structure_Date.unique()

In [None]:
dates.sort()

In [None]:
dates

In [None]:
def calculate_perc_good(df, cutoffs: list, dates=None, *, progress=True):
    """Score best-POSIT pose accuracy as reference structures accumulate over time.

    For each cutoff date, only poses docked to reference structures whose
    ``Structure_Date`` is on or before that date are considered. Each
    compound's pose with the highest ``POSIT`` score is kept, and the fraction
    of those poses with ``RMSD <= cutoff`` is reported for every RMSD cutoff.

    Args:
        df: Pose table with at least ``Compound_ID``, ``Structure_Source``,
            ``Structure_Date``, ``POSIT`` and ``RMSD`` columns.
        cutoffs: RMSD thresholds; a pose counts as good when ``RMSD <= cutoff``.
        dates: Cutoff dates to evaluate, in output order. Defaults to the
            sorted unique non-missing ``Structure_Date`` values of *df*
            (previously read from a notebook-level global).
        progress: Show a tqdm progress bar over the dates.

    Returns:
        DataFrame with one row per (cutoff, date) pair, ordered cutoff-major,
        with columns ``Date``, ``Cutoff (Å)``, ``Percentage`` (a fraction in
        [0, 1]; NaN when no structure predates the date), ``Number of
        Reference Structures`` and ``Number of Structures Used in Best Pose``.
    """
    if dates is None:
        dates = np.sort(df.Structure_Date.dropna().unique())

    # Best pose first. A stable sort breaks POSIT ties by original row order,
    # so the selected pose is reproducible between runs.
    ranked = df.sort_values("POSIT", ascending=False, kind="mergesort")

    # The date filter and best-pose choice do not depend on the RMSD cutoff,
    # so compute them once per date rather than once per (cutoff, date).
    per_date = []
    date_iter = tqdm(dates) if progress else dates
    for date in date_iter:
        selected = ranked[ranked.Structure_Date <= date]
        best = selected.groupby("Compound_ID").head(1)
        per_date.append((
            date,
            best.RMSD,
            selected.Structure_Source.nunique(dropna=False),
            best.Structure_Source.nunique(dropna=False),
        ))

    rows = []
    for cutoff in cutoffs:
        for date, best_rmsd, n_available, n_used in per_date:
            # Guard the empty case instead of dividing by zero.
            fraction = (best_rmsd <= cutoff).mean() if len(best_rmsd) else np.nan
            rows.append((date, cutoff, fraction, n_available, n_used))

    return pd.DataFrame(rows, columns=[
        "Date",
        "Cutoff (Å)",
        "Percentage",
        "Number of Reference Structures",
        "Number of Structures Used in Best Pose",
    ])

In [None]:
perc_good_df = calculate_perc_good(merged, cutoffs=[0.5, 1, 1.5, 2, 3])

In [None]:
# perc_good_df["Cutoff (Å)"] = perc_good_df["Cutoff (Å)"].astype(str)

In [None]:
fig = px.scatter(perc_good_df, x="Date", 
                 y="Percentage", 
                 color="Cutoff (Å)", 
                 height=800, 
                 width=800, 
                 color_continuous_scale="Portland")

In [None]:
fig.update_yaxes(title="Percentage of Molecules with Selected Pose RMSD from True Pose < Cutoff ", range=[0,1])
fig.update_xaxes(title="Cutoff Date for Inclusion of Reference Structures")

In [None]:
fig.write_image("../../../../figures/20230518_sars_retrospective_temporal_split_RMSD.png")

# Plot as a function of the number of structures

In [None]:
fig = px.scatter(perc_good_df, x="Number of Reference Structures", 
                 y="Percentage", 
                 color="Cutoff (Å)", 
                 height=800, 
                 width=800, 
                 color_continuous_scale="Portland")

In [None]:
fig.update_yaxes(title="Percentage of Molecules with Selected Pose RMSD from True Pose < Cutoff ", range=[0,1])
fig.update_xaxes(title="Number of Reference Structures within Cutoff Date")

In [None]:
fig.write_image("../../../../figures/20230518_sars_retrospective_temporal_split_RMSD_nreferences.png")

# Combine accuracy and the number of reference structures in a single plot

In [None]:
perc_good_df

In [None]:
from plotly.subplots import make_subplots

In [None]:
fig2 = px.line(perc_good_df[perc_good_df["Cutoff (Å)"] == 0.5], x="Date", y="Number of Reference Structures")

In [None]:
fig2.update_traces(yaxis="y2")

In [None]:
fig2.show()

In [None]:
fig = px.scatter(perc_good_df, x="Date", 
                 y="Percentage", 
                 color="Cutoff (Å)", 
                 height=800, 
                 width=800, 
                 color_continuous_scale="Portland")

In [None]:
fig.update_yaxes(title="Percentage of Molecules with Selected Pose RMSD from True Pose < Cutoff ", range=[0,1])
fig.update_xaxes(title="Cutoff Date for Inclusion of Reference Structures")

In [None]:
subfig = make_subplots(specs=[[{"secondary_y": True}]])
subfig.add_traces(fig.data + fig2.data)
subfig.update_xaxes(title="Cutoff Date for Inclusion of Reference Structures")
subfig.layout.yaxis.title="Percentage of Molecules with Chosen Pose RMSD to Reference < Cutoff"
subfig.layout.yaxis2.title="Number of Reference Structures Included"
subfig.update_layout(coloraxis_colorbar=dict(title="Cutoff (Å)"), colorscale={"sequential":"Portland"})
subfig.layout.height=800
subfig.layout.width=800
subfig.show()

In [None]:
subfig.write_image("../../../../figures/20230518_sars_retrospective_temporal_split_combined.png")

# Also add the number of reference structures actually used in the best poses

In [None]:
fig3 = px.line(perc_good_df[perc_good_df["Cutoff (Å)"] == 0.5], x="Date", y="Number of Structures Used in Best Pose")
fig3.update_traces(yaxis="y2")

In [None]:
subfig = make_subplots(specs=[[{"secondary_y": True}]])
subfig.add_traces(fig.data + fig2.data + fig3.data)
subfig.update_xaxes(title="Cutoff Date for Inclusion of Reference Structures")
subfig.layout.yaxis.title="Percentage of Molecules with Chosen Pose RMSD to Reference < Cutoff"
subfig.layout.yaxis2.title="Number of Reference Structures"
subfig.update_layout(coloraxis_colorbar=dict(title="Cutoff (Å)"), 
                     colorscale={"sequential":"Portland"},)
subfig.layout.height=800
subfig.layout.width=800
subfig.show()

In [None]:
subfig.write_image("../../../../figures/20230518_sars_retrospective_temporal_split_combined_v2.png")

# Make the two-level (stacked) plot John suggested

In [None]:
stacked = make_subplots(rows=2, cols=1)

In [None]:
date_title = "Cutoff Date for Inclusion of Reference Structures"
y_axis = "Percentage of Molecules with Pose RMSD to Reference < 2.0Å"
y_axis2 = "Number of Reference Structures"

In [None]:
fig1 = px.scatter(perc_good_df[perc_good_df["Cutoff (Å)"] == 2.0], x="Date", 
                 y="Percentage", 
                 height=800, 
                 width=800, )


In [None]:
fig1.update_yaxes(title=y_axis, range=[0,1])
fig1.update_xaxes(title=date_title)

In [None]:
figures = [
            fig1, fig2
    ]

fig = make_subplots(rows=len(figures), cols=1) 

for i, figure in enumerate(figures):
    for trace in range(len(figure["data"])):
        fig.append_trace(figure["data"][trace], row=i+1, col=1)

In [None]:
fig.layout.yaxis.range=(0,1)
fig.layout.title=y_axis
fig.layout.yaxis.title="Percentage"
fig.layout.xaxis.title=date_title
fig.layout.xaxis2.title=date_title
fig.layout.yaxis2.title=y_axis2
fig.layout.height=800
fig.layout.width=800

In [None]:
fig.show()

In [None]:
fig.write_image("../../../../figures/20230518_sars_retrospective_temporal_split_combined_stacked.png")