Implement a function `load_mnist_data()` that extracts the images from the
dataset folders and organizes them into separate lists of images and labels
for the train, validation, and test splits. Ensure that each split's images
are loaded only from that split's own folder, so there is no overlap or mixing
between splits.

In [8]:
import os
from PIL import Image
import numpy as np

def load_mnist_data(base_path='../../data/external/double_mnist',
                    splits=('train', 'val', 'test')):
    """Load grayscale images and integer labels for each dataset split.

    Expects the layout ``<base_path>/<split>/<label>/<image file>``, where
    ``<label>`` is a folder whose name parses as an integer. Each split is
    read only from its own folder, so images from different splits never mix.
    Hidden entries (names starting with ``.``, e.g. ``.ipynb_checkpoints`` or
    ``.DS_Store``) are ignored at both directory levels.

    Args:
        base_path: Root directory containing one sub-folder per split.
        splits: Names of the split sub-folders to read. Defaults to
            ``('train', 'val', 'test')``.

    Returns:
        A ``(images, labels)`` tuple of dicts keyed by split name.
        ``images[split]`` is a list of numpy arrays of the images converted
        to Pillow ``'L'`` (grayscale) mode, and ``labels[split]`` is the
        parallel list of ``int`` labels parsed from the label folder names.
        Entries are in sorted folder/file-name order.

    Raises:
        FileNotFoundError: If ``<base_path>/<split>`` does not exist.
        ValueError: If a non-hidden label folder that contains images has a
            name that is not an integer.

    Note:
        ``int(label)`` drops leading zeros (``int("05") == 5``). If folder
        names can start with ``0``, the original string is lost here.
        TODO: confirm against the dataset's folder naming.
    """
    images = {split: [] for split in splits}
    labels = {split: [] for split in splits}

    for split in splits:
        split_path = os.path.join(base_path, split)
        # os.listdir order is arbitrary. Sorting makes the resulting lists
        # reproducible across runs and machines.
        for label in sorted(os.listdir(split_path)):
            if label.startswith('.'):
                continue  # skip hidden dirs/files such as .ipynb_checkpoints
            label_path = os.path.join(split_path, label)
            if not os.path.isdir(label_path):
                continue  # skip stray files that sit next to the label folders
            for img_name in sorted(os.listdir(label_path)):
                if img_name.startswith('.'):
                    continue  # e.g. .DS_Store would make Image.open fail
                img_path = os.path.join(label_path, img_name)
                # The context manager releases the file handle Image.open holds.
                with Image.open(img_path) as img:
                    img_array = np.array(img.convert('L'))
                images[split].append(img_array)
                labels[split].append(int(label))
    return images, labels

# Load all dataset splits once. `images` and `labels` are dicts keyed by split name.
images, labels = load_mnist_data(base_path='../../data/external/double_mnist')

Create a class called `MultiMNISTDataset` that will be used to build data
loaders for training and evaluation.

In [10]:
import torch
from torch.utils.data import Dataset, DataLoader

class MultiMNISTDataset(Dataset):
    """Map-style dataset that pairs images with digit-count targets.

    Each item is ``(image, label_count)``:
      * ``image`` is a float32 tensor divided by 255, with a leading channel
        axis added via ``unsqueeze(0)``. This scales to ``[0, 1]`` assuming
        8-bit pixel values (TODO confirm for every caller).
      * ``label_count`` is the number of characters in the label's string
        form, except that a label of ``0`` (``int`` or ``"0"``) maps to 0.
        Presumably this is the "no digits" class; confirm against the data.

    Args:
        images: Sequence of 2-D image arrays (anything ``torch.tensor``
            accepts).
        labels: Sequence of labels parallel to ``images``, as ``int`` or
            ``str``.

    Raises:
        ValueError: If ``images`` and ``labels`` have different lengths.
    """

    def __init__(self, images, labels):
        # Fail early: a mismatch would otherwise surface later as an
        # IndexError (or silently drop samples) deep inside a DataLoader.
        if len(images) != len(labels):
            raise ValueError(
                f"images and labels must have the same length, "
                f"got {len(images)} and {len(labels)}"
            )
        self.images = images
        self.labels = labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        # Normalise to str so int labels (as produced by int(folder_name))
        # and str labels are treated alike. Comparing the int 0 to "0" is
        # always False, which previously made label 0 count as 1 digit.
        label_str = str(self.labels[idx])
        label_count = 0 if label_str == "0" else len(label_str)

        img = torch.tensor(img, dtype=torch.float32) / 255.0
        return img.unsqueeze(0), label_count

# One dataset per split, so train/val/test samples never share a dataset.
train_dataset = MultiMNISTDataset(images=images['train'], labels=labels['train'])
val_dataset = MultiMNISTDataset(images=images['val'], labels=labels['val'])
test_dataset = MultiMNISTDataset(images=images['test'], labels=labels['test'])

# Only the training loader shuffles. Validation and test iterate in
# dataset order.
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=32)
val_loader = DataLoader(dataset=val_dataset, shuffle=False, batch_size=32)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=32)