<a href="https://colab.research.google.com/github/YB-Sung/DXIC_Lab_13_DL4_Auto-Encoder/blob/main/Auto_Encoder_for_Image_Anomaly_Detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# [ LG전자 H&A DX Intensive Course - Auto-Encoder for Anomaly Detection ]

Auto-Encoder를 활용한 image anomaly detection

# Import modules

In [None]:
from glob import glob
import os
import random
import numpy as np
import cv2
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, Adam
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from sklearn.metrics import roc_auc_score

import matplotlib.pyplot as plt
import seaborn as sns

# Hide the right and top axes spines on every figure drawn afterwards.
custom_params = dict.fromkeys(("axes.spines.right", "axes.spines.top"), False)
sns.set_theme(rc=custom_params, style="ticks")

# Functions

In [None]:
def torch_seed(random_seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic mode.

    Args:
        random_seed: integer seed applied to all generators.
    """
    # Python-level generators
    os.environ['PYTHONHASHSEED'] = str(random_seed)
    random.seed(random_seed)
    np.random.seed(random_seed)

    # PyTorch generators (the *_all variant covers multi-GPU setups)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)

    # cuDNN: prefer deterministic kernels over auto-tuned ones
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def train(
    model, dataloader, criterion, optimizer, log_interval: int, device: str) -> float:
    """Train ``model`` for one pass over ``dataloader`` and return the mean loss.

    Args:
        model: module mapping a batch of inputs to its reconstruction.
        dataloader: iterable yielding 3-tuples; only the first element
            (the input batch) is used here.
        criterion: callable ``criterion(inputs, outputs)`` returning a loss
            tensor; its mean is the quantity that is minimised.
        optimizer: optimizer holding the parameters of ``model``.
        log_interval: print the running mean loss every this many batches
            (the last batch is always logged).
        device: device the input batches are moved to.

    Returns:
        Mean of the per-batch losses over the whole pass.
    """
    total_loss = []

    model.train()
    for i, (inputs, _, _) in enumerate(dataloader):

        # convert device
        inputs = inputs.to(device)

        # model outputs
        outputs = model(inputs)

        # loss
        loss = criterion(inputs, outputs).mean()
        total_loss.append(loss.item())

        # clear gradients *before* backward so that gradients left on the
        # parameters by earlier code cannot leak into this update
        optimizer.zero_grad()

        # calculate gradients
        loss.backward()

        # update model weights
        optimizer.step()

        # log learning history
        if i % log_interval == 0 or (i+1) == len(dataloader):
            print(f"{'TRAIN':5s} [{i+1:5d}/{len(dataloader):5d}] loss: {np.mean(total_loss):.4f}")

    # average loss
    avg_loss = np.mean(total_loss)

    return avg_loss

def test(
    model, dataloader, criterion, log_interval: int, device: str) -> tuple:
    """Score ``dataloader`` with ``model`` and return image- and pixel-level AUROC.

    For each batch the element-wise loss between inputs and reconstructions is
    reduced in two ways: the maximum over all non-batch dimensions gives one
    anomaly score per sample, and the maximum over dimension 1 gives one score
    per remaining position, which is compared with the ground-truth masks.

    Args:
        model: module mapping a batch of inputs to its reconstruction.
        dataloader: iterable yielding ``(inputs, masks, targets)``; ``masks``
            must support ``.numpy()`` and ``targets`` must support ``.tolist()``.
        criterion: callable ``criterion(inputs, outputs)`` returning an
            unreduced (element-wise) loss tensor.
        log_interval: print the running mean image score every this many
            batches (the last batch is always logged).
        device: device the input batches are moved to.

    Returns:
        ``(auroc_img, auroc_pixel)`` as computed by ``roc_auc_score``.
    """
    # for image-level auroc
    total_loss_img = []
    total_targets = []
    # for pixel-level auroc
    total_masks = []
    total_loss_pixel = []

    # NOTE(review): this reseeds the global RNGs on every call, so random
    # draws made after each evaluation restart from the same state —
    # confirm that this is intended before changing it.
    torch_seed(223)
    model.eval()

    with torch.no_grad():
        for i, (inputs, masks, targets) in enumerate(dataloader):
            # get masks
            total_masks.append(masks.numpy())

            # get targets
            total_targets.extend(targets.tolist())

            # convert device
            inputs = inputs.to(device)

            # model outputs
            outputs = model(inputs)

            # element-wise loss
            loss = criterion(inputs, outputs)

            # one score per sample: maximum over every non-batch dimension
            total_loss_img.extend(loss.flatten(start_dim=1).max(dim=1)[0].cpu().tolist())
            # one score per position: maximum over dimension 1
            total_loss_pixel.append(loss.max(dim=1)[0].cpu().numpy())

            # log learning history
            if i % log_interval == 0 or (i+1) == len(dataloader):
                print(f"{'TEST':5s} [{i+1:5d}/{len(dataloader):5d}] loss: {np.mean(total_loss_img):.4f}")

    # image-level auroc
    auroc_img = roc_auc_score(total_targets, total_loss_img)

    # pixel-level auroc: stack all batches and flatten to one long vector
    total_loss_pixel = np.concatenate(total_loss_pixel, axis=0).reshape(-1)
    total_masks = np.concatenate(total_masks, axis=0).reshape(-1)
    auroc_pixel = roc_auc_score(total_masks, total_loss_pixel)

    # return
    return auroc_img, auroc_pixel


def fit(
    model, trainloader, testloader, criterion, optimizer,
    epochs: int, log_interval: int, device: str) -> tuple:
    """Train and evaluate ``model`` for ``epochs`` epochs, plotting reconstructions.

    After every epoch one test image per category is reconstructed and shown
    next to its input.

    Args:
        model: auto-encoder to optimise.
        trainloader: loader passed to ``train``.
        testloader: loader passed to ``test``; its ``dataset`` must expose
            ``category`` and ``file_list`` and be indexable.
        criterion: unreduced loss callable shared by ``train`` and ``test``.
        optimizer: optimizer holding the parameters of ``model``.
        epochs: number of train/test rounds.
        log_interval: logging period forwarded to ``train`` and ``test``.
        device: device used for training and inference.

    Returns:
        ``(train_history, test_history_auroc_img, test_history_auroc_pixel)``,
        three lists with one entry per epoch.
    """
    train_history = []
    test_history_auroc_img = []
    test_history_auroc_pixel = []

    # use the dataset behind the loader rather than a notebook-level global
    dataset = testloader.dataset

    # fitting model
    for epoch in range(epochs):
        print(f'\nEpoch: [{epoch+1}/{epochs}]')
        train_loss = train(
            model        = model,
            dataloader   = trainloader,
            criterion    = criterion,
            optimizer    = optimizer,
            log_interval = log_interval,
            device       = device
        )

        test_auroc_img, test_auroc_pixel = test(
            model        = model,
            dataloader   = testloader,
            criterion    = criterion,
            log_interval = log_interval,
            device       = device
        )

        print(f'\nTest AUROC-image: {test_auroc_img:.4f}, AUROC-pixel: {test_auroc_pixel:.4f}')

        # show results
        with torch.no_grad():
            test_category = dataset.category
            # squeeze=False keeps `ax` two-dimensional even for one category
            fig, ax = plt.subplots(
                2, len(test_category), figsize=(2*len(test_category), 5), squeeze=False)

            # category of a file = name of its parent directory
            file_list_cat = [
                os.path.basename(os.path.dirname(path)) for path in dataset.file_list]

            for col, c in enumerate(test_category):
                # select the first image of this category
                idx = np.where(np.array(file_list_cat) == c)[0][0]
                img, mask, _ = dataset[idx]

                # inference
                output = model(img.unsqueeze(0).to(device)).cpu()[0]

                # show image
                ax[0, col].imshow(img.permute(1,2,0))
                ax[1, col].imshow(output.permute(1,2,0))

                # axis off
                ax[0, col].axis('off')
                ax[1, col].axis('off')

                # set title
                ax[0, col].set_title(f"{c}\nimage")
                ax[1, col].set_title(f"{c}\nreconstruction")
            plt.tight_layout()
            plt.show()

        # stack history
        train_history.append(train_loss)
        test_history_auroc_img.append(test_auroc_img)
        test_history_auroc_pixel.append(test_auroc_pixel)

    return train_history, test_history_auroc_img, test_history_auroc_pixel


def figure(
    all_train_history: list, all_test_history_auroc_img: list,
    all_test_history_auroc_pixel: list, all_exp_name: list) -> None:
    """Plot training loss and test AUROC curves of several experiments.

    Draws three side-by-side panels (train loss, image-level AUROC,
    pixel-level AUROC) with one line per experiment.

    Args:
        all_train_history: one list of per-epoch train losses per experiment.
        all_test_history_auroc_img: one list of per-epoch image-level AUROC
            values per experiment.
        all_test_history_auroc_pixel: one list of per-epoch pixel-level AUROC
            values per experiment.
        all_exp_name: legend label of each experiment, in the same order.
    """
    fig, ax = plt.subplots(1, 3, figsize=(15,5))

    # (histories, y-axis label, title) of each panel, left to right
    panels = [
        (all_train_history, 'MSE Loss', 'Train loss history'),
        (all_test_history_auroc_img, 'AUROC(image-level)', 'Test AUROC(image-level) history'),
        (all_test_history_auroc_pixel, 'AUROC(pixel-level)', 'Test AUROC(pixel-level) history'),
    ]

    for axis, (histories, ylabel, title) in zip(ax, panels):
        for history, exp_name in zip(histories, all_exp_name):
            sns.lineplot(
                x     = range(1, len(history)+1),
                y     = history,
                label = exp_name,
                ax    = axis
            )
        axis.set_ylabel(ylabel)
        axis.set_xlabel('Epochs')
        axis.set_title(title)

    # set y value limit; computed per history so that experiments trained
    # for different numbers of epochs (ragged lists) are supported
    max_train = max((max(h) for h in all_train_history if len(h) > 0), default=0.0)

    ax[0].set_ylim(0, max_train+0.01)
    ax[1].set_ylim(0, 1)
    ax[2].set_ylim(0, 1)

    # set legend
    for axis in ax:
        axis.legend(loc='upper left')
    plt.tight_layout()
    plt.show()

# Configuration for experiments

In [None]:
class Config:
    """Experiment settings shared by every cell of this notebook."""

    # dataset parameters
    datadir = './data'
    target = 'bottle'
    image_size = [224, 224]

    # training parameters
    epochs = 20
    batch_size = 8
    test_batch_size = 128
    learning_rate = 0.01
    num_workers = 2
    log_interval = 2000

    # device: use the GPU when one is available, otherwise fall back to the
    # CPU so the notebook still runs on machines without CUDA
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # seed
    seed = 223

cfg = Config()

# Load dataset and dataloader

**Download data**
- MVTec AD [ [link](https://www.mvtec.com/company/research/datasets/mvtec-ad) ]

In [None]:
!wget -P './data' 'https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/420937370-1629951468/bottle.tar.xz'
!tar Jxvf ./data/bottle.tar.xz -C ./data

```bash
./data/bottle
├── ground_truth
│   ├── broken_large
│   ├── broken_small
│   └── contamination
├── license.txt
├── readme.txt
├── test
│   ├── broken_large
│   ├── broken_small
│   ├── contamination
│   └── good
└── train
    └── good
```

In [None]:
# Report how many images the train split and each test category contain.
train_pattern = os.path.join(cfg.datadir, cfg.target, 'train/good/*')
print('[ trainset ]')
print(f"train good images: {len(glob(train_pattern))}")

print('\n[ testset ]')
testdir = os.path.join(cfg.datadir, cfg.target, 'test')
for name in os.listdir(testdir):
    n_images = len(glob(os.path.join(testdir, name, '*')))
    print(f"test {name} images: {n_images}")

In [None]:
class MVTecAD(Dataset):
    """One MVTec AD object category laid out as ``<datadir>/<target>/{train,test,ground_truth}``.

    Each item is ``(img, mask, target)``: the transformed RGB image, an int64
    anomaly mask (all zeros for ``good`` samples) and a label that is 0 for
    ``good`` samples and 1 otherwise.

    Args:
        datadir: root directory holding the extracted categories.
        target: object category directory to load (e.g. ``'bottle'``).
        train: load the ``train`` split when true, else the ``test`` split.
        img_size: size handed to ``cv2.resize`` as ``dsize``.
        transform: callable applied to the resized RGB image array.
    """

    def __init__(
        self, datadir: str, target: str, train: bool,
        img_size: list, transform: transforms.Compose):

        self.datadir = os.path.join(datadir, target)
        self.train = train

        self.img_size = img_size
        self.transform = transform

        split_dir = os.path.join(self.datadir, 'train' if self.train else 'test')

        # only sub-directories are categories; sorted because the listing
        # order of the filesystem is arbitrary
        self.category = sorted(
            name for name in os.listdir(split_dir)
            if os.path.isdir(os.path.join(split_dir, name))
        )

        # sorted for a reproducible sample order
        self.file_list = sorted(glob(os.path.join(split_dir, '*/*')))


    def __getitem__(self, idx):
        file_path = self.file_list[idx]

        # the category is the parent directory name; testing the whole path
        # for 'good' would also match unrelated path components
        category = os.path.basename(os.path.dirname(file_path))
        is_good = category == 'good'

        # image
        img = cv2.imread(file_path)
        if img is None:
            raise FileNotFoundError(f'cannot read image: {file_path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, dsize=tuple(self.img_size))

        # target
        target = 0 if is_good else 1

        # mask
        if is_good:
            # same spatial shape as the resized image, also for non-square sizes
            mask = np.zeros(img.shape[:2], dtype=np.int64)
        else:
            # build the mask path from its components instead of replacing
            # every 'test' substring of the image path
            stem, ext = os.path.splitext(os.path.basename(file_path))
            mask_path = os.path.join(
                self.datadir, 'ground_truth', category, f'{stem}_mask{ext}')
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if mask is None:
                raise FileNotFoundError(f'cannot read mask: {mask_path}')
            # any non-zero pixel after resizing counts as anomalous
            mask = cv2.resize(mask, dsize=tuple(self.img_size)).astype(bool).astype(np.int64)

        img = self.transform(img)
        mask = torch.from_numpy(mask)

        return img, mask, target

    def __len__(self):
        return len(self.file_list)

In [None]:
# define dataset and dataloader
trainset = MVTecAD(
    datadir   = cfg.datadir,
    target    = cfg.target,
    img_size  = cfg.image_size,
    transform = transforms.ToTensor(),
    train     = True
)

testset = MVTecAD(
    datadir   = cfg.datadir,
    target    = cfg.target,
    img_size  = cfg.image_size,
    transform = transforms.ToTensor(),
    train     = False
)

trainloader = DataLoader(trainset, batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers)
testloader = DataLoader(testset, batch_size=cfg.test_batch_size, shuffle=False, num_workers=cfg.num_workers)

In [None]:
# Show the first test image of every category (top) with its mask (bottom).
# squeeze=False keeps `ax` two-dimensional even when there is one category.
fig, ax = plt.subplots(
    2, len(testset.category), figsize=(2*len(testset.category), 5), squeeze=False)

# category of a file = name of its parent directory (portable across OSes)
file_list_cat = [os.path.basename(os.path.dirname(path)) for path in testset.file_list]

for i, c in enumerate(testset.category):
    # first sample belonging to this category
    idx = np.where(np.array(file_list_cat) == c)[0][0]
    img, mask, _ = testset[idx]
    ax[0, i].imshow(img.permute(1,2,0))
    ax[1, i].imshow(mask, cmap='gray')

    # axis off
    ax[0, i].axis('off')
    ax[1, i].axis('off')

    # set title
    ax[0, i].set_title(c)
plt.tight_layout()
plt.show()

# Convolutional Auto-Encoder

In [None]:
class ConvolutionalAutoEncoder(nn.Module):
    """Convolutional auto-encoder with a mirrored encoder/decoder.

    The encoder stacks ``Conv2d(3x3) -> ReLU -> MaxPool2d(2)`` stages, the
    decoder stacks ``ConvTranspose2d(2, stride 2) -> ReLU`` stages over the
    reversed channel list, and a final 3x3 convolution followed by a sigmoid
    produces the reconstruction.

    Args:
        input_dim: number of channels of the input (and of the output).
        dims: output channels of each encoder stage, in order.
    """

    def __init__(self, input_dim: int, dims: list):
        super().__init__()

        channels = [input_dim] + dims

        self.enc = nn.Sequential(*self.build_layer(dims=channels))
        self.dec = nn.Sequential(*self.build_layer(dims=channels[::-1], up=True))
        self.output = nn.Conv2d(input_dim, input_dim, kernel_size=3, padding=1)

    def build_layer(self, dims, up=False):
        """Return the list of modules linking consecutive channel sizes in ``dims``."""
        modules = []

        for c_in, c_out in zip(dims[:-1], dims[1:]):
            if up:
                # upsampling stage: doubles the spatial resolution
                modules.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2))
                modules.append(nn.ReLU())
            else:
                # downsampling stage: halves the spatial resolution
                modules.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
                modules.append(nn.ReLU())
                modules.append(nn.MaxPool2d(kernel_size=2, stride=2))

        return modules

    def encoder(self, x):
        """Encode ``x`` into its latent feature map."""
        return self.enc(x)

    def decoder(self, out):
        """Decode a latent feature map into a reconstruction in [0, 1]."""
        upsampled = self.dec(out)
        return F.sigmoid(self.output(upsampled))

    def forward(self, x):
        """Reconstruct ``x`` by encoding and then decoding it."""
        return self.decoder(self.encoder(x))

## CAE - shallow

In [None]:
# Build the two-stage (shallow) auto-encoder with reproducible initial weights.
torch_seed(cfg.seed)
cae_shallow = ConvolutionalAutoEncoder(input_dim=3, dims=[32, 64])
cae_shallow.to(cfg.device)

n_params_shallow = sum(p.numel() for p in cae_shallow.parameters())
print('load Convolutional Auto-Encoder')
print('The number of model parameters: ', n_params_shallow)

# keep the element-wise loss (no reduction): the mean is taken during
# training and the maximum during evaluation
criterion = nn.MSELoss(reduction='none')
optimizer = Adam(params=cae_shallow.parameters(), lr=cfg.learning_rate)

In [None]:
# Train the shallow model and keep its per-epoch histories.
torch_seed(cfg.seed)
history_cae_shallow = fit(
    cae_shallow, trainloader, testloader, criterion, optimizer,
    epochs       = cfg.epochs,
    log_interval = cfg.log_interval,
    device       = cfg.device
)
(
    train_history_cae_shallow,
    test_history_auroc_img_cae_shallow,
    test_history_auroc_pixel_cae_shallow,
) = history_cae_shallow

In [None]:
# Start the experiment collections with the shallow model and plot them.
all_exp_name = ['CAE shallow']
all_train_history = [train_history_cae_shallow]
all_test_history_auroc_img = [test_history_auroc_img_cae_shallow]
all_test_history_auroc_pixel = [test_history_auroc_pixel_cae_shallow]

figure(
    all_train_history,
    all_test_history_auroc_img,
    all_test_history_auroc_pixel,
    all_exp_name
)

## CAE - deep

In [None]:
# Build the five-stage (deep) auto-encoder with reproducible initial weights.
torch_seed(cfg.seed)
cae_deep = ConvolutionalAutoEncoder(input_dim=3, dims=[32, 64, 128, 256, 512])
cae_deep.to(cfg.device)

n_params_deep = sum(p.numel() for p in cae_deep.parameters())
print('load Convolutional Auto-Encoder')
print('The number of model parameters: ', n_params_deep)

# keep the element-wise loss (no reduction): the mean is taken during
# training and the maximum during evaluation
criterion = nn.MSELoss(reduction='none')
optimizer = Adam(params=cae_deep.parameters(), lr=cfg.learning_rate)

In [None]:
# Train the deep model and keep its per-epoch histories.
torch_seed(cfg.seed)
history_cae_deep = fit(
    cae_deep, trainloader, testloader, criterion, optimizer,
    epochs       = cfg.epochs,
    log_interval = cfg.log_interval,
    device       = cfg.device
)
(
    train_history_cae_deep,
    test_history_auroc_img_cae_deep,
    test_history_auroc_pixel_cae_deep,
) = history_cae_deep

In [None]:
# Add the deep model to the experiment collections and plot both models.
all_exp_name += ['CAE deep']
all_train_history += [train_history_cae_deep]
all_test_history_auroc_img += [test_history_auroc_img_cae_deep]
all_test_history_auroc_pixel += [test_history_auroc_pixel_cae_deep]

figure(
    all_train_history,
    all_test_history_auroc_img,
    all_test_history_auroc_pixel,
    all_exp_name
)

# Experiment results

In [None]:
# Final-epoch AUROC of each model, in the order (shallow, deep).
auroc_img_list = [
    history[-1]
    for history in (test_history_auroc_img_cae_shallow, test_history_auroc_img_cae_deep)
]
auroc_pixel_list = [
    history[-1]
    for history in (test_history_auroc_pixel_cae_shallow, test_history_auroc_pixel_cae_deep)
]

final_scores = {
    'Model'       : ['CAE shallow','CAE deep'],
    'AUROC(image)': auroc_img_list,
    'AUROC(pixel)': auroc_pixel_list
}
pd.DataFrame(final_scores).round(4)

# Anomaly score distribution

In [None]:
# category (parent directory name) of every test file
file_list_cat = [path.split('/')[-2] for path in testset.file_list]

# element-wise reconstruction error
criterion = nn.MSELoss(reduction='none')

# one anomaly score per test image: the largest element-wise error
total_loss = []

cae_deep.eval()
with torch.no_grad():
    for batch in testloader:
        inputs = batch[0].to(cfg.device)
        outputs = cae_deep(inputs)
        loss = criterion(inputs, outputs)
        per_image_max = loss.flatten(start_dim=1).max(dim=1).values
        total_loss.extend(per_image_max.cpu().tolist())

In [None]:
# Kernel density estimate of the image scores, one filled curve per category.
displot_options = dict(kind='kde', fill=True, aspect=2)
sns.displot(x=total_loss, hue=file_list_cat, **displot_options)
plt.title('Anomaly score distribution')
plt.show()

# Anomaly visualization

In [None]:
# category of a file = name of its parent directory (portable across OSes)
file_list_cat = [os.path.basename(os.path.dirname(path)) for path in testset.file_list]

# set row name
row_name = ['image', 'mask', 'anomaly region', 'image x anomaly']

# squeeze=False keeps `ax` two-dimensional even when there is one category
fig, ax = plt.subplots(
    4, len(testset.category), figsize=(2*len(testset.category), 8), squeeze=False)

# loss function
criterion = nn.MSELoss(reduction='none')

cae_deep.eval()
for i, c in enumerate(testset.category):
    # third sample of the category, or the last one when it has fewer than three
    cat_indices = np.where(np.array(file_list_cat) == c)[0]
    idx = cat_indices[min(2, len(cat_indices)-1)]

    # get image and mask
    img, mask, _ = testset[idx]

    # per-pixel anomaly map: maximum element-wise error over dimension 0
    with torch.no_grad():
        output = cae_deep(img.unsqueeze(0).to(cfg.device))[0].cpu()
        loss = criterion(img, output).max(dim=0)[0]

    # min-max scaling; a constant map has zero range and would give NaN,
    # so it is shown as all zeros instead
    loss_range = loss.max() - loss.min()
    if loss_range > 0:
        loss = (loss - loss.min()) / loss_range
    else:
        loss = torch.zeros_like(loss)

    # show image
    ax[0, i].imshow(img.permute(1,2,0))
    ax[1, i].imshow(mask, cmap='gray')
    ax[2, i].imshow(loss, cmap='gray')
    # blend the image and the anomaly map with equal weight
    ax[3, i].imshow(img.permute(1,2,0)*0.5 + (loss*0.5).unsqueeze(-1))

    # axis off
    for ax_r in ax[:, i]:
        ax_r.axis('off')

    # set title: only the top row carries the category name
    for r_idx, ax_r in enumerate(ax[:, i]):
        ax_r.set_title(f'[{c}]\n{row_name[r_idx]}' if r_idx==0 else row_name[r_idx])

plt.tight_layout()
plt.show()