# Train XLM-R

In [None]:
from datasets import load_dataset
from datasets import load_metric
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from transformers import AutoModelForQuestionAnswering
from functools import partial
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from datasets import load_metric
import torch
from helper import collate_fn, get_train_features, get_validation_features, post_process_predictions, predict, val_collate_fn
from lab6_train import *

# Optimisation hyper-parameters shared by every per-language fine-tuning run.
lr = 2e-5
n_epochs = 3
weight_decay = 0.01
warmup_steps = 200
# Checkpoint to fine-tune for each language; one model is trained per key.
bert_map = {
    'bengali': 'xlm-roberta-base',
    'english': 'xlm-roberta-base',
    'indonesian': 'xlm-roberta-base',
    'arabic': 'xlm-roberta-base'
}
# Use the GPU when one is visible; otherwise fall back to the CPU so that
# `.to(device)` does not raise on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

compute_squad = load_metric("squad_v2")
dataset = load_dataset("copenlu/answerable_tydiqa")

# Attach a running integer `id` column (0..len-1) to every split.
for split in dataset:
    dataset[split] = dataset[split].add_column('id', list(range(len(dataset[split]))))

for language, bert in bert_map.items():
    print(f'Language: {language}')
    # Keep only the examples of the current language (applied to all splits).
    language_dataset = dataset.filter(lambda example: example['language'] == language)
    tk = AutoTokenizer.from_pretrained(bert, max_len=300)

    # Replace the raw columns by the features produced by `get_train_features`.
    train_split = language_dataset['train']
    tokenized_train_dataset = train_split.map(
        partial(get_train_features, tk),
        batched=True,
        remove_columns=train_split.column_names,
    )

    train_dl = DataLoader(tokenized_train_dataset, collate_fn=collate_fn, shuffle=True, batch_size=8)

    model = AutoModelForQuestionAnswering.from_pretrained(bert).to(device)

    # Parameters whose name contains one of these substrings get no weight
    # decay; every other parameter uses `weight_decay`.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    decayed_params = []
    undecayed_params = []
    for param_name, param in model.named_parameters():
        if any(nd in param_name for nd in no_decay):
            undecayed_params.append(param)
        else:
            decayed_params.append(param)
    optimizer_grouped_parameters = [
        {'params': decayed_params, 'weight_decay': weight_decay},
        {'params': undecayed_params, 'weight_decay': 0.0},
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
    # Linear warm-up for `warmup_steps` steps, then linear decay over the
    # total number of batches seen across all epochs.
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=n_epochs * len(train_dl),
    )

    losses = train(
        model,
        train_dl,
        optimizer,
        scheduler,
        n_epochs,
        device
    )
    # The whole model object is pickled (not just its state dict); the
    # evaluation code reloads it from this exact file name with `torch.load`.
    torch.save(model, f'{language}_xml_roberta_base_span_detection_2.pt')

# Evaluate zero-shot cross-lingual performance

In [None]:

tk = AutoTokenizer.from_pretrained('xlm-roberta-base', max_len=300)
bert_map = {
    'bengali': 'xlm-roberta-base',
    'english': 'xlm-roberta-base',
    'indonesian': 'xlm-roberta-base',
    'arabic': 'xlm-roberta-base'
}
# metrics[tuned_language][evaluation_language] -> result of the squad_v2 metric.
metrics = {l: {} for l in bert_map}
# Per-language view of the full dataset (all splits).
datasets = {
    l: dataset.filter(lambda example: example['language'] == l) for l in bert_map
}
tokenized_validation_datasets = {
    l: language_dataset['validation'].map(
        partial(get_validation_features, tk),
        batched=True,
        remove_columns=language_dataset['validation'].column_names,
    )
    for l, language_dataset in datasets.items()
}

# The data loader and the gold references depend only on the evaluation
# language, not on the model under test, so build them once per language
# instead of once per (model, language) pair.
val_dataloaders = {
    l: DataLoader(tokenized, collate_fn=val_collate_fn, batch_size=32)
    for l, tokenized in tokenized_validation_datasets.items()
}
gold_references = {
    l: [
        {
            'id': example['id'],
            'answers': {
                'text': example['annotations']['answer_text'],
                'answer_start': example['annotations']['answer_start'],
            },
        }
        for example in language_dataset['validation']
    ]
    for l, language_dataset in datasets.items()
}

for language in bert_map:
    print(f'Language: {language}')
    # `map_location` places the checkpoint on the evaluation device even if it
    # was saved from a different one.
    # NOTE: the file holds a fully pickled model object; on PyTorch releases
    # where `torch.load` defaults to `weights_only=True`, this call
    # additionally needs `weights_only=False`.
    model = torch.load(f'{language}_xml_roberta_base_span_detection_2.pt', map_location=device)
    # Score the model tuned on `language` against every language's validation set.
    for language2 in bert_map:
        language_dataset = datasets[language2]
        tokenized_validation_dataset = tokenized_validation_datasets[language2]
        val_dl = val_dataloaders[language2]
        gold = gold_references[language2]

        logits = predict(model, val_dl, device)
        predictions = post_process_predictions(language_dataset['validation'], tokenized_validation_dataset, logits)
        # Every prediction is submitted with a no-answer probability of 0.
        formatted_predictions = [
            {'id': k, 'prediction_text': v, 'no_answer_probability': 0.}
            for k, v in predictions.items()
        ]

        metrics[language][language2] = compute_squad.compute(references=gold, predictions=formatted_predictions)

# Pretty-print the results

In [None]:
# Print the score matrix as LaTeX table rows, first for exact match and then
# for F1. Rows are the language a model was tuned on; columns are the
# language it was evaluated on.
language_names = list(bert_map.keys())
row_end = '\\\\'
header_row = ' & '.join(['tuned lan'] + language_names) + row_end
reports = [
    ("Exact match scores, xlm-r span detection:", "exact"),
    ("F1 scores, xlm-r span detection:", "f1"),
]

for report_index, (title, metric_key) in enumerate(reports):
    if report_index > 0:
        # Blank separator between consecutive tables.
        print("\n\n")
    print(title)
    print(header_row)
    for language in language_names:
        cells = [f'{metrics[language][target][metric_key]:.2f}' for target in language_names]
        print(' & '.join([language] + cells) + row_end)
    # Each "Average" entry is the mean of one tuned language's row (i.e. over
    # all evaluation languages), listed in tuned-language order; it is not a
    # per-column mean.
    averages = []
    for tuned in language_names:
        row_total = sum([metrics[tuned][target][metric_key] for target in language_names])
        averages.append(f'{row_total / len(language_names):.2f}')
    print(' & '.join(['Average'] + averages) + row_end)
