# Anomaly Detection: MNIST vs. TF Flowers
This `Jupyter Notebook` explores *anomaly detection* with an *autoencoder*: it first trains a simple *autoencoder* (the fully connected `MinNDAE` model) on `MNIST`, and then compares the *reconstruction error* on `MNIST` against the error on the TF Flowers dataset.

## Setup
First install the necessary packages (only needed when running on Google Colab).

In [None]:
# check for colab
if "google.colab" in str(get_ipython()):
  # install colab dependencies
  !pip install git+https://github.com/DiogenesAnalytics/autoencoder

## Get MNIST Data
We'll use `keras.datasets` to get the `MNIST` dataset, and then do some *normalizing* and *reshaping* to prepare it for the *autoencoder*.

In [None]:
# get necessary libs for data/preprocessing
import tensorflow as tf
from keras.datasets import mnist

# load the data
(x_train, _), (x_test, _) = mnist.load_data()

# preprocess the data (normalize)
x_train = x_train.astype("float32") / 255.
x_test = x_test.astype("float32") / 255.

# add grayscale dimension
x_train = tf.expand_dims(x_train, axis=-1)
x_test = tf.expand_dims(x_test, axis=-1)

# convert to tf datasets
train_ds = tf.data.Dataset.from_tensor_slices((x_train, x_train))
test_ds = tf.data.Dataset.from_tensor_slices((x_test, x_test))

# set a few params
BATCH_SIZE = 64
SHUFFLE_BUFFER_SIZE = 100

# update with batch/buffer size
train_ds = train_ds.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
test_ds = test_ds.batch(BATCH_SIZE)

## Get tf_flowers Data
The [TensorFlow Flowers](https://www.tensorflow.org/datasets/catalog/tf_flowers) dataset first needs to be downloaded, and then preprocessed.

In [None]:
# libs for tf flowers data
import keras
import pathlib

# data location
DATASET_URL = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"

# download, get path, and convert to pathlib obj
TF_FLOWERS_DATA_DIR = pathlib.Path(
    keras.utils.get_file("flower_photos", origin=DATASET_URL, untar=True, cache_dir="./data/keras")
)

In [None]:
# get keras image dataset util func
from keras.utils import image_dataset_from_directory

# create normalization func
def normalize(x):
    return x / 255.

# use keras util to load raw images into tensorflow.data.Dataset
anomalous_data = image_dataset_from_directory(
  TF_FLOWERS_DATA_DIR,
  labels=None,
  color_mode="grayscale",
  validation_split=None,
  shuffle=True,
  subset=None,
  seed=42,
  image_size=(28, 28),
  # 3670 = total number of images in flower_photos, so the whole dataset arrives as one batch
  batch_size=3670,
).map(normalize)

## Autoencoder Training
With both datasets prepared, the *autoencoder* can finally be built and trained.

In [None]:
# get libs for training ae
from autoencoder.model.minimal import MinNDAE, MinNDParams

# setup config
config = MinNDParams(
    l0={"input_shape": (28, 28, 1)},
    l2={"units": 32 * 1},
    l3={"units": 28 * 28 * 1},
    l4={"target_shape": (28, 28, 1)},
)

# get ae instance
autoencoder = MinNDAE(config)

# check network topology
autoencoder.summary()
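
To help read the output of `summary()`, the sketch below estimates the number of trainable parameters by hand. It assumes that `l2` and `l3` are ordinary dense layers with biases and that the remaining layers (flatten and reshape) carry no parameters; the `MinNDAE` internals are not shown in this notebook, so treat the numbers as a rough cross-check rather than the library's definition. The helper `dense_params` is hypothetical.

```python
def dense_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of a fully connected layer (assumed to use a bias)."""
    return n_in * n_out + n_out

# flattened input size for a (28, 28, 1) image
flat_size = 28 * 28 * 1

# bottleneck width, matching l2={"units": 32 * 1}
latent_size = 32 * 1

# encoder: 784 -> 32, decoder: 32 -> 784
encoder_params = dense_params(flat_size, latent_size)
decoder_params = dense_params(latent_size, flat_size)
total_params = encoder_params + decoder_params

# how strongly the bottleneck compresses each image
compression_ratio = flat_size / latent_size

print(encoder_params, decoder_params, total_params, compression_ratio)
```

If the totals reported by `summary()` differ, the model likely contains layers or options beyond what is assumed here.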

In [None]:
# get code for callbacks and custom loss function
from autoencoder.training import build_anomaly_loss_function
from keras.callbacks import EarlyStopping

# create callback
early_stop_callback = EarlyStopping(monitor="val_anomaly_diff", patience=2)

# get custom loss func
custom_loss = build_anomaly_loss_function(next(iter(anomalous_data)), autoencoder)

# compile ae
autoencoder.compile(
    optimizer="adam",
    loss=custom_loss,
    metrics=[custom_loss],
)

# begin model fit
autoencoder.fit(
    x=train_ds,
    epochs=10**2,
    validation_data=test_ds,
    callbacks=[early_stop_callback],
)
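
The `patience=2` setting in the callback can be illustrated with a simplified sketch of the stopping rule. It assumes the monitored value should decrease (whether that holds for `val_anomaly_diff` depends on the custom loss, which is not shown here), and it ignores options such as `min_delta`. The function `early_stop_epoch` and the `history` values are hypothetical.

```python
def early_stop_epoch(values, patience=2):
    """Return the 0-based epoch at which training would stop, or None.

    Simplified rule: stop once `patience` consecutive epochs fail to
    improve on the best value seen so far (lower is assumed better).
    """
    best = float("inf")
    wait = 0
    for epoch, value in enumerate(values):
        if value < best:
            # improvement: remember it and reset the counter
            best = value
            wait = 0
        else:
            # no improvement: count it and stop when patience runs out
            wait += 1
            if wait >= patience:
                return epoch
    return None

# made-up validation values: best at epoch 1, then two epochs without improvement
history = [0.50, 0.40, 0.41, 0.42, 0.30]
stopped_at = early_stop_epoch(history)
print(stopped_at)
```

With these values training halts at epoch 3 and never reaches the better value at epoch 4, which is the trade-off of a small `patience`.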

In [None]:
# view training loss
autoencoder.training_history()

## Reconstruction Error Distribution
Now let us take a peek at how well the *autoencoder* works as an *anomaly detector*, i.e. how **low** the *reconstruction error* is for the `MNIST` test data versus how **high** it is for the anomalous dataset.

In [None]:
# get custom anomaly detection class
from autoencoder.data.anomaly import AnomalyDetector

# get mnist instance
mnist_recon_error = AnomalyDetector(autoencoder, test_ds, axis=(1, 2, 3))

# calculate recon error
mnist_recon_error.calculate_error()
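
The `axis=(1, 2, 3)` argument suggests that the error is reduced over height, width, and channel, leaving one value per image. The sketch below shows that reduction using mean squared error; the exact metric used by `AnomalyDetector` is not shown in this notebook, so MSE and the helper `per_sample_mse` are assumptions for illustration.

```python
import numpy as np

def per_sample_mse(x, x_hat, axis=(1, 2, 3)):
    """Mean squared error per image, averaging over all non-batch axes."""
    return np.mean((x - x_hat) ** 2, axis=axis)

# three blank "images" and three fake reconstructions of increasing error
x = np.zeros((3, 28, 28, 1), dtype="float32")
x_hat = x.copy()
x_hat[1] += 0.5  # every pixel off by 0.5
x_hat[2] += 1.0  # every pixel off by 1.0

errors = per_sample_mse(x, x_hat)
print(errors.shape, errors)
```

Reducing over every axis except the batch axis is what makes it possible to build a histogram of errors, one entry per image.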

In [None]:
# get tf flowers instance
tfflower_recon_error = AnomalyDetector(autoencoder, anomalous_data)

# calculate recon error
tfflower_recon_error.calculate_error()

In [None]:
# turn on interactive plot
%matplotlib widget

In [None]:
# now compare recon error distributions
mnist_recon_error.histogram(
    "MNIST Anomaly Detection Using TF Flowers: MinNDAE",
    label="mnist",
    bins=[100, 100],
    additional_data=[tfflower_recon_error],
    additional_labels=["tf_flowers"],
)
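
Once the two distributions are visibly separated, a common next step is to pick a threshold on the reconstruction error. The sketch below uses synthetic error values (not results from this notebook) and sets the threshold at the 95th percentile of the normal errors; both the distributions and the percentile are assumptions chosen only to demonstrate the idea.

```python
import numpy as np

rng = np.random.default_rng(42)

# synthetic reconstruction errors: low for normal data, high for anomalies
normal_errors = np.abs(rng.normal(0.02, 0.005, size=1000))
anomalous_errors = np.abs(rng.normal(0.08, 0.02, size=500))

# flag anything above the 95th percentile of the normal errors
threshold = np.percentile(normal_errors, 95)

# fraction of normal samples wrongly flagged as anomalies
false_positive_rate = float(np.mean(normal_errors > threshold))

# fraction of anomalous samples correctly flagged
detection_rate = float(np.mean(anomalous_errors > threshold))

print(threshold, false_positive_rate, detection_rate)
```

The percentile trades false positives against missed anomalies: a higher percentile flags fewer normal samples but lets more anomalies slip through.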