In [2]:
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

# Transform applied to every image when a sample is fetched from the dataset.
transform = transforms.ToTensor()

# Fixed seed so the train/validation partition is identical after a kernel
# restart; an unseeded random_split draws a different partition every run.
SPLIT_SEED = 42
# Number of samples held out for validation.
VAL_SIZE = 10000

# Full training set (60k)
full_train = datasets.MNIST(
    root='./mnist',
    train=True,
    download=True,
    transform=transform
)

# Split: 50k train / 10k val
# The training size is derived from the dataset length so the two sizes always
# sum to len(full_train), which random_split requires.
train_set, val_set = random_split(
    full_train,
    [len(full_train) - VAL_SIZE, VAL_SIZE],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

test_set = datasets.MNIST(
    root='./mnist',
    train=False,
    download=True,
    transform=transform
)




Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./mnist\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:01<00:00, 7.11MB/s]


Extracting ./mnist\MNIST\raw\train-images-idx3-ubyte.gz to ./mnist\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./mnist\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 213kB/s]


Extracting ./mnist\MNIST\raw\train-labels-idx1-ubyte.gz to ./mnist\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./mnist\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 2.40MB/s]


Extracting ./mnist\MNIST\raw\t10k-images-idx3-ubyte.gz to ./mnist\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./mnist\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<?, ?B/s]

Extracting ./mnist\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./mnist\MNIST\raw






In [5]:
from collections import Counter
from torchvision import datasets, transforms

# Open the MNIST training split already stored under ./mnist (no download).
dataset = datasets.MNIST(root="./mnist", train=True, download=False,
                         transform=transforms.ToTensor())

# Tally how many samples carry each digit label.
labels = dataset.targets.tolist()
counts = Counter()
counts.update(labels)

# Report the ten digit classes in ascending order.
for cls in range(10):
    print("Class {}: {} samples".format(cls, counts[cls]))


Class 0: 5923 samples
Class 1: 6742 samples
Class 2: 5958 samples
Class 3: 6131 samples
Class 4: 5842 samples
Class 5: 5421 samples
Class 6: 5918 samples
Class 7: 6265 samples
Class 8: 5851 samples
Class 9: 5949 samples


In [6]:
# %%
import torch
from torchvision import datasets, transforms
import numpy as np

# Load dataset
dataset = datasets.MNIST(
    root="./mnist",
    train=True,
    download=False,
    transform=transforms.ToTensor()
)

# Extract all Class 0 samples
# One vectorised comparison over the label tensor replaces a Python-level loop
# over every label; the result is still an ascending list of plain ints.
class_0_indices = torch.nonzero(dataset.targets == 0, as_tuple=True)[0].tolist()
print(f"Found {len(class_0_indices)} Class 0 samples")

# Save Class 0 data and labels
# Fetch each sample exactly once: indexing the dataset separately for the image
# and for the label would load and transform every image twice.
class_0_samples = [dataset[i] for i in class_0_indices]
class_0_data = torch.stack([image for image, _ in class_0_samples])
class_0_labels = torch.tensor([label for _, label in class_0_samples])
del class_0_samples  # temporary list; the stacked tensors hold the data now

torch.save({
    'data': class_0_data,
    'labels': class_0_labels,
    'indices': class_0_indices
}, 'class_0_full.pt')
print(f"Saved all {len(class_0_indices)} Class 0 samples to 'class_0_full.pt'")

# %%
from collections import Counter

from torch.utils.data import Subset

# Randomly remove 3000 samples from Class 0
# Single source of truth for the removal count; it is used both for the draw
# and for the reported/verified class size below.
N_REMOVE = 3000
np.random.seed(42)  # for reproducibility
indices_to_remove = np.random.choice(class_0_indices, size=N_REMOVE, replace=False)
indices_to_remove_set = set(indices_to_remove.tolist())

# Create new dataset with reduced Class 0
remaining_indices = [i for i in range(len(dataset)) if i not in indices_to_remove_set]

# Create subset
reduced_dataset = Subset(dataset, remaining_indices)

# Sampling is without replacement, so exactly N_REMOVE distinct items are gone.
assert len(reduced_dataset) == len(dataset) - N_REMOVE

print(f"Original dataset size: {len(dataset)}")
print(f"Reduced dataset size: {len(reduced_dataset)}")
print(f"Class 0 now has: {len(class_0_indices) - N_REMOVE} samples")

# Verify class distribution
# Gather all remaining labels with one tensor index instead of one .item()
# call per sample.
reduced_labels = dataset.targets[remaining_indices].tolist()
new_counts = Counter(reduced_labels)
assert new_counts[0] == len(class_0_indices) - N_REMOVE
for cls in range(10):
    print(f"Class {cls}: {new_counts[cls]} samples")

# %%
# Use reduced_dataset for training
# Imported here so this cell does not rely on an import made in another cell.
from torch.utils.data import random_split

# Number of samples held out for validation.
VAL_SIZE = 10000
# The training size is derived from the dataset length so the sizes always sum
# to len(reduced_dataset); the seeded generator makes the partition repeatable.
train_set, val_set = random_split(
    reduced_dataset,
    [len(reduced_dataset) - VAL_SIZE, VAL_SIZE],
    generator=torch.Generator().manual_seed(42),
)

Found 5923 Class 0 samples
Saved all 5923 Class 0 samples to 'class_0_full.pt'
Original dataset size: 60000
Reduced dataset size: 57000
Class 0 now has: 2923 samples
Class 0: 2923 samples
Class 1: 6742 samples
Class 2: 5958 samples
Class 3: 6131 samples
Class 4: 5842 samples
Class 5: 5421 samples
Class 6: 5918 samples
Class 7: 6265 samples
Class 8: 5851 samples
Class 9: 5949 samples


In [7]:
# %%
# Save the reduced Class 0 separately
# Keep the Class 0 indices that were not drawn for removal, in original order.
reduced_class_0_indices = [i for i in class_0_indices if i not in indices_to_remove_set]
print(f"Reduced Class 0 has {len(reduced_class_0_indices)} samples")

# Fetch each sample exactly once: indexing the dataset separately for the image
# and for the label would load and transform every image twice.
reduced_class_0_samples = [dataset[i] for i in reduced_class_0_indices]
reduced_class_0_data = torch.stack([image for image, _ in reduced_class_0_samples])
reduced_class_0_labels = torch.tensor([label for _, label in reduced_class_0_samples])
del reduced_class_0_samples  # temporary list; the stacked tensors hold the data now

torch.save({
    'data': reduced_class_0_data,
    'labels': reduced_class_0_labels,
    'indices': reduced_class_0_indices
}, 'class_0_reduced.pt')
print(f"Saved reduced Class 0 ({len(reduced_class_0_indices)} samples) to 'class_0_reduced.pt'")

Reduced Class 0 has 2923 samples
Saved reduced Class 0 (2923 samples) to 'class_0_reduced.pt'
