# Convolutional AutoEncoder: Visual Inspection

This notebook visually inspects the reconstructions produced by the convolutional autoencoder trained for anomaly detection on the MASATI (v2) dataset.

Raúl Barba Rojas developed this notebook as part of his Master's thesis *Visual Anomaly Detection in Satellite Imagery*, supervised by Jorge Díez Peláez and co-supervised by José Luis Espinosa Aranda.

## Libraries

The following Python modules are required to run the notebook:

```python
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

from aerial_anomaly_detection.models.autoencoder import AutoEncoder
from aerial_anomaly_detection.datasets import DataLoader
from aerial_anomaly_detection.datasets.masati_v2 import MASATIv2

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
```

## Data

Let us load the validation split of the MASATI (v2) dataset with the code developed as part of the `aerial_anomaly_detection` package, shuffling it so that the inspected batches are drawn at random:

```python
batch_size = 8
# Forward slashes are accepted as path separators on both Windows and POSIX systems
val_dataset = MASATIv2.load('../../data/processed/MASATI-v2', 'val')
val_dataloader = DataLoader(val_dataset, batch_size = batch_size, shuffle = True)
```
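The inspection loop unpacks each batch into image names, an image tensor and a third, unused element (presumably labels), so the package's `DataLoader` appears to yield such triples. The pure-NumPy sketch below mimics that behaviour under this assumption; `iterate_batches` is a hypothetical stand-in rather than the package's API, and the toy image shape `(3, 32, 32)` is made up for illustration.

```python
import numpy as np


def iterate_batches(names, images, labels, batch_size, shuffle=True, seed=0):
    """Yield (names, images, labels) batches.

    Hypothetical stand-in for the package's DataLoader, for illustration only.
    """
    indices = np.arange(len(names))
    if shuffle:
        # Shuffle sample order once per pass, as shuffle = True does in the notebook
        np.random.default_rng(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        idx = indices[start:start + batch_size]
        yield [names[i] for i in idx], images[idx], labels[idx]


# Toy data: 20 images with values in [-1, 1] (shapes assumed, not MASATI's real ones)
rng = np.random.default_rng(42)
names = [f'img_{i:03d}' for i in range(20)]
images = rng.uniform(-1.0, 1.0, size=(20, 3, 32, 32)).astype(np.float32)
labels = np.zeros(20, dtype=np.int64)

batch_sizes = [X.shape[0] for _, X, _ in iterate_batches(names, images, labels, batch_size=8)]
print(batch_sizes)  # [8, 8, 4]
```

With 20 samples and a batch size of 8, the final batch holds only 4 images, so plotting code that always indexes up to `batch_size` can fail on a short last batch.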

## AutoEncoder

Let us load the trained AutoEncoder model:

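```python
# The arguments (1000, 32, 32) mirror the l1000_w32_h32 tags in the checkpoint name
autoencoder = AutoEncoder(1000, 32, 32).to(device)
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine
autoencoder.load_state_dict(torch.load('../../models/AutoEncoder/autoencoder_l1000_w32_h32_bs256_lr001.pth',
                                       map_location = device, weights_only = True))
```

```text
<All keys matched successfully>
```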

## Visual inspection

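Finally, let us visually inspect the autoencoder on the validation data: for each of five shuffled batches, the original images (top row) are plotted above their reconstructions (bottom row), and each figure is saved to the `visual_output` folder.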

```python
num_batches = 5  # number of validation batches to plot
(output_folder := Path('visual_output')).mkdir(exist_ok = True, parents = True)


def to_display_image(array_chw):
    # Map a (C, H, W) array from [-1, 1] to a displayable (H, W, C) uint8 image.
    # Casting out-of-range floats to uint8 is not well defined, so clip first:
    # reconstructions may fall slightly outside [-1, 1].
    return np.clip((array_chw + 1) / 2 * 255, 0, 255).astype(np.uint8).transpose(1, 2, 0)


autoencoder.eval()
with torch.inference_mode():
    for idx_batch, (sample_names, X, _) in enumerate(val_dataloader, start = 1):
        X = X.to(device)
        y_pred = autoencoder(X).cpu().numpy()
        X = X.cpu().numpy()
        fig, axs = plt.subplots(ncols = batch_size, nrows = 2, figsize = (10, 7), squeeze = False)

        # Hide ticks and frames everywhere, including unused slots of a short last batch
        for ax in axs.flat:
            ax.axis('off')

        # Iterate over the actual batch length: the last batch may be smaller than batch_size
        for idx_sample in range(X.shape[0]):
            axs[0, idx_sample].imshow(to_display_image(X[idx_sample]))       # original
            axs[1, idx_sample].imshow(to_display_image(y_pred[idx_sample]))  # reconstruction

        fig.text(s = 'Original', fontsize = 16, x = 0.44, y = 0.8)
        fig.text(s = 'AutoEncoder', fontsize = 16, x = 0.42, y = 0.425)
        fig.tight_layout()
        fig.savefig(output_folder / f'batch_{idx_batch}.png')
        plt.close(fig)

        if idx_batch == num_batches:
            break
```
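The figures above compare originals and reconstructions by eye only. A common way to quantify the same comparison, assumed here and not necessarily the scoring used in the thesis, is the per-image mean squared reconstruction error: images the autoencoder reconstructs poorly receive high scores and are candidate anomalies. The helper names and the threshold value in this sketch are hypothetical; in practice a threshold would be chosen from validation scores.

```python
import numpy as np


def reconstruction_scores(originals, reconstructions):
    """Per-image mean squared reconstruction error for batches shaped (N, C, H, W)."""
    originals = np.asarray(originals, dtype=np.float64)
    reconstructions = np.asarray(reconstructions, dtype=np.float64)
    if originals.shape != reconstructions.shape:
        raise ValueError('originals and reconstructions must have the same shape')
    squared_error = (originals - reconstructions) ** 2
    # Average over channels and pixels, keeping one score per image
    return squared_error.reshape(len(originals), -1).mean(axis=1)


def flag_anomalies(scores, threshold):
    """Boolean mask that is True where the score exceeds a (hypothetical) threshold."""
    return np.asarray(scores) > threshold


# Toy example: image 0 is reconstructed perfectly, image 1 is off by 0.5 everywhere
originals = np.zeros((2, 3, 4, 4))
reconstructions = originals.copy()
reconstructions[1] += 0.5

scores = reconstruction_scores(originals, reconstructions)
print(scores.tolist())                       # [0.0, 0.25]
print(flag_anomalies(scores, 0.1).tolist())  # [False, True]
```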