## Analyze the Labelled Dataset

In [None]:
from models_under_pressure.config import AIS_DATASETS
from models_under_pressure.interfaces.dataset import Label, LabelledDataset

# Summarise the stakes labels of the MMLU sandbagging dataset: overall,
# and separately for sandbagging and non-sandbagging examples.


def _percent(count: int, total: int) -> float:
    """Return ``count`` as a percentage of ``total``.

    Returns 0.0 when ``total`` is zero, so an empty dataset or an empty
    subset prints "0.00%" rather than raising ``ZeroDivisionError``.
    """
    return count / total * 100 if total else 0.0


def _count_stakes(all_labels, indices):
    """Count the labels found at ``indices`` of ``all_labels``.

    Returns a ``(high_stakes, low_stakes, ambiguous)`` tuple of ints.
    """
    high = sum(1 for i in indices if all_labels[i] == Label.HIGH_STAKES)
    low = sum(1 for i in indices if all_labels[i] == Label.LOW_STAKES)
    ambiguous = sum(1 for i in indices if all_labels[i] == Label.AMBIGUOUS)
    return high, low, ambiguous


dataset = LabelledDataset.load_from(AIS_DATASETS["mmlu_sandbagging"]["file_path"])

# Stakes label of every example
labels = dataset.labels
# Per-example sandbagging flag, read from the dataset's extra fields
is_sandbagging = dataset.other_fields["is_sandbagging"]

total_count = len(labels)

# Overall label distribution
high_stakes_count = labels.count(Label.HIGH_STAKES)
low_stakes_count = labels.count(Label.LOW_STAKES)
ambiguous_count = labels.count(Label.AMBIGUOUS)

print(f"Total examples: {total_count}")
print(
    f"High-stakes examples: {high_stakes_count} ({_percent(high_stakes_count, total_count):.2f}%)"
)
print(
    f"Low-stakes examples: {low_stakes_count} ({_percent(low_stakes_count, total_count):.2f}%)"
)
print(
    f"Ambiguous examples: {ambiguous_count} ({_percent(ambiguous_count, total_count):.2f}%)"
)
print("\n")

# Split example indices by the truthiness of their sandbagging flag
sandbagging_examples = [i for i, sb in enumerate(is_sandbagging) if sb]
non_sandbagging_examples = [i for i, sb in enumerate(is_sandbagging) if not sb]

# These shares are relative to the number of labels, not the number of flags
print(
    f"Sandbagging examples: {len(sandbagging_examples)} ({_percent(len(sandbagging_examples), total_count):.2f}%)"
)
print(
    f"Non-sandbagging examples: {len(non_sandbagging_examples)} ({_percent(len(non_sandbagging_examples), total_count):.2f}%)"
)
print("\n")

# Label distribution within the sandbagging examples
sandbagging_total = len(sandbagging_examples)
(
    sandbagging_high_stakes,
    sandbagging_low_stakes,
    sandbagging_ambiguous,
) = _count_stakes(labels, sandbagging_examples)

print("For sandbagging examples:")
print(
    f"High-stakes: {sandbagging_high_stakes} ({_percent(sandbagging_high_stakes, sandbagging_total):.2f}%)"
)
print(
    f"Low-stakes: {sandbagging_low_stakes} ({_percent(sandbagging_low_stakes, sandbagging_total):.2f}%)"
)
print(
    f"Ambiguous: {sandbagging_ambiguous} ({_percent(sandbagging_ambiguous, sandbagging_total):.2f}%)"
)
print("\n")

# Label distribution within the non-sandbagging examples
non_sandbagging_total = len(non_sandbagging_examples)
(
    non_sandbagging_high_stakes,
    non_sandbagging_low_stakes,
    non_sandbagging_ambiguous,
) = _count_stakes(labels, non_sandbagging_examples)

print("For non-sandbagging examples:")
print(
    f"High-stakes: {non_sandbagging_high_stakes} ({_percent(non_sandbagging_high_stakes, non_sandbagging_total):.2f}%)"
)
print(
    f"Low-stakes: {non_sandbagging_low_stakes} ({_percent(non_sandbagging_low_stakes, non_sandbagging_total):.2f}%)"
)
print(
    f"Ambiguous: {non_sandbagging_ambiguous} ({_percent(non_sandbagging_ambiguous, non_sandbagging_total):.2f}%)"
)