In [13]:
import os
import logging
import time
import json
from pathlib import Path
import glob
import warnings
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
import torch.backends.cudnn as cudnn
from torch.nn import functional as F
import torchvision.transforms as transforms
from tqdm import tqdm
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
# Setup logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Suppress specific FutureWarning for GradScaler
warnings.filterwarnings("ignore", category=FutureWarning)
# Check for GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")
class Config:
    # Dataset parameters
    dataset_path = "./mvtec_anomaly_detection"
    categories = [
        "bottle", "cable", "capsule", "carpet", "grid",
        "hazelnut", "leather", "metal_nut", "pill", "screw",
        "tile", "toothbrush", "transistor", "wood", "zipper"
    ]
    material_properties = {
        "bottle": ["transparent", "rigid", "container", "curved surface"],
        "cable": ["flexible", "elongated", "connector", "insulated"],
        "capsule": ["small", "cylindrical", "pharmaceutical", "colorful"],
        "carpet": ["textured", "flat", "fabric", "patterned"],
        "grid": ["regular pattern", "structured", "geometric", "repetitive"],
        "hazelnut": ["organic", "natural", "edible", "oval shaped"],
        "leather": ["textured", "flexible", "natural material", "durable"],
        "metal_nut": ["metallic", "threaded", "hardware", "circular"],
        "pill": ["medical", "small", "uniform", "pharmaceutical"],
        "screw": ["metallic", "threaded", "hardware", "cylindrical"],
        "tile": ["flat", "hard", "ceramic", "uniform"],
        "toothbrush": ["plastic", "bristled", "handheld", "hygienic"],
        "transistor": ["electronic", "semiconductor", "small", "technical"],
        "wood": ["natural", "grain patterned", "organic", "fibrous"],
        "zipper": ["metal", "plastic", "fastener", "interlocking"]
    }
    generic_anomaly_types = [
        "scratched", "broken", "deformed", "contaminated", 
        "discolored", "cracked", "damaged", "misshapen",
        "stained", "perforated", "corroded", "bent"
    ]
    specific_anomaly_types = {
        "bottle": ["dented", "leaking", "chipped", "malformed neck"],
        "cable": ["exposed wire", "frayed", "kinked", "connector damage"],
        "capsule": ["crushed", "empty", "incomplete filling", "color variation"],
        "carpet": ["worn", "torn", "faded", "stained"],
        "grid": ["distorted", "misaligned", "broken lines", "irregular spacing"],
        "hazelnut": ["moldy", "shriveled", "discolored", "pest damaged"],
        "leather": ["scratched", "stained", "wrinkled", "uneven texture"],
        "metal_nut": ["rusted", "cracked", "thread damage", "deformed"],
        "pill": ["chipped", "split", "faded", "stained"],
        "screw": ["stripped head", "bent", "rusted", "shortened"],
        "tile": ["chipped", "cracked", "rough surface", "uneven color"],
        "toothbrush": ["bent bristles", "missing bristles", "discolored", "malformed handle"],
        "transistor": ["bent pins", "cracked casing", "missing components", "burn marks"],
        "wood": ["splintered", "rotted", "knotted", "warped"],
        "zipper": ["missing teeth", "detached", "stuck", "bent teeth"]
    }
    target_size = (224, 224)  # Image size for model input
    batch_size = 32
    num_epochs = 10
    learning_rate = 1e-4
    weight_decay = 1e-5
    clip_model_name = "openai/clip-vit-base-patch32"
    embedding_dim = 512
    margin = 0.5
    temperature = 0.07
    anomaly_threshold = 0.5
    vis_save_dir = "visualization_results"
    checkpoint_dir = "checkpoints"
    normal_prompt_templates = [
        "a photo of normal {category}",
        "a normal {category} without defects",
        "a pristine {category} in perfect condition",
        "a flawless {category} with no anomalies"
    ]
    anomaly_prompt_templates = [
        "a photo of {category} with {anomaly_type}",
        "a defective {category} showing {anomaly_type}",
        "a {category} that has {anomaly_type} areas",
        "a damaged {category} with visible {anomaly_type}"
    ]
    material_prompt_templates = [
        "a {material_property} {category} in normal condition",
        "a {material_property} {category} without defects"
    ]
    material_anomaly_templates = [
        "a {material_property} {category} with {anomaly_type}",
        "a {material_property} {category} showing {anomaly_type}"
    ]
    use_fp16 = True  # Use mixed precision training
    num_workers = 0  # Set to 0 for debugging DataLoader issues
class MVTecDataset(Dataset):
    def __init__(self, root_dir, category, is_train=True, transform=None):
        self.root_dir = root_dir
        self.category = category
        self.is_train = is_train
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.mask_paths = []
        self.defect_types = []
        logger.info(f"{'Training' if is_train else 'Testing'} dataset for category: {category}")
        if is_train:
            train_good_path = os.path.join(root_dir, category, "train", "good")
            logger.info(f"Searching for good samples in: {train_good_path}")
            self.image_paths = sorted(glob.glob(os.path.join(train_good_path, "*.png")))
            logger.info(f"Found {len(self.image_paths)} training images.")
            if len(self.image_paths) == 0:
                raise ValueError(f"No training images found for category '{category}'")
            self.labels = [0] * len(self.image_paths)
            self.mask_paths = [None] * len(self.image_paths)
            self.defect_types = ["none"] * len(self.image_paths)
        else:
            test_good_path = os.path.join(root_dir, category, "test", "good")
            logger.info(f"Searching for test good samples in: {test_good_path}")
            good_paths = sorted(glob.glob(os.path.join(test_good_path, "*.png")))
            self.image_paths.extend(good_paths)
            self.labels.extend([0] * len(good_paths))
            self.defect_types.extend(["none"] * len(good_paths))
            defect_types = [
                d for d in os.listdir(os.path.join(root_dir, category, "test"))
                if d != "good" and os.path.isdir(os.path.join(root_dir, category, "test", d))
            ]
            logger.info(f"Found defect types for {category}: {defect_types}")
            for defect_type in defect_types:
                defect_path = os.path.join(root_dir, category, "test", defect_type)
                defect_paths = sorted(glob.glob(os.path.join(defect_path, "*.png")))
                logger.info(f"Found {len(defect_paths)} defective samples for {defect_type}")
                self.image_paths.extend(defect_paths)
                self.labels.extend([1] * len(defect_paths))
                self.defect_types.extend([defect_type] * len(defect_paths))
                for img_path in defect_paths:
                    img_filename = os.path.basename(img_path)
                    mask_path = os.path.join(
                        root_dir, category, "ground_truth", defect_type, img_filename
                    )
                    mask_with_suffix = os.path.splitext(mask_path)[0] + "_mask.png"
                    if os.path.exists(mask_path):
                        self.mask_paths.append(mask_path)
                    elif os.path.exists(mask_with_suffix):
                        self.mask_paths.append(mask_with_suffix)
                    else:
                        self.mask_paths.append(None)
                        logger.warning(f"No mask found for {img_path}")
    def __len__(self):
        return len(self.image_paths)
    def __getitem__(self, idx):
        try:
            image_path = self.image_paths[idx]
            image = Image.open(image_path).convert("RGB")
            assert self.transform is not None, "Transform must be provided!"
            image = self.transform(image)
            label = self.labels[idx]
            defect_type = self.defect_types[idx]
            mask_path = self.mask_paths[idx]
            mask = None
            if mask_path and os.path.exists(mask_path):
                mask = Image.open(mask_path).convert("L")
                mask = transforms.ToTensor()(mask)
            else:
                mask = torch.zeros((1, *Config.target_size))  # Placeholder
            return {
                "image": image,
                "label": label,
                "mask": mask,
                "image_path": image_path,
                "defect_type": defect_type
            }
        except Exception as e:
            logger.error(f"Error loading sample {idx} from {self.image_paths[idx]}: {str(e)}")
            raise
class PromptBasedAnomalyDetector(nn.Module):
    def __init__(self, clip_model_name="openai/clip-vit-base-patch32"):
        super(PromptBasedAnomalyDetector, self).__init__()
        self.clip_model = CLIPModel.from_pretrained(clip_model_name)
        self.tokenizer = CLIPTokenizer.from_pretrained(clip_model_name)
        self.text_encoder = self.clip_model.text_model
        self.vision_encoder = self.clip_model.vision_model
        self.text_proj = self.clip_model.text_projection
        self.vision_proj = self.clip_model.visual_projection
        for param in self.clip_model.parameters():
            param.requires_grad = False
        num_normal_templates = len(Config.normal_prompt_templates)
        num_anomaly_templates = len(Config.anomaly_prompt_templates)
        num_material_templates = len(Config.material_prompt_templates)
        num_material_anomaly_templates = len(Config.material_anomaly_templates)
        self.normal_prompt_embeddings = nn.Parameter(
            torch.randn(num_normal_templates, self.text_encoder.config.hidden_size, device=device)
        )
        self.anomaly_prompt_embeddings = nn.Parameter(
            torch.randn(num_anomaly_templates, self.text_encoder.config.hidden_size, device=device)
        )
        self.material_prompt_embeddings = nn.Parameter(
            torch.randn(num_material_templates, self.text_encoder.config.hidden_size, device=device)
        )
        self.material_anomaly_embeddings = nn.Parameter(
            torch.randn(num_material_anomaly_templates, self.text_encoder.config.hidden_size, device=device)
        )
        self.normal_weights = nn.Parameter(torch.ones(num_normal_templates, device=device))
        self.anomaly_weights = nn.Parameter(torch.ones(num_anomaly_templates, device=device))
        self.material_weights = nn.Parameter(torch.ones(num_material_templates, device=device))
        self.material_anomaly_weights = nn.Parameter(torch.ones(num_material_anomaly_templates, device=device))
        self.vit_layer = self.vision_encoder
        self.anomaly_localization_head = nn.Sequential(
            nn.Conv2d(self.vision_encoder.config.hidden_size, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 1, kernel_size=1)
        ).to(device)
    def encode_text(self, text):
        text_inputs = self.tokenizer(text, padding=True, return_tensors="pt").to(device)
        text_outputs = self.text_encoder(**text_inputs)
        text_embeddings = text_outputs.last_hidden_state[:, 0, :]
        text_embeddings = self.text_proj(text_embeddings)
        return text_embeddings
    def encode_image(self, image):
        vision_outputs = self.vision_encoder(image)
        image_embeddings = vision_outputs.last_hidden_state[:, 0, :]
        image_embeddings = self.vision_proj(image_embeddings)
        patch_embeddings = vision_outputs.last_hidden_state[:, 1:, :]
        batch_size = image.shape[0]
        patch_size = 32
        h = w = image.shape[2] // patch_size
        patch_features = patch_embeddings.reshape(batch_size, h, w, -1).permute(0, 3, 1, 2)
        return image_embeddings, patch_features
    def get_category_specific_anomalies(self, category):
        if category in Config.specific_anomaly_types:
            return Config.specific_anomaly_types[category] + Config.generic_anomaly_types
        return Config.generic_anomaly_types
    def get_material_properties(self, category):
        return Config.material_properties.get(category, [])
    def generate_prompts(self, category_name):
        normal_embeddings = []
        normal_weights = F.softmax(self.normal_weights, dim=0)
        for i, template in enumerate(Config.normal_prompt_templates):
            prompt = template.format(category=category_name)
            embedding = self.encode_text([prompt])
            normal_embeddings.append(embedding + self.normal_prompt_embeddings[i])
        material_weights = F.softmax(self.material_weights, dim=0)
        for i, template in enumerate(Config.material_prompt_templates):
            for prop in self.get_material_properties(category_name):
                prompt = template.format(category=category_name, material_property=prop)
                embedding = self.encode_text([prompt])
                normal_embeddings.append(embedding + self.material_prompt_embeddings[i])
        anomaly_embeddings = []
        anomaly_weights = F.softmax(self.anomaly_weights, dim=0)
        anomaly_types = self.get_category_specific_anomalies(category_name)
        for i, template in enumerate(Config.anomaly_prompt_templates):
            for anomaly_type in anomaly_types:
                prompt = template.format(category=category_name, anomaly_type=anomaly_type)
                embedding = self.encode_text([prompt])
                anomaly_embeddings.append(embedding + self.anomaly_prompt_embeddings[i])
        material_anomaly_weights = F.softmax(self.material_anomaly_weights, dim=0)
        for i, template in enumerate(Config.material_anomaly_templates):
            for prop in self.get_material_properties(category_name):
                for anomaly_type in anomaly_types[:2]:
                    prompt = template.format(category=category_name, material_property=prop, anomaly_type=anomaly_type)
                    embedding = self.encode_text([prompt])
                    anomaly_embeddings.append(embedding + self.material_anomaly_embeddings[i])
        normal_embedding = torch.cat(normal_embeddings).mean(dim=0, keepdim=True) if normal_embeddings else self.encode_text([f"a photo of normal {category_name}"])
        anomaly_embedding = torch.cat(anomaly_embeddings).mean(dim=0, keepdim=True) if anomaly_embeddings else self.encode_text([f"a photo of {category_name} with anomaly"])
        return normal_embedding, anomaly_embedding
    def forward(self, image, category_name):
        image_embedding, patch_features = self.encode_image(image)
        normal_embedding, anomaly_embedding = self.generate_prompts(category_name)
        normal_similarity = F.cosine_similarity(image_embedding, normal_embedding)
        anomaly_similarity = F.cosine_similarity(image_embedding, anomaly_embedding)
        anomaly_map = self.anomaly_localization_head(patch_features)
        anomaly_map = F.interpolate(anomaly_map, size=Config.target_size, mode='bilinear', align_corners=False)
        return {
            'normal_similarity': normal_similarity,
            'anomaly_similarity': anomaly_similarity,
            'anomaly_score': anomaly_similarity - normal_similarity,
            'anomaly_map': anomaly_map,
            'image_embedding': image_embedding,
            'normal_embedding': normal_embedding,
            'anomaly_embedding': anomaly_embedding,
            'patch_features': patch_features
        }
def train_model(model, category, train_loader, val_loader=None, num_epochs=10):
    optimizer = optim.AdamW(model.parameters(), lr=Config.learning_rate, weight_decay=Config.weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    scaler = GradScaler() if Config.use_fp16 else None
    best_val_loss = float('inf')
    augmentation = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.1, contrast=0.1)
    ])
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for batch in progress_bar:
            images = batch['image'].to(device, non_blocking=True)
            augmented_images = augmentation(images)
            with autocast(enabled=Config.use_fp16):
                outputs = model(images, category)
                aug_outputs = model(augmented_images, category)
                contrastive_loss = F.relu(Config.margin - (outputs['anomaly_similarity'] - outputs['normal_similarity'])).mean()
                consistency_loss = F.mse_loss(outputs['image_embedding'], aug_outputs['image_embedding'])
                loss = contrastive_loss + 0.5 * consistency_loss
            if Config.use_fp16:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()
            optimizer.zero_grad()
            train_loss += loss.item()
            progress_bar.set_postfix({"loss": loss.item(), "contrastive": contrastive_loss.item(), "consistency": consistency_loss.item()})
        avg_train_loss = train_loss / len(train_loader)
        logger.info(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}")
        scheduler.step()
        if val_loader:
            val_loss = validate_model(model, category, val_loader)
            logger.info(f"Epoch {epoch+1}/{num_epochs}, Validation Loss: {val_loss:.4f}")
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                save_checkpoint(model, optimizer, epoch, val_loss, f"{Config.checkpoint_dir}/{category}_best_model.pth")
        else:
            save_checkpoint(model, optimizer, epoch, avg_train_loss, f"{Config.checkpoint_dir}/{category}_epoch_{epoch+1}.pth")
    save_checkpoint(model, optimizer, num_epochs - 1, avg_train_loss, f"{Config.checkpoint_dir}/{category}_final_model.pth")
    return model
def validate_model(model, category, val_loader):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            images = batch['image'].to(device, non_blocking=True)
            outputs = model(images, category)
            loss = F.relu(Config.margin - (outputs['anomaly_similarity'] - outputs['normal_similarity'])).mean()
            total_loss += loss.item()
    return total_loss / len(val_loader)
def save_checkpoint(model, optimizer, epoch, loss, filename):
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'loss': loss
    }, filename)
    logger.info(f"Checkpoint saved to {filename}")
def evaluate_model(model, category, test_loader):
    model.eval()
    image_scores = []
    image_labels = []
    pixel_scores = []
    pixel_labels = []
    defect_specific_scores = {}
    os.makedirs(os.path.join(Config.vis_save_dir, category), exist_ok=True)
    with torch.no_grad():
        for i, batch in enumerate(tqdm(test_loader, desc=f"Evaluating {category}")):
            images = batch['image'].to(device, non_blocking=True)
            labels = batch['label'].numpy()
            masks = batch['mask'].numpy()
            image_paths = batch['image_path']
            defect_types = batch['defect_type']
            outputs = model(images, category)
            anomaly_scores = outputs['anomaly_score'].cpu().numpy()
            anomaly_maps = outputs['anomaly_map'].squeeze(1).cpu().numpy()
            image_scores.extend(anomaly_scores)
            image_labels.extend(labels)
            for idx, defect_type in enumerate(defect_types):
                if defect_type not in defect_specific_scores:
                    defect_specific_scores[defect_type] = {"scores": [], "labels": []}
                defect_specific_scores[defect_type]["scores"].append(anomaly_scores[idx])
                defect_specific_scores[defect_type]["labels"].append(labels[idx])
            for b in range(images.shape[0]):
                if labels[b] == 1 and masks[b].sum() > 0:
                    pixel_scores.append(anomaly_maps[b].flatten())
                    pixel_labels.append(masks[b].flatten())
    from sklearn.metrics import roc_auc_score
    image_roc_auc = roc_auc_score(image_labels, image_scores)
    pixel_roc_auc = roc_auc_score([item for sublist in pixel_labels for item in sublist],
                                  [item for sublist in pixel_scores for item in sublist])
    logger.info(f"Image-level ROC AUC: {image_roc_auc:.4f}")
    logger.info(f"Pixel-level ROC AUC: {pixel_roc_auc:.4f}")
    defect_specific_aucs = {}
    for defect_type, data in defect_specific_scores.items():
        if defect_type != "none" and len(set(data["labels"])) > 1:
            try:
                defect_specific_aucs[defect_type] = roc_auc_score(data["labels"], data["scores"])
                logger.info(f"  {defect_type} AUC: {defect_specific_aucs[defect_type]:.4f}")
            except:
                logger.warning(f"Could not calculate AUC for defect type {defect_type}")
    metrics = {
        'category': category,
        'image_roc_auc': float(image_roc_auc),
        'pixel_roc_auc': float(pixel_roc_auc),
        'defect_specific_aucs': {k: float(v) for k, v in defect_specific_aucs.items()}
    }
    with open(os.path.join(Config.vis_save_dir, f"{category}_metrics.json"), 'w') as f:
        json.dump(metrics, f, indent=4)
    return metrics
def get_transforms():
    train_transform = transforms.Compose([
        transforms.Resize(Config.target_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    test_transform = transforms.Compose([
        transforms.Resize(Config.target_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return train_transform, test_transform
def main():
    start_time = time.time()
    logger.info("Starting improved anomaly detection pipeline...")
    if not torch.cuda.is_available():
        logger.warning("CUDA is not available. Using CPU, which will be much slower!")
    all_metrics = {}
    os.makedirs(Config.vis_save_dir, exist_ok=True)
    os.makedirs(Config.checkpoint_dir, exist_ok=True)
    for category in Config.categories:
        logger.info(f"Processing category: {category}")
        train_transform, test_transform = get_transforms()
        try:
            train_dataset = MVTecDataset(Config.dataset_path, category, is_train=True, transform=train_transform)
            test_dataset = MVTecDataset(Config.dataset_path, category, is_train=False, transform=test_transform)
            train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True,
                                      num_workers=0, pin_memory=True, persistent_workers=False)
            test_loader = DataLoader(test_dataset, batch_size=Config.batch_size, shuffle=False,
                                     num_workers=0, pin_memory=True, persistent_workers=False)
            model = PromptBasedAnomalyDetector(Config.clip_model_name).to(device)
            logger.info(f"Model device: {next(model.parameters()).device}")
            logger.info(f"Training model for category: {category}")
            model = train_model(model, category, train_loader)
            logger.info(f"Evaluating model for category: {category}")
            metrics = evaluate_model(model, category, test_loader)
            all_metrics[category] = metrics
        except Exception as e:
            logger.error(f"Failed to process category {category}: {str(e)}")
            continue
    overall_metrics = {
        "overall_avg_image_auc": sum(m['image_roc_auc'] for m in all_metrics.values()) / len(all_metrics) if all_metrics else 0,
        "overall_avg_pixel_auc": sum(m['pixel_roc_auc'] for m in all_metrics.values()) / len(all_metrics) if all_metrics else 0,
        "categories": list(all_metrics.keys()),
        "category_performance": {cat: {"image_auc": metrics["image_roc_auc"], "pixel_auc": metrics["pixel_roc_auc"]} for cat, metrics in all_metrics.items()}
    }
    defect_type_metrics = {}
    for category, metrics in all_metrics.items():
        if "defect_specific_aucs" in metrics:
            for defect_type, auc in metrics["defect_specific_aucs"].items():
                if defect_type not in defect_type_metrics:
                    defect_type_metrics[defect_type] = []
                defect_type_metrics[defect_type].append(auc)
    overall_metrics["defect_type_performance"] = {
        defect_type: sum(aucs)/len(aucs) for defect_type, aucs in defect_type_metrics.items() if aucs
    }
    with open(os.path.join(Config.vis_save_dir, "overall_metrics.json"), 'w') as f:
        json.dump(overall_metrics, f, indent=4)
    elapsed_time = (time.time() - start_time) / 60.0
    logger.info(f"Pipeline completed in {elapsed_time:.2f} minutes.")
if __name__ == "__main__":
    if torch.cuda.device_count() > 1:
        logger.info(f"Using {torch.cuda.device_count()} GPUs!")
    cudnn.benchmark = True
    main()

2025-05-12 07:35:18,821 - __main__ - INFO - Using device: cuda
2025-05-12 07:35:18,826 - __main__ - INFO - Starting improved anomaly detection pipeline...
2025-05-12 07:35:18,827 - __main__ - INFO - Processing category: bottle
2025-05-12 07:35:18,828 - __main__ - INFO - Training dataset for category: bottle
2025-05-12 07:35:18,829 - __main__ - INFO - Searching for good samples in: ./mvtec_anomaly_detection\bottle\train\good
2025-05-12 07:35:18,831 - __main__ - INFO - Found 209 training images.
2025-05-12 07:35:18,832 - __main__ - INFO - Testing dataset for category: bottle
2025-05-12 07:35:18,832 - __main__ - INFO - Searching for test good samples in: ./mvtec_anomaly_detection\bottle\test\good
2025-05-12 07:35:18,834 - __main__ - INFO - Found defect types for bottle: ['broken_large', 'broken_small', 'contamination']
2025-05-12 07:35:18,835 - __main__ - INFO - Found 20 defective samples for broken_large
2025-05-12 07:35:18,837 - __main__ - INFO - Found 22 defective samples for broken_sm

KeyboardInterrupt: 

In [23]:
import os
import logging
import time
import json
from pathlib import Path
import glob
import warnings
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
import torch.backends.cudnn as cudnn
from torch.nn import functional as F
import torchvision.transforms as transforms
from tqdm import tqdm
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score
import numpy as np

# Setup logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", category=FutureWarning)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")

class Config:
    # Dataset parameters
    dataset_path = "./mvtec_anomaly_detection"
    categories = [
        "bottle", "cable", "capsule", "carpet", "grid",
        "hazelnut", "leather", "metal_nut", "pill", "screw",
        "tile", "toothbrush", "transistor", "wood", "zipper"
    ]
    material_properties = {
        "bottle": ["transparent", "rigid", "container", "curved surface"],
        "cable": ["flexible", "elongated", "connector", "insulated"],
        "capsule": ["small", "cylindrical", "pharmaceutical", "colorful"],
        "carpet": ["textured", "flat", "fabric", "patterned"],
        "grid": ["regular pattern", "structured", "geometric", "repetitive"],
        "hazelnut": ["organic", "natural", "edible", "oval shaped"],
        "leather": ["textured", "flexible", "natural material", "durable"],
        "metal_nut": ["metallic", "threaded", "hardware", "circular"],
        "pill": ["medical", "small", "uniform", "pharmaceutical"],
        "screw": ["metallic", "threaded", "hardware", "cylindrical"],
        "tile": ["flat", "hard", "ceramic", "uniform"],
        "toothbrush": ["plastic", "bristled", "handheld", "hygienic"],
        "transistor": ["electronic", "semiconductor", "small", "technical"],
        "wood": ["natural", "grain patterned", "organic", "fibrous"],
        "zipper": ["metal", "plastic", "fastener", "interlocking"]
    }
    generic_anomaly_types = [
        "scratched", "broken", "deformed", "contaminated",
        "discolored", "cracked", "damaged", "misshapen",
        "stained", "perforated", "corroded", "bent"
    ]
    specific_anomaly_types = {
        "bottle": ["dented", "leaking", "chipped", "malformed neck"],
        "cable": ["exposed wire", "frayed", "kinked", "connector damage"],
        "capsule": ["crushed", "empty", "incomplete filling", "color variation"],
        "carpet": ["worn", "torn", "faded", "stained"],
        "grid": ["distorted", "misaligned", "broken lines", "irregular spacing"],
        "hazelnut": ["moldy", "shriveled", "discolored", "pest damaged"],
        "leather": ["scratched", "stained", "wrinkled", "uneven texture"],
        "metal_nut": ["rusted", "cracked", "thread damage", "deformed"],
        "pill": ["chipped", "split", "faded", "stained"],
        "screw": ["stripped head", "bent", "rusted", "shortened"],
        "tile": ["chipped", "cracked", "rough surface", "uneven color"],
        "toothbrush": ["bent bristles", "missing bristles", "discolored", "malformed handle"],
        "transistor": ["bent pins", "cracked casing", "missing components", "burn marks"],
        "wood": ["splintered", "rotted", "knotted", "warped"],
        "zipper": ["missing teeth", "detached", "stuck", "bent teeth"]
    }
    target_size = (224, 224)
    batch_size = 32
    num_epochs = 10
    learning_rate = 1e-4
    weight_decay = 1e-5
    clip_model_name = "openai/clip-vit-base-patch32"
    embedding_dim = 512
    margin = 0.5
    temperature = 0.07
    anomaly_threshold = 0.5
    vis_save_dir = "visualization_results"
    checkpoint_dir = "checkpoints"

    normal_prompt_templates = [
        "a photo of normal {category}",
        "a normal {category} without defects",
        "a pristine {category} in perfect condition",
        "a flawless {category} with no anomalies"
    ]

    anomaly_prompt_templates = [
        "a photo of {category} with {anomaly_type}",
        "a defective {category} showing {anomaly_type}",
        "a {category} that has {anomaly_type} areas",
        "a damaged {category} with visible {anomaly_type}"
    ]

    use_fp16 = True
    num_workers = 0


class MVTecDataset(Dataset):
    def __init__(self, root_dir, category, is_train=True, transform=None):
        self.root_dir = root_dir
        self.category = category
        self.is_train = is_train
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.mask_paths = []
        self.defect_types = []

        logger.info(f"{'Training' if is_train else 'Testing'} dataset for category: {category}")
        if is_train:
            train_good_path = os.path.join(root_dir, category, "train", "good")
            logger.info(f"Searching for good samples in: {train_good_path}")
            self.image_paths = sorted(glob.glob(os.path.join(train_good_path, "*.png")))
            logger.info(f"Found {len(self.image_paths)} training images.")
            if len(self.image_paths) == 0:
                raise ValueError(f"No training images found for category '{category}'")
            self.labels = [0] * len(self.image_paths)
            self.mask_paths = [None] * len(self.image_paths)
            self.defect_types = ["none"] * len(self.image_paths)
        else:
            test_good_path = os.path.join(root_dir, category, "test", "good")
            logger.info(f"Searching for test good samples in: {test_good_path}")
            good_paths = sorted(glob.glob(os.path.join(test_good_path, "*.png")))
            self.image_paths.extend(good_paths)
            self.labels.extend([0] * len(good_paths))
            self.defect_types.extend(["none"] * len(good_paths))

            defect_types = [
                d for d in os.listdir(os.path.join(root_dir, category, "test"))
                if d != "good" and os.path.isdir(os.path.join(root_dir, category, "test", d))
            ]
            logger.info(f"Found defect types for {category}: {defect_types}")

            for defect_type in defect_types:
                defect_path = os.path.join(root_dir, category, "test", defect_type)
                defect_paths = sorted(glob.glob(os.path.join(defect_path, "*.png")))
                logger.info(f"Found {len(defect_paths)} defective samples for {defect_type}")
                self.image_paths.extend(defect_paths)
                self.labels.extend([1] * len(defect_paths))
                self.defect_types.extend([defect_type] * len(defect_paths))

                for img_path in defect_paths:
                    img_filename = os.path.basename(img_path)
                    mask_path = os.path.join(
                        root_dir, category, "ground_truth", defect_type, img_filename
                    )
                    mask_with_suffix = os.path.splitext(mask_path)[0] + "_mask.png"
                    if os.path.exists(mask_path):
                        self.mask_paths.append(mask_path)
                    elif os.path.exists(mask_with_suffix):
                        self.mask_paths.append(mask_with_suffix)
                    else:
                        self.mask_paths.append(None)
                        logger.warning(f"No mask found for {img_path}. Using placeholder.")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        try:
            image_path = self.image_paths[idx]
            image = Image.open(image_path).convert("RGB")
            assert self.transform is not None, "Transform must be provided!"
            image = self.transform(image)
            label = self.labels[idx]
            defect_type = self.defect_types[idx]
            mask_path = self.mask_paths[idx]
            mask = None
            if mask_path and os.path.exists(mask_path):
                mask = Image.open(mask_path).convert("L")
                mask = transforms.Resize(Config.target_size)(mask)
                mask = transforms.ToTensor()(mask)
            else:
                mask = torch.zeros((1, *Config.target_size))  # Placeholder
            return {
                "image": image,
                "label": label,
                "mask": mask,
                "image_path": image_path,
                "defect_type": defect_type
            }
        except Exception as e:
            logger.error(f"Error loading sample {idx} from {self.image_paths[idx]}: {str(e)}")
            raise


class PromptLearner(nn.Module):
    def __init__(self, dim, prompt_len=5):
        super().__init__()
        self.prefix = nn.Parameter(torch.randn(1, prompt_len, dim))
        self.suffix = nn.Parameter(torch.randn(1, prompt_len, dim))

    def forward(self, text_emb):
        prefix = self.prefix.expand(text_emb.shape[0], -1, -1)
        suffix = self.suffix.expand(text_emb.shape[0], -1, -1)
        return torch.cat([prefix, text_emb, suffix], dim=1)


class PromptBasedAnomalyDetector(nn.Module):
    def __init__(self, clip_model_name="openai/clip-vit-base-patch32"):
        super(PromptBasedAnomalyDetector, self).__init__()
        self.clip_model = CLIPModel.from_pretrained(clip_model_name)
        self.tokenizer = CLIPTokenizer.from_pretrained(clip_model_name)
        self.text_encoder = self.clip_model.text_model
        self.vision_encoder = self.clip_model.vision_model
        self.text_proj = self.clip_model.text_projection
        self.vision_proj = self.clip_model.visual_projection

        for param in self.clip_model.parameters():
            param.requires_grad = False

        self.prompt_learner = PromptLearner(self.text_encoder.config.hidden_size)

        self.normal_prompt_embeddings = nn.ParameterList([
            nn.Parameter(torch.randn(1, self.text_encoder.config.hidden_size, device=device))
            for _ in range(len(Config.normal_prompt_templates))
        ])

        self.anomaly_prompt_embeddings = nn.ParameterList([
            nn.Parameter(torch.randn(1, self.text_encoder.config.hidden_size, device=device))
            for _ in range(len(Config.anomaly_prompt_templates))
        ])

        self.vit_layer = self.vision_encoder
        self.anomaly_localization_head = nn.Sequential(
            nn.Conv2d(self.vision_encoder.config.hidden_size, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 1, kernel_size=1)
        ).to(device)

    def encode_text(self, text):
        text_inputs = self.tokenizer(text, padding=True, return_tensors="pt").to(device)
        text_outputs = self.text_encoder(**text_inputs)
        text_embeddings = text_outputs.last_hidden_state[:, 0, :]
        text_embeddings = self.text_proj(text_embeddings)
        return text_embeddings

    def encode_image(self, image):
        vision_outputs = self.vision_encoder(image)
        image_embeddings = vision_outputs.last_hidden_state[:, 0, :]
        image_embeddings = self.vision_proj(image_embeddings)
        patch_embeddings = vision_outputs.last_hidden_state[:, 1:, :]
        batch_size = image.shape[0]
        patch_size = 32
        h = w = image.shape[2] // patch_size
        patch_features = patch_embeddings.reshape(batch_size, h, w, -1).permute(0, 3, 1, 2)
        return image_embeddings, patch_features

    def get_category_specific_anomalies(self, category):
        return Config.specific_anomaly_types.get(category, []) + Config.generic_anomaly_types

    def generate_prompts(self, category_name):
        normal_prompts = [tpl.format(category=category_name) for tpl in Config.normal_prompt_templates]
        normal_embeddings = [prompt_emb + self.encode_text([prompt]) for prompt, prompt_emb in zip(normal_prompts, self.normal_prompt_embeddings)]
        normal_embedding = torch.stack(normal_embeddings).mean(dim=0)

        anomaly_types = self.get_category_specific_anomalies(category_name)
        anomaly_prompts = [tpl.format(category=category_name, anomaly_type=at) for tpl in Config.anomaly_prompt_templates for at in anomaly_types]
        anomaly_embeddings = [prompt_emb + self.encode_text([prompt]) for prompt, prompt_emb in zip(anomaly_prompts[:4], self.anomaly_prompt_embeddings)]
        anomaly_embedding = torch.stack(anomaly_embeddings).mean(dim=0)

        return normal_embedding, anomaly_embedding

def forward(self, image, category):
    image_embedding, patch_features = self.encode_image(image)

    # Vision-to-Vision prompting: Add learnable visual prompts to patch embeddings
    patched_with_prompts = self.prompt_learner(patch_features)

    # Compute similarity between image embedding and prompted embeddings
    normal_similarity = F.cosine_similarity(image_embedding, self.normal_embedding)
    anomaly_similarity = F.cosine_similarity(image_embedding, self.anomaly_embedding)

    # Generate anomaly map using semantic concatenation in localization head
    anomaly_map = self.anomaly_localization_head(patch_features, patched_with_prompts)
    anomaly_map = F.interpolate(anomaly_map, size=Config.target_size, mode='bilinear', align_corners=False)

    # Normalize anomaly map to [0, 1] using sigmoid
    anomaly_score_map = torch.sigmoid(anomaly_map).squeeze(1)  # Shape: [B, H, W]
    
    # Global anomaly score per image
    anomaly_score = anomaly_score_map.mean(dim=(1, 2))  # Shape: [B]

    return {
        'normal_similarity': normal_similarity,
        'anomaly_similarity': anomaly_similarity,
        'anomaly_score': anomaly_score,         # Image-level anomaly score
        'anomaly_map': anomaly_score_map,       # Pixel-level anomaly map [B, H, W]
        'image_embedding': image_embedding
    }

def contrastive_eam_loss(normal_sim, anomaly_sim, labels, margin=0.5):
    logits = torch.stack([normal_sim, anomaly_sim], dim=1)
    targets = labels.long()
    ce_loss = F.cross_entropy(logits / Config.temperature, targets)
    alignment_loss = F.relu(Config.margin - (anomaly_sim - normal_sim)).mean()
    return ce_loss + alignment_loss


def visualize_results(image, anomaly_map, mask, score, path, category, defect_type):
    plt.figure(figsize=(10, 5))
    
    # Denormalize image
    image = image.permute(1, 2, 0).cpu().numpy()
    image = (image * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406])  # Denormalize
    image = np.clip(image, 0, 1)
    
    plt.subplot(1, 3, 1)
    plt.title(f"Image\nDefect Type: {defect_type}")
    plt.imshow(image)
    plt.axis('off')

    plt.subplot(1, 3, 2)
    plt.title(f"Anomaly Map\n(Score: {score:.2f})")
    plt.imshow(anomaly_map, cmap='jet',  vmin=0, vmax=1)
    plt.colorbar()
    plt.axis('off')

    plt.subplot(1, 3, 3)
    plt.title("Ground Truth Mask")
    plt.imshow(mask[0].numpy(), cmap='gray')
    plt.axis('off')

    plt.tight_layout()
    save_path = os.path.join(Config.vis_save_dir, category)
    os.makedirs(save_path, exist_ok=True)
    plt.savefig(os.path.join(save_path, os.path.basename(path)))
    plt.close()


def compute_metrics(image_scores, image_labels, pixel_scores, pixel_labels):
    image_roc_auc = roc_auc_score(image_labels, image_scores)
    pixel_roc_auc = roc_auc_score([item for sublist in pixel_labels for item in sublist],
                                   [item for sublist in pixel_scores for item in sublist])

    pixel_preds = (pixel_roc_auc > Config.anomaly_threshold).astype(int)
    pixel_f1 = f1_score([item for sublist in pixel_labels for item in sublist], pixel_preds)
    pixel_acc = accuracy_score([item for sublist in pixel_labels for item in sublist], pixel_preds)

    return {
        'image_roc_auc': float(image_roc_auc),
        'pixel_roc_auc': float(pixel_roc_auc),
        'pixel_f1': float(pixel_f1),
        'pixel_accuracy': float(pixel_acc)
    }


def train_model(model, category, train_loader):
    optimizer = optim.AdamW(model.parameters(), lr=Config.learning_rate, weight_decay=Config.weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Config.num_epochs)
    scaler = GradScaler() if Config.use_fp16 else None

    augmentation = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.1, contrast=0.1)
    ])

    best_val_loss = float('inf')

    for epoch in range(Config.num_epochs):
        model.train()
        train_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{Config.num_epochs}")
        for batch in progress_bar:
            images = batch['image'].to(device)
            labels = batch['label'].float().to(device)
            augmented_images = augmentation(images)

            with autocast(enabled=Config.use_fp16):
                outputs = model(images, category)
                aug_outputs = model(augmented_images, category)

                loss = contrastive_eam_loss(outputs['normal_similarity'], outputs['anomaly_similarity'], labels)
                consistency_loss = F.mse_loss(outputs['anomaly_map'], aug_outputs['anomaly_map'])
                total_loss = loss + 0.5 * consistency_loss

            if Config.use_fp16:
                scaler.scale(total_loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                total_loss.backward()
                optimizer.step()

            optimizer.zero_grad()
            train_loss += total_loss.item()
            progress_bar.set_postfix({"loss": total_loss.item()})

        avg_train_loss = train_loss / len(train_loader)
        logger.info(f"Epoch {epoch+1}/{Config.num_epochs}, Train Loss: {avg_train_loss:.4f}")
        scheduler.step()

        save_checkpoint(model, optimizer, epoch, avg_train_loss, f"{Config.checkpoint_dir}/{category}_epoch_{epoch+1}.pth")

    return model


def evaluate_model(model, category, test_loader):
    model.eval()
    image_scores = []
    image_labels = []
    pixel_scores = []
    pixel_labels = []
    defect_specific_scores = {}

    os.makedirs(os.path.join(Config.vis_save_dir, category), exist_ok=True)

    with torch.no_grad():
        for i, batch in enumerate(tqdm(test_loader, desc=f"Evaluating {category}")):
            images = batch['image'].to(device)
            labels = batch['label'].numpy()
            masks = batch['mask'].numpy()
            image_paths = batch['image_path']
            defect_types = batch['defect_type']

            outputs = model(images, category)
            anomaly_scores = outputs['anomaly_score'].cpu().numpy()
            anomaly_maps = outputs['anomaly_map'].squeeze(1).cpu().numpy()

            image_scores.extend(anomaly_scores)
            image_labels.extend(labels)

            for idx, defect_type in enumerate(defect_types):
                if defect_type not in defect_specific_scores:
                    defect_specific_scores[defect_type] = {"scores": [], "labels": []}
                defect_specific_scores[defect_type]["scores"].append(anomaly_scores[idx])
                defect_specific_scores[defect_type]["labels"].append(labels[idx])

            for b in range(images.shape[0]):
                if labels[b] == 1 and masks[b].sum() > 0:
                    pixel_scores.append(anomaly_maps[b].flatten())
                    pixel_labels.append(masks[b].flatten().numpy())

                visualize_results(images[b], anomaly_maps[b], masks[b], anomaly_scores[b], image_paths[b], category, defect_types[b])

    metrics = compute_metrics(image_scores, image_labels, pixel_scores, pixel_labels)
    metrics['category'] = category

    defect_specific_aucs = {}
    for defect_type, data in defect_specific_scores.items():
        if defect_type != "none" and len(set(data["labels"])) > 1:
            try:
                defect_specific_aucs[defect_type] = roc_auc_score(data["labels"], data["scores"])
            except:
                pass

    metrics['defect_specific_aucs'] = {k: float(v) for k, v in defect_specific_aucs.items()}

    with open(os.path.join(Config.vis_save_dir, f"{category}_metrics.json"), 'w') as f:
        json.dump(metrics, f, indent=4)

    return metrics


def save_checkpoint(model, optimizer, epoch, loss, filename):
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'loss': loss
    }, filename)
    logger.info(f"Checkpoint saved to {filename}")


def main():
    start_time = time.time()
    logger.info("Starting improved anomaly detection pipeline...")
    if not torch.cuda.is_available():
        logger.warning("CUDA is not available. Using CPU, which will be much slower!")
    all_metrics = {}
    os.makedirs(Config.vis_save_dir, exist_ok=True)
    os.makedirs(Config.checkpoint_dir, exist_ok=True)

    for category in Config.categories:
        logger.info(f"Processing category: {category}")
        train_transform = transforms.Compose([
            transforms.Resize(Config.target_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        test_transform = transforms.Compose([
            transforms.Resize(Config.target_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        try:
            train_dataset = MVTecDataset(Config.dataset_path, category, is_train=True, transform=train_transform)
            test_dataset = MVTecDataset(Config.dataset_path, category, is_train=False, transform=test_transform)

            train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True,
                                     num_workers=0, pin_memory=True, persistent_workers=False)
            test_loader = DataLoader(test_dataset, batch_size=Config.batch_size, shuffle=False,
                                    num_workers=0, pin_memory=True, persistent_workers=False)

            model = PromptBasedAnomalyDetector(Config.clip_model_name).to(device)
            logger.info(f"Model device: {next(model.parameters()).device}")
            logger.info(f"Training model for category: {category}")
            model = train_model(model, category, train_loader)
            logger.info(f"Evaluating model for category: {category}")
            metrics = evaluate_model(model, category, test_loader)
            all_metrics[category] = metrics
        except Exception as e:
            logger.error(f"Failed to process category {category}: {str(e)}")
            continue

    overall_metrics = {
        "overall_avg_image_auc": sum(m['image_roc_auc'] for m in all_metrics.values()) / len(all_metrics) if all_metrics else 0,
        "overall_avg_pixel_auc": sum(m['pixel_roc_auc'] for m in all_metrics.values()) / len(all_metrics) if all_metrics else 0,
        "categories": list(all_metrics.keys()),
        "category_performance": {cat: {"image_auc": metrics["image_roc_auc"], "pixel_auc": metrics["pixel_roc_auc"]} for cat, metrics in all_metrics.items()}
    }

    defect_type_metrics = {}
    for category, metrics in all_metrics.items():
        if "defect_specific_aucs" in metrics:
            for defect_type, auc in metrics["defect_specific_aucs"].items():
                if defect_type not in defect_type_metrics:
                    defect_type_metrics[defect_type] = []
                defect_type_metrics[defect_type].append(auc)

    overall_metrics["defect_type_performance"] = {
        defect_type: sum(aucs)/len(aucs) for defect_type, aucs in defect_type_metrics.items() if aucs
    }

    with open(os.path.join(Config.vis_save_dir, "overall_metrics.json"), 'w') as f:
        json.dump(overall_metrics, f, indent=4)

    elapsed_time = (time.time() - start_time) / 60.0
    logger.info(f"Pipeline completed in {elapsed_time:.2f} minutes.")


if __name__ == "__main__":
    if torch.cuda.device_count() > 1:
        logger.info(f"Using {torch.cuda.device_count()} GPUs!")
    cudnn.benchmark = True
    main()

2025-05-12 08:50:58,588 - __main__ - INFO - Using device: cuda
2025-05-12 08:50:58,593 - __main__ - INFO - Starting improved anomaly detection pipeline...
2025-05-12 08:50:58,596 - __main__ - INFO - Processing category: bottle
2025-05-12 08:50:58,597 - __main__ - INFO - Training dataset for category: bottle
2025-05-12 08:50:58,598 - __main__ - INFO - Searching for good samples in: ./mvtec_anomaly_detection\bottle\train\good
2025-05-12 08:50:58,600 - __main__ - INFO - Found 209 training images.
2025-05-12 08:50:58,601 - __main__ - INFO - Testing dataset for category: bottle
2025-05-12 08:50:58,601 - __main__ - INFO - Searching for test good samples in: ./mvtec_anomaly_detection\bottle\test\good
2025-05-12 08:50:58,603 - __main__ - INFO - Found defect types for bottle: ['broken_large', 'broken_small', 'contamination']
2025-05-12 08:50:58,604 - __main__ - INFO - Found 20 defective samples for broken_large
2025-05-12 08:51:01,090 - __main__ - INFO - Found 22 defective samples for broken_sm

KeyboardInterrupt: 

In [39]:
import os
import logging
import time
import json
from pathlib import Path
import glob
import warnings
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
import torch.backends.cudnn as cudnn
from torch.nn import functional as F
import torchvision.transforms as transforms
from tqdm import tqdm
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score
import numpy as np

# Setup logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", category=FutureWarning)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")

class Config:
    dataset_path = "./data"
    categories = ["bottle"]
    target_size = (224, 224)
    batch_size = 16
    num_epochs = 30
    learning_rate = 1e-4
    weight_decay = 1e-5
    clip_model_name = "openai/clip-vit-base-patch32"
    embedding_dim = 512
    margin = 0.5
    temperature = 0.07
    anomaly_threshold = 0.5
    pixel_weight = 0.5
    vis_save_dir = "visualization_results"
    checkpoint_dir = "checkpoints"
    use_fp16 = True
    num_workers = 0

    # Anomaly types
    generic_anomaly_types = ["scratched", "broken", "contaminated", "discolored"]
    specific_anomaly_types = {
        "bottle": ["dented", "leaking", "chipped"],
        "cable": ["frayed", "kinked"],
        "capsule": ["crushed", "color variation"],
        "carpet": ["worn", "torn"],
        "grid": ["distorted", "misaligned"]
    }

class MVTecDataset(Dataset):
    def __init__(self, root_dir, category, is_train=True, transform=None):
        self.root_dir = root_dir
        self.category = category
        self.is_train = is_train
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.mask_paths = []
        self.defect_types = []

        if is_train:
            train_good_path = os.path.join(root_dir, category, "train", "good")
            self.image_paths = sorted(glob.glob(os.path.join(train_good_path, "*.png")))
            self.labels = [0] * len(self.image_paths)
            self.mask_paths = [None] * len(self.image_paths)
            self.defect_types = ["none"] * len(self.image_paths)
        else:
            test_good_path = os.path.join(root_dir, category, "test", "good")
            good_paths = sorted(glob.glob(os.path.join(test_good_path, "*.png")))
            self.image_paths.extend(good_paths)
            self.labels.extend([0] * len(good_paths))
            self.defect_types.extend(["none"] * len(good_paths))
            self.mask_paths.extend([None] * len(good_paths))
            defect_types = [
                d for d in os.listdir(os.path.join(root_dir, category, "test"))
                if d != "good" and os.path.isdir(os.path.join(root_dir, category, "test", d))
            ]
            for defect_type in defect_types:
                defect_path = os.path.join(root_dir, category, "test", defect_type)
                defect_paths = sorted(glob.glob(os.path.join(defect_path, "*.png")))
                self.image_paths.extend(defect_paths)
                self.labels.extend([1] * len(defect_paths))
                self.defect_types.extend([defect_type] * len(defect_paths))
                for img_path in defect_paths:
                    img_filename = os.path.basename(img_path)
                    mask_path = os.path.join(
                        root_dir, category, "ground_truth", defect_type, img_filename
                    )
                    mask_with_suffix = os.path.splitext(mask_path)[0] + "_mask.png"
                    if os.path.exists(mask_path):
                        self.mask_paths.append(mask_path)
                    elif os.path.exists(mask_with_suffix):
                        self.mask_paths.append(mask_with_suffix)
                    else:
                        self.mask_paths.append(None)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert("RGB")
        image = self.transform(image)

        label = self.labels[idx]
        defect_type = self.defect_types[idx]
        mask = None

        if self.mask_paths[idx] and os.path.exists(self.mask_paths[idx]):
            mask = Image.open(self.mask_paths[idx]).convert("L")
            mask = transforms.Resize(Config.target_size)(mask)
            mask = transforms.ToTensor()(mask)
        else:
            mask = torch.zeros((1, *Config.target_size))  # Default empty mask

        return {
            "image": image,
            "label": label,
            "mask": mask,
            "image_path": image_path,
            "defect_type": defect_type
        }

class VisionPromptLearner(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.embed_dim = clip_model.vision_model.config.hidden_size
        self.prompt_length = 5
        self.visual_prompts = nn.Parameter(torch.randn(self.prompt_length, self.embed_dim))

    def forward(self, patch_embeddings):
        prompts = self.visual_prompts.expand(patch_embeddings.shape[0], -1, -1)
        return torch.cat([prompts, patch_embeddings], dim=1)

class AnomalyLocalizationHead(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.fusion_layer = nn.Sequential(
            nn.Conv2d(embed_dim * 2, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 1, kernel_size=1)
        )

    def forward(self, patch_features, prompts):
        batch_size, seq_len, embed_dim = prompts.shape
        h = w = int(patch_features.shape[1]**0.5)
        
        # Take mean of prompts across sequence dimension
        prompts_global = prompts.mean(dim=1)  # (batch_size, embed_dim)
        # Expand to (batch_size, embed_dim, 1, 1) and tile to (h, w)
        prompts_expanded = prompts_global.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, h, w)
        
        # Reshape patch_features to (batch_size, embed_dim, h, w)
        patch_reshaped = patch_features.view(batch_size, embed_dim, h, w)
        
        # Concatenate along channel dimension
        fused = torch.cat([patch_reshaped, prompts_expanded], dim=1)
        return self.fusion_layer(fused)

class PromptBasedAnomalyDetector(nn.Module):
    def __init__(self, clip_model_name="openai/clip-vit-base-patch32"):
        super().__init__()
        self.clip_model = CLIPModel.from_pretrained(clip_model_name)
        self.vision_encoder = self.clip_model.vision_model
        self.vision_proj = self.clip_model.visual_projection
        self.prompt_learner = VisionPromptLearner(self.clip_model)
        self.anomaly_localization_head = AnomalyLocalizationHead(self.vision_encoder.config.hidden_size)

        for param in self.clip_model.parameters():
            param.requires_grad = False

    def encode_image(self, image):
        vision_outputs = self.vision_encoder(image)
        image_embeddings = vision_outputs.last_hidden_state[:, 0, :]
        image_embeddings = self.vision_proj(image_embeddings)
        patch_embeddings = vision_outputs.last_hidden_state[:, 1:, :]
        return image_embeddings, patch_embeddings

    def forward(self, image, category):
        image_embedding, patch_embeddings = self.encode_image(image)
        patched_with_prompts = self.prompt_learner(patch_embeddings)
        anomaly_map = self.anomaly_localization_head(patch_embeddings, patched_with_prompts)
        anomaly_map = F.interpolate(anomaly_map, size=Config.target_size, mode='bilinear', align_corners=False)
        anomaly_map = anomaly_map.squeeze(1)

        # ✅ Apply sigmoid to normalize between [0, 1]
        normalized_anomaly_map = torch.sigmoid(anomaly_map)
        anomaly_score = normalized_anomaly_map.mean(dim=(1, 2))

        return {
            'anomaly_score': anomaly_score,
            'anomaly_map': normalized_anomaly_map,  # ✅ Already in [0, 1]
            'image_embedding': image_embedding
        }
def contrastive_eam_loss(anomaly_scores, labels, anomaly_maps, masks):
    # Image-level loss
    logits = anomaly_scores.unsqueeze(1)
    targets = labels.float().unsqueeze(1)
    image_loss = F.binary_cross_entropy_with_logits(logits, targets)

    # Pixel-level loss
    if masks is not None and masks.shape == anomaly_maps.shape:
        # If masks are (B, H, W), expand to (B, 1, H, W) if necessary
        if masks.dim() == 3:
            masks = masks.unsqueeze(1)  # For BCEWithLogitsLoss which expects channel dim
        elif masks.dim() == 4 and masks.shape[1] != 1:
            masks = masks[:, 0].unsqueeze(1)  # Take first channel if multi-channel

        pixel_loss = F.binary_cross_entropy_with_logits(anomaly_maps, masks.float())
        total_loss = image_loss + Config.pixel_weight * pixel_loss
    else:
        total_loss = image_loss

    return total_loss

def visualize_results(image, anomaly_map, mask, score, path, category, defect_type):
    plt.figure(figsize=(10, 5))
    # Denormalize image
    image = image.permute(1, 2, 0).cpu().numpy()
    image = (image * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406])
    image = np.clip(image, 0, 1)

    plt.subplot(1, 3, 1)
    defect_title = defect_type.capitalize() if defect_type != "none" else "Normal"
    plt.title(f"Image\nDefect Type: {defect_title}")
    plt.imshow(image)
    plt.axis('off')

    plt.subplot(1, 3, 2)
    plt.title(f"Anomaly Map\n(Score: {score:.2f})")
    plt.imshow(anomaly_map, cmap='jet',  vmin=0, vmax=1)
    plt.colorbar()
    plt.axis('off')

    plt.subplot(1, 3, 3)
    plt.title("Ground Truth Mask")
    if mask is not None and mask.sum() > 0:
        plt.imshow(mask[0], cmap='gray')  # ✅ No .numpy() needed
    else:
        plt.text(0.5, 0.5, "No Mask", ha='center', va='center')
    plt.axis('off')

    plt.tight_layout()
    save_path = os.path.join(Config.vis_save_dir, category)
    os.makedirs(save_path, exist_ok=True)
    plt.savefig(os.path.join(save_path, os.path.basename(path)))
    plt.close()

def compute_metrics(image_scores, image_labels, pixel_scores, pixel_labels):
    # Image-level ROC-AUC
    image_roc_auc = 0.0
    if len(set(image_labels)) > 1:
        image_roc_auc = roc_auc_score(image_labels, image_scores)

    # Pixel-level metrics
    pixel_roc_auc = 0.0
    pixel_f1 = 0.0
    pixel_acc = 0.0

    if len(pixel_scores) > 0 and len(pixel_labels) > 0:
        all_pixel_scores = np.concatenate(pixel_scores)
        all_pixel_labels = np.concatenate(pixel_labels)

        # Skip if only one class exists
        if len(np.unique(all_pixel_labels)) < 2:
            logger.warning("Only one class found in pixel labels. Skipping pixel-level AUC.")
        else:
            # Normalize scores to [0, 1] using sigmoid
            all_pixel_scores = 1 / (1 + np.exp(-all_pixel_scores))  # Sigmoid normalization
            try:
                pixel_roc_auc = roc_auc_score(all_pixel_labels, all_pixel_scores)
            except ValueError as e:
                logger.error(f"Failed to compute pixel AUC: {str(e)}")

        pixel_preds = (all_pixel_scores > Config.anomaly_threshold).astype(int)
        pixel_f1 = f1_score(all_pixel_labels, pixel_preds, zero_division=0)
        pixel_acc = accuracy_score(all_pixel_labels, pixel_preds)

    return {
        'image_roc_auc': float(image_roc_auc),
        'pixel_roc_auc': float(pixel_roc_auc),
        'pixel_f1': float(pixel_f1),
        'pixel_accuracy': float(pixel_acc)
    }

def train_model(model, category, train_loader):
    optimizer = optim.AdamW(model.parameters(), lr=Config.learning_rate, weight_decay=Config.weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Config.num_epochs)
    scaler = GradScaler() if Config.use_fp16 else None
    augmentation = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.1, contrast=0.1)
    ])

    for epoch in range(Config.num_epochs):
        model.train()
        train_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{Config.num_epochs}")
        for batch in progress_bar:
            images = batch['image'].to(device)
            labels = batch['label'].float().to(device)
            masks = batch['mask'].to(device)
            augmented_images = augmentation(images)

            with autocast(enabled=Config.use_fp16):
                outputs = model(images, category)
                aug_outputs = model(augmented_images, category)
                loss = contrastive_eam_loss(
                outputs['anomaly_score'], 
                labels, 
                outputs['anomaly_map'], 
                batch['mask'].to(device) 
            )
                consistency_loss = F.mse_loss(outputs['anomaly_map'], aug_outputs['anomaly_map'])
                total_loss = loss + 0.5 * consistency_loss

            if Config.use_fp16:
                scaler.scale(total_loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                total_loss.backward()
                optimizer.step()
            optimizer.zero_grad()
            train_loss += total_loss.item()
            progress_bar.set_postfix({"loss": total_loss.item()})
        avg_train_loss = train_loss / len(train_loader)
        logger.info(f"Epoch {epoch+1}/{Config.num_epochs}, Train Loss: {avg_train_loss:.4f}")
        scheduler.step()
        save_checkpoint(model, optimizer, epoch, avg_train_loss, f"{Config.checkpoint_dir}/{category}_epoch_{epoch+1}.pth")
    return model

def evaluate_model(model, category, test_loader):
    model.eval()
    image_scores = []
    image_labels = []
    pixel_scores = []
    pixel_labels = []
    defect_specific_scores = {}

    os.makedirs(os.path.join(Config.vis_save_dir, category), exist_ok=True)

    with torch.no_grad():
        for batch in tqdm(test_loader, desc=f"Evaluating {category}"):
            images = batch['image'].to(device)
            labels = batch['label'].cpu().numpy()
            masks = batch['mask'].cpu().numpy()
            image_paths = batch['image_path']
            defect_types = batch['defect_type']

            outputs = model(images, category)
            anomaly_scores = outputs['anomaly_score'].cpu().numpy()
            anomaly_maps = outputs['anomaly_map'].cpu().numpy()

            image_scores.extend(anomaly_scores)
            image_labels.extend(labels)

            for idx, defect_type in enumerate(defect_types):
                if defect_type not in defect_specific_scores:
                    defect_specific_scores[defect_type] = {"scores": [], "labels": []}
                defect_specific_scores[defect_type]["scores"].append(anomaly_scores[idx])
                defect_specific_scores[defect_type]["labels"].append(labels[idx])

            for b in range(images.shape[0]):
                if labels[b] == 1 and masks[b].sum() > 0:
                    pixel_scores.append(anomaly_maps[b].flatten())
                    pixel_labels.append(masks[b].flatten().astype(int))
                visualize_results(images[b], anomaly_maps[b], masks[b], anomaly_scores[b], image_paths[b], category, defect_types[b])

    metrics = compute_metrics(image_scores, image_labels, pixel_scores, pixel_labels)
    metrics['category'] = category

    defect_specific_aucs = {}
    for defect_type, data in defect_specific_scores.items():
        if defect_type != "none" and len(set(data["labels"])) > 1:
            try:
                defect_specific_aucs[defect_type] = roc_auc_score(data["labels"], data["scores"])
            except:
                pass

    metrics['defect_specific_aucs'] = {k: float(v) for k, v in defect_specific_aucs.items()}

    with open(os.path.join(Config.vis_save_dir, f"{category}_metrics.json"), 'w') as f:
        json.dump(metrics, f, indent=4)

    return metrics

def save_checkpoint(model, optimizer, epoch, loss, filename):
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'loss': loss
    }, filename)
    logger.info(f"Checkpoint saved to {filename}")

def main():
    start_time = time.time()
    logger.info("Starting improved anomaly detection pipeline...")
    if not torch.cuda.is_available():
        logger.warning("CUDA is not available. Using CPU, which will be much slower!")

    all_metrics = {}
    os.makedirs(Config.vis_save_dir, exist_ok=True)
    os.makedirs(Config.checkpoint_dir, exist_ok=True)

    for category in Config.categories:
        logger.info(f"Processing category: {category}")
        train_transform = transforms.Compose([
            transforms.Resize(Config.target_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        test_transform = transforms.Compose([
            transforms.Resize(Config.target_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        train_dataset = MVTecDataset(Config.dataset_path, category, is_train=True, transform=train_transform)
        test_dataset = MVTecDataset(Config.dataset_path, category, is_train=False, transform=test_transform)

        train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True, num_workers=0, pin_memory=True)
        test_loader = DataLoader(test_dataset, batch_size=Config.batch_size, shuffle=False, num_workers=0, pin_memory=True)

        model = PromptBasedAnomalyDetector(Config.clip_model_name).to(device)
        logger.info(f"Training model for category: {category}")
        model = train_model(model, category, train_loader)
        logger.info(f"Evaluating model for category: {category}")
        metrics = evaluate_model(model, category, test_loader)
        all_metrics[category] = metrics

    overall_metrics = {
        "overall_avg_image_auc": sum(m['image_roc_auc'] for m in all_metrics.values()) / len(all_metrics),
        "overall_avg_pixel_auc": sum(m['pixel_roc_auc'] for m in all_metrics.values()) / len(all_metrics),
        "categories": list(all_metrics.keys()),
        "category_performance": {cat: {"image_auc": m["image_roc_auc"], "pixel_auc": m["pixel_roc_auc"]} for cat, m in all_metrics.items()}
    }

    with open(os.path.join(Config.vis_save_dir, "overall_metrics.json"), 'w') as f:
        json.dump(overall_metrics, f, indent=4)

    elapsed_time = (time.time() - start_time) / 60.0
    logger.info(f"Pipeline completed in {elapsed_time:.2f} minutes.")

if __name__ == "__main__":
    main()

2025-05-12 09:53:16,511 - __main__ - INFO - Using device: cuda
2025-05-12 09:53:16,515 - __main__ - INFO - Starting improved anomaly detection pipeline...
2025-05-12 09:53:16,516 - __main__ - INFO - Processing category: bottle
2025-05-12 09:53:17,918 - __main__ - INFO - Training model for category: bottle
Epoch 1/30: 100%|██████████| 14/14 [00:06<00:00,  2.31it/s, loss=0.861]
2025-05-12 09:53:23,981 - __main__ - INFO - Epoch 1/30, Train Loss: 0.9008
2025-05-12 09:53:25,715 - __main__ - INFO - Checkpoint saved to checkpoints/bottle_epoch_1.pth
Epoch 2/30: 100%|██████████| 14/14 [00:05<00:00,  2.44it/s, loss=0.837]
2025-05-12 09:53:31,449 - __main__ - INFO - Epoch 2/30, Train Loss: 0.8470
2025-05-12 09:53:34,318 - __main__ - INFO - Checkpoint saved to checkpoints/bottle_epoch_2.pth
Epoch 3/30: 100%|██████████| 14/14 [00:05<00:00,  2.40it/s, loss=0.824]
2025-05-12 09:53:40,166 - __main__ - INFO - Epoch 3/30, Train Loss: 0.8296
2025-05-12 09:53:43,045 - __main__ - INFO - Checkpoint saved t

In [43]:
import os
import logging
import time
import json
from pathlib import Path
import glob
import warnings
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
import torch.backends.cudnn as cudnn
from torch.nn import functional as F
import torchvision.transforms as transforms
from tqdm import tqdm
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_recall_curve
import numpy as np

# Setup logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", category=FutureWarning)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")

class Config:
    dataset_path = "./mvtec_anomaly_detection"
    categories = ["bottle"]
    target_size = (224, 224)
    batch_size = 16
    num_epochs = 50
    learning_rate = 1e-4
    weight_decay = 1e-5
    clip_model_name = "openai/clip-vit-base-patch32"
    embedding_dim = 512
    margin = 0.5
    temperature = 0.07
    anomaly_threshold = 0.5
    pixel_weight = 0.5
    vis_save_dir = "visualization_results"
    checkpoint_dir = "checkpoints"
    use_fp16 = True
    num_workers = 0

    # Anomaly types
    generic_anomaly_types = ["scratched", "broken", "contaminated", "discolored"]
    specific_anomaly_types = {
        "bottle": ["dented", "leaking", "chipped"],
        "cable": ["frayed", "kinked"],
        "capsule": ["crushed", "color variation"],
        "carpet": ["worn", "torn"],
        "grid": ["distorted", "misaligned"]
    }

class MVTecDataset(Dataset):
    def __init__(self, root_dir, category, is_train=True, transform=None):
        self.root_dir = root_dir
        self.category = category
        self.is_train = is_train
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.mask_paths = []
        self.defect_types = []

        if is_train:
            train_good_path = os.path.join(root_dir, category, "train", "good")
            self.image_paths = sorted(glob.glob(os.path.join(train_good_path, "*.png")))
            self.labels = [0] * len(self.image_paths)
            self.mask_paths = [None] * len(self.image_paths)
            self.defect_types = ["none"] * len(self.image_paths)
        else:
            test_good_path = os.path.join(root_dir, category, "test", "good")
            good_paths = sorted(glob.glob(os.path.join(test_good_path, "*.png")))
            self.image_paths.extend(good_paths)
            self.labels.extend([0] * len(good_paths))
            self.defect_types.extend(["none"] * len(good_paths))
            self.mask_paths.extend([None] * len(good_paths))

            defect_types = [
                d for d in os.listdir(os.path.join(root_dir, category, "test"))
                if d != "good" and os.path.isdir(os.path.join(root_dir, category, "test", d))
            ]
            for defect_type in defect_types:
                defect_path = os.path.join(root_dir, category, "test", defect_type)
                defect_paths = sorted(glob.glob(os.path.join(defect_path, "*.png")))
                self.image_paths.extend(defect_paths)
                self.labels.extend([1] * len(defect_paths))
                self.defect_types.extend([defect_type] * len(defect_paths))
                for img_path in defect_paths:
                    img_filename = os.path.basename(img_path)
                    mask_path = os.path.join(
                        root_dir, category, "ground_truth", defect_type, img_filename
                    )
                    mask_with_suffix = os.path.splitext(mask_path)[0] + "_mask.png"
                    if os.path.exists(mask_path):
                        self.mask_paths.append(mask_path)
                    elif os.path.exists(mask_with_suffix):
                        self.mask_paths.append(mask_with_suffix)
                    else:
                        self.mask_paths.append(None)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert("RGB")
        image = self.transform(image)
        label = self.labels[idx]
        defect_type = self.defect_types[idx]
        mask = None
        if self.mask_paths[idx] and os.path.exists(self.mask_paths[idx]):
            mask = Image.open(self.mask_paths[idx]).convert("L")
            mask = transforms.Resize(Config.target_size)(mask)
            mask = transforms.ToTensor()(mask)
        else:
            mask = torch.zeros((1, *Config.target_size))  # Default empty mask
        return {
            "image": image,
            "label": label,
            "mask": mask,
            "image_path": image_path,
            "defect_type": defect_type
        }

class VisionPromptLearner(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.embed_dim = clip_model.vision_model.config.hidden_size
        self.prompt_length = 10  # Increased from 5 to 10
        self.visual_prompts = nn.Parameter(torch.randn(self.prompt_length, self.embed_dim))

    def forward(self, patch_embeddings):
        prompts = self.visual_prompts.expand(patch_embeddings.shape[0], -1, -1)
        return torch.cat([prompts, patch_embeddings], dim=1)

class AnomalyLocalizationHead(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.fusion_layer = nn.Sequential(
            nn.Conv2d(embed_dim * 2, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, patch_features, prompts):
        batch_size, seq_len, embed_dim = prompts.shape
        h = w = int(patch_features.shape[1]**0.5)

        prompts_global = prompts.mean(dim=1)  # (batch_size, embed_dim)
        prompts_expanded = prompts_global.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, h, w)
        patch_reshaped = patch_features.view(batch_size, embed_dim, h, w)
        fused = torch.cat([patch_reshaped, prompts_expanded], dim=1)
        return self.fusion_layer(fused)

class PromptBasedAnomalyDetector(nn.Module):
    def __init__(self, clip_model_name="openai/clip-vit-base-patch32"):
        super().__init__()
        self.clip_model = CLIPModel.from_pretrained(clip_model_name)
        self.vision_encoder = self.clip_model.vision_model
        self.vision_proj = self.clip_model.visual_projection
        self.prompt_learner = VisionPromptLearner(self.clip_model)
        self.anomaly_localization_head = AnomalyLocalizationHead(self.vision_encoder.config.hidden_size)

        for param in self.clip_model.parameters():
            param.requires_grad = False

    def encode_image(self, image):
        vision_outputs = self.vision_encoder(image)
        image_embeddings = vision_outputs.last_hidden_state[:, 0, :]
        image_embeddings = self.vision_proj(image_embeddings)
        patch_embeddings = vision_outputs.last_hidden_state[:, 1:, :]
        return image_embeddings, patch_embeddings

    def forward(self, image, category):
        image_embedding, patch_embeddings = self.encode_image(image)
        patched_with_prompts = self.prompt_learner(patch_embeddings)
        anomaly_map = self.anomaly_localization_head(patch_embeddings, patched_with_prompts)
        anomaly_map = F.interpolate(anomaly_map, size=Config.target_size, mode='bilinear', align_corners=False)
        anomaly_map = anomaly_map.squeeze(1)
        anomaly_score = anomaly_map.mean(dim=(1, 2))
        return {
            'anomaly_score': anomaly_score,
            'anomaly_map': anomaly_map,
            'image_embedding': image_embedding
        }

def dice_loss(preds, targets, smooth=1e-6):
    preds = preds.contiguous().view(-1)
    targets = targets.contiguous().view(-1)
    intersection = (preds * targets).sum()
    union = preds.sum() + targets.sum()
    return 1 - (2 * intersection + smooth) / (union + smooth)

def contrastive_eam_loss(anomaly_scores, labels, anomaly_maps, masks):
    logits = anomaly_scores.unsqueeze(1)
    targets = labels.float().unsqueeze(1)
    image_loss = F.binary_cross_entropy_with_logits(logits, targets)

    if masks is not None and masks.shape == anomaly_maps.shape:
        if masks.dim() == 3:
            masks = masks.unsqueeze(1)
        elif masks.dim() == 4 and masks.shape[1] != 1:
            masks = masks[:, 0].unsqueeze(1)

        pixel_loss_bce = F.binary_cross_entropy_with_logits(anomaly_maps, masks.float())
        pixel_loss_dice = dice_loss(torch.sigmoid(anomaly_maps), masks.float())
        total_loss = image_loss + Config.pixel_weight * (pixel_loss_bce + pixel_loss_dice)
    else:
        total_loss = image_loss

    return total_loss

def visualize_results(image, anomaly_map, mask, score, path, category, defect_type):
    plt.figure(figsize=(10, 5))
    image = image.permute(1, 2, 0).cpu().numpy()
    image = (image * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406])
    image = np.clip(image, 0, 1)

    plt.subplot(1, 3, 1)
    defect_title = defect_type.capitalize() if defect_type != "none" else "Normal"
    plt.title(f"Image\nDefect Type: {defect_title}")
    plt.imshow(image)
    plt.axis('off')

    plt.subplot(1, 3, 2)
    plt.title(f"Anomaly Map\n(Score: {score:.2f})")
    plt.imshow(anomaly_map, cmap='jet', vmin=0, vmax=1)
    plt.colorbar()
    plt.axis('off')

    plt.subplot(1, 3, 3)
    plt.title("Ground Truth Mask")
    if mask is not None and mask.sum() > 0:
        plt.imshow(mask[0], cmap='gray')
    else:
        plt.text(0.5, 0.5, "No Mask", ha='center', va='center')
    plt.axis('off')

    plt.tight_layout()
    save_path = os.path.join(Config.vis_save_dir, category)
    os.makedirs(save_path, exist_ok=True)
    plt.savefig(os.path.join(save_path, os.path.basename(path)))
    plt.close()

def compute_metrics(image_scores, image_labels, pixel_scores, pixel_labels):
    image_roc_auc = 0.0
    if len(set(image_labels)) > 1:
        image_roc_auc = roc_auc_score(image_labels, image_scores)

    pixel_roc_auc = 0.0
    pixel_f1 = 0.0
    pixel_acc = 0.0
    best_threshold = Config.anomaly_threshold

    if len(pixel_scores) > 0 and len(pixel_labels) > 0:
        all_pixel_scores = np.concatenate(pixel_scores)
        all_pixel_labels = np.concatenate(pixel_labels)

        if len(np.unique(all_pixel_labels)) < 2:
            logger.warning("Only one class found in pixel labels. Skipping pixel-level AUC.")
        else:
            try:
                pixel_roc_auc = roc_auc_score(all_pixel_labels, all_pixel_scores)
                precisions, recalls, thresholds = precision_recall_curve(all_pixel_labels, all_pixel_scores)
                f1_scores = 2 * (precisions * recalls) / (precisions + recalls + 1e-8)
                best_idx = np.argmax(f1_scores)
                best_threshold = thresholds[best_idx]
                pixel_preds = (all_pixel_scores > best_threshold).astype(int)
                pixel_f1 = f1_score(all_pixel_labels, pixel_preds, zero_division=0)
                pixel_acc = accuracy_score(all_pixel_labels, pixel_preds)
            except Exception as e:
                logger.error(f"Failed to compute metrics: {str(e)}")

    return {
        'image_roc_auc': float(image_roc_auc),
        'pixel_roc_auc': float(pixel_roc_auc),
        'pixel_f1': float(pixel_f1),
        'pixel_accuracy': float(pixel_acc),
        'best_threshold': float(best_threshold)
    }

def train_model(model, category, train_loader):
    optimizer = optim.AdamW(model.parameters(), lr=Config.learning_rate, weight_decay=Config.weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Config.num_epochs)
    scaler = GradScaler() if Config.use_fp16 else None

    augmentation = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2)
    ])

    for epoch in range(Config.num_epochs):
        model.train()
        train_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{Config.num_epochs}")
        for batch in progress_bar:
            images = batch['image'].to(device)
            labels = batch['label'].float().to(device)
            masks = batch['mask'].to(device)

            augmented_images = augmentation(images)

            with autocast(enabled=Config.use_fp16):
                outputs = model(images, category)
                aug_outputs = model(augmented_images, category)

                loss = contrastive_eam_loss(
                    outputs['anomaly_score'],
                    labels,
                    outputs['anomaly_map'],
                    batch['mask'].to(device)
                )
                consistency_loss = F.mse_loss(outputs['anomaly_map'], aug_outputs['anomaly_map'])
                total_loss = loss + 0.5 * consistency_loss

            if Config.use_fp16:
                scaler.scale(total_loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                total_loss.backward()
                optimizer.step()

            optimizer.zero_grad()
            train_loss += total_loss.item()
            progress_bar.set_postfix({"loss": total_loss.item()})

        avg_train_loss = train_loss / len(train_loader)
        logger.info(f"Epoch {epoch+1}/{Config.num_epochs}, Train Loss: {avg_train_loss:.4f}")
        scheduler.step()
        save_checkpoint(model, optimizer, epoch, avg_train_loss, f"{Config.checkpoint_dir}/{category}_epoch_{epoch+1}.pth")

    return model

def evaluate_model(model, category, test_loader):
    model.eval()
    image_scores = []
    image_labels = []
    pixel_scores = []
    pixel_labels = []
    defect_specific_scores = {}

    os.makedirs(os.path.join(Config.vis_save_dir, category), exist_ok=True)

    with torch.no_grad():
        for batch in tqdm(test_loader, desc=f"Evaluating {category}"):
            images = batch['image'].to(device)
            labels = batch['label'].cpu().numpy()
            masks = batch['mask'].cpu().numpy()
            image_paths = batch['image_path']
            defect_types = batch['defect_type']

            outputs = model(images, category)
            anomaly_scores = outputs['anomaly_score'].cpu().numpy()
            anomaly_maps = outputs['anomaly_map'].cpu().numpy()

            image_scores.extend(anomaly_scores)
            image_labels.extend(labels)

            for idx, defect_type in enumerate(defect_types):
                if defect_type not in defect_specific_scores:
                    defect_specific_scores[defect_type] = {"scores": [], "labels": []}
                defect_specific_scores[defect_type]["scores"].append(anomaly_scores[idx])
                defect_specific_scores[defect_type]["labels"].append(labels[idx])

            for b in range(images.shape[0]):
                if labels[b] == 1 and masks[b].sum() > 0:
                    pixel_scores.append(anomaly_maps[b].flatten())
                    pixel_labels.append(masks[b].flatten().astype(int))
                visualize_results(images[b], anomaly_maps[b], masks[b], anomaly_scores[b], image_paths[b], category, defect_types[b])

    metrics = compute_metrics(image_scores, image_labels, pixel_scores, pixel_labels)
    metrics['category'] = category

    defect_specific_aucs = {}
    for defect_type, data in defect_specific_scores.items():
        if defect_type != "none" and len(set(data["labels"])) > 1:
            try:
                defect_specific_aucs[defect_type] = roc_auc_score(data["labels"], data["scores"])
            except:
                pass

    metrics['defect_specific_aucs'] = {k: float(v) for k, v in defect_specific_aucs.items()}
    with open(os.path.join(Config.vis_save_dir, f"{category}_metrics.json"), 'w') as f:
        json.dump(metrics, f, indent=4)

    return metrics

def save_checkpoint(model, optimizer, epoch, loss, filename):
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'loss': loss
    }, filename)
    logger.info(f"Checkpoint saved to {filename}")

def main():
    start_time = time.time()
    logger.info("Starting improved anomaly detection pipeline...")

    if not torch.cuda.is_available():
        logger.warning("CUDA is not available. Using CPU, which will be much slower!")

    all_metrics = {}
    os.makedirs(Config.vis_save_dir, exist_ok=True)
    os.makedirs(Config.checkpoint_dir, exist_ok=True)

    for category in Config.categories:
        logger.info(f"Processing category: {category}")

        train_transform = transforms.Compose([
            transforms.Resize(Config.target_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        test_transform = transforms.Compose([
            transforms.Resize(Config.target_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        train_dataset = MVTecDataset(Config.dataset_path, category, is_train=True, transform=train_transform)
        test_dataset = MVTecDataset(Config.dataset_path, category, is_train=False, transform=test_transform)

        train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True, num_workers=0, pin_memory=True)
        test_loader = DataLoader(test_dataset, batch_size=Config.batch_size, shuffle=False, num_workers=0, pin_memory=True)

        model = PromptBasedAnomalyDetector(Config.clip_model_name).to(device)

        logger.info(f"Training model for category: {category}")
        model = train_model(model, category, train_loader)

        logger.info(f"Evaluating model for category: {category}")
        metrics = evaluate_model(model, category, test_loader)
        all_metrics[category] = metrics

    overall_metrics = {
        "overall_avg_image_auc": sum(m['image_roc_auc'] for m in all_metrics.values()) / len(all_metrics),
        "overall_avg_pixel_auc": sum(m['pixel_roc_auc'] for m in all_metrics.values()) / len(all_metrics),
        "overall_avg_pixel_f1": sum(m['pixel_f1'] for m in all_metrics.values()) / len(all_metrics),
        "categories": list(all_metrics.keys()),
        "category_performance": {
            cat: {
                "image_auc": m["image_roc_auc"],
                "pixel_auc": m["pixel_roc_auc"],
                "pixel_f1": m["pixel_f1"]
            } for cat, m in all_metrics.items()
        }
    }

    with open(os.path.join(Config.vis_save_dir, "overall_metrics.json"), 'w') as f:
        json.dump(overall_metrics, f, indent=4)

    elapsed_time = (time.time() - start_time) / 60.0
    logger.info(f"Pipeline completed in {elapsed_time:.2f} minutes.")

if __name__ == "__main__":
    main()

2025-05-12 11:25:12,625 - __main__ - INFO - Using device: cuda
2025-05-12 11:25:12,630 - __main__ - INFO - Starting improved anomaly detection pipeline...
2025-05-12 11:25:12,831 - __main__ - INFO - Processing category: bottle
2025-05-12 11:25:26,053 - __main__ - INFO - Training model for category: bottle
Epoch 1/50: 100%|██████████| 14/14 [00:09<00:00,  1.42it/s, loss=0.829]
2025-05-12 11:25:35,911 - __main__ - INFO - Epoch 1/50, Train Loss: 0.8636
2025-05-12 11:25:37,462 - __main__ - INFO - Checkpoint saved to checkpoints/bottle_epoch_1.pth
Epoch 2/50: 100%|██████████| 14/14 [00:05<00:00,  2.51it/s, loss=0.809]
2025-05-12 11:25:43,038 - __main__ - INFO - Epoch 2/50, Train Loss: 0.8174
2025-05-12 11:25:46,046 - __main__ - INFO - Checkpoint saved to checkpoints/bottle_epoch_2.pth
Epoch 3/50: 100%|██████████| 14/14 [00:05<00:00,  2.47it/s, loss=0.798]
2025-05-12 11:25:51,725 - __main__ - INFO - Epoch 3/50, Train Loss: 0.8032
2025-05-12 11:25:53,362 - __main__ - INFO - Checkpoint saved t

In [45]:
import os
import logging
import time
import json
from pathlib import Path
import glob
import warnings
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
import torch.backends.cudnn as cudnn
from torch.nn import functional as F
import torchvision.transforms as transforms
from tqdm import tqdm
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_recall_curve
import numpy as np

# Setup logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", category=FutureWarning)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")

class Config:
    dataset_path = "./data"
    categories = ["bottle"]
    target_size = (224, 224)
    batch_size = 16
    num_epochs = 30
    learning_rate = 1e-4
    weight_decay = 1e-5
    clip_model_name = "openai/clip-vit-base-patch32"
    embedding_dim = 768
    margin = 0.5
    temperature = 0.07
    anomaly_threshold = 0.5
    pixel_weight = 0.5
    vis_save_dir = "visualization_results"
    checkpoint_dir = "checkpoints"
    use_fp16 = True
    num_workers = 0

    # Hyperparameters
    lambda_prompt = 0.5
    lambda_eam = 0.3
    lambda_align = 0.2
    prompt_learning_rate = 1e-4

    # Anomaly types
    generic_anomaly_types = ["scratched", "broken", "contaminated", "discolored"]
    specific_anomaly_types = {
        "bottle": ["dented", "leaking", "chipped"],
        "cable": ["frayed", "kinked"],
        "capsule": ["crushed", "color variation"],
        "carpet": ["worn", "torn"],
        "grid": ["distorted", "misaligned"]
    }

class MVTecDataset(Dataset):
    def __init__(self, root_dir, category, is_train=True, transform=None):
        self.root_dir = root_dir
        self.category = category
        self.is_train = is_train
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.mask_paths = []
        self.defect_types = []

        if is_train:
            train_good_path = os.path.join(root_dir, category, "train", "good")
            self.image_paths = sorted(glob.glob(os.path.join(train_good_path, "*.png")))
            self.labels = [0] * len(self.image_paths)
            self.mask_paths = [None] * len(self.image_paths)
            self.defect_types = ["none"] * len(self.image_paths)
        else:
            test_good_path = os.path.join(root_dir, category, "test", "good")
            good_paths = sorted(glob.glob(os.path.join(test_good_path, "*.png")))
            self.image_paths.extend(good_paths)
            self.labels.extend([0] * len(good_paths))
            self.defect_types.extend(["none"] * len(good_paths))
            self.mask_paths.extend([None] * len(good_paths))

            defect_types = [
                d for d in os.listdir(os.path.join(root_dir, category, "test"))
                if d != "good" and os.path.isdir(os.path.join(root_dir, category, "test", d))
            ]
            for defect_type in defect_types:
                defect_path = os.path.join(root_dir, category, "test", defect_type)
                defect_paths = sorted(glob.glob(os.path.join(defect_path, "*.png")))
                self.image_paths.extend(defect_paths)
                self.labels.extend([1] * len(defect_paths))
                self.defect_types.extend([defect_type] * len(defect_paths))
                for img_path in defect_paths:
                    img_filename = os.path.basename(img_path)
                    mask_path = os.path.join(
                        root_dir, category, "ground_truth", defect_type, img_filename
                    )
                    mask_with_suffix = os.path.splitext(mask_path)[0] + "_mask.png"
                    if os.path.exists(mask_path):
                        self.mask_paths.append(mask_path)
                    elif os.path.exists(mask_with_suffix):
                        self.mask_paths.append(mask_with_suffix)
                    else:
                        self.mask_paths.append(None)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert("RGB")
        image = self.transform(image)
        label = self.labels[idx]
        defect_type = self.defect_types[idx]

        mask = None
        if self.mask_paths[idx] and os.path.exists(self.mask_paths[idx]):
            mask = Image.open(self.mask_paths[idx]).convert("L")
            mask = transforms.Resize(Config.target_size)(mask)
            mask = transforms.ToTensor()(mask)
        else:
            mask = torch.zeros((1, *Config.target_size))  # Default empty mask

        return {
            "image": image,
            "label": label,
            "mask": mask,
            "image_path": image_path,
            "defect_type": defect_type
        }

class VisionPromptLearner(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.embed_dim = clip_model.vision_model.config.hidden_size
        self.prompt_length = 5
        self.normal_prompts = nn.Parameter(torch.randn(self.prompt_length, self.embed_dim))
        self.anomaly_prompts = nn.Parameter(torch.randn(self.prompt_length, self.embed_dim))

    def forward(self, patch_embeddings):
        prompts = torch.cat([self.normal_prompts, self.anomaly_prompts], dim=0)
        prompts = prompts.expand(patch_embeddings.shape[0], -1, -1)
        return torch.cat([prompts, patch_embeddings], dim=1)

class AnomalyLocalizationHead(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.fusion_layer = nn.Sequential(
            nn.Conv2d(embed_dim * 2, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 1, kernel_size=1)
        )

    def forward(self, patch_features, prompts):
        batch_size, seq_len, embed_dim = prompts.shape
        h = w = int(patch_features.shape[1]**0.5)
        prompts_global = prompts.mean(dim=1)
        prompts_expanded = prompts_global.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, h, w)
        patch_reshaped = patch_features.view(batch_size, embed_dim, h, w)
        fused = torch.cat([patch_reshaped, prompts_expanded], dim=1)
        return self.fusion_layer(fused)

class PromptBasedAnomalyDetector(nn.Module):
    def __init__(self, clip_model_name="openai/clip-vit-base-patch32"):
        super().__init__()
        self.clip_model = CLIPModel.from_pretrained(clip_model_name)
        self.vision_encoder = self.clip_model.vision_model
        self.vision_proj = self.clip_model.visual_projection
        self.prompt_learner = VisionPromptLearner(self.clip_model)
        self.anomaly_localization_head = AnomalyLocalizationHead(self.vision_encoder.config.hidden_size)

        for param in self.clip_model.parameters():
            param.requires_grad = False

    def encode_image(self, image):
        vision_outputs = self.vision_encoder(image)
        image_embeddings = vision_outputs.last_hidden_state[:, 0, :]
        image_embeddings = self.vision_proj(image_embeddings)
        patch_embeddings = vision_outputs.last_hidden_state[:, 1:, :]
        return image_embeddings, patch_embeddings

    def forward(self, image, category):
        image_embedding, patch_embeddings = self.encode_image(image)
        patched_with_prompts = self.prompt_learner(patch_embeddings)
        anomaly_map = self.anomaly_localization_head(patch_embeddings, patched_with_prompts)
        anomaly_map = F.interpolate(anomaly_map, size=Config.target_size, mode='bilinear', align_corners=False)
        anomaly_map = anomaly_map.squeeze(1)
        normalized_anomaly_map = torch.sigmoid(anomaly_map)
        anomaly_score = normalized_anomaly_map.mean(dim=(1, 2))

        return {
            'anomaly_score': anomaly_score,
            'anomaly_map': normalized_anomaly_map,
            'image_embedding': image_embedding,
            'patch_embeddings': patch_embeddings
        }

def contrastive_prompt_loss(normal_prompts, anomaly_prompts, cls_embeddings, labels):
    cls_embeddings = F.normalize(cls_embeddings, dim=1)
    normal_proto = F.normalize(normal_prompts.mean(dim=0), dim=0)
    anomaly_proto = F.normalize(anomaly_prompts.mean(dim=0), dim=0)

    sim_to_normal = torch.matmul(cls_embeddings, normal_proto)
    sim_to_anomaly = torch.matmul(cls_embeddings, anomaly_proto)

    pos_loss = (1 - sim_to_normal[labels == 0]).mean()
    neg_loss = (1 + sim_to_anomaly[labels == 1]).mean()

    return pos_loss + neg_loss

def eam_loss(normal_prompts, anomaly_prompts, margin=Config.margin):
    normal_proto = normal_prompts.mean(dim=0)
    anomaly_proto = anomaly_prompts.mean(dim=0)
    distance = torch.norm(normal_proto - anomaly_proto, p=2)
    return torch.relu(margin - distance)

def prompt_alignment_loss(learned_prompts, text_embeddings):
    learned_prompts = F.normalize(learned_prompts, dim=1)
    text_embeddings = F.normalize(text_embeddings, dim=1)
    cos_sim = torch.mm(learned_prompts, text_embeddings.T)
    return (1 - torch.diag(cos_sim)).mean()

def dice_loss(preds, targets, smooth=1e-6):
    preds = preds.contiguous().view(-1)
    targets = targets.contiguous().view(-1)
    intersection = (preds * targets).sum()
    return 1 - (2. * intersection + smooth) / (preds.sum() + targets.sum() + smooth)

def compute_vad_map(patch_embeddings):
    prototype = patch_embeddings.mean(dim=0)
    distances = torch.cdist(patch_embeddings, prototype.unsqueeze(0))
    return distances.squeeze(-1)

def compute_optimal_threshold(image_scores, image_labels):
    precision, recall, thresholds = precision_recall_curve(image_labels, image_scores)
    f1_scores = 2 * (precision * recall) / (precision + recall + 1e-8)
    best_idx = np.argmax(f1_scores)
    return thresholds[best_idx]

def visualize_results(image, anomaly_map, mask, score, path, category, defect_type):
    plt.figure(figsize=(10, 5))
    image = image.permute(1, 2, 0).cpu().numpy()
    image = (image * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406])
    image = np.clip(image, 0, 1)

    plt.subplot(1, 3, 1)
    defect_title = defect_type.capitalize() if defect_type != "none" else "Normal"
    plt.title(f"Image\nDefect Type: {defect_title}")
    plt.imshow(image)
    plt.axis('off')

    plt.subplot(1, 3, 2)
    plt.title(f"Anomaly Map\n(Score: {score:.2f})")
    plt.imshow(anomaly_map, cmap='jet', vmin=0, vmax=1)
    plt.colorbar()
    plt.axis('off')

    plt.subplot(1, 3, 3)
    plt.title("Ground Truth Mask")
    if mask is not None and mask.sum() > 0:
        plt.imshow(mask[0], cmap='gray')
    else:
        plt.text(0.5, 0.5, "No Mask", ha='center', va='center')
    plt.axis('off')

    plt.tight_layout()
    save_path = os.path.join(Config.vis_save_dir, category)
    os.makedirs(save_path, exist_ok=True)
    plt.savefig(os.path.join(save_path, os.path.basename(path)))
    plt.close()

def compute_metrics(image_scores, image_labels, pixel_scores, pixel_labels):
    image_roc_auc = 0.0
    if len(set(image_labels)) > 1:
        image_roc_auc = roc_auc_score(image_labels, image_scores)

    pixel_roc_auc = 0.0
    pixel_f1 = 0.0
    pixel_acc = 0.0

    if len(pixel_scores) > 0 and len(pixel_labels) > 0:
        all_pixel_scores = np.concatenate(pixel_scores)
        all_pixel_labels = np.concatenate(pixel_labels)

        if len(np.unique(all_pixel_labels)) < 2:
            logger.warning("Only one class found in pixel labels. Skipping pixel-level AUC.")
        else:
            all_pixel_scores = 1 / (1 + np.exp(-all_pixel_scores))
            try:
                pixel_roc_auc = roc_auc_score(all_pixel_labels, all_pixel_scores)
            except ValueError as e:
                logger.error(f"Failed to compute pixel AUC: {str(e)}")

        pixel_preds = (all_pixel_scores > Config.anomaly_threshold).astype(int)
        pixel_f1 = f1_score(all_pixel_labels, pixel_preds, zero_division=0)
        pixel_acc = accuracy_score(all_pixel_labels, pixel_preds)

    return {
        'image_roc_auc': float(image_roc_auc),
        'pixel_roc_auc': float(pixel_roc_auc),
        'pixel_f1': float(pixel_f1),
        'pixel_accuracy': float(pixel_acc)
    }

def train_model(model, category, train_loader):
    optimizer = optim.AdamW([
        {'params': model.prompt_learner.normal_prompts, 'lr': Config.prompt_learning_rate},
        {'params': model.prompt_learner.anomaly_prompts, 'lr': Config.prompt_learning_rate},
        {'params': model.anomaly_localization_head.parameters(), 'lr': Config.learning_rate}
    ], weight_decay=Config.weight_decay)

    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Config.num_epochs)
    scaler = GradScaler() if Config.use_fp16 else None

    augmentation = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.1, contrast=0.1)
    ])

    tokenizer = CLIPTokenizer.from_pretrained(Config.clip_model_name)
    text_encoder = CLIPModel.from_pretrained(Config.clip_model_name).text_model.to(device)

    def get_text_embeddings(texts):
        inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(device)
        text_features = text_encoder(**inputs).last_hidden_state[:, 0, :]
        return F.normalize(text_features, dim=-1)

    anomaly_texts = Config.specific_anomaly_types.get(category, [])
    if not anomaly_texts:
        anomaly_texts = Config.generic_anomaly_types[:3]

    text_embeddings = get_text_embeddings(anomaly_texts)

    for epoch in range(Config.num_epochs):
        model.train()
        train_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{Config.num_epochs}")

        for batch in progress_bar:
            images = batch['image'].to(device)
            labels = batch['label'].float().to(device)
            masks = batch['mask'].to(device)

            augmented_images = augmentation(images)

            with autocast(enabled=Config.use_fp16):
                outputs = model(images, category)
                aug_outputs = model(augmented_images, category)

                # Compute losses
                bce_loss = F.binary_cross_entropy_with_logits(outputs['anomaly_score'], labels)
                if masks.dim() == 3:
                    masks = masks.unsqueeze(1)
                pixel_loss = Config.pixel_weight * dice_loss(outputs['anomaly_map'], masks.float())

                contrastive_loss = Config.lambda_prompt * contrastive_prompt_loss(
                    model.prompt_learner.normal_prompts,
                    model.prompt_learner.anomaly_prompts,
                    outputs['image_embedding'],
                    labels.long()
                )

                eam = Config.lambda_eam * eam_loss(
                    model.prompt_learner.normal_prompts,
                    model.prompt_learner.anomaly_prompts
                )

                alignment = Config.lambda_align * prompt_alignment_loss(
                    model.prompt_learner.anomaly_prompts,
                    text_embeddings
                )

                consistency_loss = F.mse_loss(outputs['anomaly_map'], aug_outputs['anomaly_map'])

                total_loss = bce_loss + pixel_loss + contrastive_loss + eam + alignment + 0.5 * consistency_loss

            if Config.use_fp16:
                scaler.scale(total_loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                total_loss.backward()
                optimizer.step()
            optimizer.zero_grad()

            train_loss += total_loss.item()
            progress_bar.set_postfix({"loss": total_loss.item()})

        avg_train_loss = train_loss / len(train_loader)
        logger.info(f"Epoch {epoch+1}/{Config.num_epochs}, Train Loss: {avg_train_loss:.4f}")
        scheduler.step()

        save_checkpoint(model, optimizer, epoch, avg_train_loss, f"{Config.checkpoint_dir}/{category}_epoch_{epoch+1}.pth")

    return model

def evaluate_model(model, category, test_loader):
    model.eval()
    image_scores = []
    image_labels = []
    pixel_scores = []
    pixel_labels = []
    defect_specific_scores = {}

    with torch.no_grad():
        for batch in tqdm(test_loader, desc=f"Evaluating {category}"):
            images = batch['image'].to(device)
            labels = batch['label'].cpu().numpy()
            masks = batch['mask'].cpu().numpy()
            image_paths = batch['image_path']
            defect_types = batch['defect_type']

            outputs = model(images, category)
            anomaly_scores = outputs['anomaly_score'].cpu().numpy()
            anomaly_maps = outputs['anomaly_map'].cpu().numpy()

            image_scores.extend(anomaly_scores)
            image_labels.extend(labels)

            for idx, defect_type in enumerate(defect_types):
                if defect_type not in defect_specific_scores:
                    defect_specific_scores[defect_type] = {"scores": [], "labels": []}
                defect_specific_scores[defect_type]["scores"].append(anomaly_scores[idx])
                defect_specific_scores[defect_type]["labels"].append(labels[idx])

            for b in range(images.shape[0]):
                if labels[b] == 1 and masks[b].sum() > 0:
                    pixel_scores.append(anomaly_maps[b].flatten())
                    pixel_labels.append(masks[b].flatten().astype(int))
                visualize_results(images[b], anomaly_maps[b], masks[b], anomaly_scores[b], image_paths[b], category, defect_types[b])

    optimal_threshold = compute_optimal_threshold(image_scores, image_labels)
    Config.anomaly_threshold = optimal_threshold

    metrics = compute_metrics(image_scores, image_labels, pixel_scores, pixel_labels)
    metrics['category'] = category
    metrics['optimal_threshold'] = float(optimal_threshold)

    defect_specific_aucs = {}
    for defect_type, data in defect_specific_scores.items():
        if defect_type != "none" and len(set(data["labels"])) > 1:
            try:
                defect_specific_aucs[defect_type] = roc_auc_score(data["labels"], data["scores"])
            except:
                pass

    metrics['defect_specific_aucs'] = {k: float(v) for k, v in defect_specific_aucs.items()}
    with open(os.path.join(Config.vis_save_dir, f"{category}_metrics.json"), 'w') as f:
        json.dump(metrics, f, indent=4)

    return metrics

def save_checkpoint(model, optimizer, epoch, loss, filename):
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'loss': loss
    }, filename)
    logger.info(f"Checkpoint saved to {filename}")

def main():
    start_time = time.time()
    logger.info("Starting improved anomaly detection pipeline...")

    if not torch.cuda.is_available():
        logger.warning("CUDA is not available. Using CPU, which will be much slower!")

    all_metrics = {}
    os.makedirs(Config.vis_save_dir, exist_ok=True)
    os.makedirs(Config.checkpoint_dir, exist_ok=True)

    for category in Config.categories:
        logger.info(f"Processing category: {category}")
        train_transform = transforms.Compose([
            transforms.Resize(Config.target_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        test_transform = transforms.Compose([
            transforms.Resize(Config.target_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        train_dataset = MVTecDataset(Config.dataset_path, category, is_train=True, transform=train_transform)
        test_dataset = MVTecDataset(Config.dataset_path, category, is_train=False, transform=test_transform)

        train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True, num_workers=0, pin_memory=True)
        test_loader = DataLoader(test_dataset, batch_size=Config.batch_size, shuffle=False, num_workers=0, pin_memory=True)

        model = PromptBasedAnomalyDetector(Config.clip_model_name).to(device)
        logger.info(f"Training model for category: {category}")
        model = train_model(model, category, train_loader)
        logger.info(f"Evaluating model for category: {category}")
        metrics = evaluate_model(model, category, test_loader)
        all_metrics[category] = metrics

    overall_metrics = {
        "overall_avg_image_auc": sum(m['image_roc_auc'] for m in all_metrics.values()) / len(all_metrics),
        "overall_avg_pixel_auc": sum(m['pixel_roc_auc'] for m in all_metrics.values()) / len(all_metrics),
        "categories": list(all_metrics.keys()),
        "category_performance": {cat: {"image_auc": m["image_roc_auc"], "pixel_auc": m["pixel_roc_auc"]} for cat, m in all_metrics.items()}
    }

    with open(os.path.join(Config.vis_save_dir, "overall_metrics.json"), 'w') as f:
        json.dump(overall_metrics, f, indent=4)

    elapsed_time = (time.time() - start_time) / 60.0
    logger.info(f"Pipeline completed in {elapsed_time:.2f} minutes.")

if __name__ == "__main__":
    main()

2025-05-12 12:01:18,534 - __main__ - INFO - Using device: cuda
2025-05-12 12:01:18,539 - __main__ - INFO - Starting improved anomaly detection pipeline...
2025-05-12 12:01:18,540 - __main__ - INFO - Processing category: bottle
2025-05-12 12:01:30,294 - __main__ - INFO - Training model for category: bottle
Epoch 1/30:   0%|          | 0/14 [00:00<?, ?it/s]


RuntimeError: size mismatch, got input (16), mat (16x512), vec (768)