In [48]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [49]:
# Import common tensorflow layers and activations
# Import tensor kerastensor
from tensorflow.keras.layers import Dense, Dropout, BatchNormalization, Layer
from tensorflow.keras.layers import Lambda, Multiply, Add 
from tensorflow.keras.activations import relu, sigmoid, softmax
from tensorflow.keras import Model
from tensorflow.keras import Sequential
from tensorflow.keras import Input

In [87]:
class AttentiveTransformer(Layer):
    def __init__(self, units, priors = None, gamma=1.3, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.units = units
        self.priors = priors if priors is not None else []
        self.gamma = tf.constant(gamma, dtype=tf.float32)
        # self.forward = None
        self.dense = Dense(self.units, name=f"{self.name}_dense")
        self.BN = BatchNormalization(name=f"{self.name}_BN")

    
    # def build(self, input_shape):
    #     # super().build(input_shape)
    #     _in = Input(shape=input_shape[1:], name=f"{self._name}_input")
    #     x = Dense(self.units, name=f"{self._name}_dense")(_in)
    #     x = BatchNormalization(name=f"{self._name}_BN")(x)
    #     print("building", self.name)
    #     for prior in self.priors:
    #         print("prior_mul", prior._keras_history.layer.name)
    #         x = Multiply()([(self.gamma - prior),x])
    #     x = softmax(x, axis=-1)
    #     self.forward = Model(_in, x)
    
    def call(self, data):
        candidate_mask = self.BN(self.dense(data))
        print(self.priors)
        for prior in self.priors:
            print(prior)
            candidate_mask = candidate_mask * (self.gamma - prior)
        return candidate_mask
        

In [84]:
# Test AttentiveTransformer on normal data with empty list of priors
at = AttentiveTransformer(units=10, priors=[])
x = tf.random.normal((1, 7))
y = at(x)
x, y

ListWrapper([])


(<tf.Tensor: shape=(1, 7), dtype=float32, numpy=
 array([[-3.7246137 , -1.2955515 ,  1.0983032 , -1.7555252 ,  0.35104498,
         -2.64866   ,  0.9813243 ]], dtype=float32)>,
 <tf.Tensor: shape=(1, 10), dtype=float32, numpy=
 array([[-2.088805  ,  2.7135048 , -0.96651167, -3.3743446 , -2.7218435 ,
         -2.6411772 ,  1.7550788 ,  0.5382888 , -0.23096025, -2.237022  ]],
       dtype=float32)>)

In [88]:
# Test AttentiveTransformer on normal data with instance of itself as prior
at2 = AttentiveTransformer(units=10, priors=[y, 1])
z = at(x)
z

ListWrapper([])


<tf.Tensor: shape=(1, 10), dtype=float32, numpy=
array([[-2.088805  ,  2.7135048 , -0.96651167, -3.3743446 , -2.7218435 ,
        -2.6411772 ,  1.7550788 ,  0.5382888 , -0.23096025, -2.237022  ]],
      dtype=float32)>

In [86]:
# Test graph construction of AttentiveTransformer with no prior
test_in = Input(shape=(7,), name="test_in", dtype="float32")
at1 = AttentiveTransformer(units=7, priors=[], name="test_attn_1")
y = at1(test_in)

ListWrapper([])


In [89]:

at2 = AttentiveTransformer(units=7, name="test_attn_2")
x = at2(test_in)*(1.3-y)
Model(test_in, x*y)

ListWrapper([])


<keras.engine.functional.Functional at 0x7f0e8d7ea9d0>