In [3]:
import bisect
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import dgl
from dgl.nn import GraphConv
import h5py

from HighLevelFeatures import HighLevelFeatures as HLF
from XMLHandler import XMLHandler

In [4]:
# Input locations for the photon showers and their binning description.
PARTICLE = 'photon'
BINNING_XML_PATH = '../data/binning_dataset_1_photons.xml'
SHOWER_FILE_PATH = '../data/dataset_1_photons_1.hdf5'

HLF_1_photons = HLF(PARTICLE, filename=BINNING_XML_PATH)
# Kept open (read-only): the "showers" dataset is read from this handle later.
photon_file = h5py.File(SHOWER_FILE_PATH, 'r')
xml = XMLHandler(PARTICLE, filename=BINNING_XML_PATH)

def get_orig_coord(bin_no, xml):
    """Convert a flat voxel index into ``(layer, r_bin, alpha_bin)``.

    ``xml.bin_edges[layer]`` is the flat index at which ``layer`` starts.
    Inside a layer the voxels are ordered radial-bin-major, with
    ``xml.a_bins[layer]`` angular bins per radial bin::

        bin_no = bin_edges[layer] + r_bin * a_bins[layer] + alpha_bin

    The chosen layer is the last one whose start offset is ``<= bin_no``.
    Layers without voxels (repeated offsets) are therefore skipped.

    Args:
        bin_no: Flat voxel index.
        xml: Binning handler exposing ``bin_edges`` (assumed non-decreasing)
            and ``a_bins``.

    Returns:
        Tuple ``(layer, r_bin, alpha_bin)`` of indices.

    Raises:
        ValueError: If ``xml.bin_edges`` is empty, or ``bin_no`` is smaller
            than ``xml.bin_edges[0]``.

    Note:
        Indices at or beyond ``bin_edges[-1]`` are not validated here; they
        resolve to the last edge. Whether that is a real layer depends on
        whether ``bin_edges`` ends with a total-count entry (TODO confirm
        against XMLHandler).
    """
    edges = xml.bin_edges
    if len(edges) == 0:
        raise ValueError("xml.bin_edges is empty; cannot locate bin_no")
    if bin_no < edges[0]:
        # Without this check the layer would become -1 and negative indexing
        # would silently return coordinates from the last layer.
        raise ValueError(f"bin_no={bin_no} is below the first bin edge {edges[0]}")

    # Index of the last edge <= bin_no (edges are cumulative, hence sorted).
    layer = bisect.bisect_right(edges, bin_no) - 1
    offset = bin_no - edges[layer]
    n_alpha = xml.a_bins[layer]
    r_bin = int(offset // n_alpha)
    alpha_bin = offset % n_alpha
    return layer, r_bin, alpha_bin

def get_coord(bin_no, xml):
    """Return an ``(eta, phi, z)`` position for a flat voxel index.

    The voxel's radial mid-value and angle are read from ``xml`` and
    projected onto two in-plane axes. The layer index is used as ``z``.

    NOTE(review): ``eta`` uses ``sin`` and ``phi`` uses ``cos`` of the angle.
    Confirm this matches the convention of the binning handler.

    Args:
        bin_no: Flat voxel index.
        xml: Binning handler exposing ``r_midvalue`` and ``alphaListPerLayer``
            (plus what ``get_orig_coord`` reads).

    Returns:
        Tuple ``(eta, phi, z)``.
    """
    layer, r_bin, alpha_bin = get_orig_coord(bin_no, xml)
    radius = xml.r_midvalue[layer][r_bin]
    angle = xml.alphaListPerLayer[layer][r_bin][alpha_bin]
    return radius * math.sin(angle), radius * math.cos(angle), layer

def generate_graph_from_incident(data, incident_no, xml_handler=None):
    """Build a fully connected DGL graph for one shower.

    Every voxel of ``data[incident_no]`` becomes a node. There is an edge
    from every node to every node, self-loops included (``bin_num ** 2``
    edges).

    Node data:
        ``'x'``: tensor of shape ``(bin_num, 3)`` with the ``(eta, phi, z)``
            position of each voxel, as returned by ``get_coord``.
        ``'e'``: tensor of shape ``(bin_num,)`` with the shower's per-voxel
            values (presumably deposited energies; confirm). Without it every
            incident would yield an identical graph.

    Args:
        data: Indexable collection of showers. ``data[incident_no]`` is a
            1-D sequence of per-voxel values.
        incident_no: Index of the shower to convert.
        xml_handler: Binning handler passed to ``get_coord``. Defaults to the
            module-level ``xml`` handler.

    Returns:
        The constructed ``dgl`` graph.
    """
    if xml_handler is None:
        xml_handler = xml
    shower = data[incident_no]
    bin_num = len(shower)

    # src cycles 0..n-1 n times; dst repeats each node n times, so every
    # ordered pair (src, dst) appears exactly once.
    edges_start = torch.arange(bin_num).repeat(bin_num)
    edges_end = torch.repeat_interleave(torch.arange(bin_num), bin_num)
    g = dgl.graph((edges_start, edges_end))

    # Build all positions at once instead of writing the feature tensor row by row.
    coords = [get_coord(bin_no, xml_handler) for bin_no in range(bin_num)]
    feat_dtype = torch.get_default_dtype()
    g.ndata['x'] = torch.tensor(coords, dtype=feat_dtype).reshape(bin_num, 3)
    g.ndata['e'] = torch.tensor(shower, dtype=feat_dtype)

    return g

# Number of showers to convert into graphs; set to None to process all of them.
N_INCIDENTS = 100

# Slicing the h5py dataset reads only the requested rows instead of the whole
# file. The result is not called `data`, so the `torch.utils.data` alias
# imported at the top is not clobbered.
showers = photon_file["showers"][:N_INCIDENTS]

graph_list = [
    generate_graph_from_incident(showers, incident_no)
    for incident_no in range(len(showers))
]

In [53]:
class GCN(nn.Module):
    """One DGL ``GraphConv`` layer followed by a ReLU.

    Args:
        in_feats: Size of each input node feature vector.
        h_feats: Size of each output node feature vector.
    """

    def __init__(self, in_feats, h_feats):
        super().__init__()
        self.conv = GraphConv(in_feats, h_feats)

    def forward(self, g, in_feat):
        """Convolve ``in_feat`` over graph ``g`` and rectify the result."""
        return F.relu(self.conv(g, in_feat))

class GCNLayer(nn.Module):
    """Mean-aggregation graph convolution on a dense, batched adjacency matrix.

    Node features are first linearly projected. Each node then sums its
    neighbours' projected features, weighted by its row of ``adj_matrix``,
    and divides by that row's sum.

    Args:
        c_in: Input feature size.
        c_out: Output feature size.
    """

    def __init__(self, c_in, c_out):
        super().__init__()
        self.projection = nn.Linear(c_in, c_out)

    def forward(self, node_feats, adj_matrix):
        """Aggregate projected neighbour features.

        Args:
            node_feats: Batched node features, ``(batch, num_nodes, c_in)``
                (``torch.bmm`` requires 3-D inputs).
            adj_matrix: ``(batch, num_nodes, num_nodes)``. Node ``i`` receives
                node ``j``'s projected features weighted by
                ``adj_matrix[b, i, j]``. Self-connections must be present
                explicitly if a node should include its own features.

        Returns:
            Tensor of shape ``(batch, num_nodes, c_out)``. Rows whose adjacency
            sums to zero yield the bare aggregated sum instead of NaN.
        """
        num_neighbours = adj_matrix.sum(dim=-1, keepdim=True)
        # An isolated node would compute 0 / 0 = NaN. For a non-negative
        # adjacency its aggregated sum is 0, so dividing by 1 keeps it at 0.
        num_neighbours = num_neighbours.masked_fill(num_neighbours == 0, 1)
        node_feats = self.projection(node_feats)
        node_feats = torch.bmm(adj_matrix, node_feats)
        return node_feats / num_neighbours


# Sanity check for GCNLayer: with an identity projection, zero bias and an
# all-ones adjacency, every node should receive the mean of all positions.
node_feats = graph_list[0].ndata['x'].unsqueeze(0)  # (1, num_nodes, num_feats)
num_nodes, num_feats = node_feats.shape[1], node_feats.shape[2]
adj_matrix = torch.ones(1, num_nodes, num_nodes)
layer = GCNLayer(c_in=num_feats, c_out=num_feats)

with torch.no_grad():
    layer.projection.weight.copy_(torch.eye(num_feats))
    layer.projection.bias.zero_()
    out_feats = layer(node_feats, adj_matrix)

expected = node_feats.mean(dim=1, keepdim=True).expand_as(out_feats)
assert torch.allclose(out_feats, expected, rtol=1e-4, atol=1e-3), "GCNLayer did not average"

print(node_feats)
print(out_feats)

tensor([[[   0.0000,    2.5000,    0.0000],
         [   0.0000,    7.5000,    0.0000],
         [   0.0000,   20.0000,    0.0000],
         ...,
         [   0.0000,  300.0000,   12.0000],
         [   0.0000,  700.0000,   12.0000],
         [   0.0000, 1500.0000,   12.0000]]])
tensor([[[1.4512e-07, 1.3166e+01, 1.6712e+00],
         [1.4512e-07, 1.3166e+01, 1.6712e+00],
         [1.4512e-07, 1.3166e+01, 1.6712e+00],
         ...,
         [1.4512e-07, 1.3166e+01, 1.6712e+00],
         [1.4512e-07, 1.3166e+01, 1.6712e+00],
         [1.4512e-07, 1.3166e+01, 1.6712e+00]]])
