# Generating the selected-index file

In [2]:
from cifar10_web import cifar10
import numpy as np

# Load CIFAR-10 through cifar10_web; the loader hands back four arrays
# (train images, train labels, test images, test labels).
X_train, y_train, X_test, y_test = cifar10(path=None)

# The reshape treats every row as a flattened 3x32x32 image and the transpose
# moves the channel axis last, giving (N, 32, 32, 3). Using -1 instead of a
# hard-coded 50000 lets the same code handle both splits.
X_train = X_train.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)

# Fix: the test images must be derived from X_test. The previous version
# reshaped X_train here, so X_test ended up holding scrambled training data.
X_test = X_test.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)

# No pixel rescaling is applied in this cell.

# Collapse each label row to the position of its largest entry
# (presumably the labels arrive one-hot encoded -- confirm against cifar10_web).
y_train = np.argmax(y_train, axis=1)
y_test = np.argmax(y_test, axis=1)

# Index selection

In [3]:
# Fractions of a class that may be discarded when building an imbalanced set.
drop_ratio = [0.4, 0.6, 0.75, 0.9]

# Group the training images by class: entry k holds every image whose label is k.
X_train_selected = [
    X_train[y_train == class_id]
    for class_id in range(10)
]

In [4]:
def sample_index(labelled_data, drop_ratio, rng=None):
    """Draw the row positions of ``labelled_data`` that survive a drop.

    Parameters
    ----------
    labelled_data : array-like with a ``shape`` attribute
        Only ``shape[0]`` (the number of rows) is used.
    drop_ratio : float
        Fraction of rows to discard; ``ceil(rows * (1 - drop_ratio))`` rows
        are kept.
    rng : optional
        Object exposing ``choice(a, size, replace=...)``, e.g.
        ``np.random.default_rng(seed)``. When omitted, NumPy's global random
        state (``np.random.choice``) is used, as before.

    Returns
    -------
    numpy.ndarray
        Distinct integer positions in ``[0, rows)``, in random order.
    """
    n_rows = labelled_data.shape[0]

    # Round before taking the ceiling: ``1 - drop_ratio`` is inexact in binary
    # floating point (e.g. (1 - 0.7) * 10 evaluates to 3.0000000000000004), and
    # a bare ceil would then keep one row too many.
    n_keep = int(np.ceil(np.round(n_rows * (1 - drop_ratio), 8)))

    chooser = np.random if rng is None else rng
    return chooser.choice(n_rows, n_keep, replace=False)

In [5]:
# Seed NumPy's global random state so that re-running this cell regenerates the
# same index sets (sample_index draws from that global state by default).
np.random.seed(0)

# Flat list ordered ratio-major: the ten per-class index arrays for the first
# drop ratio come first, then the ten for the second ratio, and so on. This is
# the layout get_indexes relies on (position = label + 10 * ratio position).
indexes = [
    sample_index(class_images, ratio)
    for ratio in drop_ratio
    for class_images in X_train_selected
]

In [6]:
# np.savez converts each argument to an array. The index arrays have different
# lengths for different drop ratios, and recent NumPy versions raise a
# ValueError when asked to build an array from such a ragged list implicitly.
# Pack them into an explicit 1-D object array (filled element by element so
# NumPy never tries to stack them) before saving; the file is still read back
# through the "arr_0" key with allow_pickle=True.
indexes_packed = np.empty(len(indexes), dtype=object)
for position, kept_positions in enumerate(indexes):
    indexes_packed[position] = kept_positions

np.savez("selected_index_40.npz", indexes_packed)

# Reading the selected indexes

In [7]:
from cifar10_web import cifar10
import numpy as np

def get_indexes(index_list, label: int = 5, drop_ratio: float = 0.4):
    """Return the entry of ``index_list`` for one (label, drop ratio) pair.

    ``index_list`` is read as a flat sequence laid out in blocks of ten: the
    entry for ``label`` under the k-th ratio of (0.4, 0.6, 0.75, 0.9) sits at
    position ``label + 10 * k``. Any other ``drop_ratio`` raises ``KeyError``.
    """
    ratio_position = {
        ratio: position
        for position, ratio in enumerate((0.4, 0.6, 0.75, 0.9))
    }
    block_start = 10 * ratio_position[drop_ratio]
    return index_list[label + block_start]

def get_cifar_10(return_one_hot_y: bool = False):
    """Load CIFAR-10 via ``cifar10_web`` with images in channels-last layout.

    Parameters
    ----------
    return_one_hot_y : bool, default False
        When False, each label row is reduced to the position of its largest
        entry, giving 1-D label arrays. When True, the labels are returned
        exactly as the loader provides them.

    Returns
    -------
    tuple
        ``(X_train, y_train, X_test, y_test)``; both image arrays have shape
        ``(N, 32, 32, 3)``.
    """
    X_train, y_train, X_test, y_test = cifar10(path=None)

    # Each row is treated as a flattened 3x32x32 image; the transpose moves the
    # channel axis to the end.
    X_train = X_train.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    X_test = X_test.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)

    if not return_one_hot_y:
        # One vectorised argmax per split instead of a Python loop over rows.
        y_train = np.argmax(y_train, axis=1)
        y_test = np.argmax(y_test, axis=1)

    return X_train, y_train, X_test, y_test

def get_imbalanced_dataset(X_train, y_train, label: int = 5, drop_ratio: float = 0.4,
                           index_path: str = "selected_index_40.npz"):
    """Remove part of one class from a training set using stored indexes.

    Parameters
    ----------
    X_train : numpy.ndarray
        Training samples, indexed along axis 0.
    y_train : numpy.ndarray
        Labels, either 1-D class ids or a 2-D array whose per-row argmax is
        the class id.
    label : int, default 5
        Class to thin out; must be in ``0..9`` (the stored layout holds ten
        classes per drop ratio).
    drop_ratio : float, default 0.4
        One of 0.4, 0.6, 0.75 or 0.9.
    index_path : str, default "selected_index_40.npz"
        ``.npz`` file whose ``"arr_0"`` entry holds the stored index arrays.

    Returns
    -------
    tuple
        ``(X_imbalanced, y_imbalanced, X_deleted, y_deleted)``: the data with
        the dropped rows removed, followed by the dropped rows themselves.

    Raises
    ------
    AssertionError
        If ``drop_ratio`` is not one of the supported values.
    ValueError
        If ``label`` is outside ``0..9``.
    """
    assert drop_ratio in [0.4, 0.6, 0.75, 0.9], "drop_ratio must be one of 0.4, 0.6, 0.75, 0.9"

    # An out-of-range label would silently select another class's (or another
    # ratio's) index array from the flat layout, so reject it up front.
    if not 0 <= label < 10:
        raise ValueError(f"label must be in 0..9, got {label!r}")

    # NOTE(security): allow_pickle=True can execute arbitrary code while
    # unpickling -- only point index_path at files you generated yourself.
    # (Presumably needed because the index arrays are stored as one object
    # array -- confirm against the generating cell.)
    # The context manager closes the underlying file handle once read.
    with np.load(index_path, allow_pickle=True) as npzfile:
        indexes = npzfile["arr_0"]

    # Work with class ids regardless of how the labels were passed in.
    if y_train.ndim == 2:
        y_train_ = np.argmax(y_train, axis=1)
    else:
        y_train_ = y_train

    # Rows of the training set that belong to the requested class.
    label_index = np.where(y_train_ == label)[0]
    # The stored values are positions *within* label_index that are kept;
    # deleting them from label_index leaves the rows to drop.
    kept_positions = get_indexes(indexes, label=label, drop_ratio=drop_ratio)
    deleted_index = np.delete(label_index, kept_positions)

    X_imbalanced = np.delete(X_train, deleted_index, 0)
    y_imbalanced = np.delete(y_train, deleted_index, 0)

    X_deleted = X_train[deleted_index]
    y_deleted = y_train[deleted_index]
    return X_imbalanced, y_imbalanced, X_deleted, y_deleted

In [8]:
# Load the full dataset and report the array shapes.
X_train, y_train, X_test, y_test = get_cifar_10()
print("shape: ", X_train.shape, y_train.shape, X_test.shape, y_test.shape)

# Thin out class 3 with a 0.9 drop ratio; class 7 is printed as an untouched
# reference.
(X_train_imbalanced,
 y_train_imbalanced,
 X_deleted,
 y_deleted) = get_imbalanced_dataset(X_train, y_train, label=3, drop_ratio=0.9)
print(
    "shape after imbalanced: ",
    X_train_imbalanced.shape,
    y_train_imbalanced.shape,
    X_deleted.shape,
    y_deleted.shape,
)

# Per-class counts before and after the drop.
print("y = 3: ", (y_train == 3).sum())
print("y = 7: ", (y_train == 7).sum())
print("y = 3 after imbalanced: ", (y_train_imbalanced == 3).sum())
print("y = 7 after imbalanced: ", (y_train_imbalanced == 7).sum())

shape:  (50000, 32, 32, 3) (50000,) (10000, 32, 32, 3) (10000,)
shape after imbalanced:  (45500, 32, 32, 3) (45500,) (4500, 32, 32, 3) (4500,)
y = 3:  5000
y = 7:  5000
y = 3 after imbalanced:  500
y = 7 after imbalanced:  5000


In [9]:
# Load the full dataset again and report the array shapes.
X_train, y_train, X_test, y_test = get_cifar_10()
print("shape: ", X_train.shape, y_train.shape, X_test.shape, y_test.shape)

# Thin out class 9 with a 0.75 drop ratio; class 3 is printed as an untouched
# reference.
(X_train_imbalanced,
 y_train_imbalanced,
 X_deleted,
 y_deleted) = get_imbalanced_dataset(X_train, y_train, label=9, drop_ratio=0.75)
print(
    "shape after imbalanced: ",
    X_train_imbalanced.shape,
    y_train_imbalanced.shape,
    X_deleted.shape,
    y_deleted.shape,
)

# Per-class counts before and after the drop.
print("y = 3: ", (y_train == 3).sum())
print("y = 9: ", (y_train == 9).sum())
print("y = 3 after imbalanced: ", (y_train_imbalanced == 3).sum())
print("y = 9 after imbalanced: ", (y_train_imbalanced == 9).sum())

shape:  (50000, 32, 32, 3) (50000,) (10000, 32, 32, 3) (10000,)
shape after imbalanced:  (46250, 32, 32, 3) (46250,) (3750, 32, 32, 3) (3750,)
y = 3:  5000
y = 9:  5000
y = 3 after imbalanced:  5000
y = 9 after imbalanced:  1250
