In [3]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

# Download and load the training data
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=len(trainset), shuffle=True)

# Download and load the test data
testset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=False, transform=transform)
testloader = DataLoader(testset, batch_size=len(testset), shuffle=False)

# Function to convert DataLoader object to tensor
def loader_to_tensor(dataloader):
    dataiter = iter(dataloader)
    images, labels = next(dataiter)
    return images, labels

# Get the entire dataset
train_images, train_labels = loader_to_tensor(trainloader)
test_images, test_labels = loader_to_tensor(testloader)

# Check the shape of the tensors
print("Train images shape:", train_images.shape)  # Should print torch.Size([60000, 1, 28, 28])
print("Train labels shape:", train_labels.shape)  # Should print torch.Size([60000])
print("Test images shape:", test_images.shape)    # Should print torch.Size([10000, 1, 28, 28])
print("Test labels shape:", test_labels.shape)    # Should print torch.Size([10000])


Train images shape: torch.Size([60000, 1, 28, 28])
Train labels shape: torch.Size([60000])
Test images shape: torch.Size([10000, 1, 28, 28])
Test labels shape: torch.Size([10000])


In [6]:

# Sort images by labels
def sort_images_by_labels(images, labels, n_classes=10):
    sorted_images = [[] for _ in range(n_classes)]
    
    for i in range(len(labels)):
        label = labels[i]
        image = images[i]
        sorted_images[label].append(image)
    
    # Convert lists to tensors
    for i in range(n_classes):
        sorted_images[i] = torch.stack(sorted_images[i])
    
    # Combine sorted image tensors
    sorted_tensor = torch.stack(sorted_images)
    
    return sorted_tensor

train_sorted = sort_images_by_labels(train_images, train_labels)
test_sorted = sort_images_by_labels(test_images, test_labels)

print("Train sorted shape:", train_sorted.shape)  # Should print torch.Size([10, n_of_images_per_class, 1, 28, 28])
print("Test sorted shape:", test_sorted.shape)    # Should print torch.Size([10, n_of_images_per_class, 1, 28, 28])

RuntimeError: stack expects each tensor to be equal size, but got [5923, 1, 28, 28] at entry 0 and [6742, 1, 28, 28] at entry 1