# Efficient Anomaly Detection in Industrial Images using Transformers with Dynamic Tanh

## Imports

In [None]:
import time
from typing import Literal

from lightning.fabric.utilities.throughput import measure_flops
import lightning.pytorch as pl
from lightning.pytorch.loggers import TensorBoardLogger
from thop import profile
import torch
from torch.utils.data import DataLoader

from modules import data, globals, vtae

## Globals

In [None]:
# Dataset
# `dataset` and `product` are handed to data.get_loaders and are also used as
# path components for the log and checkpoint directories further down.
dataset: Literal['MVTech', 'BTAD'] = 'MVTech'
product: data.ProductType = 'bottle'
# Side lengths (in pixels, presumably -- confirm in data.get_loaders) of the
# square `resize_dim` and `crop_dim` tuples passed to data.get_loaders.
# `crop_dim` is also the height/width of the `image_shape` given to the model.
resize_dim: int = 550
crop_dim: int = 512

# Model
# Every value in this group is forwarded to the vtae.VTAE constructor; unless
# noted otherwise it is passed under the keyword of the same name.
# `patch_side` is passed as the square `patch_shape = (patch_side, patch_side)`.
patch_side: int = 64
latent_channels: int = 8
heads: int = 8
depth: int = 2
caps_per_patch: int = 32
caps_dim: int = 8
caps_iterations: int = 1
ff_dim: int = 512
mdn_components: int = 50
noise: float = 0.2
# NOTE(review): the meaning/order of the three weights is defined in vtae.VTAE,
# not here -- check that class before changing them.
loss_weights: tuple[float, float, float] = (5., 0.5, 1.)
lr: float = 1e-4
weight_decay: float = 1e-4
# Also selects the run name ('vtae_dytanh' vs 'vtae_ln'); judging by those
# names it presumably swaps LayerNorm for Dynamic Tanh -- confirm in vtae.VTAE.
use_dytanh: bool = True

# Training
# `epochs` becomes the Trainer's `max_epochs`; `batch_size` goes to
# data.get_loaders.
epochs: int = 400
batch_size: int = 8

In [None]:
# Run name: identifies the normalisation variant in the log and checkpoint paths
run_name: str
if use_dytanh:
    run_name = 'vtae_dytanh'
else:
    run_name = 'vtae_ln'

In [None]:
# Seed the random number generators through Lightning so that runs are repeatable
pl.seed_everything(seed = 42, workers = True)

## Data

In [None]:
# Declared types of the three loaders: train/val batches are 1-tuples of
# tensors, test batches are 2-tuples of tensors.
train_loader: DataLoader[tuple[torch.Tensor]]
val_loader: DataLoader[tuple[torch.Tensor]]
test_loader: DataLoader[tuple[torch.Tensor, torch.Tensor]]

# Build all three loaders in one call; resize and crop sizes are square.
train_loader, val_loader, test_loader = data.get_loaders(
    dataset,
    product,
    crop_dim = (crop_dim, crop_dim),
    resize_dim = (resize_dim, resize_dim),
    batch_size = batch_size,
)

## Network

In [None]:
# Instantiate the network from the hyper-parameters defined in the Globals cell.
# Input images are 3-channel squares of side `crop_dim`, split into square patches.
model: vtae.VTAE = vtae.VTAE(
    image_shape = (3, crop_dim, crop_dim),
    patch_shape = (patch_side, patch_side),
    latent_channels = latent_channels,
    heads = heads,
    depth = depth,
    caps_per_patch = caps_per_patch,
    caps_dim = caps_dim,
    caps_iterations = caps_iterations,
    ff_dim = ff_dim,
    mdn_components = mdn_components,
    noise = noise,
    loss_weights = loss_weights,
    lr = lr,
    weight_decay = weight_decay,
    use_dytanh = use_dytanh,
)

## Train

In [None]:
# TensorBoard logger: logs are grouped by dataset and product, one run name per model variant
logger: TensorBoardLogger = TensorBoardLogger(save_dir = globals.LOG_DIR / dataset / product,
                                              name = run_name
                                              )

# Configure the trainer and fit the model
trainer: pl.Trainer = pl.Trainer(
    logger = logger,
    max_epochs = epochs,
    precision = 'bf16-mixed',
    log_every_n_steps = len(train_loader),   # one logging step per full pass over train_loader
    enable_checkpointing = False,            # In order to not alter the times, save only at the end
)
trainer.fit(model, train_dataloaders = train_loader, val_dataloaders = val_loader)

In [None]:
# Persist the trained model once, after fitting (checkpointing is disabled during training)
checkpoint_path = globals.CHECKPOINT_DIR / dataset / product / f'{run_name}.ckpt'
trainer.save_checkpoint(checkpoint_path)

## Evaluation

In [None]:
# Reload the saved model, passing the detection threshold as an extra keyword.
# The value below is only a placeholder: replace it with the best threshold found.
threshold: float = 0.5  # placeholder
checkpoint_path = globals.CHECKPOINT_DIR / dataset / product / f'{run_name}.ckpt'
model = vtae.VTAE.load_from_checkpoint(checkpoint_path, threshold = threshold)

In [None]:
# Evaluate on the test set with a fresh trainer that does not log anything
trainer = pl.Trainer(logger = False)
trainer.test(model = model, dataloaders = test_loader)

In [None]:
# Get the inference time
#
# Runs a few warm-up forward passes on one test batch, then times a fixed
# number of forward passes and prints the mean latency per pass in milliseconds.

# Benchmark settings
warmup_runs: int = 5
timed_runs: int = 100

# Prepare for inference
model.eval()
x: torch.Tensor = next(iter(test_loader))[0]

# Cuda setup (availability is queried once and reused below)
use_cuda: bool = torch.cuda.is_available()
if use_cuda:
    model.cuda()
    x = x.cuda()

with torch.no_grad():
    # Warm-up
    for _ in range(warmup_runs):
        model(x)

    # Measure inference time
    # CUDA work is queued asynchronously, so wait for pending kernels before
    # starting and before stopping the clock.
    if use_cuda:
        torch.cuda.synchronize()
    # perf_counter is monotonic and high-resolution, unlike the wall clock time.time()
    start_time: float = time.perf_counter()

    for _ in range(timed_runs):
        model(x)

    if use_cuda:
        torch.cuda.synchronize()
    total_time: float = time.perf_counter() - start_time

# seconds for all runs -> milliseconds per run
print(f"Mean inference time: {total_time / timed_runs * 1e3:.2f} ms")

In [None]:
# Get the flops
# Only a forward pass is counted, so run it without autograd (as in the timing
# cell) to avoid building a graph that is never used.
# Relies on `model` (in eval mode) and the batch `x` prepared in the timing cell.
flops: int
with torch.no_grad():
    flops = measure_flops(model, lambda: model(x))
print(f"GFLOPS: {flops / 1e9:.2f}")

# Get the MACs and parameters
macs: int
params: int
macs, params = profile(model, inputs = (x,), verbose = False)   # type: ignore
print(f"GMACs: {macs / 1e9:.2f}")
print(f"Parameters: {params / 1e6:.2f} M")