In [1]:
import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F


from iou_graph import IOUGraph
from dgl_reflacx_tools.tools import gridify, gridify_indices, gridify_by_indices, grid_readout

from dgl_reflacx_tools.dgl_reflacx_collection import GraphCollection

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Prefer the first CUDA device; fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

In [3]:
# Dataset directory; the same path is later handed to dgl.data.CSVDataset.
dataset_pth = 'datasets/reflacx_densnet225_iou'
# Project collection used below to fetch individual samples by DGL index.
collection = GraphCollection(dataset_pth, IOUGraph)

loading metadata
metadata loaded from file
Done loading data from cached files.


#### Getting sample batch

In [4]:
# Number of samples fetched for the hand-built sample batch.
batch_size = 5
# Grid resolution: passed to gridify_indices and used by GridConv to build
# grid_size x grid_size linear layers.
grid_size = 4

In [5]:
# Fetch the first `batch_size` samples of the collection, in index order.
pairs = list(map(collection.fetch_by_dgl_index, range(batch_size)))

34cedb74-d0996b40-6d218312-a9174bea-d48dc033 -- P102R108387 
  Fixation 3 out of chest bounding box

34cedb74-d0996b40-6d218312-a9174bea-d48dc033 -- P102R108387 
  Fixation 4 out of chest bounding box

34cedb74-d0996b40-6d218312-a9174bea-d48dc033 -- P102R108387 
  Fixation 5 out of chest bounding box

34cedb74-d0996b40-6d218312-a9174bea-d48dc033 -- P102R108387 
  Fixation 18 out of chest bounding box

34cedb74-d0996b40-6d218312-a9174bea-d48dc033 -- P102R108387 
  Fixation 20 out of chest bounding box

34cedb74-d0996b40-6d218312-a9174bea-d48dc033 -- P102R108387 
  Fixation 28 out of chest bounding box

34cedb74-d0996b40-6d218312-a9174bea-d48dc033 -- P102R108387 
  Fixation 30 out of chest bounding box

34cedb74-d0996b40-6d218312-a9174bea-d48dc033 -- P102R108387 
  Fixation 31 out of chest bounding box

34cedb74-d0996b40-6d218312-a9174bea-d48dc033 -- P102R108387 
  Fixation 53 out of chest bounding box

34cedb74-d0996b40-6d218312-a9174bea-d48dc033 -- P102R108387 
  Fixation 55 out of che

In [6]:
# Split every fetched sample into its DGL graph and its label tensor,
# keeping both lists in the same order as `pairs`.
graphs, labels = [], []
for pair in pairs:
    graphs.append(pair.dgl_graph)
    labels.append(pair.dgl_labels)


In [7]:
# Merge the sample graphs into one batched graph on the target device.
batch = dgl.batch(graphs).to(device)

# Gotcha: `labels` is rebound from a list of per-graph tensors to a single
# (batch_size, n_labels) tensor, so this cell cannot be re-run without first
# re-running the cell that builds the list.
n_labels = len(labels[0])
labels = torch.cat(labels).reshape((batch_size, n_labels)).to(device)

batch, labels.shape, batch.device, labels.device

(Graph(num_nodes=721, num_edges=12067,
       ndata_schemes={'feats': Scheme(shape=(1024,), dtype=torch.float32), 'duration': Scheme(shape=(), dtype=torch.float32), 'norm_bottom_right': Scheme(shape=(2,), dtype=torch.float32), 'norm_top_left': Scheme(shape=(2,), dtype=torch.float32), 'norm_y': Scheme(shape=(), dtype=torch.float32), 'norm_x': Scheme(shape=(), dtype=torch.float32)}
       edata_schemes={'weight': Scheme(shape=(), dtype=torch.float32)}),
 torch.Size([5, 6]),
 device(type='cuda', index=0),
 device(type='cuda', index=0))

#### initialization

Set the node features to be convolved: the concatenation of the normalized (x, y) position, the duration, and the features extracted from the gaze crop.

In [8]:
def init_graph(g):
    """Prepare a graph for convolution (modifies ``g`` in place).

    Writes two node features:

    * ``'h'``: column-wise concatenation of ``norm_x``, ``norm_y`` and
      ``duration`` (each turned into a single column) followed by ``feats``.
    * ``'neigh_weight'``: for every node, the sum of the ``'weight'`` values
      of its incoming edges.
    """
    # The three per-node scalars become (N, 1) columns so they can be
    # concatenated with the 2-D 'feats' matrix along dim 1.
    scalar_columns = [g.ndata[key].unsqueeze(1)
                      for key in ('norm_x', 'norm_y', 'duration')]
    g.ndata['h'] = torch.cat(scalar_columns + [g.ndata['feats']], dim=1)

    # Copy each edge weight into a message and sum the messages per node.
    copy_weight = fn.copy_e('weight', 'm')
    sum_weights = fn.sum('m', 'neigh_weight')
    g.update_all(copy_weight, sum_weights)


In [9]:
# Adds 'h' and 'neigh_weight' to the batched graph in place.
init_graph(batch)
# Width of 'h' is 3 scalar columns (norm_x, norm_y, duration) + 1024 'feats' columns.
batch.ndata['h'].shape

torch.Size([721, 1027])

#### convolution module on a grid

In [10]:
def pass_messages(g, feat_nm, w_nm, sum_w_nm):
    """Replace each node's feature with a weighted average of its in-neighbours.

    For every node, the features of the source nodes of its incoming edges are
    multiplied by the edge weight, summed, and divided by the node's value of
    ``sum_w_nm``. The result is a weighted mean when ``sum_w_nm`` holds the sum
    of the incoming ``w_nm`` values (this is what ``init_graph`` stores in
    ``'neigh_weight'``).

    Parameters
    ----------
    g : DGL graph; its node feature ``feat_nm`` is overwritten in place.
    feat_nm : str
        Node feature to propagate (read and written).
    w_nm : str
        Edge feature used as the per-edge weight.
    sum_w_nm : str
        Node feature used as the denominator; one value per node, unsqueezed
        here so it broadcasts over the feature axis.
    """
    # u_mul_e reads the feature of the edge's *source* node, so the sum at the
    # destination aggregates neighbour features. The previous v_mul_e read the
    # destination's own feature: sum(h_v * w) / sum(w) == h_v, i.e. the
    # "convolution" handed back its input unchanged.
    g.update_all(fn.u_mul_e(feat_nm, w_nm, 'm'), fn.sum('m', feat_nm))

    denom = g.ndata[sum_w_nm].unsqueeze(1)
    # A node without incoming weight has a zero denominator; divide by one
    # instead so its feature stays finite rather than becoming 0/0 = NaN.
    denom = torch.where(denom == 0, torch.ones_like(denom), denom)
    g.ndata[feat_nm] = torch.divide(g.ndata[feat_nm], denom)

In [11]:
class GridConv(nn.Module):
    """Message passing on the whole graph followed by one linear layer per grid cell.

    Parameters
    ----------
    device : device the per-cell linear layers are moved to.
    in_feats, out_feats : int
        Input and output width of every per-cell linear layer.
    grid_size : int
        The layer holds ``grid_size x grid_size`` independent linear layers,
        indexed as ``self.grid_lin[i][j]``.
    pass_messages : callable
        Called as ``pass_messages(graph, feat_nm)`` at the start of ``forward``.
    activation : callable
        Applied to the output of each per-cell linear layer.
    """

    def __init__(self,
                 device,
                 in_feats,
                 out_feats,
                 grid_size,
                 pass_messages,
                 activation=F.relu):
        super(GridConv, self).__init__()
        # nn.ModuleList instead of a plain list of lists: only registered
        # submodules are visible to .parameters(), .state_dict() and .to().
        # With plain lists the optimizer never saw these weights and they were
        # left out of saved checkpoints.
        self.grid_lin = nn.ModuleList([
            nn.ModuleList([nn.Linear(in_feats, out_feats).to(device)
                           for j in range(grid_size)])
            for i in range(grid_size)
        ])
        self.grid_size = grid_size
        self.pass_messages = pass_messages
        self.activation = activation

    def forward(self, graph, feat_nm, grid_indices, out_feat_nm=None):
        """Convolve ``feat_nm`` and write the result back onto ``graph``.

        Parameters
        ----------
        graph : graph to convolve; modified in place.
        feat_nm : str
            Node feature handed to ``pass_messages`` and read from each cell.
        grid_indices : nested sequence where ``grid_indices[i][j]`` is a
            tensor for cell ``(i, j)``. It is assumed to hold the parent-graph
            ids of that cell's nodes, in the node order of the matching
            subgraph, with every node in exactly one cell -- TODO confirm
            against ``gridify_indices``.
        out_feat_nm : str, optional
            Node feature to store the result in; defaults to ``feat_nm``.

        Returns
        -------
        Tensor with the new features, one row per node, also stored on
        ``graph.ndata``.
        """
        # pass messages (convolution) in whole graph
        self.pass_messages(graph, feat_nm)

        # activation on grid cell model
        grid = gridify_by_indices(graph, grid_indices)
        cell_feats = []
        cell_ids = []
        for i, line in enumerate(grid):
            for j, sg in enumerate(line):
                cell_feats.append(
                    self.activation(self.grid_lin[i][j](sg.ndata[feat_nm])))
                cell_ids.append(grid_indices[i][j])

        # Concatenate once, in cell order, instead of re-allocating per cell.
        new_feats = torch.cat(cell_feats)
        i_s = torch.cat(cell_ids)

        # Rows are currently grouped by cell; sorting the ids yields the
        # permutation that puts them back in ascending id order before they
        # are written to the parent graph.
        i_s = torch.sort(i_s).indices
        new_feats = new_feats[i_s]
        graph.ndata[feat_nm if out_feat_nm is None else out_feat_nm] = new_feats

        return new_feats

In [12]:
def f_message(g, feat_nm):
    """Run ``pass_messages`` with the edge feature 'weight' and the node feature 'neigh_weight'."""
    return pass_messages(g, feat_nm, 'weight', 'neigh_weight')

In [13]:
# Input and output widths are both 1027, so the layer can be applied to its
# own output in the next cell.
conv = GridConv(device, 1027, 1027, grid_size, f_message).to(device)

In [14]:
g_i = gridify_indices(batch, grid_size)

# Two calls inside one local scope: the first call overwrites 'h', so the
# second call convolves the already-convolved feature.
with batch.local_scope():
    h = conv(batch, 'h', g_i)
    h2 = conv(batch, 'h', g_i)

# A new scope: this call starts from the feature stored outside the scopes.
with batch.local_scope():
    h3 = conv(batch, 'h', g_i)

In [15]:
# h2 is the second call in the same scope as h; h3 is the first call in a fresh scope.
torch.all(h == h2), torch.all(h == h3)

(tensor(False, device='cuda:0'), tensor(True, device='cuda:0'))

In [16]:
# One row per node of the batched graph, one column per output feature of `conv`.
h.shape

torch.Size([721, 1027])

Define the readout as, for each grid cell, the concatenation of the summed duration of the cell's nodes with the mean of their convolved features.

In [17]:
from dgl_reflacx_tools.tools import Readout

In [18]:
class ReflacxReadout(Readout):
    """Readout configured with a sum over 'duration' and a mean over 'h'."""

    def __init__(self):
        def total_duration(x, y):
            # Forward both arguments unchanged to DGL's node-sum readout.
            return dgl.sum_nodes(x, y)

        def mean_features(x, y):
            # Forward both arguments unchanged to DGL's node-mean readout.
            return dgl.mean_nodes(x, y)

        # (feature name, aggregation) pairs, in the order given to the base class.
        super().__init__([('duration', total_duration),
                          ('h', mean_features)])

In [19]:
class ReflacxClassifier(nn.Module):
    """Two grid convolutions, a readout over the grid, and a three-layer MLP.

    Parameters
    ----------
    device : device every layer is moved to.
    input_dim : int
        Width of the node feature fed to the first convolution.
    readout_dim : int
        Width of the vector returned by ``readout`` (input of ``fc1``).
    n_classes : int
        Width of the final linear layer's output.
    grid_size : int
        Grid resolution handed to both ``GridConv`` layers.
    pass_messages : callable
        Message-passing function handed to both ``GridConv`` layers.
    readout : callable
        Applied to ``gridify_by_indices(graph, grid_indices)`` after the
        convolutions.
    conv_activation, mlp_activation : callable
        Activations of the convolutions and of ``fc1``/``fc2``; they can be
        chosen independently.
    """

    def __init__(self,
                 device,
                 input_dim,
                 readout_dim,
                 n_classes,
                 grid_size,
                 pass_messages,
                 readout,
                 conv_activation=F.relu,
                 mlp_activation=F.relu):
        super().__init__()
        self.grid_size = grid_size

        def make_conv(n_in, n_out):
            # Both convolutions share every setting except their widths.
            return GridConv(device,
                            n_in,
                            n_out,
                            self.grid_size,
                            pass_messages,
                            conv_activation)

        # Convolution stack: input_dim -> 512 -> 256.
        self.conv1 = make_conv(input_dim, 512)
        self.conv2 = make_conv(512, 256)

        # MLP head: readout_dim -> 256 -> 64 -> n_classes.
        self.fc1 = nn.Linear(readout_dim, 256).to(device)
        self.fc2 = nn.Linear(256, 64).to(device)
        self.fcf = nn.Linear(64, n_classes).to(device)

        self.readout = readout
        self.conv_activation = conv_activation
        self.mlp_activation = mlp_activation

    def forward(self, graph, conv_feat_nm, grid_indices):
        """Return the output of the final linear layer for ``graph``.

        Both convolutions read and overwrite ``conv_feat_nm``; this happens
        inside ``graph.local_scope()``, as does the readout.
        """
        with graph.local_scope():
            for conv_layer in (self.conv1, self.conv2):
                conv_layer(graph, conv_feat_nm, grid_indices)
            pooled = self.readout(gridify_by_indices(graph, grid_indices))

        hidden = pooled
        for fc_layer in (self.fc1, self.fc2):
            hidden = self.mlp_activation(fc_layer(hidden))
        # No activation after the last layer.
        return self.fcf(hidden)

        

In [20]:
# input_dim matches the width of 'h' built by init_graph (3 + 1024).
# readout_dim must equal the width of the ReflacxReadout output
# (presumably 4 * 4 cells x (1 duration + 256 features) = 4112 -- TODO confirm).
clf = ReflacxClassifier(
    device=device,
    input_dim=1027,
    readout_dim=4112,
    n_classes=6,
    grid_size=grid_size,
    pass_messages=f_message,
    readout=ReflacxReadout(),
)

In [21]:
# Forward pass of the untrained classifier on the sample batch.
h = clf(graph=batch,
        conv_feat_nm='h',
        grid_indices=gridify_indices(batch, grid_size))

In [22]:
# Expected: (number of graphs in the batch, n_classes).
h.shape

torch.Size([5, 6])

In [23]:
# Raw outputs of the final linear layer (the classifier applies no activation after it).
h

tensor([[-0.1350,  0.0447, -0.1107, -0.0893, -0.0383,  0.1142],
        [-0.1392,  0.0296, -0.1065, -0.0864, -0.0450,  0.1048],
        [-0.1418, -0.0012, -0.1608, -0.0764, -0.0730,  0.0861],
        [-0.1807, -0.0118, -0.1457, -0.1331, -0.0824,  0.0853],
        [-0.1235,  0.0427, -0.1037, -0.0954, -0.0436,  0.1128]],
       device='cuda:0', grad_fn=<AddmmBackward0>)

### Train

One training step on the sample batch.

In [24]:
# Adam with default hyper-parameters over everything clf.parameters() yields.
opt = torch.optim.Adam(clf.parameters())

# One optimisation step: clear gradients, forward, MSE against the labels,
# backward, update.
opt.zero_grad()
h = clf(batch, 'h', gridify_indices(batch, grid_size))
loss = F.mse_loss(h, labels)
loss.backward()
opt.step()

Loading the data as a DGL dataset.

In [25]:
from dgl.data.utils import split_dataset
from dgl.dataloading import GraphDataLoader
from dgl.data import DGLDataset

In [26]:
# Fractions of the dataset given to split_dataset, in the order below.
data_split = [0.8, 0.1, 0.1] # train, val, test

In [27]:
# Load the same directory used for `collection`, this time as a DGL CSVDataset.
dataset = dgl.data.CSVDataset(dataset_pth)

Done loading data from cached files.


In [28]:
# Inspect one item: a (graph, label tensor) pair.
dataset[1]

(Graph(num_nodes=88, num_edges=898,
       ndata_schemes={'feats': Scheme(shape=(1024,), dtype=torch.float32), 'duration': Scheme(shape=(), dtype=torch.float32), 'norm_bottom_right': Scheme(shape=(2,), dtype=torch.float32), 'norm_top_left': Scheme(shape=(2,), dtype=torch.float32), 'norm_y': Scheme(shape=(), dtype=torch.float32), 'norm_x': Scheme(shape=(), dtype=torch.float32)}
       edata_schemes={'weight': Scheme(shape=(), dtype=torch.float32)}),
 tensor([0., 0., 1., 0., 0., 0.]))

In [33]:
# Shuffle before splitting, with a fixed seed so that the train/val/test
# membership is identical on every run (without random_state the shuffle is
# different each time and results cannot be reproduced).
train, val, test = split_dataset(dataset, data_split, shuffle=True, random_state=0)
len(train), len(val), len(test), type(train)

(2440, 305, 306, dgl.data.utils.Subset)

In [91]:
# Mini-batches of up to 100 training graphs, in fixed order (shuffle=False);
# the final, smaller batch is kept (drop_last=False).
loader = GraphDataLoader(train, batch_size=100, shuffle=False, drop_last=False)

In [34]:
# One epoch over the training split, one optimiser step per mini-batch.
#
# Iterates `loader` (the GraphDataLoader built for `train`) instead of the raw
# subset: the subset yields single graphs whose label has shape (n_labels,)
# while the classifier returns (1, n_labels), so the loss was computed on
# mismatched shapes and every step saw a single sample.
#
# NOTE(review): the last recorded run of this cell stopped with
# KeyError: 'norm_x'. Nothing visible here explains it; if it recurs, check
# whether some graph in the dataset carries no node data.
for epoch in range(1):
    for b, l in loader:
        b = b.to(device)
        l = l.to(device)

        # Build 'h' and 'neigh_weight' on the freshly loaded batch.
        init_graph(b)
        graph_indices = gridify_indices(b, grid_size)

        h = clf(b, 'h', graph_indices)
        # Align the label layout with the prediction, (graphs, n_labels).
        loss = F.mse_loss(h, l.reshape(h.shape))

        opt.zero_grad()
        loss.backward()
        opt.step()

KeyError: 'norm_x'

In [100]:
# Save the classifier's state_dict (parameters only, not the module object)
# to 'test.pt' in the current working directory.
torch.save(clf.state_dict(), 'test.pt')