# In [None]:
# Enhancements to Active Learning Pipeline


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pandas as pd
import os
import cv2
import time
import shutil
import matplotlib.pyplot as plt
from sklearn.metrics import jaccard_score
import warnings
# Silence every warning globally (this also hides deprecation warnings).
warnings.filterwarnings("ignore")

# Label names indexed by integer class id: position 0 is "Background", etc.
class_names = [
    "Background",
    "Bareland",
    "Rangeland",
    "Developed Space",
    "Road",
    "Tree",
    "Water",
    "Agriculture Land",
    "Building",
]
# Total number of segmentation classes, derived from the list above.
num_classes = len(class_names)

# UNet Model Definition
class UNet(nn.Module):
    """Four-level U-Net that maps an image to per-pixel class logits.

    Args:
        num_classes: number of output channels (one logit map per class).
        in_channels: number of channels in the input image. Defaults to 3,
            matching the colour images loaded by the datasets in this file.

    Input is ``(batch, in_channels, H, W)`` and output is
    ``(batch, num_classes, H, W)``. The encoder max-pools four times and the
    decoder concatenates each upsampled map with the matching encoder map, so
    ``H`` and ``W`` must be divisible by 16 for the shapes to line up.
    """

    def __init__(self, num_classes=9, in_channels=3):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels

        # Encoder: a conv block per level, each followed by a 2x2 max-pool.
        self.contracting_11 = self.conv_block(in_channels, 64)
        self.contracting_12 = nn.MaxPool2d(2, 2)
        self.contracting_21 = self.conv_block(64, 128)
        self.contracting_22 = nn.MaxPool2d(2, 2)
        self.contracting_31 = self.conv_block(128, 256)
        self.contracting_32 = nn.MaxPool2d(2, 2)
        self.contracting_41 = self.conv_block(256, 512)
        self.contracting_42 = nn.MaxPool2d(2, 2)

        # Bottleneck at 1/16 of the input resolution.
        self.middle = self.conv_block(512, 1024)

        # Decoder: ConvTranspose2d(kernel=3, stride=2, padding=1, output_padding=1)
        # exactly doubles H and W. Each following conv block receives the
        # upsampled map concatenated with the encoder skip, hence 2x channels in.
        self.expansive_11 = nn.ConvTranspose2d(1024, 512, 3, 2, 1, 1)
        self.expansive_12 = self.conv_block(1024, 512)
        self.expansive_21 = nn.ConvTranspose2d(512, 256, 3, 2, 1, 1)
        self.expansive_22 = self.conv_block(512, 256)
        self.expansive_31 = nn.ConvTranspose2d(256, 128, 3, 2, 1, 1)
        self.expansive_32 = self.conv_block(256, 128)
        self.expansive_41 = nn.ConvTranspose2d(128, 64, 3, 2, 1, 1)
        self.expansive_42 = self.conv_block(128, 64)

        # Final 3x3 conv producing one logit map per class.
        self.output = nn.Conv2d(64, num_classes, 3, 1, 1)

    def conv_block(self, in_channels, out_channels):
        """Return two 3x3 same-padding convolutions, each followed by ReLU then BatchNorm.

        The Conv -> ReLU -> BatchNorm order fixes the Sequential indices and
        therefore the state_dict keys; reordering would break saved checkpoints.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        """Return raw (pre-softmax) class logits with the same H and W as ``x``."""
        # Encoder; c1..c4 are kept as skip connections.
        c1 = self.contracting_11(x)
        p1 = self.contracting_12(c1)
        c2 = self.contracting_21(p1)
        p2 = self.contracting_22(c2)
        c3 = self.contracting_31(p2)
        p3 = self.contracting_32(c3)
        c4 = self.contracting_41(p3)
        p4 = self.contracting_42(c4)

        middle = self.middle(p4)

        # Decoder: upsample, concatenate the matching skip along channels, convolve.
        u1 = self.expansive_11(middle)
        u1 = self.expansive_12(torch.cat((u1, c4), dim=1))
        u2 = self.expansive_21(u1)
        u2 = self.expansive_22(torch.cat((u2, c3), dim=1))
        u3 = self.expansive_31(u2)
        u3 = self.expansive_32(torch.cat((u3, c2), dim=1))
        u4 = self.expansive_41(u3)
        u4 = self.expansive_42(torch.cat((u4, c1), dim=1))
        return self.output(u4)

class SegmentationDataset(Dataset):
    """Paired image/mask dataset used for training.

    Images and masks are paired by position after sorting each directory's
    file names, so both directories must hold the same number of files whose
    sorted orders correspond (e.g. identical file names).

    Each item is ``(image, mask)``:
      * image: float tensor ``(3, 512, 512)`` scaled to [0, 1], channel order
        as returned by ``cv2.imread`` (BGR).
      * mask: long tensor (``(512, 512)`` for a single-channel mask file) with
        values clamped to ``[0, num_classes - 1]``.

    Args:
        image_dir: directory of training images.
        mask_dir: directory of label masks.
        num_classes: number of classes; mask values above ``num_classes - 1``
            are clamped down to it. Defaults to 9.

    Raises:
        ValueError: if the two directories contain different numbers of files.
    """

    EXTENSIONS = ('.jpg', '.png', '.tif')

    def __init__(self, image_dir, mask_dir, num_classes=9):
        self.image_paths = self._list_files(image_dir)
        self.mask_paths = self._list_files(mask_dir)
        self.num_classes = num_classes
        # Positional pairing would silently shift every later pair (or overrun
        # the mask list) if one side is missing a file, so fail fast instead
        # of training on mismatched labels.
        if len(self.image_paths) != len(self.mask_paths):
            raise ValueError(
                f"Found {len(self.image_paths)} images in {image_dir} but "
                f"{len(self.mask_paths)} masks in {mask_dir}; "
                "every image needs exactly one mask"
            )

    @classmethod
    def _list_files(cls, directory):
        """Return the sorted full paths of files in ``directory`` with a known image extension."""
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.endswith(cls.EXTENSIONS)
        )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image file not found: {image_path}")
        if not os.path.exists(mask_path):
            raise FileNotFoundError(f"Mask file not found: {mask_path}")

        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Failed to load image: {image_path}")

        # IMREAD_UNCHANGED keeps the stored bit depth / channels of the mask.
        mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        if mask is None:
            raise ValueError(f"Failed to load mask: {mask_path}")

        if image.size == 0:
            raise ValueError(f"Empty image: {image_path}, shape: {image.shape}")
        if mask.size == 0:
            raise ValueError(f"Empty mask: {mask_path}, shape: {mask.shape}")

        try:
            image = cv2.resize(image, (512, 512))
            # Nearest-neighbour so resizing never invents new class ids.
            mask = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_NEAREST)
        except cv2.error:
            print(f"Resize error on image {image_path}, shape: {image.shape}")
            print(f"Resize error on mask {mask_path}, shape: {mask.shape}")
            raise

        image = torch.tensor(image).permute(2, 0, 1).float() / 255.0
        mask = torch.tensor(mask, dtype=torch.long)
        # Out-of-range values are clamped, not ignored: e.g. a 255 "void" value
        # becomes the last class.
        mask = torch.clamp(mask, 0, self.num_classes - 1)

        return image, mask

# Inference/validation dataset: same loading checks as SegmentationDataset, with optional labels
class ValidationDataset(Dataset):
    """Images to predict on, with optional ground-truth labels.

    Each item is ``(image, label, image_path)``:
      * image: float tensor ``(3, 512, 512)`` scaled to [0, 1].
      * label: long tensor ``(512, 512)`` read as grayscale, or all zeros when
        no label is available (no ``label_dir``, no matching file, or an
        unreadable file).
      * image_path: path of the source image.

    Pairing: when ``label_dir`` holds as many files as ``image_dir``, images
    and labels are paired by sorted position. Otherwise positional pairing
    would shift or overrun, so labels are matched to images by file stem
    (name without extension) and images without a match get no label.
    """

    EXTENSIONS = ('.jpg', '.png', '.tif')

    def __init__(self, image_dir, label_dir=None):
        self.image_paths = self._list_files(image_dir)
        self.has_labels = label_dir is not None

        if not self.has_labels:
            self.label_paths = [None] * len(self.image_paths)
            return

        label_paths = self._list_files(label_dir)
        if len(label_paths) == len(self.image_paths):
            self.label_paths = label_paths
            return

        by_stem = {self._stem(p): p for p in label_paths}
        self.label_paths = [by_stem.get(self._stem(p)) for p in self.image_paths]
        unmatched = sum(1 for p in self.label_paths if p is None)
        print(f"Warning: {len(label_paths)} labels for {len(self.image_paths)} images; "
              f"matched by file name, {unmatched} images have no label")

    @classmethod
    def _list_files(cls, directory):
        """Return the sorted full paths of files in ``directory`` with a known image extension."""
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.endswith(cls.EXTENSIONS)
        )

    @staticmethod
    def _stem(path):
        """Return the file name of ``path`` without directory or extension."""
        return os.path.splitext(os.path.basename(path))[0]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]

        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image file not found: {image_path}")

        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Failed to load image: {image_path}")

        if image.size == 0:
            raise ValueError(f"Empty image: {image_path}, shape: {image.shape}")

        try:
            image = cv2.resize(image, (512, 512))
        except cv2.error:
            print(f"Resize error on image {image_path}, shape: {image.shape}")
            raise

        image_tensor = torch.tensor(image).permute(2, 0, 1).float() / 255.0

        label_path = self.label_paths[idx] if self.has_labels else None
        label_tensor = self._load_label(label_path)

        return image_tensor, label_tensor, image_path

    @staticmethod
    def _load_label(label_path):
        """Load a label as a (512, 512) long tensor; return zeros (with a warning) if unavailable."""
        if not label_path:
            return torch.zeros((512, 512), dtype=torch.long)

        if not os.path.exists(label_path):
            print(f"Warning: Label file not found: {label_path}, using zeros")
            return torch.zeros((512, 512), dtype=torch.long)

        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        if label is None:
            print(f"Warning: Failed to load label: {label_path}, using zeros")
            return torch.zeros((512, 512), dtype=torch.long)

        try:
            # Nearest-neighbour so resizing never invents new class ids.
            label = cv2.resize(label, (512, 512), interpolation=cv2.INTER_NEAREST)
        except cv2.error:
            print(f"Resize error on label {label_path}, shape: {label.shape}")
            return torch.zeros((512, 512), dtype=torch.long)

        return torch.tensor(label, dtype=torch.long)


def validate_dataset_files(dataset_dir, type_str="images"):
    """Check that every image file in ``dataset_dir`` can be read and resized.

    Args:
        dataset_dir: directory to scan (non-recursive). Only ``.jpg``, ``.png``
            and ``.tif`` files are checked, in sorted name order.
        type_str: ``"masks"`` or ``"labels"`` reads files unchanged and resizes
            with nearest-neighbour interpolation. Any other value reads them
            as colour images and resizes bilinearly. Also used in log messages.

    Returns:
        List of paths that could not be read, decoded to an empty array, or
        failed the trial resize to 512x512.
    """
    print(f"Validating {type_str} in {dataset_dir}...")
    is_label = type_str in ("masks", "labels")
    invalid_files = []

    for file in sorted(os.listdir(dataset_dir)):
        if not file.endswith(('.jpg', '.png', '.tif')):
            continue
        file_path = os.path.join(dataset_dir, file)

        try:
            if is_label:
                img = cv2.imread(file_path, cv2.IMREAD_UNCHANGED)
            else:
                img = cv2.imread(file_path)

            if img is None or img.size == 0:
                print(f"Issue with {file_path}, shape: {getattr(img, 'shape', 'N/A')}")
                invalid_files.append(file_path)
                continue

            # Trial resize with the same target size as the datasets use.
            resize_mode = cv2.INTER_NEAREST if is_label else cv2.INTER_LINEAR
            cv2.resize(img, (512, 512), interpolation=resize_mode)

        except Exception as e:
            # Keep scanning: the goal is the full list of bad files, not the first failure.
            print(f"Error with {file_path}: {e}")
            invalid_files.append(file_path)

    if invalid_files:
        print(f"Found {len(invalid_files)} invalid {type_str} files")
    else:
        print(f"All {type_str} files validated successfully")

    return invalid_files

def compute_class_distribution(dataset, n_classes=None):
    """Count mask pixels per class over a dataset of ``(image, mask)`` pairs.

    Args:
        dataset: iterable of ``(image, mask)`` pairs; ``mask`` is a torch
            tensor (any device) or array of integer class ids.
        n_classes: number of classes to count; defaults to ``len(class_names)``.

    Returns:
        ``(class_counts, class_percentage)``: float arrays of length
        ``n_classes``. Percentages are relative to *all* mask pixels. Pixels
        whose value is outside ``[0, n_classes)`` are not counted for any class
        but still contribute to the denominator. An empty dataset yields
        all-zero percentages instead of NaN.
    """
    if n_classes is None:
        n_classes = len(class_names)
    class_counts = np.zeros(n_classes)
    total_pixels = 0

    for _, mask in dataset:
        if isinstance(mask, torch.Tensor):
            mask = mask.detach().cpu().numpy()
        labels = np.asarray(mask).ravel()
        # Both bounds matter: a negative id would otherwise index from the end.
        valid = labels[(labels >= 0) & (labels < n_classes)].astype(np.int64)
        class_counts += np.bincount(valid, minlength=n_classes)[:n_classes]
        total_pixels += labels.size

    if total_pixels == 0:
        return class_counts, np.zeros(n_classes)

    class_percentage = (class_counts / total_pixels) * 100
    return class_counts, class_percentage
def compute_entropy(prob_map):
    """Return the per-pixel Shannon entropy (natural log) of a probability map.

    The sum runs over dimension 0, so ``prob_map`` is expected to be laid out
    class-first, e.g. ``(num_classes, H, W)`` as in ``probs[0]`` of a softmax
    output. A 1e-10 offset keeps ``log`` finite where a probability is 0.
    """
    log_probs = torch.log(prob_map + 1e-10)
    weighted = prob_map * log_probs
    return -weighted.sum(dim=0)

def compute_miou(preds, targets, n_classes=None):
    """Compute mean IoU and per-class IoU for one prediction/target pair.

    Args:
        preds: array of predicted class ids (any shape).
        targets: array of ground-truth class ids with the same number of elements.
        n_classes: number of classes to score; defaults to the module-level
            ``num_classes``.

    Returns:
        ``(mean_iou, iou_per_class)``. A class absent from ``targets`` gets
        ``nan`` and is excluded from the mean, so predicting an absent class is
        not penalised. ``mean_iou`` is ``nan`` if no scored class is present.

    Raises:
        ValueError: if ``preds`` and ``targets`` differ in size.
    """
    if n_classes is None:
        n_classes = num_classes
    preds = np.asarray(preds).ravel()
    targets = np.asarray(targets).ravel()
    if preds.size != targets.size:
        raise ValueError(f"preds has {preds.size} elements but targets has {targets.size}")

    iou_per_class = np.full(n_classes, np.nan)
    for cls in range(n_classes):
        target_mask = targets == cls
        if not target_mask.any():
            continue
        pred_mask = preds == cls
        # Union is non-zero here because the class occurs in the target.
        intersection = np.logical_and(target_mask, pred_mask).sum()
        union = np.logical_or(target_mask, pred_mask).sum()
        iou_per_class[cls] = intersection / union

    return np.nanmean(iou_per_class), iou_per_class

def plot_training_loss(loss_history, iteration, save_dir):
    """Plot the per-epoch training loss curve and save it as a PNG.

    Args:
        loss_history: average loss per epoch, in epoch order.
        iteration: active-learning iteration number, used in the title and file name.
        save_dir: directory the image is written to.

    Returns:
        Path of the saved ``iteration<iteration>_training_loss.png`` file.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        # Epochs are numbered from 1 in the console log and epoch-loss table,
        # so label the x-axis the same way (a bare plot() would start at 0).
        epochs = range(1, len(loss_history) + 1)
        ax.plot(epochs, loss_history)
        ax.set_title(f'Training Loss - Iteration {iteration}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.grid(True)
        out_path = os.path.join(save_dir, f'iteration{iteration}_training_loss.png')
        fig.savefig(out_path)
    finally:
        # Always release the figure, even if saving fails, to avoid leaking memory.
        plt.close(fig)
    return out_path

def create_directories(base_dir, iteration):
    """Ensure the folder tree for one active-learning iteration exists.

    Everything lives under ``<base_dir>/iteration<iteration>``; result folders
    sit below its ``results`` subfolder. Existing folders are left untouched.

    Returns:
        Dict mapping a role name (``'train_data'``, ``'heatmaps'``, ...) to its path.
    """
    root = os.path.join(base_dir, f'iteration{iteration}')
    results = os.path.join(root, 'results')
    layout = {
        'iteration_dir': root,
        'train_data': os.path.join(root, 'train_data'),
        'train_labels': os.path.join(root, 'train_labels'),
        'results': results,
        'predicted_masks': os.path.join(results, 'predicted_masks'),
        'high_uncertainty': os.path.join(results, 'uncertainty_high'),
        'low_uncertainty': os.path.join(results, 'uncertainty_low'),
        'heatmaps': os.path.join(results, 'heatmaps'),
    }

    for folder in layout.values():
        os.makedirs(folder, exist_ok=True)

    return layout

# Training Function
def train_model(train_loader, model, criterion, optimizer, num_epochs, device, iteration, save_dir):
    """Train ``model``, checkpoint the lowest-loss epoch, and write training reports.

    Args:
        train_loader: DataLoader yielding ``(images, masks)`` batches.
        model: network to train; it is wrapped in ``nn.DataParallel`` here.
        criterion: loss applied as ``criterion(outputs, masks)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.
        device: ``torch.device`` (or device string) to train on.
        iteration: active-learning iteration number, used in output file names.
        save_dir: directory for the checkpoint, loss plot and Excel report.

    Returns:
        Path to ``iteration<iteration>_model.pth`` in ``save_dir``. The state
        dict is saved from the ``nn.DataParallel`` wrapper, so its keys start
        with ``module.``. If no epoch improved on the best loss (zero epochs or
        only non-finite losses), the final weights are saved there instead.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader is empty; nothing to train on")

    model = nn.DataParallel(model).to(device)
    # Mixed precision is CUDA-only; enabling it for CPU tensors would make
    # GradScaler fail when a GPU is present, so gate it on the target device.
    use_amp = torch.device(device).type == "cuda"
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    best_loss = float('inf')
    loss_history = []
    # Bound up front so the function can always return a path.
    model_path = os.path.join(save_dir, f'iteration{iteration}_model.pth')
    checkpoint_saved = False

    print(f"\n Starting training for Iteration {iteration}...")
    start_time = time.time()

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0

        for images, masks in train_loader:
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()

            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(images)
                loss = criterion(outputs, masks)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            epoch_loss += loss.item()

        avg_epoch_loss = epoch_loss / num_batches
        loss_history.append(avg_epoch_loss)

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_epoch_loss:.4f}")

        # NaN never compares below best_loss, so NaN epochs are never checkpointed.
        if avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            torch.save(model.state_dict(), model_path)
            checkpoint_saved = True

    if not checkpoint_saved:
        print(f"Warning: no epoch improved on the best loss; saving final weights to {model_path}")
        torch.save(model.state_dict(), model_path)

    total_time = time.time() - start_time
    print(f"Training completed in {total_time:.2f} seconds")

    plot_training_loss(loss_history, iteration, save_dir)

    train_details = pd.DataFrame({
        "Metric": ["Best Training Loss", "Total Training Time (s)", "Train Image Count"],
        "Value": [best_loss, total_time, len(train_loader.dataset)]
    })

    class_counts, class_percentage = compute_class_distribution(train_loader.dataset)
    class_distribution_df = pd.DataFrame({
        "Class": class_names,
        "Pixel Count": class_counts,
        "Percentage": class_percentage
    })

    excel_path = os.path.join(save_dir, f'iteration{iteration}_training_details.xlsx')
    with pd.ExcelWriter(excel_path) as writer:
        train_details.to_excel(writer, sheet_name="Training Details", index=False)
        class_distribution_df.to_excel(writer, sheet_name="Class Distribution", index=False)

    print(f" Training details saved to {excel_path}")

    return model_path

# Prediction and Uncertainty Estimation
def _load_unet_checkpoint(model_path, device):
    # Build a UNet in eval mode from a checkpoint saved with or without the nn.DataParallel "module." prefix.
    model = UNet(num_classes=num_classes).to(device)
    checkpoint = torch.load(model_path, map_location=device)

    prefix = "module."
    if checkpoint and all(k.startswith(prefix) for k in checkpoint):
        # Strip only the leading wrapper prefix; str.replace would also remove
        # any later occurrence inside a key.
        checkpoint = {k[len(prefix):]: v for k, v in checkpoint.items()}

    # strict=False tolerates mismatches, but a silent mismatch would mean
    # predicting with randomly initialised layers, so report it.
    result = model.load_state_dict(checkpoint, strict=False)
    if result.missing_keys or result.unexpected_keys:
        print(f"Warning: checkpoint {model_path} does not match UNet: "
              f"missing={result.missing_keys}, unexpected={result.unexpected_keys}")
    model.eval()
    return model


def _entropy_heatmap(entropy_map, mean_uncertainty):
    # Render an entropy map as a min-max normalised JET colour image annotated with its mean.
    entropy_min = entropy_map.min()
    entropy_max = entropy_map.max()
    if entropy_max == entropy_min:
        entropy_norm = torch.zeros_like(entropy_map)
    else:
        entropy_norm = ((entropy_map - entropy_min) / (entropy_max - entropy_min)) * 255

    heatmap = cv2.applyColorMap(entropy_norm.to(torch.uint8).cpu().numpy(), cv2.COLORMAP_JET)
    text = f"Mean Uncertainty: {mean_uncertainty:.4f}"
    cv2.putText(heatmap, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
    return heatmap


def predict_and_analyze(model_path, val_loader, iteration, dirs, device, uncertainty_threshold=0.5):
    """Predict every validation image, save masks/heatmaps, and rank images by uncertainty.

    Args:
        model_path: checkpoint file to load into a fresh ``UNet``.
        val_loader: DataLoader over ``ValidationDataset`` with ``batch_size=1``
            (only the first item of each batch is used).
        iteration: active-learning iteration number, used in the report name.
        dirs: folder dict from ``create_directories`` for this iteration.
        device: device to run inference on.
        uncertainty_threshold: images whose mean per-pixel entropy exceeds
            this are saved to the ``high_uncertainty`` folder, the rest to
            ``low_uncertainty``.

    Side effects:
        Writes predicted masks (class id * 30, as uint8), entropy heatmaps and
        ``iteration<iteration>_validation_report.xlsx``.

    Returns:
        List of ``(filename, image_path, mean_entropy)`` tuples sorted from most
        to least uncertain.
    """
    print(f"\n Running predictions and uncertainty analysis for Iteration {iteration}...")

    model = _load_unet_checkpoint(model_path, device)

    # Per-class IoU is summed only over images where the class occurs in the
    # ground truth and divided by that count. Adding 0 for absent classes
    # and dividing by all images would bias every class's IoU downward.
    iou_sum_per_class = np.zeros(num_classes)
    iou_count_per_class = np.zeros(num_classes)
    total_images = 0
    total_miou = 0
    classwise_pred_counts = np.zeros(num_classes)
    classwise_actual_counts = np.zeros(num_classes)

    uncertainty_scores = []

    for img, lbl, img_path in val_loader:
        img = img.to(device)
        filename = os.path.basename(img_path[0])

        with torch.no_grad():
            output = model(img)
            probs = F.softmax(output, dim=1).cpu()
            pred_mask = torch.argmax(probs, dim=1)[0].numpy()
            entropy_map = compute_entropy(probs[0])

        lbl = lbl.cpu().numpy()[0]

        for cls in range(num_classes):
            classwise_pred_counts[cls] += (pred_mask == cls).sum()
            classwise_actual_counts[cls] += (lbl == cls).sum()

        # ValidationDataset returns an all-zero label when none exists, so a
        # zero label is treated as "no ground truth" (this also skips labels
        # that are genuinely all Background).
        if np.sum(lbl) > 0:
            miou, classwise_iou = compute_miou(pred_mask, lbl)
            present = ~np.isnan(classwise_iou)
            iou_sum_per_class[present] += classwise_iou[present]
            iou_count_per_class[present] += 1
            total_miou += miou
            total_images += 1

        # Scale class ids so the saved mask is visible as grayscale.
        # NOTE(review): when filename is .jpg, JPEG compression alters these values; consider saving as .png.
        mask_vis = (pred_mask * 30).astype(np.uint8)
        cv2.imwrite(os.path.join(dirs['predicted_masks'], filename), mask_vis)

        mean_uncertainty = entropy_map.mean().item()
        uncertainty_scores.append((filename, img_path[0], mean_uncertainty))

        cv2.imwrite(os.path.join(dirs['heatmaps'], filename), _entropy_heatmap(entropy_map, mean_uncertainty))

        uncertainty_folder = dirs['high_uncertainty'] if mean_uncertainty > uncertainty_threshold else dirs['low_uncertainty']
        cv2.imwrite(os.path.join(uncertainty_folder, filename), mask_vis)

    avg_miou = total_miou / max(1, total_images)
    # Classes never present in any ground truth stay NaN (blank in Excel) rather than 0.
    overall_iou_per_class = np.full(num_classes, np.nan)
    np.divide(iou_sum_per_class, iou_count_per_class,
              out=overall_iou_per_class, where=iou_count_per_class > 0)

    data = [[class_names[cls], overall_iou_per_class[cls], classwise_actual_counts[cls], classwise_pred_counts[cls]]
            for cls in range(num_classes)]
    data.append(["Overall mIoU", avg_miou, "", ""])

    df = pd.DataFrame(data, columns=["Class", "IoU Score", "Actual Pixels", "Predicted Pixels"])
    report_path = os.path.join(dirs['results'], f'iteration{iteration}_validation_report.xlsx')
    df.to_excel(report_path, index=False)

    print(f" Validation report saved to {report_path}")

    uncertainty_scores.sort(key=lambda x: x[2], reverse=True)

    return uncertainty_scores

# Select and copy most uncertain samples
def select_uncertain_samples(uncertainty_scores, top_n, unlabeled_dir, validation_labels_dir, next_iteration_dirs):
    """Add the most uncertain *new*, labelled samples to the next iteration's training set.

    Candidates are taken in the given order and skipped when their file name
    is already in the next iteration's ``train_data`` folder. Copying such a
    candidate would only overwrite the existing file and add nothing. Candidates
    are also skipped when they have no label, since an unlabelled image would
    break image/mask pairing.

    Args:
        uncertainty_scores: ``(filename, image_path, uncertainty)`` tuples,
            already sorted from most to least uncertain.
        top_n: maximum number of samples to add.
        unlabeled_dir: unused; kept for interface compatibility.
        validation_labels_dir: directory holding each candidate's label under
            the same file name as its image (``None`` means no labels).
        next_iteration_dirs: folder dict from ``create_directories`` for the
            next iteration.

    Side effects:
        Writes ``uncertainty_scores.csv`` (all candidates) next to the next
        iteration's ``train_data`` folder, and copies each selected image and
        label into its ``train_data`` / ``train_labels`` folders.

    Returns:
        List of the tuples actually added (may be shorter than ``top_n``).
    """
    print(f"\n Selecting top {top_n} most uncertain samples for next iteration...")

    train_data_dir = next_iteration_dirs['train_data']
    train_labels_dir = next_iteration_dirs['train_labels']

    uncertainty_df = pd.DataFrame(uncertainty_scores, columns=["Filename", "Path", "Uncertainty"])
    uncertainty_df.to_csv(os.path.join(os.path.dirname(train_data_dir), "uncertainty_scores.csv"), index=False)

    selected = []
    for entry in uncertainty_scores:
        if len(selected) >= top_n:
            break
        img_path = entry[1]
        filename = os.path.basename(img_path)

        if os.path.exists(os.path.join(train_data_dir, filename)):
            continue

        label_path = os.path.join(validation_labels_dir, filename) if validation_labels_dir else None
        if label_path is None or not os.path.exists(label_path):
            print(f"Warning: no label found for {filename}; skipping it")
            continue

        shutil.copy2(img_path, os.path.join(train_data_dir, filename))
        shutil.copy2(label_path, os.path.join(train_labels_dir, filename))
        selected.append(entry)

    print(f" Selected {len(selected)} samples added to next iteration's training set")
    return selected

# Main active learning pipeline

def _seed_initial_training_set(initial_train_data, initial_train_labels, selected_images, dest_dirs):
    # Copy each selected image plus its same-named label into dest_dirs; return the file names copied.
    copied = []
    for img_file in selected_images:
        src_label = os.path.join(initial_train_labels, img_file)
        if not os.path.exists(src_label):
            # SegmentationDataset pairs images and masks by sorted position, so
            # an image without a label would misalign every later pair.
            print(f"Warning: no label for {img_file} in {initial_train_labels}; skipping it")
            continue
        shutil.copy2(os.path.join(initial_train_data, img_file),
                     os.path.join(dest_dirs['train_data'], img_file))
        shutil.copy2(src_label, os.path.join(dest_dirs['train_labels'], img_file))
        copied.append(img_file)
    return copied


def _carry_over_training_set(src_dirs, dst_dirs):
    # Copy every file of the current training images and labels into the next iteration's folders.
    for key in ('train_data', 'train_labels'):
        for file in os.listdir(src_dirs[key]):
            shutil.copy2(os.path.join(src_dirs[key], file), os.path.join(dst_dirs[key], file))


def active_learning_pipeline(base_dir, initial_train_data, initial_train_labels, 
                            unlabeled_data, validation_labels, 
                            total_iterations=10, initial_samples=512, samples_per_iteration=50,
                            uncertainty_threshold=0.5, num_epochs=200, batch_size=16):
    """Run the complete active-learning loop for semantic segmentation.

    Iteration 1 trains on up to ``initial_samples`` images drawn at random
    (seed 42) from ``initial_train_data``, together with their same-named
    labels. Each iteration trains a fresh ``UNet`` and scores every image in
    ``unlabeled_data`` by mean predictive entropy. It then carries the current
    training set forward and adds up to ``samples_per_iteration`` of the most
    uncertain images (with their labels from ``validation_labels``).

    Args:
        base_dir: output root; per-iteration folders are created below it.
        initial_train_data: directory of labelled seed images.
        initial_train_labels: labels for the seed images (same file names).
        unlabeled_data: images predicted and scored every iteration.
        validation_labels: labels for ``unlabeled_data``, used for validation
            metrics and copied along with newly selected images.
        total_iterations: number of train/score rounds.
        initial_samples: size of the random seed training set.
        samples_per_iteration: maximum samples added per round.
        uncertainty_threshold: split point between the high/low uncertainty folders.
        num_epochs: training epochs per iteration.
        batch_size: training batch size.

    Side effects:
        Writes checkpoints, reports and ``active_learning_metrics.xlsx`` under ``base_dir``.
    """
    import random

    image_exts = ('.jpg', '.png', '.tif')

    def count_images(directory):
        # Count only image files, matching what SegmentationDataset will load.
        return sum(1 for f in os.listdir(directory) if f.endswith(image_exts))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    print(f"Starting Active Learning Pipeline with {total_iterations} iterations")
    print(f"Initial samples: {initial_samples}, adding {samples_per_iteration} per iteration")

    # One metrics row per iteration; the Excel file is rewritten after each one.
    metrics_tracking = []
    metrics_file = os.path.join(base_dir, "active_learning_metrics.xlsx")
    os.makedirs(base_dir, exist_ok=True)

    all_train_images = sorted(f for f in os.listdir(initial_train_data) if f.endswith(image_exts))

    if len(all_train_images) > initial_samples:
        # A private RNG seeded with 42 draws the same sample as random.seed(42)
        # without resetting the global random state for the rest of the program.
        selected_images = random.Random(42).sample(all_train_images, initial_samples)
    else:
        selected_images = all_train_images
        print(f"Warning: Only {len(selected_images)} images available in initial training set (requested {initial_samples})")

    iteration1_dirs = create_directories(base_dir, 1)
    copied_images = _seed_initial_training_set(initial_train_data, initial_train_labels,
                                               selected_images, iteration1_dirs)
    print(f"Copied {len(copied_images)} initial images to iteration 1 folder")

    for iteration in range(1, total_iterations + 1):
        print(f"\n{'='*80}\n ITERATION {iteration}/{total_iterations}\n{'='*80}")

        dirs = iteration1_dirs if iteration == 1 else create_directories(base_dir, iteration)

        train_sample_count = count_images(dirs['train_data'])
        print(f"Training on {train_sample_count} samples for iteration {iteration}")

        # 1. Train a fresh model on this iteration's training set. (The class
        # distribution is computed inside train_model_enhanced, so it is not
        # recomputed here.)
        train_dataset = SegmentationDataset(dirs['train_data'], dirs['train_labels'])
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                  num_workers=0, pin_memory=device.type == "cuda")

        model = UNet(num_classes=num_classes)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(model.parameters(), lr=0.0001)

        training_start_time = time.time()
        model_path, loss_history = train_model_enhanced(
            train_loader=train_loader,
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            num_epochs=num_epochs,
            device=device,
            iteration=iteration,
            save_dir=dirs['iteration_dir']
        )
        training_time = time.time() - training_start_time

        # 2. Predict on the unlabeled pool, measure uncertainty and validate.
        val_dataset = ValidationDataset(unlabeled_data, validation_labels)
        val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

        uncertainty_scores, val_metrics = predict_and_analyze_enhanced(
            model_path=model_path,
            val_loader=val_loader,
            iteration=iteration,
            dirs=dirs,
            device=device,
            uncertainty_threshold=uncertainty_threshold
        )

        iteration_metrics = {
            "Iteration": iteration,
            "Training Samples": train_sample_count,
            "Training Time (s)": training_time,
            "Final Training Loss": loss_history[-1] if loss_history else float('nan'),
            "Best Training Loss": min(loss_history) if loss_history else float('nan'),
            "Validation mIoU": val_metrics['miou'],
            "High Uncertainty Count": val_metrics['high_uncertainty_count'],
            "Low Uncertainty Count": val_metrics['low_uncertainty_count'],
            "Average Uncertainty": val_metrics['avg_uncertainty']
        }
        for i, cls_name in enumerate(class_names):
            iteration_metrics[f"IoU_{cls_name}"] = val_metrics['class_iou'][i]

        metrics_tracking.append(iteration_metrics)
        pd.DataFrame(metrics_tracking).to_excel(metrics_file, index=False)

        # 3. Build the next iteration's training set: everything so far plus
        # the most uncertain new samples.
        if iteration < total_iterations:
            next_iteration_dirs = create_directories(base_dir, iteration + 1)
            _carry_over_training_set(dirs, next_iteration_dirs)

            new_samples = select_uncertain_samples(
                uncertainty_scores=uncertainty_scores,
                top_n=samples_per_iteration,
                unlabeled_dir=unlabeled_data,
                validation_labels_dir=validation_labels,
                next_iteration_dirs=next_iteration_dirs
            )

            new_samples_df = pd.DataFrame(new_samples, columns=["Filename", "Path", "Uncertainty"])
            new_samples_df.to_csv(os.path.join(next_iteration_dirs['iteration_dir'], f"new_samples_iter{iteration+1}.csv"), index=False)

            print(f"Next iteration will use {count_images(next_iteration_dirs['train_data'])} total training samples")

    create_summary_plots(base_dir, metrics_tracking)

    print(f"\n Active Learning Pipeline Completed Successfully!")
    print(f"Results saved in: {base_dir}")
    print(f"Summary metrics saved to: {metrics_file}")


def train_model_enhanced(train_loader, model, criterion, optimizer, num_epochs, device, iteration, save_dir):
    """Train ``model``, checkpoint the lowest-loss epoch, and write extended reports.

    Args:
        train_loader: DataLoader yielding ``(images, masks)`` batches.
        model: network to train; it is wrapped in ``nn.DataParallel`` here.
        criterion: loss applied as ``criterion(outputs, masks)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.
        device: ``torch.device`` (or device string) to train on.
        iteration: active-learning iteration number, used in output file names.
        save_dir: directory for the checkpoint, plots, epoch-loss CSV and Excel report.

    Returns:
        ``(model_path, loss_history)``. ``model_path`` is
        ``iteration<iteration>_model.pth`` in ``save_dir``. It is saved from the
        ``nn.DataParallel`` wrapper, so keys start with ``module.``. If no epoch
        improved on the best loss (zero epochs or only non-finite losses), the
        final weights are saved there instead. ``loss_history`` holds the
        average loss of each epoch.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader is empty; nothing to train on")

    model = nn.DataParallel(model).to(device)
    # Mixed precision is CUDA-only; enabling it for CPU tensors would make
    # GradScaler fail when a GPU is present, so gate it on the target device.
    use_amp = torch.device(device).type == "cuda"
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    best_loss = float('inf')
    loss_history = []
    # Bound up front so the function can always return a path.
    model_path = os.path.join(save_dir, f'iteration{iteration}_model.pth')
    checkpoint_saved = False

    print(f"\n Starting training for Iteration {iteration}...")
    start_time = time.time()

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0

        for images, masks in train_loader:
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()

            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(images)
                loss = criterion(outputs, masks)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            epoch_loss += loss.item()

        avg_epoch_loss = epoch_loss / num_batches
        loss_history.append(avg_epoch_loss)

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_epoch_loss:.4f}")

        # NaN never compares below best_loss, so NaN epochs are never checkpointed.
        if avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            torch.save(model.state_dict(), model_path)
            checkpoint_saved = True

    if not checkpoint_saved:
        print(f"Warning: no epoch improved on the best loss; saving final weights to {model_path}")
        torch.save(model.state_dict(), model_path)

    total_time = time.time() - start_time
    print(f"Training completed in {total_time:.2f} seconds")

    plot_training_loss(loss_history, iteration, save_dir)

    final_loss = loss_history[-1] if loss_history else float('nan')
    train_details = pd.DataFrame({
        "Metric": ["Best Training Loss", "Total Training Time (s)", "Train Image Count", "Final Loss"],
        "Value": [best_loss, total_time, len(train_loader.dataset), final_loss]
    })

    class_counts, class_percentage = compute_class_distribution(train_loader.dataset)
    class_distribution_df = pd.DataFrame({
        "Class": class_names,
        "Pixel Count": class_counts,
        "Percentage": class_percentage
    })

    plot_class_distribution(class_names, class_percentage, iteration, save_dir)

    # Epochs are numbered from 1, one row per recorded loss.
    epoch_loss_df = pd.DataFrame({
        "Epoch": list(range(1, len(loss_history) + 1)),
        "Loss": loss_history
    })
    epoch_loss_df.to_csv(os.path.join(save_dir, f'iteration{iteration}_epoch_loss.csv'), index=False)

    excel_path = os.path.join(save_dir, f'iteration{iteration}_training_details.xlsx')
    with pd.ExcelWriter(excel_path) as writer:
        train_details.to_excel(writer, sheet_name="Training Details", index=False)
        class_distribution_df.to_excel(writer, sheet_name="Class Distribution", index=False)
        epoch_loss_df.to_excel(writer, sheet_name="Epoch Loss", index=False)

    print(f" Training details saved to {excel_path}")

    return model_path, loss_history


def predict_and_analyze_enhanced(model_path, val_loader, iteration, dirs, device, uncertainty_threshold=0.5):
    """Make predictions, analyze uncertainty, and rank samples by uncertainty.

    For every batch from ``val_loader`` the first image is segmented with the
    UNet checkpoint at ``model_path``. The predicted mask, an entropy heatmap
    and a copy of the mask (in the high- or low-uncertainty folder) are written
    to the folders in ``dirs``. Afterwards a per-class IoU report, a confusion
    matrix and several plots are written to ``dirs['results']``.

    Args:
        model_path: Path to a saved ``state_dict``; keys may carry the
            ``module.`` prefix added by ``nn.DataParallel``.
        val_loader: Iterable yielding ``(image, label, image_path)`` batches.
            Only index 0 of each batch is used, so it presumably has
            ``batch_size=1`` -- TODO confirm with the caller.
        iteration: Active-learning iteration number (used in file names).
        dirs: Mapping with keys ``'predicted_masks'``, ``'heatmaps'``,
            ``'high_uncertainty'``, ``'low_uncertainty'`` and ``'results'``.
        device: Torch device used for inference.
        uncertainty_threshold: Images whose mean entropy is strictly greater
            than this value are counted/saved as high uncertainty.

    Returns:
        ``(uncertainty_scores, val_metrics)``: ``uncertainty_scores`` is a list
        of ``(filename, image_path, mean_uncertainty)`` sorted from most to
        least uncertain; ``val_metrics`` is a dict of summary metrics.
    """
    print(f"\n Running predictions and uncertainty analysis for Iteration {iteration}...")

    # Load model
    model = UNet(num_classes=num_classes).to(device)
    checkpoint = torch.load(model_path, map_location=device)

    # Checkpoints saved from an nn.DataParallel wrapper prefix every key with
    # "module."; strip only that leading prefix (str.replace would also rewrite
    # any later occurrence of "module." inside a key).
    prefix = "module."
    if all(k.startswith(prefix) for k in checkpoint.keys()):
        new_state_dict = {k[len(prefix):]: v for k, v in checkpoint.items()}
    else:
        new_state_dict = checkpoint

    # strict=False is kept for compatibility, but mismatched keys are reported
    # so a wrong checkpoint does not silently evaluate a partly untrained model.
    load_result = model.load_state_dict(new_state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f" Warning: checkpoint keys do not match the model "
              f"(missing: {load_result.missing_keys}, unexpected: {load_result.unexpected_keys})")
    model.eval()

    # Initialize metrics
    iou_sum_per_class = np.zeros(num_classes)    # sum of defined (non-NaN) per-image class IoUs
    iou_count_per_class = np.zeros(num_classes)  # number of images in which each class IoU was defined
    total_images = 0
    total_miou = 0
    classwise_pred_counts = np.zeros(num_classes)
    classwise_actual_counts = np.zeros(num_classes)

    # Store uncertainty for all images
    uncertainty_scores = []
    all_uncertainties = []
    high_uncertainty_count = 0
    low_uncertainty_count = 0

    # Confusion matrix: rows = ground-truth class, columns = predicted class
    confusion_matrix = np.zeros((num_classes, num_classes), dtype=np.int64)

    for img, lbl, img_path in val_loader:
        img = img.to(device)
        filename = os.path.basename(img_path[0])

        with torch.no_grad():
            output = model(img)
            probs = F.softmax(output, dim=1).cpu()
            pred_mask = torch.argmax(probs, dim=1)[0].numpy()
            entropy_map = compute_entropy(probs[0])

        lbl = lbl.cpu().numpy()[0]

        # Ground-truth pixels whose label lies outside [0, num_classes) (e.g. an
        # ignore value) are excluded from the actual counts and the confusion matrix.
        valid = (lbl >= 0) & (lbl < num_classes)
        true_valid = lbl[valid].astype(np.int64)
        pred_valid = pred_mask[valid].astype(np.int64)

        # Vectorised class-wise pixel counts and confusion-matrix update: one
        # bincount per quantity instead of num_classes**2 full-image comparisons.
        classwise_pred_counts += np.bincount(pred_mask.ravel(), minlength=num_classes)
        classwise_actual_counts += np.bincount(true_valid, minlength=num_classes)
        confusion_matrix += np.bincount(
            true_valid * num_classes + pred_valid, minlength=num_classes * num_classes
        ).reshape(num_classes, num_classes)

        # Compute IoU only when the label has a non-zero pixel; an all-zero label
        # is treated as "no ground truth" (this also skips images whose only
        # class is index 0, Background).
        if np.sum(lbl) > 0:
            miou, classwise_iou = compute_miou(pred_mask, lbl)
            classwise_iou = np.asarray(classwise_iou, dtype=float)
            # Assumes compute_miou returns NaN for classes whose IoU is undefined
            # in this image. Average each class only over images where it is
            # defined; counting NaN as 0 would drag down rarely present classes.
            defined = ~np.isnan(classwise_iou)
            iou_sum_per_class[defined] += classwise_iou[defined]
            iou_count_per_class += defined
            total_miou += miou
            total_images += 1

        # Save predicted mask (class index scaled by 30 so classes are visibly distinct)
        cv2.imwrite(os.path.join(dirs['predicted_masks'], filename), (pred_mask * 30).astype(np.uint8))

        # Calculate mean uncertainty
        mean_uncertainty = entropy_map.mean().item()
        all_uncertainties.append(mean_uncertainty)
        uncertainty_scores.append((filename, img_path[0], mean_uncertainty))

        # Create and save uncertainty heatmap (min-max normalised to 0..255)
        entropy_min = entropy_map.min()
        entropy_max = entropy_map.max()
        if entropy_max == entropy_min:
            entropy_norm = torch.zeros_like(entropy_map)
        else:
            entropy_norm = ((entropy_map - entropy_min) / (entropy_max - entropy_min)) * 255

        entropy_vis = cv2.applyColorMap(entropy_norm.to(torch.uint8).cpu().numpy(), cv2.COLORMAP_JET)

        # Add uncertainty text to heatmap
        text = f"Mean Uncertainty: {mean_uncertainty:.4f}"
        cv2.putText(entropy_vis, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)

        # Save heatmap
        cv2.imwrite(os.path.join(dirs['heatmaps'], filename), entropy_vis)

        # Save in appropriate uncertainty folder
        if mean_uncertainty > uncertainty_threshold:
            uncertainty_folder = dirs['high_uncertainty']
            high_uncertainty_count += 1
        else:
            uncertainty_folder = dirs['low_uncertainty']
            low_uncertainty_count += 1

        cv2.imwrite(os.path.join(uncertainty_folder, filename), (pred_mask * 30).astype(np.uint8))

    # Compute final metrics
    avg_miou = total_miou / max(1, total_images)
    # Classes never defined in any image keep an IoU of 0.
    overall_iou_per_class = np.divide(
        iou_sum_per_class, iou_count_per_class,
        out=np.zeros(num_classes), where=iou_count_per_class > 0,
    )

    # Save validation report
    data = [[class_names[cls], overall_iou_per_class[cls], classwise_actual_counts[cls], classwise_pred_counts[cls]]
            for cls in range(num_classes)]
    data.append(["Overall mIoU", avg_miou, "", ""])

    df = pd.DataFrame(data, columns=["Class", "IoU Score", "Actual Pixels", "Predicted Pixels"])
    report_path = os.path.join(dirs['results'], f'iteration{iteration}_validation_report.xlsx')
    df.to_excel(report_path, index=False)

    # Save confusion matrix
    confusion_df = pd.DataFrame(confusion_matrix,
                                index=[f"True_{cls}" for cls in class_names],
                                columns=[f"Pred_{cls}" for cls in class_names])
    confusion_df.to_excel(os.path.join(dirs['results'], f'iteration{iteration}_confusion_matrix.xlsx'))

    # Plot confusion matrix
    plot_confusion_matrix(confusion_matrix, class_names, iteration, dirs['results'])

    # Plot uncertainty distribution
    plot_uncertainty_distribution(all_uncertainties, uncertainty_threshold, iteration, dirs['results'])

    # Plot IoU per class
    plot_iou_per_class(overall_iou_per_class, class_names, iteration, dirs['results'])

    # Plot class distribution comparison (actual vs. predicted)
    plot_class_comparison(classwise_actual_counts, classwise_pred_counts, class_names, iteration, dirs['results'])

    print(f" Validation report saved to {report_path}")

    # Sort by uncertainty (most uncertain first)
    uncertainty_scores.sort(key=lambda x: x[2], reverse=True)

    # Create metrics dictionary
    val_metrics = {
        'miou': avg_miou,
        'class_iou': overall_iou_per_class,
        'high_uncertainty_count': high_uncertainty_count,
        'low_uncertainty_count': low_uncertainty_count,
        'avg_uncertainty': np.mean(all_uncertainties) if all_uncertainties else 0,
        'confusion_matrix': confusion_matrix,
        'classwise_actual_counts': classwise_actual_counts,
        'classwise_pred_counts': classwise_pred_counts
    }

    return uncertainty_scores, val_metrics


# New visualization functions
def plot_class_distribution(class_names, class_percentage, iteration, save_dir):
    """Save a bar chart of per-class pixel percentages for one iteration.

    Every bar is annotated with its value (one decimal place, followed by
    '%'). The chart is written to
    ``iteration{iteration}_class_distribution.png`` inside ``save_dir``.
    """
    out_file = os.path.join(save_dir, f'iteration{iteration}_class_distribution.png')

    plt.figure(figsize=(12, 6))
    for rect in plt.bar(class_names, class_percentage):
        bar_height = rect.get_height()
        x_centre = rect.get_x() + rect.get_width() / 2.
        # Label sits just above the top of its bar
        plt.text(x_centre, bar_height + 0.5, f'{bar_height:.1f}%',
                 ha='center', va='bottom', rotation=0)

    plt.title(f'Class Distribution - Iteration {iteration}')
    plt.xlabel('Class')
    plt.ylabel('Percentage (%)')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(out_file)
    plt.close()


def plot_confusion_matrix(confusion_matrix, class_names, iteration, save_dir):
    """Save a row-normalised confusion matrix heatmap for one iteration.

    Each row (ground-truth class) is divided by its total so a cell shows the
    fraction of that class's pixels predicted as the column class. Rows with
    no pixels stay all-zero. Cells are annotated with two decimals. The image
    is written to ``iteration{iteration}_confusion_matrix.png`` inside
    ``save_dir``.
    """
    counts = confusion_matrix
    totals = counts.sum(axis=1, keepdims=True)
    normalized = np.zeros_like(counts, dtype=float)
    np.divide(counts, totals, out=normalized, where=totals != 0)

    plt.figure(figsize=(14, 12))
    plt.imshow(normalized, cmap='Blues')
    plt.colorbar(label='Normalized Count')

    positions = np.arange(len(class_names))
    plt.xticks(positions, class_names, rotation=45, ha='right')
    plt.yticks(positions, class_names)

    # Switch to white text on dark cells so annotations stay readable
    midpoint = normalized.max() / 2.
    for row, col in np.ndindex(normalized.shape):
        value = normalized[row, col]
        text_color = "white" if value > midpoint else "black"
        plt.text(col, row, f'{value:.2f}', ha="center", va="center", color=text_color)

    plt.xlabel('Predicted label')
    plt.ylabel('True label')
    plt.title(f'Normalized Confusion Matrix - Iteration {iteration}')
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'iteration{iteration}_confusion_matrix.png'))
    plt.close()


def plot_uncertainty_distribution(uncertainties, threshold, iteration, save_dir):
    """Save a 30-bin histogram of per-image uncertainty values.

    A dashed red line marks ``threshold``. The chart is annotated with how
    many values are strictly above it ("High") and how many are not ("Low").
    It is written to ``iteration{iteration}_uncertainty_distribution.png``
    inside ``save_dir``.
    """
    n_high = len([value for value in uncertainties if value > threshold])
    n_low = len(uncertainties) - n_high

    plt.figure(figsize=(10, 6))
    plt.hist(uncertainties, bins=30, alpha=0.7, color='skyblue')
    plt.axvline(x=threshold, color='r', linestyle='--', label=f'Threshold ({threshold})')

    # Place both count labels near the top of the y-axis, either side of the line
    label_y = plt.ylim()[1] * 0.9
    plt.text(threshold * 1.05, label_y, f'High: {n_high}', color='r')
    plt.text(threshold * 0.8, label_y, f'Low: {n_low}', color='g')

    plt.xlabel('Uncertainty Value')
    plt.ylabel('Frequency')
    plt.title(f'Uncertainty Distribution - Iteration {iteration}')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(os.path.join(save_dir, f'iteration{iteration}_uncertainty_distribution.png'))
    plt.close()


def plot_iou_per_class(iou_values, class_names, iteration, save_dir):
    """Save a bar chart with one IoU bar per class.

    Bars are annotated with their score to three decimals. The chart is
    written to ``iteration{iteration}_iou_per_class.png`` inside ``save_dir``.
    """
    plt.figure(figsize=(12, 6))
    bar_container = plt.bar(class_names, iou_values)

    for patch in bar_container:
        top = patch.get_height()
        centre = patch.get_x() + patch.get_width() / 2.
        plt.text(centre, top + 0.01, f'{top:.3f}', ha='center', va='bottom', rotation=0)

    plt.title(f'IoU per Class - Iteration {iteration}')
    plt.xlabel('Class')
    plt.ylabel('IoU Score')
    plt.xticks(rotation=45, ha='right')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    target = os.path.join(save_dir, f'iteration{iteration}_iou_per_class.png')
    plt.savefig(target)
    plt.close()


def plot_class_comparison(actual_counts, pred_counts, class_names, iteration, save_dir):
    """Save side-by-side bars comparing ground-truth and predicted class shares.

    Both count arrays are converted to percentages of their own total. An
    all-zero array is plotted as zeros. The chart is written to
    ``iteration{iteration}_class_comparison.png`` inside ``save_dir``.
    """
    def _to_percent(counts):
        # Share of each class in the total, or all zeros when nothing was counted
        total = counts.sum()
        if total > 0:
            return counts / total * 100
        return np.zeros_like(counts)

    positions = np.arange(len(class_names))
    bar_width = 0.35

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.bar(positions - bar_width / 2, _to_percent(actual_counts), bar_width, label='Ground Truth')
    ax.bar(positions + bar_width / 2, _to_percent(pred_counts), bar_width, label='Predicted')

    ax.set_title(f'Class Distribution Comparison - Iteration {iteration}')
    ax.set_xlabel('Class')
    ax.set_ylabel('Percentage (%)')
    ax.set_xticks(positions)
    ax.set_xticklabels(class_names, rotation=45, ha='right')
    ax.legend()
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(os.path.join(save_dir, f'iteration{iteration}_class_comparison.png'))
    plt.close(fig)


def _save_progression_plot(save_path, x, series, title, ylabel, show_legend=False):
    """Plot each ``(y, plot_kwargs)`` pair of ``series`` against ``x`` on a new 10x6 figure and save it."""
    plt.figure(figsize=(10, 6))
    for y, plot_kwargs in series:
        plt.plot(x, y, **plot_kwargs)
    plt.title(title)
    plt.xlabel('Iteration')
    plt.ylabel(ylabel)
    if show_legend:
        plt.legend()
    plt.grid(True)
    plt.savefig(save_path)
    plt.close()


def create_summary_plots(base_dir, metrics_tracking):
    """Create summary plots tracking metrics across all iterations.

    Args:
        base_dir: Root results directory. Plots are written to
            ``base_dir/summary_plots`` (created if missing).
        metrics_tracking: Anything ``pd.DataFrame`` accepts (e.g. a list of
            per-iteration dicts). It must provide an ``"Iteration"`` column and
            the metric columns plotted below. Per-class ``"IoU_<class name>"``
            columns are optional and plotted when present.

    If ``metrics_tracking`` produces an empty DataFrame, a message is printed
    and nothing is written.
    """
    metrics_df = pd.DataFrame(metrics_tracking)
    if metrics_df.empty:
        # Nothing recorded yet (e.g. the run stopped before an iteration
        # finished); the column lookups below would raise KeyError.
        print("No iteration metrics recorded; skipping summary plots")
        return

    # Directory for summary plots
    summary_dir = os.path.join(base_dir, "summary_plots")
    os.makedirs(summary_dir, exist_ok=True)
    iterations = metrics_df["Iteration"]

    # 1. Training samples progression
    _save_progression_plot(
        os.path.join(summary_dir, 'training_samples_progression.png'), iterations,
        [(metrics_df["Training Samples"], dict(marker='o', linewidth=2))],
        'Training Sample Count Progression', 'Number of Training Samples')

    # 2. Training loss progression
    _save_progression_plot(
        os.path.join(summary_dir, 'training_loss_progression.png'), iterations,
        [(metrics_df["Final Training Loss"], dict(marker='o', label='Final Loss', linewidth=2)),
         (metrics_df["Best Training Loss"], dict(marker='s', label='Best Loss', linewidth=2))],
        'Training Loss Progression', 'Loss', show_legend=True)

    # 3. mIoU progression
    _save_progression_plot(
        os.path.join(summary_dir, 'miou_progression.png'), iterations,
        [(metrics_df["Validation mIoU"], dict(marker='o', linewidth=2))],
        'Validation mIoU Progression', 'mIoU')

    # 4. Class-wise IoU progression (larger figure, legend outside the axes)
    plt.figure(figsize=(12, 8))
    for cls in class_names:
        if f"IoU_{cls}" in metrics_df.columns:
            plt.plot(iterations, metrics_df[f"IoU_{cls}"], marker='o', label=cls)
    plt.title('Class-wise IoU Progression')
    plt.xlabel('Iteration')
    plt.ylabel('IoU')
    plt.legend(loc='upper left', bbox_to_anchor=(1, 1))
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(summary_dir, 'class_iou_progression.png'))
    plt.close()

    # 5. Average uncertainty
    _save_progression_plot(
        os.path.join(summary_dir, 'uncertainty_progression.png'), iterations,
        [(metrics_df["Average Uncertainty"], dict(marker='o', color='purple', linewidth=2))],
        'Average Uncertainty Progression', 'Average Uncertainty')

    # 6. High vs low uncertainty counts
    _save_progression_plot(
        os.path.join(summary_dir, 'uncertainty_counts_progression.png'), iterations,
        [(metrics_df["High Uncertainty Count"], dict(marker='o', label='High Uncertainty', linewidth=2)),
         (metrics_df["Low Uncertainty Count"], dict(marker='s', label='Low Uncertainty', linewidth=2))],
        'Uncertainty Distribution Progression', 'Number of Images', show_legend=True)

    # 7. Training time progression
    _save_progression_plot(
        os.path.join(summary_dir, 'training_time_progression.png'), iterations,
        [(metrics_df["Training Time (s)"], dict(marker='o', linewidth=2))],
        'Training Time Progression', 'Training Time (seconds)')

    print(f"Summary plots saved to {summary_dir}")

if __name__ == "__main__":
    # Root folder under which every iteration's outputs are stored
    base_dir = "./active_learning_results2"

    # All input folders live under one dataset root; only the leaf name differs.
    data_root = r"C:\Users\SRM\Desktop\Active Learning Project\Open Earth Data Set\dfc25_track1_trainval\Data\SegXAL\data"
    initial_train_data = data_root + "\\train_data"
    initial_train_labels = data_root + "\\train_labels"
    unlabeled_data = data_root + "\\Unlabeled_data"
    validation_labels = data_root + "\\Validation_labels"  # Can be None if no ground truth available

    # Active-learning schedule and training hyper-parameters
    run_settings = dict(
        total_iterations=10,          # number of active learning iterations
        samples_per_iteration=50,     # uncertain samples selected per iteration
        uncertainty_threshold=0.5,    # high/low uncertainty cut-off
        num_epochs=200,               # training epochs per iteration
        batch_size=16,                # training batch size
    )

    active_learning_pipeline(
        base_dir=base_dir,
        initial_train_data=initial_train_data,
        initial_train_labels=initial_train_labels,
        unlabeled_data=unlabeled_data,
        validation_labels=validation_labels,
        **run_settings,
    )


Using device: cuda
Starting Active Learning Pipeline with 10 iterations
Initial samples: 512, adding 50 per iteration
Copied 512 initial images to iteration 1 folder

 ITERATION 1/10
Training on 512 samples for iteration 1

 Starting training for Iteration 1...
Epoch 10/200, Loss: 1.0998
Epoch 20/200, Loss: 0.9435
Epoch 30/200, Loss: 0.8600
Epoch 40/200, Loss: 0.7579
Epoch 50/200, Loss: 0.7048
Epoch 60/200, Loss: 0.6063
Epoch 70/200, Loss: 0.5279
Epoch 80/200, Loss: 0.4534
Epoch 90/200, Loss: 0.3753
Epoch 100/200, Loss: 0.3238
Epoch 110/200, Loss: 0.2776
Epoch 120/200, Loss: 0.2165
Epoch 130/200, Loss: 0.2048
Epoch 140/200, Loss: 0.1713
Epoch 150/200, Loss: 0.1480
Epoch 160/200, Loss: 0.1343
Epoch 170/200, Loss: 0.1597
Epoch 180/200, Loss: 0.1013
Epoch 190/200, Loss: 0.1031
Epoch 200/200, Loss: 0.1922
Training completed in 3010.29 seconds
 Training details saved to ./active_learning_results2\iteration1\iteration1_training_details.xlsx

 Running predictions and uncertainty analysis for 

In [None]:
# import torch
# import torch.nn.functional as F
# import torch.nn as nn
# from torch.utils.data import DataLoader, Dataset, Subset
# import numpy as np
# import os
# import cv2
# import pandas as pd
# import matplotlib.pyplot as plt
# from sklearn.metrics import jaccard_score
# from collections import defaultdict
# from scipy.stats import entropy

# # Set device
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # Class Names
# class_names = [
#     "Background", "Bareland", "Rangeland", "Developed Space", "Road",
#     "Tree", "Water", "Agriculture Land", "Building"
# ]
# num_classes = len(class_names)

# # Dataset Class
# class SegmentationDataset(Dataset):
#     def __init__(self, image_dir, label_dir=None, transform=None):
#         self.image_paths = sorted([os.path.join(image_dir, img) for img in os.listdir(image_dir)])
#         self.label_paths = None if label_dir is None else sorted([os.path.join(label_dir, lbl) for lbl in os.listdir(label_dir)])
#         self.transform = transform
#         self.labeled = label_dir is not None
    
#     def __len__(self):
#         return len(self.image_paths)
    
#     def __getitem__(self, idx):
#         image = cv2.imread(self.image_paths[idx])
#         image = cv2.resize(image, (512, 512))
#         image_tensor = torch.tensor(image).permute(2, 0, 1).float() / 255.0
        
#         if self.labeled:
#             label = cv2.imread(self.label_paths[idx], cv2.IMREAD_GRAYSCALE)
#             label = cv2.resize(label, (512, 512), interpolation=cv2.INTER_NEAREST)
#             label_tensor = torch.tensor(label, dtype=torch.long)
#             return image_tensor, label_tensor, self.image_paths[idx]
#         else:
#             return image_tensor, self.image_paths[idx]

# # UNet Model
# class UNet(nn.Module):
#     def __init__(self, num_classes=9):
#         super(UNet, self).__init__()
#         self.num_classes = num_classes
#         self.contracting_11 = self.conv_block(3, 64)
#         self.contracting_12 = nn.MaxPool2d(2, 2)
#         self.contracting_21 = self.conv_block(64, 128)
#         self.contracting_22 = nn.MaxPool2d(2, 2)
#         self.contracting_31 = self.conv_block(128, 256)
#         self.contracting_32 = nn.MaxPool2d(2, 2)
#         self.contracting_41 = self.conv_block(256, 512)
#         self.contracting_42 = nn.MaxPool2d(2, 2)
#         self.middle = self.conv_block(512, 1024)
#         self.expansive_11 = nn.ConvTranspose2d(1024, 512, 3, 2, 1, 1)
#         self.expansive_12 = self.conv_block(1024, 512)
#         self.expansive_21 = nn.ConvTranspose2d(512, 256, 3, 2, 1, 1)
#         self.expansive_22 = self.conv_block(512, 256)
#         self.expansive_31 = nn.ConvTranspose2d(256, 128, 3, 2, 1, 1)
#         self.expansive_32 = self.conv_block(256, 128)
#         self.expansive_41 = nn.ConvTranspose2d(128, 64, 3, 2, 1, 1)
#         self.expansive_42 = self.conv_block(128, 64)
#         self.output = nn.Conv2d(64, num_classes, 3, 1, 1)

#     def conv_block(self, in_channels, out_channels):
#         return nn.Sequential(
#             nn.Conv2d(in_channels, out_channels, 3, 1, 1),
#             nn.ReLU(),
#             nn.BatchNorm2d(out_channels),
#             nn.Conv2d(out_channels, out_channels, 3, 1, 1),
#             nn.ReLU(),
#             nn.BatchNorm2d(out_channels)
#         )
    
#     def forward(self, x):
#         c1 = self.contracting_11(x)
#         p1 = self.contracting_12(c1)
#         c2 = self.contracting_21(p1)
#         p2 = self.contracting_22(c2)
#         c3 = self.contracting_31(p2)
#         p3 = self.contracting_32(c3)
#         c4 = self.contracting_41(p3)
#         p4 = self.contracting_42(c4)
#         middle = self.middle(p4)
#         u1 = self.expansive_11(middle)
#         u1 = self.expansive_12(torch.cat((u1, c4), dim=1))
#         u2 = self.expansive_21(u1)
#         u2 = self.expansive_22(torch.cat((u2, c3), dim=1))
#         u3 = self.expansive_31(u2)
#         u3 = self.expansive_32(torch.cat((u3, c2), dim=1))
#         u4 = self.expansive_41(u3)
#         u4 = self.expansive_42(torch.cat((u4, c1), dim=1))
#         output = self.output(u4)
#         return output

# # Compute Entropy for Uncertainty
# def compute_entropy(probs):
#     """Compute pixel-wise entropy from probability map"""
#     return -np.sum(probs * np.log(probs + 1e-10), axis=0)

# # Class for Active Learning Sample Selection
# class ActiveLearningSelector:
#     def __init__(self, model, unlabeled_dataset, alpha=0.7, batch_size=10):
#         """
#         Args:
#             model: The trained model
#             unlabeled_dataset: Dataset of unlabeled samples
#             alpha: Weight balance between uncertainty and diversity (higher alpha = more weight on uncertainty)
#             batch_size: Number of samples to select in each active learning round
#         """
#         self.model = model
#         self.unlabeled_dataset = unlabeled_dataset
#         self.alpha = alpha
#         self.batch_size = batch_size
#         self.device = next(model.parameters()).device
        
#     def get_class_distribution(self, labeled_indices):
#         """Calculate class distribution in current labeled dataset"""
#         class_counts = np.zeros(num_classes)
        
#         # Create dataloader for the labeled subset
#         labeled_subset = Subset(self.unlabeled_dataset, labeled_indices)
#         loader = DataLoader(labeled_subset, batch_size=1)
        
#         for img, path in loader:
#             img = img.to(self.device)
#             with torch.no_grad():
#                 output = self.model(img)
#                 pred_mask = torch.argmax(output, dim=1).cpu().numpy()[0]
                
#                 for cls in range(num_classes):
#                     class_counts[cls] += (pred_mask == cls).sum()
                    
#         return class_counts / class_counts.sum()
        
#     def select_samples(self, labeled_indices=None):
#         """
#         Select the next batch of samples for labeling
        
#         Args:
#             labeled_indices: Indices of samples that are already labeled
            
#         Returns:
#             List of indices to be labeled next
#         """
#         unlabeled_loader = DataLoader(self.unlabeled_dataset, batch_size=1)
#         all_scores = []
#         all_indices = []
        
#         # Get class distribution in current labeled set if available
#         if labeled_indices and len(labeled_indices) > 0:
#             class_distribution = self.get_class_distribution(labeled_indices)
#             # Convert to inverse weights (less represented classes get higher weights)
#             class_weights = 1.0 / (class_distribution + 1e-5)
#             # Normalize weights
#             class_weights = class_weights / class_weights.sum()
#         else:
#             # If no labeled data yet, use uniform weights
#             class_weights = np.ones(num_classes) / num_classes
            
#         print("Class weights for diversity:", class_weights)
        
#         # Calculate scores for all unlabeled samples
#         for i, (img, path) in enumerate(unlabeled_loader):
#             if labeled_indices and i in labeled_indices:
#                 continue
                
#             img = img.to(self.device)
            
#             with torch.no_grad():
#                 output = self.model(img)
#                 probs = F.softmax(output, dim=1).cpu().numpy()[0]
                
#                 # Calculate uncertainty score (mean entropy)
#                 entropy_map = compute_entropy(probs)
#                 uncertainty_score = entropy_map.mean()
                
#                 # Calculate class distribution in this sample
#                 pred_mask = np.argmax(probs, axis=0)
#                 sample_class_counts = np.zeros(num_classes)
#                 for cls in range(num_classes):
#                     sample_class_counts[cls] = (pred_mask == cls).sum()
                    
#                 # Normalize to get class distribution
#                 if sample_class_counts.sum() > 0:
#                     sample_class_distribution = sample_class_counts / sample_class_counts.sum()
#                 else:
#                     sample_class_distribution = np.ones(num_classes) / num_classes
                
#                 # Calculate diversity score (weighted by class representation)
#                 diversity_score = np.sum(sample_class_distribution * class_weights)
                
#                 # Combined score (alpha controls the balance)
#                 combined_score = self.alpha * uncertainty_score + (1 - self.alpha) * diversity_score
                
#                 all_scores.append({
#                     'index': i,
#                     'path': path[0],
#                     'uncertainty': uncertainty_score,
#                     'diversity': diversity_score,
#                     'combined_score': combined_score,
#                     'class_distribution': sample_class_distribution
#                 })
#                 all_indices.append(i)
        
#         # Sort by combined score (higher is better)
#         sorted_samples = sorted(all_scores, key=lambda x: x['combined_score'], reverse=True)
        
#         # Class-balanced sampling
#         # First, select a larger candidate pool based on combined score
#         candidate_pool = sorted_samples[:min(len(sorted_samples), self.batch_size * 3)]
        
#         # Implement class-balanced selection from candidate pool
#         selected_indices = []
#         target_class_counts = np.zeros(num_classes)
        
#         # Fill the batch with balanced class representation
#         while len(selected_indices) < self.batch_size and candidate_pool:
#             best_score = -float('inf')
#             best_idx = -1
#             best_candidate_idx = -1
            
#             for i, candidate in enumerate(candidate_pool):
#                 if candidate['index'] in selected_indices:
#                     continue
                    
#                 # Calculate how well this sample balances the class distribution
#                 temp_counts = target_class_counts.copy()
#                 for cls in range(num_classes):
#                     temp_counts[cls] += candidate['class_distribution'][cls]
                
#                 # Entropy of class distribution (higher is more balanced)
#                 balance_score = entropy(temp_counts + 1e-10)
                
#                 # Final score combines original score with balance
#                 final_score = 0.7 * candidate['combined_score'] + 0.3 * balance_score
                
#                 if final_score > best_score:
#                     best_score = final_score
#                     best_idx = candidate['index']
#                     best_candidate_idx = i
            
#             if best_idx != -1:
#                 selected_indices.append(best_idx)
#                 # Update target class counts
#                 for cls in range(num_classes):
#                     target_class_counts[cls] += candidate_pool[best_candidate_idx]['class_distribution'][cls]
#                 # Remove from candidate pool
#                 candidate_pool.pop(best_candidate_idx)
#             else:
#                 break
                
#         # Print summary of selected samples
#         print(f"Selected {len(selected_indices)} samples for labeling")
#         return selected_indices

# # Function to train with Active Learning
# def train_active_learning(
#     initial_model_path,
#     unlabeled_img_dir,
#     labeled_img_dir,
#     labeled_label_dir,
#     output_dir="output",
#     initial_labeled_samples=50,
#     active_learning_rounds=5,
#     samples_per_round=10,
#     epochs_per_round=5,
#     alpha=0.7
# ):
#     # Create output directories
#     os.makedirs(output_dir, exist_ok=True)
#     os.makedirs(os.path.join(output_dir, "models"), exist_ok=True)
#     os.makedirs(os.path.join(output_dir, "selected_samples"), exist_ok=True)
    
#     # Load initial model
#     model = UNet(num_classes=num_classes).to(device)
#     if os.path.exists(initial_model_path):
#         checkpoint = torch.load(initial_model_path, map_location=device)
#         # Handle both DataParallel and non-DataParallel models
#         if "module." in list(checkpoint.keys())[0]:
#             model = nn.DataParallel(model)
#         model.load_state_dict(checkpoint)
    
#     # Create datasets
#     unlabeled_dataset = SegmentationDataset(unlabeled_img_dir)
#     labeled_dataset = SegmentationDataset(labeled_img_dir, labeled_label_dir)
    
#     # Initial training if there's labeled data
#     if len(labeled_dataset) > 0:
#         print(f"Initial training with {len(labeled_dataset)} labeled samples")
#         train_model(model, labeled_dataset, epochs=epochs_per_round)
#         torch.save(model.state_dict(), os.path.join(output_dir, "models", "initial_model.pth"))
    
#     # Initialize active learning selector
#     selector = ActiveLearningSelector(model, unlabeled_dataset, alpha=alpha, batch_size=samples_per_round)
    
#     # Track labeled indices
#     labeled_indices = []
#     if initial_labeled_samples > 0:
#         # Random selection for initial labeled set
#         labeled_indices = np.random.choice(
#             len(unlabeled_dataset), 
#             size=min(initial_labeled_samples, len(unlabeled_dataset)),
#             replace=False
#         ).tolist()
        
#     # Metrics tracking
#     round_metrics = []
    
#     # Active learning loop
#     for round_idx in range(active_learning_rounds):
#         print(f"\n--- Active Learning Round {round_idx+1}/{active_learning_rounds} ---")
        
#         # Select samples
#         new_indices = selector.select_samples(labeled_indices)
#         labeled_indices.extend(new_indices)
        
#         # Save selected samples info
#         round_info = {
#             'round': round_idx + 1,
#             'selected_indices': new_indices,
#             'selected_paths': [unlabeled_dataset.image_paths[i] for i in new_indices]
#         }
        
#         # Here you would typically:
#         # 1. Get these images labeled (manual annotation process)
#         # 2. Add them to your labeled dataset
#         # For simulation, we'll just print the paths
        
#         print(f"Selected {len(new_indices)} new samples for labeling:")
#         for idx in new_indices:
#             print(f"  - {unlabeled_dataset.image_paths[idx]}")
        
#         # In a real scenario, after getting labels:
#         # 1. Move these samples from unlabeled to labeled dataset
#         # 2. Retrain model with updated labeled dataset
        
#         print(f"Retraining model with {len(labeled_indices)} total labeled samples")
        
#         # Simulate: In practice, you would retrain with actual labeled data
#         # For now, we'll just update metrics and save the round info
        
#         # Evaluate on validation set (if available)
#         # val_miou = evaluate_model(model, val_dataset)
#         val_miou = 0.0  # Placeholder
        
#         round_metrics.append({
#             'round': round_idx + 1,
#             'num_labeled_samples': len(labeled_indices),
#             'val_miou': val_miou
#         })
        
#         # Save round information
#         pd.DataFrame([round_info]).to_csv(
#             os.path.join(output_dir, "selected_samples", f"round_{round_idx+1}.csv"),
#             index=False
#         )
        
#         # Save model for this round
#         torch.save(model.state_dict(), 
#                   os.path.join(output_dir, "models", f"model_round_{round_idx+1}.pth"))
    
#     # Save final metrics
#     pd.DataFrame(round_metrics).to_csv(
#         os.path.join(output_dir, "active_learning_metrics.csv"),
#         index=False
#     )
    
#     # Plot learning curve
#     if round_metrics:
#         plt.figure(figsize=(10, 6))
#         rounds = [m['round'] for m in round_metrics]
#         mious = [m['val_miou'] for m in round_metrics]
#         plt.plot(rounds, mious, 'o-', linewidth=2)
#         plt.title('Active Learning Progress', fontsize=16)
#         plt.xlabel('Round', fontsize=14)
#         plt.ylabel('Validation mIoU', fontsize=14)
#         plt.grid(True)
#         plt.savefig(os.path.join(output_dir, "learning_curve.png"), dpi=300)
#         plt.close()
    
#     print("Active Learning completed successfully!")
#     return model, labeled_indices

# # Function to train model for a few epochs
# def train_model(model, dataset, epochs=5, batch_size=4):
#     """Train model for a few epochs on the given dataset.
#
#     Runs `epochs` passes over a shuffled DataLoader built from `dataset`,
#     optimising cross-entropy loss with Adam (lr=0.001), and prints the
#     average loss of each window of 10 batches. The model is updated in
#     place and the same object is returned; it is left in train() mode.
#
#     Each dataset item is unpacked as an (image, label, <ignored>) triple.
#     NOTE(review): assumes labels are integer class-index masks that
#     nn.CrossEntropyLoss can pair with the UNet logits -- TODO confirm.
#     Relies on a module-level `device`.
#     NOTE(review): the round loop in train_active_learning does not call
#     this function, so active-learning rounds never retrain -- confirm
#     this is intended.
#     """
#     train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
#     criterion = nn.CrossEntropyLoss()
#     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
#     model.train()
#     for epoch in range(epochs):
#         running_loss = 0.0
#         for i, (imgs, labels, _) in enumerate(train_loader):
#             imgs = imgs.to(device)
#             labels = labels.to(device)
            
#             optimizer.zero_grad()
#             outputs = model(imgs)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()
            
#             running_loss += loss.item()
            
#             # Report the mean loss of the last 10 batches, then reset the
#             # window. Batches after the last full window are never reported.
#             if (i+1) % 10 == 0:
#                 print(f"Epoch {epoch+1}/{epochs}, Batch {i+1}, Loss: {running_loss/10:.4f}")
#                 running_loss = 0.0
                
#     return model

# # Function to evaluate model performance
# def evaluate_model(model, dataset):
#     """Evaluate model on dataset and return mIoU.
#
#     Scores the dataset one image at a time. The per-pixel argmax class is
#     compared with the label via `compute_miou`, and the per-class IoUs are
#     accumulated over all images. Switches the model to eval() mode and does
#     not switch it back. Relies on module-level `device` and `num_classes`.
#
#     NOTE(review): the aggregation is biased. `np.nan_to_num` turns the NaN
#     IoU of a class that is absent from an image into 0, so absent classes
#     count as 0-IoU images and pull the score down. Because no NaN survives,
#     `np.nanmean(...) / total_images` is simply the class-mean of the
#     per-image IoU sums divided by the image count. A fairer aggregate keeps
#     per-class sums AND per-class counts of images where the class appears.
#     For example, take `valid = ~np.isnan(classwise_iou)` and add only the
#     valid entries. Then return `np.nanmean(sums / counts)`, with 0-count
#     classes set to NaN.
#     For an empty dataset total_images is 0, and the result is nan
#     (numpy 0/0) rather than an exception.
#     """
#     val_loader = DataLoader(dataset, batch_size=1, shuffle=False)
#     model.eval()
    
#     overall_iou_per_class = np.zeros(num_classes)
#     total_images = 0
    
#     with torch.no_grad():
#         for img, lbl, _ in val_loader:
#             img = img.to(device)
#             output = model(img)
#             # softmax is monotonic, so this argmax equals the argmax of the
#             # raw logits; [0] picks the single image of the batch_size=1 batch.
#             probs = F.softmax(output, dim=1).cpu().numpy()
#             pred_mask = np.argmax(probs, axis=1)[0]
#             lbl_np = lbl.cpu().numpy()[0]
            
#             # Compute IoU
#             miou, classwise_iou = compute_miou(pred_mask, lbl_np)
#             # NOTE(review): NaN (class absent from this image) becomes 0 here,
#             # which is the source of the bias described in the docstring.
#             overall_iou_per_class += np.nan_to_num(classwise_iou)
#             total_images += 1
    
#     avg_miou = np.nanmean(overall_iou_per_class) / total_images
#     return avg_miou

# # Compute mIoU
# def compute_miou(preds, targets):
#     """Return (mean IoU, per-class IoU array) for one prediction/label pair.
#
#     `preds` and `targets` are flattened and compared pixel-wise. For each
#     class in range(num_classes) (module-level), the IoU is sklearn's
#     jaccard_score on the boolean masks (targets == cls, preds == cls).
#     A class that never occurs in `targets` gets NaN and is ignored by
#     np.nanmean.
#
#     NOTE(review): a class present only in `preds` is NaN, so its false
#     positives are not charged to that class's own IoU. They still lower
#     the IoU of the true classes at those pixels. If no class occurs in
#     `targets`, np.nanmean of the all-NaN array returns nan with a warning.
#     Each jaccard_score call rescans every pixel. If this becomes a
#     hotspot, one np.bincount over (targets * num_classes + preds) gives
#     the full confusion matrix in a single pass.
#     """
#     iou_per_class = np.zeros(num_classes)
#     preds = preds.flatten()
#     targets = targets.flatten()
#     for cls in range(num_classes):
#         if (targets == cls).sum() == 0:
#             iou_per_class[cls] = np.nan
#             continue
#         iou_per_class[cls] = jaccard_score(targets == cls, preds == cls)
#     return np.nanmean(iou_per_class), iou_per_class

# # Example usage
# # NOTE(review): the round loop in train_active_learning never calls
# # train_model and records val_miou as a hard-coded 0.0 placeholder. As a
# # result, epochs_per_round appears to have no effect during the rounds,
# # and the saved learning curve is flat. Confirm against the full function
# # before relying on these settings. The paths below are relative and
# # resolve against the current working directory.
# if __name__ == "__main__":
#     initial_model_path = "base_unet_model.pth"
#     unlabeled_img_dir = "unlabeled/images"
#     labeled_img_dir = "labeled/images"
#     labeled_label_dir = "labeled/labels"
    
#     # Start active learning process
#     train_active_learning(
#         initial_model_path=initial_model_path,
#         unlabeled_img_dir=unlabeled_img_dir,
#         labeled_img_dir=labeled_img_dir,
#         labeled_label_dir=labeled_label_dir,
#         initial_labeled_samples=50,  # Start with 50 random samples
#         active_learning_rounds=5,    # Do 5 rounds of active learning
#         samples_per_round=10,        # Select 10 samples per round
#         epochs_per_round=5,          # Train for 5 epochs in each round
#         alpha=0.7                    # Balance between uncertainty (0.7) and diversity (0.3)
#     )