# [PG01] Unsupervised anomaly detection in industrial image data with autoencoders

> In this notebook we develop the final project for the *EAI course* held by Christian Napoli. The *dataset* is the well-known **MVTec AD**, described in the paper referenced in our report. For this reason we won't spend much time replicating the *analysis* and *statistics* that can be found in the original article.

## Imports & Download

In [None]:
# install the requirements
%pip install -r requirements.txt > /dev/null
# set to false if you already have the dataset
download_dataset = False 
if download_dataset:
    %cd dataset
    # we are already inside dataset/ after the %cd above, so the script is called without the folder prefix
    !bash download_dataset.sh
    %cd ..

In [None]:
import dataclasses
from src.data_module import MVTec_Dataset, MVTec_DataModule
from src.AE_simple import AE
from src.AE_CODE import CODE_AE
from src.hyperparameters import Hparams
from src.train import train_model
from dataclasses import asdict
import matplotlib.pyplot as plt
import wandb
import torchvision
import pytorch_lightning as pl
import gc
# reproducibility stuff
import numpy as np
import random
import torch
np.random.seed(0)
random.seed(0)
torch.cuda.manual_seed(0)
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True  # Note that this Deterministic mode can have a performance impact
torch.backends.cudnn.benchmark = False
_ = pl.seed_everything(0)
# for a smoother notebook workflow, see https://stackoverflow.com/questions/5364050/reloading-submodules-in-ipython
# these commands reload the imported .py modules automatically, so they don't have to be re-imported after every edit.
%load_ext autoreload
%autoreload 2
%env WANDB_NOTEBOOK_NAME = ./anomaly_detection.ipynb
gc.collect()

In [None]:
# login wandb to have the online logger. It is really useful since it stores all the plots and evolution of the model
# check also https://docs.wandb.ai/guides/integrations/lightning
wandb.login()

## Utilities

In [None]:
# to make sure everything works we just plot a sample of our images
def plot_objects(images, 
                images_per_row, 
                border = 10, 
                pad_value = 1,
                title = 'Industrial images',
                figsize = (16,16)):
    plt.figure(figsize = figsize)
    plt.imshow(torchvision.utils.make_grid(images,images_per_row,border,pad_value=pad_value).permute(1, 2, 0))
    plt.title(title)
    plt.axis('off')

# todo evaluate performance on the different input classes to understand which is performing better

## Dataset

> To get some visual feedback and test our code, we plot some samples from the **train** set (only *normal* samples) and the **test** set (*normal* and *anomalous*).

In [None]:
hparams = asdict(Hparams())

In [None]:
MVTec_Data = MVTec_DataModule(hparams)
# setup takes ~3 minutes
MVTec_Data.setup()
print(len(MVTec_Data.data_train)) # -->  3629 images
print(len(MVTec_Data.data_test)) # -->  1258+467=1725 images
print("TOTAL: "+str(len(MVTec_Data.data_train)+len(MVTec_Data.data_test))+" industrial images")

In [None]:
# next(iter(dataloader)) works everywhere; iter(dataloader).next() only works
# on older PyTorch releases, so we avoid it
batch = next(iter(MVTec_Data.train_dataloader()))
plot_objects(MVTec_DataModule.denormalize(batch["img"][0:40]), images_per_row=8, title="Industrial images from training dataset")
batch2 = next(iter(MVTec_Data.val_dataloader()))
plot_objects(MVTec_DataModule.denormalize(batch2["img"][0:40]), images_per_row=8, title="Industrial images from validation dataset")

> ⚡ During our implementation we also tried an additional data extraction strategy to make ***data.setup()*** more efficient. <br> At first we thought the slowness was caused by the many folder accesses, so we modified the dataset folder structure. <br> Unfortunately *NO IMPROVEMENTS* were achieved: the inefficiency actually came from the *image transformations*!
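
One way to check where the time goes is to time each stage of the setup loop separately. The sketch below is only an illustration under assumptions: `load_image` and `transform_image` are hypothetical stand-ins invented here (they are not functions from `src.data_module`), and the `sleep` calls merely simulate work.

```python
import time

def load_image(path):
    # Hypothetical stand-in for reading an image file from disk.
    time.sleep(0.001)
    return [0.0] * 16

def transform_image(img):
    # Hypothetical stand-in for the resize/normalize transformations.
    time.sleep(0.005)
    return [v * 2.0 for v in img]

def profile_setup(paths):
    """Accumulate the wall-clock time spent in each stage of a setup loop."""
    timings = {"load": 0.0, "transform": 0.0}
    for path in paths:
        start = time.perf_counter()
        img = load_image(path)
        timings["load"] += time.perf_counter() - start

        start = time.perf_counter()
        transform_image(img)
        timings["transform"] += time.perf_counter() - start
    return timings

timings = profile_setup([f"img_{i}.png" for i in range(20)])
slowest = max(timings, key=timings.get)
print(timings, "-> slowest stage:", slowest)
```

If the transformations dominate, as they did for us, reorganizing folders cannot help; caching the transformed tensors would be the more promising direction.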

## Autoencoders - **AE**

### Baseline - *CNN AE*

In [None]:
# settings for the logger working in a team
team_name = "eai_project"
project_name = "EAI_Anomaly_Detection"
version_name = "baseline_1"
run = wandb.init(entity=team_name, project=project_name, name = version_name, mode = "offline")

ae_hparams = asdict(Hparams())
data = MVTec_DataModule(ae_hparams)
model = AE(ae_hparams)
trainer = train_model(data, model, ae_hparams["batch_size"], experiment_name = version_name, patience=20, metric_to_monitor="auroc", mode="max", epochs = 100)

wandb.finish()
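
The baseline run above is monitored on `auroc`. As a reminder of what that metric measures, here is a small pure-Python sketch of AUROC in its pairwise-ranking form: the probability that a randomly chosen anomalous sample gets a higher anomaly score than a randomly chosen normal one. This is not the implementation used inside `train_model` (which is not shown in this notebook), and the scores and labels are made up.

```python
def auroc(scores, labels):
    """AUROC as the fraction of (anomalous, normal) pairs ranked correctly.

    labels: 1 = anomalous, 0 = normal. Ties count as half a correct pair.
    """
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one sample of each class")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))

# Made-up anomaly scores (e.g. reconstruction errors) and ground-truth labels.
scores = [0.02, 0.05, 0.04, 0.03, 0.30]
labels = [0, 0, 1, 0, 1]
value = auroc(scores, labels)
print(value)  # 5 of the 6 (anomalous, normal) pairs are ranked correctly
```

Because it only depends on the ranking of the scores, AUROC does not require choosing a decision threshold, which makes it convenient as an early-stopping metric.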

### CNN Advanced AE - ***CO**ntractive + **DE**noising*

In [None]:
# settings for the logger working in a team
team_name = "eai_project"
project_name = "EAI_Anomaly_Detection"
# to edit 
version_name = "advanced_AE_2"
run = wandb.init(entity=team_name, project=project_name, name = version_name, mode = "offline")

ae_hparams = asdict(Hparams())
data = MVTec_DataModule(ae_hparams)
model = CODE_AE(ae_hparams)
trainer = train_model(data,model, experiment_name = version_name, patience=20, metric_to_monitor="f1_score", mode="max", epochs = 150)

wandb.finish()

### Test and analysis

In [None]:
# best_ckpt = "models/Simple_AE_01-epoch=71-avg_val_loss=0.0278.ckpt"
# model = AE.load_from_checkpoint(best_ckpt, strict=False)

In [None]:
batch = next(iter(data.train_dataloader()))
plot_objects(MVTec_DataModule.denormalize(batch["img"][0:40]), images_per_row=8, title="Industrial images from training dataset")

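# no gradients are needed for plotting, and matplotlib cannot convert tensors that require grad
with torch.no_grad():
    batch_recon = model(batch["img"])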
plot_objects(MVTec_DataModule.denormalize(batch_recon[0:40]), images_per_row=8, title="Industrial images from training dataset (reconstructed)")
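
A common way to turn reconstructions like the ones above into an anomaly decision is to use the per-image reconstruction error as the anomaly score. The sketch below assumes that approach; whether `AE` and `CODE_AE` compute their scores exactly this way is not shown in this notebook. The tensors are random stand-ins and the threshold is hypothetical.

```python
import torch

def anomaly_scores(images, reconstructions):
    """Mean squared reconstruction error per image, shape (batch,)."""
    return ((images - reconstructions) ** 2).flatten(start_dim=1).mean(dim=1)

torch.manual_seed(0)
images = torch.rand(4, 3, 8, 8)   # stand-in batch of 4 RGB images
recon = images.clone()
recon[2:] += 0.5                  # pretend the last two images were poorly reconstructed

scores = anomaly_scores(images, recon)
threshold = 0.1                   # hypothetical value; it would be tuned on validation data
is_anomalous = scores > threshold
print(scores, is_anomalous)
```

With a trained model, `recon` would come from `model(batch["img"])` computed under `torch.no_grad()`.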