# Practice 4: Named Entity Recognition

## Introduction

### Formulation of the problem

In this assignment, you will solve the Named Entity Recognition (NER) problem, which, along with text classification, is one of the most common NLP tasks.

The task is to classify each word/token according to whether it is part of a named entity (an entity may span several words/tokens) and, if so, of which type.

For example, suppose we want to extract person and organization names. Then for the text

     Ian    Goodfellow  works  for  Google  Brain

the model should output the following tag sequence:

     B-PER  I-PER       O      O    B-ORG   I-ORG

where the prefix *B-* marks the beginning of a named entity, *I-* marks a token inside (continuing) the entity, and *O* marks a token that does not belong to any entity. This prefix scheme (*BIO* tagging) makes it possible to separate consecutive named entities of the same type.
There are other tagging schemes, such as [*BILUO*](https://en.wikipedia.org/wiki/Inside–outside–beginning_(tagging)), but in this assignment we will focus on *BIO*.
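To make the scheme concrete, here is a minimal sketch that decodes a BIO tag sequence back into entity spans. The helper `bio_to_spans` is hypothetical (it is not part of the assignment), and it assumes one common convention: an *I-* tag that does not continue the current entity is treated as the start of a new one.

```python
from typing import List, Tuple


def bio_to_spans(tags: List[str]) -> List[Tuple[str, int, int]]:
    """Decode BIO tags into (entity_type, start, end) spans; `end` is exclusive."""
    spans = []
    current_type, start = None, None
    for i, tag in enumerate(tags):
        # Assumption: a stray I- tag of a different type starts a new entity
        starts_new = tag.startswith("B-") or (
            tag.startswith("I-") and tag[2:] != current_type
        )
        if tag == "O" or starts_new:
            # Close the entity that is currently open, if any
            if current_type is not None:
                spans.append((current_type, start, i))
            current_type, start = (tag[2:], i) if starts_new else (None, None)
        # Otherwise the tag is I- of the current type: the entity continues
    if current_type is not None:
        spans.append((current_type, start, len(tags)))
    return spans


example_tokens = ["Ian", "Goodfellow", "works", "for", "Google", "Brain"]
example_tags = ["B-PER", "I-PER", "O", "O", "B-ORG", "I-ORG"]
for entity_type, start, end in bio_to_spans(example_tags):
    print(entity_type, " ".join(example_tokens[start:end]))
# PER Ian Goodfellow
# ORG Google Brain

print(bio_to_spans(["B-PER", "B-PER"]))  # [('PER', 0, 1), ('PER', 1, 2)]
```

The last call shows why the *B-* prefix matters: two adjacent *B-PER* tags decode to two separate entities, which a plain *PER*/*O* labelling could not express.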

We will solve the NER problem on the CoNLL-2003 dataset using recurrent networks and models based on the Transformer architecture.

### Libraries

Main libraries:
  - [PyTorch](https://pytorch.org/)
  - [Transformers](https://github.com/huggingface/transformers)

### Data

The data is stored in an archive, which consists of:

- *train.tsv* - training set. Each line contains a word/token and its tag, separated by a tab; sentences are separated by empty lines.

- *valid.tsv* - validation set, used for hyperparameter selection and quality measurement. It has the same structure as *train.tsv*.

- *test.tsv* - test set, used to evaluate the final quality. It has the same structure as *train.tsv*.

You can download the data here: [link](https://drive.google.com/drive/folders/1OKNrfHsBm1ehbG-yM0R1BGshbscf_eue?usp=drive_link)

In [None]:
# !pip install numpy==1.21.6 scikit-learn==1.0.2 torch==1.12.1 tqdm==4.64.0 transformers==4.21.1 wandb

In [None]:
import random
from collections import Counter, defaultdict, namedtuple
from typing import Tuple, List, Dict, Any

import torch
import numpy as np
import wandb

from tqdm import tqdm, trange

Let's fix the random seed to make the results reproducible (it is advisable to **always** do this!):

In [None]:
def set_global_seed(seed: int) -> None:
    """
    Set global seed for reproducibility.
    """

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


set_global_seed(42)

Let’s initialize the device (CPU / GPU) on which we will work (preferably **GPU**):

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

Initialize *Weights & Biases* (wandb) to log metrics during training:

In [None]:
# Initialize wandb for logging
wandb.init(
    project="ner-homework",
    name="ner-experiment",
    config={
        "task": "named_entity_recognition",
        "dataset": "conll2003",
        "framework": "pytorch"
    }
)

## Part 1. Data preparation (4 points)

First of all, we need to read the data. Let's write a function that takes as input the path to one of the CoNLL-2003 files and returns two lists:
- a list of lists of words/tokens (one inner list per sentence)
- the corresponding list of lists of tags

P.S. Let's make this function more flexible by adding a boolean argument that controls whether the data is read in *lowercase*.

**Exercise. Implement the `read_conll2003` function.** **<font color='red'>(1 point)</font>**

In [None]:
def read_conll2003(
    path: str,
    lower: bool = True,
) -> Tuple[List[List[str]], List[List[str]]]:
    """
    Prepare data in CoNLL-like format.
    """

    token_seq = []
    label_seq = []

    current_tokens = []
    current_labels = []
    
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            
            # Empty line indicates end of sentence
            if not line:
                if current_tokens:  # Only add if we have tokens
                    token_seq.append(current_tokens)
                    label_seq.append(current_labels)
                    current_tokens = []
                    current_labels = []
            else:
                # Split by tab to get token and label
                parts = line.split('\t')
                if len(parts) >= 2:
                    token = parts[0]
                    label = parts[1]
                    
                    if lower:
                        token = token.lower()
                    
                    current_tokens.append(token)
                    current_labels.append(label)
    
    # Don't forget the last sentence if file doesn't end with empty line
    if current_tokens:
        token_seq.append(current_tokens)
        label_seq.append(current_labels)

    return token_seq, label_seq

Let's read all three files. The code below expects them in the current directory as *train.txt*, *valid.txt* and *test.txt* (rename the downloaded files if necessary).

First, check that the files are in place:

In [None]:
# Check that the data files exist before reading them
import os

data_files = ["train.txt", "valid.txt", "test.txt"]
for file in data_files:
    if os.path.exists(file):
        print(f"✓ {file} found")
    else:
        print(f"✗ {file} not found")

if all(os.path.exists(f) for f in data_files):
    print("All data files are ready!")
else:
    print("Please make sure all .txt files are in the current directory")

In [None]:
train_token_seq, train_label_seq = read_conll2003("train.txt")
valid_token_seq, valid_label_seq = read_conll2003("valid.txt")
test_token_seq, test_label_seq = read_conll2003("test.txt")

Look at what we got:

In [None]:
for token, label in zip(train_token_seq[0], train_label_seq[0]):
    print(f"{token}\t{label}")

In [None]:
for token, label in zip(valid_token_seq[0], valid_label_seq[0]):
    print(f"{token}\t{label}")

In [None]:
for token, label in zip(test_token_seq[0], test_label_seq[0]):
    print(f"{token}\t{label}")

In [None]:
assert len(train_token_seq) == len(train_label_seq), "The lengths of the training token_seq and label_seq do not match, an error in the read_conll2003 function"
assert len(valid_token_seq) == len(valid_label_seq), "The lengths of the validation token_seq and label_seq do not match, an error in the read_conll2003 function"
assert len(test_token_seq) == len(test_label_seq), "The lengths of the test token_seq and label_seq do not match, an error in the read_conll2003 function"

assert train_token_seq[0] == ['eu', 'rejects', 'german', 'call', 'to', 'boycott', 'british', 'lamb', '.'], "Error in training token_seq"
assert train_label_seq[0] == ['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O'], "Error in training label_seq"

assert valid_token_seq[0] == ['cricket', '-', 'leicestershire', 'take', 'over', 'at', 'top', 'after', 'innings', 'victory', '.'], "Error in validation token_seq"
assert valid_label_seq[0] == ['O', 'O', 'B-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], "Error in validation label_seq"

assert test_token_seq[0] == ['soccer', '-', 'japan', 'get', 'lucky', 'win', ',', 'china', 'in', 'surprise', 'defeat', '.'], "Error in test token_seq"
assert test_label_seq[0] == ['O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'B-PER', 'O', 'O', 'O', 'O'], "Error in test label_seq"

print("All tests passed!")

The CoNLL-2003 dataset uses **BIO** tagging, where each label is one of:
- *B-{label}* - beginning of entity *{label}*
- *I-{label}* - continuation of entity *{label}*
- *O* - no entity

There are also other sequence tagging schemes, such as **BILUO**.
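For comparison, here is a minimal sketch of converting BIO tags to **BILUO** (B = beginning, I = inside, L = last token, U = unit-length entity, O = outside). The function `bio_to_biluo` is illustrative only and assumes well-formed BIO input in which every entity starts with a *B-* tag.

```python
from typing import List


def bio_to_biluo(tags: List[str]) -> List[str]:
    """Convert a well-formed BIO tag sequence into BILUO tags."""
    biluo = []
    for i, tag in enumerate(tags):
        if tag == "O":
            biluo.append("O")
            continue
        prefix, entity_type = tag[:2], tag[2:]
        next_tag = tags[i + 1] if i + 1 < len(tags) else "O"
        # The entity continues only if the next tag is I- of the same type
        continues = next_tag == f"I-{entity_type}"
        if prefix == "B-":
            biluo.append(("B-" if continues else "U-") + entity_type)
        else:  # prefix == "I-"
            biluo.append(("I-" if continues else "L-") + entity_type)
    return biluo


print(bio_to_biluo(["B-PER", "I-PER", "O", "B-ORG", "I-ORG", "I-ORG"]))
# ['B-PER', 'L-PER', 'O', 'B-ORG', 'I-ORG', 'L-ORG']
print(bio_to_biluo(["B-ORG", "O", "B-MISC", "O"]))
# ['U-ORG', 'O', 'U-MISC', 'O']
```

BILUO encodes entity boundaries explicitly on both sides, at the cost of roughly twice as many tag types.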

### Preparing dictionaries

To train the neural network, we will use two mappings:
- {**token**}→{**token_idx**}: maps a word/token to its row index in the *embedding* matrix (starting from 0);
- {**label**}→{**label_idx**}: maps a tag to a unique index (starting from 0);

Now we need to implement two functions:
- get_token2idx
- get_label2idx

which will return the corresponding dictionaries.

P.S. The token2idx dictionary must also contain special tokens:
- `<PAD>` is a special token for padding, since we are going to train the models in batches
- `<UNK>` is a special token for processing words/tokens that are not in the dictionary (relevant for inference)

Let's assign them to idx 0 and 1 respectively for convenience.

P.P.S. You can also add a *min_count* parameter to get_token2idx, so that only words occurring at least *min_count* times are included.

First let's collect:
- token2cnt - a dictionary mapping each unique word/token to the number of its occurrences in the training set (it is important to count over the training set only!)
- label_set - a list of unique tags

P.S. You could also use stemming to map different forms of the same word to a single token, but we will skip this step.
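The idea of frequency filtering with `<PAD>`/`<UNK>` can be seen on a toy corpus. The sketch below is a hypothetical illustration (the corpus and the `toy_` names are made up), not the graded `get_token2idx`: tokens seen fewer than `min_count` times are left out of the vocabulary and fall back to `<UNK>` at lookup time.

```python
from collections import Counter

toy_train = [["the", "cat", "sat"], ["the", "dog", "sat"]]
toy_counts = Counter(token for sentence in toy_train for token in sentence)

min_count = 2
toy_token2idx = {"<PAD>": 0, "<UNK>": 1}  # special tokens take indices 0 and 1
for token, count in toy_counts.items():   # Counter keeps first-occurrence order
    if count >= min_count:
        toy_token2idx[token] = len(toy_token2idx)

print(toy_token2idx)  # {'<PAD>': 0, '<UNK>': 1, 'the': 2, 'sat': 3}

# Unknown or rare tokens map to the <UNK> index
unk_idx = toy_token2idx["<UNK>"]
print([toy_token2idx.get(t, unk_idx) for t in ["the", "bird", "sat"]])  # [2, 1, 3]
```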

**Exercise. Implement the `get_token2idx` and `get_label2idx` functions.** **<font color='red'>(1 point)</font>**

In [None]:
token2cnt = Counter([token for sentence in train_token_seq for token in sentence])

In [None]:
token2cnt.most_common(10)

In [None]:
print(f"Number of unique words in the training dataset: {len(token2cnt)}")
print(f"Number of words occurring only once in the training dataset: {len([token for token, cnt in token2cnt.items() if cnt == 1])}")

As we can see, many words appear only once in the training set. A single occurrence is not enough to learn a useful embedding and mostly leads to overfitting, so let's drop such words from the vocabulary (they will be mapped to `<UNK>`).

In [None]:
# use the min_count parameter to cut off words with frequency cnt < min_count

def get_token2idx(
    token2cnt: Dict[str, int],
    min_count: int,
) -> Dict[str, int]:
    """
    Get mapping from tokens to indices to use with Embedding layer.
    """

    token2idx: Dict[str, int] = {}

    # Add special tokens first
    token2idx["<PAD>"] = 0
    token2idx["<UNK>"] = 1
    
    # Add tokens that meet min_count requirement
    idx = 2
    for token, count in token2cnt.items():
        if count >= min_count:
            token2idx[token] = idx
            idx += 1

    return token2idx

In [None]:
token2idx = get_token2idx(token2cnt, min_count=2)

In [None]:
# Function for sorting tags so that first there is an O tag,
# then B- tags and only after I- tags (can be set manually)

def sort_labels_func(x: str) -> int:
    if x == "O":
        return 0
    elif x.startswith("B-"):
        return 1
    else:
        return 2

label_set = sorted(
    set(label for sentence in train_label_seq for label in sentence),
    key=lambda x: (sort_labels_func(x), x),
)

In [None]:
label_set

In [None]:
def get_label2idx(label_set: List[str]) -> Dict[str, int]:
    """
    Get mapping from labels to indices.
    """

    label2idx: Dict[str, int] = {}

    for idx, label in enumerate(label_set):
        label2idx[label] = idx

    return label2idx

In [None]:
label2idx = get_label2idx(label_set)

Let's look at what we got:

In [None]:
for token, idx in list(token2idx.items())[:10]:
    print(f"{token}\t{idx}")

In [None]:
for label, idx in label2idx.items():
    print(f"{label}\t{idx}")

In [None]:
assert len(get_token2idx(token2cnt, min_count=1)) == 21012, "Error in dictionary length, most likely min_count is implemented incorrectly"
assert len(token2idx) == 10952, "Incorrect token2idx length, most likely min_count is implemented incorrectly"
assert len(label2idx) == 9, "Incorrect label2idx length"

assert list(token2idx.items())[:10] == [
    ('<PAD>', 0), ('<UNK>', 1), ('eu', 2), ('german', 3), ('call', 4),
    ('to', 5), ('boycott', 6), ('british', 7), ('lamb', 8), ('.', 9)
], "Wrong format of token2idx"
assert label2idx == {
    'O': 0, 'B-LOC': 1, 'B-MISC': 2, 'B-ORG': 3, 'B-PER': 4,
    'I-LOC': 5, 'I-MISC': 6, 'I-ORG': 7, 'I-PER': 8
}, "Wrong format of label2idx"

print("All tests passed!")

### Preparing the dataset and loader

Typically, neural networks are trained in batches: each update of the network weights is computed from several sequences at once. A technical consequence is that all sequences within a batch must be padded to the same length.

From the previous practical task, you should know about `Dataset` (`torch.utils.data.Dataset`) - a data structure that stores and can index data for training. The dataset must inherit from the standard PyTorch Dataset class and override the `__len__` and `__getitem__` methods.

The `__getitem__` method must return the sequence of token indices and the corresponding tag indices.

**Don't forget** about `<UNK>` special token for unknown words!

Let's write a custom dataset for our task whose `__init__` method receives:
- token_seq - list of lists of words/tokens
- label_seq - list of lists of tags
- token2idx
- label2idx

and whose `__getitem__` method returns two int64 tensors (`torch.LongTensor`): the indices of the words/tokens in the sample and the indices of the corresponding tags.

**Exercise. Implement the NERDataset class.** **<font color='red'>(1 point)</font>**

In [None]:
class NERDataset(torch.utils.data.Dataset):
    """
    PyTorch Dataset for NER.
    """

    def __init__(
        self,
        token_seq: List[List[str]],
        label_seq: List[List[str]],
        token2idx: Dict[str, int],
        label2idx: Dict[str, int],
    ):
        self.token2idx = token2idx
        self.label2idx = label2idx

        self.token_seq = [self.process_tokens(tokens, token2idx) for tokens in token_seq]
        self.label_seq = [self.process_labels(labels, label2idx) for labels in label_seq]

    def __len__(self):
        return len(self.token_seq)

    def __getitem__(
        self,
        idx: int,
    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
        return torch.LongTensor(self.token_seq[idx]), torch.LongTensor(self.label_seq[idx])

    @staticmethod
    def process_tokens(
        tokens: List[str],
        token2idx: Dict[str, int],
        unk: str = "<UNK>",
    ) -> List[int]:
        """
        Transform list of tokens into list of tokens' indices.
        """
        token_indices = []
        for token in tokens:
            if token in token2idx:
                token_indices.append(token2idx[token])
            else:
                token_indices.append(token2idx[unk])
        return token_indices

    @staticmethod
    def process_labels(
        labels: List[str],
        label2idx: Dict[str, int],
    ) -> List[int]:
        """
        Transform list of labels into list of labels' indices.
        """
        return [label2idx[label] for label in labels]

Create three datasets:
- *train_dataset*
- *valid_dataset*
- *test_dataset*

In [None]:
train_dataset = NERDataset(
    token_seq=train_token_seq,
    label_seq=train_label_seq,
    token2idx=token2idx,
    label2idx=label2idx,
)
valid_dataset = NERDataset(
    token_seq=valid_token_seq,
    label_seq=valid_label_seq,
    token2idx=token2idx,
    label2idx=label2idx,
)
test_dataset = NERDataset(
    token_seq=test_token_seq,
    label_seq=test_label_seq,
    token2idx=token2idx,
    label2idx=label2idx,
)

Let's look at what we got:

In [None]:
train_dataset[0]

In [None]:
valid_dataset[0]

In [None]:
test_dataset[0]

In [None]:
assert len(train_dataset) == 14986, "Incorrect train_dataset length"
assert len(valid_dataset) == 3465, "Incorrect valid_dataset length"
assert len(test_dataset) == 3683, "Incorrect test_dataset length"

assert torch.equal(train_dataset[0][0], torch.tensor([2,1,3,4,5,6,7,8,9])), "Malformed train_dataset"
assert torch.equal(train_dataset[0][1], torch.tensor([3,0,2,0,0,0,2,0,0])), "Malformed train_dataset"

assert torch.equal(
    valid_dataset[0][0],
    torch.tensor([1737,571,1777,197,687,145,349,111,1819,1558,9])
), "Malformed valid_dataset"
assert torch.equal(valid_dataset[0][1], torch.tensor([0,0,3,0,0,0,0,0,0,0,0])), "Malformed valid_dataset"

assert torch.equal(
    test_dataset[0][0],
    torch.tensor([1516,571,1434,1729,4893,2014,67,310,215,3157,3139,9])
), "Malformed test_dataset"
assert torch.equal(test_dataset[0][1], torch.tensor([0,0,1,0,0,0,0,4,0,0,0,0])), "Malformed test_dataset"

print("All tests passed!")

To pad the sequences, we will use the `collate_fn` parameter of the `DataLoader` class.

Given a batch of (sentence, tags) tensor pairs, every sequence has to be padded to the length of the longest sequence in the batch.

Use the index of the special `<PAD>` token to pad word/token sequences and -1 to pad tag sequences.

**hint**: the `torch.nn.utils.rnn.pad_sequence` function is convenient here. Pay attention to its `batch_first` parameter.
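A quick illustration of what `pad_sequence` does with and without `batch_first` (the toy tensors are made up for the example):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

token_seqs = [torch.tensor([2, 1, 3]), torch.tensor([10, 11])]
label_seqs = [torch.tensor([3, 0, 2]), torch.tensor([4, 8])]

# batch_first=True -> shape (batch_size, max_len), the layout the collator needs
padded_tokens_bf = pad_sequence(token_seqs, batch_first=True, padding_value=0)
# Default batch_first=False -> shape (max_len, batch_size)
padded_tokens_tf = pad_sequence(token_seqs, padding_value=0)
# Tags are padded with -1 instead of the <PAD> index
padded_labels_bf = pad_sequence(label_seqs, batch_first=True, padding_value=-1)

print(padded_tokens_bf.tolist())  # [[2, 1, 3], [10, 11, 0]]
print(padded_tokens_tf.tolist())  # [[2, 10], [1, 11], [3, 0]]
print(padded_labels_bf.tolist())  # [[3, 0, 2], [4, 8, -1]]
```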

`Collator` can be implemented in two ways:
- class with method `__call__`
- function

We will use the first approach.

The `__init__` method of the `Collator` class takes two parameters:
- the padding value for word/token sequences (the index of the `<PAD>` token)
- the padding value for tag sequences (-1)

The `__call__` method takes a batch as input, namely a list of tuples returned by the `__getitem__` method of our dataset. In our case, this is a list of tuples of two int64 tensors - `List[Tuple[torch.LongTensor, torch.LongTensor]]`.

As output, we want two tensors:
- padded word/token indices
- padded tag indices

P.S. The tag padding value (-1) makes padded positions easy to exclude from the loss: pass it as the `ignore_index` parameter when initializing the loss.
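A small check of how `ignore_index` works (toy values; a local random generator is used so that the notebook's global seed is not disturbed): with `ignore_index=-1`, positions labelled -1 are skipped, and the loss is averaged over the remaining positions only.

```python
import torch
import torch.nn.functional as F

g = torch.Generator().manual_seed(0)  # local generator, global seed untouched
n_classes = 3
# Scores for one sentence of 4 positions in the (batch, n_classes, seq_len) layout
toy_logits = torch.randn(1, n_classes, 4, generator=g)
toy_labels = torch.tensor([[2, 0, -1, -1]])  # the last two positions are padding

loss_with_padding = torch.nn.CrossEntropyLoss(ignore_index=-1)(toy_logits, toy_labels)

# The same loss computed only over the two real positions
loss_real_only = F.cross_entropy(toy_logits[:, :, :2], toy_labels[:, :2])
print(torch.allclose(loss_with_padding, loss_real_only))  # True
```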

**Exercise. Implement the collator class NERCollator.** **<font color='red'>(1 point)</font>**

In [None]:
class NERCollator:
    """
    Collator that handles variable-size sentences.
    """

    def __init__(
        self,
        token_padding_value: int,
        label_padding_value: int,
    ):
        self.token_padding_value = token_padding_value
        self.label_padding_value = label_padding_value

    def __call__(
        self,
        batch: List[Tuple[torch.LongTensor, torch.LongTensor]],
    ) -> Tuple[torch.LongTensor, torch.LongTensor]:

        tokens, labels = zip(*batch)

        # Pad sequences using PyTorch utility
        padded_tokens = torch.nn.utils.rnn.pad_sequence(
            tokens, 
            batch_first=True, 
            padding_value=self.token_padding_value
        )
        
        padded_labels = torch.nn.utils.rnn.pad_sequence(
            labels, 
            batch_first=True, 
            padding_value=self.label_padding_value
        )

        return padded_tokens, padded_labels

In [None]:
collator = NERCollator(
    token_padding_value=token2idx["<PAD>"],
    label_padding_value=-1,
)

Now everything is ready to define the loaders.

In [None]:
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=collator,
)
valid_dataloader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=1,  # for correct metrics measurements leave batch_size=1
    shuffle=False, # for correct metrics measurements leave shuffle=False
    collate_fn=collator,
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1,  # for correct metrics measurements leave batch_size=1
    shuffle=False, # for correct metrics measurements leave shuffle=False
    collate_fn=collator,
)

Let's look at what we got:

In [None]:
tokens, labels = next(iter(train_dataloader))

tokens = tokens.to(device)
labels = labels.to(device)

In [None]:
tokens

In [None]:
labels

In [None]:
train_tokens, train_labels = next(iter(
    torch.utils.data.DataLoader(
        train_dataset,
        batch_size=2,
        shuffle=False,
        collate_fn=collator,
    )
))
assert torch.equal(
    train_tokens,
    torch.tensor([[2, 1, 3, 4, 5, 6, 7, 8, 9], [10, 11, 0, 0, 0, 0, 0, 0, 0]])
), "Looks like a bug in the collator"
assert torch.equal(
    train_labels,
    torch.tensor([[3, 0, 2, 0, 0, 0, 2, 0, 0], [4, 8, -1, -1, -1, -1, -1, -1, -1]])
), "Looks like a bug in the collator"

valid_tokens, valid_labels = next(iter(
    torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=2,
        shuffle=False,
        collate_fn=collator,
    )
))
assert torch.equal(
    valid_tokens,
    torch.tensor([
        [1737, 571, 1777, 197, 687, 145, 349, 111,  1819, 1558, 9],
        [248, 10679, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    ])), "Looks like a bug in the collator"
assert torch.equal(
    valid_labels,
    torch.tensor([
        [0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1]
    ])), "Looks like a bug in the collator"

test_tokens, test_labels = next(iter(
    torch.utils.data.DataLoader(
        test_dataset,
        batch_size=2,
        shuffle=False,
        collate_fn=collator,
    )
))
assert torch.equal(
    test_tokens,
    torch.tensor([
        [1516, 571, 1434, 1729, 4893, 2014, 67, 310, 215, 3157, 3139, 9],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    ])), "Looks like a bug in the collator"
assert torch.equal(
    test_labels,
    torch.tensor([
        [0, 0, 1, 0, 0, 0, 0, 4, 0, 0, 0, 0],
        [4, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
    ])), "Looks like a bug in the collator"

print("All tests passed!")

## Part 2. BiLSTM tagger (6 points)

Define the network architecture using the PyTorch library.

Your architecture at this point should follow the standard tagger:
* Embedding layer at the input
* LSTM (unidirectional or bidirectional) layer for sequence processing
* Dropout (specified separately or built into LSTM) to reduce overfitting
* Linear output layer

To train the network, use a token-level cross-entropy loss function.

**Please note** that padded positions must not contribute to the loss. Adam is the recommended optimizer. To obtain predicted tags from the model outputs, take the `argmax` over the class dimension.
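The shapes involved can be sketched with random scores (toy sizes and `toy_` names, not the real model): a linear head outputs the class dimension last, while `CrossEntropyLoss` expects it in dimension 1, which is why the model below returns `logits.transpose(1, 2)`. Predictions are then an `argmax` over that class dimension, and padded positions can be dropped with a mask.

```python
import torch

g = torch.Generator().manual_seed(0)  # local generator, global seed untouched
batch_size, seq_len, n_classes = 2, 5, 9

# A linear head returns scores with the class dimension last: (batch, seq_len, n_classes)
head_output = torch.randn(batch_size, seq_len, n_classes, generator=g)

# CrossEntropyLoss expects the class dimension second: (batch, n_classes, seq_len)
toy_logits = head_output.transpose(1, 2)
toy_labels = torch.tensor([[3, 0, 2, 0, 0], [4, 8, -1, -1, -1]])  # -1 marks padding
toy_loss = torch.nn.CrossEntropyLoss(ignore_index=-1)(toy_logits, toy_labels)

# Predicted tag per token: argmax over the class dimension (dim=1 in this layout)
toy_predictions = toy_logits.argmax(dim=1)             # shape (batch, seq_len)
real_predictions = toy_predictions[toy_labels != -1]   # padding positions dropped
print(toy_predictions.shape, real_predictions.shape)
# torch.Size([2, 5]) torch.Size([7])
```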

**Exercise. Implement the BiLSTM model class.** **<font color='red'>(2 points)</font>**

In [None]:
class BiLSTM(torch.nn.Module):
    """
    Bidirectional LSTM architecture.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        hidden_size: int,
        num_layers: int,
        dropout: float,
        bidirectional: bool,
        n_classes: int,
    ):
        super().__init__()

        self.embedding = torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=0)
        
        self.rnn = torch.nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=bidirectional,
            batch_first=True
        )
        
        # Calculate input size for the head
        rnn_output_size = hidden_size * 2 if bidirectional else hidden_size
        
        self.head = torch.nn.Linear(rnn_output_size, n_classes)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, tokens: torch.LongTensor) -> torch.Tensor:
        embed = self.embedding(tokens)

        # pack_padded_sequence builds a PackedSequence, so the RNN
        # skips the padded positions of each sequence
        length = (tokens != 0).sum(dim=1).detach().cpu()
        packed_embed = torch.nn.utils.rnn.pack_padded_sequence(
            embed, length, batch_first=True, enforce_sorted=False
        )

        # we use the special function pad_packed_sequence to get a tensor from PackedSequence
        packed_rnn_output, _ = self.rnn(packed_embed)
        rnn_output, _ = torch.nn.utils.rnn.pad_packed_sequence(
            packed_rnn_output, batch_first=True
        )

        # Apply dropout
        rnn_output = self.dropout(rnn_output)
        
        logits = self.head(rnn_output)
        return logits.transpose(1, 2)

In [None]:
model = BiLSTM(
    num_embeddings=len(token2idx),
    embedding_dim=100,
    hidden_size=100,
    num_layers=1,
    dropout=0.0,
    bidirectional=True,
    n_classes=len(label2idx),
).to(device)

In [None]:
model

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)

In [None]:
outputs = model(tokens)

In [None]:
assert outputs.shape == torch.Size([2, 9, 10])
assert 2 < criterion(outputs, labels) < 3

print("All tests passed!")

### Experiments

Run experiments on the data. Tune hyperparameters on the validation set without using the test set. Your goal is to configure the network so that its macro-averaged F1 score (F1-macro) on both the validation and test sets is at least **0.76**.

Draw conclusions about model quality, overfitting, and sensitivity of the architecture to the choice of hyperparameters. Present the results of your experiments in the form of a mini-report (in the same ipython notebook).
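Why the target is stated in terms of F1-macro: most tokens carry the *O* tag, so accuracy and micro-averaged F1 can look high even for a useless model. A toy example (made-up labels, with index 0 playing the role of *O*) using the same `sklearn` function as `compute_metrics`:

```python
from sklearn.metrics import f1_score

# Toy token-level labels dominated by class 0 (the "O" tag)
toy_true = [0, 0, 0, 0, 0, 0, 0, 0, 1, 2]
# A degenerate model that predicts "O" for every token
toy_pred = [0] * 10

toy_f1_micro = f1_score(toy_true, toy_pred, average="micro", zero_division=0)
toy_f1_macro = f1_score(toy_true, toy_pred, average="macro", zero_division=0)
print(round(toy_f1_micro, 3), round(toy_f1_macro, 3))  # 0.8 0.296
```

Predicting *O* everywhere gives a micro F1 of 0.8 but a macro F1 of about 0.30, because each entity class contributes an F1 of 0 to the unweighted average.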

In [None]:
# We'll use wandb for logging instead of TensorBoard
print("Using Weights & Biases for experiment tracking and visualization")

**Exercise. Implement a metric calculation function `compute_metrics`.** **<font color='red'>(1 point)</font>**

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score


def compute_metrics(
    outputs: torch.Tensor,
    labels: torch.LongTensor,
) -> Dict[str, float]:
    """
    Compute NER metrics.
    """

    metrics = {}

    # Convert outputs to predictions
    predictions = torch.argmax(outputs, dim=1)
    
    # Flatten and convert to numpy
    predictions = predictions.view(-1).cpu().numpy()
    labels = labels.view(-1).cpu().numpy()
    
    # Filter out padding tokens (labels = -1)
    mask = labels != -1
    y_true = labels[mask]
    y_pred = predictions[mask]

    # accuracy
    accuracy = accuracy_score(
        y_true=y_true,
        y_pred=y_pred,
    )

    # precision
    precision_micro = precision_score(
        y_true=y_true,
        y_pred=y_pred,
        average="micro",
        zero_division=0,
    )
    precision_macro = precision_score(
        y_true=y_true,
        y_pred=y_pred,
        average="macro",
        zero_division=0,
    )
    precision_weighted = precision_score(
        y_true=y_true,
        y_pred=y_pred,
        average="weighted",
        zero_division=0,
    )

    # recall
    recall_micro = recall_score(
        y_true=y_true,
        y_pred=y_pred,
        average="micro",
        zero_division=0,
    )
    recall_macro = recall_score(
        y_true=y_true,
        y_pred=y_pred,
        average="macro",
        zero_division=0,
    )
    recall_weighted = recall_score(
        y_true=y_true,
        y_pred=y_pred,
        average="weighted",
        zero_division=0,
    )

    # f1
    f1_micro = f1_score(
        y_true=y_true,
        y_pred=y_pred,
        average="micro",
        zero_division=0,
    )
    f1_macro = f1_score(
        y_true=y_true,
        y_pred=y_pred,
        average="macro",
        zero_division=0,
    )
    f1_weighted = f1_score(
        y_true=y_true,
        y_pred=y_pred,
        average="weighted",
        zero_division=0,
    )

    metrics["accuracy"] = accuracy

    metrics["precision_micro"]    = precision_micro
    metrics["precision_macro"]    = precision_macro
    metrics["precision_weighted"] = precision_weighted

    metrics["recall_micro"]    = recall_micro
    metrics["recall_macro"]    = recall_macro
    metrics["recall_weighted"] = recall_weighted

    metrics["f1_micro"]    = f1_micro
    metrics["f1_macro"]    = f1_macro
    metrics["f1_weighted"] = f1_weighted

    return metrics

**Exercise. Implement the training and testing functions `train_epoch` and `evaluate_epoch`. <font color='red'>(2 points)</font>**

In [None]:
def train_epoch(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: torch.nn.Module,
    device: torch.device,
    epoch: int,
) -> None:
    """
    One training cycle (loop).
    """

    model.train()

    epoch_loss = []
    batch_metrics_list = defaultdict(list)

    for i, (tokens, labels) in tqdm(
        enumerate(dataloader),
        total=len(dataloader),
        desc="loop over train batches",
    ):

        tokens, labels = tokens.to(device), labels.to(device)

        # Loss calculation and optimizer step
        optimizer.zero_grad()
        
        outputs = model(tokens)
        loss = criterion(outputs, labels)
        
        loss.backward()
        optimizer.step()

        epoch_loss.append(loss.item())
        
        # Log to wandb every 100 batches
        if i % 100 == 0:
            wandb.log({
                "train/batch_loss": loss.item(),
                "train/epoch": epoch,
                "train/batch": epoch * len(dataloader) + i
            })

        with torch.no_grad():
            model.eval()
            outputs_inference = model(tokens)
            model.train()

        batch_metrics = compute_metrics(
            outputs=outputs_inference,
            labels=labels,
        )

        for metric_name, metric_value in batch_metrics.items():
            batch_metrics_list[metric_name].append(metric_value)

    avg_loss = np.mean(epoch_loss)
    print(f"Train loss: {avg_loss}\n")
    
    # Log epoch metrics to wandb
    wandb_log = {"train/loss": avg_loss, "epoch": epoch}
    
    for metric_name, metric_value_list in batch_metrics_list.items():
        metric_value = np.mean(metric_value_list)
        print(f"Train {metric_name}: {metric_value}\n")
        wandb_log[f"train/{metric_name}"] = metric_value
    
    wandb.log(wandb_log)

In [None]:
def evaluate_epoch(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    criterion: torch.nn.Module,
    device: torch.device,
    epoch: int,
    split_name: str = "val",
) -> None:
    """
    One evaluation cycle (loop).
    """

    model.eval()

    epoch_loss = []
    batch_metrics_list = defaultdict(list)

    with torch.no_grad():

        for i, (tokens, labels) in tqdm(
            enumerate(dataloader),
            total=len(dataloader),
            desc=f"loop over {split_name} batches",
        ):

            tokens, labels = tokens.to(device), labels.to(device)

            # Loss calculation
            outputs = model(tokens)
            loss = criterion(outputs, labels)

            epoch_loss.append(loss.item())

            batch_metrics = compute_metrics(
                outputs=outputs,
                labels=labels,
            )

            for metric_name, metric_value in batch_metrics.items():
                batch_metrics_list[metric_name].append(metric_value)

        avg_loss = np.mean(epoch_loss)
        print(f"{split_name.capitalize()} loss:  {avg_loss}\n")
        
        # Log epoch metrics to wandb
        wandb_log = {f"{split_name}/loss": avg_loss, "epoch": epoch}

        for metric_name, metric_value_list in batch_metrics_list.items():
            metric_value = np.mean(metric_value_list)
            print(f"{split_name.capitalize()} {metric_name}: {metric_value}\n")
            wandb_log[f"{split_name}/{metric_name}"] = metric_value
        
        wandb.log(wandb_log)

In [None]:
def train(
    n_epochs: int,
    model: torch.nn.Module,
    train_dataloader: torch.utils.data.DataLoader,
    test_dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: torch.nn.Module,
    device: torch.device,
) -> None:
    """
    Training loop.
    """

    for epoch in range(n_epochs):

        print(f"Epoch [{epoch+1} / {n_epochs}]\n")

        train_epoch(
            model=model,
            dataloader=train_dataloader,
            optimizer=optimizer,
            criterion=criterion,
            device=device,
            epoch=epoch,
        )
        evaluate_epoch(
            model=model,
            dataloader=test_dataloader,
            criterion=criterion,
            device=device,
            epoch=epoch,
            split_name="val",
        )

**Exercise. Conduct experiments. <font color='red'>(2 points)</font>**

In [None]:
# Experiment with BiLSTM
wandb.config.update({
    "model_type": "BiLSTM",
    "embedding_dim": 100,
    "hidden_size": 128,
    "num_layers": 2,
    "dropout": 0.1,
    "bidirectional": True,
    "learning_rate": 1e-3,
    "batch_size": 32,
    "epochs": 5
})

model = BiLSTM(
    num_embeddings=len(token2idx),
    embedding_dim=100,
    hidden_size=128,
    num_layers=2,
    dropout=0.1,
    bidirectional=True,
    n_classes=len(label2idx),
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)

print("Starting BiLSTM experiment...")
print(f"Model parameters: embedding_dim=100, hidden_size=128, num_layers=2, dropout=0.1, lr=1e-3")

train(
    n_epochs=5,
    model=model,
    train_dataloader=train_dataloader,
    test_dataloader=valid_dataloader,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
)

# Test final model on test set
print("Final evaluation on test set:")
evaluate_epoch(
    model=model,
    dataloader=test_dataloader,
    criterion=criterion,
    device=device,
    epoch=0,
    split_name="test",
)

## Part 3. Transformers tagger (6 points)

In this part of the task, you will solve the same problem with a Transformer-based model: specifically, you will fine-tune a pre-trained **BERT**-family model.

This model requires special data preparation, which is where we will start:

The **BERT** model uses a WordPiece tokenizer to split sentences into sub-word tokens. The `transformers` library provides pre-trained versions of this tokenizer in two classes: `BertTokenizer` and `BertTokenizerFast`. You can use either one, but the second is much faster because it is backed by the Hugging Face `tokenizers` library, which is written in Rust.
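To make the WordPiece idea concrete, here is a minimal sketch of its greedy longest-match-first splitting of a single word. It assumes a tiny hand-made vocabulary; `toy_vocab` and `wordpiece_tokenize` are hypothetical names, not part of `transformers`. Real BERT tokenizers also normalize the text and pre-split it on whitespace and punctuation, which this sketch omits.

```python
from typing import List, Set


def wordpiece_tokenize(word: str, vocab: Set[str], unk_token: str = "[UNK]") -> List[str]:
    """Greedy longest-match-first WordPiece split of one word (simplified sketch)."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, then shrink it from the right
        while start < end:
            candidate = word[start:end]
            if start > 0:
                # Pieces that continue a word carry the "##" prefix
                candidate = "##" + candidate
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            # No piece of the remaining word is in the vocabulary
            return [unk_token]
        pieces.append(match)
        start = end
    return pieces


toy_vocab = {"leicester", "##shire", "boy", "##cott", "lamb"}
print(wordpiece_tokenize("leicestershire", toy_vocab))  # ['leicester', '##shire']
print(wordpiece_tokenize("boycott", toy_vocab))         # ['boy', '##cott']
print(wordpiece_tokenize("xyz", toy_vocab))             # ['[UNK]']
```

This is why one word from the dataset can map to several tokens, which the collator later has to account for when aligning tags.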

Tokenizers can be trained from scratch using your own data corpus, or you can load pre-trained ones. Pre-trained tokenizers typically match a pre-trained model configuration that uses the vocabulary from that tokenizer.

We will use the pre-trained `distilbert-base-cased` checkpoint (a smaller, distilled version of **BERT**) for both the model and the tokenizer.

P.S. You will often experiment with models of different architectures, for example **BERT** and **GPT**, so it is convenient to use the `AutoTokenizer` class, which infers the right tokenizer class from the model name.

In [None]:
from transformers import AutoTokenizer

In [None]:
model_name = "distilbert-base-cased"

Pretrained models and tokenizers are loaded from the Hugging Face Hub using the `from_pretrained` constructor.

You can pass it either a local path to a saved tokenizer or, as in our case, the name of a pre-trained checkpoint; `transformers` downloads and loads the required files automatically:

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

### Preparing dictionaries

Unlike with the recurrent models, we no longer need to build a token vocabulary: the pre-trained tokenizer already comes with one.

But, as before, we will need:
- {**label**}→{**label_idx**}: mapping from each tag to a unique index (starting from 0).

We already implemented this mapping in an earlier part of the assignment.

### Preparing the dataset and loader

We also want to train the model in batches, so we will still need `Dataset`, `Collator` and `DataLoader`.

However, we cannot reuse the ones from the previous parts, because the data now has to be processed with the tokenizer.

Let's write a new custom dataset whose `__init__` method receives:
- token_seq - list of lists of words/tokens
- label_seq - list of lists of tags

and whose `__getitem__` method returns two lists:
- the words/tokens of the requested sample (`List[str]`)
- the indices of the corresponding tags (`List[int]`)

P.S. Unlike the previous custom dataset, here we return two plain `List`s instead of `torch.LongTensor`s, because building the padded batch moves into the `Collator`. This is due to how the tokenizer works: it already returns a padded tensor of token indices, while the tag indices still have to be padded by us, as in the previous dataset.

**Exercise. Implement the TransformersDataset class. <font color='red'>(1 point)</font>**

In [None]:
class TransformersDataset(torch.utils.data.Dataset):
    """
    Transformers Dataset for NER.
    """

    def __init__(
        self,
        token_seq: List[List[str]],
        label_seq: List[List[str]],
    ):
        self.token_seq = token_seq
        # label2idx is the global tag-to-index mapping built in an earlier part
        self.label_seq = [self.process_labels(labels, label2idx) for labels in label_seq]

    def __len__(self):
        return len(self.token_seq)

    def __getitem__(
        self,
        idx: int,
    ) -> Tuple[List[str], List[int]]:
        return self.token_seq[idx], self.label_seq[idx]

    @staticmethod
    def process_labels(
        labels: List[str],
        label2idx: Dict[str, int],
    ) -> List[int]:
        """
        Transform list of labels into list of labels' indices.
        """
        return [label2idx[label] for label in labels]

Create three datasets:
- *train_dataset*
- *valid_dataset*
- *test_dataset*

In [None]:
train_dataset = TransformersDataset(
    token_seq=train_token_seq,
    label_seq=train_label_seq,
)
valid_dataset = TransformersDataset(
    token_seq=valid_token_seq,
    label_seq=valid_label_seq,
)
test_dataset = TransformersDataset(
    token_seq=test_token_seq,
    label_seq=test_label_seq,
)

Let's look at what we got:

In [None]:
train_dataset[0]

In [None]:
valid_dataset[0]

In [None]:
test_dataset[0]

In [None]:
assert len(train_dataset) == 14986, "Incorrect train_dataset length"
assert len(valid_dataset) == 3465, "Incorrect valid_dataset length"
assert len(test_dataset) == 3683, "Incorrect test_dataset length"

assert train_dataset[0][0] == ['eu', 'rejects', 'german', 'call', 'to', 'boycott', 'british', 'lamb', '.'], "Malformed train_dataset"
assert train_dataset[0][1] == [3,0,2,0,0,0,2,0,0], "Malformed train_dataset"

assert valid_dataset[0][0] == ['cricket', '-', 'leicestershire', 'take', 'over', 'at', 'top', 'after', 'innings', 'victory', '.'], "Malformed valid_dataset"
assert valid_dataset[0][1] == [0,0,3,0,0,0,0,0,0,0,0], "Malformed valid_dataset"

assert test_dataset[0][0] == ['soccer', '-', 'japan', 'get', 'lucky', 'win', ',', 'china', 'in', 'surprise', 'defeat', '.'], "Malformed test_dataset"
assert test_dataset[0][1] == [0,0,1,0,0,0,0,4,0,0,0,0], "Malformed test_dataset"

print("All tests passed!")

Let's implement a new `Collator`.

The collator is initialized with 3 arguments:
- the tokenizer
- the tokenizer parameters as a dictionary (passed on as `**kwargs`)
- the padding value for tag sequences (-1)

The `__call__` method takes a batch, i.e. a list of the tuples returned by the dataset's `__getitem__` method. In our case this is `List[Tuple[List[str], List[int]]]`: each tuple holds the words of one sentence and their tag indices.

At the output we want:
- the tokenizer output (`BatchEncoding`) with padded token indices (`input_ids`) and the `attention_mask`
- a tensor of padded tag indices aligned with those sub-tokens
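The main difficulty is that the tokenizer splits words into sub-tokens, so the word-level tag indices must be re-aligned. Below is a minimal sketch of one common alternative to the offset-mapping approach used in the implementation below. It assumes a *fast* tokenizer, whose `BatchEncoding.word_ids(batch_index)` method maps every sub-token to the index of its source word (`None` for special tokens and padding). The helper name `align_labels_with_word_ids` is hypothetical, and the `word_ids` list is written by hand rather than produced by a real tokenizer. Like the collator, it labels only the first sub-token of each word.

```python
from typing import List, Optional


def align_labels_with_word_ids(
    word_ids: List[Optional[int]],
    word_labels: List[int],
    label_padding_value: int = -1,
) -> List[int]:
    """Give each word's tag to its first sub-token; everything else gets the padding value."""
    aligned = []
    previous_word_id = None
    for word_id in word_ids:
        if word_id is None:
            # Special tokens ([CLS], [SEP]) and padding positions
            aligned.append(label_padding_value)
        elif word_id != previous_word_id:
            # First sub-token of a new word carries the word's tag
            aligned.append(word_labels[word_id])
        else:
            # Continuation sub-token of the same word is ignored by the loss
            aligned.append(label_padding_value)
        previous_word_id = word_id
    return aligned


# Hypothetical layout: [CLS], word 0 split into 2 pieces,
# word 1 split into 5 pieces, [SEP], then 3 padding positions
word_ids = [None, 0, 0, 1, 1, 1, 1, 1, None, None, None, None]
print(align_labels_with_word_ids(word_ids, word_labels=[1, 0]))
# [-1, 1, -1, 0, -1, -1, -1, -1, -1, -1, -1, -1]
```

With a real batch you would call it once per sentence, e.g. `align_labels_with_word_ids(encoded.word_ids(i), labels[i])`. One practical difference: words cut off by truncation simply disappear from `word_ids`, whereas the offset-based assignment expects exactly one first sub-token per word.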

**Exercise. Implement the TransformersCollator class. <font color='red'>(2 points)</font>**

In [None]:
from transformers import PreTrainedTokenizer
from transformers.tokenization_utils_base import BatchEncoding


class TransformersCollator:
    """
    Transformers Collator that handles variable-size sentences.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        tokenizer_kwargs: Dict[str, Any],
        label_padding_value: int,
    ):
        self.tokenizer = tokenizer
        self.tokenizer_kwargs = tokenizer_kwargs

        self.label_padding_value = label_padding_value

    def __call__(
        self,
        batch: List[Tuple[List[str], List[int]]],
    ) -> Tuple[BatchEncoding, torch.LongTensor]:
        tokens, labels = zip(*batch)

        # Tokenize the pre-split words: returns padded input_ids,
        # attention_mask and (with return_offsets_mapping=True) offset_mapping
        encoded_tokens = self.tokenizer(
            list(tokens),
            **self.tokenizer_kwargs,
        )

        # Align word-level tag indices with the sub-tokens produced by the tokenizer
        encoded_labels = self.encode_labels(
            encoded_tokens,
            labels,
            self.label_padding_value,
        )

        # offset_mapping is only needed for label alignment, not as a model input
        encoded_tokens.pop("offset_mapping")

        return encoded_tokens, encoded_labels

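    @staticmethod
    def encode_labels(
        tokens: BatchEncoding,
        labels: List[List[int]],
        label_padding_value: int,
    ) -> torch.LongTensor:
        """
        Assign each word's tag to its first sub-token; special tokens,
        continuation sub-tokens and padding get label_padding_value.
        """
        encoded_labels = []

        for doc_labels, doc_offset in zip(labels, tokens["offset_mapping"]):

            doc_enc_labels = np.full(len(doc_offset), label_padding_value, dtype=int)
            arr_offset = np.array(doc_offset)

            # With is_split_into_words=True, offsets are relative to each word:
            # the first sub-token of a word starts at 0, while special tokens
            # and padding have the empty span (0, 0)
            first_subtoken = (arr_offset[:, 0] == 0) & (arr_offset[:, 1] != 0)
            doc_enc_labels[first_subtoken] = doc_labels
            encoded_labels.append(doc_enc_labels.tolist())

        return torch.LongTensor(encoded_labels)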

In [None]:
tokenizer_kwargs = {
    "is_split_into_words":    True,
    "return_offsets_mapping": True,
    "padding":                True,
    "truncation":             True,
    "max_length":             512,
    "return_tensors":         "pt",
}

In [None]:
collator = TransformersCollator(
    tokenizer=tokenizer,
    tokenizer_kwargs=tokenizer_kwargs,
    label_padding_value=-1,
)

Now you're ready to define the loaders:

In [None]:
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=collator,
)
valid_dataloader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=1,  # for correct metrics measurements leave batch_size=1
    shuffle=False, # for correct metrics measurements leave shuffle=False
    collate_fn=collator,
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1,  # for correct metrics measurements leave batch_size=1
    shuffle=False, # for correct metrics measurements leave shuffle=False
    collate_fn=collator,
)

Let's look at what we got:

In [None]:
tokens, labels = next(iter(train_dataloader))

tokens = tokens.to(device)
labels = labels.to(device)

In [None]:
tokens

In [None]:
labels

In [None]:
train_tokens, train_labels = next(iter(
    torch.utils.data.DataLoader(
        train_dataset,
        batch_size=2,
        shuffle=False,
        collate_fn=collator,
    )
))
assert torch.equal(
    train_tokens['input_ids'],
    torch.tensor([[101, 174, 1358, 22961, 176, 14170, 1840, 1106, 21423, 9304, 10721, 1324, 2495, 12913, 119, 102],
                  [101, 11109, 1200, 1602, 6715, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
                )), "Looks like a bug in the collator"
assert torch.equal(
    train_tokens['attention_mask'],
    torch.tensor([
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    ])), "Looks like a bug in the collator"
assert torch.equal(
    train_labels,
    torch.tensor([
        [-1, 3, -1, 0, 2, -1, 0, 0, 0, 2, -1, -1, 0, -1, 0, -1],
        [-1, 4, -1, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
    ])), "Looks like a bug in the collator"

valid_tokens, valid_labels = next(iter(
    torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=2,
        shuffle=False,
        collate_fn=collator,
    )
))
assert torch.equal(
    valid_tokens['input_ids'],
    torch.tensor([
        [101, 5428, 118, 5837, 18117, 5759, 15189, 1321, 1166, 1120, 1499, 1170, 6687, 2681, 119, 102],
        [101, 25338, 17996, 1820, 118, 4775, 118, 1476, 102, 0, 0, 0, 0, 0, 0, 0]
    ])), "Looks like a bug in the collator"
assert torch.equal(
    valid_tokens['attention_mask'],
    torch.tensor([
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
    ])), "Looks like a bug in the collator"
assert torch.equal(
    valid_labels,
    torch.tensor([
        [-1,  0,  0,  3, -1, -1, -1,  0,  0,  0,  0,  0,  0,  0,  0, -1],
        [-1,  1, -1,  0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
    ])), "Looks like a bug in the collator"

test_tokens, test_labels = next(iter(
    torch.utils.data.DataLoader(
        test_dataset,
        batch_size=2,
        shuffle=False,
        collate_fn=collator,
    )
))
assert torch.equal(
    test_tokens['input_ids'],
    torch.tensor([
        [101, 5862, 118, 179, 26519, 1179, 1243, 6918, 1782, 117, 5144, 1161, 1107, 3774, 3326, 119, 102],
        [101, 9468, 3309, 1306, 19122, 2293, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    ])), "Looks like a bug in the collator"
assert torch.equal(
    test_tokens['attention_mask'],
    torch.tensor([
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    ])), "Looks like a bug in the collator"
assert torch.equal(
    test_labels,
    torch.tensor([
        [-1,  0,  0,  1, -1, -1,  0,  0,  0,  0,  4, -1,  0,  0,  0,  0, -1],
        [-1,  4, -1, -1,  8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
    ])), "Looks like a bug in the collator"

print("All tests passed!")

The **transformers** library provides BERT-family model classes with task-specific classification heads already attached. For the NER task we need a token classification head, i.e. the `BertForTokenClassification` class (for our DistilBERT checkpoint, its counterpart `DistilBertForTokenClassification`).

By analogy with tokenizers, we can use the `AutoModelForTokenClassification` class, which infers the right model class from the model name.

In [None]:
from transformers import AutoModelForTokenClassification

In [None]:
model = AutoModelForTokenClassification.from_pretrained(
    model_name,
    num_labels=len(label2idx),
).to(device)

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

In [None]:
outputs = model(**tokens)

In [None]:
# criterion is the CrossEntropyLoss(ignore_index=-1) created for the BiLSTM experiment
assert 2 < criterion(outputs["logits"].transpose(1, 2), labels) < 3

print("All tests passed!")

In [None]:
# Create a TensorBoard SummaryWriter for the Transformer experiments
# (the experiment cells below log to wandb instead)

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="logs/Transformer")

### Experiments

Run experiments on the data. Tune hyperparameters on the validation set only, without touching the test set. Your goal is to configure the network so that its F1-macro on both the validation and test sets is at least **0.9**.

Draw conclusions about model quality, overfitting, and sensitivity of the architecture to the choice of hyperparameters. Present the results of your experiments in the form of a mini-report (in the same ipython notebook).
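The target metric is F1-macro: the F1 score is computed separately for each tag and then averaged without weighting, so rarer tags count as much as the frequent `O` tag. A minimal pure-Python sketch, assuming predictions are tag indices and positions labelled -1 (special tokens, continuation sub-tokens, padding) are skipped; `f1_macro` is a hypothetical helper, and the notebook's own `compute_metrics` may differ in details such as which classes it includes or how it averages across batches.

```python
from typing import List


def f1_macro(y_true: List[int], y_pred: List[int], ignore_index: int = -1) -> float:
    """Unweighted mean of per-class F1 over tag indices, skipping ignored positions."""
    pairs = [(t, p) for t, p in zip(y_true, y_pred) if t != ignore_index]
    # Classes seen either in the gold tags or in the predictions
    classes = sorted({t for t, _ in pairs} | {p for _, p in pairs})
    if not classes:
        return 0.0

    f1_scores = []
    for c in classes:
        tp = sum(1 for t, p in pairs if t == c and p == c)
        fp = sum(1 for t, p in pairs if t != c and p == c)
        fn = sum(1 for t, p in pairs if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        f1_scores.append(f1)
    return sum(f1_scores) / len(f1_scores)


# Per-class F1: class 0 -> 2/3, class 1 -> 0.8, class 2 -> 1.0
print(round(f1_macro([-1, 0, 0, 1, 1, 2, -1], [3, 0, 1, 1, 1, 2, 3]), 4))  # 0.8222
```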

You can reuse the training functions from before with two changes: call the model as `model(**tokens)` instead of `model(tokens)`, and use `outputs["logits"].transpose(1, 2)` wherever `outputs` was used.
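Why the transpose: Hugging Face token-classification models return logits shaped `(batch, seq_len, num_labels)`, while `torch.nn.CrossEntropyLoss` expects the class dimension second, `(batch, num_classes, seq_len)`, for sequence inputs. The sketch below uses random logits and hand-made labels (both assumptions for illustration) and also checks that `ignore_index=-1` simply averages the loss over the labelled positions.

```python
import torch

torch.manual_seed(0)
batch_size, seq_len, n_classes = 2, 5, 9

# Token-classification logits: (batch, seq_len, n_classes)
logits = torch.randn(batch_size, seq_len, n_classes)
# -1 marks special tokens, continuation sub-tokens and padding
labels = torch.tensor([
    [-1, 3, -1, 0, -1],
    [-1, 4, 8, -1, -1],
])

criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
# Move the class dimension to position 1: (batch, n_classes, seq_len)
loss = criterion(logits.transpose(1, 2), labels)

# Same value by hand: keep only labelled positions and average their losses
mask = labels != -1
manual = torch.nn.functional.cross_entropy(logits[mask], labels[mask])
print(loss.item(), manual.item())
```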

**Exercise. Conduct experiments.** **<font color='red'>(2 points)</font>**

In [None]:
# Start a new wandb run for Transformer experiment
wandb.finish()  # End previous run
wandb.init(
    project="ner-homework",
    name="transformer-experiment",
    config={
        "model_type": "DistilBERT",
        "learning_rate": 5e-5,
        "batch_size": 8,
        "epochs": 3,
        "max_length": 512
    }
)

# Experiment with Transformer model
from transformers import AutoModelForTokenClassification, AutoTokenizer

model_name = "distilbert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModelForTokenClassification.from_pretrained(
    model_name,
    num_labels=len(label2idx),
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)

# Create datasets for transformers
train_dataset = TransformersDataset(
    token_seq=train_token_seq,
    label_seq=train_label_seq,
)
valid_dataset = TransformersDataset(
    token_seq=valid_token_seq,
    label_seq=valid_label_seq,
)
test_dataset = TransformersDataset(
    token_seq=test_token_seq,
    label_seq=test_label_seq,
)

# Create collator
tokenizer_kwargs = {
    "is_split_into_words":    True,
    "return_offsets_mapping": True,
    "padding":                True,
    "truncation":             True,
    "max_length":             512,
    "return_tensors":         "pt",
}

collator = TransformersCollator(
    tokenizer=tokenizer,
    tokenizer_kwargs=tokenizer_kwargs,
    label_padding_value=-1,
)

# Create data loaders
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collator,
)
valid_dataloader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collator,
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collator,
)

# Modified training functions for transformer
def train_epoch_transformer(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: torch.nn.Module,
    device: torch.device,
    epoch: int,
) -> None:
    """
    One training cycle (loop) for transformer model.
    """

    model.train()

    epoch_loss = []
    batch_metrics_list = defaultdict(list)

    for i, (tokens, labels) in tqdm(
        enumerate(dataloader),
        total=len(dataloader),
        desc="loop over train batches",
    ):

        tokens = {k: v.to(device) for k, v in tokens.items()}
        labels = labels.to(device)

        # Loss calculation and optimizer step
        optimizer.zero_grad()
        
        outputs = model(**tokens)
        loss = criterion(outputs["logits"].transpose(1, 2), labels)
        
        loss.backward()
        optimizer.step()

        epoch_loss.append(loss.item())
        
        # Log to wandb every 50 batches
        if i % 50 == 0:
            wandb.log({
                "train/batch_loss": loss.item(),
                "train/epoch": epoch,
                "train/batch": epoch * len(dataloader) + i
            })

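        # Recompute predictions in eval mode (dropout off) for the training metrics;
        # this costs one extra forward pass per batch
        with torch.no_grad():
            model.eval()
            outputs_inference = model(**tokens)
        model.train()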

        batch_metrics = compute_metrics(
            outputs=outputs_inference["logits"].transpose(1, 2),
            labels=labels,
        )

        for metric_name, metric_value in batch_metrics.items():
            batch_metrics_list[metric_name].append(metric_value)

    avg_loss = np.mean(epoch_loss)
    print(f"Train loss: {avg_loss}\n")
    
    # Log epoch metrics to wandb
    wandb_log = {"train/loss": avg_loss, "epoch": epoch}

    for metric_name, metric_value_list in batch_metrics_list.items():
        metric_value = np.mean(metric_value_list)
        print(f"Train {metric_name}: {metric_value}\n")
        wandb_log[f"train/{metric_name}"] = metric_value
    
    wandb.log(wandb_log)

def evaluate_epoch_transformer(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    criterion: torch.nn.Module,
    device: torch.device,
    epoch: int,
    split_name: str = "val",
) -> None:
    """
    One evaluation cycle (loop) for transformer model.
    """

    model.eval()

    epoch_loss = []
    batch_metrics_list = defaultdict(list)

    with torch.no_grad():

        for i, (tokens, labels) in tqdm(
            enumerate(dataloader),
            total=len(dataloader),
            desc=f"loop over {split_name} batches",
        ):

            tokens = {k: v.to(device) for k, v in tokens.items()}
            labels = labels.to(device)

            # Loss calculation
            outputs = model(**tokens)
            loss = criterion(outputs["logits"].transpose(1, 2), labels)

            epoch_loss.append(loss.item())

            batch_metrics = compute_metrics(
                outputs=outputs["logits"].transpose(1, 2),
                labels=labels,
            )

            for metric_name, metric_value in batch_metrics.items():
                batch_metrics_list[metric_name].append(metric_value)

        avg_loss = np.mean(epoch_loss)
        print(f"{split_name.capitalize()} loss:  {avg_loss}\n")
        
        # Log epoch metrics to wandb
        wandb_log = {f"{split_name}/loss": avg_loss, "epoch": epoch}

        for metric_name, metric_value_list in batch_metrics_list.items():
            metric_value = np.mean(metric_value_list)
            print(f"{split_name.capitalize()} {metric_name}: {metric_value}\n")
            wandb_log[f"{split_name}/{metric_name}"] = metric_value
        
        wandb.log(wandb_log)

def train_transformer(
    n_epochs: int,
    model: torch.nn.Module,
    train_dataloader: torch.utils.data.DataLoader,
    test_dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: torch.nn.Module,
    device: torch.device,
) -> None:
    """
    Training loop for transformer.
    """

    for epoch in range(n_epochs):

        print(f"Epoch [{epoch+1} / {n_epochs}]\n")

        train_epoch_transformer(
            model=model,
            dataloader=train_dataloader,
            optimizer=optimizer,
            criterion=criterion,
            device=device,
            epoch=epoch,
        )
        evaluate_epoch_transformer(
            model=model,
            dataloader=test_dataloader,
            criterion=criterion,
            device=device,
            epoch=epoch,
            split_name="val",
        )

print("Starting Transformer experiment...")
print(f"Model: {model_name}, lr=5e-5, batch_size=8")

train_transformer(
    n_epochs=3,
    model=model,
    train_dataloader=train_dataloader,
    test_dataloader=valid_dataloader,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
)

# Test final model on test set
print("Final evaluation on test set:")
evaluate_epoch_transformer(
    model=model,
    dataloader=test_dataloader,
    criterion=criterion,
    device=device,
    epoch=0,
    split_name="test",
)

## Part 4 - Bonus. BiLSTMAttention-tagger (2 points)

Repeat the experiments from Part 2, this time with an improved BiLSTM tagger that adds an attention mechanism.

**Please note** that you do not need to implement Attention yourself; you can use `torch.nn.MultiheadAttention`.

Also draw conclusions about model quality, overfitting and the sensitivity of the architecture to the choice of hyperparameters, and briefly compare it with the previous architecture. Present the results of your experiments in the form of a mini-report (in the same ipython notebook).
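Since `BiLSTMAttn` relies on `torch.nn.MultiheadAttention` with a `key_padding_mask`, here is a small sketch of how that mask behaves. The random inputs and sizes are arbitrary choices for illustration: positions marked `True` in the mask receive zero attention weight, so padding never contributes to the attended representations.

```python
import torch

torch.manual_seed(0)
embed_dim, num_heads = 16, 4  # embed_dim must be divisible by num_heads
attn = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

# Token ids with 0 as the padding index, as in the BiLSTM taggers
tokens = torch.tensor([
    [5, 7, 2, 0, 0],
    [3, 9, 4, 6, 1],
])
x = torch.randn(2, 5, embed_dim)

# True marks key positions that attention must ignore (padding)
key_padding_mask = tokens == 0

out, weights = attn(x, x, x, key_padding_mask=key_padding_mask)
# out: (batch, seq_len, embed_dim); weights: (batch, seq_len, seq_len), averaged over heads
print(out.shape, weights.shape)
print(weights[0, 0])  # attention of the first query; the last two entries are zero
```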

**Exercise. Implement the model class BiLSTMAttn.** **<font color='red'>(1 point)</font>**

In [None]:
class BiLSTMAttn(torch.nn.Module):
    """
    Bidirectional LSTM with Attention architecture.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        hidden_size: int,
        num_layers: int,
        dropout: float,
        bidirectional: bool,
        n_classes: int,
        num_attention_heads: int = 8,
    ):
        super().__init__()

        self.embedding = torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=0)
        
        self.rnn = torch.nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=bidirectional,
            batch_first=True
        )
        
        # Calculate input size for attention
        rnn_output_size = hidden_size * 2 if bidirectional else hidden_size
        
        # Multi-head attention layer
        self.attention = torch.nn.MultiheadAttention(
            embed_dim=rnn_output_size,
            num_heads=num_attention_heads,
            dropout=dropout,
            batch_first=True
        )
        
        # Layer normalization
        self.layer_norm = torch.nn.LayerNorm(rnn_output_size)
        
        self.head = torch.nn.Linear(rnn_output_size, n_classes)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, tokens: torch.LongTensor) -> torch.Tensor:
        embed = self.embedding(tokens)

        # pack_padded_sequence builds a PackedSequence so the RNN skips padding positions
        length = (tokens != 0).sum(dim=1).detach().cpu()
        packed_embed = torch.nn.utils.rnn.pack_padded_sequence(
            embed, length, batch_first=True, enforce_sorted=False
        )
        packed_rnn_output, _ = self.rnn(packed_embed)

        # pad_packed_sequence turns the PackedSequence back into a padded tensor;
        # total_length keeps the time dimension equal to tokens.size(1),
        # so it always matches the attention mask built below
        rnn_output, _ = torch.nn.utils.rnn.pad_packed_sequence(
            packed_rnn_output, batch_first=True, total_length=tokens.size(1)
        )

        # Attention mask to ignore padding tokens (True marks padding)
        attention_mask = tokens == 0
        
        # Apply self-attention
        attn_output, _ = self.attention(
            query=rnn_output,
            key=rnn_output,
            value=rnn_output,
            key_padding_mask=attention_mask
        )
        
        # Residual connection and layer normalization
        output = self.layer_norm(rnn_output + attn_output)
        
        # Apply dropout
        output = self.dropout(output)
        
        logits = self.head(output)
        return logits.transpose(1, 2)

**Exercise. Conduct experiments and beat the metric value from part 2.** **<font color='red'>(1 point)</font>**

P.S. If the quality does not improve, explain why.

In [None]:
# Start a new wandb run for BiLSTM + Attention experiment
wandb.finish()  # End previous run
wandb.init(
    project="ner-homework",
    name="bilstm-attention-experiment",
    config={
        "model_type": "BiLSTM+Attention",
        "embedding_dim": 100,
        "hidden_size": 128,
        "num_layers": 2,
        "dropout": 0.1,
        "bidirectional": True,
        "attention_heads": 8,
        "learning_rate": 1e-3,
        "batch_size": 32,
        "epochs": 5
    }
)

# Experiment with BiLSTM + Attention
model_attn = BiLSTMAttn(
    num_embeddings=len(token2idx),
    embedding_dim=100,
    hidden_size=128,
    num_layers=2,
    dropout=0.1,
    bidirectional=True,
    n_classes=len(label2idx),
    num_attention_heads=8,
).to(device)

optimizer_attn = torch.optim.Adam(model_attn.parameters(), lr=1e-3)
criterion_attn = torch.nn.CrossEntropyLoss(ignore_index=-1)

# Recreate original dataloaders for BiLSTM experiments
train_dataset_lstm = NERDataset(
    token_seq=train_token_seq,
    label_seq=train_label_seq,
    token2idx=token2idx,
    label2idx=label2idx,
)
valid_dataset_lstm = NERDataset(
    token_seq=valid_token_seq,
    label_seq=valid_label_seq,
    token2idx=token2idx,
    label2idx=label2idx,
)
test_dataset_lstm = NERDataset(
    token_seq=test_token_seq,
    label_seq=test_label_seq,
    token2idx=token2idx,
    label2idx=label2idx,
)

collator_lstm = NERCollator(
    token_padding_value=token2idx["<PAD>"],
    label_padding_value=-1,
)

train_dataloader_lstm = torch.utils.data.DataLoader(
    train_dataset_lstm,
    batch_size=32,
    shuffle=True,
    collate_fn=collator_lstm,
)
valid_dataloader_lstm = torch.utils.data.DataLoader(
    valid_dataset_lstm,
    batch_size=1,
    shuffle=False,
    collate_fn=collator_lstm,
)
test_dataloader_lstm = torch.utils.data.DataLoader(
    test_dataset_lstm,
    batch_size=1,
    shuffle=False,
    collate_fn=collator_lstm,
)

print("Starting BiLSTM + Attention experiment...")
print(f"Model parameters: embedding_dim=100, hidden_size=128, num_layers=2, dropout=0.1, lr=1e-3, attention_heads=8")

train(
    n_epochs=5,
    model=model_attn,
    train_dataloader=train_dataloader_lstm,
    test_dataloader=valid_dataloader_lstm,
    optimizer=optimizer_attn,
    criterion=criterion_attn,
    device=device,
)

# Test final model on test set
print("Final evaluation on test set (BiLSTM + Attention):")
evaluate_epoch(
    model=model_attn,
    dataloader=test_dataloader_lstm,
    criterion=criterion_attn,
    device=device,
    epoch=0,
    split_name="test",
)

# Final cleanup
wandb.finish()

print("\n=== EXPERIMENT SUMMARY ===")
print("1. BiLSTM: Basic bidirectional LSTM with dropout")
print("2. Transformer (DistilBERT): Pre-trained transformer model fine-tuned for NER")
print("3. BiLSTM + Attention: BiLSTM enhanced with multi-head self-attention")
print("\nHypotheses to check against the logged metrics:")
print("- Transformer is expected to achieve the highest F1-macro (target: >= 0.9)")
print("- BiLSTM + Attention is expected to outperform the basic BiLSTM")
print("- Basic BiLSTM is expected to reach an F1-macro of around 0.76")
print("\nAll experiments have been logged to Weights & Biases.")
print("Check your wandb dashboard for detailed metrics and visualizations!")