In [None]:
# DCAU Implementation

In [None]:




## DCAU

from matplotlib.patches import Patch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pandas as pd
import os
import cv2
import time
import shutil
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import jaccard_score
import warnings
import multiprocessing
import psutil
import platform
import sys
import gc
warnings.filterwarnings("ignore")

def check_system_capabilities():
    """Print a short hardware summary and return a conservative ``num_workers``.

    Returns:
        int: 0 on Windows, otherwise ``min(2, cores)`` where ``cores`` is the
        physical core count, falling back to the logical count and finally
        to 1 when psutil cannot determine it.
    """
    print("="*60)
    print("SYSTEM CAPABILITY CHECK")
    print("="*60)

    logical_cores = psutil.cpu_count(logical=True)
    physical_cores = psutil.cpu_count(logical=False)

    print(f"Operating System: {platform.system()} {platform.release()}")
    print(f"Physical cores: {physical_cores}")
    print(f"Logical cores: {logical_cores}")

    memory = psutil.virtual_memory()
    total_memory_gb = memory.total / (1024**3)
    print(f"Total RAM: {total_memory_gb:.2f} GB")

    if torch.cuda.is_available():
        gpu_count = torch.cuda.device_count()
        print(f"CUDA available: Yes, {gpu_count} GPU(s)")
    else:
        print("CUDA available: No")

    # psutil.cpu_count() may return None when the count is undetermined;
    # min(2, None) would raise TypeError, so fall back step by step.
    usable_cores = physical_cores or logical_cores or 1

    # Conservative approach for Windows
    recommended_workers = 0 if platform.system() == "Windows" else min(2, usable_cores)
    print(f"Recommended num_workers: {recommended_workers}")

    return recommended_workers

# Seed the NumPy and PyTorch random number generators (and CUDA's, when a GPU
# is available) with a fixed value so repeated runs start from the same state.
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    # Request deterministic cuDNN behaviour.
    torch.backends.cudnn.deterministic = True

# Class names; the list position of each name is its integer class id,
# matching the keys of COLOR_MAP below.
class_names = [
    "Background", "Bareland", "Rangeland", "Developed Space", "Road",
    "Tree", "Water", "Agriculture Land", "Building"
]
num_classes = len(class_names)
# Class id -> 3-tuple colour used by apply_color_map when rendering masks.
COLOR_MAP = {
    0: (0, 0, 0),        # Background - Black
    1: (128, 0, 0),      # Bareland - Maroon
    2: (0, 255, 36),     # Rangeland - Green
    3: (148, 148, 148),  # Developed Space - Gray
    4: (255, 255, 255),  # Road - White
    5: (34, 97, 38),     # Tree - Dark Green
    6: (0, 69, 255),     # Water - Blue
    7: (75, 181, 73),    # Agriculture Land - Light Green
    8: (222, 31, 7)      # Building - Red
}
# UNet Model Definition
class UNet(nn.Module):
    """U-Net with four down-sampling stages, a bottleneck and four up-sampling stages.

    Each stage halves (encoder) or doubles (decoder) the spatial size, so the
    skip-connection concatenations only line up when the input height and
    width are divisible by 16. The forward pass returns raw per-class logits
    of shape ``[N, num_classes, H, W]``.
    """

    def __init__(self, num_classes=9):
        super().__init__()
        self.num_classes = num_classes

        # Encoder: double-conv block followed by 2x2 max pooling.
        self.contracting_11 = self.conv_block(3, 64)
        self.contracting_12 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.contracting_21 = self.conv_block(64, 128)
        self.contracting_22 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.contracting_31 = self.conv_block(128, 256)
        self.contracting_32 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.contracting_41 = self.conv_block(256, 512)
        self.contracting_42 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottleneck.
        self.middle = self.conv_block(512, 1024)

        # Decoder: transposed conv doubles the resolution, then a double-conv
        # block fuses the upsampled features with the matching encoder output.
        self.expansive_11 = nn.ConvTranspose2d(1024, 512, kernel_size=3, stride=2,
                                               padding=1, output_padding=1)
        self.expansive_12 = self.conv_block(1024, 512)
        self.expansive_21 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2,
                                               padding=1, output_padding=1)
        self.expansive_22 = self.conv_block(512, 256)
        self.expansive_31 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2,
                                               padding=1, output_padding=1)
        self.expansive_32 = self.conv_block(256, 128)
        self.expansive_41 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2,
                                               padding=1, output_padding=1)
        self.expansive_42 = self.conv_block(128, 64)

        # Per-pixel classifier head.
        self.output = nn.Conv2d(64, num_classes, kernel_size=3, stride=1, padding=1)

    def conv_block(self, in_channels, out_channels):
        """Return two (3x3 conv -> ReLU -> BatchNorm) layers in sequence."""
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder; return class logits."""
        encoder_stages = (
            (self.contracting_11, self.contracting_12),
            (self.contracting_21, self.contracting_22),
            (self.contracting_31, self.contracting_32),
            (self.contracting_41, self.contracting_42),
        )
        decoder_stages = (
            (self.expansive_11, self.expansive_12),
            (self.expansive_21, self.expansive_22),
            (self.expansive_31, self.expansive_32),
            (self.expansive_41, self.expansive_42),
        )

        skips = []
        features = x
        for block, pool in encoder_stages:
            features = block(features)
            skips.append(features)
            features = pool(features)

        features = self.middle(features)

        # Skip connections are consumed deepest-first.
        for upsample, fuse in decoder_stages:
            features = upsample(features)
            features = fuse(torch.cat((features, skips.pop()), dim=1))

        return self.output(features)


class SegmentationDataset(Dataset):
    """Paired image / mask dataset for training.

    Images and masks are paired by position after sorting the file names of
    each directory, so both directories must contain the same number of
    supported files.

    Args:
        image_dir: Directory with input images (.jpg, .png, .tif).
        mask_dir: Directory with single-channel class-index masks.
        image_size: ``(width, height)`` passed to ``cv2.resize``.
        num_classes: Number of classes; mask values are clamped to
            ``[0, num_classes - 1]``.

    Raises:
        ValueError: If the two directories hold a different number of files,
            or (from ``__getitem__``) if a file cannot be decoded.
    """

    IMAGE_EXTENSIONS = ('.jpg', '.png', '.tif')

    def __init__(self, image_dir, mask_dir, image_size=(512, 512), num_classes=9):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_size = tuple(image_size)
        self.num_classes = num_classes
        self.image_files = sorted(f for f in os.listdir(image_dir) if f.endswith(self.IMAGE_EXTENSIONS))
        self.mask_files = sorted(f for f in os.listdir(mask_dir) if f.endswith(self.IMAGE_EXTENSIONS))

        # Positional pairing silently mislabels every sample after a missing
        # file, so fail early instead of training on shifted masks.
        if len(self.image_files) != len(self.mask_files):
            raise ValueError(
                f"Image/mask count mismatch: {len(self.image_files)} images in {image_dir} "
                f"vs {len(self.mask_files)} masks in {mask_dir}"
            )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.image_files[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_files[idx])

        # Load image; cv2.imread returns channels in BGR order and no
        # conversion is applied here.
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Failed to load image: {image_path}")

        # Load mask as GRAYSCALE (single channel)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise ValueError(f"Failed to load mask: {mask_path}")

        # Nearest-neighbour interpolation keeps mask values valid class ids.
        image = cv2.resize(image, self.image_size)
        mask = cv2.resize(mask, self.image_size, interpolation=cv2.INTER_NEAREST)

        # Convert image to tensor [C, H, W] and scale to [0, 1]
        image = torch.tensor(image).permute(2, 0, 1).float() / 255.0

        # Convert mask to tensor [H, W] - NO channel dimension for segmentation masks
        mask = torch.tensor(mask, dtype=torch.long)
        mask = torch.clamp(mask, 0, self.num_classes - 1)  # Ensure valid class indices

        return image, mask

class ValidationDataset(Dataset):
    """Image dataset with optional labels that also yields the file name.

    Each item is ``(image_tensor, label_tensor, filename)``. When no label can
    be found or decoded for an image, an all-zero label map is returned.

    Label lookup: if at least one image file name also exists in
    ``label_dir``, labels are matched by file name (an image without a
    same-named label gets the zero map). Otherwise the sorted label list is
    paired with the sorted image list by position, as long as the index is in
    range.

    Args:
        image_dir: Directory with input images (.jpg, .png, .tif).
        label_dir: Optional directory with single-channel label masks.
        image_size: ``(width, height)`` passed to ``cv2.resize``.
        num_classes: Label values are clamped to ``[0, num_classes - 1]``.
    """

    IMAGE_EXTENSIONS = ('.jpg', '.png', '.tif')

    def __init__(self, image_dir, label_dir=None, image_size=(512, 512), num_classes=9):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.image_size = tuple(image_size)
        self.num_classes = num_classes
        self.image_files = sorted(f for f in os.listdir(image_dir) if f.endswith(self.IMAGE_EXTENSIONS))
        self.has_labels = label_dir is not None

        if self.has_labels:
            self.label_files = sorted(f for f in os.listdir(label_dir) if f.endswith(self.IMAGE_EXTENSIONS))
        else:
            self.label_files = [None] * len(self.image_files)

        self._label_names = {name for name in self.label_files if name}
        # Name matching is used as soon as the two directories share any name.
        self._match_by_name = any(name in self._label_names for name in self.image_files)

    def __len__(self):
        return len(self.image_files)

    def _resolve_label_name(self, idx):
        """Return the label file name for image ``idx``, or None if there is none."""
        if not self.has_labels:
            return None
        if self._match_by_name:
            name = self.image_files[idx]
            return name if name in self._label_names else None
        # Positional fallback; guard the index so a shorter label list
        # yields "no label" instead of IndexError.
        if idx < len(self.label_files):
            return self.label_files[idx]
        return None

    def _load_label(self, label_name):
        """Load and resize a label map; return zeros when it is unavailable."""
        width, height = self.image_size
        empty = torch.zeros((height, width), dtype=torch.long)
        if not label_name:
            return empty

        label_path = os.path.join(self.label_dir, label_name)
        if not os.path.exists(label_path):
            return empty

        # Load label as GRAYSCALE
        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        if label is None:
            return empty

        label = cv2.resize(label, self.image_size, interpolation=cv2.INTER_NEAREST)
        label_tensor = torch.tensor(label, dtype=torch.long)
        return torch.clamp(label_tensor, 0, self.num_classes - 1)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.image_files[idx])
        image = cv2.imread(image_path)

        if image is None:
            raise ValueError(f"Failed to load image: {image_path}")

        image = cv2.resize(image, self.image_size)
        image_tensor = torch.tensor(image).permute(2, 0, 1).float() / 255.0

        label_tensor = self._load_label(self._resolve_label_name(idx))

        return image_tensor, label_tensor, self.image_files[idx]

def apply_color_map(mask, color_map):
    """Render a 2-D class-index mask as an ``[H, W, 3]`` uint8 colour image.

    Pixels whose value is not a key of ``color_map`` stay black (all zeros).
    """
    height, width = mask.shape
    rendered = np.zeros((height, width, 3), dtype=np.uint8)

    for class_id in color_map:
        rendered[mask == class_id] = color_map[class_id]

    return rendered

def create_safe_dataloader(dataset, batch_size, shuffle=False, num_workers=None):
    """Build a DataLoader whose worker count is safe for the current platform.

    ``num_workers`` defaults to the module-level ``OPTIMAL_WORKERS`` and is
    forced to 0 on Windows. Pinned memory is enabled only when CUDA is
    available and loading happens in the main process.
    """
    workers = OPTIMAL_WORKERS if num_workers is None else num_workers
    if platform.system() == "Windows":
        workers = 0

    use_pinned_memory = torch.cuda.is_available() and workers == 0

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=workers,
        pin_memory=use_pinned_memory,
    )

def compute_miou(preds, targets, num_classes=9):
    """Compute mean IoU and class-wise IoU.

    Args:
        preds: Array-like of predicted class ids (any shape).
        targets: Array-like of ground-truth class ids with the same number
            of elements as ``preds``.
        num_classes: Number of classes to evaluate (ids ``0..num_classes-1``).

    Returns:
        tuple: ``(miou, iou_per_class)``. ``iou_per_class[c]`` is NaN when
        class ``c`` does not occur in ``targets``; ``miou`` averages only the
        classes that occur and is NaN when none do.

    Raises:
        ValueError: If ``preds`` and ``targets`` differ in element count.
    """
    preds = np.asarray(preds).ravel()
    targets = np.asarray(targets).ravel()
    if preds.shape != targets.shape:
        raise ValueError(
            f"preds and targets must have the same number of elements, "
            f"got {preds.size} and {targets.size}"
        )

    iou_per_class = np.full(num_classes, np.nan)

    for cls in range(num_classes):
        target_mask = targets == cls
        if not target_mask.any():
            continue  # class absent from ground truth -> stays NaN
        pred_mask = preds == cls
        intersection = np.count_nonzero(target_mask & pred_mask)
        # union > 0 is guaranteed because the class occurs in targets
        union = np.count_nonzero(target_mask | pred_mask)
        iou_per_class[cls] = intersection / union

    present = ~np.isnan(iou_per_class)
    # Average explicitly so an all-absent input yields NaN without a warning.
    miou = iou_per_class[present].mean() if present.any() else np.nan

    return miou, iou_per_class
def compute_entropy(prob_map):
    """Return the entropy over dim 0 of ``prob_map`` as an uncertainty measure.

    A small constant (1e-10) is added inside the logarithm so that zero
    probabilities do not produce ``-inf``.
    """
    log_probs = torch.log(prob_map + 1e-10)
    return -(prob_map * log_probs).sum(dim=0)

def compute_class_performance_gaps(model, val_loader, device, num_classes=9):
    """Compute IoU-based performance gaps for each class with moderated weighting.

    The gap of a class is ``max(0.1, 1 - IoU)`` over the whole validation set;
    classes that never occur in the targets get the floor value 0.1.

    Args:
        model: Network returning logits of shape ``[N, C, H, W]``.
        val_loader: Iterable of ``(images, targets)`` or
            ``(images, targets, filenames)`` batches.
        device: Device the images are moved to before the forward pass.
        num_classes: Number of classes to evaluate.

    Returns:
        np.ndarray: ``[num_classes]`` array of gaps in ``[0.1, 1.0]``.

    Note:
        The model is switched to eval mode and left in it.
    """
    model.eval()

    # Running per-class counts. Accumulating these per batch avoids holding
    # one Python object per validation pixel in memory.
    true_pos = np.zeros(num_classes, dtype=np.int64)
    false_pos = np.zeros(num_classes, dtype=np.int64)
    false_neg = np.zeros(num_classes, dtype=np.int64)

    with torch.no_grad():
        for batch in val_loader:
            # Batches are (images, targets) with an optional third filename entry.
            images, targets = batch[0], batch[1]

            outputs = model(images.to(device))
            preds = torch.argmax(outputs, dim=1).cpu().numpy().ravel()
            truth = targets.cpu().numpy().ravel()

            for cls in range(num_classes):
                pred_mask = preds == cls
                truth_mask = truth == cls
                true_pos[cls] += np.count_nonzero(pred_mask & truth_mask)
                false_pos[cls] += np.count_nonzero(pred_mask & ~truth_mask)
                false_neg[cls] += np.count_nonzero(~pred_mask & truth_mask)

    class_gaps = np.zeros(num_classes)

    for cls in range(num_classes):
        if true_pos[cls] + false_neg[cls] == 0:
            class_gaps[cls] = 0.1  # Minimum gap for classes without samples
            continue

        # The class occurs in the targets, so the denominator is non-zero.
        iou = true_pos[cls] / (true_pos[cls] + false_pos[cls] + false_neg[cls])

        # Floor the gap to prevent zero weights
        class_gaps[cls] = max(0.1, 1.0 - iou)

    return class_gaps


def compute_dynamic_weights(class_gaps, alpha=1.0, use_softer_weighting=True, regularize=True):
    """Turn per-class performance gaps into a normalised weight vector.

    Args:
        class_gaps: Per-class gap values.
        alpha: Exponent applied to the gaps; only used when
            ``use_softer_weighting`` is False.
        use_softer_weighting: Use the square root of the gaps instead of
            ``gaps ** alpha``.
        regularize: Clip every weight to ``[0.1 / n, 2.0 / n]`` (``n`` = number
            of classes) and renormalise so the weights sum to 1 again.

    Returns:
        np.ndarray: Weights summing to 1 (uniform when all gaps are zero).
    """
    n_classes = len(class_gaps)

    # Square root is the softer option; the power form is the original approach.
    scaled_gaps = np.sqrt(class_gaps) if use_softer_weighting else np.power(class_gaps, alpha)

    denominator = np.sum(scaled_gaps)
    if denominator == 0:
        # All gaps are zero: fall back to uniform weighting
        weights = np.ones(n_classes) / n_classes
    else:
        weights = scaled_gaps / denominator

    if not regularize:
        return weights

    # Prevent extreme weighting, then renormalise after clipping
    weights = np.clip(weights, a_min=0.1/n_classes, a_max=2.0/n_classes)
    return weights / weights.sum()

def compute_hybrid_dctu_score(prob_map, dynamic_weights, entropy_weight=0.7, class_weight=0.3):
    """
    Blend plain entropy with a class-weighted entropy into one image-level score.

    Args:
        prob_map: [C, H, W] probability map for each class
        dynamic_weights: [C] weights for each class based on performance gaps
            (tensor or array-like)
        entropy_weight: Weight for standard entropy component
        class_weight: Weight for class-weighted component

    Returns:
        hybrid_score: scalar tensor, ``entropy_weight * mean(entropy) +
        class_weight * mean(sum_c p_c * w_c * entropy)``
    """
    # Bring the class weights onto prob_map's device (non-tensor input also
    # adopts prob_map's dtype).
    if not torch.is_tensor(dynamic_weights):
        weights = torch.tensor(dynamic_weights, dtype=prob_map.dtype, device=prob_map.device)
    else:
        weights = dynamic_weights.to(prob_map.device)

    # Component 1: mean per-pixel entropy, entropy map is [H, W]
    pixel_entropy = -(prob_map * torch.log(prob_map + 1e-10)).sum(dim=0)
    mean_entropy = pixel_entropy.mean()

    # Component 2: entropy scaled by each class probability and its weight
    accumulated = torch.zeros_like(pixel_entropy)
    for channel in range(prob_map.size(0)):
        accumulated += prob_map[channel] * weights[channel] * pixel_entropy
    mean_weighted_entropy = accumulated.mean()

    return entropy_weight * mean_entropy + class_weight * mean_weighted_entropy
def distribute_samples_by_class(total_samples, class_weights, min_per_class=1):
    """
    Split a sample budget across classes in proportion to their weights.

    Every class first receives ``min_per_class`` samples; the rest of the
    budget is shared proportionally to ``class_weights``, with leftover units
    from rounding going to the classes with the largest fractional parts.

    If the budget cannot cover the minimum for every class, it is spread as
    evenly as possible and the remainder goes to the highest-weighted classes,
    so the result never sums to more than ``total_samples``.

    Args:
        total_samples: Total number of samples to distribute.
        class_weights: Array-like of non-negative per-class weights.
        min_per_class: Minimum samples reserved for each class.

    Returns:
        np.ndarray: Integer counts per class summing to ``total_samples``
        (treated as 0 when negative).
    """
    class_weights = np.asarray(class_weights, dtype=float)
    num_classes = len(class_weights)
    if num_classes == 0:
        return np.zeros(0, dtype=int)

    total_samples = max(0, int(total_samples))
    reserved_samples = min_per_class * num_classes

    if total_samples < reserved_samples:
        # Budget is too small for the per-class minimum: never over-allocate.
        base, extra = divmod(total_samples, num_classes)
        samples_per_class = np.full(num_classes, base, dtype=int)
        if extra:
            favoured = np.argsort(-class_weights, kind="stable")[:extra]
            samples_per_class[favoured] += 1
        return samples_per_class

    samples_per_class = np.full(num_classes, min_per_class, dtype=int)
    remaining_samples = total_samples - reserved_samples

    if remaining_samples > 0:
        weight_sum = class_weights.sum()
        if weight_sum > 0:
            normalized_weights = class_weights / weight_sum
        else:
            # Degenerate weights: share the remainder uniformly instead of dividing by zero
            normalized_weights = np.full(num_classes, 1.0 / num_classes)

        ideal_share = normalized_weights * remaining_samples
        additional_samples = ideal_share.astype(int)

        # Hand out the units lost to truncation, largest fractional part first
        shortfall = remaining_samples - int(additional_samples.sum())
        if shortfall > 0:
            fractional_parts = ideal_share - additional_samples
            top_classes = np.argsort(fractional_parts)[-shortfall:]
            additional_samples[top_classes] += 1

        samples_per_class += additional_samples

    return samples_per_class

def validate_model(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` and return summary metrics.

    Args:
        model: Network returning logits of shape ``[N, C, H, W]``.
        val_loader: Loader yielding ``(images, targets)`` or
            ``(images, targets, filenames)`` batches.
        criterion: Loss called as ``criterion(outputs, targets)``.
        device: Device the batches are moved to.

    Returns:
        tuple: ``(accuracy, avg_loss, miou, class_ious, class_accuracies)``
        where ``accuracy`` is pixel accuracy, ``avg_loss`` is the mean batch
        loss (``inf`` for an empty loader), ``miou`` / ``class_ious`` come
        from ``compute_miou`` and ``class_accuracies`` is a list with one
        entry per class (0.0 for classes absent from the targets).

    Note:
        The model is switched to eval mode and left in it.
    """
    model.eval()
    total_correct = 0
    total_pixels = 0
    total_loss = 0
    # One flat array per batch. Extending a Python list pixel by pixel would
    # create one Python object per validation pixel.
    pred_chunks = []
    target_chunks = []

    with torch.no_grad():
        for batch in val_loader:
            # Batches are (images, targets) with an optional third filename entry.
            images, targets = batch[0], batch[1]

            images, targets = images.to(device), targets.to(device)
            outputs = model(images)

            loss = criterion(outputs, targets)
            total_loss += loss.item()

            preds = torch.argmax(outputs, dim=1)
            total_correct += (preds == targets).sum().item()
            total_pixels += targets.numel()

            pred_chunks.append(preds.cpu().numpy().ravel())
            target_chunks.append(targets.cpu().numpy().ravel())

    accuracy = total_correct / total_pixels if total_pixels > 0 else 0.0
    avg_loss = total_loss / len(val_loader) if len(val_loader) > 0 else float('inf')

    # np.concatenate rejects an empty list, so handle an empty loader explicitly.
    if pred_chunks:
        all_preds = np.concatenate(pred_chunks)
        all_targets = np.concatenate(target_chunks)
    else:
        all_preds = np.array([], dtype=np.int64)
        all_targets = np.array([], dtype=np.int64)

    # Compute mIoU over the same class set used for the accuracies below
    miou, class_ious = compute_miou(all_preds, all_targets, num_classes)

    # Class-wise accuracy: share of each class's ground-truth pixels predicted correctly
    class_accuracies = []
    for cls in range(num_classes):
        mask = (all_targets == cls)
        if mask.sum() > 0:
            class_accuracies.append(np.mean(all_preds[mask] == cls))
        else:
            class_accuracies.append(0.0)

    return accuracy, avg_loss, miou, class_ious, class_accuracies

def save_model(model, iteration_dir, iteration):
    """Write a checkpoint for ``model`` into ``iteration_dir`` and return its path.

    A ``nn.DataParallel`` wrapper is unwrapped first, so the stored state dict
    holds the keys of the underlying module.
    """
    model_path = os.path.join(iteration_dir, f'model_iteration_{iteration}.pth')

    # Save the wrapped module's weights when the model is a DataParallel wrapper
    unwrapped = model.module if isinstance(model, nn.DataParallel) else model

    checkpoint = {
        'model_state_dict': unwrapped.state_dict(),
        'iteration': iteration,
        'num_classes': num_classes,
        'model_type': 'UNet'
    }
    torch.save(checkpoint, model_path)

    print(f"✓ Model saved: {model_path}")
    return model_path

def train_model_with_validation(train_loader, val_loader, model, criterion, optimizer,
                               num_epochs, device):
    """Train ``model`` for ``num_epochs`` epochs, validating after every epoch.

    No early stopping is applied: all ``num_epochs`` epochs are always run.

    Args:
        train_loader: Loader yielding ``(images, masks)`` batches.
        val_loader: Loader passed to ``validate_model`` after each epoch.
        model: Network to train; moved to ``device`` and wrapped in
            ``nn.DataParallel`` when several GPUs are visible.
        criterion: Loss called as ``criterion(outputs, masks)``.
        optimizer: Optimizer already bound to the model's parameters.
        num_epochs: Number of epochs to run.
        device: Device used for training and validation.

    Returns:
        tuple: ``(model, train_losses, val_accuracies, val_losses, val_mious,
        epochs_trained)`` with one list entry per epoch.
    """
    model = model.to(device)

    # Wrap only once. A model returned by an earlier call is already wrapped,
    # and nesting DataParallel would add a second "module." prefix to the
    # state-dict keys.
    if torch.cuda.device_count() > 1 and not isinstance(model, nn.DataParallel):
        model = nn.DataParallel(model)

    train_losses = []
    val_accuracies = []
    val_losses = []
    val_mious = []

    print(f"Starting training for {num_epochs} epochs...")

    epochs_trained = 0  # Track actual epochs trained

    for epoch in range(num_epochs):
        # Training
        model.train()
        epoch_loss = 0
        batch_count = 0

        for images, masks in train_loader:
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            batch_count += 1

        # max(1, ...) avoids a division by zero for an empty train loader
        avg_epoch_loss = epoch_loss / max(1, batch_count)
        train_losses.append(avg_epoch_loss)

        # Validation
        val_acc, val_loss, val_miou, _, _ = validate_model(model, val_loader, criterion, device)
        val_accuracies.append(val_acc)
        val_losses.append(val_loss)
        val_mious.append(val_miou)

        epochs_trained = epoch + 1  # Update epochs trained

        # Log every one of the first 10 epochs, then every 5th epoch
        if (epoch + 1) % 5 == 0 or epoch < 10:
            print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {avg_epoch_loss:.4f}, "
                  f"Val Acc: {val_acc:.4f}, Val mIoU: {val_miou:.4f}")

    print(f"Training completed after {epochs_trained} epochs")
    return model, train_losses, val_accuracies, val_losses, val_mious, epochs_trained


def predict_with_dctu(model, data_loader, device, dynamic_weights):
    """Run inference and score every image with ``compute_dctu_score``.

    ``compute_dctu_score`` is not defined in this part of the file and is
    expected to be provided elsewhere in the module.

    Returns:
        tuple: ``(predictions, uncertainties, image_filenames)`` where
        ``predictions`` holds one CPU softmax map per image, ``uncertainties``
        one score per image and ``image_filenames`` the matching names
        (placeholder names when the loader yields no filenames).
    """
    model.eval()
    predictions, uncertainties, image_filenames = [], [], []

    with torch.no_grad():
        for batch in data_loader:
            if len(batch) == 3:
                images, _, filenames = batch
            else:
                images, _ = batch
                filenames = [f"batch_image_{i}" for i in range(images.size(0))]

            probs = F.softmax(model(images.to(device)), dim=1)

            # One DCTU score per image of the batch
            uncertainties.extend(
                [compute_dctu_score(sample_probs, dynamic_weights) for sample_probs in probs]
            )
            predictions.extend(probs.cpu())

            if not isinstance(filenames, (list, tuple)):
                image_filenames.append(filenames)
            else:
                image_filenames.extend(filenames)

    return predictions, uncertainties, image_filenames



def perform_sample_selection_dctu_improved(model, pool_image_dir, pool_label_dir, target_image_dir,
                                         target_label_dir, selected_samples_dir, samples_per_iteration,
                                         device, iteration, val_loader, batch_size=1, alpha=1.0,
                                         use_hybrid=True, use_class_constraint=True):
    """
    Pick the most uncertain pool samples and move them into the training set.

    Steps:
      1. Compute per-class performance gaps on ``val_loader`` and turn them
         into dynamic class weights.
      2. Score every pool image (hybrid score when ``use_hybrid`` is True,
         otherwise ``compute_dctu_score``, which is expected to be defined
         elsewhere in the module).
      3. Take the ``samples_per_iteration`` highest-scoring images, copy each
         image/label pair to ``selected_samples_dir`` for tracking and move it
         from the pool to the target (training) directories.
      4. Write a selection log and a class analysis CSV.

    Labels are looked up under the same file name as the image. When
    ``use_class_constraint`` is True the per-class targets are computed and
    printed, but selection is still purely by uncertainty rank.

    If moving the label fails after the image was already moved, the image is
    moved back to the pool so the training set never holds an image without
    its label.

    Returns:
        tuple: ``(moved_count, successfully_moved)`` - number of moved pairs
        and their file names; ``(0, [])`` when the pool is empty.
    """
    image_extensions = ('.jpg', '.png', '.tif')

    print(f"\n--- Enhanced DCTU-based Sample Selection for Iteration {iteration + 1} ---")

    # Step 1: Compute current class performance gaps (with moderated weighting)
    print("Computing class performance gaps...")
    class_gaps = compute_class_performance_gaps(model, val_loader, device)

    # Step 2: Compute dynamic weights (with regularization and softer weighting)
    dynamic_weights = compute_dynamic_weights(class_gaps, alpha,
                                            use_softer_weighting=True,
                                            regularize=True)

    print("Enhanced Class Performance Analysis:")
    for i, (gap, weight) in enumerate(zip(class_gaps, dynamic_weights)):
        print(f"  {class_names[i]}: Gap={gap:.4f}, Weight={weight:.4f}")

    # Create tracking directories
    os.makedirs(selected_samples_dir, exist_ok=True)
    selected_images_dir = os.path.join(selected_samples_dir, 'images')
    selected_labels_dir = os.path.join(selected_samples_dir, 'labels')
    os.makedirs(selected_images_dir, exist_ok=True)
    os.makedirs(selected_labels_dir, exist_ok=True)

    # Get list of available pool samples
    pool_files = [f for f in os.listdir(pool_image_dir) if f.endswith(image_extensions)]
    print(f"Available pool samples: {len(pool_files)}")

    if len(pool_files) == 0:
        print("✗ No samples available in pool!")
        return 0, []

    if len(pool_files) < samples_per_iteration:
        print(f"⚠ Only {len(pool_files)} samples available, selecting all")
        samples_per_iteration = len(pool_files)

    # Create pool dataset and loader
    pool_dataset = ValidationDataset(pool_image_dir, pool_label_dir)
    pool_loader = create_safe_dataloader(pool_dataset, batch_size=batch_size, shuffle=False)

    # Score every pool image
    print("Computing enhanced uncertainties...")
    model.eval()
    uncertainties = []
    image_filenames = []

    with torch.no_grad():
        for batch in pool_loader:
            if len(batch) == 3:
                images, _, filenames = batch
            else:
                images, _ = batch
                filenames = [f"batch_image_{i}" for i in range(images.size(0))]

            images = images.to(device)
            probs = F.softmax(model(images), dim=1)

            for sample_probs in probs:
                if use_hybrid:
                    score = compute_hybrid_dctu_score(sample_probs, dynamic_weights)
                else:
                    score = compute_dctu_score(sample_probs, dynamic_weights)
                # Store plain numbers so sorting and logging need no further conversion
                uncertainties.append(score.item() if torch.is_tensor(score) else score)

            if isinstance(filenames, (list, tuple)):
                image_filenames.extend(filenames)
            else:
                image_filenames.append(filenames)

    if use_class_constraint:
        # The per-class targets are informational only for now; selection
        # below is by uncertainty rank in either mode.
        samples_per_class = distribute_samples_by_class(samples_per_iteration, dynamic_weights)
        print(f"Target samples per class: {dict(zip(class_names, samples_per_class))}")

    # Highest uncertainty first (stable sort keeps pool order for ties)
    ranked = sorted(zip(uncertainties, image_filenames), key=lambda pair: pair[0], reverse=True)
    selected = ranked[:samples_per_iteration]

    print(f"Top 5 uncertainty scores: {[score for score, _ in selected[:5]]}")

    # Make sure the destinations exist before anything is moved
    os.makedirs(target_image_dir, exist_ok=True)
    os.makedirs(target_label_dir, exist_ok=True)

    # Files already in the training set are skipped to avoid duplicates
    existing_train_files = {f for f in os.listdir(target_image_dir) if f.endswith(image_extensions)}

    moved_count = 0
    successfully_moved = []
    skipped_files = []
    failed_moves = []
    selection_log = []  # For tracking selections

    for unc_value, filename in selected:
        if filename in existing_train_files:
            skipped_files.append(filename)
            continue

        # Source paths
        src_image = os.path.join(pool_image_dir, filename)
        src_label = os.path.join(pool_label_dir, filename)

        # Destination paths (training)
        dst_image = os.path.join(target_image_dir, filename)
        dst_label = os.path.join(target_label_dir, filename)

        # Selected samples paths (for tracking)
        sel_image = os.path.join(selected_images_dir, filename)
        sel_label = os.path.join(selected_labels_dir, filename)

        # Verify source files exist
        if not os.path.exists(src_image) or not os.path.exists(src_label):
            failed_moves.append(f"{filename} (missing source)")
            continue

        try:
            # Copy to selected samples directory (for tracking)
            shutil.copy2(src_image, sel_image)
            shutil.copy2(src_label, sel_label)

            # Move to training directory
            shutil.move(src_image, dst_image)
            try:
                shutil.move(src_label, dst_label)
            except OSError:
                # Keep image and label together: return the image to the pool
                shutil.move(dst_image, src_image)
                raise
        except OSError as e:
            failed_moves.append(f"{filename} (error: {str(e)})")
            print(f"✗ Failed to move {filename}: {e}")
            continue

        moved_count += 1
        successfully_moved.append(filename)

        # Log selection details
        selection_log.append({
            'filename': filename,
            'uncertainty_score': unc_value,
            'iteration': iteration + 1,
            'selection_method': 'Enhanced_DCTU',
            'alpha': alpha,
            'use_hybrid': use_hybrid,
            'use_class_constraint': use_class_constraint
        })

        print(f"✓ Moved: {filename} (Uncertainty score: {unc_value:.4f})")

    # Save enhanced selection log
    if selection_log:
        log_df = pd.DataFrame(selection_log)
        log_path = os.path.join(selected_samples_dir, f'enhanced_dctu_selection_log_iteration_{iteration + 1}.csv')
        log_df.to_csv(log_path, index=False)
        print(f"✓ Enhanced DCTU selection log saved: {log_path}")

        # Also save class analysis
        class_analysis = pd.DataFrame({
            'class_name': class_names,
            'performance_gap': class_gaps,
            'dynamic_weight': dynamic_weights
        })
        class_analysis_path = os.path.join(selected_samples_dir, f'enhanced_class_analysis_iteration_{iteration + 1}.csv')
        class_analysis.to_csv(class_analysis_path, index=False)
        print(f"✓ Enhanced class analysis saved: {class_analysis_path}")

    # Summary
    print(f"\n--- Enhanced DCTU Selection Summary ---")
    print(f"Alpha parameter: {alpha}")
    print(f"Hybrid approach: {use_hybrid}")
    print(f"Class constraint: {use_class_constraint}")
    print(f"Requested: {samples_per_iteration}")
    print(f"Successfully moved: {moved_count}")
    print(f"Skipped (duplicates): {len(skipped_files)}")
    print(f"Failed: {len(failed_moves)}")

    # Update pool size
    remaining_pool = len([f for f in os.listdir(pool_image_dir) if f.endswith(image_extensions)])
    print(f"Remaining in pool: {remaining_pool}")

    return moved_count, successfully_moved
def create_test_prediction_heatmaps(model, test_loader, device, save_dir, iteration, class_names,
                                    max_confidence_batches=20, max_samples=15, samples_per_batch=2):
    """Save confidence histograms and sample prediction figures for the test set.

    Two kinds of PNG files are written to ``<save_dir>/test_predictions``:

    1. ``confidence_distributions_iter_<iteration>.png`` - one histogram per
       class of the softmax probabilities, collected over the first
       ``max_confidence_batches`` batches.
    2. ``sample_<k>_iter_<iteration>.png`` - image / ground truth / prediction
       panels for up to ``max_samples`` samples, at most ``samples_per_batch``
       from each batch.

    Args:
        model: Network returning logits of shape ``[N, C, H, W]``.
        test_loader: Loader yielding ``(images, targets)`` or
            ``(images, targets, filenames)`` batches.
        device: Device used for inference.
        save_dir: Parent directory for the output folder.
        iteration: Iteration number used in titles and file names.
        class_names: Class names, indexed by class id.
        max_confidence_batches: Batches used for the histograms.
        max_samples: Maximum number of sample figures.
        samples_per_batch: Maximum sample figures taken from one batch.
    """
    model.eval()

    # Create subdirectory for test predictions
    test_pred_dir = os.path.join(save_dir, 'test_predictions')
    os.makedirs(test_pred_dir, exist_ok=True)

    # 1. Class-wise confidence distribution (3 columns, as many rows as needed)
    n_cols = 3
    n_rows = max(1, (len(class_names) + n_cols - 1) // n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
    axes = axes.flatten()

    # One list of flat arrays per class. Per-pixel Python lists would need
    # one Python object per probability value.
    confidence_chunks = [[] for _ in class_names]

    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            if batch_idx >= max_confidence_batches:
                break

            # Only the images are needed here; targets / filenames are ignored
            images = batch[0].to(device)
            probs = F.softmax(model(images), dim=1).cpu().numpy()

            for cls in range(len(class_names)):
                confidence_chunks[cls].append(probs[:, cls, :, :].ravel())

    # Plot confidence distributions
    for i, class_name in enumerate(class_names):
        if i < len(axes) and confidence_chunks[i]:
            conf_array = np.concatenate(confidence_chunks[i])
            axes[i].hist(conf_array, bins=50, alpha=0.7, color=plt.cm.tab10(i))
            # Accumulate the mean in float64 to avoid float32 rounding on large arrays
            axes[i].set_title(f'{class_name}\nMean: {conf_array.mean(dtype=np.float64):.3f}', fontweight='bold')
            axes[i].set_xlabel('Confidence Score')
            axes[i].set_ylabel('Frequency')
            axes[i].grid(True, alpha=0.3)

    plt.suptitle(f'Test Set Class Confidence Distributions - Iteration {iteration}', fontsize=16)
    plt.tight_layout()
    plt.savefig(os.path.join(test_pred_dir, f'confidence_distributions_iter_{iteration}.png'),
                dpi=300, bbox_inches='tight')
    plt.close()

    # 2. Sample predictions visualization
    sample_count = 0

    # The legend is identical for every figure
    legend_elements = [Patch(facecolor=np.array(COLOR_MAP[cls]) / 255.0, label=class_names[cls])
                       for cls in range(len(class_names))]

    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            # Stop as soon as enough figures exist, so the rest of the test
            # set is not forwarded through the model for nothing.
            if sample_count >= max_samples:
                break

            if len(batch) == 3:
                images, targets, filenames = batch
            else:
                images, targets = batch
                filenames = [f"sample_{batch_idx}_{i}" for i in range(images.size(0))]

            images = images.to(device)
            outputs = model(images)
            predictions = torch.argmax(outputs, dim=1)

            for i in range(min(samples_per_batch, images.size(0))):
                if sample_count >= max_samples:
                    break

                # Create subplot for this sample
                fig, axes = plt.subplots(1, 3, figsize=(15, 5))

                # Original image.
                # NOTE(review): the dataset classes in this file load images
                # with cv2.imread (BGR) and do not convert them, so the channel
                # order is reversed for matplotlib, which expects RGB. Confirm
                # that test_loader is built from one of those datasets.
                img = images[i].cpu().permute(1, 2, 0).numpy()[..., ::-1]
                axes[0].imshow(img)
                axes[0].set_title('Original Image')
                axes[0].axis('off')

                # Ground truth
                gt = targets[i].cpu().numpy()
                axes[1].imshow(apply_color_map(gt, COLOR_MAP))
                axes[1].set_title('Ground Truth')
                axes[1].axis('off')

                # Prediction
                pred = predictions[i].cpu().numpy()
                axes[2].imshow(apply_color_map(pred, COLOR_MAP))
                axes[2].set_title('Prediction')
                axes[2].axis('off')

                # Legend
                fig.legend(handles=legend_elements, loc='lower center', ncol=3, bbox_to_anchor=(0.5, -0.05))

                filename = filenames[i] if isinstance(filenames[i], str) else f"sample_{sample_count}"
                plt.suptitle(f'Test Sample: {filename} - Iteration {iteration}')
                plt.tight_layout()
                plt.savefig(os.path.join(test_pred_dir, f'sample_{sample_count}_iter_{iteration}.png'),
                           dpi=300, bbox_inches='tight')
                plt.close()

                sample_count += 1


def verify_data_consistency(image_dir, label_dir, dataset_name="Dataset"):
    """Verify that every image in ``image_dir`` has a same-named label file.

    Only files whose names end in ``.jpg``, ``.png`` or ``.tif`` are
    considered. The outcome of each check is printed.

    Args:
        image_dir: Directory holding the input images.
        label_dir: Directory holding the label masks.
        dataset_name: Prefix used in the printed messages.

    Returns:
        True when both directories exist, hold the same number of supported
        files, and every image filename is also present among the labels;
        False otherwise.
    """
    # isdir (not exists): a path that points at a regular file must be
    # reported as missing instead of making os.listdir raise below.
    if not os.path.isdir(image_dir) or not os.path.isdir(label_dir):
        print(f" {dataset_name}: Directories not found")
        return False

    extensions = ('.jpg', '.png', '.tif')
    image_files = sorted(f for f in os.listdir(image_dir) if f.endswith(extensions))
    label_files = sorted(f for f in os.listdir(label_dir) if f.endswith(extensions))

    print(f"{dataset_name}: {len(image_files)} images, {len(label_files)} labels")

    if len(image_files) != len(label_files):
        print(f" {dataset_name}: Mismatch in number of images and labels")
        return False

    # Pairing is by exact filename. A set keeps the check linear in the
    # number of files instead of scanning the label list once per image.
    label_names = set(label_files)
    missing_pairs = [name for name in image_files if name not in label_names]

    if missing_pairs:
        # Only the first few offenders are shown to keep the message short.
        print(f" {dataset_name}: Missing label files: {missing_pairs[:5]}...")
        return False

    print(f"✓ {dataset_name}: Data consistency verified")
    return True

def plot_metrics(metrics_history, save_dir, iteration):
    """Draw the metric history and save it as PNG files.

    Two figures are written to ``<save_dir>/metrics``:
    ``metrics_iteration_<iteration>.png`` (a 2x3 grid of individual metrics)
    and ``performance_comparison_iteration_<iteration>.png`` (validation
    accuracy and mIoU on one chart). Returns immediately, without creating
    anything, when ``metrics_history`` is empty.

    Args:
        metrics_history: Sequence of per-iteration dicts; a key missing from
            an entry is plotted as 0.
        save_dir: Parent directory for the ``metrics`` output folder.
        iteration: Iteration number used in the figure title and file names.
    """
    if not metrics_history:
        return

    out_dir = os.path.join(save_dir, 'metrics')
    os.makedirs(out_dir, exist_ok=True)

    def series(key):
        # One value per recorded iteration, 0 where the key is absent.
        return [record.get(key, 0) for record in metrics_history]

    steps = list(range(1, len(metrics_history) + 1))
    accuracy_curve = series('val_accuracy')
    miou_curve = series('val_miou')

    heading = {'fontsize': 14, 'fontweight': 'bold'}
    stroke = {'linewidth': 2, 'markersize': 6}

    # ---- Figure 1: one panel per metric ----
    _, panels = plt.subplots(2, 3, figsize=(18, 12))

    # (grid cell, values, format string, extra plot kwargs, title, y label)
    line_panels = (
        ((0, 0), series('train_loss'), 'b-o', {}, 'Training Loss', 'Loss'),
        ((0, 1), accuracy_curve, 'g-o', {}, 'Validation Accuracy', 'Accuracy'),
        ((0, 2), series('val_loss'), 'r-o', {}, 'Validation Loss', 'Loss'),
        ((1, 0), miou_curve, 'purple', {'marker': 'o'}, 'Validation mIoU', 'mIoU'),
    )
    # (grid cell, values, bar colour, title, y label)
    bar_panels = (
        ((1, 1), series('samples_added'), 'orange',
         'Samples Added per Iteration', 'Number of Samples'),
        ((1, 2), series('epochs_trained'), 'brown',
         'Epochs Trained per Iteration', 'Number of Epochs'),
    )

    def decorate(panel, title, y_label):
        # Shared cosmetics for every panel of the overview grid.
        panel.set_title(title, **heading)
        panel.set_xlabel('Iteration')
        panel.set_ylabel(y_label)
        panel.grid(True, alpha=0.3)

    for cell, values, fmt, extra, title, y_label in line_panels:
        panel = panels[cell]
        panel.plot(steps, values, fmt, **extra, **stroke)
        decorate(panel, title, y_label)

    for cell, values, colour, title, y_label in bar_panels:
        panel = panels[cell]
        panel.bar(steps, values, color=colour, alpha=0.7)
        decorate(panel, title, y_label)

    plt.suptitle(f'Active Learning Progress - Up to Iteration {iteration}',
                 fontsize=16, fontweight='bold')
    plt.tight_layout()
    overview_path = os.path.join(out_dir, f'metrics_iteration_{iteration}.png')
    plt.savefig(overview_path, dpi=300, bbox_inches='tight')
    plt.close()

    # ---- Figure 2: accuracy and mIoU on a shared axis ----
    _, combined = plt.subplots(1, 1, figsize=(10, 6))
    combined.plot(steps, accuracy_curve, 'g-o', label='Validation Accuracy', **stroke)
    combined.plot(steps, miou_curve, 'purple', marker='s', label='Validation mIoU', **stroke)
    combined.set_title('Model Performance Over Iterations', **heading)
    combined.set_xlabel('Iteration')
    combined.set_ylabel('Score')
    combined.legend()
    combined.grid(True, alpha=0.3)
    plt.tight_layout()
    comparison_path = os.path.join(out_dir, f'performance_comparison_iteration_{iteration}.png')
    plt.savefig(comparison_path, dpi=300, bbox_inches='tight')
    plt.close()

def save_detailed_results(results, save_dir, iteration):
    """Write the metric history to CSV and a plain-text summary report.

    Both files go to ``<save_dir>/results``:
    ``detailed_metrics_iteration_<iteration>.csv`` holds one row per entry
    of ``results``; ``summary_report_iteration_<iteration>.txt`` lists the
    latest entry's metrics, its class-wise scores when present, and (with
    more than one entry) the change between the first and last entry.

    Args:
        results: List of per-iteration metric dicts; missing keys read as 0.
        save_dir: Parent directory for the ``results`` output folder.
        iteration: Iteration number used in the header and file names.
    """
    out_dir = os.path.join(save_dir, 'results')
    os.makedirs(out_dir, exist_ok=True)

    # Full history as a table.
    csv_path = os.path.join(out_dir, f'detailed_metrics_iteration_{iteration}.csv')
    pd.DataFrame(results).to_csv(csv_path, index=False)

    report_path = os.path.join(out_dir, f'summary_report_iteration_{iteration}.txt')
    with open(report_path, 'w') as report:
        emit = report.write

        emit(f"Active Learning Summary - Iteration {iteration}\n")
        emit("=" * 50 + "\n\n")

        if results:
            current = results[-1]
            emit("Current Performance:\n")

            # (caption, metric key, format spec); an empty spec prints the
            # value as-is, which is what the count-like fields use.
            headline = (
                ('Validation Accuracy', 'val_accuracy', '.4f'),
                ('Validation mIoU', 'val_miou', '.4f'),
                ('Training Loss', 'train_loss', '.4f'),
                ('Validation Loss', 'val_loss', '.4f'),
                ('Samples Added', 'samples_added', ''),
                ('Epochs Trained', 'epochs_trained', ''),
            )
            for caption, key, spec in headline:
                emit(f"  - {caption}: {format(current.get(key, 0), spec)}\n")
            emit("\n")

            if 'class_accuracies' in current:
                emit("Class-wise Accuracies:\n")
                for idx, acc in enumerate(current['class_accuracies']):
                    emit(f"  - {class_names[idx]}: {acc:.4f}\n")
                emit("\n")

            if 'class_ious' in current:
                emit("Class-wise IoUs:\n")
                for idx, iou in enumerate(current['class_ious']):
                    # NaN IoU entries are left out of the report.
                    if np.isnan(iou):
                        continue
                    emit(f"  - {class_names[idx]}: {iou:.4f}\n")
                emit("\n")

        # First-versus-latest comparison needs at least two entries.
        if len(results) > 1:
            emit("Progress Summary:\n")
            baseline, newest = results[0], results[-1]

            acc_gain = newest.get('val_accuracy', 0) - baseline.get('val_accuracy', 0)
            miou_gain = newest.get('val_miou', 0) - baseline.get('val_miou', 0)
            emit(f"  - Accuracy Improvement: {acc_gain:+.4f}\n")
            emit(f"  - mIoU Improvement: {miou_gain:+.4f}\n")

            added_total = sum(entry.get('samples_added', 0) for entry in results)
            emit(f"  - Total Samples Added: {added_total}\n")

    print(f"✓ Detailed results saved: {out_dir}")

def calculate_class_weights(train_images_dir, train_labels_dir, num_classes=9):
    """Compute inverse-pixel-frequency class weights from the label masks.

    Every ``.jpg`` / ``.png`` / ``.tif`` file in ``train_labels_dir`` is read
    as a grayscale image, resized to 512x512 with nearest-neighbour
    interpolation, clipped to ``[0, num_classes - 1]`` and its pixels are
    counted per class. Files that cannot be read are skipped.

    Classes that occur get a weight proportional to
    ``total_pixels / (num_classes * count)``; those weights are rescaled to
    average 1. Classes that never occur get the neutral weight 1.0 rather
    than a near-infinite one, which would otherwise squash every other
    weight to ~0 during normalisation. The returned weights therefore
    always sum to ``num_classes``.

    Args:
        train_images_dir: Not used by this function; kept so existing
            callers keep working.
        train_labels_dir: Directory containing the label masks.
        num_classes: Number of segmentation classes.

    Returns:
        A float32 tensor of length ``num_classes``. When no label pixels
        could be read at all, uniform weights (all ones) are returned
        instead of NaNs.
    """
    print("Calculating class weights...")

    class_counts = np.zeros(num_classes)
    total_pixels = 0

    label_files = [f for f in os.listdir(train_labels_dir) if f.endswith(('.jpg', '.png', '.tif'))]

    for label_file in label_files:
        label = cv2.imread(os.path.join(train_labels_dir, label_file), cv2.IMREAD_GRAYSCALE)
        if label is None:
            # cv2.imread signals an unreadable file by returning None.
            continue

        # Nearest-neighbour keeps the resized values valid class ids.
        label = cv2.resize(label, (512, 512), interpolation=cv2.INTER_NEAREST)
        label = np.clip(label, 0, num_classes - 1)

        # One histogram pass instead of one full-image comparison per class.
        class_counts += np.bincount(label.ravel(), minlength=num_classes)
        total_pixels += label.size

    if total_pixels == 0:
        # Nothing to estimate frequencies from; normalising would be 0/0.
        print("  No readable label pixels found; using uniform class weights.")
        return torch.ones(num_classes, dtype=torch.float32)

    present = class_counts > 0
    class_weights = np.ones(num_classes)
    inverse_freq = total_pixels / (num_classes * class_counts[present])
    # Rescale so the observed classes average 1; with the neutral 1.0 of the
    # unobserved classes the total stays at num_classes.
    class_weights[present] = inverse_freq / inverse_freq.sum() * present.sum()

    print("Class weights:")
    for i, weight in enumerate(class_weights):
        # Fall back to a generic label if num_classes exceeds the name list.
        name = class_names[i] if i < len(class_names) else f"Class {i}"
        print(f"  {name}: {weight:.4f}")

    return torch.FloatTensor(class_weights)
def main():
    """Run the active-learning pipeline for semantic segmentation.

    Steps, in order:
      1. Query system capabilities and store the recommended worker count in
         the module-level ``OPTIMAL_WORKERS``.
      2. Create (if missing) and verify the train / pool / validation / test
         directories under ``BASE_DIR``; return early if any pair fails
         verification.
      3. For up to ``MAX_ITERATIONS`` rounds: build a fresh ``UNet``, train
         it with class-weighted cross-entropy, save it, evaluate on the test
         set, write prediction figures, run sample selection (skipped in the
         last round), then record, plot and save the metrics.
      4. Print a final summary comparing the first and last recorded round.

    The loop stops early when the training directory holds no images or
    when sample selection reports that nothing was added. In the latter
    case the round's metrics are still recorded and saved before stopping.
    """

    # System setup
    print("ACTIVE LEARNING PIPELINE FOR SEMANTIC SEGMENTATION")
    print("=" * 60)

    # Published as a module global; presumably read by create_safe_dataloader
    # (defined elsewhere in this file) -- verify there.
    global OPTIMAL_WORKERS
    OPTIMAL_WORKERS = check_system_capabilities()

    # Configure device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nUsing device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

    # Directory setup
    BASE_DIR = r"D:\DATA\data\data"
    RESULTS_DIR = os.path.join(BASE_DIR, "results_dcau2")

    # Data directories
    TRAIN_IMAGES = os.path.join(BASE_DIR, "train_data")
    TRAIN_LABELS = os.path.join(BASE_DIR, "train_labels")
    POOL_IMAGES = os.path.join(BASE_DIR, "Unlabeled_data")
    POOL_LABELS = os.path.join(BASE_DIR, "Validation_labels")
    VAL_IMAGES = os.path.join(BASE_DIR, "val_img")
    VAL_LABELS = os.path.join(BASE_DIR, "val_lab")
    TEST_IMAGES = os.path.join(BASE_DIR, "test_img")
    TEST_LABELS = os.path.join(BASE_DIR, "test_lab")

    # Create directories
    for dir_path in [RESULTS_DIR, TRAIN_IMAGES, TRAIN_LABELS, POOL_IMAGES, POOL_LABELS,
                     VAL_IMAGES, VAL_LABELS, TEST_IMAGES, TEST_LABELS]:
        os.makedirs(dir_path, exist_ok=True)

    # Verify data consistency
    print("\n--- Data Verification ---")
    train_ok = verify_data_consistency(TRAIN_IMAGES, TRAIN_LABELS, "Training")
    pool_ok = verify_data_consistency(POOL_IMAGES, POOL_LABELS, "Pool")
    val_ok = verify_data_consistency(VAL_IMAGES, VAL_LABELS, "Validation")
    test_ok = verify_data_consistency(TEST_IMAGES, TEST_LABELS, "Test")

    if not all([train_ok, pool_ok, val_ok, test_ok]):
        print("Data consistency issues found. Please check your data directories.")
        return

    # Active learning parameters
    MAX_ITERATIONS = 15
    SAMPLES_PER_ITERATION = 25
    EPOCHS_PER_ITERATION = 25
    BATCH_SIZE = 8
    LEARNING_RATE = 0.0001
    # PATIENCE = 20
    # MIN_DELTA = 0.001

    print(f"\n--- Active Learning Configuration ---")
    print(f"Max Iterations: {MAX_ITERATIONS}")
    print(f"Samples per Iteration: {SAMPLES_PER_ITERATION}")
    print(f"Epochs per Iteration: {EPOCHS_PER_ITERATION}")
    print(f"Batch Size: {BATCH_SIZE}")
    print(f"Learning Rate: {LEARNING_RATE}")
    # print(f"Early Stopping Patience: {PATIENCE}")

    # Initialize metrics tracking
    metrics_history = []

    # Create validation dataset (remains constant)
    val_dataset = SegmentationDataset(VAL_IMAGES, VAL_LABELS)
    val_loader = create_safe_dataloader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Create test dataset (for evaluation)
    test_dataset = ValidationDataset(TEST_IMAGES, TEST_LABELS)
    test_loader = create_safe_dataloader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    print(f"\nValidation samples: {len(val_dataset)}")
    print(f"Test samples: {len(test_dataset)}")

    # Active learning loop
    for iteration in range(MAX_ITERATIONS):
        print(f"\n{'='*60}")
        print(f"ACTIVE LEARNING ITERATION {iteration + 1}/{MAX_ITERATIONS}")
        print(f"{'='*60}")

        # Create iteration directory
        iteration_dir = os.path.join(RESULTS_DIR, f"iteration_{iteration + 1}")
        os.makedirs(iteration_dir, exist_ok=True)

        # Check if we have training data
        train_files = [f for f in os.listdir(TRAIN_IMAGES) if f.endswith(('.jpg', '.png', '.tif'))]
        if len(train_files) == 0:
            print(" No training samples available. Skipping iteration.")
            break

        print(f"Current training samples: {len(train_files)}")

        # Create training dataset and loader (rebuilt every round so that
        # files present in the training directory now are picked up)
        train_dataset = SegmentationDataset(TRAIN_IMAGES, TRAIN_LABELS)
        train_loader = create_safe_dataloader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

        # Initialize model: a new UNet each round, not a continuation of the
        # previous round's weights
        model = UNet(num_classes=num_classes)
        print("\n--- Calculating Initial Class Weights ---")
        class_weights = calculate_class_weights(TRAIN_IMAGES, TRAIN_LABELS, num_classes)

        class_weights = class_weights.to(device)
        criterion = nn.CrossEntropyLoss(weight=class_weights)
        # criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

        # Train model
        print(f"\n--- Training Model ---")
        start_time = time.time()
        (model, train_losses, val_accuracies, val_losses, val_mious,
         epochs_trained) = train_model_with_validation(
            train_loader, val_loader, model, criterion, optimizer,
            EPOCHS_PER_ITERATION, device)
        training_time = time.time() - start_time
        print(f"Training completed in {training_time:.2f} seconds")

        # Save model
        save_model(model, iteration_dir, iteration + 1)

        # Evaluate on test set
        print(f"\n--- Test Evaluation ---")
        test_acc, test_loss, test_miou, test_class_ious, test_class_accuracies = validate_model(
            model, test_loader, criterion, device
        )

        print(f"Test Accuracy: {test_acc:.4f}")
        print(f"Test mIoU: {test_miou:.4f}")
        print(f"Test Loss: {test_loss:.4f}")

        # Create test predictions and visualizations
        create_test_prediction_heatmaps(model, test_loader, device, iteration_dir, iteration + 1, class_names)

        # Store metrics (train/val values are those of the last entry in the
        # lists returned by training; class-wise values come from the test set)
        iteration_metrics = {
            'iteration': iteration + 1,
            'train_loss': train_losses[-1] if train_losses else 0,
            'val_accuracy': val_accuracies[-1] if val_accuracies else 0,
            'val_loss': val_losses[-1] if val_losses else 0,
            'val_miou': val_mious[-1] if val_mious else 0,
            'test_accuracy': test_acc,
            'test_loss': test_loss,
            'test_miou': test_miou,
            'epochs_trained': epochs_trained,
            'class_accuracies': test_class_accuracies,
            'class_ious': test_class_ious.tolist() if hasattr(test_class_ious, 'tolist') else test_class_ious,
            'class_weights': class_weights.cpu().tolist(),
            'samples_added': 0  # Will be updated after sample selection
        }

        # Sample selection for next iteration (if not last iteration)
        stop_early = False
        if iteration < MAX_ITERATIONS - 1:
            print(f"\n--- Sample Selection ---")
            selected_samples_dir = os.path.join(iteration_dir, "selected_samples")

            samples_added, _ = perform_sample_selection_dctu_improved(
                model, POOL_IMAGES, POOL_LABELS, TRAIN_IMAGES, TRAIN_LABELS,
                selected_samples_dir, SAMPLES_PER_ITERATION, device, iteration, val_loader, BATCH_SIZE
            )

            iteration_metrics['samples_added'] = samples_added

            if samples_added == 0:
                print(" No samples could be added. Stopping active learning.")
                # Do not break here: this round was fully trained and
                # evaluated, so its metrics must still be recorded, plotted
                # and saved, and the model released, before leaving the loop.
                stop_early = True

        metrics_history.append(iteration_metrics)

        # Plot and save results
        plot_metrics(metrics_history, RESULTS_DIR, iteration + 1)
        save_detailed_results(metrics_history, RESULTS_DIR, iteration + 1)

        # Memory cleanup
        del model, train_loader, train_dataset
        torch.cuda.empty_cache()
        gc.collect()

        print(f"\n--- Iteration {iteration + 1} Complete ---")
        # Size of the set this round was trained on (counted before selection).
        print(f"Training samples: {len(train_files)}")
        print(f"Validation Accuracy: {iteration_metrics['val_accuracy']:.4f}")
        print(f"Test Accuracy: {iteration_metrics['test_accuracy']:.4f}")
        print(f"Test mIoU: {iteration_metrics['test_miou']:.4f}")

        if stop_early:
            break

    # Final summary
    print(f"\n{'='*60}")
    print("ACTIVE LEARNING COMPLETED")
    print(f"{'='*60}")

    if metrics_history:
        print(f"Total iterations: {len(metrics_history)}")
        print(f"Final test accuracy: {metrics_history[-1]['test_accuracy']:.4f}")
        print(f"Final test mIoU: {metrics_history[-1]['test_miou']:.4f}")

        # Calculate improvement
        if len(metrics_history) > 1:
            acc_improvement = metrics_history[-1]['test_accuracy'] - metrics_history[0]['test_accuracy']
            miou_improvement = metrics_history[-1]['test_miou'] - metrics_history[0]['test_miou']
            print(f"Accuracy improvement: {acc_improvement:+.4f}")
            print(f"mIoU improvement: {miou_improvement:+.4f}")

    print(f"\nResults saved in: {RESULTS_DIR}")
    print("Active learning pipeline completed successfully!")

# Entry point: run the pipeline only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

ACTIVE LEARNING PIPELINE FOR SEMANTIC SEGMENTATION
SYSTEM CAPABILITY CHECK
Operating System: Windows 11
Physical cores: 24
Logical cores: 32
Total RAM: 127.72 GB
CUDA available: Yes, 1 GPU(s)
Recommended num_workers: 0

Using device: cuda
GPU: NVIDIA GeForce RTX 3060
GPU Memory: 12.0 GB

--- Data Verification ---
Training: 512 images, 512 labels
✓ Training: Data consistency verified
Pool: 2414 images, 2414 labels
✓ Pool: Data consistency verified
Validation: 250 images, 250 labels
✓ Validation: Data consistency verified
Test: 250 images, 250 labels
✓ Test: Data consistency verified

--- Active Learning Configuration ---
Max Iterations: 15
Samples per Iteration: 25
Epochs per Iteration: 25
Batch Size: 8
Learning Rate: 0.0001

Validation samples: 250
Test samples: 250

ACTIVE LEARNING ITERATION 1/15
Current training samples: 512

--- Calculating Initial Class Weights ---
Calculating class weights...
Class weights:
  Background: 7.4063
  Bareland: 0.7718
  Rangeland: 0.0549
  Developed Sp