<a href="https://colab.research.google.com/github/MaxGubin/video_encoders/blob/main/TextVAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import jax

In [None]:
# Show the installed JAX version, both as a string and as a tuple.
jax.__version__, jax.__version_info__

('0.5.2', (0, 5, 2))

In [None]:
# prompt: write in jax transformer encoder/decoder model

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state

class TransformerEncoder(nn.Module):
    """Encoder stack: positional encoding, dropout, then `num_layers` layers.

    The input is combined with the result of
    `positional_encoding(x.shape[1], d_model)`, passed through dropout, and
    then fed sequentially through `num_layers` freshly constructed
    `EncoderLayer` modules.
    """
    # Number of EncoderLayer modules applied in sequence.
    num_layers: int
    # Width passed to positional_encoding and to every EncoderLayer.
    d_model: int
    # Head count forwarded to every EncoderLayer.
    num_heads: int
    # Hidden width of the feed-forward sub-network in every EncoderLayer.
    dff: int
    # Dropout rate used here and forwarded to every EncoderLayer.
    dropout_rate: float

    @nn.compact
    def __call__(self, x, train):
        """Run the encoder stack.

        Args:
            x: input array; `x.shape[1]` is used as the number of positions.
                Assumes x is (batch, seq_len, d_model) -- TODO confirm
                against the caller.
            train: when true, dropout is active (`deterministic=not train`).

        Returns:
            The output of the last encoder layer.
        """
        seq_len = x.shape[1]
        hidden = x + positional_encoding(seq_len, self.d_model)

        # Dropout on the position-augmented input.
        hidden = nn.Dropout(rate=self.dropout_rate)(
            hidden, deterministic=not train
        )

        # Every layer shares the same hyper-parameters.
        layer_config = dict(
            d_model=self.d_model,
            num_heads=self.num_heads,
            dff=self.dff,
            dropout_rate=self.dropout_rate,
        )
        for _layer_index in range(self.num_layers):
            hidden = EncoderLayer(**layer_config)(hidden, train=train)
        return hidden

class EncoderLayer(nn.Module):
    """One encoder layer: self-attention and a feed-forward network.

    Each of the two sub-layers is wrapped in a residual connection followed
    by `nn.LayerNorm`.
    """
    # Width forwarded to the attention and feed-forward sub-layers.
    d_model: int
    # Head count forwarded to MultiHeadAttention.
    num_heads: int
    # Hidden width of the feed-forward sub-network.
    dff: int
    # Dropout rate forwarded to MultiHeadAttention.
    dropout_rate: float

    @nn.compact
    def __call__(self, x, train):
        """Apply the layer to `x`.

        Args:
            x: input array, used as value, key and query of the attention.
            train: forwarded to MultiHeadAttention.

        Returns:
            The layer-normalised output of the feed-forward sub-layer.
        """
        # Self-attention sub-layer; the same tensor is passed as v, k and q.
        self_attention = MultiHeadAttention(
            d_model=self.d_model,
            num_heads=self.num_heads,
            dropout_rate=self.dropout_rate,
        )
        attended = self_attention(x, x, x, train)
        after_attention = nn.LayerNorm()(x + attended)  # residual + norm

        # Position-wise feed-forward sub-layer.
        feed_forward = point_wise_feed_forward_network(
            d_model=self.d_model, dff=self.dff
        )
        transformed = feed_forward(after_attention)
        after_ffn = nn.LayerNorm()(after_attention + transformed)  # residual + norm
        return after_ffn

class MultiHeadAttention(nn.Module):
    """Multi-head attention built on `scaled_dot_product_attention`.

    Query, key and value are each projected with a `Dense(d_model)` layer,
    split into `num_heads` heads, attended, merged back to `d_model`
    features and passed through a final `Dense(d_model)` layer.

    Note the argument order of `__call__`: `(v, k, q, train)`.
    """
    # Width of the query/key/value and output projections.
    d_model: int
    # Number of attention heads; must divide d_model evenly.
    num_heads: int
    # Dropout rate forwarded to scaled_dot_product_attention.
    dropout_rate: float

    @nn.compact
    def __call__(self, v, k, q, train):
        """Attend from `q` over `k`/`v`.

        Args:
            v: value input. Assumes (batch, seq_k, features) -- TODO confirm.
            k: key input. Assumes (batch, seq_k, features) -- TODO confirm.
            q: query input. Assumes (batch, seq_q, features) -- TODO confirm.
            train: forwarded to scaled_dot_product_attention to toggle dropout.

        Returns:
            Array whose last dimension is `d_model`, reshaped as
            (batch, -1, d_model) before the output projection.

        Raises:
            ValueError: if `d_model` is not divisible by `num_heads`.
        """
        if self.d_model % self.num_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by "
                f"num_heads ({self.num_heads})"
            )
        depth = self.d_model // self.num_heads

        def split_heads(t):
            # (batch, seq, d_model) -> (batch, seq, num_heads, depth), the
            # 4-D layout the einsum patterns in scaled_dot_product_attention
            # ('bqhd', 'bkhd') require.
            return t.reshape(t.shape[0], -1, self.num_heads, depth)

        # Construct the projections in a fixed order so Flax's automatic
        # submodule names stay stable.
        wq = nn.Dense(self.d_model)
        wk = nn.Dense(self.d_model)
        wv = nn.Dense(self.d_model)

        q = split_heads(wq(q))
        k = split_heads(wk(k))
        v = split_heads(wv(v))

        # The attention weights are not needed by callers of this module.
        scaled_attention, _ = scaled_dot_product_attention(
            q, k, v, depth, self.dropout_rate, train
        )

        # Merge heads: (batch, seq_q, num_heads, depth) -> (batch, seq_q, d_model).
        scaled_attention = scaled_attention.reshape(
            scaled_attention.shape[0], -1, self.d_model
        )

        output = nn.Dense(self.d_model)(scaled_attention)
        return output

def scaled_dot_product_attention(q, k, v, depth, dropout_rate, train):
    """Compute softmax(q . k / sqrt(d)) applied to v, followed by dropout.

    Args:
        q: 4-D query array indexed as 'bqhd' in the einsum.
        k: 4-D key array indexed as 'bkhd'.
        v: 4-D value array indexed as 'bkhd'.
        depth: accepted for interface compatibility; not used here (the
            scale is taken from `k.shape[-1]`).
        dropout_rate: rate of the dropout applied to the attended output.
        train: dropout is active when true (`deterministic=not train`).

    Returns:
        Tuple `(output, attention_weights)` where `output` is indexed as
        'bqhd' and `attention_weights` as 'bhqk'.
    """
    # Scale factor: square root of the size of the last key dimension.
    scale = jnp.sqrt(jnp.array(k.shape[-1], dtype=jnp.float32))

    logits = jnp.einsum('bqhd,bkhd->bhqk', q, k) / scale

    # Normalise over the key axis.
    attention_weights = jax.nn.softmax(logits, axis=-1)

    attended = jnp.einsum('bhqk,bkhd->bqhd', attention_weights, v)

    dropout = nn.Dropout(rate=dropout_rate)
    attended = dropout(attended, deterministic=not train)
    return attended, attention_weights


def point_wise_feed_forward_network(d_model, dff):
    """Build the two-layer feed-forward network Dense(dff) -> relu -> Dense(d_model).

    Args:
        d_model: output width of the final Dense layer.
        dff: width of the hidden Dense layer.

    Returns:
        An `nn.Sequential` chaining the three stages.
    """
    expand = nn.Dense(dff)
    project = nn.Dense(d_model)
    stages = [expand, nn.relu, project]
    return nn.Sequential(stages)


def positional_encoding(position, d_model):
    """Build sinusoidal positional encodings.

    Args:
        position: number of positions (rows) to encode.
        d_model: encoding width (columns).

    Returns:
        Array of shape (1, position, d_model): sine of the angles in the
        even columns and cosine of the angles in the odd columns.
    """
    angle_rads = get_angles(jnp.arange(position)[:, jnp.newaxis],
                            jnp.arange(d_model)[jnp.newaxis, :],
                            d_model)

    # JAX arrays are immutable, so `angle_rads[:, 0::2] = ...` raises
    # TypeError; use the functional `.at[...].set(...)` update instead.

    # apply sin to even indices in the array; 2i
    angle_rads = angle_rads.at[:, 0::2].set(jnp.sin(angle_rads[:, 0::2]))

    # apply cos to odd indices in the array; 2i+1
    angle_rads = angle_rads.at[:, 1::2].set(jnp.cos(angle_rads[:, 1::2]))

    # Add a leading axis of size 1 so the result broadcasts over a batch.
    return angle_rads[jnp.newaxis, ...]


def get_angles(pos, i, d_model):
    """Return the positional-encoding angles `pos / 10000 ** (2 * (i // 2) / d_model)`.

    Args:
        pos: position indices.
        i: feature (column) indices; columns 2j and 2j+1 share one rate.
        d_model: encoding width used to normalise the exponent.

    Returns:
        `pos` multiplied by the per-column angle rate (broadcast product).
    """
    exponent = (2 * (i // 2)) / jnp.float32(d_model)
    rate_per_column = 1 / jnp.power(10000, exponent)
    return pos * rate_per_column


In [None]:
import tensorflow_datasets as tfds