In [413]:
import json

from datasets import Dataset
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, DefaultDataCollator, TrainingArguments, Trainer
from torch.utils.data import DataLoader
from torch.optim import AdamW

from tqdm.auto import tqdm

import numpy as np
import evaluate
import collections


import warnings
warnings.filterwarnings("ignore")

tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-uncased')
model = AutoModelForQuestionAnswering.from_pretrained("bert-base-multilingual-uncased")

Some weights of the model checkpoint at bert-base-multilingual-uncased were not used when initializing BertForQuestionAnswering: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-b

In [393]:
S_lang2file = {
    'en' : 'tydiqa.en.train.json',
    'fi' : 'tydiqa.fi.train.json',
    'ar' : 'tydiqa.ar.train.json',
    'bn' : 'tydiqa.bn.train.json',
    'id' : 'tydiqa.in.train.json',
    'ko' : 'tydiqa.ko.train.json',
    'ru' : 'tydiqa.ru.train.json',
    'sw' : 'tydiqa.sw.train.json',
    'te' : 'tydiqa.te.train.json',
}

T_lang2file = {
    'en' : 'tydiqa.en.dev.json',
    'fi' : 'tydiqa.fi.dev.json',
    'ar' : 'tydiqa.ar.dev.json',
    'bn' : 'tydiqa.bn.dev.json',
    'id' : 'tydiqa.in.dev.json',
    'ko' : 'tydiqa.ko.dev.json',
    'ru' : 'tydiqa.ru.dev.json',
    'sw' : 'tydiqa.sw.dev.json',
    'te' : 'tydiqa.te.dev.json',
}

accuracy_dict = {} # for storing all the test accuracies in the form { (S,T,SHOT) , Acc }

In [394]:
path = "/Users/rishikesh/Desktop/Project/download/tydiqa/"

def read_data(path):  
    with open(path, 'rb') as f:
        squad = json.load(f)

    contexts = []
    questions = []
    answers = []
    id = []

    for group in squad['data']:
        for passage in group['paragraphs']:
            context = passage['context']
            for qa in passage['qas']:
                question = qa['question']
                for answer in qa['answers']:
                    contexts.append(context)
                    questions.append(question)
                    answers.append(answer)
                    id.append(qa['id'])

    return contexts, questions, answers, id

In [395]:
# All Languages: en, fi, ar, bn, id, ko, ru, sw, te = 9
# Total Language pairs = 9*9 = 81

SHOT = 0 # 0-shot or few-shot

S = 'en'
T = 'en'


In [396]:
s_path = path + 'tydiqa-goldp-v1.1-train/' + S_lang2file[S]
t_path = path + 'tydiqa-goldp-v1.1-dev/' + T_lang2file[T]

s_context, s_q, s_a, s_i = read_data(s_path)
t_context, t_q, t_a, t_i = read_data(t_path)

if SHOT>0:
    few_shot_path = path + 'tydiqa-goldp-v1.1-train/' + S_lang2file[T]
    fs_context, fs_q, fs_a, fs_i = read_data(few_shot_path)

In [397]:
s_tydi = []
for _ in range(len(s_a)):
    s_tydi.append({})
    s_tydi[_]['answers'] = s_a[_]
    s_tydi[_]['context'] = s_context[_]
    s_tydi[_]['question'] = s_q[_]
    s_tydi[_]['id'] = s_i[_]

    
if SHOT>0:
    for _ in range(SHOT):
        s_tydi.append({})
        s_tydi[len(s_tydi) - 1]['answers'] = fs_a[_]
        s_tydi[len(s_tydi) - 1]['context'] = fs_context[_]
        s_tydi[len(s_tydi) - 1]['question'] = fs_q[_]
        s_tydi[len(s_tydi) - 1]['id'] = fs_i[_]

s_data = Dataset.from_list(s_tydi)

t_tydi = []
for _ in range(len(t_a)):
    t_tydi.append({})
    t_tydi[_]['answers'] = t_a[_]
    t_tydi[_]['context'] = t_context[_]
    t_tydi[_]['question'] = t_q[_] 
    t_tydi[_]['id'] = t_i[_] 
t_data = Dataset.from_list(t_tydi)

In [398]:
# def preprocess_function(examples):
#     questions = [q.strip() for q in examples["question"]]
#     inputs = tokenizer(
#         questions,
#         examples["context"],
#         max_length=400,
#         truncation="only_second",
#         return_overflowing_tokens=True,
#         return_offsets_mapping=True,
#         padding="max_length",
#     )

#     sample_mapping = inputs.pop("overflow_to_sample_mapping")
#     offset_mapping = inputs.pop("offset_mapping")
#     inputs["start_positions"] = []
#     inputs["end_positions"] = []

#     for i, offset in enumerate(offset_mapping):
#         input_ids = inputs["input_ids"][i]
#         cls_index = input_ids.index(tokenizer.cls_token_id)
#         sequence_ids = inputs.sequence_ids(i)
#         sample_index = sample_mapping[i]
#         answers = examples['answers'][sample_index]
        
#         if answers["answer_start"] == '':
#             inputs["start_positions"].append(cls_index)
#             inputs["end_positions"].append(cls_index)
#         else:
#             # Start/end character index of the answer in the text.
#             start_char = answers["answer_start"]
#             end_char = start_char + len(answers["text"])

#             # Start token index of the current span in the text.
#             token_start_index = 0
#             while sequence_ids[token_start_index] != 1:
#                 token_start_index += 1

#             # End token index of the current span in the text.
#             token_end_index = len(input_ids) - 1
#             while sequence_ids[token_end_index] != 1:
#                 token_end_index -= 1

#             # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
#             if not (offset[token_start_index][0] <= start_char and offset[token_end_index][1] >= end_char):
#                 inputs["start_positions"].append(cls_index)
#                 inputs["end_positions"].append(cls_index)
#             else:
#                 # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
#                 # Note: we could go after the last offset if the answer is the last word (edge case).
#                 while token_start_index < len(offset) and offset[token_start_index][0] <= start_char:
#                     token_start_index += 1
#                 inputs["start_positions"].append(token_start_index - 1)
#                 while offset[token_end_index][1] >= end_char:
#                     token_end_index -= 1
#                 inputs["end_positions"].append(token_end_index + 1)

#     return inputs  

In [399]:
max_length = 384
stride = 128


def preprocess_training_examples(examples):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mapping = inputs.pop("offset_mapping")
    sample_map = inputs.pop("overflow_to_sample_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        sample_idx = sample_map[i]
        answer = answers[sample_idx]
        start_char = answer["answer_start"]
        end_char = answer["answer_start"] + len(answer["text"])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label is (0, 0)
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

In [400]:
train_dataset = s_data.map(
    preprocess_training_examples,
    batched=True,
    remove_columns=s_data.column_names,
)

Map:   0%|          | 0/3696 [00:00<?, ? examples/s]

(3696, 3804)

In [401]:
# def prepare_validation_features(examples):
#     # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
#     # in one example possible giving several features when a context is long, each of those features having a
#     # context that overlaps a bit the context of the previous feature.
#     tokenized_examples = tokenizer(
#         examples["question"],
#         examples["context"],
#         truncation="only_second",
#         max_length=400,
#         return_overflowing_tokens=True,
#         return_offsets_mapping=True,
#         padding="max_length"
#     )

#     # Since one example might give us several features if it has a long context, we need a map from a feature to
#     # its corresponding example. This key gives us just that.
#     sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

#     # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
#     # corresponding example_id and we will store the offset mappings.

#     tokenized_examples["example_id"] = []

#     for i in range(len(tokenized_examples["input_ids"])):
#         # Grab the sequence corresponding to that example (to know what is the context and what is the question).
#         sequence_ids = tokenized_examples.sequence_ids(i)
#         context_index = 1

#         # One example can give several spans, this is the index of the example containing this span of text.
#         sample_index = sample_mapping[i]
#         tokenized_examples["example_id"].append(examples["id"][sample_index])

#         # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
#         # position is part of the context or not.
#         tokenized_examples["offset_mapping"][i] = [
#             (o if sequence_ids[k] == context_index else None)
#             for k, o in enumerate(tokenized_examples["offset_mapping"][i])
#         ]

#     return tokenized_examples


In [402]:
def preprocess_validation_examples(examples):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []

    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])

        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["example_id"] = example_ids
    return inputs

In [403]:
validation_dataset = t_data.map(
    preprocess_validation_examples,
    batched=True,
    remove_columns=t_data.column_names,
)

Map:   0%|          | 0/440 [00:00<?, ? examples/s]

(440, 447)

In [404]:
# tokenized_s_data = s_data.map(preprocess_function, batched=True, batch_size=32, remove_columns=s_data.column_names)
# tokenized_t_data = t_data.map(prepare_validation_features, batched=True, batch_size=32, remove_columns=t_data.column_names)

In [405]:
metric = evaluate.load("squad")


def compute_metrics(start_logits, end_logits, features, examples):
    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    for example in tqdm(examples):
        example_id = example["id"]
        context = example["context"]
        answers = []

        # Loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            start_indexes = np.argsort(start_logit)[-1 : -20 - 1 : -1].tolist() #n_best
            end_indexes = np.argsort(end_logit)[-1 : -20 - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0 or > max_answer_length
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > 60 #max_answer_len
                    ):
                        continue

                    answer = {
                        "text": context[offsets[start_index][0] : offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                    answers.append(answer)

        # Select the answer with the best score
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            predicted_answers.append(
                {"id": example_id, "prediction_text": best_answer["text"]}
            )
        else:
            predicted_answers.append({"id": example_id, "prediction_text": ""})

    theoretical_answers = [{"id": ex["id"], "answers": {"text":[ex["answers"]["text"]], "answer_start":[ex["answers"]["answer_start"]]}} for ex in examples]
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)

In [406]:
data_collator = DefaultDataCollator()
training_args = TrainingArguments(
    output_dir='OP',
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=1,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    #compute_metrics=compute_metrics,
)

trainer.train()

Epoch,Training Loss,Validation Loss
1,No log,No log


TrainOutput(global_step=476, training_loss=1.9688453994879203, metrics={'train_runtime': 3180.9311, 'train_samples_per_second': 1.196, 'train_steps_per_second': 0.15, 'total_flos': 745479646967808.0, 'train_loss': 1.9688453994879203, 'epoch': 1.0})

In [408]:
predictions, _, _ = trainer.predict(validation_dataset)
start_logits, end_logits = predictions
f1 = compute_metrics(start_logits, end_logits, validation_dataset, t_data)

  0%|          | 0/440 [00:00<?, ?it/s]

In [410]:
accuracy_dict[(S,T,SHOT)] = f1['f1']

65.84189610164601

In [None]:
accuracy_dict

In [269]:
# # TRAIN
# tokenized_s_data.set_format("torch")
# train_dataloader = DataLoader(tokenized_s_data, batch_size=32, shuffle=True)

# optimizer = AdamW(model.parameters(), lr=1e-5)
# model = AutoModelForQuestionAnswering.from_pretrained("bert-base-multilingual-uncased")

# # Training loop
# num_epochs = 1
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)

# for epoch in range(num_epochs):
#     model.train()
#     total_loss = 0
#     total_correct = 0
#     total_samples = 0
    
#     loop = tqdm(train_dataloader, leave=True)
#     for batch in loop:
#         input_ids = batch["input_ids"].to(device)
#         attention_mask = batch["attention_mask"].to(device)
#         start_positions = batch["start_positions"].to(device)
#         end_positions = batch["end_positions"].to(device)

#         optimizer.zero_grad()

#         outputs = model(
#             input_ids=input_ids,
#             attention_mask=attention_mask,
#             start_positions=start_positions,
#             end_positions=end_positions
#         )

#         loss = outputs.loss
#         total_loss += loss.item()

#         # Get predicted start and end positions
#         pred_start_positions = torch.argmax(outputs.start_logits, dim=1)
#         pred_end_positions = torch.argmax(outputs.end_logits, dim=1)

#         # Calculate accuracy
#         correct_start = (pred_start_positions == start_positions).sum().item()
#         correct_end = (pred_end_positions == end_positions).sum().item()

#         total_correct += correct_start + correct_end
#         total_samples += start_positions.size(0) * 2  # Multiply by 2 as we have start and end positions

#         loss.backward()
#         optimizer.step()

#         loop.set_description(f'Epoch {epoch+1}')
#         loop.set_postfix(loss=loss.item())

#     average_loss = total_loss / len(train_dataloader)
#     accuracy = total_correct / total_samples

#     print(f"Epoch {epoch + 1}/{num_epochs}, Average Loss: {average_loss}, Training Accuracy: {accuracy}")


Some weights of the model checkpoint at bert-base-multilingual-uncased were not used when initializing BertForQuestionAnswering: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-b

Epoch 1/1, Average Loss: 5.92772757595983, Training Accuracy: 0.013392857142857142





In [292]:
# # TEST
# tokenized_t_data.set_format("torch")
# test_dataloader = DataLoader(tokenized_t_data, batch_size=32)

# model.eval()
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)

# total_correct = 0
# total_samples = 0

# with torch.no_grad():
#     for batch in tqdm(test_dataloader):
#         input_ids = batch["input_ids"].to(device)
#         attention_mask = batch["attention_mask"].to(device)
#         start_positions = batch["start_positions"].to(device)
#         end_positions = batch["end_positions"].to(device)

#         outputs = model(
#             input_ids=input_ids,
#             attention_mask=attention_mask,
#             start_positions=start_positions,
#             end_positions=end_positions
#         )

#         # Get predicted start and end positions
#         pred_start_positions = torch.argmax(outputs.start_logits, dim=1)
#         pred_end_positions = torch.argmax(outputs.end_logits, dim=1)

#         # Calculate accuracy
#         correct_start = (pred_start_positions == start_positions).sum().item()
#         correct_end = (pred_end_positions == end_positions).sum().item()

#         total_correct += correct_start + correct_end
#         total_samples += start_positions.size(0) * 2  # Multiply by 2 as we have start and end positions

# accuracy = total_correct / total_samples
# print(f"Testing Accuracy: {accuracy}")

100%|███████████████████████████████████████████| 25/25 [02:26<00:00,  5.84s/it]

Testing Accuracy: 0.0159846547314578



