In [19]:
import numpy as np
import tensorflow as tf

In [20]:
class PreNorm(tf.Module):
    """Apply layer normalization to the input, then call a wrapped callable.

    Args:
        dim: Feature size of the input. Kept for interface compatibility; the
            Keras normalization layer infers the feature size when first called.
        fn: Callable invoked with the normalized tensor and any extra keyword
            arguments given to this module.
    """

    def __init__(self, dim, fn):
        # The original referenced an undefined `SimpleModel` class here.
        super().__init__()
        # Normalize over the last (feature) axis. `axis=1` would normalize over
        # the sequence axis for (batch, seq, dim) inputs.
        self.norm = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-5)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(norm(x), **kwargs)``."""
        return self.fn(self.norm(x), **kwargs)

    def __call__(self, x, **kwargs):
        """Make the module callable; `tf.Module` does not define `__call__`."""
        return self.forward(x, **kwargs)

In [21]:
class FeedForward(tf.Module):
    """Two-layer MLP: Dense -> GELU -> Dropout -> Dense -> Dropout.

    Args:
        dim: Output feature size of the second dense layer.
        hidden_dim: Output feature size of the first dense layer.
        dropout: Dropout rate applied after each dense stage.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        # The original referenced an undefined `SimpleModel` class here.
        super().__init__()
        self.net = tf.keras.Sequential([
            tf.keras.layers.Dense(hidden_dim),
            tf.keras.layers.Activation(tf.nn.gelu),
            # `dropout` is a drop rate. GaussianNoise would instead add noise
            # with that value as its standard deviation.
            tf.keras.layers.Dropout(dropout),
            tf.keras.layers.Dense(dim),
            tf.keras.layers.Dropout(dropout),
        ])

    def forward(self, x, training=False):
        """Run the MLP on `x`; dropout is only active when `training` is true."""
        return self.net(x, training=training)

    def __call__(self, x, training=False):
        """Make the module callable; `tf.Module` does not define `__call__`."""
        return self.forward(x, training=training)

In [22]:
class Attention(tf.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: Feature size of the output projection.
        heads: Number of attention heads.
        dim_head: Feature size of each head.
        dropout: Dropout rate applied after the output projection.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        # The original referenced an undefined `SimpleModel` class here.
        super().__init__()
        inner_dim = dim_head * heads
        # The output projection is skipped only when a single head already
        # produces `dim` features.
        self.project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.dim_head = dim_head
        self.scale = dim_head ** -0.5

        # `tf.nn.softmax` is a function that needs logits; use the layer form
        # so that `self.attend` remains a reusable callable.
        self.attend = tf.keras.layers.Softmax(axis=-1)
        # Keras Dense takes the output size as its first argument; the input
        # size is inferred on first call.
        self.to_qkv = tf.keras.layers.Dense(inner_dim * 3, use_bias=False)

        if self.project_out:
            self.to_out = tf.keras.Sequential([
                tf.keras.layers.Dense(dim),
                tf.keras.layers.Dropout(dropout),
            ])
        else:
            self.to_out = tf.identity

    def _split_heads(self, t, batch, length):
        """Reshape (batch, length, heads * dim_head) to (batch, heads, length, dim_head)."""
        t = tf.reshape(t, (batch, length, self.heads, self.dim_head))
        return tf.transpose(t, perm=(0, 2, 1, 3))

    def forward(self, x, training=False):
        """Apply self-attention to `x`.

        Args:
            x: Rank-3 tensor; the code indexes axes 0 and 1 as batch and
                sequence length and projects the last axis.
            training: Enables dropout in the output projection when true.

        Returns:
            Tensor with the same leading two axes as `x`; the last axis is
            `dim` when the output projection is used, `heads * dim_head`
            otherwise.
        """
        # Dynamic shape so an unknown batch size or sequence length works.
        dyn_shape = tf.shape(x)
        batch, length = dyn_shape[0], dyn_shape[1]

        qkv = tf.split(self.to_qkv(x), num_or_size_splits=3, axis=-1)
        q, k, v = (self._split_heads(t, batch, length) for t in qkv)

        # (b, h, i, d) x (b, h, j, d) -> (b, h, i, j): token-to-token scores.
        dots = tf.einsum('bhid,bhjd->bhij', q, k) * self.scale
        attn = self.attend(dots)

        out = tf.einsum('bhij,bhjd->bhid', attn, v)
        # Move heads back next to the feature axis before merging them.
        out = tf.transpose(out, perm=(0, 2, 1, 3))
        out = tf.reshape(out, (batch, length, self.heads * self.dim_head))

        if self.project_out:
            return self.to_out(out, training=training)
        return self.to_out(out)

    def __call__(self, x, training=False):
        """Make the module callable; `tf.Module` does not define `__call__`."""
        return self.forward(x, training=training)

In [23]:
class Transformer(tf.Module):
    """Stack of pre-norm attention and feed-forward blocks with residual adds.

    Args:
        dim: Feature size passed to every sublayer.
        depth: Number of (attention, feed-forward) pairs.
        heads: Number of attention heads.
        dim_head: Feature size of each attention head.
        mlp_dim: Hidden size of the feed-forward sublayer.
        dropout: Dropout rate passed to both sublayers.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        # The original referenced an undefined `SimpleModel` class here.
        super().__init__()
        self.layers = [
            [
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)),
            ]
            for _ in range(depth)
        ]

    def forward(self, x, training=False):
        """Run `x` through every block, adding each sublayer output to its input.

        `training` is forwarded as a keyword argument to each sublayer.
        """
        for attn, ff in self.layers:
            x = tf.add(attn(x, training=training), x)
            x = tf.add(ff(x, training=training), x)
        return x

    def __call__(self, x, training=False):
        """Make the module callable; `tf.Module` does not define `__call__`."""
        return self.forward(x, training=training)

In [24]:
class PositionalEncoding(tf.Module):
    """Add a fixed sinusoidal positional table to the input.

    Args:
        d_model: Feature size of the table (last axis).
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, d_model, max_len=500):
        super(PositionalEncoding, self).__init__()

        position = tf.range(0, max_len, dtype=tf.float32)[:, tf.newaxis]
        div_term = tf.exp(
            tf.range(0, d_model, 2, dtype=tf.float32) * (-np.log(10000.0) / d_model)
        )
        # (max_len, ceil(d_model / 2)) angles, one column per sin/cos pair.
        angles = position * div_term

        # Interleave so even columns hold sin and odd columns hold cos. The
        # original added both half-width tensors to a (max_len, d_model) table,
        # which fails to broadcast for d_model > 2.
        interleaved = tf.stack([tf.sin(angles), tf.cos(angles)], axis=-1)
        pe = tf.reshape(interleaved, (max_len, -1))[:, :d_model]  # trim if d_model is odd
        pe = tf.expand_dims(pe, axis=0)

        # The table is a fixed encoding, so keep it out of the trainable variables.
        self.pe = tf.Variable(pe, trainable=False)

    def forward(self, x):
        """Return `x` plus the first `x.shape[1]` rows of the table.

        The sequence length on axis 1 must not exceed `max_len`.
        """
        # Dynamic length so an unknown sequence dimension still slices correctly.
        length = tf.shape(x)[1]
        return x + self.pe[:, :length, :]

    def __call__(self, x):
        """Make the module callable; `tf.Module` does not define `__call__`."""
        return self.forward(x)

In [25]:
class ViT(tf.Module):
    """Transformer encoder that projects, encodes, pools and maps to a tanh output.

    Args:
        input_dim: Feature size of the input. Kept for interface compatibility;
            the Keras projection layer infers it when first called.
        output_dim: Feature size of the final output.
        dim: Model feature size used by the projection and transformer.
        depth: Number of transformer blocks.
        heads: Number of attention heads.
        mlp_dim: Hidden size of the feed-forward sublayers.
        pool: 'mean' averages over axis 1; any other value takes index 0 of axis 1.
        dim_head: Feature size of each attention head.
        dropout: Dropout rate inside the transformer.
        emb_dropout: Dropout rate applied after the positional encoding.
    """

    def __init__(self, *, input_dim=320, output_dim=512, dim=1024, depth=6, heads=16, mlp_dim=2048, pool='cls', dim_head=64, dropout=0.1, emb_dropout=0.1):
        # The original referenced an undefined `SimpleModel` class here.
        super().__init__()
        # Keras Dense takes the output size first; passing (input_dim, dim)
        # would treat `dim` as the activation.
        self.project = tf.keras.layers.Dense(dim)
        self.pos_encoder = PositionalEncoding(dim)
        # NOTE(review): cls_token is created but never used in `forward`;
        # pool='cls' simply takes the first position. Confirm whether it
        # should be prepended to the sequence.
        self.cls_token = tf.Variable(tf.random.normal((1, 1, dim), dtype=tf.float32))
        self.dropout = tf.keras.layers.Dropout(emb_dropout)
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        self.pool = pool
        self.to_latent = tf.identity
        self.mlp_head = tf.keras.Sequential([
            tf.keras.layers.LayerNormalization(epsilon=1e-5),
            tf.keras.layers.Dense(output_dim)
        ])
        self.tanh = tf.keras.activations.tanh

    def forward(self, x, training=False):
        """Encode `x` and return the tanh-activated head output.

        Args:
            x: Rank-3 tensor; axis 1 is pooled or indexed and the last axis is
                projected.
            training: Enables the dropout layers when true.
        """
        x = self.project(x)
        x = self.pos_encoder(x)
        x = self.dropout(x, training=training)
        x = self.transformer(x, training=training)
        # TF tensors have no `.mean(dim=...)`; use reduce_mean for mean pooling.
        x = tf.reduce_mean(x, axis=1) if self.pool == 'mean' else x[:, 0]
        x = self.to_latent(x)
        return self.tanh(self.mlp_head(x))

    def __call__(self, x, training=False):
        """Make the module callable; `tf.Module` does not define `__call__`."""
        return self.forward(x, training=training)