# Run inference with a pre-trained model

In [None]:
import matplotlib.pyplot as plt

import torch

In [None]:
from grace.io.image_dataset import ImageGraphDataset
from grace.models.feature_extractor import FeatureExtractor
from grace.evaluation.visualisation import plot_simple_graph
from grace.evaluation.inference import GraphLabelPredictor


### Load some real data annotated with grace:

In [None]:
# Serialised feature-extractor backbone (a ResNet-152, judging by the filename).
# NOTE(review): machine-specific absolute path — adjust for your environment.
extractor_filename = "/Users/kulicna/Desktop/classifier/extractor/resnet152.pt"
# `weights_only=False` is passed explicitly: since PyTorch 2.6 `torch.load`
# defaults to `weights_only=True`, which refuses to unpickle a whole model
# object. Unpickling can execute arbitrary code, so only load checkpoints
# from a trusted source.
pre_trained_resnet = torch.load(extractor_filename, weights_only=False)
# Wrap the backbone; this object is used as the dataset's `transform` below.
feature_extractor = FeatureExtractor(model=pre_trained_resnet)


In [None]:
# One directory holds both the images and their grace annotation files.
# To use the "infer" split instead, swap which of the two lines is commented out.
grace_path = "/Users/kulicna/Desktop/dataset/shape_stars/train"
# grace_path = "/Users/kulicna/Desktop/dataset/shape_stars/infer"

# Build the dataset with the feature extractor as its `transform`.
# The two `keep_*_unknown_labels=False` flags presumably drop nodes/edges
# whose label is unknown — confirm against ImageGraphDataset.
dataset = ImageGraphDataset(
    grace_dir=grace_path,
    image_dir=grace_path,
    keep_edge_unknown_labels=False,
    keep_node_unknown_labels=False,
    transform=feature_extractor,
)

In [None]:
# Take one sample from the dataset (index 0 is kept as a commented-out alternative).
image, graph_data = dataset[1]
# image, graph_data = dataset[0]

# Unpack the sample: annotation, graph, and the image converted to a NumPy array.
annot = graph_data["annotation"]
G = graph_data["graph"]
image = image.numpy()
# Sanity check of sizes; shown as the cell's output.
G.number_of_nodes(), G.number_of_edges(), annot.shape, image.shape

### Visualise the data:

In [None]:
# Three side-by-side panels, each `shape` inches wide and tall.
shape = 5
_, axes = plt.subplots(nrows=1, ncols=3, figsize=(3 * shape, shape))

# Panel 0: the graph plot, then the image drawn on the same axes.
plot_simple_graph(
    G,
    title=f"Graph with {G.number_of_nodes()} nodes & {G.number_of_edges()} edges",
    ax=axes[0],
)
axes[0].imshow(image, cmap="binary_r")
# Panel 1: the annotation. Panel 2: the image on its own.
axes[1].imshow(annot, cmap="binary_r")
axes[2].imshow(image, cmap="binary_r")

plt.tight_layout()
plt.show()

### Select a pre-trained GCN model:

In [None]:
# Checkpoints from earlier training runs — uncomment exactly one to switch.
# NOTE(review): machine-specific absolute paths — adjust for your environment.
# classifier_filename = "/Users/kulicna/Desktop/classifier/runs/2023-09-07_17-07-08/classifier.pt"
# classifier_filename = "/Users/kulicna/Desktop/classifier/runs/2023-09-07_17-15-47/classifier.pt"
classifier_filename = "/Users/kulicna/Desktop/classifier/runs/2023-09-07_17-30-51/classifier.pt"  # best Linear classifier
# classifier_filename = "/Users/kulicna/Desktop/classifier/runs/2023-09-08_15-11-58/classifier.pt"  # bad GCN + Linear classifier

# The checkpoint is a whole pickled model (`.eval()` is called on it below).
# Since PyTorch 2.6 `torch.load` defaults to `weights_only=True` and would
# reject it, so `weights_only=False` is passed explicitly. Unpickling can
# execute arbitrary code — only load checkpoints from a trusted source.
pre_trained_gcn = torch.load(classifier_filename, weights_only=False)
# Put the model in evaluation mode for inference (affects layers such as
# Dropout and BatchNorm).
pre_trained_gcn.eval()


### Features have already been computed automatically (the feature extractor is the dataset's `transform`) — now predict:

In [None]:
# Wrap the loaded classifier for graph-level prediction.
predictor = GraphLabelPredictor(pre_trained_gcn)
# Compute node and edge probabilities for G (presumably stored on G itself —
# confirm in GraphLabelPredictor).
predictor.set_node_and_edge_probabilities(G)
# Plot the predictions and get back the node and edge accuracies.
node_acc, edge_acc = predictor.visualise_performance(G)
print(f"Node accuracy = {node_acc:.4f} | Edge accuracy = {edge_acc:.4f}")

##### Done!