
## All necessary imports

In [1]:
# !cd tools/ && python setup_opera_distance_metric.py build_ext --inplace

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm_notebook as tqdm
import networkx as nx

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import Sequential
from torch.distributions import Bernoulli


from tools.opera_distance_metric import generate_k_nearest_graph, \
                                        opera_distance_metric_py, \
                                        generate_radius_graph

from graph_rnn import bfs_seq, encode_adj, decode_adj

from torch.nn.utils.rnn import pack_padded_sequence, pack_sequence, pad_sequence, pad_packed_sequence
import random

sns.set(font_scale=2)

In [3]:
device = torch.device('cuda:0')

In [4]:
df = pd.read_pickle('./data/showers.pkl')

In [5]:
def bfs_handmade(G, start):
    visited, queue = set(), [start]
    while queue:
        vertex = queue.pop(0)
        if vertex not in visited:
            visited.add(vertex)
            edges = sorted(G.out_edges(vertex, data=True), key=lambda x: x[2]['weight'])
            queue.extend(set([x[1] for x in edges]) - visited)
    return np.array(list(visited))[np.argsort(np.array(list(G.nodes())))]


def encode_adj(adj, max_prev_node=10, is_full = False):
    '''
    :param adj: n*n, rows means time step, while columns are input dimension
    :param max_degree: we want to keep row number, but truncate column numbers
    :return:
    '''
    if is_full:
        max_prev_node = adj.shape[0] - 1
    
    # successors only
    adj = adj
    
    # pick up lower tri
    adj = np.tril(adj, k=-1)
    n = adj.shape[0]
    adj = adj[1:n, 0:n-1]

    # use max_prev_node to truncate
    # note: now adj is a (n-1) * (n-1) matrix
    adj_output = np.zeros((adj.shape[0], max_prev_node))
    for i in range(adj.shape[0]):
        input_start = max(0, i - max_prev_node + 1)
        input_end = i + 1
        output_start = max_prev_node + input_start - input_end
        output_end = max_prev_node
        adj_output[i, output_start:output_end] = adj[i, input_start:input_end]
        adj_output[i,:] = adj_output[i,:][::-1] # reverse order

    return adj_output

## Model parameters

In [6]:
max_prev_node = 10
graph_state_size = 64
embedding_size = 256
edge_rnn_embedding_size = 64

In [7]:
batch_size = 50

In [8]:
from collections import namedtuple

In [9]:
graphrnn_shower = namedtuple('graphrnn_shower', field_names=['x', 
                                                             'adj', 
                                                             'adj_out', 
                                                             'adj_squared', 
                                                             'ele_p',
                                                             'distances'])

In [22]:
max_prev_node=50
def preprocess_shower_for_graphrnn(shower, device, k=4, symmetric=False):
    X = np.vstack([
        np.arange(len(shower.SX)),
        shower.SX,
        shower.SY,  
        shower.SZ, 
        shower.TX,
        shower.TY,
        shower.ele_P]
    ).T
    print(len(X))
    edges_from, edges_to, distances = generate_k_nearest_graph(X, k=k, symmetric=symmetric)
    G = nx.Graph()
    edges = []
    for i in range(len(distances)):
        edges.append((edges_from[i], edges_to[i], {'weight': distances[i]}))
        
    G.add_edges_from(edges)
    G = nx.DiGraph(G)

    adj = np.asarray(nx.to_numpy_matrix(G))

    start_idx = 0
    x_idx = np.array(bfs_handmade(G, start_idx))
    adj = adj[np.ix_(x_idx, x_idx)]
    adj[adj!=0] = 1
    # actual data
    adj_output = encode_adj(adj, max_prev_node=max_prev_node)
    adj_recover = decode_adj(adj_output)
    print('error\n', np.sum(np.abs(adj_recover-adj)))
    
    X = X[x_idx, 1:]
    X = X / np.array([1e3, 1e3, 1e3, 1, 1, 1])
    distances = np.log(1. + np.array(distances))
    
    # for now forget about distances
    # TODO: what to do with distances?
    adj_output[adj_output!=0] = 1.
    
    adj_output_t = torch.tensor(np.append(np.ones((1, max_prev_node)), 
                                          adj_output, axis=0), 
                                dtype=torch.float32).to(device).view(1, -1, max_prev_node)
    
    X_t = torch.tensor(X[:, :-1], dtype=torch.float32).to(device).view(1, -1, 5)

    adj_out_t = torch.LongTensor(np.array(list(nx.from_numpy_matrix(decode_adj(adj_output), 
                                                                    create_using=nx.DiGraph).edges())).T).to(device)
    
    adj_squared_t = torch.tensor(adj, dtype=torch.float32).to(device)
    
    return graphrnn_shower(adj=adj_output_t, 
                           x=X_t, 
                           adj_out=adj_out_t,
                           adj_squared=adj_squared_t,
                           distances=distances,
                           ele_p=torch.tensor(X[-1, -1], dtype=torch.float32).to(device))

In [23]:
%%time
showers_train = []
for i, shower in list(df.iterrows())[:3]:
    showers_train.append(preprocess_shower_for_graphrnn(shower, device=device, k=3))

318
error
 0.0
153
error
 0.0
521
error
 8.0
CPU times: user 1.11 s, sys: 12 ms, total: 1.12 s
Wall time: 1.11 s


In [24]:
len(df)

8033

#### GraphRNN 

Generates embeddings for nodes.

In [25]:
class GraphRNN(nn.Module):
    def __init__(self, input_size, embedding_size, hidden_size, 
                 num_layers, has_input=True, has_output=False, output_size=None):
        super(GraphRNN, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.has_input = has_input
        self.has_output = has_output

        if has_input:
            self.input = nn.Linear(input_size, embedding_size)
            self.rnn = nn.GRU(input_size=embedding_size, hidden_size=hidden_size, 
                              num_layers=num_layers, batch_first=True)
        else:
            self.rnn = nn.GRU(input_size=input_size, hidden_size=hidden_size, 
                              num_layers=num_layers, batch_first=True)
        if has_output:
            self.output = nn.Sequential(
                nn.Linear(hidden_size, embedding_size),
                nn.ReLU(),
                nn.Linear(embedding_size, output_size)
            )

        self.relu = nn.ReLU()
        # initialize
        self.hidden_emb = nn.Sequential(
            nn.Linear(1, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, self.hidden_size)
        )
        self.hidden = None  # need initialize before forward run

    def init_hidden(self, input, batch_size):
        hidden_emb = torch.cat([self.hidden_emb(input).view(1, batch_size, self.hidden_size), 
                                torch.zeros(self.num_layers - 1, batch_size, self.hidden_size).cuda()])
        return hidden_emb

    def forward(self, input_raw, pack=False, input_len=None):
        output_raw_emb, output_raw, output_len = None, None, None
        
        if self.has_input:
            input = self.input(input_raw)
            input = self.relu(input)
        else:
            input = input_raw
        if pack:
            
            pass # input = pack_sequence(input)
        
        output_raw_emb, self.hidden = self.rnn(input, self.hidden)
        if pack:
            output_raw_emb, output_len = pad_packed_sequence(output_raw_emb, batch_first=True)
        
        if self.has_output:
            output_raw = self.output(output_raw_emb)
            
        if pack:
            output_raw_packed = pack_padded_sequence(output_raw, lengths=output_len, batch_first=True)
            return output_raw_emb, output_raw, output_len
        
        # return hidden state at each time step
        return output_raw_emb, output_raw, output_len

In [26]:
model = GraphRNN(input_size=max_prev_node, 
                 embedding_size=max_prev_node, 
                 output_size=edge_rnn_embedding_size, 
                 has_output=True, 
                 hidden_size=embedding_size, 
                 num_layers=4, 
                 has_input=False).to(device)

### Edge network

In [27]:
edge_nn = GraphRNN(input_size=1, 
                   embedding_size=edge_rnn_embedding_size,
                   hidden_size=edge_rnn_embedding_size, 
                   num_layers=4, has_input=True, has_output=True, 
                   output_size=1).to(device)

### FeaturesGCN

In [28]:
import torch_geometric.transforms as T
import torch_cluster
import torch_geometric

from torch_geometric.nn import NNConv, GCNConv, GraphConv
from torch_geometric.nn import PointConv, EdgeConv, SplineConv


class FeaturesGCN(torch.nn.Module):
    def __init__(self, dim_in, embedding_size=128, num_layers=4, dim_out=6):
        super().__init__()
        
        self.wconv_in = EdgeConv(Sequential(nn.Linear(dim_in * 2, embedding_size)), 'max')
        
        self.layers = nn.ModuleList(modules=[EdgeConv(Sequential(nn.Linear(embedding_size * 2, embedding_size)), 'max')
                                   for i in range(num_layers)])

        self.wconv_out = EdgeConv(Sequential(nn.Linear(embedding_size * 2, dim_out)), 'max')

        
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.wconv_in(x=x, edge_index=edge_index)
        
        for l in self.layers:
            x = l(x=x, edge_index=edge_index)
        
        x = self.wconv_out(x=x, edge_index=edge_index)
        
        return x

In [29]:
features_nn = FeaturesGCN(dim_in=edge_rnn_embedding_size * max_prev_node, 
                          embedding_size=128, num_layers=4,
                          dim_out=5).to(device=device)

#### Losses

In [30]:
sigmoid = nn.Sigmoid().to(device)
loss_bce = nn.BCELoss().to(device)
loss_mse = torch.nn.MSELoss().to(device)

def loss_mse_edges(shower, features):
    return loss_mse((shower.x[0][shower.adj_out[0]] - shower.x[0][shower.adj_out[1]]), 
                    (features[shower.adj_out[0]] - features[shower.adj_out[1]]))

### process_train_graphrnn

In [31]:
from torch.nn.utils.rnn import PackedSequence

def process_train_graphrnn(showers_batch):
    batch_size = len(showers_batch)
    
    model.hidden = model.init_hidden(input=torch.stack([x.ele_p for x in showers_batch]).view(-1, 1), 
                                     batch_size=batch_size)
    
    packed_adj_batch = pack_sequence([x.adj[0] for x in showers_batch])
    _, embedding_batch, output_len = model(packed_adj_batch, pack=True)

    packed_embedding_batch = pack_padded_sequence(embedding_batch, output_len, batch_first=True).data
    
    hidden_null = torch.zeros(4 - 1, packed_embedding_batch.shape[0], packed_embedding_batch.shape[1]).to(device)
    edge_nn.hidden = torch.cat((packed_embedding_batch.view(1, 
                                                            packed_embedding_batch.size(0), 
                                                            packed_embedding_batch.size(1)), hidden_null), dim=0)
    packed_adj_batch_data = packed_adj_batch.data
    packed_adj_batch_data = packed_adj_batch_data.view(packed_adj_batch_data.shape[0], 
                                                       packed_adj_batch_data.shape[1], 1)
    
    packed_adj_batch = torch.cat((torch.ones(packed_adj_batch_data.shape[0], 1, 1).to(device), 
                                  packed_adj_batch_data[:, 0:-1, 0:1]), dim=1)
    
    edges_emb, edges, _ = edge_nn(packed_adj_batch)
    return embedding_batch, output_len, pad_packed_sequence(PackedSequence(edges_emb.contiguous().view(edges_emb.size(0), -1), output_len))[0], loss_bce(torch.sigmoid(edges), packed_adj_batch_data)

In [32]:
showers_batch = showers_train
showers_batch = sorted(showers_batch, key=lambda x: -x.adj.shape[1])

In [37]:
embedding_batch, output_len, edges, ll_bce = process_train_graphrnn(showers_batch)

In [45]:
edges.shape

torch.Size([3, 521, 3200])

In [50]:
edges[2, :10]

tensor([[-0.0724,  0.0352, -0.0781,  ..., -0.9870, -0.9693,  0.9737],
        [-0.1319,  0.1038, -0.1659,  ..., -0.9845, -0.9720,  0.9692],
        [-0.1390,  0.1131, -0.1752,  ..., -0.9844, -0.9719,  0.9688],
        ...,
        [-0.1062,  0.0731, -0.1306,  ..., -0.9856, -0.9729,  0.9717],
        [-0.0825,  0.0469, -0.0950,  ..., -0.9875, -0.9739,  0.9764],
        [-0.0780,  0.0419, -0.0880,  ..., -0.9876, -0.9736,  0.9766]],
       device='cuda:0', grad_fn=<SliceBackward>)

### Optimization of edge predictions

In [34]:
from itertools import chain

learning_rate = 1e-5
optimizer_bce = torch.optim.Adam(list(model.parameters()) + 
                                 list(edge_nn.parameters()), 
                                 lr=learning_rate)

In [35]:
model.train()
edge_nn.train()

for i in tqdm(range(5000)):
    optimizer_bce.zero_grad()
    
    showers_batch = showers_train # random.sample(showers_train, batch_size)
    showers_batch = sorted(showers_batch, key=lambda x: -x.adj.shape[1])
    
    # iterate over showers in batch
    # and calc losses
    embedding_batch, output_len, edges_emb, ll_bce = process_train_graphrnn(showers_batch)

    ll_bce.backward()
    
    optimizer_bce.step()
    
    print(ll_bce.item())
    
    del embedding_batch, output_len

HBox(children=(IntProgress(value=0, max=5000), HTML(value='')))

0.6991034746170044
0.6988672018051147
0.698631763458252
0.6983963847160339
0.6981610655784607
0.6979257464408875
0.6976904273033142
0.697455108165741
0.6972197890281677
0.6969844698905945
0.6967491507530212
0.696513831615448
0.6962785720825195
0.6960433721542358
0.6958080530166626
0.6955728530883789
0.69533771276474
0.6951026320457458
0.6948676705360413
0.6946330666542053
0.694399893283844
0.6941666007041931
0.6939331889152527
0.693699836730957
0.693466305732727
0.6932325959205627
0.6929988265037537
0.692764937877655
0.6925308704376221
0.6922966837882996
0.692062258720398
0.691827654838562
0.6915930509567261
0.6913580298423767
0.6911228895187378
0.6908875703811646
0.6906520128250122
0.690416157245636
0.6901800632476807
0.6899436116218567
0.6897069811820984
0.6894699931144714
0.6892327666282654
0.6889950633049011
0.688757061958313
0.688518762588501
0.6882799863815308
0.6880409121513367
0.6878013014793396
0.6875613331794739
0.68732088804245
0.6870800852775574
0.6868386268615723
0.6865968

0.4546852111816406
0.45371460914611816
0.45274531841278076
0.45177680253982544
0.450808584690094
0.44984108209609985
0.44887399673461914
0.4479074776172638
0.44694218039512634
0.4459781348705292
0.44501492381095886
0.44405242800712585
0.44309067726135254
0.4421297609806061
0.44117021560668945
0.4402120113372803
0.4392547309398651
0.4382983446121216
0.4373430013656616
0.4363892674446106
0.43543678522109985
0.43448537588119507
0.43353527784347534
0.4325869381427765
0.43163979053497314
0.43069374561309814
0.42974936962127686
0.42880651354789734
0.42786505818367004
0.4269254803657532
0.42598748207092285
0.4250507950782776
0.4241161346435547
0.4231833219528198
0.4222523868083954
0.42132309079170227
0.4203964173793793
0.4194740056991577
0.41855430603027344
0.4176364839076996
0.4167204797267914
0.41580644249916077
0.4148942232131958
0.4139837920665741
0.41307559609413147
0.4121691584587097
0.41126495599746704
0.41036275029182434
0.40946266055107117
0.4085646867752075
0.4076687693595886
0.4067

0.23227815330028534
0.23208987712860107
0.23190250992774963
0.23171566426753998
0.23152963817119598
0.2313443422317505
0.23115982115268707
0.23097598552703857
0.23079293966293335
0.23061062395572662
0.2304290235042572
0.2302481234073639
0.2300678938627243
0.22988836467266083
0.2297096699476242
0.2295314520597458
0.22935403883457184
0.22917728126049042
0.2290012538433075
0.2288258671760559
0.22865116596221924
0.22847707569599152
0.2283036708831787
0.22813095152378082
0.22795887291431427
0.22778744995594025
0.22761669754981995
0.2274465411901474
0.22727707028388977
0.22710825502872467
0.22694005072116852
0.2267724573612213
0.22660553455352783
0.22643925249576569
0.2262735664844513
0.22610852122306824
0.22594407200813293
0.22578021883964539
0.2256169617176056
0.22545433044433594
0.22529233992099762
0.22513088583946228
0.22497005760669708
0.22480982542037964
0.22465020418167114
0.22449111938476562
0.22433267533779144
0.22417478263378143
0.22401747107505798
0.2238607257604599
0.223704561591

0.1891508847475052
0.18909351527690887
0.1890362799167633
0.1889791339635849
0.18892218172550201
0.1888653188943863
0.1888086199760437
0.18875201046466827
0.18869556486606598
0.18863923847675323
0.18858306109905243
0.18852700293064117
0.18847106397151947
0.1884152591228485
0.1883595883846283
0.18830405175685883
0.1882486492395401
0.18819338083267212
0.1881382167339325
0.1880832016468048
0.18802830576896667
0.18797355890274048
0.18791890144348145
0.18786439299583435
0.187810018658638
0.1877557635307312
0.18770165741443634
0.18764764070510864
0.1875937581062317
0.18754000961780548
0.18748638033866882
0.1874329000711441
0.18737952411174774
0.18732628226280212
0.18727314472198486
0.18722018599510193
0.18716730177402496
0.18711456656455994
0.18706196546554565
0.1870095133781433
0.18695713579654694
0.1869049221277237
0.1868528127670288
0.18680085241794586
0.18674899637699127
0.18669724464416504
0.18664567172527313
0.18659420311450958
0.18654285371303558
0.18649159371852875
0.1864404678344726

0.1739446371793747
0.1739223748445511
0.17390014231204987
0.17387795448303223
0.17385581135749817
0.1738336980342865
0.17381158471107483
0.17378957569599152
0.1737675666809082
0.17374560236930847
0.17372368276119232
0.17370180785655975
0.17367994785308838
0.17365813255310059
0.17363636195659637
0.17361462116241455
0.1735929250717163
0.17357125878334045
0.17354965209960938
0.1735280454158783
0.1735064834356308
0.17348496615886688
0.17346349358558655
0.1734420657157898
0.17342065274715424
0.17339926958084106
0.17337794601917267
0.17335663735866547
0.17333537340164185
0.1733141541481018
0.17329294979572296
0.1732717752456665
0.17325066030025482
0.17322956025600433
0.17320851981639862
0.1731874942779541
0.17316649854183197
0.17314556241035461
0.17312464118003845
0.17310374975204468
0.17308290302753448
0.17306207120418549
0.17304129898548126
0.17302055656909943
0.17299984395503998
0.17297914624214172
0.17295850813388824
0.17293789982795715
0.17291732132434845
0.17289674282073975
0.172876238

0.16644613444805145
0.16642609238624573
0.16640609502792358
0.16638612747192383
0.16636620461940765
0.16634635627269745
0.16632655262947083
0.16630686819553375
0.16628722846508026
0.16626764833927155
0.1662481278181076
0.16622863709926605
0.16620922088623047
0.16618989408016205
0.1661706119775772
0.16615141928195953
0.16613230109214783
0.1661132574081421
0.1660943478345871
0.16607548296451569
0.16605672240257263
0.16603800654411316
0.16601936519145966
0.16600075364112854
0.16598227620124817
0.16596384346485138
0.16594547033309937
0.16592715680599213
0.16590891778469086
0.16589069366455078
0.16587257385253906
0.16585451364517212
0.16583649814128876
0.16581854224205017
0.16580064594745636
0.16578282415866852
0.16576504707336426
0.16574729979038239
0.16572964191436768
0.16571202874183655
0.1656944751739502
0.16567695140838623
0.16565953195095062
0.1656421720981598
0.16562488675117493
0.16560766100883484
0.16559050977230072
0.165573388338089
0.16555634140968323
0.16553936898708344
0.165522

0.15943309664726257
0.15940311551094055
0.15937280654907227
0.15934225916862488
0.15931154787540436
0.1592806875705719
0.15924954414367676
0.15921802818775177
0.15918628871440887
0.1591542363166809
0.15912199020385742
0.1590893566608429
0.15905652940273285
0.1590234935283661
0.15899024903774261
0.1589565873146057
0.15892262756824493
0.15888835489749908
0.15885381400585175
0.15881888568401337
0.15878373384475708
0.1587483137845993
0.15871262550354004
0.15867646038532257
0.15864001214504242
0.15860320627689362
0.15856604278087616
0.15852853655815125
0.15849077701568604
0.15845274925231934
0.15841417014598846
0.15837526321411133
0.15833599865436554
0.15829627215862274
0.15825624763965607
0.1582159548997879
0.15817521512508392
0.15813395380973816
0.15809231996536255
0.1580502837896347
0.15800778567790985
0.1579650342464447
0.15792177617549896
0.15787798166275024
0.1578337401151657
0.15778903663158417
0.157743901014328
0.15769848227500916
0.1576523780822754
0.15760578215122223
0.15755875408

0.14912354946136475
0.14911475777626038
0.149105966091156
0.14909720420837402
0.14908845722675323
0.14907972514629364
0.14907100796699524
0.14906232059001923
0.1490536332130432
0.1490449756383896
0.14903631806373596
0.14902767539024353
0.1490190625190735
0.14901046454906464
0.1490018516778946
0.14899328351020813
0.14898473024368286
0.1489761918783188
0.1489676535129547
0.14895914494991302
0.14895063638687134
0.14894214272499084
0.14893367886543274
0.14892521500587463
0.1489167958498001
0.14890837669372559
0.14889994263648987
0.14889155328273773
0.1488831490278244
0.14887478947639465
0.1488664299249649
0.14885808527469635
0.148849755525589
0.14884145557880402
0.14883315563201904
0.14882487058639526
0.14881660044193268
0.1488083451986313
0.1488001048564911
0.14879187941551208
0.14878365397453308
0.14877545833587646
0.14876726269721985
0.14875909686088562
0.1487509161233902
0.14874276518821716
0.14873462915420532
0.14872650802135468
0.14871838688850403
0.14871028065681458
0.14870218932628

0.14628326892852783
0.1462782770395279
0.14627328515052795
0.14626829326152802
0.14626331627368927
0.14625835418701172
0.14625337719917297
0.14624843001365662
0.14624348282814026
0.1462385356426239
0.14623360335826874
0.14622867107391357
0.1462237536907196
0.14621885120868683
0.14621394872665405
0.14620904624462128
0.1462041586637497
0.14619925618171692
0.14619438350200653
0.14618954062461853
0.14618466794490814
0.14617982506752014
0.14617496728897095
0.14617013931274414
0.14616531133651733
0.14616048336029053
0.14615567028522491
0.1461508572101593
0.1461460441350937
0.14614105224609375
0.14613568782806396
0.14613009989261627
0.14612436294555664
0.14611855149269104
0.14611263573169708
0.14610666036605835
0.14610065519809723
0.14609459042549133
0.14608854055404663
0.14608247578144073
0.14607641100883484
0.14607031643390656
0.14606425166130066
0.14605821669101715
0.14605219662189484
0.14604616165161133
0.1460401713848114
0.14603416621685028
0.14602820575237274
0.14602230489253998
0.14601

0.14456374943256378
0.144560769200325
0.14455783367156982
0.14455488324165344
0.14455191791057587
0.14454898238182068
0.1445460319519043
0.1445430964231491
0.14454017579555511
0.14453724026679993
0.14453430473804474
0.14453136920928955
0.14452844858169556
0.14452552795410156
0.14452260732650757
0.14451968669891357
0.14451678097248077
0.14451386034488678
0.14451095461845398
0.14450806379318237
0.14450515806674957
0.14450225234031677
0.14449936151504517
0.14449648559093475
0.14449357986450195
0.14449070394039154
0.14448782801628113
0.14448495209217072
0.1444820761680603
0.1444792002439499
0.14447633922100067
0.14447346329689026
0.14447060227394104
0.14446775615215302
0.1444648951292038
0.14446206390857697
0.14445920288562775
0.14445635676383972
0.1444535255432129
0.14445069432258606
0.14444784820079803
0.1444450169801712
0.14444218575954437
0.14443936944007874
0.1444365531206131
0.14443373680114746
0.14443092048168182
0.1444281041622162
0.14442530274391174
0.1444225013256073
0.1444196999

0.14353597164154053
0.14353375136852264
0.14353151619434357
0.14352929592132568
0.1435270458459854
0.14352481067180634
0.14352259039878845
0.14352035522460938
0.1435181051492691
0.14351586997509003
0.14351363480091095
0.14351141452789307
0.1435091644525528
0.14350692927837372
0.14350469410419464
0.14350245893001556
0.1435001939535141
0.14349794387817383
0.14349570870399475
0.14349345862865448
0.1434912085533142
0.14348894357681274
0.14348670840263367
0.1434844583272934
0.14348220825195312
0.14347994327545166
0.1434776782989502
0.14347541332244873
0.14347317814826965
0.143470898270607
0.14346864819526672
0.14346636831760406
0.1434641033411026
0.14346183836460114
0.14345955848693848
0.143457293510437
0.14345502853393555
0.1434527188539505
0.14345043897628784
0.14344817399978638
0.14344589412212372
0.14344361424446106
0.1434413194656372
0.14343903958797455
0.1434367597103119
0.14343446493148804
0.14343218505382538
0.14342987537384033
0.14342759549617767
0.14342530071735382
0.1434230059385

0.13722248375415802
0.13707061111927032
0.13691608607769012
0.13675902783870697
0.136599600315094
0.13643787801265717
0.136274054646492
0.1361083686351776
0.13594096899032593
0.13577201962471008
0.13560165464878082
0.13543012738227844
0.13525748252868652
0.13508403301239014
0.13491004705429077
0.13473567366600037
0.13456113636493683
0.13438652455806732
0.13421203196048737
0.13403774797916412
0.13386382162570953
0.13369038701057434
0.13351751863956451
0.1333453208208084
0.13317398726940155
0.13300369679927826
0.13283461332321167
0.13266700506210327
0.13250133395195007
0.1323385238647461
0.1321774125099182
0.13201689720153809
0.1318589597940445
0.1317056119441986
0.13155357539653778
0.13140252232551575
0.13125640153884888
0.13111306726932526
0.13097049295902252
0.1308325082063675
0.13069851696491241
0.13056571781635284
0.13043645024299622
0.1303115040063858
0.13018891215324402
0.13006886839866638
0.12995262444019318
0.12983982264995575
0.12972930073738098
0.129621222615242
0.129516646265

### Optimization of feature reconstruction

In [36]:
learning_rate = 1e-5
optimizer_mse = torch.optim.Adam(list(features_nn.parameters()), 
                                 lr=learning_rate)

In [None]:
for i in tqdm(range(3000)):
    optimizer_mse.zero_grad()
    
    showers_batch = showers_train # random.sample(showers_train, batch_size)
    showers_batch = sorted(showers_batch, key=lambda x: -x.adj.shape[1])
    
    # iterate over showers in batch
    # and calc losses
    _, output_len, edges_emb, ll_bce = process_train_graphrnn(showers_batch)
    
    ll_mse_edges = []
    
    # iterate over showers in batch
    # and calc losses
    for k, l in enumerate(output_len):
        shower = showers_batch[k]
        
        embedding = edges_emb[k][:l]

        # teacher forcing
        shower_t = torch_geometric.data.Data(x=embedding, 
                                             edge_index=shower.adj_out).to(device)
        
        # GCN to recover shower features
        features = features_nn(shower_t)
    
        # features prediction loss        
        ll_mse_edges.append(loss_mse_edges(shower, features))
        
        del shower_t, features

    ll_mse_edges = sum(ll_mse_edges) / len(ll_mse_edges)
    
    ll_mse_edges.backward()
    
    optimizer_mse.step()
    
    del edges_emb, output_len
    
    print(ll_bce.item(), 
          ll_mse_edges.item())

### Finetuning

In [None]:
learning_rate = 0.3e-5
optimizer_fine = torch.optim.Adam(list(features_nn.parameters()) +
                                  list(edge_nn.parameters()) +
                                  list(model.parameters()), lr=learning_rate)

In [None]:
scale_vector = torch.tensor([1e1, 1e1, 1e1, 1, 1]).to(device)

In [None]:
def loss_mse_edges(shower, features, scale_vector):
    return loss_mse((shower.x[0][shower.adj_out[0]] - shower.x[0][shower.adj_out[1]]) * scale_vector, 
                    (features[shower.adj_out[0]] - features[shower.adj_out[1]]) * scale_vector)

In [None]:
for i in tqdm(range(5000)):
    optimizer_fine.zero_grad()
    
    showers_batch = showers_train # random.sample(showers_train, batch_size)
    showers_batch = sorted(showers_batch, key=lambda x: -x.adj.shape[1])
    
    # iterate over showers in batch
    # and calc losses
    _, output_len, edges_emb, ll_bce = process_train_graphrnn(showers_batch)
    
    ll_mse_edges = []
    
    # iterate over showers in batch
    # and calc losses
    for k, l in enumerate(output_len):
        shower = showers_batch[k]
        
        embedding = edges_emb[k][:l]

        # teacher forcing
        shower_t = torch_geometric.data.Data(x=embedding, 
                                             edge_index=shower.adj_out).to(device)
        
        # GCN to recover shower features
        features = features_nn(shower_t)
    
        # features prediction loss        
        ll_mse_edges.append(loss_mse_edges(shower, features, scale_vector))
        
        del shower_t, features

    ll_mse_edges = sum(ll_mse_edges) / len(ll_mse_edges)
    
    (ll_bce + ll_mse_edges * 20).backward()
    
    optimizer_fine.step()
    
    print(ll_bce.item(), 
          ll_mse_edges.item())

In [None]:
from tools.opera_tools import plot_npframe
plot_npframe(shower.x.cpu().detach().numpy()[0] * np.array([1e4, 1e4, 1e4, 1, 1]))

tmp_X = features.cpu().detach().numpy()[:, :5]
tmp_X *= np.array([1e4, 1e4, 1e4, 1, 1])
plot_npframe(tmp_X)

In [None]:
# teacher forcing
shower_t = torch_geometric.data.Data(x=embedding, 
                                     edge_index=shower.adj_out).to(device)

# GCN to recover shower features
features = features_nn(shower_t)

In [None]:
features.shape

In [None]:
shower.x.shape

In [None]:
(shower.x[0][shower.adj_out[0]] - shower.x[0][shower.adj_out[1]]) - (features[shower.adj_out[0]] - features[shower.adj_out[1]]) * torch.tensor([1e4, 1e4, 1e4, 1, 1]).to(device)

In [None]:
from tools.opera_tools import plot_npframe
plot_npframe(shower.x.cpu().detach().numpy()[0] * np.array([1e4, 1e4, 1e4, 1, 1]))

tmp_X = features.cpu().detach().numpy()[:, :5]
tmp_X *= np.array([1e4, 1e4, 1e4, 1, 1])
plot_npframe(tmp_X)

In [None]:
def get_graph(adj):
    '''
    get a graph from zero-padded adj
    :param adj:
    :return:
    '''
    # remove all zeros rows and columns
    adj = adj[~np.all(adj == 0, axis=1)]
    adj = adj[:, ~np.all(adj == 0, axis=0)]
    adj = np.asmatrix(adj)
    G = nx.from_numpy_matrix(adj)
    return G

def generate_graph(model, edge_nn, max_prev_node, test_batch_energies, device):
    test_batch_size = test_batch_energies.shape[0]
    model.hidden = model.init_hidden(test_batch_energies, test_batch_size)
    model.eval()
    model.eval()

    # generate graphs
    max_num_node = 200
    
    y_pred_long = torch.ones(test_batch_size, 
                             max_num_node, 
                             max_prev_node).to(device) # discrete prediction
    
    x_step = torch.zeros(test_batch_size, 1, max_prev_node).to(device)
    for i in tqdm(range(max_num_node)):
        _, h, _ = model(x_step)
        hidden_null = torch.zeros(edge_nn.num_layers - 1, h.size(0), h.size(2)).cuda()
        edge_nn.hidden = torch.cat((h.permute(1, 0, 2), hidden_null), dim=0)  # num_layers, batch_size, hidden_size
        x_step = torch.zeros(test_batch_size, 1, max_prev_node).to(device)
        output_x_step = torch.ones(test_batch_size, 1, 1).to(device)
        for j in range(min(max_prev_node, i+1)):
            _, output_y_pred_step, _ = edge_nn(output_x_step)
            output_x_step = Bernoulli(logits=output_y_pred_step).sample()
            x_step[:, :, j:j+1] = output_x_step
            # edge_nn.hidden = hidden.data
        y_pred_long[:, i:i + 1, :] = x_step
        model.hidden = model.hidden.data
    print(y_pred_long)
    y_pred_long_data = y_pred_long.data.long()
    
    # save graphs as pickle
    G_pred_list = []
    for i in range(test_batch_size):
        adj_pred = decode_adj(y_pred_long_data[i].detach().cpu().numpy())
        G_pred = get_graph(adj_pred) # get a graph from zero-padded adj
        G_pred_list.append(G_pred)
        
        # teacher forcing
        shower_t = torch_geometric.data.Data(x=embedding, 
                                             edge_index=shower.adj_out).to(device)
        
        # GCN to recover shower features
        features = features_nn(shower_t)

    return G_pred_list


a = generate_graph(model=model, 
                   edge_nn=edge_nn,
                   max_prev_node=max_prev_node, 
                   test_batch_energies=torch.tensor([6.6297] * 10).to(device).view(-1, 1), 
                   device=device)

In [None]:
g = nx.DiGraph(a[0])

In [None]:
adj_out_t = torch.LongTensor(np.array(list(g.edges())).T).to(device)

In [None]:
shower_t = torch_geometric.data.Data(x=embedding, 
                                     edge_index=shower.adj_out).to(device)

features = features_nn(shower_t)