### DP Divergence Estimation Tutorial


In [None]:
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

from daml.datasets import DamlDataset
from daml.metrics.divergence import HP_MST

# Fix TensorFlow's global random seed so the tf.data shuffles below are
# deterministic when the notebook is run top to bottom in a fresh kernel.
SEED = 108
tf.random.set_seed(SEED)

#### Loading in data
Let's start by loading TensorFlow's MNIST training split and displaying a few examples.
We then shuffle it and keep a random sample of 5000 images as a NumPy array.


In [None]:
# Load the MNIST training split from TensorFlow Datasets and show a grid of examples.
N_SAMPLES = 5000  # size of the random subset analysed in this tutorial

mnist_ds, ds_info = tfds.load(
    "mnist",
    split="train",
    with_info=True,
)  # type: ignore
tfds.visualization.show_examples(mnist_ds, ds_info)

# Shuffle with a buffer as large as the whole split so that `take` draws a
# uniformly random subset rather than the first N_SAMPLES records.
mnist_shuffled = mnist_ds.shuffle(mnist_ds.cardinality())
images = np.array(
    [example["image"].numpy() for example in mnist_shuffled.take(N_SAMPLES)]
)

In [None]:
# Sanity-check the sample: how many images were kept and the shape of one image.
sample_count = len(images)
first_image_shape = images[0].shape
print("Number of samples: ", sample_count)
print("Image shape:", first_image_shape)

#### Calculate initial divergence
Let's calculate the DP divergence between the first 2500 images and the second 2500 images from this sample.


In [None]:
# Split the random sample into two halves drawn from the same distribution.
# The split point is derived from the sample size instead of being hard-coded,
# so changing the number of sampled images keeps the halves consistent.
# (For an odd count, `im2` gets the extra image.)
half = len(images) // 2
im1 = images[:half]
im2 = images[half:]

In [None]:
# Estimate the divergence between the two halves with the HP_MST estimator.
# Each image is flattened to one feature row; the row count comes from the
# array itself, because a hard-coded count could silently pack several images
# into a single row if the split size ever changed.
metric = HP_MST()
div = metric.evaluate(
    DamlDataset(im1.reshape((len(im1), -1))),
    DamlDataset(im2.reshape((len(im2), -1))),
)
print(div)

We estimate that the DP divergence between these (identically distributed) image sets is, as expected, close to 0.


#### Loading in corrupted data
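Now let's load a corrupted version of MNIST: the `mnist_corrupted/translate` variant, in which the digits are translated (shifted) within the image.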


In [None]:
# Load the translated variant of MNIST-Corrupted and show a grid of examples.
corr_ds, ds_info = tfds.load(
    "mnist_corrupted/translate",
    split="train",
    with_info=True,
)  # type: ignore
tfds.visualization.show_examples(corr_ds, ds_info)

# Draw a uniformly random subset (full-size shuffle buffer) with the same number
# of images as `im1`, so the divergence below compares equally sized samples.
corr_shuffled = corr_ds.shuffle(corr_ds.cardinality())
corr_images = np.array(
    [example["image"].numpy() for example in corr_shuffled.take(len(im1))]
)

In [None]:
# Display the shape of the corrupted sample array.
corr_images.shape

#### Calculate corrupted divergence
Now let's calculate the DP divergence between this corrupted sample and the first half of the original images (`im1`).


In [None]:
# Estimate the divergence between the clean half `im1` and the translated images.
# Each image is flattened to one feature row; the row count is taken from each
# array rather than hard-coded, so a size mismatch cannot silently merge images.
div = metric.evaluate(
    DamlDataset(im1.reshape((len(im1), -1))),
    DamlDataset(corr_images.reshape((len(corr_images), -1))),
)
print(div)

We conclude that the translated MNIST images are significantly different from the original images.
