# Visualization for Tree Crown Segmentation

**Objective:** Load the `best_checkpoint` from the `TrueResSegformer_mit_b5_min_lr_0` run, re-create the exact validation split used during that training run, run inference on a subset of the validation samples, and render the results with plots optimized for binary tree crown segmentation.

## Key Enhancements:
- **Natural colors**: Green for trees (instead of cyan)
- **Enhanced error analysis**: Clear TP/FP/FN visualization
- **Improved contrast**: Better visibility on various backgrounds
- **Comprehensive analysis panels**: Multiple visualization modes

## 1. Setup and Enhanced Imports

In [None]:
# Standard imports
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import logging
import json
import sys

# Absolute location of the project checkout. It is registered on sys.path
# (at most once) so the project-local modules imported below can be resolved.
project_root_abs = "/Users/fadil/Desktop/Git_Latex_Container"
if sys.path.count(project_root_abs) == 0:
    sys.path.append(project_root_abs)

# Project-specific imports (ensure these files are in project_root_abs or adjust paths)
from config import Config
from dataset import load_and_shuffle_dataset # TCDDataset might be used by this function
from checkpoint import load_model_for_evaluation

# Enhanced visualization imports
from visualization import (
    visualize_segmentation,
    tensor_to_image,
    visualize_boundary_iou_components,
    # NEW ENHANCED FUNCTIONS:
    visualize_segmentation_enhanced,
    visualize_error_decomposition,
    plot_enhanced_segmentation_analysis,
    create_confidence_visualization
)
from utils import get_logger, set_seed

# Shared logger used by every cell of the notebook
logger = get_logger()

# Notebook-wide seed; meant to equal the seed recorded in the run's config.json
# (same seed as the TrueResSegformer_mit_b5_min_lr_0 run)
SEED = 42
set_seed(SEED)
logger.info(f"Seed set to {SEED}")

# Run on CUDA when PyTorch reports it as available, otherwise on the CPU
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
logger.info(f"Using device: {device}")

# Settings that control how the per-sample figures are written to disk
SAVE_FIGURES = True
SAVE_DIR = os.path.join(
    project_root_abs,
    "past_run_repository",
    "outputs_individual_figures",
    "validation_enhanced",
)
# The output directory is created up front, independent of SAVE_FIGURES
os.makedirs(SAVE_DIR, exist_ok=True)
SAVE_DPI = 300
SAVE_FORMATS = ["png"]
SAVE_TRANSPARENT = False
logger.info(f"Saving figures to: {SAVE_DIR} (formats: {SAVE_FORMATS}, dpi: {SAVE_DPI})")

## 2. Configuration Loading

In [None]:
# Location of the run's saved configuration, expressed relative to
# project_root_abs. Only relative components are passed after the root:
# os.path.join drops every earlier component as soon as it meets an absolute
# one, which would make project_root_abs irrelevant.
# NOTE(review): the notebook's objective text names the
# TrueResSegformer_mit_b5_min_lr_0 run, while this path points at the
# SETR_Ablation/setr_full_res outputs - confirm which run is intended.
config_path = os.path.join(
    project_root_abs,
    "past_run_repository",
    "SETR_Ablation",
    "setr_full_res",
    "outputs_setr_full_res",
    "config.json",
)

# Stays None when the file is missing; later cells test it before use.
config_dict = None
if not os.path.exists(config_path):
    logger.error(f"Config file not found at {config_path}. Please ensure the path is correct.")
else:
    config_dict = Config.load(config_path) # Config.load returns a dict
    logger.info(f"Successfully loaded configuration from {config_path}")
    logger.info(f"Run Model Name: {config_dict.get('model_name')}")
    logger.info(f"Run Dataset Name: {config_dict.get('dataset_name')}")
    logger.info(f"Run Seed: {config_dict.get('seed')}")
    logger.info(f"Run Validation Split: {config_dict.get('validation_split')}")
    # Ensure the seed used for the notebook matches the run's seed
    if config_dict.get('seed') != SEED:
        logger.warning(f"Mismatch between notebook seed ({SEED}) and run config seed ({config_dict.get('seed')}). Using notebook seed for general operations, but run's seed for dataset splitting.")

## 3. Load Model from Best Checkpoint

In [None]:
# NOTE(review): training_resolution is not referenced by any visible cell;
# kept in case cells outside this view rely on it.
training_resolution = {"height": 1024, "width": 1024}

# Checkpoint directory, expressed relative to project_root_abs. Only relative
# components follow the root because os.path.join discards all earlier
# components once it meets an absolute one.
model_checkpoint_path = os.path.join(
    project_root_abs,
    "past_run_repository",
    "SETR_Ablation",
    "setr_full_res",
    "outputs_setr_full_res",
    "best_checkpoint",
)

logger.info(f"Attempting to load model from: {model_checkpoint_path}")

# Both stay None when loading is skipped or fails; later cells test them.
model, image_processor = None, None
if not config_dict:
    logger.error("Config dictionary not loaded. Cannot load model.")
elif not os.path.exists(os.path.join(model_checkpoint_path, 'pytorch_model.bin')):
    logger.error(f"Checkpoint 'pytorch_model.bin' not found in {model_checkpoint_path}. Please ensure the path is correct.")
else:
    model, image_processor = load_model_for_evaluation(
        model_path=model_checkpoint_path,
        config=config_dict, # Pass the loaded run-specific config dictionary
        device=device,
        logger=logger
    )
    # Compare against None explicitly so the outcome does not depend on how
    # the returned objects define truthiness (e.g. via __len__).
    if model is not None and image_processor is not None:
        model.eval()
        logger.info("Model and image processor loaded successfully and model set to evaluation mode.")
    else:
        logger.error("Failed to load model and image processor.")

## 4. Re-create the Validation Dataset Split

In [None]:
from datasets import Dataset # For Hugging Face datasets.Dataset object

# Stays None when no configuration is available; the visualization cell checks it.
validation_hf_dataset = None
if not config_dict:
    logger.error("Config dictionary not loaded. Cannot re-create dataset split.")
else:
    # The seed stored in the run's config (not the notebook SEED) drives both
    # the dataset shuffle and the index permutation below.
    run_seed = config_dict['seed']
    logger.info(f"Loading dataset: {config_dict['dataset_name']} with seed for shuffling: {run_seed}")
    full_dataset_dict = load_and_shuffle_dataset(
        dataset_name=config_dict["dataset_name"],
        seed=run_seed,
    )

    original_train_dataset = full_dataset_dict["train"]
    num_total_original_train_samples = len(original_train_dataset)
    logger.info(f"Original 'train' split (before val split) has {num_total_original_train_samples} samples.")

    # Validation size is the configured fraction of the train split, truncated to an int
    validation_split_ratio = config_dict["validation_split"]
    num_val_samples = int(validation_split_ratio * num_total_original_train_samples)

    # Seeded permutation of all train indices; its leading num_val_samples
    # entries are taken as the validation indices.
    generator = torch.Generator()
    generator.manual_seed(run_seed)
    indices = torch.randperm(num_total_original_train_samples, generator=generator).tolist()
    val_indices = indices[:num_val_samples]

    # Materialize the validation subset as a dataset object
    validation_hf_dataset = original_train_dataset.select(val_indices)

    logger.info(f"Re-created validation split with {len(validation_hf_dataset)} samples using seed {run_seed} for permutation.")

## 5. Enhanced Visualization for Tree Crown Segmentation

This section provides multiple visualization modes optimized for binary tree crown segmentation analysis.

In [None]:
def _finalize_figure(fig, file_stem):
    """Lay out, optionally save, display, and then release one figure.

    Args:
        fig: Matplotlib figure to finalize.
        file_stem: Output file name without extension. When SAVE_FIGURES is
            true, one file per entry of SAVE_FORMATS is written to SAVE_DIR.
    """
    # Apply the layout before saving so the file on disk gets the same layout
    # pass as the figure that is displayed.
    fig.tight_layout()
    if SAVE_FIGURES:
        for fmt in SAVE_FORMATS:
            out_path = os.path.join(SAVE_DIR, f"{file_stem}.{fmt}")
            fig.savefig(
                out_path,
                dpi=SAVE_DPI,
                bbox_inches='tight',
                pad_inches=0.05,
                transparent=SAVE_TRANSPARENT,
            )
    plt.show()
    # Close explicitly: several figures are created per sample, and backends
    # that do not close figures on show would otherwise keep all of them open.
    plt.close(fig)


def _show_image_panel(image, title, file_stem):
    """Render `image` on a single 8x8 axis with a title and no axes, then finalize it.

    Args:
        image: Anything accepted by Axes.imshow.
        title: Title text placed above the image.
        file_stem: Output file name without extension (see _finalize_figure).
    """
    fig, ax = plt.subplots(1, 1, figsize=(8, 8))
    ax.imshow(image)
    ax.set_title(title, fontsize=14)
    ax.axis('off')
    _finalize_figure(fig, file_stem)


if model and image_processor and validation_hf_dataset and config_dict:
    num_to_visualize = min(10, len(validation_hf_dataset))  # Visualize up to 10 samples
    logger.info(f"Will visualize {num_to_visualize} samples with enhanced visualizations.")

    # Get id2label from config, provide a fallback if not present
    id2label = config_dict.get('id2label', {0: 'background', 1: 'tree_crown'})
    # Ensure keys are integers for the visualization functions
    # (string keys are converted with int()).
    if id2label and isinstance(next(iter(id2label)), str):
        id2label = {int(k): v for k, v in id2label.items()}

    # Inference resolution is the same for every sample, so resolve it once.
    # Only an integer "image_size" is interpreted (as a square); any other
    # value falls back to 1024x1024.
    training_image_size_config = config_dict.get("image_size", 1024)
    if isinstance(training_image_size_config, int):
        target_size = {"height": training_image_size_config, "width": training_image_size_config}
    else:
        target_size = {"height": 1024, "width": 1024}

    for i in range(num_to_visualize):
        sample_number = i + 1
        logger.info(f"Processing sample {sample_number}/{num_to_visualize}")
        raw_sample = validation_hf_dataset[i]

        pil_image = raw_sample['image'].convert("RGB") # Ensure image is RGB
        # The mask is looked up under 'annotation' first, then 'label'
        gt_mask_pil = raw_sample.get('annotation', raw_sample.get('label'))
        if gt_mask_pil is None:
            logger.error(f"Ground truth mask not found for sample {i}. Skipping.")
            continue
        gt_mask_np = np.array(gt_mask_pil)

        original_gt_shape = gt_mask_np.shape
        logger.info(f"Initial gt_mask_np shape: {original_gt_shape}, unique values: {np.unique(gt_mask_np)[:20]}")

        # Bring the ground truth mask to a 2D array
        if gt_mask_np.ndim == 3 and gt_mask_np.shape[-1] == 3:
            # Any pixel with a non-zero channel sum counts as foreground
            logger.info("Ground truth mask is 3-channel. Converting to binary mask.")
            gt_mask_np = (gt_mask_np.sum(axis=-1) > 0).astype(np.uint8)
        elif gt_mask_np.ndim == 3 and gt_mask_np.shape[-1] == 1:
            logger.info("Ground truth mask is single-channel 3D. Squeezing to 2D.")
            gt_mask_np = np.squeeze(gt_mask_np, axis=-1)
        elif gt_mask_np.ndim == 2:
            logger.info(f"Ground truth mask is already 2D. Shape: {gt_mask_np.shape}")

        # Remap 255 values to 1 for binary case
        if 1 in id2label and np.any(gt_mask_np == 255):
            logger.info("Remapping gt_mask_np values of 255 to 1 for tree crown class.")
            gt_mask_np[gt_mask_np == 255] = 1

        # Perform prediction at the resolved inference resolution
        logger.info(f"Resizing input image from {pil_image.size} to: {target_size} for model inference.")
        inputs = image_processor(
            images=pil_image,
            return_tensors="pt",
            do_resize=True,
            size=target_size,
            resample=Image.Resampling.BILINEAR,
            do_rescale=True,
            do_normalize=True
        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits

        # Resample the logits to the original image size before taking the
        # per-pixel argmax over the class dimension.
        original_h, original_w = pil_image.height, pil_image.width
        upsampled_logits = torch.nn.functional.interpolate(
            logits,
            size=(original_h, original_w),
            mode=config_dict.get("interpolation_mode", "bilinear"),
            align_corners=config_dict.get("interpolation_align_corners", False)
        )

        predicted_mask_np = upsampled_logits.argmax(dim=1).squeeze().cpu().numpy()

        logger.info(f"Prediction complete. Unique values: {np.unique(predicted_mask_np)}")

        # Convert PIL image to numpy for visualization
        image_np = np.array(pil_image)

        print(f"\n{'='*60}")
        print(f"SAMPLE {sample_number} - TREE CROWN SEGMENTATION ANALYSIS")
        print(f"{'='*60}")

        # === 1. COMPREHENSIVE ENHANCED ANALYSIS PANEL ===
        logger.info(f"Creating comprehensive analysis panel for sample {sample_number}")
        plot_enhanced_segmentation_analysis(
            image=image_np,
            prediction=predicted_mask_np,
            ground_truth=gt_mask_np,
            sample_id=f"Sample {sample_number}",
            figsize=(25, 12),
            id2label=id2label,
            include_boundary_analysis=True,
            include_error_decomposition=True
        )
        # The comprehensive (multi-panel) figure is only displayed, not saved;
        # only the individual figures below are written to disk.
        plt.show()

        # === 2. VISUALIZATION (INDIVIDUAL FIGURES) ===
        print("\nVisualization (Individual Figures):")
        sample_prefix = f"sample_{sample_number:03d}"

        # Original image (individual)
        _show_image_panel(image_np, "Original Image", f"{sample_prefix}_original")

        # Enhanced GT (individual); the helper receives a copy of the image
        gt_enhanced = visualize_segmentation_enhanced(
            image_np.copy(), gt_mask_np, id2label=id2label,
            alpha=0.6, use_natural_colors=True, high_contrast=True
        )
        _show_image_panel(gt_enhanced, "Ground Truth (Overlay)", f"{sample_prefix}_gt_overlay")

        # Enhanced Prediction (individual)
        pred_enhanced = visualize_segmentation_enhanced(
            image_np.copy(), predicted_mask_np, id2label=id2label,
            alpha=0.6, use_natural_colors=True, high_contrast=True
        )
        _show_image_panel(pred_enhanced, "Prediction (Overlay)", f"{sample_prefix}_prediction_overlay")

        # === 3. DETAILED ERROR ANALYSIS (INDIVIDUAL FIGURES) ===
        print("\nDetailed Error Analysis (Individual):")
        # Error decomposition (individual)
        error_vis = visualize_error_decomposition(
            image_np.copy(), predicted_mask_np, gt_mask_np,
            alpha=0.8, use_natural_colors=True
        )
        _show_image_panel(
            error_vis,
            "Error Analysis\nGreen: Correct Trees, Yellow: False Alarms, Red: Missed Trees",
            f"{sample_prefix}_error_decomposition",
        )

        # Boundary analysis (individual)
        boundary_vis = visualize_boundary_iou_components(
            image_np.copy(), gt_mask_np, predicted_mask_np,
            dilation_pixels=5, alpha=0.8
        )
        _show_image_panel(
            boundary_vis,
            "Boundary Analysis\nGreen: Correct Boundaries, Yellow: False Boundaries, Red: Missed Boundaries",
            f"{sample_prefix}_boundary_analysis",
        )

        # === 4. GROUND TRUTH BOUNDARY (INDIVIDUAL FIGURE) ===
        logger.info(f"Rendering ground truth boundary overlay for sample {sample_number}")
        # Ensure binary mask (0/1): only label 1 is treated as foreground
        gt_binary = (gt_mask_np == 1).astype(np.uint8)
        fig_gb, ax_gb = plt.subplots(1, 1, figsize=(8, 8))
        ax_gb.imshow(image_np)
        # A single contour level at 0.5 traces the 0/1 transition of the mask
        ax_gb.contour(gt_binary, levels=[0.5], colors='lime', linewidths=2)
        ax_gb.set_title(f"Sample {sample_number} - Ground Truth Boundary (Overlay)", fontsize=14)
        ax_gb.axis('off')
        _finalize_figure(fig_gb, f"{sample_prefix}_gt_boundary_contour")

        print(f"\n{'='*60}\n")

else:
    logger.error("Cannot proceed with visualization. Ensure Model, Image Processor, Validation Dataset, and Config are available.")

## Summary of Visualizations

This notebook provides several key improvements for binary tree crown segmentation analysis:

### 🎨 **Visual Enhancements:**
- **Natural Colors**: Green for trees instead of generic cyan/blue
- **High Contrast**: Better visibility against various backgrounds
- **Optimized Alpha Blending**: Different transparency levels for different information types

### 📊 **Analysis Capabilities:**
- **Comprehensive Analysis Panel**: 5-panel view with original, GT, prediction, error decomposition, and boundary analysis
- **Error Decomposition**: Clear visualization of TP (green), FP (yellow), FN (red)
- **Boundary Analysis**: Focus on edge/boundary errors critical for tree crown delineation
- **Enhanced Visualization**: Clean single-row visualization comparison

### 🎯 **Tree Crown Specific Features:**
- Optimized for binary segmentation (background vs tree crown)
- Natural forest/vegetation color scheme
- Enhanced boundary focus (critical for crown delineation)
- Intuitive error color coding for forestry applications