# Class Parity Label Analysis Tutorial


## _Problem Statement_

For machine learning tasks, a discrepancy in label frequencies between train and test datasets can result in poor model performance.

To help with this, DataEval has a tool that compares the label distributions of two datasets.
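As a rough illustration of what such a discrepancy looks like, the sketch below compares the per-class label frequencies of two small, made-up label arrays using NumPy's `bincount`. The arrays and the helper name `label_frequencies` are hypothetical and are not part of DataEval.

```python
import numpy as np

def label_frequencies(labels, num_classes):
    """Hypothetical helper: fraction of samples that fall in each class."""
    counts = np.bincount(labels, minlength=num_classes)
    return counts / counts.sum()

# Made-up integer labels for a 3-class problem
train_example = np.array([0, 0, 1, 1, 2, 2])  # balanced
test_example = np.array([0, 0, 0, 0, 1, 2])   # skewed toward class 0

train_freq = label_frequencies(train_example, num_classes=3)
test_freq = label_frequencies(test_example, num_classes=3)
print("train:", train_freq)
print("test: ", test_freq)
```

Here class 0 makes up one third of the training labels but two thirds of the test labels, which is the kind of mismatch a parity check is meant to flag.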


### _When to use_

Use `label_parity` when you want to determine whether two datasets share the same label distribution, that is, whether a sample's label is statistically independent of which dataset it came from.
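One way to picture "statistically independent labels": arrange the per-class counts of both datasets in a 2 × K contingency table (one row per dataset, one column per class) and test whether class membership is independent of the dataset. The sketch below does this with SciPy's `chi2_contingency` on made-up counts. It illustrates the idea only and is not necessarily how DataEval computes its result.

```python
import numpy as np
from scipy.stats import chi2_contingency

# Made-up per-class counts for a 3-class problem
train_counts = np.array([200, 200, 200])  # balanced training set
test_counts = np.array([50, 50, 50])      # same proportions, fewer samples

# Rows are datasets, columns are classes
table = np.vstack([train_counts, test_counts])
chi2, p_value, dof, expected = chi2_contingency(table)
print(f"chi2={chi2:.3f}, p={p_value:.3f}, dof={dof}")
```

Because the test counts are exactly proportional to the training counts, the statistic is 0 and the p-value is 1: the data give no evidence that the label distribution depends on the dataset.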


### _What you will need_

1. A labeled training image dataset
2. A labeled test image dataset whose label distribution you want to evaluate


### _Setting up_

Let's import the libraries needed to set up a minimal working example.


In [None]:
try:
    import google.colab  # noqa: F401

    # specify the version of DataEval (==X.XX.X) for versions other than the latest
    %pip install -q dataeval
except Exception:
    pass

In [None]:
from dataeval.metrics.bias import label_parity
from dataeval.utils.torch.datasets import MNIST

## Load the data

We will use the MNIST dataset, loaded through DataEval's `MNIST` utility class, for this tutorial on class label statistics.


In [None]:
# Take a subset of 2000 training images and 500 test images
train_ds = MNIST("./data", train=True, download=True, size=2000)
test_ds = MNIST("./data", train=False, download=True, size=500)

# Extract the class labels of each subset
train_labels = train_ds.targets
test_labels = test_ds.targets

## Evaluate label statistical independence

Now, let's look at how to use DataEval's label statistics analyzer.
Call `label_parity` with the training labels and the test labels to be compared. It computes the chi-squared statistic for the hypothesis that the test labels follow the same class distribution as the training labels (MNIST has 10 unique classes), and returns it as `score` along with the p-value of the test as `p_value`.
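To make the computation concrete, here is a minimal sketch of a chi-squared goodness-of-fit test written directly with SciPy's `chisquare`, assuming the training label proportions define the expected distribution and the test labels are the observed one. This is an illustration under those assumptions, not DataEval's actual implementation; the helper name `chi2_label_parity` and the sample labels are hypothetical.

```python
import numpy as np
from scipy.stats import chisquare

def chi2_label_parity(expected_labels, observed_labels, num_classes):
    """Hypothetical helper: chi-squared goodness-of-fit on class counts."""
    expected_counts = np.bincount(expected_labels, minlength=num_classes)
    observed_counts = np.bincount(observed_labels, minlength=num_classes)
    # Rescale the expected counts to the observed total, since chisquare
    # requires both count vectors to have (nearly) the same sum
    expected_scaled = expected_counts / expected_counts.sum() * observed_counts.sum()
    return chisquare(f_obs=observed_counts, f_exp=expected_scaled)

# Balanced made-up labels: 200 per class for "train", 50 per class for "test"
train = np.repeat(np.arange(10), 200)
test = np.repeat(np.arange(10), 50)
same = chi2_label_parity(train, test, num_classes=10)

# A skewed "test" set: 140 samples of class 0 and 40 of every other class
test_shifted = np.concatenate([np.repeat(np.arange(10), 40), np.zeros(100, dtype=int)])
shifted = chi2_label_parity(train, test_shifted, num_classes=10)

print(same.statistic, same.pvalue)
print(shifted.statistic, shifted.pvalue)
```

Matching proportions give a statistic of 0 and a p-value of 1, while the skewed set gives a large statistic and a very small p-value. If a class never appears in the expected labels, its expected count is zero and the statistic is undefined, so a real implementation has to handle that case explicitly.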


In [None]:
results = label_parity(train_labels, test_labels)
print(f"The chi-squared value for the two label distributions is {results.score}, with p-value {results.p_value}")

In [None]:
### TEST ASSERTION CELL ###
assert results.score == 0.0
assert results.p_value == 1.0
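When reading the result, a common convention is to compare the p-value with a significance level chosen in advance, such as 0.05: a p-value below that level suggests the two label distributions differ. The sketch below wraps that decision in a hypothetical helper, `labels_differ`; the 0.05 default is an assumption for illustration, not a DataEval setting.

```python
def labels_differ(p_value: float, alpha: float = 0.05) -> bool:
    """Hypothetical helper: True if the hypothesis of equal label
    distributions is rejected at significance level alpha."""
    if not 0.0 < alpha < 1.0:
        raise ValueError("alpha must be between 0 and 1")
    return p_value < alpha

# A p-value of 1.0, as asserted above, gives no evidence of a mismatch
print(labels_differ(1.0))    # False
print(labels_differ(1e-30))  # True
```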