# STEP 2: EfficientAD-M per Anomaly Detection - UN MODELLO PER OGNI CONNETTORE

Questo notebook addestra **9 modelli EfficientAD-M separati**, uno per ogni connettore (conn1, conn2, ..., conn9).

Ogni modello avrà il suo threshold specifico, calcolato solo sui dati OK del rispettivo connettore.

**EfficientAD-M** è un metodo di anomaly detection basato su un'architettura Teacher-Student:
- **Teacher**: ResNet18 pre-addestrato su ImageNet (congelato)
- **Student**: ResNet18 non pre-addestrato (addestrato su OK)
- **Anomaly score**: differenza tra feature teacher e student

Funziona meglio degli autoencoder quando le anomalie sono strutturali/geometriche (come nel nostro caso).

**NOTA**: Le immagini sono già in grayscale e normalizzate nel preprocessing. Le carichiamo come RGB per compatibilità con ResNet pre-addestrato.


## Setup


In [None]:
# Setup: clone the GitHub repository and mount Google Drive for the data.
# This cell relies on IPython shell magics (`!git ...`) and on google.colab,
# so it only runs inside a Colab/IPython notebook.
import os
from pathlib import Path

# Option 1: clone from GitHub (recommended during development)
# Replace with your own repository URL
GITHUB_REPO = "https://github.com/Giovanni000/Project-Work.git"  # WARNING: change this!
REPO_DIR = "/content/project"

# Clone the repository if it is not there yet, otherwise update it in place
if not Path(REPO_DIR).exists():
    !git clone {GITHUB_REPO} {REPO_DIR}
else:
    os.chdir(REPO_DIR)
    !git pull

# Make the repository the working directory
os.chdir(REPO_DIR)
# If the clone produced a single wrapping sub-folder, step into it.
# NOTE(review): this also triggers when the repository root simply contains
# exactly one visible directory (e.g. only "data" or only "models"), which
# would silently move the working directory - confirm this is intended.
subdirs = [d for d in Path(REPO_DIR).iterdir() if d.is_dir() and not d.name.startswith('.')]
if len(subdirs) == 1:
    os.chdir(subdirs[0])

print(f"Repository directory: {os.getcwd()}")

# Option 2: mount Google Drive, used only for the data (images)
from google.colab import drive
drive.mount('/content/drive')

# Location of the data on Drive
DATA_ROOT = Path("/content/drive/MyDrive/Project Work/Data")
print(f"Data directory: {DATA_ROOT}")

# IMPORTANT: images stored on Drive are SLOW to read during training!
# If training is too slow, consider copying the images to local disk first

# Import necessari
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights
from PIL import Image
import pandas as pd
from pathlib import Path
import numpy as np
from tqdm import tqdm
import cv2
import matplotlib.pyplot as plt

# Fix the NumPy and PyTorch RNG seeds so that runs are reproducible
np.random.seed(42)
torch.manual_seed(42)

# Pick the compute device and, when a GPU is present, report its name and memory
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
if device.type == "cuda":
    gpu_name = torch.cuda.get_device_name(0)
    # total_memory is expressed in bytes; show it in (decimal) gigabytes
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print("GPU: " + gpu_name)
    print("VRAM: %.1f GB" % gpu_memory)
    if "T4" in gpu_name:
        print("‚úÖ Tesla T4 rilevata - Parametri ottimizzati per questa GPU")


## Verifica/Crea Dataset CSV


In [None]:
# Create data/dataset.csv from features_labeled.csv when it does not exist yet.
# The resulting CSV has the columns 'image_path', 'label' and 'connector_name'.
import pandas as pd
from pathlib import Path

dataset_csv = Path("data/dataset.csv")

if not dataset_csv.exists():
    print("‚ö†Ô∏è  data/dataset.csv non trovato. Creazione automatica...")

    # Look for features_labeled.csv in the working directory first,
    # then fall back to the repository root used by the setup cell
    features_csv = Path("features_labeled.csv")
    if not features_csv.exists():
        features_csv = Path("/content/project/features_labeled.csv")

    if not features_csv.exists():
        raise FileNotFoundError(f"features_labeled.csv non trovato in {features_csv}")

    print(f"Leggendo CSV: {features_csv}...")
    df = pd.read_csv(features_csv)
    print(f"CSV caricato: {len(df)} righe")

    # Merge the PARTIAL OCCLUSION class into OCCLUSION
    df['label_merged'] = df['label'].replace('PARTIAL OCCLUSION', 'OCCLUSION')

    # Build the image paths (the images live on Drive)
    DRIVE_DATA_BASE = "/content/drive/MyDrive/Project Work/Data"
    df['image_path'] = [
        f"{DRIVE_DATA_BASE}/connectors/{connector}/{filename}"
        for connector, filename in zip(df['connector_name'], df['filename'])
    ]

    # Keep only the rows whose image file is actually present
    print("Verificando esistenza immagini...")
    existing = [idx for idx, path in enumerate(df['image_path']) if Path(path).exists()]

    print(f"Immagini trovate: {len(existing)}/{len(df)}")

    if not existing:
        print("‚ö†Ô∏è  PROBLEMA: Nessuna immagine trovata!")
        print("Verifica che Drive sia montato e che i path siano corretti.")
        # Stop before saving: an empty dataset.csv would be picked up as
        # "already created" on every later run and never be regenerated.
        raise FileNotFoundError(
            f"Nessuna immagine trovata sotto {DRIVE_DATA_BASE}; data/dataset.csv non creato"
        )

    df_valid = df.iloc[existing].copy()

    # Final layout: image_path, label (merged), connector_name
    output_df = df_valid[['image_path', 'label_merged', 'connector_name']].copy()
    output_df.rename(columns={'label_merged': 'label'}, inplace=True)

    # Make sure the data/ folder exists before writing
    dataset_csv.parent.mkdir(parents=True, exist_ok=True)

    output_df.to_csv(dataset_csv, index=False)
    print(f"‚úÖ Dataset preparato: {len(output_df)} righe")
    print("Distribuzione label:")
    print(output_df['label'].value_counts())
else:
    # The CSV is already there: just report its size and columns
    print(f"‚úÖ Dataset trovato: {dataset_csv}")
    df_check = pd.read_csv(dataset_csv)
    print(f"  Righe: {len(df_check)}")
    print(f"  Colonne: {list(df_check.columns)}")


## Configurazione EfficientAD-M


In [None]:
# ============================================================================
# EFFICIENTAD-M CONFIGURATION
# ============================================================================

# Square input resolution used across the project (128x128)
IMG_SIZE = 128

# Training hyper-parameters
BATCH_SIZE = 32        # tuned for a Tesla T4
NUM_EPOCHS = 25        # the student is expected to converge quickly
LEARNING_RATE = 1e-4

# Decision threshold = mu + THRESHOLD_MULTIPLIER * sigma of the OK scores
THRESHOLD_MULTIPLIER = 2.5

# Number of DataLoader worker processes (tuned for a Tesla T4)
NUM_WORKERS = 2

# Compute device shared by the training and scoring helpers
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Backbone layers compared between teacher and student, with their weights
FEATURE_LAYERS = ["layer2", "layer3", "layer4"]
FEATURE_LAYER_WEIGHTS = dict(layer2=1.0, layer3=1.0, layer4=1.0)

# Bounds of the per-pixel weights stored in the spatial mask
SPATIAL_MASK_MIN_WEIGHT = 0.5
SPATIAL_MASK_MAX_WEIGHT = 2.0

# Fraction of the highest anomaly-map values averaged into the image score
TOP_K_PERCENT = 0.01

# Connector used by the visualization / debugging helpers
DEBUG_CONNECTOR = "conn1"

# Echo the main settings in a single write
print(
    "Configurazione EfficientAD-M:\n"
    f"  IMG_SIZE: {IMG_SIZE}\n"
    f"  BATCH_SIZE: {BATCH_SIZE}\n"
    f"  NUM_EPOCHS: {NUM_EPOCHS}\n"
    f"  LEARNING_RATE: {LEARNING_RATE}\n"
    f"  THRESHOLD_MULTIPLIER: {THRESHOLD_MULTIPLIER}\n"
    f"  FEATURE_LAYERS: {FEATURE_LAYERS}\n"
    f"  TOP_K_PERCENT: {TOP_K_PERCENT}"
)


## Dataset Class (solo OK, filtrato per connettore)


In [None]:
class EfficientADDatasetPerConnector(Dataset):
    """
    PyTorch dataset with the OK images of a single connector.

    Only rows whose 'label' is 'OK' and whose 'connector_name' matches the
    requested connector are kept. Images are opened as RGB so that they can be
    fed to an ImageNet-pretrained ResNet (the notebook states they were saved
    as grayscale by the preprocessing step).
    """

    def __init__(self, csv_path, connector_name, transform=None):
        """
        Args:
            csv_path: Path to a CSV with columns 'image_path', 'label', 'connector_name'
            connector_name: Connector to select (e.g. 'conn1', 'conn2', ...)
            transform: Optional callable applied to each PIL image

        Raises:
            ValueError: if the CSV holds no OK image for this connector
        """
        df = pd.read_csv(csv_path)
        # Keep only the OK samples of the requested connector
        selected = (df['label'] == 'OK') & (df['connector_name'] == connector_name)
        self.df = df[selected].copy().reset_index(drop=True)
        self.transform = transform

        print(f"Dataset EfficientAD per {connector_name}: {len(self.df)} immagini OK")
        if len(self.df) == 0:
            raise ValueError(f"Nessuna immagine OK trovata per {connector_name}!")

    def __len__(self):
        """Return the number of OK images of the connector."""
        return len(self.df)

    def __getitem__(self, idx):
        """
        Return the (optionally transformed) RGB image at position `idx`.

        An image that cannot be read is reported on stdout and replaced by a
        black IMG_SIZE x IMG_SIZE image, so one bad file does not stop an epoch.
        """
        image_path = self.df.iloc[idx]['image_path']

        try:
            # Image.open is lazy and keeps the file open; the context manager
            # closes it even when decoding fails. convert() decodes the pixels
            # and returns an independent image that outlives the file handle.
            with Image.open(image_path) as source:
                image = source.convert('RGB')
        except Exception as e:
            # Deliberately broad: any unreadable file falls back to a black image
            print(f"Errore caricamento {image_path}: {e}")
            image = Image.new('RGB', (IMG_SIZE, IMG_SIZE), (0, 0, 0))

        if self.transform:
            image = self.transform(image)

        return image


In [None]:
def compute_spatial_weight_mask_for_connector(connector_name, csv_path="data/dataset.csv",
                                              img_size=IMG_SIZE, save_dir="models",
                                              min_weight=SPATIAL_MASK_MIN_WEIGHT,
                                              max_weight=SPATIAL_MASK_MAX_WEIGHT):
    """
    Build (or load from cache) the spatial weight mask of one connector.

    The mask is derived from all OK images of the connector:
    - every image is loaded as RGB, resized to (img_size, img_size) and
      reduced to one channel by averaging the three channels
    - the per-pixel standard deviation over the image stack is computed
    - the std map is min-max normalised and turned into weights with
      1 / (1 + std_norm): stable pixels get high weights, variable ones low
    - the weights are rescaled linearly to [min_weight, max_weight]
    - the result is saved as spatial_mask_{connector_name}.npy in `save_dir`

    A mask already present in `save_dir` is reused when its shape is
    (img_size, img_size); a cached mask of another resolution is recomputed
    and overwritten. Note that the cache is not invalidated when only
    `min_weight` / `max_weight` change.

    Args:
        connector_name: Connector whose OK images are used
        csv_path: CSV with columns 'image_path', 'label', 'connector_name'
        img_size: Side of the square mask, in pixels
        save_dir: Directory where the .npy mask is stored (created if missing)
        min_weight: Lowest weight in the final mask
        max_weight: Highest weight in the final mask

    Returns:
        Numpy array [img_size, img_size] with the weights.

    Raises:
        ValueError: if the CSV lists no OK image for the connector, or none
            of the listed images could be loaded.
    """
    models_dir = Path(save_dir)
    # parents=True so that nested save directories work as well
    models_dir.mkdir(parents=True, exist_ok=True)
    mask_path = models_dir / f"spatial_mask_{connector_name}.npy"

    if mask_path.exists():
        cached = np.load(mask_path)
        if cached.shape == (img_size, img_size):
            print(f"  ‚úÖ Spatial mask gi√† esistente per {connector_name}, skip computation")
            return cached
        # A mask computed for another image size is stale: rebuild it
        print(f"  Cached spatial mask has shape {cached.shape}, "
              f"expected {(img_size, img_size)}: recomputing")

    print(f"  üìä Computing spatial weight mask for {connector_name}...")

    # Select the OK images of this connector
    df = pd.read_csv(csv_path)
    ok_df = df[(df['label'] == 'OK') & (df['connector_name'] == connector_name)]

    if len(ok_df) == 0:
        raise ValueError(f"Nessuna immagine OK trovata per {connector_name}!")

    print(f"    Caricando {len(ok_df)} immagini OK...")

    # Load every image as a single-channel float map in [0, 1]
    images_stack = []
    for image_path in tqdm(ok_df['image_path'], total=len(ok_df), desc="  Loading images"):
        try:
            # The context manager releases the file handle even on decode errors
            with Image.open(image_path) as source:
                img = source.convert('RGB').resize((img_size, img_size), Image.Resampling.LANCZOS)
            img_array = np.asarray(img, dtype=np.float32) / 255.0
            # Grayscale as the plain average of the three channels -> [H, W]
            images_stack.append(img_array.mean(axis=2))
        except Exception as e:
            # Best effort: skip unreadable files, fail later only if none loads
            print(f"    ‚ö†Ô∏è  Errore caricamento {image_path}: {e}")
            continue

    if len(images_stack) == 0:
        raise ValueError(f"Nessuna immagine valida caricata per {connector_name}!")

    # Stack to [N, H, W]
    X = np.stack(images_stack, axis=0)
    print(f"    Stack shape: {X.shape}")

    # Per-pixel standard deviation over the stack -> [H, W]
    std_map = X.std(axis=0)

    # Min-max normalise the std map to [0, 1]
    std_min, std_max = std_map.min(), std_map.max()
    if std_max - std_min < 1e-8:
        # Flat std map: every pixel is equally stable
        std_norm = np.zeros_like(std_map)
    else:
        std_norm = (std_map - std_min) / (std_max - std_min + 1e-8)

    # Lower std -> higher weight; this yields values in [0.5, 1.0]
    weights = 1.0 / (1.0 + std_norm)

    # Rescale linearly to [min_weight, max_weight]
    w_min, w_max = weights.min(), weights.max()
    if w_max - w_min < 1e-8:
        # Uniform weights: use the middle of the requested range
        weights = np.ones_like(weights) * ((min_weight + max_weight) / 2.0)
    else:
        weights = min_weight + (weights - w_min) * (max_weight - min_weight) / (w_max - w_min + 1e-8)

    np.save(mask_path, weights)
    print(f"  ‚úÖ Spatial mask salvato in: {mask_path}")
    print(f"    Weight range: [{weights.min():.3f}, {weights.max():.3f}]")

    return weights


def load_spatial_weight_mask(connector_name, target_size, save_dir="models", device=DEVICE):
    """
    Read the saved spatial mask of a connector and return it as a 4-D tensor.

    The file spatial_mask_{connector_name}.npy is loaded from `save_dir`,
    converted to float and, when its spatial size differs from `target_size`,
    resampled with bilinear interpolation.

    Args:
        connector_name: Name of the connector
        target_size: Requested spatial size (Hf, Wf)
        save_dir: Directory that holds the saved masks
        device: Device the returned tensor is moved to

    Returns:
        Tensor [1, 1, Hf, Wf] on `device`.

    Raises:
        FileNotFoundError: if the mask file does not exist
    """
    import torch.nn.functional as F

    mask_path = Path(save_dir) / f"spatial_mask_{connector_name}.npy"
    if not mask_path.exists():
        raise FileNotFoundError(f"Spatial mask not found: {mask_path}")

    # [H, W] array -> float tensor with leading batch and channel axes
    weights = torch.from_numpy(np.load(mask_path)).float()
    mask_tensor = weights[None, None]  # [1, 1, H, W]

    # Resample only when the stored resolution is not the requested one
    needs_resize = mask_tensor.shape[2:] != target_size
    if needs_resize:
        mask_tensor = F.interpolate(mask_tensor, size=target_size,
                                    mode='bilinear', align_corners=False)

    return mask_tensor.to(device)


## Anomaly Score (multi-layer + spatial mask)


In [None]:
def compute_fused_anomaly_map(teacher_feats, student_feats,
                              spatial_mask_fullres,
                              feature_layers=FEATURE_LAYERS,
                              feature_layer_weights=FEATURE_LAYER_WEIGHTS,
                              device=DEVICE):
    """
    Fuse the per-layer teacher/student discrepancies into one 2-D anomaly map.

    For every layer the squared difference is averaged over channels,
    multiplied by the spatial mask resized to that layer's resolution,
    upsampled to the reference resolution and accumulated with the layer's
    weight. The reference resolution is the largest spatial size among the
    selected layers, independently of the order of `feature_layers`.

    Args:
        teacher_feats: Dict[layer_name -> tensor [B, C, Hf, Wf]]
        student_feats: Dict[layer_name -> tensor [B, C, Hf, Wf]]
        spatial_mask_fullres: Weight map [1, 1, H, W], resized per layer
        feature_layers: Names of the layers to fuse (must not be empty)
        feature_layer_weights: Dict of per-layer weights (missing -> 1.0)
        device: Unused here; kept so existing callers keep working

    Returns:
        Tensor [B, 1, Href, Wref] with the fused map before scalar aggregation.

    Raises:
        ValueError: if `feature_layers` is empty
    """
    import torch.nn.functional as F

    if not feature_layers:
        raise ValueError("feature_layers must contain at least one layer name")

    # Highest spatial resolution among the selected layers; max() returns the
    # first maximum, so the default order (layer2 first) gives layer2's size
    Href, Wref = max(
        (tuple(teacher_feats[name].shape[-2:]) for name in feature_layers),
        key=lambda hw: hw[0] * hw[1],
    )

    fused = None

    for layer_name in feature_layers:
        t = teacher_feats[layer_name]  # [B, C, Hf, Wf]
        s = student_feats[layer_name]  # [B, C, Hf, Wf]

        # Squared difference averaged over channels -> [B, 1, Hf, Wf]
        amap = ((t - s) ** 2).mean(dim=1, keepdim=True)
        Hf, Wf = amap.shape[-2:]

        # Bring the spatial mask to this layer's resolution -> [1, 1, Hf, Wf]
        mask_resized = F.interpolate(
            spatial_mask_fullres,
            size=(Hf, Wf),
            mode='bilinear',
            align_corners=False
        )

        # Weight the map spatially (the mask broadcasts over the batch)
        amap = amap * mask_resized

        # Upsample to the reference resolution when needed
        if (Hf, Wf) != (Href, Wref):
            amap = F.interpolate(
                amap,
                size=(Href, Wref),
                mode='bilinear',
                align_corners=False
            )

        # Weighted accumulation across layers
        weighted = feature_layer_weights.get(layer_name, 1.0) * amap
        fused = weighted if fused is None else fused + weighted

    return fused  # [B, 1, Href, Wref]


def compute_anomaly_score_from_features(teacher_feats, student_feats,
                                        spatial_mask_fullres,
                                        feature_layers=FEATURE_LAYERS,
                                        feature_layer_weights=FEATURE_LAYER_WEIGHTS,
                                        topk_percent=TOP_K_PERCENT,
                                        device=DEVICE):
    """
    Reduce the fused multi-layer anomaly map to one robust score per image.

    The fused map comes from `compute_fused_anomaly_map`; each image's map is
    then flattened and the mean of its largest values is used as the score.
    The number of values kept is int(topk_percent * number_of_pixels), with a
    minimum of one.

    Args:
        teacher_feats: Dict[layer_name -> tensor [B, C, Hf, Wf]]
        student_feats: Dict[layer_name -> tensor [B, C, Hf, Wf]]
        spatial_mask_fullres: Weight map at image resolution [1, 1, H, W]
        feature_layers: List of layer names to use
        feature_layer_weights: Dict of weights for each layer
        topk_percent: Fraction of the highest values to average (0.01 = top 1%)
        device: Forwarded to `compute_fused_anomaly_map`

    Returns:
        Tensor [B] holding one scalar anomaly score per image.
    """
    fused_map = compute_fused_anomaly_map(
        teacher_feats,
        student_feats,
        spatial_mask_fullres,
        feature_layers=feature_layers,
        feature_layer_weights=feature_layer_weights,
        device=device,
    )  # [B, 1, Href, Wref]

    # One row of anomaly values per image
    batch, _, _, _ = fused_map.shape
    per_image = fused_map.view(batch, -1)  # [B, Href*Wref]

    # Average only the strongest responses; always keep at least one value
    n_top = max(1, int(topk_percent * per_image.size(1)))
    strongest = torch.topk(per_image, k=n_top, dim=1).values
    return strongest.mean(dim=1)  # [B]


## Modelli Teacher & Student


In [None]:
class Teacher(nn.Module):
    """
    Frozen reference network: ResNet18 with ImageNet weights.

    The backbone is split into stem + layer1..layer4 so that the intermediate
    feature maps of layer2, layer3 and layer4 can be returned together.
    """

    def __init__(self):
        super().__init__()
        # ImageNet-pretrained ResNet18 used as the source of the sub-modules
        backbone = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

        # Input stage: conv -> batch norm -> relu -> max pooling
        self.stem = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool)
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        # The teacher is never trained: switch gradients off for every parameter
        self.requires_grad_(False)

    def forward(self, x):
        """
        Args:
            x: Tensor [B, 3, H, W]
        Returns:
            Dict of multi-layer feature maps:
                - "layer2": [B, 128, H/8, W/8]
                - "layer3": [B, 256, H/16, W/16]
                - "layer4": [B, 512, H/32, W/32]
        """
        out = self.layer1(self.stem(x))

        # Run the remaining stages in order, keeping every intermediate output
        features = {}
        for stage in ("layer2", "layer3", "layer4"):
            out = getattr(self, stage)(out)
            features[stage] = out
        return features


class Student(nn.Module):
    """
    Trainable network: ResNet18 created without pretrained weights.

    It is trained to reproduce the Teacher's features on OK images and exposes
    the same stem + layer1..layer4 layout, returning the same three feature maps.
    """

    def __init__(self):
        super().__init__()
        # Randomly initialised ResNet18 (no pretrained weights)
        net = resnet18(weights=None)

        # Same decomposition as the Teacher, so both produce matching feature shapes
        stem_modules = [net.conv1, net.bn1, net.relu, net.maxpool]
        self.stem = nn.Sequential(*stem_modules)
        self.layer1 = net.layer1
        self.layer2 = net.layer2
        self.layer3 = net.layer3
        self.layer4 = net.layer4

    def forward(self, x):
        """
        Args:
            x: Tensor [B, 3, H, W]
        Returns:
            Dict of multi-layer feature maps:
                - "layer2": [B, 128, H/8, W/8]
                - "layer3": [B, 256, H/16, W/16]
                - "layer4": [B, 512, H/32, W/32]
        """
        low_level = self.layer1(self.stem(x))
        feat2 = self.layer2(low_level)
        feat3 = self.layer3(feat2)
        feat4 = self.layer4(feat3)
        return dict(layer2=feat2, layer3=feat3, layer4=feat4)


## Funzione Training per Connettore


In [None]:
def train_efficientad_per_connector(connector_name, csv_path="data/dataset.csv",
                                     batch_size=32,
                                     num_epochs=20,
                                     learning_rate=1e-4,
                                     device=None):
    """
    Train the Student of one connector to imitate the frozen Teacher.

    The loss is the sum over FEATURE_LAYERS of the spatially weighted mean
    squared difference between teacher and student features, each layer being
    scaled by FEATURE_LAYER_WEIGHTS. The spatial mask is computed (or loaded
    from "models") first; if that fails a uniform mask of ones is used. The
    trained student weights are saved to
    models/efficientad_student_{connector_name}.pth.

    Args:
        connector_name: Name of the connector (e.g. 'conn1')
        csv_path: Path to the dataset CSV
        batch_size: Batch size
        num_epochs: Number of epochs
        learning_rate: Adam learning rate
        device: torch.device or device string; None selects CUDA when available

    Returns:
        teacher: Teacher model (frozen, eval mode)
        student: Trained Student model (left in train mode)
    """
    import torch.nn.functional as F

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        # Accept plain strings such as "cuda" as well as torch.device objects
        device = torch.device(device)

    print(f"\n{'='*60}")
    print(f"Training EfficientAD-M per {connector_name}")
    print(f"{'='*60}")

    # Resize + ImageNet normalisation, as expected by the pretrained ResNet
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # OK images of the requested connector only
    dataset = EfficientADDatasetPerConnector(csv_path, connector_name, transform=transform)

    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=device.type == 'cuda',
        persistent_workers=False,
    )
    if NUM_WORKERS > 0:
        # DataLoader rejects prefetch_factor when loading in the main process
        loader_kwargs["prefetch_factor"] = 2
    train_loader = DataLoader(dataset, **loader_kwargs)

    teacher = Teacher().to(device)
    student = Student().to(device)

    # The teacher is a fixed reference: eval mode and no gradients
    teacher.eval()
    teacher.requires_grad_(False)

    # Make sure the spatial mask exists on disk (best effort)
    try:
        compute_spatial_weight_mask_for_connector(
            connector_name,
            csv_path=csv_path,
            img_size=IMG_SIZE,
            save_dir="models"
        )
    except Exception as e:
        print(f"  ‚ö†Ô∏è  Errore computing spatial mask: {e}")
        print("  Continuo senza spatial mask...")

    # Load the mask at image resolution; fall back to uniform weights
    try:
        spatial_mask_fullres = load_spatial_weight_mask(
            connector_name,
            target_size=(IMG_SIZE, IMG_SIZE),
            save_dir="models",
            device=device
        )  # [1, 1, IMG_SIZE, IMG_SIZE]
    except Exception as e:
        print(f"  ‚ö†Ô∏è  Errore loading spatial mask: {e}")
        print("  Usando mask uniforme...")
        spatial_mask_fullres = torch.ones(1, 1, IMG_SIZE, IMG_SIZE, device=device)

    # Only the student is optimised
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)

    print("\nTraining Student per imitare Teacher su immagini OK (multi-layer + spatial mask)...")
    student.train()

    # The mask does not change during training, so each resized version is
    # interpolated once per feature-map size instead of once per batch
    resized_masks = {}

    for epoch in range(num_epochs):
        running_loss = 0.0
        for images in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            images = images.to(device)

            # Teacher features are targets only: no gradient needed
            with torch.no_grad():
                teacher_feats = teacher(images)  # Dict

            student_feats = student(images)  # Dict

            total_loss = 0.0
            for layer_name in FEATURE_LAYERS:
                t = teacher_feats[layer_name]  # [B, C, Hf, Wf]
                s = student_feats[layer_name]  # [B, C, Hf, Wf]

                size = tuple(t.shape[-2:])
                mask_resized = resized_masks.get(size)
                if mask_resized is None:
                    mask_resized = F.interpolate(
                        spatial_mask_fullres,
                        size=size,
                        mode='bilinear',
                        align_corners=False
                    )  # [1, 1, Hf, Wf]
                    resized_masks[size] = mask_resized

                # Masked squared error, averaged over all dimensions
                # (the mask broadcasts over batch and channels)
                layer_loss = (((t - s) ** 2) * mask_resized).mean()

                layer_weight = FEATURE_LAYER_WEIGHTS.get(layer_name, 1.0)
                total_loss = total_loss + layer_weight * layer_loss

            # Gradients flow into the student only
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            running_loss += total_loss.item()

        avg_loss = running_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.6f}")

    # Persist the student weights
    models_dir = Path("models")
    models_dir.mkdir(exist_ok=True)
    model_path = models_dir / f"efficientad_student_{connector_name}.pth"
    torch.save(student.state_dict(), model_path)
    print(f"\n‚úÖ Modello Student salvato in: {model_path}")

    return teacher, student


In [None]:
def calculate_threshold_per_connector(teacher, student, connector_name, csv_path="data/dataset.csv",
                                      threshold_multiplier=2.5, device=None):
    """
    Compute and save the anomaly threshold of a single connector.

    Threshold = mu + threshold_multiplier * sigma, where mu and sigma are the
    mean and standard deviation of the anomaly scores of all OK images of the
    connector listed in `csv_path`.

    The score of an image is the one returned by
    `compute_anomaly_score_from_features`: the mean of the top TOP_K_PERCENT
    values of the fused, spatially weighted teacher/student discrepancy map.

    NOTE(review): the OK images come from the same CSV selection that the
    training function uses, so no held-out split is visible here - confirm
    whether a separate validation set is intended.

    Args:
        teacher: Teacher model
        student: Trained Student model
        connector_name: Name of the connector (e.g. 'conn1')
        csv_path: Path to the dataset CSV
        threshold_multiplier: Number of standard deviations above the mean
        device: torch.device or device string; None selects CUDA when available

    Returns:
        The threshold (numpy float), also saved to
        models/efficientad_threshold_{connector_name}.npy.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        # Accept plain strings such as "cuda" as well as torch.device objects
        device = torch.device(device)

    print(f"\nCalcolo threshold per {connector_name}...")

    # Same preprocessing as in training
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # OK images of the connector only
    dataset = EfficientADDatasetPerConnector(csv_path, connector_name, transform=transform)

    loader = DataLoader(
        dataset,
        batch_size=1,   # one image at a time, in dataset order
        shuffle=False,
        num_workers=0,  # load in the main process to avoid multiprocessing issues
        pin_memory=False
    )

    # Spatial mask at image resolution; fall back to uniform weights
    try:
        spatial_mask_fullres = load_spatial_weight_mask(
            connector_name,
            target_size=(IMG_SIZE, IMG_SIZE),
            save_dir="models",
            device=device
        )  # [1, 1, IMG_SIZE, IMG_SIZE]
    except Exception as e:
        print(f"  ‚ö†Ô∏è  Errore loading spatial mask: {e}")
        print("  Usando mask uniforme...")
        spatial_mask_fullres = torch.ones(1, 1, IMG_SIZE, IMG_SIZE, device=device)

    scores = []

    teacher.eval()
    student.eval()

    with torch.no_grad():
        for images in tqdm(loader, desc=f"Calcolo score {connector_name}"):
            images = images.to(device)

            teacher_feats = teacher(images)  # Dict
            student_feats = student(images)  # Dict

            batch_scores = compute_anomaly_score_from_features(
                teacher_feats,
                student_feats,
                spatial_mask_fullres,
                feature_layers=FEATURE_LAYERS,
                feature_layer_weights=FEATURE_LAYER_WEIGHTS,
                topk_percent=TOP_K_PERCENT,
                device=device
            )  # [B]

            # Collect every score of the batch, not just the first element
            scores.extend(batch_scores.cpu().tolist())

    scores = np.array(scores)
    mu = np.mean(scores)
    sigma = np.std(scores)
    threshold = mu + threshold_multiplier * sigma

    print(f"  Score OK - mean: {mu:.6f}, std: {sigma:.6f}")
    print(f"  Threshold ({threshold_multiplier}*sigma): {threshold:.6f}")

    # Save the threshold; create the folder in case the models come from elsewhere
    models_dir = Path("models")
    models_dir.mkdir(parents=True, exist_ok=True)
    threshold_path = models_dir / f"efficientad_threshold_{connector_name}.npy"
    np.save(threshold_path, threshold)
    print(f"  ‚úÖ Threshold salvato in: {threshold_path}")

    return threshold


## Training Loop - Tutti i Connettori


In [None]:
# Train one model per connector (conn1 ... conn9) and compute its threshold
import traceback

connectors = ["conn%d" % i for i in range(1, 10)]

# connector name -> (teacher, student) for every connector that finished training
trained_models = {}

for connector_name in connectors:
    try:
        # 1) Train the student of this connector
        teacher, student = train_efficientad_per_connector(
            connector_name=connector_name,
            csv_path="data/dataset.csv",
            batch_size=BATCH_SIZE,
            num_epochs=NUM_EPOCHS,
            learning_rate=LEARNING_RATE,
            device=DEVICE,
        )
        trained_models[connector_name] = (teacher, student)

        # 2) Derive the decision threshold from the OK scores
        threshold = calculate_threshold_per_connector(
            teacher=teacher,
            student=student,
            connector_name=connector_name,
            csv_path="data/dataset.csv",
            threshold_multiplier=THRESHOLD_MULTIPLIER,
            device=DEVICE,
        )
    except Exception as e:
        # A failing connector is reported and skipped so the others still run
        print(f"\n‚ùå Errore durante training di {connector_name}: {e}")
        traceback.print_exc()
        continue

# Final summary
print("\n" + "=" * 60)
print(f"‚úÖ Training completato per {len(trained_models)} connettori")
print("=" * 60)


## Funzione di Caricamento Modello


In [None]:
def load_efficientad_model(connector_name, device=None, models_dir="models"):
    """
    Load a trained EfficientAD-M model for one connector.

    Builds a fresh Teacher, restores the Student weights saved for the
    connector, reads the saved anomaly threshold and reports whether a
    spatial mask file exists alongside them.

    Args:
        connector_name: Connector name (e.g. 'conn1'); used to build the file names.
        device: Device for both networks. When None, CUDA is used if available,
            otherwise CPU.
        models_dir: Directory that holds the saved artifacts. Defaults to
            "models", the location that was previously hard-coded.

    Returns:
        teacher: Teacher model (multi-layer) moved to `device`, in eval mode.
        student: Student model (multi-layer) with the saved weights, in eval mode.
        threshold: Threshold for anomaly detection, as returned by np.load.
        spatial_mask_path: Path to the spatial mask file (or None if it does not exist).

    Raises:
        FileNotFoundError: If the student weights file or the threshold file
            is missing from `models_dir`.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    artifacts_dir = Path(models_dir)
    model_path = artifacts_dir / f"efficientad_student_{connector_name}.pth"
    threshold_path = artifacts_dir / f"efficientad_threshold_{connector_name}.npy"
    spatial_mask_path = artifacts_dir / f"spatial_mask_{connector_name}.npy"

    # Check both required files before building any network, so a missing
    # artifact fails fast.
    if not model_path.exists():
        raise FileNotFoundError(f"Modello non trovato: {model_path}")

    if not threshold_path.exists():
        raise FileNotFoundError(f"Threshold non trovato: {threshold_path}")

    # Teacher is not restored from disk: it is always the same pre-trained network
    teacher = Teacher().to(device)
    teacher.eval()

    # Student weights are the per-connector artifact produced by training
    student = Student().to(device)
    student.load_state_dict(torch.load(model_path, map_location=device))
    student.eval()

    threshold = np.load(threshold_path)

    # The spatial mask is optional; callers receive None when it is absent
    spatial_mask_exists = spatial_mask_path.exists()

    print(f"‚úÖ Modello EfficientAD-M caricato per {connector_name}")
    print(f"  Threshold: {threshold:.6f}")
    print(f"  Spatial mask: {'‚úÖ' if spatial_mask_exists else '‚ùå'}")

    return teacher, student, threshold, spatial_mask_path if spatial_mask_exists else None


## Visualization Utilities


In [None]:
def visualize_spatial_mask(connector_name=DEBUG_CONNECTOR, models_dir="models"):
    """
    Plot the saved spatial weight mask of a connector.

    Reads ``spatial_mask_{connector_name}.npy`` from ``models_dir`` and shows
    it with matplotlib using the "viridis" colormap plus a colorbar, then
    prints the mask shape and its minimum/maximum weight. When the file does
    not exist, a hint is printed and nothing is plotted.

    Args:
        connector_name: Connector whose mask file should be displayed.
        models_dir: Directory containing the saved mask file.
    """
    mask_file = Path(models_dir) / f"spatial_mask_{connector_name}.npy"

    weights = np.load(mask_file) if mask_file.exists() else None
    if weights is None:
        print(f"‚ùå Spatial mask not found: {mask_file}")
        print(f"   Run training first to compute the mask.")
        return

    plt.figure(figsize=(10, 8))
    rendered = plt.imshow(weights, cmap="viridis")
    plt.colorbar(rendered, label="Weight")
    plt.title(f"Spatial Weight Mask for {connector_name}")
    plt.xlabel("Width (pixels)")
    plt.ylabel("Height (pixels)")
    plt.tight_layout()
    plt.show()

    lowest, highest = weights.min(), weights.max()
    print(f"‚úÖ Spatial mask visualized for {connector_name}")
    print(f"   Shape: {weights.shape}")
    print(f"   Weight range: [{lowest:.3f}, {highest:.3f}]")


In [None]:
def visualize_average_anomaly_map_for_connector(
    connector_name=DEBUG_CONNECTOR,
    csv_path="data/dataset.csv",
    max_ok_samples=50
):
    """
    Show the mean fused anomaly map over a connector's OK images.

    Steps:
    - load the connector's teacher/student pair and its spatial mask
      (a uniform all-ones mask is used if the mask cannot be loaded)
    - take the first ``max_ok_samples`` rows of ``csv_path`` whose label is
      "OK" and whose connector_name matches
    - for each image, compute the fused anomaly map with
      ``compute_fused_anomaly_map`` and bilinearly upsample it to
      (IMG_SIZE, IMG_SIZE)
    - average the upsampled maps and display the result as a "jet" heatmap

    Images that fail to process are reported and skipped. The function prints
    a message and returns early when the model cannot be loaded, when there
    are no OK rows, or when no map could be computed.

    Args:
        connector_name: Connector to analyse.
        csv_path: Dataset CSV with 'label', 'connector_name' and 'image_path' columns.
        max_ok_samples: Upper bound on the number of OK images used.
    """
    import torch.nn.functional as F

    print(f"üìä Computing average anomaly map for {connector_name}...")

    # Networks for this connector; the loader's threshold and mask path are not needed here
    try:
        teacher, student, _, _ = load_efficientad_model(connector_name, device=DEVICE)
    except Exception as load_err:
        print(f"‚ùå Error loading model: {load_err}")
        return

    full_res = (IMG_SIZE, IMG_SIZE)

    # Spatial weighting mask at image resolution, with a uniform fallback
    try:
        spatial_mask_fullres = load_spatial_weight_mask(
            connector_name,
            target_size=full_res,
            save_dir="models",
            device=DEVICE
        )
    except Exception as mask_err:
        print(f"‚ö†Ô∏è  Error loading spatial mask: {mask_err}")
        print(f"   Using uniform mask...")
        spatial_mask_fullres = torch.ones(1, 1, IMG_SIZE, IMG_SIZE, device=DEVICE)

    # Keep only this connector's OK samples
    samples = pd.read_csv(csv_path)
    is_ok = samples['label'] == 'OK'
    is_this_connector = samples['connector_name'] == connector_name
    ok_rows = samples[is_ok & is_this_connector].copy()

    if len(ok_rows) == 0:
        print(f"‚ùå No OK images found for {connector_name}")
        return

    ok_rows = ok_rows.head(max_ok_samples)
    print(f"   Processing {len(ok_rows)} OK images...")

    # Same resize + ImageNet normalization pipeline for every image
    preprocess = transforms.Compose([
        transforms.Resize(full_res),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    teacher.eval()
    student.eval()
    upsampled_maps = []

    progress = tqdm(ok_rows.iterrows(), total=len(ok_rows), desc="  Computing maps")
    with torch.no_grad():
        for _, sample in progress:
            image_path = sample['image_path']

            try:
                batch = preprocess(Image.open(image_path).convert('RGB')).unsqueeze(0).to(DEVICE)

                # Fused teacher/student discrepancy map at the reference feature resolution
                fused = compute_fused_anomaly_map(
                    teacher(batch),
                    student(batch),
                    spatial_mask_fullres,
                    feature_layers=FEATURE_LAYERS,
                    feature_layer_weights=FEATURE_LAYER_WEIGHTS,
                    device=DEVICE
                )

                # Bring the map to image resolution so all maps can be averaged pixel-wise
                fused_full = F.interpolate(
                    fused,
                    size=full_res,
                    mode='bilinear',
                    align_corners=False
                )
                upsampled_maps.append(fused_full.cpu().numpy())

            except Exception as img_err:
                # One bad image must not abort the whole average
                print(f"  ‚ö†Ô∏è  Error processing {image_path}: {img_err}")
                continue

    if len(upsampled_maps) == 0:
        print(f"‚ùå No valid maps computed")
        return

    # Mean over images, then drop the two leading singleton axes
    avg_map = np.stack(upsampled_maps, axis=0).mean(axis=0)[0, 0]

    plt.figure(figsize=(10, 8))
    rendered = plt.imshow(avg_map, cmap="jet")
    plt.colorbar(rendered, label="Average Anomaly Score")
    plt.title(f"Average Anomaly Map for {connector_name} (OK images, n={len(upsampled_maps)})")
    plt.xlabel("Width (pixels)")
    plt.ylabel("Height (pixels)")
    plt.tight_layout()
    plt.show()

    print(f"‚úÖ Average anomaly map computed and visualized")
    print(f"   Map shape: {avg_map.shape}")
    print(f"   Score range: [{avg_map.min():.6f}, {avg_map.max():.6f}]")


In [None]:
def visualize_example_ok_ko_heatmaps(
    connector_name=DEBUG_CONNECTOR,
    csv_path="data/dataset.csv",
    ko_labels=("KO",),
    models_dir="models"
):
    """
    Show anomaly heatmaps for one OK and one KO example of a connector.

    The first row of ``csv_path`` labelled "OK" for this connector is always
    used. The first row whose label is in ``ko_labels`` is used as the KO
    example; with the default ``("KO",)`` any other label (for instance
    occlusion labels) is therefore never selected. If no KO row exists, only
    the OK example is shown.

    For each selected image the function:
    - loads it as RGB
    - computes the fused anomaly map with ``compute_fused_anomaly_map``
    - upsamples the map bilinearly to the original image resolution
    - min-max normalizes it to [0, 1]
    - draws three panels: original image, heatmap alone, and the heatmap
      overlaid on the image with alpha 0.5

    Args:
        connector_name: Connector to visualize.
        csv_path: Dataset CSV with 'label', 'connector_name' and 'image_path' columns.
        ko_labels: Labels accepted as the KO example.
        models_dir: Accepted for interface symmetry; NOTE(review): the body
            does not read it, the mask is loaded with save_dir="models".
    """
    import torch.nn.functional as F

    print(f"üìä Visualizing OK/KO heatmaps for {connector_name}...")

    # Networks for this connector; the loader's threshold and mask path are not needed here
    try:
        teacher, student, _, _ = load_efficientad_model(connector_name, device=DEVICE)
    except Exception as load_err:
        print(f"‚ùå Error loading model: {load_err}")
        return

    # Spatial weighting mask at model input resolution, with a uniform fallback
    try:
        spatial_mask_fullres = load_spatial_weight_mask(
            connector_name,
            target_size=(IMG_SIZE, IMG_SIZE),
            save_dir="models",
            device=DEVICE
        )
    except Exception as mask_err:
        print(f"‚ö†Ô∏è  Error loading spatial mask: {mask_err}")
        print(f"   Using uniform mask...")
        spatial_mask_fullres = torch.ones(1, 1, IMG_SIZE, IMG_SIZE, device=DEVICE)

    samples = pd.read_csv(csv_path)
    connector_rows = samples[samples['connector_name'] == connector_name].copy()

    # The OK example is mandatory
    ok_rows = connector_rows[connector_rows['label'] == 'OK']
    if len(ok_rows) == 0:
        print(f"‚ùå No OK images found for {connector_name}")
        return
    ok_path = ok_rows.iloc[0]['image_path']

    # The KO example is optional
    ko_rows = connector_rows[connector_rows['label'].isin(ko_labels)]
    if len(ko_rows) > 0:
        ko_path = ko_rows.iloc[0]['image_path']
    else:
        print(f"‚ö†Ô∏è  No KO images found for {connector_name}, showing only OK")
        ko_path = None

    preprocess = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    teacher.eval()
    student.eval()

    def image_and_heatmap(image_path):
        """Return (RGB pixel array, anomaly map normalized to [0, 1]) for one image."""
        picture = Image.open(image_path).convert('RGB')
        pixels = np.array(picture)
        batch = preprocess(picture).unsqueeze(0).to(DEVICE)

        with torch.no_grad():
            fused = compute_fused_anomaly_map(
                teacher(batch),
                student(batch),
                spatial_mask_fullres,
                feature_layers=FEATURE_LAYERS,
                feature_layer_weights=FEATURE_LAYER_WEIGHTS,
                device=DEVICE
            )

            # Resize the map to the original (height, width) so it can be overlaid
            height, width = pixels.shape[:2]
            fused_full = F.interpolate(
                fused,
                size=(height, width),
                mode='bilinear',
                align_corners=False
            )

            raw_map = fused_full[0, 0].cpu().numpy()
            # Epsilon keeps the division defined for a constant map
            scaled_map = (raw_map - raw_map.min()) / (raw_map.max() - raw_map.min() + 1e-8)

        return pixels, scaled_map

    # Each panel row: (image, heatmap, image title, heatmap title, overlay title)
    print(f"  Processing OK image: {Path(ok_path).name}")
    ok_img, ok_map = image_and_heatmap(ok_path)
    panel_rows = [
        (ok_img, ok_map, f"OK Image: {Path(ok_path).name}", "OK Anomaly Heatmap", "OK Overlay"),
    ]

    if ko_path:
        print(f"  Processing KO image: {Path(ko_path).name}")
        ko_img, ko_map = image_and_heatmap(ko_path)
        panel_rows.append(
            (ko_img, ko_map, f"KO Image: {Path(ko_path).name}", "KO Anomaly Heatmap", "KO Overlay")
        )

    # One figure row per example; squeeze=False keeps axes 2-D even for a single row
    n_rows = len(panel_rows)
    fig, axes = plt.subplots(n_rows, 3, figsize=(18, 6 * n_rows), squeeze=False)

    for axis_row, (photo, heat, photo_title, heat_title, overlay_title) in zip(axes, panel_rows):
        photo_ax, heat_ax, overlay_ax = axis_row

        photo_ax.imshow(photo)
        photo_ax.set_title(photo_title)
        photo_ax.axis('off')

        heat_ax.imshow(heat, cmap="jet")
        heat_ax.set_title(heat_title)
        heat_ax.axis('off')

        overlay_ax.imshow(photo)
        overlay_ax.imshow(heat, cmap="jet", alpha=0.5)
        overlay_ax.set_title(overlay_title)
        overlay_ax.axis('off')

    plt.suptitle(f"Anomaly Heatmaps for {connector_name}", fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

    print(f"‚úÖ Heatmaps visualized for {connector_name}")


## Example Debug/Visualization Calls

The following functions can be called manually for debugging and interpretation:


In [None]:
# Example debug calls (run manually after training):
# 
# # Visualize spatial mask for DEBUG_CONNECTOR
# visualize_spatial_mask()
#
# # Visualize average anomaly map over OK images
# visualize_average_anomaly_map_for_connector()
#
# # Visualize example OK/KO heatmaps
# visualize_example_ok_ko_heatmaps()
#
# Note: Change DEBUG_CONNECTOR at the top if you want to visualize a different connector
