In [None]:
import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F


from iou_graph import IOUGraph
from dgl_reflacx_tools.tools import gridify, gridify_indices, gridify_by_indices, grid_readout

from dgl_reflacx_tools.dgl_reflacx_collection import GraphCollection

In [None]:
dataset_pth = 'datasets/reflacx_densnet225_iou'
collection = GraphCollection(dataset_pth, IOUGraph)

#### Getting sample batch

In [None]:
batch_size = 5
grid_size = 4

In [None]:
pairs = [collection.fetch_by_dgl_index(i) for i in range(batch_size)]

In [None]:
graphs = [pair.dgl_graph for pair in pairs]
labels = [pair.dgl_labels for pair in pairs]


In [None]:
batch = dgl.batch(graphs)
labels = torch.cat(labels).reshape((batch_size, len(labels[0])))

batch, labels.shape

#### initialization

In [None]:
batch.ndata['h'] = torch.cat([batch.ndata['norm_x'].unsqueeze(1),
                              batch.ndata['norm_y'].unsqueeze(1),
                              batch.ndata['duration'].unsqueeze(1),
                              batch.ndata['feats']],
                             dim=1)
batch.update_all(fn.copy_e('weight', 'm'), fn.sum('m', 'neigh_weight'))


In [None]:
batch.ndata['h'].shape

In [None]:
input_shape = batch.ndata['h'].shape[1]

In [None]:
batch.ndata.keys()

In [None]:
batch.edata.keys()

#### convolution module on a grid

In [None]:
def pass_messages(g, feat_nm, w_nm, sum_w_nm):
    g.update_all(fn.v_mul_e(feat_nm, w_nm, 'm'), fn.sum('m', feat_nm))
    g.ndata[feat_nm] = torch.divide(batch.ndata[feat_nm], batch.ndata[sum_w_nm].unsqueeze(1))

In [None]:
class GridConv(nn.Module):
    def __init__(self,
                 in_feats,
                 out_feats,
                 grid_indices,
                 pass_messages,
                 activation=F.relu):
        super(GridConv, self).__init__()
        self.grid_lin = [[nn.Linear(in_feats, out_feats) for j in range(len(grid_indices[0]))]
                         for i in range(len(grid_indices))]
        self.grid_indices = grid_indices
        self.pass_messages = pass_messages
        self.activation = activation

    def forward(self, graph, feat_nm, out_feat_nm=None):
        # pass messages (convolution) in whole graph
        self.pass_messages(graph, feat_nm)
        
        # activation on grid cell model
        grid = gridify_by_indices(graph, self.grid_indices)
        new_feats = None
        i_s = None
        for i, line in enumerate(grid):
            for j, sg in enumerate(line):
                conv_feats = self.activation(self.grid_lin[i][j](sg.ndata[feat_nm]))
                
                #concatenate new features to uptadate parent graph
                if new_feats is None:
                    new_feats = conv_feats
                    i_s = self.grid_indices[i][j]
                else:
                    new_feats = torch.cat([new_feats, conv_feats])
                    i_s = torch.cat([i_s, self.grid_indices[i][j]])
        
        #update parent graph with features calculated by grid
        i_s = torch.sort(i_s).indices
        new_feats = new_feats[i_s]
        graph.ndata[feat_nm if out_feat_nm is None else out_feat_nm] = new_feats
        
        return new_feats

In [None]:
f_message = lambda g, feat_nm: pass_messages(g, feat_nm, 'weight', 'neigh_weight')

In [None]:
conv = GridConv(input_shape, input_shape, gridify_indices(batch, grid_size), f_message)

In [None]:
with batch.local_scope():
   h = conv(batch, 'h')

In [None]:
h.shape

In [None]:
from dgl_reflacx_tools.tools import Readout

In [None]:
class ReflacxReadout(Readout):
    def __init__(self):
        feats_and_aggrs = [('duration', lambda x, y: dgl.sum_nodes(x, y).cpu()),
                           ('h', lambda x, y: dgl.mean_nodes(x, y).cpu())]
        super().__init__(feats_and_aggrs)

In [None]:
class Classifier(nn.Module):
    def __init__(self,
                 input_dim,
                 conv_dims,
                 class_dims,
                 readout_dim,
                 n_classes,
                 grid_indices,
                 pass_messages,
                 readout,
                 conv_activation=F.relu,
                 mlp_activation=F.relu): # TODO make possibel to have diff activations for conv and fc
        super(Classifier, self).__init__()
        self.grid_indices = grid_indices
        new_conv = lambda in_dims, out_dims: GridConv(in_dims,
                                                      out_dims,
                                                      grid_indices,
                                                      pass_messages,
                                                      conv_activation)
        self.convs = [new_conv(input_dim,
                               (conv_dims[0]
                                if len(conv_dims) > 0
                                else input_dim))]
        for i, dim in enumerate(conv_dims[1:], start=1):
            self.convs.append(new_conv(conv_dims[i - 1], dim))

        self.fcs = [nn.Linear(readout_dim if len(conv_dims) > 0 else input_dim,
                              class_dims[0] if len(class_dims) > 0 else n_classes)]
        for i, dim in enumerate(class_dims[1:], start=1):
            self.fcs.append(nn.Linear(class_dims[i - 1], dim))
        if len(self.fcs) > 1:
            self.fcs.append(nn.Linear(class_dims[-1], n_classes))

        self.readout = readout
        self.conv_activation = conv_activation
        self.mlp_activation = mlp_activation



    def forward(self, graph, conv_feat_nm):
        h = None
        with graph.local_scope():
            for conv_l in self.convs:
                h = conv_l(graph, conv_feat_nm)
                print('h', h.shape)
            grid = gridify_by_indices(graph, self.grid_indices)
            h = self.readout(grid)
        print('ro', h.shape)
        for fc_l in self.fcs[:-1]:
            h = self.mlp_activation(fc_l(h))
            print('h', h.shape)
        h = self.fcs[-1](h)
        print('preds', h.shape)

        return h

        

In [68]:
from collections import OrderedDict

In [65]:
def create_grid_conv_seq(shapes,
                         activation,
                         grid_indices,
                         pass_messages):
    result = []
    for i in range(len(shapes) - 1):
        in_dim = shapes[i]
        out_dim = shapes[i + 1]
        result.append(('conv{}'.format(i + 1),
                       GridConv(in_dim,
                                out_dim,
                                grid_indices,
                                pass_messages,
                                activation)))
    return OrderedDict(result)


In [69]:
nn.Sequential(create_grid_conv_seq([input_shape, 200, 100], F.relu, gridify_indices(batch, grid_size), pass_messages))

Sequential(
  (conv1): GridConv()
  (conv2): GridConv()
)

In [60]:
class Classifier2(nn.Module):
    def __init__(self,
                 input_dim,
                 conv_dims,
                 class_dims,
                 readout_dim,
                 n_classes,
                 grid_indices,
                 pass_messages,
                 readout,
                 conv_activation=F.relu,
                 mlp_activation=F.relu): # TODO make possibel to have diff activations for conv and fc
        super(Classifier2, self).__init__()
        self.grid_indices = grid_indices
        self.convs = nn.Sequential(create_grid_conv_seq([input_dim] + conv_dims,
                                                        conv_activation,
                                                        grid_indices,
                                                        pass_messages))
        
        fsshapes = [readout_dim] + class_dims
        fclist = []
        for i in range(len(fsshapes) - 1):
            fclist.append(('fc{}'.format(i + 1),
                           nn.Linear(fsshapes[i], fsshapes[i + 1])))
            fclist.append(('fc_activ{}'.format(i + 1),
                           mlp_activation()))
        self.fcs = nn.Sequential(OrderedDict(fclist))
        
        self.readout = readout
        self.conv_activation = conv_activation
        self.mlp_activation = mlp_activation



    def forward(self, graph, conv_feat_nm):
        h = None
        with graph.local_scope():
            h = self.convs(batch, 'h')
            grid = gridify_by_indices(graph, self.grid_indices)
            h = self.readout(grid)
        h = self.fcs(h)
        for fc_l in self.fcs[:-1]:
        print('preds', h.shape)

        return h

        

In [61]:
clf = Classifier(input_shape,
                 [100, 200],
                 [50, 40],
                 3216,
                 6,
                 gridify_indices(batch, grid_size),
                 f_message,
                 ReflacxReadout())

In [62]:
clf2 = Classifier2(input_shape,
                 [100, 200],
                 [50, 40],
                 3216,
                 6,
                 gridify_indices(batch, grid_size),
                 f_message,
                 ReflacxReadout())

In [64]:
list(clf.parameters())

[]

In [None]:
h = clf(batch, 'h')

In [None]:
h.shape

In [None]:
h

### Train

one training step

In [56]:
list(clf.parameters())

[]

In [55]:

opt = torch.optim.Adam(clf.parameters())
h = clf(batch, 'h')
loss = F.cross_entropy(h, labels) # TODO check if this is the correct loss for regression
opt.zero_grad()
loss.backward()
opt.step()

ValueError: optimizer got an empty parameter list