## Imports

In [None]:
import os

import wandb

# Never hard-code API keys in a notebook: they leak through version control and
# shared outputs. NOTE(review): the key that used to be written here should be
# treated as compromised and revoked.
# The key is read from the WANDB_API_KEY environment variable. If it is unset,
# key is None and wandb.login() uses its own credential lookup (per the wandb
# docs: cached login / interactive prompt).
token = os.environ.get("WANDB_API_KEY")
wandb.login(key=token)

In [None]:
# seqeval supplies the f1/precision/recall scorers imported below and used in compute_metrics.
!pip install seqeval

In [None]:
from datasets import get_dataset_config_names
from datasets import load_dataset
from collections import defaultdict, Counter
from datasets import DatasetDict
import pandas as pd
import numpy as np
import torch.nn as nn
import torch
from torch.utils.data import DataLoader

from transformers import XLMRobertaConfig, AutoConfig, AutoTokenizer, DataCollatorForTokenClassification, Trainer, TrainingArguments, EarlyStoppingCallback
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.roberta.modeling_roberta import RobertaModel
from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
from torch.nn.functional import cross_entropy
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
from transformers import TrainerCallback

import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
from datasets import concatenate_datasets


## Data Preparation and Model Building

In [None]:
def get_data(langs=("de", "fr", "it", "en"), seed=0):
    """Load the PAN-X splits of the XTREME benchmark and shuffle every split.

    Args:
        langs: Iterable of PAN-X language codes to load. Defaults to the four
            languages used in this notebook: "de", "fr", "it", "en".
        seed: Seed passed to ``shuffle`` for every split, so the shuffled order
            is reproducible. Defaults to 0.

    Returns:
        defaultdict mapping each language code to a ``DatasetDict`` that holds
        every split returned by ``load_dataset``, each shuffled with ``seed``.
    """
    panx_ch = defaultdict(DatasetDict)
    for lang in langs:
        ds = load_dataset("xtreme", name=f"PAN-X.{lang}")
        for split in ds:
            panx_ch[lang][split] = ds[split].shuffle(seed=seed)
    return panx_ch

panx_ch = get_data()

In [None]:
# Tag feature of the German training split's "ner_tags" column. The cells below
# use its int2str(), names and num_classes.
tags = panx_ch["de"]["train"].features["ner_tags"].feature

In [None]:
def create_tag_names(batch):
    """Add a human-readable ``ner_tags_str`` column for one example.

    Every integer id in ``batch["ner_tags"]`` is decoded with the
    module-level ``tags`` feature's ``int2str`` method.
    """
    tag_ids = batch["ner_tags"]
    return {"ner_tags_str": list(map(tags.int2str, tag_ids))}
panx_de = panx_ch["de"].map(create_tag_names)

In [None]:
class XLMRobertaForTokenClassification(RobertaPreTrainedModel):
    """
    Per-token classifier (e.g. for NER) built on an XLM-RoBERTa encoder.

    The encoder produces one hidden vector per input token. Each vector goes
    through dropout and then a linear layer that yields one logit per label.

    Attributes:
        config_class: Configuration class used when loading pretrained weights.
        num_labels (int): Size of the label set (``config.num_labels``).
        roberta (RobertaModel): The encoder, built without its pooling layer.
        dropout (nn.Dropout): Dropout with rate ``config.hidden_dropout_prob``.
        classifier (nn.Linear): Maps ``hidden_size`` -> ``num_labels``.

    Args:
        config (XLMRobertaConfig): Model configuration.

    forward(input_ids=None, attention_mask=None, token_type_ids=None, labels=None, **kwargs)
        input_ids (torch.Tensor): Token ids, ``(batch_size, sequence_length)``.
        attention_mask (torch.Tensor, optional): Attention mask, same shape.
        token_type_ids (torch.Tensor, optional): Segment ids, same shape.
        labels (torch.Tensor, optional): Gold label ids, same shape. When given,
            a cross-entropy loss over all positions is returned. PyTorch's
            ``CrossEntropyLoss`` ignores targets equal to -100 by default.
        **kwargs: Forwarded unchanged to the encoder.

        Returns a ``TokenClassifierOutput`` with ``loss`` (or None), ``logits``
        of shape ``(batch_size, sequence_length, num_labels)``, and the
        encoder's ``hidden_states`` / ``attentions``.
    """

    config_class = XLMRobertaConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        # Token classification needs per-token states, not a pooled sentence
        # vector, so the pooler is skipped.
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None, **kwargs):
        encoder_out = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            **kwargs,
        )
        token_states = self.dropout(encoder_out[0])
        logits = self.classifier(token_states)

        # Flatten (batch, seq) so every token is one classification sample.
        loss = (
            nn.CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
            if labels is not None
            else None
        )

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_out.hidden_states,
            attentions=encoder_out.attentions,
        )

In [None]:
# Two-way lookup between integer tag ids and tag strings.
index2tag = dict(enumerate(tags.names))
tag2index = {name: idx for idx, name in index2tag.items()}

In [None]:
# Base XLM-R config extended with this task's label set and id<->tag maps.
xlmr_config = AutoConfig.from_pretrained(
    "xlm-roberta-base",
    num_labels=tags.num_classes,
    id2label=index2tag,
    label2id=tag2index,
)

In [None]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Tokenizer for the same "xlm-roberta-base" checkpoint that xlmr_config and model_init load.
xlmr_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")

In [None]:
def tag_text(text, tags, model, tokenizer):
    """
    Tag a piece of text with a token-classification model's predictions.

    The text is tokenized once with *tokenizer*. The ids are fed to *model*, and
    the arg-max class of every token is mapped to its name through
    ``tags.names``.

    Args:
        text (str): Input text.
        tags: Object exposing ``names``, a sequence indexed by class id.
        model: Callable taking a ``(1, seq_len)`` id tensor. Element ``[0]`` of
            its output must be logits of shape ``(1, seq_len, num_classes)``.
        tokenizer: Tokenizer used for BOTH the displayed tokens and the model
            input ids, so the two always line up.

    Returns:
        pd.DataFrame: Two rows, "Tokens" and "Tags", one column per token.
    """
    encoding = tokenizer(text, return_tensors="pt")
    tokens = encoding.tokens()

    # Put the inputs on the model's own device when it has parameters.
    # Otherwise fall back to the notebook-wide `device`.
    try:
        target_device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        target_device = device
    input_ids = encoding.input_ids.to(target_device)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        logits = model(input_ids)[0]
    predictions = torch.argmax(logits, dim=2)
    preds = [tags.names[p] for p in predictions[0].cpu().tolist()]
    return pd.DataFrame([tokens, preds], index=["Tokens", "Tags"])

In [None]:
def tokenize_and_align_labels(examples):
    """
    Tokenize pre-split sentences and project word-level NER tags onto subwords.

    Only the first subword of each word keeps that word's tag. Special tokens
    (``word_id is None``) and later subwords of the same word get -100, so the
    loss skips them.

    Args:
        examples (dict): Batched columns, with
            - "tokens": list of word lists, one per sentence;
            - "ner_tags": list of integer tag lists, parallel to "tokens".

    Returns:
        The tokenizer's batch encoding (e.g. "input_ids", "attention_mask")
        with an extra "labels" entry aligned token-by-token.
    """
    encoded = xlmr_tokenizer(examples["tokens"], truncation=True,
                             is_split_into_words=True)
    aligned = []
    for row, word_labels in enumerate(examples["ner_tags"]):
        row_label_ids = []
        last_word = None
        for word in encoded.word_ids(batch_index=row):
            starts_new_word = word is not None and word != last_word
            row_label_ids.append(word_labels[word] if starts_new_word else -100)
            last_word = word
        aligned.append(row_label_ids)
    encoded["labels"] = aligned
    return encoded

In [None]:
def encode_panx_dataset(corpus):
    """
    Turn a raw PAN-X ``DatasetDict`` into model-ready features.

    ``tokenize_and_align_labels`` is mapped over every split in batched mode.
    The raw 'tokens', 'ner_tags' and 'langs' columns are dropped afterwards.

    Args:
        corpus (DatasetDict): Splits containing the columns "tokens",
            "ner_tags" and "langs".

    Returns:
        DatasetDict: Same splits, now holding the tokenizer outputs
        (e.g. "input_ids", "attention_mask") and aligned "labels".
    """
    raw_columns = ['tokens', 'ner_tags', 'langs']
    encoded = corpus.map(
        tokenize_and_align_labels,
        batched=True,
        remove_columns=raw_columns,
    )
    return encoded

## Cross-lingual transfer: multilingual training on PAN-X (en, fr, de, it) rather than zero-shot transfer from German

In [None]:
# Encode every language corpus, in the order en, fr, de, it.
panx_en_encoded, panx_fr_encoded, panx_de_encoded, panx_it_encoded = (
    encode_panx_dataset(panx_ch[lang]) for lang in ("en", "fr", "de", "it")
)

In [None]:
def align_predictions(predictions, label_ids):
    """
    Convert logits and gold ids into parallel tag-name sequences.

    Positions whose gold id is -100 (special tokens / non-first subwords) are
    dropped from both outputs. The gold and predicted sequences of each example
    therefore have equal length.

    Args:
        predictions (numpy.ndarray): Logits, ``(batch_size, seq_len, num_labels)``.
        label_ids (numpy.ndarray): Gold ids, ``(batch_size, seq_len)``, with -100
            marking ignored positions.

    Returns:
        tuple: ``(preds_list, labels_list)``. Each is a list (one per example)
        of tag-name lists.
    """
    pred_ids = np.argmax(predictions, axis=2)
    seq_len = pred_ids.shape[1]
    preds_list, labels_list = [], []
    for row, pred_row in enumerate(pred_ids):
        gold_row = label_ids[row]
        kept = [col for col in range(seq_len) if label_ids[row, col] != -100]
        labels_list.append([index2tag[gold_row[col]] for col in kept])
        preds_list.append([index2tag[pred_row[col]] for col in kept])
    return preds_list, labels_list

In [None]:
def compute_metrics(eval_pred):
    """Score an evaluation prediction with seqeval's f1, precision and recall.

    ``eval_pred`` must expose ``predictions`` (logits) and ``label_ids``. They
    are turned into tag sequences by ``align_predictions`` before scoring.
    """
    y_pred, y_true = align_predictions(eval_pred.predictions, eval_pred.label_ids)
    scores = {}
    scores["f1"] = f1_score(y_true, y_pred)
    scores["precision"] = precision_score(y_true, y_pred)
    scores["recall"] = recall_score(y_true, y_pred)
    return scores

In [None]:
# Pads input_ids/attention_mask and labels within each batch. It is used as
# collate_fn for the DataLoaders and called directly in forward_pass_with_label.
data_collator = DataCollatorForTokenClassification(xlmr_tokenizer)

In [None]:
def model_init():
    """Create a fresh classifier from the xlm-roberta-base weights.

    It uses the module-level ``xlmr_config`` (label count and id/tag maps) and
    returns the model moved to ``device``.
    """
    net = XLMRobertaForTokenClassification.from_pretrained(
        "xlm-roberta-base", config=xlmr_config
    )
    return net.to(device)

In [None]:
def get_f1_score(trainer, dataset):
    """Run ``trainer.predict`` on *dataset* and return its ``"test_f1"`` metric."""
    prediction_output = trainer.predict(dataset)
    return prediction_output.metrics["test_f1"]

In [None]:
# Make every encoded corpus yield torch tensors when indexed.
panx_en_encoded, panx_fr_encoded, panx_de_encoded, panx_it_encoded = (
    encoded.with_format("torch")
    for encoded in (panx_en_encoded, panx_fr_encoded, panx_de_encoded, panx_it_encoded)
)

In [None]:
def _make_loaders(encoded, collate_fn, batch_size=16):
    """Build ``{"train", "validation", "test"} -> DataLoader`` for one encoded corpus.

    Args:
        encoded: DatasetDict with "train", "validation" and "test" splits.
        collate_fn: Batch collator that pads variable-length examples.
        batch_size: Examples per batch (default 16).

    Only the training loader shuffles. Training batches are then visited in a
    new order every epoch rather than the same fixed order. Validation and
    test keep a deterministic order, so their losses are comparable across
    epochs.
    """
    return {
        split: DataLoader(
            encoded[split],
            batch_size=batch_size,
            shuffle=(split == "train"),
            collate_fn=collate_fn,
        )
        for split in ("train", "validation", "test")
    }

en_loaders = _make_loaders(panx_en_encoded, data_collator)
fr_loaders = _make_loaders(panx_fr_encoded, data_collator)
de_loaders = _make_loaders(panx_de_encoded, data_collator)
it_loaders = _make_loaders(panx_it_encoded, data_collator)

In [None]:
def _loss_on_batch(model, batch, device):
    """Move one collated batch to *device* and return the model's loss tensor."""
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    labels = batch["labels"].to(device)
    return model(input_ids, attention_mask=attention_mask, labels=labels).loss


def train_one_epoch(model, data_loaders, optimizer, device):
    """Run one epoch of "worst-language" training.

    1. Scoring pass: compute the mean training loss of every language, with
       autograd disabled, because only the loss values are needed.
    2. Pick the language with the highest mean loss.
    3. Update pass: iterate over that language's training loader and take one
       optimizer step per batch.

    Args:
        model: Called as ``model(input_ids, attention_mask=..., labels=...)``;
            the result must expose ``.loss``.
        data_loaders: Mapping ``lang -> {"train": DataLoader, ...}``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Target device for the batch tensors.

    Returns:
        tuple: ``(epoch_loss, worst_loss_lang)``, where ``epoch_loss`` maps each
        language to its mean scoring-pass loss (0.0 for an empty loader).

    Raises:
        ValueError: If ``data_loaders`` is empty (from ``max``).
    """
    model.train()
    epoch_loss = {}

    # Scoring pass: no backward follows, so don't build autograd graphs.
    with torch.no_grad():
        for lang, loaders in data_loaders.items():
            train_loader = loaders["train"]
            total = 0.0
            for batch in tqdm(train_loader, desc=f"Training {lang}", leave=False):
                total += _loss_on_batch(model, batch, device).item()
            n_batches = len(train_loader)
            epoch_loss[lang] = total / n_batches if n_batches else 0.0

    # Identify the language with the worst (highest) mean loss.
    worst_loss_lang = max(epoch_loss, key=epoch_loss.get)
    print(f"Worst Loss Language: {worst_loss_lang} with loss {epoch_loss[worst_loss_lang]}")

    # Update pass: step after every batch. Accumulating the whole epoch's
    # gradients and stepping once would give a single update whose gradient
    # is the SUM over all batches.
    for batch in tqdm(data_loaders[worst_loss_lang]["train"], desc=f"Backward on {worst_loss_lang}", leave=False):
        optimizer.zero_grad()
        loss = _loss_on_batch(model, batch, device)
        loss.backward()
        optimizer.step()

    return epoch_loss, worst_loss_lang


In [None]:
def validate(model, data_loaders, device):
    """Return the mean validation loss of every language.

    The model is put in eval mode (and left there). Gradients are disabled
    while each language's "validation" loader is consumed.

    Args:
        model: Called as ``model(input_ids, attention_mask=..., labels=...)``;
            the result must expose ``.loss``.
        data_loaders: Mapping ``lang -> {"validation": DataLoader, ...}``.
        device: Target device for the batch tensors.

    Returns:
        dict: ``lang -> mean per-batch loss``, in ``data_loaders`` order.
    """
    model.eval()
    val_loss = {}
    with torch.no_grad():
        for lang, loaders in data_loaders.items():
            loader = loaders["validation"]
            running = 0
            for batch in tqdm(loader, desc=f"Validating {lang}", leave=False):
                ids, mask, gold = (
                    batch[key].to(device)
                    for key in ("input_ids", "attention_mask", "labels")
                )
                running += model(ids, attention_mask=mask, labels=gold).loss.item()
            val_loss[lang] = running / len(loader)
    return val_loss


In [None]:
def train_loop(model, data_loaders, optimizer, device, epochs=5):
    """Alternate ``train_one_epoch`` and ``validate`` for *epochs* epochs.

    Per-language training and validation losses are printed every epoch.
    The same *model* object (updated in place) is returned.
    """
    for epoch in range(1, epochs + 1):
        print(f"Epoch {epoch}/{epochs}")

        train_loss, worst_loss_lang = train_one_epoch(model, data_loaders, optimizer, device)
        print(f"Train Loss per Language: {train_loss}")
        print(f"Worst loss at {worst_loss_lang}.")

        val_loss = validate(model, data_loaders, device)
        print(f"Validation Loss per Language: {val_loss}")

    return model


In [None]:
# Per-language loaders keyed by language code (order: en, fr, de, it).
data_loaders = dict(en=en_loaders, fr=fr_loaders, de=de_loaders, it=it_loaders)

model = model_init()
optimizer = Adam(model.parameters(), lr=1e-5)

In [None]:
# Run the worst-language training loop for 100 epochs. train_loop returns the
# same model object it was given, updated in place.
trained_model = train_loop(model, data_loaders, optimizer, device, epochs=100)

In [None]:
def concatenate_splits(corpora, seed=42):
    """Merge several DatasetDicts split by split and shuffle each merged split.

    Args:
        corpora: Non-empty sequence of DatasetDict objects. Split names are
            taken from the first one and must exist in all of them.
        seed: Seed for the post-concatenation shuffle (default 42).

    Returns:
        DatasetDict with one concatenated, shuffled dataset per split.

    Raises:
        ValueError: If ``corpora`` is empty.
    """
    if not corpora:
        raise ValueError("concatenate_splits() needs at least one corpus")
    multi_corpus = DatasetDict()
    for split in corpora[0].keys():
        multi_corpus[split] = concatenate_datasets(
            [corpus[split] for corpus in corpora]
        ).shuffle(seed=seed)
    return multi_corpus

In [None]:
# All four encoded languages merged per split (and shuffled by concatenate_splits).
corpora_encoded = concatenate_splits(
    [panx_en_encoded, panx_fr_encoded, panx_de_encoded, panx_it_encoded]
)

In [None]:
def forward_pass_with_label(batch):
    """Predict per-token class ids for a batched ``Dataset.map`` call.

    Args:
        batch: Dict of column name -> list of per-example values, as passed
            by ``map(..., batched=True)``.

    Returns:
        dict: ``{"predicted_label": ndarray}`` holding the arg-max class id
        of every token. Rows are padded to the longest example in the batch
        by ``data_collator``.
    """
    # Convert dict of lists to list of dicts suitable for the data collator.
    features = [dict(zip(batch, values)) for values in zip(*batch.values())]
    # Pad the batch. Only the model inputs go to the device: labels are not
    # needed for prediction.
    padded = data_collator(features)
    input_ids = padded["input_ids"].to(device)
    attention_mask = padded["attention_mask"].to(device)
    with torch.no_grad():
        output = trained_model(input_ids, attention_mask)
        # logits: [batch_size, sequence_length, num_classes]; take the
        # highest-scoring class for every token.
        predicted_label = torch.argmax(output.logits, dim=-1).cpu().numpy()
    return {"predicted_label": predicted_label}

In [None]:
# Attach model predictions to the merged test split, then move it into pandas.
test_set_corpa = corpora_encoded["test"].map(
    forward_pass_with_label, batched=True, batch_size=32
)
df = test_set_corpa.to_pandas()

In [None]:
# Build a token-level table of gold vs. predicted tags for the test set.
# Map the ignore index (-100) to a placeholder tag so the "labels" column can
# be decoded. NOTE(review): this mutates the global index2tag.
index2tag[-100] = "IGN"
# Decode ids into readable sub-word tokens and tag names.
df["input_tokens"] = df["input_ids"].apply(
 lambda x: xlmr_tokenizer.convert_ids_to_tokens(x))
df["predicted_label"] = df["predicted_label"].apply(
 lambda x: [index2tag[i] for i in x])
df["labels"] = df["labels"].apply(
 lambda x: [index2tag[i] for i in x])
# Predictions were padded by the collator to the longest sequence in each
# mapped batch. Trim them to each example's real length so all list columns
# line up position by position.
df['predicted_label'] = df.apply(
 lambda x: x['predicted_label'][:len(x['input_ids'])], axis=1)
# NOTE(review): only a cell's last expression is displayed, so this head(1) is not shown.
df.head(1)
# One row per token: explode every list column (equal lengths per example).
df_tokens = df.apply(pd.Series.explode)
# Drop positions whose gold label was the ignore index.
df_tokens = df_tokens.query("labels != 'IGN'")
df_tokens.head(7)

In [None]:
def plot_confusion_matrix(y_preds, y_true, labels):
    """Plot a confusion matrix of predicted vs. true tags, normalized per true tag.

    Args:
        y_preds: Predicted tag for every token.
        y_true: Gold tag for every token (same length as ``y_preds``).
        labels: Tag names in the order rows/columns should appear.
    """
    labels = list(labels)
    # Pass `labels` explicitly. Without it sklearn orders classes by sorted
    # value, which would not match the display labels drawn on the axes.
    cm = confusion_matrix(y_true, y_preds, labels=labels, normalize="true")
    _, ax = plt.subplots(figsize=(6, 6))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
    disp.plot(cmap="Blues", values_format=".2f", ax=ax, colorbar=False)
    plt.title("Normalized confusion matrix")
    plt.show()

# Argument order follows the signature: predictions first, gold labels second.
plot_confusion_matrix(df_tokens["predicted_label"], df_tokens["labels"], tags.names)