<a href="https://colab.research.google.com/github/appababba/USDA/blob/main/ActiveLearning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# --- install stuff (colab) ---
!pip install -q segmentation-models-pytorch==0.3.3 albumentations==1.4.7 hdbscan pretrainedmodels efficientnet_pytorch --no-deps

# --- imports ---
import os, random, shutil, math, glob, gc, pickle, json, copy, sys
import numpy as np, pandas as pd, cv2, torch
import torch.nn as nn, torch.nn.functional as F, torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from contextlib import nullcontext
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from google.colab import drive
from tqdm.notebook import tqdm
import hdbscan

# --- env + reproducibility ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True
if hasattr(torch.backends, "cuda") and hasattr(torch.backends.cuda, "matmul"):
    torch.backends.cuda.matmul.allow_tf32 = True
if hasattr(torch.backends, "cudnn"):
    torch.backends.cudnn.allow_tf32 = True

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

use_amp = torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and torch.cuda.is_bf16_supported()
amp_dtype = torch.bfloat16 if use_amp else torch.float32

# --- mount drive ---
try:
    drive.mount('/content/drive', force_remount=False)
except Exception as e:
    print(f"drive mount fail: {e}")

# --- paths ---
BASE_DRIVE = "/content/drive/Shared drives/USDA-Summer2025"
DATA_DIR   = os.path.join(BASE_DRIVE, "data")
IMG_DRIVE  = os.path.join(DATA_DIR, "Exported_Images")
MSK_DRIVE  = os.path.join(DATA_DIR, "Exported_Masks")
MODELS_DIR = os.path.join(BASE_DRIVE, "models")
LOCAL_ROOT = "/content/local_data"
IMG_LOCAL  = os.path.join(LOCAL_ROOT, "Exported_Images")
MSK_LOCAL  = os.path.join(LOCAL_ROOT, "Exported_Masks")
os.makedirs(LOCAL_ROOT, exist_ok=True)

# --- pull data local if needed ---
def ensure_local_data(src, dst):
    if os.path.isdir(dst) and any(os.scandir(dst)):
        print(f"local data ok: {dst}")
        return
    print(f"copying from {src} to {dst}...")
    shutil.copytree(src, dst)
    print("done.")

ensure_local_data(IMG_DRIVE, IMG_LOCAL)
ensure_local_data(MSK_DRIVE, MSK_LOCAL)

# --- datasets ---
class SegDataset(Dataset):
    def __init__(self, image_paths, mask_dir, transform=None):
        self.image_paths = image_paths
        self.mask_dir = mask_dir
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def _get_mask_path(self, img_path):
        base = os.path.splitext(os.path.basename(img_path))[0]
        main = os.path.join(self.mask_dir, f"{base}_mask.png")
        if os.path.exists(main):
            return main
        for ext in ('.png', '.jpg', '.jpeg', '.tif', '.tiff'):
            alt = os.path.join(self.mask_dir, base + ext)
            if os.path.exists(alt):
                return alt
        raise FileNotFoundError(f"no mask for {img_path}")

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        mask_path = self._get_mask_path(img_path)
        img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
        mask = (cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) > 0).astype(np.uint8)
        if self.transform:
            t = self.transform(image=img, mask=mask)
            img, mask = t['image'], t['mask']
        if isinstance(mask, torch.Tensor) and mask.ndim == 2:
            mask = mask.unsqueeze(0)
        return img, mask.float(), img_path

class InferenceDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
        if self.transform:
            img = self.transform(image=img)['image']
        return img, img_path

# --- plain UNet ---
class UNetPlain(nn.Module):
    def __init__(self, encoder='resnet50', in_ch=3, classes=1):
        super().__init__()
        self.net = smp.Unet(encoder_name=encoder, encoder_weights='imagenet',
                            in_channels=in_ch, classes=classes)

    def forward(self, x):
        return self.net(x)

    @torch.no_grad()
    def extract_features(self, x):
        return self.net.encoder(x)[-1]

# --- augments ---
val_transform = A.Compose([
    A.Resize(256, 256),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])

train_transform = A.Compose([
    A.Resize(256, 256),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])

# --- init model ---
print('init UNet (resnet50)')
model = UNetPlain('resnet50').to(DEVICE, memory_format=torch.channels_last)
model.eval()
try:
    print('compiling...')
    model = torch.compile(model, mode='max-autotune')
    print('compiled ok')
except Exception as e:
    print(f'compile skipped: {e}')

# --- feature extraction ---
all_image_paths = [os.path.join(IMG_LOCAL, f) for f in os.listdir(IMG_LOCAL) if not f.startswith('.')]
print(f'found {len(all_image_paths)} imgs')

dataset = InferenceDataset(all_image_paths, val_transform)
loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=0, pin_memory=(DEVICE.type=='cuda'))

all_feats, all_paths = [], []
with torch.no_grad():
    for imgs, paths in tqdm(loader, desc='extracting'):
        imgs = imgs.to(DEVICE)
        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=use_amp):
            f = model.extract_features(imgs)
            pooled = F.adaptive_avg_pool2d(f, (1,1)).flatten(1)
            normed = F.normalize(pooled, p=2, dim=1)
        all_feats.append(normed.cpu().numpy())
        all_paths.extend(paths)

feats = np.vstack(all_feats) if all_feats else np.empty((0,2048))
print('features ready', feats.shape)

# --- clustering ---
if feats.shape[0] > 0:
    print('running HDBSCAN...')
    cl = hdbscan.HDBSCAN(min_cluster_size=15, prediction_data=True, core_dist_n_jobs=-1)
    labels = cl.fit_predict(feats)
    n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
    n_noise = int(np.sum(labels == -1))
    print(f'{n_clusters} clusters, {n_noise} outliers')
    outlier_scores = cl.outlier_scores_
    scored = sorted(zip(outlier_scores, all_paths), key=lambda x: x[0], reverse=True)
else:
    print('no features, skip clustering')
    scored = []

def select_next_batch(scored_samples, labeled, batch_size=20):
    picks, seen = [], set(labeled)
    for _, p in scored_samples:
        if p not in seen:
            picks.append(p)
            if len(picks) >= batch_size:
                break
    return picks

labeled_paths = []
BATCH_SIZE = 20
first_batch = select_next_batch(scored, labeled_paths, BATCH_SIZE) if scored else []
print(f'next {len(first_batch)} to label')
for i,p in enumerate(first_batch[:5]):
    print(f'  {i+1}. {os.path.basename(p)}')

# --- active learning loop setup ---
print('\nstarting AL loop')
CKPT_DIR = '/content/drive/MyDrive/active_learning_ckpts'
os.makedirs(CKPT_DIR, exist_ok=True)
NUM_WORKERS = 0
PIN_MEMORY = (DEVICE.type == 'cuda')

ckpt_files = sorted(glob.glob(f'{CKPT_DIR}/al_ckpt_*.pt'), key=os.path.getmtime, reverse=True)
latest_ckpt = ckpt_files[0] if ckpt_files else None

current_model = UNetPlain('resnet50').to(DEVICE, memory_format=torch.channels_last)
optimizer = optim.Adam(current_model.parameters(), lr=1e-4)

start_iter = 0
al_labeled, al_unlabeled, history = [], [], []

if latest_ckpt:
    try:
        print(f'resuming from {os.path.basename(latest_ckpt)}')
        ckpt = torch.load(latest_ckpt, map_location=DEVICE)
        cur = current_model.state_dict()
        compat = {k:v for k,v in ckpt['model_state'].items() if k in cur and v.shape==cur[k].shape}
        current_model.load_state_dict(compat, strict=False)
        try:
            optimizer.load_state_dict(ckpt['optimizer_state'])
        except Exception:
            pass
        al_labeled = ckpt.get('al_labeled_pool', [])
        al_unlabeled = ckpt.get('al_unlabeled_pool', [])
        history = ckpt.get('active_learning_history', [])
        start_iter = ckpt.get('iter_idx', -1) + 1
    except Exception as e:
        print(f'bad checkpoint, starting fresh: {e}')

if not al_labeled:
    print('making new pools...')
    FULL = [os.path.join(IMG_LOCAL, f) for f in os.listdir(IMG_LOCAL) if not f.startswith('.')]
    ratios = []
    helper = SegDataset([], MSK_LOCAL)
    for p in tqdm(FULL, desc='mask ratios'):
        try:
            m = cv2.imread(helper._get_mask_path(p), 0)
            ratios.append(float(np.mean(m>0)) if m is not None else 0)
        except FileNotFoundError:
            ratios.append(0)
    df = pd.DataFrame({'path':FULL, 'ratio':ratios})
    df['ratio_bin'] = pd.qcut(df['ratio'], q=5, labels=False, duplicates='drop')
    train_val, test_paths, _, _ = train_test_split(df['path'].tolist(), df['ratio_bin'].tolist(), test_size=0.2, random_state=SEED, stratify=df['ratio_bin'].tolist())
    random.shuffle(train_val)
    al_labeled, al_unlabeled, history = train_val[:20], train_val[20:], []
else:
    if 'test_paths' not in globals():
        _, test_paths = train_test_split([os.path.join(IMG_LOCAL, f) for f in os.listdir(IMG_LOCAL)], test_size=0.2, random_state=SEED)

# --- dataloaders + train/val ---
def make_loader(paths, tfm, bs, shuffle=False):
    ds = SegDataset(paths, MSK_LOCAL, transform=tfm)
    return DataLoader(ds, batch_size=bs, shuffle=shuffle, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

def train_one_iter(model, loader, epochs=5):
    model.train()
    for _ in range(epochs):
        for imgs, masks, _ in loader:
            imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
            optimizer.zero_grad(set_to_none=True)
            with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=use_amp):
                preds = model(imgs)
                loss = nn.BCEWithLogitsLoss()(preds, masks)
            loss.backward()
            optimizer.step()

@torch.no_grad()
def validate(model, loader, max_batches=10):
    if len(loader.dataset)==0: return 0.0
    model.eval()
    total, n = 0.0, 0
    for i,(imgs, masks, _) in enumerate(loader):
        if i>=max_batches: break
        imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=use_amp):
            preds = torch.sigmoid(model(imgs))
        inter = (preds*masks).sum()
        union = (preds+masks).sum() - inter
        total += (inter/(union+1e-6)).item()
        n+=1
    return total/max(1,n)

# --- embeddings + batch select ---
EMB_CACHE = {}
PCA_MODEL = None

@torch.no_grad()
def get_embeddings(paths, bs=64):
    to_do = [p for p in paths if p not in EMB_CACHE]
    if to_do:
        ds = InferenceDataset(to_do, val_transform)
        loader = DataLoader(ds, batch_size=bs, shuffle=False, num_workers=0)
        current_model.eval()
        for imgs, pths in loader:
            imgs = imgs.to(DEVICE)
            with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=use_amp):
                f = F.normalize(F.adaptive_avg_pool2d(current_model.extract_features(imgs), (1,1)).flatten(1))
            for p,fv in zip(pths, f.cpu().numpy()):
                EMB_CACHE[p] = fv
    return np.stack([EMB_CACHE[p] for p in paths]) if paths else np.empty((0,1))

def pick_next(unlabeled, batch_size=20, candidate_k=500):
    global PCA_MODEL
    if not unlabeled: return []
    cand = random.sample(unlabeled, k=min(candidate_k,len(unlabeled)))
    feats = get_embeddings(cand)
    if PCA_MODEL is None and feats.shape[0]>=64:
        PCA_MODEL = PCA(n_components=64, random_state=SEED).fit(feats)
    reduced = PCA_MODEL.transform(feats) if PCA_MODEL else feats
    scores = hdbscan.HDBSCAN(min_cluster_size=15).fit(reduced).outlier_scores_
    ranked = [c for _,c in sorted(zip(scores,cand), key=lambda x:x[0], reverse=True)]
    return ranked[:batch_size]

# --- main loop ---
NUM_ITERS = 15
EPOCHS_PER = 5
print(f'starting from iter {start_iter+1}')

for i in range(start_iter, NUM_ITERS):
    n_labeled = len(al_labeled)
    print(f'iter {i+1}/{NUM_ITERS} | {n_labeled} labeled')
    train_loader = make_loader(al_labeled, train_transform, 32, shuffle=True)
    train_one_iter(current_model, train_loader, epochs=EPOCHS_PER)

    test_loader = make_loader(test_paths, val_transform, 64)
    val_iou = validate(current_model, test_loader)
    history.append({'num_labeled': n_labeled, 'iou': val_iou})
    print(f'val IoU: {val_iou:.4f}')

    if not al_unlabeled:
        print('no unlabeled left')
        break

    print('selecting next batch...')
    new_batch = pick_next(al_unlabeled, BATCH_SIZE)

    al_labeled.extend(new_batch)
    al_unlabeled = [p for p in al_unlabeled if p not in new_batch]

    torch.save({
        'iter_idx': i,
        'model_state': current_model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'al_labeled_pool': al_labeled,
        'al_unlabeled_pool': al_unlabeled,
        'active_learning_history': history
    }, f'{CKPT_DIR}/al_ckpt_{i}.pt')

    del train_loader, test_loader
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print(f'AL done. checkpoints in {CKPT_DIR}')


[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/58.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m106.7/106.7 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m155.7/155.7 kB[0m [31m14.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for pretrainedmodels (setup.py) ... [?25l[?25hdone
  Building wheel for efficientnet_pytorch (setup.py) ... [?25l[?25hdone


  $max \{ core_k(a), core_k(b), 1/\alpha d(a,b) \}$.


Mounted at /content/drive
Copying data from /content/drive/Shared drives/USDA-Summer2025/data/Exported_Images to /content/local_data/Exported_Images...
Data copy complete.
Copying data from /content/drive/Shared drives/USDA-Summer2025/data/Exported_Masks to /content/local_data/Exported_Masks...
Data copy complete.
Initializing model (plain ResNet50 UNet)...
Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 342MB/s]


Compiling model for performance...
Model compiled successfully.

Found 1217 total images for feature extraction.


Extracting Features:   0%|          | 0/39 [00:00<?, ?it/s]

Feature extraction complete.
Feature vector shape: (1217, 2048), Total paths: 1217

Performing HDBSCAN clustering on feature vectors...




Clustering complete. Found 0 clusters and 1217 noise points (outliers).

Selected the top 20 most informative samples to label next:
  1. IMG_5471_092846_20250813_section14.jpg
  2. IMG_5150_093149_20250813_section15.jpg
  3. IMG_4606_085228_20250813_section1.jpg
  4. IMG_5163_093151_20250813_section15.jpg
  5. IMG_4601_085226_20250813_section1.jpg

--- Initializing Active Learning Loop ---
Resuming from checkpoint: al_ckpt_13.pt
Starting active learning from iteration: 15

--- Iteration 15/15 | Labeled Samples: 300 ---
Validation IoU: 0.4551
Selecting next batch of samples...





Active learning complete. Checkpoints saved in: /content/drive/MyDrive/active_learning_ckpts
