In [1]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, random_split
import pickle
import numpy as np

In [2]:
def flip_labels(dataset, flip_percentage):
    np.random.seed(32)
    num_samples = len(dataset)
    num_classes = len(dataset.classes)
    s = 0
    arr = np.array([])

    for i in range(num_samples):
        img, label = dataset[i]
        
        if np.random.rand() < flip_percentage and s < num_samples * flip_percentage:
            new_label = np.random.choice(num_classes)
            s +=1
            arr = np.append(arr, dataset.targets[i])
            while new_label == label:
                new_label = np.random.choice(num_classes)
            
            dataset.targets[i] = new_label
    
    return dataset, s, arr

In [3]:
def get_dataset(data, shard: list):
        indices = []

        for i in range(len(data.targets)):
            if data.targets[i] in shard:
                indices.append(i)

        return indices

In [4]:
def get_dataloader(data: datasets, batch_size: int, shuffle: bool, split_flag: bool = False):
    if split_flag:
        training_data, validation_data = random_split(data, [int((1 - 0.2) * len(data)), int(0.2 * len(data))])
        return DataLoader(dataset=training_data, batch_size=batch_size, num_workers=2, pin_memory=True, shuffle=shuffle), DataLoader(dataset=validation_data, batch_size=batch_size, num_workers=2, pin_memory=True, shuffle=shuffle)
    
    return DataLoader(dataset=data, batch_size=batch_size, num_workers=2, pin_memory=True, shuffle=shuffle)

In [5]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, ), (0.5, ), )])
dataset = datasets.FashionMNIST(download=True, root='data/', train=True, transform=transform)
data_to_flip = datasets.FashionMNIST(download=True, root='data/', train=True, transform=transform)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST\raw\train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting data/FashionMNIST\raw\train-images-idx3-ubyte.gz to data/FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST\raw\train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting data/FashionMNIST\raw\train-labels-idx1-ubyte.gz to data/FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST\raw\t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting data/FashionMNIST\raw\t10k-images-idx3-ubyte.gz to data/FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST\raw\t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting data/FashionMNIST\raw\t10k-labels-idx1-ubyte.gz to data/FashionMNIST\raw



In [6]:
def flip_distribution(flip_factor, data):
    flipped_data, counter, flips = flip_labels(data, flip_factor)
    print(f"Label Distribution After Flipping {flip_factor * 100}%:")
    flipped_stats = np.unique(flipped_data.targets, return_counts=True)
    print(f'\tTotal Flips:  {counter}')
    print('\tClass: ', flipped_stats[0])
    print('\tData: ', flipped_stats[1])
    print('\tNumber of samples flipped per class: ', np.unique(flips, return_counts=True)[1])
    return flipped_data, flipped_stats, counter, flips

In [7]:
def write_metadata(path, flipped_stats, flipped_data, flips, flipped_sub, counter):
    with open(path, 'w') as file:
        file.write(f"""Dataset: FashionMNIST
    Total Flips:  {counter}
    Classes: {flipped_stats[0]}
    Data: {flipped_stats[1]}
    Number of samples flipped per class: {np.unique(flips, return_counts=True)[1]}
    Subset Distribution: {np.unique(flipped_data.targets[flipped_sub.indices], return_counts=True)[1]}
        """)    

In [9]:
flip_percentages = [0.05, 0.1, 0.15]
shards = {'0': [0, 1, 5, 6, 9], '1': [2, 3, 4, 7, 8], '2': [0, 1, 5, 7, 8], '3': [2, 3, 4, 6, 9]}
general_path = f'mislabeled_subsets/'

for percent in flip_percentages:
    flipped_data, flipped_stats, counter, flips = flip_distribution(flip_factor=percent, data=data_to_flip)
    with open(general_path + f'{percent}/metadata.txt', 'w') as file:
        file.write(f"""Dataset: FashionMNIST
    Total Flips:  {counter}
    Classes: {flipped_stats[0]}
    Data: {flipped_stats[1]}
    Number of samples flipped per class: {np.unique(flips, return_counts=True)[1]}""")
        
    for shard_id, shard in shards.items():
        sub_indices = get_dataset(dataset, shard)
        flipped_sub = Subset(dataset=flipped_data, indices=sub_indices)
        with open(general_path + f'{percent}/subset_{shard_id}.pkl', "wb") as file:
            pickle.dump(flipped_sub, file)
        with open(general_path + f'{percent}/metadata.txt', 'a') as file:
            file.write(f"\nSubset {shard_id} Distribution: {np.unique(flipped_data.targets[flipped_sub.indices], return_counts=True)[1]}")


Label Distribution After Flipping 5.0%:
	Total Flips:  2986
	Class:  [0 1 2 3 4 5 6 7 8 9]
	Data:  [5973 6022 6021 5982 6036 6001 5965 5983 5984 6033]
	Number of samples flipped per class:  [301 299 297 313 272 318 303 303 318 262]
Label Distribution After Flipping 10.0%:
	Total Flips:  5986
	Class:  [0 1 2 3 4 5 6 7 8 9]
	Data:  [5943 6058 6045 6002 6029 6052 5940 5961 6004 5966]
	Number of samples flipped per class:  [598 562 597 606 609 596 590 604 580 644]
Label Distribution After Flipping 15.0%:
	Total Flips:  8984
	Class:  [0 1 2 3 4 5 6 7 8 9]
	Data:  [5911 6067 6029 6021 6042 6098 5874 6000 6034 5924]
	Number of samples flipped per class:  [931 894 906 877 886 875 899 874 912 930]
