In [1]:
%%html
<style>
/* Give ipywidget output cells a transparent background. */
.cell-output-ipywidget-background {
    background-color: transparent !important;
}
/* Take the widget text color and font size from the --vscode-editor-* CSS variables. */
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}
</style>

In [2]:
from pathlib import Path

import numpy as np
import torch

# --- Reproducibility --------------------------------------------------------
# One fixed seed, applied to both NumPy's and torch's global RNGs.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Data locations (relative to the notebook's working directory) ----------
ROOT = Path(".")
CATS_DIR = ROOT.joinpath("images_dataset", "cats")
ANOM_DIR = ROOT.joinpath("images_dataset", "anomalies")

# --- Preprocessing / batching -----------------------------------------------
IMG_SIZE, BATCH_SIZE = 128, 64

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [3]:
from PIL import Image
from torch.utils.data import Dataset


class ImageFolderDataset(Dataset):
    """Dataset over the image files that sit directly inside one folder.

    Files are selected by extension (case-insensitive ``.png``, ``.jpg``,
    ``.jpeg``, ``.webp``) and kept in sorted path order, so an index always
    refers to the same file for a given folder content.

    Args:
        folder: Directory to scan (non-recursive). A ``str`` is accepted and
            converted to ``Path``.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned.
    """

    IMAGE_SUFFIXES = frozenset({".png", ".jpg", ".jpeg", ".webp"})

    def __init__(self, folder: Path, transform=None):
        folder = Path(folder)
        # ``is_file()`` skips sub-directories whose name happens to end in an
        # image extension; opening one of those would fail in __getitem__.
        self.paths = sorted(
            p
            for p in folder.iterdir()
            if p.is_file() and p.suffix.lower() in self.IMAGE_SUFFIXES
        )
        self.transform = transform

    def __len__(self):
        """Return the number of image files found in the folder."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Load the image at position ``idx``.

        Returns:
            ``(image, file_name)`` where ``image`` is the RGB image (after
            ``transform`` when one was given) and ``file_name`` is the base
            name of the file as a ``str``.
        """
        path = self.paths[idx]
        # The context manager closes the underlying file deterministically.
        # ``convert`` returns a new, fully loaded image, so the result stays
        # usable after the source file has been closed.
        with Image.open(path) as source:
            image = source.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, str(path.name)

In [4]:
from torch.utils.data import DataLoader, random_split
from torchvision import transforms

# Preprocessing shared by every dataset: square resize, tensor conversion,
# then per-channel normalization.
# NOTE(review): these mean/std values look like the usual ImageNet channel
# statistics expected by pretrained torchvision models — confirm.
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_norm_mean, std=_norm_std),
    ]
)

cats_dataset = ImageFolderDataset(CATS_DIR, transform=transform)
anom_dataset = ImageFolderDataset(ANOM_DIR, transform=transform)

# Hold out 10% of the cat images (at least one) for validation; the split is
# reproducible because it uses its own generator seeded with SEED.
_n_cats = len(cats_dataset)
val_size = max(1, int(0.1 * _n_cats))
train_size = _n_cats - val_size
_split_rng = torch.Generator().manual_seed(SEED)
train_dataset, val_dataset = random_split(
    cats_dataset, [train_size, val_size], generator=_split_rng
)

# Only the training loader shuffles; evaluation loaders keep dataset order.
_loader_opts = dict(batch_size=BATCH_SIZE, num_workers=0)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_opts)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_opts)
anom_loader = DataLoader(anom_dataset, shuffle=False, **_loader_opts)

len(train_dataset), len(val_dataset), len(anom_dataset)

(2503, 278, 146)

In [5]:
from torchvision import models

# ResNet50 with torchvision's default weights, moved to `device` and switched
# to eval mode.
resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
resnet = resnet.to(device).eval()

# Feature extractor: every child module of the ResNet except the last one,
# chained in order.
_feature_layers = list(resnet.children())[:-1]
backbone = torch.nn.Sequential(*_feature_layers).to(device).eval()


@torch.no_grad()
def collect_embeddings(loader, label):
    """Embed every batch of ``loader`` with the module-level ``backbone``.

    Args:
        loader: Iterable yielding ``(images, file_names)`` batches.
        label: Value repeated once per image in the returned label list.

    Returns:
        A tuple ``(features, labels, names)``: the backbone outputs flattened
        from dim 1 and concatenated over all batches as a NumPy array, the
        list of repeated ``label`` values, and the list of file names.
    """
    batch_feats = []
    labels = []
    names = []
    for images, file_names in loader:
        embedded = backbone(images.to(device)).flatten(1)
        # Move each batch to the CPU so the final array can be built by NumPy.
        batch_feats.append(embedded.cpu())
        labels += [label] * images.size(0)
        names += list(file_names)
    return torch.cat(batch_feats).numpy(), labels, names

In [6]:
# Fit a single Gaussian (mean + covariance) to the training-cat embeddings.
x_train, _, _ = collect_embeddings(train_loader, "cat")
mean = np.mean(x_train, axis=0, keepdims=True)
x_c = x_train - mean

# Unbiased sample covariance, built in two steps.
_n_samples, _n_features = x_c.shape
cov = x_c.T @ x_c
cov /= _n_samples - 1
# A small multiple of the identity on the diagonal keeps the matrix
# invertible.
cov += 1e-3 * np.eye(_n_features)
inv_cov = np.linalg.inv(cov)


def mahalanobis_distance(x):
    """Score each row of ``x`` against the module-level ``mean`` / ``inv_cov``.

    For every row ``r`` the result is the quadratic form
    ``(r - mean) @ inv_cov @ (r - mean)``, i.e. the *squared* Mahalanobis
    distance; no square root is taken.

    Args:
        x: 2-D array with one embedding per row.

    Returns:
        1-D array with one score per row of ``x``.
    """
    centered = x - mean
    projected = centered @ inv_cov
    return (projected * centered).sum(axis=1)

In [7]:
# Embed the held-out cats and the anomaly images with the same backbone.
_cat_outputs = collect_embeddings(val_loader, "cat")
_anom_outputs = collect_embeddings(anom_loader, "anomaly")
x_cat, cat_label, cat_names = _cat_outputs
x_anom, anom_label, anom_names = _anom_outputs

# Higher score = further from the Gaussian fitted on the training cats.
score_cat, score_anom = mahalanobis_distance(x_cat), mahalanobis_distance(x_anom)

In [8]:
import plotly.io as pio

# Make "plotly_dark" the default template for figures created after this point.
pio.templates.default = "plotly_dark"

In [9]:
import pandas as pd
import plotly.express as px

# Decision threshold: the 99th percentile of the cat scores.
threshold = np.percentile(score_cat, 99)

# Long-format table: one row per image with its score and its group name.
_all_scores = np.concatenate([score_cat, score_anom])
_group_names = ["cat"] * len(score_cat) + ["anomaly"] * len(score_anom)
df = pd.DataFrame({"score": _all_scores, "group": _group_names})

# Overlaid histograms of both groups, with the threshold as a vertical line.
_hist_opts = dict(
    x="score",
    color="group",
    barmode="overlay",
    nbins=60,
    title="ResNet50 embedding Mahalanobis score",
)
fig = px.histogram(df, **_hist_opts)
fig.add_vline(x=threshold)
fig.show()

In [10]:
import plotly.graph_objects as go
from sklearn.metrics import auc, precision_recall_curve, roc_curve

# Binary ground truth: 0 for cats, 1 for anomalies, in the same order as the
# concatenated scores.
_n_cat, _n_anom = len(score_cat), len(score_anom)
y_true = np.array([0] * _n_cat + [1] * _n_anom)
y_score = np.concatenate([score_cat, score_anom])

fpr, tpr, _ = roc_curve(y_true, y_score)
precision, recall, _ = precision_recall_curve(y_true, y_score)

# Both areas use trapezoidal integration via `auc`.
# NOTE(review): trapezoidal integration of a PR curve can be optimistic;
# consider sklearn's `average_precision_score` — confirm which metric is
# intended before comparing against other reports.
roc_auc = auc(fpr, tpr)
pr_auc = auc(recall, precision)

# ROC curve with the diagonal of a random classifier for reference.
_roc_trace = go.Scatter(x=fpr, y=tpr, mode="lines", name=f"ROC AUC={roc_auc:.3f}")
_chance_trace = go.Scatter(
    x=[0, 1], y=[0, 1], mode="lines", name="random", line={"dash": "dash"}
)
fig = go.Figure(data=[_roc_trace, _chance_trace])
fig.update_layout(title="ROC curve", xaxis_title="FPR", yaxis_title="TPR")
fig.show()

# Precision-recall curve.
_pr_trace = go.Scatter(
    x=recall, y=precision, mode="lines", name=f"PR AUC={pr_auc:.3f}"
)
fig = go.Figure(data=[_pr_trace])
fig.update_layout(title="PR curve", xaxis_title="Recall", yaxis_title="Precision")
fig.show()