# Efficient Anomaly Detection in Industrial Images using Transformers with Dynamic Tanh

## Imports

In [None]:
from typing import Literal

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
import torch
from torch.utils.data import DataLoader

from modules import data, globals, vtae

## Globals

In [None]:
# ---------------------------------------------------------------------------
# Experiment configuration. Every value below is a plain module-level setting
# consumed by the later cells (data loading, model construction, training).
# ---------------------------------------------------------------------------

# Dataset: `dataset` and `product` select what data.get_loaders loads and also
# name the checkpoint directory and the TensorBoard run.
dataset: Literal['MVTech', 'BTAD'] = 'MVTech'
product: data.ProductType = 'bottle'
# Side lengths of the square (resize_dim, resize_dim) and (crop_dim, crop_dim)
# sizes handed to data.get_loaders; crop_dim is also the height/width of the
# image shape given to the model.
resize_dim: int = 550
crop_dim: int = 512

# Model: each value is forwarded to vtae.VTAE under the keyword of the same
# name, except patch_side, which is expanded to a square
# (patch_side, patch_side) patch shape.
patch_side: int = 64
latent_channels: int = 8
heads: int = 8
depth: int = 6
caps_per_patch: int = 64
caps_dim: int = 8
caps_iterations: int = 3
ff_dim: int = 1024
mdn_components: int = 150
noise: float = 0.2
# Three loss-term weights; which term each one scales is defined inside
# vtae.VTAE -- TODO confirm the order there before tuning.
loss_weights: tuple[float, float, float] = (5.0, 0.5, 1.0)
lr: float = 1e-4
weight_decay: float = 1e-4
# Also selects the run-name suffix ("dytanh" vs "layernorm") used by the logger.
use_dytanh: bool = False

# Training: `epochs` is the Trainer's max_epochs (early stopping may end the
# run sooner); `batch_size` is passed to data.get_loaders.
epochs: int = 400
batch_size: int = 8

## Data

In [None]:
# Declare the loader types up front: annotations cannot be attached to the
# targets of a tuple-unpacking assignment. Per these annotations, train and
# validation batches are 1-tuples of tensors and test batches are 2-tuples
# (presumably image plus ground truth -- confirm against modules.data).
train_loader: DataLoader[tuple[torch.Tensor]]
val_loader: DataLoader[tuple[torch.Tensor]]
test_loader: DataLoader[tuple[torch.Tensor, torch.Tensor]]

# Build the train / validation / test loaders for the selected dataset and
# product, using square resize and crop sizes.
train_loader, val_loader, test_loader = data.get_loaders(
    dataset,
    product,
    crop_dim=(crop_dim, crop_dim),
    resize_dim=(resize_dim, resize_dim),
    batch_size=batch_size,
)

## Network

In [None]:
# Instantiate the VTAE network from the configuration cell. The input is a
# 3-channel square image of side crop_dim, split into square patches of side
# patch_side; all remaining hyperparameters are forwarded unchanged.
model: vtae.VTAE = vtae.VTAE(
    image_shape=(3, crop_dim, crop_dim),
    patch_shape=(patch_side, patch_side),
    latent_channels=latent_channels,
    heads=heads,
    depth=depth,
    caps_per_patch=caps_per_patch,
    caps_dim=caps_dim,
    caps_iterations=caps_iterations,
    ff_dim=ff_dim,
    mdn_components=mdn_components,
    noise=noise,
    loss_weights=loss_weights,
    lr=lr,
    weight_decay=weight_decay,
    use_dytanh=use_dytanh,
)

## Train

In [None]:
# Logger: one TensorBoard run per dataset / product / normalisation variant.
norm_str: str = "dytanh" if use_dytanh else "layernorm"
logger: TensorBoardLogger = TensorBoardLogger(
    globals.LOG_DIR,
    name=f"{dataset}_{product}_{norm_str}",
)

# Callbacks: both watch 'val_combined_loss'. No `mode` is given, so the
# library default applies (expected to be 'min' -- confirm for the installed
# Lightning version).
# Stop once the monitored loss has gone 10 validation checks without improving.
early_stopping: EarlyStopping = EarlyStopping(
    monitor='val_combined_loss',
    patience=10,
)
# Keep only the single best checkpoint.
# NOTE(review): the directory is keyed by dataset/product only, so DyT and
# LayerNorm runs write into the same folder, unlike the logger name above.
model_checkpoint: ModelCheckpoint = ModelCheckpoint(
    monitor='val_combined_loss',
    dirpath=globals.CHECKPOINT_DIR / dataset / product,
    save_top_k=1,
)

# Train the model with mixed 16-bit precision. log_every_n_steps is set to the
# number of training batches, i.e. the logging interval is one epoch's worth
# of steps.
trainer: Trainer = Trainer(
    max_epochs=epochs,
    callbacks=[early_stopping, model_checkpoint],
    precision='16-mixed',
    log_every_n_steps=len(train_loader),
    logger=logger,
)
trainer.fit(model, train_loader, val_loader)

## Evaluation

In [None]:
# Evaluate the best checkpoint saved by the ModelCheckpoint callback (the one
# with the best monitored 'val_combined_loss') instead of the weights left in
# memory when training stopped. When early stopping fires, the in-memory
# weights are `patience` epochs past the best validation result.
# Requires that `trainer.fit` has been run with this trainer in the current
# session so that a best checkpoint path is recorded.
trainer.test(model, test_loader, ckpt_path = 'best')