## Latent decomposition and recombination for controlled hallucination
Parts of this code were inspired by [Chanseok Kang's VAE implementation](https://colab.research.google.com/github/goodboychan/goodboychan.github.io/blob/main/_notebooks/2021-09-14-03-Variational-AutoEncoder-Celeb-A.ipynb), and [TensorFlow's CycleGAN implementation](https://www.tensorflow.org/tutorials/generative/cyclegan).

ssh -i "C:\Users\Thijs\OneDrive\University\Year 3\Thesis\GPU access\privateKey.pem" u529937@aurometalsaurus.uvt.nl

srun --nodes=1 --pty /bin/bash -l

In [2]:
# Standard-library modules: `os` for paths/environment variables, `json` for dumping the training history
import os, json
# NOTE(review): both variables are set before TensorFlow is imported, presumably so that
# they are read while TensorFlow initializes — confirm before reordering these lines.
os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=/usr/lib/cuda"
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"
import tensorflow as tf
import tensorflow_probability as tfp 
# Strategy used below to build the models, distribute the training data, and run the train steps
distributor = tf.distribute.MirroredStrategy()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


## Data

In [2]:
# Define some constants used throughout the script
# Shape of a single image; used as the crop size and as the element shape of every dataset
INPUT_SHAPE = (256, 256, 3)
# Number of filters of each successive encoder stage (the decoder walks this list in reverse)
FILTERS = [64,128,256,512,512]
# Number of images per batch
BATCH_SIZE = 16
# Batch size divided by the number of visible GPUs (divided by 1 when no GPU is found)
DIST_BATCH_SIZE = BATCH_SIZE//max(len(tf.config.list_physical_devices("GPU")), 1)
# Size of each latent vector (one for style, one for content)
LATENT_DIM = 254
# Root directory of the dataset; the trainA/trainB/testA/testB folders are read from here
PATH = "../../data/vangogh2photo/"

# Define a generator to yield the images stored in a directory
def gen(path):
    """Yield every file in `path` as a decoded float32 RGB image.

    Args:
        path: directory holding the image files. `tf.data.Dataset.from_generator`
            hands this over as `bytes`, which `os.listdir`/`os.path.join` accept.

    Yields:
        One rank-3 float32 tensor (height, width, 3) per file, in sorted
        file-name order.
    """
    # `os.listdir` returns entries in arbitrary order; sort for reproducible datasets
    for file in sorted(os.listdir(path)):
        image = tf.io.read_file(os.path.join(path, file))
        # Force three channels and a rank-3 result (animated GIFs would otherwise
        # decode to rank 4) so that every element matches the declared TensorSpec
        image = tf.io.decode_image(image, channels=3, dtype=tf.float32, expand_animations=False)
        yield image

# Define a function to randomly modify images
def random_jitter(image):
    """Augment one image: enlarge it, crop it back to `INPUT_SHAPE`, and mirror it at random.

    Args:
        image: a single image tensor.

    Returns:
        The augmented image with shape `INPUT_SHAPE`.
    """
    # Enlarge to 286x286 so that the random crop below can shift the content
    enlarged = tf.image.resize(image, (286,286))
    cropped = tf.image.random_crop(enlarged, size=INPUT_SHAPE)
    # Flip horizontally at random
    return tf.image.random_flip_left_right(cropped)

# Build the endless, augmented, batched stream of training images of one class
def _training_stream(folder):
    images = tf.data.Dataset.from_generator(
        gen,
        args=[os.path.join(PATH, folder)],
        output_signature=tf.TensorSpec(shape=INPUT_SHAPE, dtype=tf.float32),
    )
    # Cache the decoded files, loop over them forever, and jitter every pass anew
    images = images.cache().repeat().map(random_jitter)
    images = images.shuffle(3, reshuffle_each_iteration=True)
    return images.batch(BATCH_SIZE)

# Load training data of both classes
train_A = _training_stream("trainA")
train_B = _training_stream("trainB")
# Pair the classes batch by batch and hand the result to the distribution strategy
train = distributor.experimental_distribute_dataset(tf.data.Dataset.zip((train_A, train_B)))

# Build the batched test images of one class (incomplete final batches are dropped)
def _test_batches(folder):
    images = tf.data.Dataset.from_generator(
        gen,
        args=[os.path.join(PATH, folder)],
        output_signature=tf.TensorSpec(shape=INPUT_SHAPE, dtype=tf.float32),
    )
    return images.batch(BATCH_SIZE, drop_remainder=True)

# Load test data of both classes
test_A = _test_batches("testA")
test_B = _test_batches("testB")
# Pair the classes batch by batch
test = tf.data.Dataset.zip((test_A, test_B))

## Models

In [3]:
# Define a function that creates a Sequential model consisting of convolutions
def downsampler(nfilters, name=None, strides=(1,1), size=(12,12)):
    """Return a Sequential block: dropout, convolution, group normalization, leaky ReLU.

    Args:
        nfilters: number of convolution filters; also used as the number of normalization groups.
        name: optional name of the Sequential model.
        strides: strides of the convolution.
        size: kernel size of the convolution.
    """
    dropout = tf.keras.layers.Dropout(rate=.4)
    convolution = tf.keras.layers.Conv2D(filters=nfilters, kernel_size=size, padding="same", strides=strides)
    # One normalization group per filter
    normalization = tf.keras.layers.GroupNormalization(groups=nfilters)
    activation = tf.keras.layers.LeakyReLU()
    return tf.keras.Sequential([dropout, convolution, normalization, activation], name=name)

# Define a function that creates a Sequential model consisting of transposed convolutions
def upsampler(nfilters, name=None, strides=(2,2), size=(12,12)):
    """Return a Sequential block: dropout, transposed convolution, group normalization, leaky ReLU.

    Args:
        nfilters: number of filters; also used as the number of normalization groups.
        name: optional name of the Sequential model.
        strides: strides of the transposed convolution.
        size: kernel size of the transposed convolution.
    """
    dropout = tf.keras.layers.Dropout(rate=.4)
    deconvolution = tf.keras.layers.Conv2DTranspose(filters=nfilters, kernel_size=size, strides=strides, padding="same")
    # One normalization group per filter
    normalization = tf.keras.layers.GroupNormalization(groups=nfilters)
    activation = tf.keras.layers.LeakyReLU()
    return tf.keras.Sequential([dropout, deconvolution, normalization, activation], name=name)

# Define a function to return an inception module (built from `downsampler` blocks unless `block=upsampler` is given)
def inceptionv2(input_shape, nfilters, strides=(1,1), block=downsampler, name=None):
    """Return a four-branch inception module as a Keras model.

    Args:
        input_shape: shape of one input element (without the batch dimension).
        nfilters: total number of output filters; every branch contributes a quarter.
        strides: strides of the last block of each branch.
        block: factory of the blocks used in the branches (`downsampler` or `upsampler`).
        name: optional name of the model.
    """
    branch_filters = nfilters//4
    inputs = tf.keras.layers.Input(shape=input_shape, batch_size=BATCH_SIZE)
    # Branch 1: a single 1x1 block
    branch1 = block(nfilters=branch_filters, size=(1,1), strides=strides)(inputs)
    # Branch 2: a 1x1 block followed by a 6x6 block
    reduced2 = block(nfilters=branch_filters, size=(1,1), strides=(1,1))(inputs)
    branch2 = block(nfilters=branch_filters, size=(6,6), strides=strides)(reduced2)
    # Branch 3: a 1x1 block followed by a 12x12 block
    reduced3 = block(nfilters=branch_filters, size=(1,1), strides=(1,1))(inputs)
    branch3 = block(nfilters=branch_filters, size=(12,12), strides=strides)(reduced3)
    # Branch 4: a size-preserving resampling step followed by a 1x1 block
    if block==upsampler:
        # Enlarge three times, then crop the centre back to the input's height and width
        enlarged = tf.keras.layers.UpSampling2D(size=(3,3))(inputs)
        cropping = tuple(
            (big-small)//2
            for big, small in zip(enlarged.shape[1:-1], input_shape[:-1])
        )
        resampled = tf.keras.layers.Cropping2D(cropping)(enlarged)
    else:
        resampled = tf.keras.layers.MaxPooling2D(pool_size=(3,3), strides=(1,1), padding="same")(inputs)
    branch4 = block(nfilters=branch_filters, size=(1,1), strides=strides)(resampled)
    # Stack the branches along the channel axis
    merged = tf.keras.layers.Concatenate()([branch1, branch2, branch3, branch4])
    return tf.keras.Model(inputs=inputs, outputs=merged, name=name)

# Define a function to create a prior distribution used to condition the latent distribution
def get_prior():
    """Return a diagonal multivariate normal over `LATENT_DIM` dimensions with variable parameters.

    The location starts at random normal values; the scale starts at ones and is
    kept positive through a softplus bijector.
    """
    mu = tf.Variable(
        tf.random.normal(shape=(LATENT_DIM,)),
        trainable=True,
        name="mu"
    )
    sigma = tfp.util.TransformedVariable(
        tf.Variable(tf.ones(shape=(LATENT_DIM,))),
        bijector=tfp.bijectors.Softplus(),
        name="sigma",
        trainable=True
    )
    return tfp.distributions.MultivariateNormalDiag(loc=mu, scale_diag=sigma)

# Define a function to build the encoders
def build_encoder(name="Encoder"):
    """Build an encoder that maps an image batch to a `MultivariateNormalTriL` latent distribution.

    Args:
        name: name of the returned Keras model.

    Returns:
        A `tf.keras.Model` whose output is the latent distribution, regularized
        by the KL divergence to a fresh prior from `get_prior`.
    """
    inputs = tf.keras.layers.Input(shape=INPUT_SHAPE, batch_size=BATCH_SIZE, name="Input")
    features = inputs
    # One residual inception stage per entry of FILTERS, each halving the resolution
    for stage, nfilters in enumerate(FILTERS):
        # 1x1 projection of the stage input, used as the skip connection
        shortcut = downsampler(nfilters, size=(1,1), name=f"Downsampler_{stage*2}")(features)
        # Main path: 3x3 block followed by the inception module
        features = downsampler(nfilters=nfilters, size=(3,3), name=f"Downsampler_{stage*2+1}")(features)
        features = inceptionv2(input_shape=features.shape[1:], nfilters=nfilters, name=f"Inception_{stage}")(features)
        # Merge both paths and reduce the resolution
        features = tf.keras.layers.Add(name=f"SumSkips_{stage}")([shortcut, features])
        features = tf.keras.layers.AveragePooling2D(pool_size=(2,2), name=f"AveragePooling_{stage}")(features)
    # Flatten to a sequence so that Cropping1D can trim it to the parameter count of the distribution layer
    sequence = tf.keras.layers.Reshape((-1,1), name="Reshape")(features)
    surplus = sequence.shape[1] - tfp.layers.MultivariateNormalTriL.params_size(LATENT_DIM)
    # Split the surplus over both ends; an odd surplus puts the extra element at the end
    crop_start = surplus//2
    crop_end = surplus - crop_start
    trimmed = tf.keras.layers.Cropping1D(cropping=(crop_start, crop_end), name="Cropping")(sequence)
    params = tf.keras.layers.Flatten(name="Flatten")(trimmed)
    # Interpret the remaining values as the parameters of the latent distribution
    distribution_layer = tfp.layers.MultivariateNormalTriL(
        LATENT_DIM,
        name="LatentDistribution",
        activity_regularizer=tfp.layers.KLDivergenceRegularizer(get_prior(), use_exact_kl=True),
    )
    return tf.keras.Model(inputs=inputs, outputs=distribution_layer(params), name=name)

# Define a function to build the decoder
def build_decoder():
    """Build the decoder that maps a (style, content) pair of latents to an image distribution.

    Returns:
        A `tf.keras.Model` named "Decoder" that takes `[style, content]` (each of
        size `LATENT_DIM`) and returns an `IndependentBernoulli` distribution
        with event shape `INPUT_SHAPE`.
    """
    # Get input (the shape is passed as a tuple, the documented form)
    style = tf.keras.layers.Input(shape=(LATENT_DIM,), batch_size=BATCH_SIZE, name="StyleInput")
    content = tf.keras.layers.Input(shape=(LATENT_DIM,), batch_size=BATCH_SIZE, name="ContentInput")
    # Concatenate latent representations
    x = tf.keras.layers.Concatenate(name="Concatenate")([style, content])
    # Start at the resolution that the stages below double back up to INPUT_SHAPE
    # (8x8x512 for the constants of this script)
    nstages = len(FILTERS)
    seed_shape = (INPUT_SHAPE[0]//2**nstages, INPUT_SHAPE[1]//2**nstages, FILTERS[-1])
    x = tf.keras.layers.Dense(seed_shape[0]*seed_shape[1]*seed_shape[2], activation="relu", name="Dense")(x)
    x = tf.keras.layers.Reshape(seed_shape, name="Reshape")(x)
    # Revert the downsampling
    for i, nfilters in enumerate(reversed(FILTERS)):
        # Make skip connection (the default strides of `upsampler` double the resolution)
        x_ = upsampler(nfilters, size=(1,1), name=f"Upsampler_{i*2}")(x)
        # Pass input through inception module
        x = upsampler(nfilters=nfilters, size=(3,3), strides=(2,2), name=f"Upsampler_{i*2+1}")(x)
        x = inceptionv2(input_shape=x.shape[1:], block=upsampler, nfilters=nfilters, strides=(1,1), name=f"Inception_{i}")(x)
        # Sum skip and inception
        x = tf.keras.layers.Add(name=f"SumSkips_{i}")([x_, x])
    # Reshape to image format; the name continues the numbering of the loop
    # without relying on the loop variable still being bound
    x = upsampler(INPUT_SHAPE[-1], strides=(1,1), name=f"Upsampler_{nstages*2}")(x)
    # Fit a Bernoulli to leverage probabilistic reconstruction loss
    x = tf.keras.layers.Flatten(name="Flatten")(x)
    outputs = tfp.layers.IndependentBernoulli(event_shape=INPUT_SHAPE, name="Bernoulli", dtype=tf.float32)(x)
    return tf.keras.Model(inputs=[style, content], outputs=outputs, name="Decoder")

## Build

In [4]:
# Build within distributed scope
with distributor.scope():
    # Models: one encoder per latent factor, and a shared decoder
    encoder_style, encoder_content = (
        build_encoder(encoder_name) for encoder_name in ("StyleEncoder", "ContentEncoder")
    )
    decoder = build_decoder()
    # Optimizers: a separate Adam instance per model, all with the same learning rate
    decoder_optimizer, encoder_style_optimizer, encoder_content_optimizer = (
        tf.keras.optimizers.Adam(1e-4) for _ in range(3)
    )

## Training

In [5]:
# Define the weight of each loss term (all terms currently count equally)
KL_COEF = 1.
RECONSTRUCTION_COEF = 1.
CYCLE_COEF = 1.

# Define a function to get reconstructions, translations, and cycles
def training_call(A, B):
    """Encode both batches, decode reconstructions and translations, and re-encode those.

    Args:
        A: batch of images of class A.
        B: batch of images of class B.

    Returns:
        A 14-tuple: the content and style latents of the inputs, the two
        reconstructed image distributions, and the style/content latents of the
        reconstructions and of the translations (A before B within each pair).
    """
    # NOTE(review): the models are called without `training=True`; presumably Keras then
    # runs their Dropout layers in inference mode — confirm that this is intended.

    # Average a latent batch and tile the average back to the per-replica batch size
    # (style is averaged to enforce it to be whatever the batch has in common)
    def batch_mean(latent):
        pooled = tf.reduce_mean(latent, axis=0, keepdims=True)
        return tf.repeat(pooled, repeats=DIST_BATCH_SIZE, axis=0)

    # Latent representations of the inputs
    style_A = encoder_style(A)
    mean_style_A = batch_mean(style_A)
    style_B = encoder_style(B)
    mean_style_B = batch_mean(style_B)
    content_A = encoder_content(A)
    content_B = encoder_content(B)

    # Reconstructions: own style combined with own content
    reconstructed_A = decoder([mean_style_A, content_A])
    reconstructed_B = decoder([mean_style_B, content_B])
    # Re-encode the distribution means (means are used because of gradients)
    expected_reconstruction_A = reconstructed_A.mean()
    expected_reconstruction_B = reconstructed_B.mean()
    style_reconstructed_A = encoder_style(expected_reconstruction_A)
    style_reconstructed_B = encoder_style(expected_reconstruction_B)
    content_reconstructed_A = encoder_content(expected_reconstruction_A)
    content_reconstructed_B = encoder_content(expected_reconstruction_B)

    # Translations: the other class's style combined with own content
    expected_translation_A = decoder([mean_style_B, content_A]).mean()
    expected_translation_B = decoder([mean_style_A, content_B]).mean()
    style_translated_A = encoder_style(expected_translation_A)  # i.e., the estimated style of B
    style_translated_B = encoder_style(expected_translation_B)  # i.e., the estimated style of A
    content_translated_A = encoder_content(expected_translation_A)
    content_translated_B = encoder_content(expected_translation_B)

    return (
        content_A, content_B,
        style_A, style_B,
        reconstructed_A, reconstructed_B,
        style_reconstructed_A, style_reconstructed_B,
        content_reconstructed_A, content_reconstructed_B,
        style_translated_A, style_translated_B,
        content_translated_A, content_translated_B,
    )

# Define a train step
def train_step(data, i):
    """Run one optimization step on a pair of batches and report the losses.

    Args:
        data: pair `(A, B)` holding one batch of each image class.
        i: index of the current batch; it only controls how strongly the
            reconstruction loss weighs in the decoder's objective.

    Returns:
        A dict with the means of the reconstruction loss, the cycle loss, and
        the summed KL regularization of both encoders.
    """
    # The tape is persistent because gradients are requested three times below
    with tf.GradientTape(persistent=True) as tape:
        # Split classes
        A, B = data

        # Get latent representations of the input, reconstructions, and translations
        (
        content_A, 
        content_B, 
        style_A, 
        style_B, 
        reconstructed_A,
        reconstructed_B,
        style_reconstructed_A,
        style_reconstructed_B,
        content_reconstructed_A,
        content_reconstructed_B,
        style_translated_A, 
        style_translated_B, 
        content_translated_A, 
        content_translated_B
        ) = training_call(A, B)

        # Calculate reconstruction loss (to enforce viable output)
        reconstruction_loss = -reconstructed_A.log_prob(A)
        reconstruction_loss += -reconstructed_B.log_prob(B)
        # Calculate cycle loss: squared distance between the latents of the inputs and the
        # latents recovered from the reconstructions and translations (a translation is
        # expected to keep its content and to carry the style of the other class)
        cycle_loss = tf.square(content_A-content_reconstructed_A) + tf.square(content_A-content_translated_A)
        cycle_loss += tf.square(content_B-content_reconstructed_B) + tf.square(content_B-content_translated_B)
        cycle_loss += tf.square(style_A-style_reconstructed_A) + tf.square(style_A-style_translated_B)
        cycle_loss += tf.square(style_B-style_reconstructed_B) + tf.square(style_B-style_translated_A)
        # Calculate kl-losses (presumably the KL regularizers attached to the latent distribution layers)
        kl_style = tf.add_n(encoder_style.losses)
        kl_content = tf.add_n(encoder_content.losses)
        # Aggregate losses; the decoder's reconstruction weight .96**(i/500)+5e-3 shrinks
        # towards 5e-3 as `i` grows. The reshape to (-1,1) presumably lets the
        # reconstruction term broadcast against the cycle loss — TODO confirm shapes
        loss_decoder = CYCLE_COEF*cycle_loss + tf.reshape((.96**(i/500)+5e-3)*RECONSTRUCTION_COEF*reconstruction_loss, (-1,1))
        loss_encoder_style = RECONSTRUCTION_COEF*reconstruction_loss + KL_COEF*kl_style
        loss_encoder_content = RECONSTRUCTION_COEF*reconstruction_loss + KL_COEF*kl_content

    # Calculate and apply gradients to weights of decoder
    grads_decoder = tape.gradient(loss_decoder, decoder.trainable_variables)
    decoder_optimizer.apply_gradients(zip(grads_decoder, decoder.trainable_variables))
    # Calculate and apply gradients to weights of style encoder
    grads_encoder_style = tape.gradient(loss_encoder_style, encoder_style.trainable_variables)
    encoder_style_optimizer.apply_gradients(zip(grads_encoder_style, encoder_style.trainable_variables))
    # Calculate and apply gradients to weights of content encoder
    grads_encoder_content = tape.gradient(loss_encoder_content, encoder_content.trainable_variables)
    encoder_content_optimizer.apply_gradients(zip(grads_encoder_content, encoder_content.trainable_variables))

    # Return progress
    return {
        "reconstruction_loss":tf.reduce_mean(reconstruction_loss), 
        "cycle_loss":tf.reduce_mean(cycle_loss), 
        "kl_regularization":tf.reduce_mean(kl_style+kl_content)
    }

# Convert to distributed train step
@tf.function
def distributed_train_step(data, i):
    """Run `train_step` through the distribution strategy and sum the reported losses.

    Args:
        data: one element of the distributed training dataset.
        i: index of the current batch, forwarded to `train_step`.

    Returns:
        The dict returned by `train_step`, with every entry summed via `distributor.reduce`.
    """
    replica_losses = distributor.run(train_step, args=(data, i))
    summed_losses = distributor.reduce(tf.distribute.ReduceOp.SUM, replica_losses, axis=None)
    return summed_losses

# Run batches
# Maps the batch index to the (float) losses reported for that batch
history = {}
for i, data in enumerate(train):
    # Stop training when appropriate (the training dataset itself repeats forever)
    if i==1e5:
        break
    # Train the networks
    losses = distributed_train_step(data, tf.constant(i, dtype=tf.float32))
    history[i] = {key: float(value) for key, value in losses.items()}
    # Periodically print progress
    if i%100==0:
        print(
            f"Losses at batch {i}:\n",
            *(f"{key}: {value}, " for key, value in losses.items())
        )
        # Rewrite the complete history; the context manager closes (and thereby
        # flushes) the file even when serialization fails
        with open("./history.json", mode="w") as history_file:
            json.dump(history, history_file)
    # Periodically save weights
    if (i+1)%100==0:
        decoder.save_weights("./decoder.h5")
        encoder_style.save_weights("./encoder_style.h5")
        encoder_content.save_weights("./encoder_content.h5")

## Visualization
Pre-trained weights are available for download from https://github.com/ValueInvestorThijs/LDR.

In [71]:
# Define the call function for inference (note the slight changes such as calling `mean`)
def inference_call(A, B):
    """Encode both batches deterministically and decode reconstructions and translations.

    Unlike `training_call`, every distribution is replaced by its mean.

    Args:
        A: batch of `BATCH_SIZE` images of class A.
        B: batch of `BATCH_SIZE` images of class B.

    Returns:
        `(style_A, style_B, reconstructed_A, reconstructed_B, translated_A, translated_B)`.
    """
    # Average a latent batch and tile the average back to the batch size
    # (style is averaged to enforce it to be whatever the batch has in common)
    def batch_mean(latent):
        pooled = tf.reduce_mean(latent, axis=0, keepdims=True)
        return tf.repeat(pooled, repeats=BATCH_SIZE, axis=0)

    # Latent style representations
    style_A = encoder_style(A).mean()
    mean_style_A = batch_mean(style_A)
    style_B = encoder_style(B).mean()
    mean_style_B = batch_mean(style_B)
    # Latent content representations
    content_A = encoder_content(A).mean()
    content_B = encoder_content(B).mean()

    # Reconstructions use the own style, translations the style of the other class
    reconstructed_A = decoder([mean_style_A, content_A]).mean()
    reconstructed_B = decoder([mean_style_B, content_B]).mean()
    translated_A = decoder([mean_style_B, content_A]).mean()
    translated_B = decoder([mean_style_A, content_B]).mean()

    return style_A, style_B, reconstructed_A, reconstructed_B, translated_A, translated_B

# Load model
decoder.load_weights("./decoder.h5")
encoder_style.load_weights("./encoder_style.h5")
encoder_content.load_weights("./encoder_content.h5")

# Define a helper that stores a batch of images as PNG files;
# `template` receives the running number of each image through `str.format`
def _write_pngs(images, template, offset=0):
    for index, image in enumerate(tf.unstack(images)):
        # The images are scaled by 255 before the cast to 8-bit integers
        encoded = tf.io.encode_png(tf.cast(image*255, tf.uint8))
        tf.io.write_file(template.format(offset+index), contents=encoded)

# Define a helper that loads one image file as a batch of size one
def _read_image(path):
    image = tf.io.decode_image(tf.io.read_file(path), dtype=tf.float32)
    return tf.expand_dims(image, axis=0)

# Collect the style latents of every test image (one row per image). The batches are
# gathered in lists and concatenated once: an empty (BATCH_SIZE, 0) tensor cannot be
# concatenated along axis 0 with (BATCH_SIZE, LATENT_DIM) batches.
styles_A, styles_B = [], []
for i, (A, B) in enumerate(test):
    # Get images, reconstructions, and translations
    s_A, s_B, reconstructed_A, reconstructed_B, translated_A, translated_B = inference_call(A, B)
    styles_A.append(s_A)
    styles_B.append(s_B)
    # Save images
    _write_pngs(reconstructed_A, "./samples/reconstructed_A/{}.png", offset=i*BATCH_SIZE)
    _write_pngs(reconstructed_B, "./samples/reconstructed_B/{}.png", offset=i*BATCH_SIZE)
    _write_pngs(translated_A, "./samples/translated_A/{}.png", offset=i*BATCH_SIZE)
    _write_pngs(translated_B, "./samples/translated_B/{}.png", offset=i*BATCH_SIZE)
# `style_A`/`style_B` keep the per-image latents; the code below uses separate names for
# the class means so that the Kruskal-Wallis evaluation still sees one row per image
style_A = tf.concat(styles_A, axis=0)
style_B = tf.concat(styles_B, axis=0)

# Plot models
tf.keras.utils.plot_model(model=decoder, dpi=384, to_file="./plots/DecoderArchitecture.png", show_layer_names=False);
tf.keras.utils.plot_model(model=encoder_style, dpi=384, to_file="./plots/EncoderArchitecture.png", show_layer_names=False);
tf.keras.utils.plot_model(model=encoder_style.layers[3], dpi=384, to_file="./plots/InceptionArchitecture.png", show_layer_names=False);
tf.keras.utils.plot_model(model=encoder_style.layers[2], dpi=384, to_file="./plots/DownsamplerArchitecture.png", show_layer_names=False);

# Interpolations between the mean styles of the first test batch of each class
image1 = _read_image("../../data/vangogh2photo/testA/00673.jpg")
mean_style_A = tf.reduce_mean(encoder_style(next(iter(test_A))).mean(), axis=0, keepdims=True)
mean_style_B = tf.reduce_mean(encoder_style(next(iter(test_B))).mean(), axis=0, keepdims=True)
inter1 = mean_style_A*2/3 + mean_style_B*1/3
inter2 = mean_style_A*1/3 + mean_style_B*2/3
content1 = encoder_content(image1).mean()
interpolations1 = decoder([tf.concat([mean_style_A, inter1, inter2, mean_style_B], axis=0), tf.repeat(content1, 4, axis=0)]).mean()
_write_pngs(interpolations1, "./samples/novel/interpolation1_{}.png")
image2 = _read_image("../../data/vangogh2photo/testB/2014-08-05 16_20_33.jpg")
content2 = encoder_content(image2).mean()
interpolations2 = decoder([tf.concat([mean_style_B, inter2, inter1, mean_style_A], axis=0), tf.repeat(content2, 4, axis=0)]).mean()
_write_pngs(interpolations2, "./samples/novel/interpolation2_{}.png")

# Generate entirely random images
# NOTE(review): relies on the KL regularizer exposing the prior as `distribution_b` —
# confirm for the installed TensorFlow Probability version
rand_style = encoder_style.layers[-1].activity_regularizer.distribution_b.sample(3)
rand_content = encoder_content.layers[-1].activity_regularizer.distribution_b.sample(3)
rand = decoder([rand_style, rand_content]).mean()
_write_pngs(rand, "./samples/novel/rand{}.png")
# Generate random Van Goghs
rand_content_style_A = decoder([tf.repeat(mean_style_A, repeats=3, axis=0), rand_content]).mean()
_write_pngs(rand_content_style_A, "./samples/novel/rand_content_style_A{}.png")
# Generate random content in the style of one specific image
# NOTE(review): the heading of this section mentioned Starry Night, but the path below is
# the testB photograph that is also used for `image2` — point it at the intended painting
image3 = _read_image("../../data/vangogh2photo/testB/2014-08-05 16_20_33.jpg")
style_A_specific = encoder_style(image3).mean()
rand_content_style_A_specific = decoder([tf.repeat(style_A_specific, repeats=3, axis=0), rand_content]).mean()
# Written under their own names so that the random Van Goghs above are not overwritten
_write_pngs(rand_content_style_A_specific, "./samples/novel/rand_content_style_A_specific{}.png")
# Generate random photographs
rand_content_style_B = decoder([tf.repeat(mean_style_B, repeats=3, axis=0), rand_content]).mean()
_write_pngs(rand_content_style_B, "./samples/novel/rand_content_style_B{}.png")
# Generate random style using Van Gogh content
rand_style_content_A = decoder([rand_style, tf.repeat(content1, repeats=3, axis=0)]).mean()
_write_pngs(rand_style_content_A, "./samples/novel/rand_style_content_A{}.png")
# Generate random style using photo content
rand_style_content_B = decoder([rand_style, tf.repeat(content2, repeats=3, axis=0)]).mean()
_write_pngs(rand_style_content_B, "./samples/novel/rand_style_content_B{}.png")

## Evaluation

In [3]:
import numpy as np
import scipy
# Import the sub-packages explicitly: `import scipy` alone does not guarantee that
# `scipy.stats` and `scipy.linalg` are loaded
import scipy.linalg
import scipy.stats

# Print Kruskal-Wallis statistics (every row of `style_B`/`style_A` is passed as one group)
print("Stylistic latent samples of photographs, H-statistic: ", scipy.stats.kruskal(*style_B))
print("Stylistic latent samples of Van Goghs, H-statistic: ", scipy.stats.kruskal(*style_A))

# Get inception model (without the classification head; features are average-pooled)
inception = tf.keras.applications.InceptionV3(include_top=False, pooling="avg", input_shape=INPUT_SHAPE)
# Get data: the reconstructions that were written to disk
reconstructed_A = tf.data.Dataset.from_generator(
    gen,
    args=["./samples/reconstructed_A"],
    output_signature=tf.TensorSpec(shape=INPUT_SHAPE, dtype=tf.float32),
).batch(BATCH_SIZE)
reconstructed_B = tf.data.Dataset.from_generator(
    gen,
    args=["./samples/reconstructed_B"],
    output_signature=tf.TensorSpec(shape=INPUT_SHAPE, dtype=tf.float32),
).batch(BATCH_SIZE)
# Calculate Fréchet Inception Distance
def fid(reconstructed, true):
    """Return the Fréchet Inception Distance between two image collections.

    Args:
        reconstructed: generated images, in any form accepted by `inception.predict`
            and already scaled to the value range the Inception network expects.
        true: reference images, in the same form and value range.

    Returns:
        The distance as a Python float.
    """
    # Local import so that this function does not depend on `import scipy` having loaded the sub-package
    import scipy.linalg

    # Get inception predictions (one feature vector per image)
    inner1 = np.asarray(inception.predict(reconstructed), dtype=np.float64)
    inner2 = np.asarray(inception.predict(true), dtype=np.float64)

    # Fit a Gaussian to each set of feature vectors
    mu1, sigma1 = inner1.mean(axis=0), np.cov(inner1, rowvar=False)
    mu2, sigma2 = inner2.mean(axis=0), np.cov(inner2, rowvar=False)
    # Squared distance between the means
    mse = np.sum(np.square(mu1 - mu2))
    # Matrix square root of the covariance product; numerical error can introduce
    # an imaginary component, which is discarded
    covmean = scipy.linalg.sqrtm(sigma1.dot(sigma2))
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return float(mse + np.trace(sigma1 + sigma2 - 2.0 * covmean))

# The datasets hold float images in [0, 1], whereas the pre-trained InceptionV3
# expects pixel values in [-1, 1]
def _to_inception_range(images):
    return images*2. - 1.

print(f"FID paintings: {fid(reconstructed_A.map(_to_inception_range), test_A.map(_to_inception_range))}")
print(f"FID photographs: {fid(reconstructed_B.map(_to_inception_range), test_B.map(_to_inception_range))}")