In [None]:
import os
import sys

import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

import matplotlib.patches as patches

from grace.io.image_dataset import mrc_reader, ImageGraphDataset
from grace.napari.utils import EdgeColor
from grace.base import GraphAttrs
from grace.models.feature_extractor import resnet, FeatureExtractor
from grace.utils.augment_image import RandomImageGraphRotate
from grace.utils.augment_graph import RandomEdgeAdditionAndRemoval
from grace.models.datasets import dataset_from_graph
from grace.models.classifier import GCN
from grace.training.train import train_model

from grace.utils.augment_image import RandomEdgeCrop
from torchvision.transforms import (
    Resize,
    Lambda,
    Normalize,
    RandomApply,
    RandomAffine,
)


In [None]:
# Directory figures are written to. Set the GRACE_IMAGE_SAVE_DIR environment
# variable to override it instead of editing a machine-specific absolute path.
IMAGE_SAVE_DIR = os.environ.get('GRACE_IMAGE_SAVE_DIR', '/Users/mfamili/work/exp_grace/')

### Helper Code

In [None]:
class Compose:
    """Chain transforms that each take and return an ``(image, target)`` pair.

    Unlike ``torchvision.transforms.Compose`` (single-argument transforms), every
    step here is called as ``step(image, target)`` and must return a 2-tuple,
    which is fed to the next step.
    """

    def __init__(self, transforms):
        # Sequence of callables, applied in order on every call.
        self.transforms = transforms

    def __call__(self, image, target):
        current_image, current_target = image, target
        for step in self.transforms:
            current_image, current_target = step(current_image, current_target)
        return current_image, current_target
    
def normalize8(I):
    """Min-max rescale an array-like to the full uint8 range [0, 255].

    The minimum maps to 0 and the maximum to 255. Intermediate values are
    scaled linearly and then truncated (not rounded) by the final ``astype``.

    Args:
        I: anything ``np.array`` accepts (numpy array, nested list, CPU torch
            tensor, ...).

    Returns:
        ``np.ndarray`` of dtype ``uint8`` with the same shape as ``I``. A
        constant input (zero range) maps to all zeros. An empty input returns
        an empty uint8 array.
    """
    I = np.array(I)
    if I.size == 0:
        return I.astype(np.uint8)

    # Do the arithmetic in floating point for non-float inputs. For integer
    # dtypes (e.g. int16 image data), ``I - mn`` and ``max - mn`` can overflow
    # and wrap around, and bool arrays do not support subtraction at all.
    # Float inputs keep their dtype so their results are unchanged.
    if not np.issubdtype(I.dtype, np.floating):
        I = I.astype(np.float64)

    mn = I.min()
    value_range = I.max() - mn

    # Constant image: avoid 0/0 (NaN), which astype(uint8) would turn into garbage.
    if value_range == 0:
        return np.zeros(I.shape, dtype=np.uint8)

    I = ((I - mn) / value_range) * 255
    return I.astype(np.uint8)

def draw_graph(graph, ax, edge_color:str='cyan', node_color='teal', node_size=15):
    """Draw ``graph`` on ``ax`` in a single edge colour and a single node colour.

    Node positions come from each node's ``GraphAttrs.NODE_X`` /
    ``GraphAttrs.NODE_Y`` attributes, unmodified, so the drawing can be
    overlaid on an image shown on the same axes.

    Args:
        graph: networkx graph whose nodes carry NODE_X / NODE_Y attributes.
        ax: matplotlib Axes to draw on.
        edge_color: matplotlib colour used for every edge.
        node_color: matplotlib colour used for every node.
        node_size: node marker size (default 15, the previously hard-coded value).
    """
    pos = {
        idx: (node[GraphAttrs.NODE_X], node[GraphAttrs.NODE_Y])
        for idx, node in graph.nodes(data=True)
    }

    nx.draw(
        graph,
        ax=ax,
        pos=pos,
        with_labels=False,
        node_size=node_size,
        edge_color=[edge_color],
        node_color=[node_color],
    )

def show_image_and_graph(image, graph_data):
    """Plot an image, its annotated graph, and the graph overlaid on the image.

    Produces a 1x3 figure:
      * axes[0] -- the image alone;
      * axes[1] -- the graph alone, with y flipped (height - y) so it has the
        same top-down orientation as the image shown with ``imshow``;
      * axes[2] -- the graph drawn over the image at its stored coordinates.

    Edges and nodes are coloured by their ground-truth label via ``EdgeColor``.

    Args:
        image: 2D image (torch tensor or array-like) accepted by ``normalize8``.
        graph_data: dict with ``'graph'`` (networkx graph; nodes carry
            NODE_X / NODE_Y / NODE_GROUND_TRUTH, edges carry EDGE_GROUND_TRUTH)
            and ``'metadata'['image_filename']`` (used as the figure title).

    Returns:
        ``(fig, axes)`` as returned by ``plt.subplots``.
    """
    graph = graph_data['graph']

    # Normalise once and reuse it for both image panels. Re-normalising the
    # uint8 result (as the per-panel loop used to) is not guaranteed to be
    # exact in floating point, so the two panels could differ.
    image_u8 = normalize8(image)
    # Works for torch tensors and numpy arrays alike; row count = image height.
    height = image_u8.shape[0]

    fig, axes = plt.subplots(1, 3, figsize=(30, 10))

    # node positions in image coordinates
    pos = {
        idx: (node[GraphAttrs.NODE_X], node[GraphAttrs.NODE_Y])
        for idx, node in graph.nodes(data=True)
    }
    # Vertical flip uses the image height (the previous code used the width).
    pos_flipped = {k: (x, height - y) for k, (x, y) in pos.items()}

    edge_colors = [
        EdgeColor[data[GraphAttrs.EDGE_GROUND_TRUTH].name].value
        for _, _, data in graph.edges(data=True)
    ]
    node_colors = [
        EdgeColor[node_attrs[GraphAttrs.NODE_GROUND_TRUTH].name].value
        for _, node_attrs in graph.nodes(data=True)
    ]

    axes[1].set_aspect('equal')

    for ax, positions in zip(axes[1:], (pos_flipped, pos)):
        nx.draw(
            graph,
            ax=ax,
            pos=positions,
            with_labels=False,
            node_size=10,
            edge_color=edge_colors,
            node_color=node_colors,
        )

    for ax in (axes[0], axes[2]):
        ax.imshow(image_u8, cmap='gray')
        ax.set_xticks([])
        ax.set_yticks([])

    fig.suptitle(graph_data['metadata']["image_filename"], y=0.95, fontsize=25)

    return fig, axes

### Load Image and Graph (Grace File)

In [None]:
# Folder with the .mrc images and the matching grace annotation files (the same
# folder here). Override via environment variables instead of editing
# machine-specific absolute paths.
IMAGEPATH = os.environ.get(
    "GRACE_IMAGEPATH",
    "/Users/mfamili/work/datasets/dataset_synthetic_grace/shape_stars/train",
)
GRACEPATH = os.environ.get("GRACE_GRACEPATH", IMAGEPATH)

In [None]:
# Pass-through transform: each (image, graph) sample is returned as loaded.
image_graph_dataset = ImageGraphDataset(
    image_filetype="mrc",
    image_dir=IMAGEPATH,
    grace_dir=GRACEPATH,
    transform=lambda x, y: (x, y),
)

### Show Image and Graph (No Augmentation)

In [None]:
# One figure per sample: image | graph alone | graph overlaid on the image.
for image, graph_data in image_graph_dataset:
    show_image_and_graph(image, graph_data)
    # To export, uncomment:
    # plt.savefig(os.path.join(IMAGE_SAVE_DIR, 'full_image_raw'), bbox_inches='tight')

### Show Image and Graph (Rotation Augmentation)

In [None]:
# Same files; RandomImageGraphRotate is applied to every (image, graph) sample.
rotate = RandomImageGraphRotate()
image_graph_dataset = ImageGraphDataset(
    image_filetype="mrc",
    image_dir=IMAGEPATH,
    grace_dir=GRACEPATH,
    transform=rotate,
)

In [None]:
for image, graph_data in image_graph_dataset:

    rotated_fig = show_image_and_graph(image, graph_data)[0]
    # Fixed file name: with several samples, only the last figure is kept on disk.
    rotated_fig.savefig(os.path.join(IMAGE_SAVE_DIR, 'full_image_rotated'), bbox_inches='tight')

### Show Image and Graph (Graph Augmentation)

In [None]:
# Random edge addition/removal (2% each) on the graph; annotation_mode='unknown'.
graph_augmentation = RandomEdgeAdditionAndRemoval(
    annotation_mode='unknown',
    p_add=0.02,
    p_remove=0.02,
)
image_graph_dataset = ImageGraphDataset(
    image_dir=IMAGEPATH,
    grace_dir=GRACEPATH,
    image_filetype="mrc",
    transform=graph_augmentation,
)

In [None]:
for image, graph_data in image_graph_dataset:

    augmented_fig = show_image_and_graph(image, graph_data)[0]
    # Fixed file name: with several samples, only the last figure is kept on disk.
    augmented_fig.savefig(os.path.join(IMAGE_SAVE_DIR, 'full_image_graph_aug'), bbox_inches='tight')

### Extract Bounding Boxes (No Augmentation)

In [None]:
# Identity model and identity augmentations: node features are the raw patches.
feature_extractor = FeatureExtractor(
    model=lambda x: x,
    augmentations=lambda x: x,
)
# NOTE(review): `image` / `graph_data` are whatever the previous loop (the
# graph-augmentation dataset) left behind, so this depends on execution order.
image_e, graph_data_e = feature_extractor(image, graph_data)

In [None]:
# Dataset whose only transform is the patch extractor defined above.
extract_only = Compose([feature_extractor])
image_graph_dataset = ImageGraphDataset(
    image_dir=IMAGEPATH,
    grace_dir=GRACEPATH,
    image_filetype="mrc",
    transform=extract_only,
)

In [None]:
NODES = [150, 151, 153]

# Plot patches from the dataset built in the previous cell (identity model, no
# augmentation), like the other patch cells do. Previously this read
# `graph_data_e`, computed from `image`/`graph_data` left over from the
# graph-augmentation loop, and the dataset built for this section went unused.
for image_e, graph_data_e in image_graph_dataset:

    node_view = graph_data_e['graph'].nodes(data=True)
    fig, axes = plt.subplots(1, len(NODES), figsize=(30,10))

    for ax, node_id in zip(axes, NODES):
        patch = node_view[node_id][GraphAttrs.NODE_FEATURES]
        ax.imshow(patch[0], cmap='gray')
        ax.set_title(f"Patch {node_id}", fontsize=30)
        ax.set_xticks([])
        ax.set_yticks([])

    fig.suptitle('Patches, No Augmentation', fontsize=40, x=.51)
    plt.savefig(os.path.join(IMAGE_SAVE_DIR, 'patches_raw'), bbox_inches='tight')
    plt.show()

### Extract Bounding Boxes (Rotated)

In [None]:
# Identity model and identity augmentations: node features are the raw patches.
feature_extractor = FeatureExtractor(
    model=lambda x: x,
    augmentations=lambda x: x,
)


In [None]:
# Compose applies in list order: rotate the sample, then extract patches from it.
rotate_then_extract = Compose([RandomImageGraphRotate(), feature_extractor])
image_graph_dataset = ImageGraphDataset(
    image_dir=IMAGEPATH,
    grace_dir=GRACEPATH,
    image_filetype="mrc",
    transform=rotate_then_extract,
)

In [None]:
NODES = [150, 151, 153]

for image_e, graph_data_e in image_graph_dataset:

    # Whole rotated image + graph.
    image_fig = show_image_and_graph(image_e, graph_data_e)[0]
    image_fig.savefig(os.path.join(IMAGE_SAVE_DIR, 'full_image_rotated_'), bbox_inches='tight')

    # Extracted patches for a fixed set of nodes. This replaces the old manual
    # counters (ax_n / node_chosen / node += 1), which were dead code.
    node_view = graph_data_e['graph'].nodes(data=True)
    fig, axes = plt.subplots(1, len(NODES), figsize=(30,10))

    for ax, node in zip(axes, NODES):
        patch = node_view[node][GraphAttrs.NODE_FEATURES]
        ax.imshow(patch[0], cmap='gray')
        ax.set_title(f"Patch {node}", fontsize=30)
        ax.set_xticks([])
        ax.set_yticks([])

    fig.suptitle('Patches, Rotated', fontsize=40, x=.51)
    plt.savefig(os.path.join(IMAGE_SAVE_DIR, 'patches_rotated'), bbox_inches='tight')
    plt.show()

### Apply Transforms One by One (Disabled Debugging Snippet)

In [None]:
# Debugging snippet (disabled): apply the rotation and the feature extractor by
# hand instead of through the dataset transform. Kept as comments; a bare
# string literal here would be echoed as the cell's output.
# import torch
# for image, graph_data in image_graph_dataset:
#
#     image_a, graph_data_a = RandomImageGraphRotate()(image, graph_data)
#     #print(torch.equal(image, image_a))
#     #print(nx.utils.misc.graphs_equal(graph_data['graph'], graph_data_a['graph']))
#     image_a, graph_data_a = feature_extractor(image_a, graph_data_a)

### Extract Bounding Boxes (Translate Augmentation)

In [None]:
# Random translation of up to 20% of the width/height, no rotation. With p=1.
# the RandomApply wrapper always applies it. RandomEdgeCrop(max_fraction=0.1)
# is a disabled alternative.
translate = RandomAffine(degrees=0, translate=(0.2, 0.2))
augmentations = RandomApply([translate], p=1.)

# Passed to FeatureExtractor, presumably applied to each extracted patch.
feature_extractor = FeatureExtractor(model=lambda x: x, augmentations=augmentations)

image_graph_dataset = ImageGraphDataset(
    image_dir=IMAGEPATH,
    grace_dir=GRACEPATH,
    image_filetype="mrc",
    transform=Compose([feature_extractor]),
)


In [None]:
NODES = [150, 151, 153]

for image_e, graph_data_e in image_graph_dataset:

    node_view = graph_data_e['graph'].nodes(data=True)
    fig, axes = plt.subplots(1, len(NODES), figsize=(30,10))

    for ax, node_id in zip(axes, NODES):
        patch = node_view[node_id][GraphAttrs.NODE_FEATURES]
        ax.imshow(normalize8(patch[0]), cmap='gray')
        ax.set_title(f"Patch {node_id}", fontsize=30)
        ax.set_xticks([])
        ax.set_yticks([])

    fig.suptitle('Patches, Translation Augmentation', fontsize=40, x=.51)
    plt.savefig(os.path.join(IMAGE_SAVE_DIR, 'patches_translated'), bbox_inches='tight')
    plt.show()

### Extract Bounding Boxes (Translate & Rotation Augmentation)

In [None]:
# Random translation of up to 20% of the width/height, no rotation, always
# applied (p=1.). RandomEdgeCrop(max_fraction=0.1) is a disabled alternative.
translate = RandomAffine(degrees=0, translate=(0.2, 0.2))
augmentations = RandomApply([translate], p=1.)

feature_extractor = FeatureExtractor(model=lambda x: x, augmentations=augmentations)

# Whole-sample rotation first, then patch extraction with the translation above.
image_graph_dataset = ImageGraphDataset(
    image_dir=IMAGEPATH,
    grace_dir=GRACEPATH,
    image_filetype="mrc",
    transform=Compose([RandomImageGraphRotate(), feature_extractor]),
)


In [None]:
NODES = [150, 151, 153]

for image_e, graph_data_e in image_graph_dataset:

    node_view = graph_data_e['graph'].nodes(data=True)
    fig, axes = plt.subplots(1, len(NODES), figsize=(30,10))

    for ax, node_id in zip(axes, NODES):
        patch = node_view[node_id][GraphAttrs.NODE_FEATURES]
        ax.imshow(normalize8(patch[0]), cmap='gray')
        ax.set_title(f"Patch {node_id}", fontsize=30)
        ax.set_xticks([])
        ax.set_yticks([])

    fig.suptitle('Patches, Translation & Rotation Augmentation', fontsize=40, x=.51)
    plt.savefig(os.path.join(IMAGE_SAVE_DIR, 'patches_translated_rotated'), bbox_inches='tight')
    plt.show()

### Dataset: Render a Subgraph's Node Patches to Scale

In [None]:
# Render subgraph(s) "to scale". Each node's extracted patch is pasted onto a
# white canvas at the node's offset position and framed in black. An orange dot
# marks each patch centre, and orange lines show the subgraph's edges.

# Identity model/transforms/augmentations with bbox_size=(150,150). Node
# features are presumably the raw 150x150 crops -- confirm FeatureExtractor.
feature_extractor = FeatureExtractor(model=lambda x: x,
                                     bbox_size=(150,150),
                                     transforms=lambda x: x,
                                     augmentations=lambda x: x,)
image_graph_dataset = ImageGraphDataset(
    image_dir=IMAGEPATH,
    grace_dir=GRACEPATH,
    image_filetype="mrc",
    transform=Compose([
        feature_extractor,
    ]),
)

# Indices into each graph's subgraph dataset to render.
SUBGRAPHS = [30]

for img, gph in image_graph_dataset:
    # mode="sub": presumably one sample per node-centred subgraph -- confirm in
    # grace.models.datasets.dataset_from_graph.
    dataset = dataset_from_graph(gph['graph'],mode = "sub")
    for s in SUBGRAPHS:

        data = dataset[s]
        #x = data.x[:,0,...] # (N, 224, 224)
        x = data.x # one patch per node; presumably (N, 150, 150) here -- confirm
        img_size = x.size()[-2:]
        # NOTE(review): edge_attr is indexed per *node* below (coords[node]); it
        # appears to hold node offsets relative to the centre node -- confirm.
        box_coords = data.edge_attr
        edges = data.edge_index

        # Swap the two columns: coords[:, 0] is used as the canvas row offset
        # and coords[:, 1] as the column offset below.
        coords = np.array([[cor[1], cor[0]] for cor in box_coords])

        FACTOR = 1   # multiplier on node offsets (1 = true pixel spacing)
        PADDING = 7  # width in px of the black frame around each patch

        fig, ax = plt.subplots(figsize=(20,20))
        # Canvas size = spread of the offsets + one patch + padding.
        # NOTE(review): max_width already includes PADDING*2 and `w` adds it
        # again, so the canvas gets 2*PADDING extra white rows at the bottom,
        # whereas `h` includes the padding only once.
        max_width = max(coords[:,0]) - min(coords[:,0]) + img_size[0] + PADDING*2
        max_height = max(coords[:,1]) - min(coords[:,1]) + img_size[1]
        w = int(max_width*FACTOR - img_size[1]*(FACTOR-1)) + PADDING*2
        h = int(max_height*FACTOR - img_size[0]*(FACTOR-1)) + PADDING*2
        # White (255) canvas. Note: this rebinds the outer loop's `img`, which is
        # not used after this point.
        img = np.full((w,h), 255, dtype='uint8')

        for node in range(x.size(0)):

            # Black frame: a zero block PADDING px larger than the patch on each side.
            pad = np.zeros((img_size[-2]+PADDING*2, img_size[-1]+PADDING*2))

            # Top-left corner of this patch on the canvas (offsets shifted to start at 0).
            coord = coords[node]
            x_, y_ =  coord[0] - min(coords[:,0]), coord[1] - min(coords[:,1])
            x_, y_ = x_*FACTOR + PADDING, y_*FACTOR + PADDING
            x_box = slice(int(x_), int(x_ + img_size[0]))
            y_box = slice(int(y_), int(y_ + img_size[1]))
            img[int(x_-PADDING):int(x_+PADDING+img_size[0]), int(y_-PADDING):int(y_+PADDING+img_size[1])] = pad
            img[x_box, y_box] = normalize8(x[node])

            # Orange dot at the patch centre; Circle takes (column, row) = (x, y).
            circle = patches.Circle(xy=[y_+img_size[1]/2, x_+img_size[0]/2], radius=7, facecolor='darkorange')
            # Reused as the edge colour below. NOTE(review): `c` is only defined
            # if the subgraph has at least one node.
            c = circle.get_facecolor()
            ax.add_patch(circle)

        # Edges join patch centres, using the same offset/padding mapping as the patches.
        for edge in range(edges.size(1)):

            src_node, dst_node = edges[0,edge], edges[1,edge]
            x_vals = [coords[src_node][0], coords[dst_node][0]]
            y_vals = [coords[src_node][1], coords[dst_node][1]]
            
            x_vals = [i - min(coords[:,0]) for i in x_vals]
            x_vals = [i*FACTOR + PADDING + img_size[0]/2 for i in x_vals]
            
            y_vals = [i - min(coords[:,1]) for i in y_vals]
            y_vals = [i*FACTOR + PADDING + img_size[1]/2 for i in y_vals]

            # plot takes (horizontal, vertical) = (column, row), hence y_vals first.
            ax.plot(y_vals, x_vals, linewidth=3, color=c)

        ax.imshow(img, cmap='gray')

        ax.spines[['right', 'top', 'bottom', 'left']].set_visible(False)
        ax.set_xticks([])
        ax.set_yticks([])

        plt.savefig(os.path.join(IMAGE_SAVE_DIR, f'subgraph_{s}_orange_150px_to_scale'), bbox_inches='tight')
        plt.show()
    


### Draw the Subgraph on the Full Graph

In [None]:
# Exhaust the dataset so `image` / `graph_data` hold its last sample. This is
# presumably the same sample the subgraph cell above ended on.
for image, graph_data in image_graph_dataset:
    pass

In [None]:
# `data` is the subgraph last rendered in the "Dataset" cell above. Its
# edge_attr rows appear to be per-node offsets from the centre node, so the
# centre is the row that is ~0 in *every* coordinate. The previous
# np.where(...)[0][0] took the first row with ANY zero coordinate, which also
# matches nodes that merely share the centre's x or y.
offsets = np.asarray(data.edge_attr)
zero_rows = np.flatnonzero(np.isclose(offsets, 0).all(axis=1))
if zero_rows.size == 0:
    raise ValueError("no node with a zero offset found in the subgraph data")
central_coords = data.pos[zero_rows[0]]

# Map that position back to a node id in the full graph.
central_node = None
for node, values in graph_data['graph'].nodes(data=True):
    node_coords = np.array([values[GraphAttrs.NODE_X], values[GraphAttrs.NODE_Y]])
    if np.allclose(central_coords, node_coords):
        central_node = node
        break
if central_node is None:
    raise ValueError("central node of the subgraph not found in the full graph")

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(15, 15))

# Whole graph in the default colours, then the central node's 1-hop
# neighbourhood (ego graph) drawn on top in dark orange.
full_graph = graph_data['graph']
draw_graph(full_graph, ax)
draw_graph(nx.ego_graph(full_graph, central_node), ax, 'darkorange', 'darkorange')

image = normalize8(image)
ax.imshow(image, cmap='gray')
ax.set_xticks([])
ax.set_yticks([])

fig.savefig(os.path.join(IMAGE_SAVE_DIR, f'location_subgraph_{s}'), bbox_inches='tight')

### Classifier

In [None]:
# Set up dataset. IMAGEPATH / GRACEPATH are reused from the "Load Image and
# Graph" configuration cell. They used to be redefined here with identical
# values, which let the two sections drift apart.

# Placeholder "model": random features with a trailing dimension of 2, replacing
# the last three dims of the patch batch. Drawn from a seeded generator so the
# classifier run below is reproducible.
SEED = 0
rng = np.random.default_rng(SEED)
feature_extractor = FeatureExtractor(model=lambda x: rng.normal(size=tuple(x.size()[:-3]) + (2,)),
                                     bbox_size=(224,224),
                                     augmentations=lambda x: x,)
image_graph_dataset = ImageGraphDataset(
    image_dir=IMAGEPATH,
    grace_dir=GRACEPATH,
    image_filetype="mrc",
    transform=Compose([
        feature_extractor,
    ]),
)

In [None]:
# Collect subgraph samples from *every* image/graph pair. Re-assigning inside
# the loop, as before, kept only the last graph's samples.
training_dataset = []
for img, gph in image_graph_dataset:
    training_dataset.extend(dataset_from_graph(gph['graph'], mode='sub'))

In [None]:
# input_channels=2 matches the 2 placeholder features per node produced above.
# hidden_channels=[2] presumably means one hidden layer of width 2.
# Node and edge heads each predict 2 classes.
gcn = GCN(
    input_channels=2,
    hidden_channels=[2],
    node_output_classes=2,
    edge_output_classes=2,
)

In [None]:
# Inspect individual training samples (disabled; uncomment to print each one).
# Kept as comments: a bare string literal here is echoed as the cell's output.
# for data in training_dataset:
#     print(data)

In [None]:
#node_x, edge_x = gcn(data.x, data.edge_index)

In [None]:
#print(node_x.size(), edge_x.size(), data.edge_label.size(), data.edge_index.size())

#### Test DataLoader

In [None]:
from torch_geometric.loader import DataLoader

In [None]:
# One subgraph sample per batch.
loader = DataLoader(training_dataset, batch_size=1)

In [None]:
# Sanity-check collation. Every batch is iterated, but only the first is printed
# so the output is not flooded with one repr per sample.
n_batches = 0
for batch in loader:
    if n_batches == 0:
        print(batch)
        print(batch.x.size())
    n_batches += 1
print(f"{n_batches} batches")

### Train

In [None]:
run_log_dir = os.path.join(IMAGE_SAVE_DIR, "run_3")
train_model(
    gcn,
    training_dataset,
    epochs=10,
    log_dir=run_log_dir,
    metrics=['accuracy', 'confusion_matrix'],
)

In [None]:
# Load the TensorBoard notebook extension. train_model above was given
# log_dir=os.path.join(IMAGE_SAVE_DIR, "run_3"); point e.g.
# `%tensorboard --logdir <that directory>` at it to inspect the run.
%load_ext tensorboard