# Imports

In [None]:
from source.dataLoader import load_dataset
from source.utils import save_visualizations, save_results_to_csv, plot_all_categories_with_images, create_lr_scheduler
from source.train import train_model
from source.evaluate import evaluate_model, evaluate_ANOViT
from source.losses import mse_loss, ANOViT_loss
from source.models import ADTR, ADTR_FPN, ANOVit
import config
import torch
from torchvision import transforms
from torch.optim import AdamW
import gc
import os
import kagglehub

# Download Datasets (if needed)

In [None]:
# Resolve where the dataset lives. If downloading is enabled and the configured
# local copy is missing, fetch it from Kaggle; otherwise reuse the configured
# local root. `dataset_path` is therefore always defined for the cells below.
# (When DOWNLOAD_DATASET is False the training cell uses the configured root
# directly, so downloading in that case would only waste bandwidth/disk.)
local_root = config.MVTEC_ROOT if config.DATASET_TO_USE == 'mvtec' else config.BTAD_ROOT
dataset_path = local_root

if config.DOWNLOAD_DATASET and not os.path.exists(local_root):
    if config.DATASET_TO_USE == 'mvtec':
        dataset_path = kagglehub.dataset_download(config.MVTEC_KAGGLE_DOWNLOAD_URL)
    elif config.DATASET_TO_USE == 'btad':
        # The usable BTAD data is in the "BTech_Dataset_transformed" subfolder of the download.
        dataset_path = os.path.join(
            kagglehub.dataset_download(config.BTAD_KAGGLE_DOWNLOAD_URL),
            "BTech_Dataset_transformed",
        )

print(f"Dataset path: {dataset_path}")

# Free GPU memory (VRAM) between categories

In [None]:
def cleanup(names=("model", "optimizer", "scheduler")):
    """Drop the per-category training objects and release cached GPU memory.

    The training loop binds ``model``, ``optimizer`` and ``scheduler`` at
    notebook (module) scope. A bare ``del model`` inside this function would
    target a *local* name: it raises ``UnboundLocalError`` (a ``NameError``
    subclass) and leaves the global references alive, so nothing is freed.
    The names are therefore removed from ``globals()`` directly.

    Args:
        names: Global variable names to delete if present. ``scheduler`` is
            included by default because an LR scheduler typically holds a
            reference to its optimizer, which in turn references the model
            parameters.
    """
    print("\n--- Cleaning VRAM ---")
    namespace = globals()
    deleted = [name for name in names if name in namespace]
    for name in deleted:
        del namespace[name]

    if deleted:
        print(f"Deleted objects: {', '.join(deleted)}.")
    else:
        print("Model and optimizer objects were not found for deletion (already deleted or out of scope).")

    # Collect first so that tensors reachable only through reference cycles are
    # released before the CUDA caching allocator is asked to return memory.
    gc.collect()
    print("Garbage collection triggered.")

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("PyTorch CUDA cache cleared.")

    print("\nVRAM cleaning process complete.")

# Setup/Train/Evaluate model

In [None]:
torch.manual_seed(config.RANDOM_SEED)

# Use the Kaggle download location when downloading is enabled and the download
# cell produced a path; otherwise fall back to the configured local root. The
# download cell may not define `dataset_path` (e.g. when the data already exists
# locally), so look it up defensively instead of referencing it directly.
local_root = config.MVTEC_ROOT if config.DATASET_TO_USE == 'mvtec' else config.BTAD_ROOT
downloaded_root = globals().get("dataset_path")
DATASET = downloaded_root if config.DOWNLOAD_DATASET and downloaded_root else local_root

# Ground-truth masks are resized with nearest-neighbour interpolation so mask
# values are not blended at region edges (the default bilinear resize would
# produce fractional values along mask borders).
transform_gt = transforms.Compose([
    transforms.Resize(config.IMG_SIZE, interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
])

# ANOVit expects inputs normalised to [-1, 1]; the other models consume the
# plain ToTensor output. This is the only difference between the two pipelines.
input_normalization = (
    [transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])]  # put in the range [-1, 1]
    if config.MODEL == 'ANOVit'
    else []
)

# Light augmentation for training images only.
transform_train = transforms.Compose([
    transforms.Resize(config.IMG_SIZE),
    transforms.RandomHorizontalFlip(p=0.3),
    transforms.RandomRotation(3),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.ToTensor(),
    *input_normalization,
])

transform_test = transforms.Compose([
    transforms.Resize(config.IMG_SIZE),
    transforms.ToTensor(),
    *input_normalization,
])

categories = config.MVTEC_CATEGORIES if config.DATASET_TO_USE == 'mvtec' else config.BTAD_CATEGORIES

# Train (optionally) and evaluate one independent model per category.
for category_idx, category in enumerate(categories):
    print(f"Processing category: {category}")

    # Load the dataset
    train_loader, val_loader, test_loader = load_dataset(
        main_path=DATASET,
        transform_train=transform_train,
        transform_gt=transform_gt,
        transform_test=transform_test,
        class_selected=category
    )

    print("Instantiate model, optimizer, scheduler, and loss function")
    # select model based on configuration
    if config.MODEL == 'ADTR_FPN':
        model = ADTR_FPN.ADTR_FPN(
            in_channels=512*4,
            out_channels_fpn=512,
            transformer_dim=768
            ).to(config.DEVICE)
        optimizer = AdamW(model.parameters(), lr=config.LR, weight_decay=config.WEIGHT_DECAY)
        num_epochs = config.EPOCHS
        num_train_steps = len(train_loader) * num_epochs
        # One epoch of warm-up steps.
        warmup_steps = len(train_loader)
        scheduler = create_lr_scheduler(optimizer, num_train_steps, warmup_steps)
        criterion = mse_loss
    elif config.MODEL == 'ADTR':
        model = ADTR.ADTR(use_dyt=config.USE_DYT).to(config.DEVICE)
        optimizer = AdamW(model.parameters(), lr=config.LR, weight_decay=config.WEIGHT_DECAY)
        scheduler = None
        criterion = mse_loss
    elif config.MODEL == 'ANOVit':
        model = ANOVit.ANOVit(
            config.D_MODEL,
            img_size=config.IMG_SIZE,
            patch_size=config.PATCH_SIZE,
            n_channels=config.N_CHANNELS,
            n_heads=config.N_HEADS,
            n_layers=config.N_LAYERS,
            use_DyT=config.USE_DYT
        ).to(config.DEVICE)
        optimizer = AdamW(model.parameters(), lr=config.LR, weight_decay=config.WEIGHT_DECAY)
        scheduler = None
        criterion = ANOViT_loss
    else:
        raise ValueError(f"Model {config.MODEL} is not supported.")

    # Report the parameter count once; the architecture is the same for every category.
    if category_idx == 0:
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Total trainable parameters: {total_params}")

    # Single source of truth for the checkpoint location (used for both load and save).
    checkpoint_path = os.path.join(config.CHECKPOINT_DIR, f"{config.MODEL}_{category}.pth")

    if config.LOAD_WEIGHTS:
        print(f"Loading weights for {config.MODEL} in category {category}")
        if os.path.exists(checkpoint_path):
            # map_location lets a checkpoint saved on one device (e.g. CUDA)
            # be restored on the currently configured device (e.g. CPU-only).
            model.load_state_dict(torch.load(checkpoint_path, map_location=config.DEVICE))
        elif config.TRAIN_MODEL:
            print(f"Checkpoint for {config.MODEL} in category {category} does not exist. Training from scratch.")
        else:
            print(f"Checkpoint for {config.MODEL} in category {category} does not exist. Evaluating untrained weights.")

    # Train the model
    if config.TRAIN_MODEL:
        train_model(model, train_loader, val_loader, optimizer, criterion, scheduler=scheduler)
        os.makedirs(config.CHECKPOINT_DIR, exist_ok=True)
        torch.save(model.state_dict(), checkpoint_path)

    # Evaluate the model
    if config.MODEL == 'ANOVit':
        image_auroc, pixel_auroc, pixel_aupr, accuracy, f1 = evaluate_ANOViT(model, test_loader, config.DEVICE)
    else:
        image_auroc, pixel_auroc, pixel_aupr, accuracy, f1 = evaluate_model(model, test_loader)

    # Save visualizations
    path_images = save_visualizations(model, test_loader, category)

    save_results_to_csv(
        category_name=category,
        image_auroc=image_auroc,
        pixel_auroc=pixel_auroc,
        pixel_aupr=pixel_aupr,
        accuracy=accuracy,
        f1=f1,
        path_images=path_images
    )

    # Release this category's model/optimizer before building the next one.
    cleanup()

# Plot Results

In [None]:
# Choose which sample images to show, depending on the dataset in use,
# then plot every category's results from the model's results CSV.
if config.DATASET_TO_USE == 'mvtec':
    IMG_TO_PLOT = config.IMAGE_TO_PLOT_MVTEC
else:
    IMG_TO_PLOT = config.IMAGE_TO_PLOT_BTAD

results_csv = f"{config.RESULT_FOLDER}/{config.MODEL}_results.csv"
plot_all_categories_with_images(results_csv, img_to_plot=IMG_TO_PLOT, save_path=None)

# Delete Downloaded Dataset (KaggleHub cache)

In [None]:
if config.DOWNLOAD_DATASET and config.DELETE_CACHE_DATASET:
    # Delete KaggleHub cache to free up space.
    # shutil is used instead of the IPython `%rm -rf` magic so this cell is
    # valid plain Python and also works on platforms without `rm` (e.g. Windows).
    import shutil

    # kagglehub honours the KAGGLEHUB_CACHE environment variable; otherwise it
    # uses ~/.cache/kagglehub (the path the original cell removed).
    kagglehub_cache = os.environ.get("KAGGLEHUB_CACHE", os.path.expanduser("~/.cache/kagglehub"))
    print("\n--- Deleting KaggleHub cache ---")
    # ignore_errors mirrors `rm -rf`: a missing directory is not an error.
    shutil.rmtree(kagglehub_cache, ignore_errors=True)