# Magics

In [None]:
%load_ext autoreload

%autoreload 2

# Imports

In [None]:
import os
import gc
import configparser
import pathlib as p
import numpy as np
import seaborn as sns

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt 

In [None]:
import pandas as pd
pd.options.display.max_columns = 999

In [None]:
import torch

In [None]:
import clipper_python as clipper

In [None]:
import torch.nn as nn
import torch.optim as optim

In [None]:
from frag_nn.pytorch.network import GNINA_regressor, GNINA_regressor_v2, GNINA_regressor_v3, GNINA_regressor_v4, GNINA_regressor_v5, GNINA_regressor_v6, GNINA_regressor_v7, GNINA_regressor_v8
# from frag_nn.data import XChemData
from frag_nn.pytorch.network import ClassifierV3, ClassifierV4, ClassifierV5
from frag_nn.pytorch.dataset import EventDataset
from frag_nn.pytorch.dataset import OrthogonalGrid
from frag_nn.pytorch.dataset import GetRandomisedLocation, GetRandomisedRotation, SetRoot
from frag_nn.pytorch.dataset import GetAnnotationClassifier, GetDataRefMoveZ

from frag_nn.pytorch.dataset import XChemDataset
import frag_nn.constants as c


# Get Config

In [None]:
config_path = "/home/zoh22914/pandda_nn_2/frag_nn/params.ini"

In [None]:
conf = configparser.ConfigParser()

In [None]:
conf.read(config_path)

In [None]:
ds_conf = conf[c.x_chem_database]

In [None]:
grid_size = 48
grid_step = 0.5
filters = 64


In [None]:
network_type = "classifier"
network_version = 5
dataset_version = 3
train = "gpu"
transforms = "rottrans"

In [None]:
# Checkpoint location. The file name encodes every setting that identifies
# the run, in this order: grid size, grid step, network type, network
# version, dataset version, training device, transforms, filter count.
state_dict_dir = "/home/zoh22914/pandda_nn_2/"
state_dict_file = "".join((
    state_dict_dir,
    "model_params_{}_{}_{}_{}_{}_{}_{}_{}.pt".format(
        grid_size,
        grid_step,
        network_type,
        network_version,
        dataset_version,
        train,
        transforms,
        filters,
    ),
))

# Get accessible events

In [None]:
events_test = pd.read_csv("new_events_test_no_cheat.csv")

In [None]:
events_test

# Create Dataset

In [None]:
grid = OrthogonalGrid(grid_size, 
                     grid_step)

In [None]:
# Build the evaluation dataset from the test events table.
# The record transforms are applied in the listed order: a randomised location
# (base_trans_max=4.0, secondary_trans_max=0.0), a randomised rotation with
# max_rot=0.0, and a root set to "/data/data".
# NOTE(review): presumably max_rot=0.0 disables rotation while the translation
# jitter stays active at evaluation time — confirm against frag_nn.pytorch.dataset.
test_dataset = EventDataset(events=events_test,
                             transforms_record=[GetRandomisedLocation(base_trans_max=4.0, secondary_trans_max=0.0),
                                                    GetRandomisedRotation(max_rot=0.0),
                                                    SetRoot("/data/data")
                                               ],
                             get_annotation=GetAnnotationClassifier(),
                             get_data=GetDataRefMoveZ(grid)
                            )

# Create Dataloaders

In [None]:
# Evaluation loader: one event per batch, served by 48 worker processes.
# Shuffling is disabled on purpose: the predictions collected from this loader
# are later written into ``events_test`` positionally, so batches must arrive
# in dataset order or every score ends up on the wrong row.
# NOTE(review): assumes EventDataset indexes events by row position — confirm.
test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=48)

# Define Model

In [None]:
# Instantiate the classifier from the configured `filters` value and grid size.
model = ClassifierV5(filters,
                        grid_dimension=grid_size)

In [None]:
# Restore the trained weights from the checkpoint path built in `state_dict_file`.
model.load_state_dict(torch.load(state_dict_file))

In [None]:
# Place the network on CUDA device index 1.
# NOTE(review): if ClassifierV5 is a torch.nn.Module, .to() moves it in place,
# so `model` and `model_c` would refer to the same object — confirm.
model_c = model.to("cuda:1")

# Precision - Recall functions

In [None]:
def get_precision(y_hat, y, cutoff):
    """Return the precision of thresholded scores against binary labels.

    A sample counts as predicted positive when its score is strictly greater
    than ``cutoff``. Only labels equal to 1 (positive) or 0 (negative) take
    part in the counts; samples with any other label value are ignored.

    Args:
        y_hat: Scores; must support ``>`` against a scalar and boolean-mask
            indexing (e.g. a NumPy array).
        y: Labels with the same shape as ``y_hat``.
        cutoff: Decision threshold applied to ``y_hat``.

    Returns:
        ``TP / (TP + FP)``, or ``1`` when nothing is predicted positive.
    """
    predicted_positive = (y_hat > cutoff)

    positives_mask = (y == 1)
    negatives_mask = (y == 0)

    true_positives = np.count_nonzero(predicted_positive[positives_mask])
    false_positives = np.count_nonzero(predicted_positive[negatives_mask])

    total_predicted_positives = true_positives + false_positives

    if total_predicted_positives == 0:
        # No positive predictions means no false positives: report the
        # conventional precision of 1 instead of dividing by zero.
        return 1

    return true_positives / total_predicted_positives


In [None]:
def get_recall(y_hat, y, cutoff):
    """Return the recall of thresholded scores against binary labels.

    A sample counts as predicted positive when its score is strictly greater
    than ``cutoff`` and as predicted negative when it is less than or equal
    to ``cutoff``. Only samples whose label equals 1 contribute.

    Args:
        y_hat: Scores; must support ``>`` / ``<=`` against a scalar and
            boolean-mask indexing (e.g. a NumPy array).
        y: Labels with the same shape as ``y_hat``.
        cutoff: Decision threshold applied to ``y_hat``.

    Returns:
        ``TP / (TP + FN)``, or ``0`` when there are no positive samples.
    """
    predicted_positive = (y_hat > cutoff)
    predicted_negative = (y_hat <= cutoff)

    positives_mask = (y == 1)

    true_positives = np.count_nonzero(predicted_positive[positives_mask])
    false_negatives = np.count_nonzero(predicted_negative[positives_mask])

    total_positives = true_positives + false_negatives

    if total_positives == 0:
        # Nothing to recall: avoid dividing by zero.
        return 0

    return true_positives / total_positives

# Evaluate - Test Data

In [None]:
y_test = []
y_test_hat = []

In [None]:
# Run the trained classifier over the test loader and collect the labels and
# the network outputs on the CPU.
model_c.eval()  # evaluation behaviour for layers such as dropout / batch-norm, if the network has any
with torch.no_grad():  # inference only: do not record an autograd graph
    for i, data in enumerate(test_dataloader):
        # Each batch is a mapping: input under "data", label under "annotation".
        print("Iteration: {}".format(i))
        x = data["data"]
        y = data["annotation"].view(-1, 2)  # one 2-wide label row per sample

        x_c = x.to("cuda:1")

        outputs_c = model_c(x_c)

        outputs = outputs_c.detach().to("cpu")
        print(outputs)
        y_test.append(y.detach())
        y_test_hat.append(outputs)


In [None]:
y_test

In [None]:
y_test_hat

In [None]:
# float(outputs[0])

In [None]:
len(y_test)

In [None]:
len(y_test_hat)

In [None]:
# Join the per-batch tensors into one tensor each and keep column 1 only.
# NOTE(review): column 1 is presumably the positive-class entry of the 2-wide
# label / output rows — confirm against GetAnnotationClassifier.
y = torch.cat(y_test)[:,1]
y_hat = torch.cat(y_test_hat)[:,1]

# Define Rankings

In [None]:
def get_ranking(df):
    """Count confirmed hits within each leading prefix of a ranked event table.

    For every position ``i`` from ``0`` to ``len(df) - 1`` the result holds
    the number of rows among the first ``i`` rows of ``df`` whose
    ``"Ligand Confidence"`` is ``"High"`` or ``"Medium"``. Rows are taken in
    their current order, so ``df`` should already be sorted by the score
    under evaluation.

    Args:
        df: DataFrame with a ``"Ligand Confidence"`` column.

    Returns:
        DataFrame with one row per position and the columns ``"length"``
        (prefix length ``i``) and ``"recall"`` (hit count in that prefix; a
        raw count, not a fraction). Both columns are present even when
        ``df`` is empty.
    """
    # NOTE(review): prefix lengths stop at len(df) - 1, so the complete table
    # is never evaluated — confirm whether that is intended.
    is_hit = df["Ligand Confidence"].isin(["High", "Medium"]).astype("int64")

    # Exclusive running total: hits among the rows strictly before row i.
    # One pass replaces re-slicing and re-filtering the frame for every i.
    hits_before = (is_hit.cumsum() - is_hit).to_numpy()

    return pd.DataFrame({"length": range(len(df)),
                         "recall": hits_before})

# Append to table

In [None]:
# Attach the network score to each event as a plain NumPy column; pandas
# should not be left to interpret a torch tensor on its own.
# NOTE(review): the assignment is positional — it assumes the scores were
# collected in the same row order as ``events_test``; verify the loader order.
events_test["nn_score"] = y_hat.detach().cpu().numpy()
events_test

# Pull size sorted

In [None]:
size_df = events_test[["Ligand Confidence", "cluster_size"]]

In [None]:
size_df

In [None]:
sorted_size_df = size_df.sort_values("cluster_size", ascending=False)

In [None]:
sorted_size_df

In [None]:
size_rankings_df = get_ranking(sorted_size_df)

In [None]:
size_rankings_df

# Pull NN sorted

In [None]:
nn_score_df = events_test[["Ligand Confidence", "nn_score"]]

In [None]:
nn_score_df

In [None]:
sorted_nn_score_df = nn_score_df.sort_values("nn_score", ascending=False)

In [None]:
sorted_nn_score_df

In [None]:
nn_score_rankings_df = get_ranking(sorted_nn_score_df)

In [None]:
nn_score_rankings_df

In [None]:
perfect_df = events_test[["Ligand Confidence"]]

In [None]:
# Rank key for the ideal ordering: High (0) before Medium (1) before Low (2).
# Work on an explicit copy and write through .loc: chained assignment
# (``frame["score"][mask] = value``) on a column selection triggers
# SettingWithCopyWarning and has no effect under pandas copy-on-write.
perfect_df = perfect_df.copy()
perfect_df["score"] = events_test["Ligand Confidence"]
for confidence_label, rank in (("High", 0), ("Medium", 1), ("Low", 2)):
    perfect_df.loc[perfect_df["score"] == confidence_label, "score"] = rank


In [None]:
perfect_df

In [None]:
sorted_perfect_df = perfect_df.sort_values("score")

In [None]:
perfect_rankings_df = get_ranking(sorted_perfect_df)

In [None]:
# perfect_rankings_df["recall"] = list(range(len(perfect_rankings_df)))

In [None]:
perfect_rankings_df

In [None]:
random_df = events_test[["Ligand Confidence"]]

In [None]:
sorted_random_df = random_df.sample(len(random_df))

In [None]:
random_rankings_df = get_ranking(sorted_random_df)

In [None]:
# Stack the recall curves of the four rankings into one long frame. The outer
# index level "score" names the ranking method and the inner level "len" is
# the row position taken from each ranking frame.
recall_length_df = pd.concat(
    [
        ranking[["recall"]]
        for ranking in (
            nn_score_rankings_df,
            size_rankings_df,
            random_rankings_df,
            perfect_rankings_df,
        )
    ],
    keys=["nn_score", "cluster_size", "random", "perfect"],
    names=["score", "len"],
)

In [None]:
recall_length_df

In [None]:
recall_length_df = recall_length_df.reset_index()

In [None]:
recall_length_df

In [None]:
# recall_length_df.stack()

In [None]:
# recall_length_df.stack().reset_index()

# Plot

In [None]:
sns.lineplot(x="length",
               y="recall",
               data=size_rankings_df)

In [None]:
sns.lineplot(x="length",
               y="recall",
               data=nn_score_rankings_df)

In [None]:
# One large figure overlaying the recall-vs-length curve of every ranking
# method, coloured by the "score" column.
fig, ax = plt.subplots(figsize=(20, 20))

sns.lineplot(x="len",
             y="recall",
             data=recall_length_df,
             hue="score",
             ax=ax)

In [None]:
y.shape

In [None]:
y_hat.shape

In [None]:
# Sweep 100 evenly spaced cutoffs over [0, 1]; for each one record precision,
# recall, how many events score above the cutoff, and how many of those are
# labelled 1.
points = []
for cutoff in np.linspace(0, 1, num=100):
    precision = get_precision(y_hat, y, cutoff)
    recall = get_recall(y_hat, y, cutoff)

    above_cutoff = y_hat > cutoff
    labels_above_cutoff = y[above_cutoff]

    points.append({
        "cutoff": cutoff,
        "precision": precision,
        "recall": recall,
        "num_predicted_positives": len(y_hat[above_cutoff]),
        "num_true_positives": len(labels_above_cutoff[labels_above_cutoff == 1]),
    })

In [None]:
stats = pd.DataFrame(points).set_index("cutoff")

In [None]:
stats

In [None]:
# Precision-recall curve for the test set. estimator=None draws every cutoff
# as its own point; the default estimator would average the precision of all
# cutoffs that share a recall value and add a confidence band.
sns.lineplot(x="recall",
             y="precision",
             data=stats,
             estimator=None)

In [None]:
# Same precision-recall points as individual markers. scatterplot never
# aggregates, and current seaborn no longer accepts an ``estimator`` argument
# here, so none is passed.
sns.scatterplot(x="recall",
                y="precision",
                data=stats)

In [None]:
stats.iloc[0].num_true_positives / stats.iloc[0].num_predicted_positives

In [None]:
# Fraction of test events annotated "High" — presumably the precision baseline
# obtained by calling every event a hit (confirm the intended label set).
# The denominator is the test table itself; ``events_train`` is not defined in
# this notebook.
base_precision = len(events_test[events_test["Ligand Confidence"] == "High"]) / len(events_test)
base_precision

In [None]:
# Presumably a deliberate stop so that "Run All" halts here: the cells below
# use names (e.g. train_dataloader, events_train) that are not defined in the
# visible part of this notebook.
raise Exception

# Evaluate - Train Data

In [None]:
y_test = []
y_test_hat = []

In [None]:
# Collect labels and model outputs over the training loader.
# NOTE(review): `train_dataloader` is not defined in the visible part of this
# notebook, and `model` is called on `x` without any device transfer even
# though the model was moved to "cuda:1" earlier — this cell looks stale; confirm.
for i, data in enumerate(train_dataloader):
    # each batch is indexed by key: inputs under "x", labels under "y"
    print("Iteration: {}".format(i))
    x = data["x"]
    y = data["y"]
#     x = x.unsqueeze(1)
    y = y.view(-1,1)
    
    outputs = model(x)
    
    y_test.append(y.detach())
    y_test_hat.append(outputs.detach())
#     optimizer.zero_grad()
    # force a garbage-collection pass after every batch
    gc.collect()

In [None]:
y = torch.cat(y_test)
y_hat = torch.cat(y_test_hat)

In [None]:
# Sweep 50 evenly spaced cutoffs over [0, 1] and record the precision and
# recall obtained at each one.
points = []
for cutoff in np.linspace(0, 1, num=50):
    precision, recall = get_precision(y_hat, y, cutoff), get_recall(y_hat, y, cutoff)
    points.append(dict(cutoff=cutoff, precision=precision, recall=recall))

In [None]:
stats = pd.DataFrame(points).set_index("cutoff")

In [None]:
stats

In [None]:
sns.lineplot(x="recall",
               y="precision",
               data=stats,
            estimator=None)

In [None]:
# Precision-recall points as individual markers. scatterplot never aggregates,
# and current seaborn no longer accepts an ``estimator`` argument here, so
# none is passed.
sns.scatterplot(x="recall",
                y="precision",
                data=stats)

In [None]:
# Fraction of training events whose "ligand_confidence_inspect" is "High".
# NOTE(review): `events_train` is not defined in the visible part of this
# notebook — this cell presumably relies on state from another session; confirm.
base_precision = len(events_train[events_train["ligand_confidence_inspect"] == "High"]) / len(events_train)
base_precision

In [None]:
import time  # not imported anywhere else in this notebook

# Block forever, waking once a minute; the cell has to be interrupted by hand.
# NOTE(review): presumably a keep-alive for the kernel/session — confirm it is
# still wanted.
while True:
    time.sleep(60)

In [None]:
torch.save(model.state_dict(), "model_params.pt")

# Inspect Model

In [None]:
trace.history

In [None]:
iterator = dataset.make_one_shot_iterator()

In [None]:
x, y = iterator.next()

In [None]:
x

In [None]:
model.net(x)

In [None]:
print(y)

In [None]:
model.save_weights("model_32.h5")

In [None]:
# PyTorch-Ignite is distributed as the top-level ``ignite`` package; there is
# no ``torch.ignite`` module.
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator

In [None]:
trainer = create_supervised_trainer(model, optimizer, loss)

In [None]:
@trainer.on(Events.EPOCH_COMPLETED)
def print