# DP Divergence Estimation Tutorial


In [None]:
try:
  import google.colab
  !pip install -q daml
except:
  pass

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

In [None]:
import numpy as np
import tensorflow_datasets as tfds

from daml.metrics.divergence import HP_MST

## Loading in data
Let's start by loading TensorFlow's MNIST dataset, drawing a random sample of 5000 training images,
and then examining it.


In [None]:
# Load the MNIST training split from TensorFlow Datasets and preview a few examples.
mnist_ds, ds_info = tfds.load(
    "mnist",
    split="train",
    with_info=True,
)
tfds.visualization.show_examples(mnist_ds, ds_info)

# Draw a uniform random sample: the shuffle buffer spans the whole split, and the
# fixed seed makes the sample (and every divergence computed from it) reproducible.
N_SAMPLES = 5000
MNIST_SHUFFLE_SEED = 0
sample_ds = mnist_ds.shuffle(mnist_ds.cardinality(), seed=MNIST_SHUFFLE_SEED).take(N_SAMPLES)

# Stack the per-example image arrays into one array with N_SAMPLES as the leading axis.
images = np.stack([example["image"] for example in tfds.as_numpy(sample_ds)])

In [None]:
# Sanity-check the sample: how many images were drawn and the shape of one image.
sample_count = len(images)
first_image_shape = images[0].shape
print("Number of samples: ", sample_count)
print("Image shape:", first_image_shape)

## Calculate initial divergence
Let's calculate the DP divergence between the first 2500 images and the second 2500 images from this sample.


In [None]:
# Split the sample into two equal, disjoint halves and flatten each image into a
# single row, giving the 2-D (n_samples, n_features) arrays passed to HP_MST below.
# The half-size is derived from the sample rather than hard-coded, so changing the
# sample size does not break the reshape. With an odd count, the last image is dropped.
half = len(images) // 2
data_a = images[:half].reshape((half, -1))
data_b = images[half : 2 * half].reshape((half, -1))

In [None]:
# Baseline: DP divergence between two disjoint halves of the same clean sample.
# Both halves come from one distribution, so this should land at or near 0.
metric = HP_MST(
    data_a=data_a,
    data_b=data_b,
)
div = metric.evaluate()
print(div)

We estimate that the DP divergence between these (identically distributed) image sets is, as expected, at or close to 0.

## Loading in corrupted data
Now let's load a corrupted version of MNIST, `mnist_corrupted/translate`, in which the digits are translated (shifted) within the image.


In [None]:
# Load the translated variant of MNIST and preview a few examples.
corrupted_ds, ds_info = tfds.load(
    "mnist_corrupted/translate",
    split="train",
    with_info=True,
)
tfds.visualization.show_examples(corrupted_ds, ds_info)

# Draw a reproducible random sample, sized to match one clean half used above.
# NOTE(review): the seed deliberately differs from the clean-sample seed. If this
# split mirrors the MNIST train order (presumably), an identical seed would pick
# the translated copies of the very same images, pairing the two samples.
N_CORRUPTED = 2500
CORRUPTED_SHUFFLE_SEED = 1
corrupted_sample_ds = corrupted_ds.shuffle(
    corrupted_ds.cardinality(), seed=CORRUPTED_SHUFFLE_SEED
).take(N_CORRUPTED)

# Stack the per-example image arrays into one array with N_CORRUPTED as the leading axis.
corrupted = np.stack([example["image"] for example in tfds.as_numpy(corrupted_sample_ds)])

In [None]:
# Sanity-check the corrupted sample: count and the shape of one image.
corrupted_count = len(corrupted)
corrupted_image_shape = corrupted[0].shape
print("Number of corrupted samples: ", corrupted_count)
print("Corrupted image shape:", corrupted_image_shape)

## Calculate corrupted divergence
Now let's calculate the DP divergence between this corrupted dataset and the original images.


In [None]:
# DP divergence between the clean reference half (data_a) and the translated images.
# A fresh estimator is built instead of reassigning `metric.data_b`. This keeps the
# baseline `metric` object unchanged and does not rely on HP_MST re-reading a
# mutated attribute. The row count comes from the data, not a hard-coded 2500.
corrupted_metric = HP_MST(
    data_a=data_a,
    data_b=corrupted.reshape((len(corrupted), -1)),
)
div = corrupted_metric.evaluate()
print(div)

Compared with the near-zero baseline above, the divergence is now large. We conclude that the translated MNIST images are distributed significantly differently from the original images.
