In [1]:
from pathlib import Path

# Project layout: data and model checkpoints live one directory above the notebook.
BASE_PATH = Path('../')
PATH_TO_DATA = BASE_PATH / 'data'
PATH_TO_MODELS = BASE_PATH / 'checkpoints'

# Make sure both output folders exist before any later cell writes into them.
for directory in (PATH_TO_DATA, PATH_TO_MODELS):
    directory.mkdir(parents=True, exist_ok=True)

#### <b>Load Libraries</b>

In [2]:
import os
import pickle
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split


from torchvision.datasets import CIFAR10

#### <b>Download data</b>

In [3]:
# Fetch CIFAR-10 into PATH_TO_DATA (a no-op when the files are already there).
# `train=True` is torchvision's default, spelled out so it is obvious that `dataset`
# is the training split (50,000 images, per the split sizes printed below).
dataset = CIFAR10(
    root=PATH_TO_DATA,
    train=True,
    download=True,
)

Files already downloaded and verified


#### <b>Split dataset indices into retain / forget / unseen sets</b>

In [4]:
# One positional index per sample in `dataset`: 0 .. len(dataset)-1.
inds = list(range(len(dataset)))

# Hold out 10% as `unseen_inds`; the remaining 90% is the pool used for shadow training.
train_shadow_inds, unseen_inds = train_test_split(
    inds,
    test_size=0.10,
    random_state=42,
)
# Half of that pool becomes `train_inds`; the other half is not kept under its own
# name (it is still part of `train_shadow_inds`).
train_inds, _ = train_test_split(
    train_shadow_inds,
    test_size=0.50,
    random_state=42,
)
# Carve 5% of `train_inds` off as the forget set; the other 95% is the retain set.
retain_inds, forget_inds = train_test_split(
    train_inds,
    test_size=0.05,
    random_state=42,
)

In [6]:
# Absolute number of samples in each split: retain, forget, unseen.
list(map(len, (retain_inds, forget_inds, unseen_inds)))

[21375, 1125, 5000]

In [7]:
# Number of observations in each split (retain, forget, unseen), as a percentage
# of the full index list `inds`.
list(
    f'{100 * len(split) / len(inds):.2f}%'
    for split in (retain_inds, forget_inds, unseen_inds)
)

['42.75%', '2.25%', '10.00%']

In [8]:
# Directory reserved for shadow-set artefacts.
# NOTE(review): nothing in this cell writes into it (the index arrays below are saved
# directly under PATH_TO_DATA) — confirm a later cell uses it before removing.
shadow_datasets = PATH_TO_DATA/'shadow_inds'
shadow_datasets.mkdir(exist_ok=True, parents=True)

# Draw `num_shadows` shadow training sets from the `train_shadow_inds` pool so that every
# sample lands in exactly num_shadows//2 of them. `counter` holds, per sample, how many
# more shadow sets it still has to join. Each round takes the half of the pool with the
# largest remaining count and breaks ties at random (shuffle, then a stable sort).
num_shadows = 128
counter = pd.Series(index=train_shadow_inds, data=num_shadows//2)
# data_split.loc[sample, shadow] is True iff `sample` belongs to shadow set `shadow`.
data_split = pd.DataFrame(data=False, index=train_shadow_inds, columns=range(num_shadows))
shadow_inds = []

for shadow_idx in range(num_shadows):
    shadow_inds.append(
        counter
        .sample(frac=1, random_state=shadow_idx)        # per-shadow seed -> reproducible tie-breaking
        .sort_values(ascending=False, kind='stable')    # stable sort keeps the shuffled order among ties
        .iloc[:len(counter)//2]
        .index.tolist()
    )
    # The index holds sample ids (not positions), so update by label explicitly.
    counter.loc[shadow_inds[-1]] -= 1
    data_split.loc[shadow_inds[-1], shadow_idx] = True

# The balancing above only works when both the pool size and num_shadows are even.
# Fail loudly rather than silently saving an unbalanced split.
assert data_split.sum(axis=1).eq(num_shadows//2).all(), \
    'shadow membership is unbalanced: every sample must be in exactly num_shadows//2 shadow sets'

# shadow_inds is stored as a 2-D array of shape (num_shadows, len(train_shadow_inds)//2).
np.save(PATH_TO_DATA/'shadow_inds.npy', shadow_inds)
np.save(PATH_TO_DATA/'train_inds.npy', train_inds)
np.save(PATH_TO_DATA/'unseen_inds.npy', unseen_inds)
np.save(PATH_TO_DATA/'retain_inds.npy', retain_inds)
np.save(PATH_TO_DATA/'forget_inds.npy', forget_inds)

# Per-sample membership mask: {sample index: [in shadow 0?, ..., in shadow num_shadows-1?]}.
with open(PATH_TO_DATA/'data_split.pickle', 'wb') as file:
    pickle.dump(
        {sample_idx: membership.tolist() for sample_idx, membership in zip(data_split.index, data_split.values)},
        file,
    )