In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from plotly import graph_objects as go, io as pio
from sklearn import metrics
from importlib import reload


import polars as pl
import util

# Re-import the local `util` module so that edits made to it after the kernel
# started are picked up without restarting.
reload(util)
# Make plotly's built-in dark theme the default template for figures created below.
pio.templates.default = "plotly_dark"


## load data


In [None]:
# Load all examples (tags remapped through util.MAP_TAGS) and order them by the
# "length" column.
examples = util.load_examples(tag_map=util.MAP_TAGS).sort("length")

# Number of examples per language, most frequent language first, as {lang: count}.
per_lang = examples.group_by("lang").len().sort("len", descending=True)
lang_counts = dict(zip(per_lang["lang"].to_list(), per_lang["len"].to_list()))
print(lang_counts)

In [None]:
# train-val split
# Fraction of the examples used for training; the remainder is held out for validation.
TRAIN_FRACTION = 0.6
# Fixed seed so the shuffle, and therefore the split, repeats for a given row
# order of `examples`.
SPLIT_SEED = 42

n_train = int(TRAIN_FRACTION * len(examples))

# Shuffle into a NEW frame rather than overwriting `examples`: the cell then always
# starts from the same input, so re-running it reproduces the same split instead of
# reshuffling already-shuffled data.
shuffled_examples = examples.sample(fraction=1, shuffle=True, seed=SPLIT_SEED)

train_df = shuffled_examples.head(n_train)
# Everything after the first n_train rows. `slice(n_train)` is used instead of
# `tail(-n_train)` because tail(-0) == tail(0) returns an empty frame, which would
# silently drop all rows when n_train is 0.
val_df = shuffled_examples.slice(n_train)

# Every example must land in exactly one of the two splits.
assert len(train_df) + len(val_df) == len(examples)

print(f"split: {len(train_df)} training, {len(val_df)} val")


## most common tokens


In [None]:
def _count_list_items(frame, list_column, item_name):
    """Count occurrences of each item found in the lists of ``list_column``.

    The column is exploded so every list element becomes its own row, then rows
    are grouped by value. Returns a frame with the columns ``item_name`` and
    ``count``, sorted by ``count`` in descending order.
    """
    exploded = frame.select(pl.col(list_column).explode().alias(item_name))
    return (
        exploded.group_by(item_name)
        .agg(pl.len().alias("count"))
        .sort("count", descending=True)
    )


# NOTE: both tables are computed over `examples`, i.e. training and validation
# rows together.
token_counts = _count_list_items(examples, "tokens", "token")
tag_counts = _count_list_items(examples, "tags", "tag")

print("common tokens: ", token_counts.head(60).rows())
print("common tags  : ", tag_counts.rows())
print(f"\nwe have {len(token_counts)} unique tokens, and {len(tag_counts)} unique tags")


## make a vocab!

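- add a `<pad>` entry (index 0) to both the token and tag vocabularies, and an `<unk>` entry (index 1) to the token vocabulary for tokens outside the 100 most common
- convert tokens and tags to integer indices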


In [None]:
# Token vocabulary: the two special entries followed by the 100 most frequent
# tokens (token_counts is sorted by count, descending). "<pad>" must stay at
# index 0 and "<unk>" at index 1: the encoder below falls back to index 1 for
# unknown tokens.
vocab: list = ["<pad>", "<unk>"] + token_counts["token"].to_list()[:100]
token2idx = {tok: idx for idx, tok in enumerate(vocab)}

# Tag vocabulary: "<pad>" at index 0 followed by every observed tag.
tag_vocab: list = ["<pad>"] + tag_counts["tag"].to_list()
tag2idx = {tag: idx for idx, tag in enumerate(tag_vocab)}

print("vocab (tokens):", vocab)
print("vocab (tags)  :", tag_vocab)


def _tokens_to_ids(token_seqs):
    """Encode token sequences as index lists; out-of-vocab tokens become 1 ("<unk>")."""
    return [[token2idx.get(tok, 1) for tok in seq] for seq in token_seqs]


def _tags_to_ids(tag_seqs):
    """Encode tag sequences as index lists; a tag missing from tag2idx raises KeyError."""
    return [[tag2idx[tag] for tag in seq] for seq in tag_seqs]


# Index-encoded training data. Both are lists of lists: one inner list per example.
train_token_idx = _tokens_to_ids(train_df["tokens"])
train_tag_idx = _tags_to_ids(train_df["tags"])
print(f"\ntraining examples of length: {[len(e) for e in train_token_idx]}")

# Index-encoded validation data, same layout as the training lists.
val_token_idx = _tokens_to_ids(val_df["tokens"])
val_tag_idx = _tags_to_ids(val_df["tags"])
print(f"validation examples of length: {[len(e) for e in val_token_idx]}")


### Prepare data for model


In [None]:
def seqs2padded_tensor(sequences: list[list[int | float]], pad_value=0):
    return nn.utils.rnn.pad_sequence(
        (torch.tensor(s) for s in sequences),
        batch_first=True,
        padding_value=pad_value,
    )


# Pad tokens and tags of each split to a common width using the default pad
# value 0, which is the "<pad>" index in both vocabularies.
# Tokens and tags are padded independently; their shapes only line up if every
# example has as many tags as tokens -- TODO confirm against the loader.
train_token_tensors, train_tag_tensors = (
    seqs2padded_tensor(train_token_idx),
    seqs2padded_tensor(train_tag_idx),
)
val_token_tensors, val_tag_tensors = (
    seqs2padded_tensor(val_token_idx),
    seqs2padded_tensor(val_tag_idx),
)

print(f"token tensor (train): {train_token_tensors.shape}")
print(f"tag tensor   (train): {train_tag_tensors.shape}")
print(f"token tensor (val): {val_token_tensors.shape}")
print(f"tag tensor   (val): {val_tag_tensors.shape}")

In [None]:
class SequenceDataset(Dataset):
    """Pairs each token sequence with the label sequence at the same position.

    Args:
        tokens: Indexable, sized collection of token sequences.
        labels: Indexable, sized collection of label sequences, parallel to
            ``tokens``.

    Raises:
        ValueError: If ``tokens`` and ``labels`` differ in length. Without this
            check a mismatch would only surface later, as an IndexError in the
            middle of an epoch or as silently unused trailing labels.
    """

    def __init__(self, tokens, labels):
        if len(tokens) != len(labels):
            raise ValueError(
                "tokens and labels must have the same number of sequences, "
                f"got {len(tokens)} and {len(labels)}"
            )
        self.tokens = tokens
        self.labels = labels

    def __len__(self):
        """Number of (tokens, labels) pairs."""
        return len(self.tokens)

    def __getitem__(self, idx):
        """Return the ``(tokens, labels)`` pair stored at position ``idx``."""
        return self.tokens[idx], self.labels[idx]


# Wrap the padded training tensors in a dataset, then serve them in mini-batches
# of two examples, drawn in a freshly shuffled order each epoch.
train_dataset = SequenceDataset(tokens=train_token_tensors, labels=train_tag_tensors)
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=2)


## model


In [None]:
class LSTMTagger(nn.Module):
    """Per-token tagger: embedding -> LSTM -> linear layer over tag classes.

    Args:
        vocab_size: Number of rows in the embedding table.
        tagset_size: Number of output classes per token.
        embedding_dim: Size of each token embedding.
        hidden_dim: Size of the LSTM hidden state.
    """

    def __init__(self, vocab_size, tagset_size, embedding_dim, hidden_dim):
        super().__init__()
        # Index 0 is the padding token: its embedding starts at zero and is
        # excluded from gradient updates.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=0,
        )
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            batch_first=True,
        )
        self.hidden2tag = nn.Linear(in_features=hidden_dim, out_features=tagset_size)

    def forward(self, sentences):
        """Return unnormalised tag scores for every position of every sentence.

        ``sentences`` holds token indices, batch dimension first; the result has
        shape (batch_size, seq_len, tagset_size).
        """
        embedded = self.embedding(sentences)  # (batch_size, seq_len, embedding_dim)
        hidden_states, _final_state = self.lstm(embedded)  # (batch_size, seq_len, hidden_dim)
        return self.hidden2tag(hidden_states)  # (batch_size, seq_len, tagset_size)


In [None]:
# Parameters
embedding_dim = 32
hidden_dim = 64
vocab_size = len(vocab)
tagset_size = len(tag_vocab)
# Label index carried by padded positions in the tag tensors.
pad_tag_idx = tag_vocab.index("<pad>")

# Seed torch so weight initialisation and the DataLoader's shuffling repeat
# from run to run.
torch.manual_seed(0)

model = LSTMTagger(vocab_size, tagset_size, embedding_dim, hidden_dim)
# Exclude padded positions from the loss. Sequences are padded to the longest
# example, so without `ignore_index` the loss would largely measure how well the
# model predicts "<pad>" after the end of a sentence.
# NOTE(review): a batch whose labels are ALL padding would yield a NaN mean loss;
# this assumes every example has at least one real tag -- confirm.
loss_function = nn.CrossEntropyLoss(ignore_index=pad_tag_idx)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

# Training loop
epochs = 40
losses = []  # one loss value per mini-batch, across all epochs
model.train()  # ensure training mode even if an earlier cell called model.eval()
for epoch in range(epochs):
    # The batch variables are deliberately NOT called `examples`: that name is
    # the DataFrame used by earlier cells and must not be overwritten.
    for batch_tokens, batch_labels in train_loader:
        optimizer.zero_grad()

        # Forward pass: (batch_size, seq_len, tagset_size)
        tag_scores = model(batch_tokens)

        # Merge the batch and sequence dimensions: the loss expects
        # (N, classes) scores against (N,) targets.
        loss = loss_function(
            tag_scores.reshape(-1, tagset_size), batch_labels.reshape(-1)
        )

        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

loss_fig = go.Figure(go.Scatter(y=losses))
loss_fig.update_layout(
    title="training loss",
    xaxis_title="training step (mini-batch)",
    yaxis_title="cross-entropy loss",
)
loss_fig

## evaluate


In [None]:
# Score the whole validation set in one forward pass, without tracking gradients.
model.eval()
with torch.no_grad():
    tag_scores = model(val_token_tensors)
    predictions = tag_scores.argmax(dim=-1)  # Shape: (batch_size, seq_len)

# Flatten predictions and gold tags into two parallel lists of tag strings.
# Each predicted row is cut to the length of its unpadded gold sequence so that
# padded positions are left out of the comparison.
pred_tags = []
true_tags = []
for predicted_row, gold_row in zip(predictions.tolist(), val_tag_idx):
    true_tags += [tag_vocab[tag_id] for tag_id in gold_row]
    pred_tags += [tag_vocab[tag_id] for tag_id in predicted_row[: len(gold_row)]]
print(pred_tags)
print(true_tags)

# Rows are true tags, columns are predicted tags.
metrics.ConfusionMatrixDisplay.from_predictions(
    true_tags, pred_tags, xticks_rotation="vertical"
)