<a href="https://colab.research.google.com/github/MaxGubin/video_encoders/blob/main/TextVAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import jax

In [None]:
# Record the JAX release this notebook ran against (displayed as the cell output).
jax_version = (jax.__version__, jax.__version_info__)
jax_version

('0.5.2', (0, 5, 2))

In [None]:
# prompt: write in jax transformer encoder/decoder model

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state

class TransformerEncoder(nn.Module):
    """Encoder stack: add sinusoidal positions, apply dropout, then `num_layers` EncoderLayers.

    `x` is added directly to a (1, seq_len, d_model) positional table and there is
    no embedding layer here, so `x` must already be embedded. It is presumably a
    (batch, seq_len, d_model) float array, not raw token ids (TODO confirm with caller).

    Attributes:
        num_layers: number of EncoderLayer blocks stacked in sequence.
        d_model: model width; the last axis of `x` should equal this.
        num_heads: attention heads in every layer.
        dff: hidden width of every layer's feed-forward network.
        dropout_rate: dropout probability for the embedding dropout and inside each layer.
    """
    num_layers: int
    d_model: int
    num_heads: int
    dff: int
    dropout_rate: float

    @nn.compact
    def __call__(self, x, train):
        seq_len = x.shape[1]
        hidden = x + positional_encoding(seq_len, self.d_model)
        # Dropout on the position-augmented inputs; disabled when train is False.
        hidden = nn.Dropout(rate=self.dropout_rate)(hidden, deterministic=not train)

        layer_kwargs = dict(d_model=self.d_model, num_heads=self.num_heads,
                            dff=self.dff, dropout_rate=self.dropout_rate)
        for _layer_idx in range(self.num_layers):
            hidden = EncoderLayer(**layer_kwargs)(hidden, train=train)
        return hidden

class EncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer.

    Self-attention, then a position-wise feed-forward network. Each sub-layer is
    wrapped as LayerNorm(x + sublayer(x)).

    Attributes:
        d_model: width of the residual stream and of both sub-layer outputs.
        num_heads: number of attention heads for MultiHeadAttention.
        dff: hidden width of the feed-forward network.
        dropout_rate: dropout probability. Used inside attention and on the
            feed-forward output before its residual add.
    """
    d_model: int
    num_heads: int
    dff: int
    dropout_rate: float

    @nn.compact
    def __call__(self, x, train):
        # Self-attention: value, key and query are all `x`.
        attn_output = MultiHeadAttention(d_model=self.d_model, num_heads=self.num_heads, dropout_rate=self.dropout_rate)(x, x, x, train)
        x = nn.LayerNorm()(x + attn_output)  # Add & Norm
        ffn_output = point_wise_feed_forward_network(d_model=self.d_model, dff=self.dff)(x)
        # Residual dropout on the feed-forward sub-layer. Previously dropout_rate
        # was only applied on the attention path. Dropout has no parameters, so
        # the parameter tree is unchanged; it is a no-op when train is False.
        ffn_output = nn.Dropout(rate=self.dropout_rate)(ffn_output, deterministic=not train)
        x = nn.LayerNorm()(x + ffn_output)  # Add & Norm
        return x

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Projects value, key and query with separate Dense layers. The last axis is
    split into `num_heads` heads of size `d_model // num_heads`, each head
    attends independently, and the heads are merged and passed through an
    output Dense projection.

    Note the call order is (v, k, q, train): value first, query last.

    Attributes:
        d_model: total projection width; must be divisible by num_heads.
        num_heads: number of attention heads.
        dropout_rate: forwarded to scaled_dot_product_attention.
    """
    d_model: int
    num_heads: int
    dropout_rate: float

    @nn.compact
    def __call__(self, v, k, q, train):
        if self.d_model % self.num_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by num_heads ({self.num_heads})")
        depth = self.d_model // self.num_heads
        wq = nn.Dense(self.d_model)
        wk = nn.Dense(self.d_model)
        wv = nn.Dense(self.d_model)

        def split_heads(t):
            # (..., seq, d_model) -> (..., seq, num_heads, depth): the 'b?hd' layout
            # the einsums in scaled_dot_product_attention expect. Without this split
            # the rank-3 projections do not match the rank-4 einsum subscripts.
            return t.reshape(t.shape[:-1] + (self.num_heads, depth))

        q = split_heads(wq(q))
        k = split_heads(wk(k))
        v = split_heads(wv(v))

        scaled_attention, _attention_weights = scaled_dot_product_attention(
            q, k, v, depth, self.dropout_rate, train)

        # Merge heads: (batch, seq_q, num_heads, depth) -> (batch, seq_q, d_model).
        scaled_attention = scaled_attention.reshape(scaled_attention.shape[0], -1, self.d_model)

        output = nn.Dense(self.d_model)(scaled_attention)
        return output

def scaled_dot_product_attention(q, k, v, depth, dropout_rate, train, mask=None):
  """Per-head scaled dot-product attention: softmax(q k^T / sqrt(d)) v.

  The einsum subscripts below fix the layout of every input to rank 4.

  Args:
    q: queries, (batch, seq_q, heads, head_dim).
    k: keys, (batch, seq_k, heads, head_dim).
    v: values, (batch, seq_k, heads, head_dim).
    depth: per-head size supplied by the caller. Unused: the scale is taken
      from k.shape[-1]. Kept for interface compatibility.
    dropout_rate: dropout probability applied to the attention output.
    train: if False, dropout is disabled.
    mask: optional array broadcastable to (batch, heads, seq_q, seq_k).
      Nonzero/True entries may be attended to; others get ~zero weight. A
      (batch, seq_k) padding mask such as a tokenizer's attention_mask must
      be expanded first, e.g. mask[:, None, None, :]. Default None means no
      masking, which is the original behaviour.

  Returns:
    (output, attention_weights): output is (batch, seq_q, heads, head_dim) and
    attention_weights is (batch, heads, seq_q, seq_k).
  """
  matmul_qk = jnp.einsum('bqhd,bkhd->bhqk', q, k)
  dk = jnp.array(k.shape[-1], dtype=jnp.float32)
  scaled_attention_logits = matmul_qk / jnp.sqrt(dk)

  if mask is not None:
    # Use the most negative finite value rather than -inf, so that a fully
    # masked row yields uniform weights instead of NaN.
    scaled_attention_logits = jnp.where(
        jnp.asarray(mask).astype(bool),
        scaled_attention_logits,
        jnp.finfo(scaled_attention_logits.dtype).min)

  attention_weights = jax.nn.softmax(scaled_attention_logits, axis=-1)

  output = jnp.einsum('bhqk,bkhd->bqhd', attention_weights, v)

  output = nn.Dropout(rate=dropout_rate)(output, deterministic=not train)
  return output, attention_weights


def point_wise_feed_forward_network(d_model, dff):
    """Build the position-wise feed-forward sub-layer: Dense(dff) -> ReLU -> Dense(d_model).

    Args:
        d_model: output width (the model's residual width).
        dff: hidden width of the intermediate projection.

    Returns:
        An nn.Sequential applying the three stages in order.
    """
    # Construct the two projections in the same order as before (hidden first).
    hidden_proj = nn.Dense(dff)
    output_proj = nn.Dense(d_model)
    return nn.Sequential([hidden_proj, nn.relu, output_proj])


def positional_encoding(position, d_model):
    """Sinusoidal positional-encoding table.

    Args:
        position: number of positions (sequence length) to encode.
        d_model: encoding width.

    Returns:
        Float array of shape (1, position, d_model). Even feature indices hold
        sin(angle) and odd indices hold cos(angle). The leading size-1 axis lets
        the table broadcast over the batch when added to activations.
    """
    angle_rads = get_angles(jnp.arange(position)[:, jnp.newaxis],
                            jnp.arange(d_model)[jnp.newaxis, :],
                            d_model)

    # JAX arrays are immutable: slice assignment (`a[...] = ...`) raises TypeError.
    # Use functional `.at[...].set(...)` updates, which return a new array.
    # apply sin to even indices in the array; 2i
    angle_rads = angle_rads.at[:, 0::2].set(jnp.sin(angle_rads[:, 0::2]))

    # apply cos to odd indices in the array; 2i+1 (untouched by the update above)
    angle_rads = angle_rads.at[:, 1::2].set(jnp.cos(angle_rads[:, 1::2]))

    return angle_rads[jnp.newaxis, ...]


def get_angles(pos, i, d_model):
    """Return the angle table pos / 10000 ** (2 * (i // 2) / d_model).

    `pos` and `i` are broadcast against each other. positional_encoding passes
    a column of positions and a row of feature indices. Feature indices 2k and
    2k + 1 share the same frequency.
    """
    exponent = (2 * (i // 2)) / jnp.float32(d_model)
    inverse_frequency = 1 / jnp.power(10000, exponent)
    return pos * inverse_frequency


In [7]:
!pip install -U datasets

Collecting datasets
  Downloading datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.6.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.5/491.5 kB[0m [31m13.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2025.3.0-py3-none-any.whl (193 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.6/193.6 kB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fsspec, datasets
  Attempting uninstall: fsspec
    Found existing installation: fsspec 2025.3.2
    Uninstalling fsspec-2025.3.2:
      Successfully uninstalled fsspec-2025.3.2
  Attempting uninstall: datasets
    Found existing installation: datasets 2.14.4
    Uninstalling datasets-2.14.4:
      Successfully uninstalled datasets-2.14.4
[31mERROR: pip's dependency r

In [1]:
import datasets

In [7]:
# prompt: load some texts from Wikipedia using datasets and create training and evaluation iterators

# Tunables for the data pipeline.
WIKI_CONFIG = '20220301.en'
SHUFFLE_BUFFER_SIZE = 10_000
N_EVAL_EXAMPLES = 1_000
SEED = 42  # seeded shuffle makes the train/eval split reproducible

# Stream instead of downloading the full dump. force_redownload is dropped because it
# re-fetched on every run; use the cache-clearing cell if the cache gets corrupted.
wiki = datasets.load_dataset('wikipedia', WIKI_CONFIG, streaming=True)

# Only a 'train' split is used (no 'test' split is loaded), so carve the evaluation
# set out of it. take/skip on the same seeded shuffled stream give disjoint sets.
wiki_shuffled = wiki['train'].shuffle(seed=SEED, buffer_size=SHUFFLE_BUFFER_SIZE)
eval_dataset = wiki_shuffled.take(N_EVAL_EXAMPLES)
train_dataset = wiki_shuffled.skip(N_EVAL_EXAMPLES)

wikipedia.py:   0%|          | 0.00/36.7k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/16.0k [00:00<?, ?B/s]

In [8]:
# Peek at one streamed training article to check the record schema.
first_example = next(iter(train_dataset))
first_example

{'id': '53087239',
 'url': 'https://en.wikipedia.org/wiki/Centipede%20%28Knife%20Party%20song%29',
 'title': 'Centipede (Knife Party song)',
 'text': '"Centipede" is a song from the Knife Party EP Rage Valley. Upon its release, it hit #47 on Billboard\'s Dance/Electronic Songs.\n\nBackground\nThe song sampled a segment from the Discovery Channel series The World\'s Most Feared Animals.\nThis sample was also used in Tarantula / Fasten Your Seatbelt.\n\nIn popular culture\nThe song was featured on the television series The Wrong Mans as well as the video game Guitar Hero Live.  The song is popular in the rhythm games "Beat Saber" and "osu!" During the 2016 United States presidential election, the song was associated with Donald Trump, especially its use in the video series "You Can\'t Stump the Trump". The terms "centipede" and "nimble navigator" were also used by Trump supporters on /r/The Donald subreddit.\n\nCharts\n\nReferences \n\n2012 songs\nSongs written by Rob Swire'}

In [None]:
# prompt: create a jax-compatible tokenizer that can tokenizer samples from the dataset

!pip install -q transformers

from transformers import AutoTokenizer

# Choose a pre-trained tokenizer. BERT's tokenizer is a common choice for many tasks.
# You could also consider tokenizers from other models like GPT-2 or RoBERTa depending on your needs.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Example of using the tokenizer on a sample from the dataset
sample_text = next(iter(train_dataset))['text']
encoded_sample = tokenizer(sample_text, return_tensors="jax", padding="max_length", truncation=True, max_length=128) # Adjust max_length as needed

print("Sample Text:")
print(sample_text[:200] + "...") # Print first 200 characters
print("\nEncoded Sample (JAX Tensors):")
print(encoded_sample)

# You can access the tokenized input_ids and attention_mask as JAX arrays
input_ids = encoded_sample['input_ids']
attention_mask = encoded_sample['attention_mask']

print("\nInput IDs shape:", input_ids.shape)
print("Attention Mask shape:", attention_mask.shape)
```

In [5]:
import os
import shutil

from datasets import config as hf_datasets_config

# Wipe the local Hugging Face `datasets` cache (e.g. after a corrupted download).
# Ask `datasets` for the location instead of hard-coding ~/.cache/huggingface/datasets,
# so HF_HOME / HF_DATASETS_CACHE overrides point this at the cache actually in use.
cache_dir = os.fspath(hf_datasets_config.HF_DATASETS_CACHE)
if os.path.exists(cache_dir):
    print(f"Removing cache directory: {cache_dir}")
    shutil.rmtree(cache_dir)



Removing cache directory: /root/.cache/huggingface/datasets
