In [3]:
import csv
# import glob
# import hashlib
import json
import os
import random
import re
import time
# import uuid
from typing import Any

# import backoff
import openai
# from httpx import HTTPStatusError
# from tqdm import tqdm

In [5]:
def load_json(path: str) -> Any:
    """
    load a json file from the specified path.

    the file is decoded as UTF-8 (the encoding JSON text uses) instead of the platform's
    locale default, so non-ASCII text loads the same way on every OS.

    Args:
        path (str): path of the json file.

    Returns:
        Any: the parsed json document.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)
    
def save_json(path: str, data: Any) -> None:
    """
    serialize `data` as json (4-space indentation) and write it to `path`, replacing any
    existing file.

    Args:
        path (str): destination file path.
        data (Any): a json-serializable object.
    """
    with open(path, 'w') as destination:
        json.dump(data, destination, indent=4)

def sortf(file_name):
    """
    sort key for the part files combined by concatenate_csv_files().

    returns the integer n from a '...part_<n>_rslt.csv' name; names that do not follow
    this project-specific convention get +inf so they sort after every numbered part.
    """
    found = re.search(r'part_(\d+)_rslt\.csv', file_name)
    if found is None:
        return float('inf')
    return int(found.group(1))

def concatenate_csv_files(root_dir):
    """
    concatenate all csv files in the specified directory and its subdirectories.

    for every directory under root_dir that contains csv files (other than 'results.csv'),
    the files are ordered with sortf() and written into a single 'results.csv' in that same
    directory: the header of the first non-empty file is written once and the header row of
    every other file is dropped. empty files are skipped. the resulting row count (header
    included) is printed. very specific to the file naming convention used in the project.

    Args:
        root_dir (str): the directory to walk recursively.
    """
    for subdir, _, files in os.walk(root_dir):
        csv_files = sorted((f for f in files if f.endswith('.csv') and f != 'results.csv'), key=sortf)
        if not csv_files:
            continue

        print(csv_files)
        results_file_path = os.path.join(subdir, 'results.csv')

        # newline='' lets the csv module handle line breaks inside quoted fields (prompts and
        # responses contain them); utf-8 matches the encoding the part csv files are written with.
        with open(results_file_path, 'w', newline='', encoding='utf-8') as results_file:
            writer = csv.writer(results_file)
            header_written = False

            for file_name in csv_files:
                file_path = os.path.join(subdir, file_name)
                with open(file_path, 'r', newline='', encoding='utf-8') as csv_file:
                    reader = csv.reader(csv_file)
                    header = next(reader, None)
                    if header is None:
                        # empty file: nothing to copy (a bare next() would raise StopIteration)
                        continue

                    if not header_written:
                        writer.writerow(header)
                        header_written = True

                    writer.writerows(reader)

        with open(results_file_path, 'r', newline='', encoding='utf-8') as results_file:
            row_count = sum(1 for _ in csv.reader(results_file))
        print(f'{results_file_path}: {row_count} rows')

In [6]:
class DevQuestions:
    class Question:
        def __init__(self, identifier: str, question: str, answer: str):
            """
            initialize a Question instance.

            Args:
                identifier (str): the identifier for the question.
                question (str): the question text.
                answer (str): the answer text.
            """
            self.id = identifier
            self.question = question
            self.answer = answer

    def __init__(self, path: str):
        """
        initialize the DevQuestions instance by loading questions from a JSON file.

        Args:
            path (str): the file path to the JSON file containing the questions.
        """
        with open(path, 'r') as f:
            data = json.load(f)

        self.questions = {}
        self.lut = []

        for question in data:
            self.lut.append(question['_id'])
            self.questions[question['_id']] = self.Question(
                question['_id'],
                question['question'],
                question['answer'])

    def __len__(self):
        """
        return the number of questions.

        Returns:
            int: the number of questions.
        """
        return len(self.questions)

    def __getitem__(self, key: int|str):
        """
        retrieve a question by index or id.

        Args:
            key (int | str): the idx or id of the question.

        Returns:
            DevQuestions.Question: the question corresponding to the given index or id.
        """
        if isinstance(key, int):    # if key is an index
            return self.questions.get(self.lut[key])
        elif isinstance(key, str):  # if key is an id (str)
            return self.questions[key]

In [7]:
class QA:
    """
    build, submit and collect OpenAI batch jobs that ask the DevQuestions, optionally
    preceded by retrieval contexts.

    contexts are registered with add_context() as {question_id: [context, ...]} dictionaries.
    the batch_by_* methods turn them into a .jsonl request file stored under
    '<path>/<run name>/<YYYYMMDD_HHMMSS>/'; submit_batch()/submit_split_batch() upload such a
    file and check_batches() downloads finished results and writes them to csv.
    """

    # batch statuses after which a job will not complete (used by is_batch_failed and check_batches)
    _FAILED_STATUSES = ("failed", "expired", "canceled")

    def __init__(self, path_to_dev: str, model: str='gpt-3.5-turbo', system_prompt: str=''):
        """
        initialize a QA instance.

        Args:
            path_to_dev (str): the file path to the development questions (loaded with DevQuestions).
            model (str, optional): the model written into every request. Defaults to 'gpt-3.5-turbo'.
            system_prompt (str, optional): the system message of every request. Defaults to ''.
        """
        self.client = openai.Client()
        self.system_prompt = system_prompt
        self.dev = DevQuestions(path_to_dev)
        # context name -> {question_id: [context, ...]}, filled by add_context()
        self.contexts = {}
        self.model = model

    def batch_by_context_no_context(self,
                                    context_name: str,
                                    path: str,
                                    num_questions: int=0,
                                    write_debug_file: bool=True) -> str:
        """
        constructs a batch of requests for the openai API and saves it to a .jsonl file.
        identifies the questions by id given in the context, this version does NOT add any
        context to the prompt, only asks the question.

        Args:
            context_name (str): the name of the context whose question ids are used.
            path (str): the directory under which '<context_name>-no_context/<timestamp>/' is created.
            num_questions (int, optional): the number of questions to ask - defaults to 0 (all questions).
            write_debug_file (bool, optional): whether to write the prompts to a debug file - defaults to true.

        Returns:
            str: the path to the jsonl file containing the batched requests.

        Raises:
            ValueError: if context_name has not been added.
        """
        if context_name not in self.contexts:
            raise ValueError('Context not found!')

        num_questions = self._validate_num_questions(num_questions, context_name)

        print(f"Creating batch for {num_questions} questions from '{context_name}' context...")

        location = self._prepare_output_dir(
            f'{path}/{context_name}-no_context/{self._generate_date_string()}/')

        batch = {}
        prompts = []
        q_ids = []

        # only the ids of the context are used; its passages are deliberately ignored
        for q_id, _ in self._first_n_items(context_name, num_questions):
            prompt = self._create_prompt([], self.dev[q_id].question)
            prompts.append(prompt)
            q_ids.append(q_id)
            batch[q_id] = self._build_request_dict(q_id, prompt)

        # written once, after all requests are built, so an empty context still yields a file
        self._save_to_jsonl(batch, location, 'submitted_batch')

        if write_debug_file:
            self._write_prompts_to_file(q_ids, prompts, location, 'debug')

        return f'{location}submitted_batch.jsonl'

    def batch_by_context(self,
                         context_name: str,
                         path: str,
                         num_questions: int=0,
                         k: int=0,
                         write_debug_file: bool=True) -> str:
        """
        constructs a batch of requests for the OpenAI API and saves it to a .jsonl file.
        identifies the questions by id given in the context and prefixes each question with
        its first k contexts.

        Args:
            context_name (str): the name of the context.
            path (str): the directory under which '<context_name>-k_<k>/<timestamp>/' is created.
            num_questions (int, optional): the number of questions to ask - defaults to 0 (all questions).
            k (int, optional): the number of contexts to include per question - defaults to 0 (all available contexts).
            write_debug_file (bool, optional): whether to write the prompts to a debug file. defaults to True.

        Returns:
            str: the path to the jsonl file containing the batched requests.

        Raises:
            ValueError: if context_name has not been added.
        """
        if context_name not in self.contexts:
            raise ValueError('Context not found!')

        num_questions = self._validate_num_questions(num_questions, context_name)

        print(f"Creating batch for {num_questions} questions from '{context_name}'.")

        location = self._prepare_output_dir(
            f'{path}/{context_name}-k_{k}/{self._generate_date_string()}/')

        batch = {}
        k_adjustments = 0
        prompts = []
        q_ids = []

        for q_id, contexts in self._first_n_items(context_name, num_questions):
            k_, k_adjustments = self._adjust_k(k, contexts, k_adjustments)

            prompt = self._create_prompt(contexts[:k_], self.dev[q_id].question)
            prompts.append(prompt)
            q_ids.append(q_id)
            batch[q_id] = self._build_request_dict(q_id, prompt)

        self._print_k_adjustments_warning(k_adjustments, len(batch))
        self._save_to_jsonl(batch, location, 'submitted_batch')

        if write_debug_file:
            self._write_prompts_to_file(q_ids, prompts, location, 'debug')

        return f'{location}submitted_batch.jsonl'

    def batch_by_structured_mixed_context(self,
                                          context_name1: str,
                                          context_name2: str,
                                          path: str,
                                          num_questions: int=0,
                                          k1: int=0,
                                          k2: int=0,
                                          position: str="first",
                                          write_debug_file: bool=True) -> str:
        """
        constructs a batch of requests for the OpenAI API and saves it to a .jsonl file.
        identifies the questions by id given in the first context and mixes contexts from two
        different contexts given the position and number of contexts to include from each.

        Args:
            context_name1 (str): the name of the first context.
            context_name2 (str): the name of the second context (must cover the same question ids).
            path (str): the directory under which the run directory is created.
            num_questions (int, optional): the number of questions to ask -defaults to 0 (all questions).
            k1 (int, optional): the number of contexts to include from the first context dictionary - defaults to 0 (all available contexts).
            k2 (int, optional): the number of contexts to include from the second context dictionary - defaults to 0 (all available contexts).
                NOTE(review): the code slices 2*k2 passages from the second context, see the loop.
            position (str, optional): where the first context's passages go relative to the second's:
                "first", "middle" or "last" - defaults to "first".
            write_debug_file (bool, optional): whether to write the prompts to a debug file. defaults to True.

        Returns:
            str: the path to the jsonl file containing the batched requests.

        Raises:
            ValueError: if a context is missing, the two contexts have different question ids,
                or position is invalid.
        """
        # existence is checked first so an unknown name raises ValueError instead of KeyError
        if context_name1 not in self.contexts or context_name2 not in self.contexts:
            raise ValueError('Context not found!')

        if set(self.contexts[context_name1]) != set(self.contexts[context_name2]):
            raise ValueError('Context keys do not match!')

        num_questions = self._validate_num_questions(num_questions, context_name1)

        print(f"Creating batch for {num_questions} questions from '{context_name1}' and '{context_name2}' contexts...")

        location = self._prepare_output_dir(
            f'{path}/{context_name1}-{context_name2}-k1_{k1}-k2_{k2}_{position}/{self._generate_date_string()}/')

        batch = {}
        k1_adjustments = 0
        k2_adjustments = 0
        prompts = []
        q_ids = []

        for q_id, contexts1 in self._first_n_items(context_name1, num_questions):
            contexts2 = self.contexts[context_name2].get(q_id, [])

            k1_, k1_adjustments = self._adjust_k(k1, contexts1, k1_adjustments)
            k2_, k2_adjustments = self._adjust_k(k2, contexts2, k2_adjustments)

            # NOTE(review): up to 2*k2 passages are taken from the second context, which disagrees
            # with the k2 documentation and the 'k2_{k2}' directory label - confirm this is intended.
            combined_contexts = self._combine_contexts(contexts1[:k1_], contexts2[:k2_*2], position)
            prompt = self._create_prompt(combined_contexts, self.dev[q_id].question)
            prompts.append(prompt)
            q_ids.append(q_id)
            batch[q_id] = self._build_request_dict(q_id, prompt)

        # warnings and the request file are produced once, after all requests are built
        self._print_k_adjustments_warning(k1_adjustments, len(batch), 'k1')
        self._print_k_adjustments_warning(k2_adjustments, len(batch), 'k2')
        self._save_to_jsonl(batch, location, 'submitted_batch')

        if write_debug_file:
            self._write_prompts_to_file(q_ids, prompts, location, 'debug')

        return f'{location}submitted_batch.jsonl'

    def submit_batch(self, batch_path: str, description: str, csv_file: str='batches.csv') -> str:
        """
        submit a batch job to the openai api. saves the batch job ID, description and batch
        path to a CSV file.

        Args:
            batch_path (str): the path to the batch jsonl file to submit
            description (str): description of the batch job, saved to the csv file.
            csv_file (str, optional): the registry csv appended to - defaults to 'batches.csv'.

        Returns:
            str: the id of the created batch job.
        """
        # the upload handle is closed as soon as the file has been sent
        with open(batch_path, "rb") as batch_file:
            batch_input_file = self.client.files.create(
                file=batch_file,
                purpose="batch"
            )

        batch_job = self.client.batches.create(
            input_file_id=batch_input_file.id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
            metadata={"file_name": batch_path}
        )
        batch_job_id = batch_job.id

        print(f"Batch job '{batch_path}' submitted with ID: {batch_job_id}")

        # the registry maps ids back to their request files for load_batch_job()/check_batches()
        with open(csv_file, 'a', newline='') as csvfile:
            csv.writer(csvfile).writerow([batch_job_id, description, batch_path])

        return batch_job_id

    def submit_split_batch(self, batch_path: str, description: str, num_splits: int,
                           poll_interval: float=30) -> list:
        """
        split a .jsonl file into smaller parts and submit them as separate batch jobs, one at
        a time: each part is polled until it completes or fails before the next is submitted.
        every submitted id is appended to 'ids.txt' next to batch_path.

        Args:
            batch_path (str): the path to the .jsonl file to split and submit.
            description (str): the description to use for the batch jobs, with the part number appended.
            num_splits (int): the number of splits to create
            poll_interval (float, optional): seconds between status checks - defaults to 30.

        Returns:
            list: the submitted batch job ids, in part order.
        """
        split_file_paths = self.split_jsonl_file(batch_path, num_splits)
        # os.path.join keeps a bare file name relative ('ids.txt') instead of '/ids.txt'
        ids_path = os.path.join(os.path.dirname(batch_path), 'ids.txt')
        batch_job_ids = []

        for part_number, split_file_path in enumerate(split_file_paths, start=1):
            batch_job_id = self.submit_batch(split_file_path, f"{description} -part {part_number}")
            batch_job_ids.append(batch_job_id)
            # NOTE(review): if status retrieval keeps failing, both checks return False and this polls forever
            while not self.is_batch_done(batch_job_id) and not self.is_batch_failed(batch_job_id):
                time.sleep(poll_interval)

            with open(ids_path, 'a') as f:
                f.write(f'{batch_job_id}\n')

        return batch_job_ids

    def split_jsonl_file(self, jsonl_file_path: str, num_splits: int) -> list:
        """
        split a .jsonl file into consecutive parts saved next to it as
        '<base name>_part_<i>.jsonl'. every part gets len(lines) // num_splits lines and the
        last part also takes the remainder. num_splits is capped at the number of lines so
        that no part is empty.

        Args:
            jsonl_file_path (str): the path to the .jsonl file to split
            num_splits (int): the number of splits to create (at least 1).

        Returns:
            list: the paths of the created part files, in order.

        Raises:
            ValueError: if num_splits is smaller than 1.
        """
        if num_splits < 1:
            raise ValueError('num_splits must be at least 1!')

        base_dir = os.path.dirname(jsonl_file_path)
        base_name = os.path.basename(jsonl_file_path).split('.')[0]
        with open(jsonl_file_path, 'r', encoding='utf-8') as file:
            lines = file.readlines()

        total_lines = len(lines)
        # more parts than requests would produce empty part files (submitted as empty batches)
        num_splits = min(num_splits, max(total_lines, 1))
        lines_per_split = total_lines // num_splits
        split_file_paths = []

        for i in range(num_splits):
            start = i * lines_per_split
            end = start + lines_per_split if i < num_splits - 1 else total_lines
            split_file_path = os.path.join(base_dir, f"{base_name}_part_{i + 1}.jsonl")
            with open(split_file_path, 'w', encoding='utf-8') as split_file:
                split_file.writelines(lines[start:end])
            split_file_paths.append(split_file_path)

        return split_file_paths

    def is_batch_done(self, batch_id: str) -> bool:
        """
        check if a batch job is completed

        args:
            batch_id (str): the ID of the batch job to check.

        returns:
            bool: true if the batch job status is "completed", otherwise False (also when the
            status could not be retrieved).
        """
        try:
            return self.client.batches.retrieve(batch_id).status == "completed"
        except Exception as e:
            print(f"error checking batch job {batch_id}: {e}")
            return False

    def is_batch_failed(self, batch_id: str) -> bool:
        """
        check if a batch job has failed

        args:
            batch_id (str): The ID of the batch job to check.

        returns:
            bool: True if the status is "failed", "expired" or "canceled", otherwise False
            (also when the status could not be retrieved).
        """
        try:
            return self.client.batches.retrieve(batch_id).status in self._FAILED_STATUSES
        except Exception as e:
            print(f"Error checking batch job {batch_id}: {e}")
            return False

    def load_batch_job(self, batch_job_id: str, csv_file: str='batches.csv') -> dict:
        """
        load batch job details from the registry csv written by submit_batch(), using the
        batch job ID. blank or malformed rows (fewer than 3 columns) are skipped.

        args:
            batch_job_id (str): the ID of the batch job to look up
            csv_file (str, optional): the path to the CSV file containing batch job details - defaults to 'batches.csv'

        returns:
            dict: 'batch_job_id', 'description' and 'jsonl_file_path' of the first matching row.

        raises:
            ValueError: if no row has the given id.
        """
        with open(csv_file, 'r', newline='') as csvfile:
            for row in csv.reader(csvfile):
                if len(row) >= 3 and row[0] == batch_job_id:
                    return {
                        "batch_job_id": row[0],
                        "description": row[1],
                        "jsonl_file_path": row[2]
                    }
        raise ValueError(f"batch job id '{batch_job_id}' not found in '{csv_file}'.")

    def check_batches(self, batch_ids: list, csv_file: str='batches.csv') -> dict:
        """
        check the status of multiple batch jobs. for every completed job the output is saved
        next to its request file as '<name>_rslt.jsonl' and converted to 'csv/<name>_rslt.csv'.
        ids that are unknown, unfinished or failing are reported and skipped.

        args:
            batch_ids (list): a list of batch job ids to check.
            csv_file (str, optional): the path to the CSV file containing batch job details - defaults to 'batches.csv'

        returns:
            dict: batch id -> list of parsed output lines, for the completed jobs only.
        """
        all_results = {}

        for batch_id in batch_ids:
            try:
                batch_details = self.load_batch_job(batch_id, csv_file)
            except ValueError as e:
                # one unknown id should not stop the remaining ids from being checked
                print(f"skipping batch job {batch_id}: {e}")
                continue

            jsonl_file_path = batch_details["jsonl_file_path"]
            save_results_to = os.path.dirname(jsonl_file_path)
            batch_name = os.path.basename(jsonl_file_path).split('.')[0]

            try:
                batch_status = self.client.batches.retrieve(batch_id)
                status = batch_status.status

                if status == "completed":
                    results_file = self.client.files.content(batch_status.output_file_id)

                    results_file_path = os.path.join(save_results_to, f"{batch_name}_rslt.jsonl")
                    results = []
                    with open(results_file_path, 'wb') as output_file:
                        for line in results_file.iter_lines():
                            if line.strip():
                                json_line = json.loads(line)
                                results.append(json_line)
                                output_file.write((json.dumps(json_line) + '\n').encode('utf-8'))

                    print(f"batch job results for {batch_id} - {jsonl_file_path} saved to {results_file_path}")

                    csv_dir = os.path.join(save_results_to, 'csv')
                    os.makedirs(csv_dir, exist_ok=True)
                    self.write_csv(jsonl_file_path, os.path.join(csv_dir, f"{batch_name}_rslt.csv"), results)
                    all_results[batch_id] = results
                elif status in self._FAILED_STATUSES:
                    print(f"batch job {batch_id} / {jsonl_file_path} failed with status: {status}")
                else:
                    print(f"batch job {batch_id} / {jsonl_file_path} is not yet completed. Current status: {status}")

            except Exception as e:
                print(f"error checking batch job {batch_id} / {jsonl_file_path}: {e}")

        return all_results

    def set_system_prompt(self, prompt: str):
        """
        sets the system prompt for gpt to use in subsequently built requests

        args:
            prompt (str): the prompt to set.

        raises:
            TypeError: if prompt is not a string.
        """
        if not isinstance(prompt, str):
            raise TypeError('prompt must be a string!')
        self.system_prompt = prompt

    def get_context(self, name: str) -> dict:
        """
        get a context by name

        args:
            name (str): the name of the context.

        returns:
            dict: the context dictionary (the stored object, not a copy).

        raises:
            ValueError: if no context with that name exists.
        """
        if not isinstance(name, str) or name not in self.contexts:
            raise ValueError('Context not found!')
        return self.contexts[name]

    def print_context_names(self):
        """
        print the names of all contexts available, in insertion order
        """
        for name in self.contexts:
            print(name)

    def add_context(self, context: dict, name: str):
        """
        add a context. it should be a dictionary with the following structure:
        {
            'question_idA': ['contextA1', 'contextA2', ...],
            'question_idB': ['contextB1', 'contextB2', ...],
            ...
        }

        args:
            context (dict): the context to add (stored by reference).
            name (str): the name of the context.

        raises:
            TypeError: if context/name have the wrong type or the keys/values are not str/list.
            ValueError: if a context with that name already exists.
        """
        if not isinstance(context, dict):
            raise TypeError('context must be a dictionary!')
        if not isinstance(name, str):
            raise TypeError('name must be a string!')
        if name in self.contexts:
            raise ValueError('context already exists!')
        if not all(isinstance(k, str) and isinstance(v, list) for k, v in context.items()):
            raise TypeError('invalid format!')

        self.contexts[name] = context

    def write_csv(self, batch_file: str, target: str, responses: list):
        """
        Write responses to a CSV file with the columns question id, model response,
        ground-truth answer and the user prompt that was sent.

        Args:
            batch_file (str): The file path to the submitted batch JSONL file (source of the prompts).
            target (str): The file path to save the CSV file.
            responses (list): List of response dictionaries from the batch output. Lines without
                a completion are written with an empty response and a printed warning.
        """
        batch_prompts = {}
        with open(batch_file, 'r', encoding='utf-8') as batch:
            for line in batch:
                if not line.strip():
                    continue
                request = json.loads(line)
                # the last message is the user turn, i.e. the prompt built for this question
                batch_prompts[request['custom_id']] = request['body']['messages'][-1]['content']

        with open(target, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(["Question ID", "GPT Response", "Ground Truth Answer", "Prompt"])

            for response in responses:
                q_id = response['custom_id']
                writer.writerow([
                    q_id,
                    self._extract_response_text(response),
                    self.dev[q_id].answer,
                    batch_prompts[q_id],
                ])

    @staticmethod
    def _extract_response_text(response: dict) -> str:
        """
        return the assistant message content of one batch output line, or '' (with a printed
        warning) when the line has no choices - presumably a request that errored.
        """
        body = (response.get('response') or {}).get('body') or {}
        choices = body.get('choices') or []
        if not choices:
            print(f"WARNING: no completion for '{response.get('custom_id')}': {response.get('error')}")
            return ''
        return choices[0]['message']['content']

    def _generate_date_string(self):
        """
        generate a date string in the format YYYYMMDD_HHMMSS (local time).

        Returns:
            str: The generated date string.
        """
        return time.strftime("%Y%m%d_%H%M%S")

    def _prepare_output_dir(self, location: str) -> str:
        """
        create the directory `location` (and any parents) if needed and return it unchanged.
        errors are raised here rather than surfacing later when the first file is written.
        """
        os.makedirs(location, exist_ok=True)
        return location

    def _first_n_items(self, context_name: str, n: int) -> list:
        """return the first n (question_id, contexts) pairs of a context, in insertion order."""
        return list(self.contexts[context_name].items())[:n]

    def _validate_num_questions(self, num_questions: int, context_name: str) -> int:
        """
        Validate the number of questions to be processed.

        Args:
            num_questions (int): The number of questions to ask. If <= 0 or larger than the
                context, all questions are processed.
            context_name (str): The name of the context.

        Returns:
            int: The validated number of questions to process.
        """
        total_questions = len(self.contexts[context_name])
        if num_questions <= 0 or num_questions > total_questions:
            num_questions = total_questions
            print('Using all available questions.')
        return num_questions

    def _combine_contexts(self, contexts1: list[str], contexts2: list[str], position: str) -> list[str]:
        """
        Combine two lists of contexts based on the specified position.

        Args:
            contexts1 (list[str]): The first list of contexts.
            contexts2 (list[str]): The second list of contexts.
            position (str): The position to place contexts from contexts1 in relation to contexts2.
                            Can be "first", "middle" (inserted after the first len(contexts2) // 2
                            items), or "last".

        Returns:
            list[str]: The combined list of contexts.

        Raises:
            ValueError: for any other position value.
        """
        if position == "first":
            return contexts1 + contexts2
        elif position == "middle":
            half = len(contexts2) // 2
            return contexts2[:half] + contexts1 + contexts2[half:]
        elif position == "last":
            return contexts2 + contexts1
        else:
            raise ValueError('Invalid position value. It should be "first", "middle", or "last".')

    def _build_request_dict(self, q_id: str, prompt: str) -> dict:
        """
        Build one chat-completions request line for the OpenAI batch API.

        Args:
            q_id (str): The question ID, used as the request's custom_id.
            prompt (str): The user message of the request.

        Returns:
            dict: The request dictionary (system prompt + user prompt, max_tokens 60).
        """
        return {
            "custom_id": q_id,
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": self.model,
                "messages": [
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": prompt}],
                "max_tokens": 60
                }
            }

    def _create_prompt(self, contexts: list[str], question: str) -> str:
        """
        create a prompt: every context followed by a blank line, then the question.

        args:
            contexts (list[str]): the contexts
            question (str): the question

        returns:
            str: the prompt text.
        """
        return ''.join(context + '\n\n' for context in contexts) + question

    def _adjust_k(self, k: int, contexts: list[str], k_adjustments: int) -> tuple:
        """
        Resolve the number of contexts to take from `contexts`.

        k == 0 means "all contexts" (the documented default) and is not counted as an
        adjustment. a negative k or a k larger than the available contexts is clamped to
        len(contexts) and counted.

        Args:
            k (int): The desired number of contexts to include.
            contexts (list[str]): The list of contexts available.
            k_adjustments (int): The current count of k adjustments.

        Returns:
            tuple: The adjusted value of k and the updated count of k adjustments.
        """
        available = len(contexts)
        if k == 0:
            return available, k_adjustments
        if k < 0 or k > available:
            # a negative k would otherwise slice from the end (contexts[:-1] drops the last one)
            return available, k_adjustments + 1
        return k, k_adjustments

    def _print_k_adjustments_warning(self, k_adjustments: int, total: int, k_label: str = 'k') -> None:
        """
        Print a warning if there were adjustments made to the value of k.

        Args:
            k_adjustments (int): The number of adjustments made to k.
            total (int): The number of requests built.
            k_label (str, optional): The label for k. Defaults to 'k'.
        """
        if k_adjustments:
            print(f"WARNING: total adjustments of '{k_label}': {k_adjustments} of {total}")
            print(f"This warning is because the specified {k_label} is outside the valid range or exceeds available contexts.")

    def _write_prompts_to_file(self, ids: list[str], prompts: list[str], path: str, file_name: str) -> None:
        """
        write each question id and its prompt to '<path>/<file_name>.txt' for inspection.

        Args:
            ids (list[str]): the question ids, parallel to prompts.
            prompts (list[str]): the prompts to write.
            path (str): the directory to write to.
            file_name (str): the file name without the '.txt' extension.
        """
        with open(os.path.join(path, f'{file_name}.txt'), 'w', encoding='utf-8') as file:
            for q_id, prompt in zip(ids, prompts):
                file.write(f"Question ID: {q_id}\n")
                file.write(f"{prompt}\n\n")
                file.write('--------------------------------------\n\n')

    def _save_to_jsonl(self, batch: dict, path: str, file_name: str) -> None:
        """
        save the batch requests to '<path>/<file_name>.jsonl', one json object per line.

        Args:
            batch (dict): The batch of requests to save (values are written in insertion order).
            path (str): The directory to save the .jsonl file in.
            file_name (str): The file name without the '.jsonl' extension.
        """
        with open(os.path.join(path, f'{file_name}.jsonl'), 'w', encoding='utf-8') as f:
            for request in batch.values():
                f.write(json.dumps(request) + '\n')

In [None]:
# retrieved contexts: a dict whose values are lists of passage strings; its keys are later
# compared against dev-set '_id's, so presumably they are question ids (TODO confirm).
top10 = load_json('contexts/results_10.json')

def remove_duplicates(data):
    """
    Drop repeated entries from every list in `data`, in place.

    Two entries count as duplicates when they are equal after removing all
    whitespace and lowercasing. The first occurrence (with its original
    formatting) is kept and the original order is preserved.

    Args:
        data (dict): maps each key to a list of strings. Mutated in place.

    Returns:
        dict: the same `data` object that was passed in.
    """
    def _fingerprint(text):
        # whitespace- and case-insensitive key used for comparison only
        return ''.join(text.split()).lower()

    for key in data:
        seen = set()
        kept = []
        for entry in data[key]:
            marker = _fingerprint(entry)
            if marker in seen:
                continue
            seen.add(marker)
            kept.append(entry)
        data[key] = kept
    return data

# question ids present in the retriever output; later cells only keep these questions
keys = list(top10.keys())

# drop whitespace/case-insensitive duplicates, then replace every '\n\n' with a
# single '\n' in each remaining context (lists are updated in place)
top10 = remove_duplicates(top10)
for contexts in top10.values():
    contexts[:] = [context.replace('\n\n', '\n') for context in contexts]

save_json('contexts/top10_no_duplicates.json', top10)    

# full dev split: its questions carry 'context' and 'supporting_facts' used below
_dev = load_json('dataset/_dev.json')

In [None]:
# oracle: for every question in `keys`, only the contexts whose title appears in
# the question's supporting facts (the gold paragraphs)
oracle = {}
key_set = set(keys)  # O(1) membership instead of scanning the list once per question

for question in _dev:  # iterate over each question in the full dataset
    question_id = question['_id']
    if question_id not in key_set:
        continue  # skip questions that are not part of top10

    # titles of the supporting facts, built once per question instead of once per context
    supporting_titles = {sf[0] for sf in question['supporting_facts']}

    # each context is [title, sentences]; the title goes on its own first line and the
    # sentences are concatenated as-is (no separator is inserted between them)
    oracle[question_id] = [
        context[0] + '\n' + ''.join(context[1])
        for context in question['context']
        if context[0] in supporting_titles
    ]

save_json('contexts/oracle.json', oracle)

In [None]:
# hard negatives: retrieved top10 contexts whose title is NOT among the question's
# supporting facts, i.e. ranked highly by the retriever but not gold evidence
hard_negatives = {}
key_set = set(keys)  # O(1) membership instead of scanning the list once per question

for question in _dev:
    question_id = question['_id']
    if question_id not in key_set:
        continue  # skip questions that are not part of top10

    # titles of the supporting facts, built once per question
    supporting_titles = {sf[0] for sf in question['supporting_facts']}

    # assumes the title is the first line of every top10 context
    hard_negatives[question_id] = [
        context for context in top10[question_id]
        if context.split('\n')[0] not in supporting_titles
    ]

save_json('contexts/hard_negatives.json', hard_negatives)

In [None]:
# random: for every question, `num_contexts` passages drawn uniformly from the
# whole corpus, formatted like the other context sets (title, newline, text)
randomly_drawn = {}
complete_corpus = list(load_json('dataset/wiki_musique_corpus.json').values())

print(len(complete_corpus))

num_contexts = 10
RANDOM_SEED = 42  # fixed so re-running the notebook reproduces contexts/random.json
rng = random.Random(RANDOM_SEED)  # private generator: leaves the global `random` state untouched

for key in keys:
    # draw without replacement so one question never receives the same passage twice
    drawn = rng.sample(complete_corpus, num_contexts)
    randomly_drawn[key] = [choice['title'] + '\n' + choice['text'] for choice in drawn]

save_json('contexts/random.json', randomly_drawn)

In [8]:
# reload every context set from disk; this replaces any in-memory versions built
# above, so the experiments use exactly what was saved
hard_negatives = load_json(path='contexts/hard_negatives.json')
oracle = load_json(path='contexts/oracle.json')
randomly_drawn = load_json(path='contexts/random.json')
top10 = load_json(path='contexts/top10_no_duplicates.json')
# NOTE(review): gibberish.json is not generated in this part of the notebook — confirm its source
gibberish = load_json(path='contexts/gibberish.json')

In [10]:
# instruct the model to answer as tersely as possible
system_prompt = (
    'Answer in a concise way. NO full sentences! As few words as possible! '
    'For example just a date, the name of a place, the name of a person, a number, a yes or no, etc.'
)

qa = QA('dataset/_dev.json', model='gpt-3.5-turbo', system_prompt=system_prompt)

# register each context set under the name the experiment cells refer to
for context_name, context_set in (
    ('hard_negatives', hard_negatives),
    ('oracle', oracle),
    ('random', randomly_drawn),
    ('top10', top10),
    ('gibberish', gibberish),
):
    qa.add_context(context_set, context_name)

print('SYSTEM PROMPT:')
print(qa.system_prompt)

print('\nCONTEXTS:')
qa.print_context_names()

SYSTEM PROMPT:
Answer in a concise way. NO full sentences! As few words as possible! For example just a date, the name of a place, the name of a person, a number, a yes or no, etc.

CONTEXTS:
hard_negatives
oracle
random
top10
gibberish


## 3. RANDOM — top-10 contexts with injected random noise

In [None]:
actually_prompt = True  # False: only build the batch files, submit nothing
num_questions = 0  # NOTE(review): passed through to QA — confirm what 0 means there (all questions?)
description = 'random injected noise, serious run 1'

K_top10 = [1, 3, 5]  # values passed as k1 (with 'top10' as the first context set)
K_random = 2         # value passed as k2 (with 'random' as the second context set)
positions = ['first', 'middle', 'last']

current_ids = []  # ids of all submitted batch jobs; checked in the next cell

for k in K_top10:
    for position in positions:
        # NOTE(review): batches are written under 'results', not a per-(k, position)
        # subdirectory such as f'results/injected_noise/{k}_{position}' — confirm intent
        batch_path = qa.batch_by_structured_mixed_context(
            'top10', 'random', 'results',
            num_questions=num_questions, k1=k, k2=K_random,
            position=position, write_debug_file=True,
        )
        if actually_prompt:
            # submit immediately: the next iteration may reuse the same batch file name
            batch_job_ids = qa.submit_split_batch(batch_path, f'{description}, k={k}, pos={position}', 10)
            current_ids.extend(batch_job_ids)

In [None]:
# check the batch jobs submitted in the previous cell; the meaning of 'batches.csv'
# is defined by QA.check_batches (presumably a log of batch jobs — confirm)
result = qa.check_batches(current_ids, 'batches.csv')

In [None]:
# NOTE(review): the visible concatenate_csv_files(root_dir) is a module-level function,
# not a QA method — confirm QA defines this method (the ALL RANDOM cell's commented-out
# line calls the free function instead)
qa.concatenate_csv_files('./results')

## 4. HARD NEGATIVES

In [None]:
# NOTE(review): sets an attribute on qa; the loop below only reads the module-level flag
qa.actually_prompt = True
actually_prompt = True

path = 'results/hard_negatives'

K = [1, 3, 5]

for k in K:
    if k > 1: break  # only k=1 is ever run (debugging leftover? confirm)
    file_name = f'e4_qa_hard_negatives_k{k}'
    # NOTE(review): the ALL RANDOM cell calls batch_by_context without a file name and it
    # prints a single path; here a file name is passed positionally and the result is
    # indexed [0]/[1] — this cell may be out of date with the current QA API. TODO confirm.
    batch_loc_name = qa.batch_by_context('hard_negatives', path, file_name, num_questions=1, k=k, write_debug_file=True)
    if actually_prompt:
        responses = qa.ask_batch(batch_loc_name[0], batch_loc_name[1], f'QA on hard negatives with k={k}')
        qa.write_csv(batch_loc_name[0], batch_loc_name[1], responses)
    

## ALL RANDOM — only randomly drawn contexts (no retrieved passages)

In [23]:
actually_prompt = True  # NOTE(review): no effect while the submission below is commented out
K = [1, 3, 5, 10]

current_ids = []

description = 'just random noise'

path = 'results/'  # same for every k, so set once outside the loop

for k in K:
    batch_loc_name = qa.batch_by_context('random', path, num_questions=5, k=k, write_debug_file=True)
    # batch_by_context returns the batch file path itself (see the printed output);
    # every k prints the same 'results/submitted_batch.jsonl', so each iteration
    # overwrites the previous batch — submit inside this loop, not after it
    print(batch_loc_name)
    # if actually_prompt:
    #     batch_job_ids = qa.submit_split_batch(batch_loc_name, description, 2)
    #     current_ids.extend(batch_job_ids)

print(current_ids)

result = qa.check_batches(current_ids, 'batches.csv')
#concatenate_csv_files('./results')

Creating batch for 5 questions from 'random'.
results/submitted_batch.jsonl
Creating batch for 5 questions from 'random'.
results/submitted_batch.jsonl
Creating batch for 5 questions from 'random'.
results/submitted_batch.jsonl
Creating batch for 5 questions from 'random'.
results/submitted_batch.jsonl
[]
