In [None]:
import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F


from gaze_graphs.iou_graph import IOUGraph
from dgl_reflacx_tools.grid_tools import gridify, gridify_indices, gridify_by_indices, grid_readout

from dgl_reflacx_tools.dgl_reflacx_collection import GraphCollection

In [None]:
# Prefer the first CUDA GPU when one is available, otherwise run on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
device

In [None]:
# Root directory of the graph dataset (also reused below by CSVDataset).
dataset_pth = 'datasets/reflacx_densnet225_iou'
# Collection of the stored graphs, interpreted as IOUGraph instances.
collection = GraphCollection(
    dataset_pth,
    IOUGraph,
)

#### Getting a sample batch

In [None]:
# Number of graphs in the sample batch, and number of cells per side of the
# grid (the nodes are split into grid_size x grid_size cells).
batch_size, grid_size = 5, 4

In [None]:
# The first `batch_size` entries of the collection, looked up by DGL index.
pairs = list(map(collection.fetch_by_dgl_index, range(batch_size)))

In [None]:
# Split each fetched pair into its DGL graph and its label tensor.
graphs, labels = [], []
for pair in pairs:
    graphs.append(pair.dgl_graph)
    labels.append(pair.dgl_labels)


In [None]:
batch = dgl.batch(graphs).to(device)
# Build the label matrix directly from `pairs` instead of overwriting the `labels`
# list with a tensor derived from it: the old form (torch.cat(labels)...) failed
# when this cell was re-run, because `labels` was already a tensor by then.
# Assumes each pair's labels are a 1-D tensor of the same length, giving one row
# per graph after stacking (same result as the former cat + reshape) — TODO confirm.
labels = torch.stack([pair.dgl_labels for pair in pairs]).to(device)

batch, labels.shape, batch.device, labels.device

#### Initialization

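Set the node features to be convolved: concatenate the normalized (x, y) position, the duration, and the features extracted from the gaze crop. Also precompute, for each node, the sum of its incoming edge weights (`neigh_weight`), which is used to normalize message passing.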

In [None]:
def init_graph(g):
    """Prepare a (batched) graph for the grid convolution, in place.

    Writes two node fields:
      * ``'h'``: ``norm_x``, ``norm_y`` and ``duration`` (one column each)
        concatenated with the ``feats`` vectors, along dim 1.
      * ``'neigh_weight'``: for every node, the sum of the ``'weight'`` values
        of its incoming edges.
    """
    scalar_columns = [g.ndata[name].unsqueeze(1)
                      for name in ('norm_x', 'norm_y', 'duration')]
    g.ndata['h'] = torch.cat(scalar_columns + [g.ndata['feats']], dim=1)

    # copy each edge's weight to its destination node and add them up there
    g.update_all(fn.copy_e('weight', 'm'), fn.sum('m', 'neigh_weight'))


In [None]:
# Per-feature minimum and maximum over all nodes of the batch (dim=0 reduces over nodes).
# Previously these were swapped (fmin came from torch.max, fmax from torch.min),
# which made finterval non-positive.
fmin = torch.min(batch.ndata['feats'], dim=0).values
fmax = torch.max(batch.ndata['feats'], dim=0).values
# Range for min-max scaling. It is zero for any feature that is constant across
# the batch, so dividing by it yields NaN for that feature.
finterval = fmax - fmin
fmin.shape, fmax.shape, finterval.shape

In [None]:
# Min-max scale one node's features as a shape check (node 123 must exist in the batch).
batch.ndata['feats'][123].sub(fmin).div(finterval).shape

In [None]:
init_graph(batch)  # adds ndata['h'] and ndata['neigh_weight'] to the sample batch
batch.ndata['h'].size()

#### Convolution module on a grid

In [None]:
def pass_messages(g, feat_nm, w_nm, sum_w_nm):
    """Replace each node's features with the edge-weighted mean of its in-neighbours' features.

    Args:
        g: DGL graph; updated in place.
        feat_nm: node field holding the features to aggregate (overwritten).
        w_nm: edge field holding the edge weights.
        sum_w_nm: node field holding, per node, the sum of the weights of its
            incoming edges (as computed in ``init_graph``).
    """
    # u_mul_e: the message on edge u -> v is the *source* node's features scaled by
    # the edge weight. v_mul_e sends the destination's own features instead, so
    # after dividing by the weight sum every node simply got its own features back.
    g.update_all(fn.u_mul_e(feat_nm, w_nm, 'm'), fn.sum('m', feat_nm))

    weight_sum = g.ndata[sum_w_nm]
    # Nodes without incoming weight have a zero weight sum (and DGL's built-in sum
    # leaves their aggregated features at zero); divide those by 1 instead of 0 so
    # they become zeros rather than NaN.
    safe_sum = torch.where(weight_sum == 0, torch.ones_like(weight_sum), weight_sum)
    g.ndata[feat_nm] = torch.divide(g.ndata[feat_nm], safe_sum.unsqueeze(1))

In [None]:
class GridConv(nn.Module):
    """Graph message passing followed by a separate linear layer per grid cell.

    The whole graph first goes through ``pass_messages``. Its nodes are then
    split into the ``grid_size x grid_size`` cells described by ``grid_indices``,
    and each cell's node features go through that cell's own ``nn.Linear``
    followed by ``activation``.

    Args:
        device: device the per-cell linear layers are created on.
        in_feats: size of the input node feature vector.
        out_feats: size of the output node feature vector.
        grid_size: number of cells along each side of the grid.
        pass_messages: callable ``(graph, feat_nm)`` that updates
            ``graph.ndata[feat_nm]`` in place.
        activation: non-linearity applied to every cell's linear output.
    """

    def __init__(self,
                 device,
                 in_feats,
                 out_feats,
                 grid_size,
                 pass_messages,
                 activation=F.relu):
        super(GridConv, self).__init__()
        # nn.ModuleList (not a plain list) so the per-cell layers are registered
        # submodules. They then appear in parameters() (and get trained), in
        # state_dict() (and get saved), and follow .to()/.cuda() calls.
        self.grid_lin = nn.ModuleList([
            nn.ModuleList([nn.Linear(in_feats, out_feats).to(device)
                           for j in range(grid_size)])
            for i in range(grid_size)])
        self.grid_size = grid_size
        self.pass_messages = pass_messages
        self.activation = activation

    def forward(self, graph, feat_nm, grid_indices, out_feat_nm=None):
        """Convolve the graph, then apply each grid cell's linear layer to its nodes.

        Args:
            graph: DGL graph. ``graph.ndata[feat_nm]`` is overwritten by the
                message-passing step even when ``out_feat_nm`` is given.
            feat_nm: node field with the input features.
            grid_indices: ``grid_indices[i][j]`` is a tensor of the node ids in
                cell (i, j). Presumably these cover every node exactly once,
                which the reordering below relies on — TODO confirm.
            out_feat_nm: node field for the result (defaults to ``feat_nm``).

        Returns:
            The new node feature tensor, also stored in ``graph.ndata``.
        """
        # pass messages (convolution) in the whole graph
        self.pass_messages(graph, feat_nm)

        # per-cell linear + activation; collect pieces, concatenate once afterwards
        # (growing tensors with torch.cat inside the loop copies quadratically)
        grid = gridify_by_indices(graph, grid_indices)
        cell_feats = []
        cell_node_ids = []
        for i, line in enumerate(grid):
            for j, sg in enumerate(line):
                cell_feats.append(self.activation(self.grid_lin[i][j](sg.ndata[feat_nm])))
                cell_node_ids.append(grid_indices[i][j])

        new_feats = torch.cat(cell_feats)
        i_s = torch.cat(cell_node_ids)

        # Rows of new_feats follow grid order, and row k belongs to node i_s[k].
        # Indexing with argsort(i_s) puts them back in node-id order.
        new_feats = new_feats[torch.sort(i_s).indices]
        graph.ndata[feat_nm if out_feat_nm is None else out_feat_nm] = new_feats

        return new_feats

In [None]:
def f_message(g, feat_nm):
    """Message-passing step for GridConv.

    Binds ``pass_messages`` to the edge field ``'weight'`` and the per-node
    weight sum ``'neigh_weight'`` written by ``init_graph``.
    """
    return pass_messages(g, feat_nm, 'weight', 'neigh_weight')

In [None]:
conv = GridConv(
    device,
    in_feats=1027,   # must match the width of batch.ndata['h'] built by init_graph
    out_feats=1027,
    grid_size=grid_size,
    pass_messages=f_message,
).to(device)

In [None]:
g_i = gridify_indices(batch, grid_size)

# Two convolutions inside one scope: the second reads the 'h' written by the first.
with batch.local_scope():
    h = conv(batch, 'h', g_i)
    h2 = conv(batch, 'h', g_i)

# A fresh scope starts again from the batch's original 'h'.
with batch.local_scope():
    h3 = conv(batch, 'h', g_i)

In [None]:
# Did the second conv in the same scope change h? Does a fresh scope reproduce h?
(h == h2).all(), (h == h3).all()

In [None]:
h.size()  # same torch.Size as h.shape

Defining the readout: for each grid cell, concatenate the sum of the cell's node durations with the mean of the cell's convolved node features (`h`).

In [None]:
from dgl_reflacx_tools.grid_tools import ReadoutPipeline

In [None]:
class ReflacxReadout(ReadoutPipeline):
    """Readout made of summed node 'duration' and averaged node 'h'.

    The (field, aggregator) pairs are handed to ReadoutPipeline in this order.
    """

    def __init__(self):
        readout_spec = [
            ('duration', dgl.sum_nodes),
            ('h', dgl.mean_nodes),
        ]
        super().__init__(readout_spec)

In [None]:
class ReflacxClassifier(nn.Module):
    """Two GridConv layers, a grid readout, then a three-layer MLP head.

    Layer sizes: conv ``input_dim -> 512 -> 256``, then
    ``readout_dim -> 256 -> 64 -> n_classes``. ``forward`` returns the raw
    output of the last linear layer (no final activation).
    """

    def __init__(self,
                 device,
                 input_dim,
                 readout_dim,
                 n_classes,
                 grid_size,
                 pass_messages,
                 readout,
                 conv_activation=F.relu,
                 mlp_activation=F.relu):
        super().__init__()
        self.grid_size = grid_size

        def make_conv(n_in, n_out):
            # one GridConv layer sharing this model's grid size and message function
            return GridConv(device, n_in, n_out, self.grid_size,
                            pass_messages, conv_activation)

        # graph convolution stack
        self.conv1 = make_conv(input_dim, 512)
        self.conv2 = make_conv(512, 256)

        # MLP head applied to the readout vector
        self.fc1 = nn.Linear(readout_dim, 256).to(device)
        self.fc2 = nn.Linear(256, 64).to(device)
        self.fcf = nn.Linear(64, n_classes).to(device)

        self.readout = readout
        self.conv_activation = conv_activation
        self.mlp_activation = mlp_activation

    def forward(self, graph, conv_feat_nm, grid_indices):
        # Work on a local scope so the caller's graph.ndata is left untouched.
        with graph.local_scope():
            # each GridConv writes its output back into graph.ndata[conv_feat_nm]
            for conv_layer in (self.conv1, self.conv2):
                conv_layer(graph, conv_feat_nm, grid_indices)
            ro = self.readout(gridify_by_indices(graph, grid_indices))

        hidden = ro
        for fc in (self.fc1, self.fc2):
            hidden = self.mlp_activation(fc(hidden))
        return self.fcf(hidden)

        

In [None]:
clf = ReflacxClassifier(
    device,
    input_dim=1027,     # must match the width of ndata['h'] built by init_graph
    readout_dim=4112,   # presumably grid_size**2 * (1 + 256): summed duration + mean 256-d conv2 output per cell
    n_classes=6,
    grid_size=grid_size,
    pass_messages=f_message,
    readout=ReflacxReadout(),
)

In [None]:
# Forward pass on the sample batch; grid cell membership is computed for this batch.
h = clf(batch, 'h', grid_indices=gridify_indices(batch, grid_size))

In [None]:
h.size()  # same torch.Size as h.shape

In [None]:
# Raw classifier output for the sample batch (the last linear layer has no activation).
h

### Train

One training step on the sample batch.

In [None]:
# Adam over the classifier's registered parameters (this optimizer is reused by the loop below).
opt = torch.optim.Adam(clf.parameters())

opt.zero_grad()
h = clf(batch, 'h', gridify_indices(batch, grid_size))
loss = F.mse_loss(h, labels)
loss.backward()
opt.step()

Loading from the DGL dataset.

In [None]:
from dgl.data.utils import split_dataset
from dgl.dataloading import GraphDataLoader

In [None]:
data_split = [
    0.8,  # train
    0.1,  # val
    0.1,  # test
]

In [None]:
# Same directory as the GraphCollection above, loaded as a DGL CSVDataset.
dataset = dgl.data.CSVDataset(data_path=dataset_pth)

Regularize the `duration` and `feats` node features of the dataset.

In [None]:
from regularization.regularization_pipeline import RegularizationPipeline

In [None]:
regppl = RegularizationPipeline(
    dataset,
    device,
    ['feats', 'duration'],  # node fields handed to the pipeline
)

In [None]:
# Fixed seed: with shuffle=True and no random_state, the train/val/test membership
# changed on every kernel restart, so results were not reproducible.
SPLIT_SEED = 0
train, val, test = split_dataset(dataset, data_split, shuffle=True, random_state=SPLIT_SEED)
len(train), len(val), len(test), type(train)

In [None]:
loader = GraphDataLoader(
    train,
    batch_size=10,
    # The split itself was shuffled once by split_dataset; batches are not
    # reshuffled between epochs.
    shuffle=False,
    drop_last=False,
)

In [None]:
for epoch in range(1):
    for batch_graph, batch_labels in loader:
        opt.zero_grad()

        batch_graph = batch_graph.to(device)
        batch_labels = batch_labels.to(device)
        # freshly loaded graphs need ndata['h'] / ndata['neigh_weight'] before the convs
        init_graph(batch_graph)
        cell_indices = gridify_indices(batch_graph, grid_size)

        h = clf(batch_graph, 'h', cell_indices)
        loss = F.mse_loss(h, batch_labels)
        loss.backward()
        opt.step()

In [None]:
# Persist the classifier's state_dict (parameters/buffers of its registered submodules).
torch.save(obj=clf.state_dict(), f='test.pt')