In [None]:
"""
This script processes the MNIST dataset.
Original MNIST dataset:
- Author: Yann LeCun, Corinna Cortes, Christopher J.C. Burges
- License: Creative Commons Attribution-Share Alike 3.0 (CC BY-SA 3.0)
- Source: https://yann.lecun.com/exdb/mnist/

The processed dataset is distributed under the same license (CC BY-SA 3.0).
"""

In [None]:
import h5py
import torchvision
import torchvision.transforms as transforms
import os
import numpy as np
# Set the output directory (created if it does not exist yet).
output_dir = '../resources/MNIST'
try:
    os.makedirs(output_dir, exist_ok=True)
except OSError as e:
    # os.makedirs reports failures (permissions, a file in the way, ...) as
    # OSError, so only that is caught here.
    print(f"ディレクトリの作成に失敗しました: {e}")
    # Raise SystemExit directly: the bare `exit` helper is injected for
    # interactive sessions and is not guaranteed to exist in every runtime.
    raise SystemExit(1) from e

# Download (if needed) and prepare the MNIST train and test splits.
raw_train_data = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
raw_test_data = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

# Preprocessing function
def preprocess_image(image, target_size=(28, 28)):
    """Resize ``image`` to ``target_size`` if needed, then normalize it.

    Args:
        image: Image tensor to preprocess. Its last two dimensions are
            treated as (height, width) when deciding whether to resize.
        target_size: Desired (height, width). An ``int`` is handed to
            ``transforms.Resize`` unchanged.

    Returns:
        The image after the optional resize and after
        ``Normalize((0.5,), (0.5,))``, i.e. ``(x - 0.5) / 0.5``.
    """
    steps = []
    # Compare against the image's actual size rather than a hard-coded
    # (28, 28), so inputs of any size end up at target_size, and a list such
    # as [28, 28] is not mistaken for a different size.
    if isinstance(target_size, int) or tuple(image.shape[-2:]) != tuple(target_size):
        steps.append(transforms.Resize(target_size))
    steps.append(transforms.Normalize((0.5,), (0.5,)))
    return transforms.Compose(steps)(image)

# Helper: write one buffered batch into the HDF5 datasets
def _flush_batch(images, labels, image_batch, label_batch, start):
    """Write the buffered samples starting at row ``start``.

    Both buffers are emptied afterwards. Returns the index of the next free
    row (``start`` itself when the buffers were empty).
    """
    if not image_batch:
        return start
    end = start + len(image_batch)
    images[start:end] = np.array(image_batch)
    labels[start:end] = np.array(label_batch)
    image_batch.clear()
    label_batch.clear()
    return end


# Function that saves a dataset in HDF5 format
def save_to_hdf5(dataset, h5_filename, x_name, y_name, target_size=(28, 28)):
    """Preprocess every sample of ``dataset`` and store the result as HDF5.

    The file is created at ``output_dir/h5_filename`` with two gzip-compressed
    datasets: ``x_name`` (float32, shape ``(N, 1, *target_size)``) and
    ``y_name`` (int64, shape ``(N,)``). Samples are written in batches of 100.

    Args:
        dataset: Sized iterable yielding ``(image, label)`` pairs; each image
            is passed through ``preprocess_image`` and must reshape to
            ``(1, *target_size)`` (a single-channel image).
        h5_filename: Name of the file to create inside ``output_dir``.
        x_name: Name of the image dataset inside the file.
        y_name: Name of the label dataset inside the file.
        target_size: (height, width) of the stored images.

    Any exception raised while writing is caught and reported with ``print``;
    nothing is re-raised to the caller.
    """
    file_path = os.path.join(output_dir, h5_filename)
    batch_size = 100
    try:
        total = len(dataset)
        image_shape = (total, 1, *target_size)
        with h5py.File(file_path, 'w') as h5_file:
            # maxshape is given explicitly so the datasets can be shrunk
            # below if some samples are skipped.
            images = h5_file.create_dataset(
                x_name, image_shape, maxshape=image_shape,
                dtype='float32', compression='gzip')
            labels = h5_file.create_dataset(
                y_name, (total,), maxshape=(total,),
                dtype='int64', compression='gzip')
            image_batch = []
            label_batch = []
            # Next free row in the file. Tracking it separately from the loop
            # index keeps writes contiguous and correctly sized even when a
            # sample is skipped or the last batch is shorter than batch_size.
            written = 0

            for i, (image, label) in enumerate(dataset):
                processed_image = preprocess_image(image, target_size)

                # Validate the preprocessed image
                if processed_image is None or not processed_image.numpy().shape:
                    print(f"処理された画像が不正です。インデックス: {i}")
                    continue
                image_batch.append(processed_image.numpy().reshape(1, *target_size))
                label_batch.append(label)

                if len(image_batch) == batch_size:
                    written = _flush_batch(
                        images, labels, image_batch, label_batch, written)

            # Flush whatever is left: a final partial batch, including the
            # case where the very last sample was skipped.
            written = _flush_batch(
                images, labels, image_batch, label_batch, written)

            # Drop unwritten trailing rows so skipped samples do not show up
            # as all-zero images with label 0.
            if written < total:
                images.resize(written, axis=0)
                labels.resize(written, axis=0)

        # Check that the file exists
        if os.path.exists(file_path):
            print(f"データが {file_path} に保存されました。")  # success message
        else:
            print(f"ファイル保存に失敗しました: {file_path}")
    except Exception as e:
        print(f"{file_path} の保存中にエラーが発生しました: {e}")

# Save both splits: training set first, then the test set.
save_to_hdf5(
    raw_train_data,
    h5_filename='aug_train.h5',
    x_name='train_image',
    y_name='train_label',
)
save_to_hdf5(
    raw_test_data,
    h5_filename='aug_test.h5',
    x_name='test_image',
    y_name='test_label',
)

