In [3]:
from typing import List

from compute_data_overlap_metrics import (
    create_all_data_overlap_stats,
    get_all_data_overlap_stats,
    create_ngram_index,
    EntryDataOverlapKey,
    Ngram,
    NgramIndex,
    AllDataOverlapStats
)
from data_overlap_spec import DataOverlapStats, DataOverlapStatsKey, OverlapProtocolSpec
from light_scenario import LightScenario, LightInstance, LightScenarioKey
from light_tokenizer import LightTokenizer, DefaultTokenizer
from data_overlap_stats import (
    OldDataOverlapStatsKey,
    OldDataOverlapStats,
    PART_INPUT,
    PART_REF,
)
from scenarios.scenario import ScenarioSpec

N_VALUES = [5, 13]

ALL_DATA_OVERLAP_STATS = [
    DataOverlapStats(
        data_overlap_stats_key=DataOverlapStatsKey(
            light_scenario_key=LightScenarioKey(
                scenario_spec=ScenarioSpec(
                    class_name="helm.benchmark.scenarios.natural_qa_scenario.NaturalQAScenario", args={}
                ),
                split="test",
            ),
            overlap_protocol_spec=OverlapProtocolSpec(N=13),
        ),
        instance_ids_with_overlapping_input=["id1"],
        instance_ids_with_overlapping_reference=["id1"],
    ),
    DataOverlapStats(
        data_overlap_stats_key=DataOverlapStatsKey(
            light_scenario_key=LightScenarioKey(
                scenario_spec=ScenarioSpec(
                    class_name="helm.benchmark.scenarios.natural_qa_scenario.NaturalQAScenario2", args={}
                ),
                split="test",
            ),
            overlap_protocol_spec=OverlapProtocolSpec(N=5),
        ),
        instance_ids_with_overlapping_input=["id1"],
        instance_ids_with_overlapping_reference=["id1"],
    ),
]


TEST_DOCUMENT: str = (
    "The Center for Research on Foundation Models (CRFM) is "
    "an interdisciplinary initiative born out of the Stanford "
    "Institute for Human-Centered Artificial Intelligence (HAI) "
    "that aims to make fundamental advances in the study, development, "
    "and deployment of foundation models."
)

TEST_TOKENS_SPLIT_BY_SPACE: List[str] = [
    "The",
    "Center",
    "for",
    "Research",
    "on",
    "Foundation",
    "Models",
    "(CRFM)",
    "is",
    "an",
    "interdisciplinary",
    "initiative",
    "born",
    "out",
    "of",
    "the",
    "Stanford",
    "Institute",
    "for",
    "Human-Centered",
    "Artificial",
    "Intelligence",
    "(HAI)",
    "that",
    "aims",
    "to",
    "make",
    "fundamental",
    "advances",
    "in",
    "the",
    "study,",
    "development,",
    "and",
    "deployment",
    "of",
    "foundation",
    "models.",
]

TEST_TOKENS_BY_DEFAULT_TOKENIZER: List[str] = [
    "the",
    "center",
    "for",
    "research",
    "on",
    "foundation",
    "models",
    "crfm",
    "is",
    "an",
    "interdisciplinary",
    "initiative",
    "born",
    "out",
    "of",
    "the",
    "stanford",
    "institute",
    "for",
    "human",
    "centered",
    "artificial",
    "intelligence",
    "hai",
    "that",
    "aims",
    "to",
    "make",
    "fundamental",
    "advances",
    "in",
    "the",
    "study",
    "development",
    "and",
    "deployment",
    "of",
    "foundation",
    "models",
    "",
]

TEST_SCENARIO_1 = LightScenario(
    scenario_key=LightScenarioKey(
        scenario_spec=ScenarioSpec(
            class_name="helm.benchmark.scenarios.natural_qa_scenario.NaturalQAScenario", args={}
        ),
        split="test",
    ),
    instances=[
        LightInstance(input="Center for Research on Foundation", references=["bar", "baz"], id="id1"),
        LightInstance(input="bar bar", references=["foo", "baz"], id="id2"),
    ],
)
TEST_SCENARIO_2 = LightScenario(
    scenario_key=LightScenarioKey(
        scenario_spec=ScenarioSpec(
            class_name="helm.benchmark.scenarios.natural_qa_scenario.NaturalQAScenario2", args={}
        ),
        split="test",
    ),
    instances=[LightInstance(input=TEST_DOCUMENT, references=[TEST_DOCUMENT, TEST_DOCUMENT], id="id1")],
)



In [4]:
scenarios = [TEST_SCENARIO_1, TEST_SCENARIO_2]
all_overlap_stats: AllDataOverlapStats
all_overlap_stats = create_all_data_overlap_stats(light_scenarios=scenarios, n_values=N_VALUES)

 

Initializing all data overlap stats


In [5]:
n_values = N_VALUES
from data_overlap_stats import (
    OldDataOverlapStatsKey,
    OldDataOverlapStats,
    PART_INPUT,
    PART_REF,
)
all_overlap_stats: AllDataOverlapStats = {}
for scenario in scenarios:
    for n in n_values:
        # Initialize a stats instance for every pair of <scenario, n>
        stats: OldDataOverlapStats = OldDataOverlapStats.from_scenario(scenario, stats_tags={"N": n})
        if stats.stats_key in all_overlap_stats:
            print(stats.stats_key)
            print('hi')
        all_overlap_stats[stats.stats_key] = stats
 

In [6]:
print(all_overlap_stats)

{OldDataOverlapStatsKey(metadata={'light_scenario_key': LightScenarioKey(scenario_spec=ScenarioSpec(class_name='helm.benchmark.scenarios.natural_qa_scenario.NaturalQAScenario', args={}), split='test'), 'N': 5}): <data_overlap_stats.OldDataOverlapStats object at 0x13f4e11b0>, OldDataOverlapStatsKey(metadata={'light_scenario_key': LightScenarioKey(scenario_spec=ScenarioSpec(class_name='helm.benchmark.scenarios.natural_qa_scenario.NaturalQAScenario', args={}), split='test'), 'N': 13}): <data_overlap_stats.OldDataOverlapStats object at 0x13f4e0df0>, OldDataOverlapStatsKey(metadata={'light_scenario_key': LightScenarioKey(scenario_spec=ScenarioSpec(class_name='helm.benchmark.scenarios.natural_qa_scenario.NaturalQAScenario2', args={}), split='test'), 'N': 5}): <data_overlap_stats.OldDataOverlapStats object at 0x13f4e1390>, OldDataOverlapStatsKey(metadata={'light_scenario_key': LightScenarioKey(scenario_spec=ScenarioSpec(class_name='helm.benchmark.scenarios.natural_qa_scenario.NaturalQAScenari

In [13]:
from typing import List, DefaultDict, Set

from collections import defaultdict
from compute_data_overlap_metrics import (
    create_all_data_overlap_stats,
    compute_document_data_overlap,
    create_ngram_index,
    EntryDataOverlapKey,
    Ngram,
    NgramIndex,
    AllDataOverlapStats,
)

tokenizer = LightTokenizer()
stats_keys = set()
scenarios = [TEST_SCENARIO_1, TEST_SCENARIO_2]
ngram_index: NgramIndex
ngram_index = create_ngram_index(
    light_scenarios=scenarios, n_values=N_VALUES, tokenizer=tokenizer, stats_keys=stats_keys
)

stats_1_key, stats_2_key, stats_3_key = (
    DataOverlapStatsKey(
        light_scenario_key=TEST_SCENARIO_1.scenario_key, overlap_protocol_spec=OverlapProtocolSpec(N=5)
    ),
    DataOverlapStatsKey(
        light_scenario_key=TEST_SCENARIO_2.scenario_key, overlap_protocol_spec=OverlapProtocolSpec(N=5)
    ),
    DataOverlapStatsKey(
        light_scenario_key=TEST_SCENARIO_2.scenario_key, overlap_protocol_spec=OverlapProtocolSpec(N=13)
    ),
)

test_5_gram: Ngram = ("Center", "for", "Research", "on", "Foundation")

stats_key_to_input_ids: DefaultDict[DataOverlapStatsKey, Set] = defaultdict(set)
stats_key_to_reference_ids: DefaultDict[DataOverlapStatsKey, Set] = defaultdict(set)

compute_document_data_overlap(
    document=TEST_DOCUMENT,
    ngram_index=ngram_index,
    tokenizer=tokenizer,
    stats_key_to_input_ids=stats_key_to_input_ids,
    stats_key_to_reference_ids=stats_key_to_reference_ids,
)

Building ngram indexes for LightScenarioKey(scenario_spec=ScenarioSpec(class_name='helm.benchmark.scenarios.natural_qa_scenario.NaturalQAScenario', args={}), split='test')
Building ngram indexes for LightScenarioKey(scenario_spec=ScenarioSpec(class_name='helm.benchmark.scenarios.natural_qa_scenario.NaturalQAScenario2', args={}), split='test')


In [14]:
stats_key_to_input_ids

defaultdict(set,
            {DataOverlapStatsKey(light_scenario_key=LightScenarioKey(scenario_spec=ScenarioSpec(class_name='helm.benchmark.scenarios.natural_qa_scenario.NaturalQAScenario2', args={}), split='test'), overlap_protocol_spec=OverlapProtocolSpec(N=5)): {'id1'},
             DataOverlapStatsKey(light_scenario_key=LightScenarioKey(scenario_spec=ScenarioSpec(class_name='helm.benchmark.scenarios.natural_qa_scenario.NaturalQAScenario', args={}), split='test'), overlap_protocol_spec=OverlapProtocolSpec(N=5)): {'id1'},
             DataOverlapStatsKey(light_scenario_key=LightScenarioKey(scenario_spec=ScenarioSpec(class_name='helm.benchmark.scenarios.natural_qa_scenario.NaturalQAScenario2', args={}), split='test'), overlap_protocol_spec=OverlapProtocolSpec(N=13)): {'id1'}})

In [15]:
stats_key_to_reference_ids

defaultdict(set,
            {DataOverlapStatsKey(light_scenario_key=LightScenarioKey(scenario_spec=ScenarioSpec(class_name='helm.benchmark.scenarios.natural_qa_scenario.NaturalQAScenario2', args={}), split='test'), overlap_protocol_spec=OverlapProtocolSpec(N=5)): {'id1'},
             DataOverlapStatsKey(light_scenario_key=LightScenarioKey(scenario_spec=ScenarioSpec(class_name='helm.benchmark.scenarios.natural_qa_scenario.NaturalQAScenario2', args={}), split='test'), overlap_protocol_spec=OverlapProtocolSpec(N=13)): {'id1'}})