In [1]:
import torch
import torchvision
from torchvision import transforms, datasets
import torch.nn as nn
from torchmetrics import Accuracy
from torch.utils.data import DataLoader, ConcatDataset
from sklearn.model_selection import train_test_split
from torchsummary import summary
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure

import os
from torchvision.transforms import ToPILImage
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

import wandb
from kaggle_secrets import UserSecretsClient
import pytorch_lightning as pl
from dataclasses import dataclass
from typing import Tuple
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

In [2]:
@dataclass
class Configure:
    """
    Args:
        data_dir : The path to the directory where the MNIST dataset is stored. Defaults to the value of
            the 'PATH_DATASETS' environment variable or '.' if not set.

        batch_size : The batch size to use during training. Defaults to 256 if a GPU is available,
            or 64 otherwise.

        max_epochs : The maximum number of epochs to train the model for. Defaults to 3.

        accelerator : The accelerator to use for training. Can be one of "cpu", "gpu", "tpu", "ipu", "auto".

        devices : The number of devices to use for training. Defaults to 1.
    """

    data_dir : str = '/kaggle/input/mvtec-ad/bottle/train'
    # writing_dir : str = '/kaggle/working'
    save_dir : str = '/kaggle/working/augmented_data'
    test_dir : str = '/kaggle/input/mvtec-ad/bottle/test'
    batch_size : int = 32 if torch.cuda.is_available() else 8
    max_epochs : int = 51
    accelerator : str = 'auto'
    devices : int = 1

config = Configure()

In [10]:
class AnomalyDetector(pl.LightningModule):

    def __init__(self, data_dir : str = config.data_dir, # writing_dir : str = config.writing_dir, 
                 save_dir : str = config.save_dir,
                 test_dir : str = config.test_dir, accelerator : str = config.accelerator, 
                 learning_rate : float = 0.01, alpha : float = 0.55, beta : float = 0.0, 
                 gamma : float = 0.45, validation_size : float = 0.6, 
                 batch_size : int = config.batch_size, num_workers : int = 2, 
                 weight_decay : float = 0.0, threshold_percentile : int = 95):

        super().__init__()

        self.data_dir = data_dir   
        # self.writing_dir = writing_dir  
        self.save_dir = save_dir
        self.test_dir = test_dir
        self.learning_rate = learning_rate
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.val_size = validation_size
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.weight_decay = weight_decay
        self.threshold_percentile = threshold_percentile
        if accelerator == 'auto':
            self.my_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.my_device = torch.device(accelerator)

        self.val_accuracy = Accuracy(task='multiclass', num_classes=2)
        self.test_accuracy = Accuracy(task='multiclass', num_classes=2)
        self.pixel_errors_list = []
        self.pixel_threshold = 0
        self.example_input_image = None
        self.example_recon_image = None

        self.transform = torchvision.transforms.Compose([
            transforms.Resize((256,256)),
            transforms.ToTensor()
        ])

        self.augment_transform = torchvision.transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
        ])
        
        self.Encoder = nn.Sequential(
            
            nn.Conv2d(3, 8, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),   
            nn.Dropout(p=.3),                        # (3x256x256) -> (8x128x128)

            nn.Conv2d(8, 16, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(p=.3),                       # (8x128x128) -> (16x64x64)

            nn.Conv2d(16, 32, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),   
            nn.Dropout(p=.3),                      # (16x64x64) -> (32x32x32)

            nn.Conv2d(32, 32, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(p=.3),                      # (32x32x32) -> (32x16x16)
            
            nn.Flatten()
        )
    
        self.Bottleneck = nn.Sequential(
            
            nn.Linear(32*16*16, 1024),
            nn.ReLU(),
            # nn.Dropout(p=.3),
            
            nn.Linear(1024, 1024),
            nn.ReLU(),
            # nn.Dropout(p=.3),
            
            nn.Linear(1024, 32*16*16),
            nn.ReLU(),
            # nn.Dropout(p=.3),
            
            nn.Unflatten(-1,(32,16,16))
        )

        self.Decoder = nn.Sequential(
            
            nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2, padding=1, output_padding=0),
            nn.ReLU(),
            # nn.Dropout(p=.3),       # (32x16x16) -> (32x32x32)
            
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1, output_padding=0),
            nn.ReLU(),
            # nn.Dropout(p=.3),       # (32x32x32) -> (16x64x64)

            nn.ConvTranspose2d(16, 8, kernel_size=4, stride=2, padding=1, output_padding=0),
            nn.ReLU(),
            # nn.Dropout(p=.3),       # (16x64x64) -> (8x128x128)

            nn.ConvTranspose2d(8, 3, kernel_size=4, stride=2, padding=1, output_padding=0),
            nn.Sigmoid()
            # nn.ReLU(),
            # nn.Dropout(p=.3)        # (8x128x128) -> (3x256x256)
        )

    def forward(self, x : torch.tensor) -> torch.tensor:
        x = self.Encoder(x)
        x = self.Bottleneck(x)
        x = self.Decoder(x)
        return x

    def data_augmentation(self) -> None:
        num_augs = 5  # Number of augmented versions per image
        dataset = datasets.ImageFolder(root=self.data_dir)
        # Save augmentations
        for idx, (img, label) in tqdm(enumerate(dataset), total=len(dataset)):
            class_name = dataset.classes[label]
            input_path = dataset.imgs[idx][0]
            # Output class folder
            class_out_dir = os.path.join(self.save_dir, class_name)
            os.makedirs(class_out_dir, exist_ok=True)

            for i in range(num_augs):
               aug_img = self.augment_transform(img)
               aug_img = ToPILImage()(transforms.ToTensor()(aug_img))  # Reconvert safely to PIL
               base_name = os.path.basename(input_path).split('.')[0]
               save_path = os.path.join(class_out_dir, f"{base_name}_aug{i}.jpg")
               aug_img.save(save_path)

    def compute_loss(self, reconstruction, actual, reduction : str = 'none'):
        
        batch_size = actual.size(0)
        ssim_loss_fn = StructuralSimilarityIndexMeasure(data_range=1.0).to(self.my_device)
        ssim_vals = torch.zeros(batch_size, device = self.my_device)

        for i in range(batch_size):
            ssim_vals[i] = 1 - ssim_loss_fn(reconstruction[i].unsqueeze(0), actual[i].unsqueeze(0))
            
        mse_vals = torch.mean(nn.functional.mse_loss(reconstruction, actual, reduction='none'), dim=[1,2,3])
        l1_vals = torch.mean(nn.functional.l1_loss(reconstruction, actual, reduction='none'), dim=[1,2,3])
        
        total = self.alpha * ssim_vals + self.beta * mse_vals + self.gamma * l1_vals
        
        if reduction == 'mean':
            return total.mean()
        return total


    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr = self.learning_rate)
        scheduler = {
            'scheduler' : torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=5, factor=0.5),
            'monitor' : 'Validation accuracy',
            'interval' : 'epoch',
            'frequency' : 1
        }
        return {'optimizer':optimizer, 'lr_scheduler':scheduler}

    def prepare_data(self) -> None:
        torchvision.datasets.ImageFolder(root = self.data_dir, transform = self.transform)
        torchvision.datasets.ImageFolder(root = self.save_dir, transform = self.transform)
        torchvision.datasets.ImageFolder(root = self.test_dir, transform=self.transform)

    def setup(self, stage : str = None) -> None:

        raw_dataset = torchvision.datasets.ImageFolder(root = self.data_dir, transform = self.transform)
        augment_dataset = torchvision.datasets.ImageFolder(root = self.save_dir, transform = self.transform)
        self.train_dataset = ConcatDataset([raw_dataset, augment_dataset])

        label_remap = lambda y: 0 if y < 3 else 1
        dataset = torchvision.datasets.ImageFolder(root=self.test_dir, transform=self.transform, target_transform=label_remap)
        self.test_dataset, self.val_dataset = train_test_split(dataset, test_size=self.val_size)

    def training_step(self, batch : Tuple[torch.tensor, torch.tensor], batch_idx : int) -> torch.tensor:
        actual, _ = batch
        reconstructed = self(actual)

        if batch_idx==0 and self.example_input_image == None:
            self.example_input_image = actual[0].detach().cpu()
            self.example_recon_image = reconstructed[0].detach().cpu()
        
        loss = self.compute_loss(reconstructed, actual, reduction='mean')
        self.log('Training loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch : Tuple[torch.tensor, torch.tensor], batch_idx : int) -> None:
        actual, y = batch
        reconstructed = self(actual)
        pixel_errors = (reconstructed - actual) ** 2
        self.pixel_errors_list.append(pixel_errors.cpu())
        loss = self.compute_loss(reconstructed, actual)
        flags = (loss < self.pixel_threshold).int()
        self.val_accuracy.update(flags, y)
        self.log('Validation accuracy', self.val_accuracy, prog_bar=True)

    def test_step(self, batch : Tuple[torch.tensor, torch.tensor], batch_idx : int) -> None:
        actual, y = batch
        reconstructed = self(actual)
        loss = self.compute_loss(reconstructed, actual, reduction='none')
        flags = (loss < self.pixel_threshold).int()
        self.test_accuracy.update(flags, y)

    def on_train_epoch_start(self):
        self.example_input_image = None
        self.example_recon_image = None
    
    def on_train_epoch_end(self):
        if self.example_input_image is not None:

            input_img_np = self.example_input_image.permute(1,2,0).numpy()  # (C,H,W) -> (H,W,C)
            recon_img_np = self.example_recon_image.permute(1,2,0).numpy() 

            input_img = wandb.Image(input_img_np, caption='Input Image')
            recon_img = wandb.Image(recon_img_np, caption='Reconstructed Image')

            self.logger.experiment.log({
                'Input vs Reconstruction' : [input_img, recon_img],
                'Epoch' : self.current_epoch
            })
    
    def on_validation_epoch_end(self):
        all_pixel_errors = torch.cat(self.pixel_errors_list, dim=0)
        all_pixel_errors = all_pixel_errors.numpy().flatten()

        self.pixel_threshold = np.percentile(all_pixel_errors, self.threshold_percentile)
        self.pixel_errors_list.clear()
    
    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, pin_memory=True)
    
    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, pin_memory=True)
    
    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, pin_memory=True)

In [12]:
user_secrets = UserSecretsClient()
secret_value = user_secrets.get_secret("wandb_api_key")
os.environ["WANDB_API_KEY"] = secret_value
wandb.login()

True

In [13]:
early_stop_callback = EarlyStopping(
    monitor = 'Validation accuracy',
    min_delta = 0.01,
    patience = 5,
    verbose = 1,
    mode = 'max'
)

checkpoint_callback = ModelCheckpoint(dirpath = '/kaggle/working/checkpoints', save_top_k=1, monitor = 'Validation accuracy', mode = 'max')

In [14]:
wandb_logger = WandbLogger(project="Anomaly Detector")
model = AnomalyDetector()
trainer = pl.Trainer(
        accelerator=config.accelerator,
        devices=config.devices,
        max_epochs=config.max_epochs,
        logger=wandb_logger,
        callbacks = [early_stop_callback, checkpoint_callback]
)

In [7]:
model.data_augmentation()

100%|██████████| 209/209 [01:39<00:00,  2.10it/s]


In [15]:
trainer.fit(model)

/usr/local/lib/python3.11/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /kaggle/working/checkpoints exists and is not empty.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.11/dist-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (40) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [16]:
model.pixel_threshold

0.037598684430122375

In [17]:
trainer.test(model, dataloaders=model.val_dataloader(), verbose=1)

Testing: |          | 0/? [00:00<?, ?it/s]

[{}]

In [18]:
model.test_accuracy.compute()

tensor(0.8800)

In [19]:
wandb.finish()

0,1
Epoch,▁▂▄▅▇█
Training loss,█▃▁▁
Validation accuracy,█▁▁███
epoch,▁▂▂▄▄▅▅▇▇█
trainer/global_step,▁▁▁▂▂▃▄▄▅▅▅▇▇▇██

0,1
Epoch,5.0
Training loss,0.15551
Validation accuracy,0.88
epoch,5.0
trainer/global_step,239.0
