In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
from scipy.ndimage import gaussian_filter
from functools import partial

# Import your model classes
from models.MMR.MMR import MMR_base
from models.MMR.utils import ForwardHook, cal_anomaly_map
from config import get_cfg

class AnomalyDetector:
    """Single-image anomaly detector built from a pretrained backbone and an MMR model.

    A pretrained ``wide_resnet50_2`` is run on the input image and the outputs
    of selected layers are captured through forward hooks. The same image is
    passed through an ``MMR_base`` model restored from a checkpoint, and
    ``cal_anomaly_map`` compares the two sets of features to produce a
    per-pixel anomaly map.
    """

    def __init__(self, checkpoint_path, config_path):
        """Build both models, register the feature hooks and the preprocessing.

        Args:
            checkpoint_path: Path to a checkpoint containing a
                ``'model_state_dict'`` entry for the MMR model.
            config_path: Path to the configuration file merged into the
                default config returned by ``get_cfg()``.
        """
        # Load configuration
        self.cfg = get_cfg()
        self.cfg.merge_from_file(config_path)

        # Set device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Initialize models
        self.cur_model = self._init_current_model()
        self.mmr_model = self._init_mmr_model(checkpoint_path)

        # Register hooks on the last element of every configured layer.
        # detect() reads this dict back by layer name after each forward pass,
        # so ForwardHook is expected to store the layer output under that name.
        self.teacher_outputs_dict = {}
        for extract_layer in self.cfg.TRAIN.MMR.layers_to_extract_from:
            forward_hook = ForwardHook(self.teacher_outputs_dict, extract_layer)
            # getattr resolves registered submodules; no need to reach into
            # the module's private __dict__["_modules"] mapping.
            network_layer = getattr(self.cur_model, extract_layer)
            network_layer[-1].register_forward_hook(forward_hook)

        # Define image transforms
        # NOTE(review): no transforms.Normalize is applied here; confirm this
        # matches the preprocessing used when the checkpoint was trained.
        self.transform = transforms.Compose([
            transforms.Resize(self.cfg.DATASET.resize),
            transforms.CenterCrop(self.cfg.DATASET.imagesize),
            transforms.ToTensor(),
        ])

    def _init_current_model(self):
        """Load the pretrained ``wide_resnet50_2`` backbone in eval mode.

        The model is fetched through ``torch.hub`` (pinned to torchvision
        v0.10.0), so the first call may need network access to download it.
        """
        model = torch.hub.load(
            'pytorch/vision:v0.10.0',
            'wide_resnet50_2',
            pretrained=True,
        )
        model.to(self.device)
        model.eval()
        return model

    def _init_mmr_model(self, checkpoint_path):
        """Create the MMR model and restore its weights from ``checkpoint_path``."""
        model = MMR_base(
            patch_size=16,
            embed_dim=768,
            depth=12,
            num_heads=12,
            mlp_ratio=4,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            cfg=self.cfg,
        )
        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code. Only load checkpoints that come from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(self.device)
        model.eval()
        return model

    @staticmethod
    def _summarize_anomaly_map(anomaly_map, threshold, sigma=4):
        """Smooth a 2D anomaly map and reduce it to a score and a decision."""
        smoothed = gaussian_filter(anomaly_map, sigma=sigma)
        anomaly_score = np.max(smoothed)
        return {
            'anomaly_score': float(anomaly_score),
            'anomaly_map': smoothed,
            'is_anomaly': bool(anomaly_score > threshold),
        }

    def detect(self, image_path, threshold=0.5, sigma=4):
        """Detect anomalies in a single image.

        Args:
            image_path: Path to the image file.
            threshold: Score above which the image is flagged as anomalous.
                NOTE(review): the map is accumulated across layers
                (``amap_mode='a'``), so scores may not be confined to
                [0, 1] -- confirm the intended threshold range.
            sigma: Standard deviation of the Gaussian smoothing applied to
                the anomaly map before scoring.

        Returns:
            A dict with three entries:
                ``'anomaly_score'``: float, maximum of the smoothed map.
                ``'anomaly_map'``: the smoothed anomaly map for the image.
                ``'is_anomaly'``: bool, True when the score exceeds
                ``threshold``.
        """
        # Load and preprocess image. The context manager closes the file
        # handle as soon as the pixel data has been converted.
        with Image.open(image_path) as source_image:
            image = source_image.convert('RGB')
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        layer_names = self.cfg.TRAIN.MMR.layers_to_extract_from

        with torch.no_grad():
            # Get current model features. Clear first so stale outputs from a
            # previous image can never be picked up.
            self.teacher_outputs_dict.clear()
            _ = self.cur_model(image_tensor)
            multi_scale_features = [
                self.teacher_outputs_dict[key] for key in layer_names
            ]

            # Get MMR features
            reverse_features = self.mmr_model(
                image_tensor,
                mask_ratio=self.cfg.TRAIN.MMR.test_mask_ratio,
            )
            multi_scale_reverse_features = [
                reverse_features[key] for key in layer_names
            ]

            # Calculate anomaly map at the input's spatial size
            anomaly_map, _ = cal_anomaly_map(
                multi_scale_features,
                multi_scale_reverse_features,
                image_tensor.shape[-1],
                amap_mode='a',
            )

        # The batch holds a single image, so only entry 0 is summarized.
        return self._summarize_anomaly_map(anomaly_map[0], threshold, sigma)

# Demo entry point: build a detector and report the verdict for one image.
if __name__ == "__main__":
    # Arguments are (checkpoint_path, config_path), in that order.
    detector = AnomalyDetector(
        "checkpoints/aebad_S_AeBAD_S_MMR_model.pth",
        "method_config/AeBAD_S/MMR.yaml",
    )

    # Score a single image and print the summary lines.
    result = detector.detect("path_to_your_test_image.jpg")
    for report_line in (
        f"Anomaly Score: {result['anomaly_score']:.3f}",
        f"Is Anomaly: {result['is_anomaly']}",
    ):
        print(report_line)

Downloading: "https://github.com/pytorch/vision/zipball/v0.10.0" to /home/rvce/.cache/torch/hub/v0.10.0.zip
