# MVTEC: Anomaly Detection using PatchCore (PyTorch)

This notebook is an implementation of the PatchCore paper in PyTorch, written from scratch.

PatchCore is a state-of-the-art image anomaly detection model for the MVTec dataset, according to the Papers with Code website. It uses a pre-trained ResNet50 model to build a memory bank of patch features extracted from good (defect-free) images. Each test image is then scored by how far its patch features lie from their nearest entries in this memory bank.

Import libraries

In [None]:
import os, shutil
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

from PIL import Image
from tqdm.auto import tqdm

import torch
import torchvision
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from torchvision.models import resnet50, ResNet50_Weights

In [None]:
# Preprocessing shared by every image in this notebook: resize to 224x224,
# then convert the PIL image to a tensor.
# NOTE(review): no mean/std normalisation is applied before the pretrained
# backbone - confirm this is intentional.
transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)

Load a pretrained ResNet model

In [None]:
class resnet_feature_extractor(torch.nn.Module):
    """Frozen, pretrained ResNet-50 that turns images into patch feature vectors.

    Forward hooks on the last blocks of ``layer2`` and ``layer3`` capture the
    intermediate feature maps. Each captured map is locally averaged (3x3,
    stride 1), resized to the spatial size of the first captured map,
    concatenated along the channel axis and flattened so that every spatial
    location becomes one row of the returned matrix.

    Attributes:
        model: The pretrained ResNet-50, in eval mode with gradients disabled.
        features: Feature maps captured by the hooks during the most recent
            forward pass, in the order the hooked layers ran.
        avg: The 3x3 / stride-1 average pooling applied to each captured map.
        resize: The adaptive pooling used in the most recent forward pass
            (only set once ``forward`` has been called).
    """

    def __init__(self):
        """Load the pretrained backbone, freeze it and register the hooks."""
        super(resnet_feature_extractor, self).__init__()
        self.model = resnet50(weights=ResNet50_Weights.DEFAULT)

        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

        # Defined up front so the hooks never hit a missing attribute, even if
        # self.model is called directly before forward().
        self.features = []

        # Input-independent, so build it once instead of on every forward call.
        self.avg = torch.nn.AvgPool2d(3, stride=1)

        # Hook to extract feature maps
        def hook(module, input, output) -> None:
            """Append the hooked layer's output to self.features."""
            self.features.append(output)

        self.model.layer2[-1].register_forward_hook(hook)
        self.model.layer3[-1].register_forward_hook(hook)

    def forward(self, input):
        """Extract one feature vector per spatial location of the input batch.

        Args:
            input: Image batch fed to the ResNet; assumed to be a float tensor
                of shape (B, 3, H, W) - TODO confirm against callers.

        Returns:
            A 2-D tensor of shape (B * h * w, C): (h, w) is the spatial size of
            the first captured feature map and C is the summed channel count
            of all captured maps. Rows are ordered by image, then by row and
            column of the feature map.
        """
        # Start from a clean list; the hooks refill it during the pass below.
        self.features = []
        with torch.no_grad():
            _ = self.model(input)

        # Use height AND width so non-square inputs are not squashed to square.
        fmap_size = tuple(self.features[0].shape[-2:])
        self.resize = torch.nn.AdaptiveAvgPool2d(fmap_size)

        resized_maps = [self.resize(self.avg(fmap)) for fmap in self.features]
        patch = torch.cat(resized_maps, 1)            # Merge along channels: (B, C, h, w)

        # Move channels last before flattening so each row is exactly one
        # spatial location of one image. A plain reshape(C, -1).T is only
        # correct for B == 1 and mixes images and channels for larger batches.
        patch = patch.permute(0, 2, 3, 1).reshape(-1, patch.shape[1])

        return patch

Check feature shape

In [None]:
# Sanity check: push one test image through the extractor and look at the
# shapes of the two hooked feature maps and of the flattened patch matrix.
backbone = resnet_feature_extractor().cuda()

sample_img_path = '/kaggle/input/mvtec-ad/hazelnut/test/hole/000.png'
image = Image.open(sample_img_path)
image = transform(image)
image = image.unsqueeze(0).cuda()   # add the batch dimension, move to the GPU

feature = backbone(image)

# Shapes of the two captured maps, one per line, then the patch matrix.
print(*(fmap.shape for fmap in backbone.features[:2]), sep='\n')
print(feature.shape)

# Show the preprocessed input (channels moved last for matplotlib).
plt.imshow(image.cpu()[0].permute(1, 2, 0))
plt.show()

Create memory bank from GOOD images

In [None]:
# Build the memory bank: one row per patch feature of every defect-free
# training image.
folder_path = Path(r'/kaggle/input/mvtec-ad/hazelnut/train/good')

# iterdir() yields entries in arbitrary order; sorting makes the row order of
# the memory bank (and therefore any later subsampling) reproducible, and a
# list lets tqdm show a total.
train_image_paths = sorted(p for p in folder_path.iterdir() if p.is_file())
if not train_image_paths:
    raise FileNotFoundError(f"No training images found in {folder_path}")

# Collected per image first; concatenated into a single tensor below.
patch_features_per_image = []

for pth in tqdm(train_image_paths, leave=False):
    # The context manager closes the file handle promptly; convert('RGB')
    # guarantees the 3 channels the backbone expects, whatever the file mode.
    with Image.open(pth) as img:
        data = transform(img.convert('RGB')).cuda().unsqueeze(0)
    with torch.no_grad():
        features = backbone(data)
    # Park the features on the CPU while iterating over the images.
    patch_features_per_image.append(features.cpu())

print(len(patch_features_per_image))
print(patch_features_per_image[0].shape)
memory_bank = torch.cat(patch_features_per_image, dim=0).cuda()
memory_bank.shape # e.g. 784x391 = 306544 rows (patches per image x images)

Random sampling - keep 10% of all patches to reduce inference time and computation.

In [None]:
# Keep a random 10% of the memory-bank rows. The generator is seeded so the
# subsample - and every score and threshold derived from it - is the same on
# every run.
# NOTE(review): this cell shrinks `memory_bank` in place, so re-running it
# without rebuilding the memory bank first subsamples again.
SAMPLING_SEED = 42
sampling_rng = np.random.default_rng(SAMPLING_SEED)

# At least one row is kept when the bank is small but non-empty.
num_selected = min(len(memory_bank), max(1, len(memory_bank) // 10))
selected_indices = sampling_rng.choice(len(memory_bank), size=num_selected, replace=False)

# Index with a tensor on the same device as the memory bank.
memory_bank = memory_bank[torch.as_tensor(selected_indices, device=memory_bank.device)]
memory_bank.shape

For GOOD images [K nearest neighbours]

Distance scores for good images, to calculate threshold value.

In [None]:
# Score every defect-free training image against the memory bank; the next
# cells derive a threshold from the spread of these scores.
folder_path = Path(r'/kaggle/input/mvtec-ad/hazelnut/train/good')
y_score_good = []

for pth in tqdm(folder_path.iterdir(), leave=False):
    data = transform(Image.open(pth)).cuda().unsqueeze(0)
    with torch.no_grad():
        features = backbone(data)

    # Distance from each patch of this image to its nearest memory-bank row.
    distances = torch.cdist(features, memory_bank, p=2.0)
    dist_score, dist_score_idxs = distances.min(dim=1)

    # Image-level score: the largest of the per-patch nearest distances.
    s_star = dist_score.max()

    # Per-patch scores laid out as a 28x28 map (not used inside this loop).
    segm_map = dist_score.view(1, 1, 28, 28)

    y_score_good.append(s_star.cpu().numpy())

y_score_good[:5]

In [None]:
# Peek at the first five image-level scores of the good training images.
y_score_good[:5]
# Disabled snippet below: it would plot `segm_map` (the per-patch distance map
# left over from the scoring loop) as a grayscale image.
# image_np = segm_map.squeeze().cpu() # Remove batch & channel dimensions

# # Plot the image
# plt.imshow(image_np, cmap='gray')
# plt.title("28x28 Image")
# plt.axis("off")  # Hide axes
# plt.show()

In [None]:
# Summary statistics of the good-image scores.
good_mean = np.mean(y_score_good)
good_std = np.std(y_score_good)
print(good_mean)
print(good_std)

# Threshold: mean plus three standard deviations of the good-image scores.
best_threshold = good_mean + 3 * good_std
print(f"Threshold: {best_threshold}")

# Histogram of the good scores with the threshold drawn as a red line.
plt.hist(y_score_good, bins=50)
plt.vlines(x=best_threshold, ymin=0, ymax=30, color='r')
plt.show()

For test images (GOOD and BAD)

In [None]:
# Score every test image against the memory bank and record its ground truth.
y_score = []
y_true = []

test_root = Path("/kaggle/input/mvtec-ad/hazelnut/test")

# Discover the class folders instead of hard-coding their names, and sort both
# folders and files so y_score / y_true come out in the same order on every
# run (iterdir() order is arbitrary).
class_dirs = sorted(d for d in test_root.iterdir() if d.is_dir())

for class_dir in class_dirs:
    class_label = class_dir.name

    for pth in tqdm(sorted(class_dir.iterdir()), desc=class_label, leave=False):
        with torch.no_grad():
            test_image = transform(Image.open(pth)).cuda().unsqueeze(0)
            features = backbone(test_image)

            # Distance from each patch to its nearest memory-bank row; the
            # image-level score is the largest of those distances.
            distances = torch.cdist(features, memory_bank, p=2.0)
            dist_score, dist_score_idxs = torch.min(distances, dim=1)
            s_star = torch.max(dist_score)

        y_score.append(s_star.cpu().numpy())
        y_true.append(0 if class_label=='good' else 1)  # 0 -> GOOD, 1 -> BAD

In [None]:
# Spot-check five (score, label) pairs from the test results.
y_score[40:45], y_true[40:45]

In [None]:
# Histogram of the anomaly scores of the test images labelled 'BAD' (label 1),
# with the current threshold drawn as a red vertical line.

y_score_bad = [
    image_score
    for image_score, label in zip(y_score, y_true)
    if label == 1
]

plt.hist(y_score_bad, bins=50)
plt.vlines(x=best_threshold, ymin=0, ymax=30, color='r')
plt.show()

Visualize one anomaly image

In [None]:
# Anomaly map for a single defective test image.
test_image = Image.open(r'/kaggle/input/mvtec-ad/hazelnut/test/cut/000.png')
test_image = transform(test_image).cuda().unsqueeze(0)
features = backbone(test_image)

# Nearest memory-bank distance per patch; the largest one is the image score.
distances = torch.cdist(features, memory_bank, p=2.0)
dist_score, dist_score_idxs = distances.min(dim=1)
s_star = dist_score.max()

# Lay the per-patch distances out as a 28x28 map, then upscale it by bilinear
# interpolation to match the 224x224 input resolution.
segm_map = torch.nn.functional.interpolate(
    dist_score.view(1, 1, 28, 28),
    size=(224, 224),
    mode='bilinear',
)

plt.figure(figsize=(4, 4))
plt.imshow(segm_map.squeeze().cpu(), cmap='jet')
plt.show()

Evaluation Metrics - best threshold calculation

In [None]:
from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, ConfusionMatrixDisplay, f1_score

# Work on explicit 1-D arrays: comparing the raw Python list with a threshold
# (`y_score >= threshold`) only works through NumPy's reflected operators.
y_true_arr = np.asarray(y_true)
y_score_arr = np.asarray(y_score).reshape(-1)

# Calculate AUC-ROC score
auc_roc_score = roc_auc_score(y_true_arr, y_score_arr)
print("AUC-ROC Score:", auc_roc_score)

# ROC curve points and the score thresholds that produce them
fpr, tpr, thresholds = roc_curve(y_true_arr, y_score_arr)
print("fpr, tpr, thresholds: ", fpr, tpr, thresholds)

# F1 at every candidate threshold. zero_division=0 keeps the value sklearn
# already returns when nothing is predicted positive (0), without emitting an
# undefined-metric warning for that case.
f1_scores = [
    f1_score(y_true_arr, y_score_arr >= threshold, zero_division=0)
    for threshold in thresholds
]
print("f1_scores:", f1_scores)

# Select the best threshold based on F1 score.
# NOTE(review): this overwrites the threshold derived from the good training
# images and is tuned on the test labels themselves, so metrics reported at
# this threshold are optimistic - confirm that is acceptable.
best_threshold = thresholds[np.argmax(f1_scores)]

print(f'best_threshold = {best_threshold}')

plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % auc_roc_score)
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc="lower right")
plt.show()

# Generate confusion matrix. Fixing labels=[0, 1] keeps it 2x2 (GOOD, BAD)
# even if one class never appears among the labels or predictions.
y_pred_arr = (y_score_arr >= best_threshold).astype(int)
cm = confusion_matrix(y_true_arr, y_pred_arr, labels=[0, 1])
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['GOOD', 'BAD'])
disp.plot()
plt.show()

Results Visualization

In [None]:
import cv2, time
from IPython.display import clear_output

backbone.eval()
class_label = ['GOOD', 'BAD']
test_path = Path('/kaggle/input/mvtec-ad/hazelnut/test')

# Step through the test images of the selected defect type, showing for each:
# the input, the anomaly heat map and the thresholded segmentation mask.
for path in test_path.glob('*/*.png'):

    fault_type = path.parts[-2]
    # change defect type - crack, cut, hole, print, good
    if fault_type not in ['hole']:
        continue

    test_image = transform(Image.open(path)).cuda().unsqueeze(0)

    # Forward pass
    with torch.no_grad():
        features = backbone(test_image)

    # Nearest memory-bank distance per patch; the largest is the image score.
    distances = torch.cdist(features, memory_bank, p=2.0)
    dist_score, dist_score_idxs = distances.min(dim=1)
    s_star = dist_score.max()

    # Per-patch distances as a 28x28 map, upscaled by bilinear interpolation
    # to match the 224x224 input resolution.
    segm_map = dist_score.view(1, 1, 28, 28)
    segm_map = torch.nn.functional.interpolate(
        segm_map,
        size=(224, 224),
        mode='bilinear',
    )
    segm_map = segm_map.cpu().squeeze().numpy()

    # Image-level decision: 1 (BAD) when the score reaches the threshold.
    y_score_image = s_star.cpu().numpy()
    y_pred_image = 1*(y_score_image>=best_threshold)

    plt.figure(figsize=(12,3))

    # Panel 1: the preprocessed input image.
    plt.subplot(1,3,1)
    plt.imshow(test_image.squeeze().permute(1,2,0).cpu().numpy())
    plt.title(f'fault type: {fault_type}')

    # Panel 2: heat map, colour range from the threshold to twice the threshold.
    plt.subplot(1,3,2)
    heat_map = segm_map
    plt.imshow(heat_map, cmap='jet', vmin=best_threshold, vmax = best_threshold * 2)
    plt.title(f'Anomaly score: {y_score_image:0.2f} | {class_label[y_pred_image]}')

    # Panel 3: binary mask of the pixels whose score exceeds the threshold.
    plt.subplot(1,3,3)
    plt.imshow((heat_map > best_threshold ), cmap='gray')
    plt.title('segmentation map')

    plt.show()

    # Hold the figure for a second, then clear it before the next image.
    time.sleep(1)
    clear_output(wait=True)