# Visualise GRACE data:

In [None]:
import matplotlib.pyplot as plt
import torch

from grace.models.datasets import dataset_from_graph
from grace.models.feature_extractor import FeatureExtractor

from grace.io.image_dataset import ImageGraphDataset
from grace.evaluation.process import generate_ground_truth_graph
from grace.visualisation.subgraph import (
    plot_subgraph_geometry, 
    plot_local_node_geometry
)
from grace.visualisation.plotting import (
    display_image_and_grace_annotation, 
    read_patch_stack_by_label, 
    montage_from_image_patches, 
    overlay_from_image_patches,
    plot_simple_graph,
    plot_connected_components
)

### Visualise the overlay of the annotated graph on the image annotation mask:

In [None]:
# Patch size handed to the feature extractor; reused further down as the crop shape.
bbox_size = (224, 224)

# Load the saved model (presumably a ResNet-152, going by the filename) and
# wrap it in the FeatureExtractor that is later used as the dataset transform.
# NOTE(review): torch.load unpickles the file - only point it at a trusted checkpoint.
extractor_filename = "/Users/kulicna/Desktop/classifier/extractor/resnet152.pt"
pre_trained_resnet = torch.load(extractor_filename)
feature_extractor = FeatureExtractor(
    model=pre_trained_resnet,
    bbox_size=bbox_size,
)


In [None]:
# Images and their GRACE annotation files are read from the same folder.
grace_path = "/Users/kulicna/Desktop/dataset/shape_stars/infer"
dataset = ImageGraphDataset(image_dir=grace_path, grace_dir=grace_path, transform=feature_extractor)

# Only the first sample is inspected: an image plus a dict holding its graph.
image, graph_data = dataset[0]
graph = graph_data["graph"]

# Cell output: (node count, edge count) of the loaded graph.
(graph.number_of_nodes(), graph.number_of_edges())


### Generate ground truth graph:

In [None]:
# Derive the ground-truth graph from the loaded graph.
gt_graph = generate_ground_truth_graph(graph)

# Cell output: (node count, edge count) of the ground-truth graph.
(gt_graph.number_of_nodes(), gt_graph.number_of_edges())


### Display the annotation:

In [None]:
# Show the image together with its annotation data (the dict returned by the dataset).
display_image_and_grace_annotation(
    image=image,
    target=graph_data,
)

### Visualise the crop montages & overlay:

In [None]:
# Cut patches of size `bbox_size` out of the image, driven by the annotated graph.
# Presumably one stack of patches per label (going by the helper's name) - TODO confirm.
crops = read_patch_stack_by_label(
    G=graph_data["graph"],
    image=image,
    crop_shape=bbox_size,
)

# Cell output: the length of every entry in `crops`.
list(map(len, crops))

In [None]:
# Render the extracted crops as a montage (presumably a tiled grid per entry in
# `crops`; see grace.visualisation.plotting for the exact layout).
montage_from_image_patches(crops)

In [None]:
# Render the extracted crops as an overlay (presumably the patches of each entry
# in `crops` superimposed; see grace.visualisation.plotting for details).
overlay_from_image_patches(crops)

### Draw the plain graphs:

In [None]:
# One row of three square panels; `shape` is the side length of each panel.
shape = 5
_, axes = plt.subplots(1, 3, figsize=(shape * 3, shape))

# Left: the graph as loaded. Middle: the ground-truth graph.
# Right: the ground-truth graph drawn by the connected-components plotter.
plot_simple_graph(graph, title="Simple Random Graph", ax=axes[0])
plot_simple_graph(gt_graph, title="Simple Ground Truth Graph", ax=axes[1])
plot_connected_components(
    gt_graph,
    title="Individual Connected Components",
    ax=axes[2],
)

plt.tight_layout()
plt.show()


### Visualise a few subgraphs:

In [None]:
# Build the sub-graph dataset from the loaded graph.
# NOTE: this rebinds `dataset`, replacing the ImageGraphDataset created earlier.
dataset = dataset_from_graph(graph, mode="sub", in_train_mode=True)

# Cell output: number of entries and the type of the first entry.
(len(dataset), type(dataset[0]))

In [None]:
# Plot a handful of sub-graphs from the sub-graph dataset, one per panel.
ncols = 3
_, axes = plt.subplots(nrows=1, ncols=ncols, figsize=(18, 5))

for sub in range(ncols):
    # Dataset positions 30, 40, 50 when ncols == 3. The title previously showed
    # sub*40 (0, 40, 80), which did not match the sub-graph actually drawn.
    # Assumes the dataset holds one entry per node, in node order, so that the
    # position can be reported as the node index - TODO confirm.
    sub_idx = (sub + ncols) * 10
    sub_graph = dataset[sub_idx]
    plot_subgraph_geometry(sub_graph, title=f"Node index = {sub_idx}", ax=axes[sub])
    

In [None]:
# Plot the local geometry around a few nodes of the full graph, one per panel.
ncols = 3
_, axes = plt.subplots(nrows=1, ncols=ncols, figsize=(18, 5))

for sub in range(ncols):
    # Nodes 0, 15, 30 when ncols == 3; the same value labels the panel.
    node_idx = sub * ncols * 5
    plot_local_node_geometry(
        graph,
        node_idx=node_idx,
        title=f"Node index = {node_idx}",
        ax=axes[sub],
    )
    

##### Done!