# [PG01] Unsupervised anomaly detection in industrial image data with autoencoders

In this notebook we develop the final project for the EAI course held by Christian Napoli.

## Imports & Download

In [None]:
# install the requirements
%pip install -r requirements.txt > /dev/null
# set to false if you already have the dataset
download_dataset = False 
if download_dataset:
    %cd dataset
    !bash dataset/download_dataset.sh
    %cd ..

In [None]:
import dataclasses
from src.data_module import MVTec_Dataset, MVTec_DataModule
from src.AE_simple import AE
from src.hyperparameters import Hparams
from src.train import train_model
from dataclasses import asdict
import matplotlib.pyplot as plt
import wandb
import torchvision
import pytorch_lightning as pl
import gc
# reproducibility stuff
import numpy as np
import random
import torch
np.random.seed(0)
random.seed(0)
torch.cuda.manual_seed(0)
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True  # Note that this Deterministic mode can have a performance impact
torch.backends.cudnn.benchmark = False
_ = pl.seed_everything(0)
# to have a better workflow using notebook https://stackoverflow.com/questions/5364050/reloading-submodules-in-ipython
# these commands allow to update the .py codes imported instead of re-importing everything every time.
%load_ext autoreload
%autoreload 2
%env WANDB_NOTEBOOK_NAME = ./anomaly_detection.ipynb
gc.collect()

In [None]:
# Log in to Weights & Biases to enable the online logger. It is really useful
# because it stores all the plots and the evolution of the model.
# See also https://docs.wandb.ai/guides/integrations/lightning
wandb.login()

## Dataset test

To get visual feedback and test our code, we plot some samples from the training set (non-anomalous samples only) and from the test set (both normal and anomalous).

In [None]:
# Default hyperparameters, converted from the Hparams dataclass to a plain dict.
hparams = dataclasses.asdict(Hparams())

In [None]:
# Build the data module from the hyperparameter dict defined above.
MVTec_Data = MVTec_DataModule(hparams)
# setup() takes ~2.5 minutes to run
MVTec_Data.setup()

In [None]:
# to make sure everything works we just plot a sample of our images
def plot_objects(images, 
                images_per_row, 
                border = 10, 
                pad_value = 1,
                title = 'Industrial images',
                figsize = (16,16)):
    """Show a batch of images as a single grid figure.

    Args:
        images: batch of images accepted by ``torchvision.utils.make_grid``
            (a 4D tensor or a list of same-sized image tensors).
        images_per_row: number of images displayed in each row of the grid.
        border: padding, in pixels, between neighbouring images.
        pad_value: value used to fill the padding pixels.
        title: title of the figure.
        figsize: matplotlib figure size, in inches.
    """
    grid = torchvision.utils.make_grid(
        images,
        nrow=images_per_row,
        padding=border,
        pad_value=pad_value,
    )
    # matplotlib needs a CPU tensor without autograd history, laid out as
    # (H, W, C) instead of the (C, H, W) layout returned by make_grid.
    grid = grid.detach().cpu().permute(1, 2, 0)
    plt.figure(figsize = figsize)
    plt.imshow(grid)
    plt.title(title)
    plt.axis('off')

In [None]:
# Fetch the first batch of each loader with the built-in next(): Python 3
# iterators expose __next__ rather than a .next() method, and current PyTorch
# DataLoader iterators do not provide .next() either.
batch = next(iter(MVTec_Data.train_dataloader()))
plot_objects(MVTec_DataModule.denormalize(batch["img"][0:40]), images_per_row=8, title="Industrial images from training dataset")
batch2 = next(iter(MVTec_Data.val_dataloader()))
plot_objects(MVTec_DataModule.denormalize(batch2["img"][0:40]), images_per_row=8, title="Industrial images from validation dataset")

**TODO:** compute some statistics on the pixels of the training dataset in order to do a proper normalization.

## Autoencoders

### CNN-AE l2 deterministic

In [None]:
# settings for the logger when working in a team
team_name = "eai_project"
project_name = "EAI_Anomaly_Detection"
# edit this for every new experiment
version_name = "Simple_AE_04"
run = wandb.init(entity=team_name, project=project_name, name = version_name, mode = "offline")

try:
    ae_hparams = asdict(Hparams())
    data = MVTec_DataModule(ae_hparams)
    model = AE(ae_hparams)
    # NOTE(review): the checkpoint file names used in this notebook embed
    # "avg_val_loss"; confirm that "accuracy" is the metric actually logged.
    trainer = train_model(data,model, experiment_name = version_name, patience=20, metric_to_monitor="accuracy", mode="max", epochs = 150)
finally:
    # Always close the wandb run, even when training raises or is interrupted,
    # so that the next experiment starts from a clean logger state.
    wandb.finish()

In [None]:
# Checkpoint reloaded for the qualitative inspection below.
# NOTE(review): the file name refers to "Simple_AE_01" while the training cell
# uses version_name "Simple_AE_04" — confirm this is the intended checkpoint.
best_ckpt = "models/Simple_AE_01-epoch=71-avg_val_loss=0.0278.ckpt"
# strict=False is forwarded to load_from_checkpoint; presumably it tolerates
# state-dict keys that do not match the current AE definition — TODO confirm.
model_best = AE.load_from_checkpoint(best_ckpt, strict=False)

In [None]:
# Fetch one training batch with the built-in next(): Python 3 iterators expose
# __next__ rather than a .next() method.
batch = next(iter(data.train_dataloader()))
plot_objects(MVTec_DataModule.denormalize(batch["img"][0:40]), images_per_row=8, title="Industrial images from training dataset")

# Inference only: put the model in eval mode and disable autograd, so the
# reconstruction carries no gradient history (matplotlib cannot convert a
# tensor that requires grad to a NumPy array) and no graph is kept in memory.
model_best.eval()
with torch.no_grad():
    batch_recon = model_best(batch["img"])
plot_objects(MVTec_DataModule.denormalize(batch_recon[0:40]), images_per_row=8, title="Industrial images from training dataset (reconstructed)")