In [2]:
import torch
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import TensorDataset, DataLoader

In [3]:
def BinaryMNISTcreate_imbalanced_datasets(
        root='data',
        download=True,
        batch_size=64,
        random_seed=42,
        fractions=(0.01, 0.02, 0.05, 0.1, 0.2),
        normalize=False,
):
    """Build imbalanced binary (digit 3 vs digit 4) MNIST train loaders and a test loader.

    For every ``fraction`` all training images of digit 4 are kept and a random
    subset of ``int(num_label4 * fraction)`` training images of digit 3 is added,
    so digit 3 is the minority class. The test loader holds every test image
    whose label is 3 or 4.

    Args:
        root: Directory passed to ``datasets.MNIST`` as its storage location.
        download: Passed to ``datasets.MNIST``.
        batch_size: Batch size used for every returned DataLoader.
        random_seed: Seeds NumPy's and PyTorch's *global* RNGs (minority-subset
            sampling and training-loader shuffling).
        fractions: Iterable of ratios (#label-3 images / #label-4 images), one
            training loader per entry.
        normalize: If False (default), images are the raw ``MNIST.data`` tensors
            exactly as stored by torchvision. If True, they are converted the way
            ``transforms.ToTensor`` would: a channel dimension is inserted at
            dim 1 and values are scaled to float32 in [0, 1].

    Returns:
        dict mapping ``f'train_{fraction}'`` to a shuffled training DataLoader for
        each fraction, plus ``'test'`` mapped to a non-shuffled test DataLoader.
        Labels keep their original digit values 3 and 4 (not remapped to 0/1).

    Raises:
        ValueError: if a fraction is negative or needs more label-3 images than
            the training set contains.
    """
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    def _prepare_images(images):
        # `.data` skips the dataset's per-item transform, so ToTensor-style
        # conversion has to be applied by hand when requested.
        if normalize:
            return images.unsqueeze(1).float().div(255.0)
        return images

    # Only the raw `.data` / `.targets` tensors are used below, so no transform
    # is attached to the dataset (it would never be applied).
    full_train_dataset = datasets.MNIST(root=root, train=True, download=download)
    train_data = full_train_dataset.data
    train_labels = full_train_dataset.targets
    train_index_label3 = (train_labels == 3)
    train_index_label4 = (train_labels == 4)

    train_data_label3 = train_data[train_index_label3]
    train_data_label4 = train_data[train_index_label4]
    train_labels_label3 = train_labels[train_index_label3]
    train_labels_label4 = train_labels[train_index_label4]
    num_label4 = len(train_data_label4)
    num_label3 = len(train_data_label3)
    print('Number of images with label 3:', num_label3)
    print('Number of images with label 4:', num_label4)

    data_loaders_dict = {}
    for fraction in fractions:
        num_label3_need = int(num_label4 * fraction)
        # Sampling is without replacement, so the request must fit the pool.
        if not 0 <= num_label3_need <= num_label3:
            raise ValueError(
                f'fraction={fraction} requires {num_label3_need} images with label 3, '
                f'but only {num_label3} are available'
            )
        chosen_indices_label3 = np.random.choice(num_label3, num_label3_need, replace=False)
        subset_label3_data = train_data_label3[chosen_indices_label3]
        subset_label3_labels = train_labels_label3[chosen_indices_label3]

        new_train_data = torch.cat((train_data_label4, subset_label3_data), dim=0)
        new_train_labels = torch.cat((train_labels_label4, subset_label3_labels), dim=0)
        new_train_dataset = TensorDataset(_prepare_images(new_train_data), new_train_labels)

        data_loaders_dict[f'train_{fraction}'] = DataLoader(
            new_train_dataset, batch_size=batch_size, shuffle=True
        )

    full_test_dataset = datasets.MNIST(root=root, train=False, download=download)
    test_data = full_test_dataset.data
    test_labels = full_test_dataset.targets
    test_index_label34 = (test_labels == 3) | (test_labels == 4)
    test_dataset = TensorDataset(
        _prepare_images(test_data[test_index_label34]), test_labels[test_index_label34]
    )
    # Evaluation does not benefit from shuffling; a fixed order keeps test
    # iteration reproducible.
    data_loaders_dict['test'] = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return data_loaders_dict


In [4]:
if __name__ == '__main__':
    import os

    # Data location can be overridden with the MNIST_ROOT environment variable;
    # the fallback is the original local path used when this notebook was written.
    MNIST_ROOT = os.environ.get('MNIST_ROOT', '/Users/max/MasterThesisData/MNIST')

    MNISTDataLoader = BinaryMNISTcreate_imbalanced_datasets(
        root=MNIST_ROOT,
        download=True,
        batch_size=64,
        random_seed=42,
        fractions=[0.01, 0.02, 0.05, 0.1, 0.2],
    )
    # Individual loaders are kept as module-level names (later cells may use them).
    MNIST_train_loader001 = MNISTDataLoader['train_0.01']
    MNIST_train_loader002 = MNISTDataLoader['train_0.02']
    MNIST_train_loader005 = MNISTDataLoader['train_0.05']
    MNIST_train_loader010 = MNISTDataLoader['train_0.1']
    MNIST_train_loader020 = MNISTDataLoader['train_0.2']
    MNIST_test_loader = MNISTDataLoader['test']

    named_loaders = {
        'MNIST_train_loader001': MNIST_train_loader001,
        'MNIST_train_loader002': MNIST_train_loader002,
        'MNIST_train_loader005': MNIST_train_loader005,
        'MNIST_train_loader010': MNIST_train_loader010,
        'MNIST_train_loader020': MNIST_train_loader020,
        'MNIST_test_loader': MNIST_test_loader,
    }
    # Two passes so all batch counts are printed before all image counts.
    for loader_name, loader in named_loaders.items():
        print(f'Number of batches in {loader_name}:', len(loader))
    for loader_name, loader in named_loaders.items():
        print(f'Number of images in {loader_name}:', len(loader.dataset))

Number of images with label 3: 6131
Number of images with label 4: 5842
Number of batches in MNIST_train_loader001: 93
Number of batches in MNIST_train_loader002: 94
Number of batches in MNIST_train_loader005: 96
Number of batches in MNIST_train_loader010: 101
Number of batches in MNIST_train_loader020: 110
Number of batches in MNIST_test_loader: 32
Number of images in MNIST_train_loader001: 5900
Number of images in MNIST_train_loader002: 5958
Number of images in MNIST_train_loader005: 6134
Number of images in MNIST_train_loader010: 6426
Number of images in MNIST_train_loader020: 7010
Number of images in MNIST_test_loader: 1992


# Analysis of Imbalanced Datasets (MNIST)

In [5]:
import torch
import numpy as np
import matplotlib.pyplot as plt

def Binary_get_class_distribution(data_loader, classes=(3, 4)):
    """Return the share of each class in ``classes`` among the labels of ``data_loader``.

    Args:
        data_loader: Any iterable of ``(inputs, labels)`` pairs (e.g. a DataLoader)
            whose ``labels`` is a tensor; the inputs are ignored.
        classes: Class labels to count. Labels not listed here are ignored and do
            not contribute to the total. Defaults to the 3-vs-4 MNIST setup.

    Returns:
        Tuple of fractions in [0, 1], one per entry of ``classes`` in the same
        order, summing to 1 (up to float rounding). With the default ``classes``
        this is ``(fraction_label3, fraction_label4)``.

    Raises:
        ValueError: if no label from ``classes`` occurs in ``data_loader``.
    """
    counts = [0] * len(classes)
    for _, labels in data_loader:
        for i, cls in enumerate(classes):
            counts[i] += (labels == cls).sum().item()

    total = sum(counts)
    if total == 0:
        # Avoid a bare ZeroDivisionError on empty loaders / absent classes.
        raise ValueError(f'data_loader contains no labels from {tuple(classes)}')

    return tuple(count / total for count in counts)