In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals
import time
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import warnings
with warnings.catch_warnings():
    warnings.filterwarnings("ignore",category=FutureWarning)

In [2]:
# MultiHeadAttention
# https://www.tensorflow.org/tutorials/text/transformer, appears in "Attention is all you need" NIPS 2018 paper
import numpy as np
import tensorflow as tf


def scaled_dot_product_attention(q, k, v, mask):
    """Calculate the attention weights.
    q, k, v must have matching leading dimensions.
    k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
    The mask has different shapes depending on its type(padding or look ahead) 
    but it must be broadcastable for addition.

    Args:
    q: query shape == (..., seq_len_q, depth)
    k: key shape == (..., seq_len_k, depth)
    v: value shape == (..., seq_len_v, depth_v)
    mask: Float tensor with shape broadcastable 
          to (..., seq_len_q, seq_len_k). Defaults to None.

    Returns:
    output, attention_weights
    """

    matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)
  
    # scale matmul_qk
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

    # add the mask to the scaled tensor.
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)  

    # softmax is normalized on the last axis (seq_len_k) so that the scores
    # add up to 1.
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

    output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

    return output, attention_weights


def print_out(q, k, v):
    temp_out, temp_attn = scaled_dot_product_attention(
      q, k, v, None)
    print ('Attention weights are:')
    print (temp_attn)
    print ('Output is:')
    print (temp_out)
    
np.set_printoptions(suppress=True)

temp_k = tf.constant([[10,0,0],
                      [0,10,0],
                      [0,0,10],
                      [0,0,10]], dtype=tf.float32)  # (4, 3)

temp_v = tf.constant([[   1,0],
                      [  10,0],
                      [ 100,5],
                      [1000,6]], dtype=tf.float32)  # (4, 2)

# This `query` aligns with the second `key`,
# so the second `value` is returned.
temp_q = tf.constant([[0, 10, 0]], dtype=tf.float32)  # (1, 3)
print_out(temp_q, temp_k, temp_v)

Attention weights are:
tf.Tensor([[0. 1. 0. 0.]], shape=(1, 4), dtype=float32)
Output is:
tf.Tensor([[10.  0.]], shape=(1, 2), dtype=float32)


In [3]:
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Model

    
class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model

        assert d_model % self.num_heads == 0
    
        self.depth = d_model // self.num_heads
    
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)

        self.dense = tf.keras.layers.Dense(d_model)
        
    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth).
        Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])
    
    def call(self, q, k, v, mask=None):
        batch_size = tf.shape(q)[0]

        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)

        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

        # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
        # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = scaled_dot_product_attention(
            q, k, v, mask)
    
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)

        concat_attention = tf.reshape(scaled_attention, 
                                      (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)

        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)

        return output
    

# temp_mha = MultiHeadAttention(d_model=512, num_heads=8)
# y = tf.random.uniform((1, 60, 512))  # (batch_size, encoder_sequence, d_model)
# out = temp_mha(v=y, k=y, q=y)
# print(out.shape)


class RFF(tf.keras.layers.Layer):
    """
    Row-wise FeedForward layers.
    """
    def __init__(self, d):
        super(RFF, self).__init__()
        
        self.linear_1 = Dense(d, activation='relu')
        self.linear_2 = Dense(d, activation='relu')
        self.linear_3 = Dense(d, activation='relu')
            
    def call(self, x):
        """
        Arguments:
            x: a float tensor with shape [b, n, d].
        Returns:
            a float tensor with shape [b, n, d].
        """
        return self.linear_3(self.linear_2(self.linear_1(x)))   


# mlp = RFF(3)
# y = mlp(tf.ones(shape=(2, 4, 3)))  # The first call to the `mlp` will create the weights
# print('weights:', len(mlp.weights))
# print('trainable weights:', len(mlp.trainable_weights))

In [33]:
# Referencing https://arxiv.org/pdf/1810.00825.pdf 
# and the original PyTorch implementation https://github.com/TropComplique/set-transformer/blob/master/blocks.py
# from tensorflow import repeat
from tensorflow.keras.backend import repeat_elements
from tensorflow.keras.layers import LayerNormalization


class MultiHeadAttentionBlock(tf.keras.layers.Layer):
    def __init__(self, d, h, rff):
        super(MultiHeadAttentionBlock, self).__init__()
        self.multihead = MultiHeadAttention(d, h)
        self.layer_norm1 = LayerNormalization(epsilon=1e-6, dtype='float32')
        self.layer_norm2 = LayerNormalization(epsilon=1e-6, dtype='float32')
        self.rff = rff
    
    def call(self, x, y):
        """
        Arguments:
            x: a float tensor with shape [b, n, d].
            y: a float tensor with shape [b, m, d].
        Returns:
            a float tensor with shape [b, n, d].
        """
    
        h = self.layer_norm1(x + self.multihead(x, y, y))
        return self.layer_norm2(h + self.rff(h))

# x_data = tf.random.normal(shape=(10, 2, 9))
# y_data = tf.random.normal(shape=(10, 3, 9))
# rff = RFF(d=9)
# mab = MultiHeadAttentionBlock(9, 3, rff=rff)
# mab(x_data, y_data).shape    

    
class SetAttentionBlock(tf.keras.layers.Layer):
    def __init__(self, d, h, rff):
        super(SetAttentionBlock, self).__init__()
        self.mab = MultiHeadAttentionBlock(d, h, rff)
    
    def call(self, x):
        """
        Arguments:
            x: a float tensor with shape [b, n, d].
        Returns:
            a float tensor with shape [b, n, d].
        """
        return self.mab(x, x)

    
class InducedSetAttentionBlock(tf.keras.layers.Layer):
    def __init__(self, d, m, h, rff1, rff2):
        """
        Arguments:
            d: an integer, input dimension.
            m: an integer, number of inducing points.
            h: an integer, number of heads.
            rff1, rff2: modules, row-wise feedforward layers.
                It takes a float tensor with shape [b, n, d] and
                returns a float tensor with the same shape.
        """
        super(InducedSetAttentionBlock, self).__init__()
        self.mab1 = MultiHeadAttentionBlock(d, h, rff1)
        self.mab2 = MultiHeadAttentionBlock(d, h, rff2)
#         self.inducing_points = tf.random.normal(shape=(1, m, d))
        self.inducing_points = tf.Variable(initial_value=tf.random.normal(shape=(1, m, d)), 
                                           trainable=True)

    def call(self, x):
        """
        Arguments:
            x: a float tensor with shape [b, n, d].
        Returns:
            a float tensor with shape [b, n, d].
        """
        b = x.shape[0]
        p = self.inducing_points
        p = repeat(p, (b), axis=0)  # shape [b, m, d]  
        h = self.mab1(p, x)  # shape [b, m, d]
        return self.mab2(x, h)     
    

class PoolingMultiHeadAttention(tf.keras.layers.Layer):

    def __init__(self, d, k, h, rff, rff_s):
        """
        Arguments:
            d: an integer, input dimension.
            k: an integer, number of seed vectors.
            h: an integer, number of heads.
            rff: a module, row-wise feedforward layers.
                It takes a float tensor with shape [b, n, d] and
                returns a float tensor with the same shape.
        """
        super(PoolingMultiHeadAttention, self).__init__()
        self.mab = MultiHeadAttentionBlock(d, h, rff)
#         self.seed_vectors = tf.random.normal(shape=(1, k, d))
        self.seed_vectors = self.add_weight(initializer='uniform',
                                            shape=(1, k, d),
                                            trainable=True)
        self.rff_s = rff_s

    def call(self, z):
        """
        Arguments:
            z: a float tensor with shape [b, n, d].
        Returns:
            a float tensor with shape [b, k, d]
        """
        b = z.shape[0]
        s = self.seed_vectors
        try:
            s = repeat_elements(s, (b), axis=0)  # shape [b, k, d]
        except TypeError:
            print(s.shape, z.shape)
            print(s)
            s = tf.ones(shape=z.shape)
        return self.mab(s, self.rff_s(z))
    

# z = tf.random.normal(shape=(10, 2, 9))
# rff, rff_s = RFF(d=9), RFF(d=9) 
# pma = PoolingMultiHeadAttention(d=9, k=10, h=3, rff=rff, rff_s=rff_s)
# pma(z).shape

In [34]:
from tensorflow.keras.layers import Dense


# class STEncoderBasic(tf.keras.layers.Layer):
#     def __init__(self, d=12, m=6, h=6):
#         super(STEncoderBasic, self).__init__()
        
#         # Embedding part
#         self.linear_1 = Dense(d, activation='relu')
        
#         # Encoding part
#         self.isab_1 = InducedSetAttentionBlock(d, m, h, RFF(d), RFF(d))
#         self.isab_2 = InducedSetAttentionBlock(d, m, h, RFF(d), RFF(d)) 
            
#     def call(self, x):
#         return self.isab_2(self.isab_1(self.linear_1(x)))

    
# class STDecoderBasic(tf.keras.layers.Layer):
#     def __init__(self, out_dim, d=12, m=6, h=2, k=8):
#         super(STDecoderBasic, self).__init__()
        
#         self.PMA = PoolingMultiHeadAttention(d, k, h, RFF(d), RFF(d))
#         self.SAB = SetAttentionBlock(d, h, RFF(d))
#         self.output_mapper = Dense(out_dim)      

#     def call(self, x):
#         decoded_vec = self.SAB(self.PMA(x))
#         b, k, d = decoded_vec.shape
#         decoded_vec = tf.reshape(decoded_vec, [b, k * d])
#         return tf.reshape(self.output_mapper(decoded_vec), b)


class STEncoderBasic(tf.keras.layers.Layer):
    def __init__(self, d=12, m=6, h=6):
        super(STEncoderBasic, self).__init__()
        
        # Embedding part
        self.linear_1 = Dense(d, activation='relu')
        
        # Encoding part
        self.isab_1 = SetAttentionBlock(d, h, RFF(d))
        self.isab_2 = SetAttentionBlock(d, h, RFF(d)) 
            
    def call(self, x):
        return self.isab_2(self.isab_1(self.linear_1(x)))

    
class STDecoderBasic(tf.keras.layers.Layer):
    def __init__(self, out_dim, d=12, m=6, h=2, k=8):
        super(STDecoderBasic, self).__init__()
        
        self.PMA = PoolingMultiHeadAttention(d, k, h, RFF(d), RFF(d))
        self.SAB = SetAttentionBlock(d, h, RFF(d))
        self.output_mapper = Dense(out_dim)      

    def call(self, x):
        decoded_vec = self.SAB(self.PMA(x))
        b, k, d = decoded_vec.shape
        decoded_vec = tf.reshape(decoded_vec, [b, k * d])
        return tf.reshape(self.output_mapper(decoded_vec), b)


In [35]:
def gen_max_dataset(batch_size=1000, set_size=9):
    """
    The number of objects per set is constant in this toy example
    """
    x = np.random.uniform(1, 100, (batch_size, set_size))
    y = np.max(x, axis=1)
    x, y = np.expand_dims(x, axis=2), np.expand_dims(y, axis=1)
    return tf.cast(x, 'float32'), tf.cast(y, 'float32')

x, y = gen_max_dataset()
x.shape

TensorShape([1000, 9, 1])

In [36]:
# Generate training/test set
n_train_batches, n_test_batches = 1, 1
train_data = [gen_max_dataset() for i in range(n_train_batches)]
test_data = [gen_max_dataset() for i in range(n_test_batches)]

X_train, y_train = [i[0] for i in train_data], [i[1] for i in train_data]
X_test, y_test = [i[0] for i in test_data], [i[1] for i in test_data]

# print(len(X_test), len(y_test))
# X_test[0].shape, y_test[0].shape


X_train = X_train[0].numpy()
y_train = y_train[0].numpy().reshape(1000)

In [37]:
# Dimensionality check on encoder-decoder couple

encoder = STEncoderBasic(d=3, m=2, h=1)
encoded = encoder(X_train)
print(encoded.shape)

decoder = STDecoderBasic(out_dim=1, d=3, m=2, h=1, k=1)
decoded = decoder(encoded)
print(decoded.shape)

(1000, 9, 3)
(1000,)


In [38]:
# Actual model for max-set prediction

class SetTransformer(tf.keras.Model):
    def __init__(self, ):
        super(SetTransformer, self).__init__()
        self.basic_encoder = STEncoderBasic(d=3, m=2, h=1)
        self.basic_decoder = STDecoderBasic(out_dim=1, d=3, m=2, h=1, k=1)
    
    def call(self, x):
        enc_output = self.basic_encoder(x)  # (batch_size, set_len, d_model)
        return self.basic_decoder(enc_output)

In [39]:
set_transformer = SetTransformer()
set_transformer.compile(loss='mean_squared_error', optimizer='sgd')

In [40]:
set_transformer.fit(X_train)

(1, 1, 3) (None, 9, 3)
<tf.Variable 'Variable:0' shape=(1, 1, 3) dtype=float32>


ValueError: in user code:

    <ipython-input-14-d053fa0e5893>:11 call  *
        return self.basic_decoder(enc_output)
    <ipython-input-10-fa37133e00cc>:58 call  *
        decoded_vec = self.SAB(self.PMA(x))
    <ipython-input-33-cf0c8ec2b6c1>:116 call  *
        s = tf.ones(shape=z.shape)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/array_ops.py:2956 ones  **
        shape = ops.convert_to_tensor(shape, dtype=dtypes.int32)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:1341 convert_to_tensor
        ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/constant_op.py:334 _tensor_shape_tensor_conversion_function
        "Cannot convert a partially known TensorShape to a Tensor: %s" % s)

    ValueError: Cannot convert a partially known TensorShape to a Tensor: (None, 9, 3)
