In [25]:
import os
import argparse
import json
import math
import random
from pathlib import Path
from collections import OrderedDict
from types import SimpleNamespace
from itertools import combinations


import numpy as np
from PIL import Image
from tqdm import tqdm
import cv2
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D


import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, TensorDataset


from sklearn.metrics import roc_auc_score, accuracy_score, precision_recall_curve, confusion_matrix
from sklearn.model_selection import train_test_split
from scipy.ndimage import binary_fill_holes


--------------------------- Utilities ---------------------------

In [26]:
def mask_from_green_contour(annot_img, GREEN_THRESH):
    """Build a filled boolean region mask from a colour-annotated image.

    Despite the function name, the contour is located by thresholding for
    strongly red pixels (r > 150, g < 70, b < 70). ``GREEN_THRESH`` is kept
    in the signature but is not read anywhere in the body.

    Args:
        annot_img: 3-channel image; channels are unpacked as (r, g, b), so
            the caller is assumed to pass RGB-ordered data.
        GREEN_THRESH: unused.

    Returns:
        2-D boolean array that is True on the detected contour and in the
        holes it encloses.
    """
    red, green, blue = cv2.split(annot_img)
    contour = ((green < 70) & (blue < 70) & (red > 150)).astype(np.uint8)

    small_kernel = np.ones((3, 3), np.uint8)
    large_kernel = np.ones((5, 5), np.uint8)

    # Dilate then erode (a manual closing), followed by a larger closing, so
    # that small gaps in the drawn contour are bridged before hole filling.
    contour = cv2.dilate(contour, small_kernel, iterations=2)
    contour = cv2.erode(contour, small_kernel, iterations=2)
    contour = cv2.morphologyEx(contour, cv2.MORPH_CLOSE, large_kernel)

    filled = binary_fill_holes(contour).astype(np.uint8)

    # Expand to three identical channels and threshold the first one; the
    # result is simply the non-zero pixels of `filled` as booleans.
    three_channel = cv2.cvtColor(filled * 255, cv2.COLOR_GRAY2BGR)
    return three_channel[:, :, 0] > 0

def sample_patches(img, mask, patch_size, n_inside, n_outside, max_tries=500):
    """Randomly crop square patches that lie inside and outside a mask.

    Args:
        img: image array indexed as [row, col, ...], or None (e.g. a failed
            read), in which case nothing is sampled.
        mask: array with the same height/width as ``img`` whose mean over a
            window is the masked fraction, or None when there is no
            annotation (then only outside patches are produced).
        patch_size: side length of the square crops.
        n_inside: number of crops wanted with more than 80% mask coverage.
        n_outside: number of crops wanted with less than 5% mask coverage.
        max_tries: attempt budget, applied separately to each of the two
            sampling passes.

    Returns:
        ``(inside_patches, outside_patches)``: two lists of array views into
        ``img``. Either list may be shorter than requested if the attempt
        budget runs out, and both are empty when ``img`` is None or smaller
        than ``patch_size`` in either dimension.
    """
    inside_patches, outside_patches = [], []

    # The image must be validated before its shape is read; a None image or
    # one smaller than the patch cannot yield any crop.
    if img is None:
        return inside_patches, outside_patches
    H, W = img.shape[:2]
    if patch_size <= 0 or H < patch_size or W < patch_size:
        return inside_patches, outside_patches

    def _collect(region, n_wanted, accept):
        """Draw random windows until n_wanted pass `accept` or tries run out."""
        found = []
        tries = 0
        while len(found) < n_wanted and tries < max_tries:
            x = random.randint(0, W - patch_size)
            y = random.randint(0, H - patch_size)
            window = region[y:y + patch_size, x:x + patch_size]
            if accept(window):
                found.append(img[y:y + patch_size, x:x + patch_size])
            tries += 1
        return found

    if mask is not None:
        # require at least 80% inside mask
        inside_patches = _collect(mask, n_inside, lambda w: w.mean() > 0.8)
    else:
        # No annotation: treat the whole image as outside the mask.
        mask = np.zeros((H, W), dtype=np.uint8)

    # require <5% overlap with the masked region
    outside_patches = _collect(mask, n_outside, lambda w: w.mean() < 0.05)

    return inside_patches, outside_patches


def load_patches_from_folders(root_dir, max_per_class=None, gray=True, resize=None):
    """Load image patches from ``root_dir/inside`` and ``root_dir/outside``.

    Both sub-folders are created if they do not exist. Files are read in
    sorted name order and scaled to float32 in [0, 1] by dividing by 255.

    Args:
        root_dir: directory containing (or to contain) the two sub-folders.
        max_per_class: if truthy, keep at most this many files per folder.
        gray: convert to single-channel ('L') when True, otherwise to RGB.
        resize: optional ``(width, height)`` passed to ``PIL.Image.resize``.

    Returns:
        ``(Xin, Xout)`` float32 arrays stacked along axis 0. A folder with
        no matching files yields an array whose first dimension is 0.

    Raises:
        ValueError: from ``np.stack`` if images in one folder differ in shape.
    """
    inside_dir = Path(root_dir) / 'inside'
    outside_dir = Path(root_dir) / 'outside'
    os.makedirs(inside_dir, exist_ok=True)
    os.makedirs(outside_dir, exist_ok=True)

    extensions = ('.tif', '.tiff', '.jpg', '.jpeg')

    def empty_batch():
        # Shape of a zero-length batch; PIL sizes are (width, height) while
        # arrays are (height, width), hence the swap.
        if resize:
            width, height = resize
            shape = (0, height, width)
        else:
            shape = (0, 0, 0)
        if not gray:
            shape = shape + (3,)
        return np.zeros(shape, dtype=np.float32)

    def load_from(folder, limit):
        files = sorted(p for p in folder.iterdir() if p.suffix.lower() in extensions)
        if limit:
            files = files[:limit]
        imgs = []
        for p in files:
            # Context manager releases the file handle once pixels are copied.
            with Image.open(p) as handle:
                im = handle.convert('L' if gray else 'RGB')
                if resize:
                    im = im.resize(resize, Image.BILINEAR)
                imgs.append(np.array(im, dtype=np.float32) / 255.0)
        return np.stack(imgs, axis=0) if imgs else empty_batch()

    Xin = load_from(inside_dir, max_per_class)
    Xout = load_from(outside_dir, max_per_class)
    return Xin, Xout

def to_tensor_batch(X, device):
    """Wrap a NumPy batch as a tensor with a size-1 axis inserted at dim 1.

    For an (N, H, W) input this yields an (N, 1, H, W) tensor placed on
    ``device``.
    """
    as_tensor = torch.from_numpy(X)
    with_channel = torch.unsqueeze(as_tensor, 1)
    return with_channel.to(device)


--------------------------- Kernel generators ---------------------------

In [27]:
def make_gaussian_kernel(size, sigma_x, sigma_y=None, theta=0.0):
    """Return a rotated, unit-sum 2-D Gaussian kernel as float32.

    Args:
        size: odd side length of the square kernel.
        sigma_x: standard deviation along the rotated x axis.
        sigma_y: standard deviation along the rotated y axis; defaults to
            ``sigma_x`` (isotropic).
        theta: rotation angle in radians.

    Raises:
        AssertionError: if ``size`` is even.
    """
    assert size % 2 == 1, 'size should be odd'
    if sigma_y is None:
        sigma_y = sigma_x

    half = size // 2
    coords = np.arange(-half, half+1, 1)
    X, Y = np.meshgrid(coords, coords)

    # Express every grid point in the rotated coordinate frame.
    ct = math.cos(theta)
    st = math.sin(theta)
    Xr = ct * X + st * Y
    Yr = -st * X + ct * Y

    exponent = (Xr**2) / (sigma_x**2 + 1e-12) + (Yr**2) / (sigma_y**2 + 1e-12)
    G = np.exp(-0.5 * exponent)
    # Normalise so the weights sum to one (epsilon guards a zero sum).
    G = G / (G.sum() + 1e-12)
    return G.astype(np.float32)


def make_dog_kernel(size, sigma1, sigma2, theta=0.0):
    """Return a zero-mean difference-of-Gaussians kernel as float32.

    Two isotropic Gaussians (sigmas ``sigma1`` and ``sigma2``) built by
    ``make_gaussian_kernel`` are subtracted and the mean of the result is
    removed.
    """
    first = make_gaussian_kernel(size, sigma1, sigma1, theta)
    second = make_gaussian_kernel(size, sigma2, sigma2, theta)
    difference = first - second
    # Remove the residual mean so the kernel has zero DC response.
    centred = difference - difference.mean()
    return centred.astype(np.float32)


def make_log_kernel(size, sigma, theta=0.0):
    """Return a zero-mean Laplacian-of-Gaussian style kernel as float32.

    Args:
        size: odd side length of the square kernel.
        sigma: Gaussian scale.
        theta: rotation angle in radians applied to the sampling grid. The
            kernel depends on the grid only through the squared radius.

    Raises:
        AssertionError: if ``size`` is even.
    """
    assert size % 2 == 1
    half = size // 2
    coords = np.arange(-half, half+1, 1)
    X, Y = np.meshgrid(coords, coords)

    ct = math.cos(theta)
    st = math.sin(theta)
    Xr = ct * X + st * Y
    Yr = -st * X + ct * Y

    r2 = (Xr**2 + Yr**2)
    s2 = sigma**2
    envelope = np.exp(-r2/(2*s2))
    LoG = ((r2 - 2*s2) / (s2**2)) * envelope
    # Subtract the mean so the kernel sums to zero.
    return (LoG - LoG.mean()).astype(np.float32)


def make_gabor_kernel(size, sigma, freq, theta=0.0, phase=0.0, gamma=1.0):
    """Return a zero-mean, L1-normalised Gabor kernel as float32.

    Args:
        size: odd side length of the square kernel.
        sigma: scale of the Gaussian envelope.
        freq: frequency of the cosine carrier along the rotated x axis
            (cycles per pixel).
        theta: rotation angle in radians.
        phase: phase offset of the carrier in radians.
        gamma: aspect ratio; the rotated y coordinate is multiplied by it.

    Returns:
        ``(size, size)`` float32 array whose mean is zero and whose absolute
        values sum to one. An all-zero kernel is returned unscaled.

    Raises:
        AssertionError: if ``size`` is even.
    """
    assert size % 2 == 1
    half = size // 2
    xs = np.arange(-half, half+1, 1)
    ys = np.arange(-half, half+1, 1)
    X, Y = np.meshgrid(xs, ys)
    ct = math.cos(theta); st = math.sin(theta)
    Xr = ct * X + st * Y
    Yr = -st * X + ct * Y
    Yr = Yr * gamma
    gaussian = np.exp(-(Xr**2 + Yr**2) / (2 * (sigma**2)))
    sinusoid = np.cos(2 * np.pi * freq * Xr + phase)
    K = gaussian * sinusoid
    # zero mean
    K = K - K.mean()
    # Normalise by the L1 norm. The decision must be based on the L1 norm
    # itself: after mean subtraction the plain sum is only rounding noise,
    # so testing it would make the scaling depend on floating-point luck.
    l1_norm = np.abs(K).sum()
    if l1_norm > 0:
        K = K / (l1_norm + 1e-12)
    return K.astype(np.float32)


--------------------------- Sampling parameters ---------------------------

In [28]:
def sample_parameters(family, n_samples, size):
    """Draw ``n_samples`` random parameter dicts for one kernel family.

    Values come from the global ``np.random`` state. Supported families are
    'gaussian', 'anisotropic_gaussian', 'dog', 'log' and 'gabor'; each dict
    also carries the kernel ``size``.

    Raises:
        ValueError: for an unknown family (only when ``n_samples`` > 0,
            since the check happens per draw).
    """
    def log_uniform_sigma():
        # Log-uniform between 0.5 and half the kernel size.
        return float(10 ** np.random.uniform(np.log10(0.5), np.log10(size/2)))

    def uniform_sigma():
        return float(np.random.uniform(0.5, size/2))

    def random_angle():
        return np.random.uniform(0, math.pi)

    drawn = []
    for _ in range(n_samples):
        if family == 'gaussian':
            sigma = log_uniform_sigma()
            entry = {'sigma_x': sigma, 'sigma_y': sigma, 'theta': random_angle(), 'size': size}
        elif family == 'anisotropic_gaussian':
            sigma_x = log_uniform_sigma()
            sigma_y = float(sigma_x * np.random.uniform(0.5, 3.0))
            entry = {'sigma_x': sigma_x, 'sigma_y': sigma_y, 'theta': random_angle(), 'size': size}
        elif family == 'dog':
            s1 = uniform_sigma()
            s2 = float(s1 * np.random.uniform(1.2, 3.0))
            entry = {'sigma1': s1, 'sigma2': s2, 'theta': random_angle(), 'size': size}
        elif family == 'log':
            entry = {'sigma': uniform_sigma(), 'theta': random_angle(), 'size': size}
        elif family == 'gabor':
            # Dict values are evaluated left to right, which fixes the draw
            # order: sigma, freq, theta, phase, gamma.
            entry = {
                'sigma': uniform_sigma(),
                'freq': float(np.random.uniform(0.02, 0.5)),
                'theta': random_angle(),
                'phase': float(np.random.uniform(0, 2*math.pi)),
                'gamma': float(np.random.uniform(0.5, 1.5)),
                'size': size,
            }
        else:
            raise ValueError('Unknown family')
        drawn.append(entry)
    return drawn


--------------------------- Kernel bank builder ---------------------------

In [29]:

def build_kernel_bank(families, n_per_family, size):
    """Sample parameters and build kernels for each requested family.

    Returns a list of dicts with keys 'family', 'params' and 'kernel', in
    family order and, within a family, in sampling order.
    """
    # Maps a family name to a function turning a parameter dict into a kernel.
    makers = {
        'gaussian': lambda p: make_gaussian_kernel(p['size'], p['sigma_x'], p.get('sigma_y', None), p['theta']),
        'anisotropic_gaussian': lambda p: make_gaussian_kernel(p['size'], p['sigma_x'], p.get('sigma_y', None), p['theta']),
        'dog': lambda p: make_dog_kernel(p['size'], p['sigma1'], p['sigma2'], p['theta']),
        'log': lambda p: make_log_kernel(p['size'], p['sigma'], p['theta']),
        'gabor': lambda p: make_gabor_kernel(p['size'], p['sigma'], p['freq'], p['theta'], p['phase'], p['gamma']),
    }

    bank = []
    for fam in families:
        make = makers.get(fam)
        for p in sample_parameters(fam, n_per_family, size):
            if make is None:
                # Family without a generator: nothing to add.
                continue
            bank.append({'family': fam, 'params': p, 'kernel': make(p)})
    return bank

--------------------------- Response computation ---------------------------

In [30]:
def compute_responses(bank, X_in, X_out, device='cpu', batch_size=64, response_fn='abs_max'):
    """Convolve every kernel with both patch sets and pool to one scalar per patch.

    Args:
        bank: list of dicts, each holding a 2-D NumPy 'kernel'.
        X_in, X_out: patch batches accepted by ``to_tensor_batch``.
        device: torch device spec for the convolution.
        batch_size: number of patches convolved per call.
        response_fn: 'mean_abs' pools with the mean of |response|; 'abs_max'
            and any other value pool with the maximum of |response|.

    Returns:
        One dict per kernel, ``{'r_in': ..., 'r_out': ...}``, each a 1-D
        NumPy array with one value per patch.
    """
    device = torch.device(device)
    Xin_t = to_tensor_batch(X_in, device)
    Xout_t = to_tensor_batch(X_out, device)

    # Kernels as conv filters of shape (1, 1, Hk, Wk).
    filters = [
        torch.from_numpy(entry['kernel']).unsqueeze(0).unsqueeze(0).to(device)
        for entry in bank
    ]

    def pooled_response(patches, k_t):
        chunks = []
        for start in range(0, patches.shape[0], batch_size):
            with torch.no_grad():
                magnitude = F.conv2d(
                    patches[start:start+batch_size], k_t, padding=k_t.shape[-1]//2
                ).abs()
                if response_fn == 'mean_abs':
                    pooled = magnitude.mean(dim=[1,2,3])
                else:
                    # 'abs_max' and the fallback for unrecognised names.
                    pooled = magnitude.amax(dim=[1,2,3])
                chunks.append(pooled.cpu().numpy())
        return np.concatenate(chunks, axis=0) if len(chunks) else np.zeros((0,))

    responses = []
    for k_t in tqdm(filters, desc='Kernels'):
        r_in = pooled_response(Xin_t, k_t)
        r_out = pooled_response(Xout_t, k_t)
        responses.append({'r_in': r_in, 'r_out': r_out})
    return responses


In [31]:

def fisher_score(r_in, r_out):
    """Fisher-style separation: squared mean gap over summed sample variances.

    Empty inputs contribute a mean of 0 and inputs with fewer than two
    values contribute a variance of 0; a 1e-12 term keeps the ratio finite.
    """
    def mean_or_zero(values):
        return values.mean() if values.size else 0.0

    def sample_var_or_zero(values):
        return values.var(ddof=1) if values.size>1 else 0.0

    gap = mean_or_zero(r_in) - mean_or_zero(r_out)
    spread = sample_var_or_zero(r_in) + sample_var_or_zero(r_out) + 1e-12
    return float(gap**2 / spread)


def auc_score(r_in, r_out):
    """ROC AUC of the raw responses, labelling ``r_in`` as 1 and ``r_out`` as 0.

    Returns 0.5 when all scores are equal and also whenever the computation
    raises (for example on empty input).
    """
    labels = np.concatenate([np.ones(len(r_in)), np.zeros(len(r_out))])
    values = np.concatenate([r_in, r_out])
    try:
        # Constant scores carry no ranking information.
        if values.max() == values.min():
            return 0.5
        result = float(roc_auc_score(labels, values))
    except Exception:
        result = 0.5
    return result



In [32]:
def pairwise_response_corr(resp_i, resp_j):
    """Correlation of two 1-D response vectors after mean-centring.

    A 1e-12 term in the denominator makes the result 0.0 rather than NaN
    when either vector is constant.
    """
    centred_i = resp_i - resp_i.mean()
    centred_j = resp_j - resp_j.mean()
    scale = (np.linalg.norm(centred_i) * np.linalg.norm(centred_j) + 1e-12)
    return float(np.dot(centred_i, centred_j) / scale)


def select_diverse(bank, responses, topM=200, K=20, lambda_mm=0.75):
    """Greedily pick up to ``K`` discriminative yet mutually dissimilar kernels.

    Kernels are ranked by (AUC, Fisher score) and the best ``topM`` form the
    candidate pool. The top-ranked candidate is taken first; each following
    pick maximises ``lambda_mm * auc - (1 - lambda_mm) * max_similarity``,
    where similarity is the absolute response correlation with any kernel
    already selected (maximal marginal relevance). Ties go to the candidate
    that ranks higher.

    Args:
        bank: unused; kept for interface compatibility.
        responses: per-kernel dicts with 1-D 'r_in' and 'r_out' arrays.
        topM: size of the candidate pool.
        K: maximum number of kernels to return.
        lambda_mm: trade-off between relevance (1.0) and diversity (0.0).

    Returns:
        List of indices into ``responses`` in selection order; empty when
        there are no candidates or ``K`` is not positive.
    """
    scores = [
        {
            'idx': i,
            'auc': auc_score(resp['r_in'], resp['r_out']),
            'fisher': fisher_score(resp['r_in'], resp['r_out']),
        }
        for i, resp in enumerate(responses)
    ]
    scores = sorted(scores, key=lambda x: (x['auc'], x['fisher']), reverse=True)
    pool = scores[:topM]
    top_idxs = [s['idx'] for s in pool]
    if K <= 0 or not top_idxs:
        return []

    # Constant-time lookups instead of repeated linear scans.
    auc_of = {s['idx']: s['auc'] for s in pool}
    combined = {
        i: np.concatenate([responses[i]['r_in'], responses[i]['r_out']])
        for i in top_idxs
    }

    selected_idxs = [top_idxs[0]]
    remaining = top_idxs[1:]              # kept in ranking order for tie-breaks
    max_sim = {cand: 0.0 for cand in remaining}

    while remaining and len(selected_idxs) < K:
        # Only similarities to the most recent pick are new; fold them into
        # the running maximum rather than recomputing every pair.
        newest = combined[selected_idxs[-1]]
        for cand in remaining:
            sim = abs(pairwise_response_corr(combined[cand], newest))
            if sim > max_sim[cand]:
                max_sim[cand] = sim

        best_score = -1e9
        best_idx = None
        for cand in remaining:
            mmr = lambda_mm * auc_of[cand] - (1-lambda_mm) * max_sim[cand]
            if mmr > best_score:
                best_score = mmr
                best_idx = cand
        if best_idx is None:
            break
        selected_idxs.append(best_idx)
        remaining.remove(best_idx)
    return selected_idxs


--------------------------- Training classifier ---------------------------

In [33]:
class SimpleLogistic(torch.nn.Module):
    """Logistic-regression head: one linear layer emitting a logit per row."""

    def __init__(self, n_features):
        super().__init__()
        self.lin = torch.nn.Linear(in_features=n_features, out_features=1)

    def forward(self, x):
        """Map (N, n_features) inputs to an (N,) tensor of logits."""
        logits = self.lin(x)
        return logits.squeeze(dim=1)


class SimpleMLP(torch.nn.Module):
    """Feed-forward binary classifier emitting one logit per sample.

    Each entry of ``hidden_dims`` adds a Linear -> ReLU -> Dropout stage,
    followed by a final Linear layer to a single output. With the default
    two hidden sizes the layer layout (and therefore the ``state_dict``
    keys ``net.0``, ``net.3``, ``net.6``) matches a fixed two-hidden-layer
    network; any other number of hidden sizes, including none, also works.

    Args:
        n_features: width of the input vectors.
        hidden_dims: iterable of hidden-layer widths.
        dropout: dropout probability applied after every hidden activation.
    """

    def __init__(self, n_features, hidden_dims=(64, 32), dropout=0.2):
        super().__init__()
        layers = []
        width = n_features
        for hidden in hidden_dims:
            layers.append(torch.nn.Linear(width, hidden))
            layers.append(torch.nn.ReLU())
            layers.append(torch.nn.Dropout(dropout))
            width = hidden
        layers.append(torch.nn.Linear(width, 1))
        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Map (N, n_features) inputs to an (N,) tensor of logits."""
        return self.net(x).squeeze(1)


def standardize_split(X_tr, X_val):
    """Z-score both splits using statistics computed on ``X_tr`` only.

    Columns whose training standard deviation is below 1e-6 are left
    unscaled (divided by 1.0).
    """
    centre = X_tr.mean(axis=0, keepdims=True)
    scale = X_tr.std(axis=0, keepdims=True)
    near_constant = scale < 1e-6
    scale[near_constant] = 1.0

    def apply(X):
        return (X - centre) / scale

    return apply(X_tr), apply(X_val)


def train_binary_model(
    X_feat,
    y,
    epochs=30,
    batch_size=64,
    lr=1e-3,
    device='cpu',
    return_model=True,
    model_type='logistic',
    hidden_dims=(64, 32),
    dropout=0.2,
    standardize=True
):
    """Train a binary classifier on a stratified 80/20 split and score it.

    The model is validated after every epoch; the weights from the epoch
    with the highest validation AUC are restored before the final metrics
    are computed.

    Args:
        X_feat: (N, n_features) NumPy feature matrix.
        y: (N,) NumPy array of 0/1 labels.
        epochs, batch_size, lr: Adam training settings.
        device: torch device spec.
        return_model: when False the returned 'model' entry is None.
        model_type: 'mlp' builds ``SimpleMLP``; anything else builds
            ``SimpleLogistic``.
        hidden_dims, dropout: forwarded to ``SimpleMLP``.
        standardize: z-score features with training-split statistics.

    Returns:
        Dict with 'model', 'auc', 'acc', 'val_probs' and 'val_labels'. For
        empty input the metrics are 0.5 and the other entries are None. AUC
        falls back to 0.5 whenever it is undefined.
    """
    if X_feat.size == 0 or y.size == 0:
        return {'model': None, 'auc': 0.5, 'acc': 0.5, 'val_probs': None, 'val_labels': None}
    device = torch.device(device)
    X_tr, X_val, y_tr, y_val = train_test_split(X_feat, y, test_size=0.2, stratify=y, random_state=42)
    if standardize:
        X_tr, X_val = standardize_split(X_tr, X_val)

    tr_ds = TensorDataset(torch.from_numpy(X_tr).float().to(device), torch.from_numpy(y_tr).float().to(device))
    tr_loader = DataLoader(tr_ds, batch_size=batch_size, shuffle=True)
    X_val_t = torch.from_numpy(X_val).float().to(device)

    if model_type == 'mlp':
        model = SimpleMLP(X_feat.shape[1], hidden_dims=hidden_dims, dropout=dropout).to(device)
    else:
        model = SimpleLogistic(X_feat.shape[1]).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = torch.nn.BCEWithLogitsLoss()

    def validation_probs():
        """Predicted probabilities on the validation split (eval mode)."""
        model.eval()
        with torch.no_grad():
            return torch.sigmoid(model(X_val_t)).cpu().numpy()

    def safe_auc(scores):
        """Validation AUC, or 0.5 when it cannot be computed."""
        try:
            value = float(roc_auc_score(y_val, scores))
        except ValueError:
            # e.g. only one class present in the validation labels
            return 0.5
        return 0.5 if np.isnan(value) else value

    best_auc = 0.0
    best_state = None
    for _ in range(epochs):
        model.train()
        for xb, yb in tr_loader:
            logits = model(xb)
            loss = loss_fn(logits, yb)
            opt.zero_grad()
            loss.backward()
            opt.step()
        auc = safe_auc(validation_probs())
        if auc > best_auc:
            best_auc = auc
            # state_dict() hands back references to the live parameters, so
            # the snapshot must be cloned or later epochs would overwrite it.
            best_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
    if best_state is not None:
        model.load_state_dict(best_state)

    probs = validation_probs()
    acc = accuracy_score(y_val, (probs>0.5).astype(int))
    return {
        'model': model if return_model else None,
        'auc': safe_auc(probs),
        'acc': acc,
        'val_probs': probs,
        'val_labels': y_val
    }


def build_feature_matrix(selected_idxs, responses):
    """Assemble a per-patch feature matrix from selected kernel responses.

    Each selected kernel contributes one column made of its 'r_in' values
    followed by its 'r_out' values. Labels are 1 for the 'r_in' rows and 0
    for the 'r_out' rows, with the counts taken from ``responses[0]``.

    Returns:
        ``(X_feat, y)``; ``(np.zeros((0, 0)), np.array([]))`` when nothing
        is selected.
    """
    columns = [
        np.concatenate([responses[idx]['r_in'], responses[idx]['r_out']]).reshape(-1, 1)
        for idx in selected_idxs
    ]
    if len(columns) == 0:
        return np.zeros((0, 0)), np.array([])

    X_feat = np.concatenate(columns, axis=1)

    n_in = n_out = 0
    if responses:
        n_in = len(responses[0]['r_in'])
        n_out = len(responses[0]['r_out'])

    if n_in or n_out:
        y = np.concatenate([np.ones(n_in), np.zeros(n_out)])
    else:
        y = np.array([])
    return X_feat, y


def rank_kernels(responses):
    """Score every kernel and return the scores best-first.

    Each entry is ``{'idx', 'auc', 'fisher'}``; ordering is by AUC, then by
    Fisher score, both descending.
    """
    ranked = [
        {
            'idx': position,
            'auc': auc_score(resp['r_in'], resp['r_out']),
            'fisher': fisher_score(resp['r_in'], resp['r_out']),
        }
        for position, resp in enumerate(responses)
    ]
    return sorted(ranked, key=lambda entry: (entry['auc'], entry['fisher']), reverse=True)


def find_best_subsets(
    X_feat,
    y,
    subset_sizes,
    epochs=15,
    batch_size=64,
    lr=1e-3,
    device='cpu',
    model_type='logistic',
    hidden_dims=(64, 32),
    dropout=0.2,
    standardize=True
):
    """Exhaustively search feature subsets of the given sizes.

    For every size ``k`` in ``subset_sizes`` (sizes larger than the number
    of columns are skipped) a model is trained with ``train_binary_model``
    on each k-column combination, and the combination with the highest
    validation AUC is kept; the first one found wins ties.

    Returns:
        Dict mapping ``k`` to ``{'subset': tuple_of_columns, 'auc', 'acc'}``.
    """
    training_options = dict(
        epochs=epochs,
        batch_size=batch_size,
        lr=lr,
        device=device,
        return_model=False,
        model_type=model_type,
        hidden_dims=hidden_dims,
        dropout=dropout,
        standardize=standardize,
    )
    winners = {}
    total_columns = X_feat.shape[1]
    for k in subset_sizes:
        if k > total_columns:
            continue
        champion = None
        for combo in combinations(range(total_columns), k):
            outcome = train_binary_model(X_feat[:, combo], y, **training_options)
            if champion is None or outcome['auc'] > champion['auc']:
                champion = {'subset': combo, 'auc': outcome['auc'], 'acc': outcome['acc']}
        if champion:
            winners[k] = champion
    return winners


def fit_subset_model(
    X_feat,
    y,
    subset,
    epochs=20,
    batch_size=64,
    lr=1e-3,
    device='cpu',
    model_type='logistic',
    hidden_dims=(64, 32),
    dropout=0.2,
    standardize=True
):
    """Fit a classifier on all rows of the chosen feature columns.

    Unlike ``train_binary_model`` there is no validation split: every sample
    is used for training. The standardisation statistics are returned so a
    caller can apply the same transform to new points.

    Returns:
        Dict with 'model' (in eval mode), 'mean', 'std' (both None when
        ``standardize`` is False), 'device' and 'standardize'.
    """
    features = X_feat[:, subset]
    mean, std = None, None
    if standardize:
        mean = features.mean(axis=0, keepdims=True)
        std = features.std(axis=0, keepdims=True)
        std[std < 1e-6] = 1.0          # leave near-constant columns unscaled
        features = (features - mean) / std

    device = torch.device(device)
    inputs = torch.from_numpy(features).float().to(device)
    targets = torch.from_numpy(y).float().to(device)
    loader = DataLoader(TensorDataset(inputs, targets), batch_size=batch_size, shuffle=True)

    n_inputs = features.shape[1]
    if model_type == 'mlp':
        network = SimpleMLP(n_inputs, hidden_dims=hidden_dims, dropout=dropout)
    else:
        network = SimpleLogistic(n_inputs)
    model = network.to(device)

    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = torch.nn.BCEWithLogitsLoss()
    model.train()
    for _ in range(epochs):
        for xb, yb in loader:
            batch_loss = loss_fn(model(xb), yb)
            opt.zero_grad()
            batch_loss.backward()
            opt.step()
    model.eval()
    return {'model': model, 'mean': mean, 'std': std, 'device': device, 'standardize': standardize}


def train_classifier(
    Xin,
    Xout,
    selected_idxs,
    responses,
    epochs=30,
    batch_size=64,
    lr=1e-3,
    device='cpu',
    model_type='logistic',
    hidden_dims=(64, 32),
    dropout=0.2,
    standardize=True
):
    """Train the final classifier on the responses of the selected kernels.

    Builds the feature matrix with ``build_feature_matrix`` and hands it to
    ``train_binary_model``, whose result dict is returned unchanged. ``Xin``
    and ``Xout`` are accepted but not used.
    """
    features, labels = build_feature_matrix(selected_idxs, responses)
    return train_binary_model(
        features,
        labels,
        epochs=epochs,
        batch_size=batch_size,
        lr=lr,
        device=device,
        model_type=model_type,
        hidden_dims=hidden_dims,
        dropout=dropout,
        standardize=standardize,
    )


In [34]:
def plot_2d_scatter(X_feat, y, subset, kernel_idxs, out_path, boundary=None, title='Best 2-kernel feature space'):
    """Save a 2-D scatter of two feature columns, optionally with a decision boundary.

    Args:
        X_feat: feature matrix; columns ``subset[0]`` and ``subset[1]`` are plotted.
        y: labels; 1 is drawn as 'cancer', 0 as 'healthy'.
        subset: column indices into ``X_feat``.
        kernel_idxs: kernel ids used in the axis labels.
        out_path: destination image file.
        boundary: optional dict as returned by ``fit_subset_model``; when it
            holds a model, the predicted regions and the 0.5 contour are drawn.
        title: figure title.

    Nothing is drawn or saved when ``X_feat`` or ``y`` is empty.
    """
    if X_feat.size == 0 or y.size == 0:
        return
    col_x, col_y = subset[0], subset[1]
    fig, ax = plt.subplots(figsize=(6, 5))

    class_styles = (
        (y == 1, 'crimson', 'cancer', 0.75),
        (y == 0, 'teal', 'healthy', 0.65),
    )
    for rows, colour, label, opacity in class_styles:
        ax.scatter(X_feat[rows, col_x], X_feat[rows, col_y], c=colour, label=label,
                   alpha=opacity, edgecolors='k', linewidths=0.3)

    if boundary is not None and boundary.get('model') is not None:
        model = boundary['model']
        device = boundary.get('device', 'cpu')
        mean = boundary.get('mean', None)
        std = boundary.get('std', None)
        standardize = boundary.get('standardize', False)

        def padded_axis(values):
            # 200 evenly spaced points covering the data range plus 10% margin.
            lo, hi = values.min(), values.max()
            margin = 0.1 * (hi - lo + 1e-6)
            return np.linspace(lo - margin, hi + margin, 200)

        xx, yy = np.meshgrid(padded_axis(X_feat[:, col_x]), padded_axis(X_feat[:, col_y]))
        grid = np.stack([xx.ravel(), yy.ravel()], axis=1)
        # The model was fitted on standardised features, so the grid must be
        # transformed the same way before prediction.
        if standardize and mean is not None and std is not None:
            grid = (grid - mean) / std
        with torch.no_grad():
            grid_t = torch.from_numpy(grid).float().to(device)
            probs = torch.sigmoid(model(grid_t)).cpu().numpy()
        zz = probs.reshape(xx.shape)
        ax.contourf(xx, yy, zz, levels=[0, 0.5, 1], alpha=0.18, colors=['teal', 'crimson'])
        ax.contour(xx, yy, zz, levels=[0.5], colors='k', linewidths=1.0, linestyles='--')

    ax.set_xlabel(f'Kernel {kernel_idxs[0]} response')
    ax.set_ylabel(f'Kernel {kernel_idxs[1]} response')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(out_path, dpi=200, bbox_inches='tight')
    plt.close(fig)


def plot_3d_scatter(X_feat, y, subset, kernel_idxs, out_path):
    """Save a 3-D scatter of three feature columns, coloured by class.

    Columns ``subset[0..2]`` of ``X_feat`` are plotted; label 1 is drawn as
    'cancer' and label 0 as 'healthy'. ``kernel_idxs`` supplies the ids shown
    in the axis labels. Nothing is drawn or saved when ``X_feat`` or ``y``
    is empty.
    """
    if X_feat.size == 0 or y.size == 0:
        return
    col_x, col_y, col_z = subset[0], subset[1], subset[2]
    fig = plt.figure(figsize=(7, 5))
    ax = fig.add_subplot(111, projection='3d')

    class_styles = (
        (y == 1, 'crimson', 'cancer', 0.75),
        (y == 0, 'teal', 'healthy', 0.65),
    )
    for rows, colour, label, opacity in class_styles:
        ax.scatter(X_feat[rows, col_x], X_feat[rows, col_y], X_feat[rows, col_z],
                   c=colour, label=label, alpha=opacity, edgecolors='k', linewidths=0.3)

    axis_setters = (ax.set_xlabel, ax.set_ylabel, ax.set_zlabel)
    for set_label, kernel_id in zip(axis_setters, (kernel_idxs[0], kernel_idxs[1], kernel_idxs[2])):
        set_label(f'Kernel {kernel_id} response')
    ax.set_title('Best 3-kernel feature space')
    ax.legend()
    fig.tight_layout()
    fig.savefig(out_path, dpi=220, bbox_inches='tight')
    plt.close(fig)


def plot_roc_pr(y_true, probs, out_prefix):
    """Save side-by-side ROC and precision-recall curves.

    Args:
        y_true: true binary labels.
        probs: predicted scores for the positive class.
        out_prefix: path prefix; the figure is written to
            ``f"{out_prefix}_roc_pr.png"``.

    Nothing is drawn when either input is None or ``y_true`` is empty.
    """
    if probs is None or y_true is None or len(y_true)==0:
        return
    # roc_curve is not among the names imported by the notebook's first
    # cell, so import it here to keep this function usable on a fresh kernel.
    from sklearn.metrics import roc_curve

    fpr, tpr, _ = roc_curve(y_true, probs)
    precision, recall, _ = precision_recall_curve(y_true, probs)
    fig, ax = plt.subplots(1,2, figsize=(10,4))

    ax[0].plot(fpr, tpr, label=f"AUC={roc_auc_score(y_true, probs):.3f}")
    ax[0].plot([0,1],[0,1],'k--',alpha=0.3)   # chance diagonal
    ax[0].set_title('ROC')
    ax[0].set_xlabel('FPR')
    ax[0].set_ylabel('TPR')
    ax[0].legend()
    ax[0].grid(alpha=0.3)

    ax[1].plot(recall, precision)
    ax[1].set_title('Precision-Recall')
    ax[1].set_xlabel('Recall')
    ax[1].set_ylabel('Precision')
    ax[1].grid(alpha=0.3)

    fig.tight_layout()
    fig.savefig(f"{out_prefix}_roc_pr.png", dpi=200, bbox_inches='tight')
    plt.close(fig)


def plot_confusion(y_true, probs, out_path, threshold=0.5):
    """Save a 2x2 confusion-matrix heat map.

    Args:
        y_true: true binary labels (0 = healthy, 1 = cancer).
        probs: predicted probabilities; values >= ``threshold`` count as 1.
        out_path: destination image file.
        threshold: decision threshold applied to ``probs``.

    Nothing is drawn when either input is None or ``y_true`` is empty.
    """
    if probs is None or y_true is None or len(y_true)==0:
        return
    preds = (probs >= threshold).astype(int)
    # Pin the label set so the matrix is always 2x2, even when only one
    # class occurs in the labels and predictions; the annotation loop below
    # indexes all four cells.
    cm = confusion_matrix(y_true, preds, labels=[0, 1])
    fig, ax = plt.subplots(figsize=(4,4))
    im = ax.imshow(cm, cmap='Blues')
    ax.set_xticks([0,1]); ax.set_yticks([0,1])
    ax.set_xticklabels(['healthy','cancer']); ax.set_yticklabels(['healthy','cancer'])
    ax.set_ylabel('True'); ax.set_xlabel('Predicted')
    for i in range(2):
        for j in range(2):
            ax.text(j, i, cm[i,j], ha='center', va='center', color='black')
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    fig.tight_layout()
    fig.savefig(out_path, dpi=200, bbox_inches='tight')
    plt.close(fig)


---------------------- Configuration ----------------------

In [35]:
# Patch-extraction settings read by main() as module-level constants.
IMAGE_DIR = "data/TIFF Images/all"       # folder main() scans for source *.tif images
ANNOT_DIR = "data/Pixel-level annotation"  # folder main() checks for a same-named .tif annotation of each image
OUTPUT_DIR = "data/patches"          # root output folder; patches go to its inside/ and outside/ subfolders
PATCH_SIZE = 128                # side length (pixels) of the square patches sampled per image
N_INSIDE_PER_IMAGE = 50         # how many inside-mask patches to request per image
N_OUTSIDE_PER_IMAGE = 100       # how many outside-mask patches to request per image
GREEN_THRESH = 150              # NOTE(review): has no effect on the mask — mask_from_green_contour does not read its threshold argument
MAX_TRIES = 500                 # base attempt budget; main() passes MAX_TRIES * 20 to sample_patches


--------------------------- Main ---------------------------

In [36]:
annot = None


def _extract_patches(output_dir):
    """Sample inside/outside patches from every image in IMAGE_DIR and save them.

    An annotation with the same file stem is looked up in ANNOT_DIR; images
    without one only contribute outside patches. Images that cannot be read
    or are smaller than PATCH_SIZE are skipped.

    Returns:
        ``(n_inside_saved, n_outside_saved)``.
    """
    (output_dir / "inside").mkdir(parents=True, exist_ok=True)
    (output_dir / "outside").mkdir(parents=True, exist_ok=True)

    image_paths = sorted(Path(IMAGE_DIR).glob("*.tif"))

    inside_idx = 0
    outside_idx = 0
    for img_path in tqdm(image_paths, total=len(image_paths)):
        img = cv2.imread(str(img_path), cv2.IMREAD_UNCHANGED)
        # cv2.imread signals failure by returning None rather than raising.
        if img is None:
            print(f"Skipping unreadable image: {img_path}")
            continue
        if min(img.shape[:2]) < PATCH_SIZE:
            print(f"Skipping image smaller than {PATCH_SIZE}px: {img_path}")
            continue

        annot_path = Path(ANNOT_DIR) / (img_path.stem + '.tif')
        annot_img = None
        if annot_path.exists():
            annot_img = cv2.imread(str(annot_path), cv2.IMREAD_COLOR_RGB)

        mask = None
        if annot_img is not None:
            # The threshold argument is currently ignored by the callee.
            mask = mask_from_green_contour(annot_img, GREEN_THRESH)

        inside_patches, outside_patches = sample_patches(
            img, mask, PATCH_SIZE, N_INSIDE_PER_IMAGE, N_OUTSIDE_PER_IMAGE, max_tries=MAX_TRIES * 20
        )

        for p in inside_patches:
            Image.fromarray(p).save(output_dir / "inside" / f"inside_{inside_idx:05d}.tif")
            inside_idx += 1

        for p in outside_patches:
            Image.fromarray(p).save(output_dir / "outside" / f"outside_{outside_idx:05d}.tif")
            outside_idx += 1

    return inside_idx, outside_idx


def main(args):
    """Run the full pipeline: patch extraction, kernel search, training, plots.

    Steps: extract patches into OUTPUT_DIR, load them from ``args.data_root``,
    build a random kernel bank, compute per-patch responses, select a diverse
    kernel subset, train and evaluate a classifier on it, search the
    top-ranked kernels for the best pair and triple, and write plots and
    result files into ``args.out_dir``.

    Args:
        args: attribute namespace with the fields set in the launch cell
            (data_root, out_dir, families, n_per_family, kernel_size,
            patch_size, max_per_class, batch_size, response_fn, topM, K,
            lambda_mm, epochs, lr, force_cpu, plot_top_kernels,
            subset_search_epochs, boundary_epochs, model_type, hidden_dim1,
            hidden_dim2, dropout, standardize_features). An optional integer
            ``seed`` attribute, when present, seeds ``random``, NumPy and
            torch for reproducible runs.
    """
    device = 'cuda' if torch.cuda.is_available() and not args.force_cpu else 'cpu'
    os.makedirs(args.out_dir, exist_ok=True)

    seed = getattr(args, 'seed', None)
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    # Make the patches
    n_inside, n_outside = _extract_patches(Path(OUTPUT_DIR))
    print(f"Saved {n_inside} inside patches and {n_outside} outside patches to {OUTPUT_DIR}")

    # load data
    Xin, Xout = load_patches_from_folders(args.data_root, max_per_class=args.max_per_class, resize=(args.patch_size,args.patch_size))

    print(f'Loaded: in={len(Xin)} out={len(Xout)} patches. Device={device}')
    # build kernel bank
    families = args.families.split(',')
    bank = build_kernel_bank(families, args.n_per_family, args.kernel_size)
    print(f'Built kernel bank: {len(bank)} kernels')

    # compute responses
    responses = compute_responses(bank, Xin, Xout, device=device, batch_size=args.batch_size, response_fn=args.response_fn)

    # select diverse top kernels
    selected_idxs = select_diverse(bank, responses, topM=args.topM, K=args.K, lambda_mm=args.lambda_mm)
    print('Selected kernel indices:', selected_idxs)

    # Model settings shared by every training call below.
    model_kwargs = dict(
        model_type=args.model_type,
        hidden_dims=(args.hidden_dim1, args.hidden_dim2),
        dropout=args.dropout,
        standardize=args.standardize_features,
    )

    # train classifier
    clf_res = train_classifier(
        Xin,
        Xout,
        selected_idxs,
        responses,
        epochs=args.epochs,
        batch_size=args.batch_size,
        lr=args.lr,
        device=device,
        **model_kwargs
    )
    print(f"Classifier val AUC={clf_res['auc']:.4f} ACC={clf_res['acc']:.4f}")

    # Performance plots
    plot_roc_pr(clf_res.get('val_labels'), clf_res.get('val_probs'), str(Path(args.out_dir) / 'clf'))
    plot_confusion(clf_res.get('val_labels'), clf_res.get('val_probs'), str(Path(args.out_dir) / 'confusion.png'))

    # Visualize best 2- and 3-kernel subsets
    kernel_scores = rank_kernels(responses)
    candidate_kernel_idxs = [s['idx'] for s in kernel_scores[:args.plot_top_kernels]]
    X_candidates, y_labels = build_feature_matrix(candidate_kernel_idxs, responses)
    subset_results = find_best_subsets(
        X_candidates,
        y_labels,
        subset_sizes=[2,3],
        epochs=args.subset_search_epochs,
        batch_size=args.batch_size,
        lr=args.lr,
        device=device,
        **model_kwargs
    )

    pair_plot = Path(args.out_dir) / 'scatter_best_pair.png'
    triple_plot = Path(args.out_dir) / 'scatter_best_triple.png'
    if 2 in subset_results:
        pair = subset_results[2]
        # Subset entries index into the candidate list; map back to bank ids.
        pair_kernel_idxs = [candidate_kernel_idxs[i] for i in pair['subset']]
        boundary_model = fit_subset_model(
            X_candidates,
            y_labels,
            pair['subset'],
            epochs=args.boundary_epochs,
            batch_size=args.batch_size,
            lr=args.lr,
            device=device,
            **model_kwargs
        )
        plot_2d_scatter(
            X_candidates,
            y_labels,
            pair['subset'],
            pair_kernel_idxs,
            pair_plot,
            boundary=boundary_model,
            title='Best 2-kernel feature space with decision boundary'
        )
        print(f"Best 2-kernel subset {pair_kernel_idxs} AUC={pair['auc']:.4f} saved to {pair_plot}")
    else:
        print('Not enough kernels to make 2D scatter plot')

    if 3 in subset_results:
        triple = subset_results[3]
        triple_kernel_idxs = [candidate_kernel_idxs[i] for i in triple['subset']]
        plot_3d_scatter(X_candidates, y_labels, triple['subset'], triple_kernel_idxs, triple_plot)
        print(f"Best 3-kernel subset {triple_kernel_idxs} AUC={triple['auc']:.4f} saved to {triple_plot}")
    else:
        print('Not enough kernels to make 3D scatter plot')

    # save results
    out = {
        'selected_idxs': selected_idxs,
        'bank_meta': [{'family': b['family'], 'params': b['params']} for b in bank],
        'clf_auc': float(clf_res['auc']),
        'clf_acc': float(clf_res['acc'])
    }
    # Save for fast reuse. Written next to the other outputs so the target
    # directory is guaranteed to exist; subset_results is a dict and is
    # therefore stored as a pickled object array.
    np.savez(
        os.path.join(args.out_dir, 'feature_cache.npz'),
        X_candidates=X_candidates,
        y_labels=y_labels,
        subset_results=subset_results,
        candidate_kernel_idxs=np.array(candidate_kernel_idxs)
    )
    np.savez_compressed(os.path.join(args.out_dir, 'results.npz'), selected_idxs=np.array(selected_idxs))
    with open(os.path.join(args.out_dir, 'results.json'), 'w') as f:
        json.dump(out, f, indent=2)
    # save kernels
    kernels = np.stack([b['kernel'] for b in bank], axis=0)
    np.save(os.path.join(args.out_dir, 'kernels.npy'), kernels)
    print('Saved results to', args.out_dir)


if __name__ == '__main__':
    # Run configuration, passed to main() as an attribute namespace.
    args = SimpleNamespace(
        # --- data and output locations ---
        data_root=OUTPUT_DIR,          # folder whose inside/ and outside/ subfolders are loaded
        out_dir='./results',           # created by main(); plots and result files go here
        # --- kernel bank ---
        families='gaussian,anisotropic_gaussian,dog,log,gabor',  # comma-separated; split in main()
        n_per_family=200,              # random kernels sampled per family
        kernel_size=31,                # kernel side length; the generators assert it is odd
        # --- patch loading and responses ---
        patch_size=64,                 # patches are resized to patch_size x patch_size on load
        max_per_class=2000,            # cap on files loaded from each of inside/ and outside/
        batch_size=64,                 # used for both response computation and training
        response_fn='mean_abs',        # pooling of |conv response|: 'mean_abs' or 'abs_max'
        # --- diverse kernel selection (select_diverse) ---
        topM=200,                      # size of the candidate pool
        K=20,                          # maximum number of kernels selected
        lambda_mm=0.75,                # weight of AUC versus the similarity penalty
        # --- classifier training ---
        epochs=60,
        lr=5e-4,
        force_cpu=False,               # when False, main() uses CUDA if it is available
        # --- best pair/triple search and plots ---
        plot_top_kernels=20,           # top-ranked kernels searched exhaustively for pairs and triples
        subset_search_epochs=20,       # epochs per candidate subset in find_best_subsets
        boundary_epochs=25,            # epochs for the model behind the 2-D decision boundary
        # --- model architecture ---
        model_type='mlp',              # 'mlp' selects SimpleMLP; any other value selects SimpleLogistic
        hidden_dim1=64,
        hidden_dim2=32,
        dropout=0.2,
        standardize_features=True      # z-score features before training
    )

    main(args)


  0%|          | 0/511 [00:00<?, ?it/s]


KeyboardInterrupt: 

In [None]:
import numpy as np
from sklearn.metrics import roc_curve, roc_auc_score, confusion_matrix
import matplotlib.pyplot as plt

# Reload the features cached by main() so the plots below can be regenerated
# without recomputing kernel responses.
# NOTE: allow_pickle=True is required because subset_results was stored as a
# pickled dict. Unpickling can execute arbitrary code, so only load cache
# files produced by this notebook.
with np.load('results/feature_cache.npz', allow_pickle=True) as cache:
    X_candidates = cache['X_candidates']
    y_labels = cache['y_labels']
    subset_results = cache['subset_results'].item()
    candidate_kernel_idxs = cache['candidate_kernel_idxs']

# 2D boundary plot (re-fit quickly on the saved pair)
pair = subset_results[2]
pair_kernel_idxs = [candidate_kernel_idxs[i] for i in pair['subset']]
boundary_model = fit_subset_model(
    X_candidates, y_labels, pair['subset'],
    epochs=20, batch_size=64, lr=5e-4,
    device='cpu', model_type='mlp',
    hidden_dims=(64, 32), dropout=0.2,
    standardize=True
)
plot_2d_scatter(
    X_candidates, y_labels, pair['subset'], pair_kernel_idxs,
    'results/scatter_best_pair_cached.png', boundary=boundary_model,
    title='Best 2-kernel feature space (cached)'
)

# Per-feature histograms: one panel per kernel of the pair, one colour per class
fig, ax = plt.subplots(1, 2, figsize=(8, 3))
for panel, column, kernel_id in zip(ax, pair['subset'], pair_kernel_idxs):
    for cls, color in [(1, 'crimson'), (0, 'teal')]:
        panel.hist(X_candidates[y_labels==cls, column], bins=30, alpha=0.6, color=color, label=f'class {cls}')
    panel.set_title('Kernel ' + str(kernel_id))
    panel.legend()
    panel.grid(alpha=0.3)
fig.tight_layout()
fig.savefig('results/hist_pair_cached.png', dpi=200)
plt.close(fig)

# ROC on the saved pair. The model was fitted on these same samples, so this
# is a training-set ROC, not a held-out estimate.
model = boundary_model['model']
mean = boundary_model['mean']
std = boundary_model['std']
standardize = boundary_model['standardize']
X_pair = X_candidates[:, pair['subset']]
if standardize:
    X_pair = (X_pair - mean) / std
with torch.no_grad():
    pair_inputs = torch.from_numpy(X_pair).float().to(boundary_model['device'])
    probs = torch.sigmoid(model(pair_inputs)).cpu().numpy()
fpr, tpr, _ = roc_curve(y_labels, probs)
roc_fig, roc_ax = plt.subplots()
roc_ax.plot(fpr, tpr)
roc_ax.set_xlabel('FPR')
roc_ax.set_ylabel('TPR')
roc_ax.set_title('ROC (cached pair)')
roc_fig.savefig('results/roc_pair_cached.png', dpi=200)
plt.close(roc_fig)
