In [1]:
import tensorflow as tf
import numpy as np
import h5py as h5
import random as rd

ModuleNotFoundError: No module named 'tensorflow'

In [2]:
# Ask TensorFlow to allocate GPU memory on demand for every visible GPU
# instead of reserving the whole device at start-up.
# The loop is a no-op when no GPU is visible (empty list).
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

### Routines for data generators

To construct the adjacency matrix, we treat the hits as a set ordered by activation time. Two vertices are then connected if they are separated by no more than k cells (parameter k_nearest).

Note that the adjacency matrix is prepared in the generators.

In [4]:
import numpy as np
import h5py as h5
# Path to the HDF5 file with the data set.
# NOTE(review): absolute, machine-specific path - consider a configurable data directory.
h5f = "/home/leonov/Baikal/Cut_8_nu/APRIL/Ordered/Data/mc_baikal_norm_cut-8_ordered_equal_big.h5"

# Number of hits per event used in the dataset signatures
# (an earlier value of 110 is kept in the trailing comment).
max_len = 32 #110
file = h5f 
# Name of the top-level HDF5 group to read ('train' here).
regime = 'train' 
batch_size = 32
# Whether a final, smaller batch should be produced by the generators.
return_reminder = True 
# Width parameter of the band adjacency matrix built by the generators.
k_nearest = 2
# Print the second and the first dimension of the '<regime>/data' dataset.
with h5.File(h5f,'r') as hf:
    print(hf[regime+'/data'].shape[1]  , hf[regime+'/data'].shape[0] )

97 1694740


In [4]:
# generator without shuffling
# yields (data, adjacency); the labels stream is currently disabled
class generator_no_shuffle:
    """Sequential batch generator over one group of an HDF5 file.

    Calling an instance returns a Python generator that walks through
    ``<regime>/data`` and ``<regime>/mask`` in order and yields tuples
    ``(data, adjacency)``:

    * ``data``      - the first ``data_length`` hits of every event with the
      mask appended as an extra last channel;
    * ``adjacency`` - the band adjacency matrix multiplied by the outer
      product of the mask with itself, so that rows and columns of hits
      with a zero mask are zeroed.

    NOTE(review): other cells of this notebook (dataset signature, models)
    expect three-element tuples ``(data, labels, adjacency)``; the labels
    read is commented out in the original code and is not restored here.

    Parameters
    ----------
    file : path of the HDF5 file.
    regime : name of the top-level group (e.g. ``'train'``).
    batch_size : number of events per batch.
    return_reminder : if true, a final smaller batch with the left-over
        events is yielded as well.
    k_nearest : diagonals with offsets ``-(k_nearest-1) .. k_nearest-1`` of
        the adjacency matrix are set to one.
        NOTE(review): the notebook text speaks of hits separated by "not
        more than k cells"; confirm that the ``k_nearest-1`` range is intended.
    data_length : number of leading hits kept per event (default 32, the
        value that used to be hard-coded). It is clamped to the length
        actually stored in the file.
    """

    def __init__(self, file, regime, batch_size, return_reminder, k_nearest, data_length=32):
        self.file = file
        self.regime = regime
        self.batch_size = batch_size
        self.return_reminder = return_reminder
        with h5.File(self.file,'r') as hf:
            data_shape = hf[self.regime+'/data'].shape
        self.num = data_shape[0]
        # never ask for more hits than the file stores
        self.data_length = min(data_length, data_shape[1])
        self.batch_num = self.num // self.batch_size
        if return_reminder:
            self.gen_num = self.num
        else:
            self.gen_num = self.batch_num*self.batch_size
        self.full_adj = self._band_adjacency(self.data_length, k_nearest)

    @staticmethod
    def _band_adjacency(length, k_nearest):
        """Return a (length, length) matrix with ones on the main diagonal and
        on the diagonals with offsets +-1 .. +-(k_nearest-1)."""
        adj = np.eye(length)
        for i in range(1, k_nearest):
            adj = adj + np.eye(length, k=i) + np.eye(length, k=-i)
        return adj

    def _make_batch(self, hf, start, stop):
        """Read events [start:stop) and build the (data, adjacency) pair."""
        length = self.data_length
        # Mask and data are cut to the same number of hits, so the masked
        # adjacency has the same (length, length) shape as `full_adj`.
        mask = hf[self.regime+'/mask'][start:stop, :length]
        mask_channel = np.expand_dims( mask, axis=-1 )
        mask_channel_2 = np.expand_dims( mask, axis=1 )
        mask_to_adj = mask_channel*mask_channel_2
        data = hf[self.regime+'/data'][start:stop, :length]
        return ( np.concatenate((data, mask_channel), axis=-1),
                 self.full_adj*mask_to_adj )

    def __call__(self):
        """Yield all full batches in file order, then the optional remainder."""
        full_stop = self.batch_num*self.batch_size
        with h5.File(self.file, 'r') as hf:
            for start in range(0, full_stop, self.batch_size):
                yield self._make_batch(hf, start, start+self.batch_size)
            # Skip the remainder when the data divide evenly: an empty
            # batch carries no information.
            if self.return_reminder and full_stop < self.num:
                yield self._make_batch(hf, full_stop, self.num)
                
# generator with shuffling
class generator_with_shuffle:
    """Batch generator that shuffles events through an in-memory buffer.

    The first ``buffer_size`` events are loaded into a buffer. Every step
    draws ``batch_size`` distinct buffer slots at random, yields their
    content and refills exactly those slots with the next events of the
    file. When the file is exhausted, the left-over events are appended to
    the buffer, the buffer is shuffled once and emitted in batches.

    Each yielded element is a tuple ``(data, adjacency)`` built like in
    ``generator_no_shuffle``: data with the mask appended as the last
    channel, and the band adjacency matrix masked by the outer product of
    the mask.

    NOTE(review): other cells of this notebook expect three-element tuples
    ``(data, labels, adjacency)``; the labels handling is commented out in
    the original code and is not restored here.

    Parameters
    ----------
    file, regime, batch_size, return_reminder, k_nearest :
        same meaning as in ``generator_no_shuffle``.
    buffer_size : number of events held in the shuffle buffer; clamped to
        the number of events in the file.
    data_length : number of leading hits kept per event (default 32, the
        value that used to be hard-coded), clamped to the stored length.

    Raises
    ------
    ValueError
        If random draws are required but ``batch_size`` exceeds the buffer.
    """

    def __init__(self, file, regime, batch_size, buffer_size, return_reminder, k_nearest, data_length=32):
        self.file = file
        self.regime = regime
        self.batch_size = batch_size
        self.return_reminder = return_reminder
        with h5.File(self.file,'r') as hf:
            data_shape = hf[self.regime+'/data/'].shape
        self.num = data_shape[0]
        self.data_length = min(data_length, data_shape[1])
        # a buffer larger than the data set would only hold `num` events
        self.buffer_size = min(buffer_size, self.num)
        self.batch_num = (self.num-self.buffer_size) // self.batch_size
        if self.batch_num > 0 and self.batch_size > self.buffer_size:
            raise ValueError("batch_size must not exceed buffer_size")
        # Events left once the streaming phase is over: the buffer content
        # plus the tail of the file. Counting full batches over all of them
        # (not over the buffer alone) keeps `gen_num` consistent with what
        # is actually yielded.
        self.last_batches_num = (self.num - self.batch_num*self.batch_size) // self.batch_size
        if return_reminder:
            self.gen_num = self.num
        else:
            self.gen_num = (self.batch_num+self.last_batches_num)*self.batch_size
        te = [ np.expand_dims(np.eye(self.data_length),axis=0)]
        for i in range(1,k_nearest):
            te.append(np.expand_dims(np.eye(self.data_length, k=i),axis=0))
            te.append(np.expand_dims(np.eye(self.data_length, k=-i),axis=0))
        self.full_adj = np.sum( np.concatenate(te, axis=0), axis=0 )

    def _read(self, hf, start, stop):
        """Read events [start:stop) and return the (data, adjacency) arrays."""
        length = self.data_length
        # mask and data are cut to `length` hits to match `full_adj`
        mask = hf[self.regime+'/mask'][start:stop, :length]
        mask_channel = np.expand_dims( mask, axis=-1 )
        mask_channel_2 = np.expand_dims( mask, axis=1 )
        data = np.concatenate((hf[self.regime+'/data'][start:stop, :length], mask_channel), axis=-1)
        adj = self.full_adj*mask_channel*mask_channel_2
        return data, adj

    def __call__(self):
        """Yield shuffled batches; every event is emitted at most once."""
        step = self.batch_size
        with h5.File(self.file, 'r') as hf:
            buffer_data, buffer_adj = self._read(hf, 0, self.buffer_size)
            start = self.buffer_size
            for i in range(self.batch_num):
                idxs = rd.sample( range(self.buffer_size), k=step )
                # list indexing copies, so refilling the slots below does
                # not alter the batch that has just been yielded
                yield ( buffer_data[idxs], buffer_adj[idxs] )
                buffer_data[idxs], buffer_adj[idxs] = self._read(hf, start, start+step)
                start += step
            # fill the buffer with left data, if any
            if start < self.num:
                tail_data, tail_adj = self._read(hf, start, self.num)
                buffer_data = np.concatenate( (buffer_data, tail_data), axis=0 )
                buffer_adj = np.concatenate( (buffer_adj, tail_adj), axis=0 )
        # shuffle whatever is left and emit it batch by batch
        total = buffer_data.shape[0]
        sh_idxs = rd.sample( range(total), k=total )
        for start in range(0, total, step):
            idxs = sh_idxs[start:start+step]
            # a shorter final chunk is emitted only on request
            if len(idxs) == step or self.return_reminder:
                yield ( buffer_data[idxs], buffer_adj[idxs] )

In [5]:
### Datasets
def make_datasets(h5f,make_generator_shuffle,return_batch_reminder,train_batch_size,train_buffer_size,test_batch_size,k_nearest,num_batch_shuffle=None):
    """Build the training and the validation ``tf.data.Dataset``.

    Parameters
    ----------
    h5f : path of the HDF5 file with the 'train' and 'test' groups.
    make_generator_shuffle : if true, training data come from
        ``generator_with_shuffle``; otherwise from ``generator_no_shuffle``
        followed by a ``Dataset.shuffle`` over whole batches.
    return_batch_reminder : whether the training generator also yields the
        final, smaller batch; the batch dimension of the training
        signature is then left unspecified.
    train_batch_size, test_batch_size : batch sizes of the two generators.
    train_buffer_size : shuffle buffer size (in events) of
        ``generator_with_shuffle``.
    k_nearest : band width parameter forwarded to the generators.
    num_batch_shuffle : size (in batches) of the ``Dataset.shuffle`` buffer
        used when ``make_generator_shuffle`` is false. Defaults to
        ``train_buffer_size // train_batch_size`` (at least 1), i.e. about
        as many events as the generator-side buffer would hold.

    Returns
    -------
    (train_dataset, test_dataset) : the training dataset repeats forever,
        the validation dataset makes a single pass.

    Notes
    -----
    The hit dimension of the signatures is the notebook-level ``max_len``.
    NOTE(review): the signatures declare three components (6-channel data,
    2-channel labels, adjacency); make sure the generators in use yield
    three-element tuples of these shapes.
    """
    def _signature(batch_dim):
        # identical element structure for training and validation
        return ( tf.TensorSpec(shape=(batch_dim,max_len,6)),
                 tf.TensorSpec(shape=(batch_dim,max_len,2)),
                 tf.TensorSpec(shape=(batch_dim,max_len,max_len)) )

    # generator for training data
    if make_generator_shuffle:
        tr_generator = generator_with_shuffle(h5f,'train',train_batch_size,train_buffer_size,return_batch_reminder,k_nearest)
    else:
        tr_generator = generator_no_shuffle(h5f,'train',train_batch_size,return_batch_reminder,k_nearest)
    # size of the last batch is unknown when the remainder is returned
    tr_batch_size = None if return_batch_reminder else train_batch_size

    train_dataset = tf.data.Dataset.from_generator( tr_generator,
                        output_signature=_signature(tr_batch_size) )

    if make_generator_shuffle:
        train_dataset = train_dataset.repeat(-1).prefetch(tf.data.AUTOTUNE)
    else:
        if num_batch_shuffle is None:
            num_batch_shuffle = max(1, train_buffer_size // train_batch_size)
        train_dataset = train_dataset.repeat(-1).shuffle(num_batch_shuffle).prefetch(tf.data.AUTOTUNE)

    # generator for validation data; it never returns a remainder, so every
    # batch holds exactly `test_batch_size` events
    te_generator = generator_no_shuffle(h5f,'test',test_batch_size,False,k_nearest)
    te_batch_size = test_batch_size
    test_dataset = tf.data.Dataset.from_generator( te_generator,
                        output_signature=_signature(te_batch_size) )

    test_dataset = test_dataset.prefetch(tf.data.AUTOTUNE)

    return train_dataset, test_dataset

In [6]:
# Build the datasets: generator-side shuffling on, no remainder batch,
# a shuffle buffer of 500 batches, and the same batch size for train and test.
train_dataset, test_dataset = make_datasets(h5f,True,False,batch_size,500*batch_size,batch_size,k_nearest)

2022-05-02 17:06:20.620746: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-05-02 17:06:21.121019: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 10410 MB memory:  -> device: 0, name: NVIDIA GeForce GTX 1080 Ti, pci bus id: 0000:17:00.0, compute capability: 6.1


### Graph Convolutional Model

This is the simplest graph neural network. The protocol goes as follows:

1) Transform the initial encodings of the nodes (NodesEncoder). This is usually a one-layer NN.

2) Update the node encodings by aggregating information from connected cells (StateUpdater). This is a simple NN or just linear algebra.

In the model, we normalize the adjacency matrix to have a better data flow.

The dataset yields elements of the form (data, labels, adj).

In [7]:
# pre-transforms nodes
# adds res connection for better data flow
class NodesEncoder(tf.keras.layers.Layer):
    """Per-node feature transformation with a concatenating skip connection.

    Every dense layer in turn is applied to the current features, followed
    by the activation; the untouched input features are then appended along
    the last axis before the next layer sees the result.

    Parameters
    ----------
    units : sequence with the width of each dense layer (may be empty, in
        which case the input is returned unchanged).
    activation : callable applied after every dense layer.
    """

    def __init__(self, units, activation):
        super().__init__()
        self.num_layers = len(units)
        self.dense_layers = [ tf.keras.layers.Dense(width) for width in units ]
        self.activation = activation
        self.units = units

    def build(self, input_shape):
        # Record the width of the output encodings: the last dense width
        # plus the input width, or just the input width without layers.
        in_width = input_shape[-1]
        self.out_encs_length = in_width if self.num_layers == 0 else self.units[-1] + in_width

    def call(self, x):
        skip = x
        for dense in self.dense_layers:
            x = tf.concat((self.activation(dense(x)), skip), axis=-1)
        return x

In [8]:
# states updater
class StateUpdater(tf.keras.layers.Layer):
    """Aggregate neighbour encodings and pass them through dense layers.

    Parameters
    ----------
    units : sequence with the width of each dense layer.
    activation : callable applied after every dense layer.
    """

    def __init__(self, units, activation):
        super().__init__()
        self.dense_layers = [ tf.keras.layers.Dense(width) for width in units ]
        self.activation = activation

    def call(self, vert, adj):
        # aggregation: every node receives the adj-weighted sum of encodings
        state = tf.matmul(adj, vert)
        # state update
        for dense in self.dense_layers:
            state = self.activation(dense(state))
        return state

In [9]:
class GraphConvLayer(tf.keras.layers.Layer):
    """One graph-convolution step: encode the nodes, then update their states.

    Parameters
    ----------
    nodes_encoder : layer called as ``nodes_encoder(vert)``.
    state_updater : layer called as ``state_updater(encodings, adj)``.
    """

    def __init__(self, nodes_encoder, state_updater):
        super().__init__()
        self.nodes_encoder = nodes_encoder
        self.state_updater = state_updater

    def call(self, vert, adj):
        # feature transformation followed by the neighbour-aware update
        return self.state_updater(self.nodes_encoder(vert), adj)

In [10]:
class GraphConvModel(tf.keras.Model):
    """Stack of graph layers with a symmetrically normalised adjacency matrix.

    The model consumes elements ``(x, labels, adj)``. ``adj`` is normalised,
    ``x`` is passed through every graph layer together with it, and the
    output for hits whose mask (last channel of ``x``) is zero is replaced
    by the constant ``[0., 1.]``.
    """

    # last graph layer must have 2 channels to apply softmax
    def __init__(self, graph_conv_layers):
        super(GraphConvModel, self).__init__()
        self.gr_layers = [gr for gr in graph_conv_layers]

    def compile(self, optimizer, loss_fn, metrics):
        """Store the optimizer, the loss and the metrics used by the custom steps.

        ``metrics`` is an iterable of Keras metrics that are updated with
        class indices (argmax over the last axis of labels and predictions).
        """
        super(GraphConvModel, self).compile()
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.loss_tracker = tf.keras.metrics.Mean(name='loss')
        # copy, so that tuples are accepted and the caller's list is not shared
        self.metrics_ = list(metrics)
        self.all_metrics = self.metrics_+[self.loss_tracker]

    def norm_adj(self, adj):
        """Return D^-1/2 * adj * D^-1/2, D being the diagonal matrix of row sums."""
        degree = tf.reduce_sum(adj, axis=-1)
        # small plus to avoid 0/0
        norm_degree = tf.linalg.diag(1./tf.sqrt(degree+1e-8))
        n_adj = tf.matmul(norm_degree, tf.matmul(adj, norm_degree) )
        return n_adj

    @tf.function
    def call(self, datas):
        """Compute predictions for one dataset element ``(x, labels, adj)``.

        ``labels`` is unpacked but not used here.
        """
        (x, labels, adj) = datas
        adj = self.norm_adj(adj)
        mask = x[:,:,-1:]
        for gr_lr in self.gr_layers:
            x = gr_lr(x, adj)
        # yield correct predictions for auxiliary hits
        preds = tf.where( tf.cast(mask,bool), x, tf.constant([0.,1.]) )
        return preds

    @property
    def metrics(self):
        # empty before `compile`, so that attribute access never fails
        return getattr(self, 'all_metrics', [])

    def _update_metrics(self, loss, labels, preds):
        """Update the loss tracker and all metrics; return their current values."""
        self.loss_tracker.update_state(loss)
        prd_cls = tf.math.argmax( preds, axis=-1 )
        true_cls = tf.math.argmax( labels, axis=-1 )
        for m in self.metrics_:
            # Keras metrics take (y_true, y_pred), in this order
            m.update_state(true_cls, prd_cls)
        return { m.name : m.result() for m in self.all_metrics }

    @tf.function
    def train_step(self, datas):
        """One optimisation step; returns a dict with the metric values."""
        (x, labels, adj) = datas
        with tf.GradientTape() as tape:
            preds = self.call(datas)
            loss = self.loss_fn(labels,preds)
        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients( zip(grads, self.trainable_weights) )
        return self._update_metrics(loss, labels, preds)

    @tf.function
    def test_step(self, datas):
        """One evaluation step; returns a dict with the metric values."""
        (x, labels, adj) = datas
        preds = self.call(datas)
        loss = self.loss_fn(labels,preds)
        return self._update_metrics(loss, labels, preds)

In [11]:
# Build and compile the graph-convolutional model.
selu = tf.keras.activations.selu
softmax = tf.keras.activations.softmax

# One entry per graph layer: dense widths of the node encoder, dense widths
# of the state updater, and the activation of the state updater.
# The last layer ends with 2 channels and softmax.
nodes_encs_length_s = [[16],[16]]
new_state_length_s = [[32],[2]]
state_upd_acts = [selu,softmax]

gr_layers = []
for (nodes_l, news_l, new_act) in zip(nodes_encs_length_s,new_state_length_s, state_upd_acts):
    # node encoders always use SELU
    nodes_encoder = NodesEncoder(nodes_l, selu)
    state_updater = StateUpdater(news_l, new_act)
    gr_layers.append( GraphConvLayer(nodes_encoder,state_updater) )
    
gnn = GraphConvModel(gr_layers)

# Optimisation set-up: Adam (AMSGrad variant) with categorical cross-entropy;
# accuracy is tracked next to the loss.
lr = 0.002
loss_fn = tf.keras.losses.CategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.9, beta_2=0.999, epsilon=1e-07, amsgrad=True)
metrics = [tf.keras.metrics.Accuracy()]
gnn.compile(optimizer, loss_fn, metrics)

In [12]:
# Train for 2 epochs of 2500 steps each; validation uses 500 batches of the test dataset.
gnn.fit(train_dataset, steps_per_epoch=2500, validation_steps=500, epochs=2, validation_data=test_dataset, verbose=1)

Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x7f3f626eab00>

### Graph Attention Model

Some of the neighbouring nodes are irrelevant, and we need to account for this. 

We do this by introducing an NN that estimates the relevance of a pair of nodes (AttentionEstablisher). Its output weights the contributions of the nodes as seen by the state updater.

In [13]:
# calculates attentions for pair of channels
class AttentionEstablisher(tf.keras.layers.Layer):
    """Compute one attention weight per pair of nodes.

    The pair features go through the hidden dense layers and a final
    ``Dense(1)`` that produces one score per pair. The scores are then
    normalised with a softmax along ``axis``.

    A softmax applied by the ``Dense(1)`` layer itself would normalise over
    the single output unit and return 1 for every pair, so the
    normalisation is done explicitly over the neighbour axis instead.

    Parameters
    ----------
    hid_units : sequence with the width of each hidden dense layer.
    activation : callable applied after every hidden dense layer.
    axis : axis of the score tensor along which the softmax is taken.
        The default ``-2`` is the second-to-last axis of the
        ``(..., nodes, 1)`` score tensor.
        NOTE(review): all entries along this axis take part in the
        normalisation, including pairs that the caller masks out afterwards.
    """

    def __init__(self, hid_units, activation, axis=-2):
        super(AttentionEstablisher, self).__init__()
        self.dense_layers = [ tf.keras.layers.Dense(un) for un in hid_units ]
        self.activation = activation
        self.axis = axis
        # use softmax, not sigmoid - better data flow
        # (the layer itself outputs raw scores; softmax is applied in `call`)
        self.soft_layer = tf.keras.layers.Dense(1)

    def call(self, x):
        for lr in self.dense_layers:
            x = lr(x)
            x = self.activation(x)
        scores = self.soft_layer(x)
        res = tf.nn.softmax(scores, axis=self.axis)
        return res

In [14]:
class GraphAttLayer(tf.keras.layers.Layer):
    """Graph layer that aggregates neighbours with learned attention weights.

    Parameters
    ----------
    nodes_encoder : layer called as ``nodes_encoder(vert)``.
    state_updater : layer called as ``state_updater(encodings, weights)``.
    attention_establisher : layer mapping pair features of shape
        ``(bs, om, om, 2c)`` to one value per pair, shape ``(bs, om, om, 1)``.
    """

    def __init__(self, nodes_encoder, state_updater, attention_establisher):
        super(GraphAttLayer, self).__init__()
        self.nodes_encoder = nodes_encoder
        self.attention_establisher = attention_establisher
        self.state_updater = state_updater

    def make_pairs(self, vert, adj):
        """Build pair features for every (i, j) entry of ``adj``.

        Entry ``[b, i, j]`` of the result is the concatenation of
        ``adj[b, i, j] * vert[b, i]`` and ``adj[b, i, j] * vert[b, j]``.
        """
        adj_exp = tf.expand_dims(adj,axis=-1)
        vert_exp_1 = tf.expand_dims(vert,axis=1) # (bs,1,om,c)
        vert_exp_2 = tf.expand_dims(vert,axis=2) # (bs,om,1,c)
        rel_vert = adj_exp*vert_exp_2 # (bs, om, om, c), features of node i
        base_vert = adj_exp*vert_exp_1 # (bs, om, om, c), features of node j
        pairs = tf.concat((rel_vert,base_vert),axis=-1) # (bs, om, om, 2c)
        return pairs

    def call(self, vert, adj):
        # transform features
        transf_encs = self.nodes_encoder(vert)
        # calculate attentions (from the untransformed features)
        mask_adj = tf.cast(adj, bool)
        pairs = self.make_pairs(vert, adj)
        ## making loop over OMs might be better in terms of memory and calculations efficiency
        # Drop only the trailing unit axis: a bare squeeze would also remove
        # a batch or node axis of size one.
        scores = tf.squeeze(self.attention_establisher(pairs), axis=-1)
        atts = tf.where(mask_adj, scores, tf.zeros_like(scores))
        # aggregate information, use attention weights instead of adj matrix
        new_encs = self.state_updater(transf_encs, atts)
        return new_encs

In [15]:
# note we do not need to normilize adj matrix
class GraphAttModel(tf.keras.Model):
    """Stack of graph layers that receive the adjacency matrix unchanged.

    The model consumes elements ``(x, labels, adj)``. ``x`` is passed
    through every graph layer together with ``adj``, and the output for
    hits whose mask (last channel of ``x``) is zero is replaced by the
    constant ``[0., 1.]``.
    """

    # last graph layer must have 2 channels to apply softmax
    def __init__(self, graph_layers):
        super(GraphAttModel, self).__init__()
        self.gr_layers = [gr for gr in graph_layers]

    def compile(self, optimizer, loss_fn, metrics):
        """Store the optimizer, the loss and the metrics used by the custom steps.

        ``metrics`` is an iterable of Keras metrics that are updated with
        class indices (argmax over the last axis of labels and predictions).
        """
        super(GraphAttModel, self).compile()
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.loss_tracker = tf.keras.metrics.Mean(name='loss')
        # copy, so that tuples are accepted and the caller's list is not shared
        self.metrics_ = list(metrics)
        self.all_metrics = self.metrics_+[self.loss_tracker]

    @tf.function
    def call(self, datas):
        """Compute predictions for one dataset element ``(x, labels, adj)``.

        ``labels`` is unpacked but not used here.
        """
        (x, labels, adj) = datas
        mask = x[:,:,-1:]
        for gr_lr in self.gr_layers:
            x = gr_lr(x, adj)
        # yield correct predictions for auxiliary hits
        preds = tf.where( tf.cast(mask,bool), x, tf.constant([0.,1.]) )
        return preds

    @property
    def metrics(self):
        # empty before `compile`, so that attribute access never fails
        return getattr(self, 'all_metrics', [])

    def _update_metrics(self, loss, labels, preds):
        """Update the loss tracker and all metrics; return their current values."""
        self.loss_tracker.update_state(loss)
        prd_cls = tf.math.argmax( preds, axis=-1 )
        true_cls = tf.math.argmax( labels, axis=-1 )
        for m in self.metrics_:
            # Keras metrics take (y_true, y_pred), in this order
            m.update_state(true_cls, prd_cls)
        return { m.name : m.result() for m in self.all_metrics }

    @tf.function
    def train_step(self, datas):
        """One optimisation step; returns a dict with the metric values."""
        (x, labels, adj) = datas
        with tf.GradientTape() as tape:
            preds = self.call(datas)
            loss = self.loss_fn(labels,preds)
        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients( zip(grads, self.trainable_weights) )
        return self._update_metrics(loss, labels, preds)

    @tf.function
    def test_step(self, datas):
        """One evaluation step; returns a dict with the metric values."""
        (x, labels, adj) = datas
        preds = self.call(datas)
        loss = self.loss_fn(labels,preds)
        return self._update_metrics(loss, labels, preds)

In [16]:
# Build and compile the graph-attention model.
selu = tf.keras.activations.selu
softmax = tf.keras.activations.softmax

# One entry per graph layer: dense widths of the node encoder, of the state
# updater and of the attention network, plus the activations of the state
# updater and of the attention network.
# The last layer ends with 2 channels and softmax.
nodes_encs_length_s = [[16],[16]]
new_state_length_s = [[32],[2]]
att_units_s = [[16],[16]]
state_upd_acts = [selu,softmax]
att_upd_acts = [selu,selu]

gr_layers = []
for (nodes_l,news_l,new_act,att_l,att_act) in zip(nodes_encs_length_s,new_state_length_s,state_upd_acts,att_units_s,att_upd_acts):
    # node encoders always use SELU
    nodes_encoder = NodesEncoder(nodes_l, selu)
    state_updater = StateUpdater(news_l, new_act)
    attention_establisher = AttentionEstablisher(att_l, att_act)
    gr_layers.append( GraphAttLayer(nodes_encoder,state_updater,attention_establisher) )
    
gnn = GraphAttModel(gr_layers)

# Optimisation set-up: Adam (AMSGrad variant) with categorical cross-entropy;
# accuracy is tracked next to the loss.
lr = 0.002
loss_fn = tf.keras.losses.CategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.9, beta_2=0.999, epsilon=1e-07, amsgrad=True)
metrics = [tf.keras.metrics.Accuracy()]
gnn.compile(optimizer, loss_fn, metrics)

In [17]:
# Train for 2 epochs of 2500 steps each; validation uses 500 batches of the test dataset.
gnn.fit(train_dataset, steps_per_epoch=2500, validation_steps=500, epochs=2, validation_data=test_dataset, verbose=1)

Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x7f3f00229870>

### Message Passing Graph

Vertices keep information about themselves. Hence their encodings are not optimal for updating the encodings of other vertices. It therefore makes sense to prepare messages from one node to another, which are then used for updating the states (MessageCreator). 

In this case attention is not needed, as the messages should already incorporate relevance.

The state updater should be modified to use the messages (it does not need the adjacency matrix).

In [18]:
# states updater MP
class StateUpdaterMP(tf.keras.layers.Layer):
    """State updater for message passing: dense layers only, no adjacency matrix.

    Parameters
    ----------
    units : sequence with the width of each dense layer.
    activation : callable applied after every dense layer.
    """

    def __init__(self, units, activation):
        super().__init__()
        self.dense_layers = [ tf.keras.layers.Dense(width) for width in units ]
        self.activation = activation

    def call(self, vert):
        state = vert
        # each dense layer is followed by the activation
        for dense in self.dense_layers:
            state = self.activation(dense(state))
        return state

In [19]:
# creates messages for pairs of nodes
class MessageCreator(tf.keras.layers.Layer):
    """Turn the features of a pair of nodes into a message.

    Parameters
    ----------
    hid_units : sequence with the width of each dense layer.
    activation : callable applied after every dense layer.
    """

    def __init__(self, hid_units, activation):
        super().__init__()
        self.dense_layers = [ tf.keras.layers.Dense(width) for width in hid_units ]
        self.activation = activation

    def call(self, x):
        message = x
        # each dense layer is followed by the activation
        for dense in self.dense_layers:
            message = self.activation(dense(message))
        return message

In [20]:
class GraphMessPassLayer(tf.keras.layers.Layer):
    """Graph layer that updates node states from averaged pairwise messages.

    Parameters
    ----------
    nodes_encoder : layer called as ``nodes_encoder(vert)``.
    state_updater : layer called as ``state_updater(features)``.
    message_creator : layer mapping pair features of shape
        ``(bs, om, om, 2c)`` to messages of shape ``(bs, om, om, c')``.
    """

    def __init__(self, nodes_encoder, state_updater, message_creator):
        super(GraphMessPassLayer, self).__init__()
        self.nodes_encoder = nodes_encoder
        self.message_creator = message_creator
        self.state_updater = state_updater

    def make_pairs(self, vert, adj):
        """Build pair features for every (i, j) entry of ``adj``.

        Entry ``[b, i, j]`` of the result is the concatenation of
        ``adj[b, i, j] * vert[b, i]`` and ``adj[b, i, j] * vert[b, j]``.
        """
        adj_exp = tf.expand_dims(adj,axis=-1)
        vert_exp_1 = tf.expand_dims(vert,axis=1) # (bs,1,om,c)
        vert_exp_2 = tf.expand_dims(vert,axis=2) # (bs,om,1,c)
        rel_vert = adj_exp*vert_exp_2 # (bs, om, om, c), features of node i
        base_vert = adj_exp*vert_exp_1 # (bs, om, om, c), features of node j
        pairs = tf.concat((rel_vert,base_vert),axis=-1) # (bs, om, om, 2c)
        return pairs

    def call(self, vert, adj):
        # transform features
        transf_encs = self.nodes_encoder(vert)
        mask_adj = tf.expand_dims( tf.cast(adj, bool), axis=-1 )
        pairs = self.make_pairs(vert, adj)
        ## making loop over OMs might be better in terms of memory and calculations efficiency
        raw_messages = self.message_creator(pairs) # (bs, om, om, c)
        # A zero tensor of full rank replaces the former rank-1 `[0.]` fill
        # value, the operand named in the "layout failed" grappler errors of
        # the training log; this should avoid them.
        messages = tf.where(mask_adj, raw_messages, tf.zeros_like(raw_messages))
        # take mean over target OMs to form message: the neighbour count is
        # reduced over the same (last) axis of adj as the messages below
        n_neigh = tf.reduce_sum( adj, axis=-1, keepdims=True ) # (bs, om, 1)
        messages = tf.math.reduce_sum(messages, axis=2, keepdims=False)/(n_neigh+1e-8)
        # concat message and current encodings
        upd_from = tf.concat((transf_encs,messages), axis=-1)
        # update state
        new_encs = self.state_updater(upd_from)
        return new_encs

In [21]:
# note we do not need to normilize adj matrix
class GraphMPModel(tf.keras.Model):
    """Stack of message-passing graph layers.

    The model consumes elements ``(x, labels, adj)``. ``x`` is passed
    through every graph layer together with the unmodified ``adj``, and the
    output for hits whose mask (last channel of ``x``) is zero is replaced
    by the constant ``[0., 1.]``.
    """

    # last graph layer must have 2 channels to apply softmax
    def __init__(self, graph_layers):
        super(GraphMPModel, self).__init__()
        self.gr_layers = [gr for gr in graph_layers]

    def compile(self, optimizer, loss_fn, metrics):
        """Store the optimizer, the loss and the metrics used by the custom steps.

        ``metrics`` is an iterable of Keras metrics that are updated with
        class indices (argmax over the last axis of labels and predictions).
        """
        super(GraphMPModel, self).compile()
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.loss_tracker = tf.keras.metrics.Mean(name='loss')
        # copy, so that tuples are accepted and the caller's list is not shared
        self.metrics_ = list(metrics)
        self.all_metrics = self.metrics_+[self.loss_tracker]

    @tf.function
    def call(self, datas):
        """Compute predictions for one dataset element ``(x, labels, adj)``.

        ``labels`` is unpacked but not used here.
        """
        (x, labels, adj) = datas
        mask = x[:,:,-1:]
        for gr_lr in self.gr_layers:
            x = gr_lr(x, adj)
        # yield correct predictions for auxiliary hits
        preds = tf.where( tf.cast(mask,bool), x, tf.constant([0.,1.]) )
        return preds

    @property
    def metrics(self):
        # empty before `compile`, so that attribute access never fails
        return getattr(self, 'all_metrics', [])

    def _update_metrics(self, loss, labels, preds):
        """Update the loss tracker and all metrics; return their current values."""
        self.loss_tracker.update_state(loss)
        prd_cls = tf.math.argmax( preds, axis=-1 )
        true_cls = tf.math.argmax( labels, axis=-1 )
        for m in self.metrics_:
            # Keras metrics take (y_true, y_pred), in this order
            m.update_state(true_cls, prd_cls)
        return { m.name : m.result() for m in self.all_metrics }

    @tf.function
    def train_step(self, datas):
        """One optimisation step; returns a dict with the metric values."""
        (x, labels, adj) = datas
        with tf.GradientTape() as tape:
            preds = self.call(datas)
            loss = self.loss_fn(labels,preds)
        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients( zip(grads, self.trainable_weights) )
        return self._update_metrics(loss, labels, preds)

    @tf.function
    def test_step(self, datas):
        """One evaluation step; returns a dict with the metric values."""
        (x, labels, adj) = datas
        preds = self.call(datas)
        loss = self.loss_fn(labels,preds)
        return self._update_metrics(loss, labels, preds)

In [22]:
# Build and compile the message-passing model.
selu = tf.keras.activations.selu
softmax = tf.keras.activations.softmax

# One entry per graph layer: dense widths of the node encoder, of the state
# updater and of the message network, plus the activations of the state
# updater and of the message network.
# The last layer ends with 2 channels and softmax.
nodes_encs_length_s = [[16],[16]]
new_state_length_s = [[32],[2]]
mess_units_s = [[16],[16]]
state_upd_acts = [selu,softmax]
mess_acts = [selu,selu]

gr_layers = []
for (nodes_l,news_l,new_act,mess_l,mess_act) in zip(nodes_encs_length_s,new_state_length_s,state_upd_acts,mess_units_s,mess_acts):
    # node encoders always use SELU
    nodes_encoder = NodesEncoder(nodes_l, selu)
    state_updater = StateUpdaterMP(news_l, new_act)
    message_creator = MessageCreator(mess_l, mess_act)
    gr_layers.append( GraphMessPassLayer(nodes_encoder,state_updater,message_creator) )

# use the model class defined for message passing (GraphAttModel was used here before)
gnn = GraphMPModel(gr_layers)

# Optimisation set-up: Adam (AMSGrad variant) with categorical cross-entropy;
# accuracy is tracked next to the loss.
lr = 0.002
loss_fn = tf.keras.losses.CategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.9, beta_2=0.999, epsilon=1e-07, amsgrad=True)
metrics = [tf.keras.metrics.Accuracy()]
gnn.compile(optimizer, loss_fn, metrics)

In [23]:
# Train for 2 epochs of 2500 steps each; validation uses 500 batches of the test dataset.
gnn.fit(train_dataset, steps_per_epoch=2500, validation_steps=500, epochs=2, validation_data=test_dataset, verbose=1)

Epoch 1/2


2022-05-02 17:17:31.239158: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:828] layout failed: INVALID_ARGUMENT: Size of values 1 does not match size of permutation 4 @ fanin shape inStatefulPartitionedCall/StatefulPartitionedCall/graph_mess_pass_layer/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer




2022-05-02 17:20:03.422661: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:828] layout failed: INVALID_ARGUMENT: Size of values 1 does not match size of permutation 4 @ fanin shape inStatefulPartitionedCall/StatefulPartitionedCall/graph_mess_pass_layer/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer


Epoch 2/2


<keras.callbacks.History at 0x7f3ef86bc310>

### Possible improvements.

1) Usually, a GNN should be agnostic to the labelling of vertices (hence taking the mean). In our case, however, we have a very good natural ordering - by activation time. We can make use of this ordering by applying OM-wise convolutions in the node transformations.

2) Introduce edge features. This might be useful for track segmentation.

3) Introduce a master vertex. It stores information about the graph as a whole. At each iteration, it can be used to update the vertex encodings and vice versa.

4) Make the adjacency matrix dynamic (like an edge feature). That is, all vertices are connected and the elements of the adjacency matrix define the weights of the connections.

In [None]:
# generator without shuffling
# yields (data, target, adjacency)
class generator_no_shuffle:
    """Sequential batch generator yielding ``(data, target, adjacency)``.

    Calling an instance returns a Python generator that walks through one
    group of the HDF5 file in order:

    * ``data``      - the first ``data_length`` hits of ``<regime>/data``
      with the mask appended as an extra last channel;
    * ``target``    - cosine and sine of ``<regime>/ev_chars``, each expanded
      at axis 1 and concatenated along the last axis
      (presumably ``ev_chars`` holds one polar angle per event, giving a
      ``(batch, 2)`` target - TODO confirm against the file);
    * ``adjacency`` - the band adjacency matrix multiplied by the outer
      product of the mask with itself.

    Parameters
    ----------
    file : path of the HDF5 file.
    regime : name of the top-level group (e.g. ``'train'``).
    batch_size : number of events per batch.
    return_reminder : if true, a final smaller batch with the left-over
        events is yielded as well.
    k_nearest : diagonals with offsets ``-(k_nearest-1) .. k_nearest-1`` of
        the adjacency matrix are set to one.
    data_length : number of leading hits kept per event (default 32, the
        value that used to be hard-coded). It is clamped to the length
        actually stored in the file.
    """

    def __init__(self, file, regime, batch_size, return_reminder, k_nearest, data_length=32):
        self.file = file
        self.regime = regime
        self.batch_size = batch_size
        self.return_reminder = return_reminder
        with h5.File(self.file,'r') as hf:
            data_shape = hf[self.regime+'/data'].shape
        self.num = data_shape[0]
        # fixed number of hits, set by hand instead of the stored length;
        # never more than the file stores
        self.data_length = min(data_length, data_shape[1])
        self.batch_num = self.num // self.batch_size
        if return_reminder:
            self.gen_num = self.num
        else:
            self.gen_num = self.batch_num*self.batch_size
        te = [ np.expand_dims(np.eye(self.data_length),axis=0)]
        for i in range(1,k_nearest):
            te.append(np.expand_dims(np.eye(self.data_length, k=i),axis=0))
            te.append(np.expand_dims(np.eye(self.data_length, k=-i),axis=0))
        self.full_adj = np.sum( np.concatenate(te, axis=0), axis=0 )

    def _make_batch(self, hf, start, stop):
        """Read events [start:stop) and build the (data, target, adjacency) triple."""
        length = self.data_length
        mask = hf[self.regime+'/mask'][start:stop,:length]
        mask_channel = np.expand_dims( mask, axis=-1 )
        mask_channel_2 = np.expand_dims( mask, axis=1 )
        mask_to_adj = mask_channel*mask_channel_2
        polar = hf[self.regime+'/ev_chars'][start:stop]
        cos, sin = np.expand_dims(np.cos(polar),axis=1) ,  np.expand_dims(np.sin(polar),axis=1)
        target = np.concatenate((cos, sin), axis=-1)
        data = hf[self.regime+'/data'][start:stop,:length]
        return ( np.concatenate((data, mask_channel), axis=-1),
                 target, self.full_adj*mask_to_adj )

    def __call__(self):
        """Yield all full batches in file order, then the optional remainder."""
        full_stop = self.batch_num*self.batch_size
        with h5.File(self.file, 'r') as hf:
            for start in range(0, full_stop, self.batch_size):
                yield self._make_batch(hf, start, start+self.batch_size)
            # Skip the remainder when the data divide evenly: an empty
            # batch carries no information.
            if self.return_reminder and full_stop < self.num:
                yield self._make_batch(hf, full_stop, self.num)
                
# generator with shuffling
class generator_with_shuffle:
    """Batch generator over an HDF5 file that shuffles events through an in-memory buffer.

    Calling an instance returns a Python generator yielding tuples
    ``(data, target, adjacency)``:

    * ``data``      - ``<regime>/data`` truncated to ``data_length`` entries along
      axis 1, with ``<regime>/mask`` appended as one extra last-axis channel;
    * ``target``    - cos and sin of ``<regime>/ev_chars`` (each expanded on axis 1)
      concatenated along the last axis;
    * ``adjacency`` - the banded matrix ``full_adj`` multiplied by the outer
      product of the mask with itself, so masked-out entries have no edges.

    Shuffling scheme: the first ``buffer_size`` events are loaded into a buffer;
    each step yields ``batch_size`` randomly chosen buffer slots and refills
    exactly those slots with the next ``batch_size`` events from the file.
    When the file is exhausted the buffer (plus any leftover events) is
    shuffled once and emitted in consecutive batches.

    Parameters
    ----------
    file : path of the HDF5 file.
    regime : name of the top-level group to read (e.g. 'train').
    batch_size : number of events per yielded batch; must not exceed
        ``buffer_size`` because slots are drawn without replacement.
    buffer_size : number of events kept in the shuffle buffer.
    return_reminder : if true, events that do not fill a whole batch at the end
        are yielded as well, so that every event is produced once per pass.
    k_nearest : band half-width parameter; entries (i, j) of ``full_adj`` are 1
        when ``|i - j| < k_nearest``.
    """

    def __init__(self, file, regime, batch_size, buffer_size, return_reminder, k_nearest):
        self.file = file
        self.regime = regime
        self.batch_size = batch_size
        self.return_reminder = return_reminder
        self.buffer_size = buffer_size
        with h5.File(self.file,'r') as hf:
            # Same key spelling as every other access (the original had a stray trailing '/').
            self.num = hf[self.regime+'/data'].shape[0]
            # Hard-coded truncation length instead of hf[self.regime+'/data'].shape[1].
            self.data_length = 32
        # Full batches streamed from the file while the buffer is being refilled.
        self.batch_num = (self.num-self.buffer_size) // self.batch_size
        # Full batches emitted while draining the buffer at the end of a pass.
        self.last_batches_num = self.buffer_size // self.batch_size
        if return_reminder:
            self.gen_num = self.num
        else:
            self.gen_num = (self.batch_num+self.last_batches_num)*self.batch_size
        # Main diagonal plus the off-diagonals +-1 .. +-(k_nearest-1).
        te = [ np.expand_dims(np.eye(self.data_length),axis=0)]
        for i in range(1,k_nearest):
            te.append(np.expand_dims(np.eye(self.data_length, k=i),axis=0))
            te.append(np.expand_dims(np.eye(self.data_length, k=-i),axis=0))
        self.full_adj = np.sum( np.concatenate(te, axis=0), axis=0 )

    def _read_slice(self, hf, start, stop):
        """Read events ``[start:stop)`` from the open file ``hf``; return (data, target, adjacency)."""
        mask = hf[self.regime+'/mask'][start:stop,:self.data_length]
        mask_channel = np.expand_dims( mask, axis=-1 )
        mask_channel_2 = np.expand_dims( mask, axis=1 )
        data = np.concatenate((hf[self.regime+'/data'][start:stop,:self.data_length], mask_channel),
                              axis=-1)
        polar = hf[self.regime+'/ev_chars'][start:stop]
        cos, sin = np.expand_dims(np.cos(polar),axis=1), np.expand_dims(np.sin(polar),axis=1)
        target = np.concatenate((cos, sin), axis=-1)
        # Broadcasting mask_channel * mask_channel_2 gives the per-event outer product of the mask.
        adj = self.full_adj*mask_channel*mask_channel_2
        return data, target, adj

    def __call__(self):
        start = self.buffer_size
        stop = self.buffer_size + self.batch_size
        with h5.File(self.file, 'r') as hf:
            buffer_data, buffer_target, buffer_adj = self._read_slice(hf, 0, self.buffer_size)
            for _ in range(self.batch_num):
                idxs = rd.sample( range(self.buffer_size), k=self.batch_size )
                # Fancy indexing copies, so the yielded batch is unaffected by the refill below.
                yield ( buffer_data[idxs], buffer_target[idxs], buffer_adj[idxs] )
                new_data, new_target, new_adj = self._read_slice(hf, start, stop)
                buffer_data[idxs] = new_data
                buffer_target[idxs] = new_target   # was `buffer_target[idx]` -> NameError
                buffer_adj[idxs] = new_adj
                start += self.batch_size
                stop += self.batch_size
            # Append the events left in the file (fewer than batch_size), if any.
            tail_data, tail_target, tail_adj = self._read_slice(hf, start, stop)
            buffer_data = np.concatenate( (buffer_data, tail_data), axis=0 )
            buffer_target = np.concatenate( (buffer_target, tail_target), axis=0 )
            buffer_adj = np.concatenate( (buffer_adj, tail_adj), axis=0 )
            total = buffer_target.shape[0]
            sh_idxs = rd.sample( range(total), k=total )
            start = 0
            for _ in range(self.last_batches_num):
                idxs = sh_idxs[start:start+self.batch_size]
                yield ( buffer_data[idxs], buffer_target[idxs], buffer_adj[idxs] )
                start += self.batch_size
            if self.return_reminder:
                # Up to 2*(batch_size-1) events can remain here (buffer remainder plus
                # file leftover), so keep emitting until all are produced; this also
                # avoids yielding an empty batch when nothing remains.
                while start < total:
                    idxs = sh_idxs[start:start+self.batch_size]
                    yield ( buffer_data[idxs], buffer_target[idxs], buffer_adj[idxs] )
                    start += self.batch_size
                
### Datasets
def make_datasets(h5f,make_generator_shuffle,return_batch_reminder,train_batch_size,train_buffer_size,test_batch_size,k_nearest,
                  num_batch_shuffle=None):
    """Build the training and validation ``tf.data.Dataset`` objects from an HDF5 file.

    Parameters
    ----------
    h5f : path of the HDF5 file passed to the generators.
    make_generator_shuffle : if true, train with ``generator_with_shuffle``
        (shuffling happens inside the generator); otherwise use
        ``generator_no_shuffle`` and shuffle whole batches with ``Dataset.shuffle``.
    return_batch_reminder : forwarded to the training generator; when true the
        last training batch may be smaller, so the batch dimension is left unknown.
    train_batch_size, test_batch_size : batch sizes of the two generators.
    train_buffer_size : buffer size for ``generator_with_shuffle``.
    k_nearest : forwarded to the generators (adjacency band parameter).
    num_batch_shuffle : buffer size (in batches) for ``Dataset.shuffle``; only used
        when ``make_generator_shuffle`` is false. If omitted, a notebook-level
        variable of the same name is used when it exists.

    Returns
    -------
    (train_dataset, test_dataset); the training dataset repeats indefinitely.

    Relies on the notebook-level ``max_len`` for the per-event dimension.
    """
    # generator for training data
    if make_generator_shuffle:
        tr_generator = generator_with_shuffle(h5f,'train',train_batch_size,train_buffer_size,return_batch_reminder,k_nearest)
    else:
        tr_generator = generator_no_shuffle(h5f,'train',train_batch_size,return_batch_reminder,k_nearest)
    if return_batch_reminder:
        # size of the last batch is unknown
        tr_batch_size = None
    else:
        tr_batch_size = train_batch_size

    def _signature(batch_dim):
        # (data, target, adjacency) specs for one batch.
        # The original author marked this as "change here".
        # NOTE(review): the generators build the target from cos/sin of 'ev_chars';
        # whether that matches (batch, max_len, 2) depends on the shape of
        # 'ev_chars', which is not visible here - confirm before training.
        return ( tf.TensorSpec(shape=(batch_dim,max_len,6)),
                 tf.TensorSpec(shape=(batch_dim,max_len,2)),
                 tf.TensorSpec(shape=(batch_dim,max_len,max_len)) )

    train_dataset = tf.data.Dataset.from_generator( tr_generator,
                        output_signature=_signature(tr_batch_size) )

    if make_generator_shuffle:
        train_dataset = train_dataset.repeat(-1).prefetch(tf.data.AUTOTUNE)
    else:
        if num_batch_shuffle is None:
            # Backward compatibility: the original read this from notebook state.
            num_batch_shuffle = globals().get('num_batch_shuffle')
        if num_batch_shuffle is None:
            raise ValueError("num_batch_shuffle must be given when make_generator_shuffle is False")
        train_dataset = train_dataset.repeat(-1).shuffle(num_batch_shuffle)

    # generator for validation data
    te_generator = generator_no_shuffle(h5f,'test',test_batch_size,False,k_nearest)
    # The test generator never returns a remainder, so every batch has exactly
    # test_batch_size events (the original reused the training batch dimension).
    te_batch_size = test_batch_size
    test_dataset = tf.data.Dataset.from_generator( te_generator,
                        output_signature=_signature(te_batch_size) )

    test_dataset = test_dataset.prefetch(tf.data.AUTOTUNE)

    return train_dataset, test_dataset