<a href="https://colab.research.google.com/github/appababba/USDA/blob/main/active_learning_HDBSCAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q segmentation-models-pytorch==0.3.3 albumentations==1.4.7 hdbscan --no-deps

import os
import random
import shutil
import math
import glob
import gc
import pickle
import json
import copy
import sys
import numpy as np
import pandas as pd
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from contextlib import nullcontext
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from google.colab import drive
from tqdm.notebook import tqdm
import hdbscan

In [None]:
# --- Environment and Reproducibility ---
# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Let cuDNN auto-tune convolution algorithms and opt in to TF32 where the
# installed PyTorch build exposes the corresponding switches.
torch.backends.cudnn.benchmark = True
if hasattr(torch.backends, "cuda") and hasattr(torch.backends.cuda, "matmul"):
    torch.backends.cuda.matmul.allow_tf32 = True
if hasattr(torch.backends, "cudnn"):
    torch.backends.cudnn.allow_tf32 = True

# Seed the Python, NumPy and PyTorch random number generators with one value.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

# --- Automatic Mixed Precision (AMP) Setup ---
# The checks are chained with `and` so the bf16 query only runs when CUDA is
# available and the AMP module exists.
use_amp = torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and torch.cuda.is_bf16_supported()
if use_amp:
    amp_dtype = torch.bfloat16
else:
    amp_dtype = torch.float32

# --- Google Drive Integration ---
# Best-effort mount: a failure is reported but does not abort the cell.
try:
    drive.mount('/content/drive', force_remount=False)
except Exception as mount_error:
    print(f"Error mounting drive: {mount_error}")

In [None]:
# --- Path Definitions ---
# Locations on the mounted Google Drive shared drive: the exported images and
# masks (the sources that get mirrored locally) and the saved model weights.
BASE_DRIVE = "/content/drive/Shared drives/USDA-Summer2025"
DATA_DIR   = os.path.join(BASE_DRIVE, "data")
IMG_DRIVE  = os.path.join(DATA_DIR, "Exported_Images")
MSK_DRIVE  = os.path.join(DATA_DIR, "Exported_Masks")
MODELS_DIR = os.path.join(BASE_DRIVE, "models")
# Local mirror of the image and mask folders on the runtime's own disk.
LOCAL_ROOT = "/content/local_data"
IMG_LOCAL  = os.path.join(LOCAL_ROOT, "Exported_Images")
MSK_LOCAL  = os.path.join(LOCAL_ROOT, "Exported_Masks")
# Only the mirror root is created here; the two sub-folders are not.
os.makedirs(LOCAL_ROOT, exist_ok=True)

# --- Local Data Mirroring ---
def ensure_local_data(source_dir, destination_dir):
    """Mirror ``source_dir`` (e.g. a G-Drive folder) into a local directory.

    The copy is skipped when ``destination_dir`` already exists and contains
    at least one entry. Only "non-empty" is checked, not completeness, so a
    previously interrupted copy is treated as finished.

    Args:
        source_dir: Directory to copy from.
        destination_dir: Directory to copy into; created if missing.

    Raises:
        OSError: Propagated from ``shutil.copytree`` (e.g. missing source).
    """
    if os.path.isdir(destination_dir):
        # Use scandir as a context manager so the directory handle is closed.
        with os.scandir(destination_dir) as entries:
            has_content = any(entries)
        if has_content:
            print(f"Local data found at: {destination_dir}")
            return
    print(f"Copying data from {source_dir} to {destination_dir}...")
    # dirs_exist_ok lets the copy proceed when the destination exists but is
    # empty; without it copytree raises FileExistsError in that case.
    shutil.copytree(source_dir, destination_dir, dirs_exist_ok=True)
    print("Data copy complete.")

# Mirror the images first, then the masks, from Drive to the local disk.
for _drive_dir, _local_dir in ((IMG_DRIVE, IMG_LOCAL), (MSK_DRIVE, MSK_LOCAL)):
    ensure_local_data(_drive_dir, _local_dir)

In [None]:
# --- PyTorch Dataset Definitions ---
class SegDataset(Dataset):
    """Image/mask pairs for binary segmentation.

    Each item is ``(image, mask, img_path)``. The mask is binarised
    (any non-zero pixel becomes 1) and always returned as a float tensor
    with a leading channel dimension when it is 2-D.

    Args:
        image_paths: Sequence of image file paths.
        mask_dir: Directory searched for the matching mask files.
        transform: Optional callable invoked as
            ``transform(image=..., mask=...)`` that returns a mapping with
            ``"image"`` and ``"mask"`` keys.
    """

    # Extensions tried, in order, when "<stem>_mask.png" does not exist.
    _ALT_MASK_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tif', '.tiff')

    def __init__(self, image_paths, mask_dir, transform=None):
        self.image_paths = image_paths
        self.mask_dir = mask_dir
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def _get_mask_path(self, img_path):
        """Return the mask path for ``img_path``.

        Looks for ``<stem>_mask.png`` first, then ``<stem><ext>`` for each
        alternative extension.

        Raises:
            FileNotFoundError: If no candidate file exists in ``mask_dir``.
        """
        base_name = os.path.splitext(os.path.basename(img_path))[0]
        # Check for primary mask naming convention
        primary_mask_path = os.path.join(self.mask_dir, f"{base_name}_mask.png")
        if os.path.exists(primary_mask_path):
            return primary_mask_path
        # Check for alternative extensions
        for ext in self._ALT_MASK_EXTENSIONS:
            alt_mask_path = os.path.join(self.mask_dir, base_name + ext)
            if os.path.exists(alt_mask_path):
                return alt_mask_path
        raise FileNotFoundError(f"No corresponding mask found for image: {img_path}")

    @staticmethod
    def _as_mask_tensor(mask):
        """Convert a NumPy or tensor mask to a float tensor; 2-D gains a channel dim."""
        if not isinstance(mask, torch.Tensor):
            # Without a tensor-producing transform the mask is still a NumPy
            # array, which has no ``.float()`` method.
            mask = torch.as_tensor(mask)
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)
        return mask.float()

    def __getitem__(self, idx):
        """Load, binarise and (optionally) transform one image/mask pair.

        Raises:
            FileNotFoundError: If no mask file matches the image.
            OSError: If the image or mask file cannot be decoded.
        """
        img_path = self.image_paths[idx]
        mask_path = self._get_mask_path(img_path)

        # cv2.imread signals failure by returning None rather than raising.
        bgr = cv2.imread(img_path)
        if bgr is None:
            raise OSError(f"Could not read image: {img_path}")
        raw_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if raw_mask is None:
            raise OSError(f"Could not read mask: {mask_path}")

        image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        mask = (raw_mask > 0).astype(np.uint8)

        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image, mask = transformed["image"], transformed["mask"]

        return image, self._as_mask_tensor(mask), img_path

class InferenceDataset(Dataset):
    """Image-only dataset that yields ``(image, img_path)`` pairs.

    Args:
        image_paths: Sequence of image file paths.
        transform: Optional callable invoked as ``transform(image=...)`` that
            returns a mapping with an ``"image"`` key.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load one image as RGB and apply the optional transform.

        Raises:
            OSError: If the file cannot be decoded as an image.
        """
        img_path = self.image_paths[idx]
        # cv2.imread returns None for unreadable files; fail with the path
        # instead of an opaque error from cvtColor.
        bgr = cv2.imread(img_path)
        if bgr is None:
            raise OSError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        if self.transform:
            image = self.transform(image=image)["image"]
        return image, img_path

# --- Model Architecture ---
class GaborStem(nn.Module):
    """Parameter-free Gabor filter bank applied to a luminance channel.

    One kernel is generated per (wavelength, sigma, orientation, phase)
    combination from ``cfg``. When ``cfg["magnitude"]`` is truthy and exactly
    two phases are given, consecutive kernel pairs are merged into a single
    magnitude response, so the output has one channel per
    (wavelength, sigma, orientation). The result is instance-normalised.

    Expected ``cfg`` keys: ``kernel_size``, ``gamma``, ``magnitude``,
    ``wavelengths``, ``sigmas``, ``orientations``, ``phases``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.k = cfg["kernel_size"]
        self.gamma = cfg["gamma"]
        self.magnitude = cfg["magnitude"]

        # Integer coordinate grid centred on zero.
        half = self.k // 2
        axis = torch.arange(-half, half + 1, dtype=torch.float32)
        grid_x, grid_y = torch.meshgrid(axis, axis, indexing='xy')
        self.register_buffer('X', grid_x)
        self.register_buffer('Y', grid_y)

        # Phase is the innermost loop, so the kernels belonging to one
        # (wavelength, sigma, orientation) sit next to each other.
        thetas = [t.item() for t in torch.linspace(0, math.pi, steps=cfg["orientations"])]
        params = [
            (theta, sigma, wavelength, phase)
            for wavelength in cfg["wavelengths"]
            for sigma in cfg["sigmas"]
            for theta in thetas
            for phase in cfg["phases"]
        ]
        self.register_buffer('params_buf', torch.tensor(params, dtype=torch.float32))
        self.register_buffer('phases_buf', torch.tensor(cfg["phases"], dtype=torch.float32))

        if self.magnitude and len(cfg["phases"]) == 2:
            self.out_channels = len(cfg["wavelengths"]) * len(cfg["sigmas"]) * cfg["orientations"]
        else:
            self.out_channels = len(params)
        self.norm = nn.InstanceNorm2d(self.out_channels, affine=False)

    def _get_kernels(self, device, dtype):
        """Build the kernel stack: one zero-mean, L2-normalised kernel per parameter row."""
        table = self.params_buf.to(device, dtype)
        grid_x = self.X.to(device, dtype)
        grid_y = self.Y.to(device, dtype)
        gamma = torch.as_tensor(self.gamma, dtype=dtype, device=device)
        # Columns of the table are (theta, sigma, wavelength, phase).
        theta, sigma, lambd, phase = (column.view(-1, 1, 1) for column in table.T)

        cos_t = torch.cos(theta)
        sin_t = torch.sin(theta)
        x_rot = grid_x * cos_t + grid_y * sin_t
        y_rot = -grid_x * sin_t + grid_y * cos_t

        envelope = torch.exp(-(x_rot**2 + (gamma * y_rot)**2) / (2 * sigma**2))
        carrier = torch.cos(2 * math.pi * x_rot / lambd + phase)

        kernels = envelope * carrier
        kernels = kernels - kernels.mean(dim=[1, 2], keepdim=True)
        kernels = kernels / (kernels.norm(p=2, dim=[1, 2], keepdim=True) + 1e-8)
        return kernels

    def forward(self, x):
        """Filter the luminance of ``x`` (first three channels weighted as R, G, B)."""
        luminance = 0.299 * x[:, 0:1] + 0.587 * x[:, 1:2] + 0.114 * x[:, 2:3]
        kernels = self._get_kernels(x.device, x.dtype)
        pad = self.k // 2

        if not (self.magnitude and self.phases_buf.numel() == 2):
            return self.norm(F.conv2d(luminance, kernels.unsqueeze(1), padding=pad))

        # Even/odd rows hold the first/second phase of each filter pair.
        first_phase = F.conv2d(luminance, kernels[0::2].unsqueeze(1), padding=pad)
        second_phase = F.conv2d(luminance, kernels[1::2].unsqueeze(1), padding=pad)
        return self.norm(torch.sqrt(first_phase**2 + second_phase**2 + 1e-8))

class UNetWithGabor_Adapted(nn.Module):
    """ResNet-50 U-Net fed through an optional Gabor stem and a 1x1 adapter.

    The backbone always takes 3 input channels. Depending on ``gcfg`` the
    tensor handed to the adapter is the raw image (Gabor disabled), the image
    concatenated with the Gabor responses (``mode == "concat"``), or the
    Gabor responses alone (any other mode). The adapter is an identity when
    that tensor already has 3 channels, otherwise a small 1x1-conv stack that
    projects it to 3 channels.

    Args:
        gcfg: Gabor configuration dict; ``"enabled"`` and ``"mode"`` are read
            here, the rest is forwarded to ``GaborStem``.
        in_img_ch: Number of channels of the input image.
    """

    def __init__(self, gcfg, in_img_ch=3):
        super().__init__()
        self.gabor = GaborStem(gcfg) if gcfg.get("enabled", False) else None
        self.mode = gcfg.get("mode", "concat")

        # Channel count of the tensor that actually reaches the adapter.
        if self.gabor is None:
            adapter_in_ch = in_img_ch
        elif self.mode == "concat":
            adapter_in_ch = in_img_ch + self.gabor.out_channels
        else:
            adapter_in_ch = self.gabor.out_channels

        if adapter_in_ch == 3:
            # Already matches the backbone's in_channels.
            self.adapter = nn.Identity()
        else:
            # Any other width (including 1 or 2) must be projected to 3;
            # passing it through unchanged would break the first backbone conv.
            hidden_ch = max(16, adapter_in_ch // 4)
            self.adapter = nn.Sequential(
                nn.Conv2d(adapter_in_ch, hidden_ch, kernel_size=1, bias=False),
                nn.BatchNorm2d(hidden_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_ch, 3, kernel_size=1, bias=False)
            )
        self.net = smp.Unet('resnet50', encoder_weights='imagenet', in_channels=3, classes=1)

    def _prepare_input(self, x):
        """Apply the Gabor stem (if enabled) and the adapter to ``x``."""
        if self.gabor is None:
            return self.adapter(x)
        g_feats = self.gabor(x)
        if self.mode == "concat":
            x = torch.cat([x, g_feats], dim=1)
        else:
            x = g_feats
        return self.adapter(x)

    def forward(self, x):
        """Return the single-class segmentation output of the U-Net."""
        return self.net(self._prepare_input(x))

    def extract_features(self, x):
        """Return the output of the last encoder stage for the prepared input."""
        x = self._prepare_input(x)
        return self.net.encoder(x)[-1]

In [None]:
# --- Configuration and Utilities ---
# Gabor stem settings used when building the model in this notebook.
GABOR_CFG_INFERENCE = dict(
    enabled=True,
    mode="concat",
    kernel_size=15,
    orientations=8,
    wavelengths=[4.0, 8.0, 12.0],
    sigmas=[1.5, 3.0, 4.5],
    phases=[0, np.pi/2],
    gamma=0.5,
    magnitude=True,
)

# Deterministic preprocessing: resize, normalise, convert to a tensor.
_val_steps = [
    A.Resize(256, 256),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
]
val_transform = A.Compose(_val_steps)

# --- Model Initialization and Weight Loading ---
print("Initializing model...")
model = UNetWithGabor_Adapted(gcfg=GABOR_CFG_INFERENCE).to(DEVICE, memory_format=torch.channels_last)

model_path = os.path.join(MODELS_DIR, "UNetRes50_GaborAdapter_imagenet.pth")
print(f"Loading weights from: {model_path}")
try:
    ckpt = torch.load(model_path, map_location=DEVICE)
    # Accept either a wrapped checkpoint ({"state_dict": ...}) or a bare state dict.
    state_dict = ckpt.get("state_dict", ckpt)
    model.load_state_dict(state_dict)
except Exception as e:
    # Best-effort: keep going with the freshly constructed model.
    print(f"Error loading weights: {e}")
else:
    print("Model weights loaded successfully.")

# Switch to eval mode whether or not the weights loaded, so the feature
# extraction below never runs with BatchNorm in training mode.
model.eval()

# --- Model Compilation (Best-Effort) ---
# NOTE(review): later code calls model.extract_features rather than model(...);
# presumably that path is not covered by the compiled forward - confirm.
try:
    print("Compiling model for performance...")
    model = torch.compile(model, mode="max-autotune")
    print("Model compiled successfully.")
except Exception as e:
    print(f"torch.compile failed or is not supported: {e}")


# --- Feature Extraction ---
# os.listdir returns entries in arbitrary order; sort so that the sample order
# (and therefore tie-breaking in later rankings) is the same on every run.
all_image_paths = sorted(os.path.join(IMG_LOCAL, f) for f in os.listdir(IMG_LOCAL) if not f.startswith('.'))
print(f"\nFound {len(all_image_paths)} total images for feature extraction.")

inference_dataset = InferenceDataset(all_image_paths, transform=val_transform)
inference_loader = DataLoader(inference_dataset, batch_size=32, shuffle=False, num_workers=0, pin_memory=(DEVICE.type == "cuda"))

# One L2-normalised, globally average-pooled encoder vector per image;
# all_paths is kept in the same order as the rows of feature_vectors.
all_features, all_paths = [], []
with torch.no_grad():
    for images, paths in tqdm(inference_loader, desc="Extracting Features"):
        images = images.to(DEVICE, non_blocking=True)
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=use_amp):
            features = model.extract_features(images)
            pooled_features = F.adaptive_avg_pool2d(features, (1, 1)).flatten(1)
            normalized_features = F.normalize(pooled_features, p=2, dim=1)
        # Cast to float32 before leaving torch: NumPy has no bfloat16 dtype.
        all_features.append(normalized_features.to(torch.float32).cpu().numpy())
        all_paths.extend(paths)

feature_vectors = np.vstack(all_features)
print("Feature extraction complete.")
print(f"Feature vector shape: {feature_vectors.shape}, Total paths: {len(all_paths)}")

In [None]:
# --- Clustering with HDBSCAN ---
print("\nPerforming HDBSCAN clustering on feature vectors...")
clusterer = hdbscan.HDBSCAN(min_cluster_size=15, prediction_data=True, core_dist_n_jobs=-1)
labels = clusterer.fit_predict(feature_vectors)

# Label -1 marks noise; it is not counted as a cluster.
n_noise = np.sum(labels == -1)
n_clusters = len(set(labels)) - (1 if n_noise > 0 else 0)
print(f"Clustering complete. Found {n_clusters} clusters and {n_noise} noise points (outliers).")

# --- Identify and Rank Informative Samples ---
# Pair each outlier score with its image path and order by descending score.
outlier_scores = clusterer.outlier_scores_
scored_samples = list(zip(outlier_scores, all_paths))
scored_samples.sort(key=lambda pair: pair[0], reverse=True)

def select_next_batch(scored_samples, already_labeled_paths, batch_size=20):
    """Pick the first ``batch_size`` unlabeled paths from a ranked list.

    Args:
        scored_samples: Iterable of ``(score, path)`` pairs, already ordered
            from most to least preferred.
        already_labeled_paths: Paths to skip.
        batch_size: Maximum number of paths to return.

    Returns:
        List of at most ``batch_size`` paths in ranking order; empty when
        ``batch_size`` is zero or negative.
    """
    # Without this guard the loop would append one item before checking the limit.
    if batch_size <= 0:
        return []

    labeled_set = set(already_labeled_paths)
    next_batch = []
    for _, path in scored_samples:
        if path in labeled_set:
            continue
        next_batch.append(path)
        if len(next_batch) >= batch_size:
            break
    return next_batch

# --- Select Initial Batch for Labeling ---
# Nothing is labeled yet, so the batch is simply the top of the ranking.
BATCH_SIZE = 20
labeled_image_paths = []
first_batch_to_label = select_next_batch(
    scored_samples,
    labeled_image_paths,
    batch_size=BATCH_SIZE,
)
print(f"\nSelected the top {len(first_batch_to_label)} most informative samples to label next:")
# Preview only the first five file names.
for rank, sample_path in enumerate(first_batch_to_label[:5], 1):
    print(f"  {rank}. {os.path.basename(sample_path)}")

In [None]:
print("\n--- Initializing Active Learning Loop ---")
CKPT_DIR = '/content/drive/MyDrive/active_learning_ckpts'
os.makedirs(CKPT_DIR, exist_ok=True)
NUM_WORKERS = 0
PIN_MEMORY = (DEVICE.type == "cuda")

# All non-hidden files of the local image mirror. Sorted because os.listdir
# order is arbitrary and the seeded split below depends on the input order.
FULL_DATASET_PATHS = sorted(
    os.path.join(IMG_LOCAL, f) for f in os.listdir(IMG_LOCAL) if not f.startswith('.')
)

# --- Resume-Safe Checkpoint Loading ---
# The most recently modified checkpoint wins.
ckpt_files = sorted(glob.glob(f"{CKPT_DIR}/al_ckpt_*.pt"), key=os.path.getmtime, reverse=True)
latest_ckpt_path = ckpt_files[0] if ckpt_files else None

current_model = UNetWithGabor_Adapted(gcfg=GABOR_CFG_INFERENCE).to(DEVICE, memory_format=torch.channels_last)
optimizer = optim.Adam(current_model.parameters(), lr=1e-4)

if latest_ckpt_path:
    print(f"Resuming from checkpoint: {os.path.basename(latest_ckpt_path)}")
    ckpt = torch.load(latest_ckpt_path, map_location=DEVICE)
    current_model.load_state_dict(ckpt['model_state'])
    optimizer.load_state_dict(ckpt['optimizer_state'])
    al_labeled_pool = ckpt['al_labeled_pool']
    al_unlabeled_pool = ckpt['al_unlabeled_pool']
    active_learning_history = ckpt['active_learning_history']
    start_iter = ckpt['iter_idx'] + 1

    # Restore the held-out set. A checkpoint without 'test_paths' still
    # determines it: the two pools together are the train/val side of the
    # original split, so every remaining image belongs to the test set.
    # Drawing a fresh random split here could place training images in it.
    test_paths = ckpt.get('test_paths')
    if test_paths is None:
        pooled_paths = set(al_labeled_pool) | set(al_unlabeled_pool)
        test_paths = [p for p in FULL_DATASET_PATHS if p not in pooled_paths]
else:
    print("No existing checkpoint found. Creating initial stratified data pools...")
    start_iter = 0
    all_ratios = []
    missing_masks = 0
    # Empty dataset used only for its mask-path lookup.
    helper_dataset = SegDataset([], MSK_LOCAL)
    for path in tqdm(FULL_DATASET_PATHS, desc="Stratifying Data by Mask Ratio"):
        try:
            mask_path = helper_dataset._get_mask_path(path)
        except FileNotFoundError:
            # Kept in the dataset with ratio 0.0, but reported below.
            missing_masks += 1
            all_ratios.append(0.0)
            continue
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        # Fraction of non-zero mask pixels; 0.0 when the mask cannot be read.
        ratio = np.mean(mask > 0) if mask is not None else 0.0
        all_ratios.append(ratio)
    if missing_masks:
        print(f"Warning: {missing_masks} image(s) have no mask file; treated as mask ratio 0.0.")

    # Stratify on up to five quantile bins of the mask ratio.
    df = pd.DataFrame({'path': FULL_DATASET_PATHS, 'ratio': all_ratios})
    df['ratio_bin'] = pd.qcut(df['ratio'], q=5, labels=False, duplicates='drop')

    train_val_paths, test_paths, _, _ = train_test_split(
        df['path'].tolist(), df['ratio_bin'].tolist(),
        test_size=0.2, random_state=SEED, stratify=df['ratio_bin'].tolist()
    )
    # The first 20 shuffled train/val images seed the labeled pool.
    random.shuffle(train_val_paths)
    al_labeled_pool = train_val_paths[:20]
    al_unlabeled_pool = train_val_paths[20:]
    active_learning_history = []

# --- DataLoaders and Training/Validation Functions ---
# Training preprocessing: resize, two random augmentations, normalise, to tensor.
_train_steps = [
    A.Resize(256, 256),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
]
train_transform = A.Compose(_train_steps)
def make_loader(paths, transform, batch_size, shuffle=False):
    """Wrap ``paths`` in a ``SegDataset`` (masks from ``MSK_LOCAL``) and return its DataLoader."""
    loader_options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
    )
    dataset = SegDataset(paths, MSK_LOCAL, transform=transform)
    return DataLoader(dataset, **loader_options)

def train_one_iteration(model, loader, epochs=5):
    """Train ``model`` for ``epochs`` passes over ``loader`` with BCE-with-logits loss.

    Uses the module-level ``optimizer``, ``DEVICE``, ``amp_dtype`` and
    ``use_amp``. The model is left in training mode.
    """
    model.train()
    for _epoch in range(epochs):
        for batch_images, batch_masks, _paths in loader:
            batch_images = batch_images.to(DEVICE, non_blocking=True)
            batch_masks = batch_masks.to(DEVICE, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)
            with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=use_amp):
                logits = model(batch_images)
                # Functional form of nn.BCEWithLogitsLoss with default settings.
                loss = F.binary_cross_entropy_with_logits(logits, batch_masks)
            loss.backward()
            optimizer.step()

@torch.no_grad()
def validate_performance(model, loader, max_batches=10):
    """Return the mean per-batch soft IoU over at most ``max_batches`` batches.

    The IoU is computed on sigmoid probabilities (no thresholding), once per
    batch over all of its pixels. Returns 0.0 when the loader's dataset is
    empty. The model is switched to eval mode.
    """
    if not loader.dataset:
        return 0.0

    model.eval()
    iou_sum = 0.0
    batches_seen = 0
    for batch_idx, (images, masks, _) in enumerate(loader):
        if batch_idx >= max_batches:
            break
        images = images.to(DEVICE, non_blocking=True)
        masks = masks.to(DEVICE, non_blocking=True)
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=use_amp):
            probs = torch.sigmoid(model(images))
        intersection = (probs * masks).sum()
        union = (probs + masks).sum() - intersection
        # The small epsilon keeps the ratio defined when the union is zero.
        iou_sum += (intersection / (union + 1e-6)).item()
        batches_seen += 1
    return iou_sum / max(1, batches_seen)

# --- Fast Batch Selection with Caching ---
# Path -> embedding vector. Entries are never invalidated.
# NOTE(review): current_model is retrained between selection rounds, so cached
# vectors may come from older weights than newly computed ones - confirm this
# speed/accuracy trade-off is intended.
EMB_CACHE = {}
# PCA projection used by the batch selector; fitted on first use.
PCA_MODEL = None
@torch.no_grad()
def get_embeddings_for_paths(paths, batch_size=64):
    """Return one L2-normalised encoder embedding per path, in input order.

    Embeddings missing from ``EMB_CACHE`` are computed with ``current_model``
    (switched to eval mode) and stored in the cache.

    Args:
        paths: Image paths; duplicates are allowed.
        batch_size: Batch size for the embedding forward passes.

    Returns:
        Array with one row per entry of ``paths``; an empty ``(0, 0)`` float32
        array when ``paths`` is empty.
    """
    # dict.fromkeys drops duplicates while keeping first-seen order.
    to_compute = list(dict.fromkeys(p for p in paths if p not in EMB_CACHE))
    if to_compute:
        ds = InferenceDataset(to_compute, transform=val_transform)
        loader = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0)
        current_model.eval()
        for images, computed_paths in loader:
            images = images.to(DEVICE)
            with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=use_amp):
                feats = F.normalize(F.adaptive_avg_pool2d(current_model.extract_features(images), (1, 1)).flatten(1))
            # Cast before converting: NumPy cannot represent bfloat16 tensors.
            feats = feats.to(torch.float32).cpu().numpy()
            for p, f in zip(computed_paths, feats):
                EMB_CACHE[p] = f
    if not paths:
        # np.stack rejects an empty list.
        return np.empty((0, 0), dtype=np.float32)
    return np.stack([EMB_CACHE[p] for p in paths])

def select_next_batch_fast(unlabeled_paths, batch_size=20, candidate_k=500):
    """Pick up to ``batch_size`` paths with the highest HDBSCAN outlier scores.

    A random subset of at most ``candidate_k`` unlabeled paths is embedded,
    projected with PCA, and scored by HDBSCAN; the highest-scoring candidates
    are returned. The PCA model is fitted on the first call that needs it and
    reused afterwards through the module-level ``PCA_MODEL``.

    Args:
        unlabeled_paths: Pool to draw candidates from.
        batch_size: Maximum number of paths to return.
        candidate_k: Maximum number of candidates to score.

    Returns:
        List of selected paths; empty when the pool is empty or
        ``batch_size`` is not positive.
    """
    global PCA_MODEL
    if not unlabeled_paths: return []
    if batch_size <= 0: return []
    candidates = random.sample(unlabeled_paths, k=min(candidate_k, len(unlabeled_paths)))

    # Every candidate would be selected anyway, so skip the ranking. This also
    # keeps tiny pools away from PCA/HDBSCAN.
    if len(candidates) <= batch_size:
        return candidates

    feats = get_embeddings_for_paths(candidates)

    if PCA_MODEL is None:
        # n_components cannot exceed the number of samples or features.
        n_components = min(64, feats.shape[0], feats.shape[1])
        PCA_MODEL = PCA(n_components=n_components, random_state=SEED).fit(feats)
    reduced_feats = PCA_MODEL.transform(feats)

    scores = hdbscan.HDBSCAN(min_cluster_size=15).fit(reduced_feats).outlier_scores_
    ranked = [c for _, c in sorted(zip(scores, candidates), key=lambda x: x[0], reverse=True)]
    return ranked[:batch_size]

# --- Main Loop Execution ---
NUM_ITERATIONS = 15
EPOCHS_PER_ITER = 5
print(f"Starting active learning from iteration: {start_iter + 1}")

# Each iteration: train on the labeled pool, evaluate on the held-out set,
# move a newly selected batch from the unlabeled to the labeled pool, checkpoint.
for i in range(start_iter, NUM_ITERATIONS):
    num_labeled = len(al_labeled_pool)
    print(f"\n--- Iteration {i+1}/{NUM_ITERATIONS} | Labeled Samples: {num_labeled} ---")

    train_loader = make_loader(al_labeled_pool, train_transform, batch_size=32, shuffle=True)
    train_one_iteration(current_model, train_loader, epochs=EPOCHS_PER_ITER)

    test_loader = make_loader(test_paths, val_transform, batch_size=64)
    test_iou = validate_performance(current_model, test_loader)
    active_learning_history.append({'num_labeled': num_labeled, 'iou': test_iou})
    print(f"Validation IoU: {test_iou:.4f}")

    # Note: leaving here skips the checkpoint for this final iteration.
    if not al_unlabeled_pool:
        print("Unlabeled pool is empty. Stopping.")
        break

    print("Selecting next batch of samples...")
    new_batch = select_next_batch_fast(al_unlabeled_pool, batch_size=BATCH_SIZE)

    al_labeled_pool.extend(new_batch)
    # Set membership keeps the pool filter linear in the pool size.
    selected = set(new_batch)
    al_unlabeled_pool = [p for p in al_unlabeled_pool if p not in selected]

    # 'test_paths' is stored so that a resumed run can evaluate on exactly the
    # same held-out images, in the same order.
    torch.save({
        'iter_idx': i, 'model_state': current_model.state_dict(),
        'optimizer_state': optimizer.state_dict(), 'al_labeled_pool': al_labeled_pool,
        'al_unlabeled_pool': al_unlabeled_pool, 'active_learning_history': active_learning_history,
        'test_paths': list(test_paths)
    }, f"{CKPT_DIR}/al_ckpt_{i}.pt")

    # Drop the loaders and release cached GPU memory before the next round.
    del train_loader, test_loader
    gc.collect()
    torch.cuda.empty_cache()

print(f"\nActive learning complete. Checkpoints saved in: {CKPT_DIR}")