# Notebook Initialization

In [1]:
!pip install datasets
!pip install groq
!pip install spacy
!python -m spacy download en_core_web_sm

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (1

In [2]:
from typing import Iterable, Any
from google.colab import drive
from tqdm import tqdm
from groq import Groq
import os
import re
import pickle
import json
import pandas as pd
import numpy as np
import datasets
import spacy

In [3]:
# Mount Google Drive so that DATASET_ROOT below is reachable.
drive.mount('/content/drive')
# spaCy pipeline used by the NER augmentation.
nlp = spacy.load('en_core_web_sm')

# SECURITY: never embed the API key in the notebook. It is read from the
# GROQ_API_KEY environment variable instead (on Colab, populate it from the
# secrets panel before running this cell). Any key that was previously
# committed with this notebook must be considered leaked and be rotated.
groq_api_key = os.environ.get('GROQ_API_KEY')
if not groq_api_key:
    raise ValueError('Missing Groq API key: set the GROQ_API_KEY environment variable')
groq_client = Groq(api_key=groq_api_key)

DATASET_ROOT = '/content/drive/MyDrive/ADSP Project/datasets/'
SPLIT_RATIOS = (0.8, 0.1, 0.1)  # (train, validation, test)
GROQ_MODEL = 'llama3-8b-8192'

if not os.path.exists(DATASET_ROOT):
    raise ValueError('Invalid data root')
# Compare with a tolerance: summing float ratios is not guaranteed to give exactly 1.0.
if any(split_ratio <= 0 for split_ratio in SPLIT_RATIOS) or abs(sum(SPLIT_RATIOS) - 1.0) > 1e-9:
    raise ValueError('Invalid split ratio')
# Fails early if the model id is unknown to the API; the returned description is the cell output.
groq_client.models.retrieve(GROQ_MODEL)

Mounted at /content/drive


Model(id='llama3-8b-8192', created=1693721698, object='model', owned_by='Meta', active=True, context_window=8192, public_apps=None)

# Dataset Class

In [4]:
class Dataset:
    """Passages, queries, their relations, augmentations and learning splits.

    Every instance attribute whose name does not start with ``_`` is part of
    the persisted state: ``save`` pickles exactly those attributes and
    ``__init__`` restores them. Queries and passages are referenced by their
    position in ``query_list`` / ``passage_list``; ``relation_list[q]`` is the
    set of passage indices related to query ``q``.
    """

    def __init__(self, file_name:str, dataset_name:str=None) -> None:
        """Load ``<DATASET_ROOT>/<file_name>.pickle`` if it exists, otherwise start empty.

        Args:
            file_name: Base name (without extension) of the pickle file.
            dataset_name: 'ms-marco' or 'hotpot-qa'. Only validated when no
                pickle exists; a restored pickle overrides it.

        Raises:
            ValueError: No pickle exists and ``dataset_name`` is missing or unknown.
        """
        self._file_name:str = file_name
        self._stat_dict = {
            'passages': dict[str, int](),
            'queries': dict[str, int](),
            'augmentations': dict[str, int](),
            'relations': dict[str, int](),
            'learning': dict[str, int]()
        }
        self.dataset_name = dataset_name
        self.passage_list = list[str]()
        self.query_list = list[str]()
        self.passage_augmentation_list = list[dict[str, dict[str, int]]]()
        self.query_augmentation_list = list[dict[str, dict[str, int]]]()
        self.augmentation_dict = dict[str, set[int]]()
        self.relation_list = list[set[int]]()
        self.train_set = set[int]()
        self.validation_set = set[int]()
        self.test_set = set[int]()
        potential_dataset_path = os.path.join(DATASET_ROOT, f'{file_name}.pickle')
        if os.path.exists(potential_dataset_path):
            # NOTE(security): pickle.load can execute arbitrary code; only open files written by save().
            with open(potential_dataset_path, 'rb') as file_handle:
                public_dataset = pickle.load(file_handle)
            for attribute in public_dataset:
                setattr(self, attribute, public_dataset[attribute])
        elif dataset_name is None:
            raise ValueError('Invalid file name')
        elif dataset_name not in {'ms-marco', 'hotpot-qa'}:
            raise ValueError('Invalid dataset name')
        self._update_stat()

    def __str__(self) -> str:
        """Return the file/dataset names followed by one line per non-empty statistic group."""
        output = f'\nnames -> file: {self._file_name}, dataset: {self.dataset_name}\n'
        for stat, values in self._stat_dict.items():
            if len(values) == 0:
                continue
            output += f'{stat} -> ' + ', '.join(f'{attribute}: {value}' for attribute, value in values.items()) + '\n'
        return output

    def _shuffle(self) -> None:
        """Randomly permute passages and queries in place, remapping every stored index.

        A shuffle map sends old position ``i`` to new position ``shuffle_map[i]``.
        Lists are rearranged with it and index sets are translated through it,
        so every query keeps the same passages, augmentations and split.
        """
        def shuffle_element_list(element_list:list[Any], shuffle_map:list[int]) -> None:
            snapshot = element_list.copy()
            for old_position, element in enumerate(snapshot):
                element_list[shuffle_map[old_position]] = element
        def shuffle_index_set(index_set:set[Any], shuffle_map:list[int]) -> None:
            old_indices = index_set.copy()
            index_set.clear()
            index_set.update(shuffle_map[old_index] for old_index in old_indices)
        passage_shuffle_map = np.random.permutation(len(self.passage_list)).tolist()
        query_shuffle_map = np.random.permutation(len(self.query_list)).tolist()
        shuffle_element_list(self.passage_list, passage_shuffle_map)
        shuffle_element_list(self.query_list, query_shuffle_map)
        shuffle_element_list(self.passage_augmentation_list, passage_shuffle_map)
        shuffle_element_list(self.query_augmentation_list, query_shuffle_map)
        for augmented_query_set in self.augmentation_dict.values():
            shuffle_index_set(augmented_query_set, query_shuffle_map)
        # relation_list is positioned by query, while each of its sets holds passage indices.
        shuffle_element_list(self.relation_list, query_shuffle_map)
        for relation_set in self.relation_list:
            shuffle_index_set(relation_set, passage_shuffle_map)
        shuffle_index_set(self.train_set, query_shuffle_map)
        shuffle_index_set(self.validation_set, query_shuffle_map)
        shuffle_index_set(self.test_set, query_shuffle_map)

    def _update_stat(self) -> None:
        """Recompute the statistics kept in ``_stat_dict`` from the current state."""
        def count_quantity(key:str, suffix:str, target) -> None:
            # An empty suffix yields the key 'total_'; kept as-is so existing stat.json entries stay comparable.
            self._stat_dict[key][f'total_{suffix}'] = len(target)
        def compute_stat(key:str, suffix:str, target_list:list[Iterable]) -> None:
            # min/max are undefined for an empty list, so nothing is recorded in that case.
            if len(target_list) > 0:
                sizes = [len(iterable) for iterable in target_list]
                self._stat_dict[key][f'minimum_{suffix}'] = min(sizes)
                self._stat_dict[key][f'average_{suffix}'] = round(sum(sizes) / len(sizes))
                self._stat_dict[key][f'maximum_{suffix}'] = max(sizes)
        count_quantity('passages', '', self.passage_list)
        compute_stat('passages', 'length', self.passage_list)
        count_quantity('queries', '', self.query_list)
        compute_stat('queries', 'length', self.query_list)
        for augmentation_name, augmented_query_set in self.augmentation_dict.items():
            count_quantity('augmentations', f'queries_augmented_with_{augmentation_name}', augmented_query_set)
        compute_stat('relations', 'related_passages', self.relation_list)
        count_quantity('learning', 'queries_in_train_set', self.train_set)
        count_quantity('learning', 'queries_in_validation_set', self.validation_set)
        count_quantity('learning', 'queries_in_test_set', self.test_set)

    def save(self) -> None:
        """Pickle the public attributes and record the statistics in ``stat.json``.

        Both files live under ``DATASET_ROOT``. ``stat.json`` is shared by all
        datasets and keyed by file name; this dataset's entry is replaced.
        """
        self._update_stat()
        dataset_path = os.path.join(DATASET_ROOT, f'{self._file_name}.pickle')
        public_dataset = {attribute: getattr(self, attribute) for attribute in self.__dict__ if not attribute.startswith('_')}
        with open(dataset_path, 'wb') as file_handle:
            pickle.dump(public_dataset, file_handle, protocol=pickle.HIGHEST_PROTOCOL)
        stat_path = os.path.join(DATASET_ROOT, 'stat.json')
        if os.path.exists(stat_path):
            with open(stat_path, 'r') as file_handle:
                full_stat = json.load(file_handle)
        else:
            full_stat = dict[str, dict[str, dict[str, int]]]()
        full_stat[self._file_name] = {'dataset_name': self.dataset_name} | self._stat_dict
        with open(stat_path, 'w') as file_handle:
            json.dump(full_stat, file_handle, indent=4)

    @staticmethod
    def _ms_marco_point(point) -> tuple[list[str], str]:
        """Return ``(passages, query)`` of one MS MARCO record."""
        return list(point['passages']['passage_text']), point['query']

    @staticmethod
    def _hotpot_qa_point(point) -> tuple[list[str], str]:
        """Return ``(passages, query)`` of one HotpotQA record.

        Each passage is a context title followed by its sentences, one per line.
        """
        context = point['context']
        passages = [
            '\n'.join([title, *context['sentences'][index]])
            for index, title in enumerate(context['title'])
        ]
        return passages, point['question']

    def _expand_split(self, split_name:str, split_ratio:float, total_queries:int, load_arguments:tuple, extract_point) -> None:
        """Stream more points of one split and append them to the dataset.

        Args:
            split_name: 'train', 'validation' or 'test'; selects both the
                remote split and the local ``<split_name>_set``.
            split_ratio: Share of ``total_queries`` assigned to this split.
            total_queries: Number of queries requested over all splits.
            load_arguments: Positional arguments for ``datasets.load_dataset``.
            extract_point: Callable mapping a record to ``(passages, query)``.
        """
        extra_split_size = int(total_queries * split_ratio)
        target_learning_set:set[int] = getattr(self, split_name + '_set')
        with tqdm(total=extra_split_size, desc=f'Downloading {split_name} set from {self.dataset_name} dataset') as pbar:
            stream = datasets.load_dataset(*load_arguments, split=split_name, streaming=True)
            # Skip as many records as this split already holds so repeated calls fetch new points.
            for point in stream.skip(len(target_learning_set)).take(extra_split_size):
                extra_passage_list, extra_query = extract_point(point)
                first_new_passage_index = len(self.passage_list)
                current_query_index = len(self.query_list)
                for passage in extra_passage_list:
                    self.passage_list.append(passage)
                    self.passage_augmentation_list.append(dict[str, dict[str, int]]())
                self.query_list.append(extra_query)
                self.query_augmentation_list.append(dict[str, dict[str, int]]())
                self.relation_list.append(set(range(first_new_passage_index, len(self.passage_list))))
                target_learning_set.add(current_query_index)
                pbar.update(1)

    def add_points(self, total_queries:int) -> None:
        """Download ``total_queries`` more queries (with their passages), then shuffle.

        The queries are distributed over train/validation/test according to
        ``SPLIT_RATIOS``; each share is truncated to an integer.

        Raises:
            ValueError: ``dataset_name`` is not a supported dataset (previously
                this case silently downloaded nothing).
        """
        sources = {
            'ms-marco': (('microsoft/ms_marco', 'v1.1'), self._ms_marco_point),
            'hotpot-qa': (('hotpot_qa', 'fullwiki'), self._hotpot_qa_point),
        }
        if self.dataset_name not in sources:
            raise ValueError('Invalid dataset name')
        load_arguments, extract_point = sources[self.dataset_name]
        for split_name, split_ratio in zip(['train', 'validation', 'test'], SPLIT_RATIOS):
            self._expand_split(split_name, split_ratio, total_queries, load_arguments, extract_point)
        self._shuffle()
        self._update_stat()

    def _select_queries(self, augmentation_name:str, total_queries:int=None) -> tuple[set[int], int]:
        """Pick random queries that do not have ``augmentation_name`` yet.

        Registers ``augmentation_name`` in ``augmentation_dict`` when missing.

        Args:
            augmentation_name: Key of the augmentation in ``augmentation_dict``.
            total_queries: Maximum number of queries to pick; ``None`` picks
                every query not augmented yet. Zero or a negative value picks
                none (the limit used to be tested only after the first pick,
                so 0 selected every query).

        Returns:
            The chosen query indices and the total number of passages related to them.
        """
        already_augmented = self.augmentation_dict.setdefault(augmentation_name, set[int]())
        chosen_query_index_set = set[int]()
        for query_index in np.random.permutation(len(self.query_list)).tolist():
            if total_queries is not None and len(chosen_query_index_set) >= total_queries:
                break
            if query_index not in already_augmented:
                chosen_query_index_set.add(query_index)
        total_texts = sum(len(self.relation_list[query_index]) for query_index in chosen_query_index_set)
        return chosen_query_index_set, total_texts

    def augment_with_ner(self, total_queries:int=None) -> None:
        """Attach spaCy named entities to queries and to their related passages.

        Entities are stored per text as ``{'spacy_entity_<label>': {entity: count}}``
        with labels and entity texts lower-cased and stripped.

        Args:
            total_queries: Maximum number of not-yet-augmented queries to
                process; ``None`` processes all of them.
        """
        def extract_ner(text:str) -> dict[str, dict[str, int]]:
            ner_dict = dict[str, dict[str, int]]()
            for ent in nlp(text).ents:
                key = 'spacy_entity_' + ent.label_.lower().strip()
                entity = ent.text.lower().strip()
                entity_counts = ner_dict.setdefault(key, dict[str, int]())
                entity_counts[entity] = entity_counts.get(entity, 0) + 1
            return ner_dict
        chosen_query_index_set, total_texts = self._select_queries('spacy_ner', total_queries)
        with tqdm(total=total_texts, desc='Augmenting with Spacy NER') as pbar:
            for query_index in chosen_query_index_set:
                # Extract everything for this query first and only then store it,
                # so a failure midway leaves the query entirely un-augmented.
                query_ner_dict = extract_ner(self.query_list[query_index])
                complex_passage_ner_dict = dict[int, dict[str, dict[str, int]]]()
                for passage_index in self.relation_list[query_index]:
                    complex_passage_ner_dict[passage_index] = extract_ner(self.passage_list[passage_index])
                    pbar.update(1)
                self.query_augmentation_list[query_index].update(query_ner_dict)
                for passage_index, passage_ner_dict in complex_passage_ner_dict.items():
                    self.passage_augmentation_list[passage_index].update(passage_ner_dict)
                self.augmentation_dict['spacy_ner'].add(query_index)
        self._update_stat()

    @staticmethod
    def _parse_string_list(response:str, field_name:str) -> set[str]:
        """Extract the JSON string list stored under ``field_name`` in a model response.

        Returns the lower-cased, stripped entries. The result is empty when
        the field is missing, is not valid JSON, or contains an entry that is
        not a non-blank string. Such a list used to yield a partial result
        that depended on set iteration order.
        """
        if not isinstance(response, str):
            return set[str]()
        matched_groups = re.search(rf'"{re.escape(field_name)}"\s*:\s*(\[[^\]]*\])', response)
        if matched_groups is None:
            return set[str]()
        try:
            raw_items = json.loads(matched_groups.group(1))
        except ValueError:  # json.JSONDecodeError is a ValueError
            return set[str]()
        parsed_items = set[str]()
        for raw_item in raw_items:
            if not isinstance(raw_item, str):
                return set[str]()
            cleaned_item = raw_item.lower().strip()
            if len(cleaned_item) == 0:
                return set[str]()
            parsed_items.add(cleaned_item)
        return parsed_items

    def augment_with_keyword_and_topic(self, total_queries:int=None) -> None:
        """Attach model-extracted keywords and topics to queries and their passages.

        A query is augmented only if keywords and topics could be extracted
        for the query itself and for every related passage; otherwise nothing
        is stored for it and it stays eligible for a later call.

        Args:
            total_queries: Maximum number of not-yet-augmented queries to
                process; ``None`` processes all of them.
        """
        def extract_keywords_and_topics(text:str) -> tuple[set[str], set[str]]:
            # NOTE(review): the dictionary template in the system prompt never closes its '{';
            # the prompt is left untouched so results stay comparable with earlier runs.
            chat_completion = groq_client.chat.completions.create(
                messages=[
                    {'role': 'system', 'content': (
                        'You are an AI assistant tasked with identifying keywords and topics from the given text.'
                        '\nYour task rules are as follows:'
                        '\n- The output is the following dictionary:'
                        '\n  {'
                        '\n      "keyword_list": [list of most important keywords],'
                        '\n      "topic_list": [list of most important topics]'
                        '\n- Your response must be the output dictionary in JSON format without any extra information.'
                    )},
                    {'role': 'user', 'content': (
                        'Here is the text:'
                        '\n' + text + ''
                    )}
                ],
                model = GROQ_MODEL
            )
            response = chat_completion.choices[0].message.content
            return self._parse_string_list(response, 'keyword_list'), self._parse_string_list(response, 'topic_list')
        augmentation_name = f'{GROQ_MODEL}_keyword_and_topic_extraction'
        keyword_key = f'{GROQ_MODEL}_keyword'
        topic_key = f'{GROQ_MODEL}_topic'
        chosen_query_index_set, total_texts = self._select_queries(augmentation_name, total_queries)
        with tqdm(total=total_texts, desc=f'Augmenting with {GROQ_MODEL} Keyword and Topic Extraction') as pbar:
            for query_index in chosen_query_index_set:
                related_passages = self.relation_list[query_index]
                query_keyword_set, query_topic_set = extract_keywords_and_topics(self.query_list[query_index])
                if len(query_keyword_set) == 0 or len(query_topic_set) == 0:
                    pbar.update(len(related_passages))
                    continue
                passage_keyword_and_topic_dict = dict[int, tuple[set[str], set[str]]]()
                processed_passages = 0
                for passage_index in related_passages:
                    passage_keyword_set, passage_topic_set = extract_keywords_and_topics(self.passage_list[passage_index])
                    processed_passages += 1
                    pbar.update(1)
                    if len(passage_keyword_set) == 0 or len(passage_topic_set) == 0:
                        break
                    passage_keyword_and_topic_dict[passage_index] = passage_keyword_set, passage_topic_set
                if len(passage_keyword_and_topic_dict) != len(related_passages):
                    # Advance by the passages that were never attempted; the failed one is already counted.
                    pbar.update(len(related_passages) - processed_passages)
                    continue
                self.query_augmentation_list[query_index][keyword_key] = dict.fromkeys(query_keyword_set, 1)
                self.query_augmentation_list[query_index][topic_key] = dict.fromkeys(query_topic_set, 1)
                for passage_index, (passage_keyword_set, passage_topic_set) in passage_keyword_and_topic_dict.items():
                    self.passage_augmentation_list[passage_index][keyword_key] = dict.fromkeys(passage_keyword_set, 1)
                    self.passage_augmentation_list[passage_index][topic_key] = dict.fromkeys(passage_topic_set, 1)
                self.augmentation_dict[augmentation_name].add(query_index)
        self._update_stat()

# Operation

## MS-MARCO

### With Augmentation

In [None]:
# Restore 'ms-marco-spacy-llama8b' from its pickle under DATASET_ROOT if present; otherwise start an empty MS MARCO dataset.
ms_marco_dataset = Dataset('ms-marco-spacy-llama8b', 'ms-marco')
print(ms_marco_dataset)

In [None]:
# Stream 100 more queries (divided over the splits by SPLIT_RATIOS) with their passages, then persist and show the statistics.
ms_marco_dataset.add_points(100)
ms_marco_dataset.save()
print(ms_marco_dataset)

In [None]:
# No limit is passed, so every query that has no spaCy NER augmentation yet is processed.
ms_marco_dataset.augment_with_ner()
ms_marco_dataset.save()
print(ms_marco_dataset)

In [None]:
# Keyword/topic extraction calls the Groq API once per text, so only up to 10 queries are processed per run.
ms_marco_dataset.augment_with_keyword_and_topic(10)
ms_marco_dataset.save()
print(ms_marco_dataset)

### Without Augmentation

In [None]:
# NOTE: this rebinds ms_marco_dataset, replacing the augmented dataset loaded in the section above.
ms_marco_dataset = Dataset('ms-marco-no-augmentation', 'ms-marco')
print(ms_marco_dataset)

In [None]:
# Stream 1000 more queries (divided over the splits by SPLIT_RATIOS) with their passages, then persist and show the statistics.
ms_marco_dataset.add_points(1000)
ms_marco_dataset.save()
print(ms_marco_dataset)

## Hotpot QA

### With Augmentation

In [None]:
# Restore 'hotpot-qa-spacy-llama8b' from its pickle under DATASET_ROOT if present; otherwise start an empty HotpotQA dataset.
hotpot_qa_dataset = Dataset('hotpot-qa-spacy-llama8b', 'hotpot-qa')
print(hotpot_qa_dataset)


names -> file: hotpot-qa-spacy-llama8b, dataset: hotpot-qa
passages -> total_: 0
queries -> total_: 0
learning -> total_queries_in_train_set: 0, total_queries_in_validation_set: 0, total_queries_in_test_set: 0



In [None]:
# Stream 100 more queries (divided over the splits by SPLIT_RATIOS) with their passages, then persist and show the statistics.
# The recorded output shows an interactive prompt asking to run the dataset's custom loading code.
hotpot_qa_dataset.add_points(100)
hotpot_qa_dataset.save()
print(hotpot_qa_dataset)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/9.19k [00:00<?, ?B/s]

hotpot_qa.py:   0%|          | 0.00/6.42k [00:00<?, ?B/s]

The repository for hotpot_qa contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/hotpot_qa.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


Downloading train set from hotpot-qa dataset: 100%|██████████| 80/80 [00:27<00:00,  2.94it/s]
Downloading validation set from hotpot-qa dataset: 100%|██████████| 10/10 [00:02<00:00,  3.62it/s]
Downloading test set from hotpot-qa dataset: 100%|██████████| 10/10 [00:01<00:00,  5.46it/s]



names -> file: hotpot-qa-spacy-llama8b, dataset: hotpot-qa
passages -> total_: 1000, minimum_length: 80, average_length: 598, maximum_length: 8307
queries -> total_: 100, minimum_length: 40, average_length: 111, maximum_length: 418
relations -> minimum_related_passages: 10, average_related_passages: 10, maximum_related_passages: 10
learning -> total_queries_in_train_set: 80, total_queries_in_validation_set: 10, total_queries_in_test_set: 10



In [None]:
# No limit is passed, so every query that has no spaCy NER augmentation yet is processed.
hotpot_qa_dataset.augment_with_ner()
hotpot_qa_dataset.save()
print(hotpot_qa_dataset)

Augmenting with Spacy NER: 100%|██████████| 1000/1000 [00:33<00:00, 29.84it/s]


names -> file: hotpot-qa-spacy-llama8b, dataset: hotpot-qa
passages -> total_: 1000, minimum_length: 80, average_length: 598, maximum_length: 8307
queries -> total_: 100, minimum_length: 40, average_length: 111, maximum_length: 418
augmentations -> total_queries_augmented_with_spacy_ner: 100
relations -> minimum_related_passages: 10, average_related_passages: 10, maximum_related_passages: 10
learning -> total_queries_in_train_set: 80, total_queries_in_validation_set: 10, total_queries_in_test_set: 10






In [None]:
# Keyword/topic extraction calls the Groq API once per text, so only up to 5 queries are processed per run.
hotpot_qa_dataset.augment_with_keyword_and_topic(5)
hotpot_qa_dataset.save()
print(hotpot_qa_dataset)

Augmenting with llama3-8b-8192 Keyword and Topic Extraction: 100%|██████████| 40/40 [00:38<00:00,  1.04it/s]


names -> file: hotpot-qa-spacy-llama8b, dataset: hotpot-qa
passages -> total_: 1000, minimum_length: 80, average_length: 598, maximum_length: 8307
queries -> total_: 100, minimum_length: 40, average_length: 111, maximum_length: 418
augmentations -> total_queries_augmented_with_spacy_ner: 100, total_queries_augmented_with_llama3-8b-8192_keyword_and_topic_extraction: 100
relations -> minimum_related_passages: 10, average_related_passages: 10, maximum_related_passages: 10
learning -> total_queries_in_train_set: 80, total_queries_in_validation_set: 10, total_queries_in_test_set: 10






### Without Augmentation

In [5]:
# NOTE: this rebinds hotpot_qa_dataset, replacing the augmented dataset loaded in the section above.
hotpot_qa_dataset = Dataset('hotpot-qa-no-augmentation', 'hotpot-qa')
print(hotpot_qa_dataset)


names -> file: hotpot-qa-no-augmentation, dataset: hotpot-qa
passages -> total_: 0
queries -> total_: 0
learning -> total_queries_in_train_set: 0, total_queries_in_validation_set: 0, total_queries_in_test_set: 0



In [6]:
# Stream 1000 more queries (divided over the splits by SPLIT_RATIOS) with their passages, then persist and show the statistics.
hotpot_qa_dataset.add_points(1000)
hotpot_qa_dataset.save()
print(hotpot_qa_dataset)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/9.19k [00:00<?, ?B/s]

hotpot_qa.py:   0%|          | 0.00/6.42k [00:00<?, ?B/s]

The repository for hotpot_qa contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/hotpot_qa.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


Downloading train set from hotpot-qa dataset: 100%|██████████| 800/800 [01:06<00:00, 11.95it/s]
Downloading validation set from hotpot-qa dataset: 100%|██████████| 100/100 [00:04<00:00, 24.56it/s]
Downloading test set from hotpot-qa dataset: 100%|██████████| 100/100 [00:02<00:00, 35.02it/s]



names -> file: hotpot-qa-no-augmentation, dataset: hotpot-qa
passages -> total_: 9913, minimum_length: 63, average_length: 567, maximum_length: 8307
queries -> total_: 1000, minimum_length: 32, average_length: 104, maximum_length: 542
relations -> minimum_related_passages: 1, average_related_passages: 10, maximum_related_passages: 10
learning -> total_queries_in_train_set: 800, total_queries_in_validation_set: 100, total_queries_in_test_set: 100



# Tests

## Shuffling

In [None]:
from typing import Iterable, Any

def shuffle_elemenet_list(element_list:list[Any], shuffle_map:list[int]) -> None:
    """Rearrange ``element_list`` in place so the element at position ``i`` ends up at ``shuffle_map[i]``."""
    snapshot = element_list.copy()
    for old_position, element in enumerate(snapshot):
        new_position = shuffle_map[old_position]
        element_list[new_position] = element

def shuffle_index_set(index_set:set[Any], shuffle_map:list[int]) -> None:
    """Replace, in place, every index ``i`` of ``index_set`` by ``shuffle_map[i]``."""
    old_indices = index_set.copy()
    index_set.clear()
    index_set.update(shuffle_map[old_index] for old_index in old_indices)

# Minimal check: bar holds positions into foo, so the elements it designates
# must be the same before and after both are shuffled with the same map.
foo = ['a', 'b', 'c', 'd', 'e']
bar = {4, 1}
print(foo)
print({foo[i] for i in bar})

# One random permutation drives both the list rearrangement and the index remapping.
shuffle_map = np.random.permutation(len(foo)).tolist()
shuffle_elemenet_list(foo, shuffle_map)
shuffle_index_set(bar, shuffle_map)
print(shuffle_map)
print(foo)
print({foo[i] for i in bar})

['a', 'b', 'c', 'd', 'e']
{'e', 'b'}
[4, 3, 0, 1, 2]
['c', 'd', 'e', 'b', 'a']
{'b', 'e'}


In [None]:
def print_shuffle_state(passage_list, passage_syntactic_dict, query_list, relation_list, train_set, augmentation_dict) -> None:
    """Print passages, per-passage data, queries, relations, training set and augmentations by name."""
    print('passage list:', passage_list)
    print('passage syntactic data:')
    for key in passage_syntactic_dict:
        print('  key:', key)
        for passage_index in range(len(passage_list)):
            print('    ', passage_list[passage_index], passage_syntactic_dict[key][passage_index])
    print('query_list:', query_list)
    print('relations:')
    for query_index in range(len(query_list)):
        print('  ', query_list[query_index], '->', {passage_list[passage_index] for passage_index in relation_list[query_index]})
    print('training set:', {query_list[query_index] for query_index in train_set})
    print('augmentations:')
    for augmentation_name in augmentation_dict:
        print('  ', augmentation_name, '->', {query_list[query_index] for query_index in augmentation_dict[augmentation_name]})

def name_based_view(passage_list, passage_syntactic_dict, query_list, relation_list, train_set, augmentation_dict) -> dict:
    """Return the same information keyed by passage/query names, which shuffling must not change."""
    return {
        'syntactic': {
            key: {passage_list[passage_index]: entries[passage_index] for passage_index in range(len(passage_list))}
            for key, entries in passage_syntactic_dict.items()
        },
        'relations': {
            query_list[query_index]: {passage_list[passage_index] for passage_index in relation_list[query_index]}
            for query_index in range(len(query_list))
        },
        'training': {query_list[query_index] for query_index in train_set},
        'augmentations': {
            augmentation_name: {query_list[query_index] for query_index in query_index_set}
            for augmentation_name, query_index_set in augmentation_dict.items()
        },
    }

print('----------')
passage_list = ['p0', 'p1', 'p2', 'p3', 'p4']
passage_syntactic_dict = {
    'k0': [{'e_000': 1, 'e_001': 1}, {}, {'e_020': 2}, {}, {'e_040': 2, 'e_041': 5, 'e_042': 1}],
}
query_list = ['q0', 'q1', 'q2', 'q3', 'q4']
# Queries without related passages use set(): a bare {} is an empty dict, not an empty set.
relation_list = [{2, 3}, {1}, set(), {0, 4}, set()]
train_set = {0, 3, 4}
augmentation_dict = {
    'a0': {0, 4},
    'a1': {1, 2}
}

print_shuffle_state(passage_list, passage_syntactic_dict, query_list, relation_list, train_set, augmentation_dict)
view_before_shuffle = name_based_view(passage_list, passage_syntactic_dict, query_list, relation_list, train_set, augmentation_dict)

print('----------')
passage_suffle_map = np.random.permutation(len(passage_list)).tolist()
query_shuffle_map = np.random.permutation(len(query_list)).tolist()

print('shuffle maps:')
print('  passage:', passage_suffle_map)
print('  query:', query_shuffle_map)


print('----------')
# Lists positioned by passage use the passage map, lists positioned by query use the query map;
# index sets are translated with the map of the kind of index they contain.
shuffle_elemenet_list(passage_list, passage_suffle_map)
for syntactic_list in passage_syntactic_dict.values():
    shuffle_elemenet_list(syntactic_list, passage_suffle_map)
shuffle_elemenet_list(query_list, query_shuffle_map)
shuffle_elemenet_list(relation_list, query_shuffle_map)
for relation_set in relation_list:
    shuffle_index_set(relation_set, passage_suffle_map)
shuffle_index_set(train_set, query_shuffle_map)
for augmentation_set in augmentation_dict.values():
    shuffle_index_set(augmentation_set, query_shuffle_map)

print_shuffle_state(passage_list, passage_syntactic_dict, query_list, relation_list, train_set, augmentation_dict)
view_after_shuffle = name_based_view(passage_list, passage_syntactic_dict, query_list, relation_list, train_set, augmentation_dict)
assert view_after_shuffle == view_before_shuffle, 'shuffling changed which passages/queries belong together'

----------
passage list: ['p0', 'p1', 'p2', 'p3', 'p4']
passage syntactic data:
  key: k0
     p0 {'e_000': 1, 'e_001': 1}
     p1 {}
     p2 {'e_020': 2}
     p3 {}
     p4 {'e_040': 2, 'e_041': 5, 'e_042': 1}
query_list: ['q0', 'q1', 'q2', 'q3', 'q4']
relations:
   q0 -> {'p2', 'p3'}
   q1 -> {'p1'}
   q2 -> set()
   q3 -> {'p4', 'p0'}
   q4 -> set()
training set: {'q4', 'q3', 'q0'}
augmentations:
   a0 -> {'q4', 'q0'}
   a1 -> {'q2', 'q1'}
----------
shuffle maps:
  passage: [3, 1, 0, 4, 2]
  query: [3, 2, 4, 1, 0]
----------
passage list: ['p2', 'p1', 'p4', 'p0', 'p3']
passage syntactic data:
  key: k0
     p2 {'e_020': 2}
     p1 {}
     p4 {'e_040': 2, 'e_041': 5, 'e_042': 1}
     p0 {'e_000': 1, 'e_001': 1}
     p3 {}
query_list: ['q4', 'q3', 'q1', 'q0', 'q2']
relations:
   q4 -> set()
   q3 -> {'p4', 'p0'}
   q1 -> {'p1'}
   q0 -> {'p2', 'p3'}
   q2 -> set()
training set: {'q0', 'q4', 'q3'}
augmentations:
   a0 -> {'q4', 'q0'}
   a1 -> {'q2', 'q1'}


## NER

In [None]:
def augment_with_ner(self, total_queries:int=None) -> None:
    """Standalone variant of ``Dataset.augment_with_ner`` used for experimenting in this cell.

    Attaches spaCy named entities to randomly chosen queries that have no
    'spacy_ner' augmentation yet, and to the passages related to them.
    Entities are stored per text as ``{'spacy_entity_<label>': {entity: count}}``.

    Args:
        self: The dataset object to augment (passed explicitly).
        total_queries: Maximum number of queries to process; ``None``
            processes every query not augmented yet. Zero or a negative value
            processes none (the limit used to be tested only after the first
            pick, so 0 selected every query).
    """
    augmentation_name = 'spacy_ner'
    def choose_queries() -> tuple[set[int], int]:
        already_augmented = self.augmentation_dict.setdefault(augmentation_name, set[int]())
        chosen_query_index_set = set[int]()
        for query_index in np.random.permutation(len(self.query_list)).tolist():
            if total_queries is not None and len(chosen_query_index_set) >= total_queries:
                break
            if query_index not in already_augmented:
                chosen_query_index_set.add(query_index)
        total_texts = sum(len(self.relation_list[query_index]) for query_index in chosen_query_index_set)
        return chosen_query_index_set, total_texts
    def extract_ner(text:str) -> dict[str, dict[str, int]]:
        ner_dict = dict[str, dict[str, int]]()
        for ent in nlp(text).ents:
            key = 'spacy_entity_' + ent.label_.lower().strip()
            entity = ent.text.lower().strip()
            entity_counts = ner_dict.setdefault(key, dict[str, int]())
            entity_counts[entity] = entity_counts.get(entity, 0) + 1
        return ner_dict
    chosen_query_index_set, total_texts = choose_queries()
    with tqdm(total=total_texts, desc='Augmenting with Spacy NER') as pbar:
        for query_index in chosen_query_index_set:
            # Extract everything for this query first and only then store it,
            # so a failure midway leaves the query entirely un-augmented.
            query_ner_dict = extract_ner(self.query_list[query_index])
            complex_passage_ner_dict = dict[int, dict[str, dict[str, int]]]()
            for passage_index in self.relation_list[query_index]:
                complex_passage_ner_dict[passage_index] = extract_ner(self.passage_list[passage_index])
                pbar.update(1)
            self.query_augmentation_list[query_index].update(query_ner_dict)
            for passage_index, passage_ner_dict in complex_passage_ner_dict.items():
                self.passage_augmentation_list[passage_index].update(passage_ner_dict)
            self.augmentation_dict[augmentation_name].add(query_index)
    self._update_stat()

# Try the standalone function on up to 5 not-yet-augmented queries; this mutates ms_marco_dataset in place.
augment_with_ner(ms_marco_dataset, 5)

In [None]:
# Show which queries each augmentation covers, then for every training query print its text and
# augmentation, followed by the first 50 characters and the augmentation of each related passage.
print(ms_marco_dataset.augmentation_dict)
for query_index in ms_marco_dataset.train_set:
    print(query_index, ms_marco_dataset.query_list[query_index], ms_marco_dataset.query_augmentation_list[query_index])
    for passage_index in ms_marco_dataset.relation_list[query_index]:
        print('  ', passage_index, ms_marco_dataset.passage_list[passage_index][:50], ms_marco_dataset.passage_augmentation_list[passage_index])

{'ner': {0, 3, 5, 6, 9}}
0 was ronald reagan a democrat {'entity_person': {'ronald reagan': 1}, 'entity_norp': {'democrat': 1}}
   65 Ronald Reagan (1911-2004), a former actor and Cali {'entity_person': {'ronald reagan': 2, 'walter mondale': 1, 'geraldine ferraro': 1}, 'entity_date': {'1911-2004': 1, '1981': 1, '1989': 1, 'his 20s': 1, '1967': 1, '1975': 1, 'november 1984': 1}, 'entity_gpe': {'california': 2, 'u.s.': 2, 'illinois': 1, 'hollywood': 1}, 'entity_ordinal': {'40th': 1, 'first': 1}, 'entity_norp': {'republican': 1}}
   34 When Reagan was a 'liberal Democrat'. In 1948, a v {'entity_person': {'reagan': 1, 'ronald reagan': 1, 'harry truman': 1}, 'entity_norp': {'democrat': 2}, 'entity_date': {'1948': 1}}
   73 In his younger years, Ronald Reagan was a member o {'entity_date': {'years': 1, 'the early 1960s': 1, 'november 1984': 1}, 'entity_person': {'ronald reagan': 2, 'walter mondale': 1, 'geraldine ferraro': 1}, 'entity_org': {'the democratic party': 1}, 'entity_norp': {'democ

## Keyword and Topic

In [None]:
def augment_with_keyword_and_topic(self:Dataset, total_queries:int) -> None:
    """Augment randomly chosen queries and their related passages with LLM-extracted keywords and topics.

    Up to `total_queries` query indices that are not yet recorded under
    `self.augmentation_dict[f'{GROQ_MODEL}_keyword_and_topic_extraction']` are picked at random.
    For each one, the query text and every passage in `self.relation_list[query_index]` are sent
    to the Groq chat model. Results are stored only when the query and all of its passages
    yielded at least one keyword and one topic; otherwise the query is skipped and stays
    eligible for a later run.

    Args:
        self: Dataset whose `query_augmentation_list`, `passage_augmentation_list` and
            `augmentation_dict` are updated in place.
        total_queries: Maximum number of queries to augment; values <= 0 select nothing.

    Returns:
        None. `self._update_stat()` is called once at the end.
    """
    augmentation_name = f'{GROQ_MODEL}_keyword_and_topic_extraction'
    keyword_key = f'{GROQ_MODEL}_keyword'
    topic_key = f'{GROQ_MODEL}_topic'

    def __initialize(augmentation_name:str) -> tuple[set[int], int]:
        """Pick the query indices to process and count the passages attached to them."""
        if augmentation_name not in self.augmentation_dict:
            self.augmentation_dict[augmentation_name] = set[int]()
        chosen_query_index_set = set[int]()
        for query_index in np.random.permutation(len(self.query_list)).tolist():
            # Check the quota before adding, so total_queries <= 0 selects nothing
            # instead of running through every query.
            if len(chosen_query_index_set) >= total_queries:
                break
            if query_index not in self.augmentation_dict[augmentation_name]:
                chosen_query_index_set.add(query_index)
        total_texts = sum(len(self.relation_list[query_index]) for query_index in chosen_query_index_set)
        return chosen_query_index_set, total_texts

    def __parse_string_list(response:str, field_name:str) -> list[str]:
        """Extract `"<field_name>": [...]` from the model response as normalized, de-duplicated strings.

        Returns an empty list when the field is missing, is not valid JSON, or contains any
        non-string entry (all-or-nothing, so a malformed answer never yields a partial list).
        """
        matched_groups = re.search(r'\"' + re.escape(field_name) + r'\"\s*:\s*(\[[^\]]*\])', response)
        if matched_groups is None:
            return []
        try:
            candidate_list = json.loads(matched_groups.group(1))
        except json.JSONDecodeError:
            return []
        if not all(isinstance(candidate, str) for candidate in candidate_list):
            return []
        normalized_candidates = (candidate.lower().strip() for candidate in candidate_list)
        # dict.fromkeys de-duplicates while keeping the model's ordering; blank entries carry no information.
        return [candidate for candidate in dict.fromkeys(normalized_candidates) if candidate]

    def __extract_keywords_and_topics(text:str) -> tuple[list[str], list[str]]:
        """Ask the Groq model for the keywords and topics of `text`; either list may be empty on a bad answer."""
        chat_completion = groq_client.chat.completions.create(
            messages=[
                {'role': 'system', 'content': (
                    'You are an AI assistant tasked with identifying keywords and topics from the given text.'
                    '\nYour task rules are as follows:'
                    '\n- The output is the following dictionary:'
                    '\n  {'
                    '\n      "keyword_list": [list of most important keywords],'
                    '\n      "topic_list": [list of most important topics]'
                    '\n  }'
                    '\n- Your response must be the output dictionary in JSON format without any extra information.'
                )},
                {'role': 'user', 'content': 'Here is the text:\n' + text}
            ],
            model = GROQ_MODEL
        )
        # The message content can be None; treat that like an unparsable answer.
        response = chat_completion.choices[0].message.content or ''
        return __parse_string_list(response, 'keyword_list'), __parse_string_list(response, 'topic_list')

    chosen_query_index_set, total_texts = __initialize(augmentation_name)
    with tqdm(total=total_texts, desc=f'Augmenting with {GROQ_MODEL} Keyword and Topic Extraction') as pbar:
        for query_index in chosen_query_index_set:
            query_keyword_list, query_topic_list = __extract_keywords_and_topics(self.query_list[query_index])
            passage_keyword_and_topic_dict = dict[int, tuple[list[str], list[str]]]()
            for passage_index in self.relation_list[query_index]:
                passage_keyword_and_topic_dict[passage_index] = __extract_keywords_and_topics(self.passage_list[passage_index])
                pbar.update(1)
            # Commit only complete results: the query and every related passage need keywords and topics.
            query_is_incomplete = len(query_keyword_list) == 0 or len(query_topic_list) == 0
            any_passage_is_incomplete = any(
                len(passage_keyword_list) == 0 or len(passage_topic_list) == 0
                for passage_keyword_list, passage_topic_list in passage_keyword_and_topic_dict.values()
            )
            if query_is_incomplete or any_passage_is_incomplete:
                continue
            self.query_augmentation_list[query_index][keyword_key] = {keyword: 1 for keyword in query_keyword_list}
            self.query_augmentation_list[query_index][topic_key] = {topic: 1 for topic in query_topic_list}
            for passage_index, (passage_keyword_list, passage_topic_list) in passage_keyword_and_topic_dict.items():
                self.passage_augmentation_list[passage_index][keyword_key] = {keyword: 1 for keyword in passage_keyword_list}
                self.passage_augmentation_list[passage_index][topic_key] = {topic: 1 for topic in passage_topic_list}
            self.augmentation_dict[augmentation_name].add(query_index)
    self._update_stat()

# Augment up to 2 randomly chosen, not-yet-augmented queries and their related passages.
# Each run calls the Groq chat API once per query and once per related passage.
augment_with_keyword_and_topic(ms_marco_dataset, 2)

Augmenting with llama3-8b-8192 Keyword and Topic Extraction: 100%|██████████| 17/17 [00:04<00:00,  4.23it/s]


In [None]:
# Show the augmentation bookkeeping, then every query index in train_set together with
# its related passages and the augmentations currently stored for each of them.
print(ms_marco_dataset.augmentation_dict)
for query_index in ms_marco_dataset.train_set:
    print(query_index, ms_marco_dataset.query_list[query_index], ms_marco_dataset.query_augmentation_list[query_index])
    for passage_index in ms_marco_dataset.relation_list[query_index]:
        # Flatten newlines in the 50-character preview so each passage stays on one output line.
        passage_preview = ms_marco_dataset.passage_list[passage_index][:50].replace('\n', ' ').replace('  ', ' ')
        print('  ', passage_index, passage_preview, ms_marco_dataset.passage_augmentation_list[passage_index])

{'ner': {0, 3, 5, 6, 9}, 'keyword_and_topic': {4, 6}, 'llama3-8b-8192_keyword_and_topic_extraction': {8, 6}}
0 was ronald reagan a democrat {'entity_person': {'ronald reagan': 1}, 'entity_norp': {'democrat': 1}}
   65 Ronald Reagan (1911-2004), a former actor and Cali {'entity_person': {'ronald reagan': 2, 'walter mondale': 1, 'geraldine ferraro': 1}, 'entity_date': {'1911-2004': 1, '1981': 1, '1989': 1, 'his 20s': 1, '1967': 1, '1975': 1, 'november 1984': 1}, 'entity_gpe': {'california': 2, 'u.s.': 2, 'illinois': 1, 'hollywood': 1}, 'entity_ordinal': {'40th': 1, 'first': 1}, 'entity_norp': {'republican': 1}}
   34 When Reagan was a 'liberal Democrat'. In 1948, a v {'entity_person': {'reagan': 1, 'ronald reagan': 1, 'harry truman': 1}, 'entity_norp': {'democrat': 2}, 'entity_date': {'1948': 1}}
   73 In his younger years, Ronald Reagan was a member o {'entity_date': {'years': 1, 'the early 1960s': 1, 'november 1984': 1}, 'entity_person': {'ronald reagan': 2, 'walter mondale': 1, 'geral

## Overall

In [None]:
# Overall view: augmentation bookkeeping first, then each query index in train_set with
# its stored augmentations, followed by a one-line preview of every related passage.
print(ms_marco_dataset.augmentation_dict)
for query_index in ms_marco_dataset.train_set:
    query_text = ms_marco_dataset.query_list[query_index]
    query_augmentation = ms_marco_dataset.query_augmentation_list[query_index]
    print(query_index, query_text, query_augmentation)
    for passage_index in ms_marco_dataset.relation_list[query_index]:
        # First 50 characters only, with newlines and double spaces replaced by single spaces.
        passage_preview = ms_marco_dataset.passage_list[passage_index][:50]
        passage_preview = passage_preview.replace('\n', ' ').replace('  ', ' ')
        print('  ', passage_index, passage_preview, ms_marco_dataset.passage_augmentation_list[passage_index])

In [None]:
# List each augmentation name with the query indices recorded for it,
# then the augmentation dictionary stored for every query.
for augmentation_name, augmented_query_indices in ms_marco_dataset.augmentation_dict.items():
    print(augmentation_name, augmented_query_indices)
query_count = len(ms_marco_dataset.query_list)
for i in range(query_count):
    print(i, ms_marco_dataset.query_augmentation_list[i])

In [None]:
# Report which queries, and then which passages, have no Groq keyword / topic augmentation yet.
# The keys are derived from GROQ_MODEL so they always match the ones written by
# augment_with_keyword_and_topic, even if the configured model changes.
keyword_key = f'{GROQ_MODEL}_keyword'
topic_key = f'{GROQ_MODEL}_topic'

def _indices_missing_key(augmentation_list, total_items:int, key:str) -> set[int]:
    """Return the indices in range(total_items) whose augmentation dict has no `key` entry."""
    return {index for index in range(total_items) if key not in augmentation_list[index]}

# Queries without keyword / topic augmentation.
keyword_empty_set = _indices_missing_key(ms_marco_dataset.query_augmentation_list, len(ms_marco_dataset.query_list), keyword_key)
topic_empty_set = _indices_missing_key(ms_marco_dataset.query_augmentation_list, len(ms_marco_dataset.query_list), topic_key)
print(keyword_empty_set)
print(topic_empty_set)

# Passages without keyword / topic augmentation (same variable names are rebound, as before).
keyword_empty_set = _indices_missing_key(ms_marco_dataset.passage_augmentation_list, len(ms_marco_dataset.passage_list), keyword_key)
topic_empty_set = _indices_missing_key(ms_marco_dataset.passage_augmentation_list, len(ms_marco_dataset.passage_list), topic_key)
print(keyword_empty_set)
print(topic_empty_set)