In [1]:
from datasets import load_dataset

dataset_name = "batterydata/pos_tagging"
training_dataset = load_dataset(dataset_name, split="train")
test_dataset = load_dataset(dataset_name, split="test")


In [2]:
training_dataset


Dataset({
    features: ['words', 'labels'],
    num_rows: 13054
})

In [3]:
test_dataset


Dataset({
    features: ['words', 'labels'],
    num_rows: 1451
})

So for data preprocessing:
1. Make words to idx and labels to idx dictionaries
2. Make a validation split from the training set
3. Encode all the data with indices found from 1

In [4]:
# a dict containing word -> idx mapping
def create_word_indices(dataset):
    unique_words = set()
    word_to_idx = dict()
    # add an out of vocab token
    oov_token = "<OOV>"
    word_to_idx[oov_token] = 0
    
    # find unique words
    for data in dataset:
        words = data["words"]
        for w in words:
            unique_words.add(w)
            
    # add index to them
    for idx, uw in enumerate(list(unique_words)):
        word_to_idx[uw] = idx + 1 # since oov is at 0
        
    
    return word_to_idx


# ===============
word_to_idx = create_word_indices(training_dataset)
len(word_to_idx)


24848

In [5]:
def create_label_to_idx(dataset):
    unique_labels = set()
    label_to_idx = dict()
    # add an out of vocab token
    oov_token = "<OOV>"
    label_to_idx[oov_token] = 0
    
    # find the labels
    for data in dataset:
        labels = data["labels"]
        for l in labels:
            unique_labels.add(l)
            
    # index
    for idx, label in enumerate(list(unique_labels)):
        label_to_idx[label] = idx + 1
        
    return label_to_idx
    
label_to_idx = create_label_to_idx(training_dataset)
print(len(label_to_idx))
    

49


In [6]:
print(label_to_idx)


{'<OOV>': 0, '-RRB-': 1, 'CD': 2, '(': 3, 'NN': 4, 'VBG': 5, 'VBZ': 6, 'RBS': 7, 'PRP': 8, 'EX': 9, 'LS': 10, 'JJ': 11, 'PRP$': 12, 'WRB': 13, 'RP': 14, 'DT': 15, 'WP': 16, '-NONE-': 17, 'RB': 18, '-LRB-': 19, 'NNP': 20, 'JJR': 21, ':': 22, ')': 23, 'MD': 24, 'WDT': 25, '#': 26, ',': 27, 'TO': 28, 'UH': 29, 'FW': 30, 'VBP': 31, 'NNPS': 32, 'PDT': 33, 'NNS': 34, 'VBN': 35, 'CC': 36, '``': 37, 'RBR': 38, 'VBD': 39, 'IN': 40, "''": 41, '$': 42, 'SYM': 43, 'JJS': 44, 'WP$': 45, 'POS': 46, 'VB': 47, '.': 48}


In [10]:
# for a single instance
def encode_data_instance(data, word_to_idx, label_to_idx):
    words = [
        word_to_idx.get(word, word_to_idx["<OOV>"]) for word in data["words"]
    ]
    
    labels = [
        label_to_idx[label] for label in data["labels"]
    ]
    
    return {
        "words": words,
        "labels": labels
    }
    

print(encode_data_instance(training_dataset[0], word_to_idx, label_to_idx))


{'words': [15446, 427, 22449, 8706, 2829, 2527, 16321, 19167, 5695, 20770, 24005, 10805, 14760, 11868, 19064, 2835, 23737, 12453, 19839, 2835, 18153, 12428, 12453, 5759, 19167, 19285, 11537, 20158, 2837, 18421, 15437, 2231, 13408, 13704, 3811, 12177, 8760], 'labels': [4, 40, 15, 4, 6, 18, 35, 28, 47, 15, 11, 4, 40, 4, 34, 40, 20, 27, 11, 40, 4, 4, 27, 47, 28, 47, 15, 11, 4, 40, 20, 36, 20, 46, 11, 34, 48]}


In [11]:
trainset = map(lambda data: encode_data_instance(data, word_to_idx, label_to_idx), training_dataset)
trainset = list(trainset)

print(trainset[0])


{'words': [15446, 427, 22449, 8706, 2829, 2527, 16321, 19167, 5695, 20770, 24005, 10805, 14760, 11868, 19064, 2835, 23737, 12453, 19839, 2835, 18153, 12428, 12453, 5759, 19167, 19285, 11537, 20158, 2837, 18421, 15437, 2231, 13408, 13704, 3811, 12177, 8760], 'labels': [4, 40, 15, 4, 6, 18, 35, 28, 47, 15, 11, 4, 40, 4, 34, 40, 20, 27, 11, 40, 4, 4, 27, 47, 28, 47, 15, 11, 4, 40, 20, 36, 20, 46, 11, 34, 48]}


In [12]:
testset = map(lambda data: encode_data_instance(
    data, word_to_idx, label_to_idx), test_dataset)
testset = list(testset)

print(testset[0])


{'words': [7390, 0, 21970, 22449, 2671, 5809, 12453, 21448, 16107, 9733, 10171, 12453, 6045, 22449, 4621, 20680, 3481, 6866, 5509, 15700, 427, 19232, 6218, 11362, 12453, 2231, 7783, 22063, 18696, 19, 19924, 11757, 10383, 8760], 'labels': [36, 20, 39, 15, 20, 4, 27, 8, 6, 17, 17, 27, 40, 15, 20, 20, 11, 4, 6, 4, 40, 11, 40, 4, 27, 36, 18, 6, 18, 47, 15, 4, 34, 48]}


In [13]:
assert len(training_dataset) == len(trainset)


In [14]:
# now to create the validation set
import numpy as np

def create_train_validation_splits(trainset, validation_ratio):
    validation_set_size = int(len(trainset) * validation_ratio)
    validation_indices = np.random.choice(len(trainset), replace=False, size=validation_set_size).tolist()
    
    # now to separate trainset indices
    trainset_indices = [i for i in range(len(trainset)) if i not in validation_indices]
    
    return trainset_indices, validation_indices


trainset_indices, validation_indices = create_train_validation_splits(trainset, 0.3)

print(len(trainset_indices))
print(len(validation_indices))


assert len(trainset_indices) + len(validation_indices) == len(trainset)


9138
3916


In [24]:
from jax import random, jit, vmap, grad
import jax.numpy as jnp
import flax.linen as nn


In [25]:
master_key = random.PRNGKey(seed=2023)
master_key, model_init_key = random.split(master_key)


In [26]:
# lstm in flax: https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.LSTMCell.html

from typing import Any


class LSTMTagger(nn.Module):
    vocab_size: int
    embedding_dimensions: int
    projection_dims: int # aka hidden dims for projection after lstm
    n_labels: int
        
    @nn.compact
    def __call__(self, words) -> Any:
        # ========== Embedding ==========
        x = nn.Embed(
            num_embeddings=self.vocab_size, features=self.embedding_dimensions)(words)
        
        # ========= LSTM ============
        lstm = nn.OptimizedLSTMCell(features=self.projection_dims)
        carry = lstm.initialize_carry(random.PRNGKey(2023), x.shape)    
        carry, x = lstm(carry=carry, inputs=x)
        
        # ========== Dense ==========
        x = nn.Dense(features=self.n_labels)(x)
        x = nn.leaky_relu(x)
        
        x = nn.log_softmax(x)
        
        return x


model = LSTMTagger(len(word_to_idx), 150, 150, len(label_to_idx))

init_params = model.init(model_init_key, np.array(trainset[0]["words"]))
