
## All necessary imports

In [None]:
# !cd tools/ && python setup_opera_distance_metric.py build_ext --inplace

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm_notebook as tqdm
import networkx as nx

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import Sequential
from torch.distributions import Bernoulli


from tools.opera_distance_metric import generate_k_nearest_graph, \
                                        opera_distance_metric_py, \
                                        generate_radius_graph

from graph_rnn import bfs_seq, encode_adj, decode_adj

from torch.nn.utils.rnn import pack_padded_sequence, pack_sequence, pad_sequence, pad_packed_sequence
import random

sns.set(font_scale=2)

In [None]:
# Use the first GPU when available; fall back to CPU so the notebook still runs
# on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Load the shower table. Each row is later passed to
# preprocess_shower_for_graphrnn, which reads its array-like columns
# SX, SY, SZ, TX, TY and ele_P.
# NOTE(review): pd.read_pickle unpickles arbitrary objects -- only load trusted files.
df = pd.read_pickle('./data/showers.pkl')

In [None]:
def bfs_handmade(G, start):
    """Return the nodes reachable from ``start`` in breadth-first order.

    Neighbours of each node are enqueued in order of increasing edge
    ``'weight'``, so among siblings the closest node is visited first.
    Nodes that are not reachable from ``start`` are not included.

    :param G: graph exposing ``out_edges(node, data=True)`` whose edge data
        carries a ``'weight'`` key (e.g. a ``networkx.DiGraph``).
    :param start: node label to start the traversal from.
    :return: ``np.ndarray`` of node labels in BFS visiting order.
    """
    from collections import deque

    order = [start]
    seen = {start}
    queue = deque([start])
    while queue:
        vertex = queue.popleft()
        # Walk neighbours as an ordered list (not a set) so the weight ordering
        # actually decides which sibling comes first.
        neighbours = sorted(G.out_edges(vertex, data=True), key=lambda e: e[2]['weight'])
        for _, neighbour, _ in neighbours:
            if neighbour not in seen:
                seen.add(neighbour)
                order.append(neighbour)
                queue.append(neighbour)
    return np.array(order)


def encode_adj(adj, max_prev_node=10, is_full=False):
    """Encode a square adjacency matrix as GraphRNN predecessor rows.

    Row ``i`` of the result describes node ``i + 1``: entry ``j`` is its
    (strictly lower-triangular) connection to node ``i - j``, i.e. the most
    recent predecessor comes first. Only the ``max_prev_node`` most recent
    predecessors are kept; older ones are truncated.

    :param adj: n x n adjacency matrix (rows are time steps).
    :param max_prev_node: width of each encoded row.
    :param is_full: if True, ignore ``max_prev_node`` and use ``n - 1`` so
        nothing is truncated.
    :return: float array of shape ``(n - 1, max_prev_node)``.
    """
    if is_full:
        max_prev_node = adj.shape[0] - 1

    # Strictly-lower triangle; the first node has no predecessors and the
    # last node is never a predecessor, so drop row 0 and the last column.
    lower = np.asarray(np.tril(adj, k=-1))
    size = lower.shape[0]
    predecessors = lower[1:size, 0:size - 1]

    n_rows = predecessors.shape[0]
    encoded = np.zeros((n_rows, max_prev_node))
    for row in range(n_rows):
        first_col = max(0, row - max_prev_node + 1)
        window = predecessors[row, first_col:row + 1]
        # Reversed so the nearest predecessor lands in column 0.
        encoded[row, :window.shape[0]] = window[::-1]
    return encoded

## Model parameters

In [None]:
# Width of each encoded adjacency row: how many previous nodes (in BFS order)
# a new node can be connected to (see encode_adj).
max_prev_node = 10
# NOTE(review): graph_state_size is not referenced anywhere else in this notebook.
graph_state_size = 64
# Hidden size of the node-level GraphRNN (`model`).
embedding_size = 256
# Output size of `model` and hidden size of `edge_nn`: each node's `model`
# output seeds the first hidden layer of the edge RNN.
edge_rnn_embedding_size = 64

In [None]:
# NOTE(review): only referenced by the commented-out random.sample(...) calls;
# the training loops below currently use the whole of showers_train each step.
batch_size = 50

In [None]:
from collections import namedtuple

In [None]:
# One preprocessed shower, as produced by preprocess_shower_for_graphrnn.
graphrnn_shower = namedtuple(
    'graphrnn_shower',
    'x adj adj_out adj_squared ele_p distances',
)

In [None]:
def preprocess_shower_for_graphrnn(shower, device, k=4, symmetric=False):
    """Convert one shower (a row of ``df``) into a ``graphrnn_shower``.

    Builds a k-nearest-neighbour graph over the hits, orders the nodes by BFS
    from node 0, encodes the reordered adjacency with ``encode_adj`` and moves
    the results to ``device``.

    :param shower: row exposing array-like ``SX, SY, SZ, TX, TY, ele_P``.
    :param device: torch device for the returned tensors.
    :param k: neighbour count passed to ``generate_k_nearest_graph``.
    :param symmetric: forwarded to ``generate_k_nearest_graph``.
    :return: ``graphrnn_shower`` with
        ``adj``: (1, n_nodes, max_prev_node) float tensor; the first row is all
        ones (start row) followed by the binarised ``encode_adj`` rows;
        ``x``: (1, n_nodes, 5) float tensor of SX, SY, SZ (divided by 1e3), TX, TY;
        ``adj_out``: (2, n_edges) long tensor of edges decoded back from the encoding;
        ``adj_squared``: (n_nodes, n_nodes) float adjacency in BFS order;
        ``ele_p``: scalar float tensor, ``ele_P`` of the last node in BFS order;
        ``distances``: ``log(1 + d)`` of the kNN distances (numpy, original edge order).
    """
    X = np.vstack([
        np.arange(len(shower.SX)),
        shower.SX,
        shower.SY,
        shower.SZ,
        shower.TX,
        shower.TY,
        shower.ele_P]
    ).T
    edges_from, edges_to, distances = generate_k_nearest_graph(X, k=k, symmetric=symmetric)

    G = nx.Graph()
    G.add_edges_from(
        (edges_from[i], edges_to[i], {'weight': distances[i]})
        for i in range(len(distances))
    )
    G = nx.DiGraph(G)

    start_idx = 0
    x_idx = np.array(bfs_handmade(G, start_idx))
    # Build the adjacency directly in BFS order: `nodelist` maps node labels to
    # rows, so row r of `adj` and row r of X[x_idx] describe the same hit.
    # (G.nodes() follows edge-insertion order, not label order, so positional
    # indexing into a default-ordered matrix would misalign adj and X.)
    adj = nx.to_numpy_array(G, nodelist=list(x_idx))

    # actual data
    adj_output = encode_adj(adj, max_prev_node=max_prev_node)
    X = X[x_idx, 1:]
    X = X / np.array([1e3, 1e3, 1e3, 1, 1, 1])
    distances = np.log(1. + np.array(distances))

    # Distances are not used as targets yet: binarise the encoded rows.
    # TODO: what to do with distances?
    adj_output[adj_output != 0] = 1.

    # Prepend an all-ones start row to the adjacency sequence.
    adj_output_t = torch.tensor(np.append(np.ones((1, max_prev_node)),
                                          adj_output, axis=0),
                                dtype=torch.float32).to(device).view(1, -1, max_prev_node)

    # Drop ele_P (last column): node features are SX, SY, SZ, TX, TY.
    X_t = torch.tensor(X[:, :-1], dtype=torch.float32).to(device).view(1, -1, 5)

    # Edge list recovered from the (possibly truncated) encoding; reshape keeps
    # the (2, n_edges) layout even when there are no edges.
    decoded = nx.from_numpy_array(decode_adj(adj_output), create_using=nx.DiGraph)
    edge_array = np.array(list(decoded.edges()), dtype=np.int64).reshape(-1, 2).T
    adj_out_t = torch.tensor(edge_array, dtype=torch.long).to(device)

    adj_squared_t = torch.tensor(adj, dtype=torch.float32).to(device)

    return graphrnn_shower(adj=adj_output_t,
                           x=X_t,
                           adj_out=adj_out_t,
                           adj_squared=adj_squared_t,
                           distances=distances,
                           ele_p=torch.tensor(X[-1, -1], dtype=torch.float32).to(device))

In [None]:
%%time
# Small debug-sized training set: preprocess only the first few showers.
# df.head(...) avoids materialising every row of df just to keep a handful.
n_train_showers = 3
showers_train = []
for i, shower in df.head(n_train_showers).iterrows():
    showers_train.append(preprocess_shower_for_graphrnn(shower, device=device, k=3))

In [None]:
# Total number of showers in the dataset (only a few are preprocessed above).
len(df)

#### GraphRNN 

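Node-level RNN: it reads the encoded adjacency rows one node at a time and produces, for each node, an embedding. That embedding initialises the edge-level RNN.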

In [None]:
class GraphRNN(nn.Module):
    """Multi-layer GRU used both as the node-level and as the edge-level RNN.

    :param input_size: feature size of each time step of the raw input.
    :param embedding_size: width of the optional input projection and of the
        hidden layer of the optional output head.
    :param hidden_size: GRU hidden size.
    :param num_layers: number of stacked GRU layers.
    :param has_input: if True, inputs go through ``Linear(input_size,
        embedding_size)`` + ReLU before the GRU.
    :param has_output: if True, GRU outputs go through a two-layer MLP
        producing ``output_size`` features.
    :param output_size: size of the output head (required when ``has_output``).

    The recurrent state lives in ``self.hidden`` and is carried across calls;
    set it (e.g. via ``init_hidden``) before feeding a new batch of sequences.
    """

    def __init__(self, input_size, embedding_size, hidden_size,
                 num_layers, has_input=True, has_output=False, output_size=None):
        super(GraphRNN, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.has_input = has_input
        self.has_output = has_output

        if has_input:
            self.input = nn.Linear(input_size, embedding_size)
            self.rnn = nn.GRU(input_size=embedding_size, hidden_size=hidden_size,
                              num_layers=num_layers, batch_first=True)
        else:
            self.rnn = nn.GRU(input_size=input_size, hidden_size=hidden_size,
                              num_layers=num_layers, batch_first=True)
        if has_output:
            self.output = nn.Sequential(
                nn.Linear(hidden_size, embedding_size),
                nn.ReLU(),
                nn.Linear(embedding_size, output_size)
            )

        self.relu = nn.ReLU()
        # Maps a 1-feature conditioning value to the first layer's initial state.
        self.hidden_emb = nn.Sequential(
            nn.Linear(1, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, self.hidden_size)
        )
        self.hidden = None  # need initialize before forward run

    def init_hidden(self, input, batch_size):
        """Build an initial hidden state conditioned on one value per sequence.

        The first GRU layer starts from ``hidden_emb(input)``; the remaining
        ``num_layers - 1`` layers start at zero.

        :param input: tensor with ``batch_size`` rows of 1 feature, e.g. (batch, 1).
        :param batch_size: number of sequences.
        :return: tensor of shape (num_layers, batch_size, hidden_size).
        """
        first = self.hidden_emb(input).view(1, batch_size, self.hidden_size)
        # Allocate on the embedding's device/dtype instead of assuming CUDA.
        rest = torch.zeros(self.num_layers - 1, batch_size, self.hidden_size,
                           device=first.device, dtype=first.dtype)
        return torch.cat([first, rest])

    def forward(self, input_raw, pack=False, input_len=None):
        """Run the GRU over a batch, continuing from ``self.hidden``.

        :param input_raw: (batch, seq, input_size) tensor (batch_first), or a
            ``PackedSequence`` when ``pack=True``. The input projection is applied
            directly to ``input_raw``, so packed input requires ``has_input=False``.
        :param pack: if True, the GRU output is unpacked with
            ``pad_packed_sequence`` and the sequence lengths are returned.
        :param input_len: unused; kept for interface compatibility.
        :return: ``(gru_outputs, head_outputs, lengths)``; ``head_outputs`` is
            None unless ``has_output``, ``lengths`` is None unless ``pack``.
        """
        output_raw = None
        output_len = None

        if self.has_input:
            rnn_input = self.relu(self.input(input_raw))
        else:
            rnn_input = input_raw

        output_raw_emb, self.hidden = self.rnn(rnn_input, self.hidden)
        if pack:
            output_raw_emb, output_len = pad_packed_sequence(output_raw_emb, batch_first=True)

        if self.has_output:
            output_raw = self.output(output_raw_emb)

        # hidden state at each time step (padded when pack=True)
        return output_raw_emb, output_raw, output_len

In [None]:
# Node-level RNN: reads encoded adjacency rows (max_prev_node wide) directly
# (has_input=False) and emits an edge_rnn_embedding_size vector per node that
# seeds edge_nn. Here embedding_size only sets the hidden width of the output
# head: hidden_size (256) -> max_prev_node (10) -> output_size (64).
# NOTE(review): that 10-wide bottleneck in the output head may be unintended.
model = GraphRNN(input_size=max_prev_node, 
                 embedding_size=max_prev_node, 
                 output_size=edge_rnn_embedding_size, 
                 has_output=True, 
                 hidden_size=embedding_size, 
                 num_layers=4, 
                 has_input=False).to(device)

### Edge network

In [None]:
# Edge-level RNN: for one node, emits one logit per adjacency bit while reading
# the previous bit (input_size=1) at every step. Its first hidden layer is
# seeded with `model`'s per-node output, hence hidden_size == edge_rnn_embedding_size.
edge_nn = GraphRNN(input_size=1, 
                   embedding_size=edge_rnn_embedding_size,
                   hidden_size=edge_rnn_embedding_size, 
                   num_layers=4, has_input=True, has_output=True, 
                   output_size=1).to(device)

### FeaturesGCN

In [None]:
import torch_geometric.transforms as T
import torch_cluster
import torch_geometric

from torch_geometric.nn import NNConv, GCNConv, GraphConv
from torch_geometric.nn import PointConv, EdgeConv, SplineConv


class FeaturesGCN(torch.nn.Module):
    """Stack of EdgeConv layers mapping node embeddings to node features.

    :param dim_in: size of the input node features.
    :param embedding_size: width of every hidden EdgeConv layer.
    :param num_layers: number of hidden EdgeConv layers between input and output.
    :param dim_out: number of predicted features per node.
    """

    def __init__(self, dim_in, embedding_size=128, num_layers=4, dim_out=6):
        super().__init__()

        def edge_conv(n_in, n_out):
            # The EdgeConv MLP receives the concatenated features of an edge's
            # two endpoints (hence 2 * n_in inputs); messages are max-aggregated.
            return EdgeConv(Sequential(nn.Linear(2 * n_in, n_out)), 'max')

        self.wconv_in = edge_conv(dim_in, embedding_size)
        self.layers = nn.ModuleList(
            modules=[edge_conv(embedding_size, embedding_size) for _ in range(num_layers)]
        )
        self.wconv_out = edge_conv(embedding_size, dim_out)

    def forward(self, data):
        """Apply every EdgeConv layer to ``data.x`` over ``data.edge_index``."""
        edge_index = data.edge_index
        h = data.x
        for conv in [self.wconv_in, *self.layers, self.wconv_out]:
            h = conv(x=h, edge_index=edge_index)
        return h

In [None]:
# GCN predicting 5 features per node (the SX, SY, SZ, TX, TY columns of
# shower.x) from each node's flattened edge-RNN outputs, whose width is
# max_prev_node * edge_rnn_embedding_size.
features_nn = FeaturesGCN(dim_in=edge_rnn_embedding_size * max_prev_node, 
                          embedding_size=128, num_layers=4,
                          dim_out=5).to(device=device)

#### Losses

In [None]:
# Loss modules shared by the training cells below.
sigmoid = nn.Sigmoid().to(device)
loss_bce = nn.BCELoss().to(device)
loss_mse = nn.MSELoss().to(device)


def loss_mse_edges(shower, features):
    """MSE between true and predicted feature differences along every edge.

    :param shower: ``graphrnn_shower``; ``shower.x`` is (1, n_nodes, n_feat)
        and ``shower.adj_out`` a (2, n_edges) edge index.
    :param features: (n_nodes, n_feat) predicted node features.
    :return: scalar MSE over all edges and features.
    """
    src, dst = shower.adj_out[0], shower.adj_out[1]
    true_x = shower.x[0]
    target_delta = true_x[src] - true_x[dst]
    predicted_delta = features[src] - features[dst]
    return loss_mse(target_delta, predicted_delta)

### process_train_graphrnn

In [None]:
from torch.nn.utils.rnn import PackedSequence

def process_train_graphrnn(showers_batch):
    """Teacher-forced forward pass of the node- and edge-level RNNs.

    :param showers_batch: list of ``graphrnn_shower`` sorted by decreasing
        number of nodes (required by ``pack_sequence``).
    :return: ``(embedding_batch, output_len, edges_emb, loss)``:
        ``embedding_batch`` is the padded (batch, max_len, .) output of ``model``;
        ``output_len`` holds the per-shower sequence lengths;
        ``edges_emb`` is the padded (batch, max_len, max_prev_node * .) edge-RNN
        output, with row ``k`` belonging to ``showers_batch[k]``;
        ``loss`` is the BCE of the predicted adjacency bits.
    """
    batch_size = len(showers_batch)

    # Initialise the node RNN's first layer from each shower's ele_p.
    energies = torch.stack([s.ele_p for s in showers_batch]).view(-1, 1)
    model.hidden = model.init_hidden(input=energies, batch_size=batch_size)

    packed_adj = pack_sequence([s.adj[0] for s in showers_batch])
    _, embedding_batch, output_len = model(packed_adj, pack=True)

    # One row per (shower, node) in packed order, aligned with packed_adj.data.
    packed_embedding = pack_padded_sequence(embedding_batch, output_len, batch_first=True).data

    # Edge RNN: first layer seeded with the node embedding, the rest with zeros.
    hidden_null = torch.zeros(edge_nn.num_layers - 1,
                              packed_embedding.size(0),
                              packed_embedding.size(1),
                              device=packed_embedding.device,
                              dtype=packed_embedding.dtype)
    edge_nn.hidden = torch.cat((packed_embedding.unsqueeze(0), hidden_null), dim=0)

    # Targets: every adjacency bit of every node, shape (n_rows, max_prev_node, 1).
    target = packed_adj.data.unsqueeze(-1)
    # Teacher forcing: a leading 1 followed by all but the last true bit.
    start = torch.ones(target.size(0), 1, 1, device=target.device, dtype=target.dtype)
    edge_input = torch.cat((start, target[:, :-1, :]), dim=1)

    edges_emb, edge_logits, _ = edge_nn(edge_input)

    # Re-pad the per-node edge outputs. PackedSequence needs the packed
    # batch_sizes (number of sequences alive at each step), not per-sequence
    # lengths, and batch_first=True makes edges_emb[k] correspond to shower k.
    flat = edges_emb.contiguous().view(edges_emb.size(0), -1)
    edges_emb_padded, _ = pad_packed_sequence(PackedSequence(flat, packed_adj.batch_sizes),
                                              batch_first=True)

    # NOTE(review): each target row is the same adjacency row the node RNN has
    # already consumed at that step (no one-step x/y shift) -- confirm intent.
    # BCE-with-logits equals BCELoss(sigmoid(logits)) but is numerically stable.
    loss = F.binary_cross_entropy_with_logits(edge_logits, target)
    return embedding_batch, output_len, edges_emb_padded, loss

In [None]:
# Longest shower first: pack_sequence (used in process_train_graphrnn)
# expects sequences in decreasing length order.
showers_batch = sorted(showers_train, key=lambda s: s.adj.shape[1], reverse=True)

In [None]:
# Smoke test: one teacher-forced forward pass over the sorted batch before training.
embedding_batch, output_len, edges, ll_bce = process_train_graphrnn(showers_batch)

### Optimization of edge predictions

In [None]:
from itertools import chain

learning_rate = 1e-5
# Stage 1 optimiser: only the node- and edge-level RNNs are trained.
optimizer_bce = torch.optim.Adam(chain(model.parameters(), edge_nn.parameters()),
                                 lr=learning_rate)

In [None]:
# Stage 1: train model + edge_nn on the adjacency BCE loss only.
model.train()
edge_nn.train()

for i in tqdm(range(5000)):
    optimizer_bce.zero_grad()
    
    # Whole training set every step, longest shower first (pack_sequence order).
    showers_batch = showers_train # random.sample(showers_train, batch_size)
    showers_batch = sorted(showers_batch, key=lambda x: -x.adj.shape[1])
    
    # One batched teacher-forced pass over all showers; ll_bce is the
    # adjacency-bit loss.
    embedding_batch, output_len, edges_emb, ll_bce = process_train_graphrnn(showers_batch)

    ll_bce.backward()
    
    optimizer_bce.step()
    
    # NOTE(review): prints one line per iteration (5000 lines).
    print(ll_bce.item())
    
    del embedding_batch, output_len

### Optimization of feature reconstruction

In [None]:
learning_rate = 1e-5
# Stage 2 optimiser: only the feature-reconstruction GCN is updated.
optimizer_mse = torch.optim.Adam(features_nn.parameters(), lr=learning_rate)

In [None]:
# Stage 2: train features_nn to reconstruct node features from the frozen
# edge-RNN embeddings. The RNN pass runs under no_grad: optimizer_mse only
# updates features_nn, so backpropagating into model/edge_nn would just
# accumulate gradients on parameters this loop never steps or zeroes.
for i in tqdm(range(3000)):
    optimizer_mse.zero_grad()

    showers_batch = showers_train # random.sample(showers_train, batch_size)
    showers_batch = sorted(showers_batch, key=lambda x: -x.adj.shape[1])

    with torch.no_grad():
        _, output_len, edges_emb, ll_bce = process_train_graphrnn(showers_batch)

    ll_mse_edges = []

    # One GCN pass per shower; assumes edges_emb is padded batch-first so that
    # edges_emb[k][:l] are the l node rows of showers_batch[k].
    for k, l in enumerate(output_len):
        shower = showers_batch[k]

        embedding = edges_emb[k][:l]

        # teacher forcing: use the true edge list
        shower_t = torch_geometric.data.Data(x=embedding,
                                             edge_index=shower.adj_out).to(device)

        # GCN to recover shower features
        features = features_nn(shower_t)

        # features prediction loss
        ll_mse_edges.append(loss_mse_edges(shower, features))

        del shower_t, features

    ll_mse_edges = sum(ll_mse_edges) / len(ll_mse_edges)

    ll_mse_edges.backward()

    optimizer_mse.step()

    del edges_emb, output_len

    print(ll_bce.item(),
          ll_mse_edges.item())

### Finetuning

In [None]:
# Stage 3: jointly fine-tune the GCN and both RNNs with a smaller learning rate.
learning_rate = 0.3e-5
optimizer_fine = torch.optim.Adam(list(features_nn.parameters()) +
                                  list(edge_nn.parameters()) +
                                  list(model.parameters()), lr=learning_rate)

In [None]:
# Per-feature weights for the fine-tuning loss: positions (SX, SY, SZ) x10,
# angles (TX, TY) x1.
scale_vector = torch.tensor([1e1, 1e1, 1e1, 1, 1], device=device)

In [None]:
def loss_mse_edges(shower, features, scale_vector=None):
    """MSE between true and predicted feature differences along every edge.

    This redefines the two-argument helper from the Losses cell. Because
    ``scale_vector`` defaults to None (no scaling), cells that call the
    two-argument form keep working after this cell has run.

    :param shower: ``graphrnn_shower``; ``shower.x`` is (1, n_nodes, n_feat)
        and ``shower.adj_out`` a (2, n_edges) edge index.
    :param features: (n_nodes, n_feat) predicted node features.
    :param scale_vector: optional per-feature weights broadcast over the
        differences before the MSE.
    :return: scalar MSE over all edges and features.
    """
    src, dst = shower.adj_out[0], shower.adj_out[1]
    target_delta = shower.x[0][src] - shower.x[0][dst]
    predicted_delta = features[src] - features[dst]
    if scale_vector is not None:
        target_delta = target_delta * scale_vector
        predicted_delta = predicted_delta * scale_vector
    return loss_mse(target_delta, predicted_delta)

In [None]:
# Stage 3: joint fine-tuning on adjacency BCE + weighted feature MSE.
for i in tqdm(range(5000)):
    optimizer_fine.zero_grad()
    
    showers_batch = showers_train # random.sample(showers_train, batch_size)
    showers_batch = sorted(showers_batch, key=lambda x: -x.adj.shape[1])
    
    # One batched teacher-forced pass through both RNNs.
    _, output_len, edges_emb, ll_bce = process_train_graphrnn(showers_batch)
    
    ll_mse_edges = []
    
    # One GCN pass per shower; expects edges_emb[k][:l] to be the node rows of
    # showers_batch[k].
    for k, l in enumerate(output_len):
        shower = showers_batch[k]
        
        embedding = edges_emb[k][:l]

        # teacher forcing
        shower_t = torch_geometric.data.Data(x=embedding, 
                                             edge_index=shower.adj_out).to(device)
        
        # GCN to recover shower features
        features = features_nn(shower_t)
    
        # features prediction loss        
        ll_mse_edges.append(loss_mse_edges(shower, features, scale_vector))
        
        # NOTE(review): `features` is deleted here, so it is undefined after the
        # loop; later cells that read it must recompute it.
        del shower_t, features

    ll_mse_edges = sum(ll_mse_edges) / len(ll_mse_edges)
    
    # The feature loss is weighted x20 relative to the BCE term.
    (ll_bce + ll_mse_edges * 20).backward()
    
    optimizer_fine.step()
    
    print(ll_bce.item(), 
          ll_mse_edges.item())

In [None]:
from tools.opera_tools import plot_npframe

# Undo the preprocessing normalisation (SX, SY, SZ were divided by 1e3).
denorm = np.array([1e3, 1e3, 1e3, 1, 1])

# `shower` and `embedding` are left over from the last fine-tuning iteration;
# `features` was deleted inside that loop, so recompute it for the same shower.
with torch.no_grad():
    features = features_nn(torch_geometric.data.Data(x=embedding,
                                                     edge_index=shower.adj_out).to(device))

plot_npframe(shower.x.cpu().detach().numpy()[0] * denorm)

tmp_X = features.cpu().detach().numpy()[:, :5]
tmp_X *= denorm
plot_npframe(tmp_X)

In [None]:
# teacher forcing
shower_t = torch_geometric.data.Data(x=embedding, 
                                     edge_index=shower.adj_out).to(device)

# GCN to recover shower features
features = features_nn(shower_t)

In [None]:
# Predicted features: one row of 5 values (features_nn dim_out) per node.
features.shape

In [None]:
# True features: (1, n_nodes, 5) -- preprocessing adds a leading batch dim of 1.
shower.x.shape

In [None]:
# Per-edge residual between true and predicted feature differences, with both
# sides in the same units: rescaled by the factors preprocessing divided out
# (SX, SY, SZ by 1e3).
denorm_t = torch.tensor([1e3, 1e3, 1e3, 1, 1], device=device)
true_delta = shower.x[0][shower.adj_out[0]] - shower.x[0][shower.adj_out[1]]
pred_delta = features[shower.adj_out[0]] - features[shower.adj_out[1]]
(true_delta - pred_delta) * denorm_t

In [None]:
from tools.opera_tools import plot_npframe

# Undo the preprocessing normalisation (SX, SY, SZ were divided by 1e3) so the
# true and predicted showers are plotted in the original units.
denorm = np.array([1e3, 1e3, 1e3, 1, 1])
plot_npframe(shower.x.cpu().detach().numpy()[0] * denorm)

tmp_X = features.cpu().detach().numpy()[:, :5]
tmp_X *= denorm
plot_npframe(tmp_X)

In [None]:
def get_graph(adj):
    '''
    Build an undirected networkx graph from a zero-padded adjacency matrix.

    All-zero rows are dropped first, then all-zero columns of what remains, so
    for a symmetric matrix padding/isolated nodes disappear from the graph.

    :param adj: 2-D adjacency array (e.g. the output of ``decode_adj``).
    :return: ``nx.Graph`` built from the remaining entries.
    '''
    # remove all zeros rows and columns
    adj = adj[~np.all(adj == 0, axis=1)]
    adj = adj[:, ~np.all(adj == 0, axis=0)]
    # nx.from_numpy_matrix was removed in networkx 3.0; from_numpy_array
    # builds the same undirected graph from a plain ndarray.
    return nx.from_numpy_array(np.asarray(adj))

def generate_graph(model, edge_nn, max_prev_node, test_batch_energies, device,
                   max_num_node=200):
    """Sample adjacency structures autoregressively and return them as graphs.

    :param model: node-level GraphRNN (provides ``init_hidden`` and per-node outputs).
    :param edge_nn: edge-level GraphRNN producing one logit per adjacency bit.
    :param max_prev_node: width of the encoded adjacency rows.
    :param test_batch_energies: (batch, 1) tensor passed to ``model.init_hidden``.
    :param device: torch device for the sampling tensors.
    :param max_num_node: number of node steps sampled (no early stopping).
    :return: list with one ``nx.Graph`` per batch element; all-zero
        rows/columns of the decoded adjacency are removed (see ``get_graph``).
    """
    test_batch_size = test_batch_energies.shape[0]
    model.eval()
    edge_nn.eval()

    with torch.no_grad():
        model.hidden = model.init_hidden(test_batch_energies, test_batch_size)

        # Every row is overwritten below; discrete 0/1 predictions.
        y_pred_long = torch.zeros(test_batch_size, max_num_node, max_prev_node, device=device)

        # Start row of all ones, matching the row prepended to every training
        # sequence in preprocess_shower_for_graphrnn.
        x_step = torch.ones(test_batch_size, 1, max_prev_node, device=device)
        for i in tqdm(range(max_num_node)):
            _, h, _ = model(x_step)
            # Seed the first edge-RNN layer with this node's output, others zero.
            hidden_null = torch.zeros(edge_nn.num_layers - 1, h.size(0), h.size(2), device=device)
            edge_nn.hidden = torch.cat((h.permute(1, 0, 2), hidden_null), dim=0)  # num_layers, batch_size, hidden_size

            x_step = torch.zeros(test_batch_size, 1, max_prev_node, device=device)
            edge_in = torch.ones(test_batch_size, 1, 1, device=device)
            # Row i describes node i + 1, which has at most i + 1 predecessors.
            for j in range(min(max_prev_node, i + 1)):
                _, edge_logits, _ = edge_nn(edge_in)
                edge_in = Bernoulli(logits=edge_logits).sample()
                x_step[:, :, j:j + 1] = edge_in
            y_pred_long[:, i:i + 1, :] = x_step

    y_pred_long_data = y_pred_long.long()

    G_pred_list = []
    for i in range(test_batch_size):
        adj_pred = decode_adj(y_pred_long_data[i].cpu().numpy())
        G_pred_list.append(get_graph(adj_pred))  # get a graph from zero-padded adj
    return G_pred_list


# Sample 10 graphs, all conditioned on the same value fed to model.init_hidden
# (the role ele_p plays during training).
# NOTE(review): 6.6297 is hard-coded; its provenance is not shown in this notebook.
a = generate_graph(model=model, 
                   edge_nn=edge_nn,
                   max_prev_node=max_prev_node, 
                   test_batch_energies=torch.tensor([6.6297] * 10).to(device).view(-1, 1), 
                   device=device)

In [None]:
# First sampled graph as a directed graph (each undirected edge becomes two arcs).
g = nx.DiGraph(a[0])

In [None]:
# (2, n_edges) edge index of the sampled graph; reshape keeps that layout even
# when the sampled graph has no edges.
adj_out_t = torch.tensor(np.array(list(g.edges()), dtype=np.int64).reshape(-1, 2).T,
                         dtype=torch.long).to(device)

In [None]:
# NOTE(review): this still uses the teacher-forced `embedding` and `shower.adj_out`
# of the last training shower, not the sampled graph `g` / `adj_out_t` built above.
shower_t = torch_geometric.data.Data(x=embedding, 
                                     edge_index=shower.adj_out).to(device)

features = features_nn(shower_t)