In [None]:
from datasets import load_dataset

# Load the "train" and "test" splits of the dataset named below, in that order.
dataset_name = "batterydata/pos_tagging"
training_dataset, test_dataset = (
    load_dataset(dataset_name, split=split_name)
    for split_name in ("train", "test")
)


In [None]:
# Bare expression: the notebook shows the training split's repr as the cell output.
training_dataset


In [None]:
# Bare expression: the notebook shows the test split's repr as the cell output.
test_dataset


Data preprocessing steps:
1. Build word-to-index and label-to-index dictionaries
2. Make a validation split from the training set
3. Encode all the data using the indices built in step 1

In [None]:
# a dict containing word -> idx mapping
def create_word_indices(dataset):
    """Build a word -> integer index vocabulary from ``dataset``.

    Index 0 is reserved for the out-of-vocabulary token ``"<OOV>"`` and
    index 1 for the padding token ``"<PAD>"``. Every other distinct word found
    under the ``"words"`` key of the records is numbered from 2 upwards.

    Words are numbered in sorted order so the mapping is the same on every
    run; iterating a ``set`` of strings directly is not stable across
    interpreter sessions because of string hash randomisation.

    Args:
        dataset: Iterable of records, each exposing a ``"words"`` sequence.
            Assumes the words are mutually comparable (e.g. all strings).

    Returns:
        dict mapping each token (including the two reserved ones) to its index.
    """
    oov_token = "<OOV>"
    pad_token = "<PAD>"
    word_to_idx = {oov_token: 0, pad_token: 1}

    # find unique words
    unique_words = set()
    for data in dataset:
        unique_words.update(data["words"])

    # keep the reserved tokens pinned to 0 and 1 even if they occur in the data
    unique_words -= {oov_token, pad_token}

    # start at 2 since oov is at 0 and pad at 1
    for idx, word in enumerate(sorted(unique_words), start=2):
        word_to_idx[word] = idx

    return word_to_idx


# ===============
# Build the vocabulary from the training split; the cell output is its size.
word_to_idx = create_word_indices(training_dataset)
len(word_to_idx)


In [None]:
def create_label_to_idx(dataset):
    """Build a label -> integer index mapping from ``dataset``.

    Index 0 is reserved for ``"<OOV>"`` and index 1 for ``"<PAD>"``. Every
    other distinct label found under the ``"labels"`` key of the records is
    numbered from 2 upwards.

    Labels are numbered in sorted order so the mapping is the same on every
    run; iterating a ``set`` of strings directly is not stable across
    interpreter sessions because of string hash randomisation.

    Args:
        dataset: Iterable of records, each exposing a ``"labels"`` sequence.
            Assumes the labels are mutually comparable (e.g. all strings).

    Returns:
        dict mapping each label (including the two reserved ones) to its index.
    """
    oov_token = "<OOV>"
    pad_token = "<PAD>"
    label_to_idx = {oov_token: 0, pad_token: 1}

    # find the labels
    unique_labels = set()
    for data in dataset:
        unique_labels.update(data["labels"])

    # keep the reserved tokens pinned to 0 and 1 even if they occur in the data
    unique_labels -= {oov_token, pad_token}

    # index, starting at 2 because 0 and 1 are reserved
    for idx, label in enumerate(sorted(unique_labels), start=2):
        label_to_idx[label] = idx

    return label_to_idx
    
# Build the label mapping from the training split and print its size
# (the count includes the reserved <OOV> and <PAD> entries).
label_to_idx = create_label_to_idx(training_dataset)
print(len(label_to_idx))
    

In [None]:
# Print the full label -> index mapping for inspection.
print(label_to_idx)


In [None]:
# for a single instance
def encode_data_instance(data, word_to_idx, label_to_idx):
    """Convert one record's words and labels to integer indices.

    Words missing from ``word_to_idx`` are mapped to the index of ``"<OOV>"``.
    Labels missing from ``label_to_idx`` are mapped to its ``"<OOV>"`` index
    when the mapping defines one, so an unseen label no longer aborts the
    encoding. A mapping without an ``"<OOV>"`` entry keeps the strict lookup.

    Args:
        data: Record with ``"words"`` and ``"labels"`` sequences.
        word_to_idx: dict from word to index; must contain ``"<OOV>"`` when
            ``data["words"]`` is non-empty.
        label_to_idx: dict from label to index.

    Returns:
        dict with keys ``"words"`` and ``"labels"``, each a list of ints.

    Raises:
        KeyError: if a label is unknown and ``label_to_idx`` has no ``"<OOV>"``.
    """
    words = [
        word_to_idx.get(word, word_to_idx["<OOV>"]) for word in data["words"]
    ]

    def encode_label(label):
        # known label: direct lookup
        if label in label_to_idx:
            return label_to_idx[label]
        # unknown label: use the reserved fallback when it is available
        if "<OOV>" in label_to_idx:
            return label_to_idx["<OOV>"]
        raise KeyError(label)

    labels = [encode_label(label) for label in data["labels"]]

    return {
        "words": words,
        "labels": labels
    }
    

# Spot check: encode the first training record and print the result.
print(encode_data_instance(training_dataset[0], word_to_idx, label_to_idx))


In [None]:
# Encode every training record to integer indices, materialised as a list.
trainset = [
    encode_data_instance(record, word_to_idx, label_to_idx)
    for record in training_dataset
]

print(trainset[0])


In [None]:
# Encode every test record with the mappings built from the training split.
testset = [
    encode_data_instance(record, word_to_idx, label_to_idx)
    for record in test_dataset
]

print(testset[0])


In [None]:
# Sanity check: encoding must not drop or duplicate records.
assert len(training_dataset) == len(trainset)


In [None]:
# now to create the validation set
import numpy as np

def create_train_validation_splits(trainset, validation_ratio, seed=None):
    """Randomly partition the positions of ``trainset`` into train/validation.

    Args:
        trainset: Sized collection; only ``len(trainset)`` is used.
        validation_ratio: Fraction of positions to hold out. The validation
            size is ``int(len(trainset) * validation_ratio)``.
        seed: Optional seed for a reproducible split. When ``None`` (the
            default) NumPy's global random state is used, so ``np.random.seed``
            still controls the result.

    Returns:
        ``(trainset_indices, validation_indices)``: two disjoint lists of ints.
        ``trainset_indices`` is in ascending order; ``validation_indices`` is
        in the order the indices were drawn.
    """
    n_samples = len(trainset)
    validation_set_size = int(n_samples * validation_ratio)

    if seed is None:
        drawn = np.random.choice(n_samples, replace=False, size=validation_set_size)
    else:
        drawn = np.random.default_rng(seed).choice(
            n_samples, replace=False, size=validation_set_size
        )
    validation_indices = drawn.tolist()

    # set membership keeps the complement computation linear in n_samples
    held_out = set(validation_indices)
    trainset_indices = [i for i in range(n_samples) if i not in held_out]

    return trainset_indices, validation_indices


# Hold out a 0.3 fraction of the encoded training examples for validation.
trainset_indices, validation_indices = create_train_validation_splits(trainset, 0.3)

print(len(trainset_indices))
print(len(validation_indices))


# Sanity check: the two index lists together have one entry per encoded example.
assert len(trainset_indices) + len(validation_indices) == len(trainset)


In [None]:
# Length of the longest word sequence in the encoded training set (shown as the cell output).
max_seq_len = np.max([len(d["words"]) for d in trainset])
max_seq_len


In [None]:
from torch.utils.data import Dataset
import jax_dataloader as jdl

class TagDataset(Dataset):
    """Map-style dataset yielding fixed-length ``(words, labels)`` index arrays.

    Each item is looked up in ``dataset`` (optionally through ``indices``) and
    both of its sequences are right-padded with ``pad_idx`` to ``max_len``.
    Sequences longer than ``max_len`` are truncated to the first ``max_len``
    entries instead of raising on assignment.

    Args:
        indices: Positions into ``dataset`` that make up this subset, or
            ``None`` to expose ``dataset`` as-is (used for the test set).
        dataset: Indexable collection of records with ``"words"`` and
            ``"labels"`` integer sequences.
        max_len: Length of every returned array (default 300).
        pad_idx: Fill value for positions past the end of a sequence
            (default 1, the ``<PAD>`` index).
    """

    def __init__(self, indices, dataset, max_len=300, pad_idx=1) -> None:
        self.indices = indices
        self.dataset = dataset
        self.max_len = max_len
        self.pad_idx = pad_idx

    def __len__(self):
        if self.indices is None:
            # this is for the test case
            return len(self.dataset)
        else:
            return len(self.indices)

    def _pad(self, sequence):
        """Return ``sequence`` as an int32 array of length ``max_len``."""
        padded = np.full((self.max_len, ), self.pad_idx, dtype=np.int32)
        # clip first so an over-long sequence cannot overflow the buffer
        clipped = sequence[:self.max_len]
        padded[:len(clipped)] = clipped
        return padded

    def __getitem__(self, index):
        """Return the padded ``(words, labels)`` pair for position ``index``."""
        if self.indices is None:
            idx = index
        else:
            idx = self.indices[index]

        data = self.dataset[idx]

        return self._pad(data["words"]), self._pad(data["labels"])

# Batches of 128; only the training loader is shuffled.
train_loader = jdl.DataLoader(TagDataset(trainset_indices, trainset), "pytorch", batch_size=128, shuffle=True)
val_loader = jdl.DataLoader(TagDataset(validation_indices, trainset), "pytorch", batch_size=128, shuffle=False)
# The test split gets its own name so that `train_loader` keeps pointing at the training subset.
test_loader = jdl.DataLoader(TagDataset(None, testset), "pytorch", batch_size=128, shuffle=False)


# =========== test a dataloader ==========
# Smoke test: print the first batch only, then stop iterating.
for batch in train_loader:
    print(batch)
    break


In [None]:
from jax import random, jit, vmap, grad, value_and_grad
import jax.numpy as jnp
import flax.linen as nn


In [None]:
# Derive separate keys from one seeded master key: one for parameter
# initialisation and one for dropout. Each split also refreshes master_key.
master_key = random.PRNGKey(seed=2023)
master_key, model_init_key = random.split(master_key)
master_key, dropout_key = random.split(master_key)


In [None]:
# lstm in flax: https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.LSTMCell.html

from typing import Any


class LSTMTagger(nn.Module):
    """Embedding -> dropout -> LSTM -> dense tagger producing per-token scores.

    ``__call__`` takes an integer array of word indices whose second-to-last
    axis after embedding is time, i.e. ``words`` has shape ``(*batch, seq_len)``,
    and returns an array of shape ``(*batch, seq_len, n_labels)``.
    """

    # number of rows in the embedding table
    vocab_size: int
    # width of each word embedding
    embedding_dimensions: int
    projection_dims: int # aka hidden dims for projection after lstm
    # number of output scores per token
    n_labels: int
    # Annotated so it is a real module field: pass training=False at
    # construction to make dropout deterministic (an unannotated class
    # attribute cannot be set through the constructor).
    training: bool = True

    @nn.compact
    def __call__(self, words) -> Any:
        # ========== Embedding ==========
        x = nn.Embed(
            num_embeddings=self.vocab_size, features=self.embedding_dimensions, name="embedding")(words)
        x = nn.Dropout(0.2, deterministic=not self.training)(x)

        # ========= LSTM ============
        # nn.RNN scans the cell over the time axis and carries the state from
        # one token to the next. Calling the cell directly on the whole
        # (seq_len, features) array would treat seq_len as a batch axis and
        # process every token independently from a fresh carry.
        x = nn.RNN(
            nn.OptimizedLSTMCell(features=self.projection_dims), name="lstm")(x)

        # ========== Dense ==========
        x = nn.Dense(features=self.n_labels, name="dense")(x)
        # NOTE(review): an activation on the output scores is unusual before a
        # softmax cross-entropy loss; kept as in the original, confirm intent.
        x = nn.leaky_relu(x)

        return x


# Positional arguments: vocab_size, embedding_dimensions, projection_dims, n_labels.
# NOTE(review): n_labels is hard-coded to 300 rather than len(label_to_idx) — confirm this is intended.
model = LSTMTagger(len(word_to_idx), 300, 300, 300)

# One PRNG key per named stream the module draws from: "params" for
# parameter initialisation and "dropout" for the dropout layer.
# https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html
model_rngs = {"params": model_init_key, "dropout": dropout_key}

# Initialise on a single unpadded example, then run one forward pass and
# print the output shape as a smoke test.
init_params = model.init(model_rngs, np.array(trainset[0]["words"]))
logits = model.apply(init_params, jnp.array(trainset[0]["words"]), rngs={"dropout": random.PRNGKey(99)})
print(logits.shape)


In [None]:
import optax


@jit
def calculate_loss(params, words, labels):
    """Mean per-token softmax cross-entropy for one unbatched example.

    Args:
        params: Model variables as returned by ``model.init``.
        words: Integer word indices for one sequence.
        labels: Integer class indices aligned position-by-position with ``words``.

    Returns:
        Scalar loss: the cross-entropy averaged over the sequence positions.
    """
    # NOTE(review): the dropout key is constant, so the same dropout mask is
    # drawn on every call — confirm whether a per-step key is wanted.
    logits = model.apply(params, words, rngs={"dropout": random.PRNGKey(90)})
    # `labels` holds class indices, not one-hot/probability vectors, so the
    # integer-label variant is required. `optax.softmax_cross_entropy` would
    # broadcast the index vector against the class axis and use the index
    # values as class weights.
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
    # NOTE(review): padded positions are included in this mean.
    return loss.mean(axis=-1)


# The batching (vmap) lives in this dedicated loss function so that the
# gradient is taken of the batch loss, rather than vmapping the grad function.
@jit
def batched_loss(params, words_batched, labels_batched):
    """Average ``calculate_loss`` over the leading (batch) axis.

    ``params`` is shared across the batch; ``words_batched`` and
    ``labels_batched`` are mapped over axis 0.
    """
    per_example_loss = vmap(calculate_loss, in_axes=(None, 0, 0))
    losses = per_example_loss(params, words_batched, labels_batched)
    return losses.mean(axis=-1)


In [None]:
from tqdm.auto import trange
from flax.training import train_state
from functools import partial

# SGD optimiser with a fixed learning rate of 0.01.
optimiser = optax.sgd(learning_rate=0.01)
# Initial training state: forward function, freshly initialised params, optimiser.
init_state = train_state.TrainState.create(
    apply_fn=model.apply, # the forward function
    params=init_params,
    tx=optimiser
)
# Returns (loss, gradients with respect to the first argument, i.e. params).
criterion = value_and_grad(batched_loss)



@partial(jit, static_argnums=0)
def train_step(criterion, state, words_batched, labels_batched):
    """Run one optimisation step and return ``(loss, new_state)``.

    ``criterion`` is a static (non-traced) callable taking
    ``(params, words_batched, labels_batched)`` and returning
    ``(loss, grads)``. The gradients are applied through
    ``state.apply_gradients``.
    """
    step_loss, gradients = criterion(state.params, words_batched, labels_batched)
    return step_loss, state.apply_gradients(grads=gradients)


def train_model(state, train_loader, epochs=100, log_every_n_step=200):
    """Train for ``epochs`` passes over ``train_loader`` and return the final state.

    Each batch is unpacked into ``(words, labels)`` and passed to ``train_step``
    together with the module-level ``criterion``. The step counter runs across
    epochs, and the current loss is printed every ``log_every_n_step`` steps.
    """
    steps_done = 0
    for _epoch in trange(epochs):
        for words, labels in train_loader:
            loss, state = train_step(criterion, state, words, labels)
            steps_done += 1

            if steps_done % log_every_n_step != 0:
                continue
            print(loss)

    return state


In [None]:
# Train with the default epoch count and logging interval, starting from the initial state.
state = train_model(init_state, train_loader)
