<a href="https://colab.research.google.com/github/appababba/USDA/blob/main/rand_sampling_baseline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# =============================================
# SETUP: DEPENDENCIES AND ENVIRONMENT
# =============================================
# --- Package Installation ---
!pip install -q segmentation-models-pytorch==0.3.3 albumentations==1.4.7 hdbscan --no-deps

# --- Core Imports ---
import os, random, shutil, math, glob, gc, pickle, json, warnings
import numpy as np
import pandas as pd
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from contextlib import nullcontext
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from google.colab import drive
from tqdm.notebook import tqdm
import hdbscan

# --- Environment Configuration ---
# Use the GPU when PyTorch can see one; the rest of the notebook keys off DEVICE.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# DataLoaders only request pinned host memory when a GPU is in use.
PIN_MEMORY = (DEVICE.type == "cuda")
# Let cuDNN benchmark convolution algorithms for the shapes it encounters.
# NOTE(review): benchmark mode can pick different algorithms between runs, so the
# seeding below does not by itself make GPU results bit-reproducible — confirm
# that trade-off is intended.
torch.backends.cudnn.benchmark = True
if torch.cuda.is_available():
    # Permit TF32 for matmuls and cuDNN convolutions, and relax the float32
    # matmul precision setting to 'high'.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision('high')

# --- Reproducibility ---
# Seed the three RNGs used in this notebook: Python's `random`, NumPy, and PyTorch.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- AMP Configuration ---
# Autocast is enabled only when CUDA is available AND reports bf16 support.
# Otherwise use_amp is False and amp_dtype is float32, i.e. there is no fp16
# fallback path.
use_amp = torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and torch.cuda.is_bf16_supported()
amp_dtype = torch.bfloat16 if use_amp else torch.float32

# --- Suppress specific warnings ---
# Silence any warning whose message matches "force_all_finite".
# NOTE(review): presumably a scikit-learn deprecation surfaced through hdbscan — confirm.
warnings.filterwarnings("ignore", ".*force_all_finite.*")

print("Setup complete. Environment configured.")

In [None]:
# =============================================
# DATA MANAGEMENT: PATHS AND LOCAL MIRRORING
# =============================================
# --- Google Drive Mount ---
# Best-effort mount: a failure is printed, not raised, so any later cell that
# touches a /content/drive path will be the one that actually errors out.
try:
    drive.mount('/content/drive', force_remount=False)
except Exception as e:
    print(f"Error mounting Google Drive: {e}")

# --- Path Definitions ---
# Source data lives on the shared drive; *_DRIVE are the originals and *_LOCAL
# are the copies on the Colab VM's own disk.
BASE_DRIVE = "/content/drive/Shared drives/USDA-Summer2025"
DATA_DIR = os.path.join(BASE_DRIVE, "data")
IMG_DRIVE = os.path.join(DATA_DIR, "Exported_Images")
MSK_DRIVE = os.path.join(DATA_DIR, "Exported_Masks")
# NOTE(review): MODELS_DIR is not referenced by any cell visible in this notebook.
MODELS_DIR = os.path.join(BASE_DRIVE, "models")
LOCAL_ROOT = "/content/local_data"
IMG_LOCAL = os.path.join(LOCAL_ROOT, "Exported_Images")
MSK_LOCAL = os.path.join(LOCAL_ROOT, "Exported_Masks")
# Checkpoints go to the user's personal MyDrive, not the shared drive above.
CKPT_DIR = '/content/drive/MyDrive/active_learning_ckpts'
# Create the local data root and the checkpoint directory if they are missing.
os.makedirs(LOCAL_ROOT, exist_ok=True)
os.makedirs(CKPT_DIR, exist_ok=True)

# --- Data Mirroring Function ---
def ensure_local_data(source_dir, dest_dir):
    """Mirror ``source_dir`` into ``dest_dir`` unless a non-empty copy already exists.

    Args:
        source_dir: Directory to copy from (e.g. a mounted Google Drive folder).
        dest_dir: Directory to copy to (e.g. local Colab storage).

    Behaviour:
        * ``dest_dir`` exists and contains at least one entry -> nothing is
          copied; the existing contents are used as-is (no freshness check).
        * ``dest_dir`` is missing or empty -> the whole tree is copied.

    Raises:
        Whatever ``shutil.copytree`` raises, e.g. if ``source_dir`` is missing.
    """
    has_local_copy = False
    if os.path.isdir(dest_dir):
        # Use scandir as a context manager so the directory handle is released
        # immediately instead of whenever the iterator is garbage-collected.
        with os.scandir(dest_dir) as entries:
            has_local_copy = any(entries)

    if has_local_copy:
        print(f"Using existing local data at: {dest_dir}")
        return

    print(f"Copying data from {source_dir} to {dest_dir}...")
    # dirs_exist_ok=True: an existing-but-empty destination (e.g. left behind by
    # an interrupted run) must not make copytree fail with FileExistsError.
    shutil.copytree(source_dir, dest_dir, dirs_exist_ok=True)
    print("Data copy complete.")

# Mirror images and masks from Drive to the VM's local disk; the dataset code
# below reads exclusively from IMG_LOCAL / MSK_LOCAL.
ensure_local_data(IMG_DRIVE, IMG_LOCAL)
ensure_local_data(MSK_DRIVE, MSK_LOCAL)

In [None]:
# =============================================
# CORE COMPONENTS: DATASET AND MODEL CLASSES
# =============================================
# --- Dataset Classes ---
class SegDataset(Dataset):
    """Image / binary-mask pairs for segmentation training and evaluation.

    Args:
        image_paths: Sequence of image file paths.
        mask_dir: Directory holding the masks. For an image ``foo.ext`` the mask
            is looked up as ``foo_mask.png`` first, then ``foo`` with one of
            ``.png/.jpg/.jpeg/.tif/.tiff``.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            returning a mapping with ``"image"`` and ``"mask"`` keys (the
            albumentations calling convention).

    Each item is ``(image, mask, image_path)`` where ``mask`` is a float tensor
    with a leading channel dimension and values in {0, 1}.
    """

    def __init__(self, image_paths, mask_dir, transform=None):
        self.image_paths, self.mask_dir, self.transform = image_paths, mask_dir, transform

    def __len__(self):
        return len(self.image_paths)

    def _get_mask_path(self, img_path):
        """Return the mask path for ``img_path``; raise ``FileNotFoundError`` if none exists."""
        base = os.path.splitext(os.path.basename(img_path))[0]
        p1 = os.path.join(self.mask_dir, f"{base}_mask.png")
        if os.path.exists(p1):
            return p1
        for ext in ('.png', '.jpg', '.jpeg', '.tif', '.tiff'):
            p2 = os.path.join(self.mask_dir, base + ext)
            if os.path.exists(p2):
                return p2
        raise FileNotFoundError(f"Mask not found for image: {img_path}")

    def __getitem__(self, idx):
        ip = self.image_paths[idx]
        mp = self._get_mask_path(ip)

        # cv2.imread signals failure by returning None rather than raising, so
        # check explicitly and report which file was the problem.
        bgr = cv2.imread(ip)
        if bgr is None:
            raise OSError(f"Failed to read image: {ip}")
        raw_mask = cv2.imread(mp, cv2.IMREAD_GRAYSCALE)
        if raw_mask is None:
            raise OSError(f"Failed to read mask: {mp}")

        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        # Any non-zero mask pixel counts as foreground.
        msk = (raw_mask > 0).astype(np.uint8)

        if self.transform:
            transformed = self.transform(image=img, mask=msk)
            img, msk = transformed["image"], transformed["mask"]

        # Without a tensor-producing transform the mask is still a NumPy array,
        # which has no .float(); convert so the return below always works.
        if not isinstance(msk, torch.Tensor):
            msk = torch.from_numpy(np.ascontiguousarray(msk))
        if msk.ndim == 2:
            msk = msk.unsqueeze(0)  # add the channel dimension: (H, W) -> (1, H, W)
        return img, msk.float(), ip

class InferenceDataset(Dataset):
    """Image-only dataset for inference (no masks).

    Args:
        image_paths: Sequence of image file paths.
        transform: Optional callable invoked as ``transform(image=...)`` and
            returning a mapping with an ``"image"`` key.

    Each item is ``(image, image_path)``.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths, self.transform = image_paths, transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # cv2.imread returns None on failure instead of raising; surface the
        # offending path rather than an opaque cvtColor error.
        bgr = cv2.imread(img_path)
        if bgr is None:
            raise OSError(f"Failed to read image: {img_path}")
        image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        if self.transform:
            image = self.transform(image=image)["image"]
        return image, img_path

# --- Model Architecture ---
class GaborStem(nn.Module):
    """Fixed (buffer-only, non-trainable) Gabor filter bank over a single-channel image.

    ``cfg`` keys read here: ``kernel_size``, ``gamma``, ``magnitude``,
    ``phases``, ``wavelengths``, ``sigmas``, ``orientations``.

    Filters are generated in the order: each (wavelength, sigma) pair, then
    each orientation, then each phase. Phases therefore vary fastest, which is
    what the even/odd slicing in ``forward`` depends on.

    When ``magnitude`` is truthy and exactly two phases are configured, each
    phase pair is collapsed to one magnitude channel, so ``out_channels`` is
    ``len(wavelengths) * orientations``; otherwise there is one channel per filter.
    """

    def __init__(self, cfg):
        super().__init__()
        self.k = cfg["kernel_size"]
        self.gamma = cfg["gamma"]
        self.magnitude = cfg["magnitude"]

        # Integer pixel offsets centred on the middle tap of the kernel.
        half = self.k // 2
        offsets = torch.arange(-half, half + 1, dtype=torch.float32)
        X, Y = torch.meshgrid(offsets, offsets, indexing='xy')

        phases = torch.tensor(cfg["phases"], dtype=torch.float32)
        lambdas = torch.tensor(cfg["wavelengths"], dtype=torch.float32)
        sigmas = torch.tensor(cfg["sigmas"], dtype=torch.float32)
        # NOTE(review): linspace includes both 0 and pi, which describe the same
        # Gabor orientation — confirm the duplicated orientation is intended.
        thetas = torch.linspace(0, math.pi, steps=cfg["orientations"], dtype=torch.float32)

        # One (theta, sigma, lambda, phase) row per filter; phase is innermost.
        rows = []
        for lam, sig in zip(lambdas, sigmas):
            for ang in thetas:
                for ph in phases:
                    rows.append((ang.item(), sig.item(), lam.item(), ph.item()))
        base = torch.tensor(rows)

        # Reshape each parameter column to (N, 1, 1) so it broadcasts over the grid.
        theta, sigma, lambd, phase = (col.view(-1, 1, 1) for col in base.T)

        # Rotate the coordinate grid by each filter's orientation.
        cos_t = torch.cos(theta)
        sin_t = torch.sin(theta)
        Xp = X * cos_t + Y * sin_t
        Yp = -X * sin_t + Y * cos_t

        envelope = torch.exp(-(Xp**2 + (torch.as_tensor(self.gamma) * Yp)**2) / (2 * sigma**2))
        wave = torch.cos(2 * math.pi * Xp / lambd + phase)
        bank = envelope * wave
        # Make every kernel zero-mean, then scale it to unit L2 norm.
        bank = bank - bank.mean(dim=(1, 2), keepdim=True)
        l2 = bank.square().sum(dim=(1, 2), keepdim=True).sqrt()
        bank = bank / (l2 + 1e-8)

        self.register_buffer('kernels', bank)
        self.register_buffer('phases_buf', phases)

        collapse_pairs = self.magnitude and len(phases) == 2
        num_per_phase = lambdas.numel() * cfg["orientations"]
        self.out_channels = num_per_phase if collapse_pairs else base.shape[0]
        self.norm = nn.InstanceNorm2d(self.out_channels, affine=False)

    def forward(self, x):
        """Filter a 3-channel batch ``x`` and return instance-normalised responses.

        The first three channels are combined with weights 0.299 / 0.587 / 0.114
        into a single channel before filtering.
        """
        ch0, ch1, ch2 = x[:, 0:1], x[:, 1:2], x[:, 2:3]
        gray = 0.299 * ch0 + 0.587 * ch1 + 0.114 * ch2
        bank = self.kernels.to(x.device, x.dtype)
        pad = self.k // 2

        if not (self.magnitude and self.phases_buf.numel() == 2):
            # One output channel per filter.
            return self.norm(F.conv2d(gray, bank.unsqueeze(1), padding=pad))

        # Even indices hold the first phase, odd indices the second; combine
        # each pair into a single magnitude response.
        first = F.conv2d(gray, bank[0::2].unsqueeze(1), padding=pad)
        second = F.conv2d(gray, bank[1::2].unsqueeze(1), padding=pad)
        return self.norm(torch.sqrt(first**2 + second**2 + 1e-8))

class UNetWithGabor_Adapted(nn.Module):
    """ResNet-50 U-Net with an optional Gabor front end and a channel adapter.

    ``gcfg["enabled"]`` switches the Gabor stem on; ``gcfg["mode"] == "concat"``
    appends the Gabor channels to the input image, any other mode feeds the
    Gabor channels alone. Whenever the resulting channel count is not 3, a
    1x1-conv adapter maps it back to 3 channels, so the U-Net is always built
    with ``in_channels=3`` and ImageNet encoder weights.
    """

    def __init__(self, gcfg, in_img_ch=3):
        super().__init__()
        self.use_gabor = gcfg["enabled"]
        self.mode = gcfg["mode"]
        self.gabor = GaborStem(gcfg) if self.use_gabor else None
        g_out = self.gabor.out_channels if self.gabor else 0

        # Number of channels that reach the adapter.
        if not self.use_gabor:
            in_ch = in_img_ch
        elif self.mode == "concat":
            in_ch = in_img_ch + g_out
        else:
            in_ch = g_out

        if in_ch != 3:
            # Bottleneck 1x1 convolutions squeezing in_ch channels down to 3.
            mid_ch = max(16, in_ch // 4)
            self.adapter = nn.Sequential(
                nn.Conv2d(in_ch, mid_ch, 1, bias=False),
                nn.BatchNorm2d(mid_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(mid_ch, 3, 1, bias=False)
            )
        else:
            self.adapter = nn.Identity()
        self.net = smp.Unet(encoder_name='resnet50', encoder_weights='imagenet', in_channels=3, classes=1)

    def _prep(self, x):
        """Apply the optional Gabor stem and the adapter; returns the U-Net input."""
        if not self.gabor:
            return self.adapter(x)
        g = self.gabor(x)
        if self.mode == "concat":
            stacked = torch.cat([x, g], dim=1)
        else:
            stacked = g
        return self.adapter(stacked)

    def forward(self, x):
        """Return the U-Net output for input batch ``x``."""
        prepared = self._prep(x)
        return self.net(prepared)

    def extract_features(self, x):
        """Return the last element of the encoder's output for input batch ``x``."""
        encoder_outputs = self.net.encoder(self._prep(x))
        return encoder_outputs[-1]

# Status line printed once the dataset and model classes above have been defined.
print("Core components defined.")

In [None]:
# =============================================
# EXPERIMENT 1: RANDOM SAMPLING BASELINE
# =============================================
print("\n--- Starting Random Sampling Baseline Experiment ---")

# --- Configuration ---
# Settings consumed by GaborStem / UNetWithGabor_Adapted.
GABOR_CFG = {
    "enabled": True,
    "mode": "concat",
    "kernel_size": 15,
    "orientations": 8,
    "wavelengths": [4.0, 8.0, 12.0],
    "sigmas": [1.5, 3.0, 4.5],
    "phases": [0, np.pi / 2],
    "gamma": 0.5,
    "magnitude": True,
}

# Training pipeline: resize, light augmentation, ImageNet normalisation, to tensor.
train_transform = A.Compose([
    A.Resize(256, 256),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])

# Evaluation pipeline: same resize and normalisation, no augmentation.
val_transform = A.Compose([
    A.Resize(256, 256),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])

loss_fn = nn.BCEWithLogitsLoss()

# Half the CPU count (falling back to 2 CPUs when it is unknown), clamped to [2, 8].
CPU_COUNT = os.cpu_count() or 2
NUM_WORKERS = min(8, max(2, CPU_COUNT // 2))

# --- Dataloader Functions ---
def make_loader(paths, transform, batch_size, shuffle=False):
    """Build a DataLoader over ``paths`` with masks resolved from ``MSK_LOCAL``.

    Worker count and memory pinning come from the module-level ``NUM_WORKERS``
    and ``PIN_MEMORY``; workers are kept alive between epochs whenever any are used.
    """
    dataset = SegDataset(paths, MSK_LOCAL, transform=transform)
    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": NUM_WORKERS,
        "pin_memory": PIN_MEMORY,
        "persistent_workers": NUM_WORKERS > 0,
    }
    return DataLoader(dataset, **loader_options)

# --- Checkpoint and Pool Initialization ---
def _strip_compile_prefix(state_dict, prefix="_orig_mod."):
    """Return ``state_dict`` with a leading ``torch.compile`` wrapper prefix removed from its keys."""
    return {(key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()}


def _list_image_paths():
    """Return the sorted paths of all non-hidden entries in ``IMG_LOCAL``."""
    return sorted(os.path.join(IMG_LOCAL, name) for name in os.listdir(IMG_LOCAL) if not name.startswith('.'))


def _foreground_ratio(mask_path):
    """Return the fraction of non-zero pixels in the mask at ``mask_path``."""
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:  # cv2.imread returns None on failure instead of raising
        raise OSError(f"Failed to read mask: {mask_path}")
    return np.mean(mask > 0)


ckpt_files = sorted(glob.glob(f"{CKPT_DIR}/rand_ckpt_*.pt"), key=os.path.getmtime, reverse=True)

# Move the model to its final device BEFORE building the optimizer and loading
# optimizer state. Otherwise the restored optimizer state is placed on the
# parameters' device at load time (CPU) and is left behind when the model is
# moved to the GPU afterwards.
model_rand = UNetWithGabor_Adapted(gcfg=GABOR_CFG)
model_rand.to(DEVICE, memory_format=torch.channels_last)
optimizer_rand = optim.AdamW(model_rand.parameters(), lr=1e-4, fused=torch.cuda.is_available())

if ckpt_files:
    print(f"Resuming from checkpoint: {os.path.basename(ckpt_files[0])}")
    ckpt = torch.load(ckpt_files[0], map_location=DEVICE)
    # Checkpoints written from a torch.compile-wrapped model carry an
    # "_orig_mod." key prefix that the plain (not yet compiled) model lacks.
    model_rand.load_state_dict(_strip_compile_prefix(ckpt['model_state']))
    optimizer_rand.load_state_dict(ckpt['optimizer_state'])
    rand_labeled_pool = ckpt['rand_labeled_pool']
    rand_unlabeled_pool = ckpt['rand_unlabeled_pool']
    random_history = ckpt['random_history']
    start_iter_rand = ckpt['iter_idx'] + 1

    # test_paths is required further down to build the test loader. Use the
    # stored split when the checkpoint has one; otherwise rebuild it as every
    # image that is in neither pool (labeled + unlabeled == train/val split).
    # NOTE: the rebuilt list is the same SET as the original test split but in
    # sorted order, so a capped evaluation may see a different subset.
    test_paths = ckpt.get('test_paths')
    if test_paths is None:
        assigned = set(rand_labeled_pool) | set(rand_unlabeled_pool)
        test_paths = [p for p in _list_image_paths() if p not in assigned]
else:
    print("No checkpoint found. Creating stratified data pools...")
    start_iter_rand = 0
    # Sorted listing: os.listdir order is arbitrary, which would make the
    # seeded split below differ from run to run.
    all_paths = _list_image_paths()
    mask_lookup = SegDataset([], MSK_LOCAL)  # only used to resolve mask paths
    ratios = [_foreground_ratio(mask_lookup._get_mask_path(p)) for p in tqdm(all_paths, "Calculating Ratios")]
    df = pd.DataFrame({'path': all_paths, 'ratio': ratios})
    # Stratify the train/test split on foreground-ratio quantile bins.
    df['ratio_bin'] = pd.qcut(df['ratio'], q=5, labels=False, duplicates='drop')
    train_val_paths, test_paths, _, _ = train_test_split(
        df['path'].tolist(), df['ratio_bin'].tolist(),
        test_size=0.2, random_state=SEED, stratify=df['ratio_bin'].tolist())
    random.shuffle(train_val_paths)
    # Seed the labeled pool with the first 20 shuffled images.
    rand_labeled_pool, rand_unlabeled_pool, random_history = train_val_paths[:20], train_val_paths[20:], []

try:
    print("Compiling model...")
    model_rand = torch.compile(model_rand, mode="max-autotune")
except Exception as e:
    # Best effort: fall back to the eager model if compile wrapping is unavailable.
    print(f"torch.compile skipped: {e}")

# --- Training and Validation Loop ---
def train_iteration(model, loader, epochs=5):
    """Train ``model`` in place for ``epochs`` full passes over ``loader``.

    Each batch from ``loader`` must be a 3-tuple whose first two elements are
    the inputs and targets. Optimisation uses the module-level
    ``optimizer_rand`` and ``loss_fn``, under autocast as configured by
    ``use_amp`` / ``amp_dtype``. Returns ``None``.
    """
    model.train()
    for _epoch in range(epochs):
        for inputs, targets, _paths in loader:
            inputs = inputs.to(DEVICE, non_blocking=True)
            targets = targets.to(DEVICE, non_blocking=True)

            optimizer_rand.zero_grad(set_to_none=True)
            with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=use_amp):
                logits = model(inputs)
                batch_loss = loss_fn(logits, targets)
            batch_loss.backward()
            optimizer_rand.step()

@torch.no_grad()
def validate(model, loader, max_batches=10):
    """Return the mean per-batch soft IoU of ``model`` over the first batches of ``loader``.

    Args:
        model: Module returning logits for a batch of images.
        loader: Iterable of ``(images, masks, paths)`` batches.
        max_batches: Upper bound on the number of batches evaluated.

    Returns:
        Average over the evaluated batches of
        ``intersection / (union + 1e-6)``, computed on sigmoid probabilities
        (not thresholded) and pooled across the whole batch. Returns ``nan``
        when no batch was evaluated (empty loader or ``max_batches <= 0``).
    """
    model.eval()
    total_iou = 0.0
    batches_seen = 0
    for batch_idx, (images, masks, _) in enumerate(loader):
        if batch_idx >= max_batches:
            break
        images, masks = images.to(DEVICE, non_blocking=True), masks.to(DEVICE, non_blocking=True)
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=use_amp):
            preds = torch.sigmoid(model(images))
        inter = (preds * masks).sum()
        union = (preds + masks).sum() - inter
        total_iou += (inter / (union + 1e-6)).item()
        batches_seen += 1

    # Divide by the number of batches actually scored. The loop index cannot be
    # used for this: after an early break it is one past the last scored batch,
    # and it is undefined when the loader is empty.
    if batches_seen == 0:
        return float("nan")
    return total_iou / batches_seen

# NUM_ITERATIONS: labeling rounds; BATCH_SIZE: images added to the labeled pool
# per round (not the DataLoader batch size); EPOCHS_PER_ITER: epochs per round.
NUM_ITERATIONS, BATCH_SIZE, EPOCHS_PER_ITER = 15, 20, 5
test_loader = make_loader(test_paths, val_transform, batch_size=128)
print(f"Starting baseline experiment from iteration: {start_iter_rand + 1}")

for i in range(start_iter_rand, NUM_ITERATIONS):
    # Resuming from a checkpoint written after the pool ran dry: the fully
    # labeled pool has already been trained on and scored, so stop here rather
    # than repeating that round.
    if (not rand_unlabeled_pool and random_history
            and random_history[-1]['num_labeled'] == len(rand_labeled_pool)):
        print("Unlabeled pool exhausted.")
        break

    print(f"\n[Baseline] Iteration {i+1}/{NUM_ITERATIONS} | Labeled: {len(rand_labeled_pool)}")
    train_loader = make_loader(rand_labeled_pool, train_transform, batch_size=64, shuffle=True)
    train_iteration(model_rand, train_loader, epochs=EPOCHS_PER_ITER)

    test_iou = validate(model_rand, test_loader)
    random_history.append({'num_labeled': len(rand_labeled_pool), 'iou': test_iou})
    print(f"  Test IoU: {test_iou:.4f}")

    pool_exhausted = not rand_unlabeled_pool
    if not pool_exhausted:
        # Random-sampling baseline: move a uniformly drawn batch to the labeled pool.
        new_batch = random.sample(rand_unlabeled_pool, k=min(BATCH_SIZE, len(rand_unlabeled_pool)))
        rand_labeled_pool.extend(new_batch)
        picked = set(new_batch)  # set membership keeps the filter linear
        rand_unlabeled_pool = [p for p in rand_unlabeled_pool if p not in picked]

    # Save the underlying module's weights: a torch.compile wrapper exposes the
    # original module as `_orig_mod` and would otherwise prefix every key with
    # "_orig_mod.", which a freshly built model cannot load.
    raw_model = getattr(model_rand, "_orig_mod", model_rand)
    # The checkpoint is written on the exhausted round too, so the final
    # history entry is persisted; test_paths is stored so a resumed run can
    # rebuild the same test loader.
    torch.save({
        'iter_idx': i, 'model_state': raw_model.state_dict(), 'optimizer_state': optimizer_rand.state_dict(),
        'rand_labeled_pool': rand_labeled_pool, 'rand_unlabeled_pool': rand_unlabeled_pool, 'random_history': random_history,
        'test_paths': test_paths
    }, f"{CKPT_DIR}/rand_ckpt_{i}.pt")

    # Drop this round's loader before the next one is created.
    del train_loader
    gc.collect()
    torch.cuda.empty_cache()

    if pool_exhausted:
        print("Unlabeled pool exhausted.")
        break

print("\nBaseline experiment complete.")

In [None]:
# =============================================
# RESULTS VISUALIZATION
# =============================================
import matplotlib.pyplot as plt

print("Generating final comparison plot...")

# --- Load Experiment Histories ---
# Newest checkpoint first (sorted by modification time, descending).
al_ckpt_files = sorted(glob.glob(f"{CKPT_DIR}/al_ckpt_*.pt"), key=os.path.getmtime, reverse=True)
rand_ckpt_files = sorted(glob.glob(f"{CKPT_DIR}/rand_ckpt_*.pt"), key=os.path.getmtime, reverse=True)

# map_location="cpu": only the history lists are needed here, and this lets the
# cell run on a CPU-only runtime even if the checkpoint tensors were saved from a GPU.
# Note that random_history is re-read from disk, replacing any in-memory value.
active_learning_history = torch.load(al_ckpt_files[0], map_location="cpu")['active_learning_history'] if al_ckpt_files else []
random_history = torch.load(rand_ckpt_files[0], map_location="cpu")['random_history'] if rand_ckpt_files else []

# --- Plotting ---
plt.style.use('seaborn-v0_8-whitegrid')
fig, ax = plt.subplots(figsize=(12, 8))

# Each curve is drawn only when its history is non-empty.
if active_learning_history:
    al_df = pd.DataFrame(active_learning_history)
    ax.plot(al_df['num_labeled'], al_df['iou'], marker='o', linestyle='-', color='crimson', label='Active Learning (HDBSCAN Outlier)')

if random_history:
    rand_df = pd.DataFrame(random_history)
    ax.plot(rand_df['num_labeled'], rand_df['iou'], marker='s', linestyle='--', color='dodgerblue', label='Random Sampling (Baseline)')

ax.set_title('Active Learning vs. Random Sampling Performance', fontsize=16, fontweight='bold')
ax.set_xlabel('Number of Labeled Images', fontsize=12)
ax.set_ylabel('Test Set Intersection over Union (IoU)', fontsize=12)
# Horizontal reference line at the 0.60 IoU target.
ax.axhline(y=0.6, color='gray', linestyle=':', linewidth=2, label='Target IoU (0.60)')
ax.legend(fontsize=12)
ax.grid(True, which='both', linestyle='--', linewidth=0.5)
plt.tight_layout()
plt.show()