MH-PatchCore

A PyTorch implementation of Mahalanobis-flavored PatchCore with online and batch processing modes.

This repository implements an enhanced version of PatchCore with support for various feature reduction, clustering, and distance metric strategies, including Mahalanobis distance in a whitened Z-space. The implementation offers both VANILLA (batch) and ONLINE (streaming) processing modes.

For a detailed comparison of implementation approaches and architectural flow diagrams, see ARCHITECTURE.md.

Installation

  1. Clone the repository:

    git clone https://github.com/yourusername/MH-PatchCore.git
    cd MH-PatchCore
  2. Install dependencies:

    pip install -r requirements.txt

    Note: Ensure you have a CUDA-capable GPU and PyTorch installed with CUDA support for optimal performance.

Features

  • Core Logic: Replicates the original PatchCore logic for full backward compatibility.
  • Processing Modes:
    • VANILLA: batch processing (collect all features, then process); compatible with the original PatchCore.
    • ONLINE: streaming processing with memory-efficient incremental learning.
  • Feature Reduction:
    • NONE: No dimensionality reduction (original PatchCore behavior).
    • PCA: Principal Component Analysis.
      • VANILLA mode: Standard sklearn PCA.
      • ONLINE mode: IncrementalPCA for streaming data.
  • Clustering / Aggregation:
    • GREEDY: Original PatchCore Greedy Coreset Sampling.
    • KMEANS: MiniBatch K-Means clustering.
      • VANILLA mode: Fit on all data at once.
      • ONLINE mode: partial_fit for incremental clustering.
    • NONE: Use all features without sampling.
  • Distance Metric:
    • EUCLIDEAN: Nearest Neighbor Euclidean distance (Original PatchCore).
    • MAHALANOBIS: Mahalanobis distance to cluster centroids (performed in whitened Z-space).
      • Euclidean distance in Z-space equals Mahalanobis distance in the original space (see the check below).
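
Why this holds: with Σ = L·L^T and z = L^-1(x − μ), the squared Euclidean distance ‖z1 − z2‖² expands to (x1 − x2)^T Σ^-1 (x1 − x2), i.e. the squared Mahalanobis distance. A quick numerical check (a standalone sketch; the dimension and seed are arbitrary):

import torch

torch.manual_seed(0)
d = 8
A = torch.randn(d, d, dtype=torch.float64)
cov = A @ A.T + 1e-3 * torch.eye(d, dtype=torch.float64)  # SPD covariance
L = torch.linalg.cholesky(cov)                            # cov = L @ L.T

x = torch.randn(d, dtype=torch.float64)
mu = torch.randn(d, dtype=torch.float64)

# Mahalanobis distance in the original space
d_maha = torch.sqrt((x - mu) @ torch.linalg.solve(cov, x - mu))

# Euclidean norm after whitening: z = L^-1 (x - mu)
z = torch.linalg.solve_triangular(L, (x - mu).unsqueeze(1), upper=False).squeeze(1)
d_eucl = torch.linalg.norm(z)

assert torch.allclose(d_maha, d_eucl)  # equal up to numerical error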

Usage

MHPatchCore

The main class MHPatchCore allows flexible configuration.

import torch
from mhpc.core.mh_patch_core import (
    MHPatchCore,
    PatchCoreAlgorithm,
    PatchCoreAggregation,
    PatchCoreFeatureReduction,
    PatchCoreDistance
)
from mhpc.core.backbones import BackboneName

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize model
model = MHPatchCore(
    backbone=BackboneName.wideresnet50,
    embedding_layers=["layer2", "layer3"],
    device=device,
    input_shape=(224, 224),
    pretrain_embed_dimension=1024,
    target_embed_dimension=1024,
    patchsize=3,
    algorithm=PatchCoreAlgorithm.VANILLA,
    aggregation=PatchCoreAggregation.GREEDY,
    feature_reduction=PatchCoreFeatureReduction.NONE,
    distance=PatchCoreDistance.EUCLIDEAN,
    coreset_percentage=0.1
)

# Fit
# train_dataloader = ...
# model.fit(train_dataloader)

# Predict
# test_dataloader = ...
# scores, masks, labels_gt, masks_gt = model.predict(test_dataloader)

Ablation Study Guide

To perform an ablation study, you can iterate through combinations of the following parameters:

  1. Vanilla PatchCore (Baseline):

    • algorithm=VANILLA
    • aggregation=GREEDY
    • feature_reduction=NONE
    • distance=EUCLIDEAN
  2. Mahalanobis PatchCore (Proposed):

    • algorithm=ONLINE (or VANILLA)
    • aggregation=KMEANS
    • feature_reduction=PCA
    • distance=MAHALANOBIS
  3. Variations:

    • Effect of PCA: Compare feature_reduction=NONE vs PCA (with pca_variance_ratio=0.99).
    • Effect of Clustering: Compare aggregation=GREEDY vs KMEANS vs NONE.
    • Effect of Distance: Compare distance=EUCLIDEAN vs MAHALANOBIS.
    • Effect of Online Processing: Compare algorithm=VANILLA vs ONLINE with same settings (e.g., KMEANS + PCA).

Example Iteration:

configs = [
    {"alg": "VANILLA", "agg": "GREEDY", "red": "NONE", "dist": "EUCLIDEAN"}, # Baseline
    {"alg": "ONLINE", "agg": "KMEANS", "red": "PCA", "dist": "MAHALANOBIS"}, # Proposed
    {"alg": "VANILLA", "agg": "KMEANS", "red": "PCA", "dist": "MAHALANOBIS"}, # Batch Proposed
    # ... add more combinations
]

for cfg in configs:
    model = MHPatchCore(..., 
        algorithm=PatchCoreAlgorithm[cfg["alg"]],
        aggregation=PatchCoreAggregation[cfg["agg"]],
        feature_reduction=PatchCoreFeatureReduction[cfg["red"]],
        distance=PatchCoreDistance[cfg["dist"]]
    )
    model.fit(train_loader)
    # evaluate...
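
For the "# evaluate..." step, image-level AUROC is a common choice; a minimal sketch of the loop body, assuming the predict() outputs shown in the Usage section and scikit-learn's roc_auc_score (test_loader is a placeholder):

from sklearn.metrics import roc_auc_score

# inside the loop, after model.fit(train_loader):
scores, masks, labels_gt, masks_gt = model.predict(test_loader)
print(cfg, "image AUROC:", roc_auc_score(labels_gt, scores))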

The MH-PatchCore Idea (ONLINE Mode)

The ONLINE mode implements a memory-efficient streaming approach that processes training data incrementally:

Complete Flow:

  1. Pass 1 - PCA Training: Fit IncrementalPCA online by streaming through data batch-by-batch.
  2. Pass 2 - Covariance Computation: Compute global covariance matrix Σ and Cholesky decomposition L (where Σ = L·L^T) using OnlineCovariance.
  3. Pass 3 - Z-space Clustering:
    • Transform each batch to Z-space using z = L^-1(x - μ).
    • Incrementally fit MiniBatchKMeans using partial_fit.
    • Collect cluster centroids.
  4. Memory Bank Creation: Build memory bank using the centroids from step 3.
  5. Inference:
    • Input sample → PCA transform → whiten to Z-space via z = L^-1(x - μ) → nearest-neighbor search against the centroid memory bank in Z-space.
    • Key insight: Euclidean KNN in Z-space is equivalent to Mahalanobis distance in the original space.

This approach never loads all training data into memory simultaneously, making it suitable for large datasets.
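
A condensed, self-contained sketch of the three passes (random features stand in for backbone patch features; the dimensions, cluster count, and batches() helper are illustrative, not the repository's actual API):

import torch
from sklearn.decomposition import IncrementalPCA
from sklearn.cluster import MiniBatchKMeans

def batches():  # placeholder: re-iterable stream of [B, D] feature batches
    gen = torch.Generator().manual_seed(0)
    for _ in range(10):
        yield torch.randn(256, 512, generator=gen)

# Pass 1: fit PCA incrementally
pca = IncrementalPCA(n_components=128)
for x in batches():
    pca.partial_fit(x.numpy())

# Pass 2: streaming mean and covariance in PCA space
n, s, ss = 0, torch.zeros(128), torch.zeros(128, 128)
for x in batches():
    y = torch.from_numpy(pca.transform(x.numpy()))
    n += y.shape[0]; s += y.sum(0); ss += y.T @ y
mu = s / n
cov = (ss - n * torch.outer(mu, mu)) / (n - 1)
Lmat = torch.linalg.cholesky(cov + 1e-6 * torch.eye(128))  # cov = L @ L.T

# Pass 3: cluster in whitened Z-space, z = L^-1 (y - mu)
km = MiniBatchKMeans(n_clusters=64)
for x in batches():
    y = torch.from_numpy(pca.transform(x.numpy()))
    z = torch.linalg.solve_triangular(Lmat, (y - mu).T, upper=False).T
    km.partial_fit(z.numpy())

centroids = km.cluster_centers_  # Z-space memory bank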

Low-Level Utilities

The repository also contains low-level PyTorch components for Mahalanobis distance:

  • OnlineCovariance: a batched/streaming covariance estimator.
  • CovarianceCholesky: a small wrapper around torch.linalg.cholesky with optional diagonal jitter for numerical stability.

Online covariance (batched/streaming)

import torch
from mhpc.mahalanobis import OnlineCovariance

cov = OnlineCovariance(num_features=128, dtype=torch.float64)
for batch in data_loader:
    cov.update(batch)  # batch shape: [batch, features...]

cov_matrix = cov.covariance(unbiased=True)

Cholesky from covariance

import torch
from mhpc.mahalanobis import CovarianceCholesky

cov = torch.randn(64, 64, dtype=torch.float64)
cov = cov @ cov.t() + 1e-3 * torch.eye(64, dtype=torch.float64)

chol = CovarianceCholesky(jitter=0.0)
L = chol.compute(cov)
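
The factor L can then be used to whiten feature vectors, which is how the Z-space transform above is realized. A short continuation of the block above (x and mu are placeholder tensors):

x = torch.randn(64, dtype=torch.float64)   # a feature vector
mu = torch.zeros(64, dtype=torch.float64)  # the feature mean
z = torch.linalg.solve_triangular(L, (x - mu).unsqueeze(1), upper=False).squeeze(1)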

Tests

Pytest is configured to collect tests only from the tests/ directory.

pytest
