In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#@title Imports
from tensorboardX import SummaryWriter
import chex
import copy
import dataclasses
import jax
import jax.numpy as jnp
import numpy as np
import optax
import pandas as pd
import time

from toylib import nn
from toylib.data import imdb

In [None]:
#@title General hyperparms

@dataclasses.dataclass
class Config:
    """Hyperparameters shared by the data pipeline and the training loop.

    Field order is part of the interface (positional construction), so new
    fields should be appended rather than inserted.
    """

    # --- Data pipeline -----------------------------------------------------
    batch_size: int = 128      # examples per batch
    embedding_dim: int = 100   # width of each word vector (matches the 100d GloVe file below)
    max_tokens: int = 160      # token budget per example, forwarded to imdb.load_dataset

    # File holding the pre-trained word vectors.
    embeddings_path: str = 'glove.6B.100d.txt'

    # --- Training loop -----------------------------------------------------
    num_epochs: int = 5
    learning_rate: float = 1e-2

    # --- Classification token ----------------------------------------------
    # Stand-in for BERT's <CLS> token. The vocabulary has no <CLS> entry and
    # 'unk' is not used for anything else, so it is repurposed here.
    cls_token: str = 'unk'


# Single shared instance read by the cells below.
config = Config()

In [None]:
import os

import nltk

# NOTE(review): these three scratch values are not referenced by any cell
# visible in this notebook; kept in case another cell still relies on them.
in_ = 10
out_ = (8, 12)
x = np.zeros((4, in_))

# Expand "~" explicitly instead of relying on nltk to do it: a literal
# '~/data/nltk_data' string can otherwise end up as a directory named "~"
# under the current working directory.
nltk_data_dir = os.path.expanduser('~/data/nltk_data')

# Fetch the tokenizer resources into the shared data directory.
nltk.download('punkt', download_dir=nltk_data_dir)
nltk.download('punkt_tab', download_dir=nltk_data_dir)

# Make the directory visible to nltk's resource lookup. The membership check
# keeps the search path from growing every time this cell is re-run.
if nltk_data_dir not in nltk.data.path:
    nltk.data.path.append(nltk_data_dir)

In [None]:
# Load the word-embedding table from the file named in `config.embeddings_path`.
# Only the embeddings are loaded here; the dataset itself is built in the next
# cell, which receives `glove` as an argument.
glove = imdb.load_glove_embeddings(config.embeddings_path)

In [None]:
#@title Set up dataloaders: use IMDB movie sentiment reviews from HuggingFace datasets
# Build the train/val/test datasets and their dataloaders in one call.
# `imdb.load_dataset` returns six values, unpacked in the order below.
(
    train_dataset,
    val_dataset,
    test_dataset,
    train_dataloader,
    val_dataloader,
    test_dataloader,
) = imdb.load_dataset(
    glove=glove,
    batch_size=config.batch_size,
    embedding_dim=config.embedding_dim,
    max_tokens=config.max_tokens,
    cls_token=config.cls_token,
)
# NOTE: these are plain `len(dataset) / batch_size` ratios, so the printed
# values can be fractional; they are not rounded to a whole number of batches.
print(f'Training batches: {len(train_dataset) / config.batch_size}')
print(f'Val batches: {len(val_dataset) / config.batch_size}')
print(f'Test batches: {len(test_dataset) / config.batch_size}')

In [None]:
#@title Visualize a batch of data
def visualize_samples(batch) -> pd.DataFrame:
    """Summarise one batch as a table with one row per example.

    Args:
        batch: Mapping that provides the keys 'embedding_missing', 'num_pad',
            'label', 'raw_text' and 'num_tokens'.

    Returns:
        A DataFrame with columns 'labels', 'text', 'missing' and 'num_tokens'
        (in that order). 'missing' is the per-row sum of 'embedding_missing'
        along axis 1 minus 'num_pad' -- presumably the number of real
        (non-padding) tokens that have no embedding.
    """
    missing_flags = batch['embedding_missing']
    # Padding positions are subtracted out of the per-row missing count.
    missing_excluding_pad = missing_flags.sum(axis=1) - batch['num_pad']

    return pd.DataFrame(dict(
        labels=batch['label'],
        text=batch['raw_text'],
        missing=missing_excluding_pad,
        num_tokens=batch['num_tokens'],
    ))

# Pull the first batch from a fresh iterator over the training dataloader and
# render it as a table (the cell's last expression is its displayed output).
batch = next(iter(train_dataloader))
visualize_samples(batch)


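For the exact formulation of attention, we use scaled dot-product attention as defined in the [Attention Is All You Need](https://arxiv.org/pdf/1706.03762) paper: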
<div style="text-align: center;">
    <img src="../images/04.scaled-dot-product.png" alt="Scaled Dot Product Attention">
</div>

$$\text{Attention}(Q,K,V)=\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

We implement a model inspired by the [BERT](https://arxiv.org/abs/1810.04805)-style self-attention-based model architecture. Instead of using a learned tokenizer, as is common in all LLMs now, we continue to rely on the pre-trained and fixed GloVe word embeddings. Further, we use a much smaller model and do not perform any self-supervised pre-training on a large text corpus.

All these changes greatly reduce the scope of our initial approach, while still resulting in a relatively performant model.

Similar to BERT, we add a special token at the beginning of each example. Here, we use the `unk` token because it exists in the GloVe vocabulary and we are not otherwise using it in our present model. This differs from BERT, which uses a dedicated `<CLS>` token: since we do no pre-training, the model would have no way of learning the significance of such a token beforehand.

Our overall strategy is as follows:
1. Add the `unk` token at the beginning of each example
1. Use a few multi-head self attention layers to fuse/mix the input token embeddings
1. Use the final layer embedding corresponding to the `unk` token as the final sentence representation
1. Apply the output projection on the `unk` token embedding





For each self attention layer, we use the structure from the [Attention Is All You Need](https://arxiv.org/pdf/1706.03762) paper:

<div style="text-align: center;">
    <img src="../images/04.mha.png" alt="Multi Headed Attention Block">
</div>

This paper used $h = 8$ parallel heads in each attention block / layer. For each of these, it uses
$d_k = d_v = d_{\text{model}}/h = 64$. Due to the reduced dimension of each head, the total computational cost
is similar to that of single-head attention with full dimensionality.



In [None]:
from typing import Optional

In [None]:
#@title Define a model
@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class SelfAttentionClassifier(nn.module.Module):
    """Sequence classifier built from stacked multi-head self-attention layers.

    Pipeline for a single example: input projection -> `num_layers`
    self-attention layers -> output projection applied to the first token only.

    Attributes:
        key: PRNG key; split into one sub-key per sub-layer in `__post_init__`.
        input_dim: Size of each incoming token embedding.
        output_dim: Size of the output vector (1 for a single logit).
        num_layers: Number of self-attention layers.
        qkv_dim: Model width; the input projection maps to this size and it is
            passed to every attention layer as `qkv_dim`.
        num_heads: Number of heads in each attention layer.
    """
    key: jax.random.PRNGKey

    input_dim: int = 100
    output_dim: int = 1

    # Attention layers
    num_layers: int = 2
    qkv_dim: int = 256
    num_heads: int = 4

    def __post_init__(self) -> None:
        """Create the sub-layers, giving each one its own PRNG sub-key."""
        # One key for the input projection, one per attention layer, and one
        # for the output projection.
        keys = jax.random.split(self.key, self.num_layers+2)

        # Input projection
        self.input_projection = nn.layers.Linear(in_features=self.input_dim, out_features=self.qkv_dim, key=keys[0])

        # Self-attention layers
        self.layers = []
        for ix in range(self.num_layers):
            self.layers.append(nn.attention.MultiHeadAttention(num_heads=self.num_heads, qkv_dim=self.qkv_dim, key=keys[ix+1]))

        # Output projection
        self.output_layer = nn.layers.Linear(in_features=self.qkv_dim, out_features=self.output_dim, key=keys[-1])

    def __call__(self, x: jnp.ndarray, mask: Optional[jnp.ndarray] = None) -> jnp.ndarray:
        """Run the classifier on a single, un-batched example.

        Args:
            x: Token embeddings of rank 2, `[num_tokens, embed_dim]`. Batches
                are handled by wrapping the call in `jax.vmap`.
            mask: Optional per-token mask of shape `[num_tokens]`, forwarded
                unchanged to every attention layer.

        Returns:
            The output projection of the first token's final representation.
        """
        chex.assert_rank(x, 2)  # [num_tokens, embed_dim] - batch should be vmapped
        if mask is not None:
            chex.assert_shape(mask, (x.shape[0],))

        # Project the embeddings to the model dimension.
        x = self.input_projection(x)

        # Self-attention: the same tensor serves as queries, keys and values.
        # Each layer returns a pair; only the first element is used here.
        for layer in self.layers:
            x, _ = layer(Q=x, K=x, V=x, mask=mask)

        # Apply the output projection on only the first token, which
        # corresponds to the <CLS> token.
        return self.output_layer(x[0])

In [None]:
def loss_fn(model, batch):
    """Mean binary cross-entropy (from logits) of `model` over one batch.

    Args:
        model: Callable taking a single example `(x, mask)` and returning its
            logit(s); it is vmapped over the leading batch axis here.
        batch: Mapping with keys "embedding" (model inputs), "label" (binary
            targets) and "embedding_missing" (boolean flags; their negation is
            passed to the model as the mask).

    Returns:
        Scalar mean loss over the batch.

    Raises:
        An error from the reshape if the number of logits does not match the
        number of labels (for example a model with more than one output per
        example).
    """
    x, y, mask = batch["embedding"], batch["label"], ~batch["embedding_missing"]
    logits = jax.vmap(model)(x, mask)

    # A single-output model yields logits of shape [batch, 1] while labels are
    # [batch]. Multiplying those directly broadcasts to [batch, batch] and
    # pairs every label with every example, so align the shapes first.
    logits = jnp.reshape(logits, jnp.shape(y))

    # Binary cross entropy from logits
    log_probs = jax.nn.log_sigmoid(logits)
    log_1minus_probs = jax.nn.log_sigmoid(-logits)  # log(1-sigmoid(x)) = log_sigmoid(-x)

    loss = -(y * log_probs + (1 - y) * log_1minus_probs)

    return jnp.mean(loss)

In [None]:
%load_ext tensorboard
%tensorboard --logdir 4-attention/

In [None]:
# Model
model = SelfAttentionClassifier(
    input_dim=config.embedding_dim,
    output_dim=1,
    num_layers=2,
    num_heads=4,
    qkv_dim=256,
    key=jax.random.PRNGKey(10)
)

# Optimizer
optimizer = optax.adam(learning_rate=config.learning_rate)
opt_state = optimizer.init(model)

# Value and gradient of the loss with respect to the model (first argument)
loss_and_grad_fn = jax.value_and_grad(loss_fn)

# TensorBoard writer; each run gets its own timestamped sub-directory
writer = SummaryWriter(logdir="./4-attention/" + time.strftime("%Y%m%d-%H%M%S"))

# Snapshot of the freshly initialised model, taken before any update is
# applied (presumably for comparing against the trained model afterwards).
orig_model = copy.deepcopy(model)

# Training loop
step = 0
try:
    # NOTE(review): this runs 200 outer iterations and, because of the `break`
    # at the end of the inner loop, trains on only the first batch of each
    # pass over the dataloader. `config.num_epochs` is not used. This looks
    # like a debugging setup; confirm before relying on the results.
    for epoch in range(200):
        for batch in train_dataloader:
            loss_val, grads = loss_and_grad_fn(model, batch)

            # Apply gradients
            updates, opt_state = optimizer.update(grads, opt_state)
            model = optax.apply_updates(model, updates)

            # Log to TensorBoard
            writer.add_scalar("train/loss", float(loss_val), step)
            writer.add_scalar("train/learning_rate", config.learning_rate, step)

            # Log the raw gradients (not the optimizer-transformed updates)
            # for the first three parameter leaves, so the "gradients/*" tags
            # show what their name says.
            grad_leaves = jax.tree_util.tree_leaves(grads)
            for leaf_ix, grad_leaf in enumerate(grad_leaves[:3]):
                writer.add_scalar(f"gradients/{leaf_ix}/mean", float(grad_leaf.mean()), step)

            # Batch statistics: padding, tokens without an embedding
            # (padding excluded), and label balance.
            num_missing = np.mean(batch['embedding_missing'].sum(axis=1) - batch['num_pad'])
            writer.add_scalar("data/padding", float(batch['num_pad'].mean()), step)
            writer.add_scalar("data/num_missing", float(num_missing), step)
            writer.add_scalar("label/mean", float(batch['label'].mean()), step)

            # Increment step
            step += 1
            break

        writer.flush()
finally:
    # Close the writer even if a training step raises, so buffered events
    # are written out and the file handle is released.
    writer.close()


In [None]:
!ls -lht ./runs/