# Load Model


In [None]:
import torch

# Location of the trained checkpoint inspected in this cell.
ckpt_path = '/content/ckpt.pth'
# ckpt_path = '/content/results/MVTecAD_Results/simplenet_mvtec/run/models/0/mvtec_floods/ckpt.pth'

# map_location="cpu" remaps the saved tensors onto the CPU, so the file can be read without a GPU.
# NOTE(review): torch.load relies on pickle — only open checkpoints from a trusted source.
state_dict = torch.load(ckpt_path, map_location="cpu")

print("Keys in checkpoint:", state_dict.keys())

# Input feature width handed to both modules built below.
in_planes = 1536
#  def __init__(self, in_planes, n_layers=1, hidden=None):
discriminator = Discriminator(in_planes=in_planes, n_layers=2, hidden=1024)
# def __init__(self, in_planes, out_planes=None, n_layers=1, layer_type=0):
# NOTE(review): `projection` is constructed here, but no checkpoint weights are
# loaded into it anywhere in this cell — confirm whether that is intended.
projection = Projection(in_planes=in_planes, out_planes=1, n_layers=2)
# Restore the discriminator from the checkpoint's "discriminator" entry when present.
if "discriminator" in state_dict:
    try:
        discriminator_weights = state_dict["discriminator"]
        discriminator.load_state_dict(discriminator_weights)
        print("Discriminator loaded successfully.")
    except Exception as e:
        # A failed load is only reported; execution continues with the
        # weights the discriminator was constructed with.
        print(f"Error loading discriminator weights: {e}")
else:
    print("Discriminator weights not found in the checkpoint.")

# Switch the discriminator to evaluation mode for inference.
discriminator.eval()

# print('state_dict[pre_projection]', state_dict["pre_projection"])

# t-SNE

In [None]:
import torch
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np
# Root folder of the dataset on the mounted Drive.
datapath = '/content/drive/MyDrive/mvtec_anomaly_detection'
# Test split of the "floods" class; `resize` / `imagesize` are forwarded to MVTecDataset as-is.
test_dataset = MVTecDataset(
    source=datapath,
    classname="floods",
    resize=329,
    imagesize=288,
    split=DatasetSplit.TEST,
)
# batch_size=1 and shuffle=False: one image per batch, in dataset order.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained ResNet-50 fetched through torch hub, used as the feature backbone.
backbone = torch.hub.load('pytorch/vision', 'resnet50', pretrained=True)
layers_to_extract_from = ['layer4']
# input_shape matches imagesize=288 used for the dataset above.
input_shape = (3, 288, 288)
pretrain_embed_dimension = 1536
target_embed_dimension = 1536
patchsize = 3
patchstride = 1

# Build the SimpleNet wrapper around the backbone with the settings above.
simplenet = SimpleNet(device)
simplenet.load(
    backbone=backbone,
    layers_to_extract_from=layers_to_extract_from,
    device=device,
    input_shape=input_shape,
    pretrain_embed_dimension=pretrain_embed_dimension,
    target_embed_dimension=target_embed_dimension,
    patchsize=patchsize,
    patchstride=patchstride
)
simplenet.backbone.eval()

# Accumulators filled by the loop below: embedding arrays and "is_anomaly" labels.
embeddings = []
labels = []

# Gradients are not needed for feature extraction.
with torch.no_grad():
    for batch in test_loader:
        images = batch["image"].to(device)

        batch_embeddings = simplenet.embed(images)

        # Keep only element 0 of the result.
        # NOTE(review): presumably `embed` returns a tuple whose first item is
        # the feature tensor — TODO confirm against SimpleNet.embed.
        batch_embeddings = batch_embeddings[0]

        print(f"Batch embeddings shape: {batch_embeddings.shape}")

        batch_embeddings = batch_embeddings.cpu().numpy()
        embeddings.append(batch_embeddings)

        # Only used for the shape printout; the labels are re-read just below.
        labels_batch = batch["is_anomaly"].numpy()
        print(f"Labels batch shape: {labels_batch.shape}")

        labels.extend(batch["is_anomaly"].cpu().numpy())

# Stack the per-batch arrays along axis 0.
# NOTE(review): `labels` holds one entry per image, while `embeddings` gets
# `batch_embeddings.shape[0]` rows per batch. If `embed` yields more than one
# row per image (e.g. one per patch), the label-derived indices used for the
# scatter plot below will not line up with the embedding rows — compare the
# two printed sizes.
embeddings = np.concatenate(embeddings, axis=0)
labels = np.array(labels)

print(f"Embeddings shape: {embeddings.shape}")
print(f"Labels length: {len(labels)}")

# Project to 2-D; random_state is fixed so repeated runs use the same seed.
tsne = TSNE(n_components=2, random_state=42)
reduced_embeddings = tsne.fit_transform(embeddings)

plt.figure(figsize=(10, 8))

# One scatter series per distinct label value; truthy labels are named "Anomalous".
for label in np.unique(labels):
    indices = np.where(labels == label)[0]

    plt.scatter(
        reduced_embeddings[indices, 0],
        reduced_embeddings[indices, 1],
        label=f"{'Anomalous' if label else 'Normal'}",
        alpha=0.6,
    )

plt.legend()
plt.title("t-SNE of Feature Embeddings")
plt.xlabel("t-SNE Dimension 1")
plt.ylabel("t-SNE Dimension 2")
plt.grid(True)
plt.show()


# Anomaly Maps Generation

In [None]:
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import numpy as np

# Rebuild the test loader (one image per batch, dataset order). Relies on
# `test_dataset`, `device`, `backbone` and the embedding settings defined in
# the t-SNE cell.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

# Fresh SimpleNet instance configured exactly like the one in the t-SNE cell.
# NOTE(review): nothing in this cell passes the checkpoint read in the
# "Load Model" cell to this instance — confirm whether `simplenet.load`
# restores trained weights on its own.
simplenet = SimpleNet(device)
simplenet.load(
    backbone=backbone,
    layers_to_extract_from=layers_to_extract_from,
    device=device,
    input_shape=input_shape,
    pretrain_embed_dimension=pretrain_embed_dimension,
    target_embed_dimension=target_embed_dimension,
    patchsize=patchsize,
    patchstride=patchstride,
)
simplenet.backbone.eval()

def normalize_mask(mask):
    """Rescale ``mask`` linearly so that its values span roughly [0, 1].

    The smallest value maps to 0. A small epsilon is added to the value
    range so the division stays defined when every entry of ``mask`` is
    identical (the result is then all zeros).
    """
    lowest = np.min(mask)
    value_range = np.max(mask) - lowest
    return (mask - lowest) / (value_range + 1e-8)

def generate_anomaly_maps(simplenet, dataloader):
    """Run ``simplenet._predict`` over every batch of ``dataloader``.

    Each batch is read as a mapping with ``"image"``, ``"image_path"`` and
    ``"is_anomaly"`` entries plus an optional ``"mask"`` entry. Images are
    moved to the module-level ``device`` before prediction, and the backbone
    is put in eval mode first.

    Returns:
        A 5-tuple ``(scores, masks, img_paths, labels_gt, masks_gt)`` of
        lists accumulated over all batches. ``masks_gt`` only receives
        entries from batches whose ``"mask"`` value is not ``None``.
    """
    pred_scores, pred_masks = [], []
    paths, gt_labels, gt_masks = [], [], []

    simplenet.backbone.eval()

    progress = tqdm(dataloader, desc="Generating Anomaly Maps", leave=False)
    with progress as batches:
        for batch in batches:
            batch_images = batch["image"].to(device)
            paths.extend(batch["image_path"])
            gt_labels.extend(batch["is_anomaly"].cpu().numpy().tolist())

            # Ground-truth masks are optional in a batch.
            gt_mask = batch.get("mask")
            if gt_mask is not None:
                gt_masks.extend(gt_mask.cpu().numpy().tolist())

            # The third value returned by _predict is not needed here.
            batch_scores, batch_masks, _ = simplenet._predict(batch_images)
            pred_scores.extend(batch_scores)
            pred_masks.extend(batch_masks)

    return pred_scores, pred_masks, paths, gt_labels, gt_masks

# Run the model over the whole test loader and keep the per-image outputs as
# notebook globals for the visualization and Dice cells below.
scores, masks, img_paths, labels_gt, masks_gt = generate_anomaly_maps(simplenet, test_loader)

import matplotlib.pyplot as plt
import numpy as np

def plot_binary_masks(masks):
    """
    Show each mask as a black/white image thresholded at 0.5.

    Values below 0.5 are drawn black and values of 0.5 or more are drawn
    white. Every mask gets its own 5x5 figure titled "Binary Mask <k>",
    with k counting from 1.

    Args:
    - masks (list of numpy arrays): Normalized masks to display.
    """
    for number, mask in enumerate(masks, start=1):
        # 1.0 where the mask reaches the threshold, 0.0 elsewhere.
        thresholded = np.where(mask >= 0.5, 1.0, 0.0)

        plt.figure(figsize=(5, 5))
        plt.title(f"Binary Mask {number}")
        # The gray colormap renders 0.0 as black and 1.0 as white.
        plt.imshow(thresholded, cmap='gray')
        plt.axis('off')
        plt.show()


# def print_normalized_masks(masks, num_to_print=5):
#     """
#     Prints the normalized values of the first few anomaly masks.

#     Args:
#     - masks (list): List of masks (numpy arrays).
#     - num_to_print (int): Number of masks to print.
#     """
#     print("Normalized Anomaly Masks (Numerical Values):")
#     for i, mask in enumerate(masks[:num_to_print]):
#         # normalized_mask = normalize_mask(mask)  # Normalize the mask
#         print(f"\nMask {i + 1}:")
#         print(normalized_mask)


# pred_binary_masks2 = []
# for i, (image_path, mask) in enumerate(zip(img_paths[:5], masks[:5])):
#     normalized_mask = normalize_mask(mask)
#     plot_binary_masks([normalized_mask])
#     # Threshold the anomaly map for binary visualization
#     binary_mask_pred = np.where(normalized_mask > 0.5, 0.0, 1.0)
#     pred_binary_masks2.append(binary_mask_pred)

#     plt.figure(figsize=(15, 5))

#     plt.subplot(1, 3, 1)
#     plt.title("Original Image")
#     plt.imshow(plt.imread(image_path))
#     plt.axis("off")

#     plt.subplot(1, 3, 2)
#     plt.title("Normalized Anomaly Map")
#     plt.imshow(normalized_mask, cmap="hot")
#     plt.axis("off")

#     plt.subplot(1, 3, 3)
#     plt.title("Overlayed Anomaly Map")
#     plt.imshow(plt.imread(image_path), alpha=0.6)
#     plt.imshow(normalized_mask, cmap="jet", alpha=0.4)
#     plt.axis("off")

#     plt.show()


# Number of leading test images to visualize.
num_images_to_plot = 5

# Binary prediction masks collected here are reused by the Dice cell further down.
pred_binary_masks2 = []
for i, (image_path, mask) in enumerate(zip(img_paths[:num_images_to_plot], masks[:num_images_to_plot])):
    normalized_mask = normalize_mask(mask)
    plot_binary_masks([normalized_mask])
    # NOTE(review): this mapping is inverted relative to plot_binary_masks —
    # values > 0.5 become 0.0 and everything else 1.0, whereas the plot above
    # shows values >= 0.5 as white (1.0). Confirm the inversion is intentional.
    binary_mask_pred = np.where(normalized_mask > 0.5, 0.0, 1.0)
    pred_binary_masks2.append(binary_mask_pred)

    # Three panels side by side: input image, anomaly heat map, overlay.
    plt.figure(figsize=(15, 5))

    plt.subplot(1, 3, 1)
    plt.title("Original Image")
    plt.imshow(plt.imread(image_path))
    plt.axis("off")

    plt.subplot(1, 3, 2)
    plt.title("Normalized Anomaly Map")
    plt.imshow(normalized_mask, cmap="hot")
    plt.axis("off")

    plt.subplot(1, 3, 3)
    plt.title("Overlayed Anomaly Map")
    # The image is drawn first, then the semi-transparent map on top of it.
    plt.imshow(plt.imread(image_path), alpha=0.6)
    plt.imshow(normalized_mask, cmap="jet", alpha=0.4)
    plt.axis("off")

    plt.show()

# Dice Coefficient
## Load Ground Truth Files

In [None]:
from PIL import Image
import numpy as np
import os
import matplotlib.pyplot as plt

def load_ground_truth_images(image_paths):
    """Load ground-truth images as grayscale NumPy arrays.

    Args:
        image_paths: Iterable of file paths to read.

    Returns:
        list of np.ndarray: One array per path that exists on disk, in
        input order, each converted to PIL mode "L" (single-channel
        grayscale). Paths that do not exist are reported with a printed
        warning and skipped, so the result can be shorter than
        ``image_paths`` and its positions then no longer match the input.
    """
    numpy_arrays = []
    for path in image_paths:
        if not os.path.exists(path):
            print(f"Warning: File not found - {path}")
            continue

        # The context manager closes the underlying file as soon as the
        # pixel data has been copied, instead of leaving it to the garbage
        # collector.
        with Image.open(path) as img:
            numpy_arrays.append(np.array(img.convert("L")))

    return numpy_arrays

def plot_images(image_arrays, titles=None):
    """
    Draw the given images side by side in a single row.

    Each image is rendered with the gray colormap and its axes hidden. When
    ``titles`` is empty or omitted, panels are captioned "Image <k>" with k
    counting from 1.

    Args:
    - image_arrays (list of np.array): Images to display.
    - titles (list of str, optional): Caption for each image, by position.
    """
    count = len(image_arrays)
    plt.figure(figsize=(15, 5))

    for position, pixels in enumerate(image_arrays, start=1):
        plt.subplot(1, count, position)
        plt.imshow(pixels, cmap='gray')
        if titles:
            caption = titles[position - 1]
        else:
            caption = f"Image {position}"
        plt.title(caption)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Folder holding the ground-truth masks for the "floods" class.
path = "/content/drive/MyDrive/mvtec_anomaly_detection/floods/ground_truth/floods/"
# Hand-picked ground-truth files to load.
image_names = [
    "EMSR260_02VIADANA.png",
    "EMSR271_02FARKADONA.png",
    "EMSR324_04LESPIGNAN.png",
    "EMSR333_02PORTOPALO.png"
]
image_paths = [os.path.join(path, img_name) for img_name in image_names]

# Missing files are skipped by the loader, so `ground_truth_arrays` can be
# shorter than `image_names`; the titles passed below are matched by position.
ground_truth_arrays = load_ground_truth_images(image_paths)
plot_images(ground_truth_arrays, titles=image_names)


In [None]:
import cv2

import cv2

def dice_coefficient(mask1, mask2):
    """Compute the Dice overlap between two masks.

    Both inputs are binarized first: any value greater than 0 counts as
    foreground. When the shapes differ, ``mask1`` is resized to the shape
    of ``mask2`` with nearest-neighbour interpolation.

    Returns:
        ``2 * |A ∩ B| / (|A| + |B| + 1e-8)``. The epsilon keeps the division
        defined, so two masks without any foreground score 0.
    """
    first = (mask1 > 0).astype(np.float32)
    second = (mask2 > 0).astype(np.float32)

    if first.shape != second.shape:
        # cv2.resize takes the target size as (width, height).
        target_height, target_width = second.shape[0], second.shape[1]
        first = cv2.resize(first, (target_width, target_height), interpolation=cv2.INTER_NEAREST)

    overlap = np.sum(first * second)
    foreground_total = np.sum(first) + np.sum(second)
    return (2.0 * overlap) / (foreground_total + 1e-8)

# Print one Dice score per zipped triple of masks.
# NOTE(review): the score compares an entry of `pred_binary_masks` with an
# entry of `pred_binary_masks2` (bound to the name `binary_mask_gt`); the
# ground-truth array `arr` is unpacked but never used. `pred_binary_masks` is
# also not defined in the visible part of this notebook. Confirm whether the
# intended comparison is prediction vs. `ground_truth_arrays`.
for idx, (binary_mask_pred, binary_mask_gt, arr) in enumerate(zip(pred_binary_masks, pred_binary_masks2, ground_truth_arrays)):
    dice_score = dice_coefficient(binary_mask_pred, binary_mask_gt)

    print(f"Dice Coefficient for image {idx + 1}: {dice_score:.4f}")