# Notebook Initialization

In [None]:
# Install third-party dependencies via Colab shell escapes: Hugging Face `datasets`
# (streams MS MARCO / HotpotQA), the Groq client, and spaCy plus its small English
# pipeline (used for NER).
!pip install datasets
!pip install groq
!pip install spacy
!python -m spacy download en_core_web_sm

In [None]:
import json
import math
import os
import pickle
import re
from typing import Any, Callable, Iterable, Optional

import datasets
import numpy as np
import pandas as pd
import spacy
from google.colab import drive
from groq import Groq
from tqdm import tqdm

In [None]:
# Mount Google Drive (DATASET_ROOT lives there) and load the spaCy English pipeline.
drive.mount('/content/drive')
nlp = spacy.load('en_core_web_sm')
# Prefer the GROQ_API_KEY environment variable so the key does not have to be
# pasted into (and shared with) the notebook; the placeholder is the fallback.
groq_client = Groq(api_key=os.environ.get('GROQ_API_KEY', 'YOUR API KEY FROM GROQ'))

DATASET_ROOT = '/content/drive/MyDrive/ADSP Project/datasets/'
# (train, validation, test) fractions applied to every Dataset.add_points() call.
SPLIT_RATIOS = (0.8, 0.1, 0.1)
GROQ_MODEL = 'llama3-8b-8192'

if not os.path.exists(DATASET_ROOT):
    raise ValueError('Invalid data root')
# Exactly three positive ratios summing to 1. The sum is compared with a
# tolerance because float addition is inexact: sum((0.7, 0.2, 0.1)) == 0.9999999999999999.
if (
    len(SPLIT_RATIOS) != 3
    or any(split_ratio <= 0 for split_ratio in SPLIT_RATIOS)
    or not math.isclose(sum(SPLIT_RATIOS), 1.0)
):
    raise ValueError('Invalid split ratio')
# Sanity-check the model id / API key against the Groq API before doing any work.
groq_client.models.retrieve(GROQ_MODEL)

# Dataset Class

In [None]:
class Dataset:
    """Query/passage retrieval dataset built from MS MARCO or HotpotQA.

    ``save()`` pickles every public attribute to
    ``<DATASET_ROOT>/<file_name>.pickle`` and writes the private ``_stat_dict``
    summary into ``<DATASET_ROOT>/stat.json`` under ``file_name``.

    Public attributes (a text's position in its list is its index):
        dataset_name: 'ms-marco' or 'hotpot-qa'.
        passage_list, query_list: raw passage / query texts.
        passage_augmentation_list, query_augmentation_list: one dict per text,
            mapping an augmentation key to {term: count}.
        augmentation_dict: augmentation name -> indices of queries already
            augmented with it.
        relation_list: for each query index, the set of related passage indices.
        train_set, validation_set, test_set: query indices of each split.
    """

    # Hub split names, paired positionally with SPLIT_RATIOS.
    _SPLIT_NAMES = ('train', 'validation', 'test')

    def __init__(self, file_name: str, dataset_name: Optional[str] = None) -> None:
        """Load ``<DATASET_ROOT>/<file_name>.pickle`` if it exists, else start empty.

        Args:
            file_name: base name (no extension) of the pickle and of the
                stat.json entry.
            dataset_name: 'ms-marco' or 'hotpot-qa'. Required only when no
                pickle exists yet; a loaded pickle's stored value overrides it.

        Raises:
            ValueError: no pickle exists and ``dataset_name`` is missing or unknown.
        """
        self._file_name: str = file_name
        self._stat_dict: dict[str, dict[str, int]] = {
            'passages': {},
            'queries': {},
            'augmentations': {},
            'relations': {},
            'learning': {},
        }
        self.dataset_name = dataset_name
        self.passage_list: list[str] = []
        self.query_list: list[str] = []
        self.passage_augmentation_list: list[dict[str, dict[str, int]]] = []
        self.query_augmentation_list: list[dict[str, dict[str, int]]] = []
        self.augmentation_dict: dict[str, set[int]] = {}
        self.relation_list: list[set[int]] = []
        self.train_set: set[int] = set()
        self.validation_set: set[int] = set()
        self.test_set: set[int] = set()
        potential_dataset_path = os.path.join(DATASET_ROOT, f'{file_name}.pickle')
        if os.path.exists(potential_dataset_path):
            # NOTE: unpickling can execute arbitrary code; only load pickles
            # this notebook wrote itself via save().
            with open(potential_dataset_path, 'rb') as file_handle:
                public_dataset = pickle.load(file_handle)
            # Stored attributes (including dataset_name) replace the defaults above.
            for attribute, value in public_dataset.items():
                setattr(self, attribute, value)
        elif dataset_name is None:
            raise ValueError('Invalid file name')
        elif dataset_name not in {'ms-marco', 'hotpot-qa'}:
            raise ValueError('Invalid dataset name')
        self._update_stat()

    def __str__(self) -> str:
        """Render the file/dataset names followed by every non-empty statistics group."""
        output = f'\nnames -> file: {self._file_name}, dataset: {self.dataset_name}\n'
        for stat, values in self._stat_dict.items():
            if not values:
                continue
            output += f'{stat} -> ' + ', '.join(f'{attribute}: {value}' for attribute, value in values.items()) + '\n'
        return output

    def _shuffle(self) -> None:
        """Randomly permute passages and queries, remapping every stored index.

        Element ``i`` of a list moves to position ``permutation[i]``. Every
        stored index (relations, augmentation bookkeeping, learning splits) is
        remapped with the same permutation, so cross-references stay valid.
        Containers are mutated in place.
        """
        def _permute_list(element_list: list[Any], new_positions: list[int]) -> None:
            original = element_list.copy()
            for old_index, element in enumerate(original):
                element_list[new_positions[old_index]] = element

        def _remap_index_set(index_set: set[int], new_positions: list[int]) -> None:
            remapped = {new_positions[index] for index in index_set}
            index_set.clear()
            index_set.update(remapped)

        passage_shuffle_map = np.random.permutation(len(self.passage_list)).tolist()
        query_shuffle_map = np.random.permutation(len(self.query_list)).tolist()
        _permute_list(self.passage_list, passage_shuffle_map)
        _permute_list(self.query_list, query_shuffle_map)
        _permute_list(self.passage_augmentation_list, passage_shuffle_map)
        _permute_list(self.query_augmentation_list, query_shuffle_map)
        for augmented_queries in self.augmentation_dict.values():
            _remap_index_set(augmented_queries, query_shuffle_map)
        # relation_list is indexed by query but stores passage indices.
        _permute_list(self.relation_list, query_shuffle_map)
        for relation_set in self.relation_list:
            _remap_index_set(relation_set, passage_shuffle_map)
        _remap_index_set(self.train_set, query_shuffle_map)
        _remap_index_set(self.validation_set, query_shuffle_map)
        _remap_index_set(self.test_set, query_shuffle_map)

    def _update_stat(self) -> None:
        """Refresh ``_stat_dict`` (totals and min/average/max sizes) from the current data."""
        def _store_total(key: str, suffix: str, collection: Any) -> None:
            # An empty suffix yields the key 'total_'; kept unchanged so the
            # stat.json key names stay stable.
            self._stat_dict[key][f'total_{suffix}'] = len(collection)

        def _store_size_summary(key: str, suffix: str, collection: list[Any]) -> None:
            # Skipped when empty: min/max of nothing is undefined.
            if not collection:
                return
            sizes = [len(item) for item in collection]
            self._stat_dict[key][f'minimum_{suffix}'] = min(sizes)
            self._stat_dict[key][f'average_{suffix}'] = round(sum(sizes) / len(sizes))
            self._stat_dict[key][f'maximum_{suffix}'] = max(sizes)

        _store_total('passages', '', self.passage_list)
        _store_size_summary('passages', 'length', self.passage_list)
        _store_total('queries', '', self.query_list)
        _store_size_summary('queries', 'length', self.query_list)
        for augmentation_name, augmented_queries in self.augmentation_dict.items():
            _store_total('augmentations', f'queries_augmented_with_{augmentation_name}', augmented_queries)
        _store_size_summary('relations', 'related_passages', self.relation_list)
        _store_total('learning', 'queries_in_train_set', self.train_set)
        _store_total('learning', 'queries_in_validation_set', self.validation_set)
        _store_total('learning', 'queries_in_test_set', self.test_set)

    def save(self) -> None:
        """Pickle all public attributes and upsert this file's entry in stat.json."""
        self._update_stat()
        dataset_path = os.path.join(DATASET_ROOT, f'{self._file_name}.pickle')
        public_dataset = {attribute: value for attribute, value in vars(self).items() if not attribute.startswith('_')}
        with open(dataset_path, 'wb') as file_handle:
            pickle.dump(public_dataset, file_handle, protocol=pickle.HIGHEST_PROTOCOL)
        stat_path = os.path.join(DATASET_ROOT, 'stat.json')
        if os.path.exists(stat_path):
            with open(stat_path, 'r', encoding='utf-8') as file_handle:
                full_stat = json.load(file_handle)
        else:
            full_stat = {}
        full_stat[self._file_name] = {'dataset_name': self.dataset_name} | self._stat_dict
        with open(stat_path, 'w', encoding='utf-8') as file_handle:
            json.dump(full_stat, file_handle, indent=4)

    @staticmethod
    def _extract_ms_marco_point(point: dict[str, Any]) -> tuple[str, list[str]]:
        """Return ``(query, passages)`` from one MS MARCO v1.1 record."""
        return point['query'], list(point['passages']['passage_text'])

    @staticmethod
    def _extract_hotpot_qa_point(point: dict[str, Any]) -> tuple[str, list[str]]:
        """Return ``(question, passages)`` from one HotpotQA record.

        Each passage is a context title followed by its sentences, newline-joined.
        """
        context = point['context']
        passages = [
            '\n'.join([title, *sentences])
            for title, sentences in zip(context['title'], context['sentences'])
        ]
        return point['question'], passages

    def _expand_split(
        self,
        hub_path: str,
        hub_config: str,
        extract_point: Callable[[dict[str, Any]], tuple[str, list[str]]],
        split_name: str,
        extra_split_size: int,
    ) -> None:
        """Append up to ``extra_split_size`` not-yet-downloaded points of one hub split.

        Each point adds one query, its passages, empty augmentation dicts, and
        a relation from the query to exactly those passages. The query index
        is recorded in ``<split_name>_set``.
        """
        target_learning_set: set[int] = getattr(self, split_name + '_set')
        with tqdm(total=extra_split_size, desc=f'Downloading {split_name} set from {self.dataset_name} dataset') as pbar:
            stream = datasets.load_dataset(hub_path, hub_config, split=split_name, streaming=True)
            # The split already holds len(target_learning_set) points taken from the
            # head of the stream, so skip them. Assumes the hub stream order is
            # stable across calls — TODO confirm.
            for point in stream.skip(len(target_learning_set)).take(extra_split_size):
                extra_query, extra_passage_list = extract_point(point)
                first_passage_index = len(self.passage_list)
                query_index = len(self.query_list)
                for passage in extra_passage_list:
                    self.passage_list.append(passage)
                    self.passage_augmentation_list.append({})
                self.query_list.append(extra_query)
                self.query_augmentation_list.append({})
                self.relation_list.append(set(range(first_passage_index, len(self.passage_list))))
                target_learning_set.add(query_index)
                pbar.update(1)

    def add_points(self, total_queries: int) -> None:
        """Download new queries with their passages, then reshuffle everything.

        ``total_queries`` is divided across train/validation/test by
        SPLIT_RATIOS, each share rounded down. Each split skips the points it
        already holds, so repeated calls append new points instead of
        duplicates.
        """
        hub_sources = {
            'ms-marco': ('microsoft/ms_marco', 'v1.1', self._extract_ms_marco_point),
            'hotpot-qa': ('hotpot_qa', 'fullwiki', self._extract_hotpot_qa_point),
        }
        if self.dataset_name in hub_sources:
            hub_path, hub_config, extract_point = hub_sources[self.dataset_name]
            for split_name, split_ratio in zip(self._SPLIT_NAMES, SPLIT_RATIOS):
                # The epsilon stops float error (e.g. 100 * 0.29 == 28.999999999999996)
                # from losing a point; otherwise the share is floored.
                extra_split_size = int(total_queries * split_ratio + 1e-9)
                self._expand_split(hub_path, hub_config, extract_point, split_name, extra_split_size)
        self._shuffle()
        self._update_stat()

    def _choose_unaugmented_queries(self, augmentation_name: str, total_queries: Optional[int]) -> tuple[set[int], int]:
        """Randomly pick queries not yet augmented with ``augmentation_name``.

        Registers ``augmentation_name`` in ``augmentation_dict`` if it is new.

        Args:
            augmentation_name: key in ``augmentation_dict``.
            total_queries: maximum number of queries to pick. None picks every
                remaining query; 0 (or less) picks none.

        Returns:
            ``(chosen query indices, number of passages related to them)``.
            The count is used as the progress-bar length.
        """
        already_augmented = self.augmentation_dict.setdefault(augmentation_name, set())
        chosen_query_index_set: set[int] = set()
        for query_index in np.random.permutation(len(self.query_list)).tolist():
            # Check the limit before adding so total_queries=0 selects nothing.
            if total_queries is not None and len(chosen_query_index_set) >= total_queries:
                break
            if query_index not in already_augmented:
                chosen_query_index_set.add(query_index)
        total_texts = sum(len(self.relation_list[query_index]) for query_index in chosen_query_index_set)
        return chosen_query_index_set, total_texts

    @staticmethod
    def _extract_ner(text: str) -> dict[str, dict[str, int]]:
        """Map ``'spacy_entity_<label>'`` -> {lower-cased entity text: occurrence count} for ``text``."""
        ner_dict: dict[str, dict[str, int]] = {}
        for ent in nlp(text).ents:
            key = 'spacy_entity_' + ent.label_.lower().strip()
            entity = ent.text.lower().strip()
            entity_counts = ner_dict.setdefault(key, {})
            entity_counts[entity] = entity_counts.get(entity, 0) + 1
        return ner_dict

    def augment_with_ner(self, total_queries: Optional[int] = None) -> None:
        """Augment not-yet-tagged queries and their related passages with spaCy NER.

        Args:
            total_queries: number of queries to augment; None means all remaining.
        """
        chosen_query_index_set, total_texts = self._choose_unaugmented_queries('spacy_ner', total_queries)
        with tqdm(total=total_texts, desc='Augmenting with Spacy NER') as pbar:
            for query_index in chosen_query_index_set:
                # Extract everything first and write afterwards, so a failure
                # mid-way leaves this query's texts untouched.
                query_ner_dict = self._extract_ner(self.query_list[query_index])
                passage_ner_dicts: dict[int, dict[str, dict[str, int]]] = {}
                for passage_index in self.relation_list[query_index]:
                    passage_ner_dicts[passage_index] = self._extract_ner(self.passage_list[passage_index])
                    pbar.update(1)
                self.query_augmentation_list[query_index].update(query_ner_dict)
                for passage_index, passage_ner_dict in passage_ner_dicts.items():
                    self.passage_augmentation_list[passage_index].update(passage_ner_dict)
                self.augmentation_dict['spacy_ner'].add(query_index)
        self._update_stat()

    @staticmethod
    def _parse_json_string_list(response: Optional[str], field_name: str) -> set[str]:
        """Extract the JSON string list stored under ``"field_name"`` in an LLM response.

        Entries are lower-cased and stripped. The parse is all-or-nothing: an
        empty set is returned if the response is empty, the field is missing,
        the list is not valid JSON, or any entry is not a non-empty string.
        """
        if not response:
            return set()
        match = re.search(r'"' + re.escape(field_name) + r'"\s*:\s*(\[[^\]]*\])', response)
        if match is None:
            return set()
        try:
            candidates = json.loads(match.group(1))
        except json.JSONDecodeError:
            return set()
        cleaned: set[str] = set()
        for candidate in candidates:
            if not isinstance(candidate, str):
                return set()
            normalized = candidate.lower().strip()
            if not normalized:
                return set()
            cleaned.add(normalized)
        return cleaned

    @classmethod
    def _extract_keywords_and_topics(cls, text: str) -> tuple[set[str], set[str]]:
        """Ask the Groq model for ``text``'s keywords and topics.

        Returns:
            ``(keyword_set, topic_set)``; either set is empty if its part of
            the response could not be parsed. API errors propagate.
        """
        chat_completion = groq_client.chat.completions.create(
            messages=[
                {'role': 'system', 'content': (
                    'You are an AI assistant tasked with identifying keywords and topics from the given text.'
                    '\nYour task rules are as follows:'
                    '\n- The output is the following dictionary:'
                    '\n  {'
                    '\n      "keyword_list": [list of most important keywords],'
                    '\n      "topic_list": [list of most important topics]'
                    '\n  }'
                    '\n- Your response must be the output dictionary in JSON format without any extra information.'
                )},
                {'role': 'user', 'content': 'Here is the text:\n' + text},
            ],
            model=GROQ_MODEL,
        )
        response = chat_completion.choices[0].message.content
        return (
            cls._parse_json_string_list(response, 'keyword_list'),
            cls._parse_json_string_list(response, 'topic_list'),
        )

    def augment_with_keyword_and_topic(self, total_queries: Optional[int] = None) -> None:
        """Augment queries and their related passages with LLM keywords and topics.

        A query is committed only if it and every related passage yield
        non-empty keyword AND topic sets. Otherwise it is skipped and stays
        eligible for a later call. Terms are stored with count 1.

        Args:
            total_queries: number of queries to attempt; None means all remaining.
        """
        augmentation_name = f'{GROQ_MODEL}_keyword_and_topic_extraction'
        keyword_key = f'{GROQ_MODEL}_keyword'
        topic_key = f'{GROQ_MODEL}_topic'
        chosen_query_index_set, total_texts = self._choose_unaugmented_queries(augmentation_name, total_queries)
        with tqdm(total=total_texts, desc=f'Augmenting with {GROQ_MODEL} Keyword and Topic Extraction') as pbar:
            for query_index in chosen_query_index_set:
                related_passages = self.relation_list[query_index]
                query_keyword_set, query_topic_set = self._extract_keywords_and_topics(self.query_list[query_index])
                if not query_keyword_set or not query_topic_set:
                    pbar.update(len(related_passages))
                    continue
                passage_results: dict[int, tuple[set[str], set[str]]] = {}
                for passage_index in related_passages:
                    passage_keyword_set, passage_topic_set = self._extract_keywords_and_topics(self.passage_list[passage_index])
                    pbar.update(1)
                    if not passage_keyword_set or not passage_topic_set:
                        break
                    passage_results[passage_index] = (passage_keyword_set, passage_topic_set)
                if len(passage_results) < len(related_passages):
                    # A passage failed. It was already counted by the update
                    # above, so only advance past the passages never attempted.
                    pbar.update(len(related_passages) - len(passage_results) - 1)
                    continue
                self.query_augmentation_list[query_index][keyword_key] = dict.fromkeys(query_keyword_set, 1)
                self.query_augmentation_list[query_index][topic_key] = dict.fromkeys(query_topic_set, 1)
                for passage_index, (passage_keyword_set, passage_topic_set) in passage_results.items():
                    self.passage_augmentation_list[passage_index][keyword_key] = dict.fromkeys(passage_keyword_set, 1)
                    self.passage_augmentation_list[passage_index][topic_key] = dict.fromkeys(passage_topic_set, 1)
                self.augmentation_dict[augmentation_name].add(query_index)
        self._update_stat()

# Operation

## MS MARCO

### With Augmentation

In [None]:
# Open the augmented MS MARCO dataset (restored from DATASET_ROOT if previously saved).
ms_marco_dataset = Dataset(file_name='ms-marco-spacy-llama8b', dataset_name='ms-marco')
print(ms_marco_dataset)

In [None]:
# Download 100 more queries (divided by SPLIT_RATIOS), persist, and report stats.
ms_marco_dataset.add_points(total_queries=100)
ms_marco_dataset.save()
print(ms_marco_dataset)

In [None]:
# NER-augment every query not yet tagged, then persist and report stats.
ms_marco_dataset.augment_with_ner(total_queries=None)
ms_marco_dataset.save()
print(ms_marco_dataset)

In [None]:
# LLM keyword/topic extraction is slow and rate-limited, so only 10 queries per run.
ms_marco_dataset.augment_with_keyword_and_topic(total_queries=10)
ms_marco_dataset.save()
print(ms_marco_dataset)

### Without Augmentation

In [None]:
# Open the un-augmented MS MARCO baseline dataset (restored if previously saved).
ms_marco_dataset = Dataset(file_name='ms-marco-no-augmentation', dataset_name='ms-marco')
print(ms_marco_dataset)

In [None]:
# Download 1000 more queries for the baseline, persist, and report stats.
ms_marco_dataset.add_points(total_queries=1000)
ms_marco_dataset.save()
print(ms_marco_dataset)

## HotpotQA

### With Augmentation

In [None]:
# Open the augmented HotpotQA dataset (restored from DATASET_ROOT if previously saved).
hotpot_qa_dataset = Dataset(file_name='hotpot-qa-spacy-llama8b', dataset_name='hotpot-qa')
print(hotpot_qa_dataset)

In [None]:
# Download 100 more questions (divided by SPLIT_RATIOS), persist, and report stats.
hotpot_qa_dataset.add_points(total_queries=100)
hotpot_qa_dataset.save()
print(hotpot_qa_dataset)

In [None]:
# NER-augment every question not yet tagged, then persist and report stats.
hotpot_qa_dataset.augment_with_ner(total_queries=None)
hotpot_qa_dataset.save()
print(hotpot_qa_dataset)

In [None]:
# LLM keyword/topic extraction on 5 questions (each has many context passages).
hotpot_qa_dataset.augment_with_keyword_and_topic(total_queries=5)
hotpot_qa_dataset.save()
print(hotpot_qa_dataset)

### Without Augmentation

In [None]:
# Open the un-augmented HotpotQA baseline dataset (restored if previously saved).
hotpot_qa_dataset = Dataset(file_name='hotpot-qa-no-augmentation', dataset_name='hotpot-qa')
print(hotpot_qa_dataset)

In [None]:
# Download 1000 more questions for the baseline, persist, and report stats.
hotpot_qa_dataset.add_points(total_queries=1000)
hotpot_qa_dataset.save()
print(hotpot_qa_dataset)