In [1]:
# load corpus dataset
from src.data.create_corpus_dataset import create_dataset, DatasetConfig, UtteranceType, DatasetSplitType
import os
from datasets import load_from_disk
from src.data.dataset_splits import create_splits_from_corpus_dataset
from collections import defaultdict
import pandas as pd

# Location of the processed corpus dataset on disk.
dataset_folder = "../../data/processed/"
dataset_path = os.path.join(dataset_folder, "corpus_dataset_v1")

# Only build the dataset when nothing exists at dataset_path yet.
if not os.path.exists(dataset_path):
    # Example of creating a dataset. The available DatasetConfig options are
    # described in its docstring.
    corpus_config = DatasetConfig(
        dataset_path=dataset_path,
        project_dir="../../",
        num_previous_turns=3,
        include_role=True,
        sep_token="[SEP]",
        utterance_type=UtteranceType.User,
        eval_size=0.5,
        validation_test_ratio=0.5
    )
    create_dataset(corpus_config)

# Load the dataset stored at dataset_path.
corpus_dataset = load_from_disk(dataset_path)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Create the in-distribution split from the corpus dataset. The call receives
# save_folder / dataset_save_name, so it presumably also persists the split
# (the cell output shows "Saving the dataset ..." progress bars).
in_distribution_split = create_splits_from_corpus_dataset(
    corpus_dataset=corpus_dataset,
    dataset_split_type=DatasetSplitType.InDistribution,
    save_folder=dataset_folder,
    dataset_save_name="dataset_split_in_distribution",
)

# Keep the split as the cell's last expression so the notebook displays it.
# (A trailing debug print previously prevented this expression from being shown.)
in_distribution_split

Saving the dataset (1/1 shards): 100%|██████████| 1943/1943 [00:00<00:00, 169785.47 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 10392/10392 [00:00<00:00, 1455138.12 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 1943/1943 [00:00<00:00, 196905.69 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 1943/1943 [00:00<00:00, 261521.49 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 416/416 [00:00<00:00, 61609.07 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 10392/10392 [00:00<00:00, 953830.82 examples/s] 
Saving the dataset (1/1 shards): 100%|██████████| 416/416 [00:00<00:00, 80786.67 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 416/416 [00:00<00:00, 80887.79 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 418/418 [00:00<00:00, 71329.96 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 10392/10392 [00:00<00:00, 983687.82 examples/s] 
Saving the dataset (1/1 shards): 100%|█████

a





In [3]:

# Count how often each label occurs, grouped by split and then by
# discussion scenario: label_counts[split][scenario][label] -> int.
label_counts = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))

for split in ["train", "validation", "test"]:
    for query in in_distribution_split[split]["queries"]:
        scenario = query["discussion_scenario"]
        for label in query["labels"]:
            label_counts[split][scenario][label] += 1

# Flatten the nested counts into one record per (split, scenario, label).
records = [
    {"scenario": scenario, "label": label, "split": split, "count": count}
    for split, split_dict in label_counts.items()
    for scenario, labels in split_dict.items()
    for label, count in labels.items()
]

# Pass the column names explicitly: without them an empty `records` list
# yields a frame with no columns at all, and any later access such as
# dataframe["split"] would raise a KeyError.
dataframe = pd.DataFrame(records, columns=["scenario", "label", "split", "count"])

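# # Plot stacked bar chart per scenario (disabled: the snippet below uses `df` and `plt`, which the cells above do not define — the frame is named `dataframe`)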
# scenarios = df["scenario"].unique()
# for scenario in scenarios:
#     df_scenario = df[df["scenario"] == scenario]
#     pivot_df = df_scenario.pivot(index="label", columns="split", values="count").fillna(0)
#     pivot_df.plot(kind="bar", stacked=True, color=["blue", "yellow", "red"], figsize=(12, 6))
#     plt.title(f"Label Distribution in {scenario}")
#     plt.ylabel("Count")
#     plt.xlabel("Label")
#     plt.xticks(rotation=90)
#     plt.tight_layout()
#     plt.show()


In [4]:
def check_label_distribution(df):
    """
    Check that every (scenario, label) pair found in 'validation' or 'test'
    also exists in 'train'.

    Only rows with count > 0 are taken into account. If offending pairs are
    found, the corresponding rows are printed; otherwise a confirmation
    message is printed.

    Args:
        df: DataFrame with the columns "scenario", "label", "split" and
            "count" (one row per scenario/label/split combination).

    Returns:
        None. The result is reported via print().
    """
    # Rows with a zero (or negative) count do not represent an occurrence.
    observed = df[df["count"] > 0]

    # All (scenario, label) pairs that occur in the train split.
    train_rows = observed[observed["split"] == "train"]
    allowed_pairs = set(zip(train_rows["scenario"], train_rows["label"]))

    # Occurrences in the evaluation splits.
    non_train = observed[observed["split"].isin(["validation", "test"])]

    # Build the mask from the zipped columns instead of DataFrame.apply(axis=1):
    # this avoids the per-row Series overhead and stays a well-formed boolean
    # Series even when there are no validation/test rows.
    missing_in_train = pd.Series(
        [
            pair not in allowed_pairs
            for pair in zip(non_train["scenario"], non_train["label"])
        ],
        index=non_train.index,
        dtype=bool,
    )
    faulty = non_train[missing_in_train]

    if len(faulty) > 0:
        print("Found scenario/label pairs in test/validation that do not appear in train:")
        print(faulty[["scenario", "label", "split", "count"]])
    else:
        print("All scenario/label pairs in test and validation are valid (they appear in train).")

# Run the check on the per-split label counts; the outcome is printed by the function.
check_label_distribution(dataframe)




All scenario/label pairs in test and validation are valid (they appear in train).
