# Imports

In [None]:
import plotly.express as ex
import pandas as pd
import numpy as np
import math, sys, os
from collections import Counter
sys.path.append('../')
from asapdiscovery.docking import plotting as pl
# from asapdiscovery.docking.analysis import DockingDataset

In [None]:
fn = "../scripts/docking_results.csv"

In [None]:
# Load the docking results and replace missing values with a -1 sentinel.
# Numeric columns receive the number -1 so that the numeric comparisons made
# later in this notebook (e.g. ``df.MCSS_Rank == -1``, ``df.RMSD < 0``) keep
# working; filling them with the string "-1" would turn them into object
# columns. Remaining (text) columns keep the original "-1" string marker.
df = pd.read_csv(fn)
numeric_cols = df.select_dtypes(include="number").columns
df[numeric_cols] = df[numeric_cols].fillna(-1)
df = df.fillna("-1")

In [None]:
df.index = df.Complex_ID

In [None]:
df["POSIT_R"] = -df.POSIT+1

In [None]:
df.head()

# filter df

In [None]:
redocked_df = df[df.MCSS_Rank == -1]

In [None]:
redocked_df

In [None]:
test_df = df[df.MCSS_Rank > -1]

In [None]:
test_df

In [None]:
cmplx_list = df.index

In [None]:
len(cmplx_list)

In [None]:
len(redocked_df.index)

In [None]:
len(test_df.index)

## make sure the numbers match up

In [None]:
len(np.append(redocked_df.Complex_ID, test_df.Complex_ID))

### they do!

## Filter out wacky results from docked dataset

### plot histogram of RMSDs

In [None]:
ex.histogram(df, x="RMSD")

### examine RMSDs lower than 0

In [None]:
RMSDs_less_than_0 = df[df.RMSD < 0]

In [None]:
len(RMSDs_less_than_0)

In [None]:
set(RMSDs_less_than_0.Compound_ID)

In [None]:
MAT_POS_090737b9_1_df = df[df.Compound_ID == 'MAT-POS-090737b9-1']

In [None]:
set(MAT_POS_090737b9_1_df.RMSD)

In [None]:
set(MAT_POS_090737b9_1_df.Reference_SDF)

#### Conclusion: all the RMSDs calculated from this ligand "MAT-POS-090737b9-1" are -1

### Examine RMSDs higher than the rest

In [None]:
RMSDs_large = df[df.RMSD >15]

In [None]:
len(RMSDs_large)

In [None]:
len(RMSDs_large[(RMSDs_large.Compound_ID == "MAT-POS-5d65ec79-1") | (RMSDs_large.Crystal_ID == "Mpro-P0097")])

#### Conclusion: All large RMSDs have to do with the fragalysis structure "Mpro-P0097" which had "MAT-POS-5d65ec79-1" as its ligand

In [None]:
set(RMSDs_large.Compound_ID)

In [None]:
set(RMSDs_large.Crystal_ID)

### plot histogram of chemgauss4 score

In [None]:
ex.histogram(df, x="Chemgauss4")

In [None]:
large_chemgauss4 = df[df.Chemgauss4 > 100]

In [None]:
len(large_chemgauss4)

In [None]:
set(large_chemgauss4.Chemgauss4)

In [None]:
set(large_chemgauss4.Crystal_ID)

In [None]:
set(large_chemgauss4.Reference_SDF)

In [None]:
set(large_chemgauss4.Compound_ID)

In [None]:
ex.scatter(large_chemgauss4, x="RMSD", y="POSIT_R", color="Compound_ID", symbol="Crystal_ID")

#### Conclusion: no idea

## Filter out bad data

In [None]:
bad_data_df = pd.concat([large_chemgauss4, RMSDs_large, RMSDs_less_than_0])

### are there any redocked structures in the bad data?

In [None]:
redocked_bad_data_df = bad_data_df[bad_data_df.MCSS_Rank == -1]

In [None]:
test_bad_data_df = bad_data_df[bad_data_df.MCSS_Rank != -1]

In [None]:
redocked_bad_data_df

#### Conclusion: the only bad redocked structures were the ones with RMSD = -1

In [None]:
len(bad_data_df)

In [None]:
redocked_filtered_df = redocked_df.drop(redocked_bad_data_df.Complex_ID, axis='index')

In [None]:
test_filtered_df = test_df.drop(test_bad_data_df.Complex_ID, axis='index')

In [None]:
bad_data_df = df[(df.Chemgauss4 > 0) | (df.RMSD < 0) | (df.RMSD > 15)]

In [None]:
len(redocked_df) - len(redocked_filtered_df)

In [None]:
len(test_df) - len(test_filtered_df)

#### Conclusion: 2 redocked and 62 test complexes were removed

# Make a compound results Dataframe

## get all the info

In [None]:
test_filtered_df.groupby("Compound_ID")["RMSD"].agg(["count", "min"])

In [None]:
total_poses = test_filtered_df.groupby('Compound_ID')["RMSD"].count()

In [None]:
total_poses

In [None]:
RMSDs = test_filtered_df.groupby('Compound_ID')[['RMSD']].apply(lambda x: x[x <= 2].agg(["count", "min"]))

In [None]:
RMSDs

In [None]:
n_good_poses = RMSDs.xs("count", level=1)["RMSD"]

In [None]:
min_RMSD = RMSDs.xs("min", level=1)["RMSD"]

In [None]:
perc_good_poses = n_good_poses / total_poses

In [None]:
perc_good_poses

In [None]:
min_posit_R = test_filtered_df.groupby('Compound_ID')['POSIT_R'].min()

In [None]:
min_posit_R

In [None]:
cmpd_df = pd.DataFrame({
#     "Compound_ID": all_cmpd_counter.keys(),
    "N_Poses": total_poses,
    "N_Good_Poses": n_good_poses,
    "Perc_Good_Poses": perc_good_poses, 
    "Min_RMSD": min_RMSD,
    "Min_POSIT_R": min_posit_R,
                      })

In [None]:
cmpd_df

In [None]:
cmpd_df.sort_values("Perc_Good_Poses")

## Plot the dataframe

In [None]:
fig = ex.bar(cmpd_df.sort_values("Perc_Good_Poses"), 
       x=cmpd_df.sort_values("Perc_Good_Poses").index, 
       y="Perc_Good_Poses",
      )
fig.update_traces(width=1)
fig.update_layout(height=600, 
                  width=3600, 
                  font_size=12,
#                  xaxis={'categoryorder':'category ascending'}
                 )
fig.show()

In [None]:
too_good_df = cmpd_df[cmpd_df.Perc_Good_Poses == 1]

In [None]:
too_good_df

In [None]:
_df = test_filtered_df[test_filtered_df.Compound_ID.isin(too_good_df.index)]

In [None]:
ex.scatter(_df, 
           x="RMSD", 
           y="POSIT_R", 
           color="Compound_ID", 
           hover_data=["Crystal_ID", 
                       "Reference_SDF", 
                       "Chain_ID"], 
          )

## what percentage of compounds have at least 1 structure with RMSD < 2?

In [None]:
sum(cmpd_df.Perc_Good_Poses > 0)

In [None]:
sum(cmpd_df.Perc_Good_Poses == 0)

In [None]:
perc_cmpds = sum(cmpd_df.Perc_Good_Poses > 0) / len(cmpd_df)

In [None]:
perc_cmpds

### Conclusion: 87.6% of compounds have at least 1 pose with RMSD <= 2

# Plotting Test Data

In [None]:
ex.scatter(test_filtered_df, 
           x="MCSS_Rank", 
           y="RMSD", 
           color="Compound_ID", 
           hover_data=["Crystal_ID", 
                       "Reference_SDF", 
                       "Chain_ID"]
          )

## General

In [None]:
ex.scatter(test_filtered_df, 
           x="RMSD", 
           y="Chemgauss4", 
           color="POSIT_R", 
           hover_data=["Crystal_ID", 
                       "Reference_SDF", 
                       "Chain_ID"], 
          )

## How well does POSIT recapitulate RMSD?

In [None]:
ex.scatter(test_filtered_df, 
           x="POSIT_R", 
           y="RMSD", 
           color="Chemgauss4", 
           hover_data=["Crystal_ID", 
                       "Reference_SDF", 
                       "Chain_ID"], 
          )

In [None]:
ex.density_heatmap(test_filtered_df, 
           x="POSIT_R", 
           y="RMSD", 
#            color="Chemgauss4", 
           hover_data=["Crystal_ID", 
                       "Reference_SDF", 
                       "Chain_ID"], 
                   marginal_x="histogram", 
                   marginal_y="histogram"
          )

## How well does chemgauss4 recapitulate RMSD?

In [None]:
ex.density_heatmap(test_filtered_df, 
           x="RMSD", 
           y="Chemgauss4", 
#            color="POSIT_R", 
           hover_data=["Crystal_ID", 
                       "Reference_SDF", 
                       "Chain_ID"], 
                   marginal_x="histogram", 
                   marginal_y="histogram"
          )

In [None]:
ex.scatter(test_filtered_df, 
           x="MCSS_Rank", 
           y="Chemgauss4", 
           color="RMSD", 
           hover_data=["Crystal_ID", 
                       "Reference_SDF", 
                       "Chain_ID"], 
#            facet_row="MCSS_Rank",
          )

In [None]:
ex.density_heatmap(test_filtered_df, 
           x="MCSS_Rank", 
           y="Chemgauss4", 
          )

In [None]:
ex.density_heatmap(test_filtered_df, 
           x="MCSS_Rank", 
           y="POSIT_R", 
                   marginal_y='histogram'
          )

In [None]:
ex.density_heatmap(test_filtered_df, 
                   x="MCSS_Rank", 
                   y="RMSD",
                   marginal_y='histogram'
          )

In [None]:
ex.density_heatmap(test_filtered_df, 
           x="MCSS_Rank", 
           y="POSIT",
                   marginal_y='histogram'
          )

# Plotting Redocked Data

In [None]:
redocked_filtered_df

In [None]:
ex.scatter(redocked_filtered_df, 
           x="POSIT_R", 
           y="RMSD", 
           color="Chemgauss4", 
           hover_data=["Crystal_ID", 
                       "Reference_SDF", 
                       "Chain_ID"], 
          )

In [None]:
# Redocked counterpart of the POSIT_R vs RMSD density heatmap. This section
# covers the redocked complexes, so plot redocked_filtered_df; the same plot
# for test_filtered_df is already drawn in the test-data section.
ex.density_heatmap(redocked_filtered_df, 
                   x="POSIT_R", 
                   y="RMSD",
                   marginal_y='histogram',
                   marginal_x='histogram'
          )

In [None]:
total_good_cmpds = len(set(df[df["RMSD"] <= 2].Compound_ID))

In [None]:
total_good_cmpds

In [None]:
Counter(test_df[test_df["RMSD"] <= 2].Compound_ID)

In [None]:
test_df[test_df.Compound_ID == 'ALP-POS-64a710fa-1']

# AUC

## write AUC calculation functions

In [None]:
class Rock():
    """ROC-style evaluation of a single score column against pose quality.

    A pose is counted as "good" when its ``RMSD`` value is <= 2. For a ladder
    of score cutoffs, every pose whose score is <= the cutoff is treated as
    selected, and true/false positive rates and precision are recorded, both
    per pose and per compound (``Compound_ID``).

    Attributes written by the methods:
        score_range: cutoffs, from ``min(score) - 1`` to ``max(score)``.
        total_*: pose/compound counts of the most recently evaluated frame.
        tpr_poses, fpr_poses, precision_poses: per-cutoff pose statistics.
        tpr_cmpds, precision_cmpds: per-cutoff compound statistics
            (``fpr_cmpds`` is created but intentionally left empty).
        auc_poses: list of every pose-level AUC computed so far; entry 0 is
            the full-data AUC once ``get_bootstrapped_error_bars`` or
            ``get_auc_from_df()`` has run.
        auc_cmpds: created empty and never filled by this class.
        poses_ci: ``(lower_offset, upper_offset)`` set by
            ``get_bootstrapped_error_bars``.
    """

    # Attributes overwritten by get_auc_from_df(); they are snapshotted and
    # restored around bootstrapping so resampling does not clobber the curve
    # of the real data.
    _CURVE_STATE = (
        "total_poses", "total_good_poses", "total_bad_poses",
        "total_cmpds", "total_good_cmpds", "total_bad_cmpds",
        "tpr_poses", "fpr_poses", "precision_poses",
        "tpr_cmpds", "fpr_cmpds", "precision_cmpds",
    )

    def __init__(self, 
                 df,
                 score_name,
                 n_samples, 
                 ):
        """Store the inputs, build the cutoff ladder and count the totals.

        Args:
            df: DataFrame with ``RMSD``, ``Compound_ID`` and ``score_name``
                columns.
            score_name: name of the score column (lower is treated as better).
            n_samples: number of cutoffs placed along the score range.
        """
        self.df = df
        self.score_name = score_name
        self.n_samples = n_samples
        self.get_score_range()

        (self.total_poses, self.total_good_poses, self.total_bad_poses,
         self.total_cmpds, self.total_good_cmpds,
         self.total_bad_cmpds) = self.calc_data(self.df)

        self.auc_poses = []
        self.auc_cmpds = []

    def calc_data(self, df):
        """Count poses and compounds in ``df``.

        Returns:
            ``(n_poses, n_good_poses, n_bad_poses, n_cmpds, n_good_cmpds,
            n_bad_cmpds)`` where "good" means ``RMSD <= 2``, "bad" poses are
            all remaining rows, and a "good" compound has at least one good
            pose.
        """
        good_mask = df["RMSD"] <= 2
        n_poses = len(df)
        # Vectorised counts: this method runs once per cutoff per bootstrap.
        n_good_poses = int(good_mask.sum())
        n_bad_poses = n_poses - n_good_poses

        n_cmpds = int(df["Compound_ID"].nunique(dropna=False))
        n_good_cmpds = int(df.loc[good_mask, "Compound_ID"].nunique(dropna=False))
        n_bad_cmpds = n_cmpds - n_good_cmpds

        return n_poses, n_good_poses, n_bad_poses, n_cmpds, n_good_cmpds, n_bad_cmpds

    def calc_auc_from_fpr_tpr(self, fpr, tpr):
        """Integrate ``tpr`` over ``fpr`` with the trapezoidal rule."""
        # np.trapz was renamed to np.trapezoid in NumPy 2.0; support both.
        integrate = getattr(np, "trapezoid", None)
        if integrate is None:
            integrate = np.trapz
        return integrate(x=fpr, y=tpr)

    def get_score_range(self):
        """Build ``self.score_range`` from the score column of ``self.df``.

        The ladder starts one unit below the minimum score so that the first
        cutoff selects nothing, and ends at the maximum score.
        """
        scores = self.df[self.score_name]
        self.score_range = np.linspace(scores.min() - 1,
                                       scores.max(),
                                       self.n_samples,
                                       endpoint=True)

    def weird_division(self, n, d):
        """Return ``n / d``, or 0 when the denominator is zero."""
        return n / d if d else 0

    def get_auc_from_df(self, df=None, bootstrap=False):
        """Compute the per-cutoff curves and append the pose AUC.

        Args:
            df: frame to evaluate; ``None`` means ``self.df``. When a frame is
                given, the ``total_*`` attributes are recomputed from it.
            bootstrap: unused; kept for interface compatibility.

        The cutoffs always come from ``self.score_range``. Rates whose
        denominator is zero (e.g. a frame without bad poses) are reported as 0
        through ``weird_division`` rather than dividing by zero.
        """
        if df is None:
            df = self.df
        else:
            (self.total_poses, self.total_good_poses, self.total_bad_poses,
             self.total_cmpds, self.total_good_cmpds,
             self.total_bad_cmpds) = self.calc_data(df)

        self.tpr_poses = []  # same thing as recall
        self.fpr_poses = []
        self.precision_poses = []

        self.tpr_cmpds = []  # same thing as recall
        self.fpr_cmpds = []  # a compound-level FPR is not computed
        self.precision_cmpds = []

        scores = df[self.score_name]
        for cutoff in self.score_range:
            (n_poses, n_good_poses, n_bad_poses,
             n_cmpds, n_good_cmpds, _) = self.calc_data(df[scores <= cutoff])

            self.tpr_poses.append(self.weird_division(n_good_poses, self.total_good_poses))
            self.fpr_poses.append(self.weird_division(n_bad_poses, self.total_bad_poses))
            self.precision_poses.append(self.weird_division(n_good_poses, n_poses))

            self.tpr_cmpds.append(self.weird_division(n_good_cmpds, self.total_good_cmpds))
            self.precision_cmpds.append(self.weird_division(n_good_cmpds, n_cmpds))

        self.auc_poses.append(self.calc_auc_from_fpr_tpr(self.fpr_poses, self.tpr_poses))

    def get_bootstrapped_error_bars(self, n_bootstraps):
        """Estimate a 95% interval of the pose AUC by resampling ``self.df``.

        Sets ``self.poses_ci = (mean - lower, upper - mean)`` where ``lower``
        and ``upper`` are the 2.5% / 97.5% order statistics of the bootstrap
        AUCs produced by this call and ``mean`` is their mean. The bootstrap
        AUCs are also appended to ``self.auc_poses`` (after the full-data AUC,
        which is computed first if it is missing). The stored curves and
        totals of the real data are restored afterwards.

        Raises:
            ValueError: if ``n_bootstraps`` is ``None`` or smaller than 1.
        """
        if n_bootstraps is None or n_bootstraps < 1:
            raise ValueError("n_bootstraps must be a positive integer")

        # Guarantee that auc_poses[0] is the AUC of the real data.
        if not self.auc_poses:
            self.get_auc_from_df()

        saved_state = {name: getattr(self, name)
                       for name in self._CURVE_STATE if hasattr(self, name)}
        first_replicate = len(self.auc_poses)
        try:
            for _ in range(n_bootstraps):
                self.get_auc_from_df(self.df.sample(frac=1, replace=True))
        finally:
            for name, value in saved_state.items():
                setattr(self, name, value)

        auc_poses_array = np.sort(np.array(self.auc_poses[first_replicate:]))
        auc_poses_bounds = math.floor(len(auc_poses_array) * 0.025)
        mean_auc = auc_poses_array.mean()
        # Index from the end with -(bounds + 1): a plain -bounds is index 0
        # (the minimum) whenever bounds == 0, i.e. for fewer than 40 samples.
        self.poses_ci = (mean_auc - auc_poses_array[auc_poses_bounds],
                         auc_poses_array[-(auc_poses_bounds + 1)] - mean_auc)

    def get_df(self):
        """Pack the current curves into ``auc_poses_df`` and ``auc_cmpds_df``.

        Requires ``get_auc_from_df`` to have been called. The compound frame
        has no ``False_Positive`` column because that rate is not computed.
        """
        self.auc_poses_df = pd.DataFrame({"True_Positive": self.tpr_poses,
                                          "False_Positive": self.fpr_poses,
                                          "Value": self.score_range,
                                          "Score_Type": self.score_name,
                                          "Precision": self.precision_poses,
                                          })
        self.auc_cmpds_df = pd.DataFrame({"True_Positive": self.tpr_cmpds,
                                          "Value": self.score_range,
                                          "Score_Type": self.score_name,
                                          "Precision": self.precision_cmpds,
                                          })
        

class Rocks():
    """Run the same ROC analysis for several score columns of one DataFrame.

    One ``Rock`` is built per entry of ``score_list`` and kept in
    ``rock_dict`` (keyed by score name). Typical use is ``get_aucs()``,
    ``combine_dfs()``, ``plot_auc()`` and, for error bars, ``get_auc_cis()``.
    """

    def __init__(self, 
                 df, 
                 score_list, 
                 n_samples, 
                 n_bootstraps=None,
                ):
        """Store the settings and build one ``Rock`` per score column.

        Args:
            df: DataFrame passed unchanged to every ``Rock``.
            score_list: names of the score columns to evaluate.
            n_samples: number of score cutoffs used by every ``Rock``.
            n_bootstraps: number of bootstrap resamples for ``get_auc_cis``;
                must be set before that method is called.
        """
        self.df = df
        self.score_list = score_list
        self.n_samples = n_samples
        self.n_bootstraps = n_bootstraps
        self.rock_dict = {}

        self.build_rocks()

    def build_rocks(self):
        """Create a ``Rock`` for every requested score column.

        Raises:
            ValueError: if a requested score is not a column of ``self.df``.
        """
        for score_name in self.score_list:
            # Explicit check instead of ``assert``, which is stripped under -O.
            if score_name not in self.df.columns:
                raise ValueError(
                    f"score column {score_name!r} not found in the dataframe")
            self.rock_dict[score_name] = Rock(self.df, score_name, self.n_samples)

    def get_aucs(self):
        """Compute the curves and AUC of every rock on the full data."""
        for rock in self.rock_dict.values():
            rock.get_auc_from_df()
            rock.get_df()

    def combine_dfs(self):
        """Concatenate the per-score curve frames into ``poses_df``/``cmpds_df``."""
        rocks = list(self.rock_dict.values())
        for rock in rocks:
            rock.get_df()
        self.poses_df = pd.concat([rock.auc_poses_df for rock in rocks])
        self.cmpds_df = pd.concat([rock.auc_cmpds_df for rock in rocks])

    def get_auc_cis(self):
        """Bootstrap every rock and collect the results in ``self.model_df``.

        ``model_df`` has one row per score with the columns ``Score_Type``,
        ``Lower_Bound``, ``AUC`` and ``Upper_Bound``; the bounds are the two
        entries of the rock's ``poses_ci`` and ``AUC`` is ``auc_poses[0]``.

        Raises:
            ValueError: if ``n_bootstraps`` was not provided.
        """
        if self.n_bootstraps is None:
            raise ValueError(
                "n_bootstraps must be set to compute confidence intervals")

        score_names = []
        lower_bound_list = []
        upper_bound_list = []
        auc_list = []
        for score_name, rock in self.rock_dict.items():
            print(score_name)
            # Make sure auc_poses[0] is the full-data AUC even when
            # get_aucs() was not called first.
            if not rock.auc_poses:
                rock.get_auc_from_df()
            rock.get_bootstrapped_error_bars(self.n_bootstraps)
            score_names.append(score_name)
            lower_bound_list.append(rock.poses_ci[0])
            upper_bound_list.append(rock.poses_ci[1])
            auc_list.append(rock.auc_poses[0])
        # Use the iterated keys (not score_list) so all columns stay aligned
        # even if score_list contains duplicates.
        self.model_df = pd.DataFrame({
            "Score_Type": score_names,
            "Lower_Bound": lower_bound_list,
            "AUC": auc_list,
            "Upper_Bound": upper_bound_list
        })

    def plot_auc(self, df_type='poses'):
        """Return a plotly line figure of true vs. false positive rate.

        Args:
            df_type: ``'poses'`` (uses ``poses_df``) or ``'cmpds'`` (uses
                ``cmpds_df``); ``combine_dfs()`` must have been called.

        Raises:
            ValueError: if ``df_type`` is not one of the two options.

        NOTE(review): the compound frame produced by ``Rock.get_df`` in this
        file carries no ``False_Positive`` column, so ``df_type='cmpds'``
        presumably cannot be drawn as written; confirm before relying on it.
        """
        frame_attr = {'poses': 'poses_df', 'cmpds': 'cmpds_df'}
        if df_type not in frame_attr:
            raise ValueError(
                f"df_type must be 'poses' or 'cmpds', got {df_type!r}")
        df = getattr(self, frame_attr[df_type])

        fig = ex.line(df, 
            x="False_Positive", 
            y="True_Positive", 
            color="Score_Type",
            hover_data=["Value"],
                     )
        fig.update_layout(height=600, width=600, title=df_type)
        # Equal axis scaling so the diagonal is at 45 degrees.
        fig.update_yaxes(
            scaleanchor="x",
            scaleratio=1,
        )

        return fig

In [None]:
rocks = Rocks(test_filtered_df,
             ["POSIT_R", "Chemgauss4", "MCSS_Rank"],
             n_samples=100,
             n_bootstraps=100)

In [None]:
rocks.get_aucs()

In [None]:
rocks.combine_dfs()

In [None]:
rocks.plot_auc()

In [None]:
rocks.get_auc_cis()

In [None]:
rocks.model_df

In [None]:
def get_auc_from_df(df, score_name, n_samples):
    """Compute a pose-level ROC curve and AUC for one score column.

    A pose is "good" when ``RMSD <= 2`` and "bad" when ``RMSD > 2``. For each
    cutoff, all poses with ``score <= cutoff`` count as selected.

    Args:
        df: DataFrame with ``RMSD``, ``Compound_ID`` and ``score_name`` columns.
        score_name: score column to threshold (lower is treated as better).
        n_samples: number of cutoffs between ``min(score) - 1`` and
            ``max(score)``.

    Returns:
        ``(tps, fps, auc, perc_cmpds, perc_good_cmpds, score_range)``:
        per-cutoff true positive rate, false positive rate, the trapezoidal
        area of tps over fps, the fraction of all compounds selected, the
        fraction of compounds with a good pose that still have a good pose
        selected, and the cutoffs themselves. A rate whose denominator is
        zero is reported as 0.0.
    """
    def ratio(numerator, denominator):
        # Avoid NaN / ZeroDivisionError when a class is absent from df.
        return numerator / denominator if denominator else 0.0

    scores = df[score_name]
    score_range = np.linspace(scores.min() - 1,
                              scores.max(),
                              n_samples,
                              endpoint=True)
    is_good = df["RMSD"] <= 2
    is_bad = df["RMSD"] > 2
    total_good_poses = int(is_good.sum())
    total_bad_poses = int(is_bad.sum())
    total_n_cmpds = int(df["Compound_ID"].nunique(dropna=False))
    total_good_cmpds = int(df.loc[is_good, "Compound_ID"].nunique(dropna=False))

    tps = []
    fps = []
    perc_cmpds = []
    perc_good_cmpds = []
    for cutoff in score_range:
        selected = scores <= cutoff
        selected_good = selected & is_good

        n_good_poses = int(selected_good.sum())
        n_bad_poses = int((selected & is_bad).sum())
        n_cmpds = int(df.loc[selected, "Compound_ID"].nunique(dropna=False))
        n_good_cmpds = int(df.loc[selected_good, "Compound_ID"].nunique(dropna=False))

        tps.append(ratio(n_good_poses, total_good_poses))
        fps.append(ratio(n_bad_poses, total_bad_poses))
        perc_cmpds.append(ratio(n_cmpds, total_n_cmpds))
        perc_good_cmpds.append(ratio(n_good_cmpds, total_good_cmpds))

    # np.trapz was renamed to np.trapezoid in NumPy 2.0; support both.
    integrate = getattr(np, "trapezoid", None)
    if integrate is None:
        integrate = np.trapz
    auc = integrate(x=fps, y=tps)

    return tps, fps, auc, perc_cmpds, perc_good_cmpds, score_range

In [None]:
def get_auc_bootstrap_error_bars(df, score_name, n_samples, n_bootstraps=1):
    """Bootstrap a 95% interval for the AUC returned by ``get_auc_from_df``.

    ``df`` is resampled with replacement ``n_bootstraps`` times and the AUC is
    recomputed for every resample.

    Args:
        df, score_name, n_samples: forwarded to ``get_auc_from_df``.
        n_bootstraps: number of resamples (at least 1).

    Returns:
        ``(mean - lower, upper - mean)`` where ``lower``/``upper`` are the
        2.5% / 97.5% order statistics of the bootstrap AUCs and ``mean`` is
        their mean.

    Raises:
        ValueError: if ``n_bootstraps`` is smaller than 1.
    """
    if n_bootstraps < 1:
        raise ValueError("n_bootstraps must be at least 1")

    auc_list = []
    for _ in range(n_bootstraps):
        bootstrap_df = df.sample(frac=1, replace=True)
        auc = get_auc_from_df(bootstrap_df, score_name, n_samples)[2]
        auc_list.append(auc)
    auc_array = np.sort(np.array(auc_list))

    bounds = math.floor(len(auc_array) * 0.025)
    mean_auc = auc_array.mean()
    # Count from the end with -(bounds + 1): a plain -bounds is index 0 (the
    # minimum) whenever bounds == 0, i.e. for fewer than 40 bootstraps.
    ci = (mean_auc - auc_array[bounds], auc_array[-(bounds + 1)] - mean_auc)
    return ci

In [None]:
get_auc_bootstrap_error_bars(df=test_filtered_df, 
                             score_name="Chemgauss4", 
                             n_samples=10, 
                             n_bootstraps=10)

In [None]:
def get_all_aucs(df, score_list, n_samples, n_bootstraps):
    """Compute ROC curves, AUCs and bootstrap error bars for several scores.

    Args:
        df: DataFrame forwarded to ``get_auc_from_df``.
        score_list: names of the score columns to evaluate.
        n_samples: number of cutoffs per curve.
        n_bootstraps: number of bootstrap resamples per score.

    Returns:
        ``(auc_data_df, model_df)``. ``auc_data_df`` stacks the per-cutoff
        curves of every score plus a diagonal ``"RANDOM"`` reference line;
        ``model_df`` has one row per score with ``Score_Type``,
        ``Lower_Bound``, ``AUC`` and ``Upper_Bound``.
    """
    auc_data_df_list = []

    auc_list = []
    lower_bound_list = []
    upper_bound_list = []
    score_name_list = []

    for score_name in score_list:
        print(score_name)
        # The cutoffs must come from this call; the previous spelling
        # (``pscore_range``) left ``score_range`` undefined in this scope.
        tps, fps, auc, perc_cmpds, perc_good_cmpds, score_range = get_auc_from_df(df, score_name, n_samples)
        auc_df = pd.DataFrame({"True_Positive": tps,
                               "False_Positive": fps,
                               "Value": score_range,
                               "Score_Type": score_name,
                               "Perc_CMPDs": perc_cmpds,
                               "Perc_Good_CMPDs": perc_good_cmpds,
                               })
        auc_data_df_list.append(auc_df)

        lower_bound, upper_bound = get_auc_bootstrap_error_bars(df, score_name, n_samples, n_bootstraps)

        auc_list.append(auc)
        lower_bound_list.append(lower_bound)
        upper_bound_list.append(upper_bound)
        score_name_list.append(score_name)

    ## Add random
    # Every curve has n_samples points, so size the diagonal from n_samples;
    # this also works when score_list is empty.
    diagonal = np.linspace(0, 1, n_samples)
    random_auc_df = pd.DataFrame({"True_Positive": diagonal,
                                  "False_Positive": diagonal,
                                  "Value": diagonal,
                                  "Score_Type": "RANDOM"})

    model_df = pd.DataFrame({
        "Score_Type": score_name_list,
        "Lower_Bound": lower_bound_list,
        "AUC": auc_list,
        "Upper_Bound": upper_bound_list
    })

    return pd.concat(auc_data_df_list + [random_auc_df]), model_df

## calculate all AUCs for test data

In [None]:
auc_data_df, model_df = get_all_aucs(test_filtered_df, 
                                     ["Chemgauss4", 
                                      "POSIT_R", 
                                      "MCSS_Rank"], 
                                     n_samples=100, 
                                     n_bootstraps=100)

## which score function does the best job at predicting good poses (RMSD <=2)?

In [None]:
fig = ex.line(auc_data_df, 
        x="False_Positive", 
        y="True_Positive", 
        color="Score_Type",
       hover_data=["Value"],
       )
fig.update_layout(height=600, width=600)
fig.update_yaxes(
    scaleanchor="x",
    scaleratio=1,
)

### Conclusion: POSIT

### I don't think this graph is very helpful:

In [None]:
fig = ex.line(auc_data_df, 
        x="False_Positive", 
        y="Perc_CMPDs", 
        color="Score_Type",
       hover_data=["Value"],
       )
fig.update_layout(height=600, width=600)
fig.update_yaxes(
    scaleanchor="x",
    scaleratio=1,
)

## Which score function does the best at returning at least 1 good pose per compound?

In [None]:
fig = ex.line(auc_data_df, 
        x="False_Positive", 
        y="Perc_Good_CMPDs", 
        color="Score_Type",
       hover_data=["Value"],
       )
fig.update_layout(height=600, width=600)
fig.update_yaxes(
    scaleanchor="x",
    scaleratio=1,
)

### Conclusion: still POSIT

## How many are we actually returning?

### POSIT > 0.25

In [None]:
good_POSIT_R = test_df[test_df.POSIT_R < 0.75]

In [None]:
sum(good_POSIT_R.RMSD <=2)

In [None]:
sum(good_POSIT_R.RMSD > 2)

In [None]:
len(good_POSIT_R)

In [None]:
sum(test_df.RMSD <= 2)

In [None]:
sum(good_POSIT_R.RMSD <=2) / sum(test_df.RMSD <= 2)

In [None]:
len(set(good_POSIT_R.Compound_ID)) / len(set(test_df.Compound_ID))

### Conclusion: with a POSIT score larger than 0.25, we retain 97.8% of the good poses (RMSD <= 2)

## Plot Precision vs Recall

In [None]:
auc_data_df

In [None]:
model_df

In [None]:
ex.bar(model_df, x="Score_Type", y="AUC", error_y="Upper_Bound", error_y_minus="Lower_Bound")

## calculate all AUCs for redocked data

In [None]:
auc_data_df, model_df = get_all_aucs(redocked_filtered_df, ["Chemgauss4", "POSIT_R"], 100, 100)

In [None]:
fig = ex.line(auc_data_df, 
        x="False_Positive", 
        y="True_Positive", 
        color="Score_Type",
       hover_data=["Value"],
       )
fig.update_layout(height=600, width=600)
fig.update_yaxes(
    scaleanchor="x",
    scaleratio=1,
)

In [None]:
model_df

In [None]:
ex.bar(model_df, 
       x="Score_Type", 
       y="AUC", 
       error_y="Upper_Bound", 
       error_y_minus="Lower_Bound")

# Testing analyzing mers data using the module code

In [None]:
import os, sys
from importlib import reload
sys.path.append('../')

In [None]:
from asapdiscovery.docking import plotting as pl

In [None]:
reload(pl)

In [None]:
fn = "/Users/alexpayne/lilac-mount-point/mers_hallucination/all_results.csv"

In [None]:
rocks = pl.Rocks(fn, ["POSIT_prob", "chemgauss4_score"], "docked_RMSD", 10)

In [None]:
import pandas as pd
df = pd.read_csv(fn)

In [None]:
df.docked_file[0]

In [None]:
df

In [None]:
df[df.docked_RMSD == -1]

# Try on sars data instead

In [None]:
import os, sys
from importlib import reload
sys.path.append('../')
from asapdiscovery.docking import plotting as pl
reload(pl)

In [None]:
reload(pl)

In [None]:
fn = "../scripts/20220818-sars-docking_v2.csv"

In [None]:
rocks = pl.Rocks(fn,
        ["POSIT_R", "Chemgauss4", "MCSS_Rank"],
        "RMSD",
        100,
        100)

In [None]:
rocks.clean_dataframe()

In [None]:
rocks.df

# initial plotting

In [None]:
ex.histogram(rocks.df, x="RMSD")

In [None]:
ex.scatter(rocks.df, 
           x="MCSS_Rank", 
           y="RMSD", 
           color="Compound_ID", 
           hover_data=["Crystal_ID", 
                       "Reference", 
                       "Chain_ID"]
          )

In [None]:
ex.scatter(rocks.df, 
           x="RMSD", 
           y="Chemgauss4", 
           color="POSIT", 
           hover_data=["Crystal_ID", 
                       "Reference", 
                       "Chain_ID"], 
          )

In [None]:
ex.scatter(rocks.df, 
           x="POSIT_R", 
           y="RMSD", 
           color="Chemgauss4", 
           hover_data=["Compound_ID",
                       "Crystal_ID", 
                       "Reference", 
                       "Chain_ID"], 
          )

In [None]:
rocks.get_aucs()

In [None]:
rocks.combine_dfs()

In [None]:
rocks.plot_poses_auc()

In [None]:
rocks.poses_df

In [None]:
rocks.combine_dfs()

In [None]:
rocks.get_auc_cis()

In [None]:
rocks.plot_precision_recall()

In [None]:
rocks.model_df

In [None]:
ex.bar(rocks.model_df, 
       x="Score_Type", 
       y="AUC", 
       error_y="Upper_Bound", 
       error_y_minus="Lower_Bound")

In [None]:
fig = rocks.plot_poses_auc()

In [None]:
fig.add_shape(x0=0, x1=1, y0=0, y1=1)