# Data Preliminary Analysis and Pre-processing

## Overview

This notebook loads the original MNIST dataset and analyses the class distributions of its training and validation splits. It visualises sample images, computes class counts and percentages for both splits, and saves the training images to disk as JPEG files to examine their file sizes.

It then defines three variants of MNIST for continual learning:

* Permuted MNIST: Applies a fixed random pixel permutation, chosen per task, to every image.
* Rotated MNIST: Rotates every image by the same angle, chosen randomly per task.
* Partitioned MNIST: Partitions the dataset so that each task contains only the images of one pair of classes.

For each variant, the notebook shows how to load the modified dataset, update tasks, and visualise sample batches.

The main components are:

* MNIST class to load the original MNIST dataset
* Analysis of training and validation splits
* Visualisations of sample images and class distributions
* MNISTPerm class applying random permutations
* RotatingMNIST class applying random rotations
* PartitionMNIST class creating class-partitioned tasks
* Methods to update tasks and show sample batches for each variant

This provides a template for creating continual learning datasets by transforming MNIST and analysing their key properties. The visualisation code shows how to examine the dataset distributions.

## Importing Required Libraries

In [None]:
import os
import copy
import glob
import torch
import random
import torchvision
import numpy as np
import pandas as pd
import seaborn as sns
from PIL import Image
import torchvision.transforms
import matplotlib.pyplot as plt
from collections import Counter
from torchvision import datasets, transforms

## Original MNIST Dataset 

In [None]:
class MNIST:
    """Original MNIST with normalised training and validation data loaders.

    Attributes:
        train_dataset: torchvision MNIST training split (tensors, normalised).
        val_dataset: torchvision MNIST test split, used here for validation.
        train_loader: shuffled DataLoader over ``train_dataset``.
        val_loader: unshuffled DataLoader over ``val_dataset``.
    """

    def __init__(self, data_root="mnist", batch_size=128):
        """Load both MNIST splits and wrap them in DataLoaders.

        Args:
            data_root: directory the MNIST files are stored in / downloaded to.
            batch_size: batch size used by both loaders.
        """
        super().__init__()

        # Both splits share one pipeline: convert to a tensor, then normalise
        # with the mean/std values passed to Normalize below.
        transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        self.train_dataset = torchvision.datasets.MNIST(
            data_root, train=True, download=True, transform=transform
        )
        self.val_dataset = torchvision.datasets.MNIST(
            data_root, train=False, transform=transform
        )

        # Training batches are shuffled every epoch
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset, batch_size=batch_size, shuffle=True
        )

        # Validation order is kept fixed so repeated passes see the same batches
        self.val_loader = torch.utils.data.DataLoader(
            self.val_dataset, batch_size=batch_size, shuffle=False
        )

    def summary(self):
        """Print dataset sizes, number of batches and the batch size."""
        print(f'Train dataset size: {len(self.train_dataset)}')
        print(f'Validation dataset size: {len(self.val_loader.dataset)}')
        print(f'Training batches: {len(self.train_loader)}')
        print(f'Validation batches: {len(self.val_loader)}')
        print(f'Batch size: {self.train_loader.batch_size}')

    @staticmethod
    def _count_labels(targets):
        """Return a ``{label: count}`` dict with int keys in ascending order."""
        if torch.is_tensor(targets):
            labels = targets.tolist()
        else:
            labels = [int(t) for t in targets]
        counts = Counter(labels)
        return {label: counts[label] for label in sorted(counts)}

    def compute_class_counts(self):
        """Count the samples of each class in the training and validation splits.

        Labels are read directly from the datasets' ``targets`` rather than by
        iterating the DataLoaders, which would load, convert and normalise
        every image only to discard it.

        Returns:
            ``(train_counts, val_counts)``: two dicts mapping the int class
            label to its number of samples, with keys in ascending order.
        """
        return (
            self._count_labels(self.train_dataset.targets),
            self._count_labels(self.val_loader.dataset.targets),
        )

    
# Build the MNIST wrapper (download=True fetches the data into ./mnist if needed)
mnist = MNIST()

# Print split sizes, batch counts and batch size
mnist.summary()

# Per-class sample counts for each split
TRAIN_CLASS_COUNTS, VAL_CLASS_COUNTS = mnist.compute_class_counts()

# One row per digit: label (as text), training count, validation count
class_distribution_df = pd.DataFrame(
    [
        (str(digit), TRAIN_CLASS_COUNTS[digit], VAL_CLASS_COUNTS[digit])
        for digit in range(10)
    ],
    columns=['Class Label', 'Training Counts', 'Validation Counts'],
)

# Show the table without the numeric index column
print(class_distribution_df.to_string(index=False))

# Grab the first (unshuffled) validation batch and its labels
batch, labels = next(iter(mnist.val_loader))

# Tile the first 64 images into a grid with 5-pixel padding filled with 0.2,
# min-max scaled to [0, 1]; the final expression is what the cell displays
image_grid = torchvision.utils.make_grid(
    batch[:64], normalize=True, padding=5, pad_value=0.2
)
torchvision.transforms.ToPILImage()(image_grid)

In [None]:
# Use the default seaborn theme in the "paper" plotting context
sns.set_theme()
sns.set_context("paper")

# Create a figure with a size of 9x5 inches
fig = plt.figure(figsize=(9, 5))

# Plot classes in ascending order so both panels share the same x ordering
train_classes = sorted(TRAIN_CLASS_COUNTS)
val_classes = sorted(VAL_CLASS_COUNTS)

# Left panel: training split class counts
ax1 = fig.add_subplot(1, 2, 1)
ax1.set_xlabel('Class Names')
ax1.set_ylabel('Class Counts')
ax1.set_title('Training Split Class Counts')
ax1.bar(
    train_classes,
    [TRAIN_CLASS_COUNTS[c] for c in train_classes],
    tick_label=train_classes,
)

# Right panel: validation split class counts
ax2 = fig.add_subplot(1, 2, 2)
ax2.set_xlabel('Class Names')
ax2.set_ylabel('Class Counts')
ax2.set_title('Validation Split Class Counts')
ax2.bar(
    val_classes,
    [VAL_CLASS_COUNTS[c] for c in val_classes],
    tick_label=val_classes,
)
# Leave 100 counts of headroom above the tallest validation bar
ax2.set_ylim(0, max(VAL_CLASS_COUNTS.values()) + 100)

# Adjust the subplot layout for better spacing
fig.tight_layout()

# Save the figure, creating the output folder first so savefig cannot fail
# with FileNotFoundError on a fresh checkout
os.makedirs('../figures', exist_ok=True)
fig.savefig('../figures/MNIST_class_count.png', dpi=300)

# Display the plot
plt.show()

In [None]:
# Total number of training and validation images
total_train_images = sum(TRAIN_CLASS_COUNTS.values())
total_val_images = sum(VAL_CLASS_COUNTS.values())

# Classes in ascending order so table rows and bars are ordered by digit
train_classes = sorted(TRAIN_CLASS_COUNTS)
val_classes = sorted(VAL_CLASS_COUNTS)

# Percentage of each class within its own split
train_percentages = [TRAIN_CLASS_COUNTS[c] / total_train_images * 100 for c in train_classes]
val_percentages = [VAL_CLASS_COUNTS[c] / total_val_images * 100 for c in val_classes]

# Long-format DataFrame: one row per (split, class) pair, as seaborn's hue expects
class_percentages_df = pd.DataFrame({
    'Class': train_classes + val_classes,
    'Percentage': train_percentages + val_percentages,
    'Dataset': ['Training'] * len(train_classes) + ['Validation'] * len(val_classes),
})

# Round the percentages to two decimal places
class_percentages_df['Percentage'] = class_percentages_df['Percentage'].round(2)

# Use the default seaborn theme in the "paper" plotting context
sns.set_theme()
sns.set_context("paper")

# Create a figure with a size of 9x5 inches
fig = plt.figure(figsize=(9, 5))

# Grouped bar chart of class percentages, one colour per split
ax = fig.add_subplot(1, 1, 1)
sns.barplot(x='Class', y='Percentage', hue='Dataset', data=class_percentages_df, ax=ax)
ax.set_xlabel('Class Names')
ax.set_ylabel('Class Percentages')
ax.set_title('Class Percentages in Training and Validation Datasets')

# Adjust the subplot layout for better spacing
fig.tight_layout()

# Save the figure, creating the output folder first if it does not exist
os.makedirs('../figures', exist_ok=True)
fig.savefig('../figures/MNIST_class_percentages.png', dpi=300)

print(class_percentages_df.to_string(index=False))

# Display the plot
plt.show()


In [None]:
def rescale_to_unit_range(image):
    """Linearly map a tensor's values onto [0, 1] (minimum -> 0, maximum -> 1).

    Uses the same min-max scaling as ``torchvision.utils.make_grid(normalize=True)``,
    which this notebook uses to display images. A constant image (zero range)
    maps to all zeros instead of dividing by zero.
    """
    image = image.float()
    low, high = image.min(), image.max()
    return (image - low) / max(float(high - low), 1e-5)


def save_images(mnist_dataset, path, dataset='train', *, max_images=None, rescale=True):
    """Write the images of ``mnist_dataset`` to ``path`` as JPEG files.

    Files are named ``{dataset}_{index}.jpeg``. Only the first channel of each
    image tensor is written.

    Args:
        mnist_dataset: dataset yielding ``(image, label)`` pairs, where
            ``image`` is assumed to be a ``(C, H, W)`` tensor.
        path: output directory; created if it does not exist.
        dataset: filename prefix, e.g. ``'train'`` or ``'val'``.
        max_images: if given, stop after saving this many images.
        rescale: min-max rescale each image to [0, 1] before conversion.
            ``ToPILImage`` multiplies float tensors by 255 and casts to uint8,
            so values outside [0, 1] (such as the negative background left by
            ``Normalize``) would wrap around and corrupt the saved pixels.
    """
    os.makedirs(path, exist_ok=True)
    to_pil = torchvision.transforms.ToPILImage()

    for i, (image, _) in enumerate(mnist_dataset):
        if max_images is not None and i >= max_images:
            break
        channel = image[0]
        if rescale:
            channel = rescale_to_unit_range(channel)
        to_pil(channel).save(os.path.join(path, f"{dataset}_{i}.jpeg"))

def analyse_image_sizes(path):
    """Print summary statistics of the byte sizes of the ``.jpeg`` files in ``path``.

    Only files directly inside ``path`` are considered; other extensions are
    ignored.
    """
    jpeg_pattern = os.path.join(path, "*.jpeg")
    file_sizes = pd.Series(list(map(os.path.getsize, glob.glob(jpeg_pattern))))
    print("Image size statistics:")
    print(file_sizes.describe())

# Destination folder for the exported training images
path = 'mnist_images/train'

# Fresh MNIST wrapper for this cell
mnist = MNIST()

# Export the training split as JPEG files, then summarise their file sizes
save_images(mnist.train_dataset, path, dataset='train')
analyse_image_sizes(path)

## Permuted MNIST Dataset Variant

In [None]:
class MNISTPerm:
    """Permuted MNIST: each task applies its own fixed pixel permutation.

    A single ``permute`` transform instance is shared by the training and
    validation pipelines, so changing its permutation affects both loaders;
    DataLoader iterators created after the change use the new permutation.
    """

    class permute:
        """Transform that flattens an image tensor, reorders its 784 pixels
        with ``self.perm`` and reshapes the result to ``(1, 28, 28)``."""

        def __init__(self):
            # Start with the identity permutation (images left unchanged)
            self.perm = np.arange(784)

        def __call__(self, tensor):
            out = tensor.flatten()
            out = out[self.perm]
            return out.view(1, 28, 28)

        def __repr__(self):
            return self.__class__.__name__

    def __init__(self, seed=0, data_root="mnist", batch_size=128):
        """Load both MNIST splits with the permutation appended to the transforms.

        Args:
            seed: offset added to the task id when seeding each task's permutation.
            data_root: directory the MNIST files are stored in / downloaded to.
            batch_size: batch size used by both loaders.
        """
        super().__init__()

        # Shared permutation transform, updated in place by update_task/unpermute
        self.permuter = self.permute()
        self.seed = seed

        # The permutation runs after normalisation, on the (1, 28, 28) tensor
        transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307,), (0.3081,)),
                self.permuter,
            ]
        )

        train_dataset = torchvision.datasets.MNIST(
            data_root, train=True, download=True, transform=transform
        )
        val_dataset = torchvision.datasets.MNIST(
            data_root, train=False, transform=transform
        )

        # Training batches are shuffled; validation order stays fixed
        self.train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True
        )
        self.val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=batch_size, shuffle=False
        )

    def update_task(self, i):
        """Switch to the pixel permutation of task ``i``.

        The permutation comes from a private ``RandomState`` seeded with
        ``i + self.seed``, so a task id always maps to the same permutation.
        It is identical to what ``np.random.seed(i + self.seed)`` followed by
        ``np.random.permutation(784)`` produces, but NumPy's global RNG state
        is left untouched.
        """
        rng = np.random.RandomState(i + self.seed)
        self.permuter.perm = rng.permutation(784)

    def unpermute(self):
        """Reset to the identity permutation (original pixel order)."""
        self.permuter.perm = np.arange(784)
        
# Initialise the permuted MNIST dataset
mnist = MNISTPerm()

# Make sure the first batch is drawn from the unpermuted images
mnist.unpermute()

# Load a single batch of validation images and their corresponding labels
batch, labels = next(iter(mnist.val_loader))

# Switch to the pixel permutation of task 3
mnist.update_task(3)

# The validation loader does not shuffle, so this batch holds the same images
# as ``batch``, now with task 3's permutation applied
task3, labels = next(iter(mnist.val_loader))

# Show each original image next to its permuted version (concatenated along
# the width) for the first 64 pairs, min-max scaled to [0, 1], with 5-pixel
# padding filled with 0.2
torchvision.transforms.ToPILImage()(
    torchvision.utils.make_grid(
        torch.cat([batch, task3], dim=-1)[:64],
        normalize=True,
        padding=5,
        pad_value=0.2,
    )
)

## Rotated MNIST Dataset Variant

In [None]:
class Rotate:
    """Callable image transform that rotates by ``self.angle`` degrees.

    The angle is passed unchanged to ``transforms.functional.rotate`` and may
    be reassigned at any time (e.g. when switching tasks).
    """

    def __init__(self, angle=90):
        self.angle = angle

    def __call__(self, img):
        return transforms.functional.rotate(img, self.angle)

    def __repr__(self):
        return f"{type(self).__name__}(angle={self.angle})"


class RotatingMNIST:
    """Rotated MNIST: every task rotates all images by one task-specific angle.

    A single ``Rotate`` transform instance is shared by the training and
    validation pipelines, so :meth:`update_task` changes the angle for both.
    """

    def __init__(self, seed=0, data_root="mnist", batch_size=128):
        """Load both MNIST splits with the rotation inserted into the transforms.

        Args:
            seed: offset added to the task id when choosing each task's angle.
            data_root: directory the MNIST files are stored in / downloaded to.
            batch_size: batch size used by both loaders.
        """
        super().__init__()

        self.seed = seed
        self.rotater = Rotate()

        # The PIL image is converted to 3 channels, rotated, converted back to
        # 1 channel, and only then turned into a normalised tensor
        transform = transforms.Compose(
            [
                transforms.Grayscale(3),
                self.rotater,
                transforms.Grayscale(1),
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        train_dataset = datasets.MNIST(
            data_root, train=True, download=True, transform=transform
        )
        val_dataset = datasets.MNIST(data_root, train=False, transform=transform)

        self.train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True
        )
        self.val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=batch_size, shuffle=True
        )

    def update_task(self, i):
        """Set the rotation angle for task ``i``.

        The angle is an integer in ``[0, 359]`` drawn from a private RNG seeded
        with ``i + self.seed``: the same task id always gets the same angle,
        and the global ``random`` state is not consumed.
        """
        rng = random.Random(i + self.seed)
        self.rotater.angle = rng.randrange(360)
        
# Initialise a rotating MNIST dataset
mnist = RotatingMNIST()

# Update the task (task id 7), which picks a new rotation angle
mnist.update_task(7)

# Load a single batch of validation images and their corresponding labels
images, labels = next(iter(mnist.val_loader))

# Show the first 64 rotated images as a grid, min-max scaled to [0, 1],
# with 5-pixel padding filled with 0.2
torchvision.transforms.ToPILImage()(
    torchvision.utils.make_grid(
        images[:64],
        normalize=True,
        padding=5,
        pad_value=0.2,
    )
)

## Partitioned MNIST Dataset Variant

In [None]:
def partition_dataset(dataset, label_pair):
    """Return a shallow copy of ``dataset`` containing only the given labels.

    Args:
        dataset: torchvision-style dataset with ``data`` and ``targets``
            attributes (e.g. ``torchvision.datasets.MNIST``). It is not modified.
        label_pair: labels to keep, typically a pair such as ``(0, 1)``; any
            iterable of labels is accepted.

    Returns:
        A ``copy.copy`` of ``dataset`` whose ``data`` and ``targets`` keep only
        the samples whose label is in ``label_pair``, in their original order.
        ``targets`` is a tensor; ``data`` stays a tensor when the source
        ``data`` is one, otherwise it becomes a list.
    """
    newdataset = copy.copy(dataset)

    # Build one boolean mask with a vectorised comparison per label, instead
    # of allocating two tensors per sample in Python loops
    targets = torch.as_tensor(dataset.targets)
    mask = torch.zeros_like(targets, dtype=torch.bool)
    for label in label_pair:
        mask |= targets == label

    data = dataset.data
    if torch.is_tensor(data):
        newdataset.data = data[mask]
    else:
        keep = mask.tolist()
        newdataset.data = [item for item, kept in zip(data, keep) if kept]
    newdataset.targets = targets[mask]

    return newdataset

class PartitionMNIST:
    """MNIST split into tasks that each contain only a subset of classes.

    ``self.loaders[t]`` is the ``(train_loader, val_loader)`` pair for task
    ``t``. :meth:`update_task` selects which pair is exposed as
    ``self.train_loader`` / ``self.val_loader``; task 0 is selected on
    construction so these attributes always exist.
    """

    # Default label combinations for the 10 tasks
    DEFAULT_LABEL_PAIRS = (
        (0, 1), (2, 3), (4, 5), (6, 7), (8, 9),
        (0, 2), (1, 3), (4, 6), (5, 8), (7, 9),
    )

    def __init__(self, label_pairs=None, data_root="mnist", batch_size=128):
        """Partition both MNIST splits by label pair and build their loaders.

        Args:
            label_pairs: sequence of label tuples, one per task; defaults to
                ``DEFAULT_LABEL_PAIRS``.
            data_root: directory the MNIST files are stored in / downloaded to.
            batch_size: batch size used by every loader.
        """
        super().__init__()
        if label_pairs is None:
            label_pairs = self.DEFAULT_LABEL_PAIRS

        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

        train_dataset = datasets.MNIST(
            data_root, train=True, download=True, transform=transform
        )
        val_dataset = datasets.MNIST(data_root, train=False, transform=transform)

        # One (train, validation) subset per label combination
        splits = [
            (
                partition_dataset(train_dataset, label_pair),
                partition_dataset(val_dataset, label_pair),
            )
            for label_pair in label_pairs
        ]

        # Print the number of training and validation samples in every split
        for train_split, val_split in splits:
            print(len(train_split.data))
            print(len(val_split.data))
            print("==")

        self.loaders = [
            (
                torch.utils.data.DataLoader(
                    train_split, batch_size=batch_size, shuffle=True
                ),
                torch.utils.data.DataLoader(
                    val_split, batch_size=batch_size, shuffle=True
                ),
            )
            for train_split, val_split in splits
        ]

        # Expose the first task's loaders right away
        if self.loaders:
            self.update_task(0)

    def update_task(self, i):
        """Expose the loaders of task ``i`` as ``train_loader`` and ``val_loader``."""
        self.train_loader, self.val_loader = self.loaders[i]

        
# Initialise a PartitionMNIST dataset
mnist = PartitionMNIST()

# Switch to task 9 (labels 7 and 9 with the default label pairs)
mnist.update_task(9)

# Load a single batch of validation images and their corresponding labels
images, labels = next(iter(mnist.val_loader))

# Show the first 64 images of the batch as a grid, min-max scaled to [0, 1],
# with 5-pixel padding filled with 0.2
torchvision.transforms.ToPILImage()(
    torchvision.utils.make_grid(
        images[:64],
        normalize=True,
        padding=5,
        pad_value=0.2,
    )
)

--------------------------------------------------------------------------------------------------------------------------------

#### Code adapted from:

* https://github.com/pytorch
* https://github.com/RAIVNLab/supsup
* https://www.programcreek.com/python/example/105103/torchvision.datasets.MNIST
* https://www.programcreek.com/python/example/104832/torchvision.transforms.Compose