# [PG01] Unsupervised anomaly detection in industrial image data with autoencoders

> In this notebook we develop the final project for the *EAI course* held by Christian Napoli. The *dataset* is the well-known **MVTec AD**, described in the paper referenced in our report. For this reason we won't spend much time replicating the *analysis* and *statistics* that can be found in the original article.

## Imports & Download

In [None]:
# install the requirements
%pip install -r requirements.txt > /dev/null
# set to false if you already have the dataset
download_dataset = False 
if download_dataset:
    %cd dataset
    !bash dataset/download_dataset.sh
    %cd ..

In [None]:
import dataclasses
from src.data_module import MVTec_Dataset, MVTec_DataModule
from src.AE_simple import AE
from src.AE_CODE import CODE_AE
from src.AE_mixer import Mixer_AE
from src.hyperparameters import Hparams
from src.train import train_model
from dataclasses import asdict
import matplotlib.pyplot as plt
import wandb
import torchvision
import pytorch_lightning as pl
import gc
from collections import Counter
import seaborn as sns
# reproducibility stuff
import numpy as np
import random
import torch
# Fix the seed (0) of NumPy, Python's `random`, PyTorch (CPU and CUDA) and
# Lightning so that runs are reproducible.
np.random.seed(0)
random.seed(0)
torch.cuda.manual_seed(0)
torch.manual_seed(0)
# make cuDNN deterministic and disable its benchmark (autotuning) mode
torch.backends.cudnn.deterministic = True  # Note that this Deterministic mode can have a performance impact
torch.backends.cudnn.benchmark = False
_ = pl.seed_everything(0)  # the returned seed is not needed
# to have a better workflow using notebooks https://stackoverflow.com/questions/5364050/reloading-submodules-in-ipython
# these commands reload the imported .py modules (src/...) when they change, instead of re-importing everything every time.
%load_ext autoreload
%autoreload 2
# tell wandb the name of this notebook
%env WANDB_NOTEBOOK_NAME = ./anomaly_detection.ipynb
# release unreferenced objects
gc.collect()

In [None]:
# log in to Weights & Biases, used as the online logger: it stores all the plots
# and the evolution of the model during training.
# check also https://docs.wandb.ai/guides/integrations/lightning
wandb.login()

## Utilities

In [None]:
# to make sure everything works we just plot a sample of our images
def plot_objects(images,
                 images_per_row,
                 border=10,
                 pad_value=1,
                 title='Industrial images',
                 figsize=(16, 16)):
    """Display a batch of images as a single grid in a new matplotlib figure.

    Args:
        images: images accepted by ``torchvision.utils.make_grid`` (a tensor
            or a list of tensors).
        images_per_row: number of images per grid row (``nrow``).
        border: padding, in pixels, between neighbouring images.
        pad_value: value used to fill the padding.
        title: figure title.
        figsize: matplotlib figure size in inches.
    """
    grid = torchvision.utils.make_grid(images, nrow=images_per_row,
                                       padding=border, pad_value=pad_value)
    # matplotlib converts the tensor to a NumPy array, which fails for tensors
    # that still track gradients (e.g. reconstructions computed outside
    # torch.no_grad()) or live on the GPU, so detach and move to CPU first.
    grid = grid.detach().cpu()
    plt.figure(figsize=figsize)
    plt.imshow(grid.permute(1, 2, 0))  # CHW -> HWC, as expected by imshow
    plt.title(title)
    plt.axis('off')

from tqdm import tqdm
# Metrics used to compare our models with the results of the 2019 MVTec AD paper.
def _safe_div(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is 0."""
    return numerator / denominator if denominator else 0.0


def performance_evaluation(model, dataset):
    """Print per-class anomaly-detection metrics and plot the overall confusion matrix.

    POSITIVES are anomalies and NEGATIVES are anomaly-free (normal) instances.
    For each object class the report contains the two accuracies used in the
    2019 MVTec AD paper, plus the F1 score:

    - anomaly-free predicted correctly: TN / (TN + FP)
    - anomalies predicted correctly (i.e. the recall): TP / (TP + FN)
    - f1_score: harmonic mean of precision TP / (TP + FP) and recall

    A metric whose denominator is zero for a class (e.g. no predicted
    anomalies) is reported as 0 instead of raising ZeroDivisionError.

    Args:
        model: exposes ``eval()`` and ``anomaly_prediction(images)``; a
            prediction > 0 counts as "anomaly", == 0 as "normal". The model is
            switched to eval mode and left there.
        dataset: exposes ``id2c`` (class id -> class name) and
            ``val_dataloader()``; every batch is a dict with "img", "label"
            (> 0 for anomalies, 0 for normal samples) and "class_obj"
            (class ids) entries.

    Returns:
        numpy.ndarray: the 2x2 confusion matrix summed over all classes, laid
        out as [[TN, FN], [FP, TP]] (rows: predicted normal / predicted
        anomaly, columns: true normal / true anomaly).
    """
    model.eval()
    id2c = dataset.id2c  # class id -> class name
    # One Counter per quantity, keyed by class name: metrics are per class.
    total_anom = Counter()  # TP + FN
    total_norm = Counter()  # TN + FP
    total_predicted_anom = Counter()  # TP + FP
    true_anom = Counter()  # TP
    true_norm = Counter()  # TN

    def _names(class_ids, mask):
        """Class names of the samples selected by the boolean ``mask``."""
        return [id2c[i] for i in class_ids[mask].tolist()]

    with torch.no_grad():
        for batch in tqdm(dataset.val_dataloader()):
            pred = model.anomaly_prediction(batch["img"])
            class_ids = batch["class_obj"]
            # Labels and predictions share the same convention (> 0 anomaly,
            # == 0 normal) so TP/TN always agree with the totals above.
            is_anomaly = batch["label"] > 0
            is_normal = batch["label"] == 0
            pred_anomaly = pred > 0
            pred_normal = pred == 0
            total_anom.update(_names(class_ids, is_anomaly))
            total_norm.update(_names(class_ids, is_normal))
            total_predicted_anom.update(_names(class_ids, pred_anomaly))
            true_anom.update(_names(class_ids, is_anomaly & pred_anomaly))
            true_norm.update(_names(class_ids, is_normal & pred_normal))

    tot_conf_matrix = np.zeros((2, 2), dtype=int)
    # Report every class present in the validation set, including classes
    # that happen to have no anomalous samples.
    for k in sorted(set(total_anom) | set(total_norm)):
        tp = true_anom[k]
        tn = true_norm[k]
        fn = total_anom[k] - tp
        fp = total_norm[k] - tn
        precision = _safe_div(tp, total_predicted_anom[k])
        recall = _safe_div(tp, total_anom[k])
        f1_score = _safe_div(2 * precision * recall, precision + recall)

        print(f"[{k.upper()}]")
        print(f"Total number: {total_anom[k] + total_norm[k]}")
        print(f"anomaly-free predicted correctly: {_safe_div(tn, total_norm[k]):.3f}")
        print(f"anomalies predicted correctly: {recall:.3f}")
        print(f"f1_score: {f1_score:.3f}")
        print("-------------------------------------------------")
        tot_conf_matrix += np.array([[tn, fn], [fp, tp]])

    y_axis_labels = ["predicted normal", "predicted anomaly"]
    x_axis_labels = ["true normal", "true anomaly"]
    # Guard against an empty validation set when normalising to percentages.
    normalized = tot_conf_matrix / max(tot_conf_matrix.sum(), 1)
    sns.heatmap(normalized, xticklabels=x_axis_labels, yticklabels=y_axis_labels,
                annot=True, fmt='.2%', cmap='Greens')
    return tot_conf_matrix
            
            

## Dataset

> To get visual feedback and test our code, we plot some samples from the **train** set (only *normal* samples) and from the **test** set (*normal* and *anomalous*).

In [None]:
# default hyperparameters as a plain dict (passed to the data module below)
hparams = asdict(Hparams())

In [None]:
MVTec_Data = MVTec_DataModule(hparams)
# setup takes ~3 minutes (the time is spent in the image transformations)
MVTec_Data.setup()
print(len(MVTec_Data.data_train)) # -->  3629 images
print(len(MVTec_Data.data_test)) # -->  1258+467=1725 images
print("TOTAL: "+str(len(MVTec_Data.data_train)+len(MVTec_Data.data_test))+" industrial images")

In [None]:
# Sanity check: plot the first 40 images of one training batch (only normal
# samples) and of one validation batch (normal and anomalous samples).
# next(iter(dataloader)) is the portable form; iter(dataloader).next() is not
# available on every DataLoader iterator version.
batch = next(iter(MVTec_Data.train_dataloader()))
plot_objects(MVTec_DataModule.denormalize(batch["img"][0:40]), images_per_row=8, title="Industrial images from training dataset")
batch2 = next(iter(MVTec_Data.val_dataloader()))
plot_objects(MVTec_DataModule.denormalize(batch2["img"][0:40]), images_per_row=8, title="Industrial images from validation dataset")

> ⚡ During our implementation we also tried an alternative data extraction strategy to make ***data.setup()*** more efficient. <br> At first we thought the slowness of the operation was caused by the many folder accesses, so we modified the dataset folder structure. <br> Unfortunately *NO IMPROVEMENTS* were achieved: the inefficiency actually came from the *image transformations*!

## Autoencoders - **AE**

### Baseline - *CNN AE*

In [None]:
# settings for the logger working in a team
mixer = False # True only for models whose forward returns a tuple (read by the reconstruction cell)
team_name = "eai_project"
project_name = "EAI_Anomaly_Detection"
version_name = "baseline"
run = wandb.init(entity=team_name, project=project_name, name = version_name, mode = "online")

# Train the baseline CNN AE. patience / metric_to_monitor / mode are forwarded to
# train_model (presumably early stopping on the validation f1_score, higher is
# better — confirm in src/train.py); at most 100 epochs.
ae_hparams = asdict(Hparams())
data = MVTec_DataModule(ae_hparams)
model = AE(ae_hparams)
trainer = train_model(data, model, experiment_name = version_name, \
    patience=20, metric_to_monitor="f1_score", mode="max", epochs = 100)

# close the wandb run opened above
wandb.finish()

### CNN Advanced AE - ***CO**ntractive + **DE**noising*

In [None]:
# settings for the logger working in a team
mixer = False # CODE_AE's forward returns only the reconstruction (read by the reconstruction cell)
team_name = "eai_project"
project_name = "EAI_Anomaly_Detection"
# change this to log a new run under a different name
version_name = "CODE_AE"
run = wandb.init(entity=team_name, project=project_name, name = version_name, mode = "online")

# Train the contractive + denoising AE. patience / metric_to_monitor / mode are
# forwarded to train_model (presumably early stopping on the validation
# f1_score — confirm in src/train.py); patience 30, at most 100 epochs.
ae_hparams = asdict(Hparams())
data = MVTec_DataModule(ae_hparams)
model = CODE_AE(ae_hparams)
trainer = train_model(data,model, experiment_name = version_name, \
    patience=30, metric_to_monitor="f1_score", mode="max", epochs = 100)

# close the wandb run opened above
wandb.finish()

### Mixer AE

In [None]:
# settings for the logger working in a team
mixer = True # Mixer_AE's forward returns a tuple; the reconstruction cell unpacks `batch_recon, _`
team_name = "eai_project"
project_name = "EAI_Anomaly_Detection"
# change this to log a new run under a different name
version_name = "mixer_AE"
run = wandb.init(entity=team_name, project=project_name, name = version_name, mode = "online")

# Train the Mixer AE. patience / metric_to_monitor / mode are forwarded to
# train_model (presumably early stopping on the validation f1_score — confirm
# in src/train.py); patience 20, at most 150 epochs.
ae_hparams = asdict(Hparams())
data = MVTec_DataModule(ae_hparams)
model = Mixer_AE(ae_hparams)
trainer = train_model(data,model, experiment_name = version_name, \
    patience=20, metric_to_monitor="f1_score", mode="max", epochs = 150)

# close the wandb run opened above
wandb.finish()

### Test and analysis

In [None]:
# Restore a trained model from disk instead of (re)training it.
load_ckpt = True
if load_ckpt:
    best_ckpt = "models/CODE_AE-epoch=05-f1_score=0.3834.ckpt"
    # This checkpoint comes from the CODE_AE run, so restore it with the CODE_AE
    # class: loading it into AE with strict=False would silently skip any weight
    # that does not match AE's modules.
    # NOTE(review): strict=False is kept from the original; consider dropping it
    # so that architecture mismatches raise instead of being ignored.
    model = CODE_AE.load_from_checkpoint(best_ckpt, strict=False)
    # keep the `mixer` flag in sync with the loaded model (the CODE_AE training
    # cell uses mixer = False); the reconstruction cell reads it to decide how
    # to unpack the forward output
    mixer = False

In [None]:
# Reconstruct one training batch and compare it with the input images.
# next(iter(...)) is portable; iter(...).next() is not available on every
# DataLoader iterator version.
batch = next(iter(data.train_dataloader()))
plot_objects(MVTec_DataModule.denormalize(batch["img"][0:40]), images_per_row=8, title="Industrial images from training dataset")

# eval mode avoids training-only behaviour during this inspection (e.g. dropout
# or batch-norm statistic updates, if the model has them); no_grad keeps the
# output free of autograd history so that it can be converted and plotted.
model.eval()
with torch.no_grad():
    if mixer:
        # mixer models return a tuple whose first element is the reconstruction
        batch_recon, _ = model(batch["img"])
    else:
        batch_recon = model(batch["img"])
plot_objects(MVTec_DataModule.denormalize(batch_recon[0:40]), images_per_row=8, title="Industrial images from training dataset (reconstructed)")

In [None]:
# if we want to test without training we need to set up the data
# (fresh default hyperparameters and a data module with its splits built)
ae_hparams = asdict(Hparams())
data = MVTec_DataModule(ae_hparams)
data.setup()

In [None]:
# performance evaluation on the validation split: per-class report + confusion matrix
performance_evaluation(model, data)