# Introduction
Check the [pdf report](https://github.com/DavideEspositoPelella/Anomaly_Detection-GAN_Based.git/Report) or the [GitHub repository](https://github.com/DavideEspositoPelella/Anomaly_Detection-GAN_Based.git).

This Colab notebook implements the Deep Learning project _**GAN-based Anomaly Detection in Imbalance Problems**_.

<a href="https://colab.research.google.com/drive/1wHPGH00jiXZUCCsaCI6dreoZdNqTyijK?usp=share_link"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo" height=30px></a>


# Imports
Importing all the libraries and dependencies needed

In [None]:
! pip install --quiet "pytorch-lightning"

In [None]:
import os
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torchvision.transforms as T
from torchvision.datasets import FashionMNIST
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
import numpy as np
from sklearn.cluster import KMeans
import random

# `platform` is not among the notebook's imports, so import it here; without
# it the Python-version line below raises NameError.
import platform

# Report the versions of the main dependencies for reproducibility.
print("Lightning version:", pl.__version__)
print("PyTorch version:", torch.__version__)
print("Python version:", platform.python_version())


In [None]:
# Fix the global seed through Lightning's helper so that runs are repeatable.
pl.seed_everything(42)

# Dataset and Preprocessing

In [None]:
class Rotation3:
    """Transform that rotates an image by a random multiple of 90 degrees.

    The angle is drawn uniformly from 0, 90, 180 and 270 degrees, so the
    image may also come back unrotated.
    """

    def __call__(self, img):
        chosen_angle = random.choice([0, 90, 180, 270])
        return T.functional.rotate(img, chosen_angle)

# DataModule
class AnomalyDetectionDataModule_MNIST(pl.LightningDataModule):
    """Fashion-MNIST data module for one-vs-rest anomaly detection.

    Samples of ``normal_class`` are labelled 1 ("normal"); samples of every
    other class are labelled 0 ("anomalous"). Training anomalies are drawn
    through k-means clustering so that they are spread over the clusters.

    Args:
        batch_size: Batch size used by every dataloader.
        subsampling: If true, keep at most ``subsamples`` normal and
            ``subsamples`` anomalous items per split.
        subsamples: Per-label cap applied when ``subsampling`` is true.
        normal_class: Dataset class index treated as "normal" (default 0).
        data_dir: Root directory for the downloaded dataset.
    """

    def __init__(self, batch_size, subsampling, subsamples, normal_class=0, data_dir="./"):
        super().__init__()

        self.normal_class = normal_class
        self.subsampling = subsampling
        self.subsamples = subsamples

        # Training images are augmented (flips, crop, 90-degree rotations) and
        # end up as 1x32x32 tensors scaled to [-1, 1].
        self.train_transform = T.Compose([
            T.Resize(32),
            T.RandomVerticalFlip(),
            T.RandomHorizontalFlip(),
            T.RandomCrop(30),
            T.Resize(32),
            T.Grayscale(1),  # Use 1 channel (grayscale)
            T.RandomApply([Rotation3()], p=1),
            T.ToTensor(),
            T.Normalize([0.5], [0.5])
        ])
        # Test images are only resized and normalised.
        self.test_transform = T.Compose([
            T.Resize(32),
            T.Grayscale(1),  # Use 1 channel (grayscale)
            T.ToTensor(),
            T.Normalize([0.5], [0.5])
        ])

        self.data_dir = data_dir
        self.dataset_classes = 10
        self.num_classes = 2  # Normal and Anomaly
        self.batch_size = batch_size
        # One cluster per non-normal class.
        self.num_clusters = self.dataset_classes - 1

    @staticmethod
    def _indices_where(dataset, class_idx, equal):
        """Return dataset indices whose target equals (or differs from) ``class_idx``."""
        targets = torch.as_tensor(dataset.targets)
        mask = (targets == class_idx) if equal else (targets != class_idx)
        return torch.nonzero(mask).flatten().tolist()

    @staticmethod
    def _build_dataset(dataset, normal_indices, anomal_indices):
        """Stack the selected samples into a TensorDataset (normal=1, anomalous=0)."""
        normal_data = [dataset[i][0] for i in normal_indices]
        anomal_data = [dataset[i][0] for i in anomal_indices]

        all_data = torch.stack(normal_data + anomal_data)
        all_labels = torch.cat([
            torch.ones(len(normal_data), dtype=torch.float32),
            torch.zeros(len(anomal_data), dtype=torch.float32),
        ])
        return torch.utils.data.TensorDataset(all_data, all_labels)

    # K-means based sampling of anomalous classes
    def kmeans_sampling(self, dataset, class_idx, num_samples):
        """Sample anomalous items evenly across k-means clusters.

        All items whose target differs from ``class_idx`` are clustered into
        ``self.num_clusters`` groups, and up to ``num_samples`` items are drawn
        without replacement from each group.

        Returns:
            A list of dataset indices, grouped by cluster in cluster order.
        """
        indices = self._indices_where(dataset, class_idx, equal=False)
        class_data = torch.stack([dataset[i][0] for i in indices])
        # Flatten once and reuse for both fit and predict.
        flat_data = class_data.view(class_data.size(0), -1).numpy()

        kmeans = KMeans(n_clusters=self.num_clusters, random_state=0, n_init=9)
        kmeans.fit(flat_data)
        cluster_assignments = kmeans.predict(flat_data)

        sampled_indices = []
        for cluster_idx in range(self.num_clusters):
            cluster_indices = np.where(cluster_assignments == cluster_idx)[0]
            num_samples_from_cluster = min(num_samples, len(cluster_indices))
            sampled_indices.extend(np.random.choice(cluster_indices, num_samples_from_cluster, replace=False))
        return [indices[idx] for idx in sampled_indices]

    def prepare_data(self):
        """Download Fashion-MNIST into ``data_dir`` if it is not already there."""
        FashionMNIST(root=self.data_dir, train=True, download=True)

    def setup(self, stage=None):
        """Build ``train_dataset`` (stage 'fit') and/or ``test_dataset`` (stage 'test')."""
        normal_class = self.normal_class

        if stage == 'fit' or stage is None:  # Setup dataset for training phase
            train_dataset = FashionMNIST(root=self.data_dir, train=True, transform=self.train_transform)

            # Use 1 class as 'normal' and the others as 'anomal'
            normal_indices = self._indices_where(train_dataset, normal_class, equal=True)
            per_cluster = len(normal_indices)

            if self.subsampling:
                normal_indices = normal_indices[:self.subsamples]
                # kmeans_sampling returns indices grouped by cluster, so simply
                # truncating its full output would keep (almost) only the first
                # cluster. Cap the per-cluster quota instead (ceil division) so
                # the retained anomalies are spread over all clusters.
                per_cluster = min(per_cluster, -(-self.subsamples // self.num_clusters))

            anomal_indices = self.kmeans_sampling(train_dataset, normal_class, num_samples=per_cluster)
            if self.subsampling:
                anomal_indices = anomal_indices[:self.subsamples]

            # Build the new dataset following requirements
            self.train_dataset = self._build_dataset(train_dataset, normal_indices, anomal_indices)

        if stage == 'test' or stage is None:  # Setup dataset for testing phase
            test_dataset = FashionMNIST(root=self.data_dir, train=False, transform=self.test_transform)

            normal_indices = self._indices_where(test_dataset, normal_class, equal=True)
            anomal_indices = self._indices_where(test_dataset, normal_class, equal=False)

            # Use only a subset of dataset
            if self.subsampling:
                normal_indices = normal_indices[:self.subsamples]
                anomal_indices = anomal_indices[:self.subsamples]

            # Build the new dataset
            self.test_dataset = self._build_dataset(test_dataset, normal_indices, anomal_indices)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        # NOTE(review): `self.anomaly_data` is never assigned in this class, so
        # calling this raises AttributeError. The intended validation split is
        # not visible here -- confirm and wire it up in `setup`.
        return torch.utils.data.DataLoader(self.anomaly_data, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        return torch.utils.data.DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)


# The training cells refer to the data module by this shorter name.
AnomalyDetectionDataModule = AnomalyDetectionDataModule_MNIST


# Models
Utilities and models implemented

## GAN-Based

In [None]:

class ENCODER(pl.LightningModule):
    """Convolutional encoder: four stride-2 convolutions, each followed by ReLU."""

    def __init__(self):
        super().__init__()

        # Every layer uses kernel 4 / stride 2 / padding 1, which halves the
        # spatial size; trailing comments give the size for a 32x32 input.
        halve = dict(kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, **halve)     # 16x16
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, **halve)   # 8x8
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, **halve)  # 4x4
        self.conv4 = nn.Conv2d(in_channels=256, out_channels=512, **halve)  # 2x2

    def forward(self, x):
        """Apply the four conv+ReLU stages in order and return the feature map."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.relu(conv(x))
        return x

class DECODER(pl.LightningModule):
    """Transposed-convolution decoder mirroring ENCODER, ending in tanh."""

    def __init__(self):
        super().__init__()

        # Every layer uses kernel 4 / stride 2 / padding 1, which doubles the
        # spatial size; trailing comments give the size for a 2x2 input.
        double = dict(kernel_size=4, stride=2, padding=1)
        self.tr_conv1 = nn.ConvTranspose2d(in_channels=512, out_channels=256, **double)  # 4x4
        self.tr_conv2 = nn.ConvTranspose2d(in_channels=256, out_channels=128, **double)  # 8x8
        self.tr_conv3 = nn.ConvTranspose2d(in_channels=128, out_channels=64, **double)   # 16x16
        self.tr_conv4 = nn.ConvTranspose2d(in_channels=64, out_channels=1, **double)     # 32x32

    def forward(self, x):
        """Upsample through three ReLU stages, then a tanh-activated output layer."""
        for up in (self.tr_conv1, self.tr_conv2, self.tr_conv3):
            x = F.relu(up(x))
        return F.tanh(self.tr_conv4(x))


class GENERATOR(pl.LightningModule):
    """Autoencoder-style generator: an ENCODER followed by a DECODER."""

    def __init__(self):
        super().__init__()

        self.encoder = ENCODER()
        self.decoder = DECODER()

    def forward(self, x):
        """Encode ``x`` and decode the result back to image space."""
        latent = self.encoder(x)
        return self.decoder(latent)


class DISCRIMINATOR(pl.LightningModule):
    """Five-layer convolutional discriminator producing a small score map.

    Every convolution, including the last one, is followed by ReLU. Trailing
    comments give channels x height x width for a 1x32x32 input.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=4, stride=2, padding=1)    # 64x16x16
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1)  # 128x8x8
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1)  # 64x4x4
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=4, stride=1, padding=1)   # 32x3x3
        self.conv5 = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=4, stride=1, padding=1)    # 1x2x2

    def forward(self, x):
        """Run the five conv+ReLU stages and return the score map."""
        stages = (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5)
        for conv in stages:
            x = F.relu(conv(x))
        return x
    

class GAN_SOTA(pl.LightningModule):
    """GAN-based anomaly detector with one generator and two discriminators.

    The generator ``G`` reconstructs its input. ``D_norm`` is trained on
    samples labelled 1 (normal) and ``D_anom`` on all other samples. Training
    uses manual optimization with three Adam optimizers. At test time the
    per-sample reconstruction error is used as the anomaly score and the
    AUROC is logged.

    Args:
        latent_dim: Stored as a hyperparameter; not used by the model itself.
        lr: Learning rate shared by the three optimizers.
        display: If true, ``test_step`` prints tensors and shows one image.
        b1, b2: Adam beta coefficients.
    """

    def __init__(self, latent_dim=100, lr=0.0002, display=False, b1=0.5, b2=0.999):
        super().__init__()
        self.lr = lr
        self.display = display

        self.save_hyperparameters()
        self.G = GENERATOR()
        self.Enc = self.G.encoder  # alias of the generator's encoder (same module)
        self.D_norm = DISCRIMINATOR()
        self.D_anom = DISCRIMINATOR()

        self.automatic_optimization = False
        # [reconstruction errors, true labels] accumulated over the test batches
        self.test_step_outputs = [[], []]

    def forward(self, x):
        return self.G(x)

    def print_generator_parameters(self):
        """Print the name and value of every generator parameter."""
        for name, param in self.G.named_parameters():
            print(f"Generator Parameter Name: {name}")
            print(f"Parameter Value:\n{param}")

    def patch_loss(self, X, generated_X, n=3, patch_size=8):
        """Average L1 error of the ``n`` worst-reconstructed patches.

        The absolute error is averaged over non-overlapping
        ``patch_size`` x ``patch_size`` patches (and over channels); for each
        sample the ``n`` largest patch errors are kept, and their mean over the
        batch is returned. Previously the error was averaged over the whole
        image per channel, so no patches were formed and the top-n selection
        acted on the channel axis.

        Args:
            X, generated_X: Tensors of shape (batch, channels, height, width).
            n: Number of worst patches to keep per sample.
            patch_size: Patch side length. NOTE(review): 8 is an assumed
                default; tune it against the reference method.
        """
        abs_err = torch.abs(X - generated_X)
        # Never let the pooling window exceed the image.
        kernel = (min(patch_size, abs_err.size(-2)), min(patch_size, abs_err.size(-1)))
        patch_means = F.avg_pool2d(abs_err, kernel_size=kernel)
        patch_means = patch_means.mean(dim=1).flatten(start_dim=1)  # (batch, patches)
        k = min(n, patch_means.size(1))
        top_patch_errors = torch.topk(patch_means, k, dim=1).values
        return top_patch_errors.mean()

    def show_batch_image(self, batch):
        """Plot the first image of ``batch`` together with its label."""
        # NOTE(review): `plt` is not imported anywhere in the visible notebook;
        # `import matplotlib.pyplot as plt` is needed before calling this.
        X, Y = batch
        image = X[0].detach().cpu().numpy()
        plt.figure(figsize=(5, 5))
        plt.title("Label: {}".format(Y[0]))
        plt.imshow(image[0], cmap='gray')  # Assuming the input image is single-channel (gray)
        plt.axis('off')
        plt.show()

    def _discriminator_update(self, discriminator, optimizer, X):
        """Run one least-squares discriminator step; return the detached loss."""
        # The generator is not updated here, so skip building its graph.
        with torch.no_grad():
            generated_X = self.G(X)

        optimizer.zero_grad()
        d_real = discriminator(X)
        d_fake = discriminator(generated_X)
        # Targets take the discriminator's own output shape instead of a
        # hard-coded (batch, 1, 2, 2).
        loss = ((d_real - torch.ones_like(d_real)) ** 2
                + (d_fake - torch.zeros_like(d_fake)) ** 2).mean()
        self.manual_backward(loss)
        optimizer.step()
        return loss.detach()

    def _normal_generator_update(self, g_opt, X):
        """Generator step on normal samples; return the detached loss."""
        g_opt.zero_grad()

        generated_X = self.G(X)
        d_fake = self.D_norm(generated_X)
        encoded_X = self.G.encoder(X)

        # L1 RECONSTRUCTION ERROR
        l1_loss = F.l1_loss(X, generated_X)
        # PATCH L1 LOSS
        patch_loss = self.patch_loss(X, generated_X)
        # LATENT VECTOR LOSS
        latent_loss = F.l1_loss(encoded_X, self.G.encoder(generated_X))
        # ADVERSARIAL LOSS: the reconstruction should be scored as real
        adv_loss = ((d_fake - torch.ones_like(d_fake)) ** 2).mean()

        loss = patch_loss * 1.5 + adv_loss * 0.5 + latent_loss * 0.5 + l1_loss * 1.5
        self.manual_backward(loss)
        g_opt.step()
        return loss.detach()

    def _anomal_generator_update(self, g_opt, X, eps=1e-8):
        """Generator step on anomalous samples; return the detached loss."""
        g_opt.zero_grad()

        generated_X = self.G(X)
        d_fake = self.D_anom(generated_X)

        # ANOM ADVERS LOSS
        adv_loss = ((d_fake - torch.zeros_like(d_fake)) ** 2).mean()

        # ABC LOSS: -log(1 - exp(-L1)). `-expm1(-x)` equals `1 - exp(-x)` but
        # is accurate for small x; the clamp avoids log(0) when the
        # reconstruction is exact.
        recon = F.l1_loss(generated_X, X)
        abc_loss = -torch.log(torch.clamp(-torch.expm1(-recon), min=eps))

        loss = adv_loss * 1 + abc_loss * 0.5
        self.manual_backward(loss)
        g_opt.step()
        return loss.detach()

    # The generator is trained to output 1 from normal data and 0 from anomaly data
    def training_step(self, batch, batch_idx):
        """Update the matching discriminator and the generator for one batch.

        Samples labelled 1 go through the "normal" objectives and all others
        through the "anomalous" ones. The batch is split by label, so batches
        larger than one sample (and mixed batches) are supported; a
        single-sample batch takes exactly one of the two paths.
        """
        g_opt, d_norm_opt, d_anom_opt = self.optimizers()

        X, Y = batch
        is_normal = Y.reshape(-1) == 1

        errD_normal = 0
        errD_anomal = 0
        g_losses = []
        metrics = {}

        if is_normal.any():
            X_norm = X[is_normal]
            errD_normal = self._discriminator_update(self.D_norm, d_norm_opt, X_norm)
            g_losses.append(self._normal_generator_update(g_opt, X_norm))
            metrics["d_normal_loss"] = errD_normal

        if (~is_normal).any():
            X_anom = X[~is_normal]
            errD_anomal = self._discriminator_update(self.D_anom, d_anom_opt, X_anom)
            g_losses.append(self._anomal_generator_update(g_opt, X_anom))
            metrics["d_anomal_loss"] = errD_anomal

        errG = torch.stack(g_losses).mean()
        metrics["g_loss"] = errG
        self.log_dict(metrics, prog_bar=True)

        return {"g_loss": errG, "d_norm_loss": errD_normal, "d_anom_loss": errD_anomal}

    def test_step(self, batch, batch_idx):
        """Store the per-sample reconstruction error and label of one batch."""
        X, Y = batch

        generated_X = self.G(X)

        # Per-sample mean squared reconstruction error over all pixels
        reconstruction_error = F.mse_loss(generated_X, X, reduction='none')
        reconstruction_error = reconstruction_error.view(reconstruction_error.size(0), -1).mean(dim=1)
        self.test_step_outputs[0].append(reconstruction_error.detach())
        self.test_step_outputs[1].append(Y)

        if self.display:
            print("generated_X:", generated_X)
            print("reconstruction_error:", reconstruction_error)
            self.show_batch_image(batch)

        return {"reconstruction_error": reconstruction_error, "true_labels": Y}

    def on_test_epoch_end(self):
        """Compute and log the AUROC over all accumulated test batches."""
        errors, labels = self.test_step_outputs
        if not errors:
            return

        all_reconstruction_errors = torch.cat(errors).float().cpu().numpy()
        true_labels = torch.cat(labels).reshape(-1).float().cpu().numpy()

        # Calculate AUROC on the continuous scores. Label 1 means "normal", and
        # normal samples should reconstruct with LOWER error, so the negated
        # error is the score of the positive class. (Thresholding the errors
        # first would collapse the ROC curve to a single operating point.)
        auroc = roc_auc_score(true_labels, -all_reconstruction_errors)

        self.log("auroc", auroc, prog_bar=True)

        # Clear the test step outputs after processing them
        self.test_step_outputs = [[], []]

    def on_epoch_end(self):
        # NOTE(review): recent Lightning releases appear to no longer invoke an
        # `on_epoch_end` hook, in which case this never runs -- confirm against
        # the installed version before relying on these epoch-level metrics.
        self.log('g_loss_epoch', self.trainer.logged_metrics['g_loss'].mean(), prog_bar=True)
        self.log('d_anomal_loss_epoch', self.trainer.logged_metrics['d_anomal_loss'].mean(), prog_bar=True)
        self.log('d_normal_loss_epoch', self.trainer.logged_metrics['d_normal_loss'].mean(), prog_bar=True)

    def configure_optimizers(self):
        """Return Adam optimizers for the generator and both discriminators."""
        betas = (self.hparams.b1, self.hparams.b2)

        opt_g = torch.optim.Adam(self.G.parameters(), lr=self.lr, betas=betas)
        opt_d_norm = torch.optim.Adam(self.D_norm.parameters(), lr=self.lr, betas=betas)
        opt_d_anom = torch.optim.Adam(self.D_anom.parameters(), lr=self.lr, betas=betas)
        return opt_g, opt_d_norm, opt_d_anom


## Random Guessing

In [None]:
class RandomGuessing:
    """Baseline that assigns every sample a uniform random score in [0, 1)."""

    def __init__(self):
        self.predictions = []
        self.true_labels = []

    def predict(self, dataloader):
        """Draw one random score per sample and record the true labels.

        Iterates over ``dataloader`` (batches of ``(inputs, labels)``), stores
        scores and labels as NumPy arrays on the instance and returns the
        scores.
        """
        self.predictions, self.true_labels = [], []

        for inputs, labels in dataloader:
            self.predictions.append(torch.rand(inputs.size(0)))
            self.true_labels.append(labels)

        self.predictions = torch.cat(self.predictions).cpu().numpy()
        self.true_labels = torch.cat(self.true_labels).cpu().numpy()
        return self.predictions

    def compute_auroc(self):
        """Return the AUROC of the stored scores against the stored labels."""
        return roc_auc_score(self.true_labels, self.predictions)

## Autoencoder

In [None]:

class Autoencoder(nn.Module):
    """Fully connected autoencoder with a 16-dimensional bottleneck.

    Args:
        input_dim: Number of features per flattened sample (default 784).
        output_activation: Module applied after the last decoder layer.
            ``None`` keeps the original ``nn.ReLU()``.
    """

    def __init__(self, input_dim=784, output_activation=None):
        super(Autoencoder, self).__init__()
        self.input_dim = input_dim
        self.enc = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 16),
            nn.ReLU()
        )
        self.dec = nn.Sequential(
            nn.Linear(16, 32),
            nn.ReLU(),
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, input_dim),
            output_activation if output_activation is not None else nn.ReLU()
        )

    def forward(self, x):
        """Reconstruct ``x``; the output has the same shape as the input.

        Inputs whose last dimension already equals ``input_dim`` are passed
        through unchanged (as before). Anything else, e.g. image batches of
        shape (batch, channels, height, width), is flattened per sample first.
        """
        original_shape = x.shape
        if x.size(-1) != self.input_dim:
            x = x.reshape(x.size(0), -1)
        decode = self.dec(self.enc(x))
        return decode.reshape(original_shape)


class AE_based(pl.LightningModule):
    """Lightning wrapper that trains ``Autoencoder`` with an MSE reconstruction loss.

    Args:
        latent_dim: Stored as a hyperparameter; not used by the model itself.
        lr: Learning rate of the AdamW optimizer.
        b1, b2: AdamW beta coefficients.
        display: Accepted for parity with ``GAN_SOTA`` (the training cell
            passes it); only stored.
        input_dim: Flattened sample size. The default matches the 1x32x32
            tensors produced by the data module's transforms.
    """

    def __init__(self, latent_dim=100, lr=0.0002, b1=0.5, b2=0.999, display=False, input_dim=32 * 32):
        super(AE_based, self).__init__()
        self.lr = lr
        self.display = display
        self.save_hyperparameters()
        # The data is normalised to [-1, 1]; a ReLU output could never
        # reproduce the negative half of that range, so use tanh here.
        self.AE = Autoencoder(input_dim=input_dim, output_activation=nn.Tanh())

    def forward(self, x):
        return self.AE(x)

    def training_step(self, batch, batch_idx):
        """Return the MSE reconstruction loss for one batch."""
        X, _ = batch

        generated_X = self.AE(X)
        loss = F.mse_loss(generated_X, X)

        self.log_dict({"train loss": loss}, prog_bar=True)
        # Lightning's automatic optimization needs the loss returned; without
        # it no optimizer step is performed.
        return loss

    def test_step(self, batch, batch_idx):
        """Log the MSE reconstruction loss for one test batch."""
        X, _ = batch

        generated_X = self.AE(X)
        loss = F.mse_loss(generated_X, X)

        self.log_dict({"test loss": loss}, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return the AdamW optimizer over all parameters."""
        # The betas live in `self.hparams` (set by `save_hyperparameters`);
        # they are not plain attributes of the module.
        betas = (self.hparams.b1, self.hparams.b2)
        return torch.optim.AdamW(self.parameters(), lr=self.lr, betas=betas)


# Train
The training sections for each model follow. To train a model properly, run the corresponding cells in this order:
1) Import
2) Dataset 
3) Model

## GAN

In [None]:
# Utilities
checkpoint_path = './GAN/checkpoints'
os.makedirs(checkpoint_path, exist_ok=True)

# Save a checkpoint for every epoch (save_top_k=-1) plus the last one.
# No metric named 'auc' is logged during fitting (the model only logs 'auroc'
# at test time), so there is nothing to monitor; monitoring it with mode='min'
# would also have ranked a higher-is-better score the wrong way round.
checkpoint_callback = ModelCheckpoint(
    dirpath=checkpoint_path,
    filename="GAN_{epoch}_ckpt",
    monitor=None,
    save_last=True,
    save_top_k=-1,
)


In [None]:
# Hyperparameters
BATCH_SIZE = 1  # samples per batch, passed to the data module
epochs = 15  # maximum number of training epochs (Trainer max_epochs)
learning_rate = 0.0001  # learning rate passed to the model
subsampling = True  # train/test on a subset of the dataset
subsamples = 1000  # per-label cap used when subsampling is enabled
accelerator_enabled = True  # use CUDA when it is available
display = False  # forwarded to the model's `display` flag

In [None]:
# Initialize dataset
# The data module class is defined as `AnomalyDetectionDataModule_MNIST`.
data_module = AnomalyDetectionDataModule_MNIST(batch_size=BATCH_SIZE,
                                               subsampling=subsampling,
                                               subsamples=subsamples,
                                               normal_class=0)


In [None]:
# Initialize models and trainers
use_gpu = torch.cuda.is_available() and accelerator_enabled
device_name = "cuda" if use_gpu else "cpu"
print("GPU" if use_gpu else "CPU")

# The first positional parameter of GAN_SOTA is `latent_dim`, not the batch
# size, so BATCH_SIZE is no longer passed there.
model = GAN_SOTA(lr=learning_rate, display=display).to(device_name)
trainer = pl.Trainer(accelerator=device_name,
                     max_epochs=epochs,
                     callbacks=[TQDMProgressBar(),
                                checkpoint_callback]
                     )


In [None]:
# Train and save
trainer.fit(model, data_module)
# Save next to the per-epoch checkpoints (relative path, not the filesystem root).
trainer.save_checkpoint("./GAN/checkpoints/GAN.ckpt")
# f-string so the sample count and epoch count actually end up in the name.
torch.save(model.state_dict(), f"./GAN/GAN_{subsamples}samples_{epochs}ep")

# Testing
test_results = trainer.test(model, data_module)
print(test_results)

## Random guessing

In [None]:
# Create an instance of the data module (`normal_class` selects the normal class)
data_module = AnomalyDetectionDataModule_MNIST(batch_size=64, subsampling=True, subsamples=1000, normal_class=0)

# Download the data if needed, then set up the data module (prepare datasets)
data_module.prepare_data()
data_module.setup()

# Create dataloaders
train_dataloader = data_module.train_dataloader()
test_dataloader = data_module.test_dataloader()

random_guesser = RandomGuessing()
test_predictions = random_guesser.predict(test_dataloader)
print(test_predictions)
# Report the baseline's AUROC so it can be compared with the trained models.
print("Random guessing AUROC:", random_guesser.compute_auroc())

## Autoencoder

In [None]:
# Utilities
checkpoint_path = './AE/checkpoints'
os.makedirs(checkpoint_path, exist_ok=True)

# Save a checkpoint for every epoch (save_top_k=-1) plus the last one.
# The autoencoder never logs a metric named 'auc', so there is nothing to
# monitor; checkpoints are kept for all epochs regardless.
checkpoint_callback = ModelCheckpoint(
    dirpath=checkpoint_path,
    filename="AE_{epoch}_ckpt",
    monitor=None,
    save_last=True,
    save_top_k=-1,
)


In [None]:
# Hyperparameters
BATCH_SIZE = 1  # samples per batch, passed to the data module
epochs = 15  # maximum number of training epochs (Trainer max_epochs)
learning_rate = 0.0001  # learning rate passed to the model
subsampling = True  # train/test on a subset of the dataset
subsamples = 1000  # per-label cap used when subsampling is enabled
accelerator_enabled = True  # use CUDA when it is available
display = False  # display flag shared with the GAN section

In [None]:
# Initialize dataset
# The data module class is defined as `AnomalyDetectionDataModule_MNIST`.
data_module = AnomalyDetectionDataModule_MNIST(batch_size=BATCH_SIZE,
                                               subsampling=subsampling,
                                               subsamples=subsamples,
                                               normal_class=0)

In [None]:
# Initialize model and trainer

use_gpu = torch.cuda.is_available() and accelerator_enabled
device_name = "cuda" if use_gpu else "cpu"
print("GPU" if use_gpu else "CPU")

# Only `lr` is passed: the first positional parameter of AE_based is
# `latent_dim` (not the batch size), and the class as originally written has
# no `display` parameter.
model = AE_based(lr=learning_rate).to(device_name)
trainer = pl.Trainer(accelerator=device_name,
                     max_epochs=epochs,
                     callbacks=[TQDMProgressBar(),
                                checkpoint_callback]
                     )


In [None]:
# Trainer
trainer.fit(model, datamodule=data_module)
# Save next to the per-epoch checkpoints, under the autoencoder's own name.
trainer.save_checkpoint("./AE/checkpoints/AE.ckpt")
# f-string so the sample count and epoch count actually end up in the name.
torch.save(model.state_dict(), f"./AE/AE_{subsamples}samples_{epochs}ep")

# Testing
test_results = trainer.test(model, data_module)
print(test_results)