## Data Splitting
To simulate several data owners that each hold a different dataset, we split the data into a public dataset and private datasets, each with a different set of labels.

In [2]:
# import necessary packages for this notebook 
import numpy as np
import pandas as pd

In [3]:
# root folder of the CIFAR-10 export; img_path points at its images/ subfolder
dataset_path = '../Datasets/CIFAR10/'
img_path = f'{dataset_path}images/'

In [4]:
# load the metadata table for every image; later cells rely on its
# `image`, `label` and `label_name` columns
data_df = pd.read_csv(f'{dataset_path}data.csv')
data_df

FileNotFoundError: [Errno 2] No such file or directory: '../Datasets/CIFAR10/data.csv'

In [None]:
# map numeric label id -> human-readable name, ordered by id
id2label = (
    data_df[['label', 'label_name']]
    .drop_duplicates()
    .sort_values('label')
    .set_index('label')['label_name']
    .to_dict()
)
# -1 is the id for the catch-all "other" class
id2label[-1] = 'other'
# inverse lookup: label name -> numeric id
label2id = {name: label_id for label_id, name in id2label.items()}
label2id

{'plane': 0,
 'car': 1,
 'bird': 2,
 'cat': 3,
 'deer': 4,
 'dog': 5,
 'frog': 6,
 'horse': 7,
 'ship': 8,
 'truck': 9,
 'other': -1}

In [5]:
def get_dataset(data_df: pd.DataFrame, drop=True, random_state=None, **kwargs):
    """Sample a custom dataset by giving, per label, a row count or a fraction.

    Labels are passed as keyword arguments, e.g.
    ``get_dataset(df, plane=0.2, cat=1500, other=2000)``. A value ``<= 1`` is a
    fraction of the rows of that label currently in ``data_df``; a larger value
    is an absolute number of rows. The special key ``other`` samples from all
    rows whose label was not named explicitly and tags them with label ``-1``
    (a value ``<= 1`` for ``other`` is a fraction of that pool).

    Args:
        data_df (pd.DataFrame): source rows with ``image``, ``label`` (integer
            id) and ``label_name`` columns.
        drop (bool, optional): whether the selected rows are removed from the
            returned frame. Defaults to True.
        random_state (int | np.random.RandomState | np.random.Generator, optional):
            seed or generator used for sampling, for reproducible splits.
            Defaults to None, i.e. numpy's global random state.
        **kwargs: ``label=count_or_fraction`` pairs. A label is a
            ``label_name`` (e.g. ``plane``), a label id written as a string
            (``**{'3': 100}``), or ``other``.

    Returns:
        tuple[pd.DataFrame, pd.DataFrame]: ``(remaining_df, dataset)``.
        ``remaining_df`` is ``data_df`` without the sampled rows (``data_df``
        itself when ``drop`` is False). ``dataset`` has columns ``image`` and
        ``label`` and a fresh RangeIndex; rows of explicitly named labels come
        first, in argument order, followed by the ``other`` rows.

    Raises:
        AssertionError: if a label is unknown or requested twice.
        ValueError: (from ``DataFrame.sample``) if more rows are requested
            than are available.
    """

    def _n_rows(count, available):
        # values <= 1 are fractions of the available rows, larger values are counts
        return int(round(count * available)) if count <= 1 else int(count)

    # one generator for all .sample() calls, so a seeded run is reproducible
    rng = (np.random.RandomState(random_state)
           if isinstance(random_state, (int, np.integer)) else random_state)

    # resolve label names (and label ids written as strings) to integer ids
    # from the frame itself instead of a notebook-global mapping
    pairs = data_df[['label', 'label_name']].drop_duplicates()
    label_ids = {str(i): int(i) for i in pairs['label']}
    label_ids.update(
        (name, int(i)) for i, name in zip(pairs['label'], pairs['label_name']))

    selected_idx = []   # index labels of the sampled rows, in output order
    out_labels = []     # label of each sampled row, kept aligned with selected_idx
    named_ids = []      # ids requested explicitly; excluded from the "other" pool

    for key, count in kwargs.items():
        if key == 'other':
            continue  # sampled last, from whatever labels were not named
        assert key in label_ids, f'unknown label {key!r}'
        label_id = label_ids[key]
        assert label_id not in named_ids, f'label {key!r} requested twice'
        named_ids.append(label_id)

        rows = data_df[data_df['label'] == label_id]
        picked = rows.sample(_n_rows(count, len(rows)), random_state=rng).index
        selected_idx.extend(picked)
        out_labels.extend([label_id] * len(picked))

    if 'other' in kwargs:
        pool = data_df[~data_df['label'].isin(named_ids)]
        picked = pool.sample(_n_rows(kwargs['other'], len(pool)),
                             random_state=rng).index
        selected_idx.extend(picked)
        out_labels.extend([-1] * len(picked))

    dataset = pd.DataFrame({
        'image': data_df.loc[selected_idx, 'image'].to_numpy(),
        'label': out_labels,
    })
    # drop by index label (not position) so repeated calls on an
    # already-reduced frame with a non-contiguous index remove the right rows
    remaining = data_df.drop(index=selected_idx) if drop else data_df
    return remaining, dataset


# example split: 20% of the planes, cars and birds, 1500 cats, and 2000 images
# drawn from the labels not named here; data_df is rebound to the remaining rows
example_split = dict(plane=0.2, car=0.2, bird=0.2, cat=1500, other=2000)
data_df, dataset = get_dataset(data_df, **example_split)

dataset


NameError: name 'data_df' is not defined