### Message Passing for REFLACX graph datapoints

In [None]:
import torch
import dgl.function as fn
import dgl
import matplotlib.pyplot as plt
import numpy as np

from iou_graph import IOUGraph
from dgl_reflacx_tools.tools import get_node, get_edge, draw, get_connected_component
from dgl_reflacx_tools.dgl_reflacx_collection import GraphCollection

In [None]:
# Location of the dataset handed to GraphCollection in the next cell.
# NOTE(review): "densnet225" looks like a misspelling of a DenseNet variant —
# confirm against the directory name on disk before changing it.
dataset_pth = 'datasets/reflacx_densnet225_iou'

In [None]:
# Open the graph collection at dataset_pth. IOUGraph is passed as the second
# argument — presumably the graph class used for the entries; confirm in
# GraphCollection.
collection = GraphCollection(dataset_pth, IOUGraph)

In [None]:
# Fetch one datapoint by its two identifiers.
# NOTE(review): presumably (image/DICOM id, REFLACX reading id) — confirm
# against fetch_by_reflacx.
pair = collection.fetch_by_reflacx('1bdf3180-0209f001-967acab6-0b811ea2-3c2e13eb', 'P300R510107')

In [None]:
# Pull the DGL graph and its labels out of the fetched datapoint and display
# both as the cell output.
g = pair.dgl_graph
labels = pair.dgl_labels
g, labels

In [None]:
# Move graph and labels to the first GPU when at least one CUDA device is
# visible; otherwise both stay where they were loaded.
if torch.cuda.device_count() > 0:
    g = g.to(torch.device('cuda:0'))
    labels = labels.to(torch.device('cuda:0'))
# Display the objects together with their devices to confirm the placement.
g, labels, g.nodes().device, labels.device

In [None]:
# Render the full graph with the project's draw helper.
draw(g)

In [None]:
def node_edge_fns(g):
    """Return a ``(node, edge)`` pair of accessors bound to graph ``g``.

    ``node(i)`` forwards to ``get_node(g, i)`` and ``edge(i, j)`` forwards to
    ``get_edge(g, i, j)``, so callers do not have to repeat the graph argument.
    """

    def node(i):
        return get_node(g, i)

    def edge(i, j):
        return get_edge(g, i, j)

    return node, edge

In [None]:
# Dense adjacency matrix of g as a NumPy array (brought back to the CPU first,
# since NumPy cannot read GPU tensors).
adj = g.adjacency_matrix().to_dense().cpu().detach().numpy()

In [None]:
# A copy of adj is passed so that adj itself stays intact.
# NOTE(review): presumably get_connected_component modifies its argument in
# place — confirm; the result is used as a set of node ids in the next cell.
cc = get_connected_component(np.copy(adj))

In [None]:
# Subgraph of g restricted to the nodes collected in cc.
sg = g.subgraph(list(cc))

In [None]:
# Shape of the node-averaged 'feats' readout over the subgraph.
dgl.mean_nodes(sg, 'feats').shape

In [None]:
# Render only the subgraph built from the nodes in cc.
draw(sg)

In [None]:
def print_keys(keys):
    """Print every entry of ``keys`` on its own line.

    Nothing is written when ``keys`` is empty.
    """
    rendered = [str(key) for key in keys]
    if rendered:
        print("\n".join(rendered))

In [None]:
# List the names of the edge features stored on the subgraph.
print_keys(sg.edata.keys())

In [None]:
# List the names of the node features stored on the subgraph.
print_keys(sg.ndata.keys())

Setting neighborhood weight factors for duration and IOU

In [None]:
def init_gnn(g):
    """Attach the weighting factors used by the convolution steps to ``g``.

    Writes the edge feature ``'factor'`` (source node ``'duration'`` times
    edge ``'weight'``) and the node feature ``'neigh_factor'`` (sum of the
    ``'factor'`` values on each node's incoming edges). ``g`` is modified in
    place; nothing is returned.
    """
    # Per-edge factor: source-node duration scaled by the edge weight.
    edge_factor = fn.u_mul_e('duration', 'weight', 'factor')
    g.apply_edges(edge_factor)

    # Forward each edge's factor as message 'm' and total it per receiving node.
    forward_factor = fn.copy_e('factor', 'm')
    total_factor = fn.sum('m', 'neigh_factor')
    g.update_all(forward_factor, total_factor)

# Populate 'factor' (edges) and 'neigh_factor' (nodes) on the subgraph in place.
init_gnn(sg)

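Convolution step 1: weight each source node's `feats` by the edge `factor`, divide by the destination node's `neigh_factor`, and sum the result into `conv_feats`.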

In [None]:
def conv_1(g):
    """Run the first convolution round on ``g``, reading the raw ``'feats'``.

    Each edge gets ``'w_feats'`` = source ``'feats'`` times edge ``'factor'``.
    Those values are divided by the destination node's ``'neigh_factor'`` and
    summed per destination into the node feature ``'conv_feats'``. Expects
    ``'factor'`` and ``'neigh_factor'`` to be present already. ``g`` is
    modified in place.
    """
    # Scale every source node's features by the factor of the outgoing edge.
    weigh_by_factor = fn.u_mul_e('feats', 'factor', 'w_feats')
    g.apply_edges(weigh_by_factor)

    # Normalise by the receiver's total factor, then accumulate per receiver.
    normalise = fn.e_div_v('w_feats', 'neigh_factor', 'norm_feats')
    accumulate = fn.sum('norm_feats', 'conv_feats')
    g.update_all(normalise, accumulate)

# First convolution round on the subgraph; writes the node feature 'conv_feats'.
conv_1(sg)

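Convolution steps 2 and onward: the same update as step 1, but reading `conv_feats` from the previous step instead of `feats`.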

In [None]:
def conv_2n(g):
    """Run one follow-up convolution round on ``g``.

    Same update as the first round, except that the input is the node feature
    ``'conv_feats'`` left by the previous round, which is then overwritten
    with the new result. Expects ``'conv_feats'``, ``'factor'`` and
    ``'neigh_factor'`` to be present. ``g`` is modified in place.
    """
    # Scale the previous round's output by the factor of the outgoing edge.
    weigh_by_factor = fn.u_mul_e('conv_feats', 'factor', 'w_feats')
    g.apply_edges(weigh_by_factor)

    # Normalise by the receiver's total factor, then accumulate per receiver.
    normalise = fn.e_div_v('w_feats', 'neigh_factor', 'norm_feats')
    accumulate = fn.sum('norm_feats', 'conv_feats')
    g.update_all(normalise, accumulate)

# One further round; re-running this cell applies another round on top.
conv_2n(sg)

### Aggregating graph into grid
Make one subgraph per grid cell, each delimited by the cell's x and y bounds.

In [None]:
# Accessors bound to the full graph g (not the subgraph sg).
node, edge = node_edge_fns(g)

In [None]:
from dgl_reflacx_tools.tools import gridify, grid_readout

In [None]:
# Split g into a grid of subgraphs; the first value returned by gridify is
# not needed here.
# NOTE(review): 4 is presumably the number of cells per side — confirm in gridify.
_, sg_grid = gridify(g, 4)
# Display the first entry of the grid.
sg_grid[0]

In [None]:
# Per-cell readout of 'duration': sum over the nodes of each cell, with the
# result moved to the CPU.
duration_ro = grid_readout(sg_grid, 'duration', lambda x, y: dgl.sum_nodes(x, y).cpu())
duration_ro.shape

In [None]:
# Inspect the summed durations per grid cell.
duration_ro

In [None]:
# Per-cell readout of 'feats': mean over the nodes of each cell, with the
# result moved to the CPU.
feats_ro = grid_readout(sg_grid, 'feats', lambda x, y: dgl.mean_nodes(x, y).cpu())
feats_ro.shape

#### Tests with batched graphs

In [None]:
# Place the two extra graphs on whatever device g already lives on. g is only
# moved to the GPU when one is present, so a hard-coded 'cuda:0' would fail on
# a CPU-only machine and would hand dgl.batch graphs on different devices.
g2 = collection.fetch_by_dgl_index(3).dgl_graph.to(g.device)
g3 = collection.fetch_by_dgl_index(5).dgl_graph.to(g.device)
# Combine the three graphs into one batched graph.
b = dgl.batch([g, g2, g3])

In [None]:
# Grid size shared by every gridify call in the batch comparison below.
gridsize = 4

In [None]:
# Gridify the batched graph and each of its members with the same grid size,
# so the batched cells can be compared against the per-graph cells.
_, bb = gridify(b, gridsize)
_, gg1 = gridify(g, gridsize)
_, gg2 = gridify(g2, gridsize)
_, gg3 = gridify(g3, gridsize)


In [None]:
# Grid indices of the cell inspected in the following cells.
X, Y = 0, 1

In [None]:
# Take the same cell [X][Y] from the batched grid and from each per-graph grid.
cell_bb = bb[X][Y]
cell_gg1 = gg1[X][Y]
cell_gg2 = gg2[X][Y]
cell_gg3 = gg3[X][Y]

In [None]:
# Node counts of the selected cell, one graph at a time.
cell_gg1.batch_num_nodes(), cell_gg2.batch_num_nodes(), cell_gg3.batch_num_nodes()

In [None]:
# Node counts of the same cell in the batched grid; compare by eye with the
# per-graph counts shown above.
cell_bb.batch_num_nodes()

In [None]:
# Edge counts of the selected cell, one graph at a time.
cell_gg1.batch_num_edges(), cell_gg2.batch_num_edges(), cell_gg3.batch_num_edges()

In [None]:
# Edge counts of the same cell in the batched grid; compare by eye with the
# per-graph counts shown above.
cell_bb.batch_num_edges()

In [None]:
# Check that summing 'duration' over the batched cell gives the same values as
# concatenating the three per-graph sums.
# NOTE(review): this is an exact == comparison; if 'duration' is floating
# point, torch.allclose may be the safer check — confirm.
torch.all(dgl.sum_nodes(cell_bb, 'duration') == torch.cat([dgl.sum_nodes(cell_gg1, 'duration'), dgl.sum_nodes(cell_gg2, 'duration'), dgl.sum_nodes(cell_gg3, 'duration')]))

In [None]:
# Same consistency check for the edge feature 'weight', summed over edges.
# NOTE(review): exact == comparison; consider torch.allclose if 'weight' is
# floating point — confirm.
torch.all(dgl.sum_edges(cell_bb, 'weight') == torch.cat([dgl.sum_edges(cell_gg1, 'weight'), dgl.sum_edges(cell_gg2, 'weight'), dgl.sum_edges(cell_gg3, 'weight')]))

Testing readouts for batches

In [None]:
# Mean-'feats' readout of the batched grid and of each per-graph grid, all
# using the same aggregator (results moved to the CPU).
bb_ro = grid_readout(bb, 'feats', lambda x, y: dgl.mean_nodes(x, y).cpu())
bb_gg1 = grid_readout(gg1, 'feats', lambda x, y: dgl.mean_nodes(x, y).cpu())
bb_gg2 = grid_readout(gg2, 'feats', lambda x, y: dgl.mean_nodes(x, y).cpu())
bb_gg3 = grid_readout(gg3, 'feats', lambda x, y: dgl.mean_nodes(x, y).cpu())

# Display the four shapes side by side.
bb_ro.shape, bb_gg1.shape, bb_gg2.shape, bb_gg3.shape

In [None]:
# Entry i of the batched readout should match the readout of the i-th graph
# taken on its own (exact element-wise comparison).
(torch.all(bb_ro[0] == bb_gg1),
 torch.all(bb_ro[1] == bb_gg2),
 torch.all(bb_ro[2] == bb_gg3))

In [None]:
# Summed-'duration' readout of the batched grid (result moved to the CPU).
bbd_ro = grid_readout(bb, 'duration', lambda x, y: dgl.sum_nodes(x, y).cpu())

In [None]:
# Compare the shapes of the 'feats' and 'duration' readouts before joining them.
bb_ro.shape, bbd_ro.shape

In [None]:
# Give 'duration' a trailing axis so it can be concatenated with 'feats' along
# the last dimension; display the combined shape. Here duration comes first.
torch.cat([bbd_ro.unsqueeze(-1), bb_ro], dim=-1).shape

In [None]:
from dgl_reflacx_tools.tools import Readout

In [None]:
class Readout:
    """Combine several grid readouts into a single tensor.

    Each ``(feature_name, aggregator)`` pair is turned into a callable that
    runs ``grid_readout(grid, feature_name, aggregator, replace_nan)``.
    Calling the instance on a grid runs every readout in the order given and
    concatenates the results along the last dimension.

    NOTE(review): this definition rebinds the name ``Readout`` that the
    notebook also imports from ``dgl_reflacx_tools.tools`` — confirm which of
    the two is meant to be used.
    """

    def __init__(self, feats_and_aggrs, replace_nan=True):
        """Build one readout callable per ``(feature_name, aggregator)`` pair.

        Args:
            feats_and_aggrs: iterable of ``(feature_name, aggregator)`` pairs.
            replace_nan: forwarded unchanged as the fourth argument of
                ``grid_readout``.
        """
        self.replace_nan = replace_nan
        # The lambda defaults bind the current loop values at definition time;
        # a plain closure would see only the last pair (late binding).
        self.readouts = [
            lambda grid, f=feat_nm, a=aggr, r=replace_nan: grid_readout(grid, f, a, r)
            for feat_nm, aggr in feats_and_aggrs
        ]

    def __call__(self, grid, flatten=True):
        """Run every configured readout on ``grid`` and join the results.

        Args:
            grid: the grid passed through to each readout callable.
            flatten: when true, every dimension after the first is flattened
                into one; otherwise the concatenated tensor is returned as is.

        Returns:
            The readouts concatenated along the last dimension.

        Raises:
            ValueError: if no readouts are configured.
        """
        parts = []
        for readout in self.readouts:
            ro = readout(grid)
            # A readout with fewer than four dimensions gets a trailing axis
            # so it can be joined with the others along dim=-1.
            # NOTE(review): assumes such readouts are scalar-per-cell with a
            # leading batch axis — confirm against grid_readout.
            if ro.dim() < 4:
                ro = ro.unsqueeze(-1)
            parts.append(ro)

        if not parts:
            raise ValueError("Readout has no (feature, aggregator) pairs configured")

        # Concatenate once instead of re-copying the accumulated tensor on
        # every iteration.
        result = torch.cat(parts, dim=-1)
        return result.flatten(1, -1) if flatten else result

In [None]:
# Combined readout: mean of 'feats' first, then sum of 'duration'. This order
# fixes the channel layout of the result (duration ends up last).
refro = Readout([('feats', lambda x, y: dgl.mean_nodes(x, y).cpu()),
                 ('duration', lambda x, y: dgl.sum_nodes(x, y).cpu())
                ])

In [None]:
# Apply the combined readout to the batched grid without flattening.
cons_ro = refro(bb, flatten=False)

In [None]:
# Shape of the unflattened combined readout.
cons_ro.shape

In [None]:
# Shape after flattening every dimension except the first, which is what
# Readout.__call__ returns with flatten=True.
cons_ro.flatten(1, -1).shape

In [None]:
# Expected flattened width, to compare by eye with the shape shown above.
# NOTE(review): presumably gridsize * gridsize * (1024 feature dims + 1
# duration channel) — confirm the feature dimensionality.
4* 4* 1025

In [None]:
# The combined readout should hold 'feats' in all but the last channel and
# 'duration' in the last one, matching the order given to Readout.
torch.all(cons_ro[:, :, :, :-1] == bb_ro), torch.all(cons_ro[:, :, :, -1] == bbd_ro)