In [None]:
import os
import sys
sys.path.append("../.")
sys.path.append("../models")

import matplotlib.pyplot as plt
import numpy as np
import torch
from anomalib.data import MVTec
from anomalib.engine import Engine

import models
from evaluate import load_model
from metrics import F1Max

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Experiment configuration.
# The weights location was a machine-specific absolute path; it can now be
# overridden with the VAND_WEIGHTS_PATH environment variable. The original path
# is kept as the fallback so existing runs behave exactly as before.
weights_path = os.environ.get(
    "VAND_WEIGHTS_PATH",
    "/home/lcondados/workspace/competition-INTEL_VAND2/anomaly_detection-MVTEC/experiments/weights-Patchcore-mobilenet",
)
# MVTec category to evaluate (passed to both load_model and the MVTec datamodule).
category = "capsule"
# Root of the MVTec dataset, relative to this notebook unless VAND_DATASET_PATH is set.
dataset_path = os.environ.get("VAND_DATASET_PATH", "../datasets/MVTec")

In [None]:
# Which model implementation to instantiate: dotted module path + class name,
# resolved by the project's load_model helper together with the saved weights.
module_path, class_name = "models.simple_patchcore", "SimplePatchcore"

model = load_model(module_path, class_name, weights_path, category)
model = model.to(device)

In [None]:
from torchvision.transforms import v2 as T

In [None]:
def show_image_gt_pred(image, pred_mask, gt_mask, pred_cls, gt_cls):
    """Plot an image next to its ground-truth and predicted masks.

    Draws three side-by-side panels (image, ground truth, prediction) on the
    current matplotlib figure, hides the axes, and calls ``plt.show()``.

    Args:
        image: picture for the first panel (anything ``plt.imshow`` accepts).
        pred_mask: predicted mask for the third panel.
        gt_mask: ground-truth mask for the second panel.
        pred_cls: value formatted into the title of the prediction panel.
        gt_cls: value formatted into the title of the ground-truth panel.
    """
    panels = (
        ("Image", image),
        (gt_cls, gt_mask),
        (pred_cls, pred_mask),
    )
    for position, (title, picture) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title("{}".format(title))
        plt.imshow(picture)
        plt.axis("off")

    plt.show()

In [None]:
# data["mask"].squeeze().cpu().numpy().min(), data["mask"].squeeze().cpu().numpy().max()

In [None]:
# Display names for the binary label, indexed by the integer label value.
cls_name = ["Normal", "Anomaly"]

# Size the datamodule resizes every image to.
IMAGE_SIZE = (256, 256)

# No extra test-time transform. Evaluation must be deterministic, so random
# augmentations (flips, rotations, blur, photometric jitter) do not belong here.
eval_transform = None

# eval_batch_size=1: the evaluation loop squeezes per-sample tensors and relies
# on there being exactly one image per batch.
datamodule = MVTec(
    root=dataset_path,
    category=category,
    eval_batch_size=1,
    eval_transform=eval_transform,
    image_size=IMAGE_SIZE,
)
datamodule.setup()

# Image-level and pixel-level F1Max, accumulated over the whole test set.
image_metric = F1Max()
pixel_metric = F1Max()

# Set to True to print every per-image anomaly score (debug aid).
VERBOSE = False

# Pure inference: put the model in eval mode and skip autograd bookkeeping.
# NOTE(review): load_model may already call .eval(); this is a no-op in that case.
model.eval()

with torch.no_grad():
    for data in datamodule.test_dataloader():
        output = model(data["image"].to(device))

        pred_score = output["pred_score"].cpu()
        image_metric.update(pred_score, data["label"])

        # squeeze() drops singleton dims; relies on eval_batch_size=1.
        pixel_metric.update(
            output["anomaly_map"].squeeze().cpu(),
            data["mask"].squeeze().cpu(),
        )

        if VERBOSE:
            print(pred_score)

# Aggregate over all test samples.
image_score = image_metric.compute()
pixel_score = pixel_metric.compute()

print("Image F1Max = {:.2f}".format(image_score.item()))
print("Pixel F1Max = {:.2f}".format(pixel_score.item()))
