# Pipeline:

In [None]:
import json
import torch
import matplotlib.pyplot as plt


In [None]:
from scripts.reader import read_data_from_url, extract_relevant_data, shuffle_lists_together, split_graphs_dataset
from scripts.plotter import plot_distributions, plot_training_curves
from scripts.dataset import prepare_graph_dataset, load_into_dataloader
from scripts.predictor import compute_metrics_and_matrix_classification
from scripts.trainer import train_model, build_model


### Read the data:

In [None]:
# Fetch the raw data through the project's reader helper; the bare name on
# the last line makes the notebook render the returned object.
data_frame = read_data_from_url()
data_frame


### Extract the data & shuffle the lists:

In [None]:
# Pull the three parallel columns out of the frame, then shuffle them jointly
# so that entries at the same index stay aligned.
raw_names, raw_smiles, raw_pIC50 = extract_relevant_data(data_frame=data_frame)
names_data, smile_data, pIC50_data = shuffle_lists_together(raw_names, raw_smiles, raw_pIC50)


### Split the datasets into train, valid & infer:

In [None]:
# Partition the SMILES and pIC50 lists into the train / valid / infer subsets.
split_dataset = split_graphs_dataset(total_smiles_data=smile_data, total_pIC50_value=pIC50_data)

# Each entry of the mapping unpacks into a (smiles, pIC50) pair for its subset.
(
    (train_smile_data, train_pIC50_data),
    (valid_smile_data, valid_pIC50_data),
    (infer_smile_data, infer_pIC50_data),
) = (split_dataset[subset] for subset in ("train", "valid", "infer"))


### Visualise:

In [None]:
# Plot the pIC50 values of the three subsets so they can be compared visually.
pIC50_by_subset = [train_pIC50_data, valid_pIC50_data, infer_pIC50_data]
hist_fig = plot_distributions(pIC50_by_subset, figsize=(10, 3))
plt.show()
plt.close(hist_fig)  # release the figure once it has been shown

In [None]:
### Normalise (and later de-normalise) the y-label value:

# y_min, y_max = np.min(train_pIC50_data), np.max(train_pIC50_data)
# y_min, y_max

## Create pytorch geometric dataset:

In [None]:
# Load the run configuration. The context manager closes the file handle
# deterministically instead of leaving it open until garbage collection, and
# the explicit encoding avoids depending on the platform's locale default.
with open("./results/config.json", encoding="utf-8") as config_file:
    config = json.load(config_file)
config

In [None]:
# All three subsets are converted with the same problem type from the config.
problem_type = config["problem_type"]
train_dataset = prepare_graph_dataset(train_smile_data, train_pIC50_data, problem_type=problem_type)
valid_dataset = prepare_graph_dataset(valid_smile_data, valid_pIC50_data, problem_type=problem_type)
infer_dataset = prepare_graph_dataset(infer_smile_data, infer_pIC50_data, problem_type=problem_type)

# Display the size of each resulting dataset.
len(train_dataset), len(valid_dataset), len(infer_dataset)


In [None]:
# Peek at the first training sample and the dtypes of its `x` and `y` tensors.
first_sample = train_dataset[0]
first_sample, first_sample.x.dtype, first_sample.y.dtype

### Check that the label distributions are approximately identical across the splits:

In [None]:
# Collect the scalar label of every graph, one list per subset, and plot the
# three distributions side by side.
labels_by_subset = [
    [graph.y.item() for graph in dataset]
    for dataset in (train_dataset, valid_dataset, infer_dataset)
]
hist_fig = plot_distributions(labels_by_subset, figsize=(10, 3))
plt.show()
plt.close(hist_fig)  # release the figure once it has been shown

### Wrap into the DataLoader:

In [None]:
# One batch size for every loader; only the training loader is shuffled.
batch_size = config["batch_size"]
train_loader = load_into_dataloader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = load_into_dataloader(valid_dataset, batch_size=batch_size, shuffle=False)
infer_loader = load_into_dataloader(infer_dataset, batch_size=batch_size, shuffle=False)


## Build the GNN model & train:

### Define loss criterion for model training:

In [None]:
# Per-class weights for the cross-entropy loss: index 0 keeps weight 1.0 and
# index 1 takes its weight from the config.
class_weights = [1.0, config["pos_class_weight"]]  # Adjust the value as needed
weight = torch.Tensor(class_weights)
loss_fn = torch.nn.CrossEntropyLoss(weight=weight)
weight, loss_fn

In [None]:
# Gather the model / optimiser / scheduler hyper-parameters from the config,
# then build all three objects in one call.
model_kwargs = {
    "num_features": config["num_features"],
    "embedding_size": config["embedding_size"],
    "num_heads": config["num_attn_heads"],
    "dropout_prob": config["dropout_prob"],
    "use_batch_norm": config["use_batch_norm"],
    "learning_rate": config["learning_rate"],
    "weight_decay": config["weight_decay"],
    "scheduler_gamma": config["scheduler_gamma"],
}
gat_model, gat_optimizer, gat_scheduler = build_model(**model_kwargs)

In [None]:
# Train the model:

# Run the training loop; `results` is unpacked in the evaluation section.
results = train_model(
    gat_model,
    loss_fn,
    gat_optimizer,
    gat_scheduler,
    train_loader=train_loader,
    valid_loader=valid_loader,
    infer_loader=infer_loader,
    num_epochs=config["num_epochs"],
    logger_freq=config["logging_frequency"],
)


## Evaluate performance:

In [None]:
# Split the training results into the model and its loss / metric histories.
(
    pre_trained_model,
    train_losses,
    train_metric,
    valid_losses,
    valid_metric,
    infer_loss,
    infer_metr,
) = results

# Switch the model to evaluation mode before computing inference metrics.
pre_trained_model.eval()


In [None]:
# Plot the loss curves, save them to disk and display them.
loss_fig = plot_training_curves(train_losses, valid_losses, infer_loss, figsize=(5, 3))
plt.savefig("./results/training_loss.png")
plt.show()
# Close the figure created in this cell. The previous code closed `hist_fig`,
# which had already been closed, so the loss figure was left open.
# NOTE(review): assumes plot_training_curves returns the Figure, as the
# `loss_fig` name suggests — confirm against scripts.plotter.
plt.close(loss_fig)

In [None]:
# Plot the metric curves, label the y-axis, save them to disk and display them.
loss_fig = plot_training_curves(train_metric, valid_metric, infer_metr, figsize=(5, 3))
plt.ylabel("Metric [accuracy]")
plt.savefig("./results/training_metric.png")
plt.show()
# Close the figure created in this cell. The previous code closed `hist_fig`,
# which had already been closed, so the metric figure was left open.
# NOTE(review): assumes plot_training_curves returns the Figure — confirm
# against scripts.plotter.
plt.close(loss_fig)

### Evaluate metrics:

In [None]:
# def compute_metrics_and_matrix_regression(data_loader, pre_trained_model, cutoff_limit: float = 8.0):
#     y_pred_results = []
#     y_true_results = []

#     for batch in data_loader:
    
#         # Passing the node features and the connection info
#         predictions, _ = pre_trained_model(batch.x, batch.edge_index, batch.batch)
#         y_pred = predictions.squeeze()
#         y_true = batch.y.squeeze()
        
#         # Process into integer predictions:
#         y_pred = [value > cutoff_limit for value in y_pred]
#         y_true = [value > cutoff_limit for value in y_true]
        
#         # Store the result:
#         y_pred_results.extend(y_pred)
#         y_true_results.extend(y_true)

#     y_pred_results = np.stack(y_pred_results)
#     y_true_results = np.stack(y_true_results)

#     accuracy = accuracy_score(y_true=y_true_results, y_pred=y_pred_results)
#     precision, recall, f1score, support = precision_recall_fscore_support(y_true=y_true_results, y_pred=y_pred_results, pos_label=1)
#     metrics = {
#         "accuracy" : accuracy, 
#         "precision" : precision, 
#         "recall" : recall, 
#         "f1score" : f1score,
# }
#     conf_matrix = ConfusionMatrixDisplay.from_predictions(
#         y_true=y_true_results, 
#         y_pred=y_pred_results,
#         normalize="true",
#         cmap="copper"
#     )
    
#     return metrics, conf_matrix


In [None]:
# Score the trained model on the inference loader.
metrics, conf_matrix, auroc_curve, avg_pred_curve = compute_metrics_and_matrix_classification(
    data_loader=infer_loader,
    pre_trained_model=pre_trained_model,
)
print(f"Inference metrics: {metrics}")

# For each returned display: draw it, title it, save it to disk, then close
# the figure. Only the confusion matrix is drawn with a custom colour map.
diagnostic_plots = (
    (conf_matrix, {"cmap": "copper"}, "Confusion Matrix", "./results/infer_confusion_matrix.png"),
    (auroc_curve, {}, "Area Under ROC Curve", "./results/infer_area_under_curve.png"),
    (avg_pred_curve, {}, "Precision Recall Curve", "./results/infer_precision_recall_curve.png"),
)
for display, plot_kwargs, title, output_path in diagnostic_plots:
    display.plot(**plot_kwargs)
    plt.title(title)
    plt.savefig(output_path)
    plt.close()
    

### Save the model & resulting metrics as appropriate:

In [None]:
# Persist the whole trained model object.
# NOTE(review): saving the full module pickles it, so loading it later needs
# the model class to be importable; consider saving the state_dict instead.
torch.save(pre_trained_model, "./results/classifier.pt")

# Write the inference metrics next to the model as indented JSON.
with open("./results/metrics.json", "w") as outfile:
    outfile.write(json.dumps(metrics, indent=4))


### Done!