In [1]:
import tensorflow as tf
import numpy as np

import h5py    

# Fetch the Keras-provided MNIST train/test splits.
mnist = tf.keras.datasets.mnist
(x_train_mnist, y_train_mnist), (x_test_mnist, y_test_mnist) = mnist.load_data()

# Rescale pixel values by 1/255.0 (floating-point result).
x_train_mnist = x_train_mnist / 255.0
x_test_mnist = x_test_mnist / 255.0

# Pool both official splits into one dataset; the later cells re-split it
# per anomaly class via splitMNIST(x, y, ...).
x = np.concatenate([x_train_mnist, x_test_mnist], axis=0)
y = np.concatenate([y_train_mnist, y_test_mnist], axis=0)

  from ._conv import register_converters as _register_converters


In [2]:
def splitMNIST(x, y, anom_class, test_ratio=0.2, rng=None):
    """
    Split a labelled dataset into normal train/test sets and an anomaly set.

    Samples whose label equals ``anom_class`` form the anomaly set; all other
    samples are "normal", are shuffled, and are split into train and test parts.

    :param x: Input samples indexed along the first axis (28 by 28 images for MNIST).
    :param y: Label for each sample in x; used only to separate the anomaly
              class from the normal classes.
    :param anom_class: Label value treated as the anomaly class.
    :param test_ratio: how much of NORMAL x is to be held out as test data;
                       must lie in [0, 1].
    :param rng: Optional numpy Generator or RandomState used for the shuffle.
                None (default) uses the global ``np.random`` state.
    :return: (x_train_normal, x_test_normal, x_anomaly). The anomaly set keeps
             its original order (it is not shuffled).
    :raises ValueError: if test_ratio is outside [0, 1].
    """
    # Reject ratios that would silently produce negative / oversized slices.
    if not 0.0 <= test_ratio <= 1.0:
        raise ValueError('test_ratio must be in [0, 1], got {}'.format(test_ratio))

    # Accept lists as well as arrays; boolean masking below needs ndarrays.
    x = np.asarray(x)
    y = np.asarray(y)
    x_normal = x[y != anom_class]
    x_anomaly = x[y == anom_class]

    # Shuffle the normal samples so the train/test split is random.
    shuffler = np.random if rng is None else rng
    order = np.arange(x_normal.shape[0])
    shuffler.shuffle(order)
    x_normal = x_normal[order]

    # Number of NORMAL samples kept for training; the remainder is the test set.
    n_train = int((1 - test_ratio) * x_normal.shape[0])
    # Slice only the first axis so inputs of any rank work, not just 3-D.
    x_train_normal, x_test_normal = x_normal[:n_train], x_normal[n_train:]

    return x_train_normal, x_test_normal, x_anomaly

In [4]:
# Train and test set
# For every digit i, treat i as the anomaly class: the training file holds only
# NORMAL samples; the test file holds held-out NORMAL samples plus all ABNORMAL ones.
SAMPLE_TYPES = ['NORMAL', 'ABNORMAL']
for i in range(10):
    print('Preparing data set where the anomaly class is ',i)
    x_train_normal, x_test_normal, x_test_abnormal = splitMNIST(x, y, anom_class=i, test_ratio=0.2)
    samples_train = {'NORMAL': x_train_normal}
    samples_test = dict(zip(SAMPLE_TYPES, (x_test_normal, x_test_abnormal)))

    # mode='w' creates/truncates the file so the cell is re-runnable (append mode
    # fails in create_dataset on existing keys; h5py >= 3 defaults to read-only).
    # `with` guarantees each file is flushed and closed.
    # Store float32: x was scaled by 1/255.0, so an int8 cast would truncate
    # almost every pixel to 0 and destroy the images.
    with h5py.File('mnist_data_train_abnormalclass-{}.hd5'.format(i), 'w') as h_train:
        for k, v in samples_train.items():
            h_train.create_dataset(k, data=np.asarray(v, dtype=np.float32))

    with h5py.File('mnist_data_test_abnormalclass-{}.hd5'.format(i), 'w') as h_test:
        for k, v in samples_test.items():
            h_test.create_dataset(k, data=np.asarray(v, dtype=np.float32))

Preparing data set where the anomaly class is  0
Preparing data set where the anomaly class is  1
Preparing data set where the anomaly class is  2
Preparing data set where the anomaly class is  3
Preparing data set where the anomaly class is  4
Preparing data set where the anomaly class is  5
Preparing data set where the anomaly class is  6
Preparing data set where the anomaly class is  7
Preparing data set where the anomaly class is  8
Preparing data set where the anomaly class is  9


In [7]:
# Open the test-split file for anomaly class 0 read-only; the handle stays open for inspection below.
h5file = h5py.File('mnist_data_test_abnormalclass-{}.hd5'.format(0), mode='r')

In [9]:
# Number of ABNORMAL samples in the opened test file. `.shape[0]` reads dataset
# metadata only, instead of loading every row into a Python list just to count it.
h5file['ABNORMAL'].shape[0]

6903