In [12]:
import os
import re
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.datasets import mnist

In [13]:
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

In [14]:
# Define the desired imbalance ratio
imbalance_ratio = 0.1

# Get the number of samples in the minority class
num_minority_class = int(np.ceil(np.sum(x_train == 0) * imbalance_ratio))

# Get the indices of the samples in the minority class
minority_class_indices = np.where(x_train == 0)[0][:num_minority_class]

# Get the indices of the samples in the majority class
majority_class_indices = np.where(x_train != 0)[0]

# Combine the minority and majority class indices
balanced_indices = np.concatenate([minority_class_indices, majority_class_indices])

# Shuffle the indices
np.random.shuffle(balanced_indices)

# Create the imbalanced dataset
x_train_imbalanced = x_train[balanced_indices]
y_train_imbalanced = y_train[balanced_indices]

In [15]:
np.save('x_train_imbalanced.npy', x_train_imbalanced)
np.save('y_train_imbalanced.npy', y_train_imbalanced)

In [16]:
x_train_imbalanced = np.load('x_train_imbalanced.npy')
y_train_imbalanced = np.load('y_train_imbalanced.npy')

##### The script above only generates an imbalanced training dataset

In [17]:
x_train_imbalanced

array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       ...,

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 