In [144]:
import bisect
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.nn import GraphConv
import h5py

from HighLevelFeatures import HighLevelFeatures as HLF
from XMLHandler import XMLHandler

In [145]:
# Paths to the dataset-1 photon files. The binning XML is read by both the
# high-level-feature helper and the XML handler, so its path is defined once.
DATA_DIR = '../data'
BINNING_XML = f'{DATA_DIR}/binning_dataset_1_photons.xml'
SHOWERS_HDF5 = f'{DATA_DIR}/dataset_1_photons_1.hdf5'

HLF_1_photons = HLF('photon', filename=BINNING_XML)
# Opened read-only and intentionally left open: later cells read
# photon_file["showers"].
photon_file = h5py.File(SHOWERS_HDF5, 'r')
xml = XMLHandler('photon', filename=BINNING_XML)

In [146]:
def get_orig_coord(bin_no, xml):
    """Map a flat bin index to its (layer, radial bin, angular bin) indices.

    ``xml.bin_edges`` is treated as the flat index at which each layer starts:
    layer ``i`` owns indices ``bin_edges[i] <= bin_no < bin_edges[i + 1]``.
    Within a layer, bins are laid out radial-major, with ``xml.a_bins[layer]``
    angular bins per radial bin.

    Args:
        bin_no: flat bin index, expected in ``[bin_edges[0], bin_edges[-1])``.
        xml: binning handler exposing ``bin_edges`` (assumed non-decreasing)
            and ``a_bins``.

    Returns:
        ``(layer, r_bin, alpha_bin)`` integer indices.

    Raises:
        IndexError: if ``bin_no`` lies outside the range covered by
            ``xml.bin_edges`` (including when ``bin_edges`` is empty).
    """
    edges = xml.bin_edges
    if not edges or not edges[0] <= bin_no < edges[-1]:
        raise IndexError(f"bin {bin_no} is outside the binning range")
    # Last layer whose start edge is <= bin_no. bisect_right steps past
    # zero-width layers that share the same edge value, matching the original
    # "first edge strictly greater than bin_no, minus one" scan.
    layer = bisect.bisect_right(edges, bin_no) - 1
    local_bin = bin_no - edges[layer]
    r_bin, alpha_bin = divmod(local_bin, xml.a_bins[layer])
    return layer, r_bin, alpha_bin

def get_coord(bin_no, xml):
    """Return the ``(eta, phi, z)`` position used for a flat bin index.

    The radius is looked up in ``xml.r_midvalue`` (presumably the radial bin
    midpoint) and the angle in ``xml.alphaListPerLayer``. The position is
    ``eta = r * sin(alpha)``, ``phi = r * cos(alpha)``, and ``z`` is the
    layer index.
    """
    layer_idx, radial_idx, angular_idx = get_orig_coord(bin_no, xml)
    radius = xml.r_midvalue[layer_idx][radial_idx]
    angle = xml.alphaListPerLayer[layer_idx][radial_idx][angular_idx]
    return radius * math.sin(angle), radius * math.cos(angle), layer_idx

def _bin_features(incident_data, xml_handler):
    """Return ``(coords, energy)`` float32 tensors for one shower's bins.

    ``coords`` has shape ``(num_bins, 3)`` and holds ``get_coord`` output per
    bin. ``energy`` has shape ``(num_bins, 1)`` and holds
    ``incident_data[bin_no]``.
    """
    bin_num = len(incident_data)
    coords = torch.tensor(
        [get_coord(bin_no, xml_handler) for bin_no in range(bin_num)],
        dtype=torch.float32,
    ).reshape(bin_num, 3)
    energy = torch.as_tensor(incident_data, dtype=torch.float32).reshape(bin_num, 1)
    return coords, energy


def generate_graph_from_incident(incident_data, xml_handler=None):
    """Build a fully connected DGL graph over all bins of one shower.

    Every ordered pair of bins, including each bin with itself, becomes an
    edge, so the graph has ``len(incident_data) ** 2`` edges. Node features
    ``ndata['x']`` are float32 rows ``[eta, phi, z, energy]``.

    Every node has the same neighbourhood (all nodes). A GCN layer therefore
    produces the same output row for every node on this graph, as the
    notebook's output shows.

    Args:
        incident_data: sequence of per-bin energies for one shower.
        xml_handler: binning handler passed to ``get_coord``. Defaults to the
            notebook-level ``xml`` for backward compatibility.
    """
    if xml_handler is None:
        xml_handler = xml  # fall back to the notebook-level binning handler
    bin_num = len(incident_data)
    nodes = torch.arange(bin_num)
    edges_start = nodes.repeat(bin_num)               # 0..n-1, 0..n-1, ...
    edges_end = torch.repeat_interleave(nodes, bin_num)  # 0,0,..,1,1,..
    g = dgl.graph((edges_start, edges_end), num_nodes=bin_num)

    coords, energy = _bin_features(incident_data, xml_handler)
    g.ndata['x'] = torch.cat((coords, energy), 1)
    return g


def generate_knn_graph_from_incident(incident_data, k, xml_handler=None):
    """Build a k-nearest-neighbour DGL graph over the bins of one shower.

    Neighbours are found by Euclidean distance in ``(eta, phi, z)`` space via
    ``dgl.knn_graph``. Node features ``ndata['x']`` are float32 rows
    ``[eta, phi, z, energy]``.

    Args:
        incident_data: sequence of per-bin energies for one shower.
        k: number of neighbours per node.
        xml_handler: binning handler passed to ``get_coord``. Defaults to the
            notebook-level ``xml`` for backward compatibility.
    """
    if xml_handler is None:
        xml_handler = xml  # fall back to the notebook-level binning handler
    coords, energy = _bin_features(incident_data, xml_handler)
    knn_g = dgl.knn_graph(coords, k, algorithm='bruteforce-blas', dist='euclidean')
    knn_g.ndata['x'] = torch.cat((coords, energy), 1)
    return knn_g

In [152]:
# Number of showers to convert into graphs; set to None to convert all of them.
N_INCIDENTS = 1
# Neighbours per node when building the k-NN graphs.
K_NEIGHBOURS = 5

# Reads every shower into memory, even though only N_INCIDENTS are converted.
data = photon_file["showers"][:]

n_incidents = len(data) if N_INCIDENTS is None else min(N_INCIDENTS, len(data))

graph_list = []
knn_graph_list = []

for incident_no in range(n_incidents):
    g = generate_graph_from_incident(data[incident_no])
    graph_list.append(g)
    knn_g = generate_knn_graph_from_incident(data[incident_no], K_NEIGHBOURS)
    knn_graph_list.append(knn_g)

In [148]:
# class GCNLayer(nn.Module):
#
#     def __init__(self, c_in, c_out):
#         super().__init__()
#         self.projection = nn.Linear(c_in, c_out)
#
#     def forward(self, node_feats, adj_matrix):
#         num_neighbours = adj_matrix.sum(dim=-1, keepdims=True)
#         node_feats = self.projection(node_feats)
#         node_feats = torch.bmm(adj_matrix, node_feats)
#         node_feats = node_feats / num_neighbours
#         return node_feats
#
# test_graph = graph_list[0]
# node_feats = test_graph.ndata['x'].view(1, 368, 4)
# adj_matrix = torch.ones(1, 368, 368)
# layer = GCNLayer(c_in=4, c_out=4)
# layer.projection.weight.data = torch.Tensor([[1., 0., 0., 0.], [0., 1., 0., 0.], [0., 0., 1., 0.], [0., 0., 0., 1.]])
# layer.projection.bias.data = torch.Tensor([0., 0., 0., 0.])
#
# with torch.no_grad():
#     out_feats = layer(node_feats, adj_matrix)
#
# print("input: ", node_feats)
# print("output: ", out_feats)

In [149]:
class GCN(nn.Module):
    """One-layer graph convolutional network: a GraphConv followed by ReLU.

    Args:
        in_feats: width of each input node feature vector.
        h_feats: width of each output node feature vector.
    """

    def __init__(self, in_feats, h_feats):
        super().__init__()
        self.conv = GraphConv(in_feats, h_feats)

    def forward(self, g, in_feat):
        """Convolve ``in_feat`` over graph ``g`` and apply ReLU.

        Args:
            g: DGL graph whose nodes correspond to the rows of ``in_feat``.
            in_feat: node feature tensor, presumably ``(num_nodes, in_feats)``.

        Returns:
            Non-negative node features of width ``h_feats``.
        """
        return F.relu(self.conv(g, in_feat))

In [153]:
# Pass a graph through GCN
test_graph = graph_list[0]
node_feats = test_graph.ndata['x']
model = GCN(in_feats=4, h_feats=4)

with torch.no_grad():
    out_feats = model(test_graph, node_feats)

print("input: ", node_feats)
print("output: ", out_feats)

input:  tensor([[   0.0000,    2.5000,    0.0000,  300.0140],
        [   0.0000,    7.5000,    0.0000,   44.5615],
        [   0.0000,   20.0000,    0.0000,    0.0000],
        ...,
        [   0.0000,  300.0000,   12.0000,    0.0000],
        [   0.0000,  700.0000,   12.0000,    0.0000],
        [   0.0000, 1500.0000,   12.0000,    0.0000]])
output:  tensor([[ 4.9052, 15.1358,  0.0000, 26.0411],
        [ 4.9052, 15.1358,  0.0000, 26.0411],
        [ 4.9052, 15.1358,  0.0000, 26.0411],
        ...,
        [ 4.9052, 15.1358,  0.0000, 26.0411],
        [ 4.9052, 15.1358,  0.0000, 26.0411],
        [ 4.9052, 15.1358,  0.0000, 26.0411]])


In [157]:
# Pass a k-nn graph through GCN
test_graph = knn_graph_list[0]
node_feats = test_graph.ndata['x']
model = GCN(in_feats=4, h_feats=4)

with torch.no_grad():
    out_feats = model(test_graph, node_feats)

print("input: ", node_feats)
print("output: ", out_feats)

input:  tensor([[   0.0000,    2.5000,    0.0000,  300.0140],
        [   0.0000,    7.5000,    0.0000,   44.5615],
        [   0.0000,   20.0000,    0.0000,    0.0000],
        ...,
        [   0.0000,  300.0000,   12.0000,    0.0000],
        [   0.0000,  700.0000,   12.0000,    0.0000],
        [   0.0000, 1500.0000,   12.0000,    0.0000]])
output:  tensor([[   0.0000,    0.0000,   70.6326,   11.7041],
        [   0.0000,    0.0000,   37.3929,    3.4382],
        [   0.0000,   17.0260,   13.2687,    0.0000],
        ...,
        [  31.1571,  216.6079,   47.7082,    0.0000],
        [  88.1828,  500.0758,   95.4867,    0.0000],
        [ 170.2868, 1004.7780,  175.4178,    0.0000]])
