### Message passing for REFLACX graph data points

In [None]:
import torch
import dgl.function as fn
import dgl
import matplotlib.pyplot as plt
import numpy as np


from iou_graph import IOUGraph
from dgl_reflacx_tools.tools import get_node, get_edge, draw, get_connected_component
from dgl_reflacx_tools.dgl_reflacx_collection import GraphCollection

In [None]:
# Path to the REFLACX graph dataset; it is opened with GraphCollection below.
dataset_pth = 'datasets/reflacx_densnet225_iou'

In [None]:
# Open the dataset as a GraphCollection. IOUGraph is presumably the graph class
# used to build each datapoint (confirm in dgl_reflacx_collection).
collection = GraphCollection(dataset_pth, IOUGraph)

In [None]:
# Fetch one datapoint by its DICOM id and REFLACX id.
pair = collection.fetch_by_reflacx('1bdf3180-0209f001-967acab6-0b811ea2-3c2e13eb', 'P300R510107')

In [None]:
# Unpack the DGL graph and its labels; the trailing tuple is the cell output.
g = pair.dgl_graph
labels = pair.dgl_labels
g, labels

In [None]:
# When at least one CUDA device is present, put the graph and its labels on the
# first GPU. Then display both, together with the device each now lives on.
if torch.cuda.device_count() > 0:
    gpu = torch.device('cuda:0')
    g, labels = g.to(gpu), labels.to(gpu)
g, labels, g.nodes().device, labels.device

In [None]:
# Draw the full graph.
draw(g)

In [None]:
def node_edge_fns(g):
    """Return ``(node, edge)`` accessors bound to graph *g*.

    ``node(i)`` forwards to ``get_node(g, i)`` and ``edge(i, j)`` forwards to
    ``get_edge(g, i, j)``, so callers do not have to repeat the graph argument.
    """
    def node(i):
        return get_node(g, i)

    def edge(i, j):
        return get_edge(g, i, j)

    return node, edge

In [None]:
# Dense adjacency matrix of g as a NumPy array. The .cpu() is needed because g
# may have been moved to CUDA above, and .numpy() only works on CPU tensors.
adj = g.adjacency_matrix().to_dense().cpu().detach().numpy()

In [None]:
# Pass a copy of adj so it cannot be altered (presumably get_connected_component
# mutates its argument; confirm in dgl_reflacx_tools.tools).
cc = get_connected_component(np.copy(adj))

In [None]:
# Node-induced subgraph on the nodes of that connected component.
sg = g.subgraph(list(cc))

In [None]:
# Shape of the graph-level readout: the mean of node field 'feats' over sg.
dgl.mean_nodes(sg, 'feats').shape

In [None]:
# Draw the connected-component subgraph.
draw(sg)

In [None]:
def print_keys(keys):
    """Print each item of *keys* on its own line (prints nothing when empty)."""
    items = list(keys)
    if items:
        print(*items, sep="\n")

In [None]:
# List the edge-data fields stored on the subgraph.
print_keys(sg.edata.keys())

In [None]:
# List the node-data fields stored on the subgraph.
print_keys(sg.ndata.keys())

Set the neighborhood weight factors from node durations and IOU edge weights.

In [None]:
def init_gnn(g):
    """Compute the weighting factors used by the convolution steps, in place.

    Writes two fields onto *g*:

    * ``g.edata['factor']``: the source node's ``duration`` multiplied by the
      edge's ``weight``;
    * ``g.ndata['neigh_factor']``: the sum of ``factor`` over each node's
      incoming edges, used as the normalizer by ``conv_1``/``conv_2n``.
    """
    edge_factor = fn.u_mul_e('duration', 'weight', 'factor')
    g.apply_edges(edge_factor)
    # Each edge sends its factor to its destination node, which sums them.
    g.update_all(
        message_func=fn.copy_e('factor', 'm'),
        reduce_func=fn.sum('m', 'neigh_factor'),
    )


init_gnn(sg)

Convolution step 1: aggregate the raw node features (`feats`).

In [None]:
def conv_1(g):
    """Run the first convolution step: a normalized aggregation of ``feats``.

    For every edge u -> v, ``feats[u]`` is scaled by the edge ``factor`` and
    divided by ``neigh_factor[v]``. The results are summed at v into
    ``g.ndata['conv_feats']``. Because ``neigh_factor`` is the sum of incoming
    factors, this is a factor-weighted mean of the in-neighbors' features.
    ``init_gnn(g)`` must have run first.
    """
    weighted = fn.u_mul_e('feats', 'factor', 'w_feats')
    g.apply_edges(weighted)
    normalized = fn.e_div_v('w_feats', 'neigh_factor', 'norm_feats')
    g.update_all(normalized, fn.sum('norm_feats', 'conv_feats'))


conv_1(sg)

Convolution steps 2 onward: aggregate the previous step's output (`conv_feats`).

In [None]:
def conv_2n(g):
    """Run one further convolution step on the output of the previous step.

    This is the same normalized aggregation as ``conv_1``, except that it reads
    ``g.ndata['conv_feats']`` instead of ``feats`` and overwrites
    ``conv_feats`` with the new result.
    """
    # NOTE(review): nodes with no incoming edges receive no messages. Check what
    # DGL stores for them after the reduce (zeros in recent releases).
    weighted = fn.u_mul_e('conv_feats', 'factor', 'w_feats')
    g.apply_edges(weighted)
    normalized = fn.e_div_v('w_feats', 'neigh_factor', 'norm_feats')
    g.update_all(normalized, fn.sum('norm_feats', 'conv_feats'))


conv_2n(sg)

Test execution time on the largest graph (the sample with the most fixations).

In [None]:
# REFLACX metadata behind the collection; used below to enumerate every sample.
metadata = collection.reflacx

In [None]:
# Scan every (DICOM id, REFLACX id) sample and keep the first one with the most
# fixations. Its graph is the largest, i.e. the worst case for the timing test.
samples = (
    metadata.get_sample(dicom_id, reflacx_id)
    for dicom_id in metadata.list_dicom_ids()
    for reflacx_id in metadata.list_reflacx_ids(dicom_id)
)
max_pt, max_len = None, 0
for sample in samples:
    n_fixations = len(sample.get_fixations())
    if max_pt is None or n_fixations > max_len:
        max_pt, max_len = sample, n_fixations

In [None]:
# Number of fixations in the largest sample.
max_len

In [None]:
# DICOM and REFLACX ids of the sample with the most fixations.
d, r = max_pt.dicom_id, max_pt.reflacx_id
d, r

In [None]:
# NOTE(review): these hard-coded ids override the d, r computed in the previous
# cell. They are presumably the result of an earlier run of the search, so the
# expensive scan can be skipped. Confirm they still match max_pt.
d = '5d93b668-2ecb804a-0b026b1d-08c7dd4d-0bd8202c'
r = 'P300R591856'

In [None]:
# Load the largest datapoint. Unlike the first graph, it is not moved to CUDA
# here, so the timings below run on whatever device fetch_by_reflacx returns.
pair = collection.fetch_by_reflacx(d, r)
g = pair.dgl_graph
labels = pair.dgl_labels
g, labels

In [None]:
# List the edge-data fields of the largest graph.
print_keys(g.edata.keys())

In [None]:
# List the node-data fields of the largest graph.
print_keys(g.ndata.keys())

In [None]:
from time import time

In [None]:
from time import perf_counter


def _timed(step, graph):
    """Run ``step(graph)`` and return its wall-clock duration in seconds.

    Uses ``perf_counter`` (monotonic, high resolution) rather than ``time()``,
    which can jump and has coarse resolution on some platforms.
    """
    # CUDA kernels launch asynchronously. Synchronize before and after the step
    # so the measurement covers the actual work, not just the kernel launch.
    on_cuda = graph.device.type == 'cuda'
    if on_cuda:
        torch.cuda.synchronize(graph.device)
    start = perf_counter()
    step(graph)
    if on_cuda:
        torch.cuda.synchronize(graph.device)
    return perf_counter() - start


# Time each stage separately, in the same order as on the small subgraph.
init_t = _timed(init_gnn, g)
conv_1_t = _timed(conv_1, g)
conv_2n_t = _timed(conv_2n, g)

In [None]:
# Elapsed seconds for each step.
init_t, conv_1_t, conv_2n_t

### Aggregating the graph into a grid
Build one subgraph per grid cell, bounded by the cell's x and y limits.

In [None]:
# Accessors bound to the largest graph g.
node, edge = node_edge_fns(g)

In [None]:
# Inspect node 0.
node(0)

In [None]:
from dgl_reflacx_tools.tools import gridify, grid_readout

In [None]:
# Split g into per-grid-cell subgraphs. The 4 is presumably the number of cells
# per axis (confirm in dgl_reflacx_tools.tools.gridify). Show the first cell.
sg_grid = gridify(g, 4)
sg_grid[0]

In [None]:
# Per-cell readout: the sum of node field 'duration' over each cell's subgraph,
# moved to CPU.
duration_ro = grid_readout(sg_grid, 'duration', lambda x, y: dgl.sum_nodes(x, y).cpu())
duration_ro.shape

In [None]:
# Per-cell readout: the mean of node field 'feats' over each cell's subgraph,
# moved to CPU.
feats_ro = grid_readout(sg_grid, 'feats', lambda x, y: dgl.mean_nodes(x, y).cpu())
feats_ro.shape

In [None]:
# Stack the duration and feature readouts depth-wise (along the third axis).
fs = torch.dstack((duration_ro, feats_ro))
fs.shape