In [2]:
from datasets import load_dataset

dataset_name = "batterydata/pos_tagging"
training_dataset = load_dataset(dataset_name, split="train")
test_dataset = load_dataset(dataset_name, split="test")


In [4]:
training_dataset


Dataset({
    features: ['words', 'labels'],
    num_rows: 13054
})

In [5]:
test_dataset


Dataset({
    features: ['words', 'labels'],
    num_rows: 1451
})

So for data preprocessing:
1. Make words to idx and labels to idx dictionaries
2. Make a validation split from the training set
3. Encode all the data with indices found from 1

In [9]:
# a dict containing word -> idx mapping
def create_word_indices(dataset):
    unique_words = set()
    word_to_idx = dict()
    # add an out of vocab token
    oov_token = "<OOV>"
    word_to_idx[oov_token] = 0
    
    # find unique words
    for data in dataset:
        words = data["words"]
        for w in words:
            unique_words.add(w)
            
    # add index to them
    for idx, uw in enumerate(list(unique_words)):
        word_to_idx[uw] = idx + 1 # since oov is at 0
        
    
    return word_to_idx


# ===============
word_to_idx = create_word_indices(training_dataset)
len(word_to_idx)


24848

In [10]:
def create_label_to_idx(dataset):
    unique_labels = set()
    label_to_idx = dict()
    # add an out of vocab token
    oov_token = "<OOV>"
    label_to_idx[oov_token] = 0
    
    # find the labels
    for data in dataset:
        labels = data["labels"]
        for l in labels:
            unique_labels.add(l)
            
    # index
    for idx, label in enumerate(list(unique_labels)):
        label_to_idx[label] = idx + 1
        
    return label_to_idx
    
label_to_idx = create_label_to_idx(training_dataset)
print(len(label_to_idx))
    

49


In [21]:
print(label_to_idx)


{'<OOV>': 0, 'VBZ': 1, "''": 2, 'TO': 3, 'FW': 4, '-NONE-': 5, 'POS': 6, 'RBS': 7, 'UH': 8, 'LS': 9, 'RBR': 10, 'WP': 11, 'RP': 12, 'RB': 13, 'VB': 14, 'PRP$': 15, 'WRB': 16, 'CC': 17, 'WP$': 18, 'DT': 19, 'NN': 20, ')': 21, ',': 22, '``': 23, '$': 24, 'VBD': 25, 'VBG': 26, 'PRP': 27, 'SYM': 28, 'VBP': 29, '-RRB-': 30, 'IN': 31, '.': 32, 'VBN': 33, 'MD': 34, 'WDT': 35, '(': 36, 'JJS': 37, 'NNS': 38, '#': 39, 'JJR': 40, 'JJ': 41, ':': 42, 'CD': 43, 'NNP': 44, '-LRB-': 45, 'PDT': 46, 'NNPS': 47, 'EX': 48}


In [20]:
# for a single instance
def encode_data_instance(data, word_to_idx, label_to_idx):
    words = [
        word_to_idx[word] for word in data["words"]
    ]
    
    labels = [
        label_to_idx[label] for label in data["labels"]
    ]
    
    return {
        "words": words,
        "labels": labels
    }
    

print(encode_data_instance(training_dataset[0], word_to_idx, label_to_idx))


{'words': [972, 22062, 9313, 24281, 23983, 8078, 14643, 3389, 5765, 2188, 15739, 14561, 359, 3839, 18638, 24336, 19654, 222, 16838, 24336, 8019, 5675, 222, 3756, 3389, 17664, 6815, 14065, 2387, 13151, 23047, 10021, 22512, 19, 8031, 22436, 5927], 'labels': [20, 31, 19, 20, 1, 13, 33, 3, 14, 19, 41, 20, 31, 20, 38, 31, 44, 22, 41, 31, 20, 20, 22, 14, 3, 14, 19, 41, 20, 31, 44, 17, 44, 6, 41, 38, 32]}


In [31]:
trainset = map(lambda data: encode_data_instance(data, word_to_idx, label_to_idx), training_dataset)
trainset = list(trainset)

print(trainset[0])


{'words': [972, 22062, 9313, 24281, 23983, 8078, 14643, 3389, 5765, 2188, 15739, 14561, 359, 3839, 18638, 24336, 19654, 222, 16838, 24336, 8019, 5675, 222, 3756, 3389, 17664, 6815, 14065, 2387, 13151, 23047, 10021, 22512, 19, 8031, 22436, 5927], 'labels': [20, 31, 19, 20, 1, 13, 33, 3, 14, 19, 41, 20, 31, 20, 38, 31, 44, 22, 41, 31, 20, 20, 22, 14, 3, 14, 19, 41, 20, 31, 44, 17, 44, 6, 41, 38, 32]}


In [30]:
testset = map(lambda data: encode_data_instance(
    data, word_to_idx, label_to_idx), test_dataset)
testset = list(testset)

print(testset[0])


{'words': [972, 22062, 9313, 24281, 23983, 8078, 14643, 3389, 5765, 2188, 15739, 14561, 359, 3839, 18638, 24336, 19654, 222, 16838, 24336, 8019, 5675, 222, 3756, 3389, 17664, 6815, 14065, 2387, 13151, 23047, 10021, 22512, 19, 8031, 22436, 5927], 'labels': [20, 31, 19, 20, 1, 13, 33, 3, 14, 19, 41, 20, 31, 20, 38, 31, 44, 22, 41, 31, 20, 20, 22, 14, 3, 14, 19, 41, 20, 31, 44, 17, 44, 6, 41, 38, 32]}


In [32]:
assert len(training_dataset) == len(trainset)


In [41]:
# now to create the validation set
import numpy as np

def create_train_validation_splits(trainset, validation_ratio):
    validation_set_size = int(len(trainset) * validation_ratio)
    validation_indices = np.random.choice(len(trainset), replace=False, size=validation_set_size).tolist()
    
    # now to separate trainset indices
    trainset_indices = [i for i in range(len(trainset)) if i not in validation_indices]
    
    return trainset_indices, validation_indices


trainset_indices, validation_indices = create_train_validation_splits(trainset, 0.3)

print(len(trainset_indices))
print(len(validation_indices))


assert len(trainset_indices) + len(validation_indices) == len(trainset)


9138
3916


In [51]:
from jax import random, jit, vmap, grad
import jax.numpy as jnp
import flax.linen as nn


master_key = random.PRNGKey(seed=2023)
master_key, model_init_key = random.split(master_key)


In [56]:
# lstm in flax: https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.LSTMCell.html

from typing import Any


class LSTMTagger(nn.Module):
    vocab_size: int
    embedding_dimensions: int
    projection_dims: int # aka hidden dims for projection after lstm
    n_labels: int
    
    def setup(self) -> None:
        self.embedding = nn.Embed(
            num_embeddings=self.vocab_size, features=self.embedding_dimensions)
        
        self.lstm = nn.LSTMCell(self.projection_dims)
        self.dense = nn.Dense(features=self.n_labels)
        
    
    def __call__(self, words) -> Any:
        x = (words)
        x = nn.LSTMCell(features=self.projection_dims)(x)
        x = nn.Dense(features=self.n_labels)(x)
        x = nn.leaky_relu(x)
        x = nn.log_softmax(x)
        
        return x


model = LSTMTagger(len(word_to_idx), 150, 100, len(label_to_idx))

init_params = model.init(model_init_key, np.array(trainset[0]["words"]))


TypeError: LSTMCell.__init__() got an unexpected keyword argument 'features'

In [55]:
init_params


FrozenDict({
    params: {
        Embed_0: {
            embedding: Array([[-0.05916596,  0.01026116,  0.08525214, ..., -0.03253829,
                     0.01560668, -0.0402378 ],
                   [ 0.08890485, -0.09597716,  0.01476798, ...,  0.26407877,
                     0.10020454,  0.06108525],
                   [ 0.01983153, -0.04647755, -0.03345514, ...,  0.05045682,
                     0.06162264,  0.03780266],
                   ...,
                   [ 0.00149957, -0.13568471, -0.03070986, ..., -0.1231816 ,
                    -0.07690515,  0.0227146 ],
                   [-0.01137438, -0.09372773,  0.03765107, ...,  0.04488415,
                     0.09964965, -0.01214164],
                   [ 0.0132378 , -0.09923965, -0.0267396 , ...,  0.06017342,
                    -0.07800906,  0.08927516]], dtype=float32),
        },
        Dense_0: {
            kernel: Array([[-0.03942307, -0.12093627, -0.07201098, ..., -0.1012962 ,
                     0.02250638,  0.0582009