In [1]:
# Reload edited modules (e.g. the local `utils`) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# Auto-format code cells with black (extension presumably provided by the nb_black package).
%load_ext lab_black

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

from utils import (
    get_train_test_loaders,
    CustomVGG,
    train,
    evaluate,
    predict_localize,
    get_cv_train_test_loaders,
)

# Parameters

In [3]:
# Seed the RNGs first so weight initialisation (and any torch/numpy-based shuffling)
# repeats across Restart & Run All. GPU kernels may still be non-deterministic; see
# the PyTorch "Reproducibility" notes if bit-exact results are required.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)  # seeds the CPU RNG and all CUDA devices

# Paths and image geometry
data_folder = "data/"
input_size = (224, 224)
# Class index that receives the larger loss weight below; also passed to predict_localize.
neg_class = 0

# Optimisation settings
batch_size = 10
lr = 0.0001
epochs = 7
# Loss weight given to `neg_class`; the other class gets weight 1.
neg_class_weight = 5
class_weight = [neg_class_weight, 1] if neg_class == 0 else [1, neg_class_weight]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Passed to predict_localize as `thres` (its meaning is defined in utils).
localization_thres = 10
n_cv_folds = 5

In [3]:
import torch
from torch.utils.data import Dataset
import os

In [None]:
class MVTEC_AD_DATASET(Dataset):
    """Work-in-progress ``Dataset`` rooted at the directory ``root``.

    The original cell had an empty ``__init__`` body (a SyntaxError that broke
    Restart & Run All). Only the constructor is implemented; ``__len__`` and
    ``__getitem__`` raise ``NotImplementedError`` until the loading logic is
    written.

    NOTE(review): the name suggests the MVTec AD directory layout, but nothing
    here reads that layout yet — TODO. This class is not referenced by the cells
    shown here; the loaders come from ``utils.get_train_test_loaders``.

    Args:
        root: Directory the samples will be read from.
    """

    def __init__(self, root):
        super().__init__()
        self.root = root

    def __len__(self):
        raise NotImplementedError("MVTEC_AD_DATASET.__len__ is not implemented yet")

    def __getitem__(self, index):
        raise NotImplementedError(
            "MVTEC_AD_DATASET.__getitem__ is not implemented yet"
        )

# Data

In [4]:
# Single train/test split of the images under `data_folder`.
# test_size=0.2 is presumably the hold-out fraction and random_state=42 the split
# seed — semantics live in utils.get_train_test_loaders.
train_loader, test_loader = get_train_test_loaders(
    root=data_folder,
    img_size=input_size,
    batch_size=batch_size,
    random_state=42,
    test_size=0.2,
)

# Model Training

In [None]:
# Put the model on `device` before building the optimizer, as the torch.optim docs
# recommend; harmless if utils.train also moves it.
model = CustomVGG(input_size).to(device)

# Keep the config list `class_weight` untouched so re-running this cell is idempotent.
# The float32 tensor on `device` gets its own name; as_tensor also accepts a tensor
# input without torch.tensor()'s copy-construct warning.
class_weight_tensor = torch.as_tensor(class_weight, dtype=torch.float32, device=device)
criterion = nn.CrossEntropyLoss(weight=class_weight_tensor)
optimizer = optim.Adam(model.parameters(), lr=lr)

In [None]:
# Run the training loop from utils for `epochs` epochs on `device`.
# The return value replaces `model` from here on (presumably the trained model —
# TODO confirm in utils.train).
model = train(train_loader, model, optimizer, criterion, epochs, device)

In [None]:
# Persist the whole trained module object (pickle-based torch.save, not a state_dict).
# NOTE(review): the ".h5" suffix is misleading — the file is a torch pickle, not HDF5.
# The path is kept unchanged so anything already loading "model.h5" keeps working.
# To reload: torch.load(model_path, map_location=device); on torch>=2.6 a full pickled
# module additionally needs weights_only=False (only do this for trusted files).
model_path = "model.h5"
torch.save(obj=model, f=model_path)

# Evaluation

In [None]:
# Evaluate `model` on `test_loader` (the held-out split from the Data section when
# the notebook is run top to bottom). What is reported/returned is defined in utils.
evaluate(model, test_loader, device)

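# Cross Validation

Train a fresh model on each of the `n_cv_folds` folds and evaluate it on that fold's held-out split.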

In [None]:
# K-fold cross-validation: each fold trains a freshly initialised model.
# Every per-fold object uses a `fold_`/`cv_` name so this cell does NOT overwrite
# `model`, `train_loader`, `test_loader`, `optimizer`, `criterion` or `class_weight`
# from the main run. Previously the visualization cell silently used the last
# fold's model and loader.
cv_folds = get_cv_train_test_loaders(
    root=data_folder,
    batch_size=batch_size,
    img_size=input_size,
    n_folds=n_cv_folds,
)

# as_tensor accepts either the config list or an already-built tensor without the
# copy-construct warning torch.tensor() emits for tensor inputs.
cv_class_weight = torch.as_tensor(class_weight, dtype=torch.float32, device=device)
cv_criterion = nn.CrossEntropyLoss(weight=cv_class_weight)

for fold_idx, (fold_train_loader, fold_test_loader) in enumerate(cv_folds, start=1):
    print(f"Fold {fold_idx}/{n_cv_folds}")
    # Move to `device` before building the optimizer (torch.optim recommendation).
    fold_model = CustomVGG(input_size).to(device)
    fold_optimizer = optim.Adam(fold_model.parameters(), lr=lr)
    fold_model = train(
        fold_train_loader, fold_model, fold_optimizer, cv_criterion, epochs, device
    )
    evaluate(fold_model, fold_test_loader, device)

# Visualization

In [None]:
# Localization output for 6 samples drawn from `test_loader` using `model`.
# NOTE(review): `model` and `test_loader` are whatever the most recently executed
# cell bound to those names. Make sure the Cross Validation cell has not rebound
# them if you intend to visualize the main trained model.
predict_localize(
    model,
    test_loader,
    device,
    n_samples=6,
    neg_class=neg_class,
    thres=localization_thres,
)