In [1]:
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse
import ast
import random
import matplotlib.pyplot as plt

from torch.distributions import MultivariateNormal
from IPython.display import clear_output
import scipy.ndimage as ndimage
import os
import cv2
from torch.utils.data import DataLoader, Dataset
import metrics
import joblib
import copy
from torchvision import transforms
from enum import Enum
import faiss

from dataloaders import dataloader_MVTec_setup
from utils import get_files_masks, evaluate
from evaluate import evaluate_metrics, get_scores, load_model, visualize_result
from DQN import DQN
from environment import Environment

In [4]:
def parse_tuple(value):
    """Turn a command-line string such as ``"(256,256)"`` into a Python literal.

    Meant to be used as an ``argparse`` ``type=`` callable. The text is
    evaluated with ``ast.literal_eval``; the result is returned unchanged and
    is not checked to actually be a tuple.

    Args:
        value: Text to evaluate.

    Returns:
        Whatever ``ast.literal_eval(value)`` produces.

    Raises:
        argparse.ArgumentTypeError: If ``literal_eval`` raises ``ValueError``
            or ``SyntaxError`` for ``value``.
    """
    try:
        parsed = ast.literal_eval(value)
    except (ValueError, SyntaxError):
        raise argparse.ArgumentTypeError(f"Invalid tuple: {value}")
    else:
        return parsed
def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty text (including ``"False"``) becomes ``True``. This converter
    maps the usual spellings explicitly and rejects everything else.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1", "on"):
        return True
    if text in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Invalid boolean: {value}")


parser = argparse.ArgumentParser()
# ------------
# setup
# ------------
parser.add_argument("--use_gpu", action="store_true", help="Use GPU for training")
# Alternative dataset roots used previously:
#parser.add_argument("--data_root", type=str, default= "./../../../../scratch-beauty/zzhan762/data/MVTec_AD")
#parser.add_argument("--data_root", type=str, default= "./mvtec")
#parser.add_argument("--data_root", type=str, default= "./data/mvtec")
parser.add_argument("--data_root", type=str, default= "./data/BTech_Dataset_transformed")
parser.add_argument("--verbose", type=_str2bool, default=True)
# ------------------
# feature extractor
# ------------------
# String defaults are passed through ``type`` by argparse, so these become tuples.
parser.add_argument("--target_size", type=parse_tuple, default="(256,256)")
parser.add_argument("--resize_size", type=parse_tuple, default="(256,256)")
parser.add_argument("--class_name", type=str, default="toothbrush")
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--patch_size",type=int, default=3)
parser.add_argument("--target_embed_dimension", type=int, default=1024)
parser.add_argument("--edc", action="store_true")
parser.add_argument("--backbone", type=str,default="wide_resnet50_2")
parser.add_argument("--preprocessing_dimension", type=int, default=1024)

# ------------
# DQN
# ------------
parser.add_argument("--action_dim", type=int, default=2)
parser.add_argument("--max_steps",type=int, default=int(4e4))
parser.add_argument("--eval_interval", type=int, default =int(1e3))
parser.add_argument("--lr",type=float,default = 0.00025)
parser.add_argument("--epsilon",type=float,default=1.0)
parser.add_argument("--epsilon_min",type=float,default=0.1)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--dqn_batch_size",type=int,default=32)
parser.add_argument("--warmup_steps",type=int, default=int(2e3))
parser.add_argument("--buffer_size",type=int,default=int(5e3))
parser.add_argument("--target_update_interval",type=int,default=int(5e3))

# ---------------
# Environment
# ---------------
parser.add_argument("--prob", type=float, default=0.5)
# NOTE(review): declared as float but defaults to the integer 40 -- confirm which the environment expects.
parser.add_argument("--max_samples",type=float,default=40)
parser.add_argument("--iForest_update_interval",type=int,default=int(2e3))
parser.add_argument("--iForest_max_samples",type=float,default=0.3)
parser.add_argument("--iForest_total_samples",type=int,default=int(1e6))
parser.add_argument("--iForest_batch_size",type=int,default=1024)


# -----------------------
# components
# -----------------------
parser.add_argument("--use_prioritized", type=_str2bool, default=False)
parser.add_argument("--use_intrinsic",type=_str2bool, default=False)
parser.add_argument("--use_copypaste", type=_str2bool, default=False)
parser.add_argument("--use_faiss", type=_str2bool, default=False)


# An explicit argument list is passed so that sys.argv is not parsed.
args = parser.parse_args(['--use_gpu'])

In [5]:
# Run on CUDA only when it is available and --use_gpu was given; otherwise on CPU.
if torch.cuda.is_available() and args.use_gpu:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f'device: {device}')

# Extra settings stored on the argparse namespace: the backbone is pinned here,
# while layers / hidden_sizes / state_dim have no command-line flag.
vars(args).update(
    backbone='wide_resnet50_2',
    layers=['layer2', 'layer3'],
    hidden_sizes=[512, 256, 128],
    state_dim=args.target_embed_dimension,
)

device: cuda


In [6]:
def _split_anomaly_files(testpath, maskpath, anomaly_types, train_num):
    """Pair anomaly images with masks per type and split the pairs into train/test lists."""
    train_files, train_masks, test_files, test_masks = [], [], [], []
    for anomaly in anomaly_types:
        anomaly_path = os.path.join(testpath, anomaly)
        mask_dir = os.path.join(maskpath, anomaly)
        image_names = sorted(os.listdir(anomaly_path))
        mask_names = sorted(os.listdir(mask_dir))
        # Images and masks are matched purely by their position in the two
        # sorted listings; zip() stops at the shorter listing.
        for i, (image_name, mask_name) in enumerate(zip(image_names, mask_names)):
            # The first `train_num` pairs of every anomaly type go to the train split.
            if i < train_num:
                files, masks = train_files, train_masks
            else:
                files, masks = test_files, test_masks
            files.append(os.path.join(anomaly_path, image_name))
            masks.append(os.path.join(mask_dir, mask_name))
    return train_files, train_masks, test_files, test_masks


def get_files_masks(dataroot,class_name,target_types, unknown_types,train_num, verbose=False):
    """Collect image and mask paths for one class and split them into train/test sets.

    Reads the following directories:
        ``<dataroot>/<class_name>/train/good``          normal training images
        ``<dataroot>/<class_name>/test/good``           normal test images
        ``<dataroot>/<class_name>/test/<type>``         anomalous images
        ``<dataroot>/<class_name>/ground_truth/<type>`` masks for ``<type>``

    Note: this definition shadows the ``get_files_masks`` imported from
    ``utils`` at the top of the notebook.

    Args:
        dataroot: Root directory of the dataset.
        class_name: Name of the class sub-directory.
        target_types: Anomaly type names handled as "target" anomalies.
        unknown_types: Anomaly type names handled as "unknown" anomalies.
        train_num: Number of image/mask pairs per anomaly type placed in the
            train split; the remaining pairs go to the test split.
        verbose: If true, print the length of every returned list.

    Returns:
        dict mapping ``train_normal_files``, ``train_target_files``,
        ``train_target_masks``, ``train_unknown_files``, ``train_unknown_masks``,
        ``test_normal_files``, ``test_target_files``, ``test_target_masks``,
        ``test_unknown_files`` and ``test_unknown_masks`` to lists of paths.

    Raises:
        OSError: Propagated from ``os.listdir`` when a directory is missing.
    """
    trainpath = os.path.join(dataroot,class_name,'train','good')
    testpath = os.path.join(dataroot, class_name,'test')
    maskpath = os.path.join(dataroot, class_name, "ground_truth")

    train_normal_files = [os.path.join(trainpath,x) for x in sorted(os.listdir(trainpath))]

    # test normal files
    test_good_path = os.path.join(testpath,'good')
    test_normal_files = [os.path.join(test_good_path,x) for x in sorted(os.listdir(test_good_path))]

    (train_target_files, train_target_masks,
     test_target_files, test_target_masks) = _split_anomaly_files(
        testpath, maskpath, target_types, train_num)

    # Unknown types go through the same pairing as target types. Previously
    # this loop unpacked each pair as (mask_file, file), which swapped the
    # image and mask names in the returned paths.
    (train_unknown_files, train_unknown_masks,
     test_unknown_files, test_unknown_masks) = _split_anomaly_files(
        testpath, maskpath, unknown_types, train_num)

    files_dict = {'train_normal_files': train_normal_files,
                  'train_target_files': train_target_files,
                  'train_target_masks': train_target_masks,
                  'train_unknown_files': train_unknown_files,
                  'train_unknown_masks': train_unknown_masks,
                  'test_normal_files': test_normal_files,
                  'test_target_files': test_target_files,
                  'test_target_masks': test_target_masks,
                  'test_unknown_files': test_unknown_files,
                  'test_unknown_masks': test_unknown_masks,}
    if verbose:
        for item, value in files_dict.items():
            print("{}: {}".format(item, len(value)))
    return files_dict


In [8]:
# Train one DQN agent per class and periodically checkpoint / evaluate it.
# Every entry of the data root that does not end in '.txt' is treated as a class name.
classes = [item for item in os.listdir(args.data_root) if not item.endswith('.txt')]
assert (len(classes) == 3), "expected 3 classes under {}, found {}".format(args.data_root, classes)

# torch.save and plt.savefig do not create missing parent directories, so make
# sure the output directory exists before the first evaluation step.
CHECKPOINT_DIR = 'checkpoints/BTAD_1_None'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

for class_name in classes:
    torch.cuda.empty_cache()
    args.class_name = class_name

    # Every folder under test/ except 'good' is used as a target anomaly type.
    test_dir = os.path.join(args.data_root, args.class_name, 'test')
    target_types = [item for item in os.listdir(test_dir) if item != 'good']
    unknown_types = []
    # train_num=1: one image/mask pair per anomaly type goes to the train split.
    files_dict = get_files_masks(args.data_root,args.class_name,target_types,unknown_types,1,verbose=True)

    agent = DQN(args,device=device)
    env = Environment(files_dict['train_normal_files'],
                    files_dict['train_target_files'],
                    files_dict['train_target_masks'],args,device)
    env.update_subsamples(agent.network)
    # The evaluation environment is built from the same training files, with eval=True.
    eval_env = Environment(files_dict['train_normal_files'],
                        files_dict['train_target_files'],
                        files_dict['train_target_masks'],args,device,eval=True)
    eval_env.update_subsamples(agent.network)

    checkpoint_path = os.path.join(CHECKPOINT_DIR,'dqn_'+class_name +'_1.pt')
    best_checkpoint_path = os.path.join(CHECKPOINT_DIR,'dqn_'+class_name +'_1_best.pt')
    plot_path = os.path.join(CHECKPOINT_DIR,'dqn_'+class_name +'_1.png')

    history = {'Step':[],'AvgReturn':[],'auroc':[],'i_auroc':[]}
    s = env.reset()
    best_result = 0
    while True:
        action = agent.act(s)
        next_state, reward, terminated, truncated, info = env.step(agent.network,action)
        # The value returned by process() was never used, so it is no longer bound.
        agent.process((s, action, reward, next_state, terminated))
        s = next_state
        if terminated or truncated:
            s = env.reset()
        if agent.total_steps % args.eval_interval == 0:
            ret = evaluate(eval_env,agent)
            checkpoint = {
                'model': agent.network.state_dict(),
                'args': args,
                'files_dict':files_dict
            }
            # get_scores() takes a checkpoint path, so save the current weights first.
            torch.save(checkpoint,checkpoint_path)
            scores_dict = get_scores(checkpoint_path,device)
            # Min-max normalise the scores before computing metrics.
            norm_scores = (scores_dict['total_scores']-scores_dict['total_scores'].min())/(scores_dict['total_scores'].max()-(scores_dict['total_scores'].min()))
            result = evaluate_metrics(scores_dict['total_labels'],norm_scores)
            # Model selection criterion: mean of the 'auroc' and 'i_auroc' metrics.
            cur_res = (result['auroc'] + result['i_auroc'])/2
            if cur_res >= best_result:
                best_result=cur_res
                torch.save(checkpoint,best_checkpoint_path)
            history['Step'].append(agent.total_steps)
            history['AvgReturn'].append(ret)
            history['auroc'].append(result['auroc'])
            history['i_auroc'].append(result['i_auroc'])
            clear_output()
            print(class_name)
            fig, axes = plt.subplots(1,3,figsize=(12,3))
            for ax, key, label in zip(axes,
                                      ('AvgReturn','auroc','i_auroc'),
                                      ('AvgReturn','AUROC','IAUROC')):
                ax.plot(history['Step'],history[key],'r-')
                ax.set_xlabel('Step',fontsize=16)
                ax.set_ylabel(label, fontsize=16)
                ax.grid(axis='y')
            plt.tight_layout()
            plt.savefig(plot_path)
            plt.show()
            # Release the figure so repeated evaluations do not accumulate open figures.
            plt.close(fig)
        if agent.total_steps > args.max_steps:
            break

train_normal_files: 400
train_target_files: 1
train_target_masks: 1
train_unknown_files: 0
train_unknown_masks: 0
test_normal_files: 21
test_target_files: 48
test_target_masks: 48
test_unknown_files: 0
test_unknown_masks: 0
torch.Size([1, 512, 32, 32])
torch.Size([1, 1024, 16, 16])


In [5]:
# Score the "best" checkpoint of every class and report the mean of the
# 'auroc' and 'i_auroc' metrics over the three classes.
# NOTE(review): this cell reads './checkpoints/BTAD_1' while the training cell
# writes to 'checkpoints/BTAD_1_None' -- confirm the intended directory.
classes = [entry for entry in os.listdir(args.data_root) if not entry.endswith('.txt')]
assert (len(classes) == 3)
IAUROC = 0
PAUROC = 0
for class_name in classes:
    checkpoint_path = f'./checkpoints/BTAD_1/dqn_{class_name}_1_best.pt'
    scores_dict = get_scores(checkpoint_path,device)
    # Min-max normalise the scores before computing metrics.
    raw_scores = scores_dict['total_scores']
    score_min = raw_scores.min()
    norm_scores = (raw_scores - score_min)/(raw_scores.max() - score_min)
    result = evaluate_metrics(scores_dict['total_labels'],norm_scores)
    # PAUROC / IAUROC are presumably pixel-level / image-level AUROC -- confirm in evaluate_metrics.
    PAUROC += result['auroc']
    IAUROC += result['i_auroc']
print(PAUROC/3)
print(IAUROC/3)

torch.Size([1, 512, 32, 32])
torch.Size([1, 1024, 16, 16])
auroc: 0.9847047742673309
auprc: 0.6728562875319548
aupro: 0.5738668740141769
thres: 0.9996626973152161
image_thres: 0.8055776357650757
i_auroc: 1.0
i_auprc: 1.0
f1: 1.0
accuracy: 1.0
torch.Size([1, 512, 32, 32])
torch.Size([1, 1024, 16, 16])
auroc: 0.973549302797679
auprc: 0.798400240385052
aupro: 0.10246835928752002
thres: 0.9998587965965271
image_thres: 0.0
i_auroc: 0.7914630172694689
i_auprc: 0.964090986608401
f1: 0.927400468384075
accuracy: 0.8646288209606987
torch.Size([1, 512, 32, 32])
torch.Size([1, 1024, 16, 16])
auroc: 0.975049182988505
auprc: 0.614723935957781
aupro: 0.39389380061893564
thres: 0.8374662399291992
image_thres: 0.9940745830535889
i_auroc: 0.9953252032520324
i_auprc: 0.935221666508814
f1: 0.9180327868852459
accuracy: 0.9886363636363636
0.9777677533511717
0.9289294068405004


In [None]:
# Score a single checkpoint, print its metrics and visualise the results.
# NOTE(review): the recorded output of this cell is a FileNotFoundError for
# this path -- point it at an existing checkpoint before re-running.
checkpoint_path = './checkpoints/final_1/dqn_zipper_10_best.pt'
scores_dict = get_scores(checkpoint_path,device)

# Min-max normalise the scores before computing metrics.
raw_scores = scores_dict['total_scores']
score_min = raw_scores.min()
norm_scores = (raw_scores - score_min)/(raw_scores.max() - score_min)

evaluate_metrics(scores_dict['total_labels'],norm_scores)
visualize_result(8,scores_dict['total_imgs'],scores_dict['total_labels'],norm_scores)

FileNotFoundError: [Errno 2] No such file or directory: './checkpoints/final_1/dqn_zipper_10_best.pt'