# 📒 Scanner ID Pipeline with Flatfield Fingerprints
This notebook extracts **scanner fingerprints** from flat-field images, compares document noise residuals against them using correlation and FFT features, and trains ML models (SVM, Random Forest, XGBoost, CNN) on the result.

In [None]:
import os, glob, random
import numpy as np
import matplotlib.pyplot as plt

from typing import Dict, List, Tuple
from collections import Counter

from skimage import io, color
from skimage.util import img_as_float32
from skimage.restoration import denoise_wavelet

from numpy.fft import fft2, fftshift

from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, ConfusionMatrixDisplay

import joblib

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split


In [None]:
# ==== EDIT THESE PATHS FOR YOUR MACHINE ====
FLATROOT = r"D:\scanner_id_pipeline\data\flatfields"
OFFICIALROOT = r"D:\scanner_id_pipeline\data\Official"
WIKIROOT = r"D:\scanner_id_pipeline\data\Wikipedia"


def check_path(p):
    """Print ``p`` followed by ``OK`` if it exists on disk, else ``MISSING``."""
    status = "OK" if os.path.exists(p) else "MISSING"
    print(p, "=>", status)


# Report the state of every configured data root before doing any work.
print("Checking your paths...")
for _root in (FLATROOT, OFFICIALROOT, WIKIROOT):
    check_path(_root)


In [None]:
def load_image_gray(path: str) -> np.ndarray:
    """Read the image at ``path`` and return it as a float32 grayscale array.

    Three-dimensional inputs are treated as colour images: a fourth (alpha)
    channel is dropped and the remainder goes through ``color.rgb2gray``.
    Every other input is cast to float32 and, for integer dtypes, divided by
    the dtype's maximum value.
    """
    raw = io.imread(path)
    if raw.ndim != 3:
        # Integer data is scaled by the dtype maximum; other dtypes keep their range.
        is_integer = np.issubdtype(raw.dtype, np.integer)
        scale = np.iinfo(raw.dtype).max if is_integer else 1.0
        return img_as_float32(raw.astype(np.float32) / scale)

    rgb = raw[..., :3] if raw.shape[-1] == 4 else raw
    return img_as_float32(color.rgb2gray(rgb))

def normalize_image(img: np.ndarray) -> np.ndarray:
    """Standardise ``img``: subtract its mean, divide by its std plus 1e-8."""
    spread = np.std(img) + 1e-8  # epsilon keeps constant images from dividing by zero
    centred = img - np.mean(img)
    return centred / spread

def residual_wavelet(img: np.ndarray) -> np.ndarray:
    """Return the float32 noise residual: ``img`` minus its wavelet-denoised version.

    Denoising uses ``denoise_wavelet`` with BayesShrink soft thresholding.
    """
    smooth = denoise_wavelet(img, method='BayesShrink', mode='soft', rescale_sigma=True)
    return (img - smooth).astype(np.float32)


In [None]:
def list_images(root: str, exts=('.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp')) -> List[str]:
    """Return the sorted paths under ``root`` (recursively) whose name ends in one of ``exts``.

    The suffix comparison is case-insensitive, so ``scan.TIF`` and ``scan.Jpg``
    are found on case-sensitive filesystems too. The tree is walked once
    rather than once per extension.

    Args:
        root: Directory to search. A missing directory yields an empty list.
        exts: Filename suffixes to accept (leading dot included).

    Returns:
        Sorted list of matching paths.
    """
    suffixes = tuple(ext.lower() for ext in exts)
    # A single '*' glob keeps glob's usual rule of skipping dot-files.
    candidates = glob.glob(os.path.join(root, '**', '*'), recursive=True)
    return sorted(p for p in candidates if p.lower().endswith(suffixes))

def group_by_scanner(root: str) -> Dict[str, List[str]]:
    """Group every image under ``root`` by the first path component below ``root``.

    That first component is used as the scanner name; the values are the
    image paths in the order ``list_images`` returns them.
    """
    groups: Dict[str, List[str]] = {}
    for img_path in list_images(root):
        top_level, *_ = os.path.relpath(img_path, root).split(os.sep)
        if top_level not in groups:
            groups[top_level] = []
        groups[top_level].append(img_path)
    return groups


## 🔑 Step 1: Build Flatfield Fingerprints

In [None]:
def build_scanner_fingerprints(flatroot: str) -> Dict[str, np.ndarray]:
    """Average the wavelet residuals of each scanner's flatfields into a fingerprint.

    Residuals are accumulated as running sums keyed by image shape, so memory
    use no longer grows with the number of flatfields and a scanner whose
    flatfields come in several sizes no longer crashes the averaging. When
    shapes differ, the most frequent shape is used and the rest are ignored.

    Args:
        flatroot: Directory whose first-level sub-folders are scanner names.

    Returns:
        Mapping of scanner name to its float32 fingerprint array.
    """
    by_scanner = group_by_scanner(flatroot)
    fingerprints = {}
    for lab, paths in by_scanner.items():
        # shape -> [running float64 sum, number of images with that shape]
        totals = {}
        for p in paths:
            res = residual_wavelet(normalize_image(load_image_gray(p)))
            if res.shape in totals:
                totals[res.shape][0] += res
                totals[res.shape][1] += 1
            else:
                totals[res.shape] = [res.astype(np.float64), 1]
        if not totals:
            continue

        shape, (acc, n_used) = max(totals.items(), key=lambda item: item[1][1])
        if len(totals) > 1:
            print(f"Warning: {lab} has flatfields in {len(totals)} different shapes; "
                  f"using the {n_used} with shape {shape}.")
        fingerprints[lab] = (acc / n_used).astype(np.float32)
        print(f"Built fingerprint for {lab}, using {n_used} flatfield images.")
    return fingerprints

# Build one fingerprint per scanner folder found under FLATROOT.
fingerprints = build_scanner_fingerprints(FLATROOT)


## 🔑 Step 2: Extract Document Residuals and Correlate with Fingerprints

In [None]:
def corr2(a, b):
    """Return the normalised cross-correlation of ``a`` and ``b``.

    Both inputs are mean-centred first; 1e-8 is added to the denominator so
    constant inputs give 0 instead of a division by zero.
    """
    da = a - np.mean(a)
    db = b - np.mean(b)
    norm = np.sqrt(np.sum(da * da)) * np.sqrt(np.sum(db * db))
    return np.sum(da * db) / (norm + 1e-8)

def correlation_features(residual: np.ndarray, fingerprints: Dict[str, np.ndarray], label_names: List[str]) -> np.ndarray:
    """Correlate ``residual`` with each scanner fingerprint, in ``label_names`` order.

    When a fingerprint and the residual differ in shape, both are cropped to
    their common top-left region before correlating instead of raising a
    broadcast error.
    NOTE(review): this assumes scans share the same origin so that top-left
    cropping keeps the sensor pattern aligned; confirm for your data.

    Args:
        residual: 2-D noise residual of a document.
        fingerprints: Scanner name -> 2-D fingerprint.
        label_names: Scanner names that define the feature order.

    Returns:
        float32 vector with one correlation per entry of ``label_names``.
    """
    feats = []
    for lab in label_names:
        fp = fingerprints[lab]
        if fp.shape == residual.shape:
            feats.append(corr2(residual, fp))
            continue
        rows = min(residual.shape[0], fp.shape[0])
        cols = min(residual.shape[1], fp.shape[1])
        feats.append(corr2(residual[:rows, :cols], fp[:rows, :cols]))
    return np.array(feats, dtype=np.float32)

def fft_radial_stats(patch: np.ndarray, n_bins: int = 16) -> np.ndarray:
    """Summarise the 2-D power spectrum of ``patch`` as ``n_bins`` radial features.

    The centred power spectrum is split into equal-width rings of normalised
    radius. Each ring's mean power is log1p-compressed, and the resulting
    vector is standardised to zero mean and unit spread.
    """
    power = np.abs(fftshift(fft2(patch))) ** 2
    rows, cols = power.shape
    yy, xx = np.indices(power.shape)
    dist = np.sqrt((yy - rows // 2) ** 2 + (xx - cols // 2) ** 2)
    dist_unit = dist / (dist.max() + 1e-8)

    edges = np.linspace(0, 1.0, n_bins + 1)
    ring_means = []
    for lo, hi in zip(edges[:-1], edges[1:]):
        ring = (dist_unit >= lo) & (dist_unit < hi)
        # Rings containing no pixels contribute a zero feature.
        ring_means.append(power[ring].mean() if np.any(ring) else 0.0)

    log_feats = np.log1p(np.array(ring_means, dtype=np.float32))
    return (log_feats - log_feats.mean()) / (log_feats.std() + 1e-8)


In [None]:
def build_document_dataset(docroot: str, fingerprints: Dict[str, np.ndarray], max_docs: int = None):
    """Turn every document under ``docroot`` into a feature vector and class index.

    Each feature vector is the correlation with every fingerprint followed by
    the radial FFT statistics of the document's noise residual. Scanner
    folders that have no fingerprint are skipped with a message; previously
    they raised ``ValueError`` from ``label_names.index``.

    Args:
        docroot: Directory whose first-level sub-folders are scanner names.
        fingerprints: Scanner name -> fingerprint array.
        max_docs: Per-scanner cap on documents; ``None`` or 0 means no cap.

    Returns:
        ``(X, y, label_names)``: float32 features, int64 class indices into
        ``label_names``, and the sorted fingerprint names.
    """
    by_scanner = group_by_scanner(docroot)
    label_names = sorted(fingerprints.keys())
    label_index = {name: i for i, name in enumerate(label_names)}
    X_list, y_list = [], []
    for lab, paths in by_scanner.items():
        if lab not in label_index:
            print(f"Skipping {lab}: no flatfield fingerprint available "
                  f"({len(paths)} documents ignored).")
            continue
        for p in paths[:max_docs or None]:
            res = residual_wavelet(normalize_image(load_image_gray(p)))
            # feature = correlation with fingerprints + FFT features
            feats = np.concatenate([correlation_features(res, fingerprints, label_names),
                                    fft_radial_stats(res)], axis=0)
            X_list.append(feats)
            y_list.append(label_index[lab])
    return np.array(X_list, dtype=np.float32), np.array(y_list, dtype=np.int64), label_names

# Feature matrices for the two document sources, both scored against the same fingerprints.
X_off, y_off, labels = build_document_dataset(OFFICIALROOT, fingerprints)
X_wiki, y_wiki, _ = build_document_dataset(WIKIROOT, fingerprints)

# Stack both sources into one dataset. Both calls derive their label order from
# sorted(fingerprints), so the class indices in y_off and y_wiki agree.
X = np.concatenate([X_off, X_wiki], axis=0)
y = np.concatenate([y_off, y_wiki], axis=0)
print("Dataset:", X.shape, "labels:", labels)


## 🔑 Step 3: Train ML Classifier on Features

In [None]:
def make_svm_pipeline(C: float = 10.0, gamma: str | float = 'scale') -> Pipeline:
    """Return a ``StandardScaler`` -> RBF ``SVC`` pipeline.

    The classifier has probability estimates enabled and uses balanced class
    weights.
    """
    classifier = SVC(C=C, kernel='rbf', gamma=gamma, probability=True, class_weight='balanced')
    steps = [('scaler', StandardScaler()), ('clf', classifier)]
    return Pipeline(steps)

def train_and_eval(X, y, labels, test_size=0.25):
    """Fit the SVM pipeline on a stratified split and report hold-out metrics.

    The report and confusion matrix are given the full list of class indices,
    so they still work when some scanner in ``labels`` has no samples.
    Without that, scikit-learn raises because ``target_names`` and the
    classes found in the data differ in length.

    Args:
        X: Feature matrix.
        y: Integer class indices into ``labels``.
        labels: Class names, indexed by the values in ``y``.
        test_size: Fraction of samples held out for testing.

    Returns:
        The fitted pipeline.
    """
    Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=test_size, stratify=y, random_state=42)
    model = make_svm_pipeline()
    model.fit(Xtr, ytr)
    yhat_te = model.predict(Xte)

    class_ids = list(range(len(labels)))
    print("Test Accuracy:", accuracy_score(yte, yhat_te))
    print(classification_report(yte, yhat_te, labels=class_ids,
                                target_names=labels, zero_division=0))
    cm = confusion_matrix(yte, yhat_te, labels=class_ids)
    ConfusionMatrixDisplay(cm, display_labels=labels).plot(xticks_rotation=45)
    plt.show()
    return model

# Train and evaluate the SVM on the combined Official + Wikipedia feature set.
svm_model = train_and_eval(X, y, labels)


## 🔑 Step 4: Train with Random Forest and XGBoost

In [None]:
from sklearn.ensemble import RandomForestClassifier
import xgboost as xgb

def train_random_forest(X, y, labels, test_size=0.25):
    """Fit a 200-tree balanced Random Forest on a stratified split and report metrics.

    The report is given the full list of class indices so it does not raise
    when a scanner in ``labels`` has no samples.

    Args:
        X: Feature matrix.
        y: Integer class indices into ``labels``.
        labels: Class names, indexed by the values in ``y``.
        test_size: Fraction of samples held out for testing.

    Returns:
        The fitted ``RandomForestClassifier``.
    """
    Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=test_size, stratify=y, random_state=42)
    model = RandomForestClassifier(n_estimators=200, random_state=42, class_weight='balanced')
    model.fit(Xtr, ytr)
    yhat = model.predict(Xte)

    class_ids = list(range(len(labels)))
    print("Random Forest Test Accuracy:", accuracy_score(yte, yhat))
    print(classification_report(yte, yhat, labels=class_ids,
                                target_names=labels, zero_division=0))
    return model

def train_xgboost(X, y, labels, test_size=0.25):
    """Fit an XGBoost multi-class model on a stratified split and report metrics.

    The report is given the full list of class indices so it does not raise
    when a scanner in ``labels`` has no samples.
    NOTE(review): the model is built with ``num_class=len(labels)``; whether
    XGBoost itself accepts a ``y`` that skips some indices should be
    confirmed for the installed version.

    Args:
        X: Feature matrix.
        y: Integer class indices into ``labels``.
        labels: Class names, indexed by the values in ``y``.
        test_size: Fraction of samples held out for testing.

    Returns:
        The fitted ``xgb.XGBClassifier``.
    """
    Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=test_size, stratify=y, random_state=42)
    model = xgb.XGBClassifier(
        n_estimators=300, max_depth=6, learning_rate=0.1, subsample=0.8, colsample_bytree=0.8,
        objective='multi:softmax', num_class=len(labels), random_state=42, n_jobs=-1
    )
    model.fit(Xtr, ytr)
    yhat = model.predict(Xte)

    class_ids = list(range(len(labels)))
    print("XGBoost Test Accuracy:", accuracy_score(yte, yhat))
    print(classification_report(yte, yhat, labels=class_ids,
                                target_names=labels, zero_division=0))
    return model

# Train both tree-based models on the same feature set used for the SVM.
rf_model = train_random_forest(X, y, labels)
xgb_model = train_xgboost(X, y, labels)

# Persist each model together with the label names needed to decode its predictions.
# NOTE(review): svm_model is not saved in this cell — confirm whether that is intended.
os.makedirs("artifacts", exist_ok=True)
joblib.dump({'model': rf_model, 'labels': labels}, "artifacts/random_forest.joblib")
joblib.dump({'model': xgb_model, 'labels': labels}, "artifacts/xgboost.joblib")
print("Saved Random Forest and XGBoost models.")


## 🔑 Step 5: CNN Classifier (ResNet18 on Residual Patches)

In [None]:
from torch.utils.data import Dataset, DataLoader, random_split
import torch
from torch import nn
import torchvision.models as models
import numpy as np

# ---- Dataset wrapper for patches ----
class PatchDataset(Dataset):
    """Torch dataset over a stack of 2-D patches and their integer class labels.

    Patches are stored as float32 and labels as int64. Each item is a
    ``(1, H, W)`` tensor paired with a scalar label tensor.
    """

    def __init__(self, patches: np.ndarray, labels: np.ndarray):
        self.X = patches.astype(np.float32)
        self.y = labels.astype(np.int64)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # Prepend the single grayscale channel axis expected by Conv2d.
        patch = np.expand_dims(self.X[idx], axis=0)
        return torch.from_numpy(patch), torch.tensor(self.y[idx])

# ---- ResNet18 modified for grayscale ----
class ResNetScanner(nn.Module):
    """ResNet18 adapted to single-channel input and ``n_classes`` outputs.

    The network starts from the default torchvision weights. Its first
    convolution is replaced by a freshly initialised 1-channel layer (the
    pretrained first-layer weights are not reused) and its final linear layer
    by one sized for ``n_classes``.
    """

    def __init__(self, n_classes: int):
        super().__init__()
        backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        backbone.fc = nn.Linear(backbone.fc.in_features, n_classes)
        self.resnet = backbone

    def forward(self, x):
        return self.resnet(x)

# ---- Patch extraction ----
def extract_patches(img: np.ndarray, patch: int = 128, stride: int = 128, min_margin: int = 16):
    """Cut ``img`` into square ``patch`` x ``patch`` tiles on a regular grid.

    Tiles start ``min_margin`` pixels from the top/left edge, advance by
    ``stride``, and must end at least ``min_margin`` pixels from the
    bottom/right edge. The result has shape ``(n, patch, patch)``, with
    ``n == 0`` when the image is too small for a single tile.
    """
    rows, cols = img.shape
    tops = range(min_margin, rows - patch - min_margin + 1, stride)
    lefts = range(min_margin, cols - patch - min_margin + 1, stride)
    tiles = [img[t:t + patch, l:l + patch] for t in tops for l in lefts]
    if not tiles:
        return np.empty((0, patch, patch), dtype=img.dtype)
    return np.stack(tiles, axis=0)

def build_patch_dataset(docroot: str, max_docs: int = None, label_names: List[str] = None):
    """Collect 128x128 residual patches (stride 256) for every document under ``docroot``.

    Changes from the earlier behaviour:
      * returns empty, correctly shaped arrays instead of crashing in
        ``np.stack`` when no patch could be extracted;
      * accepts an optional ``label_names`` so that several roots can share
        one label -> index mapping (by default indices come from the sorted
        scanner folders of this root only, which may differ between roots);
      * uses the enumerate index instead of an O(n) ``list.index`` per patch.

    Args:
        docroot: Directory whose first-level sub-folders are scanner names.
        max_docs: Per-scanner cap on documents; ``None`` or 0 means no cap.
        label_names: Optional fixed class order. Scanner folders not listed
            are ignored; listed scanners with no folder contribute nothing.

    Returns:
        ``(patches, y, labels)``: the patch stack, the class index of each
        patch, and the label list those indices refer to.
    """
    by_scanner = group_by_scanner(docroot)
    labels = sorted(by_scanner.keys()) if label_names is None else list(label_names)
    X_list, y_list = [], []
    for class_id, lab in enumerate(labels):
        for p in by_scanner.get(lab, [])[:max_docs or None]:
            res = residual_wavelet(normalize_image(load_image_gray(p)))
            patches = extract_patches(res, patch=128, stride=256)
            X_list.extend(patches)
            y_list.extend([class_id] * len(patches))
    if not X_list:
        return (np.empty((0, 128, 128), dtype=np.float32),
                np.empty((0,), dtype=np.int64),
                labels)
    return np.stack(X_list), np.array(y_list), labels

# Build patch dataset from both sources
# NOTE(review): each call derives class indices from the sorted scanner folders of
# its own root and the returned label lists are discarded here, so the indices in
# y_patches_off and y_patches_wiki only agree if both roots contain exactly the
# same scanner folders — confirm before trusting the concatenated labels.
X_patches_off, y_patches_off, _ = build_patch_dataset(OFFICIALROOT)
X_patches_wiki, y_patches_wiki, _ = build_patch_dataset(WIKIROOT)
X_patches = np.concatenate([X_patches_off, X_patches_wiki], axis=0)
y_patches = np.concatenate([y_patches_off, y_patches_wiki], axis=0)
print("Patch dataset:", X_patches.shape, "labels:", len(set(y_patches)))


## 🔑 Step 6: Train ResNet18

In [None]:
def train_resnet(patches, labels, epochs=10, batch=64, lr=1e-4):
    """Train a ``ResNetScanner`` on ``patches`` and print validation accuracy per epoch.

    20% of the patches (at least one) are held out at random for validation.
    The network has ``labels.max() + 1`` outputs and is optimised with Adam
    and cross-entropy loss. It runs on CUDA when available. The model is
    returned as it stands after the final epoch.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    full_ds = PatchDataset(patches, labels)
    val_size = max(1, int(0.2 * len(full_ds)))
    train_part, val_part = random_split(full_ds, [len(full_ds) - val_size, val_size])
    train_loader = DataLoader(train_part, batch_size=batch, shuffle=True)
    val_loader = DataLoader(val_part, batch_size=batch)

    n_classes = int(labels.max()+1)
    net = ResNetScanner(n_classes).to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(1, epochs + 1):
        # One optimisation pass over the training split.
        net.train()
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            loss_fn(net(inputs), targets).backward()
            optimizer.step()

        # Accuracy on the held-out split.
        net.eval()
        hits = seen = 0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                hits += (net(inputs).argmax(1) == targets).sum().item()
                seen += len(targets)
        print(f"Epoch {epoch}/{epochs}, Val Acc: {hits/max(1,seen):.3f}")

    return net

# Train the CNN on the combined patch set.
cnn_model = train_resnet(X_patches, y_patches, epochs=10, batch=64, lr=1e-4)

# Save only the weights. The "artifacts" directory is created by the os.makedirs
# call in the Step 4 cell, so that cell must have run first.
torch.save({'state_dict': cnn_model.state_dict()}, "artifacts/cnn_resnet18.pt")
print("Saved CNN model to artifacts/cnn_resnet18.pt")


## 🔑 Step 7: Final CNN Evaluation (Overall Accuracy)

In [None]:
from sklearn.metrics import classification_report, accuracy_score

def evaluate_resnet(model, patches, labels):
    """Run ``model`` over ``patches`` and print accuracy plus a classification report.

    Batches are sent to the device the model's parameters live on. Choosing
    the device from CUDA availability alone raised a device-mismatch error
    whenever the model was on the CPU of a CUDA-capable machine.

    Args:
        model: Trained network taking ``(N, 1, H, W)`` float tensors.
        patches: Patch stack accepted by ``PatchDataset``.
        labels: Integer class index per patch.

    Returns:
        Overall accuracy as returned by ``accuracy_score``.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    ds = PatchDataset(patches, labels)
    dl = DataLoader(ds, batch_size=64)
    model.eval()
    preds, trues = [], []
    with torch.no_grad():
        for xb, yb in dl:
            out = model(xb.to(device)).argmax(1).cpu().numpy()
            preds.extend(out)
            trues.extend(yb.numpy())
    acc = accuracy_score(trues, preds)
    print("Final CNN Overall Accuracy:", acc)
    print(classification_report(trues, preds))
    return acc

# Evaluate on all patches used
# NOTE(review): X_patches is the same array train_resnet split into train and
# validation parts, so this figure includes patches the model was trained on and
# should not be read as a held-out accuracy.
evaluate_resnet(cnn_model, X_patches, y_patches)


In [None]:
# 🔍 Scanner ID Pipeline with Flatfield Fingerprints - IMPROVED VERSION
# This notebook extracts scanner fingerprints from flat-field images, then compares 
# document noise patterns against them using correlation, FFT features, and trains ML models (SVM/CNN).

import os
import glob
import random
import warnings
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple, Optional, Union
from collections import Counter
import logging

# Image processing
from skimage import io, color
from skimage.util import img_as_float32
from skimage.restoration import denoise_wavelet

# FFT
from numpy.fft import fft2, fftshift

# Machine Learning
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, ConfusionMatrixDisplay
import joblib

# Deep Learning
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models

# XGBoost
try:
    import xgboost as xgb
    XGBOOST_AVAILABLE = True
except ImportError:
    XGBOOST_AVAILABLE = False
    print("Warning: XGBoost not available. Install with: pip install xgboost")

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore', category=UserWarning)

# ==== CONFIGURATION ====
class Config:
    """Central place for the paths and tunable parameters used by the pipeline."""

    # --- Data locations: EDIT THESE PATHS FOR YOUR MACHINE ---
    FLATROOT = r"D:\scanner_id_pipeline\data\flatfields"
    OFFICIALROOT = r"D:\scanner_id_pipeline\data\Official"
    WIKIROOT = r"D:\scanner_id_pipeline\data\Wikipedia"
    ARTIFACTS_DIR = "artifacts"

    # --- Feature and patch extraction ---
    MAX_DOCS_PER_SCANNER = None  # None means every document is used
    PATCH_SIZE, PATCH_STRIDE, MIN_MARGIN = 128, 256, 16
    FFT_BINS = 16

    # --- Classical ML ---
    TEST_SIZE = 0.25
    RANDOM_STATE = 42

    # --- CNN training ---
    CNN_EPOCHS = 15  # raised from the earlier value of 10
    CNN_BATCH_SIZE = 64
    CNN_LEARNING_RATE = 1e-4

    # --- Memory management ---
    MAX_PATCHES_IN_MEMORY = 10000  # patch budget used to avoid memory problems


config = Config()

def check_paths():
    """Verify the three configured data roots exist and create the artifacts folder.

    Every path is logged with its status. If any is missing a
    ``FileNotFoundError`` is raised after all of them have been reported.
    """
    required = (
        ("Flatfields", config.FLATROOT),
        ("Official Documents", config.OFFICIALROOT),
        ("Wikipedia Documents", config.WIKIROOT),
    )

    logger.info("Checking paths...")
    missing = 0
    for name, path in required:
        if os.path.exists(path):
            status = "✓ OK"
        else:
            status = "✗ MISSING"
            missing += 1
        logger.info(f"{name}: {path} => {status}")

    if missing:
        raise FileNotFoundError("Some required paths are missing. Please update Config class.")

    # Create artifacts directory
    Path(config.ARTIFACTS_DIR).mkdir(exist_ok=True)
    logger.info(f"Artifacts directory: {config.ARTIFACTS_DIR}")

# ==== IMAGE PROCESSING FUNCTIONS ====
def load_image_gray(path: str) -> Optional[np.ndarray]:
    """Load ``path`` as a float32 grayscale array, or return ``None`` on any failure.

    A missing file, an array that is neither 2-D nor 3-D, or any exception
    raised while reading or converting is logged and results in ``None``.
    """
    try:
        if not os.path.exists(path):
            logger.warning(f"File not found: {path}")
            return None

        raw = io.imread(path)
        if raw.ndim not in (2, 3):
            logger.warning(f"Unexpected image dimensions: {raw.shape} for {path}")
            return None

        if raw.ndim == 3:
            # Colour input: drop an alpha channel if present, then convert.
            rgb = raw[..., :3] if raw.shape[-1] == 4 else raw
            gray = color.rgb2gray(rgb)
        elif np.issubdtype(raw.dtype, np.integer):
            # Integer grayscale is scaled by the dtype maximum.
            gray = raw.astype(np.float32) / np.iinfo(raw.dtype).max
        else:
            gray = raw.astype(np.float32)

        return img_as_float32(gray)

    except Exception as e:
        logger.error(f"Failed to load image {path}: {str(e)}")
        return None

def normalize_image(img: np.ndarray) -> np.ndarray:
    """Return ``img`` shifted to zero mean and scaled by its standard deviation.

    ``None`` is passed straight through. A standard deviation below 1e-8 is
    replaced by 1e-8 (with a warning) to avoid dividing by zero.
    """
    if img is None:
        return None

    centre = np.mean(img)
    scale = np.std(img)
    if scale < 1e-8:
        logger.warning("Image has very low variance, normalization might be unstable")
        scale = 1e-8
    shifted = img - centre
    return shifted / scale

def residual_wavelet(img: np.ndarray) -> Optional[np.ndarray]:
    """Return ``img`` minus its wavelet-denoised version as float32.

    ``None`` input yields ``None``. Any exception during denoising is logged
    and also yields ``None``.
    """
    if img is None:
        return None
    try:
        smooth = denoise_wavelet(img, method='BayesShrink', mode='soft', rescale_sigma=True)
        return (img - smooth).astype(np.float32)
    except Exception as e:
        logger.error(f"Wavelet denoising failed: {str(e)}")
        return None

# ==== FILE DISCOVERY FUNCTIONS ====
def list_images(root: str, extensions=('.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp')) -> List[str]:
    """Return the sorted image paths found recursively under ``root``.

    The tree is walked once and suffixes are compared case-insensitively.
    The previous lower-plus-upper glob missed mixed-case names such as
    ``scan.Jpg`` and traversed the tree twice per extension.

    Args:
        root: Directory to search. A missing directory is logged and yields
            an empty list.
        extensions: Filename suffixes to accept (leading dot included).

    Returns:
        Sorted list of matching paths; empty if listing fails.
    """
    if not os.path.exists(root):
        logger.warning(f"Directory does not exist: {root}")
        return []

    suffixes = tuple(ext.lower() for ext in extensions)
    try:
        # A single '*' glob keeps glob's usual rule of skipping dot-files.
        candidates = glob.glob(os.path.join(root, '**', '*'), recursive=True)
        files = sorted(p for p in candidates if p.lower().endswith(suffixes))
        logger.info(f"Found {len(files)} image files in {root}")
        return files

    except Exception as e:
        logger.error(f"Error listing images in {root}: {str(e)}")
        return []

def group_by_scanner(root: str) -> Dict[str, List[str]]:
    """Map each first-level folder name under ``root`` to the image paths inside it.

    Paths whose relative location cannot be computed are logged and skipped.
    The number of files per scanner is logged before returning.
    """
    groups: Dict[str, List[str]] = {}

    for img_path in list_images(root):
        try:
            scanner = os.path.relpath(img_path, root).split(os.sep)[0]
        except Exception as e:
            logger.warning(f"Could not process path {img_path}: {str(e)}")
            continue
        groups.setdefault(scanner, []).append(img_path)

    # Log scanner distribution
    for scanner, files in groups.items():
        logger.info(f"Scanner '{scanner}': {len(files)} files")

    return groups

# ==== FINGERPRINT BUILDING ====
def build_scanner_fingerprints(flatroot: str) -> Dict[str, np.ndarray]:
    """Average each scanner's flatfield residuals into a float32 fingerprint.

    Images that fail to load or denoise are counted and skipped. If a
    scanner's residuals differ in shape, only those with the most common
    shape are averaged. Scanners with no usable residual are logged as errors
    and left out of the result.
    """
    by_scanner = group_by_scanner(flatroot)
    fingerprints: Dict[str, np.ndarray] = {}

    for scanner_name, paths in by_scanner.items():
        logger.info(f"Processing scanner: {scanner_name}")
        residuals = []
        failed_count = 0

        for path in paths:
            # Each stage returns None on failure; stop at the first one that does.
            stage = load_image_gray(path)
            if stage is not None:
                stage = normalize_image(stage)
            if stage is not None:
                stage = residual_wavelet(stage)
            if stage is None:
                failed_count += 1
            else:
                residuals.append(stage)

        if not residuals:
            logger.error(f"No flatfield images found for scanner {scanner_name}")
            continue

        # Averaging needs one common shape; keep the most frequent one.
        shapes = [r.shape for r in residuals]
        distinct = set(shapes)
        if len(distinct) > 1:
            logger.warning(f"Scanner {scanner_name} has images with different shapes: {distinct}")
            common_shape = Counter(shapes).most_common(1)[0][0]
            residuals = [r for r in residuals if r.shape == common_shape]
            logger.info(f"Using {len(residuals)} images with shape {common_shape}")

        if not residuals:
            logger.error(f"No valid residuals for scanner {scanner_name}")
            continue

        fingerprints[scanner_name] = np.mean(residuals, axis=0).astype(np.float32)
        logger.info(f"Built fingerprint for {scanner_name} using {len(residuals)} images")
        if failed_count > 0:
            logger.warning(f"Failed to process {failed_count} images for {scanner_name}")

    return fingerprints

# ==== FEATURE EXTRACTION ====
def correlation_coefficient(a: np.ndarray, b: np.ndarray) -> float:
    """Return the normalised correlation coefficient of ``a`` and ``b``.

    Best-effort by design: inputs that cannot be correlated (mismatched
    shapes, non-numeric data), a near-zero denominator, or a non-finite
    result all yield 0.0. Only ``ValueError`` and ``TypeError`` are caught, so
    unrelated failures (and ``KeyboardInterrupt``) are no longer swallowed as
    they were by the previous bare ``except``.

    Args:
        a: First array.
        b: Second array, same shape as ``a``.

    Returns:
        Correlation as a Python float, or 0.0 when it is undefined.
    """
    try:
        a_centered = a - np.mean(a)
        b_centered = b - np.mean(b)
        numerator = np.sum(a_centered * b_centered)
        denominator = np.sqrt(np.sum(a_centered**2)) * np.sqrt(np.sum(b_centered**2))
    except (ValueError, TypeError):
        return 0.0

    if not np.isfinite(denominator) or denominator < 1e-10:
        return 0.0

    result = float(numerator / denominator)
    # NaN/inf must not leak into the feature matrix.
    return result if np.isfinite(result) else 0.0

def correlation_features(residual: np.ndarray, fingerprints: Dict[str, np.ndarray], 
                        label_names: List[str]) -> np.ndarray:
    """Correlate ``residual`` with every fingerprint, in ``label_names`` order.

    Labels without a fingerprint score 0.0. When shapes differ, both arrays
    are cropped to their common top-left region before correlating.
    """
    scores = []
    for label in label_names:
        if label not in fingerprints:
            scores.append(0.0)
            continue

        fp = fingerprints[label]
        if residual.shape == fp.shape:
            scores.append(correlation_coefficient(residual, fp))
            continue

        # Shapes differ: compare only the overlapping top-left window.
        rows = min(residual.shape[0], fp.shape[0])
        cols = min(residual.shape[1], fp.shape[1])
        scores.append(correlation_coefficient(residual[:rows, :cols], fp[:rows, :cols]))

    return np.array(scores, dtype=np.float32)

def fft_radial_stats(patch: np.ndarray, n_bins: int = None) -> np.ndarray:
    """Summarise the Hann-windowed power spectrum of ``patch`` as radial features.

    The centred power spectrum is split into ``n_bins`` equal-width rings of
    normalised radius. Each non-empty ring contributes ``log1p(mean power)``
    and empty rings contribute 0. The vector is standardised unless its
    spread is negligible.

    The outermost ring includes its upper edge. The pixel at exactly the
    maximum radius (the corner holding the Nyquist/Nyquist component) sits at
    normalised radius 1.0 and was previously excluded from every ring.

    Args:
        patch: 2-D array.
        n_bins: Number of radial rings; defaults to ``config.FFT_BINS``.

    Returns:
        float32 vector of length ``n_bins``; all zeros if extraction fails.
    """
    if n_bins is None:
        n_bins = config.FFT_BINS

    try:
        # Apply window to reduce spectral leakage
        window = np.outer(np.hanning(patch.shape[0]), np.hanning(patch.shape[1]))
        power_spectrum = np.abs(fftshift(fft2(patch * window))) ** 2

        H, W = power_spectrum.shape
        center_y, center_x = H // 2, W // 2

        y, x = np.indices(power_spectrum.shape)
        radius = np.sqrt((y - center_y)**2 + (x - center_x)**2)
        max_radius = np.sqrt(center_y**2 + center_x**2)

        if max_radius < 1e-6:
            return np.zeros(n_bins, dtype=np.float32)

        radius_normalized = radius / max_radius
        bins = np.linspace(0, 1.0, n_bins + 1)

        features = []
        for i in range(n_bins):
            lower, upper = bins[i], bins[i + 1]
            if i == n_bins - 1:
                # Close the last ring on the right so radius == max_radius is counted.
                mask = (radius_normalized >= lower) & (radius_normalized <= upper)
            else:
                mask = (radius_normalized >= lower) & (radius_normalized < upper)
            if np.any(mask):
                features.append(np.log1p(np.mean(power_spectrum[mask])))
            else:
                features.append(0.0)

        features = np.array(features, dtype=np.float32)

        # Normalize features
        std_feat = np.std(features)
        if std_feat > 1e-8:
            features = (features - np.mean(features)) / std_feat

        return features

    except Exception as e:
        logger.warning(f"FFT feature extraction failed: {str(e)}")
        return np.zeros(n_bins, dtype=np.float32)

# ==== DATASET BUILDING ====
def build_document_dataset(docroot: str, fingerprints: Dict[str, np.ndarray], 
                          max_docs: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray, List[str]]:
    """Build the (features, class index, label names) dataset for ``docroot``.

    Only scanner folders that have a fingerprint are processed. Each document
    contributes its fingerprint correlations followed by its radial FFT
    statistics. Documents that fail to load or denoise are skipped. Empty
    arrays and an empty label list are returned when nothing usable is found.
    """
    by_scanner = group_by_scanner(docroot)
    label_names = sorted(fingerprints.keys())

    # Filter to only scanners that have fingerprints
    by_scanner = {name: files for name, files in by_scanner.items() if name in fingerprints}
    if not by_scanner:
        logger.error("No scanners found with matching fingerprints!")
        return np.array([]), np.array([]), []

    X_list, y_list = [], []

    for scanner_name, scanner_files in by_scanner.items():
        files_to_process = scanner_files[:max_docs] if max_docs else scanner_files
        n_files = len(files_to_process)
        logger.info(f"Processing {n_files} documents for scanner {scanner_name}")

        for i, path in enumerate(files_to_process):
            if i % 50 == 0:  # Progress logging
                logger.info(f"  Progress: {i}/{n_files}")

            img = load_image_gray(path)
            if img is None:
                continue
            residual = residual_wavelet(normalize_image(img))
            if residual is None:
                continue

            # Correlation features first, then the FFT statistics.
            X_list.append(np.concatenate([
                correlation_features(residual, fingerprints, label_names),
                fft_radial_stats(residual),
            ]))
            y_list.append(label_names.index(scanner_name))

    if not X_list:
        logger.error("No valid features extracted!")
        return np.array([]), np.array([]), []

    X = np.array(X_list, dtype=np.float32)
    y = np.array(y_list, dtype=np.int64)

    logger.info(f"Dataset built: {X.shape[0]} samples, {X.shape[1]} features, {len(label_names)} classes")
    return X, y, label_names

# ==== MACHINE LEARNING MODELS ====
def create_svm_pipeline(C: float = 10.0, gamma: Union[str, float] = 'scale') -> Pipeline:
    """Return a ``StandardScaler`` -> RBF ``SVC`` pipeline.

    The classifier has probability estimates enabled, uses balanced class
    weights, and is seeded with ``config.RANDOM_STATE``.
    """
    svm = SVC(
        C=C,
        kernel='rbf',
        gamma=gamma,
        probability=True,
        class_weight='balanced',
        random_state=config.RANDOM_STATE,
    )
    steps = [('scaler', StandardScaler()), ('svm', svm)]
    return Pipeline(steps)

def train_and_evaluate_model(X: np.ndarray, y: np.ndarray, labels: List[str], 
                           model_name: str, model, test_size: float = None) -> object:
    """Fit ``model`` on a stratified train/test split and report hold-out metrics.

    A single split is used; no cross-validation is performed.

    Fixes over the previous version:
      * the report and confusion matrix are given the full list of class
        indices, so they no longer raise when a scanner in ``labels`` has no
        samples;
      * the confusion matrix is drawn on the 10x8 axes created here. Before,
        ``plt.figure`` opened a blank figure and
        ``ConfusionMatrixDisplay.plot`` drew on a second, default-sized one.

    Args:
        X: Feature matrix.
        y: Integer class indices into ``labels``.
        labels: Class names, indexed by the values in ``y``.
        model_name: Name used in log lines and the plot title.
        model: Estimator with ``fit`` and ``predict``.
        test_size: Hold-out fraction; defaults to ``config.TEST_SIZE``.

    Returns:
        The fitted ``model``.
    """
    if test_size is None:
        test_size = config.TEST_SIZE

    logger.info(f"Training {model_name}...")

    # Stratified split keeps the class proportions in both parts.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, stratify=y, random_state=config.RANDOM_STATE
    )

    model.fit(X_train, y_train)

    y_pred = model.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    class_ids = list(range(len(labels)))

    logger.info(f"{model_name} Test Accuracy: {accuracy:.4f}")
    print(f"\n=== {model_name} Results ===")
    print(f"Test Accuracy: {accuracy:.4f}")
    print("\nClassification Report:")
    print(classification_report(y_test, y_pred, labels=class_ids,
                                target_names=labels, zero_division=0))

    # Confusion matrix, drawn on explicitly created axes.
    cm = confusion_matrix(y_test, y_pred, labels=class_ids)
    fig, ax = plt.subplots(figsize=(10, 8))
    ConfusionMatrixDisplay(cm, display_labels=labels).plot(ax=ax, xticks_rotation=45)
    ax.set_title(f"{model_name} Confusion Matrix")
    fig.tight_layout()
    plt.show()

    return model

def train_random_forest(X: np.ndarray, y: np.ndarray, labels: List[str]) -> RandomForestClassifier:
    """Build a balanced 200-tree Random Forest and hand it to ``train_and_evaluate_model``."""
    forest_params = dict(
        n_estimators=200,
        max_depth=15,
        min_samples_split=5,
        min_samples_leaf=2,
        random_state=config.RANDOM_STATE,
        class_weight='balanced',
        n_jobs=-1,
    )
    forest = RandomForestClassifier(**forest_params)
    return train_and_evaluate_model(X, y, labels, "Random Forest", forest)

def train_xgboost(X: np.ndarray, y: np.ndarray, labels: List[str]) -> object:
    """Train an XGBoost multi-class model via ``train_and_evaluate_model``.

    Returns ``None`` (after a warning) when XGBoost could not be imported.
    """
    if XGBOOST_AVAILABLE:
        booster_params = {
            'n_estimators': 300,
            'max_depth': 6,
            'learning_rate': 0.1,
            'subsample': 0.8,
            'colsample_bytree': 0.8,
            'objective': 'multi:softmax',
            'num_class': len(labels),
            'random_state': config.RANDOM_STATE,
            'n_jobs': -1,
        }
        booster = xgb.XGBClassifier(**booster_params)
        return train_and_evaluate_model(X, y, labels, "XGBoost", booster)

    logger.warning("XGBoost not available, skipping...")
    return None

# ==== DEEP LEARNING COMPONENTS ====
class PatchDataset(Dataset):
    """Torch dataset pairing 2-D patches (float32) with integer class labels (int64)."""

    def __init__(self, patches: np.ndarray, labels: np.ndarray):
        self.patches = patches.astype(np.float32)
        self.labels = labels.astype(np.int64)

    def __len__(self) -> int:
        return len(self.patches)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # A leading channel axis turns the (H, W) patch into (1, H, W).
        with_channel = np.expand_dims(self.patches[idx], axis=0)
        target = torch.tensor(self.labels[idx], dtype=torch.long)
        return torch.from_numpy(with_channel), target

class ResNetScanner(nn.Module):
    """ResNet18 modified for grayscale scanner identification.

    Changes relative to stock torchvision ResNet18: a 1-channel first
    convolution, a final linear layer sized for ``n_classes``, and dropout
    (p=0.3) on the pooled features before that layer.

    The forward pass now applies ``self.resnet.maxpool`` after the stem. The
    previous hand-written forward skipped it, which silently changed the
    architecture (4x larger feature maps in every residual stage) although
    nothing documents that as intended. ``maxpool`` has no parameters, so
    saved state dicts stay loadable.
    NOTE(review): if dropping the max-pool was deliberate for 128x128
    patches, revert this change and say so in a comment.
    """

    def __init__(self, n_classes: int):
        super().__init__()
        self.resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

        # Modify first conv layer for grayscale input
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

        # Modify final layer
        in_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(in_features, n_classes)

        # Add dropout for regularization
        self.dropout = nn.Dropout(0.3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        net = self.resnet
        # Stem: conv -> batch norm -> ReLU -> max-pool, as in torchvision's ResNet.
        x = net.maxpool(net.relu(net.bn1(net.conv1(x))))
        for stage in (net.layer1, net.layer2, net.layer3, net.layer4):
            x = stage(x)
        features = torch.flatten(net.avgpool(x), 1)
        return net.fc(self.dropout(features))

def extract_patches_from_image(img: np.ndarray, patch_size: int = None, 
                              stride: int = None, min_margin: int = None) -> np.ndarray:
    """Extract square patches from a 2-D image on a regular grid.

    Patch origins start at ``min_margin`` and advance by ``stride`` for as
    long as a full ``patch_size`` patch plus ``min_margin`` still fits. When
    the grid holds more patches than the per-image budget
    (``config.MAX_PATCHES_IN_MEMORY // 100``, but never less than 1), a subset
    of grid rows and columns is sampled instead. Sampling uses a private RNG
    seeded with ``config.RANDOM_STATE``, so the selection is reproducible and
    the global ``random`` module state is left untouched.

    Args:
        img: 2-D array to cut patches from.
        patch_size: Side length of each patch; ``None`` uses ``config.PATCH_SIZE``.
        stride: Step between patch origins; ``None`` uses ``config.PATCH_STRIDE``.
        min_margin: Border kept free of patches; ``None`` uses ``config.MIN_MARGIN``.

    Returns:
        Array of shape ``(n, patch_size, patch_size)`` with ``img``'s dtype;
        ``n`` is 0 when no full patch fits.
    """
    if patch_size is None:
        patch_size = config.PATCH_SIZE
    if stride is None:
        stride = config.PATCH_STRIDE
    if min_margin is None:
        min_margin = config.MIN_MARGIN

    height, width = img.shape

    # Conservative per-image budget. Clamped to 1 so that a small global cap
    # cannot make the row count below zero and trigger a division by zero.
    max_patches_per_image = max(1, config.MAX_PATCHES_IN_MEMORY // 100)

    y_positions = list(range(min_margin, height - patch_size - min_margin + 1, stride))
    x_positions = list(range(min_margin, width - patch_size - min_margin + 1, stride))

    if len(y_positions) * len(x_positions) > max_patches_per_image:
        # Thin the grid: keep roughly sqrt(budget) rows and as many columns as
        # the remaining budget allows. A local Random instance yields the same
        # draws as seeding the module-level RNG would, without clobbering it.
        rng = random.Random(config.RANDOM_STATE)
        n_y = min(len(y_positions), int(np.sqrt(max_patches_per_image)))
        n_x = min(len(x_positions), max_patches_per_image // n_y)

        y_positions = sorted(rng.sample(y_positions, n_y))
        x_positions = sorted(rng.sample(x_positions, n_x))

    patches = []
    for y in y_positions:
        for x in x_positions:
            patch = img[y:y + patch_size, x:x + patch_size]
            # Guard against truncated slices (e.g. from a negative margin).
            if patch.shape == (patch_size, patch_size):
                patches.append(patch)

    if not patches:
        return np.empty((0, patch_size, patch_size), dtype=img.dtype)
    return np.stack(patches)

def build_patch_dataset_memory_efficient(docroot: str, max_docs: Optional[int] = None,
                                         max_patches_per_image: int = 50) -> Tuple[np.ndarray, np.ndarray, List[str]]:
    """Build a labelled patch dataset from the documents under ``docroot``.

    Files are grouped per scanner via ``group_by_scanner``; each file is
    loaded, normalised, reduced to its wavelet residual and cut into patches.
    Class indices are positions in the sorted list of scanner names found
    under ``docroot``, so they are only comparable between two calls when both
    roots contain the same scanners.

    Args:
        docroot: Root directory handed to ``group_by_scanner``.
        max_docs: Maximum number of files per scanner. A falsy value
            (``None`` or 0) means no limit.
        max_patches_per_image: Maximum number of patches kept from a single
            image; a random subset is drawn when more are available.

    Returns:
        ``(X_patches, y_patches, labels)``. When nothing could be extracted
        the result is two empty arrays and an empty label list.

    Raises:
        ValueError: If ``max_patches_per_image`` is smaller than 1.
    """
    if max_patches_per_image < 1:
        raise ValueError(f"max_patches_per_image must be >= 1, got {max_patches_per_image}")

    by_scanner = group_by_scanner(docroot)
    labels = sorted(by_scanner.keys())
    patch_cap = config.MAX_PATCHES_IN_MEMORY

    all_patches = []
    all_labels = []
    patch_count = 0

    for scanner_idx, scanner_name in enumerate(labels):
        files = by_scanner[scanner_name][:max_docs] if max_docs else by_scanner[scanner_name]
        logger.info(f"Extracting patches from {len(files)} documents for scanner {scanner_name}")

        for i, path in enumerate(files):
            if patch_count >= patch_cap:
                logger.warning(f"Reached maximum patch limit ({config.MAX_PATCHES_IN_MEMORY}), stopping...")
                break

            # Each stage may signal failure with None; skip the file then.
            img = load_image_gray(path)
            if img is None:
                continue

            img_norm = normalize_image(img)
            if img_norm is None:
                continue

            residual = residual_wavelet(img_norm)
            if residual is None:
                continue

            patches = extract_patches_from_image(residual)

            if len(patches) > 0:
                # Honour both the per-image limit and what is left of the
                # global budget, so the overall cap is never overshot.
                # The remaining budget is >= 1 here because of the check at
                # the top of the loop.
                keep = min(max_patches_per_image, patch_cap - patch_count)
                if len(patches) > keep:
                    indices = np.random.choice(len(patches), keep, replace=False)
                    patches = patches[indices]

                all_patches.append(patches)
                all_labels.extend([scanner_idx] * len(patches))
                patch_count += len(patches)

            if i % 10 == 0:
                logger.info(f"  Processed {i+1}/{len(files)} files, {patch_count} patches so far")

        # Propagate the inner "cap reached" stop to the scanner loop.
        if patch_count >= patch_cap:
            break

    if not all_patches:
        logger.error("No patches extracted!")
        return np.array([]), np.array([]), []

    X_patches = np.concatenate(all_patches, axis=0)
    y_patches = np.array(all_labels, dtype=np.int64)

    logger.info(f"Patch dataset: {X_patches.shape} patches from {len(labels)} scanners")
    return X_patches, y_patches, labels

def train_cnn_model(patches: np.ndarray, labels: np.ndarray, n_classes: int, 
                   epochs: int = None, batch_size: int = None, lr: float = None,
                   num_workers: int = 0) -> ResNetScanner:
    """Train a ``ResNetScanner`` on patches and return it with its best weights.

    The data is split 80/20 at random (seeded with ``config.RANDOM_STATE``;
    the split is not stratified by class). After every epoch the model is
    scored on the validation part; whenever the score improves, a checkpoint
    is written to ``best_cnn_model.pt`` in ``config.ARTIFACTS_DIR``. Before
    returning, the best-scoring weights are loaded back into the model so the
    returned object matches the checkpoint on disk. Loss and accuracy curves
    are saved to ``training_curves.png`` in the same directory and shown.

    Args:
        patches: Patch array accepted by ``PatchDataset``.
        labels: Integer class index per patch.
        n_classes: Number of output classes.
        epochs: Number of epochs; ``None`` uses ``config.CNN_EPOCHS``.
        batch_size: Batch size; ``None`` uses ``config.CNN_BATCH_SIZE``.
        lr: Adam learning rate; ``None`` uses ``config.CNN_LEARNING_RATE``.
        num_workers: DataLoader worker processes. Defaults to 0 because the
            data already lives in memory and worker processes are known to
            fail for datasets defined inside a notebook on platforms that
            spawn workers (e.g. Windows).

    Returns:
        The trained model, in eval mode, on the training device.

    Raises:
        ValueError: If fewer than two patches are supplied, since a
            train/validation split is then impossible.
    """
    if epochs is None:
        epochs = config.CNN_EPOCHS
    if batch_size is None:
        batch_size = config.CNN_BATCH_SIZE
    if lr is None:
        lr = config.CNN_LEARNING_RATE

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Training CNN on device: {device}")

    dataset = PatchDataset(patches, labels)
    if len(dataset) < 2:
        raise ValueError(
            f"Need at least 2 patches for a train/validation split, got {len(dataset)}"
        )

    # Plain random split (NOT stratified), seeded for reproducibility.
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size

    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size],
        generator=torch.Generator().manual_seed(config.RANDOM_STATE)
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers)

    model = ResNetScanner(n_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    os.makedirs(config.ARTIFACTS_DIR, exist_ok=True)
    checkpoint_path = os.path.join(config.ARTIFACTS_DIR, 'best_cnn_model.pt')

    best_val_acc = 0.0
    best_state = None  # in-memory copy of the best weights seen so far
    train_losses = []
    val_accuracies = []

    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        epoch_loss = 0.0
        num_batches = 0

        for data, target in train_loader:
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        avg_train_loss = epoch_loss / num_batches
        train_losses.append(avg_train_loss)

        # ---- validation pass ----
        model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                pred = model(data).argmax(dim=1)
                correct += pred.eq(target).sum().item()
                total += target.size(0)

        val_acc = correct / total
        val_accuracies.append(val_acc)

        # Checkpoint on improvement. The first epoch always counts, so a
        # checkpoint exists even if validation accuracy never rises above 0.
        if best_state is None or val_acc > best_val_acc:
            best_val_acc = val_acc
            best_state = {name: tensor.detach().clone()
                          for name, tensor in model.state_dict().items()}
            torch.save({
                'epoch': epoch,
                'model_state_dict': best_state,
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
            }, checkpoint_path)

        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']

        logger.info(f"Epoch {epoch+1}/{epochs}: "
                   f"Train Loss: {avg_train_loss:.4f}, "
                   f"Val Acc: {val_acc:.4f}, "
                   f"LR: {current_lr:.6f}")

    # Hand back the weights that were checkpointed, not the last epoch's.
    if best_state is not None:
        model.load_state_dict(best_state)
    model.eval()

    # Training curves; the figure is closed afterwards so repeated calls do
    # not accumulate open figures.
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

    loss_ax.plot(train_losses)
    loss_ax.set_title('Training Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')

    acc_ax.plot(val_accuracies)
    acc_ax.set_title('Validation Accuracy')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy')

    fig.tight_layout()
    fig.savefig(os.path.join(config.ARTIFACTS_DIR, 'training_curves.png'))
    plt.show()
    plt.close(fig)

    logger.info(f"Best validation accuracy: {best_val_acc:.4f}")
    return model

def evaluate_cnn_final(model: ResNetScanner, patches: np.ndarray, labels: np.ndarray, 
                      label_names: List[str]) -> float:
    """Score a trained CNN on a set of patches and report per-class results.

    Prints the overall patch accuracy and a classification report, and saves
    and shows a confusion matrix (``cnn_confusion_matrix.png`` in
    ``config.ARTIFACTS_DIR``). The report and matrix always cover class ids
    ``0 .. len(label_names) - 1``, so a class that is absent from both the
    targets and the predictions no longer breaks the report. Targets or
    predictions outside that range are ignored by the report and matrix but
    still count towards the returned accuracy.

    Args:
        model: Trained model; it is evaluated on whatever device it is on.
        patches: Patch array accepted by ``PatchDataset``.
        labels: Integer class index per patch.
        label_names: Class names, indexed by class id.

    Returns:
        Fraction of patches classified correctly.
    """
    # Follow the model's own device instead of guessing: a CPU-resident model
    # on a CUDA-capable machine would otherwise receive CUDA inputs.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    model.eval()

    dataset = PatchDataset(patches, labels)
    dataloader = DataLoader(dataset, batch_size=config.CNN_BATCH_SIZE, shuffle=False)

    all_preds = []
    all_targets = []

    with torch.no_grad():
        for data, target in dataloader:
            output = model(data.to(device))
            all_preds.extend(output.argmax(dim=1).cpu().tolist())
            all_targets.extend(target.tolist())

    accuracy = accuracy_score(all_targets, all_preds)
    class_ids = list(range(len(label_names)))

    print(f"\n=== Final CNN Evaluation ===")
    print(f"Overall Patch Accuracy: {accuracy:.4f}")
    print("\nDetailed Classification Report:")
    print(classification_report(all_targets, all_preds, labels=class_ids,
                                target_names=label_names, zero_division=0))

    # Draw into an explicitly created figure: ConfusionMatrixDisplay.plot()
    # opens its own figure when no axes are given, which would leave a blank
    # extra figure and ignore the requested size.
    cm = confusion_matrix(all_targets, all_preds, labels=class_ids)
    fig, ax = plt.subplots(figsize=(10, 8))
    ConfusionMatrixDisplay(cm, display_labels=label_names).plot(ax=ax, xticks_rotation=45)
    ax.set_title('CNN Final Confusion Matrix')
    fig.tight_layout()
    fig.savefig(os.path.join(config.ARTIFACTS_DIR, 'cnn_confusion_matrix.png'))
    plt.show()
    plt.close(fig)

    return accuracy

# ==== SAVE/LOAD FUNCTIONS ====
def save_models_and_data(models_dict: Dict, fingerprints: Dict, labels: List[str]):
    """Persist every trained model plus the shared metadata as joblib files.

    Each non-``None`` entry of ``models_dict`` is written to its own file in
    ``config.ARTIFACTS_DIR``, named after the lower-cased model name with
    spaces replaced by underscores. Fingerprints, labels and a snapshot of
    selected config values go into ``metadata.joblib``.
    """
    artifacts_dir = config.ARTIFACTS_DIR

    # One file per model; entries whose value is None are skipped.
    for model_name, model in models_dict.items():
        if model is None:
            continue
        stem = model_name.lower().replace(' ', '_')
        filename = os.path.join(artifacts_dir, f"{stem}.joblib")
        joblib.dump({'model': model, 'labels': labels}, filename)
        logger.info(f"Saved {model_name} to {filename}")

    # Settings stored alongside the fingerprints.
    settings = {
        'patch_size': config.PATCH_SIZE,
        'patch_stride': config.PATCH_STRIDE,
        'fft_bins': config.FFT_BINS,
        'random_state': config.RANDOM_STATE,
    }
    metadata = {
        'fingerprints': fingerprints,
        'labels': labels,
        'config': settings,
    }

    metadata_file = os.path.join(artifacts_dir, 'metadata.joblib')
    joblib.dump(metadata, metadata_file)
    logger.info(f"Saved metadata to {metadata_file}")

def load_model(model_path: str):
    """Read a model bundle from disk and return ``(model, labels)``.

    Any failure is logged and reported as ``(None, None)`` instead of being
    raised.
    """
    # NOTE: joblib files are pickle-based; only load artifacts you trust.
    try:
        bundle = joblib.load(model_path)
        model, labels = bundle['model'], bundle['labels']
    except Exception as e:
        logger.error(f"Failed to load model from {model_path}: {str(e)}")
        return None, None
    return model, labels

# ==== MAIN PIPELINE EXECUTION ====
def _concat_nonempty(first, second):
    """Concatenate two arrays along axis 0, or return whichever one is non-empty."""
    if len(first) > 0 and len(second) > 0:
        return np.concatenate([first, second], axis=0)
    return first if len(first) > 0 else second


def run_complete_pipeline():
    """Run the whole scanner-identification pipeline end to end.

    Stages, in order: path check, fingerprint building, feature dataset
    building (Official + Wikipedia), classical model training (SVM, Random
    Forest, and XGBoost when available), patch dataset building, CNN training
    and evaluation, and artifact saving. Returns ``None``; returns early when
    no fingerprints or no feature data can be built. Any exception is logged
    and re-raised.
    """
    print("🔍 Starting Scanner Identification Pipeline")
    print("=" * 50)

    try:
        check_paths()

        # Step 1: scanner fingerprints
        print("\n📋 Step 1: Building Scanner Fingerprints")
        fingerprints = build_scanner_fingerprints(config.FLATROOT)

        if not fingerprints:
            logger.error("No fingerprints could be built. Check your flatfield data.")
            return

        print(f"✓ Built fingerprints for {len(fingerprints)} scanners")

        # Step 2: feature-based dataset
        print("\n📊 Step 2: Building Feature Dataset")
        print("Processing Official documents...")
        X_official, y_official, labels = build_document_dataset(
            config.OFFICIALROOT, fingerprints, config.MAX_DOCS_PER_SCANNER
        )

        print("Processing Wikipedia documents...")
        X_wiki, y_wiki, wiki_labels = build_document_dataset(
            config.WIKIROOT, fingerprints, config.MAX_DOCS_PER_SCANNER
        )

        if len(X_official) == 0 and len(X_wiki) == 0:
            logger.error("No feature data could be built. Check your document data.")
            return

        if len(X_official) == 0:
            # Only Wikipedia produced samples, so its label list is the one
            # the surviving targets index into.
            labels = wiki_labels
        elif len(X_wiki) > 0 and list(labels) != list(wiki_labels):
            logger.warning(
                "Official and Wikipedia datasets report different label lists; "
                "class indices in the combined feature dataset may not line up."
            )

        X_combined = _concat_nonempty(X_official, X_wiki)
        y_combined = _concat_nonempty(y_official, y_wiki)

        print(f"✓ Combined dataset: {X_combined.shape[0]} samples, {len(labels)} classes")

        # Step 3: classical ML models
        print("\n🤖 Step 3: Training Traditional ML Models")
        trained_models = {}

        svm_model = create_svm_pipeline()
        trained_models['SVM'] = train_and_evaluate_model(X_combined, y_combined, labels, "SVM", svm_model)

        trained_models['Random Forest'] = train_random_forest(X_combined, y_combined, labels)

        if XGBOOST_AVAILABLE:
            trained_models['XGBoost'] = train_xgboost(X_combined, y_combined, labels)

        # Step 4: patch dataset for the CNN
        print("\n🧠 Step 4: Building Patch Dataset for CNN")
        print("Processing Official documents for patches...")
        X_patches_off, y_patches_off, patch_labels = build_patch_dataset_memory_efficient(
            config.OFFICIALROOT, config.MAX_DOCS_PER_SCANNER
        )

        print("Processing Wikipedia documents for patches...")
        X_patches_wiki, y_patches_wiki, wiki_patch_labels = build_patch_dataset_memory_efficient(
            config.WIKIROOT, config.MAX_DOCS_PER_SCANNER
        )

        X_patches = None
        cnn_model = None

        if len(X_patches_off) == 0 and len(X_patches_wiki) == 0:
            logger.warning("No patch data could be built. Skipping CNN training.")
        else:
            if (len(X_patches_off) > 0 and len(X_patches_wiki) > 0
                    and list(patch_labels) != list(wiki_patch_labels)):
                logger.warning(
                    "Official and Wikipedia patch datasets report different label lists; "
                    "class indices in the combined patch dataset may not line up."
                )

            X_patches = _concat_nonempty(X_patches_off, X_patches_wiki)
            y_patches = _concat_nonempty(y_patches_off, y_patches_wiki)

            print(f"✓ Patch dataset: {X_patches.shape[0]} patches")

            # Step 5: CNN training
            print("\n🔥 Step 5: Training CNN (ResNet18)")
            cnn_model = train_cnn_model(X_patches, y_patches, len(labels))

            # NOTE: this scores the model on the same patches it was trained
            # and validated on, so it is not a held-out estimate.
            evaluate_cnn_final(cnn_model, X_patches, y_patches, labels)

        # Step 6: persist everything
        print("\n💾 Step 6: Saving Models and Results")
        save_models_and_data(trained_models, fingerprints, labels)

        if cnn_model is not None:
            # The CNN checkpoint is written during training.
            logger.info("CNN model saved as best_cnn_model.pt")

        n_patches = X_patches.shape[0] if X_patches is not None else 0
        n_trained = sum(m is not None for m in trained_models.values()) + int(cnn_model is not None)

        print("\n🎉 Pipeline completed successfully!")
        print(f"📁 All artifacts saved in: {config.ARTIFACTS_DIR}/")
        print("\nSummary:")
        print(f"  - Scanner fingerprints: {len(fingerprints)} scanners")
        print(f"  - Feature dataset: {X_combined.shape[0]} samples")
        print(f"  - Patch dataset: {n_patches} patches")
        print(f"  - Trained models: {n_trained}")

    except Exception as e:
        logger.error(f"Pipeline failed: {str(e)}")
        raise

# ==== INFERENCE FUNCTIONS ====
def predict_document_scanner(image_path: str, model_path: str, fingerprints_path: str) -> Dict:
    """Predict which scanner produced a single document image.

    Loads the model bundle and the metadata file, computes the wavelet
    residual of the image, builds the feature vector (correlation features
    followed by FFT radial statistics) and classifies it.

    Args:
        image_path: Image to classify.
        model_path: Joblib bundle readable by ``load_model``.
        fingerprints_path: Joblib metadata file containing a
            ``'fingerprints'`` entry.

    Returns:
        On success, a dict with ``predicted_scanner``, ``confidence`` and
        ``all_probabilities``; the latter two are ``None`` when the model has
        no ``predict_proba``. On any failure, ``{"error": <message>}``; this
        function does not raise.
    """
    try:
        model, labels = load_model(model_path)
        if model is None:
            return {"error": "Failed to load model"}

        # NOTE: joblib files are pickle-based; only load artifacts you trust.
        metadata = joblib.load(fingerprints_path)
        fingerprints = metadata['fingerprints']

        img = load_image_gray(image_path)
        if img is None:
            return {"error": "Failed to load image"}

        residual = residual_wavelet(normalize_image(img))

        corr_feats = correlation_features(residual, fingerprints, labels)
        fft_feats = fft_radial_stats(residual)
        features = np.concatenate([corr_feats, fft_feats]).reshape(1, -1)

        predicted_class = int(model.predict(features)[0])

        confidence = None
        all_probabilities = None
        if hasattr(model, 'predict_proba'):
            probabilities = model.predict_proba(features)[0]
            # predict_proba columns follow model.classes_, which can be a
            # subset of the label ids when a class was missing at training
            # time; map columns through it rather than assuming column i is
            # class i.
            classes = getattr(model, 'classes_', None)
            if classes is None:
                classes = range(len(probabilities))
            by_class = {int(c): float(p) for c, p in zip(classes, probabilities)}
            confidence = by_class.get(predicted_class)
            all_probabilities = {labels[c]: p for c, p in by_class.items()}

        return {
            "predicted_scanner": labels[predicted_class],
            "confidence": confidence,
            "all_probabilities": all_probabilities,
        }

    except Exception as e:
        # Deliberate best-effort contract: report the failure to the caller
        # as data instead of raising.
        return {"error": str(e)}

def batch_predict_documents(image_folder: str, model_path: str, fingerprints_path: str) -> List[Dict]:
    """Run ``predict_document_scanner`` on every image found in a folder.

    Returns one result dict per image, in ``list_images`` order, each
    extended with an ``'image_path'`` entry. Progress is logged after every
    tenth image.
    """
    image_paths = list_images(image_folder)
    results = []

    for done, img_path in enumerate(image_paths, start=1):
        outcome = predict_document_scanner(img_path, model_path, fingerprints_path)
        outcome['image_path'] = img_path
        results.append(outcome)

        if done % 10 != 0:
            continue
        logger.info(f"Processed {len(results)}/{len(image_paths)} images")

    return results

# ==== JUPYTER NOTEBOOK EXECUTION ====
if __name__ == "__main__" or "__file__" not in globals():
    # Usage banner. Shown when run as a script, or whenever `__file__` is not
    # defined (presumably the notebook case - confirm for your kernel).
    print("\n".join((
        "🔍 Scanner Identification Pipeline - Improved Version",
        "=" * 60,
        "",
        "To run the complete pipeline, execute:",
        ">>> run_complete_pipeline()",
        "",
        "To predict a single document:",
        ">>> result = predict_document_scanner('path/to/image.jpg', 'artifacts/svm.joblib', 'artifacts/metadata.joblib')",
        "",
        "To run batch predictions:",
        ">>> results = batch_predict_documents('path/to/images/', 'artifacts/svm.joblib', 'artifacts/metadata.joblib')",
        "",
        "Configuration can be modified in the Config class above.",
        "Current settings:",
    )))
    # Active configuration values, one per line.
    print(f"  - Max documents per scanner: {config.MAX_DOCS_PER_SCANNER}")
    print(f"  - Patch size: {config.PATCH_SIZE}")
    print(f"  - CNN epochs: {config.CNN_EPOCHS}")
    print(f"  - Max patches in memory: {config.MAX_PATCHES_IN_MEMORY}")

# Uncomment the line below to run the pipeline automatically
# run_complete_pipeline()