In [None]:
import torch
from PIL import Image
from create_model import IDOneClassClassifier
from albumentations.pytorch import ToTensorV2
import albumentations as A
import numpy as np
import os
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score


In [None]:
# Define transformation for the image (same as training):
# resize to 224x224, normalise with the ImageNet mean/std, then convert the
# HWC numpy array to a CHW torch tensor.
transform = A.Compose([
    A.Resize(224, 224),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])

# Pick the best available accelerator. `torch.backends.mps.is_available()` is the
# documented availability check for Apple-silicon GPUs.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

# Load model parameters
model_checkpoint_path = 'best_model_ResNet50_one_class.pth'

# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint = torch.load(model_checkpoint_path, map_location=device)
model = IDOneClassClassifier().to(device)
model.load_state_dict(checkpoint['model_state_dict'])
# The hypersphere parameters may have been saved either as tensors or as plain
# Python numbers; torch.as_tensor accepts both and places them on `device`
# (a bare `.to(device)` would fail on a float).
center = torch.as_tensor(checkpoint['center'], device=device)
radius = torch.as_tensor(checkpoint['radius'], device=device)

model.eval();  # Set model to evaluation mode; ';' suppresses the large module repr in notebook output

In [None]:
def preprocess_image(image_path):
    """
    Load an image from disk and preprocess it for the model.

    Uses the module-level ``transform`` and ``device``.

    Args:
        image_path (str): Path to the image.

    Returns:
        torch.Tensor: Preprocessed image tensor with a leading batch dimension
        of size 1, moved to ``device``.
    """
    # Use a context manager so the underlying file handle is always released;
    # Pillow may keep it open after decoding (e.g. multi-frame TIFFs).
    with Image.open(image_path) as pil_img:
        rgb = np.array(pil_img.convert('RGB'))
    img = transform(image=rgb)['image']
    img = img.unsqueeze(0)  # Add batch dimension
    return img.to(device)

def predict_image(img_tensor, model, center, radius, threshold_multiplier=1.0):
    """
    Classify a single preprocessed image as normal or anomalous.

    The image is embedded with ``model`` (without tracking gradients). The
    squared Euclidean distance from the embedding to ``center`` is compared
    against ``(radius * threshold_multiplier) ** 2``. Points on or inside that
    sphere count as normal.

    Args:
        img_tensor (torch.Tensor): Preprocessed image tensor (batch of one).
        model (nn.Module): Loaded model producing the embedding.
        center (torch.Tensor): Center of the learned hypersphere.
        radius (torch.Tensor): Radius of the hypersphere.
        threshold_multiplier (float): Scales the radius; values > 1 accept
            more images as normal, values < 1 fewer.

    Returns:
        int: 1 if the image is normal, 0 if it is anomalous.
    """
    with torch.no_grad():
        embedding = model(img_tensor)
        sq_dist = ((embedding - center) ** 2).sum(dim=1)

    inside_sphere = sq_dist <= (radius * threshold_multiplier) ** 2
    if inside_sphere.item():
        return 1  # normal
    return 0  # anomaly

def load_images_from_directory(directory, extensions=('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')):
    """
    Load and preprocess all images directly inside a directory (non-recursive).

    Entries are visited in sorted filename order. ``os.listdir`` returns them in
    arbitrary order, so sorting keeps the resulting list reproducible across
    runs and platforms. Non-file entries (e.g. a subdirectory named ``x.jpg``)
    are skipped.

    Args:
        directory (str): Directory path to load images from.
        extensions (Iterable[str]): Filename suffixes (case-insensitive) that
            are treated as images.

    Returns:
        list: List of preprocessed image tensors, one per matching file.
    """
    # str.endswith needs a tuple; normalise case so '.JPG' matches too.
    suffixes = tuple(ext.lower() for ext in extensions)
    image_tensors = []
    for filename in sorted(os.listdir(directory)):
        if not filename.lower().endswith(suffixes):
            continue
        img_path = os.path.join(directory, filename)
        if not os.path.isfile(img_path):
            continue
        image_tensors.append(preprocess_image(img_path))
    return image_tensors

In [None]:
# Load datasets
normal_image_dir = 'dataset/valid'  # Replace with path to normal images
anomalous_image_dir = 'dataset/invalid'  # Replace with path to anomalous images

print("Loading datasets...")
normal_images = load_images_from_directory(normal_image_dir)
anomalous_images = load_images_from_directory(anomalous_image_dir)
# Both classes are required for the metrics below to be meaningful.
assert normal_images, f"No images found in {normal_image_dir!r}"
assert anomalous_images, f"No images found in {anomalous_image_dir!r}"
print(f"Loaded {len(normal_images)} normal and {len(anomalous_images)} anomalous images")

# Define threshold multipliers to test
threshold_multipliers = [0.8, 1.0, 1.2, 1.5]

# Ground-truth labels do not depend on the threshold, so build them once.
# Label convention: 1 = normal (the positive class), 0 = anomalous; precision,
# recall and F1 below are therefore reported for the *normal* class.
y_true = [1] * len(normal_images) + [0] * len(anomalous_images)
all_images = normal_images + anomalous_images

# Evaluate model on each threshold
for multiplier in threshold_multipliers:
    print(f"\nEvaluating with threshold multiplier: {multiplier}")
    y_pred = [
        predict_image(img_tensor, model, center, radius, multiplier)
        for img_tensor in all_images
    ]

    # Calculate metrics. zero_division=0: a tight threshold can reject every
    # image (no positive predictions); report 0 instead of raising
    # UndefinedMetricWarning.
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)

    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")