# MVTec Data Loader Example

## This notebook provides an example on how to use the MVTec data loader to train a binary classifier
## This is a modified version of original https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

## Import packages

# IMAGENET-30

In [3]:
from torch.utils.data import Dataset
from PIL import ImageFilter, Image, ImageOps
from torchvision.datasets.folder import default_loader
import os

class IMAGENET30_TEST_DATASET(Dataset):
    def __init__(self, root_dir="/kaggle/input/imagenet30-dataset/one_class_test/one_class_test/", transform=None):
        """
        Args:
            root_dir (string): Directory with all the classes.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.img_path_list = []
        self.targets = []

        # Map each class to an index
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(sorted(os.listdir(root_dir)))}
        print(f"self.class_to_idx in ImageNet30_Test_Dataset:\n{self.class_to_idx}")

        # Walk through the directory and collect information about the images and their labels
        for i, class_name in enumerate(os.listdir(root_dir)):
            class_path = os.path.join(root_dir, class_name)
            for instance_folder in os.listdir(class_path):
                instance_path = os.path.join(class_path, instance_folder)
                if instance_path != "/kaggle/input/imagenet30-dataset/one_class_test/one_class_test/airliner/._1.JPEG":
                    for img_name in os.listdir(instance_path):
                        if img_name.endswith('.JPEG'):
                            img_path = os.path.join(instance_path, img_name)
                            # image = Image.open(img_path).convert('RGB')
                            self.img_path_list.append(img_path)
                            self.targets.append(self.class_to_idx[class_name])

    def __len__(self):
        return len(self.img_path_list)

    def __getitem__(self, idx):
        img_path = self.img_path_list[idx]
        image = default_loader(img_path)
        label = self.targets[idx]
        if self.transform:
            image = self.transform(image)
        return image, label
    
imagenet30_testset = IMAGENET30_TEST_DATASET()

self.class_to_idx in ImageNet30_Test_Dataset:
{'acorn': 0, 'airliner': 1, 'ambulance': 2, 'american_alligator': 3, 'banjo': 4, 'barn': 5, 'bikini': 6, 'digital_clock': 7, 'dragonfly': 8, 'dumbbell': 9, 'forklift': 10, 'goblet': 11, 'grand_piano': 12, 'hotdog': 13, 'hourglass': 14, 'manhole_cover': 15, 'mosque': 16, 'nail': 17, 'parking_meter': 18, 'pillow': 19, 'revolver': 20, 'rotary_dial_telephone': 21, 'schooner': 22, 'snowmobile': 23, 'soccer_ball': 24, 'stingray': 25, 'strawberry': 26, 'tank': 27, 'toaster': 28, 'volcano': 29}


# mvtecDataset

In [96]:
# This is a modified version of original  https://github.com/pytorch/vision/blob/master/torchvision/datasets/cifar.py
# This file and the mvtec data directory must be in the same directory, such that:
# /.../this_directory/mvtecDataset.py
# /.../this_directory/mvtec/bottle/...
# /.../this_directory/mvtec/cable/...
# and so on

from __future__ import print_function
from PIL import Image
import os
import os.path
import numpy as np
import torch.utils.data as data
import matplotlib.image as mpimg
from torchvision import transforms
import random


from PIL import Image

def center_paste(large_img, small_img):
    # Calculate the center position
    large_width, large_height = large_img.size
    small_width, small_height = small_img.size
    
    # Calculate the top-left position
    left = (large_width - small_width) // 2
    top = (large_height - small_height) // 2
    
    # Create a copy of the large image to keep the original unchanged
    result_img = large_img.copy()
    
    # Paste the small image onto the large one at the calculated position
    result_img.paste(small_img, (left, top))
    
    return result_img

class MVTEC(data.Dataset):
    """`MVTEC <https://www.mvtec.com/company/research/datasets/mvtec-ad/>`_ Dataset.
    Args:
        root (string): Root directory of dataset where directories
            ``bottle``, ``cable``, etc., exists.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        resize (int, optional): Desired output image size.
        interpolation (int, optional): Interpolation method for downsizing image.
        category: bottle, cable, capsule, etc.
    """


    def __init__(self, root, train=True,
                 transform=None, target_transform=None,
                 category='carpet', resize=None, interpolation=2, use_imagenet=False, select_random_image_from_imagenet=False, shrink_factor=0.9):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train
        self.resize = resize
        if use_imagenet:
            self.resize = int(resize * shrink_factor)
        self.interpolation = interpolation
        self.select_random_image_from_imagenet = select_random_image_from_imagenet
        
        # load images for training
        if self.train:
            self.train_data = []
            self.train_labels = []
            cwd = os.getcwd()
            trainFolder = self.root+'/'+category+'/train/good/'
            os.chdir(trainFolder)
            filenames = [f.name for f in os.scandir()]
            for file in filenames:
                img = mpimg.imread(file)
                img = img*255
                img = img.astype(np.uint8)
                self.train_data.append(img)
                self.train_labels.append(1)                 
            os.chdir(cwd)
                
            self.train_data = np.array(self.train_data)      
        else:
        # load images for testing
            self.test_data = []
            self.test_labels = []
            
            cwd = os.getcwd()
            testFolder = self.root+'/'+category+'/test/'
            os.chdir(testFolder)
            subfolders = [sf.name for sf in os.scandir() if sf.is_dir()]
#             print(subfolders)
            cwsd = os.getcwd()
            
            # for every subfolder in test folder
            for subfolder in subfolders:
                label = 0
                if subfolder == 'good':
                    label = 1
                testSubfolder = testFolder+subfolder+'/'
#                 print(testSubfolder)
                os.chdir(testSubfolder)
                filenames = [f.name for f in os.scandir()]
                for file in filenames:
                    img = mpimg.imread(file)
                    img = img*255
                    img = img.astype(np.uint8)
                    self.test_data.append(img)
                    self.test_labels.append(label)
                os.chdir(cwsd)
            os.chdir(cwd)
                
            self.test_data = np.array(self.test_data)
                
    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)
        
        if self.select_random_image_from_imagenet:
            imagenet30_img = imagenet30_testset[int(random.random() * len(imagenet30_testset))][0].resize((224, 224))
        else:
            imagenet30_img = imagenet30_testset[100][0].resize((224, 224))
        
        
        #if resizing image
        if self.resize is not None:
            resizeTransf = transforms.Resize(self.resize, self.interpolation)
            img = resizeTransf(img)
            
#         print(f"imagenet30_img.size: {imagenet30_img.size}")
#         print(f"img.size: {img.size}")
        img = center_paste(imagenet30_img, img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)
        
        return img, target

    def __len__(self):
        """
        Args:
            None
        Returns:
            int: length of array.
        """
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)

In [8]:
import torch
import torchvision
import torchvision.transforms as transforms

## Data loaders

##############################################################################################
### To use our data loader please download all the MVTec data available at:- 
### https://www.mvtec.com/company/research/datasets/mvtec-ad
### And save them in the folder ./mvtec
##############################################################################################

# KNN

In [42]:
!pip install faiss-gpu

Collecting faiss-gpu
  Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (85.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.5/85.5 MB[0m [31m19.5 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: faiss-gpu
Successfully installed faiss-gpu-1.7.2


In [44]:
import numpy as np
import torch
from tqdm import tqdm
import faiss
from sklearn.metrics import roc_auc_score

def knn_score(train_set, test_set, n_neighbours=2):
    """
    Calculates the KNN distance
    """
    index = faiss.IndexFlatL2(train_set.shape[1])
    index.add(train_set)
    D, _ = index.search(test_set, n_neighbours)
    return np.sum(D, axis=1)


def get_score_knn_auc(model, device, train_feature_space, test_loader, bd_test_loader=False):
    model.to(device)
    model.eval()

    test_feature_space = []
    test_labels = []
    with torch.no_grad():
        for idx, (imgs, labels) in enumerate(test_loader):
            imgs = imgs.to(device)
            features = model(imgs)
            test_feature_space.append(features)
            test_labels.append(labels)
        test_feature_space = torch.cat(test_feature_space, dim=0).contiguous().cpu().numpy()
        test_labels = torch.cat(test_labels, dim=0).cpu().numpy()

    distances = knn_score(train_feature_space, test_feature_space)

    auc = roc_auc_score(test_labels, -1 * distances) # I multiplied distances(scores) by -1 because here in dist label is 1

#     print(f"knn_auc: {auc}")

    return auc


def eval_step_knn_auc(
        device,
        model,
        train_loader,
        test_dataloader_ood
):
    model.to(device)
    model.eval()
    train_feature_space = []
    with torch.no_grad():
        for idx, (imgs, labels) in enumerate(train_loader, start=1):
            imgs = imgs.to(device)
            features = model(imgs)
            train_feature_space.append(features)
        train_feature_space = torch.cat(train_feature_space, dim=0).contiguous().cpu().numpy()

    knn_clean_test_auc = get_score_knn_auc(model, device, train_feature_space, test_dataloader_ood, bd_test_loader=False)

    return knn_clean_test_auc

# Model

In [45]:
import torch.nn.functional as F

class Model(torch.nn.Module):
    def __init__(self, backbone):
        super().__init__()
        if backbone == 'ViT': # ViT
            self.backbone = torchvision.models.vit_b_16(weights='DEFAULT')
        elif backbone == 152:
            self.backbone = models.resnet152(pretrained=True)
        else:
            self.backbone = models.resnet18(pretrained=True)
        self.backbone.fc = torch.nn.Identity()
        freeze_parameters(self.backbone, backbone, train_fc=False)

    def forward(self, x):
        z1 = self.backbone(x)
        z_n = F.normalize(z1, dim=-1)
        return z_n
    
def freeze_parameters(model, backbone, train_fc=False):
    if not train_fc:
        for p in model.fc.parameters():
            p.requires_grad = False
    if backbone == 152:
        for p in model.conv1.parameters():
            p.requires_grad = False
        for p in model.bn1.parameters():
            p.requires_grad = False
        for p in model.layer1.parameters():
            p.requires_grad = False
        for p in model.layer2.parameters():
            p.requires_grad = False

In [93]:
def list_directories(path):
    """
    Returns a list of directory names found in the given path.
    
    :param path: The path of the directory to list subdirectories from.
    :return: A list of directory names.
    """
    # List all entries in the given path
    entries = os.listdir(path)
    
    # Filter out entries that are directories
    dir_names = [entry for entry in entries if os.path.isdir(os.path.join(path, entry))]
    
    return dir_names

all_categories = list_directories("/kaggle/input/mvtec-ad")
print(all_categories)

['wood', 'screw', 'metal_nut', 'capsule', 'hazelnut', 'carpet', 'pill', 'grid', 'zipper', 'transistor', 'tile', 'leather', 'toothbrush', 'bottle', 'cable']


In [None]:
from torchvision import models

for shrink_factor in [0.8, 0.85, 0.9, 0.95, 1]:
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    batch_size = 70

    # Original images are high resolution, so we resize them using the transformation
    # provided by pytorch: https://pytorch.org/vision/stable/transforms.html
    im_shape = 224

    # Interpolation method for resizing the image
    interpol = 3

    # Data category to use: carpet, leather, wood, bottle, etc.
    auc_dict = {}
    auc_sum = 0.0
    
    for cat in all_categories:
        trainset = MVTEC(root='/kaggle/input/mvtec-ad/', train=True, transform=transform,
                            resize=im_shape, interpolation=interpol, category=cat, use_imagenet=True, select_random_image_from_imagenet=True, shrink_factor=shrink_factor)

        trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                                  shuffle=True)

        testset = MVTEC(root='/kaggle/input/mvtec-ad/', train=False, transform=transform,
                            resize=im_shape, interpolation=interpol, category=cat, use_imagenet=True, select_random_image_from_imagenet=True, shrink_factor=shrink_factor)

        testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                                 shuffle=False)

        # Each category has two types of images: good (label 1) or defective (label 0)
        classes = ('defective', 'good')

        device='cuda:0'
        model = Model('ViT')
        model.to(device)
        auc = eval_step_knn_auc(device, model, trainloader, testloader)
        auc_sum += auc
        auc_dict[cat] = auc
        print(f"({shrink_factor}, {cat}) -> {auc}")
    print(f"{shrink_factor} -> {auc_dict}")
    print(f"auc mean ({shrink_factor}): {auc_sum / len(all_categories)}")

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 140MB/s] 
Test set feature extracting: 2it [00:01,  1.33it/s]


knn_auc: 0.6850877192982456
(0.8, wood) -> 0.6850877192982456


Test set feature extracting: 3it [00:02,  1.31it/s]


knn_auc: 0.47509735601557695
(0.8, screw) -> 0.47509735601557695


Test set feature extracting: 2it [00:01,  1.36it/s]


knn_auc: 0.6627565982404692
(0.8, metal_nut) -> 0.6627565982404692


Test set feature extracting: 2it [00:02,  1.22s/it]


knn_auc: 0.6071001196649382
(0.8, capsule) -> 0.6071001196649382


Test set feature extracting: 2it [00:01,  1.02it/s]


knn_auc: 0.8571428571428571
(0.8, hazelnut) -> 0.8571428571428571


Test set feature extracting: 2it [00:02,  1.04s/it]


knn_auc: 0.5642054574638844
(0.8, carpet) -> 0.5642054574638844


Test set feature extracting: 3it [00:02,  1.14it/s]


knn_auc: 0.5561920349154392
(0.8, pill) -> 0.5561920349154392


Test set feature extracting: 2it [00:01,  1.78it/s]


knn_auc: 0.4912280701754386
(0.8, grid) -> 0.4912280701754386


Test set feature extracting: 3it [00:01,  1.71it/s]


knn_auc: 0.5383403361344538
(0.8, zipper) -> 0.5383403361344538


Test set feature extracting: 2it [00:01,  1.12it/s]


knn_auc: 0.71625
(0.8, transistor) -> 0.71625


Test set feature extracting: 2it [00:01,  1.19it/s]


knn_auc: 0.5386002886002886
(0.8, tile) -> 0.5386002886002886


Test set feature extracting: 2it [00:02,  1.13s/it]


knn_auc: 0.5492527173913043
(0.8, leather) -> 0.5492527173913043


Test set feature extracting: 1it [00:00,  1.29it/s]


knn_auc: 0.6833333333333333
(0.8, toothbrush) -> 0.6833333333333333


Test set feature extracting: 2it [00:01,  1.52it/s]


knn_auc: 0.7658730158730159
(0.8, bottle) -> 0.7658730158730159


Test set feature extracting: 3it [00:02,  1.06it/s]


knn_auc: 0.8135307346326837
(0.8, cable) -> 0.8135307346326837
0.8 -> {'wood': 0.6850877192982456, 'screw': 0.47509735601557695, 'metal_nut': 0.6627565982404692, 'capsule': 0.6071001196649382, 'hazelnut': 0.8571428571428571, 'carpet': 0.5642054574638844, 'pill': 0.5561920349154392, 'grid': 0.4912280701754386, 'zipper': 0.5383403361344538, 'transistor': 0.71625, 'tile': 0.5386002886002886, 'leather': 0.5492527173913043, 'toothbrush': 0.6833333333333333, 'bottle': 0.7658730158730159, 'cable': 0.8135307346326837}
auc mean (0.8): 0.633599375925462
