# DiYT: Do it Yourself Transformer

### A visual transformer by Alessandro Massari and Matteo Pelliccione developed for Vision & Perceptron exam 2024-2025 - MSc AI and Robotics - Sapienza Università di Roma


A complete description of the project can be found in the following GitHub repository:
https://github.com/alessandromassari/vision-project-diyt

## Import

In [None]:
# Move into the Kaggle working directory, then recursively remove it.
# NOTE(review): `%rm -r` targets the directory the kernel has just entered,
# so anything previously written there is lost — confirm this is intended.
%cd /kaggle/working/
%rm  -r /kaggle/working/

In [None]:
# Install packages
!pip install pytorch-msssim 
!pip install thop

# IMPORT libraries
import kagglehub
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import random
import sys
import shutil
from PIL import Image
from tqdm import tqdm
from torchvision import transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from concurrent.futures import ThreadPoolExecutor as ThPE
from sklearn.model_selection import train_test_split
from torch.nn import Parameter
from torch import ones, zeros, tanh
from einops import repeat
import math
from pytorch_msssim import ssim
from sklearn.metrics import f1_score, roc_auc_score, roc_curve, precision_recall_curve
import seaborn as sns
from thop import profile

In [None]:
# Check if GPU is available and return the device
def hardware_check():
    """Report the available compute hardware and return the matching device.

    Prints the GPU name and total memory when CUDA is available; otherwise
    prints the CPU model (read from /proc/cpuinfo, so it is only populated on
    Linux) and the number of CPU cores.

    Returns:
        torch.device: ``cuda`` when a GPU is usable, ``cpu`` otherwise.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
        # Pass the device index explicitly: older PyTorch releases require the
        # `device` argument of get_device_properties().
        gpu_index = torch.cuda.current_device()
        gpu_props = torch.cuda.get_device_properties(gpu_index)
        print("GPU is available!")
        print(f"  -> GPU - {torch.cuda.get_device_name(gpu_index)}")
        print(f"  -> Total Memory: {gpu_props.total_memory / 1024**3:.2f} GB")
    else:
        device = torch.device("cpu")
        print("GPU is not available, using CPU.")
        print("\nCPU Information:")

        # The context manager closes the pipe once the output has been read.
        with os.popen("cat /proc/cpuinfo | grep \"model name\" | uniq") as pipe:
            cpu_model = pipe.read().strip()
        print(f"CPU Model: {cpu_model}")
        print(f"Number of CPU cores: {os.cpu_count()}")

    return device

# Select the compute device once, at notebook start-up, and report it.
device = hardware_check()
print(f"\nUsing {device} for computation")

## Utils

Here are all the utility functions used in this project. Remember to run them all.

* Data augmentation function
* Show dataset instances 
* Reconstruction error image plotting and saving if requested
* Get Ground truth mask function
* Plotter training loss
* AUPRO custom implementation
* plot_metric_curve

In [None]:
# Data Augmentation utility function
# Every transformation is applied on its own to the source image (they are not
# chained); some transformations are only enabled for specific classes.

def augment_image(source_image_path, dest_path, class_name, n_aug=1):
  """Write augmented copies of one image into ``dest_path``.

  For each of the ``n_aug`` rounds, every transformation in the list is applied
  separately to the original image and saved as
  ``<base_name>_aug_<round>_<transform_index>.png``.

  Args:
    source_image_path: path of the image to augment.
    dest_path: existing directory that receives the augmented PNG files.
    class_name: dataset class; a horizontal flip is added only for the
      classes listed below.
    n_aug: number of augmentation rounds.
  """
  base_name = os.path.splitext(os.path.basename(source_image_path))[0]

  # all class safe transformations
  crop_resize = transforms.Compose([transforms.CenterCrop(size=200),
            transforms.Resize((256,256))])
  transf = [transforms.RandomRotation(degrees=30),
            transforms.RandomPerspective(distortion_scale=0.2, p=0.8),
            crop_resize
  ]
  # specific classes transformations
  if class_name in ['carpet', 'grid', 'leather', 'wood', 'screw', 'tile']:
    transf.append(transforms.RandomHorizontalFlip(p=0.9))

  # The context manager releases the file handle; without it every call keeps
  # a descriptor open until garbage collection.
  with Image.open(source_image_path) as image:
    for i in range(n_aug):
      for j, t in enumerate(transf):
        aug_image = t(image)
        aug_image_name = f"{base_name}_aug_{i}_{j}.png"
        aug_image.save(os.path.join(dest_path, aug_image_name))

# ------------------------------------------------------------------------------------ #

# Show one instance from each class in a grid
def show_one_instance_per_class_grid(ds_path, folder='train', instance_type='good', grid_size=(3, 5), index_in_class=0):
    """Display one image per class folder of ``ds_path`` in a subplot grid.

    Args:
        ds_path: dataset root; every sub-directory is treated as a class.
        folder: 'train', 'test' or 'ground_truth'.
        instance_type: sub-folder of ``folder`` to read from ('good' or a
            defect name). With ``folder='train'`` only 'good' is supported.
        grid_size: (rows, cols) of the grid; classes beyond rows*cols are
            not shown.
        index_in_class: index (in sorted file order) of the image to show;
            falls back to 0 when out of range.
    """
    # Get a sorted list of class folders
    samples_folders = sorted(f for f in os.listdir(ds_path) if os.path.isdir(os.path.join(ds_path, f)))

    num_rows, num_cols = grid_size
    num_subplots = num_rows * num_cols

    # Limit the number of classes to display based on the grid size
    classes_to_display = samples_folders[:num_subplots]

    if not classes_to_display:
        print(f"No class folders found in {ds_path}")
        return

    # squeeze=False always returns a 2-D array of Axes, so flatten() also
    # works for a 1x1 grid (where plt.subplots would return a bare Axes).
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(num_cols * 3, num_rows * 3), squeeze=False)
    axes = axes.flatten()

    fig.suptitle(f"One Sample per Class (from {folder}/{instance_type})", fontsize=14)

    def _placeholder(ax, name, note):
        # Empty cell carrying only the class name and the reason it is empty.
        ax.set_title(f"{name}\n({note})")
        ax.axis("off")

    # Both supported combinations resolve to the same path layout.
    supported = (folder == 'train' and instance_type == 'good') or folder in ('test', 'ground_truth')
    image_exts = ('.bmp', '.png', '.jpg', '.jpeg')

    for ax, class_name in zip(axes, classes_to_display):
        if not supported:
            print(f"Warning: Unsupported folder '{folder}' and instance_type '{instance_type}' combination for path construction.")
            _placeholder(ax, class_name, "Path Error")
            continue

        source_folder = os.path.join(ds_path, class_name, folder, instance_type)
        if not os.path.exists(source_folder):
            print(f"Folder not found for class {class_name}: {source_folder}")
            _placeholder(ax, class_name, "Folder Not Found")
            continue

        image_files = sorted(f for f in os.listdir(source_folder) if f.lower().endswith(image_exts))
        if not image_files:
            print(f"No image files found in {source_folder}")
            _placeholder(ax, class_name, "No Images")
            continue

        if 0 <= index_in_class < len(image_files):
            chosen = image_files[index_in_class]
        else:
            print(f"Warning: Index {index_in_class} out of bounds for class {class_name}. Using index 0.")
            chosen = image_files[0]
        img_to_display_path = os.path.join(source_folder, chosen)

        try:
            # convert()/resize() return new in-memory images, so the file can
            # be closed as soon as the block exits.
            with Image.open(img_to_display_path) as image:
                resized_image = image.convert("RGB").resize((224, 224))
            ax.imshow(resized_image)
            ax.set_title(f"{class_name}\n({os.path.basename(img_to_display_path)})", fontsize=10)
            ax.axis("off")
        except Exception as e:
            # Best effort: one unreadable image must not abort the whole grid.
            print(f"Error loading or processing image for class {class_name} ({img_to_display_path}): {e}")
            _placeholder(ax, class_name, "Load Error")

    # Hide any unused subplots if there are fewer classes than grid spots
    for ax in axes[len(classes_to_display):]:
        ax.axis("off")

    plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # Adjust layout to prevent title overlap
    plt.show()

# ------------------------------------------------------------------------------------ #

# Utility function to plot reconstruction error in different classes
# Shows the input image, the pixel-wise error heatmap and the ground-truth mask;
# saves each figure when save_dir is specified

def visualize_with_gt(model,
                      dataloaders,
                      class_name,
                      device,
                      hyperparams,
                      root_folder,    # e.g. './mvtec_ad_validation' or './mvtec_ad_test'
                      num_images=3,
                      save_dir=None):   # pass save_dir as arg to save the image
    """Plot input, reconstruction-error map and ground-truth mask per sample.

    For every batch of ``dataloaders[class_name]["val"]`` the model output is
    folded back into full images, the squared error (averaged over channels)
    is computed per pixel, and up to ``num_images`` samples of that batch are
    displayed.

    Args:
        model: module returning a tuple whose first element holds the patch
            predictions (assumed shape (B, num_patches, p*p*C) — TODO confirm
            against the model definition).
        dataloaders: dict mapping class name -> {"val": loader, ...}; the
            loader must yield (images, paths) batches.
        class_name: class to visualise; also used to locate the defect type
            inside each image path.
        device: device the inputs are moved to.
        hyperparams: dict providing 'patch_size' and 'image_size'.
        root_folder: dataset root containing
            <class_name>/ground_truth/<defect>/<name>_mask.png.
        num_images: maximum number of samples shown per batch.
        save_dir: optional directory where every figure is saved as PNG.
    """
    model.eval()
    val_loader = dataloaders[class_name]["val"]
    p = hyperparams['patch_size']
    H = W = hyperparams['image_size']

    with torch.no_grad():
        for batch_idx, (inputs, paths) in enumerate(val_loader):
            inputs = inputs.to(device)
            pred, _ = model(inputs)

            # Full-image reconstruction from the per-patch predictions
            B, _, _ = pred.shape
            npd = H // p  # patches per image side
            recon = (pred.view(B, npd, npd, p, p, -1)
                         .permute(0,5,1,3,2,4)
                         .contiguous()
                         .view(B, -1, H, W))

            # Error map pixel-wise
            error_maps = ((recon - inputs)**2).mean(dim=1).cpu().numpy()

            for i in range(min(num_images, B)):
                img = inputs[i].cpu().permute(1,2,0).numpy()
                err = error_maps[i]

                # === locate the ground-truth mask ===
                # e.g. './mvtec_ad_validation/cable/poke_insulation/007.png'
                path = paths[i]
                fname = os.path.basename(path)
                parts = path.split('/')

                is_good_sample = 'good' in parts
                gt_mask = None # Initialize ground truth mask as None
                # Bound up-front so the figure title never hits an unbound
                # name when the path layout is unexpected.
                defect = 'good' if is_good_sample else 'unknown'

                if not is_good_sample:
                    # Try to construct and load ground truth mask path for non-'good' samples
                    try:
                        # The folder after class_name in the path is the defect type
                        class_idx = parts.index(class_name)
                        defect = parts[class_idx + 1]
                        base_fname = os.path.splitext(parts[-1])[0]
                        gt_path = os.path.join(
                            root_folder, # Use the passed root folder
                            class_name,
                            'ground_truth', # Ground truth masks are in this folder
                            defect,       # Under the defect type folder
                            f"{base_fname}_mask.png" # Mask filename convention
                        )

                        if os.path.exists(gt_path):
                            gt_mask = Image.open(gt_path)

                    except (ValueError, IndexError, FileNotFoundError) as e:
                        # Handle cases where path structure is unexpected or mask not found
                        print(f"Could not load ground truth mask for {path}: {e}")
                        gt_mask = None # Ensure gt_mask is None if loading fails
                    except Exception as e:
                        print(f"An unexpected error occurred loading mask for {path}: {e}")
                        gt_mask = None

                # Visualisation
                fig, axes = plt.subplots(1, 3, figsize=(15, 5))
                # Use defect type or 'good' in the title
                display_type = defect
                fig.suptitle(f"Class: {class_name} - Type: {display_type} - Image: {fname}", fontsize=14, weight='bold')

                # Original Image
                axes[0].imshow(img)
                axes[0].set_title("Original Image")
                axes[0].axis("off")

                # Error Map
                im1 = axes[1].imshow(err, cmap='viridis') # Use a suitable colormap for error maps
                axes[1].set_title("Error Map")
                axes[1].axis("off")
                fig.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)

                # Ground Truth Mask or Placeholder
                if gt_mask is not None:
                    axes[2].imshow(gt_mask, cmap='gray') # Use grayscale for masks
                    axes[2].set_title("Ground Truth Mask")
                    axes[2].axis("off")
                else:
                    # Display a black image if no ground truth mask (for 'good' samples or missing masks)
                    axes[2].imshow(np.zeros((H, W), dtype=np.uint8), cmap='gray')
                    axes[2].set_title("No Ground Truth Mask")
                    axes[2].axis("off")

                plt.tight_layout(rect=[0, 0.03, 1, 0.95])

                # Save before showing: some backends release the figure in
                # plt.show(), so saving first is the safe order.
                if save_dir:
                    os.makedirs(save_dir, exist_ok=True)
                    fig.savefig(os.path.join(save_dir, f"{class_name}_sample_{batch_idx}_{i}.png"))

                plt.show()
                # Free the figure; one figure per sample would otherwise pile up.
                plt.close(fig)

# ------------------------------------------------------------------------------------ #

def get_gt_mask_path(img_path, root_folder="./mvtec_ad_test"):
    """Build the ground-truth mask path that corresponds to a test image.

    The image path is expected to look like
    ``<root_folder>/<class_name>/<defect_type>/<image_file>`` and the mask is
    placed at
    ``<root_folder>/<class_name>/ground_truth/<defect_type>/<name>_mask.png``.

    Args:
        img_path: '/'-separated path of the image.
        root_folder: dataset root; a trailing slash is accepted.

    Returns:
        The mask path as a string (its existence is not checked), or ``None``
        when the root folder, class or defect type cannot be located in
        ``img_path`` (a message is printed in that case).
    """
    base_fname = os.path.splitext(os.path.basename(img_path))[0]

    # normpath drops a trailing slash; without it basename('./root/') is ''
    # and the lookup below fails or matches the wrong path component.
    root_name = os.path.basename(os.path.normpath(root_folder))

    parts = img_path.split('/')
    try:
        root_idx = parts.index(root_name)
    except ValueError:
        print(f"Error: Root folder '{root_name}' not found in image path '{img_path}'")
        return None # Cannot construct the path if root is not found

    # class_name is expected to be one level deeper than the root_folder
    if root_idx + 1 >= len(parts):
        print(f"Error: Cannot extract class name from path '{img_path}'")
        return None
    class_name = parts[root_idx + 1]

    # defect type is expected to be two levels deeper than the root_folder
    if root_idx + 2 >= len(parts):
        print(f"Error: Cannot extract defect type from path '{img_path}'")
        return None
    defect = parts[root_idx + 2]

    return os.path.join(
        root_folder,
        class_name,
        "ground_truth",
        defect,
        f"{base_fname}_mask.png"
    )

# ------------------------------------------------------------------------------------ #

# Plot losses in training
# Usage: plot_training_losses(training_losses), where the argument is a dict
# mapping each class name to its sequence of per-epoch losses

def plot_training_losses(training_losses):
    """Draw one training-loss curve per class on a single figure and show it."""
    # Custom Pantone-based palette
    palette = (
        "#822433",
        "#006778",  # PMS 3155 EC
        "#70A489",  # PMS 556 EC
        "#00B3BE",  # PMS 7466 EC
        "#AAC9B6",  # PMS 558 EC
        "#AAA38E",  # PMS 7536 EC
        "#D7D3C7",  # PMS 7534 EC
        "#C54C00",  # PMS 1525 EC
        "#F69240",  # PMS 715 EC
        "#C79900",  # PMS 117 EC
        "#D7A900",  # PMS 110 EC
        "#A6BCC6",  # PMS 5435 EC
        "#D3DEE4",  # PMS 642 EC
    )
    n_colors = len(palette)

    plt.figure(figsize=(10, 6))

    position = 0
    for class_name in training_losses:
        # Wrap around when there are more classes than colours
        plt.plot(
            training_losses[class_name],
            label=class_name,
            color=palette[position % n_colors],
        )
        position += 1

    plt.title("Training Loss per Epoch for Each Class")
    plt.xlabel("Epoch")
    plt.ylabel("Training Loss")
    plt.legend()
    plt.grid(True)
    plt.show()

# ------------------------------------------------------------------------------------ #

# AUPRO custom implementation
def compute_aupro(pred_mask, gt_mask, max_fpr=0.3):
    """Return the normalised area under the TPR-vs-FPR curve up to ``max_fpr``.

    NOTE: this is a pixel-level approximation of AUPRO. All anomalous pixels
    are treated as a single region; no per-connected-component overlap is
    computed.

    The previous implementation integrated recall over ``1 - precision``.
    That quantity is not the false-positive rate and is not sorted, so the
    integral was not a valid (and could be a negative) area.

    Args:
        pred_mask: anomaly scores, any shape (flattened internally).
        gt_mask: ground-truth mask with the same number of elements; any
            non-zero value (after the cast to uint8) counts as anomalous.
        max_fpr: upper integration limit on the false-positive rate.

    Returns:
        float in [0, 1]: area under the ROC curve restricted to
        ``fpr <= max_fpr`` divided by ``max_fpr``. Returns 0.0 when the
        ground truth contains a single class (the ROC curve is undefined) or
        when ``max_fpr`` is not positive.
    """
    limit = min(float(max_fpr), 1.0)
    if limit <= 0:
        return 0.0

    # Binarise so that 0/255 masks and 0/1 masks are handled alike.
    labels = (gt_mask.astype(np.uint8).ravel() > 0).astype(np.uint8)
    scores = pred_mask.ravel()

    if labels.size == 0 or labels.min() == labels.max():
        return 0.0

    fpr, tpr, _ = roc_curve(labels, scores)

    # Keep the part of the curve left of the limit and close it with the
    # interpolated point exactly at the limit.
    inside = fpr <= limit
    fpr_clipped = np.append(fpr[inside], limit)
    tpr_clipped = np.append(tpr[inside], np.interp(limit, fpr, tpr))

    # Explicit trapezoidal rule (np.trapz is deprecated in NumPy 2.x).
    area = np.sum(np.diff(fpr_clipped) * (tpr_clipped[1:] + tpr_clipped[:-1]) / 2.0)
    return float(area / limit)

# ------------------------------------------------------------------------------------ #
# used in validation for plotting and saving, if required, metric curves
def plot_metric_curve(
    epochs, 
    values_dict, 
    title="Metric Trend During Finetuning", 
    ylabel="Score", 
    save_path=None
):
    """Plot one curve per metric over the epochs and optionally save it.

    Args:
        epochs: x values shared by all curves.
        values_dict: dict mapping curve label -> sequence of metric values.
        title: figure title.
        ylabel: y-axis label; the y range is fixed to [0, 1.05].
        save_path: when given, the figure is also written to this file
            (missing parent directories are created).
    """
    fig = plt.figure(figsize=(10, 5))

    colors = ['#822433', '#006778']
    for idx, (label, values) in enumerate(values_dict.items()):
        color = colors[idx % len(colors)]
        plt.plot(epochs, values, marker='o', label=label, color=color)

    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.ylim(0.0, 1.05)
    plt.grid(True)
    plt.legend()
    plt.tight_layout()

    if save_path:
        # Save BEFORE show(): after show() pyplot may no longer hold this
        # figure, and plt.savefig would then write an empty canvas.
        out_dir = os.path.dirname(save_path)
        if out_dir:  # a bare filename has no directory to create
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path)

    plt.show()

    if save_path:
        plt.close(fig)

## Data

In [None]:
# Define the dataset root here; the other cells read it through `path`.
# The default points at the Kaggle-mounted copy of MVTec AD.
path = '/kaggle/input/mvtec-ad'
# Alternative (commented out): fetch the dataset through kagglehub instead.
# path = kagglehub.dataset_download("ipythonx/mvtec-ad")
print("Path to dataset files:", path)

In [None]:
# Dataset structure 🌳📂
# Each class has a ground_truth directory containing the mask pictures,
# a train directory containing only good pieces and a test directory with one
# sub-folder per defect type

def _print_limited_subfolders(dirs, prefix, max_folders, max_files):
    """Print at most ``max_folders`` sub-folders and the files of the first two."""
    shown_dirs = dirs[:max_folders]
    dirs_truncated = len(dirs) > max_folders

    for i, folder in enumerate(shown_dirs):
        # A trailing "..." entry follows when truncated, so a real folder is
        # only the last item when nothing was cut off.
        is_last_item = i == len(shown_dirs) - 1 and not dirs_truncated
        print(prefix + ("└── " if is_last_item else "├── ") + os.path.basename(folder))

        # Only the first two sub-folders get their files listed.
        if i >= 2:
            continue

        folder_prefix = prefix + ("    " if is_last_item else "│   ")
        entries = sorted(os.path.join(folder, name) for name in os.listdir(folder))
        folder_files = [entry for entry in entries if os.path.isfile(entry)]
        shown_files = folder_files[:max_files]
        files_truncated = len(folder_files) > max_files

        for j, file_path in enumerate(shown_files):
            is_last_file = j == len(shown_files) - 1 and not files_truncated
            print(folder_prefix + ("└── " if is_last_file else "├── ") + os.path.basename(file_path))
        if files_truncated:
            print(folder_prefix + "└── ...")

    if dirs_truncated:
        print(prefix + "└── ...")


def print_directory_tree(path, prefix="", is_last=True, max_folders=4, 
                         max_test_files=4, max_files_in_subfolders=4, is_root=False):
    """Print an abbreviated directory tree rooted at ``path``.

    The branch is chosen by substring match on the full lower-cased ``path``:

    * contains "train" or "ground_truth": show up to ``max_folders``
      sub-folders and, for the first two, up to ``max_files_in_subfolders``
      files;
    * contains "test": same, but files are limited by ``max_test_files``;
    * otherwise: recurse into every sub-folder and list every file.

    In the two abbreviated branches, files located directly in ``path`` are
    not listed.

    Args:
        path: file or directory to print.
        prefix: indentation accumulated from the parent levels.
        is_last: whether this entry is the last child of its parent.
        max_folders: sub-folder limit for the abbreviated branches.
        max_test_files: per-folder file limit in the "test" branch.
        max_files_in_subfolders: per-folder file limit in the
            "train"/"ground_truth" branch.
        is_root: print the full ``path`` instead of a connector + basename.
    """
    # Print the current item (full path if it's the root)
    if is_root:
        print(path)
    else:
        print(prefix + ("└── " if is_last else "├── ") + os.path.basename(path))

    if not os.path.isdir(path):
        return

    # Prefix for child items
    new_prefix = prefix + ("    " if is_last else "│   ")

    items = sorted(os.path.join(path, item) for item in os.listdir(path))
    dirs = [item for item in items if os.path.isdir(item)]
    files = [item for item in items if os.path.isfile(item)]

    lowered = path.lower()
    if "train" in lowered or "ground_truth" in lowered:
        _print_limited_subfolders(dirs, new_prefix, max_folders, max_files_in_subfolders)
    elif "test" in lowered:
        # The ellipsis now depends on the number of sub-folders (it used to
        # look at the number of files) and max_test_files is honoured.
        _print_limited_subfolders(dirs, new_prefix, max_folders, max_test_files)
    else:
        # For other folders, show everything normally
        all_items = dirs + files
        for i, item in enumerate(all_items):
            is_last_item = i == len(all_items) - 1
            print_directory_tree(item, new_prefix, is_last_item, max_folders, max_test_files, max_files_in_subfolders)


# Print the directory tree of the original dataset, starting from `path`
print("📂 mvtec Anomaly Detection - Directory Tree:\n")
print_directory_tree(path, is_root=True)

In [None]:
# Show the first 'good' training image of each class of the original
# mvtec-ad dataset in a 3x5 grid
show_one_instance_per_class_grid(path, folder='train', instance_type='good', grid_size=(3, 5), index_in_class=0)

### ⚠️ **Choose only one of the following two code snippets! Read the initial comments!** ⚠️

In [None]:
# ----- This code is really time consuming: instead of running it, you can use the pre-augmented dataset loaded in the next snippet ----- #

# Run once in your environment; afterwards the other, faster snippet can be used.
# This cell copies the training images and augments the training set,
# parallelizing the work with multiple threads.

# Class names are the entries of the dataset root, excluding .txt files
classes = os.listdir(path)
classes = [item for item in classes if not item.endswith('.txt')]
print(classes)

# NOTE(review): label_list is not used anywhere in the visible part of the notebook
label_list = []

print(f"")
# define the new training set folder
train_folder = './mvtec_ad_train/'

def process_image(file, source_path, destination_dir, class_name, n_aug):
    """Copy one training image to ``destination_dir`` and augment it there.

    Files whose extension is not .png/.jpg/.jpeg are ignored. Any error raised
    while copying or augmenting is reported on stdout and swallowed, so one
    bad file does not stop the worker pool.
    """
    accepted_suffixes = ('.png', '.jpg', '.jpeg')
    if file.lower().endswith(accepted_suffixes):
        src = os.path.join(source_path, file)
        dst = os.path.join(destination_dir, file)
        try:
            shutil.copy(src, dst)
            augment_image(src, destination_dir, class_name, n_aug=n_aug)
        except Exception as e:
            print(f"Error processing {file}: {e}")

# Thread-pool size and number of augmentation rounds per image
max_workers = 8
n_aug = 1
# Copy the 'good' training samples of every class into the new training set
# and augment them
for class_name in tqdm(classes, desc = "processing"):

    # source directory: <dataset>/<class>/train/good
    source_path = os.path.join(path, class_name, 'train', 'good')
    # destination directory: <train_folder>/<class>
    destination_dir = os.path.join(train_folder, class_name)
    os.makedirs(destination_dir, exist_ok=True)

    files = os.listdir(source_path)

    # One pool per class; leaving the `with` block waits for all submitted jobs.
    # The returned futures are discarded: process_image reports its own errors.
    with ThPE(max_workers=max_workers) as executor:
        for file in files:
            executor.submit(process_image, file, source_path, destination_dir, class_name, n_aug)

# print the new directory tree
print_directory_tree(train_folder, is_root=True)


# ------------ Visualize a sample in both original and augmented version ------------ #
# Use it only if you want to visualize an image and its transformed versions
def show_augmented(folder_path, class_name, n_augmented=3, sample_index=8):
  """Show one original training image next to its augmented versions.

  Args:
    folder_path: training-set root containing one sub-folder per class.
    class_name: class whose images are displayed.
    n_augmented: maximum number of augmented versions to show.
    sample_index: index of the original image in sorted file order. The
      default keeps the previously hard-coded value; an out-of-range index
      falls back to 0.
  """
  class_path = os.path.join(folder_path, class_name)

  # Sort so that the chosen sample does not depend on os.listdir ordering.
  all_files = sorted(f for f in os.listdir(class_path) if f.lower().endswith(('.png', '.jpg')))

  or_files = [f for f in all_files if "_aug_" not in f]
  if not or_files:
    print(f"No original images found in {class_path}")
    return
  if not 0 <= sample_index < len(or_files):
    print(f"Warning: Index {sample_index} out of bounds for class {class_name}. Using index 0.")
    sample_index = 0

  original_img_name = or_files[sample_index]
  base_original_name = os.path.splitext(original_img_name)[0]
  aug_imgs = [f for f in all_files if f.startswith(base_original_name + "_aug_")][:n_augmented]

  fig, axes = plt.subplots(1, 1 + len(aug_imgs), figsize=(4*(1 + len(aug_imgs)), 4))
  fig.suptitle(f"Original + Augmentations - Class: {class_name}", fontsize=14, weight='bold')

  # plt.subplots returns a bare Axes when there is a single column
  axes = np.ravel(axes)

  with Image.open(os.path.join(class_path, original_img_name)) as original_img:
    axes[0].imshow(original_img)
  axes[0].set_title("Original")
  axes[0].axis("off")

  for i, aug_name in enumerate(aug_imgs):
    with Image.open(os.path.join(class_path, aug_name)) as aug_img:
      axes[i + 1].imshow(aug_img)
    axes[i + 1].set_title(f"Augmented {i+1}")
    axes[i + 1].axis("off")

  plt.tight_layout()
  plt.show()

# Training-set root used for the preview (alternatively '/kaggle/working/mvtec_ad_train/')
train_folder = './mvtec_ad_train/'  # or '/kaggle/working/mvtec_ad_train/'
# Preview one 'wood' sample together with its augmentations
show_augmented(train_folder, class_name='wood')

In [None]:
# This is the alternative to the previous code snippet! Don't run it if you have already run the previous one!

# Augmented dataset taken directly from Kaggle (the output of the previous snippet, pre-computed by the authors)
aug_trainset_folder = '/kaggle/input/mvtec-ad-augmneted-trainset/mvtec_ad_train' 
# Class names are the sorted entries of the augmented training-set folder
classes = sorted(os.listdir(aug_trainset_folder))
print("Augmented trainset classes: ", classes)

# print the directory tree of the augmented dataset
print_directory_tree(aug_trainset_folder, is_root=True)

In [None]:
# Split the original MVTec test images (and their masks) into a validation
# set and a test set
validation_folder = './mvtec_ad_validation/'
test_folder = './mvtec_ad_test/'

# 20% validation, 80% test
split_ratio = 0.2


def _copy_with_masks(file_names, source_dir, dest_root, dataset_root, class_name, defect_type):
    """Copy images to dest_root/<class>/<defect> together with their masks, when present."""
    dst_dir = os.path.join(dest_root, class_name, defect_type)
    os.makedirs(dst_dir, exist_ok=True)
    for f in file_names:
        shutil.copy(os.path.join(source_dir, f), os.path.join(dst_dir, f))

        # Masks live in <dataset>/<class>/ground_truth/<defect>/<name>_mask.png;
        # 'good' images have none, hence the existence check.
        mask_filename = os.path.splitext(f)[0] + '_mask.png'
        source_mask_path = os.path.join(dataset_root, class_name, 'ground_truth', defect_type, mask_filename)
        if os.path.exists(source_mask_path):
            dst_mask_sub = os.path.join(dest_root, class_name, 'ground_truth', defect_type)
            os.makedirs(dst_mask_sub, exist_ok=True)
            shutil.copy(source_mask_path, os.path.join(dst_mask_sub, mask_filename))


for class_name in classes:

  # source directory: the test folder also has one sub-folder per defect type
  source_path = os.path.join(path, class_name, 'test')
  sub_folders = sorted(d for d in os.listdir(source_path)
                       if os.path.isdir(os.path.join(source_path, d)))
  for s in sub_folders:

    source_sub_path = os.path.join(source_path, s)
    # Sorted so that the seeded split below is reproducible: os.listdir gives
    # no ordering guarantee, which would change the split between machines.
    files = sorted(f for f in os.listdir(source_sub_path)
                   if f.lower().endswith(('.png','.jpg','.bmp')))
    # split using sklearn train_test_split function, random_state = seed
    test_files, val_files = train_test_split(files, test_size=split_ratio,random_state=46,shuffle=True)

    _copy_with_masks(val_files, source_sub_path, validation_folder, path, class_name, s)
    _copy_with_masks(test_files, source_sub_path, test_folder, path, class_name, s)

# print directory trees - comment out the following three lines to skip them

print_directory_tree(test_folder, is_root=True)
print(f"")
print_directory_tree(validation_folder, is_root=True)

### Data Loader ...

In [None]:
# e.g. classes = ['bottle', 'cable']
# return a dictionary of dataloaders with keys train, val, test 
# new class to handle validation set structure (same structure of test set)
class extract_images_dir(Dataset):
    """Dataset of all images found recursively under ``root_dir``.

    Directories that have a ``ground_truth`` path component and files whose
    name contains ``_mask`` are skipped, so only input images are returned.
    Samples are ordered deterministically (sorted directories, sorted files).

    Each item is ``(image, img_path)``, where ``image`` is the RGB image after
    the optional ``transform``.
    """

    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, root_dir, transform=None):
        self.transform = transform
        self.samples = []

        for dirpath, dirnames, filenames in os.walk(root_dir):
            # Sorting in place makes os.walk visit sub-folders in a stable
            # order, so sample indices do not depend on the filesystem.
            dirnames.sort()

            # exclude ground_truth from validation and test images
            if 'ground_truth' in os.path.normpath(dirpath).split(os.sep):
                continue

            for f in sorted(filenames):
                name = f.lower()
                # Case-insensitive, like the other extension checks in this file
                if not name.endswith(self.IMAGE_EXTENSIONS):
                    continue
                # Redundant safety check on the file name only, so a parent
                # directory containing '_mask' does not hide every image.
                if '_mask' in os.path.splitext(name)[0]:
                    continue
                self.samples.append(os.path.join(dirpath, f))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path = self.samples[idx]
        # convert() returns a new image, so the file handle can be closed
        # right away instead of lingering until garbage collection.
        with Image.open(img_path) as raw:
            image = raw.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, img_path
        
def create_class_dataloader(classes,
    trainset_folder = '/kaggle/input/mvtec-ad-augmneted-trainset/mvtec_ad_train',
    validation_folder = './mvtec_ad_validation/',
    test_folder = './mvtec_ad_test/',
    batch_size=32):
    """Build train/val/test DataLoaders for every class.

    Args:
        classes: iterable of class names, e.g. ['bottle', 'cable'].
        trainset_folder: root with one sub-folder of training images per class.
        validation_folder: root of the validation split (one folder per class).
        test_folder: root of the test split (one folder per class).
        batch_size: batch size shared by all loaders.

    Returns:
        dict mapping class name -> {"train": ..., "val": ..., "test": ...};
        only the training loader shuffles its samples.
    """
    # Every image is resized to 256x256 and converted to a tensor
    preprocess = transforms.Compose([
        transforms.Resize((256,256)),
        transforms.ToTensor()
    ])

    # (split name, split root, shuffle?) in the order the loaders are built
    split_specs = (
        ("train", trainset_folder, True),
        ("val", validation_folder, False),
        ("test", test_folder, False),
    )

    # Classes are loaded separately, hence one dictionary entry per class
    loaders_by_class = {}
    for c in classes:
        print(f"Creating dataloader for class: {c}")

        per_split = {}
        for split_name, split_root, do_shuffle in split_specs:
            split_dataset = extract_images_dir(os.path.join(split_root, c), transform=preprocess)
            per_split[split_name] = DataLoader(
                split_dataset,
                batch_size=batch_size,
                shuffle=do_shuffle,
                num_workers=4,
            )
        loaders_by_class[c] = per_split

    return loaders_by_class

## Network

In [None]:
# Model parameters and random seeds

# Seed the torch (CPU and CUDA) and Python RNGs.
# NOTE(review): NumPy's RNG is not seeded here — confirm whether any NumPy
# randomness needs to be reproducible.
torch.manual_seed(42)
torch.cuda.manual_seed(42)
random.seed(42)

hyperparameters = {
    'image_size' : 256,   # prev val: 224
    'patch_size' :  16,   # (256/16)^2 = 256 patches per image, assuming non-overlapping square patches
    'in_channels': 3,
    'embed_dim' : 512,    # prev val: 512
    'num_heads' : 8,      # prev val: 16 - 8
    'depth_enc' : 16,     # prev val: 16
    'depth_dec' : 2,      # prev val: 8 - 2
    'mlp_dim'   : 512,    # prev val: 1024
    'dropout_rate' : 0.1,
    'init_alpha' : 0.5,
    'mask_ratio' : 0.75,
    'dec_embed_dim': 256  # prev: 512
}

In [None]:
# DyT normalization layer definition

class DyT(nn.Module):
    """Dynamic-Tanh layer: ``gamma * tanh(alpha * x) + beta``.

    ``alpha`` is a single learnable scalar initialised to ``init_alpha``;
    ``gamma`` (initialised to ones) and ``beta`` (initialised to zeros) are
    learnable vectors of length ``C`` applied along the last dimension of the
    input through broadcasting.
    """

    def __init__(self, C, init_alpha=0.1):
        super().__init__()
        # Registration order (alpha, gamma, beta) is kept on purpose.
        alpha_init = init_alpha * ones(1)
        self.alpha = Parameter(alpha_init)
        self.gamma = Parameter(ones(C))
        self.beta = Parameter(zeros(C))

    def forward(self, x):
        squashed = tanh(x * self.alpha)
        scaled = squashed * self.gamma
        return scaled + self.beta
        

In [None]:
# New Feature Aggregration module 

class FeatureAggregationModule(nn.Module):
    """Fuse features tapped from several encoder layers into one spatial map.

    Each tapped feature tensor only contains the tokens that were kept by the
    masking step. For every tap the kept tokens are scattered back to their
    original patch position on a zero-filled ``(N, D)`` grid, passed through a
    per-tap linear projection, and reshaped to ``(D, H/P, W/P)``. The maps of
    all taps are concatenated along the channel axis and merged by a 1x1
    convolution.

    Args:
        embed_dim: token dimension ``D`` of the encoder.
        patch_size: side length ``P`` of a patch in pixels.
        image_size: side length of the (square) input image in pixels.
        intermediate_layers: one entry per tapped encoder layer; only its
            length is used here (number of projections / combiner channels).
    """

    def __init__(self, embed_dim, patch_size, image_size, intermediate_layers):
        super().__init__()
        self.patch_size = patch_size
        self.image_size = image_size
        self.intermediate_layers = intermediate_layers
        self.embed_dim = embed_dim

        # Spatial size of the patch grid the tokens are laid out on.
        self.target_h = self.target_w = image_size // patch_size
        self.num_target_patches = self.target_h * self.target_w

        # One D -> D projection per tapped encoder layer.
        self.projection_layers = nn.ModuleList([
            nn.Linear(embed_dim, embed_dim) for _ in intermediate_layers
        ])

        # 1x1 conv that merges the concatenated per-tap maps back to D channels.
        self.combiner = nn.Conv2d(len(intermediate_layers) * embed_dim, embed_dim, kernel_size=1)

    def forward(self, intermediate_features_with_ids, original_N):
        """Aggregate the tapped features.

        Args:
            intermediate_features_with_ids: list of ``(features, ids_keep)``
                tuples, ``features`` of shape ``(B, len_keep, D)`` and
                ``ids_keep`` of shape ``(B, len_keep)`` holding the original
                patch index of every token; ``-1`` marks a padding slot.
            original_N: number of patches before masking; must equal the
                module's grid size.

        Returns:
            Tensor of shape ``(B, embed_dim, H/P, W/P)``.

        Raises:
            ValueError: if ``original_N`` does not match the patch grid.
        """
        # Validate once, before any work is done (previously this ran inside
        # the innermost loop, after the scatter).
        if original_N != self.num_target_patches:
            raise ValueError(f"Original number of patches ({original_N}) does not match target grid size ({self.num_target_patches}).")

        projected_maps = []
        for i, (features, ids_keep) in enumerate(intermediate_features_with_ids):
            batch_size = features.size(0)

            # Zero grid for the whole batch; masked / padded positions stay zero.
            full_grid = features.new_zeros(batch_size, original_N, self.embed_dim)

            # Scatter every valid token to (sample, original patch index) in one
            # indexed assignment instead of looping over the batch in Python.
            valid = ids_keep != -1  # -1 is the padding value
            batch_idx = torch.nonzero(valid, as_tuple=True)[0]
            full_grid[batch_idx, ids_keep[valid]] = features[valid]

            # The linear layer acts on the last (channel) axis, so projecting the
            # flat (B, N, D) grid equals projecting the (H/P, W/P, D) layout.
            projected = self.projection_layers[i](full_grid)

            # (B, N, D) -> (B, D, H/P, W/P)
            projected_maps.append(
                projected.transpose(1, 2).reshape(batch_size, self.embed_dim, self.target_h, self.target_w)
            )

        # (B, len(taps) * D, H/P, W/P)
        aggregated_features_spatial = torch.cat(projected_maps, dim=1)

        # (B, embed_dim, H/P, W/P)
        return self.combiner(aggregated_features_spatial)

In [None]:
# Building ViT with DyT as normalizer

# Building ViT
def random_masking(x, mask_ratio):
    """Block-wise random masking of patch tokens.

    For every sample, random square blocks of the patch grid are *unmasked*
    until at least ``int(N * (1 - mask_ratio))`` patches are visible. Because
    blocks may overlap or overshoot, the number of visible patches can differ
    between samples; per-sample results are padded to a common length.

    Args:
        x: tensor of shape ``(B, N, D)``; ``N`` must be a perfect square.
        mask_ratio: fraction of patches to hide, in ``[0, 1]``.

    Returns:
        Tuple ``(x_masked, mask, ids_restore, ids_keep, ids_masked)``:
          * ``x_masked``   ``(B, K, D)``: visible tokens, zero-padded to the
            largest visible count ``K`` in the batch.
          * ``mask``       ``(B, N)`` bool, in original patch order:
            ``True`` = masked, ``False`` = visible.
          * ``ids_restore`` ``(B, N)``: ``argsort`` of ``[ids_keep, ids_masked]``.
          * ``ids_keep``   ``(B, K)``: original indices of the visible patches,
            padded with ``-1``.
          * ``ids_masked`` ``(B, M)``: original indices of the masked patches,
            padded with ``-1``.

    Raises:
        ValueError: if ``mask_ratio`` is outside ``[0, 1]`` (a negative ratio
            would make the unmasking loop below run forever).
    """
    if not 0.0 <= mask_ratio <= 1.0:
        raise ValueError(f"mask_ratio must be in [0, 1], got {mask_ratio}")

    B, N, D = x.shape
    len_keep = int(N * (1 - mask_ratio))

    # Patches are assumed to form a square grid (sqrt(N) x sqrt(N)).
    grid_size = int(math.sqrt(N))
    assert grid_size * grid_size == N, "N should be a perfect square for block masking"

    # Block side is drawn from [2, grid_size // 2]. For grids smaller than 4
    # that range is empty (torch.randint would raise), so clamp it to the
    # largest block that still fits.
    max_block = max(grid_size // 2, min(2, grid_size))
    min_block = min(2, max_block)

    x_masked_list = []
    mask_list = []
    ids_restore_list = []
    ids_keep_list = []
    ids_masked_list = []

    for b in range(B):
        # 2D mask: 1 = masked, 0 = keep. Start fully masked.
        mask_2d = torch.ones(grid_size, grid_size, device=x.device)

        # Unmask random blocks until enough patches are visible.
        while mask_2d.sum() > (N - len_keep):
            block_size = torch.randint(low=min_block, high=max_block + 1, size=(1,)).item()
            top = torch.randint(0, grid_size - block_size + 1, size=(1,)).item()
            left = torch.randint(0, grid_size - block_size + 1, size=(1,)).item()
            mask_2d[top:top + block_size, left:left + block_size] = 0

        mask_flat = mask_2d.flatten()
        ids_keep = torch.nonzero(mask_flat == 0, as_tuple=False).squeeze(1)
        ids_masked = torch.nonzero(mask_flat == 1, as_tuple=False).squeeze(1)

        # Permutation that maps the [kept, masked] ordering back to patch order.
        ids_restore = torch.argsort(torch.cat([ids_keep, ids_masked], dim=0))

        x_masked_list.append(x[b, ids_keep])  # (n_keep_b, D)

        # mask_flat is already indexed by original patch position, so it IS the
        # final mask. Gathering it through ids_restore (as was done before)
        # would permute it a second time and flag the wrong patches.
        mask_list.append(mask_flat == 1)

        ids_restore_list.append(ids_restore)
        ids_keep_list.append(ids_keep)
        ids_masked_list.append(ids_masked)

    # Visible / masked counts vary per sample: pad to the batch maximum.
    max_len_keep = max(len(t) for t in ids_keep_list)
    max_len_masked_ids = max(len(t) for t in ids_masked_list)

    x_masked_padded = torch.stack(
        [F.pad(t, (0, 0, 0, max_len_keep - t.size(0))) for t in x_masked_list], dim=0
    )
    # mask and ids_restore always have length N, so no padding is required.
    mask_padded = torch.stack(mask_list, dim=0)
    ids_restore_padded = torch.stack(ids_restore_list, dim=0)

    # -1 marks padding slots in the index tensors.
    ids_keep_padded = torch.stack(
        [F.pad(t, (0, max_len_keep - t.size(0)), value=-1) for t in ids_keep_list], dim=0
    )
    ids_masked_padded = torch.stack(
        [F.pad(t, (0, max_len_masked_ids - t.size(0)), value=-1) for t in ids_masked_list], dim=0
    )

    return x_masked_padded, mask_padded, ids_restore_padded, ids_keep_padded, ids_masked_padded


# Patch embedding with learnable and sinusoidal positional encodings
class PatchEmbedding(nn.Module):
    """Split an image into patches and embed them as tokens.

    A strided convolution maps every ``patch_size x patch_size`` patch to an
    ``embed_dim`` vector. Two positional encodings are added to each token: a
    learnable one and a fixed 2D sinusoidal one.

    The sinusoidal table encodes BOTH grid coordinates (row and column). It is
    stored as a non-persistent buffer: it follows ``module.to(device)`` but is
    not written to / expected in the ``state_dict``.

    Args:
        image_size: side length of the (square) input image in pixels.
        patch_size: side length of a patch in pixels.
        in_channels: number of image channels.
        embed_dim: token dimension; must be even.
    """

    def __init__(self, image_size, patch_size, in_channels, embed_dim):
        super().__init__()
        self.patch_size = patch_size
        self.image_size = image_size

        # Non-overlapping patch projection.
        self.proj = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size
        )

        num_patches = (image_size // patch_size) ** 2
        self.num_patches = num_patches
        self.grid_size = image_size // patch_size

        # Learnable positional embeddings
        self.learnable_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        nn.init.trunc_normal_(self.learnable_pos_embed, std=0.02)

        # Fixed sinusoidal positional embeddings. persistent=False keeps the
        # state_dict keys identical to a plain attribute.
        self.register_buffer(
            "sinusoidal_pos_embed",
            self._build_sinusoidal_pos_embed(num_patches, embed_dim),
            persistent=False,
        )

    @staticmethod
    def _axis_sincos(positions, dim):
        """Interleaved sin/cos encoding of 1-D ``positions`` into ``dim`` (even) channels."""
        encoding = torch.zeros(positions.numel(), dim)
        if dim == 0:
            return encoding
        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        angles = positions.float().unsqueeze(1) * div_term
        encoding[:, 0::2] = torch.sin(angles)
        encoding[:, 1::2] = torch.cos(angles)
        return encoding

    def _build_sinusoidal_pos_embed(self, num_patches, embed_dim):
        """Return a ``(1, num_patches, embed_dim)`` 2D sin/cos position table.

        The first block of channels encodes the row index of a patch, the
        remaining channels encode its column index. (The previous version fed
        only the row index into both sin and cos, so every patch in a row
        received the same code.)

        Raises:
            ValueError: if ``embed_dim`` is odd (sin/cos come in pairs).
        """
        if embed_dim % 2 != 0:
            raise ValueError(f"embed_dim must be even for sin/cos encoding, got {embed_dim}")

        grid_size = int(math.sqrt(num_patches))
        coords_h = torch.arange(grid_size)
        coords_w = torch.arange(grid_size)
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))
        coords_flatten = torch.flatten(coords, 1)  # (2, num_patches): rows, cols

        # Split the sin/cos pairs between the two axes (rows get the extra
        # pair when the number of pairs is odd).
        num_pairs = embed_dim // 2
        dim_h = 2 * ((num_pairs + 1) // 2)
        dim_w = embed_dim - dim_h

        pe = torch.cat(
            [
                self._axis_sincos(coords_flatten[0], dim_h),
                self._axis_sincos(coords_flatten[1], dim_w),
            ],
            dim=1,
        )
        pe = pe.unsqueeze(0)  # (1, num_patches, embed_dim)
        return pe.to(self.proj.weight.device)

    def forward(self, x: torch.Tensor):
        """Embed ``x`` of shape ``(B, C, H, W)`` into ``(B, num_patches, embed_dim)``."""
        x = self.proj(x)  # (B, embed_dim, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)

        # Both tables have a leading dim of 1 and broadcast over the batch.
        return x + self.learnable_pos_embed + self.sinusoidal_pos_embed

class MLP(nn.Module):
    """Feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    The output has the same feature size as the input.

    Args:
        in_features: size of the input (and output) feature dimension.
        hidden_features: size of the hidden layer.
        dropout_rate: dropout probability applied after each linear stage.
    """

    def __init__(self, in_features, hidden_features, dropout_rate):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.dropout1 = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.dropout2 = nn.Dropout(dropout_rate)

        self.init_weights()

    def init_weights(self):
        """Xavier-uniform weights and zero biases for both linear layers."""
        linears = (self.fc1, self.fc2)
        for linear in linears:
            nn.init.xavier_uniform_(linear.weight)
        for linear in linears:
            nn.init.zeros_(linear.bias)

    def forward(self, x):
        hidden = self.dropout1(self.act(self.fc1(x)))
        return self.dropout2(self.fc2(hidden))

# Encoder
class EncoderLayer(nn.Module):
    """Pre-norm transformer encoder block with DyT instead of LayerNorm.

    Computes ``x = x + SelfAttn(DyT(x))`` followed by ``x = x + MLP(DyT(x))``.

    Args:
        embed_dim: token dimension.
        num_heads: number of attention heads.
        mlp_dim: hidden size of the feed-forward block.
        dropout_rate: dropout used inside attention and the MLP.
    """

    def __init__(self, embed_dim, num_heads, mlp_dim, dropout_rate):
        super().__init__()
        # Attention sub-block
        self.norm1 = DyT(embed_dim, init_alpha=0.1)
        self.attn = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout_rate, batch_first=True
        )
        # Feed-forward sub-block
        self.norm2 = DyT(embed_dim, init_alpha=0.1)
        self.mlp = MLP(embed_dim, mlp_dim, dropout_rate)

    def forward(self, x):
        # Self-attention on the normalised tokens, added back as a residual.
        attn_in = self.norm1(x)
        x = x + self.attn(attn_in, attn_in, attn_in)[0]

        # Feed-forward residual.
        return x + self.mlp(self.norm2(x))

# Decoder
class DecoderLayer(nn.Module):
    """Pre-norm transformer decoder block with DyT instead of LayerNorm.

    Three residual sub-blocks are applied in order: self-attention,
    cross-attention (queries from the decoder stream, keys/values from the
    encoder-side tensor) and an MLP.

    Args:
        dec_embed_dim: token dimension of the decoder stream.
        num_heads: number of heads for both attention modules.
        mlp_dim: hidden size of the feed-forward block.
        dropout_rate: dropout used inside attention and the MLP.
        enc_embed_dim: feature dimension of the keys/values fed to the
            cross-attention.
    """

    def __init__(self, dec_embed_dim, num_heads, mlp_dim, dropout_rate, enc_embed_dim):
        super().__init__()
        # Self-attention sub-block
        self.norm1 = DyT(dec_embed_dim, init_alpha=0.1)
        self.attn = nn.MultiheadAttention(
            dec_embed_dim, num_heads, dropout=dropout_rate, batch_first=True
        )

        # Cross-attention sub-block: kdim/vdim let keys and values keep the
        # encoder width while queries use the decoder width.
        self.norm2 = DyT(dec_embed_dim, init_alpha=0.1)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=dec_embed_dim,
            num_heads=num_heads,
            dropout=dropout_rate,
            batch_first=True,
            kdim=enc_embed_dim,
            vdim=enc_embed_dim,
        )

        # Feed-forward sub-block
        self.norm3 = DyT(dec_embed_dim, init_alpha=0.1)
        self.mlp = MLP(dec_embed_dim, mlp_dim, dropout_rate)

    def forward(self, x, encoder_output_cross_attention):
        memory = encoder_output_cross_attention

        # Self-attention residual
        query = self.norm1(x)
        x = x + self.attn(query, query, query)[0]

        # Cross-attention residual: encoder-side tensor is both key and value
        x = x + self.cross_attn(self.norm2(x), memory, memory)[0]

        # Feed-forward residual
        return x + self.mlp(self.norm3(x))

In [None]:
class MAE(nn.Module):
    """Masked auto-encoder with DyT-normalised transformer blocks.

    Pipeline: patch embedding -> (training only) block masking -> encoder ->
    feature aggregation over several encoder layers -> decoder with
    cross-attention to the aggregated features -> per-patch pixel prediction.

    Args:
        image_size: side length of the (square) input image.
        patch_size: side length of a patch.
        in_channels: number of image channels.
        embed_dim: encoder token dimension.
        num_heads: attention heads (encoder and decoder).
        depth_enc: number of encoder layers.
        depth_dec: number of decoder layers.
        mlp_dim: hidden size of the MLP blocks.
        dropout_rate: dropout used in attention and MLP blocks.
        init_alpha: DyT alpha initial value for the final encoder/decoder norms.
        mask_ratio: default fraction of patches masked during training.
        dec_embed_dim: decoder token dimension.
    """

    def __init__(self, image_size, patch_size, in_channels, embed_dim,
                 num_heads, depth_enc, depth_dec, mlp_dim, dropout_rate, init_alpha, mask_ratio, dec_embed_dim=256):
        super().__init__()

        self.patch_embed = PatchEmbedding(image_size, patch_size, in_channels, embed_dim)

        # Attributes read by the training loop and the loss functions.
        self.patch_size = patch_size
        self.image_size = image_size
        self.in_channels = in_channels
        # Previously accepted but dropped; forward() now falls back to it.
        self.mask_ratio = mask_ratio

        # Encoder
        self.encoder = nn.ModuleList([
            EncoderLayer(embed_dim, num_heads, mlp_dim, dropout_rate)
            for _ in range(depth_enc)
        ])
        self.encoder_norm = DyT(embed_dim, init_alpha=init_alpha)

        # Projection from encoder width to decoder width.
        # NOTE: self.apply(self.init_weights) at the end of __init__ re-initialises
        # every nn.Linear, so the explicit inits on decoder_embed / decoder_pred
        # below are overwritten; they are kept only so the RNG stream (and thus
        # seeded runs) stays unchanged.
        self.decoder_embed = nn.Linear(embed_dim, dec_embed_dim)
        nn.init.xavier_uniform_(self.decoder_embed.weight)
        nn.init.zeros_(self.decoder_embed.bias)

        # Learnable token standing in for every masked patch in the decoder.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, dec_embed_dim))
        nn.init.trunc_normal_(self.mask_token, std=0.02)

        # Decoder
        self.decoder = nn.ModuleList([
            DecoderLayer(dec_embed_dim, num_heads, mlp_dim, dropout_rate, enc_embed_dim=embed_dim)
            for _ in range(depth_dec)
        ])
        self.decoder_norm = DyT(dec_embed_dim, init_alpha=init_alpha)

        # Prediction head: one flattened patch of pixels per token.
        self.decoder_pred = nn.Linear(dec_embed_dim, patch_size * patch_size * in_channels)
        nn.init.xavier_uniform_(self.decoder_pred.weight)
        nn.init.zeros_(self.decoder_pred.bias)

        # Encoder layers (1-based) whose outputs feed the aggregation module.
        # Layer 0 never matches and duplicates would be tapped only once, so
        # both are removed; otherwise the aggregator would be built for more
        # taps than forward() collects (e.g. depth_enc < 3).
        self.intermediate_layers = sorted({
            layer_idx
            for layer_idx in (depth_enc // 3, 2 * depth_enc // 3, depth_enc)
            if layer_idx >= 1
        })
        self.feature_aggregator = FeatureAggregationModule(embed_dim, patch_size, image_size, self.intermediate_layers)

        # Final linear layer on the aggregated features.
        self.aggregated_feature_proj = nn.Linear(embed_dim, embed_dim)

        # Apply weight initialization to every sub-module.
        self.apply(self.init_weights)

    def init_weights(self, module):
        """Initialiser passed to ``nn.Module.apply``.

        ``nn.Linear``: Xavier-uniform for square weights, Xavier-normal
        otherwise, zero bias. ``nn.MultiheadAttention``: Xavier-uniform on the
        packed ``in_proj_weight`` when it exists (it is ``None`` when
        ``kdim``/``vdim`` differ from ``embed_dim``).
        """
        if isinstance(module, nn.Linear):
            if module.weight.shape[0] == module.weight.shape[1]:  # square matrix
                nn.init.xavier_uniform_(module.weight)
            else:
                nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.MultiheadAttention):
            if getattr(module, 'in_proj_weight', None) is not None:
                nn.init.xavier_uniform_(module.in_proj_weight)
            # The attention's out_proj is itself an nn.Linear sub-module, so
            # apply() already handles it through the branch above. The former
            # hasattr(module, 'out_proj.weight') test was always False (hasattr
            # does not resolve dotted names) and has been removed.

    def forward(self, x, mask_ratio=None):
        """Reconstruct the patches of ``x``.

        Args:
            x: images of shape ``(B, C, H, W)``.
            mask_ratio: fraction of patches to mask while training; defaults to
                the value given to the constructor. Ignored in eval mode.

        Returns:
            ``(pred, mask)``: ``pred`` has shape ``(B, N, patch_size**2 *
            in_channels)``; ``mask`` is the ``(B, N)`` bool mask from
            ``random_masking`` in training and all ``False`` in eval mode.
        """
        if mask_ratio is None:
            mask_ratio = self.mask_ratio

        x = self.patch_embed(x)
        B, N, _ = x.shape

        # Masking is used only in training; eval/test sees every patch.
        if self.training:
            x_masked, mask, _, ids_keep, _ = random_masking(x, mask_ratio)
        else:
            x_masked = x
            mask = torch.zeros(B, N, device=x.device, dtype=torch.bool)
            ids_keep = torch.arange(N, device=x.device).unsqueeze(0).repeat(B, 1)

        # NOTE(review): when samples keep different numbers of patches,
        # x_masked holds zero-padded slots (ids_keep == -1) that the encoder
        # still attends to, since EncoderLayer takes no padding mask.
        intermediate_features = []
        x_encoder = x_masked
        for i, layer in enumerate(self.encoder):
            x_encoder = layer(x_encoder)
            if (i + 1) in self.intermediate_layers:
                intermediate_features.append((x_encoder, ids_keep))

        encoder_output_last_layer = self.encoder_norm(x_encoder)

        # (B, D, H/P, W/P) -> (B, N, D) memory for the decoder cross-attention.
        aggregated_features_spatial = self.feature_aggregator(intermediate_features, N)
        aggregated_features_flat = aggregated_features_spatial.flatten(2).transpose(1, 2)
        aggregated_features_projected = self.aggregated_feature_proj(aggregated_features_flat)

        # Project encoder output to the decoder dimension.
        encoder_output_projected = self.decoder_embed(encoder_output_last_layer)

        # Decoder input: start from mask tokens at every position and write
        # each visible token back to its original patch index. Positions of
        # padded slots (-1) are skipped and therefore keep the mask token.
        # (Concatenating [encoder tokens, mask tokens] and gathering with
        # ids_restore picked the encoder output of a padded slot for those
        # positions whenever a sample kept fewer patches than the batch max.)
        x_ = self.mask_token.to(encoder_output_projected.dtype).repeat(B, N, 1)
        valid = ids_keep != -1
        batch_idx = torch.nonzero(valid, as_tuple=True)[0]
        x_[batch_idx, ids_keep[valid]] = encoder_output_projected[valid]

        for layer in self.decoder:
            x_ = layer(x_, aggregated_features_projected)
        x_ = self.decoder_norm(x_)

        pred = self.decoder_pred(x_)  # (B, N, patch_size**2 * in_channels)

        return pred, mask

## Sanity check

Verify:
* Batch type and shapes
* I/O shape
* Model size and number of parameters
* FLOPs required by the model 

### ⚠️ **To have a correct training procedure you MUST run this section of code, it's more than just a check.** ⚠️

In [None]:
def mae_sanity_check(model, dataloader, device, hyperparameters):
    """Run one batch through ``model`` and report its size and op count.

    Checks the dtype and channel count of the input batch, the dtype, batch
    size and per-patch size of the prediction, then prints the parameter
    count / size in MB and the operation count reported by ``thop.profile``.

    Args:
        model: module whose forward returns ``(pred, mask)``.
        dataloader: iterable of batches; a batch is either an image tensor or
            a list/tuple whose first element is the image tensor.
        device: device the model and the batch are moved to.
        hyperparameters: dict providing ``'in_channels'`` and ``'patch_size'``.

    Raises:
        ValueError: if ``dataloader`` yields no batch (previously the function
            printed "PASSED" and then failed with a NameError in the profiling
            step).
        AssertionError: if one of the checks fails.
    """
    model = model.to(device)
    model.eval()
    print(f"----- Model Sanity check started ----")

    # Only a single batch is inspected.
    sample = next(iter(dataloader), None)
    if sample is None:
        raise ValueError("mae_sanity_check: the dataloader yielded no batches")

    if isinstance(sample, (list, tuple)):
        input_sample = sample[0]  # sample = (image, label/path): take the image
    else:
        input_sample = sample  # images only
    input_sample = input_sample.to(device)

    # No gradients here: it is a check, not a training step.
    with torch.no_grad():
        # === SANITY CHECK ===
        assert input_sample.dtype == torch.float32, "Input tensor should be float32"
        assert input_sample.size(1) == hyperparameters['in_channels'], f"Expected {hyperparameters['in_channels']} input channels"
        print("Input check: OK")

        # === MODEL FORWARD ===
        pred, mask = model(input_sample)

        # pred: [B, N, patch_dim]
        assert pred.dtype == torch.float32, "Prediction tensor should be float32"
        assert pred.size(0) == input_sample.size(0), "Batch size mismatch between input and prediction"
        assert pred.size(2) == (hyperparameters['patch_size'] ** 2) * hyperparameters['in_channels'], \
            f"Each patch prediction should have size patch^2 * in_channels"

        print("Output prediction shape check: OK")

    # Number of parameters and size assuming 4 bytes (float32) per parameter.
    total_params = sum(p.numel() for p in model.parameters())
    size_in_mb = total_params * 4 / 1024 / 1024
    print(f"Model size: {size_in_mb:.2f} MB ({total_params/ 1e6:.2f} M params)")
    print("Model forward sanity check: PASSED 🎯")
    print("")

    # Operation count, measured on the whole checked batch (not one image).
    # NOTE(review): thop's documentation names this return value MACs, which
    # would make the "FLOPs" label off by a factor of two; confirm.
    flops, params = profile(model, inputs=(input_sample, ))
    print(f"Model FLOPs: {flops / 1e9:.2f} GigaFLOPs") # Print FLOPs in GigaFLOPs

In [49]:
# Classes to train and test on: edit this list to choose them.
classes = ['capsule','carpet','grid','hazelnut','leather']
# One {"train", "val", "test"} dict of DataLoaders per class name.
dataloaders = create_class_dataloader(classes=classes,batch_size=32)
model = MAE(**hyperparameters)
# The check runs on a single batch taken from the 'hazelnut' validation loader.
val_loader_to_check = dataloaders['hazelnut']['val']
mae_sanity_check(model, val_loader_to_check, device, hyperparameters)

Creating dataloader for class: capsule
Creating dataloader for class: carpet
Creating dataloader for class: hazelnut
Creating dataloader for class: leather
----- Model Sanity check started ----
Input check: OK
Output prediction shape check: OK
Model size: 120.65 MB (31.63 M params)
Model forward sanity check: PASSED 🎯

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
Model FLOPs: 73.69 GigaFLOPs


## Training

In [None]:
# from an image to patches - patchification process
def target_to_patches(pred, target, patch_size, in_channels):
    """Cut ``target`` images into flattened patches matching ``pred``.

    Every patch is flattened in ``(patch_row, patch_col, channel)`` order,
    i.e. the inverse of the reconstruction used by the loss and evaluation
    code in this file::

        pred.view(B, H // P, W // P, P, P, C).permute(0, 5, 1, 3, 2, 4)

    The previous implementation flattened patches channel-first ``(C, P, P)``,
    so the patch-wise MSE and the image-level reconstruction assigned different
    meanings to the same prediction entries.

    Args:
        pred: prediction of shape ``(B, N, P*P*C)``; only ``B`` and ``N`` are
            read from it.
        target: images of shape ``(B, C, H, W)``.
        patch_size: patch side length ``P``.
        in_channels: number of channels ``C``.

    Returns:
        Tensor of shape ``(B, N, P*P*C)``; patches are ordered row-major over
        the patch grid.
    """
    B, N, _ = pred.shape

    # (B, C, H, W) -> (B, C, H/P, W/P, P, P)
    patches = target.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    # -> (B, H/P, W/P, P, P, C): channels last inside each patch
    patches = patches.permute(0, 2, 3, 4, 5, 1)
    # -> (B, N, P*P*C)
    return patches.reshape(B, N, patch_size * patch_size * in_channels)

In [None]:
# DEFINE LOSS FUNCTIONS

# Define the training loss function taking in account both MSE and SSIM loss
def mae_loss_with_ssim(pred, target, mask, patch_size, in_channels):
    """Training loss: ``0.8 * patch MSE + 0.2 * (1 - SSIM)``.

    The MSE term is averaged over the patches flagged in ``mask`` (over all
    patches when ``mask`` flags none). The SSIM term compares the full image
    rebuilt from ``pred`` with ``target``; it uses ``data_range=1.0``, i.e. it
    assumes images normalised to [0, 1].

    Args:
        pred: ``(B, N, P*P*C)`` patch predictions.
        target: ``(B, C, H, W)`` images; ``H == W`` is assumed.
        mask: ``(B, N)``, non-zero / True where a patch counts for the MSE.
        patch_size: patch side length ``P``.
        in_channels: number of channels ``C``.

    Returns:
        Scalar loss tensor.
    """
    batch, _, _ = pred.shape

    # --- Patch-wise MSE (restricted to the flagged patches) ---
    target_patches = target_to_patches(pred, target, patch_size, in_channels)
    squared_err = (pred - target_patches) ** 2

    if mask.sum() > 0:
        weights = mask.unsqueeze(-1).type_as(squared_err)  # (B, N, 1)
        mse_loss = (squared_err * weights).sum() / (weights.sum() * squared_err.size(-1))
    else:
        mse_loss = squared_err.mean()

    # --- SSIM between the rebuilt image and the original ---
    side = target.shape[2]
    cells = side // patch_size
    recon = pred.view(batch, cells, cells, patch_size, patch_size, in_channels)
    recon = recon.permute(0, 5, 1, 3, 2, 4).contiguous()
    recon = recon.view(batch, in_channels, side, side)

    ssim_loss = 1 - ssim(recon, target, data_range=1.0, size_average=True)

    # Weighted sum of the two terms.
    return 0.8 * mse_loss + 0.2 * ssim_loss

# Define pixel wise loss used in fine-tuning
def pixelwise_loss_mean(pred, target, patch_size, in_channels):
    """Pixel-wise MSE between the image rebuilt from ``pred`` and ``target``.

    Used as the fine-tuning loss.

    Args:
        pred: ``(B, N, P*P*C)`` patch predictions, each patch flattened in
            ``(row, col, channel)`` order.
        target: ``(B, C, H, W)`` images; ``H == W`` is assumed.
        patch_size: patch side length ``P``.
        in_channels: number of channels ``C``.

    Returns:
        Scalar tensor: mean squared error over every pixel and channel.
    """
    batch, _, _ = pred.shape
    side = target.shape[2]
    per_side = side // patch_size

    # (B, N, P*P*C) -> (B, H/P, W/P, P, P, C) -> (B, C, H/P, P, W/P, P) -> (B, C, H, W)
    grid = pred.view(batch, per_side, per_side, patch_size, patch_size, in_channels)
    image = grid.permute(0, 5, 1, 3, 2, 4).contiguous().view(batch, in_channels, side, side)

    return F.mse_loss(image, target)

In [None]:
 # Allocator setting for PyTorch's CUDA caching allocator (expandable segments).
 # NOTE(review): presumably only effective if set before the first CUDA
 # allocation in this process - confirm it runs early enough in the notebook.
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [None]:
# training per class - one train step

# Training parameters
# NOTE(review): BATCH_SIZE is not referenced by the visible cells; the
# dataloaders are created with a literal batch_size=32 - keep the two in sync.
BATCH_SIZE= 32  
LEARNING_RATE = 5e-5   # AdamW learning rate
WEIGHT_DECAY = 1e-2    # AdamW weight decay

# class name -> list of mean training losses, one entry per epoch (used for plotting)
training_losses = {}

def train_one_ep(model, dataloader, optimizer, criterion, c_name, device):
    """Train ``model`` for one epoch and return the mean batch loss.

    Args:
        model: module whose forward returns ``(pred, mask)`` and which exposes
            ``patch_size`` and ``in_channels``.
        dataloader: yields batches whose first element is the image tensor.
        optimizer: optimizer stepped once per batch.
        criterion: loss function. ``pixelwise_loss_mean`` (fine-tuning) is
            called without the mask; any other criterion is called as
            ``criterion(pred, images, mask, patch_size, in_channels)``.
        c_name: class name (not used inside the function).
        device: device the images are moved to.

    Returns:
        Sum of the per-batch losses divided by ``len(dataloader)``.
    """
    model.train()

    # Pre-training and fine-tuning use losses with different signatures.
    uses_pixel_loss = criterion == pixelwise_loss_mean

    batch_losses = []
    for batch in dataloader:
        images = batch[0].to(device)

        optimizer.zero_grad()
        pred, mask = model(images)

        if uses_pixel_loss:
            loss = pixelwise_loss_mean(pred, images, model.patch_size, model.in_channels)
        else:
            loss = criterion(pred, images, mask, model.patch_size, model.in_channels)

        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    return sum(batch_losses) / len(dataloader)

In [None]:
# Training
# First pre-training, then run the fine-tuning section to get the overall results
EPOCHS = 80        # maximum number of epochs per class
PATIENCE_T = 20    # epochs without a new best training loss before stopping

# One independent model is trained from scratch for every class.
for c_name, dt_loader in dataloaders.items():
    print(f"\nRunning training on class: {c_name} ")
    
    model = MAE(**hyperparameters).to(device)
    train_loader = dt_loader["train"]
    # NOTE(review): val_loader is assigned but never used in this cell;
    # model selection and early stopping below look at the training loss only.
    val_loader   = dt_loader["val"]
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

    # NOTE(review): with this starting value no checkpoint is written until the
    # epoch loss drops below 0.25; if it never does, the run stops after
    # PATIENCE_T epochs without saving anything. float("inf") would always
    # save the first epoch - confirm before removing the development value.
    best_tl = 0.25  # just for the development
    early_stop_count = 0
    # Re-running this cell appends to an existing list instead of resetting it.
    if c_name not in training_losses :
        training_losses[c_name] = []
    
    # Learning rate scheduler: cosine annealing over EPOCHS steps down to 0.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, EPOCHS, eta_min=0.0)
    
    for epoch in range(EPOCHS):
        print(f"CLASS: {c_name} | EPOCH: {epoch+1}/{EPOCHS}")
        train_loss = train_one_ep(model, train_loader, optimizer, mae_loss_with_ssim, c_name, device)
        print(f"Train Loss: {train_loss:.4f}")
        
        training_losses[c_name].append(train_loss)
        scheduler.step()
        
        # Checkpoint on a new best training loss, otherwise count towards early stopping.
        if (train_loss < best_tl):
            best_tl = train_loss
            early_stop_count = 0
            torch.save(model.state_dict(), f"/kaggle/working/{c_name}_pretrained_model.pth")
            print(f" New best model saved for class: {c_name}")
        else:
            early_stop_count += 1
            if early_stop_count >= PATIENCE_T:
                print("Early stopping triggered.")
                break
   
    
    print(f"-- Finished training for class: {c_name} --\n")

print("\n---- Training complete for all classes ----\n")

In [None]:
# Plot the loss-vs-epoch curves; training_losses maps each class name to its
# list of per-epoch mean training losses filled in by the training cell.
plot_training_losses(training_losses)

In [None]:
# Change class_name to visualize different class samples.
# NOTE(review): `model` is whatever the training loop bound last (the last
# class in `dataloaders`, with its final-epoch weights), which is not
# necessarily the model trained for `class_name`. Presumably the matching
# "<class>_pretrained_model.pth" checkpoint should be loaded first - confirm.
visualize_with_gt(model, dataloaders, class_name="hazelnut", device=device, 
                  hyperparams=hyperparameters,root_folder='./mvtec_ad_validation', num_images=4, save_dir='./pictures/pretrained')

In [None]:
# evaluation function called for validation and test
def evaluation_fun(model, dataloader, device, patch_size, folder, threshold=None):
    model.eval()

    all_image_scores, all_image_labels = [], []
    good_samples_err = []
    # AUPRO initialization
    all_pixel_maps = []
    all_gt_masks = []

    with torch.no_grad():

        for imgs, paths in dataloader:
            imgs = imgs.to(device)
            # no mask needed in validation and test
            pred, _ = model(imgs)

            # reconstruct the image
            B = imgs.size(0)
            H = W = hyperparameters["image_size"]
            recon = pred.view(B, H // patch_size, W // patch_size, patch_size, patch_size, hyperparameters['in_channels'])
            recon = recon.permute(0, 5, 1, 3, 2, 4).contiguous().view(B, 3, H, W)

            # compute error map - check if 'good' in the path
            err_maps = ((recon - imgs)**2).mean(dim=1)
            image_scores_batch = torch.amax(err_maps, dim=[1, 2])
            batch_scores = image_scores_batch.cpu().numpy()
            all_image_scores.extend(batch_scores)
            all_image_labels += [0 if "good" in p.lower() else 1 for p in paths]

            # used in the threshold
            for score, path in zip(batch_scores, paths):
                if "good" in path.lower():
                    good_samples_err.append(score)

            # get ground truth masks
            for i, path in enumerate(paths):
                    all_pixel_maps.append(err_maps[i].cpu().numpy())
                    # --- Modified section ---
                    # Check if it's an anomaly image before trying to load a ground truth mask
                    if "good" not in path.lower():
                        gt_path = get_gt_mask_path(path, root_folder=folder)
                        if os.path.exists(gt_path):
                            gt_mask = Image.open(gt_path).convert("L")
                            # Resize mask to match image size for pixel-wise comparison
                            gt_mask = np.array(gt_mask.resize((W, H), Image.NEAREST)) / 255.0
                            all_gt_masks.append(gt_mask)
                        else:
                             # If an anomaly image has no GT mask, still add a zero mask for dimension consistency,
                             # but this case should ideally not happen for anomalous samples in MVTec AD.
                             # Print a warning or handle this as an error if necessary.
                             print(f"Warning: Ground truth mask not found for anomaly image: {path}")
                             gt_mask = np.zeros((H, W))
                             all_gt_masks.append(gt_mask)
                    else:
                        # 'good' image case - no ground truth mask available, add a zero mask
                        gt_mask = np.zeros((H,W))
                        all_gt_masks.append(gt_mask)
                    # --- End of modified section ---


        # check if have good samples to evaluate best_threshold
        if len(good_samples_err) == 0:
            raise ValueError("No 'good' samples found in validation set\n")

        all_image_scores = np.array(all_image_scores)
        all_image_labels = np.array(all_image_labels)

        # Validation case: calculate the threshold
        if threshold is None:
            best_threshold = np.percentile(good_samples_err, 95) #np.mean(good_samples_err) + 3 * np.std(good_samples_err)
            best_class_threshold[c_name] = best_threshold
            print(f"Updated Class {c_name} threshold: {best_threshold:.3f}")
        # Test case: use the threshold passed as argument: no cheating allowed!
        elif threshold is not None:
            best_threshold = threshold
        
        # AUC and F1 (image wise)
        auc = roc_auc_score(all_image_labels,  all_image_scores)
        if best_threshold is not None:
             preds = (all_image_scores >= best_threshold).astype(int)
             f1  = f1_score(all_image_labels, preds)
        else:
             # If no threshold is available (e.g., no good samples), set F1 to 0
             f1 = 0.0
             print("F1 score not calculated due to missing threshold.")


        # AUPRO
        aupro_scores = []
        # Iterate through pixel maps and ground truth masks simultaneously
        for pred_map, gtruth_mask in zip(all_pixel_maps, all_gt_masks):
            # Only calculate AUPRO for samples with actual anomaly masks (sum > 0)
            if gtruth_mask.sum() > 0:
                 # Appiattisci la heatmap e la maschera GT
                 scores = pred_map.flatten()
                 masks = gtruth_mask.flatten()

                 # Calcola AUPRO usando roc_auc_score
                 # Ensure there are at least two unique classes (0 and 1) in the mask
                 if len(np.unique(masks)) > 1:
                     aupro = roc_auc_score(masks, scores)
                     aupro_scores.append(aupro)
                 else:
                     # This case should only happen if a mask file exists but is all 0s or all 1s for an anomaly image
                     print(f"Warning: Ground truth mask for AUPRO calculation has only one unique value for a potential anomaly sample.")


        if len(aupro_scores) > 0:
            aupro_mean = np.mean(aupro_scores)
        else:
            # If no anomaly samples with valid GT masks were found after filtering
            aupro_mean = 0.0
            print("Warning: No valid anomaly samples with corresponding GT masks found for AUPRO calculation after filtering.")

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    # ROC graph - Plotting is done regardless of whether threshold/F1 was calculated
    fpr, tpr, _ = roc_curve(all_image_labels, all_image_scores)
    ax[0].plot(fpr, tpr, label=f"AUC={auc:.3f}",color='#822433')
    ax[0].plot([0,1],[0,1],'k--')
    ax[0].set_xlabel("False Positive Rate")
    ax[0].set_ylabel("True Positive Rate")
    ax[0].set_title("ROC Curve")
    ax[0].legend()
    ax[0].grid(True, linestyle='--', alpha=0.6)
    
    # Plot histogram about error distribution - Plotting is done regardless of threshold
    ax[1].hist(all_image_scores[all_image_labels==0], bins=30, alpha=0.9, label='good', color='#006778')
    ax[1].hist(all_image_scores[all_image_labels==1], bins=30, alpha=0.8, label='anomaly', color='#822433')
    ax[1].axvline(best_threshold, color='k', linestyle='--', label=f'thr={best_threshold:.3f}')
    #else:
    # If thresholding failed, maybe indicate this on the plot or in the title
    #     plt.title("Error Distribution (Threshold not available)");
    ax[1].set_xlabel("error value")
    ax[1].legend()
    ax[1].set_title("Error Distribution")
    ax[1].grid(True, linestyle='--', alpha=0.6)

    plt.tight_layout()
    plt.show()
    # Return best_threshold as the fourth value in the tuple, even if it's a fallback or None
    return auc, f1, aupro_mean, best_threshold

In [None]:
# Fine-tuning loop. Every EVAL_STEP epochs the model is evaluated on the
# validation loader; the checkpoint with the best validation AUC is saved and the
# threshold computed for THAT checkpoint is the one kept for the test stage.
EPOCHS_FT = 40
LEARNING_RATE_FT = 5e-5
PATIENCE_FT = 20  # epochs without validation-AUC improvement before stopping
EVAL_STEP = 5
WEIGHT_DECAY = 1e-2

# class threshold dictionary
best_class_threshold = {}

for c_name, dt_loader in dataloaders.items():
    print(f"\nRunning FINE-TUNING on class: {c_name} ")

    # lists for storing and then plot auc and aupro metrics
    auc_story = []
    aupro_story = []
    epochs_recorded = []  # epoch numbers at which the metrics were recorded

    # load the pretrained model (map_location keeps this working when the
    # checkpoint was written on a different device)
    model = MAE(**hyperparameters).to(device)
    model.load_state_dict(torch.load(f"/kaggle/working/{c_name}_pretrained_model.pth",
                                     map_location=device))
    train_loader = dt_loader["train"]
    val_loader   = dt_loader["val"]

    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE_FT, weight_decay=WEIGHT_DECAY)
    criterion = pixelwise_loss_mean
    # -inf guarantees the first evaluation always produces a checkpoint
    best_auc = float("-inf")
    best_val_threshold = None
    early_stop_count = 0  # evaluations since the last improvement

    for epoch in range(1, EPOCHS_FT+1):
        model.train()
        ft_loss = train_one_ep(model, train_loader, optimizer, criterion, c_name, device)
        print(f"CLASS: {c_name} | EPOCH: {epoch}/{EPOCHS_FT} | LOSS: {ft_loss:.4f}")

        if epoch % EVAL_STEP != 0:
            continue

        # validation
        val_auc, val_f1, val_aupro, val_threshold = evaluation_fun(model, val_loader, device,
                                               patch_size=hyperparameters['patch_size'], folder="./mvtec_ad_validation")
        # store auc and aupro values
        auc_story.append(val_auc)
        aupro_story.append(val_aupro if val_aupro is not None else 0)
        epochs_recorded.append(epoch)

        print(f"\tEPOCH {epoch}/{EPOCHS_FT} Validation metrics: ")
        print(f"Val AUC: {val_auc:.3f} | Val F1: {val_f1:.3f} | Val AUPRO: {val_aupro:.3f}")

        if val_auc > best_auc:
            best_auc = val_auc
            best_val_threshold = val_threshold
            early_stop_count = 0
            torch.save(model.state_dict(), f"/kaggle/working/{c_name}_best_model_finetuned.pth")
            print(f" New best finetuned model saved for class: {c_name}")
        else:
            early_stop_count += 1

        # Pin the dictionary entry to the threshold of the saved checkpoint, so
        # the test stage never pairs the best weights with a threshold computed
        # for a later, worse model.
        if best_val_threshold is not None:
            best_class_threshold[c_name] = best_val_threshold

        # early stopping: the counter advances once per evaluation, so convert
        # it to epochs before comparing with the patience (otherwise 20 could
        # never be reached with only EPOCHS_FT / EVAL_STEP evaluations)
        if early_stop_count * EVAL_STEP >= PATIENCE_FT:
            print("Early stopping triggered.")
            break

    print(f"-- Fine-tuning completed for class: {c_name} --\n")

    # Use the recorded epochs for plotting; make sure the output folder exists
    os.makedirs("/kaggle/working/visuals", exist_ok=True)
    plot_metric_curve(epochs=epochs_recorded,
                      values_dict={"AUC": auc_story, "AUPRO": aupro_story},
                      title=f"Validation Metrics - Class: {c_name}",
                      save_path=f"/kaggle/working/visuals/{c_name}_metrics_plot.png")

print("\n---- Fine-tuning complete for all classes ----\n")
print("Best thresholds per class founded:")
for c, th in best_class_threshold.items():
    print(f"{c}: {th:.3f}")

## Evaluation

Evaluate the model on the test set using F1 score, AUC and AUPRO.


In [None]:
# Test: evaluate the best fine-tuned checkpoint of every class on its test
# loader, using the threshold stored in `best_class_threshold` during fine-tuning.
print(f"\n---- Test section ----\n")

for c_name, dt_loader in dataloaders.items():
    print(f"\nTesting class: {c_name} \n")

    # get the class threshold from the thresholds dictionary; without it
    # evaluation_fun would re-estimate the threshold on the test data itself
    class_threshold = best_class_threshold.get(c_name)
    if class_threshold is None:
        raise RuntimeError(
            f"No validation threshold stored for class '{c_name}': run the fine-tuning cell first."
        )

    # Load the checkpoint into the model that is actually evaluated
    # (it was previously loaded into `model`, leaving `final_model` untrained)
    final_model = MAE(**hyperparameters).to(device)
    path_saved_model = f"/kaggle/working/{c_name}_best_model_finetuned.pth"
    final_model.load_state_dict(torch.load(path_saved_model, map_location=device))

    test_loader = dt_loader["test"]
    test_auc, test_f1, test_aupro, _ = evaluation_fun(model=final_model, dataloader=test_loader, device=device,
                                                      patch_size=hyperparameters['patch_size'], folder="./mvtec_ad_test", threshold=class_threshold)

    # Later cells reference the global `model`; keep it pointing at the most
    # recently loaded best checkpoint, as it did before this fix.
    model = final_model

    print(f"\nFinal results summary per class: {c_name}")
    print(f"Threshold: {class_threshold:.3f}")
    print(f"Test AUC: {test_auc:.3f}")
    print(f"Test F1: {test_f1:.3f}")
    print(f"Test AUPRO: {test_aupro:.3f}\n")
    print("--------------------------------------")

print(f"\n---- THIS IS THE END of the Test section ----\n")

## **EXTRA!**

In [None]:
# Change vis_class to visualize different class samples.
# The in-memory `model` is tied to whichever class was processed last, so the
# best fine-tuned checkpoint of the requested class is loaded explicitly.
vis_class = "hazelnut"
vis_ckpt = f"/kaggle/working/{vis_class}_best_model_finetuned.pth"

if os.path.exists(vis_ckpt):
    vis_model = MAE(**hyperparameters).to(device)
    vis_model.load_state_dict(torch.load(vis_ckpt, map_location=device))
    vis_model.eval()  # inference-only use
else:
    # No checkpoint on disk: fall back to the model currently in memory
    print(f"Warning: checkpoint not found at {vis_ckpt}; using the in-memory model instead.")
    vis_model = model

visualize_with_gt(vis_model, dataloaders, class_name=vis_class, device=device,
                  hyperparams=hyperparameters, root_folder='./mvtec_ad_validation',
                  num_images=10, save_dir='./pictures/finetuned')

In [None]:
# TODO: we could implement training- and validation-loss plotting, as they do too
# Debug aid: list the keys of `dataloaders`
print("Dataloaders keys:", dataloaders.keys())