<a href="https://colab.research.google.com/github/appababba/USDA/blob/main/UNet_ResNet50_GrapeSegmentation_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q segmentation-models-pytorch
!pip install -q --upgrade scipy

In [None]:
# 1. Setup, Imports, and Google Drive Mount
import os
import random
import subprocess
import collections
import numpy as np
import cv2
from glob import glob
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from contextlib import nullcontext

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp
from google.colab import drive

# Announce the run, then attach Google Drive at /content/drive.
print(" Starting the complete UNet-ResNet50 training pipeline...")
try:
    drive.mount('/content/drive')
    print(" Google Drive mounted.")
except Exception as mount_error:
    # Best-effort: report the problem and let the notebook continue.
    print(f"ℹ Drive already mounted or mount error: {mount_error}")

🚀 Starting the complete UNet-ResNet50 training pipeline...
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✅ Google Drive mounted.


In [None]:
# 2. Configuration & Hyperparameters

# --- Paths ---
BASE_DRIVE_DIR = '/content/drive/Shared drives/USDA-Summer2025/data'
IMAGE_DIR_DRIVE = os.path.join(BASE_DRIVE_DIR, 'Exported_Images')
MASK_DIR_DRIVE  = os.path.join(BASE_DRIVE_DIR, 'Exported_Masks')
MODELS_SAVE_DIR = '/content/drive/Shared drives/USDA-Summer2025/models'

# --- CORRECTED LOCAL PATHS ---
LOCAL_IMG_DIR  = '/content/data/images'
LOCAL_MASK_DIR = '/content/data/masks'

# --- Hyperparameters ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True
IMG_SIZE = (256, 256)
BATCH_SIZE = 8
LEARNING_RATE = 1e-4
NUM_EPOCHS = 15
RANDOM_SEED = 42

# --- Mixed Precision Setup ---
USE_AMP = (DEVICE.type == 'cuda')
amp_autocast = torch.amp.autocast(device_type=DEVICE.type, dtype=torch.float16) if USE_AMP else nullcontext()
scaler = torch.amp.GradScaler(enabled=USE_AMP)

print(f"✅ Configuration updated. Local image path set to: {LOCAL_IMG_DIR}")

✅ Configuration updated. Local image path set to: /content/data/images


In [None]:
import os, subprocess, textwrap, sys

# 1) Define sources
# This spelling ('Shareddrives', no space) is the one listed under /content/drive.
BASE = "/content/drive/Shareddrives/USDA-Summer2025/data"
IMG_SRC  = f"{BASE}/Exported_Images"
MASK_SRC = f"{BASE}/Exported_Masks"

# 2) Define local targets
# Reuse the config-cell constants instead of repeating the literals: the
# datasets read from LOCAL_IMG_DIR / LOCAL_MASK_DIR, so the copy destination
# must always be exactly those directories.
LOCAL_IMG = LOCAL_IMG_DIR
LOCAL_MSK = LOCAL_MASK_DIR
os.makedirs(LOCAL_IMG, exist_ok=True)
os.makedirs(LOCAL_MSK, exist_ok=True)

def check_dir(p):
    """Confirm that directory `p` exists and print a short preview of it.

    Prints the number of entries followed by the first five names in sorted
    order. Raises FileNotFoundError when `p` is not an existing directory.
    """
    if os.path.isdir(p):
        entries = os.listdir(p)
        print(f"✅ {p} exists with {len(entries)} items")
        # Preview only the first few names, alphabetically.
        preview = sorted(entries)[:5]
        for entry_name in preview:
            print("   •", entry_name)
        return
    raise FileNotFoundError(f" Not found (check spelling/case): {p}")

print(" Checking mount and base...")
# List the mount root and the shared-drive folder to make path typos obvious.
for listing_label, listing_path in (
    ("drive root:", "/content/drive"),
    ("shared drives:", "/content/drive/Shareddrives"),
):
    print(listing_label, os.listdir(listing_path))

# Fail early if any source directory is missing.
for source_dir in (BASE, IMG_SRC, MASK_SRC):
    check_dir(source_dir)

def rsync_copy(src, dst):
    """Copy the contents of directory `src` into directory `dst` using rsync.

    Both paths get a trailing slash so rsync copies the *contents* of `src`
    rather than the directory itself. Raises subprocess.CalledProcessError
    when rsync exits with a non-zero status (check=True).
    """
    source_arg = src + "/"
    target_arg = dst + "/"
    rsync_cmd = ["rsync", "-a", "--info=progress2", source_arg, target_arg]
    print("\n Running:", " ".join(rsync_cmd))
    subprocess.run(rsync_cmd, check=True)
    print(f" Sync complete → {dst}")

# Images first, then masks.
for copy_src, copy_dst in ((IMG_SRC, LOCAL_IMG), (MASK_SRC, LOCAL_MSK)):
    rsync_copy(copy_src, copy_dst)


🔎 Checking mount and base...
drive root: ['Shareddrives', 'MyDrive', '.shortcut-targets-by-id', '.Trash-0']
shared drives: ['CSci158', 'Term Project', 'USDA-Summer2025']
✅ /content/drive/Shareddrives/USDA-Summer2025/data exists with 7 items
   • Exported_Images
   • Exported_Masks
   • Julian
   • RealData
   • Section 8
✅ /content/drive/Shareddrives/USDA-Summer2025/data/Exported_Images exists with 1217 items
   • IMG_3459_085619_20250813_section1.jpg
   • IMG_3462_085619_20250813_section1.jpg
   • IMG_3468_085619_20250813_section1.jpg
   • IMG_3471_085619_20250813_section1.jpg
   • IMG_3472_085620_20250813_section1.jpg
✅ /content/drive/Shareddrives/USDA-Summer2025/data/Exported_Masks exists with 1217 items
   • IMG_3459_085619_20250813_section1_mask.png
   • IMG_3462_085619_20250813_section1_mask.png
   • IMG_3468_085619_20250813_section1_mask.png
   • IMG_3471_085619_20250813_section1_mask.png
   • IMG_3472_085620_20250813_section1_mask.png

▶ Running: rsync -a --info=progress2 /cont

In [None]:
# 4. Group Extraction
import re

def extract_group(path):
    """Build a group key such as '20250813_sec1' from an image path.

    The key combines the first 8-digit run enclosed by underscores (the date)
    with the digits following 'section' (case-insensitive) in the file name.
    Missing parts are replaced by 'nodate' / 'nosec'.
    """
    name = os.path.basename(path)
    date_match = re.search(r'_(\d{8})_', name)
    section_match = re.search(r'section(\d+)', name, re.IGNORECASE)

    date_part = "nodate" if date_match is None else date_match.group(1)
    section_part = "nosec" if section_match is None else section_match.group(1)

    return "{}_sec{}".format(date_part, section_part)

In [None]:
# 5. Group-Aware Data Splitting

img_exts = ('*.jpg', '*.jpeg', '*.png', '*.JPG', '*.PNG')
# glob() returns matches in arbitrary (filesystem) order, and the patterns can
# overlap on case-insensitive filesystems. Collect into a set and sort so the
# path list -- and everything derived from it -- is reproducible.
found_image_paths = set()
for ext in img_exts:
    found_image_paths.update(glob(os.path.join(LOCAL_IMG_DIR, ext)))
all_local_image_paths = sorted(found_image_paths)
if not all_local_image_paths:
    raise RuntimeError(f"No images found in {LOCAL_IMG_DIR}")

# 1. Map all images to their group ID
images_by_group = collections.defaultdict(list)
for path in all_local_image_paths:
    group_id = extract_group(path)
    images_by_group[group_id].append(path)

unique_groups = sorted(list(images_by_group.keys()))
print(f"Found {len(all_local_image_paths)} images across {len(unique_groups)} unique groups.")

# 2. Split the list of unique groups, not the images, so that every image of a
#    group lands in exactly one split.
# NOTE: the split is by number of groups; groups of very different sizes can
# produce image counts far from the nominal 20% / 15% fractions.
random.seed(RANDOM_SEED)
random.shuffle(unique_groups)
train_val_groups, test_groups = train_test_split(unique_groups, test_size=0.20, random_state=RANDOM_SEED)
train_groups, val_groups      = train_test_split(train_val_groups, test_size=0.15, random_state=RANDOM_SEED)

# 3. Build the final image lists from the split groups
def get_paths_from_groups(groups, group_map):
    """Return one flat list with the paths of every group in `groups`, in order."""
    collected = []
    for group_id in groups:
        collected.extend(group_map[group_id])
    return collected

train_paths = get_paths_from_groups(train_groups, images_by_group)
val_paths   = get_paths_from_groups(val_groups, images_by_group)
test_paths  = get_paths_from_groups(test_groups, images_by_group)

# 4. Final shuffle for randomness during training
random.seed(RANDOM_SEED)
random.shuffle(train_paths)
random.shuffle(val_paths)
random.shuffle(test_paths)

# 5. Verify no overlap between ANY pair of splits (including val/test) and
#    fail loudly on leakage instead of printing success unconditionally.
train_set_groups = {extract_group(p) for p in train_paths}
val_set_groups = {extract_group(p) for p in val_paths}
test_set_groups = {extract_group(p) for p in test_paths}
split_overlaps = {
    "Train-Val": train_set_groups & val_set_groups,
    "Train-Test": train_set_groups & test_set_groups,
    "Val-Test": val_set_groups & test_set_groups,
}
for pair_name, shared_groups in split_overlaps.items():
    print(f"{pair_name} Overlap: {len(shared_groups)}")
leaking_groups = {name: sorted(shared) for name, shared in split_overlaps.items() if shared}
if leaking_groups:
    raise RuntimeError(f"Group leakage between splits: {leaking_groups}")
print(" Group split complete with no overlap.")


🛡️ Setting up GROUP-AWARE data splits...
Found 1217 images across 12 unique groups.
Train-Val Overlap: 0
Train-Test Overlap: 0
✅ Group split complete with no overlap.


In [None]:
from glob import glob
import os, collections

# Diagnostic: how many local images match each glob pattern.
img_exts = ('*.jpg', '*.jpeg', '*.png', '*.JPG', '*.PNG')
matches_by_ext = {
    pattern: glob(os.path.join(LOCAL_IMG_DIR, pattern)) for pattern in img_exts
}
counts = {pattern: len(found) for pattern, found in matches_by_ext.items()}
all_paths = [found_path for found in matches_by_ext.values() for found_path in found]

print("Per-extension counts:", counts)
print("Total found:", len(all_paths))
print("Sample dir:", LOCAL_IMG_DIR)


Per-extension counts: {'*.jpg': 1217, '*.jpeg': 0, '*.png': 0, '*.JPG': 0, '*.PNG': 0}
Total found: 1217
Sample dir: /content/data/images


In [None]:
# 6. PyTorch Dataset and DataLoaders

class GrapeDataset(Dataset):
    """Image/mask pairs for binary segmentation.

    Each item is an `(image, mask)` tuple of float tensors: the image is RGB,
    channel-first and scaled to [0, 1]; the mask has one channel with values
    0.0 / 1.0 (any non-zero mask pixel counts as foreground).

    Args:
        image_paths: sequence of image file paths.
        mask_dir: directory holding the masks.
        size: target size handed straight to `cv2.resize`, which interprets
            it as (width, height).
    """

    def __init__(self, image_paths, mask_dir, size=(256, 256)):
        self.image_paths = image_paths
        self.mask_dir = mask_dir
        self.size = size

    def __len__(self):
        return len(self.image_paths)

    def _mask_path_for(self, img_path):
        """Return '<stem>_mask.png' in mask_dir if it exists, else the image's own file name there."""
        file_name = os.path.basename(img_path)
        stem = os.path.splitext(file_name)[0]
        preferred = os.path.join(self.mask_dir, f"{stem}_mask.png")
        if os.path.exists(preferred):
            return preferred
        return os.path.join(self.mask_dir, file_name)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]

        # Image: read as BGR, convert to RGB, resize with bilinear interpolation.
        bgr = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if bgr is None:
            raise FileNotFoundError(f"Failed to read image: {img_path}")
        rgb = cv2.resize(
            cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB),
            self.size,
            interpolation=cv2.INTER_LINEAR,
        )

        # Mask: nearest-neighbour resize keeps labels crisp, then binarise.
        mask_path = self._mask_path_for(img_path)
        raw_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if raw_mask is None:
            raise FileNotFoundError(f"Failed to read mask for {img_path} at {mask_path}")
        resized_mask = cv2.resize(raw_mask, self.size, interpolation=cv2.INTER_NEAREST)
        binary_mask = (resized_mask > 0).astype(np.float32)

        image = torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.0
        mask = torch.from_numpy(binary_mask).unsqueeze(0).float()
        return image, mask

# One dataset per split; all read masks from the same local directory.
train_dataset = GrapeDataset(train_paths, LOCAL_MASK_DIR, size=IMG_SIZE)
val_dataset = GrapeDataset(val_paths, LOCAL_MASK_DIR, size=IMG_SIZE)
test_dataset = GrapeDataset(test_paths, LOCAL_MASK_DIR, size=IMG_SIZE)

# Options shared by every loader; only the training loader shuffles.
loader_options = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

print(f"Data ready. Train: {len(train_dataset)}, Validation: {len(val_dataset)}, Test: {len(test_dataset)}")

Data ready. Train: 639, Validation: 121, Test: 457


In [None]:
# 7. Model, Loss Function, and Optimizer

print("\n🧠 Initializing UNet+ResNet50 model...")
# U-Net with a ResNet50 encoder and a single output channel; activation=None
# is passed so the loss functions below receive the un-activated output.
unet_config = dict(
    encoder_name='resnet50',
    encoder_weights='imagenet',
    in_channels=3,
    classes=1,
    activation=None,
)
model = smp.Unet(**unet_config).to(DEVICE)

# Loss = BCE-with-logits + binary Dice.
dice_loss = smp.losses.DiceLoss(mode='binary')
bce_loss = nn.BCEWithLogitsLoss()

def combined_loss(pred, target):
    """Return the sum of the BCE-with-logits and Dice losses for `pred` vs `target`."""
    bce_term = bce_loss(pred, target)
    dice_term = dice_loss(pred, target)
    return bce_term + dice_term

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# mode='max': the scheduler is stepped with a metric that should increase.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.5,
    patience=2,
)


🧠 Initializing UNet+ResNet50 model...


In [None]:
# 8. Training and Validation Execution

# IoU during training uses a fixed 0.5 probability threshold.
@torch.no_grad()
def iou_from_logits(logits, y):
    """Return the IoU (Python float) of thresholded predictions against `y`.

    Predictions are sigmoid(logits) > 0.5. Intersection and union are summed
    over the whole tensor, and 1e-6 is added to both so an empty union yields
    1.0 instead of a division by zero.
    """
    eps = 1e-6
    pred = torch.sigmoid(logits).gt(0.5).float()
    intersection = (pred * y).sum()
    union = pred.sum() + y.sum() - intersection
    score = (intersection + eps) / (union + eps)
    return score.item()

def train_epoch(loader, model):
    """Run one optimisation pass over `loader`.

    Uses the module-level optimizer, AMP autocast context and GradScaler.
    Returns `(mean_loss, mean_iou)`, both averaged over the number of batches.
    """
    model.train()
    loss_sum = 0.0
    iou_sum = 0.0
    progress = tqdm(loader, desc="Training", leave=False)
    for images, masks in progress:
        images = images.to(DEVICE, non_blocking=True)
        masks = masks.to(DEVICE, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with amp_autocast:
            logits = model(images)
            loss = combined_loss(logits, masks)

        # Scaled backward pass / optimizer step via the GradScaler.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_sum += loss.item()
        iou_sum += iou_from_logits(logits.detach(), masks)

    n_batches = len(loader)
    return loss_sum / n_batches, iou_sum / n_batches

@torch.no_grad()
def validate_epoch(loader, model):
    """Evaluate `model` on `loader` without gradient tracking.

    Returns `(mean_loss, mean_iou)`, both averaged over the number of batches.
    """
    model.eval()
    loss_sum = 0.0
    iou_sum = 0.0
    for images, masks in loader:
        images = images.to(DEVICE, non_blocking=True)
        masks = masks.to(DEVICE, non_blocking=True)
        with amp_autocast:
            logits = model(images)
            loss = combined_loss(logits, masks)
        loss_sum += loss.item()
        iou_sum += iou_from_logits(logits, masks)

    n_batches = len(loader)
    return loss_sum / n_batches, iou_sum / n_batches

print("\n Starting main training loop...")
best_val_iou = -1.0
os.makedirs(MODELS_SAVE_DIR, exist_ok=True)
BEST_MODEL_PATH = os.path.join(MODELS_SAVE_DIR, 'UNet-ResNet50_GroupSplit_Champion.pth')

for ep in range(1, NUM_EPOCHS + 1):
    tr_loss, tr_iou = train_epoch(train_loader, model)
    va_loss, va_iou = validate_epoch(val_loader, model)
    print(f"Epoch {ep}/{NUM_EPOCHS} | Train Loss: {tr_loss:.4f}, IoU: {tr_iou:.3f} || Val Loss: {va_loss:.4f}, IoU: {va_iou:.3f}")

    # Step the plateau scheduler on validation IoU and report any LR drop.
    old_lr = optimizer.param_groups[0]['lr']
    scheduler.step(va_iou)
    new_lr = optimizer.param_groups[0]['lr']
    if new_lr < old_lr:
        print(f" Learning rate reduced to {new_lr}")

    # Checkpoint only when validation IoU strictly improves.
    if not va_iou > best_val_iou:
        continue
    best_val_iou = va_iou
    torch.save(model.state_dict(), BEST_MODEL_PATH)
    print(f" New best model saved! Val IoU: {best_val_iou:.4f}")

print(f"\n\n Training complete. Best model saved to {BEST_MODEL_PATH}")


🚦 Starting main training loop...


Training:   0%|          | 0/80 [00:00<?, ?it/s]

Epoch 1/15 | Train Loss: 1.0747, IoU: 0.418 || Val Loss: 0.7728, IoU: 0.668
💾 New best model saved! Val IoU: 0.6682


Training:   0%|          | 0/80 [00:00<?, ?it/s]

Epoch 2/15 | Train Loss: 0.7150, IoU: 0.649 || Val Loss: 0.5501, IoU: 0.734
💾 New best model saved! Val IoU: 0.7339


Training:   0%|          | 0/80 [00:00<?, ?it/s]

Epoch 3/15 | Train Loss: 0.5694, IoU: 0.701 || Val Loss: 0.4620, IoU: 0.755
💾 New best model saved! Val IoU: 0.7551


Training:   0%|          | 0/80 [00:00<?, ?it/s]

Epoch 4/15 | Train Loss: 0.4764, IoU: 0.727 || Val Loss: 0.4047, IoU: 0.762
💾 New best model saved! Val IoU: 0.7620


Training:   0%|          | 0/80 [00:00<?, ?it/s]

Epoch 5/15 | Train Loss: 0.4106, IoU: 0.746 || Val Loss: 0.3570, IoU: 0.762
✅ Learning rate reduced to 5e-05


Training:   0%|          | 0/80 [00:00<?, ?it/s]

Epoch 6/15 | Train Loss: 0.3459, IoU: 0.785 || Val Loss: 0.3386, IoU: 0.769
💾 New best model saved! Val IoU: 0.7688


Training:   0%|          | 0/80 [00:00<?, ?it/s]

Epoch 7/15 | Train Loss: 0.3150, IoU: 0.802 || Val Loss: 0.3217, IoU: 0.768
✅ Learning rate reduced to 2.5e-05


Training:   0%|          | 0/80 [00:00<?, ?it/s]

Epoch 8/15 | Train Loss: 0.2893, IoU: 0.817 || Val Loss: 0.3134, IoU: 0.767


Training:   0%|          | 0/80 [00:00<?, ?it/s]

Epoch 9/15 | Train Loss: 0.2770, IoU: 0.824 || Val Loss: 0.3081, IoU: 0.768


Training:   0%|          | 0/80 [00:00<?, ?it/s]

Epoch 10/15 | Train Loss: 0.2594, IoU: 0.836 || Val Loss: 0.3074, IoU: 0.770
💾 New best model saved! Val IoU: 0.7699


Training:   0%|          | 0/80 [00:00<?, ?it/s]

Epoch 11/15 | Train Loss: 0.2559, IoU: 0.839 || Val Loss: 0.3017, IoU: 0.770
💾 New best model saved! Val IoU: 0.7701


Training:   0%|          | 0/80 [00:00<?, ?it/s]

Epoch 12/15 | Train Loss: 0.2465, IoU: 0.843 || Val Loss: 0.3018, IoU: 0.768
✅ Learning rate reduced to 6.25e-06


Training:   0%|          | 0/80 [00:00<?, ?it/s]

Epoch 13/15 | Train Loss: 0.2384, IoU: 0.849 || Val Loss: 0.2982, IoU: 0.770


Training:   0%|          | 0/80 [00:00<?, ?it/s]

Epoch 14/15 | Train Loss: 0.2357, IoU: 0.850 || Val Loss: 0.2982, IoU: 0.771
💾 New best model saved! Val IoU: 0.7709


Training:   0%|          | 0/80 [00:00<?, ?it/s]

Epoch 15/15 | Train Loss: 0.2346, IoU: 0.850 || Val Loss: 0.2953, IoU: 0.766
✅ Learning rate reduced to 3.125e-06


🏁 Training complete. Best model saved to /content/drive/Shared drives/USDA-Summer2025/models/UNet-ResNet50_GroupSplit_Champion.pth


In [None]:
# 9. Find Optimal Threshold & Final Test Evaluation

# Reload the best checkpoint saved during training. The file holds a plain
# state_dict, so weights_only=True is sufficient and avoids unpickling
# arbitrary Python objects.
model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=DEVICE, weights_only=True))

@torch.no_grad()
def eval_iou_with_threshold(loader, model, thresh=0.5):
    """Return the IoU of `model` over all of `loader` at probability threshold `thresh`.

    Intersection and union are accumulated across every batch before the final
    ratio is taken; 1e-6 is added to numerator and denominator.
    """
    model.eval()
    inter_sum = 0.0
    union_sum = 0.0
    batches = tqdm(loader, desc=f"Eval with Thresh={thresh:.2f}", leave=False)
    for xb, yb in batches:
        xb = xb.to(DEVICE)
        yb = yb.to(DEVICE)
        probs = torch.sigmoid(model(xb))
        pred = (probs > thresh).float()
        overlap = (pred * yb).sum()
        inter_sum += overlap.item()
        union_sum += (pred.sum() + yb.sum() - overlap).item()
    return (inter_sum + 1e-6) / (union_sum + 1e-6)

# 1. Pick the decision threshold on the validation set only.
print("\n Finding optimal threshold on validation set...")
thresholds = np.linspace(0.3, 0.7, 9)
val_ious = []
for candidate_threshold in thresholds:
    val_ious.append(eval_iou_with_threshold(val_loader, model, candidate_threshold))
best_t_idx = np.argmax(val_ious)
best_threshold = thresholds[best_t_idx]
print(f"Optimal threshold found: {best_threshold:.2f} (with Val IoU: {val_ious[best_t_idx]:.4f})")

# 2. Score the test set once, with that single threshold.
print("\n Evaluating on test set with the chosen threshold...")
test_iou = eval_iou_with_threshold(test_loader, model, best_threshold)
print(f"\n Final Test IoU: {test_iou:.4f}")


🔍 Finding optimal threshold on validation set...


Eval with Thresh=0.30:   0%|          | 0/16 [00:00<?, ?it/s]

Eval with Thresh=0.35:   0%|          | 0/16 [00:00<?, ?it/s]

Eval with Thresh=0.40:   0%|          | 0/16 [00:00<?, ?it/s]

Eval with Thresh=0.45:   0%|          | 0/16 [00:00<?, ?it/s]

Eval with Thresh=0.50:   0%|          | 0/16 [00:00<?, ?it/s]

Eval with Thresh=0.55:   0%|          | 0/16 [00:00<?, ?it/s]

Eval with Thresh=0.60:   0%|          | 0/16 [00:00<?, ?it/s]

Eval with Thresh=0.65:   0%|          | 0/16 [00:00<?, ?it/s]

Eval with Thresh=0.70:   0%|          | 0/16 [00:00<?, ?it/s]

Optimal threshold found: 0.35 (with Val IoU: 0.7813)

🧪 Evaluating on test set with the chosen threshold...


Eval with Thresh=0.35:   0%|          | 0/58 [00:00<?, ?it/s]


✅ Final Test IoU: 0.7220


In [None]:
# 9b: Evaluating Previous Unet
import os
import torch
import numpy as np
from tqdm.notebook import tqdm

# --- 1. Locate the checkpoint to compare against ---
OTHER_MODEL_PATH = '/content/drive/Shared drives/USDA-Summer2025/src/rory_models/Unet-ResNet50_IoU-val83-test82.pth'

if not os.path.exists(OTHER_MODEL_PATH):
    raise FileNotFoundError(f"Model not found at: {OTHER_MODEL_PATH}")
print(f"✅ Found model file: {os.path.basename(OTHER_MODEL_PATH)}")

# --- 2. Create a new "shell" of the model architecture ---
# encoder_weights=None: every weight comes from the checkpoint loaded below.
model_to_test = smp.Unet(
    encoder_name='resnet50',
    encoder_weights=None,
    in_channels=3,
    classes=1,
).to(DEVICE)

# --- 3. Load the saved weights ---
# The checkpoint is consumed as a state_dict, so restrict torch.load to
# tensors/containers (weights_only=True) rather than a full unpickle of a
# file that lives on a shared drive.
model_to_test.load_state_dict(torch.load(OTHER_MODEL_PATH, map_location=DEVICE, weights_only=True))
print(" Loaded model weights.")

# --- 4. Run the same evaluation logic ---

# Sweep thresholds on the validation set
# NOTE: this reuses (and overwrites) thresholds / val_ious / best_threshold /
# test_iou from the previous evaluation cell.
print("\n Finding optimal threshold for this model on validation set...")
thresholds = np.linspace(0.3, 0.7, 9)
val_ious = []
for candidate_threshold in thresholds:
    val_ious.append(eval_iou_with_threshold(val_loader, model_to_test, candidate_threshold))
best_t_idx = np.argmax(val_ious)
best_threshold = thresholds[best_t_idx]
print(f"Optimal threshold found: {best_threshold:.2f} (with Val IoU: {val_ious[best_t_idx]:.4f})")

# Score the test set once with the threshold chosen on validation.
print("\n Evaluating on test set with the chosen threshold...")
test_iou = eval_iou_with_threshold(test_loader, model_to_test, best_threshold)
print(f"\n Final Test IoU for '{os.path.basename(OTHER_MODEL_PATH)}': {test_iou:.4f}")

✅ Found model file: Unet-ResNet50_IoU-val83-test82.pth
✅ Successfully loaded model weights.

🔍 Finding optimal threshold for this model on validation set...


Eval with Thresh=0.30:   0%|          | 0/16 [00:00<?, ?it/s]

Eval with Thresh=0.35:   0%|          | 0/16 [00:00<?, ?it/s]

Eval with Thresh=0.40:   0%|          | 0/16 [00:00<?, ?it/s]

Eval with Thresh=0.45:   0%|          | 0/16 [00:00<?, ?it/s]

Eval with Thresh=0.50:   0%|          | 0/16 [00:00<?, ?it/s]

Eval with Thresh=0.55:   0%|          | 0/16 [00:00<?, ?it/s]

Eval with Thresh=0.60:   0%|          | 0/16 [00:00<?, ?it/s]

Eval with Thresh=0.65:   0%|          | 0/16 [00:00<?, ?it/s]

Eval with Thresh=0.70:   0%|          | 0/16 [00:00<?, ?it/s]

Optimal threshold found: 0.55 (with Val IoU: 0.8804)

🧪 Evaluating on test set with the chosen threshold...


Eval with Thresh=0.55:   0%|          | 0/58 [00:00<?, ?it/s]


✅ Final Test IoU for 'Unet-ResNet50_IoU-val83-test82.pth': 0.8636


In [None]:
# 10.  Analyze Data Split Distributions

!pip install --upgrade scipy -q
from scipy.stats import ks_2samp

def get_foreground_ratio(mask_path):
    """Return the fraction of non-zero pixels in the grayscale mask at `mask_path`.

    If the mask cannot be read the function still returns 0.0 (as before), but
    now emits a warning: an unreadable mask would otherwise be counted silently
    as "no foreground" and skew the coverage statistics.
    """
    import warnings

    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        warnings.warn(f"Could not read mask; counting it as 0.0 foreground: {mask_path}")
        return 0.0
    return float((mask > 0).sum()) / float(mask.size)

def analyze_ratios(paths, mask_dir=LOCAL_MASK_DIR):
    """Compute the foreground ratio of the mask belonging to each image in `paths`.

    The mask is '<stem>_mask.png' in `mask_dir` when that file exists,
    otherwise the image's own file name in `mask_dir`.
    Returns `(ratios, mean, std)` where `ratios` is a float32 array.
    """
    def _mask_for(img_path):
        file_name = os.path.basename(img_path)
        stem = os.path.splitext(file_name)[0]
        preferred = os.path.join(mask_dir, f"{stem}_mask.png")
        if os.path.exists(preferred):
            return preferred
        return os.path.join(mask_dir, file_name)

    ratios = [
        get_foreground_ratio(_mask_for(img_path))
        for img_path in tqdm(paths, desc="Analyzing ratios", leave=False)
    ]

    arr = np.array(ratios, dtype=np.float32)
    return arr, float(arr.mean()), float(arr.std())

print("\n📊 Analyzing foreground ratio distributions...")
r_tr, m_tr, s_tr = analyze_ratios(train_paths)
r_va, m_va, s_va = analyze_ratios(val_paths)
r_te, m_te, s_te = analyze_ratios(test_paths)

# Mean ± std foreground coverage per split.
for split_label, split_mean, split_std in (
    ("Train", m_tr, s_tr),
    ("Val", m_va, s_va),
    ("Test", m_te, s_te),
):
    print(f"{split_label:<5} coverage: mean={split_mean:.3f} ± {split_std:.3f}")

# Two-sample KS test of each held-out split against the training split;
# a small p-value means the coverage distributions differ.
print("\n--- Kolmogorov-Smirnov Test (Distribution Similarity) ---")
for pair_label, other_ratios in (
    ("KS(train,val) ", r_va),
    ("KS(train,test)", r_te),
):
    print(f"{pair_label} p-value: {ks_2samp(r_tr, other_ratios).pvalue:.4f}")


📊 Analyzing foreground ratio distributions...


Analyzing ratios:   0%|          | 0/639 [00:00<?, ?it/s]

Analyzing ratios:   0%|          | 0/121 [00:00<?, ?it/s]

Analyzing ratios:   0%|          | 0/457 [00:00<?, ?it/s]

Train coverage: mean=0.088 ± 0.065
Val   coverage: mean=0.104 ± 0.056
Test  coverage: mean=0.100 ± 0.077

--- Kolmogorov-Smirnov Test (Distribution Similarity) ---
KS(train,val)  p-value: 0.0007
KS(train,test) p-value: 0.0177
