In [1]:
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import Sequence
from tensorflow.keras.utils import to_categorical
import numpy as np

In [2]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()

print(x_train.shape)
print(y_train.shape)
print(x_test.shape)
print(y_test.shape)

(60000, 28, 28)
(60000,)
(10000, 28, 28)
(10000,)


In [26]:
class MNISTDataGenerator(Sequence):
    def __init__(self, batch_size=16, n_classes=10, noise_shape=(100,), shuffle=False):
        self.batch_size = batch_size
        self.n_classes = n_classes
        self.noise_shape = noise_shape
        self.shuffle = shuffle
        
        (x_train, y_train), (x_test, y_test) = mnist.load_data()
        
        # use small datset for the sake of development
        x_train = x_train[:1000, :, :]
        y_train = y_train[:1000]

        self.x_train = x_train
        self.y_train = to_categorical(y_train, num_classes=self.n_classes)
        
        # toggle whether datagenerator should return noise only
        self.noise_only = True
        
        self.on_epoch_end()

        

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(self.x_train.shape[0] // self.batch_size)


    def __getitem__(self, index):
        """
        Generate one batch of data
        :return: {'real_input': real_data, 'noise_input': noise},
                 {'real_output': real_output, 'fake_output': fake_output}
        """
        if self.noise_only:
            # blank data
            real_data = np.zeros((self.batch_size,)+self.x_train.shape[1:])
            real_output = np.zeros(self.batch_size)
            real_output = to_categorical(real_output, num_classes=self.n_classes)
        else:
            real_data = self.x_train[index:(index+self.batch_size), :, :]
            real_output = self.y_train[index:(index+self.batch_size), :]
        
        noise = np.random.uniform(-1.0, 1.0, size=(self.batch_size, 100))
        fake_output = np.random.randint(self.n_classes, size=100)
        fake_output = to_categorical(fake_output, num_classes=self.n_classes)
        
        return {'real_input': real_data, 'noise_input': noise}, {'real_output': real_output, 'fake_output': fake_output}
    
    def on_epoch_end(self):
        pass

In [27]:
mnist_gen = MNISTDataGenerator(n_classes=10)
x, y = mnist_gen[0]
for key in x:
    print(key, x[key].shape)
for key in y:
    print(key, y[key].shape)

real_input (16, 28, 28)
noise_input (16, 100)
real_output (16, 10)
fake_output (100, 10)


In [25]:
mnist_gen.noise_only = False
x, y = mnist_gen[0]
for key in x:
    print(key, x[key].shape)
for key in y:
    print(key, y[key].shape)

np.all(x['real_input'] == np.zeros((28, 28)))

real_input (16, 28, 28)
noise_input (16, 100)
real_output (16, 10)
fake_output (100, 10)


False