# Imports

In [None]:
from pathlib import Path
import pandas as pd, numpy as np
import plotly.express as px
from datetime import datetime
from tqdm import tqdm

In [None]:
# Directory that holds the results CSV loaded below.
# NOTE(review): hard-coded, user-specific mount point — adjust when running elsewhere.
main_path = Path("/Users/alexpayne/lilac-mount-point/asap-datasets/retro_docking/sars_fragalysis_retrospective/20230411")

In [None]:
# Location of the cleaned results table inside the results directory.
csv = main_path.joinpath("all_results_cleaned.csv")

In [None]:
# Check that the results file is reachable before trying to read it.
csv.exists()

## Load as pandas DataFrame

In [None]:
# Read the results table, using the first CSV column as the index.
df = pd.read_csv(csv, index_col=0)

In [None]:
# Preview the first five rows of the results table.
df.head(n=5)

# Keep only Mpro-P structures

In [None]:
# Keep only rows where both the complex ID and the compound source contain
# "Mpro-P".  Vectorised literal matching replaces the per-row lambdas;
# ``na=False`` treats missing values as non-matching instead of raising, and
# ``.copy()`` yields an independent frame so the column added to ``p_df``
# further down does not trigger pandas' SettingWithCopyWarning.
is_p_complex = df["Complex_ID"].str.contains("Mpro-P", regex=False, na=False)
is_p_source = df["Compound_Source"].str.contains("Mpro-P", regex=False, na=False)
p_df = df[is_p_complex & is_p_source].copy()

In [None]:
# Number of distinct compound IDs among the Mpro-P rows.
len(p_df["Compound_ID"].unique())

In [None]:
# Structure name = the part of Structure_Source before the first underscore.
# ``assign`` builds a new frame instead of writing a column into ``p_df`` in
# place: ``p_df`` comes from boolean indexing on ``df``, and in-place column
# assignment on such a selection raises SettingWithCopyWarning.  The
# vectorised ``.str`` accessor also leaves missing values as missing rather
# than failing on them.
p_df = p_df.assign(
    Structure_Name=p_df["Structure_Source"].str.split("_", n=1).str[0]
)

# Load Mpro_soaks.csv

In [None]:
# CSV from which sample names and data-collection dates are read below.
# NOTE(review): hard-coded, user-specific mount point — adjust when running elsewhere.
mpro_soaks = Path("/Users/alexpayne/lilac-mount-point/asap-datasets/current/sars_00_fragalysis/extra_files/Mpro_soaks.csv")

In [None]:
# Load the Mpro_soaks.csv table.
date_df = pd.read_csv(mpro_soaks)

In [None]:
# Keep only the sample name and the raw data-collection date.
date_columns = ["Sample Name", "Data Collection Date"]
ddf = date_df.loc[:, date_columns]

In [None]:
# Preview the first five rows of the date table.
ddf.head(n=5)

In [None]:
def date_processor(date_string):
    """Parse a raw data-collection date entry into a ``datetime``.

    Two layouts are tried in order: ``YYYY-MM-DD HH:MM:SS`` first, then
    ``DD/MM/YYYY HH:MM``.

    Args:
        date_string: Raw cell value.  Anything that is not a string (for
            example a missing value) and the literal string ``'None'`` are
            treated as "no date".

    Returns:
        The parsed ``datetime``, or ``None`` when there is no date.

    Raises:
        ValueError: If the value is a string that matches neither layout.
    """
    # isinstance (rather than an exact type comparison) also accepts str
    # subclasses such as numpy string scalars.
    if not isinstance(date_string, str) or date_string == 'None':
        return None

    try:
        return datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S")
    except ValueError:
        # Fall back to the day-first layout; a second failure propagates.
        return datetime.strptime(date_string, "%d/%m/%Y %H:%M")

In [None]:
# Parse every raw collection-date entry with date_processor.
ddf["Sanitized_Date"] = ddf["Data Collection Date"].map(date_processor)

In [None]:
# Rename the three columns positionally: the sample name becomes
# Structure_Name (the merge key used below), the raw date becomes
# Data_Collection_Date and the parsed date becomes Structure_Date.
ddf.columns = ["Structure_Name", "Data_Collection_Date", "Structure_Date"]

In [None]:
# Attach the collection dates to the docking rows via Structure_Name.
# pandas' default join is an inner join, so rows without a match in the
# other table are dropped.
merged = p_df.merge(ddf, on="Structure_Name")

In [None]:
# Drop the time of day, keeping only the calendar date of each structure.
merged["Structure_Date"] = merged.Structure_Date.apply(lambda x: x.date())

In [None]:
# Display the rows ordered by structure date.
# NOTE(review): the sorted frame is not assigned, so `merged` itself (and the
# CSV written next) keeps its original row order — confirm that is intended.
merged.sort_values("Structure_Date")

In [None]:
# Write the dated results next to the input CSV.
dated_results_path = csv.with_name("results_with_structure_dates.csv")
merged.to_csv(dated_results_path)

# Implement a temporal split

In [None]:
# Distinct structure dates present in the merged table.
dates = merged["Structure_Date"].unique()

In [None]:
# Sort the dates in place.
dates.sort()

In [None]:
# Display the sorted dates.
dates

In [None]:
def calculate_perc_good(df, cutoffs: list, dates=None):
    """Compute, per RMSD cutoff and date, the fraction of well-docked compounds.

    For every date, only rows whose ``Structure_Date`` is on or before that
    date are considered.  Among those rows, the one with the highest
    ``POSIT`` value is selected for each ``Compound_ID``, and the fraction of
    selected rows with ``RMSD <= cutoff`` is recorded.

    Args:
        df: Table with ``POSIT``, ``Structure_Date``, ``Compound_ID`` and
            ``RMSD`` columns.
        cutoffs: RMSD thresholds to evaluate.
        dates: Dates to evaluate, in the order they should appear in the
            output.  Defaults to the sorted distinct non-missing
            ``Structure_Date`` values of ``df``, so the function no longer
            depends on a module-level ``dates`` variable.

    Returns:
        A DataFrame with one row per (cutoff, date) pair — cutoffs in the
        outer order, dates in the inner order — and the columns ``Date``,
        ``Cutoff`` and ``Percentage``.  Despite its name, ``Percentage`` is a
        fraction between 0 and 1; it is NaN when no row falls on or before
        the date.
    """
    if dates is None:
        dates = sorted(df["Structure_Date"].dropna().unique())
    else:
        # Materialise once so the dates can be iterated for every cutoff.
        dates = list(dates)

    # Highest POSIT first, so head(1) per compound picks its top-ranked row.
    ranked = df.sort_values(["POSIT"], ascending=[False])

    # The selected rows depend only on the date, not on the cutoff, so
    # compute the selected RMSD values once per date.
    rmsd_by_date = [
        ranked[ranked["Structure_Date"] <= date]
        .groupby("Compound_ID")
        .head(1)["RMSD"]
        for date in dates
    ]

    date_list = []
    cutoff_list = []
    perc_good = []
    for cutoff in tqdm(cutoffs):
        for date, rmsd in zip(dates, rmsd_by_date):
            date_list.append(date)
            cutoff_list.append(cutoff)
            # Mean of a boolean mask == fraction of rows within the cutoff.
            perc_good.append((rmsd <= cutoff).mean())

    return pd.DataFrame(
        {"Date": date_list, "Cutoff": cutoff_list, "Percentage": perc_good}
    )

In [None]:
# RMSD thresholds at which the success fraction is evaluated.
rmsd_cutoffs = [0.5, 1, 1.5, 2, 3]
perc_good_df = calculate_perc_good(merged, cutoffs=rmsd_cutoffs)

In [None]:
# Display the per-date, per-cutoff results table.
perc_good_df

In [None]:
# Scatter of the success fraction against date, coloured by RMSD cutoff on
# the "Portland" continuous colour scale, drawn on an 800x800 canvas.
scatter_options = {
    "x": "Date",
    "y": "Percentage",
    "color": "Cutoff",
    "height": 800,
    "width": 800,
    "color_continuous_scale": "Portland",
}
fig = px.scatter(perc_good_df, **scatter_options)

In [None]:
# Set the y-axis title.
# NOTE(review): the label says "< Cutoff" and "Percentage", while
# calculate_perc_good counts RMSD <= cutoff and returns a fraction (0-1) —
# confirm whether the label or the computation should change.
fig.update_yaxes(title="Percentage of Molecules with Selected Pose RMSD from True Pose < Cutoff ")

In [None]:
# Save the figure as a PNG; the path is relative to the current working directory.
fig.write_image("../../../../figures/20230518_sars_retrospective_temporal_split_RMSD.png")