In [None]:
#@title Imports
from datasets import load_dataset
from nltk.tokenize import word_tokenize
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import chex
import copy
import dataclasses
import functools
import jax
import jax.numpy as jnp
import nltk
import numpy as np
import optax
import os
import pandas as pd
import re
import time
from toylib import nn

In [None]:
# Python's file-system calls do not expand "~", so expand it explicitly; passing
# the raw string would create a directory literally named "~" under the cwd.
NLTK_DATA_DIR = os.path.expanduser('~/data/nltk_data')

for _nltk_resource in ('punkt', 'punkt_tab'):
    nltk.download(_nltk_resource, download_dir=NLTK_DATA_DIR)

# Avoid accumulating duplicate search-path entries when this cell is re-run.
if NLTK_DATA_DIR not in nltk.data.path:
    nltk.data.path.append(NLTK_DATA_DIR)

In [None]:
#@title General hyperparameters

@dataclasses.dataclass
class Config:
    """Hyperparameters shared by the data pipeline and the training loop."""

    # --- data pipeline ---
    batch_size: int = 128        # examples per DataLoader batch
    embedding_dim: int = 100     # size of each vector in `embeddings_path`
    max_tokens: int = 160        # reviews are truncated / padded to this many tokens
    embeddings_path: str = "glove.6B.100d.txt"

    # --- optimization ---
    num_epochs: int = 5
    learning_rate: float = 0.01


config = Config()

In [None]:
#@title Dataset Loaders
def preprocess_text(text: str, stopwords = None) -> list[str]:
    """Normalize and tokenize a raw review.

    Steps: lowercase; replace every character that is not an ASCII letter,
    digit or whitespace with a space; tokenize with nltk's ``word_tokenize``;
    optionally drop stopwords.

    Args:
        text: Raw input text.
        stopwords: Optional iterable of tokens to drop. Punctuation
            (including apostrophes) is removed before tokenizing, so entries
            such as "it's" can never match -- only punctuation-free entries
            have an effect.

    Returns:
        The remaining tokens, in their original order.
    """
    # Convert to lowercase
    text = text.lower()

    # Remove special characters and replace with space
    text = re.sub(r'[^a-zA-Z0-9\s]', ' ', text)

    # Tokenize
    tokens = word_tokenize(text)

    # Remove stopwords if provided. Materialize them as a set once: O(1)
    # membership instead of a linear scan per token, and safe for one-shot
    # iterables such as generators.
    if stopwords:
        stopword_set = set(stopwords)
        tokens = [word for word in tokens if word not in stopword_set]

    return tokens

def embed_tokens(tokens, glove, embedding_dim, verbose: bool = False):
    """Look up a GloVe vector for each token.

    Args:
        tokens: Sequence of token strings.
        glove: Mapping from token to a 1-D vector of length ``embedding_dim``.
        embedding_dim: Vector size; used for the zero vectors of missing tokens.
        verbose: If True, print every token that has no vector. Off by default
            because this runs once per dataset item and would flood the
            notebook output.

    Returns:
        ``(vectors, missing)``: an array of shape ``(len(tokens), embedding_dim)``
        and a bool array of shape ``(len(tokens),)`` that is True where the
        token was not found (its row is all zeros). Empty ``tokens`` yields
        zero-row arrays instead of raising.
    """
    if len(tokens) == 0:
        # np.stack raises on an empty list; return well-shaped empty arrays.
        return np.zeros((0, embedding_dim), dtype=np.float32), np.zeros(0, dtype=bool)

    vectors = []
    emb_missing = []
    for token in tokens:
        if token in glove:
            vectors.append(glove[token])
            emb_missing.append(False)
        else:
            # float32 to match the vectors produced by the GloVe loader, so the
            # output dtype does not depend on whether any token was missing.
            vectors.append(np.zeros(embedding_dim, dtype=np.float32))
            emb_missing.append(True)
            if verbose:
                print(f"Token not found in glove: {token}")
    return np.stack(vectors), np.array(emb_missing, dtype=bool)


class EmbeddedTextDataset(Dataset):
    """Map-style dataset turning text records into fixed-length GloVe sequences.

    Each record is indexed with ``"text"`` and ``"label"`` keys; the text is
    preprocessed, truncated to ``max_tokens`` tokens, embedded, and padded
    with zero rows up to ``max_tokens``.
    """

    def __init__(self, hf_dataset, glove, embedding_dim, max_tokens=50, use_stopwords: bool = False):
        """
        Args:
            hf_dataset: Indexable, sized collection of records with "text" and
                "label" keys (e.g. a HuggingFace dataset split).
            glove: Mapping from token to embedding vector.
            embedding_dim: Length of each embedding vector.
            max_tokens: Sequence length every example is truncated/padded to.
            use_stopwords: If True, drop ``self.stopwords`` during
                preprocessing. Defaults to False (no filtering).
        """
        self.data = hf_dataset
        self.glove = glove
        self.embedding_dim = embedding_dim
        self.max_tokens = max_tokens
        self.use_stopwords = use_stopwords

    def __len__(self):
        return len(self.data)

    @property
    def stopwords(self) -> list[str]:
        """Generic English stopwords plus review-specific ones.

        NOTE: preprocessing strips apostrophes, so entries such as "it's"
        cannot match a token.
        """
        stopwords = [ "a", "about", "above", "after", "again", "against", "all", "am", "an", "and", "any", "are", "as", "at", "be", "because", "been", "before", "being", "below", "between", "both", "but", "by", "could", "did", "do", "does", "doing", "down", "during", "each", "few", "for", "from", "further", "had", "has", "have", "having", "he", "he'd", "he'll", "he's", "her", "here", "here's", "hers", "herself", "him", "himself", "his", "how", "how's", "i", "i'd", "i'll", "i'm", "i've", "if", "in", "into", "is", "it", "it's", "its", "itself", "let's", "me", "more", "most", "my", "myself", "nor", "of", "on", "once", "only", "or", "other", "ought", "our", "ours", "ourselves", "out", "over", "own", "same", "she", "she'd", "she'll", "she's", "should", "so", "some", "such", "than", "that", "that's", "the", "their", "theirs", "them", "themselves", "then", "there", "there's", "these", "they", "they'd", "they'll", "they're", "they've", "this", "those", "through", "to", "too", "under", "until", "up", "very", "was", "we", "we'd", "we'll", "we're", "we've", "were", "what", "what's", "when", "when's", "where", "where's", "which", "while", "who", "who's", "whom", "why", "why's", "with", "would", "you", "you'd", "you'll", "you're", "you've", "your", "yours", "yourself", "yourselves" ]

        # specific stopwords
        specific_sw = ['br', 'movie', 'film']

        # all stopwords
        stopwords = stopwords + specific_sw
        return stopwords

    def __getitem__(self, idx):
        """Return one example embedded and padded to ``max_tokens`` rows.

        Keys: "embedding" (float32, ``(max_tokens, embedding_dim)``),
        "embedding_missing" (bool, ``(max_tokens,)``; True for tokens without
        a vector and for padding), "num_tokens" (tokens kept after
        truncation), "raw_text", "label", "num_pad" (padding rows appended).
        """
        item = self.data[idx]
        stopwords = self.stopwords if self.use_stopwords else None
        tokens = preprocess_text(item["text"], stopwords)[:self.max_tokens]

        if len(tokens) > 0:
            emb, emb_missing = embed_tokens(tokens, self.glove, self.embedding_dim)
        else:
            # Nothing survived preprocessing; build empty arrays directly rather
            # than asking embed_tokens to stack an empty list.
            emb = np.zeros((0, self.embedding_dim), dtype=np.float32)
            emb_missing = np.zeros(0, dtype=bool)

        # `tokens` is already truncated, so normally only padding is needed;
        # the slice is a cheap safeguard.
        emb = emb[:self.max_tokens]
        emb_missing = emb_missing[:self.max_tokens]

        num_pad = self.max_tokens - emb.shape[0]
        if num_pad > 0:
            emb = np.vstack([emb, np.zeros((num_pad, self.embedding_dim), dtype=emb.dtype)])
            # Padding rows are flagged as missing so downstream masking ignores them.
            emb_missing = np.concatenate([emb_missing, np.ones(num_pad, dtype=bool)])

        return {
            "embedding": emb.astype(np.float32),
            # builtin `bool`: the `np.bool` alias is absent in NumPy 1.24-1.26.
            "embedding_missing": emb_missing.astype(bool),
            "num_tokens": len(tokens),
            "raw_text": item["text"],
            "label": item["label"],
            "num_pad": num_pad,
        }

def numpy_collate(batch):
    """Collate a list of dataset items into a dict of numpy arrays.

    With B = len(batch): "embedding" (B, T, D), "embedding_missing" (B, T),
    "label" (B,) float32, "num_pad" (B,), "num_tokens" (B,), "raw_text" (B,)
    string array. Passed as ``collate_fn`` so batches stay numpy arrays
    instead of the DataLoader's default torch tensors.
    """
    def column(key):
        # Gather one field across every example, preserving batch order.
        return [example[key] for example in batch]

    collated = {}
    for key in ("embedding", "embedding_missing"):
        collated[key] = np.stack(column(key))
    collated["label"] = np.array(column("label"), dtype=np.float32)
    for key in ("num_pad", "num_tokens", "raw_text"):
        collated[key] = np.array(column(key))
    return collated

In [None]:
#@title Embeddings loader
def load_glove_embeddings(filepath):
    """Parse a GloVe text file into a ``{word: float32 vector}`` dict.

    Each non-blank line is ``word v1 v2 ... vD`` separated by whitespace.
    Blank lines are skipped instead of raising IndexError.

    Args:
        filepath: Path to the UTF-8 encoded embeddings file.

    Returns:
        Dict mapping each word to a 1-D ``np.float32`` array.
    """
    embeddings = {}
    with open(filepath, "r", encoding="utf-8") as f:
        for line in f:
            # str.split() with no argument already trims surrounding whitespace.
            parts = line.split()
            if not parts:
                continue
            word, *values = parts
            embeddings[word] = np.array(values, dtype=np.float32)
    return embeddings

glove = load_glove_embeddings(config.embeddings_path)

# Fail fast on an empty file or a dimension mismatch with the config; otherwise
# the mismatch only surfaces later as a confusing shape error when GloVe
# vectors are stacked with zero vectors of size config.embedding_dim.
if not glove:
    raise ValueError(f"No embeddings loaded from {config.embeddings_path}")
_glove_dim = len(next(iter(glove.values())))
if _glove_dim != config.embedding_dim:
    raise ValueError(
        f"{config.embeddings_path} has {_glove_dim}-d vectors but "
        f"config.embedding_dim={config.embedding_dim}"
    )

In [None]:
#@title Set up dataloaders: use IMDB movie sentiment reviews from HuggingFace datasets
imdb_dataset = load_dataset("imdb")

# Carve a 5% validation split out of the train split (seeded, so it is the
# same split on every run).
train_val = imdb_dataset["train"].train_test_split(test_size=0.05, seed=42)

dataset_fn = functools.partial(EmbeddedTextDataset, glove=glove, embedding_dim=config.embedding_dim, max_tokens=config.max_tokens)

train_dataset = dataset_fn(train_val['train'])
val_dataset = dataset_fn(train_val['test'])
# Evaluate on IMDB's own held-out test split.
test_dataset = dataset_fn(imdb_dataset['test'])

# NOTE(review): the train loader is not shuffled, so every epoch sees the same
# batch order (train_test_split above shuffles only once). Consider
# shuffle=True for full training runs.
train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=False, collate_fn=numpy_collate)
val_dataloader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, collate_fn=numpy_collate)
test_dataloader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False, collate_fn=numpy_collate)

In [None]:
# len(DataLoader) is the actual number of batches per epoch (a final partial
# batch counts as one), whereas len(dataset) / batch_size is a fraction.
print(f'Training batches: {len(train_dataloader)}')
print(f'Val batches: {len(val_dataloader)}')
print(f'Test batches: {len(test_dataloader)}')

In [None]:
#@title Visualize a batch of data
def visualize_samples(batch) -> pd.DataFrame:
    """Tabulate a collated batch, one row per example.

    ``missing`` is the number of non-padding positions flagged in
    ``embedding_missing`` (tokens with no GloVe vector): all flagged
    positions minus the padding count.
    """
    flagged = batch['embedding_missing'].sum(axis=1)
    columns = {
        'labels': batch['label'],
        'text': batch['raw_text'],
        'missing': flagged - batch['num_pad'],
        'num_tokens': batch['num_tokens'],
    }
    return pd.DataFrame(columns)

# Pull the first batch from the training loader and tabulate it; `batch` is
# also used by the following preprocessing-preview cell.
batch = next(iter(train_dataloader))
visualize_samples(batch)

In [None]:
# Show the tokens produced by preprocess_text for the first review in `batch`.
preprocess_text(batch['raw_text'][0])

In [None]:
#@title Define a model
@jax.tree_util.register_pytree_node_class
class BagOfWordsClassifier(nn.module.Module):
    """Masked mean of token embeddings followed by a ReLU MLP.

    Operates on a single example; batch it with ``jax.vmap``.
    NOTE(review): assumes nn.module.Module supplies the tree_flatten /
    tree_unflatten methods that register_pytree_node_class requires.
    """
    hidden_sizes: list[int]
    output_dim: int = 1
    input_dim: int = 100

    def __init__(self, hidden_sizes: list[int], input_dim: int, output_dim: int, *, key: jax.random.PRNGKey) -> None:
        """
        Args:
            hidden_sizes: Width of each hidden layer, in order.
            input_dim: Embedding dimension of each token vector.
            output_dim: Number of output logits.
            key: PRNG key; a subkey is split off per hidden layer and the
                remainder initializes the output layer.
        """
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_sizes = hidden_sizes

        self.layers = []
        in_features = input_dim
        for hidden_size in hidden_sizes:
            key_used, key = jax.random.split(key, 2)
            self.layers.append(nn.layers.Linear(in_features, hidden_size, key=key_used))
            in_features = hidden_size

        self.output_layer = nn.layers.Linear(in_features, output_dim, key=key)

    def __call__(self, x: np.ndarray, mask: np.ndarray) -> jax.Array:
        """Compute logits for one example.

        Args:
            x: Token embeddings, shape ``[num_tokens, embed_dim]``.
            mask: Shape ``[num_tokens]``; truthy entries are averaged,
                the rest ignored.

        Returns:
            Output of the final linear layer (presumably shape
            ``[output_dim]`` -- depends on nn.layers.Linear).
        """
        chex.assert_rank(x, 2)  # [num_tokens, embed_dim] - batch should be vmapped
        chex.assert_shape(mask, (x.shape[0],))
        embed_dim = x.shape[-1]

        # Masked mean -> "sentence" embedding. The count is clamped to 1 so an
        # example with no valid tokens (empty or fully out-of-vocabulary) gives
        # a zero vector instead of 0/0 = NaN, which would poison the loss and
        # every gradient in the batch.
        weights = mask.astype(x.dtype)
        num_valid = jnp.maximum(jnp.sum(weights), 1.0)
        x = jnp.sum(x * weights[:, np.newaxis], axis=0) / num_valid
        chex.assert_shape(x, (embed_dim,))  # [embed_dim]

        for layer in self.layers:
            x = jax.nn.relu(layer(x))
        return self.output_layer(x)

In [None]:
def loss_fn(model, batch):
    """Mean binary cross-entropy of ``model`` on a collated batch.

    Args:
        model: Callable ``model(x, mask)`` for a single example, vmapped over
            the leading batch axis; must produce one logit per example.
        batch: Dict with "embedding" (B, T, D), "label" (B,) and
            "embedding_missing" (B, T), where True marks positions to ignore.

    Returns:
        Scalar mean loss over the batch.
    """
    x, y, mask = batch["embedding"], batch["label"], ~batch["embedding_missing"]
    logits = jax.vmap(model)(x, mask)
    # A single-output model yields (B, 1); reshape to the label shape (B,).
    # Without this, (B,) * (B, 1) broadcasts to a (B, B) matrix pairing every
    # prediction with every label, which pushes all predictions toward the
    # batch-mean label. A model with >1 output per example now fails loudly.
    logits = jnp.reshape(logits, jnp.shape(y))

    # Binary cross entropy from logits, numerically stable form:
    # log(1 - sigmoid(z)) == log_sigmoid(-z).
    log_probs = jax.nn.log_sigmoid(logits)
    log_1minus_probs = jax.nn.log_sigmoid(-logits)

    loss = -(y * log_probs + (1 - y) * log_1minus_probs)

    return jnp.mean(loss)

In [None]:
# Input size comes from the config so it always matches the embedding vectors.
bow_model = BagOfWordsClassifier([256, 256], config.embedding_dim, 1, key=jax.random.PRNGKey(1))

# Optimizer
optimizer = optax.adam(learning_rate=config.learning_rate)
opt_state = optimizer.init(bow_model)

# Loss value and gradient w.r.t. the model (first argument of loss_fn)
loss_and_grad_fn = jax.value_and_grad(loss_fn)

# Sanity-check mode: for 200 epochs, train only on the first batch the loader
# yields each epoch. A working model/loss should be able to push train/loss
# toward zero on it. Set to False to run config.num_epochs full epochs.
OVERFIT_ONE_BATCH = True
num_epochs = 200 if OVERFIT_ONE_BATCH else config.num_epochs

# TensorBoard writer
writer = SummaryWriter(logdir="./runs/" + time.strftime("%Y%m%d-%H%M%S"))

# Untrained copy of the model, presumably kept for later comparison.
orig_model = copy.deepcopy(bow_model)

step = 0
try:
    for epoch in range(num_epochs):
        for batch in train_dataloader:
            loss_val, grads = loss_and_grad_fn(bow_model, batch)

            # Apply gradients
            updates, opt_state = optimizer.update(grads, opt_state)
            bow_model = optax.apply_updates(bow_model, updates)

            # Log to TensorBoard. Values are converted with float() because
            # they may be jax/numpy arrays rather than Python scalars.
            writer.add_scalar("train/loss", float(loss_val), step)
            writer.add_scalar("train/learning_rate", config.learning_rate, step)
            # Log the raw gradients (not the Adam-transformed updates), one
            # tag per parameter leaf.
            for leaf_idx, grad_leaf in enumerate(jax.tree_util.tree_leaves(grads)):
                writer.add_scalar(f"gradients/{leaf_idx}/mean", float(jnp.mean(grad_leaf)), step)

            # Average count of non-padding tokens without a GloVe vector.
            num_missing = np.mean(batch['embedding_missing'].sum(axis=1) - batch['num_pad'])
            writer.add_scalar("data/padding", float(batch['num_pad'].mean()), step)
            writer.add_scalar("data/num_missing", float(num_missing), step)
            writer.add_scalar("label/mean", float(batch['label'].mean()), step)

            step += 1
            if OVERFIT_ONE_BATCH:
                break

        writer.flush()
finally:
    # Close the writer even if training is interrupted or raises.
    writer.close()


In [None]:
# List TensorBoard run directories, newest first (-t), with human-readable sizes (-h).
!ls -lht ./runs/

In [None]:
# Load the TensorBoard notebook extension and show the runs logged under ./runs/ inline.
%load_ext tensorboard
%tensorboard --logdir runs/