# Run inference on a pre-trained model:

In [None]:
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

import torch
import numpy.typing as npt
from torch_geometric.data import Data
from tqdm.auto import tqdm


In [None]:
from grace.base import GraphAttrs
from grace.io.image_dataset import ImageGraphDataset
from grace.models.feature_extractor import FeatureExtractor
from grace.models.datasets import dataset_from_graph
from grace.evaluation.visualisation import plot_simple_graph
from grace.evaluation.inference import GraphLabelPredictor
from grace.evaluation.utils import plot_confusion_matrix_tiles

from grace.evaluation.metrics_classifier import (
    accuracy_metric, 
    confusion_matrix_metric, 
    areas_under_curves_metrics,
)
# from sklearn.metrics import ConfusionMatrixDisplay

### Read some real grace-annotated data:

In [None]:
# Absolute path on the author's machine - point it at your own saved extractor.
extractor_filename = "/Users/kulicna/Desktop/classifier/extractor/resnet152.pt"
# NOTE: torch.load unpickles the file, which can execute arbitrary code -
# only load checkpoints that come from a trusted source.
pre_trained_resnet = torch.load(extractor_filename)
# Wrap the loaded model; the wrapper is later passed to the dataset as `transform`.
feature_extractor = FeatureExtractor(model=pre_trained_resnet)


In [None]:
# Alternative directory (training split), kept for quick switching:
# grace_path = "/Users/kulicna/Desktop/dataset/shape_stars/train"
grace_path = "/Users/kulicna/Desktop/dataset/shape_stars/infer"
# The same directory is used for both the images and the grace annotation files.
# The feature extractor is applied as the dataset's `transform`.
# NOTE(review): the two `keep_*_unknown_labels=False` flags presumably drop
# nodes / edges without a known label - confirm against ImageGraphDataset.
dataset = ImageGraphDataset(
    image_dir=grace_path, 
    grace_dir=grace_path, 
    transform=feature_extractor,
    keep_node_unknown_labels=False, 
    keep_edge_unknown_labels=False, 
    
)

In [None]:
# Pick a single (image, graph data) sample from the dataset; index 1 is an alternative.
# image, graph_data = dataset[1]
image, graph_data = dataset[0]

# The graph and its annotation are stored under the "graph" / "annotation" keys.
G = graph_data["graph"]
# Convert the image to a NumPy array so matplotlib can display it below.
image = image.numpy()
annot = graph_data["annotation"]
# Cell output: node count, edge count and the annotation / image shapes.
G.number_of_nodes(), G.number_of_edges(), annot.shape, image.shape 

### Visualise the data:

In [None]:
# Side length (in inches) of one panel; the figure is one row of three panels.
shape = 5
_, axes = plt.subplots(nrows=1, ncols=3, figsize=(shape*3, shape*1))

# Panel 0: graph and image drawn on the same axes.
plot_simple_graph(G, title=f"Graph with {G.number_of_nodes()} nodes & {G.number_of_edges()} edges", ax=axes[0])
axes[0].imshow(image, cmap="binary_r")
# Panel 1: annotation only; panel 2: image only.
axes[1].imshow(annot, cmap="binary_r")
axes[2].imshow(image, cmap="binary_r")

plt.tight_layout()
plt.show()

### Nominate a pre-trained GCN model:

In [None]:
# Checkpoints from different training runs; exactly one line stays uncommented.
# classifier_filename = "/Users/kulicna/Desktop/classifier/runs/2023-09-07_17-07-08/classifier.pt"
# classifier_filename = "/Users/kulicna/Desktop/classifier/runs/2023-09-07_17-15-47/classifier.pt"
classifier_filename = "/Users/kulicna/Desktop/classifier/runs/2023-09-07_17-30-51/classifier.pt"  # best Linear classifier
# classifier_filename = "/Users/kulicna/Desktop/classifier/runs/2023-09-08_15-11-58/classifier.pt"  # bad GCN + Linear classifier

# NOTE: torch.load unpickles the file, which can execute arbitrary code -
# only load checkpoints that come from a trusted source.
pre_trained_gcn = torch.load(classifier_filename)
# Switch the classifier to evaluation mode before running inference.
pre_trained_gcn.eval()


### Features are now automatically appended to the image - predict:

In [None]:
# Run the pre-trained classifier over the graph G.
# NOTE(review): presumably this writes the NODE_PREDICTION / EDGE_PREDICTION
# attributes onto G in place (they are read from G in the metrics cell) - confirm.
GraphLabelPredictor(pre_trained_gcn).set_node_and_edge_probabilities(G)

In [None]:
# Display the attributes stored on node 0 as a quick sanity check.
G.nodes[0]

### Now compute metrics:

In [None]:
# Walk the node and edge views once and keep only the attribute dictionaries.
node_attributes = [attributes for _, attributes in G.nodes(data=True)]
edge_attributes = [attributes for _, _, attributes in G.edges(data=True)]

# Nodes: ground truth, first entry of the stored prediction, and its second
# entry collected into an array.
node_true = [attributes[GraphAttrs.NODE_GROUND_TRUTH] for attributes in node_attributes]
node_pred = [attributes[GraphAttrs.NODE_PREDICTION][0] for attributes in node_attributes]
node_probabs = np.array(
    [attributes[GraphAttrs.NODE_PREDICTION][1] for attributes in node_attributes]
)

# Edges: same three collections, read from the edge attributes.
edge_true = [attributes[GraphAttrs.EDGE_GROUND_TRUTH] for attributes in edge_attributes]
edge_pred = [attributes[GraphAttrs.EDGE_PREDICTION][0] for attributes in edge_attributes]
edge_probabs = np.array(
    [attributes[GraphAttrs.EDGE_PREDICTION][1] for attributes in edge_attributes]
)


In [None]:
# Accuracy of the node and edge predictions read from the graph attributes.
node_acc, edge_acc = accuracy_metric(node_pred, edge_pred, node_true, edge_true)
# Cell output: the (node, edge) accuracy pair.
node_acc, edge_acc

In [None]:
# Area-under-curve plots; note that probabilities (not hard labels) are passed here.
areas_fig = areas_under_curves_metrics(node_probabs, edge_probabs, node_true, edge_true, figsize=(10, 4))

In [None]:


# figsize = (12, 12)
# colormap = "copper"
# confusion_matrix_plotting_data = [
#     [node_pred, node_true, "nodes"], 
#     [edge_pred, edge_true, "edges"],
# ]

# _, axs = plt.subplots(2, 2, figsize=figsize)

# for d, matrix_data in enumerate(confusion_matrix_plotting_data):
#     for n, nrm in enumerate([None, "true"]):
#         ConfusionMatrixDisplay.from_predictions(
#             y_pred=matrix_data[0],
#             y_true=matrix_data[1],
#             normalize=nrm,
#             ax=axs[d, n],
#             cmap=colormap,
#             display_labels=["TN", "TP"],
#             text_kw={"fontsize": "large"},
#         )

#         flag = "Raw Counts" if nrm is None else "Normalised"
#         text = f"{matrix_data[2].capitalize()} | {flag} Values"
#         axs[d, n].set_title(text)

# plt.show()


In [None]:
# Convert the graph into the data object(s) consumed by the classifier.
# NOTE(review): with mode="whole" the result is later iterated as a sequence of
# batches - presumably a list holding the entire graph as one item; confirm.
data_whole_graph = dataset_from_graph(G, mode="whole")
data_whole_graph

### Chop off the last Linear layers:

In [None]:
def drop_linear_layers_from_model(model: torch.nn.Module) -> torch.nn.Sequential:
    """Return the classifier without its last two child modules, frozen.

    Chops off the last 2 (Linear) layers of the classifier so the node
    embeddings learnt by the remaining layers can be accessed.

    Parameters
    ----------
    model : torch.nn.Module
        Classifier whose direct children end with the two layers to drop.

    Returns
    -------
    torch.nn.Sequential
        The remaining direct children of ``model``, chained in registration
        order, with ``requires_grad`` disabled on all of their parameters.

    Notes
    -----
    The child modules are shared with ``model``, not copied, so freezing
    the parameters here also freezes them inside ``model``.
    """
    # Read the layers from the `model` argument rather than from a notebook
    # global, so the function works for whichever classifier is passed in.
    kept_children = list(model.children())[:-2]
    node_emb_extractor = torch.nn.Sequential(*kept_children)

    # The extractor is only used for inference, so no gradients are needed.
    for parameter in node_emb_extractor.parameters():
        parameter.requires_grad = False

    return node_emb_extractor


In [None]:
# Build the embedding extractor from the loaded classifier and display its layers.
node_emb_extractor = drop_linear_layers_from_model(model=pre_trained_gcn)
node_emb_extractor

### Get the predictions:

In [None]:
def get_predictions_for_data_batches(
    model: torch.nn.Module,
    data_batches: list[Data],
) -> tuple[torch.Tensor, ...]:
    """Run the classifier over every batch and gather predictions and labels.

    Parameters
    ----------
    model : torch.nn.Module
        Classifier exposing ``predict(x=..., edge_index=...)`` which returns
        a ``(node_output, edge_output)`` pair with the class axis at dim 1.
    data_batches : list[Data]
        Batches providing ``x``, ``edge_index``, ``y`` (node labels) and
        ``edge_label`` (edge labels).
        NOTE(review): ``y`` and ``edge_label`` are assumed to be tensors
        indexed by node / edge along dim 0 - confirm against the dataset.

    Returns
    -------
    tuple[torch.Tensor, ...]
        Six tensors, each concatenated over all batches along dim 0, in this
        order: node softmax probabilities, edge softmax probabilities,
        node argmax classes, edge argmax classes, node labels, edge labels.

    Raises
    ------
    RuntimeError
        If ``data_batches`` is empty (there is nothing to concatenate).
    """
    node_softmax_preds = []
    edge_softmax_preds = []
    node_argmax_preds = []
    edge_argmax_preds = []
    node_labels = []
    edge_labels = []

    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        for data in tqdm(data_batches, desc="Predicting for the entire graph: "):
            # Ground truth labels of this batch:
            node_labels.append(data.y)
            edge_labels.append(data.edge_label)

            # Raw model outputs for the nodes and edges of this batch:
            node_x, edge_x = model.predict(x=data.x, edge_index=data.edge_index)

            # Normalise over the class axis, then take the most likely class:
            node_soft = node_x.softmax(dim=1)
            node_softmax_preds.append(node_soft)
            node_argmax_preds.append(node_soft.argmax(dim=1).long())

            edge_soft = edge_x.softmax(dim=1)
            edge_softmax_preds.append(edge_soft)
            edge_argmax_preds.append(edge_soft.argmax(dim=1).long())

    # Join whole batches at once instead of unrolling them row by row.
    return (
        torch.cat(node_softmax_preds, dim=0),
        torch.cat(edge_softmax_preds, dim=0),
        torch.cat(node_argmax_preds, dim=0),
        torch.cat(edge_argmax_preds, dim=0),
        torch.cat(node_labels, dim=0),
        torch.cat(edge_labels, dim=0),
    )


In [None]:
# Predict directly with the classifier on the whole-graph data.
predicted_results = get_predictions_for_data_batches(model=pre_trained_gcn, data_batches=data_whole_graph)
# NOTE: this rebinds node_probabs, edge_probabs, node_pred, edge_pred, node_true
# and edge_true, replacing the values read from the graph attributes earlier.
node_probabs, edge_probabs, node_pred, edge_pred, node_true, edge_true = predicted_results


In [None]:
# Quick look: node labels and node softmax probabilities on the same axes.
plt.plot(node_true)
plt.plot(node_probabs)
plt.show()

In [None]:
# Quick look: edge labels and edge softmax probabilities on the same axes.
plt.plot(edge_true)
plt.plot(edge_probabs)
plt.show()

In [None]:
# # Investigate

# plt.scatter(x=node_pred, y=node_probabs[:, 0], color='firebrick', label='TN')
# plt.scatter(x=node_pred, y=node_probabs[:, 1], color='limegreen', label='TP')
# plt.title("Nodes")
# plt.legend()
# plt.show()

# plt.scatter(x=range(node_probabs.shape[0]), y=node_probabs[:, 0], color='firebrick', label='TN')
# plt.scatter(x=range(node_probabs.shape[0]), y=node_probabs[:, 1], color='limegreen', label='TP')
# plt.title("Nodes")
# plt.legend()
# plt.show()


# plt.scatter(x=edge_pred, y=edge_probabs[:, 0], color='firebrick', label='TN')
# plt.scatter(x=edge_pred, y=edge_probabs[:, 1], color='limegreen', label='TP')
# plt.title("Edges")
# plt.legend()
# plt.show()

# plt.scatter(x=range(edge_probabs.shape[0]), y=edge_probabs[:, 0], color='firebrick', label='TN')
# plt.scatter(x=range(edge_probabs.shape[0]), y=edge_probabs[:, 1], color='limegreen', label='TP')
# plt.title("Edges")
# plt.legend()
# plt.show()


## Evaluation:
### Simple metrics first:

In [None]:
from grace.evaluation.metrics_classifier import (
    accuracy_metric, 
    confusion_matrix_metric, 
    areas_under_curves_metrics,
)
from sklearn.metrics import ConfusionMatrixDisplay

In [None]:
# Accuracy of the argmax predictions returned by get_predictions_for_data_batches.
node_acc, edge_acc = accuracy_metric(node_pred, edge_pred, node_true, edge_true)
# Cell output: the (node, edge) accuracy pair.
node_acc, edge_acc

In [None]:
# Node and edge confusion matrices, requested with normalize="true".
# NOTE(review): presumably forwarded to scikit-learn, where "true" normalises
# over the true labels - confirm against confusion_matrix_metric.
n_cm, e_cm = confusion_matrix_metric(node_pred, edge_pred, node_true, edge_true, normalize="true")

In [None]:
# Same matrices without normalisation; this overwrites n_cm / e_cm from the cell above.
n_cm, e_cm = confusion_matrix_metric(node_pred, edge_pred, node_true, edge_true, normalize=None)

In [None]:
# NOTE(review): hard argmax predictions are passed here, whereas the earlier call
# to this helper passed probabilities (node_probabs / edge_probabs). If the helper
# expects scores, the curves drawn from hard labels will be degenerate - confirm
# which input is intended.
areas_fig = areas_under_curves_metrics(node_pred, edge_pred, node_true, edge_true)

### Possibly, display all 4 confusion matrices:

In [None]:
# Shared figure settings for the 2 x 2 grid of confusion matrices.
figsize = (12, 12)
colormap = "copper"

# One entry per grid row: [predicted labels, true labels, row name].
confusion_matrix_plotting_data = [
    [node_pred, node_true, "nodes"],
    [edge_pred, edge_true, "edges"],
]

_, axs = plt.subplots(2, 2, figsize=figsize)

# Rows follow the graph element (nodes, edges); columns follow the
# normalisation mode (raw counts first, then normalised values).
for (predicted, actual, element), axes_row in zip(confusion_matrix_plotting_data, axs):
    for normalisation, axis in zip((None, "true"), axes_row):
        ConfusionMatrixDisplay.from_predictions(
            y_true=actual,
            y_pred=predicted,
            normalize=normalisation,
            ax=axis,
            cmap=colormap,
            display_labels=["TN", "TP"],
            text_kw={"fontsize": "large"},
        )

        if normalisation is None:
            flag = "Raw Counts"
        else:
            flag = "Normalised"
        axis.set_title(f"{element.capitalize()} | {flag} Values")

plt.show()


##### Done!