In [1]:
import json
import argparse
import os
import glob

from typing import List, Tuple, Set, Any, Optional
from nltk import ngrams
from typing import Dict
from tqdm import tqdm
from dataclasses import dataclass

from light_scenario import LightInstance, LightScenario, LightScenarioKey
from light_tokenizer import LightTokenizer, DefaultTokenizer
from load_documents import get_document_iterator
from data_overlap_stats import (
    DataOverlapStats,
    DataOverlapStatsKey,
    PART_INPUT,
    PART_REF,
)
from common.hierarchical_logger import hlog, htrack_block
from common.general import asdict_without_nones, write
from scenarios.scenario import ScenarioSpec


# The n values of the ngrams to be computed.
# create_all_data_overlap_stats and create_ngram_index each produce one entry per value here.
# The exploration cells further down index ngram_index[5], so 5 must stay in this list.
N_VALUES: List[int] = [5, 9, 13]  # TODO: Pick the N values


@dataclass(frozen=True)
class EntryDataOverlapKey:
    """
    Identifies one side (the input, or the references) of a single instance in a scenario.

    The dataclass is frozen, so keys are immutable and hashable. That lets them be stored in
    the per-ngram sets of an NgramIndex.
    """

    # Stats bucket (scenario key + N) that this entry belongs to.
    stats_key: DataOverlapStatsKey
    # Which side of the instance this key refers to: PART_INPUT or PART_REF.
    part: str
    # Position of the instance within its scenario's list of instances.
    instance_id: int


# type alias for overlap-related data structures
# An ngram: a tuple of tokens.
Ngram = Tuple[str, ...]
# n -> ngram -> set of entries (instance input or references) containing that ngram; built by create_ngram_index.
NgramIndex = Dict[int, Dict[Ngram, Set[EntryDataOverlapKey]]]
# stats key -> stats object for one (scenario, n) pair; built by create_all_data_overlap_stats.
AllDataOverlapStats = Dict[DataOverlapStatsKey, DataOverlapStats]
# Per-entry ngram counts. NOTE(review): not used in any of the visible cells.
NgramCounter = Dict[EntryDataOverlapKey, Dict[Ngram, int]]


In [2]:

def load_light_scenarios_from_jsonl(path: str) -> List[LightScenario]:
    """
    Create a list of light scenarios from a jsonl file, where each json represents a LightScenario object.

    Input file format (one JSON object per line; blank lines are skipped):

    LightScenario JSON 1
    LightScenario JSON 2
    LightScenario JSON 3
    ...

    Each object is expected to contain:
      - "scenario_key": a dict with "scenario_spec" (keyword arguments for ScenarioSpec) and "split"
      - "instances": a list of dicts, each with "input", "references" and "id"

    Args:
        path: Path to the jsonl file.

    Returns:
        One LightScenario per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If one of the fields listed above is missing.
    """

    def create_light_instance_from_dict(instance_dict: dict) -> LightInstance:
        return LightInstance(input=instance_dict["input"], references=instance_dict["references"], id=instance_dict["id"])

    light_scenarios: List[LightScenario] = []
    # Context manager so the file handle is released even if parsing fails part-way through.
    with open(path, "r") as jsonl_file:
        for light_scenario_json in jsonl_file:
            # Tolerate blank lines (e.g. a trailing empty line) instead of crashing in json.loads.
            if not light_scenario_json.strip():
                continue
            light_scenario_dict: dict = json.loads(light_scenario_json)

            light_scenario_key_dict: dict = light_scenario_dict["scenario_key"]
            # if the light_scenarios are exported from helm, they will have a scenario_spec field
            scenario_spec = ScenarioSpec(**light_scenario_key_dict["scenario_spec"])
            # NOTE(review): when this notebook ran, the next line raised
            # `TypeError: LightScenarioKey.__init__() got an unexpected keyword argument 'scenario_spec'`.
            # The installed light_scenario.LightScenarioKey therefore does not take `scenario_spec`.
            # TODO: align this call (and the expected JSON layout) with its actual fields.
            light_scenario_key = LightScenarioKey(scenario_spec=scenario_spec, split=light_scenario_key_dict["split"])
            light_instances: List[LightInstance] = [
                create_light_instance_from_dict(instance_dict) for instance_dict in light_scenario_dict["instances"]
            ]
            light_scenarios.append(LightScenario(scenario_key=light_scenario_key, instances=light_instances))
    return light_scenarios

In [3]:
# Load the exported light scenarios: one LightScenario JSON object per line, path relative to the kernel's cwd.
# NOTE(review): in the recorded run this cell raised
# `TypeError: LightScenarioKey.__init__() got an unexpected keyword argument 'scenario_spec'`.
# That left `light_scenarios` undefined, so later cells that use it fail on a fresh kernel.
light_scenarios = load_light_scenarios_from_jsonl('run_specs_small')

TypeError: LightScenarioKey.__init__() got an unexpected keyword argument 'scenario_spec'

In [None]:
# Inspect the first loaded scenario.
light_scenarios[0]

In [None]:

def create_all_data_overlap_stats(light_scenarios: List[LightScenario], n_values: List[int]) -> AllDataOverlapStats:
    """
    Create one DataOverlapStats per (scenario, n) pair, keyed by its stats_key.

    Args:
        light_scenarios: Scenarios to create stats for.
        n_values: The n-gram sizes; each is passed to DataOverlapStats.from_scenario as the "N" tag.

    Returns:
        Mapping from each stats object's stats_key to the stats object.

    Raises:
        ValueError: If two (scenario, n) pairs produce the same stats_key
            (presumably e.g. a repeated n value; depends on DataOverlapStats.from_scenario).
    """
    hlog("Initializing all data overlap stats")
    stats_by_key: AllDataOverlapStats = {}
    scenario_n_pairs = ((scenario, n) for scenario in light_scenarios for n in n_values)
    for scenario, n in scenario_n_pairs:
        new_stats: DataOverlapStats = DataOverlapStats.from_scenario(scenario, stats_tags={"N": n})
        key = new_stats.stats_key
        if key in stats_by_key:
            raise ValueError("Duplicated settings detected.")
        stats_by_key[key] = new_stats
    return stats_by_key



In [None]:
# One DataOverlapStats per (scenario, n) pair, keyed by stats key.
# (The previous version duplicated the call as a chained assignment target, which is a SyntaxError.)
all_overlap_stats = create_all_data_overlap_stats(light_scenarios=light_scenarios, n_values=N_VALUES)

In [None]:
# Only a cell's last expression is displayed, so print the container type explicitly.
print(type(all_overlap_stats))
# Peek at one stats key to see its structure (prints None when no stats were created).
print(next(iter(all_overlap_stats), None))

In [None]:
def create_ngram_index(
    light_scenarios: List[LightScenario], n_values: List[int], tokenizer: LightTokenizer
) -> NgramIndex:
    """
    Given a list of scenarios and n values, initialize ngram_index.

    For every scenario, instance and n, index each n-gram of the tokenized instance input and
    of every tokenized reference. Each n-gram maps to the set of entries it occurs in.

    Args:
        light_scenarios: Scenarios whose instances are indexed.
        n_values: The n-gram sizes to index; each gets its own top-level entry in the result.
        tokenizer: Splits instance inputs and references into tokens.

    Returns:
        ngram_index[n][ngram] -> set of EntryDataOverlapKey.
        Each entry's fields are:
          - stats_key: DataOverlapStatsKey(metadata={"light_scenario_key": ..., "N": n})
          - instance_id: the instance's position in scenario.instances
          - part: PART_INPUT or PART_REF
        Every stored set is non-empty.
    """
    ngram_index: NgramIndex = {n: {} for n in n_values}
    for scenario in light_scenarios:
        hlog(f"Building ngram indexes for {scenario.scenario_key}")
        # One stats key per n for this scenario; it does not depend on the instance.
        stats_keys = {
            n: DataOverlapStatsKey(metadata={"light_scenario_key": scenario.scenario_key, "N": n}) for n in n_values
        }
        for instance_id, instance in enumerate(scenario.instances):
            # Tokenize once per instance rather than once per (instance, n).
            input_tokens = tokenizer.tokenize(instance.input)
            reference_token_lists = [tokenizer.tokenize(reference) for reference in instance.references]

            for n in n_values:
                index_for_n = ngram_index[n]

                input_key = EntryDataOverlapKey(stats_key=stats_keys[n], instance_id=instance_id, part=PART_INPUT)
                for input_ngram in ngrams(input_tokens, n):
                    index_for_n.setdefault(input_ngram, set()).add(input_key)

                # compute reference ngrams
                reference_key = EntryDataOverlapKey(stats_key=stats_keys[n], instance_id=instance_id, part=PART_REF)
                for reference_tokens in reference_token_lists:
                    for reference_ngram in ngrams(reference_tokens, n):
                        index_for_n.setdefault(reference_ngram, set()).add(reference_key)
    return ngram_index


In [None]:
# Build the ngram index for every n in N_VALUES.
# NOTE(review): DefaultTokenizer is imported at the top but never used.
# Confirm that plain LightTokenizer is the intended tokenizer here.
tokenizer = LightTokenizer()
ngram_index = create_ngram_index(
    light_scenarios=light_scenarios,
    n_values=N_VALUES,
    tokenizer=tokenizer,
)
       

In [None]:
# Only a cell's last expression is displayed, so print the type explicitly.
print(type(ngram_index))
# Top-level keys are the n values that were indexed.
ngram_index.keys()

In [None]:
# Only a cell's last expression is displayed, so print the type and size explicitly.
print(type(ngram_index[5]), len(ngram_index[5]))
# Show only a handful of 5-grams; the full key view can be enormous.
list(ngram_index[5])[:5]

In [None]:
# Take the first indexed 5-gram and its set of entries.
# create_ngram_index adds a key to each set right after creating it, so every set is non-empty
# and no length check is needed. This raises StopIteration if nothing was indexed for n=5.
# `k` and `v` are bound explicitly because the following cells read them.
k, v = next(iter(ngram_index[5].items()))
print(k)
print(v)

In [None]:
# `k` (an ngram) and `v` (its set of EntryDataOverlapKey) are set by the previous cell.
# These exploration cells depend on that state, so run them in order.
k

In [None]:
# After this loop `vv` stays bound to the last entry printed; the cells below inspect it.
for vv in v:
    print(vv)

In [None]:
# Stats bucket (scenario key + N) this entry was indexed under.
stats_key = vv.stats_key

In [None]:
stats_key.metadata

In [None]:
# PART_INPUT or PART_REF.
vv.part

In [None]:
# Position of the instance within its scenario's list of instances.
vv.instance_id