# Convolutional Autoencoder Anomaly Detection (Fashion-MNIST)

**Goal:** Unsupervised anomaly detection using **reconstruction error**.  
We train a convolutional autoencoder (CAE) on one **normal** class and treat other classes as anomalies.

**Pipeline**
1) Train CAE on normal-class images only  
2) Score test images by reconstruction error  
3) Evaluate with ROC-AUC / PR-AUC and a thresholded confusion matrix  
4) Visualize reconstructions and top anomalies


## 0) Install & imports


In [None]:
# (Optional for Colab)
# !pip -q install torch torchvision numpy matplotlib tqdm scikit-learn


In [None]:
import os, random, time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

from tqdm.auto import tqdm
import matplotlib.pyplot as plt

from sklearn.metrics import roc_auc_score, roc_curve, precision_recall_curve, auc, confusion_matrix

# Report the installed PyTorch build, then pick the compute device once for the
# whole notebook. `device` stays a plain string because later cells compare it
# against "cuda".
print("torch:", torch.__version__)
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("device:", device)


## 1) Config


In [None]:
# Seed every RNG in play (Python, NumPy, PyTorch CPU and CUDA) from one value.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)

# Fashion-MNIST class ids:
#   0 T-shirt/top | 1 Trouser | 2 Pullover | 3 Dress | 4 Coat
#   5 Sandal      | 6 Shirt   | 7 Sneaker  | 8 Bag   | 9 Ankle boot
# The class chosen here is the "normal" one.
NORMAL_CLASS = 0

# Optimisation hyper-parameters.
EPOCHS, BATCH_SIZE = 10, 256
LR, WEIGHT_DECAY = 1e-3, 1e-5

# Output folder for generated artifacts; created up front so writes cannot fail
# on a missing directory.
OUT_DIR = "outputs_cae_anomaly"
os.makedirs(OUT_DIR, exist_ok=True)

print("NORMAL_CLASS =", NORMAL_CLASS)


## 2) Data


In [None]:
# Only convert images to tensors; no augmentation or normalisation is applied.
tf = transforms.Compose([transforms.ToTensor()])

# Both splits share the same root, download flag and transform.
_split_kwargs = dict(root="data", download=True, transform=tf)
train_ds = datasets.FashionMNIST(train=True, **_split_kwargs)
test_ds = datasets.FashionMNIST(train=False, **_split_kwargs)

# Training uses images of the normal class only.
train_targets = np.array(train_ds.targets)
normal_train_idx = np.flatnonzero(train_targets == NORMAL_CLASS)
train_normal = Subset(train_ds, normal_train_idx)

# Binary ground truth over the full test split: 1 = anomaly, 0 = normal.
test_targets = np.array(test_ds.targets)
y_true = np.not_equal(test_targets, NORMAL_CLASS).astype(np.int64)

_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=(device == "cuda"))
train_dl = DataLoader(train_normal, shuffle=True, **_loader_kwargs)
# The test loader is not shuffled, so per-sample results stay aligned with y_true.
test_dl = DataLoader(test_ds, shuffle=False, **_loader_kwargs)

print("Train (normal only):", len(train_normal))
print("Test:", len(test_ds), "| anomaly ratio:", y_true.mean())


## 3) Model (Convolutional Autoencoder)


In [None]:
class ConvAutoencoder(nn.Module):
    """Convolutional autoencoder for single-channel 28x28 images.

    The encoder downsamples twice with strided convolutions to a 64x7x7
    feature map, which a linear layer compresses to ``latent_dim`` values.
    The decoder mirrors it: a linear layer expands the code back to 64x7x7,
    two transposed convolutions upsample to 28x28, and a final sigmoid
    bounds every output pixel to [0, 1].

    Args:
        latent_dim: Length of the bottleneck vector returned by ``encode``.
    """

    def __init__(self, latent_dim=64):
        super().__init__()

        # Shape comments assume a 1x28x28 input.
        encoder_layers = [
            nn.Conv2d(1, 16, 3, stride=2, padding=1),   # -> 16x14x14
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),  # -> 32x7x7
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, stride=1, padding=1),  # -> 64x7x7
            nn.ReLU(inplace=True),
        ]
        self.enc = nn.Sequential(*encoder_layers)

        # Bottleneck: flattened 64x7x7 features <-> latent vector.
        self.enc_fc = nn.Linear(64 * 7 * 7, latent_dim)
        self.dec_fc = nn.Linear(latent_dim, 64 * 7 * 7)

        decoder_layers = [
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),  # -> 32x14x14
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1),  # -> 16x28x28
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 1, 3, padding=1),                      # -> 1x28x28
            nn.Sigmoid(),
        ]
        self.dec = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Map a batch of images to latent vectors of shape (batch, latent_dim)."""
        batch = x.size(0)
        features = self.enc(x)
        flat = features.view(batch, -1)
        return self.enc_fc(flat)

    def decode(self, z):
        """Map latent vectors back to images of shape (batch, 1, 28, 28)."""
        batch = z.size(0)
        expanded = self.dec_fc(z)
        feature_map = expanded.view(batch, 64, 7, 7)
        return self.dec(feature_map)

    def forward(self, x):
        """Reconstruct ``x`` by encoding it and decoding the resulting code."""
        return self.decode(self.encode(x))

# Build the autoencoder with a 64-value bottleneck, move it to the selected
# device, and print its layer structure.
model = ConvAutoencoder(latent_dim=64)
model = model.to(device)
print(model)


## 4) Train on normal data only


In [None]:
# Train the autoencoder to reproduce its input; the labels from the loader are
# ignored, so the objective is purely reconstruction MSE.
criterion = nn.MSELoss()
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

train_losses = []  # one entry per epoch: per-sample mean MSE

start = time.perf_counter()  # monotonic clock, unaffected by system time changes
for epoch in range(1, EPOCHS + 1):
    model.train()  # re-assert training mode at the start of every epoch
    running = 0.0  # sum of per-sample losses seen this epoch
    seen = 0       # number of samples seen this epoch
    pbar = tqdm(train_dl, desc=f"Epoch {epoch}/{EPOCHS}")
    for x, _ in pbar:
        x = x.to(device, non_blocking=True)
        xhat = model(x)
        loss = criterion(xhat, x)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # Read the scalar once: every .item() call forces a host/device sync.
        batch_loss = loss.item()
        batch_size = x.size(0)
        # Weight by batch size so a smaller final batch does not skew the epoch mean.
        running += batch_loss * batch_size
        seen += batch_size
        pbar.set_postfix(loss=f"{batch_loss:.4f}")

    # max(..., 1) avoids a ZeroDivisionError if the loader yielded nothing.
    avg = running / max(seen, 1)
    train_losses.append(avg)
    print(f"Epoch {epoch}: mse={avg:.6f}")

print(f"Training time: {(time.perf_counter()-start)/60:.2f} min")

# Persist the weights together with the class the model was trained on.
ckpt_path = os.path.join(OUT_DIR, "cae.pt")
torch.save({"model": model.state_dict(), "config": {"NORMAL_CLASS": NORMAL_CLASS}}, ckpt_path)
print("Saved:", ckpt_path)


### Plot training loss


In [None]:
# Epoch-level training curve: one point per entry in train_losses, x axis 1-based.
epoch_axis = range(1, len(train_losses) + 1)
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(epoch_axis, train_losses)
ax.set_xlabel("Epoch")
ax.set_ylabel("MSE")
ax.set_title("CAE Training Loss (normal-only)")
ax.grid(True)
plt.show()


## 5) Score test images by reconstruction error


In [None]:
@torch.no_grad()
def reconstruction_errors(model, dataloader):
    """Return the per-sample mean squared reconstruction error.

    Args:
        model: Module that maps a batch to a reconstruction of the same shape.
        dataloader: Iterable of ``(inputs, labels)`` batches; labels are ignored.

    Returns:
        1-D NumPy array with one error per sample, in iteration order. The
        array is empty when ``dataloader`` yields no batches.

    Note:
        The model is switched to eval mode and left in that mode.
    """
    model.eval()

    # Feed batches to the device the model actually lives on instead of relying
    # on a notebook-level global. A parameter-free model receives each batch
    # wherever it already is.
    first_param = next(model.parameters(), None)

    errs = []
    for x, _ in tqdm(dataloader, desc="Scoring"):
        if first_param is not None:
            x = x.to(first_param.device, non_blocking=True)
        xhat = model(x)
        # Average the squared error over every non-batch dimension.
        per_sample = (xhat - x).pow(2).flatten(start_dim=1).mean(dim=1)
        errs.append(per_sample.cpu().numpy())

    # np.concatenate rejects an empty list, so handle "no batches" explicitly.
    if not errs:
        return np.empty(0, dtype=np.float32)
    return np.concatenate(errs, axis=0)

# One score per test image, in test_dl order; a higher error means "more anomalous".
scores = reconstruction_errors(model=model, dataloader=test_dl)
print("scores:", scores.shape)


### Score distribution (normal vs anomaly)


In [None]:
# Overlay the error histograms of the two ground-truth groups on a single axis.
fig, ax = plt.subplots(figsize=(6, 4))
for group_flag, group_name in ((0, "Normal"), (1, "Anomaly")):
    ax.hist(scores[y_true == group_flag], bins=50, alpha=0.7, label=group_name)
ax.set_xlabel("Reconstruction error (MSE)")
ax.set_ylabel("Count")
ax.set_title("Reconstruction Error Distribution")
ax.grid(True)
ax.legend()
plt.show()


## 6) Metrics: ROC-AUC, PR-AUC, and thresholding


In [None]:
# Imported here so this cell carries the one extra metric it needs.
from sklearn.metrics import average_precision_score

# Threshold-free ranking quality of the anomaly scores (positive class = anomaly).
roc_auc = roc_auc_score(y_true, scores)
fpr, tpr, thr = roc_curve(y_true, scores)

prec, rec, _ = precision_recall_curve(y_true, scores)
# Summarise the PR curve with average precision (a step-wise sum) rather than
# trapezoidal auc(rec, prec): linearly interpolating between PR operating
# points can overestimate the area.
pr_auc = average_precision_score(y_true, scores)

print(f"ROC-AUC: {roc_auc:.4f}")
print(f"PR-AUC : {pr_auc:.4f}")

plt.figure(figsize=(5,5))
plt.plot(fpr, tpr)
plt.plot([0,1],[0,1], linestyle="--")  # chance-level diagonal for reference
plt.xlabel("FPR")
plt.ylabel("TPR")
plt.title(f"ROC (AUC={roc_auc:.3f})")
plt.grid(True)
plt.show()

plt.figure(figsize=(5,5))
# Step drawing matches the step-wise summary reported in the title.
plt.step(rec, prec, where="post")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title(f"Precision-Recall (AUC={pr_auc:.3f})")
plt.grid(True)
plt.show()


### Choose threshold via Youden's J and compute confusion matrix


In [None]:
# Youden's J (TPR - FPR) selects the ROC operating point farthest above chance.
# NOTE: the threshold is tuned on the same labelled data it is evaluated on, so
# the numbers below are optimistic; use a held-out split for an unbiased estimate.
j = tpr - fpr
# thr[0] is a sentinel above every score ("flag nothing") where J is always 0.
# Skipping it guarantees the chosen threshold is a real score value. This never
# lowers J, because the last ROC point (1, 1) also has J = 0.
best_i = (int(j[1:].argmax()) + 1) if len(j) > 1 else 0
best_thr = thr[best_i]
print("Best threshold:", best_thr)

y_pred = (scores >= best_thr).astype(np.int64)
# Pin the label order so the matrix is always 2x2 and the 4-way unpack below
# cannot fail when only one class is present.
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
tn, fp, fn, tp = cm.ravel()

print("Confusion matrix:\n", cm)
print(f"TP={tp}, FP={fp}, TN={tn}, FN={fn}")
# max(..., 1) avoids dividing by zero when a denominator is empty.
print("Precision:", tp / max(tp+fp, 1))
print("Recall   :", tp / max(tp+fn, 1))


## 7) Visualize reconstructions and top anomalies


In [None]:
@torch.no_grad()
def recon_pair(model, dataset, idx):
    """Fetch one sample and its reconstruction.

    Args:
        model: Module that maps a batch to a reconstruction of the same shape.
        dataset: Indexable collection whose items are ``(image, label)`` pairs.
        idx: Position of the sample inside ``dataset``.

    Returns:
        Tuple ``(x, xhat, y)``: the image exactly as stored in the dataset, its
        reconstruction on the CPU with the batch dimension removed, and the
        label converted to a Python ``int``.

    Note:
        The model is switched to eval mode and left in that mode.
    """
    model.eval()
    x, y = dataset[idx]

    # Run on the device the model lives on instead of relying on a
    # notebook-level global. A parameter-free model gets the input as-is.
    x_in = x.unsqueeze(0)  # add the batch dimension the model expects
    first_param = next(model.parameters(), None)
    if first_param is not None:
        x_in = x_in.to(first_param.device)

    xhat = model(x_in).squeeze(0).cpu()
    return x, xhat, int(y)

def show_recon_grid(indices, title):
    """Plot test images (top row) above their reconstructions (bottom row).

    Args:
        indices: Positions in ``test_ds`` to display, one column per entry.
        title: Figure-level title.
    """
    n_cols = len(indices)
    plt.figure(figsize=(12, 4))
    for col, sample_idx in enumerate(indices, start=1):
        original, recon, label = recon_pair(model, test_ds, sample_idx)
        # Subplot slots are numbered row-major: the top row holds slots
        # 1..n_cols and the bottom row continues from n_cols + 1.
        panels = (
            (col, original, f"y={label}"),
            (n_cols + col, recon, "recon"),
        )
        for slot, image, caption in panels:
            plt.subplot(2, n_cols, slot)
            plt.imshow(image.squeeze(0), cmap="gray")
            plt.axis("off")
            plt.title(caption)
    plt.suptitle(title)
    plt.show()

# Show a random handful of normal and anomalous test images next to their
# reconstructions.
test_targets = np.array(test_ds.targets)
is_normal = test_targets == NORMAL_CLASS
normal_idx = np.flatnonzero(is_normal)
anom_idx = np.flatnonzero(~is_normal)

# Shuffle in place (normal pool first, then anomalies) so the first six of each
# pool form a random sample.
for index_pool in (normal_idx, anom_idx):
    np.random.shuffle(index_pool)

show_recon_grid(normal_idx[:6], "Normal samples (top) vs reconstructions (bottom)")
show_recon_grid(anom_idx[:6], "Anomaly samples (top) vs reconstructions (bottom)")


### Top-k most anomalous samples (highest reconstruction error)


In [None]:
# Rank test images by descending reconstruction error and display the worst
# ones: originals on the top row, reconstructions (captioned with the error) below.
top_k = 12
top_idx = np.argsort(-scores)[:top_k]

plt.figure(figsize=(14, 4))
for col, idx in enumerate(top_idx, start=1):
    x, xhat, y = recon_pair(model, test_ds, int(idx))
    panels = (
        (col, x, f"y={y}"),
        (top_k + col, xhat, f"e={scores[idx]:.4f}"),
    )
    for slot, image, caption in panels:
        plt.subplot(2, top_k, slot)
        plt.imshow(image.squeeze(0), cmap="gray")
        plt.axis("off")
        plt.title(caption)
plt.suptitle("Top anomalies (highest reconstruction error)")
plt.show()


## 8) Next steps (to make it even stronger)
- Try different NORMAL_CLASS values and compare ROC-AUC  
- Use a **denoising autoencoder** (add noise, reconstruct clean)  
- Replace CAE with a **VAE** and use negative log-likelihood  
- Add latent-space visualization (PCA/t-SNE on `z`)
