In [None]:
import typing

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

data_path = "../../datasets"
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # normalize by mean and standard devia,
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

train_dataset = datasets.MNIST(
    data_path,
    train=True,
    download=True,
    transform=transform,
)

test_loader = DataLoader(
    datasets.MNIST(data_path, train=False, download=False, transform=transform),
    shuffle=False,
    drop_last=False,
    batch_size=10000,
    generator=torch.Generator(),
)

In [3]:
from collections import Counter


def count_distinct_values(targets: typing.Iterable[typing.Any]):
    targets_count = Counter(targets)

    for target, count in targets_count.items():
        print(f"Target {target}: {count} records")

    print(f"Total: {targets_count.total()}")

In [4]:
count_distinct_values(train_dataset.targets.numpy())

Target 5: 5421 records
Target 0: 5923 records
Target 4: 5842 records
Target 1: 6742 records
Target 9: 5949 records
Target 2: 5958 records
Target 3: 6131 records
Target 6: 5918 records
Target 7: 6265 records
Target 8: 5851 records
Total: 60000


In [5]:
from data_splitting import (
    index_by_approximate_binary_target_partitions,
    partition_dataset,
    IndexVector,
)

sample_index = index_by_approximate_binary_target_partitions(train_dataset, 10, 42)
subsets = partition_dataset(train_dataset, sample_index)

for subset_idx, subset in enumerate(subsets):
    print(f"==== subset #{subset_idx} ====")
    subset_targets = list(int(it) for _, it in DataLoader(subset))
    count_distinct_values(subset_targets)

==== subset #0 ====
Target 2: 3000 records
Target 7: 3000 records
Total: 6000
==== subset #1 ====
Target 3: 3000 records
Target 6: 2935 records
Target 7: 65 records
Total: 6000
==== subset #2 ====
Target 1: 3665 records
Target 2: 2335 records
Total: 6000
==== subset #3 ====
Target 2: 623 records
Target 3: 2377 records
Target 5: 17 records
Target 6: 2983 records
Total: 6000
==== subset #4 ====
Target 3: 754 records
Target 4: 2246 records
Target 0: 2923 records
Target 1: 77 records
Total: 6000
==== subset #5 ====
Target 8: 51 records
Target 9: 2949 records
Target 0: 3000 records
Total: 6000
==== subset #6 ====
Target 4: 596 records
Target 5: 2404 records
Target 7: 200 records
Target 8: 2800 records
Total: 6000
==== subset #7 ====
Target 5: 3000 records
Target 4: 3000 records
Total: 6000
==== subset #8 ====
Target 7: 3000 records
Target 8: 3000 records
Total: 6000
==== subset #9 ====
Target 9: 3000 records
Target 1: 3000 records
Total: 6000
