In [9]:
from datasets import load_dataset
from datasets import load_metric
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from transformers import AutoModelForQuestionAnswering, XLMRobertaTokenizer, XLMRobertaForSequenceClassification
from functools import partial
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from datasets import load_metric
import torch
from lab5_helper import collate_fn, get_train_features, get_validation_features, post_process_predictions, predict, val_collate_fn
from lab5_train import *
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

lr = 2e-5
n_epochs = 3
weight_decay = 0.01
warmup_steps = 200
bert_map = {
    'bengali': 'xlm-roberta-base', 
    'english': 'xlm-roberta-base', 
    'indonesian': 'xlm-roberta-base', 
    'arabic': 'xlm-roberta-base'
}
device = 'cuda'

compute_squad = load_metric("squad_v2")
dataset = load_dataset("copenlu/answerable_tydiqa")

for split in dataset.keys():
    dataset[split] = dataset[split].add_column('id', list(range(len(dataset[split]))))

for language, bert in list(bert_map.items()):
    print(f'-------- Base Language: {language} ------------')
    # language_dataset = dataset.filter(lambda example: example['language'] == language)
    # tk = AutoTokenizer.from_pretrained(bert, max_len=300)

    # tokenized_train_dataset = language_dataset['train'].map(partial(get_train_features, tk), batched=True, remove_columns=language_dataset['train'].column_names)
    # tokenized_validation_dataset = language_dataset['validation'].map(partial(get_train_features, tk), batched=True, remove_columns=language_dataset['validation'].column_names)

    # train_dl = DataLoader(tokenized_train_dataset, collate_fn=collate_fn, shuffle=True, batch_size=4)

    # # Load xlm-roberta-base model
    # model = XLMRobertaForSequenceClassification.from_pretrained(bert, num_labels=2).to(device)
    # no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    # optimizer_grouped_parameters = [
    #     {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
    #     'weight_decay': weight_decay},
    #     {'params': [p for n, p in model.named_parameters() if any(
    #         nd in n for nd in no_decay)], 'weight_decay': 0.0}
    # ]

    # optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
    # scheduler = get_linear_schedule_with_warmup(
    #     optimizer,
    #     warmup_steps,
    #     n_epochs * len(train_dl)
    # )

    # losses = train(
    #     model,
    #     train_dl,
    #     optimizer,
    #     scheduler,
    #     n_epochs,
    #     device
    # )
    # torch.save(model, f'{language}_xlm-roberta-base_classification.pt')
    model = torch.load(f'{language}_xlm-roberta-base_classification.pt')
    for language2, bert in list(bert_map.items()):
        language_dataset = dataset.filter(lambda example: example['language'] == language2)
        tk = AutoTokenizer.from_pretrained(bert, max_len=300)
        
        tokenized_validation_dataset = language_dataset['validation'].map(partial(get_train_features, tk), batched=True, remove_columns=language_dataset['validation'].column_names)  

        val_dl = DataLoader(tokenized_validation_dataset, collate_fn=collate_fn, batch_size=32)
        
        logits = predict(model, val_dl, device)
        # Assuming binary classification; use torch.argmax for multi-class.
        predictions = logits.argmax(dim=1).cpu().numpy()
        ground_truth = []
        for batch in val_dl:
            ground_truth.extend(batch['labels'].numpy())
            
        accuracy = accuracy_score(ground_truth, predictions)
        precision = precision_score(ground_truth, predictions)
        recall = recall_score(ground_truth, predictions)
        f1 = f1_score(ground_truth, predictions)

        print(f"Results for {language2}:")
        print(f"Accuracy: {accuracy}")
        print(f"Precision: {precision}")
        print(f"Recall: {recall}")
        print(f"F1 Score: {f1}")
        print("-------------------------")


-------- Base Language: bengali ------------


Evaluation: 100%|██████████| 9/9 [00:01<00:00,  5.36it/s]


Results for bengali:
Accuracy: 0.5300751879699248
Precision: 0.5300751879699248
Recall: 1.0
F1 Score: 0.6928746928746928
-------------------------


Evaluation: 100%|██████████| 34/34 [00:10<00:00,  3.25it/s]


Results for english:
Accuracy: 0.4911792014856082
Precision: 0.4911792014856082
Recall: 1.0
F1 Score: 0.6587795765877957
-------------------------


Evaluation: 100%|██████████| 40/40 [00:12<00:00,  3.29it/s]


Results for indonesian:
Accuracy: 0.49801113762927607
Precision: 0.49801113762927607
Recall: 1.0
F1 Score: 0.6648964418481148
-------------------------


Evaluation: 100%|██████████| 66/66 [00:21<00:00,  3.00it/s]


Results for arabic:
Accuracy: 0.497377205531712
Precision: 0.497377205531712
Recall: 1.0
F1 Score: 0.6643312101910828
-------------------------
-------- Base Language: english ------------


Evaluation: 100%|██████████| 9/9 [00:02<00:00,  3.51it/s]


Results for bengali:
Accuracy: 0.7443609022556391
Precision: 0.762589928057554
Recall: 0.75177304964539
F1 Score: 0.757142857142857
-------------------------


Evaluation: 100%|██████████| 34/34 [00:11<00:00,  2.95it/s]


Results for english:
Accuracy: 0.8365831012070566
Precision: 0.8058925476603119
Recall: 0.8790170132325141
F1 Score: 0.8408679927667269
-------------------------


Evaluation: 100%|██████████| 40/40 [00:12<00:00,  3.18it/s]


Results for indonesian:
Accuracy: 0.8591885441527446
Precision: 0.8437978560490046
Recall: 0.8801916932907349
F1 Score: 0.8616106333072713
-------------------------


Evaluation: 100%|██████████| 66/66 [00:24<00:00,  2.69it/s]


Results for arabic:
Accuracy: 0.8721983786361469
Precision: 0.8666035950804163
Recall: 0.87823585810163
F1 Score: 0.8723809523809524
-------------------------
-------- Base Language: indonesian ------------


Evaluation: 100%|██████████| 9/9 [00:02<00:00,  3.69it/s]


Results for bengali:
Accuracy: 0.48872180451127817
Precision: 0.6086956521739131
Recall: 0.09929078014184398
F1 Score: 0.17073170731707316
-------------------------


Evaluation: 100%|██████████| 34/34 [00:12<00:00,  2.68it/s]


Results for english:
Accuracy: 0.5041782729805014
Precision: 0.45614035087719296
Recall: 0.04914933837429111
F1 Score: 0.0887372013651877
-------------------------


Evaluation: 100%|██████████| 40/40 [00:14<00:00,  2.83it/s]


Results for indonesian:
Accuracy: 0.771678599840891
Precision: 0.7548872180451128
Recall: 0.8019169329073482
F1 Score: 0.7776917118512781
-------------------------


Evaluation: 100%|██████████| 66/66 [00:26<00:00,  2.51it/s]


Results for arabic:
Accuracy: 0.5517405817835003
Precision: 0.7288888888888889
Recall: 0.15723873441994246
F1 Score: 0.2586750788643533
-------------------------
-------- Base Language: arabic ------------


Evaluation: 100%|██████████| 9/9 [00:02<00:00,  3.44it/s]


Results for bengali:
Accuracy: 0.5300751879699248
Precision: 0.5300751879699248
Recall: 1.0
F1 Score: 0.6928746928746928
-------------------------


Evaluation: 100%|██████████| 34/34 [00:13<00:00,  2.58it/s]


Results for english:
Accuracy: 0.4911792014856082
Precision: 0.4911792014856082
Recall: 1.0
F1 Score: 0.6587795765877957
-------------------------


Evaluation: 100%|██████████| 40/40 [00:14<00:00,  2.76it/s]


Results for indonesian:
Accuracy: 0.49801113762927607
Precision: 0.49801113762927607
Recall: 1.0
F1 Score: 0.6648964418481148
-------------------------


Evaluation: 100%|██████████| 66/66 [00:26<00:00,  2.47it/s]


Results for arabic:
Accuracy: 0.497377205531712
Precision: 0.497377205531712
Recall: 1.0
F1 Score: 0.6643312101910828
-------------------------


In [None]:
metrics = {l: {} for l in bert_map.keys()}
for language, bert in list(bert_map.items()):
    print(f'Language: {language}')
    model = torch.load(f'{language}_xlm-roberta-base_span_detection.pt')
    for language2, bert in list(bert_map.items()):
        language_dataset = dataset.filter(lambda example: example['language'] == language2)
        tk = AutoTokenizer.from_pretrained(bert, max_len=300)

        tokenized_train_dataset = language_dataset['train'].map(partial(get_train_features, tk), batched=True, remove_columns=language_dataset['train'].column_names)
        tokenized_validation_dataset = language_dataset['validation'].map(partial(get_validation_features, tk), batched=True, remove_columns=language_dataset['validation'].column_names)

        train_dl = DataLoader(tokenized_train_dataset, collate_fn=collate_fn, shuffle=True, batch_size=4)
        val_dl = DataLoader(tokenized_validation_dataset, collate_fn=val_collate_fn, batch_size=32)
        
        logits = predict(model, val_dl, device)
        predictions = post_process_predictions(language_dataset['validation'], tokenized_validation_dataset, logits)
        formatted_predictions = [{'id': k, 'prediction_text': v, 'no_answer_probability': 0.} for k, v in predictions.items()]
        gold = [{
            'id': example['id'],
            'answers': {
                'text': example['annotations']['answer_text'],
                'answer_start': example['annotations']['answer_start']}
            }
            for example in language_dataset['validation']]


        metrics[language][language2] = compute_squad.compute(references=gold, predictions=formatted_predictions)

Language: bengali


Map:   0%|          | 0/4779 [00:00<?, ? examples/s]

Map:   0%|          | 0/224 [00:00<?, ? examples/s]

Evaluation:   0%|          | 0/9 [00:00<?, ?it/s]


KeyError: 'start_logits'

In [None]:
# Pretty print metrics
print("Exact match scores, xlm-r span detection:")
print(' & '.join(['tuned lan'] + list(bert_map.keys())) + '\\\\')
for language in bert_map.keys():
    print(' & '.join([language] + [f'{metrics[language][language2]["exact"]:.2f}' for language2 in bert_map.keys()]) + '\\\\')
print(' & '.join(['Average'] + [f'{sum([metrics[language][language2]["exact"] for language2 in bert_map.keys()]) / len(bert_map.keys()):.2f}' for language in bert_map.keys()]) + '\\\\')



Exact match scores, xlm-r span detection:
tuned lan & bengali & english & indonesian & arabic\\
bengali & 28.12 & 21.31 & 20.49 & 13.72\\
english & 15.62 & 25.96 & 27.54 & 20.08\\
indonesian & 17.86 & 26.16 & 34.17 & 25.13\\
arabic & 18.30 & 26.16 & 29.89 & 30.34\\
Average & 20.91 & 22.30 & 25.83 & 26.17\\
