In [13]:
from datasets import load_dataset
from datasets import load_metric
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from transformers import AutoModelForQuestionAnswering
from functools import partial
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from datasets import load_metric
import torch
from helper import collate_fn, get_train_features, get_validation_features, post_process_predictions, predict, val_collate_fn
from lab6_train import *

lr = 2e-5
n_epochs = 3
weight_decay = 0.01
warmup_steps = 200
bert_map = {
    'bengali': 'xlm-roberta-base', 
    'english': 'xlm-roberta-base', 
    'indonesian': 'xlm-roberta-base', 
    'arabic': 'xlm-roberta-base'
}
device = 'cuda'

compute_squad = load_metric("squad_v2")
dataset = load_dataset("copenlu/answerable_tydiqa")

for split in dataset.keys():
    dataset[split] = dataset[split].add_column('id', list(range(len(dataset[split]))))

for language, bert in list(bert_map.items()):
    print(f'Language: {language}')
    language_dataset = dataset.filter(lambda example: example['language'] == language)
    tk = AutoTokenizer.from_pretrained(bert, max_len=300)

    tokenized_train_dataset = language_dataset['train'].map(partial(get_train_features, tk), batched=True, remove_columns=language_dataset['train'].column_names)
    tokenized_validation_dataset = language_dataset['validation'].map(partial(get_validation_features, tk), batched=True, remove_columns=language_dataset['validation'].column_names)

    train_dl = DataLoader(tokenized_train_dataset, collate_fn=collate_fn, shuffle=True, batch_size=4)

    model = AutoModelForQuestionAnswering.from_pretrained(bert).to(device)
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        'weight_decay': weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(
            nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        warmup_steps,
        n_epochs * len(train_dl)
    )

    losses = train(
        model,
        train_dl,
        optimizer,
        scheduler,
        n_epochs,
        device
    )
    torch.save(model, f'{language}_xlm-roberta-base_span_detection.pt')
    model = torch.load(f'{language}_xlm-roberta-base_span_detection.pt')
    for language, bert in list(bert_map.items()):
        language_dataset = dataset.filter(lambda example: example['language'] == language)
        tk = AutoTokenizer.from_pretrained(bert, max_len=300)

        tokenized_train_dataset = language_dataset['train'].map(partial(get_train_features, tk), batched=True, remove_columns=language_dataset['train'].column_names)
        tokenized_validation_dataset = language_dataset['validation'].map(partial(get_validation_features, tk), batched=True, remove_columns=language_dataset['validation'].column_names)

        train_dl = DataLoader(tokenized_train_dataset, collate_fn=collate_fn, shuffle=True, batch_size=4)
        val_dl = DataLoader(tokenized_validation_dataset, collate_fn=val_collate_fn, batch_size=32)
        
        logits = predict(model, val_dl, device)
        predictions = post_process_predictions(language_dataset['validation'], tokenized_validation_dataset, logits)
        formatted_predictions = [{'id': k, 'prediction_text': v, 'no_answer_probability': 0.} for k, v in predictions.items()]
        gold = [{
            'id': example['id'],
            'answers': {
                'text': example['annotations']['answer_text'],
                'answer_start': example['annotations']['answer_start']}
            }
            for example in language_dataset['validation']]


        print(compute_squad.compute(references=gold, predictions=formatted_predictions))


Language: bengali


Downloading (…)lve/main/config.json:   0%|          | 0.00/615 [00:00<?, ?B/s]

Downloading (…)tencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/9.10M [00:00<?, ?B/s]

Map:   0%|          | 0/4779 [00:00<?, ? examples/s]

Map:   0%|          | 0/224 [00:00<?, ? examples/s]

Downloading model.safetensors:   0%|          | 0.00/1.12G [00:00<?, ?B/s]

Some weights of XLMRobertaForQuestionAnswering were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 1486/1486 [03:21<00:00,  7.38it/s]
100%|██████████| 1486/1486 [04:15<00:00,  5.83it/s]
100%|██████████| 1486/1486 [04:17<00:00,  5.77it/s]
Evaluation: 100%|██████████| 9/9 [00:02<00:00,  3.27it/s]
100%|██████████| 224/224 [00:00<00:00, 1121.04it/s]


{'exact': 28.125, 'f1': 34.021687637759065, 'total': 224, 'HasAns_exact': 28.125, 'HasAns_f1': 34.021687637759065, 'HasAns_total': 224, 'best_exact': 28.125, 'best_exact_thresh': 0.0, 'best_f1': 34.021687637759065, 'best_f1_thresh': 0.0}
Language: english


Map:   0%|          | 0/7389 [00:00<?, ? examples/s]

Map:   0%|          | 0/990 [00:00<?, ? examples/s]

Some weights of XLMRobertaForQuestionAnswering were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 2013/2013 [05:12<00:00,  6.45it/s]
100%|██████████| 2013/2013 [05:11<00:00,  6.47it/s]
100%|██████████| 2013/2013 [05:11<00:00,  6.47it/s]
Evaluation: 100%|██████████| 34/34 [00:11<00:00,  2.91it/s]
100%|██████████| 990/990 [00:00<00:00, 1240.64it/s]


{'exact': 25.95959595959596, 'f1': 32.8463579671379, 'total': 990, 'HasAns_exact': 25.95959595959596, 'HasAns_f1': 32.8463579671379, 'HasAns_total': 990, 'best_exact': 25.95959595959596, 'best_exact_thresh': 0.0, 'best_f1': 32.8463579671379, 'best_f1_thresh': 0.0}
Language: indonesian


Map:   0%|          | 0/11394 [00:00<?, ? examples/s]

Map:   0%|          | 0/1191 [00:00<?, ? examples/s]

Some weights of XLMRobertaForQuestionAnswering were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 3003/3003 [07:01<00:00,  7.12it/s]
100%|██████████| 3003/3003 [06:59<00:00,  7.15it/s]
100%|██████████| 3003/3003 [07:06<00:00,  7.05it/s]
Evaluation: 100%|██████████| 40/40 [00:12<00:00,  3.10it/s]
100%|██████████| 1191/1191 [00:00<00:00, 1317.53it/s]


{'exact': 34.17296389588581, 'f1': 39.967987795222456, 'total': 1191, 'HasAns_exact': 34.17296389588581, 'HasAns_f1': 39.967987795222456, 'HasAns_total': 1191, 'best_exact': 34.17296389588581, 'best_exact_thresh': 0.0, 'best_f1': 39.967987795222456, 'best_f1_thresh': 0.0}
Language: arabic


Map:   0%|          | 0/29598 [00:00<?, ? examples/s]

Map:   0%|          | 0/1902 [00:00<?, ? examples/s]

Some weights of XLMRobertaForQuestionAnswering were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 8443/8443 [22:24<00:00,  6.28it/s]
100%|██████████| 8443/8443 [22:14<00:00,  6.33it/s]
100%|██████████| 8443/8443 [22:17<00:00,  6.31it/s]
Evaluation: 100%|██████████| 66/66 [00:23<00:00,  2.84it/s]
100%|██████████| 1902/1902 [00:01<00:00, 1219.61it/s]


{'exact': 30.336487907465827, 'f1': 38.44675862942356, 'total': 1902, 'HasAns_exact': 30.336487907465827, 'HasAns_f1': 38.44675862942356, 'HasAns_total': 1902, 'best_exact': 30.336487907465827, 'best_exact_thresh': 0.0, 'best_f1': 38.44675862942356, 'best_f1_thresh': 0.0}


In [17]:
metrics = {l: {} for l in bert_map.keys()}
for language, bert in list(bert_map.items()):
    print(f'Language: {language}')
    model = torch.load(f'{language}_xlm-roberta-base_span_detection.pt')
    for language2, bert in list(bert_map.items()):
        language_dataset = dataset.filter(lambda example: example['language'] == language2)
        tk = AutoTokenizer.from_pretrained(bert, max_len=300)

        tokenized_train_dataset = language_dataset['train'].map(partial(get_train_features, tk), batched=True, remove_columns=language_dataset['train'].column_names)
        tokenized_validation_dataset = language_dataset['validation'].map(partial(get_validation_features, tk), batched=True, remove_columns=language_dataset['validation'].column_names)

        train_dl = DataLoader(tokenized_train_dataset, collate_fn=collate_fn, shuffle=True, batch_size=4)
        val_dl = DataLoader(tokenized_validation_dataset, collate_fn=val_collate_fn, batch_size=32)
        
        logits = predict(model, val_dl, device)
        predictions = post_process_predictions(language_dataset['validation'], tokenized_validation_dataset, logits)
        formatted_predictions = [{'id': k, 'prediction_text': v, 'no_answer_probability': 0.} for k, v in predictions.items()]
        gold = [{
            'id': example['id'],
            'answers': {
                'text': example['annotations']['answer_text'],
                'answer_start': example['annotations']['answer_start']}
            }
            for example in language_dataset['validation']]


        metrics[language][language2] = compute_squad.compute(references=gold, predictions=formatted_predictions)

Language: bengali


Evaluation: 100%|██████████| 9/9 [00:01<00:00,  5.39it/s]
100%|██████████| 224/224 [00:00<00:00, 1118.94it/s]
Evaluation: 100%|██████████| 34/34 [00:06<00:00,  5.10it/s]
100%|██████████| 990/990 [00:00<00:00, 1250.43it/s]
Evaluation: 100%|██████████| 40/40 [00:07<00:00,  5.16it/s]
100%|██████████| 1191/1191 [00:00<00:00, 1339.36it/s]
Evaluation: 100%|██████████| 66/66 [00:15<00:00,  4.23it/s]
100%|██████████| 1902/1902 [00:01<00:00, 1184.38it/s]


Language: english


Evaluation: 100%|██████████| 9/9 [00:01<00:00,  5.32it/s]
100%|██████████| 224/224 [00:00<00:00, 1132.63it/s]
Evaluation: 100%|██████████| 34/34 [00:08<00:00,  4.20it/s]
100%|██████████| 990/990 [00:00<00:00, 1249.72it/s]
Evaluation: 100%|██████████| 40/40 [00:09<00:00,  4.29it/s]
100%|██████████| 1191/1191 [00:00<00:00, 1326.67it/s]
Evaluation: 100%|██████████| 66/66 [00:17<00:00,  3.67it/s]
100%|██████████| 1902/1902 [00:01<00:00, 1232.29it/s]


Language: indonesian


Evaluation: 100%|██████████| 9/9 [00:01<00:00,  5.18it/s]
100%|██████████| 224/224 [00:00<00:00, 1122.07it/s]
Evaluation: 100%|██████████| 34/34 [00:09<00:00,  3.75it/s]
100%|██████████| 990/990 [00:00<00:00, 1231.18it/s]
Evaluation: 100%|██████████| 40/40 [00:10<00:00,  3.95it/s]
100%|██████████| 1191/1191 [00:00<00:00, 1332.25it/s]
Evaluation: 100%|██████████| 66/66 [00:19<00:00,  3.41it/s]
100%|██████████| 1902/1902 [00:01<00:00, 1224.58it/s]


Language: arabic


Evaluation: 100%|██████████| 9/9 [00:01<00:00,  4.75it/s]
100%|██████████| 224/224 [00:00<00:00, 1112.89it/s]
Evaluation: 100%|██████████| 34/34 [00:09<00:00,  3.44it/s]
100%|██████████| 990/990 [00:00<00:00, 1234.54it/s]
Evaluation: 100%|██████████| 40/40 [00:10<00:00,  3.66it/s]
100%|██████████| 1191/1191 [00:00<00:00, 1310.22it/s]
Evaluation: 100%|██████████| 66/66 [00:20<00:00,  3.23it/s]
100%|██████████| 1902/1902 [00:01<00:00, 1226.49it/s]


In [27]:
# Pretty print metrics
print("Exact match scores, xlm-r span detection:")
print(' & '.join(['tuned lan'] + list(bert_map.keys())) + '\\\\')
for language in bert_map.keys():
    print(' & '.join([language] + [f'{metrics[language][language2]["exact"]:.2f}' for language2 in bert_map.keys()]) + '\\\\')
print(' & '.join(['Average'] + [f'{sum([metrics[language][language2]["exact"] for language2 in bert_map.keys()]) / len(bert_map.keys()):.2f}' for language in bert_map.keys()]) + '\\\\')



Exact match scores, xlm-r span detection:
tuned lan & bengali & english & indonesian & arabic\\
bengali & 28.12 & 21.31 & 20.49 & 13.72\\
english & 15.62 & 25.96 & 27.54 & 20.08\\
indonesian & 17.86 & 26.16 & 34.17 & 25.13\\
arabic & 18.30 & 26.16 & 29.89 & 30.34\\
Average & 20.91 & 22.30 & 25.83 & 26.17\\
