# PaDiM implementation

In [1]:
# https://github.com/remmarp/PaDiM-TF

In [2]:
import os

import cv2
import numpy as np
import tensorflow as tf

In [3]:
import matplotlib
import matplotlib.pyplot as plt

from skimage import morphology
from skimage.segmentation import mark_boundaries

In [4]:
import sklearn.metrics as metrics
from scipy.ndimage import gaussian_filter
from scipy.spatial.distance import mahalanobis

In [5]:
# data_loader.py 
class MVTecADLoader(object):
    #     base_path = r'D:\mvtec_ad'
    base_path = r'mvtec_ad'


    train, test = None, None
    num_train, num_test = 0, 0

    category = {'bottle': ['good', 'broken_large', 'broken_small', 'contamination'],
                'cable': ['good', 'bent_wire', 'cable_swap', 'combined', 'cut_inner_insulation', 'cut_outer_insulation',
                          'missing_cable', 'missing_wire', 'poke_insulation'],
                'capsule': ['good', 'crack', 'faulty_imprint', 'poke', 'scratch', 'squeeze'],
                'carpet': ['good', 'color', 'cut', 'hole', 'metal_contamination', 'thread'],
                'grid': ['good', 'bent', 'broken', 'glue', 'metal_contamination', 'thread'],
                'hazelnut': ['good', 'crack', 'cut', 'hole', 'print'],
                'leather': ['good', 'color', 'cut', 'fold', 'glue', 'poke'],
                'metal_nut': ['good', 'bent', 'color', 'flip', 'scratch'],
                'pill': ['good', 'color', 'combined', 'contamination', 'crack', 'faulty_imprint', 'pill_type',
                         'scratch'],
                'screw': ['good', 'manipulated_front', 'scratch_head', 'scratch_neck', 'thread_side', 'thread_top'],
                'tile': ['good', 'crack', 'glue_strip', 'gray_stroke', 'oil', 'rough'],
                'toothbrush': ['good', 'defective'],
                'transistor': ['good', 'bent_lead', 'cut_lead', 'damaged_case', 'misplaced'],
                'wood': ['good', 'color', 'combined', 'good', 'hole', 'liquid', 'scratch'],
                'zipper': ['good', 'broken_teeth', 'combined', 'fabric_border', 'fabric_interior', 'rough',
                           'split_teeth', 'squeezed_teeth']}

    def setup_base_path(self, path):
        self.base_path = path

    def load(self, category, repeat=4, max_rot=10):
        # data, mask, binary anomaly label (0 for anomaly, 1 for good)
        x, y, z = [], [], []

        # Load train set
        path = os.path.join(os.path.join(self.base_path, category), 'train\good')
        files = os.listdir(path)

        zero_mask = tf.zeros(shape=(224, 224), dtype=tf.int32)

        for rdx in range(repeat):
            for _files in files:
                full_path = os.path.join(path, _files)
                img = self._read_image(full_path=full_path)

                if not max_rot == 0:
                    img = tf.keras.preprocessing.image.random_rotation(img, max_rot)

                mask = zero_mask

                x.append(img)
                y.append(mask)
                z.append(1)

        x = np.asarray(x)
        y = np.asarray(y)
        self.num_train = len(x)

        x = tf.data.Dataset.from_tensor_slices(tf.convert_to_tensor(x, dtype=tf.float32))
        y = tf.data.Dataset.from_tensor_slices(tf.convert_to_tensor(y, dtype=tf.int32))
        z = tf.data.Dataset.from_tensor_slices(tf.convert_to_tensor(z, dtype=tf.int32))

        self.train = tf.data.Dataset.zip((x, y, z))

        # data, anomaly label (e.g., good, cut, ..., etc.), binary anomaly label (0 for anomaly, 1 for good)
        x, y, z = [], [], []

        # Load test set
        for _label in self.category[category]:
            path = os.path.join(os.path.join(self.base_path, category), f'test\\{_label}')

            files = os.listdir(path)
            for _files in files:
                full_path = os.path.join(path, _files)
                img = self._read_image(full_path=full_path)

                if _label == 'good':
                    mask = zero_mask
                else:
                    mask_path = os.path.join(os.path.join(self.base_path, category), 'ground_truth/{}'.format(_label))
                    _mask_path = os.path.join(mask_path, '{}_mask.png'.format(_files.split('.')[0]))
                    mask = cv2.resize(cv2.imread(_mask_path, flags=cv2.IMREAD_GRAYSCALE), dsize=(256, 256)) / 255
                    mask = mask[16:-16, 16:-16]
                    mask = tf.convert_to_tensor(mask, dtype=tf.int32)

                x.append(img)
                y.append(mask)
                z.append(int(self.category[category].index(_label) == 0))

        x = np.asarray(x)
        y = np.asarray(y)
        self.num_test = len(x)

        x = tf.data.Dataset.from_tensor_slices(tf.convert_to_tensor(x, dtype=tf.float32))
        y = tf.data.Dataset.from_tensor_slices(tf.convert_to_tensor(y, dtype=tf.int32))
        z = tf.data.Dataset.from_tensor_slices(tf.convert_to_tensor(z, dtype=tf.int32))

        self.test = tf.data.Dataset.zip((x, y, z))

    @staticmethod
    def _read_image(full_path, flags=cv2.IMREAD_COLOR):
        img = cv2.imread(full_path, flags=flags)
        b, g, r = cv2.split(img)
        img = cv2.merge([r, g, b])

        img = cv2.resize(img, dsize=(256, 256))

        img = img[16:-16, 16:-16, :]

        return img

In [6]:
# utils.py
def embedding_concat(l1, l2):
    bs, h1, w1, c1 = l1.shape
    _, h2, w2, c2 = l2.shape

    s = int(h1 / h2)
    x = tf.compat.v1.extract_image_patches(l1, ksizes=[1, s, s, 1], strides=[1, s, s, 1], rates=[1, 1, 1, 1],
                                           padding='VALID')
    x = tf.reshape(x, (bs, -1, h2, w2, c1))

    col_z = []
    for idx in range(x.shape[1]):
        col_z.append(tf.concat([x[:, idx, :, :, :], l2], axis=-1))
    z = tf.stack(col_z, axis=1)

    z = tf.reshape(z, (bs, h2, w2, -1))
    if s == 1:
        return z
    z = tf.nn.depth_to_space(z, block_size=s)

    return z


def plot_fig(test_img, scores, gts, threshold, save_dir, class_name):
    num = len(scores)
    vmax = scores.max() * 255.
    vmin = scores.min() * 255.
    for i in range(num):
        img = test_img[i][0]

        gt = gts[i].transpose(1, 2, 0).squeeze()

        heat_map = scores[i] * 255
        mask = scores[i]
        mask[mask > threshold] = 1
        mask[mask <= threshold] = 0

        kernel = morphology.disk(4)
        mask = morphology.opening(mask, kernel)
        mask *= 255
        vis_img = mark_boundaries(img, mask, color=(1, 0, 0), mode='thick')
        fig_img, ax_img = plt.subplots(1, 5, figsize=(12, 3))
        fig_img.subplots_adjust(right=0.9)
        norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
        for ax_i in ax_img:
            ax_i.axes.xaxis.set_visible(False)
            ax_i.axes.yaxis.set_visible(False)
        ax_img[0].imshow(img.astype(int))
        ax_img[0].title.set_text('Image')
        ax_img[1].imshow(gt.astype(int), cmap='gray')
        ax_img[1].title.set_text('GroundTruth')
        ax = ax_img[2].imshow(heat_map, cmap='jet', norm=norm)
        ax_img[2].imshow(img.astype(int), cmap='gray', interpolation='none')
        ax_img[2].imshow(heat_map, cmap='jet', alpha=0.5, interpolation='none')
        ax_img[2].title.set_text('Predicted heat map')
        ax_img[3].imshow(mask.astype(int), cmap='gray')
        ax_img[3].title.set_text('Predicted mask')
        ax_img[4].imshow(vis_img.astype(int))
        ax_img[4].title.set_text('Segmentation result')
        left = 0.92
        bottom = 0.15
        width = 0.015
        height = 1 - 2 * bottom
        rect = [left, bottom, width, height]
        cbar_ax = fig_img.add_axes(rect)
        cb = plt.colorbar(ax, shrink=0.6, cax=cbar_ax, fraction=0.046)
        cb.ax.tick_params(labelsize=8)
        font = {
            'family': 'serif',
            'color': 'black',
            'weight': 'normal',
            'size': 8,
        }
        cb.set_label('Anomaly Score', fontdict=font)

        fig_img.savefig(os.path.join(save_dir, class_name + '_{}'.format(i)), dpi=100)
        plt.close()


def draw_auc(fp_list, tp_list, auc, path):
    plt.figure()
    plt.plot(fp_list, tp_list, color='darkorange', lw=2, label='ROC curve (area = {:.4f})'.format(auc))

    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic example')
    plt.legend(loc="lower right")
    plt.savefig(path)

    plt.clf()
    plt.cla()
    plt.close()


def draw_precision_recall(precision, recall, base_line, path):
    f1_score = []
    for _idx in range(0, len(precision)):
        _precision = precision[_idx]
        _recall = recall[_idx]

        if _precision + _recall == 0:
            _f1 = 0
        else:
            _f1 = 2 * (_precision * _recall) / (_precision + _recall)
        f1_score.append(_f1)

    plt.figure()
    plt.plot(recall, precision, marker='.', label='precision-recall curve')
    plt.plot([0, 1], [base_line, base_line], linestyle='--', color='grey', label='No skill ({:.04f})'.format(base_line))
    plt.plot(recall, f1_score, linestyle='-', color='red', label='f1 score (Max.: {:.4f})'.format(np.max(f1_score)))
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title('Precision-Recall Curve')
    plt.legend(loc='lower left')
    plt.savefig(path)

    plt.clf()
    plt.cla()
    plt.close()

    return np.max(f1_score)

In [7]:
# input_tensor = tf.keras.layers.Input([224, 224, 3], dtype=tf.float32)
# x = tf.keras.applications.efficientnet.preprocess_input(input_tensor)
# model = tf.keras.applications.EfficientNetB7(include_top=False,
#                                              weights='imagenet',
#                                              input_tensor=x,
#                                              pooling=None)                                
# model.summary()

In [8]:
# padim.py 
def embedding_net(net_type='res'):
    input_tensor = tf.keras.layers.Input([224, 224, 3], dtype=tf.float32)

    if net_type == 'res':
        # resnet 50v2
        x = tf.keras.applications.resnet_v2.preprocess_input(input_tensor)
        model = tf.keras.applications.ResNet50V2(include_top=False, weights='imagenet', input_tensor=x, pooling=None)

        layer1 = model.get_layer(name='conv3_block1_preact_relu').output
        layer2 = model.get_layer(name='conv4_block1_preact_relu').output
        layer3 = model.get_layer(name='conv5_block1_preact_relu').output

    elif net_type == 'eff':
        # efficient net B7
        x = tf.keras.applications.efficientnet.preprocess_input(input_tensor)
        model = tf.keras.applications.EfficientNetB7(include_top=False, weights='imagenet', input_tensor=x,
                                                     pooling=None)

        layer1 = model.get_layer(name='block5a_activation').output
        layer2 = model.get_layer(name='block6a_activation').output
        layer3 = model.get_layer(name='block7a_activation').output
        
    elif net_type == 'eff_net1':
        # efficient net B7; according to github gives best location indication
        x = tf.keras.applications.efficientnet.preprocess_input(input_tensor)
        model = tf.keras.applications.EfficientNetB7(include_top=False, weights='imagenet', input_tensor=x,
                                                     pooling=None)

        layer1 = model.get_layer(name='block5a_expand_activation').output
        layer2 = model.get_layer(name='block6a_expand_activation').output
        layer3 = model.get_layer(name='block7a_expand_activation').output
        
    elif net_type == 'res_chehlarov':
        # new by Chehlarov
        x = tf.keras.applications.resnet.preprocess_input(input_tensor)
        model = tf.keras.applications.ResNet101(include_top=False, weights='imagenet', input_tensor=x, pooling=None)

        layer1 = model.get_layer(name='conv3_block1_preact_relu').output
        layer2 = model.get_layer(name='conv4_block1_preact_relu').output
        layer3 = model.get_layer(name='conv5_block1_preact_relu').output

    else:
        raise Exception("[NotAllowedNetType] network type is not allowed ")

    model.trainable = False
    # model.summary(line_length=100)
    shape = (layer1.shape[1], layer1.shape[2], layer1.shape[3] + layer2.shape[3] + layer3.shape[3])

    return tf.keras.Model(model.input, outputs=[layer1, layer2, layer3]), shape

In [9]:
def padim(category, batch_size, rd, net_type='eff', is_plot=False):
    loader = MVTecADLoader()
    loader.load(category=category, repeat=1, max_rot=0)

    train_set = loader.train.batch(batch_size=batch_size, drop_remainder=True).shuffle(buffer_size=loader.num_train,
                                                                                       reshuffle_each_iteration=True)
    test_set = loader.test.batch(batch_size=1, drop_remainder=False)

    net, _shape = embedding_net(net_type=net_type)
    h, w, c = _shape  # height and width of layer1, channel sum of layer 1, 2, and 3, and randomly sampled dimension

    out = []
    for x, _, _ in train_set:
        l1, l2, l3 = net(x)
        _out = tf.reshape(embedding_concat(embedding_concat(l1, l2), l3), (batch_size, h * w, c))  # (b, h x w, c)
        out.append(_out.numpy())

    # calculate multivariate Gaussian distribution.
    out = np.concatenate(out, axis=0)
    out = np.transpose(out, axes=[0, 2, 1])  # (b, c, h * w)

    # RD: random dimension selecting
    tmp = tf.unstack(out, axis=0)
    _tmp = []
    rd_indices = tf.random.shuffle(tf.range(c))[:rd]
    for tensor in tmp:
        _tmp.append(tf.gather(tensor, rd_indices))
    out = tf.stack(_tmp, axis=0)

    mu = np.mean(out, axis=0)
    cov = np.zeros((rd, rd, h * w))
    identity = np.identity(rd)

    for idx in range(h * w):
        cov[:, :, idx] = np.cov(out[:, :, idx], rowvar=False) + 0.01 * identity

    train_outputs = [mu, cov]

    out, gt_list, gt_mask, batch_size, test_imgs = [], [], [], 1, []
    #  x - data |   y - mask    |   z - binary label
    for x, y, z in test_set:
        test_imgs.append(x.numpy())
        gt_list.append(z.numpy())
        gt_mask.append(y.numpy())

        l1, l2, l3 = net(x)
        _out = tf.reshape(embedding_concat(embedding_concat(l1, l2), l3), (batch_size, h * w, c))  # (BS, h x w, c)
        out.append(_out.numpy())

    # calculate multivariate Gaussian distribution. skip random dimension selecting
    out = np.concatenate(out, axis=0)
    gt_list = np.concatenate(gt_list, axis=0)
    out = np.transpose(out, axes=[0, 2, 1])

    # RD
    tmp = tf.unstack(out, axis=0)
    _tmp = []
    for tensor in tmp:
        _tmp.append(tf.gather(tensor, rd_indices)) # Chehlarov: why random - comments above say the opposite
    out = tf.stack(_tmp, axis=0)

    b, _, _ = out.shape

    dist_list = []
    for idx in range(h * w):
        mu = train_outputs[0][:, idx]
        cov_inv = np.linalg.inv(train_outputs[1][:, :, idx])
        dist = [mahalanobis(sample[:, idx], mu, cov_inv) for sample in out]
        dist_list.append(dist)

    dist_list = np.reshape(np.transpose(np.asarray(dist_list), axes=[1, 0]), (b, h, w))

    ################
    #   DATA Level #
    ################
    # upsample
    score_map = tf.squeeze(tf.image.resize(np.expand_dims(dist_list, -1), size=[h, w])).numpy()

    for i in range(score_map.shape[0]):
        score_map[i] = gaussian_filter(score_map[i], sigma=4)

    # Normalization
    max_score = score_map.max()
    min_score = score_map.min()
    scores = (score_map - min_score) / (max_score - min_score)
    scores = -scores

    # calculate image-level ROC AUC score
    img_scores = scores.reshape(scores.shape[0], -1).max(axis=1)

    gt_list = np.asarray(gt_list)
    img_roc_auc = metrics.roc_auc_score(gt_list, img_scores)

    if is_plot is True:
        fpr, tpr, _ = metrics.roc_curve(gt_list, img_scores)
        precision, recall, _ = metrics.precision_recall_curve(gt_list, img_scores)

        save_dir = os.path.join(os.getcwd(), 'img')
        if os.path.isdir(save_dir) is False:
            os.mkdir(save_dir)
        draw_auc(fpr, tpr, img_roc_auc, os.path.join(save_dir, 'AUROC-{}.png'.format(category)))
        base_line = np.sum(gt_list) / len(gt_list)
        draw_precision_recall(precision, recall, base_line, os.path.join(os.path.join(save_dir,
                                                                                      'PR-{}.png'.format(category))))

    #################
    #   PATCH Level #
    #################
    # upsample
    score_map = tf.squeeze(tf.image.resize(np.expand_dims(dist_list, -1), size=[224, 224])).numpy()

    for i in range(score_map.shape[0]):
        score_map[i] = gaussian_filter(score_map[i], sigma=4)

    # Normalization
    max_score = score_map.max()
    min_score = score_map.min()
    scores = (score_map - min_score) / (max_score - min_score)
    # Note that Binary mask indicates 0 for good and 1 for anomaly. It is opposite from our setting.
    # scores = -scores

    # calculate per-pixel level ROCAUC
    gt_mask = np.asarray(gt_mask)
    fp_list, tp_list, _ = metrics.roc_curve(gt_mask.flatten(), scores.flatten())
    patch_auc = metrics.auc(fp_list, tp_list)

    precision, recall, threshold = metrics.precision_recall_curve(gt_mask.flatten(), scores.flatten(), pos_label=1)
    numerator = 2 * precision * recall
    denominator = precision + recall

    numerator[np.where(denominator == 0)] = 0
    denominator[np.where(denominator == 0)] = 1

    # get optimal threshold
    f1_list = numerator / denominator
    best_ths = threshold[np.argmax(f1_list).astype(int)]

    print('[{}] image ROCAUC: {:.04f}\t pixel ROCAUC: {:.04f}'.format(category, img_roc_auc, patch_auc))

    if is_plot is True:
        save_dir = os.path.join(os.getcwd(), 'img')
        if os.path.isdir(save_dir) is False:
            os.mkdir(save_dir)
        plot_fig(test_imgs, scores, gt_mask, best_ths, save_dir, category)

    return img_roc_auc, patch_auc

In [10]:
np.random.seed(10)
# random.seed(10)
tf.random.set_seed(10)
padim(category='carpet', batch_size=2, rd=400, net_type='eff_net1', is_plot=True)

[carpet] image ROCAUC: 0.9446	 pixel ROCAUC: 0.9720


(0.9446227929373997, 0.9720316032862548)