# Installs

In [None]:
# Mount Google Drive at /content/drive (Google Colab only).
# You don't need to run this cell unless your data/results live on Google Drive.

from google.colab import drive
drive.mount("/content/drive", force_remount=True)

## DON'T WORRY IF THE INSTALLS SHOW AN ERROR; THE NOTEBOOK WILL RUN ANYWAY.
### THE FOLLOWING ERROR IS EXPECTED:

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
yellowbrick 1.4 requires scikit-learn>=1.0.0, but you have scikit-learn 0.24.2 which is incompatible.
torchtext 0.12.0 requires torch==1.11.0, but you have torch 1.10.2 which is incompatible.
torchaudio 0.11.0+cu113 requires torch==1.11.0, but you have torch 1.10.2 which is incompatible.
albumentations 0.1.12 requires imgaug<0.2.7,>=0.2.5, but you have imgaug 0.2.9 which is incompatible.

In [None]:
# Python dependencies; several are pinned to the versions this notebook was written against.
!pip install --quiet "torch" "pytorch-lightning" "opencv-python==4.5.2.52" "scikit-learn==0.24.2" "torchmetrics" "torchvision==0.11.3" "pyod"
# FAISS (GPU build) is used for the nearest-neighbour index.
!pip install faiss-gpu==1.7.1.post3
# System libraries (OpenBLAS / OpenMP) — presumably runtime requirements of faiss; TODO confirm.
!sudo apt-get install libopenblas-dev
!sudo apt-get install libomp-dev


Check whether we are running on a GPU, and check how much memory is available.

In [None]:
# Capture the output of the `nvidia-smi` shell command (IPython `!` capture) and join it into one string.
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
# The word 'failed' in the command output is treated as "no GPU attached".
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

In [None]:
# Report how much RAM this runtime has and whether it counts as "high-RAM".
from psutil import virtual_memory

# `total` is the total physical memory in bytes; convert to decimal gigabytes.
ram_gb = virtual_memory().total / 1e9
print(f'Your runtime has {ram_gb:.1f} gigabytes of available RAM\n')

# Anything below 20 GB is reported as a standard (not high-RAM) runtime.
print('Not using a high-RAM runtime' if ram_gb < 20 else 'You are using a high-RAM runtime!')

# Imports

In [None]:
import numpy as np
import cv2
import os
import psutil
import os.path as path
import glob
import shutil
import pickle
import faiss
import torch
import argparse
import gc
import torchvision.models as models
import torch.nn.functional as F
import pytorch_lightning as pl
import tensorflow as tf

from PIL import Image
from torch import nn
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

from tqdm.notebook import tqdm

from sklearn.metrics import roc_auc_score, confusion_matrix
from sklearn.random_projection import SparseRandomProjection
from sklearn.neighbors import NearestNeighbors
from scipy.ndimage import gaussian_filter

from tempfile import mkdtemp


In [None]:
# Print whatever TensorFlow reports as the GPU device name
# (presumably an empty string when no GPU is visible — confirm against the TF docs).
print('Testing gpu availability: ', end='')
print(tf.test.gpu_device_name())

# DATASET

In [None]:
class MVTecDataset(Dataset):
    """Dataset over an MVTec-style folder tree.

    Expected layout per category::

        <root>/<category>/train/good/*.png
        <root>/<category>/test/<defect_type or 'good'>/*.png
        <root>/<category>/ground_truth/<defect_type>/*.png   (optional masks)

    Args:
        root: folder containing one sub-folder per category.
        category: a category folder name. If it contains '_' and is NOT itself
            a folder under ``root``, it is interpreted as several category
            names joined by '_' (e.g. ``'bottle_cable'``).
        transform: callable applied to every RGB image.
        gt_transform: callable applied to every ground-truth mask image.
        phase: ``'train'`` reads the ``train`` split, anything else reads ``test``.
        load_complete_data: if true, use every category folder found in ``root``.

    Attributes:
        img_paths / gt_paths / labels / types: parallel lists built by
        :meth:`load_dataset`. ``labels`` is 0 for 'good' and 1 for anomalies;
        ``gt_paths`` holds a mask path, or the placeholder ``0`` when the
        sample has no mask.
    """

    def __init__(self, root, category, transform, gt_transform, phase, load_complete_data=False):
        split_dir = 'train' if phase == 'train' else 'test'
        categories = self._resolve_categories(root, category, load_complete_data)

        self.img_paths_root = [os.path.join(root, name, split_dir) for name in categories]
        # Defined for every phase so load_dataset never hits a missing attribute.
        self.gt_paths_root = [os.path.join(root, name, 'ground_truth') for name in categories]

        self.transform = transform
        self.gt_transform = gt_transform
        # load dataset
        self.img_paths, self.gt_paths, self.labels, self.types = self.load_dataset() # self.labels => good : 0, anomaly : 1

    @staticmethod
    def _resolve_categories(root, category, load_complete_data):
        """Return the list of category folder names selected by the arguments."""
        if load_complete_data:
            return sorted(name for name in os.listdir(root)
                          if os.path.isdir(os.path.join(root, name)))
        # A real folder such as 'metal_nut' must not be split at the underscore.
        if '_' in category and not os.path.isdir(os.path.join(root, category)):
            return category.split('_')
        return [category]

    def load_dataset(self):
        """Scan the split folders and return (img_paths, gt_paths, labels, types)."""
        img_tot_paths = []
        gt_tot_paths = []
        tot_labels = []
        tot_types = []

        for ix, img_path in enumerate(self.img_paths_root):
            for defect_type in sorted(os.listdir(img_path)):
                type_dir = os.path.join(img_path, defect_type)
                if not os.path.isdir(type_dir):
                    continue  # stray files next to the class folders

                # Sorted so images and masks pair up and the order is reproducible.
                img_paths = sorted(glob.glob(type_dir + "/*.png"))

                if defect_type == 'good':
                    gt_paths = [0] * len(img_paths)
                    label = 0
                else:
                    gt_dir = os.path.join(self.gt_paths_root[ix], defect_type)
                    gt_paths = sorted(glob.glob(gt_dir + "/*.png"))
                    if not gt_paths:
                        # No masks at all for this defect type: image-level
                        # evaluation still works, __getitem__ yields an empty mask.
                        gt_paths = [0] * len(img_paths)
                    label = 1

                img_tot_paths.extend(img_paths)
                gt_tot_paths.extend(gt_paths)
                tot_labels.extend([label] * len(img_paths))
                tot_types.extend([defect_type] * len(img_paths))

        assert len(img_tot_paths) == len(gt_tot_paths), "Something wrong with test and ground truth pair!"
        return img_tot_paths, gt_tot_paths, tot_labels, tot_types

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Return (image, mask, label, file stem, defect type) for sample ``idx``."""
        img_path, gt, label, img_type = self.img_paths[idx], self.gt_paths[idx], self.labels[idx], self.types[idx]
        img = Image.open(img_path).convert('RGB')
        img = self.transform(img)
        if gt == 0:
            # Placeholder mask: all zeros with the image's height AND width.
            gt = torch.zeros([1, img.size()[-2], img.size()[-1]])
        else:
            gt = Image.open(gt)
            gt = self.gt_transform(gt)

        assert img.size()[1:] == gt.size()[1:], "image.size != gt.size !!!"

        return img, gt, label, os.path.basename(img_path[:-4]), img_type

# Filemanagement

In [None]:
def copy_files(src, dst, ignores=()):
    """Recursively copy the contents of ``src`` into ``dst``.

    Args:
        src: existing source directory.
        dst: destination directory; created (with parents) if it is missing.
        ignores: iterable of substrings; any entry whose name contains one of
            them is skipped (files and whole sub-directories alike).
    """
    # shutil.copy needs the target folder to exist, so create it up front.
    os.makedirs(dst, exist_ok=True)

    for file_name in os.listdir(src):
        if any(pattern in file_name for pattern in ignores):
            continue
        src_entry = os.path.join(src, file_name)
        dst_entry = os.path.join(dst, file_name)
        if os.path.isfile(src_entry):
            shutil.copy(src_entry, dst_entry)
        elif os.path.isdir(src_entry):
            copy_files(src_entry, dst_entry, ignores)

def prep_dirs(root):
    """Create the working folders for a run and return their paths.

    The embeddings folder is ``./embeddings/<args.category>`` (relative to the
    current working directory, using the module-level ``args``); the sample and
    source-code folders are ``<root>/sample`` and ``<root>/src``.

    Returns:
        (embeddings_path, sample_path, source_code_save_path)
    """
    def _ensure(directory):
        # Create the folder if needed and hand the path back.
        os.makedirs(directory, exist_ok=True)
        return directory

    # (An earlier variant placed the embeddings under `root` instead.)
    embeddings_path = _ensure(os.path.join('./', 'embeddings', args.category))
    sample_path = _ensure(os.path.join(root, 'sample'))
    # Only the folder is created; copying the source code into it is disabled.
    source_code_save_path = _ensure(os.path.join(root, 'src'))
    return embeddings_path, sample_path, source_code_save_path


# Evaluation

In [None]:
def cal_confusion_matrix(y_true, y_pred_no_thresh, thresh, img_path_list):
    """Threshold raw scores, print and return the confusion matrix and the misclassified items.

    Args:
        y_true: ground-truth labels (0 or 1), indexable by position.
        y_pred_no_thresh: raw scores; a score strictly above ``thresh`` predicts 1.
        thresh: decision threshold.
        img_path_list: identifiers aligned with the scores, used for reporting.

    Returns:
        (cm, false_p, false_n): the confusion matrix plus the identifiers of
        false positives and false negatives.
    """
    pred_thresh, false_p, false_n = [], [], []

    for position, raw_score in enumerate(y_pred_no_thresh):
        flagged = raw_score > thresh
        pred_thresh.append(1 if flagged else 0)
        actual = y_true[position]
        if flagged and actual == 0:
            false_p.append(img_path_list[position])
        elif not flagged and actual == 1:
            false_n.append(img_path_list[position])

    cm = confusion_matrix(y_true, pred_thresh)
    print(cm, 'false positive', false_p, 'false negative', false_n, sep='\n')
    return cm, false_p, false_n
    

# Distances

In [None]:
def distance_matrix(x, y=None, power=2, distance_batch_size=0):  # pairwise distance of vectors
    """Pairwise ``sum_k |x_i[k] - y_j[k]| ** power`` between the rows of ``x`` and ``y``.

    Note that the result is NOT rooted; callers apply ``** (1 / power)`` themselves.

    Args:
        x: tensor of shape (n, d).
        y: tensor of shape (m, d); defaults to ``x``.
        power: exponent of the Minkowski-style distance. The absolute difference
            is used so that odd exponents also give a valid (non-negative) distance.
        distance_batch_size: 0 computes everything at once; a positive value
            processes that many rows of ``y`` per step to bound peak memory.

    Returns:
        Tensor of shape (n, m).

    Raises:
        ValueError: if ``distance_batch_size`` is negative.
    """
    if distance_batch_size < 0:
        raise ValueError(f'Parameter "distance_batch_size"={distance_batch_size} is Invalid. Please define a positive Integer.')

    y = x if y is None else y

    n, m, d = x.size(0), y.size(0), x.size(1)

    # expand() creates broadcast views, so nothing of size n*m*d exists until the subtraction.
    x_unsqueezed = x.unsqueeze(1).expand(n, m, d)
    y_unsqueezed = y.unsqueeze(0).expand(n, m, d)

    def _partial(start, stop):
        diff = x_unsqueezed[:, start:stop, :] - y_unsqueezed[:, start:stop, :]
        return torch.pow(torch.abs(diff), power).sum(2)

    if distance_batch_size == 0 or m == 0:
        return _partial(0, m)

    return torch.cat([_partial(start, start + distance_batch_size)
                      for start in range(0, m, distance_batch_size)], 1)

class NN():
    """1-nearest-neighbour lookup: returns the label of the closest stored point."""

    def __init__(self, X=None, Y=None, p=2):
        # p is the exponent of the distance (2 => Euclidean).
        self.p = p
        self.train(X, Y)

    def train(self, X, Y):
        """Store the reference points and their labels (no fitting is performed)."""
        self.train_pts, self.train_label = X, Y

    def __call__(self, x):
        return self.predict(x)

    def predict(self, x):
        """Return, for each row of ``x``, the label of its nearest stored point.

        Raises:
            RuntimeError: if points or labels have not been provided yet.
        """
        missing_data = self.train_pts is None or self.train_label is None
        if missing_data:
            name = type(self).__name__
            raise RuntimeError(f"{name} wasn't trained. Need to execute {name}.train() first")

        dist = distance_matrix(x, self.train_pts, self.p) ** (1 / self.p)
        nearest = torch.argmin(dist, dim=1)
        return self.train_label[nearest]

class KNN(NN):
    """k-nearest-neighbour search that returns the k smallest distances and their indices."""

    def __init__(self, X=None, Y=None, k=3, p=2, distance_batch_size=0):
        # Both attributes must exist before the base constructor calls train().
        self.k = k
        self.distance_batch_size = distance_batch_size
        super().__init__(X, Y, p)

    def train(self, X, Y):
        """Store points/labels and, when labels are given, cache the distinct label values."""
        super().train(X, Y)
        if Y is not None:
            self.unique_labels = self.train_label.unique()

    def predict(self, x):
        """Return ``topk`` (values, indices) of the ``k`` nearest stored points for each row of ``x``."""
        powered = distance_matrix(
            x,
            self.train_pts,
            self.p,
            distance_batch_size=self.distance_batch_size,
        )
        return (powered ** (1 / self.p)).topk(self.k, largest=False)


# Sampling

In [None]:
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Returns points that minimizes the maximum distance of any point to a center.
Implements the k-Center-Greedy method in
Ozan Sener and Silvio Savarese.  A Geometric Approach to Active Learning for
Convolutional Neural Networks. https://arxiv.org/abs/1708.00489 2017
Distance metric defaults to l2 distance.  Features used to calculate distance
are either raw features or if a model has transform method then uses the output
of model.transform(X).
Can be extended to a robust k centers algorithm that ignores a certain number of
outlier datapoints.  Resulting centers are solution to multiple integer program.
"""

from sklearn.metrics import pairwise_distances
import abc
import numpy as np

class SamplingMethod(object):
  """Base class for sampling strategies operating on a feature matrix ``X``."""

  # Python-2 style metaclass declaration, kept as in the upstream source.
  # Under Python 3 this attribute has no effect, so the class can be instantiated.
  __metaclass__ = abc.ABCMeta

  @abc.abstractmethod
  def __init__(self, X, y, seed, **kwargs):
    self.X = X
    self.y = y
    self.seed = seed

  def flatten_X(self):
    """Return ``X`` as a 2-D array of shape (n_samples, n_features).

    Arrays with more than two dimensions are reshaped so that all trailing
    dimensions are merged; 1-D/2-D inputs are returned unchanged.
    """
    shape = self.X.shape
    flat_X = self.X
    if len(shape) > 2:
      # np.prod instead of np.product: the latter was removed in NumPy 2.0.
      flat_X = np.reshape(self.X, (shape[0], np.prod(shape[1:])))
    return flat_X


  @abc.abstractmethod
  def select_batch_(self):
    return

  def select_batch(self, **kwargs):
    """Public entry point; forwards all keyword arguments to ``select_batch_``."""
    return self.select_batch_(**kwargs)

  def to_dict(self):
    return None

class kCenterGreedy(SamplingMethod):
  """Greedy k-center selection: repeatedly pick the point farthest from all chosen centers."""

  def __init__(self, X, y, seed, metric='euclidean'):
    # NOTE: `seed` is accepted for interface compatibility but not used here.
    self.X = X
    self.y = y
    self.flat_X = self.flatten_X()
    self.name = 'kcenter'
    self.features = self.flat_X
    self.metric = metric
    self.min_distances = None
    self.n_obs = self.X.shape[0]
    self.already_selected = []

  def update_distances(self, cluster_centers, only_new=True, reset_dist=False):
    """Update min distances given cluster centers.
    Args:
      cluster_centers: indices of cluster centers
      only_new: only calculate distance for newly selected points and update
        min_distances.
      reset_dist: whether to reset min_distances.
    """

    if reset_dist:
      self.min_distances = None
    if only_new:
      cluster_centers = [d for d in cluster_centers
                         if d not in self.already_selected]
    if cluster_centers:
      # Update min_distances for all examples given new cluster center.
      x = self.features[cluster_centers]
      dist = pairwise_distances(self.features, x, metric=self.metric)

      # Reduce over the supplied centers first so min_distances always stays
      # of shape (n_obs, 1), also when several centers are passed at once.
      nearest = np.min(dist, axis=1).reshape(-1, 1)

      if self.min_distances is None:
        self.min_distances = nearest
      else:
        self.min_distances = np.minimum(self.min_distances, nearest)

  def select_batch_(self, model, already_selected, N, **kwargs):
    """
    Diversity promoting active learning method that greedily forms a batch
    to minimize the maximum distance to a cluster center among all unlabeled
    datapoints.
    Args:
      model: object with a transform(X) method used to compute the features,
        or a falsy value to use X directly
      already_selected: index of datapoints already selected
      N: batch size
    Returns:
      indices of points selected to minimize distance to cluster centers
    """

    try:
      # Assumes that the transform function takes in original data and not
      # flattened data.
      print('Getting transformed features...')
      if model:
        self.features = model.transform(self.X)
      else:
        self.features = self.X

      print('Calculating distances...')
      self.update_distances(already_selected, only_new=False, reset_dist=True)
    except Exception:
      # Best-effort fallback kept from the original; KeyboardInterrupt and
      # SystemExit are no longer swallowed.
      print('Using flat_X as features.')
      self.update_distances(already_selected, only_new=True, reset_dist=False)

    new_batch = []

    for _ in tqdm(range(N)):
      if self.already_selected is None:
        # Initialize centers with a randomly selected datapoint
        ind = np.random.choice(np.arange(self.n_obs))
      else:
        # NOTE(review): with no initial centers min_distances is still None on
        # the first pick; np.argmax then presumably yields index 0 — confirm.
        ind = np.argmax(self.min_distances)
      # New examples should not be in already selected since those points
      # should have min_distance of zero to a cluster center.
      assert ind not in already_selected

      self.update_distances([ind], only_new=True, reset_dist=False)
      new_batch.append(ind)

    self.already_selected = already_selected
    return new_batch

In [None]:
def embedding_concat(x, y):
    """Channel-wise concatenation of two feature maps with different spatial sizes.

    ``x`` is (B, C1, H1, W1) and ``y`` is (B, C2, H2, W2) with H1 a multiple of H2.
    ``x`` is cut into non-overlapping s x s tiles (s = H1 / H2), every tile
    position is paired with ``y``, and the tiles are folded back, giving a
    (B, C1 + C2, H1, W1) tensor. The output buffer is created with
    ``torch.zeros`` defaults.

    Adapted from https://github.com/xiahaifeng1995/PaDiM-Anomaly-Detection-Localization-master
    """
    batch, c_fine, h_fine, w_fine = x.size()
    _, c_coarse, h_coarse, w_coarse = y.size()
    scale = int(h_fine / h_coarse)

    # (B, C1 * s*s, H2*W2) -> (B, C1, s*s, H2, W2): one slice per position inside a tile.
    tiles = F.unfold(x, kernel_size=scale, dilation=1, stride=scale)
    tiles = tiles.view(batch, c_fine, -1, h_coarse, w_coarse)
    n_offsets = tiles.size(2)

    merged = torch.zeros(batch, c_fine + c_coarse, n_offsets, h_coarse, w_coarse)
    for offset in range(n_offsets):
        merged[:, :, offset, :, :] = torch.cat((tiles[:, :, offset, :, :], y), 1)

    merged = merged.view(batch, -1, h_coarse * w_coarse)
    return F.fold(merged, kernel_size=scale, output_size=(h_fine, w_fine), stride=scale)

def reshape_embedding(embedding):
    """Flatten a (N, C, H, W) array into a list of N*H*W feature vectors of length C.

    Vectors are ordered by image, then row, then column.
    """
    dims = embedding.shape
    return [
        embedding[image, :, row, col]
        for image in range(dims[0])
        for row in range(dims[2])
        for col in range(dims[3])
    ]


# Patchcore Custom

In [None]:
class STPM(pl.LightningModule):
    """PatchCore-style anomaly detector on top of a frozen, pre-trained backbone.

    Training: the outputs of ``layer2`` and ``layer3`` (captured with forward
    hooks) are pooled, concatenated into per-patch embeddings and written to
    disk batch by batch. At the end of the epoch they are sub-sampled with
    k-center greedy selection and stored both as a pickle and as a FAISS index.

    Testing: every test image's patch embeddings are matched against the FAISS
    index; the image-level score is derived from the nearest-neighbour distances.

    All settings are read from the module-level ``args`` namespace. The
    ``hparams`` constructor argument is accepted but not otherwise used.
    """

    def __init__(self, hparams):
        super(STPM, self).__init__()
        #self.save_hyperparameters(hparams)

        self.init_features()

        def hook_t(module, input, output):
            # Forward hook: collect the hooked layer's output for the current pass.
            self.features.append(output)

        # USE THIS IF YOU ARE INTERESTED IN efficientnet_b7
        # self.model = models.efficientnet_b7(pretrained=True)
        # self.model._modules["features"][5][-1]._modules['block'][-1]._modules['0'].register_forward_hook(hook_t)
        # self.model._modules["features"][6][-1]._modules['block'][-1]._modules['0'].register_forward_hook(hook_t)

        torch.hub._validate_not_a_forked_repo=lambda a,b,c: True # bug workaround to load resnet correctly
        self.model = torch.hub.load(args.pytorch_version, args.model_name, pretrained=True)
        self.model.layer2[-1].register_forward_hook(hook_t)
        self.model.layer3[-1].register_forward_hook(hook_t)

        # The backbone is a fixed feature extractor; nothing is optimised.
        for param in self.model.parameters():
            param.requires_grad = False

        # just for pytorch lightning
        self.criterion = torch.nn.MSELoss(reduction='sum')

        self.init_results_list()

        # imagenet
        mean_train = [0.485, 0.456, 0.406]
        std_train = [0.229, 0.224, 0.225]

        # transform data
        # Image.LANCZOS is the filter formerly exposed as Image.ANTIALIAS
        # (that alias was removed in Pillow 10).
        self.data_transforms = transforms.Compose([
                        transforms.Resize((args.load_size, args.load_size), Image.LANCZOS),
                        transforms.ToTensor(),
                        transforms.CenterCrop(args.input_size),
                        transforms.Normalize(mean=mean_train,
                                             std=std_train)])
        self.gt_transforms = transforms.Compose([
                        transforms.Resize((args.load_size, args.load_size)),
                        transforms.ToTensor(),
                        transforms.CenterCrop(args.input_size)])

        # Inverse of the normalisation above, derived from the same constants
        # (the blue-channel std is 0.225; it was previously mistyped as 0.255).
        self.inv_normalize = transforms.Normalize(
            mean=[-m / s for m, s in zip(mean_train, std_train)],
            std=[1 / s for s in std_train])

    def init_results_list(self):
        """Reset the per-run result accumulators used during testing."""
        self.gt_list_px_lvl = []
        self.pred_list_px_lvl = []
        self.gt_list_img_lvl = []
        self.pred_list_img_lvl = []
        self.img_path_list = []

    def init_features(self):
        """Clear the list the forward hooks append to."""
        self.features = []

    def forward(self, x_t):
        """Run the backbone and return the hooked feature maps [layer2, layer3]."""
        self.init_features()
        _ = self.model(x_t)
        return self.features

    def train_dataloader(self):
        image_datasets = MVTecDataset(root=os.path.join(args.dataset_path),
                                      category=args.category,
                                      transform=self.data_transforms,
                                      gt_transform=self.gt_transforms,
                                      phase='train',
                                      load_complete_data=args.load_complete_data)
        train_loader = DataLoader(image_datasets,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=4,
                                  pin_memory=True)
        return train_loader

    def test_dataloader(self):
        test_datasets = MVTecDataset(root=os.path.join(args.dataset_path),
                                     category=args.category,
                                     transform=self.data_transforms,
                                     gt_transform=self.gt_transforms,
                                     phase='test',
                                     load_complete_data=args.load_complete_data)
        test_loader = DataLoader(test_datasets,
                                 batch_size=1, # batchsize 1 is required
                                 shuffle=False,
                                 num_workers=4,
                                 pin_memory=True)
        return test_loader

    def configure_optimizers(self):
        return None

    def on_train_start(self):
        self.model.eval()
        _, self.sample_path, self.source_code_save_path = prep_dirs(self.logger.log_dir)

        self.embedding_dir_path = f'{args.embedding_path}/{args.category}'
        os.makedirs(self.embedding_dir_path, exist_ok=True)
        self.embedding_list = []
        # Per-batch embedding files written during the current epoch.
        self.batch_embedding_paths = []

    def on_test_start(self):
        self.embedding_dir_path = f'{args.embedding_path}/{args.category}'

        # NOTE: pickle can execute arbitrary code — only load embedding files
        # produced by this notebook.
        with open(os.path.join(self.embedding_dir_path, 'embedding.pickle'), 'rb') as f:
            embedding = pickle.load(f)
        # training_epoch_end stores the coreset array directly; a dict with an
        # 'embedding' key is still accepted.
        self.embedding_coreset = embedding['embedding'] if isinstance(embedding, dict) else embedding

        # test_step needs a FAISS index. After fit() it already exists; for a
        # test-only run load the saved one, or rebuild it from the coreset.
        if getattr(self, 'index', None) is None:
            index_path = os.path.join(self.embedding_dir_path, 'index.faiss')
            if os.path.isfile(index_path):
                self.index = faiss.read_index(index_path)
            else:
                self.index = faiss.IndexFlatL2(self.embedding_coreset.shape[1])
                self.index.add(self.embedding_coreset)
        #if torch.cuda.is_available():
        #    res = faiss.StandardGpuResources()
        #    self.index = faiss.index_cpu_to_gpu(res, 0 ,self.index)
        self.init_results_list()
        #self.embedding_dir_path, self.sample_path, self.source_code_save_path = prep_dirs(self.logger.log_dir)

    def training_step(self, batch, batch_idx):
        """Extract the patch embeddings of one batch and write them to disk."""
        x, _, _, file_name, _ = batch
        features = self(x)
        pool = torch.nn.AvgPool2d(3, 1, 1)
        embeddings = [pool(feature) for feature in features]
        embedding = embedding_concat(embeddings[0], embeddings[1])
        embedding_reshaped = reshape_embedding(np.array(embedding))

        # Spill to disk instead of accumulating in RAM.
        batch_path = os.path.join(self.embedding_dir_path, f'{args.category}_{batch_idx}_embedding.pickle')
        with open(batch_path, 'wb') as f:
            pickle.dump(embedding_reshaped, f)
        self.batch_embedding_paths.append(batch_path)

        del features
        del embeddings
        del embedding
        del embedding_reshaped
        gc.collect()

        memory_usage = psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2
        print(f'Batch_idx: {batch_idx}, Memory Usage: {memory_usage}')
        #self.embedding_list.extend(reshape_embedding(np.array(embedding)))

    def training_epoch_end(self, outputs):
        """Build the coreset memory bank from this epoch's per-batch files."""
        # Read back exactly the files written in this epoch, in batch order.
        # Scanning the folder instead would also pick up stale files left by
        # an earlier run that had more batches.
        self.embedding_list = []
        for embedding_path in self.batch_embedding_paths:
            with open(embedding_path, 'rb') as f:
                self.embedding_list.extend(pickle.load(f))
        self.batch_embedding_paths = []

        total_embeddings = np.array(self.embedding_list)

        print('total_embeddings.shape', total_embeddings.shape)

        # Random projection
        self.randomprojector = SparseRandomProjection(n_components='auto', eps=0.9) # 'auto' => Johnson-Lindenstrauss lemma
        self.randomprojector.fit(total_embeddings)

        # Coreset Subsampling
        selector = kCenterGreedy(total_embeddings,0,0)
        selected_idx = selector.select_batch(model=self.randomprojector, already_selected=[], N=int(total_embeddings.shape[0]*args.coreset_sampling_ratio))
        self.embedding_coreset = total_embeddings[selected_idx]

        # print('initial embedding size : ', total_embeddings.shape)
        print('final embedding size : ', self.embedding_coreset.shape)

        # faiss
        self.index = faiss.IndexFlatL2(self.embedding_coreset.shape[1])
        self.index.add(self.embedding_coreset)
        faiss.write_index(self.index,  os.path.join(self.embedding_dir_path,'index.faiss'))

        # save embedding
        with open(os.path.join(self.embedding_dir_path,'embedding.pickle'), 'wb') as f:
            pickle.dump(self.embedding_coreset, f)

    def test_step(self, batch, batch_idx): # Nearest Neighbour Search
        """Score one test image (batch size 1) against the memory bank."""
        x, gt, label, file_name, x_type = batch

        # extract embedding
        features = self(x)
        pool = torch.nn.AvgPool2d(3, 1, 1)
        embeddings = [pool(feature) for feature in features]

        embedding_ = embedding_concat(embeddings[0], embeddings[1])
        embedding_test = np.array(reshape_embedding(np.array(embedding_)))
        score_patches, _ = self.index.search(embedding_test , k=args.n_neighbors)

        # Neighbour distances of the most anomalous patch.
        N_b = score_patches[np.argmax(score_patches[:,0])]
        # w = 1 - max(softmax(N_b)); subtracting the maximum first leaves the
        # value unchanged but keeps np.exp from overflowing on large distances.
        exp_shifted = np.exp(N_b - np.max(N_b))
        w = (1 - (np.max(exp_shifted)/np.sum(exp_shifted)))
        score = w*max(score_patches[:,0]) # Image-level score
        gt_np = gt.cpu().numpy()[0,0].astype(int)

        self.gt_list_px_lvl.extend(gt_np.ravel())
        self.gt_list_img_lvl.append(label.cpu().numpy()[0])
        self.pred_list_img_lvl.append(score)
        self.img_path_list.extend(file_name)


        # COMMENT IN FOR ANOMALY MAPS

        # anomaly_map = score_patches[:,0].reshape((28,28))
        # anomaly_map_resized = cv2.resize(anomaly_map, (args.input_size, args.input_size))
        # anomaly_map_resized_blur = gaussian_filter(anomaly_map_resized, sigma=4)
        # self.pred_list_px_lvl.extend(anomaly_map_resized_blur.ravel())
        # save images
        # x = self.inv_normalize(x)
        # input_x = cv2.cvtColor(x.permute(0,2,3,1).cpu().numpy()[0]*255, cv2.COLOR_BGR2RGB)
        # self.save_anomaly_map(anomaly_map_resized_blur, input_x, gt_np*255, file_name[0], x_type[0])

    def test_epoch_end(self, outputs):
        """Compute the image-level AUROC, persist the results and the confusion matrix."""
        img_auc = roc_auc_score(self.gt_list_img_lvl, self.pred_list_img_lvl)

        # makedirs also creates missing parent folders of the result path.
        os.makedirs(args.result_path, exist_ok=True)

        save_path = f'{args.result_path}/{args.category}.pickle'

        # The key spelling 'groud_truth' is kept: existing result files use it.
        results = {
                  'groud_truth': self.gt_list_img_lvl,
                  'scores': self.pred_list_img_lvl,
                  'image_paths': self.img_path_list,
                  'img_auc': img_auc
                  }

        with open(save_path, 'wb') as f:
            pickle.dump(results, f)

        print("Total image-level auc-roc score :")
        print(img_auc)
        print('test_epoch_end')

        # thresholding
        cm, false_p, false_n = cal_confusion_matrix(self.gt_list_img_lvl,
                                                    self.pred_list_img_lvl,
                                                    img_path_list=self.img_path_list,
                                                    thresh = args.threshhold_img_lvl)
        output_cm_str = args.category + ':\n\t' + f'Confusion Matrix:\n\t{cm}\n\tFalse positive:{false_p}\n\tFalse negative:{false_n}'
        print(output_cm_str)

        with open(f'{args.result_path}/{args.category}_cofusion_matrix.txt', 'w') as f:
            f.write(output_cm_str)


        # COMMENT IN FOR PIXEL LEVEL !MAY NEED DEBUGGING
        # print("Total pixel-level auc-roc score :")
        # pixel_auc = roc_auc_score(self.gt_list_px_lvl, self.pred_list_px_lvl)
        # print(pixel_auc)
        # values = {'pixel_auc': pixel_auc, 'img_auc': img_auc}
        # self.log_dict(values)

# RUN

## Example and NECESSARY Folder Structure
### PLEASE NOTE: We create the folder "run_from_here" and cd into it. This is because the logs and some other folders are created relative to the working directory. I have not been able to remove this from the code...
### THIS MEANS that you probably need to prepend one more ../ to your data path

./
  - some_folder/
    - this_notebook.ipynb <--- this is this notebook
    - run_from_here/ <--- this gets created and we cd into this dir because of log creation
  - datafolder/
    - category/
      - ground_truth/ <--- only needed if we have masks and want pixel-level evaluation
        - anomaly_1/
          - 000.png
          - 001.png
      - train/ <--- THIS FOLDER NEEDS TO CONTAIN ONLY GOOD IMAGES, NO ANOMALIES!
        - good/
          - 000.png
          - 001.png
      - test/
        - good/
          - 000.png
          - 001.png
        - anomaly_1/
          - 000.png
          - 001.png
        - anomaly_2/
          - 000.png
          - 001.png


In [None]:
import os
import shutil
# Work from inside a "run_from_here" sub-folder: later cells use paths relative
# to the current directory (e.g. './embeddings' and default_root_dir='./<category>').
if not 'run_from_here' in os.getcwd():
  !mkdir run_from_here
  %cd run_from_here
else:
  # Already inside the folder (e.g. this cell was re-run): just show where we are.
  print( os.getcwd())
def _empty_cache():
  torch.cuda.empty_cache()
  gc.collect()
  gc.collect()

In [None]:
  def get_args(name='carpet', layers=None):
    """Build the run configuration for one category.

    Every option keeps its default; only ``--category`` is set (to ``name``).
    ``layers`` is accepted but not used.

    Returns:
        argparse.Namespace with the options listed below.
    """
    option_specs = [
        ('--phase', dict(choices=['train', 'test', 'train_test'], default='train_test')),
        ('--category', dict(default=None)),
        ('--load_complete_data', dict(default=False)),

        # paths
        # remember that we create run_from_here-folder so add ../ from where you would normaly put the path
        ('--dataset_path', dict(default='../drive/MyDrive/data/mvtec_anomaly_detection')),
        ('--project_root_path', dict(default='./')),
        ('--embedding_path', dict(default='../drive/MyDrive/data/test/wideresnet_embedding')),
        ('--result_path', dict(default='../drive/MyDrive/data/test/results')),

        # Training and Image level parameters
        ('--num_epochs', dict(default=1)),
        ('--batch_size', dict(default=32)),
        ('--load_size', dict(default=512)),  # 256
        ('--input_size', dict(default=224)),

        # Pytorch
        ('--pytorch_version', dict(default='pytorch/vision:v0.9.0')),
        ('--model_name', dict(default='wide_resnet50_2')),

        # Patchcore parameters
        # recommended, less is faster but performs worse
        ('--coreset_sampling_ratio', dict(default=0.1)),
        # fine tune this after you have your scores and your auroc
        ('--threshhold_img_lvl', dict(type=float, default=0.00097)),
        ('--n_neighbors', dict(type=int, default=9)),
        ('--nearest_neighbors', dict(type=int, default=9)),
        ('--distance_batch_size', dict(default=1000)),
    ]

    parser = argparse.ArgumentParser(description='ANOMALYDETECTION')
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    return parser.parse_args(['--category', name])


# Pick the accelerator once; the Trainer below requests a GPU only when one exists.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

_to_test = [
            'bottle' # this is your foldername which contains all the data
        ]

for category in _to_test:

    # `args` is deliberately module-level: STPM and prep_dirs read it as a global.
    args = get_args(category)
    print()
    print(args.category)
    print()

    trainer = Trainer.from_argparse_args(args,
                                         default_root_dir=os.path.join(args.project_root_path, args.category),
                                         max_epochs=args.num_epochs,
                                         gpus=1 if device.type == 'cuda' else 0,
                                         log_every_n_steps=1)
                                         # not needed but if you are interested
                                         #, tpu_cores=[4])) #, check_val_every_n_epoch=args.val_freq,  num_sanity_val_steps=0) # ,fast_dev_run=True)

    model = STPM(hparams=args)
    # 'train_test' runs both stages; the cache is released between and after them.
    if args.phase in ('train', 'train_test'):
        trainer.fit(model)
        _empty_cache()
    if args.phase in ('test', 'train_test'):
        trainer.test(model)
    _empty_cache()


In [None]:

# this is how you would load your embeddings (the coreset saved at the end of training):
args = get_args('bottle')
# NOTE: pickle can execute arbitrary code — only load files you created yourself.
with open(f'{args.embedding_path}/{args.category}/embedding.pickle', 'rb') as f:
  embedding_coreset = pickle.load(f)
embedding_coreset.shape

In [None]:

# this is how you would load your results:
args = get_args('bottle')

# NOTE: pickle can execute arbitrary code — only load files you created yourself.
with open(f'{args.result_path}/{args.category}.pickle', 'rb') as f:
  results = pickle.load(f)
  print(results)

# The confusion-matrix report is written as plain text, so read it in text mode
# (binary mode would show a bytes repr with escaped newlines and tabs).
with open(f'{args.result_path}/{args.category}_cofusion_matrix.txt', 'r') as f:
  cm_str = f.read()

print(cm_str)
