In [None]:
!pip install -q diffusers==0.30.0 transformers accelerate safetensors opensimplex

In [None]:
# Root of the MVTec-AD "carpet" category as mounted on Kaggle; the "train"
# and "test" sub-folders are joined onto it below.
ds_root = "/kaggle/input/mvtec-ad/carpet"
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Build the preprocessing transform and load the train / test image folders.

# Every image is resized to 512x512, collapsed to a single grayscale channel
# and converted to a tensor. Later cells repeat that channel three times
# before feeding the VAE and rescale the values with `x * 2 - 1`.
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),  # presumably yields floats in [0, 1]; downstream code relies on that range
])

# ImageFolder yields (image, label) pairs; the label is ignored everywhere
# in this notebook (loops unpack it into `_`).
train_dataset = datasets.ImageFolder(
    root=os.path.join(ds_root, "train"),
    transform=transform
)

# Same preprocessing for the test split so reconstructions are compared in
# the same pixel space as the training data.
test_dataset = datasets.ImageFolder(
    root=os.path.join(ds_root, "test"),
    transform=transform
)


In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

#
# A basic UNet with sinusoidal time embeddings, used to predict the noise added to a latent.
# Credits: https://apxml.com/courses/advanced-diffusion-architectures/chapter-2-advanced-unet-architectures/unet-time-embeddings
# (a pretty interesting read)


class SinusoidalTimeEmbedding(nn.Module):
    """
    Parameter-free sinusoidal embedding of diffusion timesteps.

    The first ``dim // 2`` output columns are sines and the next ``dim // 2``
    are cosines of the timestep multiplied by a geometric ladder of
    frequencies running from 1 down to 1/10000.

    The output always has exactly ``dim`` columns:
    - for an odd ``dim`` a trailing zero column is appended (previously the
      result silently had ``dim - 1`` columns, which breaks any following
      ``nn.Linear(dim, ...)``);
    - for ``dim == 2`` the single frequency is 1 (previously this raised
      ``ZeroDivisionError``).
    For even ``dim >= 4`` the values are unchanged.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        t: [B] (int or float timesteps)
        returns: [B, dim]
        """
        device = t.device
        half_dim = self.dim // 2
        # Guard the exponent step: with half_dim == 1 the original divisor
        # (half_dim - 1) is zero.
        denom = max(half_dim - 1, 1)
        # [half_dim]
        freqs = torch.exp(
            torch.arange(half_dim, device=device) * (-math.log(10000.0) / denom)
        )
        # [B, 1] * [1, half_dim] -> [B, half_dim]
        args = t.float().unsqueeze(1) * freqs.unsqueeze(0)
        # [B, 2 * half_dim]
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
        if self.dim % 2 == 1:
            # Zero-pad the last column so the width matches `dim` exactly.
            emb = F.pad(emb, (0, 1))
        return emb


# --------- Building blocks ---------

class DoubleConv(nn.Module):
    """
    Two stacked 3x3 convolution stages, each followed by batch norm and ReLU.

    When a time embedding is supplied (and the block was built with a
    ``time_dim``), it is projected to ``out_ch`` values and added as a
    per-channel bias after the first normalisation, before the first ReLU.

    Channels: in_ch -> out_ch -> out_ch. Spatial size is preserved.
    """

    def __init__(self, in_ch: int, out_ch: int, time_dim: int = None):
        super().__init__()

        def conv3x3(c_in: int, c_out: int) -> nn.Conv2d:
            # padding=1 keeps H and W unchanged for a 3x3 kernel
            return nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)

        # Submodules are created in the same order as before so that seeded
        # initialisation and state_dict ordering stay identical.
        self.conv1 = conv3x3(in_ch, out_ch)
        self.conv2 = conv3x3(out_ch, out_ch)
        self.norm1 = nn.BatchNorm2d(out_ch)
        self.norm2 = nn.BatchNorm2d(out_ch)
        self.act = nn.ReLU(inplace=True)

        # Optional projection of the time embedding onto the channel axis.
        self.time_mlp = nn.Linear(time_dim, out_ch) if time_dim is not None else None

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor = None) -> torch.Tensor:
        """
        x:     [B, C, H, W]
        t_emb: [B, time_dim] or None
        returns: [B, out_ch, H, W]
        """
        out = self.norm1(self.conv1(x))

        has_time = not (self.time_mlp is None or t_emb is None)
        if has_time:
            # [B, out_ch] -> [B, out_ch, 1, 1] so it broadcasts over H and W
            out = out + self.time_mlp(t_emb)[..., None, None]

        out = self.act(out)
        return self.act(self.norm2(self.conv2(out)))


class DownBlock(nn.Module):
    """
    Encoder stage: a time-conditioned DoubleConv followed by 2x2 max pooling.

    Returns both the full-resolution features (used later as a skip
    connection) and their pooled version (fed to the next stage).
    """

    def __init__(self, in_ch: int, out_ch: int, time_dim: int):
        super().__init__()
        self.conv = DoubleConv(in_ch, out_ch, time_dim)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor):
        features = self.conv(x, t_emb)          # [B, out_ch, H,   W]
        return features, self.pool(features)    # pooled: [B, out_ch, H/2, W/2]


class UpBlock(nn.Module):
    """
    Decoder stage: transposed-conv upsampling, concatenation with the
    matching encoder skip tensor, then a time-conditioned DoubleConv.
    """

    def __init__(self, in_ch: int, out_ch: int, time_dim: int):
        """
        in_ch:  channels arriving from the coarser level (input of the
                transposed convolution)
        out_ch: channels after upsampling; the skip tensor is expected to
                carry out_ch channels too, so the DoubleConv receives
                2 * out_ch channels and reduces them back to out_ch
        """
        super().__init__()
        # stride-2, kernel-2 transposed conv doubles H and W
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.conv = DoubleConv(out_ch * 2, out_ch, time_dim)

    def forward(self, x: torch.Tensor, skip: torch.Tensor, t_emb: torch.Tensor):
        upsampled = self.up(x)

        target_h, target_w = skip.shape[-2:]
        cur_h, cur_w = upsampled.shape[-2:]
        if (cur_h, cur_w) != (target_h, target_w):
            # Sizes can disagree when the encoder pooled an odd dimension;
            # pad so the upsampled map lines up with the skip tensor, with
            # any odd remainder going to the right / bottom edge.
            gap_h = target_h - cur_h
            gap_w = target_w - cur_w
            left, top = gap_w // 2, gap_h // 2
            upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        merged = torch.cat([upsampled, skip], dim=1)   # [B, 2*out_ch, H, W]
        return self.conv(merged, t_emb)


# --------- Full UNet ---------

class UNet(nn.Module):
    """
    Time-conditioned UNet for DDPM-style noise prediction.

    Four encoder stages, a bottleneck and four decoder stages with skip
    connections. The output has the same shape as the input.

    - in_channels:   channels of the input (4 for SD latents, 3 for RGB pixels)
    - base_channels: width of the first stage; widths double at each level
    - time_dim:      size of the timestep embedding
    """

    def __init__(self, in_channels: int = 4, base_channels: int = 64, time_dim: int = 256):
        super().__init__()

        # Timestep conditioning: fixed sinusoidal features refined by a small MLP.
        self.time_embed = SinusoidalTimeEmbedding(time_dim)
        self.time_mlp = nn.Sequential(
            nn.Linear(time_dim, time_dim),
            nn.ReLU(inplace=True),
            nn.Linear(time_dim, time_dim),
        )

        # Stage widths; with base_channels=64 these are 64, 128, 256, 512, 1024.
        w1, w2, w3, w4, w_bottleneck = (base_channels * mult for mult in (1, 2, 4, 8, 16))

        # Encoder: in_channels -> w1 -> w2 -> w3 -> w4
        self.down1 = DownBlock(in_channels, w1, time_dim)
        self.down2 = DownBlock(w1, w2, time_dim)
        self.down3 = DownBlock(w2, w3, time_dim)
        self.down4 = DownBlock(w3, w4, time_dim)

        # Bottleneck at the coarsest resolution: w4 -> w_bottleneck
        self.bot = DoubleConv(w4, w_bottleneck, time_dim)

        # Decoder mirrors the encoder: w_bottleneck -> w4 -> w3 -> w2 -> w1
        self.up1 = UpBlock(w_bottleneck, w4, time_dim)
        self.up2 = UpBlock(w4, w3, time_dim)
        self.up3 = UpBlock(w3, w2, time_dim)
        self.up4 = UpBlock(w2, w1, time_dim)

        # 1x1 projection back to the input channel count
        self.out_conv = nn.Conv2d(w1, in_channels, kernel_size=1)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        x: [B, in_channels, H, W]   (noisy image or latent)
        t: [B]                      (timestep indices)
        returns: [B, in_channels, H, W]  (predicted noise)
        """
        t_emb = self.time_mlp(self.time_embed(t))   # [B, time_dim]

        # Encoder: keep the pre-pool features of every stage for the skips.
        # Resolutions of the stored features: H, H/2, H/4, H/8.
        skips = []
        h = x
        for stage in (self.down1, self.down2, self.down3, self.down4):
            features, h = stage(h, t_emb)
            skips.append(features)

        # Bottleneck runs on the output of the fourth pooling: H/16 x W/16.
        h = self.bot(h, t_emb)

        # Decoder: consume the skips from deepest to shallowest.
        for stage in (self.up1, self.up2, self.up3, self.up4):
            h = stage(h, skips.pop(), t_emb)

        return self.out_conv(h)   # [B, in_channels, H, W]


In [None]:
# from diffusers import AutoencoderKL
# import torch

# device = "cuda" if torch.cuda.is_available() else "cpu"

# vae = AutoencoderKL.from_pretrained(
#     "runwayml/stable-diffusion-v1-5",
#     subfolder="vae"
# ).to(device)
# vae.eval()


In [None]:
import numpy as np
import torch
from opensimplex import OpenSimplex

# Written with help from GPT and the reference code of the AnoDDPM paper

def simplex_noise(
    batch,
    ch,
    H,
    W,
    seed=0,
    persistence=0.5,
    colorless=True,
    normalize_per_sample=False,
    device="cpu",
    dtype=torch.float32,
):
    """
    AnoDDPM-style simplex noise generator.

    Returns a tensor of shape [batch, ch, H, W] in which every H x W field is
    standardised to zero mean and unit standard deviation.

    Args:
        batch, ch, H, W: output shape.
        seed: integer base seed. Sample ``b`` uses ``seed + b`` in colorless
            mode; in independent mode field ``(b, c)`` uses
            ``seed + b * ch + c``.
        persistence: amplitude decay per octave (octave ``o`` is weighted by
            ``persistence ** o``). Only one octave (frequency 2^-6) is
            currently enabled, so this has no visible effect yet.
        colorless: if True, all channels of one sample share the same noise
            map, but each sample in the batch gets its own map. Previously
            one map was tiled over the whole batch, so every sample in a
            batch received identical noise. ``batch == 1`` is unchanged.
        normalize_per_sample: kept for backward compatibility. Every field
            is always standardised, so this flag has no additional effect.
        device, dtype: placement and dtype of the returned tensor.

    NOTE(review): coordinates span [0, 1) before being scaled by the
    frequency, so the sampled window is [0, 2^-6). Confirm this is the
    intended scale rather than pixel-index coordinates.
    """

    # Octave frequencies. The full AnoDDPM ladder would be
    # [2**(-i) for i in range(1, 7)]; only the lowest frequency is used here.
    freqs = [2**(-6)]

    # Sampling coordinates along each axis
    xs = np.linspace(0.0, 1.0, W, endpoint=False).astype(np.float32)
    ys = np.linspace(0.0, 1.0, H, endpoint=False).astype(np.float32)

    def sample_octave(rng, freq):
        """Evaluate one octave on the H x W grid (row index = y, column = x)."""
        vectorized = getattr(rng, "noise2array", None)
        if vectorized is not None:
            # Grid evaluation in one call instead of H*W Python-level calls;
            # the result has shape (len(y), len(x)).
            grid = vectorized(
                (xs * freq).astype(np.float64),
                (ys * freq).astype(np.float64),
            )
            return np.asarray(grid, dtype=np.float32)

        # Fallback for generators that only expose the scalar API.
        grid = np.empty((H, W), dtype=np.float32)
        for i, yi in enumerate(ys):
            for j, xj in enumerate(xs):
                grid[i, j] = rng.noise2(xj * freq, yi * freq)
        return grid

    def build_field(rng):
        """Sum the octaves and standardise the resulting H x W map."""
        field = np.zeros((H, W), dtype=np.float32)
        for o, freq in enumerate(freqs):
            field += (persistence ** o) * sample_octave(rng, freq)
        # epsilon avoids division by zero for a constant field
        return (field - field.mean()) / (field.std() + 1e-9)

    arr = np.zeros((batch, ch, H, W), dtype=np.float32)

    if colorless:
        # One independent map per sample, broadcast across its channels.
        for b in range(batch):
            arr[b] = build_field(OpenSimplex(seed + b))
    else:
        # Fully independent map for every (sample, channel) pair.
        loc_seed = seed
        for b in range(batch):
            for c in range(ch):
                arr[b, c] = build_field(OpenSimplex(loc_seed))
                loc_seed += 1

    return torch.tensor(arr, device=device, dtype=dtype)


In [None]:
# Load the Stable Diffusion 1.5 VAE so the diffusion model can work in its
# latent space instead of pixel space.

import torch
from torch.utils.data import DataLoader
from diffusers import AutoencoderKL

# Single device string reused by every later cell.
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): confirm this Hub repo id still resolves in your environment.
vae = AutoencoderKL.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="vae"
).to(device)
# The VAE is only used as a fixed encoder/decoder: eval mode, no gradients.
vae.eval()
vae.requires_grad_(False)

batch_size = 8  # batch size used only for the one-off latent encoding pass
# shuffle=False: the encoding pass just walks the dataset once; shuffling
# happens later in the loader over the saved latents.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=2)


In [None]:
# Precompute the training latents once and save them to disk: encoding on the
# fly during training ran into Kaggle out-of-memory issues.

# Collects one CPU tensor of latents per batch; concatenated after the loop.
all_latents = []

with torch.no_grad():
    for x0,_ in train_loader:
        x0 = x0.repeat(1,3,1,1) # grayscale input (1 channel) but the VAE expects 3 channels
        x0 = x0.to(device) 
        x0_norm = x0 * 2 - 1   # [0,1] -> [-1,1]
        # encode with VAE
        posterior = vae.encode(x0_norm)
        # use the posterior mean (not a random sample) so the encoding is deterministic
        z0 = posterior.latent_dist.mean  # [B, 4, H_lat, W_lat]
        # scale the latents the way Stable Diffusion does; the decode path
        # divides by the same constant
        z0 = z0 * 0.18215
        # move to CPU right away so GPU memory is not held across batches
        all_latents.append(z0.cpu())

all_latents = torch.cat(all_latents, dim=0)  # [N, 4, H_lat, W_lat]
print("Latents shape:", all_latents.shape)
# Written to the working directory; the LatentDataset cell loads this file.
torch.save(all_latents, "train_latents.pt")
print("Saved latents to train_latents.pt")


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

class LatentDataset(Dataset):
    """
    In-memory dataset over a tensor of precomputed latents stored on disk.

    The file at ``latents_path`` is read once with ``torch.load``; indexing
    returns a single latent with the leading (sample) axis removed.
    """

    def __init__(self, latents_path: str):
        super().__init__()
        # Whole tensor is kept in memory, e.g. [N, 4, H_lat, W_lat]
        self.latents = torch.load(latents_path)

    def __len__(self):
        # number of stored latents = size of the leading axis
        num_latents = self.latents.shape[0]
        return num_latents

    def __getitem__(self, idx):
        latent = self.latents[idx]   # e.g. [4, H_lat, W_lat]
        return latent


# Load the latents saved by the precompute cell and serve them in shuffled
# mini-batches of 16 for training.
latent_dataset = LatentDataset("train_latents.pt")
latent_loader = DataLoader(latent_dataset, batch_size=16, shuffle=True, num_workers=2)


In [None]:
# Initialize the noise-prediction model and its optimizer.
in_channels = 4      # channel count of the VAE latents produced above
base_channels = 64   # width of the first UNet stage
time_dim = 256       # size of the timestep embedding

model = UNet(in_channels=in_channels, base_channels=base_channels, time_dim=time_dim).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

In [None]:
# Training block: diffusion schedule constants
T = 1000  # max number of noising steps

# Scheduler start and end values (linear beta schedule)
beta_start = 1e-4
beta_end   = 0.02

# Precompute the scheduler constants once; they are indexed by timestep later.

betas = torch.linspace(beta_start, beta_end, T, dtype=torch.float32)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)  # [T], running product of alphas

# Coefficients of the closed-form forward process:
#   z_t = sqrt(alpha_bar_t) * z0 + sqrt(1 - alpha_bar_t) * eps
sqrt_alpha_bars      = torch.sqrt(alpha_bars)
sqrt_one_minus_abars = torch.sqrt(1.0 - alpha_bars)

# Move to the compute device so they can be indexed by on-device timestep tensors
betas               = betas.to(device)
alphas              = alphas.to(device)
alpha_bars          = alpha_bars.to(device)
sqrt_alpha_bars     = sqrt_alpha_bars.to(device)
sqrt_one_minus_abars= sqrt_one_minus_abars.to(device)

# Forward diffusion: build z_t from z0, t and eps
def q_sample(z0, t, eps):
    """
    Noise clean latents to timestep ``t`` in closed form.

    z0:  [B, C, H, W] clean latents
    t:   [B] integer timestep indices into the schedule tables
    eps: noise, broadcastable against z0
    returns: sqrt(alpha_bar_t) * z0 + sqrt(1 - alpha_bar_t) * eps

    Reads the module-level tables ``sqrt_alpha_bars`` and
    ``sqrt_one_minus_abars``.
    """
    # one coefficient per sample, shaped to broadcast over C, H, W
    coeff_shape = (z0.shape[0], 1, 1, 1)
    signal_scale = sqrt_alpha_bars[t].view(coeff_shape)
    noise_scale = sqrt_one_minus_abars[t].view(coeff_shape)
    return signal_scale * z0 + noise_scale * eps

# training loop
num_epochs = 20
global_step = 0

# Base seed for the simplex noise. It advances by the batch size after every
# step, so each step (and each sample, if the generator seeds per sample)
# sees a different noise field. With a fixed seed the regression target would
# be the same tensor at every step and the network could fit a constant.
noise_seed = 123

# Make the cell re-runnable even after the evaluation cell froze the model.
model.requires_grad_(True)
model.train()

for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0

    for z0 in latent_loader:
        z0 = z0.to(device)  # [B, C_lat, H_lat, W_lat]
        # Take the noise shape from the data instead of hard-coding 4x64x64.
        B, C_lat, H_lat, W_lat = z0.shape

        # Sample a random timestep per sample from the upper three quarters
        # of the schedule, [T//4, T).
        t = torch.randint(low=T//4, high=T, size=(B,), device=device)

        # Create the noise directly on the active device: a hard-coded "cuda"
        # crashes when the notebook falls back to CPU.
        eps = simplex_noise(
            batch=B,
            ch=C_lat,
            H=H_lat,
            W=W_lat,
            seed=noise_seed,
            persistence=0.5,
            colorless=True,
            device=device,
        )
        noise_seed += B

        # forward-diffuse the clean latents and predict the injected noise
        z_t = q_sample(z0, t, eps)
        eps_hat = model(z_t, t)

        loss = F.mse_loss(eps_hat, eps)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1
        global_step += 1

    # Report the epoch mean rather than whatever the last batch produced;
    # max(..., 1) keeps an empty loader from raising.
    print(f"Epoch {epoch+1}/{num_epochs} - loss: {epoch_loss / max(num_batches, 1):.6f}")


In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

device = "cuda" if torch.cuda.is_available() else "cpu"

# Switch both networks to inference mode for reconstruction.
model.eval()
vae.eval()

# Freeze all parameters. Note this also disables gradients on `model`, so
# the training cell cannot be re-run afterwards unless gradients are
# re-enabled first.
vae.requires_grad_(False)
model.requires_grad_(False)

def q_sample(z0, t, eps, sqrt_alpha_bars, sqrt_one_minus_abars):
    """
    Noise clean latents to timestep ``t`` and also return the coefficients.

    Unlike the training-time helper of the same name (which this definition
    replaces once the cell runs), the schedule tables are passed explicitly
    and the per-sample coefficients are returned so the caller can invert
    the forward process.

    returns: (z_t, sqrt_alpha_bar_t, sqrt_one_minus_alpha_bar_t), with both
             coefficients shaped [B, 1, 1, 1]
    """
    coeff_shape = (z0.shape[0], 1, 1, 1)
    signal_scale = sqrt_alpha_bars[t].view(coeff_shape)
    noise_scale = sqrt_one_minus_abars[t].view(coeff_shape)
    z_t = signal_scale * z0 + noise_scale * eps
    return z_t, signal_scale, noise_scale

@torch.no_grad()
def reconstruct_and_heatmap(x_batch, t_step):
    """
    Reconstruct a batch of images through the latent diffusion model and
    derive a per-pixel anomaly heatmap.

    Pipeline: encode with the VAE -> add simplex noise at timestep
    ``t_step`` -> predict the noise with ``model`` -> estimate the clean
    latent in a single step -> decode -> squared error against the input.

    Args:
        x_batch: [B, 1, H, W] (or [B, 3, H, W]) images with values in [0, 1].
        t_step:  integer noising step, 0 <= t_step < number of schedule steps.

    Returns:
        (x0, x_recon, heat_norm) on CPU:
        x0        [B, 3, H, W] input images (channel-replicated if grayscale)
        x_recon   [B, 3, H, W] reconstruction, clamped to [0, 1]
        heat_norm [B, H, W]    squared error averaged over channels and
                               min-max normalised per image

    Raises:
        ValueError: if ``t_step`` is outside the schedule. A negative index
            would otherwise silently wrap around.
    """
    num_steps = sqrt_alpha_bars.shape[0]
    if not 0 <= t_step < num_steps:
        raise ValueError(f"t_step must be in [0, {num_steps}), got {t_step}")

    B = x_batch.shape[0]
    x0 = x_batch.to(device)

    # The VAE expects 3 channels; only replicate when the input is grayscale.
    if x0.shape[1] == 1:
        x0 = x0.repeat(1, 3, 1, 1)
    x0_norm = x0 * 2 - 1                      # [0,1] -> [-1,1]

    posterior = vae.encode(x0_norm)
    # Use the posterior mean, matching how the training latents were
    # precomputed; sampling here would add extra noise the model never saw.
    z0 = posterior.latent_dist.mean * 0.18215  # same latent scaling as in training

    t = torch.full((B,), t_step, device=device, dtype=torch.long)

    # Noise shape follows the latent, and the tensor is created on the
    # active device (a hard-coded "cuda" fails on CPU-only machines).
    _, C_lat, H_lat, W_lat = z0.shape
    eps = simplex_noise(
        batch=B,
        ch=C_lat,
        H=H_lat,
        W=W_lat,
        seed=123,            # fixed seed keeps evaluation deterministic
        persistence=0.5,
        colorless=True,
        device=device,
    )

    z_t, sqrt_ab, sqrt_ombab = q_sample(z0, t, eps, sqrt_alpha_bars, sqrt_one_minus_abars)

    eps_hat = model(z_t, t)  # [B, C_lat, H_lat, W_lat]

    # Invert z_t = sqrt_ab * z0 + sqrt_ombab * eps using the predicted noise.
    z_recon = (z_t - sqrt_ombab * eps_hat) / sqrt_ab

    # Undo the latent scaling before decoding, then map [-1,1] -> [0,1].
    x_recon = vae.decode(z_recon / 0.18215).sample
    # Clamp so the reconstruction is a valid image, as the input is.
    x_recon = ((x_recon + 1) / 2).clamp(0, 1)

    residual = (x_recon - x0) ** 2
    heat = residual.mean(dim=1)               # [B, H, W]

    # Per-image min-max normalisation of the heatmap; the epsilon avoids
    # division by zero for a constant map.
    B, H, W = heat.shape
    heat_flat = heat.reshape(B, -1)
    h_min = heat_flat.min(dim=1, keepdim=True)[0]
    h_max = heat_flat.max(dim=1, keepdim=True)[0]
    heat_norm = (heat_flat - h_min) / (h_max - h_min + 1e-8)
    heat_norm = heat_norm.view(B, H, W)

    return x0.cpu(), x_recon.cpu(), heat_norm.cpu()


In [None]:
from torch.utils.data import DataLoader

# Draw one random batch of test images and reconstruct it.
# shuffle=True means a different batch is shown on every run.
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=True)
t_step = 500  # noising step used for the reconstruction (the schedule has T steps)

x_test_batch, _ = next(iter(test_loader))  # labels are not needed here
x_orig, x_recon, heatmap = reconstruct_and_heatmap(x_test_batch, t_step)


B = x_orig.shape[0]
n_show = min(B, 4)  # show at most four rows

# One row per image: original | reconstruction | heatmap overlay
plt.figure(figsize=(9, 3 * n_show))

for i in range(n_show):
    # Column 1: input image (CHW -> HWC for matplotlib)
    plt.subplot(n_show, 3, 3*i + 1)
    plt.imshow(x_orig[i].permute(1, 2, 0).numpy())
    plt.title("Original")
    plt.axis("off")

    # Column 2: decoded reconstruction
    plt.subplot(n_show, 3, 3*i + 2)
    plt.imshow(x_recon[i].permute(1, 2, 0).numpy())
    plt.title("Reconstruction")
    plt.axis("off")

    # Column 3: normalised error map drawn semi-transparently over the input
    plt.subplot(n_show, 3, 3*i + 3)
    img = x_orig[i].permute(1, 2, 0).numpy()
    hm  = heatmap[i].numpy()
    plt.imshow(img)
    plt.imshow(hm, cmap="jet", alpha=0.5)  # overlay
    plt.title("Anomaly heatmap")
    plt.axis("off")

plt.tight_layout()
plt.show()
