# Run inference with a pre-trained model:

In [None]:
import matplotlib.pyplot as plt
import torch

In [None]:
from grace.styling import COLORMAPS
from grace.io.image_dataset import ImageGraphDataset
from grace.models.feature_extractor import FeatureExtractor

from grace.evaluation.inference import GraphLabelPredictor
from grace.visualisation.manifold import TSNEManifoldProjection
from grace.visualisation.plotting import (
    plot_simple_graph,
    read_patch_stack_by_label, 
    montage_from_image_patches,
    overlay_from_image_patches,
)

### Read some real grace-annotated data:

In [None]:
# Size of the bounding-box patches handed to the feature extractor (also used
# as `crop_shape` when plotting node patches further down).
bbox_size = (224, 224)
extractor_filename = "/Users/kulicna/Desktop/classifier/extractor/resnet152.pt"

# The file holds a whole pickled model (it is passed directly to
# FeatureExtractor as `model`), not just a state_dict. Since PyTorch 2.6,
# `torch.load` defaults to `weights_only=True`, which refuses to unpickle full
# modules, so opt out explicitly. Only do this for trusted local files:
# unpickling can execute arbitrary code.
pre_trained_resnet = torch.load(extractor_filename, weights_only=False)
feature_extractor = FeatureExtractor(model=pre_trained_resnet, bbox_size=bbox_size)


In [None]:
# Images and their GRACE annotations are read from the same directory.
grace_path = "/Users/kulicna/Desktop/dataset/shape_stars/infer"

# Retain nodes and edges whose label is unknown instead of dropping them
# (presumably so predictions cover the whole graph — confirm in ImageGraphDataset).
keep_unknown = True

# `feature_extractor` is supplied as the per-item transform.
dataset = ImageGraphDataset(
    image_dir=grace_path,
    grace_dir=grace_path,
    transform=feature_extractor,
    keep_node_unknown_labels=keep_unknown,
    keep_edge_unknown_labels=keep_unknown,
)

In [None]:
# Take the first (image, graph data) pair from the dataset.
image, graph_data = dataset[0]

# Unpack the graph and its annotation from the item's dict.
G = graph_data["graph"]
annot = graph_data["annotation"]

# Convert the image (presumably a torch tensor) to a NumPy array for plotting.
image = image.numpy()

# Sanity check: graph size plus annotation/image shapes (displayed by the notebook).
(G.number_of_nodes(), G.number_of_edges(), annot.shape, image.shape)

### Visualise the data:

In [None]:
panel_size = 5  # width & height of each subplot panel, in inches
n_panels = 3
_, axes = plt.subplots(
    nrows=1, ncols=n_panels, figsize=(panel_size * n_panels, panel_size)
)
graph_ax, annot_ax, image_ax = axes

# Left: graph and image on the same axes; middle: annotation; right: image alone.
n_nodes, n_edges = G.number_of_nodes(), G.number_of_edges()
plot_simple_graph(
    G, title=f"Graph with {n_nodes} nodes & {n_edges} edges", ax=graph_ax
)
for ax, img in ((graph_ax, image), (annot_ax, annot), (image_ax, image)):
    ax.imshow(img, cmap=COLORMAPS["mask"])

plt.tight_layout()
plt.show()

In [None]:
# Cut patches of size `bbox_size` out of `image` for the nodes of G
# (the helper's name suggests they are grouped by label — TODO confirm).
crops = read_patch_stack_by_label(G, image=image, crop_shape=bbox_size)
# Plot the same patches two ways: as a montage and as an overlay.
montage_from_image_patches(crops)
overlay_from_image_patches(crops)

### Select a pre-trained GCN model:

In [None]:
# Candidate classifier checkpoints, keyed by training-run timestamp; the
# trailing notes are the author's assessment of each run.
# NOTE(review): the two 2023-09-26 labels repeat the 2023-09-07/08 labels
# verbatim — confirm they really describe those runs.
candidate_classifiers = {
    "2023-09-08_15-11-58": "/Users/kulicna/Desktop/classifier/runs/2023-09-08_15-11-58/classifier.pt",  # poor GCN + Linear classifier
    "2023-09-07_17-30-51": "/Users/kulicna/Desktop/classifier/runs/2023-09-07_17-30-51/classifier.pt",  # best Linear classifier
    "2023-09-26_18-50-07": "/Users/kulicna/Desktop/classifier/runs/2023-09-26_18-50-07/classifier.pt",  # poor GCN + Linear classifier
    "2023-09-26_19-00-06": "/Users/kulicna/Desktop/classifier/runs/2023-09-26_19-00-06/classifier.pt",  # best Linear classifier
}

# Change the key to evaluate a different run.
classifier_filename = candidate_classifiers["2023-09-26_18-50-07"]
classifier_filename


### Compare t-SNE projections before & after the GCN:

In [None]:
# Project the graph's features with t-SNE, before and after passing them
# through the GCN of the selected classifier checkpoint, and plot both.
tsne_projection = TSNEManifoldProjection(graph=G, model=classifier_filename)
tsne_projection.plot_TSNE_before_and_after_GCN()

# Render the figure, then release it so later cells start from a clean state.
plt.show()
plt.close()

### Show how well the classifier performs:

In [None]:
# Build a predictor from the selected classifier checkpoint; later cells reuse `GLP`.
GLP = GraphLabelPredictor(model=classifier_filename)

# Attach predicted node & edge probabilities to this single graph. The return
# value is not used, so G is presumably modified in place — confirm in GraphLabelPredictor.
GLP.set_node_and_edge_probabilities(G=G)

# Draw the predicted probabilities on the graph, show the figure, then close it.
graph_plots = GLP.visualise_prediction_probs_on_graph(G=G)
plt.show()
plt.close()


In [None]:
# Compute numerical results for a batch of inference targets (graph_data dicts).
# NOTE: despite the batch API, this list currently holds only the first item
# loaded from the dataset (`dataset[0]`), not the entire inference dataset.
# Append further `graph_data` dicts to evaluate more images.
infer_target_list = [graph_data]

predicted_results = GLP.calculate_numerical_results_on_entire_batch(infer_target_list)
predicted_results

In [None]:
# Display the batch-level performance figures inline.
# To keep the figures, pass `save_figures=<directory path>` instead
# (presumably writes them into that directory — confirm in GraphLabelPredictor).
GLP.visualise_model_performance_on_entire_batch(
    infer_target_list,
    show_figures=True,
)

##### Done!