In [None]:
import dgl
import dgl.function as fn
import torch
import torch.nn as nn


from iou_graph import IOUGraph
from dgl_reflacx_tools.tools import gridify, grid_readout

from dgl_reflacx_tools.dgl_reflacx_collection import GraphCollection

In [None]:
# Dataset path handed to GraphCollection (a relative path; presumably resolved
# against the notebook's working directory -- confirm).
dataset_pth = 'datasets/reflacx_densnet225_iou'
# Open the dataset as a collection whose items are built with IOUGraph.
collection = GraphCollection(dataset_pth, IOUGraph)

#### Getting a sample batch

In [None]:
# Number of graphs fetched into the sample batch below.
batch_size = 5
# Grid resolution passed to gridify(); also the row/column count of grid_lin.
grid_size = 4

In [None]:
# Fetch the items with DGL indices 0 .. batch_size - 1 from the collection.
pairs = [collection.fetch_by_dgl_index(i) for i in range(batch_size)]

In [None]:
# Split every fetched pair into its DGL graph and its labels, keeping both
# lists in the same order as `pairs`.
graphs, labels = [], []
for pair in pairs:
    graphs.append(pair.dgl_graph)
    labels.append(pair.dgl_labels)


In [None]:
# Merge the individual graphs into a single batched DGL graph.
batch = dgl.batch(graphs)
# Build the (num_graphs, label_width) label matrix straight from `pairs`
# instead of from the `labels` list: `labels` is rebound to a tensor here, so
# reading the old list would make this cell fail when it is run a second time.
labels = torch.cat([pair.dgl_labels for pair in pairs]).reshape(len(pairs), -1)

batch, labels.shape

#### Initialization

In [None]:
# Start 'conv_feats' as an independent copy of the raw node features.
batch.ndata['conv_feats'] = batch.ndata['feats'].clone()
# Give every edge a 'duration' equal to its source node's 'duration'.
batch.apply_edges(fn.copy_u('duration', 'duration'))
# edge_factor[v] = sum of 'weight' over the edges arriving at node v.
batch.update_all(fn.copy_e('weight', 'm'), fn.sum('m', 'edge_factor'))
# neighbors_factor[v] = sum of 'duration' over the source nodes of the edges
# arriving at node v.
batch.update_all(fn.copy_u('duration', 'm'), fn.sum('m', 'neighbors_factor'))


In [None]:
# List the node feature names currently stored on the batched graph.
batch.ndata.keys()

In [None]:
# List the edge feature names currently stored on the batched graph.
batch.edata.keys()

In [None]:
# Partition the batched graph with gridify(). The cells below iterate the
# result as rows of graphs (presumably grid_size x grid_size sub-graphs of
# `batch` -- confirm against dgl_reflacx_tools).
grid = gridify(batch, grid_size)

### Convolution layer

In [None]:
# Layer applied to the output of conv(): 1025 inputs = 1 'duration' column plus
# the feature columns (presumably 1024 wide, matching the (1, 1024) splits
# applied to the layer output further down).
lin = nn.Linear(1025, 1025)

In [None]:
def conv(g):
    """Run one weighted aggregation step on ``g`` and build the linear-layer input.

    Side effect: ``g.ndata['conv_feats']`` is overwritten with the aggregated
    features scaled by ``g.ndata['edge_factor']``. Nodes without incoming
    edges receive no messages, so their aggregated features are zero.

    Args:
        g: DGL graph carrying node data ``'feats'``, ``'duration'`` and
            ``'edge_factor'`` and edge data ``'weight'``.

    Returns:
        Tensor with one row per node: the node's ``'duration'`` as the first
        column, followed by its freshly computed ``'conv_feats'``.
    """
    # Every edge carries feats[dst] * weight; messages are summed per
    # destination node.
    # NOTE(review): v_mul_e reads the *destination* node's features, so no
    # neighbour features are mixed in; u_mul_e may have been intended -- confirm.
    g.update_all(fn.v_mul_e('feats', 'weight', 'm'), fn.sum('m', 'conv_feats'))
    g.ndata['conv_feats'] = torch.multiply(g.ndata['conv_feats'],
                                           g.ndata['edge_factor'].unsqueeze(1))
    # Hand the aggregated features (not the raw 'feats') to the next layer;
    # before, the aggregation result was computed but never returned.
    return torch.cat([g.ndata['duration'].unsqueeze(1), g.ndata['conv_feats']], dim=1)

In [None]:
# Build the linear-layer input for every node of the full batch
# (conv() also rewrites batch.ndata['conv_feats'] as a side effect).
x = conv(batch)

In [None]:
# Apply the linear layer to every node row.
y = lin(x)

In [None]:
# Sanity check: the last dimension should be 1025 (out_features of `lin`).
y.shape

In [None]:
# Split the output columns into the leading single column and the remaining
# 1024 columns, then show both shapes.
a, b = y.split_with_sizes((1, 1024), dim=1)
a.shape, b.shape

In [None]:
def post_conv(h, g):
    """Store the feature part of a layer output back on the graph.

    ``h`` is split along dim 1 into a 1-column part and a 1024-column part.
    The 1-column part is discarded and the 1024-column part is written to
    ``g.ndata['conv_feats']``. Nothing is returned.
    """
    pieces = h.split_with_sizes((1, 1024), dim=1)
    # pieces[0] is the leading column, pieces[1] the remaining 1024 columns.
    g.ndata['conv_feats'] = pieces[1]

In [None]:
# Snapshot conv_feats, let post_conv() overwrite it from `y`, then display the
# element-wise comparison to see which entries (if any) stayed the same.
old = batch.ndata['conv_feats'].clone()
post_conv(y, batch)
old == batch.ndata['conv_feats']

In [None]:
# Weight matrix shape of the layer, laid out as (out_features, in_features).
lin.weight.shape

#### Now the same convolution on each grid cell

In [None]:
# Every grid cell shares the single `lin` layer (and therefore its weights).
# For independent weights per cell, build a fresh nn.Linear(1025, 1025) for
# each (i, j) instead.
grid_lin = [[lin] * grid_size for _ in range(grid_size)]

In [None]:
# Apply each cell's linear layer to the conv() output of that cell's graph,
# keeping the results in the same row/column layout as `grid`.
grid_y = [
    [grid_lin[row_idx][col_idx](conv(cell)) for col_idx, cell in enumerate(row)]
    for row_idx, row in enumerate(grid)
]

In [None]:
# Output shape for the grid cell in row 0, column 3.
grid_y[0][3].shape

In [None]:
# Run conv() + the cell's linear layer on every grid cell and collect, per
# cell, the 1024 feature columns together with the cell's node IDs ('_ID').
# Chunks are gathered in lists and concatenated once, rather than growing a
# tensor with torch.cat on every iteration.
feat_chunks = []
id_chunks = []
for i, line in enumerate(grid):
    for j, sg in enumerate(line):
        cell_out = grid_lin[i][j](conv(sg))
        # Drop the leading single column; keep the 1024 feature columns.
        feat_chunks.append(cell_out.split_with_sizes((1, 1024), dim=1)[1])
        id_chunks.append(sg.ndata['_ID'])

new_feats = torch.cat(feat_chunks)

# TODO activation

# Sorting the collected node IDs yields the permutation that puts the
# concatenated rows back into ascending node-ID order before they are
# written onto the full batch.
i_s = torch.sort(torch.cat(id_chunks)).indices
batch.ndata['conv_feats'] = new_feats[i_s]
