# Efficient Anomaly Detection in Industrial Images using Transformers with Dynamic Tanh

## Imports

In [1]:
from typing import Literal

from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
import torch
from torch.utils.data import DataLoader

from modules import data, globals, vtae

## Globals

In [2]:
# Dataset
# `dataset` and `product` select the data and also name the checkpoint/log directories.
dataset: Literal['MVTech', 'BTAD'] = 'MVTech'
product: data.ProductType = 'hazelnut'
resize_dim: int = 550  # passed to data.get_loaders as resize_dim = (resize_dim, resize_dim)
crop_dim: int = 512    # passed to data.get_loaders as crop_dim = (crop_dim, crop_dim) and to VTAE's image_shape

# Model (all forwarded to vtae.VTAE; exact semantics live in modules/vtae)
patch_side: int = 64         # side of the square patch_shape given to VTAE
latent_channels: int = 8
heads: int = 8
depth: int = 6
caps_per_patch: int = 64
caps_dim: int = 8
caps_iterations: int = 3
ff_dim: int = 1024
mdn_components: int = 150    # presumably number of mixture-density components — confirm in modules/vtae
noise: float = 0.2           # presumably input-noise strength during training — confirm in modules/vtae
lr: float = 1e-4
weight_decay: float = 1e-4

# Training
epochs: int = 400            # upper bound; early stopping may end training sooner
batch_size: int = 8          # shared by train/val/test loaders
seed: int = 42

# Seed Python, NumPy and torch RNGs (and dataloader workers) so that weight init,
# shuffling and noise sampling are reproducible across Restart & Run All.
seed_everything(seed, workers = True)

## Data

In [3]:
# Typed placeholders: train/val batches hold a single tensor, while test batches pair
# two tensors (presumably images and ground-truth labels/masks — TODO confirm in modules/data).
train_loader: DataLoader[tuple[torch.Tensor]]
val_loader: DataLoader[tuple[torch.Tensor]]
test_loader: DataLoader[tuple[torch.Tensor, torch.Tensor]]

# Use the configured `dataset`/`product` instead of hard-coded literals. Otherwise the data
# actually loaded could disagree with the checkpoint directory and logger name built from them.
train_loader, val_loader, test_loader = data.get_loaders(dataset,
                                                         product,
                                                         crop_dim = (crop_dim, crop_dim),
                                                         resize_dim = (resize_dim, resize_dim),
                                                         batch_size = batch_size
                                                         )

## Network

In [4]:
# Gather every VTAE constructor argument in one mapping so the full configuration can be
# inspected in a single place, then build the model from it.
# The input is a 3-channel crop_dim x crop_dim image, split into square patch_side patches.
vtae_kwargs: dict[str, object] = dict(
    image_shape = (3, crop_dim, crop_dim),
    patch_shape = (patch_side, patch_side),
    latent_channels = latent_channels,
    heads = heads,
    depth = depth,
    caps_per_patch = caps_per_patch,
    caps_dim = caps_dim,
    caps_iterations = caps_iterations,
    ff_dim = ff_dim,
    mdn_components = mdn_components,
    noise = noise,
    lr = lr,
    weight_decay = weight_decay,
)
model: vtae.VTAE = vtae.VTAE(**vtae_kwargs)



## Train

In [5]:
# Both callbacks watch the same validation metric (presumably logged by VTAE's validation step).
monitored_metric: str = 'val_combined_loss'
# Validation runs only every 5 epochs. EarlyStopping's `patience` counts validation checks,
# so patience = 4 means training stops after 4 * 5 = 20 epochs without improvement.
val_every_n_epochs: int = 5

# Callbacks
early_stopping: EarlyStopping = EarlyStopping(monitor = monitored_metric, patience = 4)
model_checkpoint: ModelCheckpoint = ModelCheckpoint(
    monitor = monitored_metric,
    dirpath = globals.CHECKPOINT_DIR / dataset / product,
    save_top_k = 1,  # keep only the best checkpoint
)

# Logger: one TensorBoard run name per dataset/product pair
logger: TensorBoardLogger = TensorBoardLogger(globals.LOG_DIR, name = f"{dataset}_{product}")

# Train the model
trainer: Trainer = Trainer(
    max_epochs = epochs,
    callbacks = [early_stopping, model_checkpoint],
    precision = '16-mixed',                  # automatic mixed precision
    check_val_every_n_epoch = val_every_n_epochs,
    log_every_n_steps = len(train_loader),   # roughly one log point per training epoch
    logger = logger,
)
trainer.fit(model, train_loader, val_loader)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('AMD Radeon RX 6800') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
2025-06-15 18:59:47.535593: E external/local_xla/xla/stream_executor/plugin_registry.cc:91] Invalid plugin kind specified: FFT
2025-06-15 18:59:47.574598: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-06-15 18:59:47

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/valerio/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
/home/valerio/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

## Evaluation

In [None]:
# Evaluate the best checkpoint selected by `model_checkpoint`, not the weights left in memory
# after the last epoch. With early stopping those can be several validation checks past the best.
# Fall back to the in-memory weights (ckpt_path = None) if no checkpoint was ever written,
# e.g. when training was interrupted before the first validation run.
best_ckpt: str | None = 'best' if model_checkpoint.best_model_path else None
trainer.test(model, test_loader, ckpt_path = best_ckpt)