In [None]:
import random
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
from datasets import Dataset, DatasetDict, load_dataset

# Apply matplotlib's "ggplot" style sheet to every figure drawn in this notebook
plt.style.use("ggplot")

In [None]:
# The manually annotated samples live in `data/raw/foqa.csv`, one directory above
# the current working directory. Fail early if the file is missing.
annotated_data_path = Path.cwd().parent.joinpath("data", "raw", "foqa.csv")
assert annotated_data_path.exists()

In [None]:
# Load every split of the `all-samples` configuration and stack them into a single
# dataframe. `ignore_index=True` gives the stacked frame a fresh 0..n-1 index: each
# split's dataframe carries its own index, so without it several splits would
# produce duplicate labels, and the index-aligned column assignments further down
# cannot work with duplicate labels.
raw_dataset_dict = load_dataset("alexandrainst/foqa", "all-samples")
raw_df = pd.concat(
    [split_dataset.to_pandas() for split_dataset in raw_dataset_dict.values()],
    ignore_index=True,
)
raw_df

In [None]:
# Plot the distribution of the (unshortened) context lengths, in characters. The
# full contexts are stored under `original_context` once the rename further down has
# happened, and under `context` before that, so pick whichever column is present.
length_source_column = (
    "original_context" if "original_context" in raw_df.columns else "context"
)
context_lengths = raw_df[length_source_column].str.len().sort_values()
context_lengths.hist(bins="auto")
plt.show()

In [None]:
def _longest_windows(
    pieces: list, anchor_idx: int, separator: str, max_length: int
) -> list:
    """Lists the longest windows of `pieces` that contain the anchor piece.

    For every start index at or before the anchor, the window is extended to the
    right for as long as the joined text stays strictly shorter than `max_length`.

    Args:
        pieces:
            The text pieces (paragraphs or lines), in document order.
        anchor_idx:
            The index of the piece that every window has to contain.
        separator:
            The string used to join the pieces back together.
        max_length:
            The exclusive upper bound on the number of characters in a window.

    Returns:
        The distinct windows, ordered from the latest start index to the earliest.
        Empty if the anchor piece on its own is already too long.
    """
    windows = []
    for start_idx in range(anchor_idx, -1, -1):
        longest_window = None
        for end_idx in range(anchor_idx, len(pieces)):
            candidate = separator.join(pieces[start_idx : end_idx + 1])
            if len(candidate) >= max_length:
                break
            longest_window = candidate

        # If not even the window ending at the anchor fits, then starting earlier
        # can only make it longer, so there is nothing more to find
        if longest_window is None:
            break
        windows.append(longest_window)

    # Remove duplicates while keeping a deterministic order, so that a seeded
    # `random.choice` gives the same result on every run
    return list(dict.fromkeys(windows))


def extract_shorter_context(
    answer_dict: dict, context: str, max_context_length: int = 5000
) -> str:
    """Extracts a shorter context containing the answer.

    This ensures that the number of characters in the context is strictly
    smaller than `max_context_length`.

    Args:
        answer_dict:
            The answer dictionary, with keys 'text' and 'answer_start'. Only the
            first entry of 'text' is used.
        context:
            The context.
        max_context_length:
            The exclusive upper bound on the number of characters in the returned
            context. Defaults to 5000.

    Returns:
        The shortened context, which always contains the answer. If several
        candidate contexts exist then one of them is picked at random.

    Raises:
        RuntimeError:
            If it wasn't possible to shorten the context, i.e., if the answer does
            not appear within a single paragraph or line, or if the line containing
            the answer is itself too long.
    """
    # If the context is short enough already then we don't have to shorten it
    if len(context) < max_context_length:
        return context

    answer = answer_dict["text"][0]

    paragraphs = [p for p in context.split("\n\n") if p]
    paragraph_answer_idx = next(
        (idx for idx, paragraph in enumerate(paragraphs) if answer in paragraph),
        None,
    )
    if paragraph_answer_idx is None:
        raise RuntimeError(
            "Could not shorten the context, as no single paragraph contains the "
            f"answer {answer!r}."
        )
    answer_paragraph = paragraphs[paragraph_answer_idx]

    # If the paragraph containing the answer is short enough, then we identify all
    # the possible longest contexts with the answer, by including paragraphs above
    # and below that paragraph
    if len(answer_paragraph) < max_context_length:
        all_valid_contexts = _longest_windows(
            pieces=paragraphs,
            anchor_idx=paragraph_answer_idx,
            separator="\n\n",
            max_length=max_context_length,
        )

    # Otherwise, we split up the paragraph containing the answer into lines and do
    # the same as we did for the paragraphs, just for lines instead
    else:
        lines = answer_paragraph.split("\n")
        line_answer_idx = next(
            (idx for idx, line in enumerate(lines) if answer in line), None
        )
        if line_answer_idx is None:
            raise RuntimeError(
                "Could not shorten the context, as no single line contains the "
                f"answer {answer!r}."
            )
        all_valid_contexts = _longest_windows(
            pieces=lines,
            anchor_idx=line_answer_idx,
            separator="\n",
            max_length=max_context_length,
        )

    if not all_valid_contexts:
        raise RuntimeError(
            "Could not shorten the context, as the line containing the answer "
            f"has at least {max_context_length} characters."
        )

    # We select the context at random from the valid candidates
    return random.choice(all_valid_contexts)

In [None]:
def update_answer_start(context: str, answer_dict: dict) -> dict:
    """Recomputes the answer's start position within the given context.

    Only the first answer text is considered, and it is located at its first
    occurrence in the context.

    Args:
        context:
            The context where the answer appears.
        answer_dict:
            The original answer dictionary, with keys 'text' and 'answer_start'.

    Returns:
        A dictionary with keys 'text' and 'answer_start', each being
        lists with a single element.

    Raises:
        ValueError:
            If the answer does not appear in the context.
    """
    first_answer = answer_dict["text"][0]

    # `str.index` raises a ValueError if the answer is not in the context
    start_idx = context.index(first_answer)
    end_idx = start_idx + len(first_answer)
    assert context[start_idx:end_idx] == first_answer

    return {"text": [first_answer], "answer_start": [start_idx]}

In [None]:
# Keep the untouched contexts under `original_context`, so that the `context`
# column can hold the shortened version. The guard makes the cell safe to re-run.
if "original_context" not in raw_df.columns:
    raw_df = raw_df.rename(columns={"context": "original_context"})

# The shortened context is picked at random among the valid candidates, so we seed
# the random number generator to avoid the published contexts changing between runs
random.seed(4242)

raw_df["context"] = raw_df.apply(
    lambda x: extract_shorter_context(
        answer_dict=x.answers, context=x.original_context
    ),
    axis=1,
)

# The answer positions referred to the original contexts, so recompute them
# relative to the shortened ones
raw_df["answers"] = raw_df.apply(
    lambda x: update_answer_start(context=x.context, answer_dict=x.answers), axis=1
)

# Plot the distribution of the shortened context lengths, in characters
shorter_context_lengths = raw_df.context.str.len().sort_values()
shorter_context_lengths.hist(bins="auto")
plt.show()

In [None]:
# Load the manually annotated samples and give the columns canonical names
annotated_df = pd.read_csv(annotated_data_path)
annotated_df.columns = [
    "id",
    "title",
    "context",
    "question",
    "deleteme",
    "validation",
    "answers",
]

# Discard the columns we do not need, take the answers and the (shortened) contexts
# from `raw_df` (aligned on the index), and use the same columns and column order
# as `raw_df`, except for `original_context`
output_columns = raw_df.columns.drop("original_context")
annotated_df = annotated_df.drop(columns=["deleteme", "answers"]).assign(
    answers=raw_df.answers, context=raw_df.context
)[output_columns]
annotated_df

In [None]:
# Keep the samples whose validation label is one of the accepted values
correct_texts = ["correct", "corrected"]
correct_df = annotated_df.loc[annotated_df["validation"].isin(correct_texts)]
correct_df

In [None]:
NUM_TEST_SAMPLES = 1024
NUM_VAL_SAMPLES = 128
SPLIT_SEED = 4242  # Fixed seed, so that the splits are identical on every run

# All the samples marked as "corrected" go into the test split, which is then
# filled up to `NUM_TEST_SAMPLES` with randomly sampled remaining samples. If there
# are already enough "corrected" samples then nothing more is added.
test_df = correct_df.query('validation == "corrected"')
num_test_samples_missing = max(NUM_TEST_SAMPLES - len(test_df), 0)
test_df = pd.concat(
    [
        test_df,
        correct_df.drop(index=test_df.index).sample(
            n=num_test_samples_missing, random_state=SPLIT_SEED
        ),
    ]
)

# The validation split is sampled from what is left, and the training split
# consists of everything that ended up in neither of the other two splits
val_df = correct_df.drop(index=test_df.index).sample(
    n=NUM_VAL_SAMPLES, random_state=SPLIT_SEED
)
train_df = correct_df.drop(index=test_df.index).drop(index=val_df.index)

# The three splits have to partition the correct samples
assert len(train_df) + len(val_df) + len(test_df) == len(correct_df)

len(train_df), len(val_df), len(test_df)

In [None]:
# Convert each split to a `Dataset`, leaving out the dataframe index, and upload
# the three splits as the `default` configuration of the dataset
train, val, test = (
    Dataset.from_pandas(split_df, preserve_index=False)
    for split_df in (train_df, val_df, test_df)
)
DatasetDict({"train": train, "val": val, "test": test}).push_to_hub(
    "alexandrainst/foqa", "default"
)

In [None]:
# The samples that were annotated as being incorrect, without any rows that have
# missing values
incorrect_texts = ["incorrect", "incorrect-answer"]
incorrect_df = annotated_df.query("validation in @incorrect_texts").dropna()

# For the samples that were marked as "corrected", take the corresponding rows
# from `raw_df` as well. We take a copy, so that adding the label below does not
# write to a slice of `raw_df`.
corrected_index = annotated_df.query('validation == "corrected"').index
wrong_corrected_df = (
    raw_df.drop(columns="original_context").loc[corrected_index].copy()
)

# Bracket assignment sets the column whether or not it already exists, whereas
# attribute assignment would merely set a Python attribute on the dataframe if the
# column was missing
wrong_corrected_df["validation"] = annotated_df.loc[corrected_index, "validation"]

incorrect_df = pd.concat(objs=[incorrect_df, wrong_corrected_df])
incorrect_df.head()

In [None]:
# Upload the incorrect samples and the complete set of samples as two further
# configurations of the dataset, leaving out the dataframe index in both cases
incorrect_dataset = Dataset.from_pandas(incorrect_df, preserve_index=False)
incorrect_dataset.push_to_hub("alexandrainst/foqa", "incorrect-samples")

all_samples_dataset = Dataset.from_pandas(raw_df, preserve_index=False)
all_samples_dataset.push_to_hub("alexandrainst/foqa", "all-samples")