In [1]:
# This notebook displays statistics about the out-of-distribution (OOD) dataset split

In [2]:
# load corpus dataset
from src.data.create_corpus_dataset import create_dataset, DatasetConfig, UtteranceType, DatasetSplitType
import os
from datasets import load_from_disk
from src.data.dataset_splits import create_splits_from_corpus_dataset
from collections import defaultdict
import pandas as pd

# Locate the processed corpus dataset on disk
dataset_folder = "../../data/processed/"
dataset_path = os.path.join(dataset_folder, "corpus_dataset_v1")

# Build the dataset only when it has not been written to disk yet.
# The available DatasetConfig options are described in its docstring.
dataset_already_built = os.path.exists(dataset_path)
if not dataset_already_built:
    corpus_config = DatasetConfig(
        dataset_path=dataset_path,
        project_dir="../../",
        num_previous_turns=3,
        include_role=True,
        sep_token="[SEP]",
        utterance_type=UtteranceType.User,
        eval_size=0.5,
        validation_test_ratio=0.5,
    )
    create_dataset(corpus_config)

# Example of loading the dataset from disk. (The original note also mentions a
# DataLoader collate_function for dynamically loading a subset of the negative
# passages; only the loading step is performed in this cell.)
corpus_dataset = load_from_disk(dataset_path)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Create an Out-of-Distribution (Simple) split
# NOTE(review): the output recorded below this cell ends with
# "ValueError: Test set should be at most 0.18 of the total dataset. ... Choose a different seed."
# so this seed does not appear to produce a valid split — confirm and pick another seed.
ood_split_options = {
    "dataset_split_type": DatasetSplitType.OutOfDistributionSimple,
    "save_folder": dataset_folder,
    "dataset_save_name": "dataset_split_out_of_distribution_simple",
    "seed": 424,
}
ood_split_simple = create_splits_from_corpus_dataset(corpus_dataset, **ood_split_options)

# Show the training queries of the split as the cell output
ood_split_simple["train"]["queries"]


{'Z.K1'}
MEDAI
{'FAQ.3', 'Z.K5-1', 'Z.P7', 'Z.P3', 'Z.P3-2', 'FAQ.7', 'Z.K16-1', 'Z.K15-1', 'Z.K3-1-1-1', 'Z.K5'}

{'Z.P1-1-1'}
MEDAI
{'FAQ.3', 'Z.K5-1', 'Z.P7', 'Z.P3', 'Z.P3-2', 'FAQ.7', 'Z.K16-1', 'Z.K15-1', 'Z.K3-1-1-1', 'Z.K5'}

{'Z.P3'}
MEDAI
{'FAQ.3', 'Z.K5-1', 'Z.P7', 'Z.P3', 'Z.P3-2', 'FAQ.7', 'Z.K16-1', 'Z.K15-1', 'Z.K3-1-1-1', 'Z.K5'}

{'Z.P3'}
MEDAI
{'FAQ.3', 'Z.K5-1', 'Z.P7', 'Z.P3', 'Z.P3-2', 'FAQ.7', 'Z.K16-1', 'Z.K15-1', 'Z.K3-1-1-1', 'Z.K5'}

{'Z.P5'}
MEDAI
{'FAQ.3', 'Z.K5-1', 'Z.P7', 'Z.P3', 'Z.P3-2', 'FAQ.7', 'Z.K16-1', 'Z.K15-1', 'Z.K3-1-1-1', 'Z.K5'}

{'Z.P4-1'}
MEDAI
{'FAQ.3', 'Z.K5-1', 'Z.P7', 'Z.P3', 'Z.P3-2', 'FAQ.7', 'Z.K16-1', 'Z.K15-1', 'Z.K3-1-1-1', 'Z.K5'}

{'Z.K6'}
MEDAI
{'FAQ.3', 'Z.K5-1', 'Z.P7', 'Z.P3', 'Z.P3-2', 'FAQ.7', 'Z.K16-1', 'Z.K15-1', 'Z.K3-1-1-1', 'Z.K5'}

{'Z.K1'}
MEDAI
{'FAQ.3', 'Z.K5-1', 'Z.P7', 'Z.P3', 'Z.P3-2', 'FAQ.7', 'Z.K16-1', 'Z.K15-1', 'Z.K3-1-1-1', 'Z.K5'}

{'Z.P1'}
MEDAI
{'FAQ.3', 'Z.K5-1', 'Z.P7', 'Z.P3', 'Z.P3-2', 'FAQ.7', 'Z.K

ValueError: Test set should be at most 0.18 of the total dataset. but is 0.5480734605689593. Choose a different seed.

In [None]:
# Tally label occurrences, nested as split -> discussion scenario -> label
label_counts = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))

for split_name in ("train", "validation", "test"):
    for query in ood_split_simple[split_name]["queries"]:
        scenario_name = query["discussion_scenario"]
        for label_name in query["labels"]:
            label_counts[split_name][scenario_name][label_name] += 1

# Flatten the nested counts into one record per (split, scenario, label)
records = [
    {
        "scenario": scenario_name,
        "label": label_name,
        "split": split_name,
        "count": occurrences,
    }
    for split_name, per_scenario in label_counts.items()
    for scenario_name, per_label in per_scenario.items()
    for label_name, occurrences in per_label.items()
]

# Long-format table with columns: scenario, label, split, count
dataframe = pd.DataFrame(records)
dataframe

In [None]:
import matplotlib.pyplot as plt

# Plot one stacked bar chart per scenario: labels on the x-axis, one bar segment per split.
#
# pivot() sorts the split columns alphabetically ("test", "train", "validation") and
# leaves out splits that a scenario does not contain, so a positional colour list
# would colour the same split differently from figure to figure. The column order is
# therefore fixed and the colours are keyed by split name.
# NOTE(review): train=blue, validation=yellow, test=red is presumed to be the intended
# mapping of the original colour list — confirm.
split_colors = {"train": "blue", "validation": "yellow", "test": "red"}
split_order = list(split_colors)

scenarios = dataframe["scenario"].unique()
for scenario in scenarios:
    df_scenario = dataframe[dataframe["scenario"] == scenario]
    pivot_df = (
        df_scenario
        .pivot(index="label", columns="split", values="count")
        .reindex(columns=split_order)  # same columns, same order, in every figure
        .fillna(0)                     # label/split combinations without samples
    )

    fig, ax = plt.subplots(figsize=(12, 6))
    pivot_df.plot(
        kind="bar",
        stacked=True,
        color=[split_colors[split_name] for split_name in split_order],
        ax=ax,
    )
    ax.set_title(f"Label Distribution in {scenario}")
    ax.set_ylabel("Count")
    ax.set_xlabel("Label")
    ax.tick_params(axis="x", rotation=90)
    fig.tight_layout()
    plt.show()
    plt.close(fig)  # release the figure; one is created per scenario

In [None]:
def check_label_distribution(df):
    """
    Check the out-of-distribution property of the label statistics.

    A (scenario, label) pair that occurs in the 'test' split (count > 0) must
    NOT also occur in the 'train' split (count > 0). Pairs that occur in both
    are printed as faulty; otherwise a confirmation message is printed.
    The 'validation' split is not inspected.

    Args:
        df: DataFrame with the columns 'scenario', 'label', 'split' and 'count'.

    Returns:
        None. The result is reported via print().
    """
    has_samples = df["count"] > 0

    # All (scenario, label) pairs seen during training; these must not show up in test
    train_rows = df[(df["split"] == "train") & has_samples]
    train_pairs = set(zip(train_rows["scenario"], train_rows["label"]))

    # All rows of the test split that actually contain samples
    test_rows = df[(df["split"] == "test") & has_samples]

    # Boolean mask over test_rows: True where the pair leaks from train into test.
    # Built as a plain list so an empty test selection is handled without a row-wise apply.
    leaks_from_train = [
        pair in train_pairs
        for pair in zip(test_rows["scenario"], test_rows["label"])
    ]
    faulty = test_rows.loc[leaks_from_train]

    if len(faulty) > 0:
        print("Found scenario/label pairs in test that do appear in train:")
        print(faulty[["scenario", "label", "split", "count"]])
    else:
        print("All scenario/label pairs in test are valid (none of them appear in train).")

# Run the train/test overlap check on the per-split label counts; the result is printed
check_label_distribution(dataframe)