In [1]:
import numpy as np

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import layers, models
from tensorflow.keras.layers import Input, Dense, Conv2D, Flatten, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Reshape

import tensorflow_datasets as tfds

In [57]:
# Thor DataLoad
# Specify the path to the manually downloaded dataset
data_dir = r'C:\Users\thorp\OneDrive\Dokumenter\Uni\Kandidat\Anvendt maskinlæring\Exam\Data\patch_camelyon'

# Load PatchCamelyon dataset using TFDS
def convert_sample(sample):
    image, label = sample['image'], sample['label']  
    image = tf.image.convert_image_dtype(image, tf.float32)
    label = tf.one_hot(label, 2, dtype=tf.float32)
    return image, image # Here i make sure that labels are not passed through since the AE does not use labels

In [58]:
ds1,ds2,ds3 = tfds.load('patch_camelyon',split=['train[:20%]','test[:5%]','validation[:5%]'],
                        data_dir = data_dir,
                        download=False,
                        shuffle_files=True)

In [59]:
train_dataset       = ds1.map(convert_sample).batch(32)
validation_dataset  = ds3.map(convert_sample).batch(32)
test_dataset        = ds2.map(convert_sample).batch(32)

In [60]:
latent_dim = 2
# Encoder
encoder_inputs = tf.keras.layers.Input(shape=(96, 96, 3))
x = tf.keras.layers.Conv2D(32, 3, 2, activation='relu')(encoder_inputs)
x = tf.keras.layers.Conv2D(64, 3, 2, activation='relu')(x)
x = tf.keras.layers.Flatten()(x)
encoded = tf.keras.layers.Dense(2 * latent_dim)(x)  # 2 for mean and standard deviation

encoder = tf.keras.models.Model(inputs=encoder_inputs, outputs=encoded)
encoder.summary()

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 96, 96, 3)]       0         
                                                                 
 conv2d_16 (Conv2D)          (None, 47, 47, 32)        896       
                                                                 
 conv2d_17 (Conv2D)          (None, 23, 23, 64)        18496     
                                                                 
 flatten_8 (Flatten)         (None, 33856)             0         
                                                                 
 dense_15 (Dense)            (None, 4)                 135428    
                                                                 
Total params: 154,820
Trainable params: 154,820
Non-trainable params: 0
_________________________________________________________________


In [61]:
decoder = tf.keras.models.Sequential([
    tf.keras.layers.Dense(units=6*6*64, activation='relu', input_shape=(latent_dim,)),
    tf.keras.layers.Reshape(target_shape=(6, 6, 64)),  # To get in "image format"
    tf.keras.layers.Conv2DTranspose(128, 3, 2, padding='same', activation='relu'),
    tf.keras.layers.Conv2DTranspose(64, 3, 2, padding='same', activation='relu'),
    tf.keras.layers.Conv2DTranspose(32, 3, 2, padding='same', activation='relu'),
    tf.keras.layers.Conv2DTranspose(3, 3, 2, padding='same', activation='sigmoid'),  # Adjust stride to get to 96x96x3
])

decoder.summary()

Model: "sequential_14"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_16 (Dense)            (None, 2304)              6912      
                                                                 
 reshape_8 (Reshape)         (None, 6, 6, 64)          0         
                                                                 
 conv2d_transpose_30 (Conv2D  (None, 12, 12, 128)      73856     
 Transpose)                                                      
                                                                 
 conv2d_transpose_31 (Conv2D  (None, 24, 24, 64)       73792     
 Transpose)                                                      
                                                                 
 conv2d_transpose_32 (Conv2D  (None, 48, 48, 32)       18464     
 Transpose)                                                      
                                                     

In [62]:
class VAE(tf.keras.Model):
    def __init__(self, latent_dim, encoder, decoder):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim
        self.encoder = encoder
        self.decoder = decoder
    
    def encode(self, x):
        params = self.encoder(x)
        return tf.split(params, num_or_size_splits=2, axis=1) # mean, logvar
        
    def decode(self, z):
        return self.decoder(z)
    
    def reparameterize(self, mean, logvar):
        eps = tf.random.normal(shape=mean.shape)
        return eps * tf.exp(logvar * 0.5) + mean #sigma= sqrt(exp(logvar))
    
    @tf.function
    def sample(self, eps=None):
        if eps is None:
            eps = tf.random.normal(shape=(100, self.latent_dim))
        return tf.sigmoid(self.decode(eps))

In [63]:
def log_normal_pdf(sample, mean, logvar, raxis=1):
    log2pi = tf.math.log(2. * np.pi)
    vals = -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi)
    return tf.reduce_sum(vals, axis=raxis)

def compute_loss(model, x):
    # Output from encoder
    mean, logvar = model.encode(x)    
    # The reparameterization trick
    z = model.reparameterize(mean, logvar)    
    # We assume that p(x|z) is multivariate Bernoulli, ie. the final dense layer 
    # has a sigmoid activation function, see page. 11
    # in Kingma, D. P., & Welling, M. (2013).
    x_logit = model.decode(z)
    cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, 
                                                        labels=x)
    logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])    
    # Assume normaility of p(z)
    logpz = log_normal_pdf(z, 0., 0.)    
    # Assume normality of q(z|x)
    logqz_x = log_normal_pdf(z, mean, logvar)
    # -tf.reduce_mean(decoder + sampler - encoder)
    return -tf.reduce_mean(logpx_z + logpz - logqz_x)

In [64]:
@tf.function
def train_step(model, x, optimizer):
    with tf.GradientTape() as tape:
        loss = compute_loss(model, x)

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

In [65]:
model = VAE(latent_dim, encoder, decoder)
optimizer = tf.keras.optimizers.Adam(1e-4)
test_sample = next(iter(test_dataset.take(1)))[:16]

def generate_and_save_images(model, epoch, test_sample):
    mean, logvar = model.encode(test_sample)
    z = model.reparameterize(mean, logvar)
    predictions = model.sample(z)
    fig = plt.figure(figsize=(4, 4))

    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        plt.imshow(predictions[i, :, :, 0], cmap='gray')
        plt.axis('off')

    plt.savefig('./vae-img/image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()

In [67]:
for epoch in range(1):
    for train_x, _ in train_dataset:
        train_step(model, train_x, optimizer)

    loss = tf.keras.metrics.Mean()
    for test_x in test_dataset:
        loss(compute_loss(model, test_x))
    variational_lower_bound = -loss.result()

    #print(f'Epoch: {epoch}, Test set variational lower bound:
   #        {variational_lower_bound}')
    generate_and_save_images(model, epoch, test_sample)

ValueError: Layer "model_1" expects 1 input(s), but it received 2 input tensors. Inputs received: [<tf.Tensor: shape=(32, 96, 96, 3), dtype=float32, numpy=
array([[[[0.57254905, 0.3019608 , 0.6666667 ],
         [0.69411767, 0.41960788, 0.76470596],
         [0.81568635, 0.5529412 , 0.86274517],
         ...,
         [0.1137255 , 0.07843138, 0.39607847],
         [0.52156866, 0.33333334, 0.73333335],
         [0.427451  , 0.10196079, 0.5647059 ]],

        [[0.6       , 0.34509805, 0.65882355],
         [0.6784314 , 0.43529415, 0.7294118 ],
         [0.90196085, 0.6862745 , 0.95294124],
         ...,
         [0.15686275, 0.1254902 , 0.45098042],
         [0.27058825, 0.10980393, 0.50980395],
         [0.43137258, 0.16470589, 0.60784316]],

        [[0.7294118 , 0.49411768, 0.74509805],
         [0.7960785 , 0.6       , 0.8313726 ],
         [0.8313726 , 0.6784314 , 0.90196085],
         ...,
         [0.07843138, 0.03921569, 0.3803922 ],
         [0.3921569 , 0.27450982, 0.65882355],
         [0.40000004, 0.19607845, 0.61960787]],

        ...,

        [[0.16470589, 0.08627451, 0.41960788],
         [0.227451  , 0.07450981, 0.43137258],
         [0.38431376, 0.16470589, 0.54509807],
         ...,
         [0.29803923, 0.16862746, 0.5254902 ],
         [0.32941177, 0.20392159, 0.58431375],
         [0.31764707, 0.20000002, 0.5921569 ]],

        [[0.20000002, 0.12156864, 0.50980395],
         [0.2784314 , 0.1137255 , 0.5137255 ],
         [0.43529415, 0.20784315, 0.6117647 ],
         ...,
         [0.3647059 , 0.24705884, 0.58431375],
         [0.28627452, 0.18431373, 0.5411765 ],
         [0.23529413, 0.13725491, 0.5058824 ]],

        [[0.30980393, 0.23529413, 0.654902  ],
         [0.29411766, 0.12156864, 0.5568628 ],
         [0.40784317, 0.16470589, 0.5921569 ],
         ...,
         [0.5686275 , 0.46274513, 0.7803922 ],
         [0.33333334, 0.24705884, 0.5803922 ],
         [0.20392159, 0.1254902 , 0.47450984]]],


       [[[0.8705883 , 0.70980394, 0.8431373 ],
         [0.6745098 , 0.50980395, 0.6666667 ],
         [0.5529412 , 0.37254903, 0.5568628 ],
         ...,
         [0.62352943, 0.4039216 , 0.5803922 ],
         [0.4666667 , 0.27450982, 0.45882356],
         [0.4784314 , 0.29803923, 0.4901961 ]],

        [[0.909804  , 0.73333335, 0.8862746 ],
         [0.48235297, 0.30588236, 0.4666667 ],
         [0.6509804 , 0.47058827, 0.6509804 ],
         ...,
         [0.6313726 , 0.4156863 , 0.57254905],
         [0.6313726 , 0.43921572, 0.6156863 ],
         [0.5294118 , 0.34901962, 0.53333336]],

        [[0.45882356, 0.26666668, 0.4431373 ],
         [0.6117647 , 0.427451  , 0.6       ],
         [0.5529412 , 0.36862746, 0.53333336],
         ...,
         [0.7490196 , 0.54509807, 0.6862745 ],
         [0.59607846, 0.40784317, 0.5647059 ],
         [0.5803922 , 0.40000004, 0.5803922 ]],

        ...,

        [[0.8588236 , 0.627451  , 0.7372549 ],
         [0.93725497, 0.7176471 , 0.8313726 ],
         [0.90196085, 0.6901961 , 0.8235295 ],
         ...,
         [0.8000001 , 0.61960787, 0.72156864],
         [0.8352942 , 0.654902  , 0.76470596],
         [0.83921576, 0.64705884, 0.7725491 ]],

        [[0.7490196 , 0.45882356, 0.5882353 ],
         [0.69411767, 0.41960788, 0.5529412 ],
         [0.7607844 , 0.50980395, 0.64705884],
         ...,
         [0.90196085, 0.7254902 , 0.8078432 ],
         [0.8352942 , 0.64705884, 0.7490196 ],
         [0.8431373 , 0.6509804 , 0.76470596]],

        [[0.93725497, 0.6156863 , 0.7568628 ],
         [0.854902  , 0.5529412 , 0.69411767],
         [0.7607844 , 0.48627454, 0.627451  ],
         ...,
         [0.82745105, 0.6431373 , 0.72156864],
         [0.82745105, 0.6392157 , 0.7254902 ],
         [0.8980393 , 0.70980394, 0.80392164]]],


       [[[0.5529412 , 0.41960788, 0.7176471 ],
         [0.48627454, 0.3529412 , 0.6509804 ],
         [0.15686275, 0.07843138, 0.35686275],
         ...,
         [0.8862746 , 0.86274517, 0.90196085],
         [0.8862746 , 0.86274517, 0.90196085],
         [0.8862746 , 0.86274517, 0.90196085]],

        [[0.5568628 , 0.43921572, 0.72156864],
         [0.36862746, 0.2509804 , 0.53333336],
         [0.14117648, 0.07450981, 0.34901962],
         ...,
         [0.882353  , 0.8588236 , 0.8980393 ],
         [0.882353  , 0.8588236 , 0.8980393 ],
         [0.882353  , 0.8588236 , 0.8980393 ]],

        [[0.54901963, 0.4431373 , 0.7137255 ],
         [0.27058825, 0.16470589, 0.43529415],
         [0.19215688, 0.1254902 , 0.3921569 ],
         ...,
         [0.87843144, 0.854902  , 0.8941177 ],
         [0.87843144, 0.854902  , 0.8941177 ],
         [0.87843144, 0.854902  , 0.8941177 ]],

        ...,

        [[0.1764706 , 0.01176471, 0.34901962],
         [0.34509805, 0.13725491, 0.49803925],
         [0.37254903, 0.11764707, 0.48627454],
         ...,
         [0.3803922 , 0.30980393, 0.6       ],
         [0.20784315, 0.1254902 , 0.427451  ],
         [0.26666668, 0.16078432, 0.4784314 ]],

        [[0.1254902 , 0.03921569, 0.36078432],
         [0.21960786, 0.08627451, 0.42352945],
         [0.3019608 , 0.10588236, 0.454902  ],
         ...,
         [0.24313727, 0.17254902, 0.47058827],
         [0.2392157 , 0.14901961, 0.46274513],
         [0.3803922 , 0.26666668, 0.59607846]],

        [[0.19215688, 0.19215688, 0.48235297],
         [0.20000002, 0.14509805, 0.44705886],
         [0.20784315, 0.07058824, 0.3921569 ],
         ...,
         [0.25882354, 0.18823531, 0.48627454],
         [0.2392157 , 0.14117648, 0.4666667 ],
         [0.2901961 , 0.16078432, 0.5019608 ]]],


       ...,


       [[[0.3647059 , 0.20000002, 0.58431375],
         [0.8000001 , 0.7176471 , 0.98823535],
         [0.5058824 , 0.47450984, 0.6745098 ],
         ...,
         [0.60784316, 0.56078434, 0.80392164],
         [0.24705884, 0.21176472, 0.47450984],
         [0.13725491, 0.10588236, 0.3921569 ]],

        [[0.19215688, 0.01176471, 0.39607847],
         [0.7254902 , 0.6392157 , 0.8941177 ],
         [0.6431373 , 0.60784316, 0.79215693],
         ...,
         [0.6313726 , 0.5647059 , 0.7686275 ],
         [0.3019608 , 0.26666668, 0.48235297],
         [0.16078432, 0.14901961, 0.37254903]],

        [[0.25490198, 0.07058824, 0.43137258],
         [0.49803925, 0.4039216 , 0.64705884],
         [0.76470596, 0.72156864, 0.89019614],
         ...,
         [0.7254902 , 0.65882355, 0.8235295 ],
         [0.23529413, 0.21568629, 0.37254903],
         [0.07843138, 0.08627451, 0.2392157 ]],

        ...,

        [[0.3803922 , 0.3372549 , 0.5568628 ],
         [0.47058827, 0.39607847, 0.6156863 ],
         [0.74509805, 0.62352943, 0.854902  ],
         ...,
         [0.76470596, 0.627451  , 0.83921576],
         [0.63529414, 0.47058827, 0.7176471 ],
         [0.7137255 , 0.5176471 , 0.80392164]],

        [[0.43529415, 0.34901962, 0.6313726 ],
         [0.38431376, 0.27450982, 0.5647059 ],
         [0.32941177, 0.18823531, 0.4784314 ],
         ...,
         [0.7843138 , 0.6313726 , 0.8862746 ],
         [0.7411765 , 0.54509807, 0.8313726 ],
         [0.7019608 , 0.47058827, 0.8000001 ]],

        [[0.53333336, 0.41960788, 0.7490196 ],
         [0.3803922 , 0.24705884, 0.58431375],
         [0.50980395, 0.36078432, 0.69411767],
         ...,
         [0.47450984, 0.30980393, 0.5921569 ],
         [0.68235296, 0.4666667 , 0.7843138 ],
         [0.8117648 , 0.5529412 , 0.9058824 ]]],


       [[[0.5411765 , 0.3254902 , 0.72156864],
         [0.80392164, 0.64705884, 0.92549026],
         [0.85098046, 0.77647066, 0.90196085],
         ...,
         [0.2784314 , 0.18823531, 0.5411765 ],
         [0.25490198, 0.10196079, 0.4666667 ],
         [0.3529412 , 0.13333334, 0.5137255 ]],

        [[0.2509804 , 0.02745098, 0.43921572],
         [0.7019608 , 0.5294118 , 0.83921576],
         [0.9568628 , 0.85098046, 1.        ],
         ...,
         [0.31764707, 0.21960786, 0.5764706 ],
         [0.3529412 , 0.19215688, 0.56078434],
         [0.3921569 , 0.16862746, 0.54901963]],

        [[0.28235295, 0.0627451 , 0.49803925],
         [0.4666667 , 0.2784314 , 0.6156863 ],
         [1.        , 0.8745099 , 1.        ],
         ...,
         [0.32156864, 0.19215688, 0.54901963],
         [0.38431376, 0.20784315, 0.57254905],
         [0.5176471 , 0.29411766, 0.6745098 ]],

        ...,

        [[0.83921576, 0.454902  , 0.7803922 ],
         [0.93725497, 0.5647059 , 0.8862746 ],
         [0.8588236 , 0.50980395, 0.8235295 ],
         ...,
         [0.54901963, 0.2509804 , 0.5019608 ],
         [0.54901963, 0.18431373, 0.45882356],
         [0.5921569 , 0.1764706 , 0.4666667 ]],

        [[0.5568628 , 0.22352943, 0.5411765 ],
         [0.58431375, 0.25882354, 0.57254905],
         [0.73333335, 0.41176474, 0.73333335],
         ...,
         [0.7058824 , 0.3529412 , 0.6431373 ],
         [0.6313726 , 0.22352943, 0.5176471 ],
         [0.7176471 , 0.25882354, 0.5647059 ]],

        [[0.65882355, 0.36862746, 0.6784314 ],
         [0.7803922 , 0.48627454, 0.8078432 ],
         [0.47450984, 0.18823531, 0.50980395],
         ...,
         [0.7254902 , 0.3254902 , 0.6509804 ],
         [0.62352943, 0.18039216, 0.49411768],
         [0.627451  , 0.13333334, 0.454902  ]]],


       [[[0.81568635, 0.6901961 , 0.8862746 ],
         [0.92549026, 0.77647066, 0.9607844 ],
         [0.8862746 , 0.70980394, 0.87843144],
         ...,
         [0.69803923, 0.32156864, 0.7058824 ],
         [0.8117648 , 0.39607847, 0.7803922 ],
         [0.69803923, 0.30980393, 0.6745098 ]],

        [[0.7607844 , 0.6       , 0.82745105],
         [0.7686275 , 0.58431375, 0.8000001 ],
         [0.86666673, 0.65882355, 0.86274517],
         ...,
         [0.6784314 , 0.28627452, 0.6745098 ],
         [0.6784314 , 0.28627452, 0.6627451 ],
         [0.6117647 , 0.25882354, 0.6117647 ]],

        [[0.7254902 , 0.5019608 , 0.7803922 ],
         [0.7490196 , 0.5137255 , 0.78823537],
         [0.8235295 , 0.5803922 , 0.8352942 ],
         ...,
         [0.6509804 , 0.24313727, 0.6313726 ],
         [0.7960785 , 0.427451  , 0.7960785 ],
         [0.6431373 , 0.32941177, 0.67058825]],

        ...,

        [[0.7725491 , 0.42352945, 0.74509805],
         [0.79215693, 0.48235297, 0.8000001 ],
         [0.6431373 , 0.3647059 , 0.6784314 ],
         ...,
         [0.5764706 , 0.23137257, 0.5686275 ],
         [0.65882355, 0.3372549 , 0.6666667 ],
         [0.62352943, 0.32156864, 0.6431373 ]],

        [[0.6627451 , 0.3372549 , 0.6431373 ],
         [0.7490196 , 0.4666667 , 0.76470596],
         [0.63529414, 0.38431376, 0.68235296],
         ...,
         [0.7372549 , 0.43529415, 0.7490196 ],
         [0.6509804 , 0.35686275, 0.6784314 ],
         [0.6313726 , 0.34509805, 0.6627451 ]],

        [[0.8235295 , 0.5254902 , 0.81568635],
         [0.7490196 , 0.48235297, 0.7686275 ],
         [0.6156863 , 0.3921569 , 0.67058825],
         ...,
         [0.9960785 , 0.7176471 , 1.        ],
         [0.9176471 , 0.6392157 , 0.95294124],
         [0.77647066, 0.49803925, 0.8117648 ]]]], dtype=float32)>, <tf.Tensor: shape=(32, 96, 96, 3), dtype=float32, numpy=
array([[[[0.57254905, 0.3019608 , 0.6666667 ],
         [0.69411767, 0.41960788, 0.76470596],
         [0.81568635, 0.5529412 , 0.86274517],
         ...,
         [0.1137255 , 0.07843138, 0.39607847],
         [0.52156866, 0.33333334, 0.73333335],
         [0.427451  , 0.10196079, 0.5647059 ]],

        [[0.6       , 0.34509805, 0.65882355],
         [0.6784314 , 0.43529415, 0.7294118 ],
         [0.90196085, 0.6862745 , 0.95294124],
         ...,
         [0.15686275, 0.1254902 , 0.45098042],
         [0.27058825, 0.10980393, 0.50980395],
         [0.43137258, 0.16470589, 0.60784316]],

        [[0.7294118 , 0.49411768, 0.74509805],
         [0.7960785 , 0.6       , 0.8313726 ],
         [0.8313726 , 0.6784314 , 0.90196085],
         ...,
         [0.07843138, 0.03921569, 0.3803922 ],
         [0.3921569 , 0.27450982, 0.65882355],
         [0.40000004, 0.19607845, 0.61960787]],

        ...,

        [[0.16470589, 0.08627451, 0.41960788],
         [0.227451  , 0.07450981, 0.43137258],
         [0.38431376, 0.16470589, 0.54509807],
         ...,
         [0.29803923, 0.16862746, 0.5254902 ],
         [0.32941177, 0.20392159, 0.58431375],
         [0.31764707, 0.20000002, 0.5921569 ]],

        [[0.20000002, 0.12156864, 0.50980395],
         [0.2784314 , 0.1137255 , 0.5137255 ],
         [0.43529415, 0.20784315, 0.6117647 ],
         ...,
         [0.3647059 , 0.24705884, 0.58431375],
         [0.28627452, 0.18431373, 0.5411765 ],
         [0.23529413, 0.13725491, 0.5058824 ]],

        [[0.30980393, 0.23529413, 0.654902  ],
         [0.29411766, 0.12156864, 0.5568628 ],
         [0.40784317, 0.16470589, 0.5921569 ],
         ...,
         [0.5686275 , 0.46274513, 0.7803922 ],
         [0.33333334, 0.24705884, 0.5803922 ],
         [0.20392159, 0.1254902 , 0.47450984]]],


       [[[0.8705883 , 0.70980394, 0.8431373 ],
         [0.6745098 , 0.50980395, 0.6666667 ],
         [0.5529412 , 0.37254903, 0.5568628 ],
         ...,
         [0.62352943, 0.4039216 , 0.5803922 ],
         [0.4666667 , 0.27450982, 0.45882356],
         [0.4784314 , 0.29803923, 0.4901961 ]],

        [[0.909804  , 0.73333335, 0.8862746 ],
         [0.48235297, 0.30588236, 0.4666667 ],
         [0.6509804 , 0.47058827, 0.6509804 ],
         ...,
         [0.6313726 , 0.4156863 , 0.57254905],
         [0.6313726 , 0.43921572, 0.6156863 ],
         [0.5294118 , 0.34901962, 0.53333336]],

        [[0.45882356, 0.26666668, 0.4431373 ],
         [0.6117647 , 0.427451  , 0.6       ],
         [0.5529412 , 0.36862746, 0.53333336],
         ...,
         [0.7490196 , 0.54509807, 0.6862745 ],
         [0.59607846, 0.40784317, 0.5647059 ],
         [0.5803922 , 0.40000004, 0.5803922 ]],

        ...,

        [[0.8588236 , 0.627451  , 0.7372549 ],
         [0.93725497, 0.7176471 , 0.8313726 ],
         [0.90196085, 0.6901961 , 0.8235295 ],
         ...,
         [0.8000001 , 0.61960787, 0.72156864],
         [0.8352942 , 0.654902  , 0.76470596],
         [0.83921576, 0.64705884, 0.7725491 ]],

        [[0.7490196 , 0.45882356, 0.5882353 ],
         [0.69411767, 0.41960788, 0.5529412 ],
         [0.7607844 , 0.50980395, 0.64705884],
         ...,
         [0.90196085, 0.7254902 , 0.8078432 ],
         [0.8352942 , 0.64705884, 0.7490196 ],
         [0.8431373 , 0.6509804 , 0.76470596]],

        [[0.93725497, 0.6156863 , 0.7568628 ],
         [0.854902  , 0.5529412 , 0.69411767],
         [0.7607844 , 0.48627454, 0.627451  ],
         ...,
         [0.82745105, 0.6431373 , 0.72156864],
         [0.82745105, 0.6392157 , 0.7254902 ],
         [0.8980393 , 0.70980394, 0.80392164]]],


       [[[0.5529412 , 0.41960788, 0.7176471 ],
         [0.48627454, 0.3529412 , 0.6509804 ],
         [0.15686275, 0.07843138, 0.35686275],
         ...,
         [0.8862746 , 0.86274517, 0.90196085],
         [0.8862746 , 0.86274517, 0.90196085],
         [0.8862746 , 0.86274517, 0.90196085]],

        [[0.5568628 , 0.43921572, 0.72156864],
         [0.36862746, 0.2509804 , 0.53333336],
         [0.14117648, 0.07450981, 0.34901962],
         ...,
         [0.882353  , 0.8588236 , 0.8980393 ],
         [0.882353  , 0.8588236 , 0.8980393 ],
         [0.882353  , 0.8588236 , 0.8980393 ]],

        [[0.54901963, 0.4431373 , 0.7137255 ],
         [0.27058825, 0.16470589, 0.43529415],
         [0.19215688, 0.1254902 , 0.3921569 ],
         ...,
         [0.87843144, 0.854902  , 0.8941177 ],
         [0.87843144, 0.854902  , 0.8941177 ],
         [0.87843144, 0.854902  , 0.8941177 ]],

        ...,

        [[0.1764706 , 0.01176471, 0.34901962],
         [0.34509805, 0.13725491, 0.49803925],
         [0.37254903, 0.11764707, 0.48627454],
         ...,
         [0.3803922 , 0.30980393, 0.6       ],
         [0.20784315, 0.1254902 , 0.427451  ],
         [0.26666668, 0.16078432, 0.4784314 ]],

        [[0.1254902 , 0.03921569, 0.36078432],
         [0.21960786, 0.08627451, 0.42352945],
         [0.3019608 , 0.10588236, 0.454902  ],
         ...,
         [0.24313727, 0.17254902, 0.47058827],
         [0.2392157 , 0.14901961, 0.46274513],
         [0.3803922 , 0.26666668, 0.59607846]],

        [[0.19215688, 0.19215688, 0.48235297],
         [0.20000002, 0.14509805, 0.44705886],
         [0.20784315, 0.07058824, 0.3921569 ],
         ...,
         [0.25882354, 0.18823531, 0.48627454],
         [0.2392157 , 0.14117648, 0.4666667 ],
         [0.2901961 , 0.16078432, 0.5019608 ]]],


       ...,


       [[[0.3647059 , 0.20000002, 0.58431375],
         [0.8000001 , 0.7176471 , 0.98823535],
         [0.5058824 , 0.47450984, 0.6745098 ],
         ...,
         [0.60784316, 0.56078434, 0.80392164],
         [0.24705884, 0.21176472, 0.47450984],
         [0.13725491, 0.10588236, 0.3921569 ]],

        [[0.19215688, 0.01176471, 0.39607847],
         [0.7254902 , 0.6392157 , 0.8941177 ],
         [0.6431373 , 0.60784316, 0.79215693],
         ...,
         [0.6313726 , 0.5647059 , 0.7686275 ],
         [0.3019608 , 0.26666668, 0.48235297],
         [0.16078432, 0.14901961, 0.37254903]],

        [[0.25490198, 0.07058824, 0.43137258],
         [0.49803925, 0.4039216 , 0.64705884],
         [0.76470596, 0.72156864, 0.89019614],
         ...,
         [0.7254902 , 0.65882355, 0.8235295 ],
         [0.23529413, 0.21568629, 0.37254903],
         [0.07843138, 0.08627451, 0.2392157 ]],

        ...,

        [[0.3803922 , 0.3372549 , 0.5568628 ],
         [0.47058827, 0.39607847, 0.6156863 ],
         [0.74509805, 0.62352943, 0.854902  ],
         ...,
         [0.76470596, 0.627451  , 0.83921576],
         [0.63529414, 0.47058827, 0.7176471 ],
         [0.7137255 , 0.5176471 , 0.80392164]],

        [[0.43529415, 0.34901962, 0.6313726 ],
         [0.38431376, 0.27450982, 0.5647059 ],
         [0.32941177, 0.18823531, 0.4784314 ],
         ...,
         [0.7843138 , 0.6313726 , 0.8862746 ],
         [0.7411765 , 0.54509807, 0.8313726 ],
         [0.7019608 , 0.47058827, 0.8000001 ]],

        [[0.53333336, 0.41960788, 0.7490196 ],
         [0.3803922 , 0.24705884, 0.58431375],
         [0.50980395, 0.36078432, 0.69411767],
         ...,
         [0.47450984, 0.30980393, 0.5921569 ],
         [0.68235296, 0.4666667 , 0.7843138 ],
         [0.8117648 , 0.5529412 , 0.9058824 ]]],


       [[[0.5411765 , 0.3254902 , 0.72156864],
         [0.80392164, 0.64705884, 0.92549026],
         [0.85098046, 0.77647066, 0.90196085],
         ...,
         [0.2784314 , 0.18823531, 0.5411765 ],
         [0.25490198, 0.10196079, 0.4666667 ],
         [0.3529412 , 0.13333334, 0.5137255 ]],

        [[0.2509804 , 0.02745098, 0.43921572],
         [0.7019608 , 0.5294118 , 0.83921576],
         [0.9568628 , 0.85098046, 1.        ],
         ...,
         [0.31764707, 0.21960786, 0.5764706 ],
         [0.3529412 , 0.19215688, 0.56078434],
         [0.3921569 , 0.16862746, 0.54901963]],

        [[0.28235295, 0.0627451 , 0.49803925],
         [0.4666667 , 0.2784314 , 0.6156863 ],
         [1.        , 0.8745099 , 1.        ],
         ...,
         [0.32156864, 0.19215688, 0.54901963],
         [0.38431376, 0.20784315, 0.57254905],
         [0.5176471 , 0.29411766, 0.6745098 ]],

        ...,

        [[0.83921576, 0.454902  , 0.7803922 ],
         [0.93725497, 0.5647059 , 0.8862746 ],
         [0.8588236 , 0.50980395, 0.8235295 ],
         ...,
         [0.54901963, 0.2509804 , 0.5019608 ],
         [0.54901963, 0.18431373, 0.45882356],
         [0.5921569 , 0.1764706 , 0.4666667 ]],

        [[0.5568628 , 0.22352943, 0.5411765 ],
         [0.58431375, 0.25882354, 0.57254905],
         [0.73333335, 0.41176474, 0.73333335],
         ...,
         [0.7058824 , 0.3529412 , 0.6431373 ],
         [0.6313726 , 0.22352943, 0.5176471 ],
         [0.7176471 , 0.25882354, 0.5647059 ]],

        [[0.65882355, 0.36862746, 0.6784314 ],
         [0.7803922 , 0.48627454, 0.8078432 ],
         [0.47450984, 0.18823531, 0.50980395],
         ...,
         [0.7254902 , 0.3254902 , 0.6509804 ],
         [0.62352943, 0.18039216, 0.49411768],
         [0.627451  , 0.13333334, 0.454902  ]]],


       [[[0.81568635, 0.6901961 , 0.8862746 ],
         [0.92549026, 0.77647066, 0.9607844 ],
         [0.8862746 , 0.70980394, 0.87843144],
         ...,
         [0.69803923, 0.32156864, 0.7058824 ],
         [0.8117648 , 0.39607847, 0.7803922 ],
         [0.69803923, 0.30980393, 0.6745098 ]],

        [[0.7607844 , 0.6       , 0.82745105],
         [0.7686275 , 0.58431375, 0.8000001 ],
         [0.86666673, 0.65882355, 0.86274517],
         ...,
         [0.6784314 , 0.28627452, 0.6745098 ],
         [0.6784314 , 0.28627452, 0.6627451 ],
         [0.6117647 , 0.25882354, 0.6117647 ]],

        [[0.7254902 , 0.5019608 , 0.7803922 ],
         [0.7490196 , 0.5137255 , 0.78823537],
         [0.8235295 , 0.5803922 , 0.8352942 ],
         ...,
         [0.6509804 , 0.24313727, 0.6313726 ],
         [0.7960785 , 0.427451  , 0.7960785 ],
         [0.6431373 , 0.32941177, 0.67058825]],

        ...,

        [[0.7725491 , 0.42352945, 0.74509805],
         [0.79215693, 0.48235297, 0.8000001 ],
         [0.6431373 , 0.3647059 , 0.6784314 ],
         ...,
         [0.5764706 , 0.23137257, 0.5686275 ],
         [0.65882355, 0.3372549 , 0.6666667 ],
         [0.62352943, 0.32156864, 0.6431373 ]],

        [[0.6627451 , 0.3372549 , 0.6431373 ],
         [0.7490196 , 0.4666667 , 0.76470596],
         [0.63529414, 0.38431376, 0.68235296],
         ...,
         [0.7372549 , 0.43529415, 0.7490196 ],
         [0.6509804 , 0.35686275, 0.6784314 ],
         [0.6313726 , 0.34509805, 0.6627451 ]],

        [[0.8235295 , 0.5254902 , 0.81568635],
         [0.7490196 , 0.48235297, 0.7686275 ],
         [0.6156863 , 0.3921569 , 0.67058825],
         ...,
         [0.9960785 , 0.7176471 , 1.        ],
         [0.9176471 , 0.6392157 , 0.95294124],
         [0.77647066, 0.49803925, 0.8117648 ]]]], dtype=float32)>]

In [68]:
print(train_x.shape)

(13, 96, 96, 3)
