# MNIST Images Converted to netCDF for Faster Batch Loading

In [13]:
# !pip3 install Pillow
# !pip3 install tqdm
# !pip3 install matplotlib
!pip3 install torchvision
!pip3 install torch

Collecting torchvision
  Using cached torchvision-0.15.2-cp310-cp310-macosx_10_9_x86_64.whl (1.5 MB)
Collecting torch==2.0.1
  Using cached torch-2.0.1-cp310-none-macosx_10_9_x86_64.whl (143.4 MB)
Collecting networkx
  Using cached networkx-3.1-py3-none-any.whl (2.1 MB)
Collecting typing-extensions
  Downloading typing_extensions-4.6.3-py3-none-any.whl (31 kB)
Collecting filelock
  Downloading filelock-3.12.2-py3-none-any.whl (10 kB)
Collecting sympy
  Using cached sympy-1.12-py3-none-any.whl (5.7 MB)
Collecting mpmath>=0.19
  Using cached mpmath-1.3.0-py3-none-any.whl (536 kB)
Installing collected packages: mpmath, typing-extensions, sympy, networkx, filelock, torch, torchvision
Successfully installed filelock-3.12.2 mpmath-1.3.0 networkx-3.1 sympy-1.12 torch-2.0.1 torchvision-0.15.2 typing-extensions-4.6.3

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m23.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49

In [3]:
from PIL import Image
import os
import numpy as np
import numpy as np
from tqdm import tqdm
import pncpy
from mpi4py import MPI
from array import array
import struct

# Directory paths for the whale-and-dolphin train/test image folders.
# NOTE(review): neither name is referenced by the cells shown in this
# notebook; confirm nothing else depends on them before removing.
TRAIN_IMAGES, TEST_IMAGES = (
    '../../../exercise/whale-and-dolphin/train_images',
    '../../../exercise/whale-and-dolphin/test_images',
)


class MnistDataloader(object):
    """Read the MNIST dataset from its raw big-endian IDX binary files.

    File layouts handled here:

    * label file: two big-endian uint32 (``magic == 2049``, ``count``)
      followed by ``count`` unsigned bytes, one label per byte;
    * image file: four big-endian uint32 (``magic == 2051``, ``count``,
      ``rows``, ``cols``) followed by ``count * rows * cols`` unsigned
      bytes, one image after another in row-major order.

    Args:
        training_images_filepath: Path to the training-set image IDX file.
        training_labels_filepath: Path to the training-set label IDX file.
        test_images_filepath: Path to the test-set image IDX file.
        test_labels_filepath: Path to the test-set label IDX file.
    """

    # Magic numbers identifying IDX label and image files.
    _LABELS_MAGIC = 2049
    _IMAGES_MAGIC = 2051

    def __init__(self, training_images_filepath, training_labels_filepath,
                 test_images_filepath, test_labels_filepath):
        self.training_images_filepath = training_images_filepath
        self.training_labels_filepath = training_labels_filepath
        self.test_images_filepath = test_images_filepath
        self.test_labels_filepath = test_labels_filepath

    @classmethod
    def _read_labels(cls, labels_filepath):
        """Return the labels of an IDX label file as an ``array('B')``.

        Raises:
            ValueError: on a wrong magic number or a truncated label payload.
        """
        with open(labels_filepath, 'rb') as file:
            magic, count = struct.unpack(">II", file.read(8))
            if magic != cls._LABELS_MAGIC:
                raise ValueError('Magic number mismatch, expected 2049, got {}'.format(magic))
            labels = array("B", file.read(count))
        if len(labels) != count:
            raise ValueError('Label file {} is truncated: header declares {} labels, found {}'
                             .format(labels_filepath, count, len(labels)))
        return labels

    @classmethod
    def _read_images(cls, images_filepath):
        """Return the images of an IDX image file as a list of uint8 arrays.

        Each list element is a writable ``(rows, cols)`` ndarray, where
        ``rows``/``cols`` come from the file header rather than being assumed.

        Raises:
            ValueError: on a wrong magic number or a truncated pixel payload.
        """
        with open(images_filepath, 'rb') as file:
            magic, count, rows, cols = struct.unpack(">IIII", file.read(16))
            if magic != cls._IMAGES_MAGIC:
                raise ValueError('Magic number mismatch, expected 2051, got {}'.format(magic))
            n_bytes = count * rows * cols
            raw = file.read(n_bytes)
        if len(raw) != n_bytes:
            raise ValueError('Image file {} is truncated: expected {} pixel bytes, found {}'
                             .format(images_filepath, n_bytes, len(raw)))
        # frombuffer gives a read-only view of `raw`; copy so callers can
        # modify individual images in place.
        pixels = np.frombuffer(raw, dtype=np.uint8).reshape(count, rows, cols).copy()
        return list(pixels)

    def read_images_labels(self, images_filepath, labels_filepath):
        """Load one split (images + labels) of the dataset.

        Args:
            images_filepath: Path to an IDX image file.
            labels_filepath: Path to the matching IDX label file.

        Returns:
            ``(images, labels)``: ``images`` is a list of ``(rows, cols)``
            uint8 ndarrays and ``labels`` an ``array('B')`` of equal length.

        Raises:
            ValueError: on a bad magic number, a truncated file, or when the
                number of images differs from the number of labels.
        """
        labels = self._read_labels(labels_filepath)
        images = self._read_images(images_filepath)
        if len(images) != len(labels):
            raise ValueError('Found {} images in {} but {} labels in {}'
                             .format(len(images), images_filepath, len(labels), labels_filepath))
        return images, labels

    def load_data(self):
        """Load both splits.

        Returns:
            ``((x_train, y_train), (x_test, y_test))`` as produced by
            :meth:`read_images_labels`.
        """
        x_train, y_train = self.read_images_labels(self.training_images_filepath, self.training_labels_filepath)
        x_test, y_test = self.read_images_labels(self.test_images_filepath, self.test_labels_filepath)
        return (x_train, y_train), (x_test, y_test)



%matplotlib inline
import random
import matplotlib.pyplot as plt

# ---------------------------------------------------------------------------
# Locations of the four raw MNIST IDX files, relative to `input_path`.
# Each file is assumed to sit inside a folder of the same name, e.g.
# ./train-images-idx3-ubyte/train-images-idx3-ubyte.
# ---------------------------------------------------------------------------
input_path = '.'
(training_images_filepath,
 training_labels_filepath,
 test_images_filepath,
 test_labels_filepath) = (
    os.path.join(input_path, relative_name)
    for relative_name in (
        'train-images-idx3-ubyte/train-images-idx3-ubyte',
        'train-labels-idx1-ubyte/train-labels-idx1-ubyte',
        't10k-images-idx3-ubyte/t10k-images-idx3-ubyte',
        't10k-labels-idx1-ubyte/t10k-labels-idx1-ubyte',
    )
)

# ---------------------------------------------------------------------------
# Parse the MNIST train and test splits from the IDX files above.
# ---------------------------------------------------------------------------
mnist_dataloader = MnistDataloader(
    training_images_filepath,
    training_labels_filepath,
    test_images_filepath,
    test_labels_filepath,
)
(x_train, y_train), (x_test, y_test) = mnist_dataloader.load_data()



# def list_files(gtdir):
#     file_list = []
#     for root, dirs, files in os.walk(gtdir):
#         for file in files:
#             file_list.append(os.path.join(root,file))
#     return file_list

def to_nc(samples, labels, comm, out_file_path='mnist_train_images.nc'):
    """Write a stack of images and their labels into one netCDF file.

    The file is created over ``comm`` with ``pncpy`` in the "64BIT_DATA"
    format and holds two NC_UBYTE variables:

    * ``images`` with dimensions ``(idx, Y, X)``;
    * ``labels`` with dimension ``(idx,)``.

    This is collective over ``comm``: every rank must call it with the same
    arguments (it contains an MPI barrier and collective netCDF calls).

    Args:
        samples: Sequence of equally-sized 2-D images whose pixel values fit
            in uint8 (e.g. the image list produced by ``MnistDataloader``).
        labels: Iterable of per-image labels, one per element of ``samples``.
        comm: mpi4py communicator shared by all participating ranks.
        out_file_path: Destination path; a pre-existing file is removed first.

    Raises:
        ValueError: if ``samples`` is empty or not a stack of 2-D images, or
            if the number of labels differs from the number of images.
    """
    images = np.asarray(samples, dtype=np.uint8)
    label_values = np.asarray(list(labels), dtype=np.uint8)

    # Validate before any collective call so that every rank (which holds the
    # same data) fails identically instead of deadlocking. An empty stack is
    # rejected because a dimension of length 0 would be defined as UNLIMITED.
    if images.ndim != 3 or images.shape[0] == 0:
        raise ValueError('samples must be a non-empty stack of equally-sized 2-D '
                         'images; got an array of shape {}'.format(images.shape))
    if label_values.shape != (images.shape[0],):
        raise ValueError('expected {} labels, got {}'.format(images.shape[0], label_values.shape[0]))
    n_images, height, width = images.shape

    # Only one rank deletes a stale output file: if every rank ran
    # exists()+remove(), all but the first would hit FileNotFoundError.
    # The barrier keeps any rank from creating the new file before the old
    # one is gone.
    if comm.Get_rank() == 0 and os.path.exists(out_file_path):
        os.remove(out_file_path)
    comm.Barrier()

    with pncpy.File(out_file_path, comm=comm, mode="w", format="64BIT_DATA") as fnc:
        dim_y = fnc.def_dim("Y", height)
        dim_x = fnc.def_dim("X", width)
        dim_num = fnc.def_dim("idx", n_images)

        # One variable for all images and a parallel 1-D variable for labels.
        v = fnc.def_var("images", pncpy.NC_UBYTE, (dim_num, dim_y, dim_x))
        v_label = fnc.def_var("labels", pncpy.NC_UBYTE, (dim_num, ))

        # Leave define mode, then write. Every rank writes the same full
        # arrays, but with one call per variable instead of one per image.
        fnc.enddef()
        v_label[:] = label_values
        v[:, :, :] = images


# Every rank has loaded the full dataset and takes part in the collective
# netCDF writes; only rank 0 prints, so the banners are not repeated once
# per process under mpirun.
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

if rank == 0:
    print('=> ========= Converting Train Images ========= <=')
to_nc(x_train, y_train, comm, 'mnist_train_images.nc')
if rank == 0:
    print('=> ========= Converting Test Images ========= <=')
to_nc(x_test, y_test, comm, "mnist_test_images.nc")

