<a href="https://www.kaggle.com/code/shashankroy568/mvtec6?scriptVersionId=262855021" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input directory tree and print the full path of every file
# found, so the attached datasets can be located from the notebook output.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# MVTec Anomaly Detection - Fixed Anomalib Installation
# Designed for Kaggle GPU P100 environment
import os
import sys
import warnings
import subprocess
warnings.filterwarnings('ignore')

print("🔧 Starting robust environment setup...")

def install_package(package_name, import_name=None, extra_args=""):
    """Make sure a package is available, installing it with pip if needed.

    Args:
        package_name: Requirement handed to ``pip install`` (a name, a pinned
            version such as ``"anomalib==1.0.1"``, or a VCS URL).
        import_name: Optional module name to probe first. If it imports
            cleanly, pip is not run at all.
        extra_args: Whitespace-separated extra pip flags, e.g. ``"--quiet"``.

    Returns:
        True if the module was already importable or pip exited with status 0;
        False if the module is present but broken, pip reported a failure, or
        pip could not be launched.
    """
    if import_name:
        try:
            __import__(import_name)
        except ImportError:
            # Module is missing: fall through to the pip install below.
            # (Swallowing this error used to skip the installation entirely.)
            pass
        except Exception as e:
            # Module exists but fails while importing; report it as before.
            print(f"❌ Error with {package_name}: {e}")
            return False
        else:
            print(f"✅ {package_name} already available")
            return True

    print(f"Installing {package_name}...")
    # Run pip through the current interpreter so the package is installed
    # into the environment this process is actually using.
    cmd = [
        sys.executable, "-m", "pip", "install",
        *package_name.split(),
        *extra_args.split(),
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    except (OSError, subprocess.SubprocessError) as e:
        print(f"❌ Error with {package_name}: {e}")
        return False

    if result.returncode == 0:
        print(f"✅ {package_name} installed successfully")
        return True

    print(f"⚠️ Warning installing {package_name}: {result.stderr}")
    return False

# Step 1: Install all dependencies step by step
print("📦 Installing dependencies...")

# Core dependencies first.
# Each call passes the importable module name as the second argument so that
# install_package can probe for an existing installation.
install_package("python-dotenv", "dotenv", "--quiet")
install_package("opencv-python", "cv2", "--quiet")
install_package("Pillow", "PIL", "--quiet --upgrade")

# Try different anomalib installation strategies, in order of preference.
# Each entry is (pip requirement, extra pip flags).
print("\n🔧 Installing anomalib with multiple strategies...")
strategies = [
    ("anomalib", "--quiet --no-deps --upgrade"),
    ("anomalib", "--quiet --force-reinstall"),
    ("anomalib==1.0.1", "--quiet"),
    ("git+https://github.com/openvinotoolkit/anomalib.git", "--quiet")
]

# Stays False unless one strategy both installs and imports successfully.
anomalib_installed = False
for package, args in strategies:
    print(f"Trying: pip install {package} {args}")
    # import_name is None here, so the import probe is skipped and pip is
    # always invoked for every strategy that is tried.
    if install_package(package, None, args):
        # Test import after each installation; the first strategy whose
        # install is followed by a working import ends the loop.
        # NOTE(review): the first strategy uses --no-deps, so a successful
        # top-level import may not guarantee the submodules import — confirm.
        try:
            import anomalib
            print(f"✅ Anomalib successfully imported!")
            anomalib_installed = True
            break
        except Exception as e:
            print(f"⚠️ Installation succeeded but import failed: {e}")
            continue

# No strategy worked: only a warning is printed, execution continues.
if not anomalib_installed:
    print("⚠️ Anomalib installation issues detected. Using manual approach...")

# Step 2: Import required libraries with comprehensive fallbacks
print("\n📚 Importing libraries...")

# Standard imports
try:
    import kagglehub
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    from pathlib import Path
    import torch
    from PIL import Image
    import cv2
    from glob import glob
    import json
    from sklearn.metrics import roc_auc_score, roc_curve
    from sklearn.preprocessing import MinMaxScaler
    import torchvision.transforms as transforms
    from torch.utils.data import Dataset, DataLoader
    
    print("✅ Standard libraries imported")
except ImportError as e:
    print(f"❌ Standard import error: {e}")
    sys.exit(1)

# Anomalib imports with multiple fallback options.
# ANOMALIB_VERSION ends up as one of "v1.0+", "v0.7+", "legacy" or "manual";
# anomalib_components maps role names to the classes/functions that were
# importable for the detected API.
ANOMALIB_VERSION = None
anomalib_components = {}

print("🔍 Detecting anomalib configuration...")

# Strategy 1: Try latest anomalib API.
# All four imports must succeed; a single missing name raises ImportError and
# moves on to the next strategy.
# NOTE(review): confirm that `MVTecDataModule` and `PadimLightningModule` are
# really exported by the installed anomalib release — if either name is
# absent, this strategy can never be selected.
try:
    from anomalib import TaskType
    from anomalib.data.image.mvtec import MVTecDataModule
    from anomalib.models.image.padim import Padim, PadimLightningModule
    from anomalib.engine import Engine
    
    ANOMALIB_VERSION = "v1.0+"
    anomalib_components = {
        'datamodule_class': MVTecDataModule,
        'model_class': Padim,
        'engine_class': Engine
    }
    print("✅ Anomalib v1.0+ API detected")
    
except ImportError:
    # Strategy 2: Try older anomalib API
    # NOTE(review): same caveat — verify `PadimLightningModule` exists under
    # `anomalib.models.padim` for the targeted release.
    try:
        from anomalib.data import MVTec
        from anomalib.models.padim import Padim, PadimLightningModule
        from anomalib.utils.callbacks import get_callbacks
        
        ANOMALIB_VERSION = "v0.7+"
        anomalib_components = {
            'datamodule_class': MVTec,
            'model_class': Padim,
            'callbacks_fn': get_callbacks
        }
        print("✅ Anomalib v0.7+ API detected")
        
    except ImportError:
        # Strategy 3: Try even older API.
        # This variant registers a dataset class but no 'datamodule_class'.
        try:
            from anomalib.data.mvtec import MVTecDataset
            from anomalib.models.padim.lightning_model import PadimLightningModule
            
            ANOMALIB_VERSION = "legacy"
            anomalib_components = {
                'dataset_class': MVTecDataset,
                'model_class': PadimLightningModule
            }
            print("✅ Anomalib legacy API detected")
            
        except ImportError:
            # Nothing importable: anomalib_components stays empty.
            print("⚠️ No anomalib API detected - will use manual implementation")
            ANOMALIB_VERSION = "manual"

# Step 3: Manual PaDiM implementation as ultimate fallback
class ManualPaDiM:
    """Simplified PaDiM-style anomaly scorer used when anomalib is unavailable.

    Feature maps of selected backbone layers are captured with forward hooks,
    pooled to a common 28x28 grid and concatenated along the channel axis.
    ``fit`` estimates a single Gaussian (mean and covariance) over all patch
    embeddings of the training images; ``predict`` scores every image by the
    largest squared Mahalanobis distance among its patches.
    """

    def __init__(self, backbone='resnet18', layers=('layer1', 'layer2', 'layer3')):
        """Load the pretrained backbone and attach the feature hooks.

        Args:
            backbone: Name of the backbone network; only ``'resnet18'`` is
                supported.
            layers: Names of the backbone sub-modules whose outputs are used
                as features.

        Raises:
            ValueError: If ``backbone`` is not supported.
        """
        if backbone != 'resnet18':
            # Fail early with a clear message instead of an AttributeError on
            # the missing ``self.backbone`` further down.
            raise ValueError(
                f"Unsupported backbone {backbone!r}; only 'resnet18' is available"
            )

        self.backbone_name = backbone
        # Copy so the instance never shares state with the caller's sequence
        # (or with the default argument).
        self.layer_names = list(layers)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"🧠 Manual PaDiM initialized with {backbone} on {self.device}")

        # Load pretrained backbone
        import torchvision.models as models
        try:
            # Current torchvision API; DEFAULT selects the ImageNet weights.
            self.backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        except AttributeError:
            # Older torchvision releases only know the ``pretrained`` flag.
            self.backbone = models.resnet18(pretrained=True)
        self.backbone.eval()
        self.backbone.to(self.device)

        # Gaussian parameters, filled in by fit().
        self.mean = None
        self.cov = None

        # layer name -> output of the most recent forward pass
        self.feature_extractor = {}
        self._register_hooks()

    def _register_hooks(self):
        """Register forward hooks that store the output of each selected layer."""
        def hook_fn(name):
            def hook(module, input, output):
                self.feature_extractor[name] = output
            return hook

        for name, module in self.backbone.named_modules():
            if name in self.layer_names:
                module.register_forward_hook(hook_fn(name))

    def extract_features(self, images):
        """Run the backbone and return the concatenated, pooled feature maps.

        Args:
            images: Batch tensor already placed on ``self.device``.

        Returns:
            Tensor of shape ``(batch, channels, 28, 28)``, or None when no
            hooked layer produced an output.
        """
        self.feature_extractor.clear()

        with torch.no_grad():
            _ = self.backbone(images)

        # Bring every captured map to the same 28x28 grid so they can be
        # concatenated along the channel axis.
        feature_maps = [
            torch.nn.functional.adaptive_avg_pool2d(
                self.feature_extractor[layer_name], (28, 28)
            )
            for layer_name in self.layer_names
            if layer_name in self.feature_extractor
        ]

        if feature_maps:
            return torch.cat(feature_maps, dim=1)

        print("⚠️ No features extracted")
        return None

    def fit(self, train_loader):
        """Estimate the Gaussian of normal patch embeddings.

        Args:
            train_loader: Iterable yielding ``(images, labels)`` batches; the
                labels are ignored.

        Returns:
            True when statistics were computed, False when no features could
            be extracted.
        """
        print("🔄 Training manual PaDiM...")

        all_features = []

        for batch_idx, (images, _) in enumerate(train_loader):
            if batch_idx % 10 == 0:
                print(f"Processing batch {batch_idx}...")

            images = images.to(self.device)
            features = self.extract_features(images)

            if features is not None:
                # Reshape to (batch_size * height * width, features)
                b, c, h, w = features.shape
                features = features.permute(0, 2, 3, 1).reshape(-1, c)
                all_features.append(features.cpu().numpy())

        if not all_features:
            print("❌ No features extracted during training")
            return False

        all_features = np.vstack(all_features)

        # One global mean/covariance over every patch of every image
        # (patches from all spatial locations are pooled together).
        self.mean = np.mean(all_features, axis=0)
        self.cov = np.cov(all_features, rowvar=False)

        # Add small epsilon to diagonal for numerical stability
        self.cov += np.eye(self.cov.shape[0]) * 1e-6

        print(f"✅ Training completed. Feature shape: {all_features.shape}")
        return True

    def predict(self, test_loader):
        """Score every test image with its maximum patch Mahalanobis distance.

        Args:
            test_loader: Iterable yielding ``(images, labels)`` batches where
                ``labels`` is a tensor.

        Returns:
            Tuple ``(scores, labels)`` of NumPy arrays with one entry per
            image. Batches whose distance computation fails receive the
            neutral score 0.5.

        Raises:
            RuntimeError: If called before a successful ``fit``.
        """
        if getattr(self, "mean", None) is None or getattr(self, "cov", None) is None:
            raise RuntimeError("ManualPaDiM.predict() called before a successful fit()")

        print("🔍 Predicting with manual PaDiM...")

        # The covariance does not change between batches, so invert it once.
        try:
            inv_cov = np.linalg.pinv(self.cov)
        except np.linalg.LinAlgError as e:
            print(f"⚠️ Error in distance calculation: {e}")
            inv_cov = None

        predictions = []
        true_labels = []

        for images, labels in test_loader:
            images = images.to(self.device)
            features = self.extract_features(images)

            true_labels.extend(labels.numpy())

            if features is None or inv_cov is None:
                predictions.extend([0.5] * len(labels))
                continue

            b, c, h, w = features.shape
            try:
                flat = features.permute(0, 2, 3, 1).reshape(-1, c).cpu().numpy()

                # Squared Mahalanobis distance of every patch.
                diff = flat - self.mean
                distances = np.sum(diff @ inv_cov * diff, axis=1)
                distances = distances.reshape(b, h, w)

                # Max pooling to get image-level score
                predictions.extend(np.max(distances, axis=(1, 2)))
            except Exception as e:
                # Best effort: keep the run alive with a neutral score.
                print(f"⚠️ Error in distance calculation: {e}")
                predictions.extend([0.5] * b)

        return np.array(predictions), np.array(true_labels)

# Step 4: Enhanced dataset exploration
def explore_dataset_structure():
    """Download the MVTec dataset with kagglehub and report its layout.

    Prints a depth-limited directory tree, searches the whole download for
    folders named after the known MVTec categories, and summarises the
    sub-folders of the first category found.

    Returns:
        ``(root_path, found_categories)`` where ``found_categories`` maps a
        category name to its folder; ``(root_path, {})`` when no category is
        found; ``(None, {})`` when any step raises.
    """
    print("\n📂 Downloading and exploring MVTec dataset...")

    # Folder names that identify an MVTec AD category.
    known_categories = [
        'bottle', 'cable', 'capsule', 'carpet', 'grid',
        'hazelnut', 'leather', 'metal_nut', 'pill', 'screw',
        'tile', 'toothbrush', 'transistor', 'wood', 'zipper'
    ]

    try:
        download_dir = kagglehub.dataset_download("shashankroy568/mvtec-anomaly-detection")
        print(f"✅ Dataset downloaded to: {download_dir}")

        root_path = Path(download_dir)
        root_prefix = str(root_path)

        # Directory tree: folders down to depth 2, files only above depth 2.
        print(f"\n🌳 Complete directory structure:")
        for current_dir, _subdirs, file_names in os.walk(root_path):
            depth = current_dir.replace(root_prefix, '').count(os.sep)
            if depth >= 3:
                continue
            print(f"{' ' * 2 * depth}📁 {os.path.basename(current_dir)}/")
            if depth >= 2:
                continue
            child_indent = ' ' * 2 * (depth + 1)
            for file_name in file_names[:3]:
                print(f"{child_indent}📄 {file_name}")
            if len(file_names) > 3:
                print(f"{child_indent}... and {len(file_names)-3} more files")

        print(f"\n🔍 Searching for MVTec categories...")

        # A later match for the same name overwrites an earlier one.
        found_categories = {}
        for current_dir, subdirs, _files in os.walk(root_path):
            for subdir in subdirs:
                if subdir not in known_categories:
                    continue
                location = Path(current_dir) / subdir
                found_categories[subdir] = location
                print(f"  ✅ Found {subdir} at: {location}")

        if not found_categories:
            print("❌ No MVTec categories found in expected locations")
            print("📋 Available directories:")
            for entry in root_path.rglob("*"):
                if entry.is_dir():
                    print(f"  📁 {entry.relative_to(root_path)}")
            return root_path, {}

        print(f"\n🎯 Found {len(found_categories)} MVTec categories!")

        # Inspect the first category as a sample of the folder layout.
        first_category = next(iter(found_categories))
        first_path = found_categories[first_category]

        print(f"\n🔬 Examining structure of '{first_category}':")
        for entry in first_path.iterdir():
            if entry.is_dir():
                file_count = sum(1 for _ in entry.rglob("*.*"))
                print(f"  📁 {entry.name}: {file_count} files")

        return root_path, found_categories

    except Exception as e:
        print(f"❌ Dataset exploration error: {e}")
        return None, {}

# Step 5: FIXED manual dataset class with correct label logic
class MVTecManualDataset(Dataset):
    """Folder-based MVTec AD dataset used when anomalib is unavailable.

    Expected layout below the category folder:
        train/good/     -> normal training images (label 0)
        test/good/      -> normal test images     (label 0)
        test/<defect>/  -> anomalous test images  (label 1)
    """

    # Glob patterns (searched recursively) that count as image files.
    IMAGE_PATTERNS = ('*.png', '*.jpg', '*.jpeg', '*.bmp')
    # Number of example files echoed per label while indexing.
    PREVIEW_LIMIT = 3

    def __init__(self, root_path, category, split='train', transform=None):
        """Index the image files of one category/split.

        Args:
            root_path: Dataset root; the category folder is looked up directly
                below it or below a few common wrapper folders.
            category: MVTec category name, e.g. ``'bottle'``.
            split: ``'train'`` for train/good only; any other value loads the
                test split.
            transform: Callable applied to each PIL image. Defaults to a
                224x224 resize, tensor conversion and ImageNet normalisation.
        """
        self.root_path = Path(root_path)
        self.category = category
        self.split = split
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                              std=[0.229, 0.224, 0.225])
        ])

        self.samples = self._load_samples()
        print(f"📊 Loaded {len(self.samples)} {split} samples for {category}")

    def _find_images(self, folder):
        """Return the image paths below ``folder``, pattern by pattern, sorted."""
        found = []
        for pattern in self.IMAGE_PATTERNS:
            # Sorting makes the sample order independent of the filesystem.
            found.extend(sorted(folder.rglob(pattern)))
        return found

    def _add_folder(self, samples, previewed, folder, label, indent):
        """Append all images of ``folder`` to ``samples`` with ``label``.

        ``previewed`` counts, per label, how many example lines were printed
        so far, so only the first ``PREVIEW_LIMIT`` files of each label are
        echoed without rescanning ``samples`` for every file.
        """
        tag = "NORMAL" if label == 0 else "ANOMALY"
        for img_path in self._find_images(folder):
            samples.append((str(img_path), label))
            if previewed[label] < self.PREVIEW_LIMIT:
                previewed[label] += 1
                print(f"{indent}Sample: {img_path.name} -> {tag} (label {label})")

    def _load_samples(self):
        """Build the list of ``(image_path, label)`` pairs for this split.

        Returns:
            List of ``(str, int)`` tuples; empty when the category folder (or
            the expected split folder) cannot be found.
        """
        samples = []

        # Find category path
        possible_paths = [
            self.root_path / self.category,
            self.root_path / "mvtec_anomaly_detection" / self.category,
            self.root_path / "MVTec" / self.category,
            self.root_path / "mvtec" / self.category
        ]

        category_path = next((p for p in possible_paths if p.exists()), None)
        if category_path is None:
            print(f"❌ Category path not found. Tried:")
            for path in possible_paths:
                print(f"   - {path}")
            return samples
        print(f"✅ Found category at: {category_path}")

        # Example lines already printed, per label.
        previewed = {0: 0, 1: 0}

        if self.split == 'train':
            # Training: Only normal samples from train/good
            good_path = category_path / 'train' / 'good'
            if good_path.exists():
                print(f"📁 Loading NORMAL training samples from: {good_path}")
                self._add_folder(samples, previewed, good_path, 0, "   ")
            else:
                print(f"❌ Training good path not found: {good_path}")

        else:  # test split
            # Test: Both normal and anomaly samples
            test_path = category_path / 'test'
            if test_path.exists():
                print(f"📁 Loading test samples from: {test_path}")

                # Normal test samples live in test/good
                good_test_path = test_path / 'good'
                if good_test_path.exists():
                    print(f"   📁 Loading NORMAL test samples from: {good_test_path}")
                    self._add_folder(samples, previewed, good_test_path, 0, "      ")

                # Every other sub-folder of test/ is one defect type
                for defect_dir in sorted(test_path.iterdir()):
                    if defect_dir.is_dir() and defect_dir.name != 'good':
                        print(f"   📁 Loading ANOMALY samples from: {defect_dir.name}")
                        self._add_folder(samples, previewed, defect_dir, 1, "      ")
            else:
                print(f"❌ Test path not found: {test_path}")

        # Count samples by label
        normal_count = sum(1 for _, label in samples if label == 0)
        anomaly_count = sum(1 for _, label in samples if label == 1)

        print(f"   ✅ Found {normal_count} NORMAL samples (label 0)")
        print(f"   ✅ Found {anomaly_count} ANOMALY samples (label 1)")

        return samples

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        If the image cannot be opened or transformed, a warning is printed and
        an all-zero 3x224x224 tensor is returned with the original label so a
        single bad file does not abort the run.
        """
        img_path, label = self.samples[idx]

        try:
            image = Image.open(img_path).convert('RGB')
            if self.transform:
                image = self.transform(image)
            return image, label
        except Exception as e:
            print(f"⚠️ Error loading {img_path}: {e}")
            # Return dummy data
            dummy_image = torch.zeros((3, 224, 224))
            return dummy_image, label

# Step 6: Main training pipeline
def run_anomaly_detection_pipeline():
    """Run the end-to-end PaDiM experiment on one MVTec category.

    Steps: download and explore the dataset, pick a category, build the data
    loaders (an anomalib datamodule when one was detected, otherwise the
    manual dataset), train, evaluate and print a summary. Failures are
    reported on stdout and end the run early; nothing is returned.
    """

    print("🚀 Starting Anomaly Detection Pipeline")
    print("=" * 60)

    # Explore dataset
    root_path, categories = explore_dataset_structure()

    if not root_path or not categories:
        print("❌ Cannot proceed without dataset")
        return

    # Select category: the first preferred one that exists, else any category
    target_categories = ['bottle', 'metal_nut', 'capsule', 'cable']
    available_targets = [cat for cat in target_categories if cat in categories]

    if not available_targets:
        test_category = next(iter(categories))
        print(f"⚠️ No target categories found, using: {test_category}")
    else:
        test_category = available_targets[0]
        print(f"🎯 Using target category: {test_category}")

    category_path = categories[test_category]

    # Create datasets
    print(f"\n📊 Creating datasets...")

    # Defined up front so every later branch sees bound names, including the
    # case where an anomalib API was detected that offers no datamodule.
    use_anomalib = False
    datamodule = None
    train_loader = None
    test_loader = None

    anomalib_detected = bool(ANOMALIB_VERSION) and ANOMALIB_VERSION != "manual"
    datamodule_class = anomalib_components.get('datamodule_class')

    if anomalib_detected:
        print(f"Using anomalib {ANOMALIB_VERSION} API...")

    if anomalib_detected and datamodule_class is not None:
        try:
            datamodule = datamodule_class(
                root=str(category_path.parent),
                category=test_category,
                image_size=(224, 224),
                train_batch_size=8,
                eval_batch_size=8,
                num_workers=2
            )

            # Setup data
            datamodule.setup()
            train_loader = datamodule.train_dataloader()
            test_loader = datamodule.test_dataloader()

            print("✅ Anomalib datamodule created successfully")
            use_anomalib = True
        except Exception as e:
            print(f"⚠️ Anomalib dataset failed ({e}), using manual approach...")
    else:
        reason = (
            "no datamodule available for this anomalib API"
            if anomalib_detected else "Using manual approach"
        )
        print(f"⚠️ Anomalib dataset failed ({reason}), using manual approach...")

    if not use_anomalib:
        # Manual dataset creation
        train_dataset = MVTecManualDataset(root_path, test_category, 'train')
        test_dataset = MVTecManualDataset(root_path, test_category, 'test')

        # Safety checks for empty datasets
        if len(train_dataset) == 0:
            print(f"❌ No training samples found for {test_category}")
            return

        if len(test_dataset) == 0:
            print(f"❌ No test samples found for {test_category}")
            return

        # Check if we have both normal and anomaly samples for test
        test_labels = [label for _, label in test_dataset.samples]
        unique_labels = set(test_labels)

        if len(unique_labels) < 2:
            print(f"⚠️ Warning: Test set only has {unique_labels} labels. AUC calculation may fail.")
            print("This might happen if the dataset structure is different than expected.")

        print(f"✅ Dataset ready: {len(train_dataset)} train, {len(test_dataset)} test samples")
        print(f"✅ Test set has {test_labels.count(0)} normal and {test_labels.count(1)} anomaly samples")

        train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=2)
        test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=2)

    # Train model
    print(f"\n🧠 Training model...")

    try:
        if use_anomalib and 'model_class' in anomalib_components:
            # Use anomalib model
            print("Using anomalib PaDiM model...")
            model = anomalib_components['model_class']()

            # Training logic depends on version
            if ANOMALIB_VERSION == "v1.0+":
                engine = anomalib_components['engine_class'](max_epochs=1)
                engine.fit(model=model, datamodule=datamodule)
                results = engine.test(model=model, datamodule=datamodule)
                print(f"✅ Anomalib training completed: {results}")
            else:
                # Use PyTorch Lightning trainer
                import pytorch_lightning as pl
                trainer = pl.Trainer(max_epochs=1, accelerator='gpu' if torch.cuda.is_available() else 'cpu', devices=1)
                trainer.fit(model, datamodule)
                results = trainer.test(model, datamodule)
                print(f"✅ Anomalib training completed: {results}")

        else:
            # Use manual PaDiM
            print("Using manual PaDiM implementation...")
            model = ManualPaDiM()

            # Train; without fitted statistics there is nothing to evaluate,
            # so do not go on to report a successful run.
            if not model.fit(train_loader):
                print("❌ Manual PaDiM training failed; aborting pipeline")
                return

            # Test
            predictions, true_labels = model.predict(test_loader)

            # Calculate metrics
            if len(predictions) == len(true_labels) and len(predictions) > 0:
                # AUC is only defined when both classes are present
                unique_test_labels = set(true_labels)

                if len(unique_test_labels) >= 2:
                    auc_score = roc_auc_score(true_labels, predictions)
                    print(f"✅ Manual PaDiM AUC Score: {auc_score:.4f}")
                else:
                    print(f"⚠️ Cannot calculate AUC - only one class present: {unique_test_labels}")

                # Show prediction distribution (boolean masks per class)
                normal_scores = predictions[true_labels == 0]
                anomaly_scores = predictions[true_labels == 1]

                print(f"📊 Score Statistics:")
                if len(normal_scores) > 0:
                    print(f"   Normal samples: {len(normal_scores)}, mean score: {normal_scores.mean():.4f}")
                if len(anomaly_scores) > 0:
                    print(f"   Anomaly samples: {len(anomaly_scores)}, mean score: {anomaly_scores.mean():.4f}")

            else:
                print(f"⚠️ Prediction length mismatch: {len(predictions)} vs {len(true_labels)}")

    except Exception as e:
        print(f"❌ Training failed: {e}")
        import traceback
        traceback.print_exc()
        return

    # Create summary
    print(f"\n" + "=" * 60)
    print("📋 PIPELINE SUMMARY")
    print("=" * 60)
    print(f"✅ Dataset: MVTec Anomaly Detection")
    print(f"✅ Category: {test_category}")
    print(f"✅ Model: PaDiM ({'Anomalib' if use_anomalib else 'Manual'})")
    print(f"✅ Framework: {ANOMALIB_VERSION}")
    print(f"✅ Training: Completed")
    print(f"✅ Device: {torch.device('cuda' if torch.cuda.is_available() else 'cpu')}")

    print(f"\n🎉 Anomaly detection pipeline completed successfully!")

# Execute the pipeline when this code is run as the main module; importing it
# from elsewhere does not start a run.
if __name__ == "__main__":
    run_anomaly_detection_pipeline()



In [None]:
# MVTec Anomaly Detection - Multi-Category Training
# Designed for Kaggle GPU P100 environment
import os
import sys
import warnings
import subprocess
warnings.filterwarnings('ignore')

print("🔧 Starting robust environment setup...")

def install_package(package_name, import_name=None, extra_args=""):
    """Make sure a package is available, installing it with pip if needed.

    Args:
        package_name: Requirement handed to ``pip install`` (a name, a pinned
            version such as ``"anomalib==1.0.1"``, or a VCS URL).
        import_name: Optional module name to probe first. If it imports
            cleanly, pip is not run at all.
        extra_args: Whitespace-separated extra pip flags, e.g. ``"--quiet"``.

    Returns:
        True if the module was already importable or pip exited with status 0;
        False if the module is present but broken, pip reported a failure, or
        pip could not be launched.
    """
    if import_name:
        try:
            __import__(import_name)
        except ImportError:
            # Module is missing: fall through to the pip install below.
            # (Swallowing this error used to skip the installation entirely.)
            pass
        except Exception as e:
            # Module exists but fails while importing; report it as before.
            print(f"❌ Error with {package_name}: {e}")
            return False
        else:
            print(f"✅ {package_name} already available")
            return True

    print(f"Installing {package_name}...")
    # Run pip through the current interpreter so the package is installed
    # into the environment this process is actually using.
    cmd = [
        sys.executable, "-m", "pip", "install",
        *package_name.split(),
        *extra_args.split(),
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    except (OSError, subprocess.SubprocessError) as e:
        print(f"❌ Error with {package_name}: {e}")
        return False

    if result.returncode == 0:
        print(f"✅ {package_name} installed successfully")
        return True

    print(f"⚠️ Warning installing {package_name}: {result.stderr}")
    return False

# Step 1: Install all dependencies step by step
print("📦 Installing dependencies...")

# Core dependencies first.
# Each call passes the importable module name as the second argument so that
# install_package can probe for an existing installation.
install_package("python-dotenv", "dotenv", "--quiet")
install_package("opencv-python", "cv2", "--quiet")
install_package("Pillow", "PIL", "--quiet --upgrade")

# Try different anomalib installation strategies, in order of preference.
# Each entry is (pip requirement, extra pip flags).
print("\n🔧 Installing anomalib with multiple strategies...")
strategies = [
    ("anomalib", "--quiet --no-deps --upgrade"),
    ("anomalib", "--quiet --force-reinstall"),
    ("anomalib==1.0.1", "--quiet"),
    ("git+https://github.com/openvinotoolkit/anomalib.git", "--quiet")
]

# Stays False unless one strategy both installs and imports successfully.
anomalib_installed = False
for package, args in strategies:
    print(f"Trying: pip install {package} {args}")
    # import_name is None here, so the import probe is skipped and pip is
    # always invoked for every strategy that is tried.
    if install_package(package, None, args):
        # Test import after each installation; the first strategy whose
        # install is followed by a working import ends the loop.
        # NOTE(review): the first strategy uses --no-deps, so a successful
        # top-level import may not guarantee the submodules import — confirm.
        try:
            import anomalib
            print(f"✅ Anomalib successfully imported!")
            anomalib_installed = True
            break
        except Exception as e:
            print(f"⚠️ Installation succeeded but import failed: {e}")
            continue

# No strategy worked: only a warning is printed, execution continues.
if not anomalib_installed:
    print("⚠️ Anomalib installation issues detected. Using manual approach...")

# Step 2: Import required libraries with comprehensive fallbacks
print("\n📚 Importing libraries...")

# Standard imports
try:
    import kagglehub
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    from pathlib import Path
    import torch
    from PIL import Image
    import cv2
    from glob import glob
    import json
    from sklearn.metrics import roc_auc_score, roc_curve
    from sklearn.preprocessing import MinMaxScaler
    import torchvision.transforms as transforms
    from torch.utils.data import Dataset, DataLoader
    import time
    
    print("✅ Standard libraries imported")
except ImportError as e:
    print(f"❌ Standard import error: {e}")
    sys.exit(1)

# Anomalib imports with multiple fallback options.
# ANOMALIB_VERSION ends up as one of "v1.0+", "v0.7+", "legacy" or "manual";
# anomalib_components maps role names to the classes/functions that were
# importable for the detected API.
ANOMALIB_VERSION = None
anomalib_components = {}

print("🔍 Detecting anomalib configuration...")

# Strategy 1: Try latest anomalib API.
# All four imports must succeed; a single missing name raises ImportError and
# moves on to the next strategy.
# NOTE(review): confirm that `MVTecDataModule` and `PadimLightningModule` are
# really exported by the installed anomalib release — if either name is
# absent, this strategy can never be selected.
try:
    from anomalib import TaskType
    from anomalib.data.image.mvtec import MVTecDataModule
    from anomalib.models.image.padim import Padim, PadimLightningModule
    from anomalib.engine import Engine
    
    ANOMALIB_VERSION = "v1.0+"
    anomalib_components = {
        'datamodule_class': MVTecDataModule,
        'model_class': Padim,
        'engine_class': Engine
    }
    print("✅ Anomalib v1.0+ API detected")
    
except ImportError:
    # Strategy 2: Try older anomalib API
    # NOTE(review): same caveat — verify `PadimLightningModule` exists under
    # `anomalib.models.padim` for the targeted release.
    try:
        from anomalib.data import MVTec
        from anomalib.models.padim import Padim, PadimLightningModule
        from anomalib.utils.callbacks import get_callbacks
        
        ANOMALIB_VERSION = "v0.7+"
        anomalib_components = {
            'datamodule_class': MVTec,
            'model_class': Padim,
            'callbacks_fn': get_callbacks
        }
        print("✅ Anomalib v0.7+ API detected")
        
    except ImportError:
        # Strategy 3: Try even older API.
        # This variant registers a dataset class but no 'datamodule_class'.
        try:
            from anomalib.data.mvtec import MVTecDataset
            from anomalib.models.padim.lightning_model import PadimLightningModule
            
            ANOMALIB_VERSION = "legacy"
            anomalib_components = {
                'dataset_class': MVTecDataset,
                'model_class': PadimLightningModule
            }
            print("✅ Anomalib legacy API detected")
            
        except ImportError:
            # Nothing importable: anomalib_components stays empty.
            print("⚠️ No anomalib API detected - will use manual implementation")
            ANOMALIB_VERSION = "manual"

# Step 3: Manual PaDiM implementation as ultimate fallback
class ManualPaDiM:
    """Simplified PaDiM-style anomaly scorer used when anomalib is unavailable.

    Activations of selected backbone layers are pooled to a common 28x28 grid
    and concatenated along the channel axis. ``fit`` estimates ONE Gaussian
    (mean vector and covariance matrix) over the feature vectors of all
    patches of all training images; ``predict`` scores an image by the largest
    squared Mahalanobis distance among its patches.
    """

    def __init__(self, backbone='resnet18', layers=('layer1', 'layer2', 'layer3')):
        """Load the pretrained backbone and hook the requested layers.

        Args:
            backbone: Backbone name. Only ``'resnet18'`` is supported.
            layers: Names (as reported by ``named_modules``) of the backbone
                modules whose outputs are used as features.

        Raises:
            ValueError: If ``backbone`` is not a supported name.
        """
        if backbone != 'resnet18':
            # Fail fast: without a backbone there is nothing to hook and
            # every later call would break on the missing attribute.
            raise ValueError(
                f"Unsupported backbone {backbone!r}; only 'resnet18' is available"
            )

        self.backbone_name = backbone
        # Private copy so the instance never shares a sequence with the caller
        # (or with the function default).
        self.layer_names = list(layers)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"🧠 Manual PaDiM initialized with {backbone} on {self.device}")

        # Load pretrained backbone
        import torchvision.models as models
        self.backbone = models.resnet18(pretrained=True)
        self.backbone.eval()
        self.backbone.to(self.device)

        # Filled by the forward hooks: layer name -> last output of that layer
        self.feature_extractor = {}
        self._register_hooks()

    def _register_hooks(self):
        """Attach forward hooks that store the output of each selected layer."""
        def make_hook(name):
            def hook(module, inputs, output):
                self.feature_extractor[name] = output
            return hook

        for name, module in self.backbone.named_modules():
            if name in self.layer_names:
                module.register_forward_hook(make_hook(name))

    def extract_features(self, images):
        """Run the backbone on ``images`` and return the concatenated features.

        Each hooked layer output is average-pooled to 28x28 and the results
        are concatenated along the channel dimension (in ``layer_names``
        order).

        Returns:
            The concatenated feature tensor, or ``None`` if no hooked layer
            produced an output.
        """
        self.feature_extractor.clear()

        with torch.no_grad():
            _ = self.backbone(images)

        feature_maps = [
            torch.nn.functional.adaptive_avg_pool2d(
                self.feature_extractor[layer_name], (28, 28)
            )
            for layer_name in self.layer_names
            if layer_name in self.feature_extractor
        ]

        if not feature_maps:
            print("⚠️ No features extracted")
            return None
        return torch.cat(feature_maps, dim=1)

    def fit(self, train_loader):
        """Estimate the feature mean and covariance from the training batches.

        Args:
            train_loader: Iterable yielding ``(images, labels)`` batches;
                the labels are ignored.

        Returns:
            True when statistics were computed, False when no features could
            be extracted (e.g. an empty loader).
        """
        print("🔄 Training manual PaDiM...")

        all_features = []

        for batch_idx, (images, _) in enumerate(train_loader):
            if batch_idx % 10 == 0:
                print(f"Processing batch {batch_idx}...")

            images = images.to(self.device)
            features = self.extract_features(images)

            if features is not None:
                # Reshape to (batch_size * height * width, channels)
                c = features.shape[1]
                features = features.permute(0, 2, 3, 1).reshape(-1, c)
                all_features.append(features.cpu().numpy())

        if not all_features:
            print("❌ No features extracted during training")
            return False

        all_features = np.vstack(all_features)

        # One global mean/covariance over every patch of every image
        # (statistics are shared by all spatial locations).
        self.mean = np.mean(all_features, axis=0)
        # atleast_2d keeps the single-channel case a (1, 1) matrix
        self.cov = np.atleast_2d(np.cov(all_features, rowvar=False))

        # Add small epsilon to diagonal for numerical stability
        self.cov += np.eye(self.cov.shape[0]) * 1e-6

        print(f"✅ Training completed. Feature shape: {all_features.shape}")
        return True

    def predict(self, test_loader):
        """Score every image yielded by ``test_loader``.

        Args:
            test_loader: Iterable yielding ``(images, labels)`` batches.

        Returns:
            ``(scores, labels)`` as NumPy arrays. A score is the maximum
            squared Mahalanobis distance over the image's patches; batches
            that cannot be scored get the placeholder score 0.5.

        Raises:
            RuntimeError: If called before a successful ``fit``.
        """
        if not hasattr(self, 'mean') or not hasattr(self, 'cov'):
            raise RuntimeError("ManualPaDiM.predict() called before a successful fit()")

        print("🔍 Predicting with manual PaDiM...")

        # The inverse covariance does not depend on the batch: compute it once
        # instead of once per batch.
        try:
            inv_cov = np.linalg.pinv(self.cov)
        except np.linalg.LinAlgError as e:
            print(f"⚠️ Error in distance calculation: {e}")
            inv_cov = None

        predictions = []
        true_labels = []

        for images, labels in test_loader:
            images = images.to(self.device)
            features = self.extract_features(images)

            true_labels.extend(labels.numpy())

            if features is None or inv_cov is None:
                predictions.extend([0.5] * len(labels))
                continue

            b, c, h, w = features.shape
            features = features.permute(0, 2, 3, 1).reshape(-1, c).cpu().numpy()

            # Squared Mahalanobis distance of every patch to the fitted Gaussian
            diff = features - self.mean
            try:
                distances = np.sum((diff @ inv_cov) * diff, axis=1)
                distances = distances.reshape(b, h, w)

                # Max pooling to get image-level score
                predictions.extend(np.max(distances, axis=(1, 2)))
            except Exception as e:
                # Best effort: keep predictions aligned with the labels
                print(f"⚠️ Error in distance calculation: {e}")
                predictions.extend([0.5] * b)

        return np.array(predictions), np.array(true_labels)

# Step 4: Enhanced dataset exploration
def explore_dataset_structure():
    """Download the MVTec dataset and locate its category directories.

    The dataset is fetched through ``kagglehub`` and the downloaded tree is
    walked recursively; every directory whose name is one of the 15 MVTec
    category names is recorded (a later match with the same name replaces an
    earlier one).

    Returns:
        ``(root_path, found)`` where ``root_path`` is the download directory
        as a ``Path`` and ``found`` maps category name -> ``Path`` of its
        directory (empty when nothing matched). On any error ``(None, {})``
        is returned instead.
    """
    print("\n📂 Downloading and exploring MVTec dataset...")

    # Names of the MVTec categories we look for (membership test only)
    known_categories = {
        'bottle', 'cable', 'capsule', 'carpet', 'grid',
        'hazelnut', 'leather', 'metal_nut', 'pill', 'screw',
        'tile', 'toothbrush', 'transistor', 'wood', 'zipper',
    }

    try:
        dataset_path = kagglehub.dataset_download("shashankroy568/mvtec-anomaly-detection")
        print(f"✅ Dataset downloaded to: {dataset_path}")

        dataset_root = Path(dataset_path)

        print("\n🔍 Searching for MVTec categories...")

        located = {}
        for parent, subdirs, _ in os.walk(dataset_root):
            for sub in subdirs:
                if sub not in known_categories:
                    continue
                located[sub] = Path(parent) / sub
                print(f"  ✅ Found {sub} at: {located[sub]}")

        if not located:
            print("❌ No MVTec categories found in expected locations")
        else:
            print(f"\n🎯 Found {len(located)} MVTec categories!")
        return dataset_root, located

    except Exception as e:
        print(f"❌ Dataset exploration error: {e}")
        return None, {}

# Step 5: FIXED manual dataset class with correct label logic
class MVTecManualDataset(Dataset):
    """Image-level MVTec dataset that reads the folder layout directly.

    Expected layout below the category directory::

        train/good/...          normal training images  (label 0)
        test/good/...           normal test images      (label 0)
        test/<defect_type>/...  anomalous test images   (label 1)

    ``samples`` holds ``(image_path_str, label)`` tuples in a deterministic
    (sorted) order.
    """

    # Lower-case file suffixes that are treated as images
    IMAGE_SUFFIXES = ('.png', '.jpg', '.jpeg', '.bmp')

    def __init__(self, root_path, category, split='train', transform=None):
        """Index the images of one category/split.

        Args:
            root_path: Directory that contains the category folder, either
                directly or inside one of the known wrapper folders
                (``mvtec_anomaly_detection``, ``MVTec``, ``mvtec``).
            category: Category folder name, e.g. ``'bottle'``.
            split: ``'train'`` for the training images; any other value
                selects the test split.
            transform: Callable applied to the loaded RGB PIL image. Defaults
                to resize to 224x224, tensor conversion and ImageNet
                normalisation.
        """
        self.root_path = Path(root_path)
        self.category = category
        self.split = split
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])
        ])

        self.samples = self._load_samples()
        print(f"📊 {category}: Loaded {len(self.samples)} {split} samples")

    def _find_category_path(self):
        """Return the first existing candidate directory for the category, or None."""
        candidates = (
            self.root_path / self.category,
            self.root_path / "mvtec_anomaly_detection" / self.category,
            self.root_path / "MVTec" / self.category,
            self.root_path / "mvtec" / self.category,
        )
        return next((path for path in candidates if path.exists()), None)

    def _list_images(self, directory):
        """Return the image files below ``directory`` (recursive), sorted by path.

        Suffix matching is case-insensitive; a missing directory yields an
        empty list.
        """
        if not directory.is_dir():
            return []
        return sorted(
            path for path in directory.rglob('*')
            if path.suffix.lower() in self.IMAGE_SUFFIXES and path.is_file()
        )

    def _load_samples(self):
        """Build the ``(path, label)`` list for the configured split.

        Label 0 = normal (``good`` folders), label 1 = anomaly (every other
        folder directly below ``test``).
        """
        category_path = self._find_category_path()
        if category_path is None:
            print(f"❌ Category path not found for {self.category}")
            return []

        samples = []
        if self.split == 'train':
            # Training: only normal samples from train/good
            good_dir = category_path / 'train' / 'good'
            samples.extend((str(p), 0) for p in self._list_images(good_dir))
        else:
            # Test: normal samples from test/good, anomalies from the
            # remaining test/<defect_type> folders
            test_dir = category_path / 'test'
            if test_dir.is_dir():
                samples.extend(
                    (str(p), 0) for p in self._list_images(test_dir / 'good')
                )
                # sorted() keeps the defect folders in a reproducible order
                for defect_dir in sorted(test_dir.iterdir()):
                    if defect_dir.is_dir() and defect_dir.name != 'good':
                        samples.extend(
                            (str(p), 1) for p in self._list_images(defect_dir)
                        )

        normal_count = sum(1 for _, label in samples if label == 0)
        anomaly_count = len(samples) - normal_count

        print(f"   ✅ {self.category} {self.split}: {normal_count} normal, {anomaly_count} anomaly")

        return samples

    def __len__(self):
        """Number of indexed samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        If the image cannot be loaded or transformed, a warning is printed
        and an all-zero ``(3, 224, 224)`` tensor is returned with the real
        label so that a single bad file does not abort the run.
        """
        img_path, label = self.samples[idx]

        try:
            # The context manager closes the underlying file promptly
            with Image.open(img_path) as raw:
                image = raw.convert('RGB')
            if self.transform:
                image = self.transform(image)
            return image, label
        except Exception as e:
            print(f"⚠️ Error loading {img_path}: {e}")
            # Return dummy data
            return torch.zeros((3, 224, 224)), label

# Step 6: Single category training function
def train_single_category(root_path, category, results_dict):
    """Fit and evaluate the manual PaDiM model on a single category.

    The outcome is stored in ``results_dict[category]``: on success a dict
    with ``'status': 'success'``, the AUC, sample counts, mean scores and the
    elapsed time; otherwise ``{'status': 'failed', 'reason': ...}``. The
    function returns nothing. Any exception raised while loading, fitting or
    scoring is caught and recorded as the failure reason.
    """

    def record_failure(reason):
        results_dict[category] = {'status': 'failed', 'reason': reason}

    print("\n" + "=" * 50)
    print(f"🚀 TRAINING CATEGORY: {category.upper()}")
    print("=" * 50)

    started_at = time.time()

    try:
        # Build both splits up front so missing data is detected early
        train_dataset = MVTecManualDataset(root_path, category, 'train')
        test_dataset = MVTecManualDataset(root_path, category, 'test')

        if len(train_dataset) == 0:
            print(f"❌ No training samples found for {category}")
            record_failure('no_train_data')
            return

        if len(test_dataset) == 0:
            print(f"❌ No test samples found for {category}")
            record_failure('no_test_data')
            return

        # AUC needs both classes in the test split
        unique_labels = {sample[1] for sample in test_dataset.samples}
        if len(unique_labels) < 2:
            print(f"⚠️ Warning: {category} test set only has {unique_labels} labels")
            record_failure('insufficient_labels')
            return

        train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=2)
        test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=2)

        print(f"✅ {category}: Dataset ready - {len(train_dataset)} train, {len(test_dataset)} test")

        print(f"🧠 Training PaDiM for {category}...")
        model = ManualPaDiM()

        if not model.fit(train_loader):
            record_failure('training_failed')
            return

        scores, labels = model.predict(test_loader)

        if len(scores) != len(labels) or len(scores) == 0:
            record_failure('prediction_mismatch')
            return

        if len(set(labels)) < 2:
            record_failure('insufficient_test_labels')
            return

        auc_score = roc_auc_score(labels, scores)

        # Split the scores by ground-truth class for the summary statistics
        is_normal = labels == 0
        is_anomaly = labels == 1

        if is_normal.any():
            normal_scores = scores[is_normal]
        else:
            normal_scores = np.array([])

        if is_anomaly.any():
            anomaly_scores = scores[is_anomaly]
        else:
            anomaly_scores = np.array([])

        training_time = time.time() - started_at

        results_dict[category] = {
            'status': 'success',
            'auc_score': auc_score,
            'train_samples': len(train_dataset),
            'test_samples': len(test_dataset),
            'normal_test_samples': len(normal_scores),
            'anomaly_test_samples': len(anomaly_scores),
            'normal_mean_score': float(normal_scores.mean()) if len(normal_scores) > 0 else 0,
            'anomaly_mean_score': float(anomaly_scores.mean()) if len(anomaly_scores) > 0 else 0,
            'training_time': training_time,
        }

        print(f"✅ {category}: AUC Score = {auc_score:.4f}")
        print(f"📊 {category}: Normal scores = {normal_scores.mean():.2f}, Anomaly scores = {anomaly_scores.mean():.2f}")
        print(f"⏱️  {category}: Training time = {training_time:.1f}s")

    except Exception as e:
        print(f"❌ {category}: Training failed - {e}")
        record_failure(str(e))

# Step 7: Multi-category training pipeline
def run_multi_category_pipeline():
    """Train and evaluate the manual PaDiM model on every target category.

    Steps: locate the dataset via ``explore_dataset_structure``, run
    ``train_single_category`` for each target category that was found, print
    a summary table and, when at least one category succeeded, export the
    per-category metrics to ``mvtec_padim_results.csv``.

    Returns:
        The dict of per-category results (category name -> result dict), or
        ``None`` when the dataset or all target categories are missing.
    """

    print("🚀 Starting Multi-Category Anomaly Detection Pipeline")
    print("=" * 70)

    # Explore dataset
    root_path, categories = explore_dataset_structure()

    if not root_path or not categories:
        print("❌ Cannot proceed without dataset")
        return None

    # Target categories for warehouse research
    target_categories = ['bottle', 'metal_nut', 'capsule', 'cable']
    available_targets = [cat for cat in target_categories if cat in categories]

    if not available_targets:
        print("❌ No target categories found in dataset")
        return None

    print(f"\n🎯 Will train on {len(available_targets)} categories: {available_targets}")

    # Train each category
    results = {}
    total_start_time = time.time()

    for i, category in enumerate(available_targets, 1):
        print(f"\n🔄 Progress: {i}/{len(available_targets)} categories")
        # Hand over the directory in which the category was actually found.
        # The exploration step searches recursively, so the category may live
        # deeper than the dataset root; passing the root would make the
        # dataset class miss it and report zero samples.
        category_root = Path(categories[category]).parent
        train_single_category(category_root, category, results)

    total_time = time.time() - total_start_time

    # Generate comprehensive results summary
    print("\n" + "=" * 70)
    print("📋 MULTI-CATEGORY TRAINING RESULTS")
    print("=" * 70)

    successful_trainings = [cat for cat, res in results.items() if res.get('status') == 'success']
    failed_trainings = [cat for cat, res in results.items() if res.get('status') == 'failed']

    # Always bound, so later use cannot hit an undefined name
    avg_auc = None
    if successful_trainings:
        avg_auc = np.mean([results[cat]['auc_score'] for cat in successful_trainings])

    print(f"✅ Successful: {len(successful_trainings)}/{len(available_targets)} categories")
    print(f"❌ Failed: {len(failed_trainings)}/{len(available_targets)} categories")
    print(f"⏱️  Total time: {total_time:.1f}s")

    # Detailed results table
    if successful_trainings:
        print("\n📊 DETAILED RESULTS:")
        print("-" * 90)
        print(f"{'Category':<12} {'AUC Score':<10} {'Train':<7} {'Test':<6} {'Normal':<8} {'Anomaly':<8} {'Time':<6}")
        print("-" * 90)

        for category in successful_trainings:
            res = results[category]
            print(f"{category:<12} {res['auc_score']:<10.4f} {res['train_samples']:<7} {res['test_samples']:<6} "
                  f"{res['normal_test_samples']:<8} {res['anomaly_test_samples']:<8} {res['training_time']:<6.1f}s")

        print("-" * 90)
        print(f"{'AVERAGE':<12} {avg_auc:<10.4f}")
        print("-" * 90)

    # Failed categories details
    if failed_trainings:
        print("\n❌ FAILED CATEGORIES:")
        for category in failed_trainings:
            reason = results[category].get('reason', 'unknown')
            print(f"   {category}: {reason}")

    # Research summary
    print("\n🎓 RESEARCH SUMMARY:")
    print("   Dataset: MVTec Anomaly Detection")
    print("   Model: PaDiM (Manual Implementation)")
    print(f"   Categories: {len(successful_trainings)} warehouse-relevant products")
    if avg_auc is not None:
        print(f"   Performance: {avg_auc:.4f} average AUC score")
    else:
        print("   Performance: No successful trainings")
    print(f"   Device: {torch.device('cuda' if torch.cuda.is_available() else 'cpu')}")

    # Export results to CSV for research paper
    if successful_trainings:
        metric_columns = [
            'auc_score', 'train_samples', 'test_samples',
            'normal_test_samples', 'anomaly_test_samples',
            'normal_mean_score', 'anomaly_mean_score', 'training_time',
        ]
        results_df = pd.DataFrame([
            {'category': category,
             **{column: results[category][column] for column in metric_columns}}
            for category in successful_trainings
        ])

        results_df.to_csv('mvtec_padim_results.csv', index=False)
        print("\n💾 Results exported to: mvtec_padim_results.csv")

    print("\n🎉 Multi-category training pipeline completed!")
    return results

# Execute the multi-category pipeline.
# Runs only when this module is the entry point; the value returned by the
# pipeline is kept in the module-level name `results` for later inspection.
if __name__ == "__main__":
    results = run_multi_category_pipeline()


In [None]:
# MVTec Anomaly Detection - Multi-Category Training with PaDiM
# Designed for Kaggle GPU P100 environment
import os
import sys
import warnings
import subprocess
warnings.filterwarnings('ignore')

print("🔧 Starting robust environment setup...")

def install_package(package_name, import_name=None, extra_args=""):
    """Make sure a package is available, installing it with pip if needed.

    Args:
        package_name: Requirement handed to ``pip install`` (name, pinned
            version or VCS URL).
        import_name: Optional module name to try importing first. When the
            import succeeds nothing is installed.
        extra_args: Additional pip command-line flags as one
            whitespace-separated string.

    Returns:
        True if the module was already importable or pip exited with status
        0; False if pip failed or could not be started.
    """
    if import_name:
        try:
            __import__(import_name)
        except Exception:
            # Missing (or broken) module: fall through to the pip install
            # below instead of giving up without installing anything.
            pass
        else:
            print(f"✅ {package_name} already available")
            return True

    print(f"Installing {package_name}...")
    # Invoke pip through the running interpreter so the package is installed
    # into the environment this process actually imports from.
    cmd = [
        sys.executable, "-m", "pip", "install",
        *package_name.split(),
        *extra_args.split(),
    ]

    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    except (OSError, subprocess.SubprocessError) as e:
        print(f"❌ Error with {package_name}: {e}")
        return False

    if result.returncode == 0:
        print(f"✅ {package_name} installed successfully")
        return True

    print(f"⚠️ Warning installing {package_name}: {result.stderr}")
    return False

# Step 1: Install all dependencies step by step
import importlib

print("📦 Installing dependencies...")
# Core dependencies first (skipped when the module can already be imported)
install_package("python-dotenv", "dotenv", "--quiet")
install_package("opencv-python", "cv2", "--quiet")
install_package("Pillow", "PIL", "--quiet --upgrade")

# Try different anomalib installation strategies, stopping at the first one
# after which `import anomalib` works.
print("\n🔧 Installing anomalib with multiple strategies...")
strategies = [
    ("anomalib", "--quiet --no-deps --upgrade"),
    ("anomalib", "--quiet --force-reinstall"),
    ("anomalib==1.0.1", "--quiet"),
    ("git+https://github.com/openvinotoolkit/anomalib.git", "--quiet")
]

anomalib_installed = False
for package, args in strategies:
    print(f"Trying: pip install {package} {args}")
    if not install_package(package, None, args):
        continue

    # Packages installed after interpreter start-up can be hidden by the
    # import system's cached directory listings; refresh them before probing.
    importlib.invalidate_caches()

    # Test import after each installation
    try:
        import anomalib
    except Exception as e:
        print(f"⚠️ Installation succeeded but import failed: {e}")
    else:
        print("✅ Anomalib successfully imported!")
        anomalib_installed = True
        break

if not anomalib_installed:
    print("⚠️ Anomalib installation issues detected. Using manual approach...")

# Step 2: Import required libraries with comprehensive fallbacks
print("\n📚 Importing libraries...")
# Standard imports
try:
    import kagglehub
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    from pathlib import Path
    import torch
    from PIL import Image
    import cv2
    from glob import glob
    import json
    from sklearn.metrics import roc_auc_score, roc_curve, precision_score, recall_score, f1_score, confusion_matrix
    from sklearn.preprocessing import MinMaxScaler
    import torchvision.transforms as transforms
    from torch.utils.data import Dataset, DataLoader
    import time
    
    print("✅ Standard libraries imported")
except ImportError as e:
    print(f"❌ Standard import error: {e}")
    sys.exit(1)

# Anomalib imports with multiple fallback options.
# Each known import layout is tried in turn, newest first. The first layout
# that imports cleanly sets ANOMALIB_VERSION and fills anomalib_components;
# when none of them works ANOMALIB_VERSION ends up as "manual".
ANOMALIB_VERSION = None
anomalib_components = {}
print("🔍 Detecting anomalib configuration...")

# Strategy 1: latest anomalib API
try:
    from anomalib import TaskType
    from anomalib.data.image.mvtec import MVTecDataModule
    from anomalib.models.image.padim import Padim, PadimLightningModule
    from anomalib.engine import Engine
except ImportError:
    pass
else:
    ANOMALIB_VERSION = "v1.0+"
    anomalib_components = {
        'datamodule_class': MVTecDataModule,
        'model_class': Padim,
        'engine_class': Engine,
    }
    print("✅ Anomalib v1.0+ API detected")

# Strategy 2: older anomalib API (only attempted if strategy 1 failed)
if ANOMALIB_VERSION is None:
    try:
        from anomalib.data import MVTec
        from anomalib.models.padim import Padim, PadimLightningModule
        from anomalib.utils.callbacks import get_callbacks
    except ImportError:
        pass
    else:
        ANOMALIB_VERSION = "v0.7+"
        anomalib_components = {
            'datamodule_class': MVTec,
            'model_class': Padim,
            'callbacks_fn': get_callbacks,
        }
        print("✅ Anomalib v0.7+ API detected")

# Strategy 3: even older API (only attempted if strategies 1 and 2 failed)
if ANOMALIB_VERSION is None:
    try:
        from anomalib.data.mvtec import MVTecDataset
        from anomalib.models.padim.lightning_model import PadimLightningModule
    except ImportError:
        pass
    else:
        ANOMALIB_VERSION = "legacy"
        anomalib_components = {
            'dataset_class': MVTecDataset,
            'model_class': PadimLightningModule,
        }
        print("✅ Anomalib legacy API detected")

# Nothing importable: downstream code relies on the manual implementation
if ANOMALIB_VERSION is None:
    print("⚠️ No anomalib API detected - will use manual implementation")
    ANOMALIB_VERSION = "manual"

# Helper functions for enhanced metrics calculation
def calculate_optimal_threshold(y_true, y_scores):
    """Return the score threshold that maximises Youden's J statistic.

    J = TPR - FPR is evaluated at every threshold produced by ``roc_curve``;
    the threshold with the largest J is returned (the first one on ties,
    following ``np.argmax``).
    """
    false_pos_rate, true_pos_rate, candidates = roc_curve(y_true, y_scores)
    youden_j = true_pos_rate - false_pos_rate
    return candidates[np.argmax(youden_j)]

def calculate_comprehensive_metrics(y_true, y_scores, threshold=None):
    """Compute thresholded classification metrics for binary anomaly scores.

    Args:
        y_true: Ground-truth labels (0 = normal, 1 = anomaly).
        y_scores: Anomaly scores, one per label; higher means more anomalous.
            Any array-like is accepted.
        threshold: Scores ``>= threshold`` are predicted as anomalies. When
            omitted, ``calculate_optimal_threshold`` selects it.

    Returns:
        Dict with the keys ``threshold``, ``precision``, ``recall``,
        ``f1_score``, ``specificity``, ``accuracy``, ``true_positives``,
        ``true_negatives``, ``false_positives`` and ``false_negatives``.
        Ratios with an empty denominator are reported as 0.
    """
    # Accept plain lists as well as arrays (the comparison below needs an array)
    y_true = np.asarray(y_true)
    y_scores = np.asarray(y_scores)

    if threshold is None:
        threshold = calculate_optimal_threshold(y_true, y_scores)

    # Convert scores to binary predictions
    y_pred = (y_scores >= threshold).astype(int)

    # Calculate metrics
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)

    # Pin the label set so the matrix is always 2x2; without it a single-class
    # input yields a 1x1 matrix and the four-way unpack fails.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    total = tp + tn + fp + fn
    accuracy = (tp + tn) / total if total > 0 else 0

    return {
        'threshold': threshold,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'specificity': specificity,
        'accuracy': accuracy,
        'true_positives': int(tp),
        'true_negatives': int(tn),
        'false_positives': int(fp),
        'false_negatives': int(fn)
    }

# Step 3: Manual PaDiM implementation as ultimate fallback
class ManualPaDiM:
    """Simplified PaDiM-style anomaly scorer used when anomalib is unavailable.

    Activations of selected backbone layers are pooled to a common 28x28 grid
    and concatenated along the channel axis. ``fit`` estimates ONE Gaussian
    (mean vector and covariance matrix) over the feature vectors of all
    patches of all training images; ``predict`` scores an image by the largest
    squared Mahalanobis distance among its patches.
    """

    def __init__(self, backbone='resnet18', layers=('layer1', 'layer2', 'layer3')):
        """Load the pretrained backbone and hook the requested layers.

        Args:
            backbone: Backbone name. Only ``'resnet18'`` is supported.
            layers: Names (as reported by ``named_modules``) of the backbone
                modules whose outputs are used as features.

        Raises:
            ValueError: If ``backbone`` is not a supported name.
        """
        if backbone != 'resnet18':
            # Fail fast: without a backbone there is nothing to hook and
            # every later call would break on the missing attribute.
            raise ValueError(
                f"Unsupported backbone {backbone!r}; only 'resnet18' is available"
            )

        self.backbone_name = backbone
        # Private copy so the instance never shares a sequence with the caller
        # (or with the function default).
        self.layer_names = list(layers)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"🧠 Manual PaDiM initialized with {backbone} on {self.device}")

        # Load pretrained backbone
        import torchvision.models as models
        self.backbone = models.resnet18(pretrained=True)
        self.backbone.eval()
        self.backbone.to(self.device)

        # Filled by the forward hooks: layer name -> last output of that layer
        self.feature_extractor = {}
        self._register_hooks()

    def _register_hooks(self):
        """Attach forward hooks that store the output of each selected layer."""
        def make_hook(name):
            def hook(module, inputs, output):
                self.feature_extractor[name] = output
            return hook

        for name, module in self.backbone.named_modules():
            if name in self.layer_names:
                module.register_forward_hook(make_hook(name))

    def extract_features(self, images):
        """Run the backbone on ``images`` and return the concatenated features.

        Each hooked layer output is average-pooled to 28x28 and the results
        are concatenated along the channel dimension (in ``layer_names``
        order).

        Returns:
            The concatenated feature tensor, or ``None`` if no hooked layer
            produced an output.
        """
        self.feature_extractor.clear()

        with torch.no_grad():
            _ = self.backbone(images)

        feature_maps = [
            torch.nn.functional.adaptive_avg_pool2d(
                self.feature_extractor[layer_name], (28, 28)
            )
            for layer_name in self.layer_names
            if layer_name in self.feature_extractor
        ]

        if not feature_maps:
            print("⚠️ No features extracted")
            return None
        return torch.cat(feature_maps, dim=1)

    def fit(self, train_loader):
        """Estimate the feature mean and covariance from the training batches.

        Args:
            train_loader: Iterable yielding ``(images, labels)`` batches;
                the labels are ignored.

        Returns:
            True when statistics were computed, False when no features could
            be extracted (e.g. an empty loader).
        """
        print("🔄 Training manual PaDiM...")

        all_features = []

        for batch_idx, (images, _) in enumerate(train_loader):
            if batch_idx % 10 == 0:
                print(f"Processing batch {batch_idx}...")

            images = images.to(self.device)
            features = self.extract_features(images)

            if features is not None:
                # Reshape to (batch_size * height * width, channels)
                c = features.shape[1]
                features = features.permute(0, 2, 3, 1).reshape(-1, c)
                all_features.append(features.cpu().numpy())

        if not all_features:
            print("❌ No features extracted during training")
            return False

        all_features = np.vstack(all_features)

        # One global mean/covariance over every patch of every image
        # (statistics are shared by all spatial locations).
        self.mean = np.mean(all_features, axis=0)
        # atleast_2d keeps the single-channel case a (1, 1) matrix
        self.cov = np.atleast_2d(np.cov(all_features, rowvar=False))

        # Add small epsilon to diagonal for numerical stability
        self.cov += np.eye(self.cov.shape[0]) * 1e-6

        print(f"✅ Training completed. Feature shape: {all_features.shape}")
        return True

    def predict(self, test_loader):
        """Score every image yielded by ``test_loader``.

        Args:
            test_loader: Iterable yielding ``(images, labels)`` batches.

        Returns:
            ``(scores, labels)`` as NumPy arrays. A score is the maximum
            squared Mahalanobis distance over the image's patches; batches
            that cannot be scored get the placeholder score 0.5.

        Raises:
            RuntimeError: If called before a successful ``fit``.
        """
        if not hasattr(self, 'mean') or not hasattr(self, 'cov'):
            raise RuntimeError("ManualPaDiM.predict() called before a successful fit()")

        print("🔍 Predicting with manual PaDiM...")

        # The inverse covariance does not depend on the batch: compute it once
        # instead of once per batch.
        try:
            inv_cov = np.linalg.pinv(self.cov)
        except np.linalg.LinAlgError as e:
            print(f"⚠️ Error in distance calculation: {e}")
            inv_cov = None

        predictions = []
        true_labels = []

        for images, labels in test_loader:
            images = images.to(self.device)
            features = self.extract_features(images)

            true_labels.extend(labels.numpy())

            if features is None or inv_cov is None:
                predictions.extend([0.5] * len(labels))
                continue

            b, c, h, w = features.shape
            features = features.permute(0, 2, 3, 1).reshape(-1, c).cpu().numpy()

            # Squared Mahalanobis distance of every patch to the fitted Gaussian
            diff = features - self.mean
            try:
                distances = np.sum((diff @ inv_cov) * diff, axis=1)
                distances = distances.reshape(b, h, w)

                # Max pooling to get image-level score
                predictions.extend(np.max(distances, axis=(1, 2)))
            except Exception as e:
                # Best effort: keep predictions aligned with the labels
                print(f"⚠️ Error in distance calculation: {e}")
                predictions.extend([0.5] * b)

        return np.array(predictions), np.array(true_labels)

# Step 4: Enhanced dataset exploration
def explore_dataset_structure():
    """Download the MVTec dataset and locate its category directories.

    The dataset is fetched through ``kagglehub`` and the downloaded tree is
    walked recursively; every directory whose name is one of the 15 MVTec
    category names is recorded (a later match with the same name replaces an
    earlier one).

    Returns:
        ``(root_path, found)`` where ``root_path`` is the download directory
        as a ``Path`` and ``found`` maps category name -> ``Path`` of its
        directory (empty when nothing matched). On any error ``(None, {})``
        is returned instead.
    """
    print("\n📂 Downloading and exploring MVTec dataset...")

    # Names of the MVTec categories we look for (membership test only)
    known_categories = {
        'bottle', 'cable', 'capsule', 'carpet', 'grid',
        'hazelnut', 'leather', 'metal_nut', 'pill', 'screw',
        'tile', 'toothbrush', 'transistor', 'wood', 'zipper',
    }

    try:
        dataset_path = kagglehub.dataset_download("shashankroy568/mvtec-anomaly-detection")
        print(f"✅ Dataset downloaded to: {dataset_path}")

        dataset_root = Path(dataset_path)

        print("\n🔍 Searching for MVTec categories...")

        located = {}
        for parent, subdirs, _ in os.walk(dataset_root):
            for sub in subdirs:
                if sub not in known_categories:
                    continue
                located[sub] = Path(parent) / sub
                print(f"  ✅ Found {sub} at: {located[sub]}")

        if not located:
            print("❌ No MVTec categories found in expected locations")
        else:
            print(f"\n🎯 Found {len(located)} MVTec categories!")
        return dataset_root, located

    except Exception as e:
        print(f"❌ Dataset exploration error: {e}")
        return None, {}

# Step 5: FIXED manual dataset class with correct label logic
class MVTecManualDataset(Dataset):
    """Image-level MVTec dataset read directly from the on-disk folder layout.

    Each item is ``(image, label)`` with label ``0`` for normal images and
    ``1`` for anomalous ones:

    * ``split == 'train'``: images under ``<category>/train/good`` (label 0).
    * any other ``split``: treated as the test split; images under
      ``<category>/test/good`` get label 0, and images under every other
      sub-directory of ``<category>/test`` get label 1.

    Samples are gathered in sorted path order, so the dataset order does not
    depend on filesystem enumeration order.

    Args:
        root_path: Directory that contains the category folder, either
            directly or inside ``mvtec_anomaly_detection``, ``MVTec`` or
            ``mvtec``.
        category: Category folder name, e.g. ``'bottle'``.
        split: ``'train'`` or ``'test'``.
        transform: Callable applied to the RGB PIL image. When omitted, a
            224x224 resize + tensor conversion + normalisation pipeline is
            used.
    """

    # Lower-case suffixes accepted as images; matching ignores case so files
    # such as ``001.PNG`` are not silently dropped.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')

    def __init__(self, root_path, category, split='train', transform=None):
        self.root_path = Path(root_path)
        self.category = category
        self.split = split
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        self.samples = self._load_samples()
        print(f"📊 {category}: Loaded {len(self.samples)} {split} samples")

    def _find_category_path(self):
        """Return the first existing candidate directory for the category, or None."""
        candidates = (
            self.root_path / self.category,
            self.root_path / "mvtec_anomaly_detection" / self.category,
            self.root_path / "MVTec" / self.category,
            self.root_path / "mvtec" / self.category,
        )
        for candidate in candidates:
            if candidate.exists():
                return candidate
        return None

    def _collect_images(self, directory, label):
        """Return sorted ``(path_str, label)`` pairs for every image file below ``directory``."""
        return [
            (str(path), label)
            for path in sorted(directory.rglob('*'))
            if path.is_file() and path.suffix.lower() in self.IMAGE_EXTENSIONS
        ]

    def _load_samples(self):
        """Build the list of ``(image_path, label)`` pairs for the configured split.

        Returns an empty list (after printing a message) when the category
        directory cannot be found under ``root_path``.
        """
        samples = []

        category_path = self._find_category_path()
        if category_path is None:
            print(f"❌ Category path not found for {self.category}")
            return samples

        if self.split == 'train':
            # Training: only normal samples from train/good
            good_path = category_path / 'train' / 'good'
            if good_path.is_dir():
                samples.extend(self._collect_images(good_path, 0))
        else:
            # Test: normal samples from test/good plus one folder per defect type
            test_path = category_path / 'test'
            if test_path.is_dir():
                good_test_path = test_path / 'good'
                if good_test_path.is_dir():
                    samples.extend(self._collect_images(good_test_path, 0))

                # Sorted so defect folders are always visited in the same order
                for defect_dir in sorted(test_path.iterdir()):
                    if defect_dir.is_dir() and defect_dir.name != 'good':
                        samples.extend(self._collect_images(defect_dir, 1))

        normal_count = sum(1 for _, label in samples if label == 0)
        anomaly_count = sum(1 for _, label in samples if label == 1)

        print(f"   ✅ {self.category} {self.split}: {normal_count} normal, {anomaly_count} anomaly")

        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for sample ``idx``.

        Loading is best-effort: if the image cannot be opened or transformed,
        the error is printed and an all-zero ``(3, 224, 224)`` tensor is
        returned with the sample's label so a single bad file does not abort
        a whole run.
        """
        img_path, label = self.samples[idx]

        try:
            # The context manager releases the file handle as soon as the
            # pixel data has been copied into the converted image.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
            image = self.transform(image)
            return image, label
        except Exception as e:
            print(f"⚠️ Error loading {img_path}: {e}")
            # Return dummy data
            dummy_image = torch.zeros((3, 224, 224))
            return dummy_image, label

# Step 6: Enhanced single category training function with comprehensive metrics
def train_single_category(root_path, category, results_dict):
    """Fit and evaluate the manual PaDiM model on one category.

    Builds the train/test ``MVTecManualDataset`` objects, fits ``ManualPaDiM``
    on the training loader, scores the test loader and computes AUC plus the
    thresholded metrics from ``calculate_comprehensive_metrics``.

    Args:
        root_path: Directory handed to ``MVTecManualDataset`` as its root.
        category: Category name to train on.
        results_dict: Mutated in place; ``results_dict[category]`` receives
            either a ``{'status': 'success', ...}`` record with all metrics
            or ``{'status': 'failed', 'reason': ...}``.

    Returns:
        None. Every outcome, including any exception, is reported through
        ``results_dict``.
    """
    banner = "=" * 50
    print("\n" + banner)
    print(f"🚀 TRAINING PADIM: {category.upper()}")
    print(banner)

    started_at = time.time()

    def fail(reason):
        # Single place that records a failed outcome for this category.
        results_dict[category] = {'status': 'failed', 'reason': reason}

    # Metrics copied verbatim from calculate_comprehensive_metrics' result.
    metric_keys = (
        'precision', 'recall', 'f1_score', 'specificity', 'accuracy',
        'threshold', 'true_positives', 'true_negatives',
        'false_positives', 'false_negatives',
    )

    try:
        train_set = MVTecManualDataset(root_path, category, 'train')
        test_set = MVTecManualDataset(root_path, category, 'test')

        # Bail out early when there is nothing to train or evaluate on.
        if len(train_set) == 0:
            print(f"❌ No training samples found for {category}")
            fail('no_train_data')
            return

        if len(test_set) == 0:
            print(f"❌ No test samples found for {category}")
            fail('no_test_data')
            return

        # AUC needs both classes present in the test split.
        label_values = {sample[1] for sample in test_set.samples}
        if len(label_values) < 2:
            print(f"⚠️ Warning: {category} test set only has {label_values} labels")
            fail('insufficient_labels')
            return

        train_loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=2)
        test_loader = DataLoader(test_set, batch_size=8, shuffle=False, num_workers=2)

        print(f"✅ {category}: Dataset ready - {len(train_set)} train, {len(test_set)} test")

        print(f"🧠 Training PaDiM for {category}...")
        model = ManualPaDiM()

        if not model.fit(train_loader):
            fail('training_failed')
            return

        predictions, true_labels = model.predict(test_loader)

        if len(predictions) != len(true_labels) or len(predictions) == 0:
            fail('prediction_mismatch')
            return

        if len(set(true_labels)) < 2:
            fail('insufficient_test_labels')
            return

        auc_score = roc_auc_score(true_labels, predictions)
        metrics = calculate_comprehensive_metrics(true_labels, predictions)

        # Split the scores by ground-truth class for the summary statistics.
        is_normal = true_labels == 0
        is_anomaly = true_labels == 1

        normal_scores = predictions[is_normal] if is_normal.any() else np.array([])
        anomaly_scores = predictions[is_anomaly] if is_anomaly.any() else np.array([])

        # Measured from function entry, so it covers dataset setup, fit and predict.
        elapsed = time.time() - started_at

        summary = {'status': 'success', 'auc_score': auc_score}
        summary.update((key, metrics[key]) for key in metric_keys)
        summary.update({
            'train_samples': len(train_set),
            'test_samples': len(test_set),
            'normal_test_samples': len(normal_scores),
            'anomaly_test_samples': len(anomaly_scores),
            'normal_mean_score': float(normal_scores.mean()) if len(normal_scores) > 0 else 0,
            'anomaly_mean_score': float(anomaly_scores.mean()) if len(anomaly_scores) > 0 else 0,
            'training_time': elapsed,
            'model_type': 'manual_padim',
        })
        results_dict[category] = summary

        print(f"✅ {category}: AUC Score = {auc_score:.4f}")
        print(f"📊 {category}: Precision = {metrics['precision']:.4f}, Recall = {metrics['recall']:.4f}, F1 = {metrics['f1_score']:.4f}")
        print(f"📊 {category}: Normal scores = {normal_scores.mean():.2f}, Anomaly scores = {anomaly_scores.mean():.2f}")
        print(f"⏱️  {category}: Training time = {elapsed:.1f}s")

    except Exception as err:
        print(f"❌ {category}: Training failed - {err}")
        import traceback
        traceback.print_exc()
        fail(str(err))

# Step 7: Multi-category training pipeline with 8 categories
def run_multi_category_pipeline():
    """Train and evaluate PaDiM on every available target category and report.

    Steps:
      1. Locate the dataset with ``explore_dataset_structure``.
      2. Run ``train_single_category`` for each of the 8 target categories
         that was found.
      3. Print a per-category table, averages, failures and a confusion
         matrix summary.
      4. Write the successful results to
         ``mvtec_padim_8categories_results.csv`` in the working directory.

    Returns:
        dict | None: Mapping of category -> result record as written by
        ``train_single_category``. ``None`` when the dataset or none of the
        target categories could be found.
    """
    print("🚀 Starting Multi-Category PaDiM Pipeline")
    print("=" * 70)

    # Explore dataset
    root_path, categories = explore_dataset_structure()

    if not root_path or not categories:
        print("❌ Cannot proceed without dataset")
        return None

    # ALL 8 TARGET CATEGORIES - Updated as requested
    target_categories = ['bottle', 'metal_nut', 'capsule', 'cable', 'screw', 'pill', 'transistor', 'hazelnut']
    available_targets = [cat for cat in target_categories if cat in categories]

    if not available_targets:
        print("❌ No target categories found in dataset")
        return None

    print(f"\n🎯 Will train PaDiM on {len(available_targets)} categories: {available_targets}")

    # Train each category
    results = {}
    total_start_time = time.time()

    for i, category in enumerate(available_targets, 1):
        print(f"\n🔄 Progress: {i}/{len(available_targets)} categories")
        # The exploration step finds category folders at any depth, while the
        # dataset class only probes a few fixed locations below its root.
        # Passing the parent of the discovered folder makes
        # ``<root>/<category>`` resolve to exactly the directory that was found.
        category_root = categories[category].parent
        train_single_category(category_root, category, results)

    total_time = time.time() - total_start_time

    # Generate comprehensive results summary
    print("\n" + "=" * 70)
    print("📋 MULTI-CATEGORY PADIM RESULTS WITH PRECISION & RECALL")
    print("=" * 70)

    successful_trainings = [cat for cat, res in results.items() if res.get('status') == 'success']
    failed_trainings = [cat for cat, res in results.items() if res.get('status') == 'failed']

    print(f"✅ Successful: {len(successful_trainings)}/{len(available_targets)} categories")
    print(f"❌ Failed: {len(failed_trainings)}/{len(available_targets)} categories")
    print(f"⏱️  Total time: {total_time:.1f}s")

    # Per-metric mean over the successful categories; stays empty when none succeeded.
    averaged_metrics = ('auc_score', 'precision', 'recall', 'f1_score', 'accuracy')
    averages = {
        metric: np.mean([results[cat][metric] for cat in successful_trainings])
        for metric in averaged_metrics
    } if successful_trainings else {}

    # Enhanced detailed results table
    if successful_trainings:
        print("\n📊 DETAILED PADIM RESULTS:")
        print("-" * 120)
        print(f"{'Category':<12} {'AUC':<8} {'Precision':<9} {'Recall':<8} {'F1':<8} {'Acc':<8} {'Train':<6} {'Test':<5} {'Time':<6}")
        print("-" * 120)

        for category in successful_trainings:
            res = results[category]
            print(f"{category:<12} {res['auc_score']:<8.4f} {res['precision']:<9.4f} {res['recall']:<8.4f} "
                  f"{res['f1_score']:<8.4f} {res['accuracy']:<8.4f} {res['train_samples']:<6} {res['test_samples']:<5} "
                  f"{res['training_time']:<6.1f}s")

        print("-" * 120)
        print(f"{'AVERAGE':<12} {averages['auc_score']:<8.4f} {averages['precision']:<9.4f} {averages['recall']:<8.4f} "
              f"{averages['f1_score']:<8.4f} {averages['accuracy']:<8.4f}")
        print("-" * 120)

    # Failed categories details
    if failed_trainings:
        print("\n❌ FAILED CATEGORIES:")
        for category in failed_trainings:
            reason = results[category].get('reason', 'unknown')
            print(f"   {category}: {reason}")

    # Research summary
    print("\n🎓 PADIM RESEARCH SUMMARY:")
    print("   Dataset: MVTec Anomaly Detection")
    print("   Model: PaDiM (Manual Implementation)")
    print(f"   Categories: {len(successful_trainings)} warehouse-relevant products")
    if successful_trainings:
        print("   Performance:")
        print(f"     - Average AUC: {averages['auc_score']:.4f}")
        print(f"     - Average Precision: {averages['precision']:.4f}")
        print(f"     - Average Recall: {averages['recall']:.4f}")
        print(f"     - Average F1-Score: {averages['f1_score']:.4f}")
        print(f"     - Average Accuracy: {averages['accuracy']:.4f}")
    else:
        print("   Performance: No successful trainings")
    print(f"   Device: {torch.device('cuda' if torch.cuda.is_available() else 'cpu')}")

    # Export comprehensive results to CSV for research paper
    if successful_trainings:
        # Explicit column order for the CSV; each name is a key of the
        # per-category success record.
        exported_fields = (
            'auc_score', 'precision', 'recall', 'f1_score', 'specificity',
            'accuracy', 'threshold', 'true_positives', 'true_negatives',
            'false_positives', 'false_negatives', 'train_samples',
            'test_samples', 'normal_test_samples', 'anomaly_test_samples',
            'normal_mean_score', 'anomaly_mean_score', 'training_time',
            'model_type',
        )
        results_df = pd.DataFrame([
            {
                'category': category,
                'model': 'PaDiM',
                **{field: results[category][field] for field in exported_fields},
            }
            for category in successful_trainings
        ])

        csv_name = 'mvtec_padim_8categories_results.csv'
        results_df.to_csv(csv_name, index=False)
        print(f"\n💾 PaDiM results exported to: {csv_name}")

        # Show confusion matrix summary
        print("\n🔍 CONFUSION MATRIX SUMMARY:")
        print("-" * 80)
        print(f"{'Category':<12} {'TP':<5} {'TN':<5} {'FP':<5} {'FN':<5} {'Threshold':<10}")
        print("-" * 80)
        for category in successful_trainings:
            res = results[category]
            print(f"{category:<12} {res['true_positives']:<5} {res['true_negatives']:<5} "
                  f"{res['false_positives']:<5} {res['false_negatives']:<5} {res['threshold']:<10.4f}")
        print("-" * 80)

    print("\n🎉 Multi-category PaDiM training pipeline completed!")
    return results

# Entry point: run the multi-category pipeline over the 8 target categories
# when this file is executed directly. `results` is kept at module level so
# it can be inspected afterwards; it is None when the pipeline stops early
# because the dataset or the target categories could not be found.
if __name__ == "__main__":
    results = run_multi_category_pipeline()
