In [None]:
# Import libraries and modules
import datetime
import math
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
import tensorflow as tf
# Record the numpy and TensorFlow versions, one per line, so rendered outputs
# can be tied to the environment that produced them.
print(np.__version__, tf.__version__, sep="\n")

# Make the parent of the current working directory (normally the notebook's
# directory) importable, so the local `proganomaly_modules` package below
# resolves. Only append it once, even when this cell is re-run.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

from proganomaly_modules.training_module.trainer import custom_layers
from proganomaly_modules.training_module.trainer import training_inputs

from proganomaly_modules.inference_module import image_utils
from proganomaly_modules.inference_module import gan_inference
from proganomaly_modules.inference_module import inference_inputs

## Prediction

### Set batch sizes.

In [None]:
# Rows of latent vectors in Z, and images per batch of query images.
batch_size_Z, batch_size_query_images = 8, 8

### Sample latent vectors Z.

In [None]:
# Seed TensorFlow's global RNG and give the op its own seed, so the sampled
# latent vectors are repeatable across runs of this cell.
tf.random.set_seed(1234)

# One 512-dimensional standard-normal vector per row.
Z = tf.random.normal(
    (batch_size_Z, 512),
    mean=0.0,
    stddev=1.0,
    dtype=tf.float32,
    seed=1,
)

### Load a batch of query images.

In [None]:
# Dataset to draw query images from. One of: "normal_skin", "bach_breast",
# "mnist", "cifar10_car", "cifar10", "celeba_hq".
dataset_name = "bach_breast"

# Slide read by the whole-slide-image datasets ("normal_skin" and
# "bach_breast"). NOTE(review): "{slide_name}" is a placeholder; replace it
# with a real slide name before reading those datasets.
slide_name = "{slide_name}"


def _scalar_features(*name_dtype_pairs):
    """Builds a `tf_record_example_schema` list of scalar FixedLen features.

    Args:
        *name_dtype_pairs: `(feature_name, dtype)` tuples; dtype is a string
            such as "str" or "int".

    Returns:
        One feature-spec dict (keys "name", "type", "shape", "dtype") per
        pair, in the order given.
    """
    return [
        {"name": name, "type": "FixedLen", "shape": [], "dtype": dtype}
        for name, dtype in name_dtype_pairs
    ]


def _block_idx(image_size):
    """Returns the growth block index for a power-of-two image size.

    Uses the mapping log2(size) - 2, e.g. 32 -> 3, 512 -> 7, 1024 -> 8.
    `math.log2` is used because it is documented as more accurate than
    `math.log(size, 2)`.
    """
    return int(math.log2(image_size)) - 2


def _read_query_dataset(file_pattern, size, batch_size, schema,
                        image_feature_name, image_encoding,
                        label_feature_name="",
                        use_multiple_resolution_records=False):
    """Reads query images from TFRecords with `inference_inputs.read_dataset`.

    Args:
        file_pattern: TFRecord file pattern passed to `read_dataset`.
        size: Used for both `image_predownscaled_height` and
            `image_predownscaled_width`; also determines `block_idx`.
        batch_size: Number of images per batch.
        schema: `tf_record_example_schema`, e.g. from `_scalar_features`.
        image_feature_name: Schema feature holding the encoded image.
        image_encoding: Encoding of that feature ("png", "jpeg" or "raw").
        label_feature_name: Schema feature holding the label; "" for none.
        use_multiple_resolution_records: Forwarded to `read_dataset` params.

    Returns:
        The dataset built by `read_dataset`'s input fn, limited to its first
        element with `take(1)`.
    """
    return inference_inputs.read_dataset(
        file_pattern=file_pattern,
        batch_size=batch_size,
        block_idx=_block_idx(size),
        params={
            "use_multiple_resolution_records": use_multiple_resolution_records,
            "tf_record_example_schema": schema,
            "image_feature_name": image_feature_name,
            "image_encoding": image_encoding,
            "image_predownscaled_height": size,
            "image_predownscaled_width": size,
            "image_depth": 3,
            "label_feature_name": label_feature_name,
            "input_fn_autotune": False,
            "generator_projection_dims": [4, 4, 512]
        }
    )().take(1)


def _first_batch(dataset):
    """Returns the first element of `dataset`.

    Raises:
        ValueError: If `dataset` is empty. Failing here prevents a re-run of
            this cell from silently reusing `numpy_batch` from an earlier run.
    """
    for batch in dataset:
        return batch
    raise ValueError("Query dataset yielded no batches.")


# Features stored in the whole-slide-image TFRecords.
_WSI_SCHEMA = _scalar_features(
    ("image/encoded", "str"),
    ("image/name", "str"),
    ("image/width", "int"),
    ("image/height", "int"),
)
# Features stored in the CIFAR-10 TFRecords.
_CIFAR_SCHEMA = _scalar_features(("image_raw", "str"), ("label", "int"))

if dataset_name == "normal_skin":
    size = 1024
    block_idx = _block_idx(size)
    features = _first_batch(_read_query_dataset(
        file_pattern="gs://.../TF_DIR/{0}/{0}.svs.{1}.tfrecords".format(
            slide_name, 8 - block_idx
        ),
        size=size,
        batch_size=batch_size_query_images,
        # These records additionally carry a per-image rescale factor.
        schema=_WSI_SCHEMA + _scalar_features(("image/rescale_factor", "int")),
        image_feature_name="image/encoded",
        image_encoding="png",
        use_multiple_resolution_records=True,
    ))
elif dataset_name == "bach_breast":
    size = 512
    block_idx = _block_idx(size)
    features = _first_batch(_read_query_dataset(
        file_pattern="gs://.../BACH/train/{0}/{0}_L{1}.tfrecords".format(
            slide_name, 8 - block_idx
        ),
        size=size,
        batch_size=batch_size_query_images,
        schema=_WSI_SCHEMA,
        image_feature_name="image/encoded",
        image_encoding="png",
        use_multiple_resolution_records=True,
    ))
elif dataset_name == "mnist":
    size = 32
    features, _ = _first_batch(training_inputs.mnist_dataset(
        batch_size=batch_size_query_images,
        block_idx=_block_idx(size),
        params={"input_fn_autotune": False},
        training=False
    )().take(1))
elif dataset_name == "cifar10_car":
    size = 32
    features, _ = _first_batch(_read_query_dataset(
        file_pattern="data/cifar10_car/test_{0}x{0}_0.tfrecord".format(size),
        size=size,
        batch_size=batch_size_query_images,
        schema=_CIFAR_SCHEMA,
        image_feature_name="image_raw",
        image_encoding="raw",
        label_feature_name="label",
    ))
elif dataset_name == "cifar10":
    size = 32
    features, _ = _first_batch(_read_query_dataset(
        file_pattern="gs://.../data/cifar10/test_{0}x{0}_0.tfrecord".format(
            size
        ),
        size=size,
        batch_size=batch_size_query_images,
        schema=_CIFAR_SCHEMA,
        image_feature_name="image_raw",
        image_encoding="raw",
        label_feature_name="label",
    ))
elif dataset_name == "celeba_hq":
    size = 1024
    features = _first_batch(_read_query_dataset(
        file_pattern="gs://.../data/celeba_hq/train-00000-of-00080",
        size=size,
        batch_size=batch_size_query_images,
        schema=_scalar_features(("image_raw", "str")),
        image_feature_name="image_raw",
        image_encoding="jpeg",
    ))
else:
    raise ValueError("Unknown dataset_name: {!r}".format(dataset_name))

numpy_batch = {key: tensor.numpy() for key, tensor in features.items()}
if dataset_name == "mnist":
    # NOTE(review): only the MNIST pipeline's outputs are passed through
    # `descale_images`; presumably `mnist_dataset` yields images in the
    # model's scaled range rather than pixel range — confirm.
    numpy_batch = {
        key: image_utils.descale_images(array)
        for key, array in numpy_batch.items()
    }

query_images = numpy_batch["image"]

print(query_images.shape)
image_utils.plot_images(images=query_images, depth=3, num_rows=8)

### Plot exports.

In [None]:
# Plot outputs of the selected "berg" exports for the latent vectors Z and the
# query images loaded above.
berg_output_dir = "gs://.../trained_models/experiment"
predictions_by_growth = gan_inference.plot_all_exports_by_architecture(
    Z=Z,
    query_images=query_images,
    # Derived from the path so this flag cannot disagree with output_dir.
    exports_on_gcs=berg_output_dir.startswith("gs://"),
    export_start_idx=0,
    export_end_idx=17,
    max_size=1024,
    only_output_growth_set=set(range(9)),
    num_rows=1,
    generator_architecture="berg",
    overrides=dict(
        output_dir=berg_output_dir,
        export_all_growth_phases=False,
        # Generator outputs.
        output_generated_images=True,
        output_encoded_generated_images=True,
        # Query reconstructions.
        output_query_images=True,
        output_query_encoded_images=True,
        # Anomaly maps.
        output_query_anomaly_images_sigmoid=True,
        output_query_anomaly_images_linear=True,
        output_query_mahalanobis_distances=True,
        output_query_mahalanobis_distance_images_sigmoid=True,
        output_query_mahalanobis_distance_images_linear=True,
        output_query_pixel_anomaly_flags=True,
        # Image-level scores and flags are not requested.
        output_query_anomaly_scores=False,
        output_query_anomaly_flags=False,
    ),
)

In [None]:
# Plot outputs of the selected GANomaly exports for the query images loaded
# above. No latent vectors are passed (Z=None).
ganomaly_output_dir = "gs://.../trained_models/experiment"
predictions_by_growth = gan_inference.plot_all_exports_by_architecture(
    Z=None,
    query_images=query_images,
    # Derived from output_dir so a gs:// path is never paired with
    # exports_on_gcs=False. NOTE(review): assumes the flag should be True
    # exactly when output_dir is a gs:// path — confirm against gan_inference.
    exports_on_gcs=ganomaly_output_dir.startswith("gs://"),
    export_start_idx=0,
    export_end_idx=17,
    max_size=1024,
    only_output_growth_set=set(range(9)),
    num_rows=1,
    generator_architecture="GANomaly",
    overrides={
        "output_dir": ganomaly_output_dir,

        "export_all_growth_phases": False,

        "output_query_images": True,

        "output_query_gen_encoded_images": True,
        "output_query_enc_encoded_images": True,

        "output_query_anomaly_images_sigmoid": True,
        "output_query_anomaly_images_linear": True,

        "output_query_mahalanobis_distances": True,
        "output_query_mahalanobis_distance_images_sigmoid": True,
        "output_query_mahalanobis_distance_images_linear": True,

        "output_query_pixel_anomaly_flags": True,

        "output_query_anomaly_scores": False,
        "output_query_anomaly_flags": False
    }
)