In [2]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim

#### Config File

In [5]:
# Data locations, relative to the notebook's working directory.
# RAW_PATH is the torchvision download root for MNIST; PROCESSED_PATH is
# where the pickled DataLoader is cached. Trailing slashes are intentional.
RAW_PATH, PROCESSED_PATH = (
    "../data/raw/",
    "../data/processed/",
)

#### Utils File

In [9]:
import os
import shutil

import joblib as pkl

def pickle(value = None, filename = None):
    """
    Serialize ``value`` to ``filename`` with ``joblib.dump``.

    Note: this helper shadows the name of the stdlib ``pickle`` module in
    this notebook; the name is kept so existing callers keep working.

    Parameters
    ----------
    value : object
        Object to serialize. Falsy objects (0, [], {}) are valid values;
        only ``None`` counts as missing.
    filename : str
        Destination file path.

    Raises
    ------
    ValueError
        If ``value`` or ``filename`` is None.
    """
    # Check each argument on its own: the former ``(value and filename) is not None``
    # evaluated to ``value`` whenever it was falsy, letting a None filename through.
    if value is None or filename is None:
        # Previously the exception was only constructed, never raised.
        raise ValueError("Pickle is not possible due to missing arguments".capitalize())
    pkl.dump(value=value, filename=filename)
        
def clean_folder(path = None):
    """
    Delete every entry inside ``path``, leaving the (now empty) folder itself.

    Files and symlinks are unlinked; real sub-directories are removed
    recursively. A message is printed either way.

    Parameters
    ----------
    path : str
        Folder to empty. If it does not exist, a message is printed and
        nothing else happens.

    Raises
    ------
    ValueError
        If ``path`` is None.
    OSError
        If an entry cannot be removed (e.g. permission denied).
    """
    if path is None:
        raise ValueError("Clean folder is not possible due to missing arguments".capitalize())

    if not os.path.exists(path):
        # The path is formatted in (it used to print a literal "{}"), and not
        # passed through capitalize(), which would lowercase the path itself.
        print("{} - path doesn't exist".format(path))
        return

    for entry in os.listdir(path):
        entry_path = os.path.join(path, entry)
        # os.remove() cannot delete directories (e.g. the MNIST/ folder that
        # torchvision creates under the raw path), which broke every re-run.
        # Symlinks to directories are unlinked rather than followed.
        if os.path.isdir(entry_path) and not os.path.islink(entry_path):
            shutil.rmtree(entry_path)
        else:
            os.remove(entry_path)

    print("{} - path cleaned".format(path))

#### Create the DataLoader

In [13]:
import os

class Loader:
    """
    A class for loading and preprocessing the MNIST dataset.

    This class downloads the MNIST training split into ``RAW_PATH`` and applies image
    transformations. It wraps the data in a shuffled DataLoader and caches that
    DataLoader as ``dataloader.pkl`` in ``PROCESSED_PATH``.

    | Parameters | Description |
    |------------|-------------|
    | batch_size | int, default=128. The number of samples to include in each batch of data. |

    | Attributes | Description |
    |------------|-------------|
    | batch_size | int. The size of the batch of data. |

    | Methods    | Description |
    |------------|-------------|
    | _do_transformation()  | Builds the transformation pipeline applied to each image. |
    | _prepare_folder(path) | Creates ``path`` if it is missing, otherwise empties it. |
    | download_mnist()      | Downloads MNIST, applies transformations, batches the data and caches the DataLoader. |

    Examples
    --------
    >>> loader = Loader(batch_size=128)
    >>> dataloader = loader.download_mnist()
    """
    def __init__(self, batch_size = 128):
        """
        Initializes the Loader with a specified batch size.

        Parameters
        ----------
        batch_size : int, optional
            The number of samples per batch. Default is 128.
        """
        self.batch_size = batch_size

    def _do_transformation(self):
        """
        Build the transformations applied to every dataset image.

        Returns
        -------
        torchvision.transforms.Compose
            Resize to 28x28, convert to a tensor, then normalize with
            mean 0.5 and std 0.5.
        """
        # MNIST digits are already 28x28, so Resize is effectively a no-op for
        # this dataset. Normalize((0.5,), (0.5,)) computes (x - 0.5) / 0.5,
        # mapping ToTensor's [0, 1] output to [-1, 1].
        return transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

    def _prepare_folder(self, path):
        """Ensure ``path`` exists and is empty: create it if missing, else clean it."""
        if os.path.exists(path):
            clean_folder(path=path)
        else:
            os.makedirs(path)
            print("{} - path created".format(path))

    def download_mnist(self):
        """
        Download the MNIST training split and prepare it for training.

        First the raw folder is created if missing, or emptied if present. Then
        MNIST is downloaded and wrapped in a shuffled DataLoader, which is
        cached to ``PROCESSED_PATH/dataloader.pkl``. A fresh kernel therefore
        gets a DataLoader on the first run, with no "run the code again" step.

        Returns
        -------
        torch.utils.data.DataLoader
            Shuffled batches of the transformed MNIST training set. It is
            returned even if caching it to disk fails.

        Raises
        ------
        OSError
            If the raw folder cannot be created or emptied.
        Exception
            Errors from ``datasets.MNIST`` (e.g. a failed download) propagate.
        """
        # Let failures here surface: previously they were printed and the
        # method silently returned None, breaking later cells.
        self._prepare_folder(RAW_PATH)

        mnist_train = datasets.MNIST(
            root=RAW_PATH,
            train=True,
            download=True,
            transform=self._do_transformation(),
        )
        dataloader = DataLoader(dataset=mnist_train, batch_size=self.batch_size, shuffle=True)

        # Caching is best-effort: the in-memory DataLoader is usable even if
        # the processed folder cannot be prepared or the pickle fails.
        try:
            self._prepare_folder(PROCESSED_PATH)
            pickle(value=dataloader, filename=os.path.join(PROCESSED_PATH, "dataloader.pkl"))
        except Exception as e:
            print("Exception caught in the section # {}".format(e))

        return dataloader


if __name__ == "__main__":
    # Inside a notebook kernel __name__ is "__main__", so this cell always runs;
    # the guard only matters if the code is exported to a .py module and imported.
    loader = Loader(batch_size=128)
    dataloader = loader.download_mnist()
    # Fail loudly in this cell if no DataLoader was produced, instead of in a
    # later cell that tries to iterate over it.
    assert dataloader is not None, "download_mnist() did not return a DataLoader"

../data/raw/ - path cleaned
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/raw/MNIST/raw/train-images-idx3-ubyte.gz


100.0%


Extracting ../data/raw/MNIST/raw/train-images-idx3-ubyte.gz to ../data/raw/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/raw/MNIST/raw/train-labels-idx1-ubyte.gz


100.0%
27.8%

Extracting ../data/raw/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/raw/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/raw/MNIST/raw/t10k-images-idx3-ubyte.gz


100.0%
100.0%


Extracting ../data/raw/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/raw/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/raw/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ../data/raw/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/raw/MNIST/raw

../data/processed/ - path cleaned
