In [1]:
from moviad.entrypoints.cfa import train_cfa
%load_ext autoreload
%autoreload 2

In [2]:
import os

import matplotlib.pyplot as plt
import torch
from torchvision.models import MobileNet_V2_Weights
from torchvision.transforms import transforms

from main_scripts.main_cfa import main_train_cfa, main_test_cfa
from moviad.datasets.mvtec.mvtec_dataset import MVTecDataset
from moviad.datasets.realiad.realiad_dataset import RealIadDataset
from moviad.datasets.realiad.realiad_dataset_configurations import RealIadClassEnum
from moviad.datasets.visa.visa_dataset import VisaDataset
from moviad.datasets.visa.visa_dataset_configurations import VisaDatasetCategory
from moviad.utilities.configurations import TaskType, Split
from tests.datasets.realiaddataset_tests import IMAGE_SIZE, REAL_IAD_DATASET_PATH, AUDIO_JACK_DATASET_JSON
from tests.datasets.visadataset_tests import VISA_DATASET_PATH, VISA_DATASET_CSV_PATH
from tests.main.common import get_training_args, MVTECH_DATASET_PATH, REALIAD_DATASET_PATH

In [3]:
# Dataset locations used by this notebook. These assignments deliberately
# shadow the same-named values imported from `tests.main.common`.
# Each path can be overridden without editing the notebook by setting an
# environment variable of the same name before starting the kernel; when the
# variable is absent the original machine-specific default is used.
MVTECH_DATASET_PATH = os.environ.get(
    "MVTECH_DATASET_PATH",
    'E:\\VisualAnomalyDetection\\datasets\\mvtec',
)
REALIAD_DATASET_PATH = os.environ.get(
    "REALIAD_DATASET_PATH",
    'E:\\VisualAnomalyDetection\\datasets\\Real-IAD\\realiad_256',
)

## CFA Method

### Train CFA on the MVTec dataset

In [4]:
from torchvision.transforms import InterpolationMode

# Seed values noted by the author for repeated runs: 3, 1024, 2048.
# The one applied in this cell is 3.
torch.manual_seed(3)

# Start from the project's default training arguments, then override the
# fields this experiment sets. The pairs are applied in the listed order.
args = get_training_args()
for field_name, field_value in (
    ("dataset_path", MVTECH_DATASET_PATH),
    ("category", 'pill'),
    ("backbone", 'wide_resnet50_2'),
    ("ad_layers", ["layer1", "layer2", "layer3"]),
    ("save_path", "./patch.pt"),
    ("device", torch.device('cuda' if torch.cuda.is_available() else 'cpu')),
    ("epochs", 30),
    # Same file as `save_path` above.
    ("model_checkpoint_path", "./patch.pt"),
    ("visual_test_path", "./visual_test"),
):
    setattr(args, field_name, field_value)

# Image preprocessing pipeline: resize, convert to a tensor, resize again with
# nearest-neighbour interpolation, then convert the dtype to float32.
# NOTE(review): `transform` is not passed to any dataset in the cells visible
# here -- confirm it is still needed.
# NOTE(review): torchvision documents `antialias` as only affecting
# bilinear/bicubic modes, so it presumably has no effect together with
# NEAREST -- confirm against the installed torchvision version.
preprocessing_steps = [
    transforms.Resize(IMAGE_SIZE),
    transforms.PILToTensor(),
    transforms.Resize(
        IMAGE_SIZE,
        interpolation=InterpolationMode.NEAREST,
        antialias=True,
    ),
    transforms.ConvertImageDtype(torch.float32),
]
transform = transforms.Compose(preprocessing_steps)

In [5]:
# Build the train and test views of the selected MVTec category with identical
# settings; only the split differs. Construction order is train, then test.
train_dataset, test_dataset = (
    MVTecDataset(
        TaskType.SEGMENTATION,
        args.dataset_path,
        args.category,
        split,
        img_size=(256, 256),
    )
    for split in (Split.TRAIN, Split.TEST)
)

# Constructing a dataset is followed by an explicit load call for each split.
train_dataset.load_dataset()
test_dataset.load_dataset()

# Disabled experiment, kept for reference:
# train_dataset.contaminate(test_dataset, 0.1)

In [6]:
# Re-seed immediately before training so this cell starts from the same
# torch RNG state every time it is run.
torch.manual_seed(3)

# Positional arguments for `train_cfa`, in the order the call passes them.
cfa_call_args = (
    train_dataset,
    test_dataset,
    32,  # presumably the batch size -- confirm against the train_cfa signature
    args.category,
    args.backbone,
    args.ad_layers,
    args.epochs,
    args.save_path,
    args.device,
)

# `train_cfa` returns a pair, unpacked here as `results` and `best_results`.
results, best_results = train_cfa(*cfa_call_args)


Training CFA for category: pill 

Length train dataset: 267
Length test dataset: 167


100%|██████████| 8/8 [00:01<00:00,  7.26it/s]


EPOCH: 0


100%|██████████| 8/8 [00:09<00:00,  1.20s/it]
Eval: 100%|██████████| 5/5 [00:02<00:00,  1.98it/s]


End training performances:

                img_roc: 0.8062222222222222 

                pxl_roc: 0.9602539765799462 

                f1_img: 0.9178082191780821 

                f1_pxl: 0.6092782512754819 

                img_pr: 0.9613449856074278 

                pxl_pr: 0.6134243473053876 

                pxl_pro: 0.927173732535669 

            
EPOCH: 1


100%|██████████| 8/8 [00:07<00:00,  1.13it/s]
Eval: 100%|██████████| 5/5 [00:02<00:00,  2.44it/s]


End training performances:

                img_roc: 0.8562962962962962 

                pxl_roc: 0.9737677073361425 

                f1_img: 0.9215017064846416 

                f1_pxl: 0.635610162228743 

                img_pr: 0.9691788367793465 

                pxl_pr: 0.6783865957688482 

                pxl_pro: 0.9501416732218404 

            
EPOCH: 2


100%|██████████| 8/8 [00:06<00:00,  1.17it/s]
Eval: 100%|██████████| 5/5 [00:02<00:00,  2.45it/s]


End training performances:

                img_roc: 0.8860505166475315 

                pxl_roc: 0.9809472137233818 

                f1_img: 0.924187725631769 

                f1_pxl: 0.7095257881154479 

                img_pr: 0.9762158926947507 

                pxl_pr: 0.7730713294630963 

                pxl_pro: 0.9599679939454481 

            
EPOCH: 3


100%|██████████| 8/8 [00:06<00:00,  1.18it/s]
Eval: 100%|██████████| 5/5 [00:02<00:00,  2.45it/s]


End training performances:

                img_roc: 0.9241481481481482 

                pxl_roc: 0.984603976176516 

                f1_img: 0.935361216730038 

                f1_pxl: 0.7538414416961523 

                img_pr: 0.9859866248128064 

                pxl_pr: 0.8202532917444509 

                pxl_pro: 0.968462926445346 

            
EPOCH: 4


100%|██████████| 8/8 [00:06<00:00,  1.17it/s]
Eval: 100%|██████████| 5/5 [00:02<00:00,  2.37it/s]


End training performances:

                img_roc: 0.9253333333333333 

                pxl_roc: 0.9865188266709733 

                f1_img: 0.9368029739776952 

                f1_pxl: 0.7557227398454691 

                img_pr: 0.9861042896176254 

                pxl_pr: 0.8273007575888915 

                pxl_pro: 0.9685549308609677 

            
EPOCH: 5


100%|██████████| 8/8 [00:07<00:00,  1.05it/s]
Eval: 100%|██████████| 5/5 [00:02<00:00,  2.31it/s]


End training performances:

                img_roc: 0.9365925925925925 

                pxl_roc: 0.9860383332511551 

                f1_img: 0.9368029739776952 

                f1_pxl: 0.763276518012883 

                img_pr: 0.9884864452521794 

                pxl_pr: 0.8323137548175957 

                pxl_pro: 0.9705934129466159 

            
EPOCH: 6


100%|██████████| 8/8 [00:06<00:00,  1.25it/s]
Eval: 100%|██████████| 5/5 [00:01<00:00,  2.57it/s]


End training performances:

                img_roc: 0.9270952927669347 

                pxl_roc: 0.9863140048369957 

                f1_img: 0.9402985074626865 

                f1_pxl: 0.7639703775207012 

                img_pr: 0.9853477066321792 

                pxl_pr: 0.8342957731914017 

                pxl_pro: 0.9710253605613676 

            
EPOCH: 7


100%|██████████| 8/8 [00:06<00:00,  1.27it/s]
Eval: 100%|██████████| 5/5 [00:01<00:00,  2.57it/s]


End training performances:

                img_roc: 0.9252450980392157 

                pxl_roc: 0.9875583331032042 

                f1_img: 0.9450549450549451 

                f1_pxl: 0.7683732567041307 

                img_pr: 0.9863483115980929 

                pxl_pr: 0.8394398413373247 

                pxl_pro: 0.9717512486707859 

            
EPOCH: 8


100%|██████████| 8/8 [00:06<00:00,  1.16it/s]
Eval: 100%|██████████| 5/5 [00:02<00:00,  2.39it/s]


End training performances:

                img_roc: 0.936280137772675 

                pxl_roc: 0.9872933875555688 

                f1_img: 0.9477611940298507 

                f1_pxl: 0.7667568666489072 

                img_pr: 0.9870619748394399 

                pxl_pr: 0.8378004297729252 

                pxl_pro: 0.9737220198986365 

            
EPOCH: 9


100%|██████████| 8/8 [00:07<00:00,  1.14it/s]
Eval: 100%|██████████| 5/5 [00:01<00:00,  2.59it/s]


End training performances:

                img_roc: 0.938962962962963 

                pxl_roc: 0.9859836504862707 

                f1_img: 0.952029520295203 

                f1_pxl: 0.7455732956914143 

                img_pr: 0.9882937047349885 

                pxl_pr: 0.8121590560386425 

                pxl_pro: 0.9727151210740097 

            
EPOCH: 10


100%|██████████| 8/8 [00:07<00:00,  1.05it/s]
Eval: 100%|██████████| 5/5 [00:02<00:00,  2.50it/s]


End training performances:

                img_roc: 0.9374814814814815 

                pxl_roc: 0.9875332951783171 

                f1_img: 0.9481481481481482 

                f1_pxl: 0.7648334786399302 

                img_pr: 0.9880484733942574 

                pxl_pr: 0.8381884941309633 

                pxl_pro: 0.9727164234748599 

            
EPOCH: 11


100%|██████████| 8/8 [00:09<00:00,  1.16s/it]
Eval: 100%|██████████| 5/5 [00:02<00:00,  2.30it/s]


End training performances:

                img_roc: 0.933037037037037 

                pxl_roc: 0.9875168693855781 

                f1_img: 0.9416058394160584 

                f1_pxl: 0.7646713615023474 

                img_pr: 0.9872067747807534 

                pxl_pr: 0.8382745157194178 

                pxl_pro: 0.9724452468873045 

            
EPOCH: 12


100%|██████████| 8/8 [00:06<00:00,  1.19it/s]
Eval: 100%|██████████| 5/5 [00:02<00:00,  2.48it/s]


End training performances:

                img_roc: 0.9359931113662457 

                pxl_roc: 0.9878184807040895 

                f1_img: 0.9425287356321839 

                f1_pxl: 0.7654189814528828 

                img_pr: 0.9873265082412794 

                pxl_pr: 0.8394625664970801 

                pxl_pro: 0.9721410029596577 

            
EPOCH: 13


100%|██████████| 8/8 [00:06<00:00,  1.14it/s]
Eval: 100%|██████████| 5/5 [00:02<00:00,  2.33it/s]


End training performances:

                img_roc: 0.9339259259259259 

                pxl_roc: 0.9886908635060484 

                f1_img: 0.9433962264150944 

                f1_pxl: 0.7609533777607017 

                img_pr: 0.9875102024650091 

                pxl_pr: 0.8394357312369168 

                pxl_pro: 0.9737400155000389 

            
EPOCH: 14


100%|██████████| 8/8 [00:07<00:00,  1.13it/s]
Eval: 100%|██████████| 5/5 [00:02<00:00,  2.41it/s]


End training performances:

                img_roc: 0.9374282433983926 

                pxl_roc: 0.9880813547635118 

                f1_img: 0.9420849420849421 

                f1_pxl: 0.7646611636045978 

                img_pr: 0.9877507262273135 

                pxl_pr: 0.8392509426429695 

                pxl_pro: 0.9725938352530662 

            
EPOCH: 15


100%|██████████| 8/8 [00:07<00:00,  1.08it/s]
Eval: 100%|██████████| 5/5 [00:02<00:00,  2.07it/s]


End training performances:

                img_roc: 0.9381127450980392 

                pxl_roc: 0.9878396229311885 

                f1_img: 0.9473684210526316 

                f1_pxl: 0.7619792236194642 

                img_pr: 0.988835367048107 

                pxl_pr: 0.8370194432407598 

                pxl_pro: 0.972630890067927 

            
EPOCH: 16


100%|██████████| 8/8 [00:06<00:00,  1.21it/s]
Eval: 100%|██████████| 5/5 [00:01<00:00,  2.53it/s]


End training performances:

                img_roc: 0.939950980392157 

                pxl_roc: 0.987998910330564 

                f1_img: 0.9477611940298508 

                f1_pxl: 0.7616526912446714 

                img_pr: 0.9891627555480653 

                pxl_pr: 0.8371349971551888 

                pxl_pro: 0.9722064282322649 

            
EPOCH: 17


100%|██████████| 8/8 [00:06<00:00,  1.20it/s]
Eval: 100%|██████████| 5/5 [00:02<00:00,  2.47it/s]


End training performances:

                img_roc: 0.9451593137254902 

                pxl_roc: 0.9886017660361474 

                f1_img: 0.9473684210526316 

                f1_pxl: 0.7673066721780166 

                img_pr: 0.9903223338692928 

                pxl_pr: 0.8423655588627437 

                pxl_pro: 0.9734524272385271 

            
EPOCH: 18


100%|██████████| 8/8 [00:06<00:00,  1.20it/s]
Eval: 100%|██████████| 5/5 [00:02<00:00,  2.47it/s]


End training performances:

                img_roc: 0.9357037037037036 

                pxl_roc: 0.988715004951743 

                f1_img: 0.9433962264150944 

                f1_pxl: 0.7671679671975683 

                img_pr: 0.9879431650827156 

                pxl_pr: 0.8428377125438736 

                pxl_pro: 0.9726327936951854 

            
EPOCH: 19


100%|██████████| 8/8 [00:08<00:00,  1.09s/it]
Eval: 100%|██████████| 5/5 [00:02<00:00,  2.13it/s]


End training performances:

                img_roc: 0.9377777777777778 

                pxl_roc: 0.9873347914647002 

                f1_img: 0.9469696969696969 

                f1_pxl: 0.7348308607604592 

                img_pr: 0.988336771159319 

                pxl_pr: 0.8063909148946173 

                pxl_pro: 0.9728049457353258 

            
EPOCH: 20


100%|██████████| 8/8 [00:06<00:00,  1.18it/s]
Eval: 100%|██████████| 5/5 [00:02<00:00,  2.44it/s]


End training performances:

                img_roc: 0.9371412169919634 

                pxl_roc: 0.9887522846683245 

                f1_img: 0.946969696969697 

                f1_pxl: 0.7664564386359559 

                img_pr: 0.9877084534296561 

                pxl_pr: 0.8443166708196819 

                pxl_pro: 0.9718264666106617 

            
EPOCH: 21


100%|██████████| 8/8 [00:06<00:00,  1.17it/s]
Eval: 100%|██████████| 5/5 [00:02<00:00,  2.43it/s]


End training performances:

                img_roc: 0.9371412169919633 

                pxl_roc: 0.9885152624236527 

                f1_img: 0.946969696969697 

                f1_pxl: 0.7562617629999645 

                img_pr: 0.9877513440798483 

                pxl_pr: 0.8328642658633651 

                pxl_pro: 0.9728124627648296 

            
EPOCH: 22


100%|██████████| 8/8 [00:07<00:00,  1.10it/s]
Eval: 100%|██████████| 5/5 [00:02<00:00,  2.35it/s]


End training performances:

                img_roc: 0.9377777777777778 

                pxl_roc: 0.9885017592341947 

                f1_img: 0.9473684210526315 

                f1_pxl: 0.7598220202593962 

                img_pr: 0.9883018037840812 

                pxl_pr: 0.8368847716273713 

                pxl_pro: 0.9735412330715645 

            
EPOCH: 23


100%|██████████| 8/8 [00:06<00:00,  1.19it/s]
Eval: 100%|██████████| 5/5 [00:02<00:00,  2.48it/s]


End training performances:

                img_roc: 0.949037037037037 

                pxl_roc: 0.9886557882703237 

                f1_img: 0.9509433962264151 

                f1_pxl: 0.7613605509682879 

                img_pr: 0.9908770500803922 

                pxl_pr: 0.8377919721197656 

                pxl_pro: 0.971964883934373 

            
EPOCH: 24


100%|██████████| 8/8 [00:08<00:00,  1.09s/it]
Eval: 100%|██████████| 5/5 [00:02<00:00,  2.32it/s]


End training performances:

                img_roc: 0.94345579793341 

                pxl_roc: 0.9889785981069538 

                f1_img: 0.9509433962264152 

                f1_pxl: 0.7560422528389452 

                img_pr: 0.9889152137274564 

                pxl_pr: 0.8342547945005327 

                pxl_pro: 0.9746829023950239 

            
EPOCH: 25


100%|██████████| 8/8 [00:06<00:00,  1.16it/s]
Eval: 100%|██████████| 5/5 [00:02<00:00,  2.45it/s]


End training performances:

                img_roc: 0.9410370370370371 

                pxl_roc: 0.9871583357291689 

                f1_img: 0.9509433962264151 

                f1_pxl: 0.7254488839648563 

                img_pr: 0.9889666452204393 

                pxl_pr: 0.7867217568779036 

                pxl_pro: 0.9738488753032778 

            
EPOCH: 26


100%|██████████| 8/8 [00:06<00:00,  1.25it/s]
Eval: 100%|██████████| 5/5 [00:01<00:00,  2.58it/s]


End training performances:

                img_roc: 0.9392592592592592 

                pxl_roc: 0.9887346651673262 

                f1_img: 0.9473684210526315 

                f1_pxl: 0.7592868424531245 

                img_pr: 0.9888061311987377 

                pxl_pr: 0.8358749608494362 

                pxl_pro: 0.9718796063842916 

            
EPOCH: 27


100%|██████████| 8/8 [00:07<00:00,  1.06it/s]
Eval: 100%|██████████| 5/5 [00:02<00:00,  2.47it/s]


End training performances:

                img_roc: 0.938074074074074 

                pxl_roc: 0.9891376184071101 

                f1_img: 0.9473684210526315 

                f1_pxl: 0.7642789443754081 

                img_pr: 0.9883670282429518 

                pxl_pr: 0.8418363153822283 

                pxl_pro: 0.9722716942308036 

            
EPOCH: 28


100%|██████████| 8/8 [00:06<00:00,  1.20it/s]
Eval: 100%|██████████| 5/5 [00:02<00:00,  2.48it/s]


End training performances:

                img_roc: 0.936 

                pxl_roc: 0.9887841357474054 

                f1_img: 0.9473684210526315 

                f1_pxl: 0.7587311427980988 

                img_pr: 0.9880547516059036 

                pxl_pr: 0.8356003026561671 

                pxl_pro: 0.9732505145391501 

            
EPOCH: 29


100%|██████████| 8/8 [00:07<00:00,  1.05it/s]
Eval: 100%|██████████| 5/5 [00:02<00:00,  2.48it/s]


End training performances:

                img_roc: 0.9377777777777778 

                pxl_roc: 0.9885525282931087 

                f1_img: 0.9473684210526315 

                f1_pxl: 0.7565508045504936 

                img_pr: 0.9882392494352358 

                pxl_pr: 0.834102695745027 

                pxl_pro: 0.9729937768266892 

            


In [None]:
import wandb

# Weights & Biases logging example.
# The hyperparameters are passed to `wandb.init` through `config=` instead of
# being assigned onto the run object afterwards, and the run is used as a
# context manager so it is finished even if a logging call raises.
# NOTE(review): the project name, config values and metric are placeholders,
# and "./my_model.pt" is not written by any cell visible in this notebook
# (the training cells use `args.save_path`) -- TODO replace with real values.
with wandb.init(
    project="my-model-training-project",
    config={"epochs": 1337, "learning_rate": 3e-4},
) as run:
    run.log({"metric": 42})
    my_model_artifact = run.log_artifact("./my_model.pt", type="model")

In [12]:
# Plot every tracked metric against the epoch number.
# Assumes each `results.<metric>` is a sequence with one value per evaluated
# epoch -- TODO confirm against what `train_cfa` returns.
metric_names = (
    "img_roc",
    "pxl_roc",
    "f1_img",
    "f1_pxl",
    "img_pr",
    "pxl_pr",
    "pxl_pro",
)

# Take the x axis from the data itself. The previous `range(1, args.epochs)`
# yields `args.epochs - 1` points, one fewer than a run that evaluates once
# per epoch, which makes x and y lengths disagree.
epochs = range(1, len(results.img_roc) + 1)

fig, ax = plt.subplots()
for metric_name in metric_names:
    ax.plot(epochs, getattr(results, metric_name), label=metric_name)
ax.set_xlabel("epoch")
ax.set_ylabel("score")
ax.set_title("CFA metrics per epoch")
ax.legend()
plt.show()


KeyboardInterrupt: 

In [13]:
# Rebuild fresh train/test dataset objects for the next training run, using
# the same category and image size as before. Train is constructed first.
train_dataset, test_dataset = [
    MVTecDataset(
        TaskType.SEGMENTATION,
        args.dataset_path,
        args.category,
        split,
        img_size=(256, 256),
    )
    for split in (Split.TRAIN, Split.TEST)
]

# Each split is loaded explicitly after construction.
train_dataset.load_dataset()
test_dataset.load_dataset()

In [23]:
# Second training run, written to its own checkpoint file so it does not
# overwrite "./patch.pt" from the first run.
args.save_path = "./patch_unsupervised.pt"
args.model_checkpoint_path = "./patch_unsupervised.pt"

# Seed here as the first training cell does, so the result of this cell does
# not depend on how much of the RNG stream earlier cells consumed.
torch.manual_seed(3)

# NOTE(review): the saved output of this cell ends in
#   RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
# during epoch 0. The cause is not visible in this notebook. One candidate is
# gradient tracking having been switched off globally by a cell that was run
# and later removed, so the call is scoped under an explicit grad-enabled
# context. If the error persists, investigate `train_cfa` itself.
with torch.set_grad_enabled(True):
    results_unsupervised, best_results_unsupervised = train_cfa(
        train_dataset,
        test_dataset,
        8,  # presumably the batch size -- confirm against the train_cfa signature
        args.category,
        args.backbone,
        args.ad_layers,
        args.epochs,
        args.save_path,
        args.device,
    )

Training CFA for category: pill 

Length train dataset: 267
Length test dataset: 167


100%|██████████| 33/33 [00:00<00:00, 93.16it/s] 


EPOCH: 0


  0%|          | 0/33 [00:00<?, ?it/s]


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [None]:
# Fetch the training sample at index 2 for inspection, using subscript
# syntax rather than calling the `__getitem__` dunder directly.
image_entry = train_dataset[2]


