In [2]:
from __future__ import absolute_import, division, print_function, unicode_literals
import time
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import warnings
with warnings.catch_warnings():
    warnings.filterwarnings("ignore",category=FutureWarning)

In [3]:
# MultiHeadAttention
# https://www.tensorflow.org/tutorials/text/transformer, appears in "Attention is all you need" NIPS 2018 paper
import numpy as np
import tensorflow as tf


def scaled_dot_product_attention(q, k, v, mask):
    """Calculate the attention weights.
    q, k, v must have matching leading dimensions.
    k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
    The mask has different shapes depending on its type(padding or look ahead) 
    but it must be broadcastable for addition.

    Args:
    q: query shape == (..., seq_len_q, depth)
    k: key shape == (..., seq_len_k, depth)
    v: value shape == (..., seq_len_v, depth_v)
    mask: Float tensor with shape broadcastable 
          to (..., seq_len_q, seq_len_k). Defaults to None.

    Returns:
    output, attention_weights
    """

    matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)
  
    # scale matmul_qk
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

    # add the mask to the scaled tensor.
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)  

    # softmax is normalized on the last axis (seq_len_k) so that the scores
    # add up to 1.
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

    output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

    return output, attention_weights


def print_out(q, k, v):
    temp_out, temp_attn = scaled_dot_product_attention(
      q, k, v, None)
    print ('Attention weights are:')
    print (temp_attn)
    print ('Output is:')
    print (temp_out)
    
np.set_printoptions(suppress=True)

temp_k = tf.constant([[10,0,0],
                      [0,10,0],
                      [0,0,10],
                      [0,0,10]], dtype=tf.float32)  # (4, 3)

temp_v = tf.constant([[   1,0],
                      [  10,0],
                      [ 100,5],
                      [1000,6]], dtype=tf.float32)  # (4, 2)

# This `query` aligns with the second `key`,
# so the second `value` is returned.
temp_q = tf.constant([[0, 10, 0]], dtype=tf.float32)  # (1, 3)
print_out(temp_q, temp_k, temp_v)

Attention weights are:
tf.Tensor([[0. 1. 0. 0.]], shape=(1, 4), dtype=float32)
Output is:
tf.Tensor([[10.  0.]], shape=(1, 2), dtype=float32)


In [4]:
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Model

    
class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model

        assert d_model % self.num_heads == 0
    
        self.depth = d_model // self.num_heads
    
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)

        self.dense = tf.keras.layers.Dense(d_model)
        
    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth).
        Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])
    
    def call(self, q, k, v, mask=None):
        batch_size = tf.shape(q)[0]

        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)

        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

        # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
        # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = scaled_dot_product_attention(
            q, k, v, mask)
    
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)

        concat_attention = tf.reshape(scaled_attention, 
                                      (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)

        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)

        return output
    

# temp_mha = MultiHeadAttention(d_model=512, num_heads=8)
# y = tf.random.uniform((1, 60, 512))  # (batch_size, encoder_sequence, d_model)
# out = temp_mha(v=y, k=y, q=y)
# print(out.shape)


class RFF(tf.keras.layers.Layer):
    """
    Row-wise FeedForward layers.
    """
    def __init__(self, d):
        super(RFF, self).__init__()
        
        self.linear_1 = Dense(d, activation='relu')
        self.linear_2 = Dense(d, activation='relu')
        self.linear_3 = Dense(d, activation='relu')
            
    def call(self, x):
        """
        Arguments:
            x: a float tensor with shape [b, n, d].
        Returns:
            a float tensor with shape [b, n, d].
        """
        return self.linear_3(self.linear_2(self.linear_1(x)))   


# mlp = RFF(3)
# y = mlp(tf.ones(shape=(2, 4, 3)))  # The first call to the `mlp` will create the weights
# print('weights:', len(mlp.weights))
# print('trainable weights:', len(mlp.trainable_weights))

In [357]:
# Referencing https://arxiv.org/pdf/1810.00825.pdf 
# and the original PyTorch implementation https://github.com/TropComplique/set-transformer/blob/master/blocks.py
from tensorflow import repeat
# from tensorflow.keras.backend import repeat_elements
from tensorflow.keras.layers import LayerNormalization


class MultiHeadAttentionBlock(tf.keras.layers.Layer):
    def __init__(self, d, h, rff):
        super(MultiHeadAttentionBlock, self).__init__()
        self.multihead = MultiHeadAttention(d, h)
        self.layer_norm1 = LayerNormalization(epsilon=1e-6, dtype='float32')
        self.layer_norm2 = LayerNormalization(epsilon=1e-6, dtype='float32')
        self.rff = rff
    
    def call(self, x, y):
        """
        Arguments:
            x: a float tensor with shape [b, n, d].
            y: a float tensor with shape [b, m, d].
        Returns:
            a float tensor with shape [b, n, d].
        """
    
        h = self.layer_norm1(x + self.multihead(x, y, y))
        return self.layer_norm2(h + self.rff(h))

# x_data = tf.random.normal(shape=(10, 2, 9))
# y_data = tf.random.normal(shape=(10, 3, 9))
# rff = RFF(d=9)
# mab = MultiHeadAttentionBlock(9, 3, rff=rff)
# mab(x_data, y_data).shape    

    
class SetAttentionBlock(tf.keras.layers.Layer):
    def __init__(self, d, h, rff):
        super(SetAttentionBlock, self).__init__()
        self.mab = MultiHeadAttentionBlock(d, h, rff)
    
    def call(self, x):
        """
        Arguments:
            x: a float tensor with shape [b, n, d].
        Returns:
            a float tensor with shape [b, n, d].
        """
        return self.mab(x, x)

    
class InducedSetAttentionBlock(tf.keras.layers.Layer):
    def __init__(self, d, m, h, rff1, rff2):
        """
        Arguments:
            d: an integer, input dimension.
            m: an integer, number of inducing points.
            h: an integer, number of heads.
            rff1, rff2: modules, row-wise feedforward layers.
                It takes a float tensor with shape [b, n, d] and
                returns a float tensor with the same shape.
        """
        super(InducedSetAttentionBlock, self).__init__()
        self.mab1 = MultiHeadAttentionBlock(d, h, rff1)
        self.mab2 = MultiHeadAttentionBlock(d, h, rff2)
        self.inducing_points = tf.random.normal(shape=(1, m, d))

    def call(self, x):
        """
        Arguments:
            x: a float tensor with shape [b, n, d].
        Returns:
            a float tensor with shape [b, n, d].
        """
#         b = x.shape[0] 
        b = 1000
        p = self.inducing_points
        p = repeat_elements(p, (b), axis=0)  # shape [b, m, d]  
        
        h = self.mab1(p, x)  # shape [b, m, d]
        return self.mab2(x, h)     
    

class PoolingMultiHeadAttention(tf.keras.layers.Layer):

    def __init__(self, d, k, h, rff, rff_s):
        """
        Arguments:
            d: an integer, input dimension.
            k: an integer, number of seed vectors.
            h: an integer, number of heads.
            rff: a module, row-wise feedforward layers.
                It takes a float tensor with shape [b, n, d] and
                returns a float tensor with the same shape.
        """
        super(PoolingMultiHeadAttention, self).__init__()
        self.mab = MultiHeadAttentionBlock(d, h, rff)
        self.seed_vectors = tf.random.normal(shape=(1, k, d))
        self.rff_s = rff_s

    @tf.function
    def call(self, z):
        """
        Arguments:
            z: a float tensor with shape [b, n, d].
        Returns:
            a float tensor with shape [b, k, d]
        """
        b = tf.shape(z)[0]
        s = self.seed_vectors
        s = repeat(s, (b), axis=0)  # shape [b, k, d]
        return self.mab(s, self.rff_s(z))
    

# z = tf.random.normal(shape=(10, 2, 9))
# rff, rff_s = RFF(d=9), RFF(d=9) 
# pma = PoolingMultiHeadAttention(d=9, k=10, h=3, rff=rff, rff_s=rff_s)
# pma(z).shape

In [415]:
from tensorflow.keras.layers import Dense
    

class STEncoderBasic(tf.keras.layers.Layer):
    def __init__(self, d=12, m=6, h=6):
        super(STEncoderBasic, self).__init__()
        
        # Embedding part
        self.linear_1 = Dense(d, activation='relu')
        
        # Encoding part
        self.isab_1 = SetAttentionBlock(d, h, RFF(d))
        self.isab_2 = SetAttentionBlock(d, h, RFF(d))
            
    def call(self, x):
        return self.isab_2(self.isab_1(self.linear_1(x)))

    
class STDecoderBasic(tf.keras.layers.Layer):
    def __init__(self, out_dim, d=12, m=6, h=2, k=8):
        super(STDecoderBasic, self).__init__()
        
        self.PMA = PoolingMultiHeadAttention(d, k, h, RFF(d), RFF(d))
        self.SAB = SetAttentionBlock(d, h, RFF(d))
        self.output_mapper = Dense(out_dim)   
        self.k, self.d = k, d

    def call(self, x):
        decoded_vec = self.SAB(self.PMA(x))
        decoded_vec = tf.reshape(decoded_vec, [-1, self.k * self.d])
        return tf.reshape(self.output_mapper(decoded_vec), (tf.shape(decoded_vec)[0],))


In [416]:
def gen_max_dataset(batch_size=1000, set_size=9):
    """
    The number of objects per set is constant in this toy example
    """
    x = np.random.uniform(1, 100, (batch_size, set_size))
    y = np.max(x, axis=1)
    x, y = np.expand_dims(x, axis=2), np.expand_dims(y, axis=1)
    return tf.cast(x, 'float32'), tf.cast(y, 'float32')

x, y = gen_max_dataset()
x.shape

TensorShape([1000, 9, 1])

In [417]:
# Generate training/test set
n_train_batches, n_test_batches = 1, 1
train_data = [gen_max_dataset() for i in range(n_train_batches)]
test_data = [gen_max_dataset() for i in range(n_test_batches)]

X_train, y_train = [i[0] for i in train_data], [i[1] for i in train_data]
X_test, y_test = [i[0] for i in test_data], [i[1] for i in test_data]

X_train = X_train[0].numpy()
y_train = y_train[0].numpy().reshape(1000)

In [418]:
# Dimensionality check on encoder-decoder couple

encoder = STEncoderBasic(d=3, m=2, h=1)
encoded = encoder(X_train)
print(encoded.shape)

decoder = STDecoderBasic(out_dim=1, d=1, m=2, h=1, k=1)
decoded = decoder(encoded)
print(decoded.shape)

(1000, 9, 3)
(1000,)


In [430]:
# Actual model for max-set prediction

class SetTransformer(tf.keras.Model):
    def __init__(self, ):
        super(SetTransformer, self).__init__()
        self.basic_encoder = STEncoderBasic(d=4, m=3, h=2)
        self.basic_decoder = STDecoderBasic(out_dim=1, d=4, m=2, h=2, k=2)
    
    def call(self, x):
        enc_output = self.basic_encoder(x)  # (batch_size, set_len, d_model)
        return self.basic_decoder(enc_output)
    
set_t = SetTransformer()
set_t(X_train).shape

TensorShape([1000])

In [431]:
set_transformer = SetTransformer()
set_transformer.compile(loss='mae', optimizer='adam')
set_transformer.fit(X_train, y_train, epochs=300, batch_size=64)

Train on 1000 samples
Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 31/300
Epoch 32/300
Epoch 33/300
Epoch 34/300
Epoch 35/300
Epoch 36/300
Epoch 37/300
Epoch 38/300
Epoch 39/300
Epoch 40/300
Epoch 41/300
Epoch 42/300
Epoch 43/300
Epoch 44/300
Epoch 45/300
Epoch 46/300
Epoch 47/300
Epoch 48/300
Epoch 49/300
Epoch 50/300
Epoch 51/300
Epoch 52/300
Epoch 53/300
Epoch 54/300
Epoch 55/300
Epoch 56/300
Epoch 57/300
Epoch 58/300
Epoch 59/300
Epoch 60/300
Epoch 61/300
Epoch 62/300
Epoch 63/300
Epoch 64/300
Epoch 65/300
Epoch 66/300
Epoch 67/300
Epoch 68/300
Epoch 69/300
Epoch 70/300
Epoch 71/300
Epoch 72/300
Epoch 73/300
Epoch 74/300
Epoch 75/300
Epoch 76/300

Epoch 92/300
Epoch 93/300
Epoch 94/300
Epoch 95/300
Epoch 96/300
Epoch 97/300
Epoch 98/300
Epoch 99/300
Epoch 100/300
Epoch 101/300
Epoch 102/300
Epoch 103/300
Epoch 104/300
Epoch 105/300
Epoch 106/300
Epoch 107/300
Epoch 108/300
Epoch 109/300
Epoch 110/300
Epoch 111/300
Epoch 112/300
Epoch 113/300
Epoch 114/300
Epoch 115/300
Epoch 116/300
Epoch 117/300
Epoch 118/300
Epoch 119/300
Epoch 120/300
Epoch 121/300
Epoch 122/300
Epoch 123/300
Epoch 124/300
Epoch 125/300
Epoch 126/300
Epoch 127/300
Epoch 128/300
Epoch 129/300
Epoch 130/300
Epoch 131/300
Epoch 132/300
Epoch 133/300
Epoch 134/300
Epoch 135/300
Epoch 136/300
Epoch 137/300
Epoch 138/300
Epoch 139/300
Epoch 140/300
Epoch 141/300
Epoch 142/300
Epoch 143/300
Epoch 144/300
Epoch 145/300
Epoch 146/300
Epoch 147/300
Epoch 148/300
Epoch 149/300
Epoch 150/300
Epoch 151/300
Epoch 152/300
Epoch 153/300
Epoch 154/300
Epoch 155/300
Epoch 156/300
Epoch 157/300
Epoch 158/300
Epoch 159/300
Epoch 160/300
Epoch 161/300
Epoch 162/300
Epoch 163/300


Epoch 183/300
Epoch 184/300
Epoch 185/300
Epoch 186/300
Epoch 187/300
Epoch 188/300
Epoch 189/300
Epoch 190/300
Epoch 191/300
Epoch 192/300
Epoch 193/300
Epoch 194/300
Epoch 195/300
Epoch 196/300
Epoch 197/300
Epoch 198/300
Epoch 199/300
Epoch 200/300
Epoch 201/300
Epoch 202/300
Epoch 203/300
Epoch 204/300
Epoch 205/300
Epoch 206/300
Epoch 207/300
Epoch 208/300
Epoch 209/300
Epoch 210/300
Epoch 211/300
Epoch 212/300
Epoch 213/300
Epoch 214/300
Epoch 215/300
Epoch 216/300
Epoch 217/300
Epoch 218/300
Epoch 219/300
Epoch 220/300
Epoch 221/300
Epoch 222/300
Epoch 223/300
Epoch 224/300
Epoch 225/300
Epoch 226/300
Epoch 227/300
Epoch 228/300
Epoch 229/300
Epoch 230/300
Epoch 231/300
Epoch 232/300
Epoch 233/300
Epoch 234/300
Epoch 235/300
Epoch 236/300
Epoch 237/300
Epoch 238/300
Epoch 239/300
Epoch 240/300
Epoch 241/300
Epoch 242/300
Epoch 243/300
Epoch 244/300
Epoch 245/300
Epoch 246/300
Epoch 247/300
Epoch 248/300
Epoch 249/300
Epoch 250/300
Epoch 251/300
Epoch 252/300
Epoch 253/300
Epoch 

Epoch 274/300
Epoch 275/300
Epoch 276/300
Epoch 277/300
Epoch 278/300
Epoch 279/300
Epoch 280/300
Epoch 281/300
Epoch 282/300
Epoch 283/300
Epoch 284/300
Epoch 285/300
Epoch 286/300
Epoch 287/300
Epoch 288/300
Epoch 289/300
Epoch 290/300
Epoch 291/300
Epoch 292/300
Epoch 293/300
Epoch 294/300
Epoch 295/300
Epoch 296/300
Epoch 297/300
Epoch 298/300
Epoch 299/300
Epoch 300/300


<tensorflow.python.keras.callbacks.History at 0x7f35fd0777f0>

In [425]:
prediction = set_transformer.predict(X_test)
np.abs(prediction - y_test).mean()

6.777