# BARE Pipeline Implementation

This notebook implements a complete end-to-end pipeline for training, evaluating, and deploying the BARE model for tree crown delineation using the refactored codebase.

## Pipeline Overview

1. Setup & Configuration
2. Dataset Loading & Exploration
3. Model Creation
4. Training Pipeline
5. Evaluation & Metrics
6. Prediction on New Images
7. Advanced Visualization & Analysis

In [None]:
import os
import torch
import logging
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import warnings
import json
import time
import glob

# Ignore specific warnings if needed (e.g., from PIL)
warnings.filterwarnings("ignore", category=UserWarning, module='PIL')

# Import necessary modules from the codebase
from config import Config, adjust_config_for_architecture
from dataset import load_and_shuffle_dataset, create_dataloaders
from model import create_model
from pipeline import run_training_pipeline, evaluate_model, run_prediction_pipeline
from visualization import (
    plot_worst_predictions, 
    plot_confusion_matrix, 
    plot_error_analysis_map, 
    visualize_segmentation
)
from metrics import calculate_metrics # For potential re-evaluation
from utils import get_logger, set_seed
from checkpoint import load_model_for_evaluation

# Ensure plots are displayed inline in the notebook
%matplotlib inline

In [None]:
# Set up the project logger shared by every cell in this notebook.
logger = get_logger()

# Instantiate the configuration from the defaults in config.py, then apply the
# adjustments for the architecture selected there.
config = Config()
config = adjust_config_for_architecture(config)

# Seed the RNGs before any data loading or model construction. `set_seed` was
# imported but never called, so the seed logged below was never actually applied
# to the dataloader sample or the model built in this notebook.
# NOTE(review): assumes utils.set_seed takes the seed as its single positional
# argument — confirm against utils.py.
set_seed(config['seed'])

logger.info(f"Configuration loaded for '{config['architecture']}' architecture.")
logger.info(f"Output directory: {config['output_dir']}")
logger.info(f"Using seed: {config['seed']}")

In [None]:
# Select the compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if device.type == 'cuda':
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"CUDA Device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")
    # `total_memory` is the card's total capacity, not the currently free amount,
    # so it is labelled "Total" rather than "Available".
    print(f"Total CUDA Memory: {gpu_props.total_memory / 1e9:.2f} GB")

## 1. Setup & Configuration

### Load Configuration and Setup Logging

The cells above instantiate the `Config` object from the defaults in `config.py` (which match the reference model's settings), apply the architecture-specific adjustments, and set up logging.

### Save config to JSON

In [None]:
# Persist the active configuration as JSON inside the run's output directory, so
# the artifacts of a run carry the settings that produced them.
# Note: id2label/label2id are only added to `config` by the dataset-loading cell,
# so they are not part of the file written here.
if hasattr(config, 'to_dict'):
    config_dict = config.to_dict()
else:
    # Plain dicts are used as-is; any other mapping-like Config is copied into a dict.
    config_dict = config if isinstance(config, dict) else dict(config)

os.makedirs(config['output_dir'], exist_ok=True)
config_save_path = os.path.join(config['output_dir'], 'config.json')

with open(config_save_path, 'w') as config_file:
    json.dump(config_dict, config_file, indent=4)

logger.info(f"Configuration saved to {config_save_path}")

## 2. Dataset Loading & Exploration

Let's load the Tree Crown Delineation (TCD) dataset and explore its characteristics.

Load the dataset using `load_and_shuffle_dataset` for consistency and then create the dataloaders using `create_dataloaders`. We'll also display a sample.

In [None]:
# Load the dataset and shuffle it deterministically with the configured seed.
logger.info(f"Loading dataset: {config['dataset_name']}")
dataset_dict = load_and_shuffle_dataset(dataset_name=config['dataset_name'], seed=config['seed'])

# No architecture is resized at the dataset level: target_size stays None for all
# of them. For SETR, native-resolution (setr_input_size) processing is handled
# inside the SETRWrapper instead.
target_size = None
logger.info(
    f"Architecture is SETR. No resizing at data loading stage to allow native {config['setr_input_size']}px processing."
    if config['architecture'] == 'setr'
    else f"Architecture is '{config['architecture']}'. No final resizing will be applied."
)

logger.info("Creating dataloaders...")
train_dataloader, eval_dataloader, id2label, label2id = create_dataloaders(
    dataset_dict=dataset_dict,
    image_processor=None,  # built inside create_dataloaders (presumably with do_resize=False)
    config=config,
    train_batch_size=config['train_batch_size'],
    eval_batch_size=config['eval_batch_size'],
    num_workers=config['num_workers'],
    validation_split=config['validation_split'],
    seed=config['seed'],
    target_size=target_size,
)

# Store the dataset's real label mappings in the config; model creation reads them.
config['id2label'] = id2label
config['label2id'] = label2id

logger.info(f"Dataloaders created. Train batches: {len(train_dataloader)}, Eval batches: {len(eval_dataloader)}")
logger.info(f"Class mapping: {id2label}")

# Preview one augmented training sample. Best-effort: any failure is logged, not raised.
try:
    batch = next(iter(train_dataloader))
    img_tensor = batch['pixel_values'][0]
    lbl_tensor = batch['labels'][0]

    # CHW tensor -> HWC array, then undo the normalisation assuming ImageNet
    # mean/std (the statistics typically used by the SegFormer image processor).
    # NOTE(review): if the processor uses other statistics, preview colours will be off.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img_display = np.clip(img_tensor.permute(1, 2, 0).cpu().numpy() * std + mean, 0, 1)
    lbl_display = lbl_tensor.cpu().numpy()

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    panels = (
        (img_display, 'Sample Image (Augmented)', None),
        (lbl_display, 'Sample Mask (Augmented)', 'gray'),  # assumes a binary mask
    )
    for panel_ax, (panel_data, panel_title, panel_cmap) in zip(ax, panels):
        panel_ax.imshow(panel_data, cmap=panel_cmap)
        panel_ax.set_title(panel_title)
        panel_ax.axis('off')
    plt.show()
except Exception as e:
    logger.warning(f"Could not display sample batch: {e}")

## 3. Model Creation

Now we'll create the model selected by `config['architecture']` (e.g. SegFormer or SETR), together with its optimizer and learning-rate scheduler.

In [None]:
# Total optimizer steps over the run: batches per epoch x epochs, divided by the
# gradient-accumulation factor. create_model receives it for scheduler setup.
num_training_steps = len(train_dataloader) * config['num_epochs'] // config['gradient_accumulation_steps']

logger.info("Creating model...")
model, optimizer, scheduler = create_model(
    config=config,
    num_training_steps=num_training_steps,
    logger=logger,
)
# NOTE(review): run_training_pipeline (next section) only receives `config` and
# `logger`, so this instance is presumably used for inspection only and training
# builds its own model — confirm in pipeline.py.

# Use the spatial input size the model actually receives for this architecture.
if config['architecture'] == 'setr':
    input_size_h = config['setr_input_size']
    input_size_w = config['setr_input_size']
    logger.info(f"Using SETR input size for model summary: {input_size_h}x{input_size_w}")
else:
    input_size_h = config['augmentation']['random_crop_size']
    input_size_w = config['augmentation']['random_crop_size']
    logger.info(f"Using augmentation crop size for model summary: {input_size_h}x{input_size_w}")

input_size = (config['train_batch_size'], 3, input_size_h, input_size_w)

# The summary is optional. The torchinfo import lives inside the try so a missing
# package skips the summary instead of failing the whole cell.
try:
    from torchinfo import summary
    # torchinfo defaults to verbose=0 inside Jupyter, so the table only appears
    # if the returned ModelStatistics is printed explicitly.
    print(summary(model, input_size=input_size, verbose=0))
except ImportError:
    logger.warning("torchinfo not installed. Skipping model summary. Install with 'pip install torchinfo'")
except Exception as e:
    logger.warning(f"Could not print model summary: {e}")

logger.info("Model, optimizer, and scheduler created.")

## 4. Training Pipeline

Run the training using the `run_training_pipeline` function, which handles the complete loop, evaluation during training, and checkpointing.

In [None]:
# Train end to end. run_training_pipeline owns the loop, in-training evaluation and
# checkpointing; is_notebook=True tells it it runs in a notebook (tqdm compatibility).
logger.info("Starting training pipeline...")
training_start_time = time.time()
training_results = run_training_pipeline(config=config, logger=logger, is_notebook=True)
training_duration = time.time() - training_start_time

print(f"\nTraining finished in {training_duration / 60:.2f} minutes.")

# Keys of training_results used in this and later cells:
#   'model'     -> the trained model
#   'metrics'   -> final evaluation metrics
#   'model_dir' -> directory holding the saved model artifacts
logger.info(f"Training finished. Final model saved to: {training_results['model_dir']}")
logger.info(f"Final evaluation metrics from training: {training_results['metrics']}")

### Computational Cost Analysis

In [None]:
# Computational Cost Reporting: parameter counts, estimated FLOPs, inference
# latency and peak GPU memory for the model returned by the training pipeline.
logger.info("\n--- Computational Cost Analysis ---")

# Inference-timing knobs. The first forward passes pay one-off costs (CUDA
# context/kernel setup, cuDNN autotuning, allocator growth), so they run untimed.
NUM_WARMUP_ITERS = 2
MAX_TIMED_BATCHES = 20

# Same check as "'model' in training_results and training_results['model'] is not None".
has_trained_model = training_results.get('model') is not None

# 1. Model Parameters
if has_trained_model:
    model = training_results['model']
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Total Model Parameters: {total_params:,}")
    logger.info(f"Trainable Model Parameters: {trainable_params:,}")
    print(f"Total Model Parameters: {total_params:,}")
    print(f"Trainable Model Parameters: {trainable_params:,}")
else:
    logger.warning("Model not found in training_results. Skipping parameter count.")
    print("Model not found in training_results. Skipping parameter count.")

# 2. FLOPs Estimation (using the optional thop library)
try:
    from thop import profile
    if has_trained_model:
        model_device = next(model.parameters()).device

        # Profile at the input size this architecture actually receives.
        if config['architecture'] == 'setr':
            H = W = config['setr_input_size']
            logger.info(f"Using SETR input size for FLOPs calculation: {H}x{W}")
        else:
            H = W = config['augmentation']['random_crop_size']
            logger.info(f"Using augmentation crop size for FLOPs calculation: {H}x{W}")

        input_channels = 3
        dummy_input = torch.randn(1, input_channels, H, W).to(model_device)

        macs, params_thop = profile(model, inputs=(dummy_input,), verbose=False)
        gflops = macs * 2 / 1e9  # thop reports MACs; 1 MAC = 2 FLOPs
        logger.info(f"Estimated GFLOPs: {gflops:.2f} GFLOPs")
        print(f"Estimated GFLOPs: {gflops:.2f} GFLOPs (using input size {input_channels}x{H}x{W})")
    else:
        logger.warning("Model not found for FLOPs calculation.")
        print("Model not found for FLOPs calculation.")
except ImportError:
    logger.warning("thop library not found. Skipping FLOPs calculation. Install with 'pip install thop'")
    print("thop library not found. Skipping FLOPs calculation. Install with 'pip install thop'")
except Exception as e:
    logger.error(f"Error during FLOPs calculation: {e}")
    print(f"Error during FLOPs calculation: {e}")

# 3. Inference Time
logger.info("Attempting to measure inference time...")
if has_trained_model:
    model.eval()
    val_dataloader = None

    # At notebook top level, locals() is globals(), so one lookup is enough.
    if 'eval_dataloader' in globals():
        val_dataloader = eval_dataloader
        logger.info(f"Using existing 'eval_dataloader' for inference time measurement. Batches: {len(val_dataloader)}")
    else:
        logger.warning("'eval_dataloader' not found. Skipping inference time measurement.")

    # `is not None` rather than truthiness: bool(DataLoader) calls __len__, which
    # raises for loaders over unsized iterable datasets.
    if val_dataloader is not None:
        total_inference_time = 0.0
        num_samples = 0
        num_batches = 0
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        is_cuda = device.type == 'cuda'

        with torch.no_grad():
            for batch in val_dataloader:
                if isinstance(batch, dict) and 'pixel_values' in batch:
                    inputs = batch['pixel_values'].to(device)
                elif isinstance(batch, (list, tuple)):
                    inputs = batch[0].to(device)
                else:
                    inputs = batch.to(device)

                if num_batches == 0:
                    # Untimed warm-up so one-off setup costs don't skew the average.
                    for _warmup in range(NUM_WARMUP_ITERS):
                        model(inputs)

                # CUDA kernels run asynchronously: drain pending work (input copy,
                # warm-up) before starting the clock, and again before stopping it.
                if is_cuda:
                    torch.cuda.synchronize()
                start_time = time.perf_counter()  # monotonic, high resolution
                model(inputs)
                if is_cuda:
                    torch.cuda.synchronize()
                total_inference_time += time.perf_counter() - start_time

                num_samples += inputs.size(0)
                num_batches += 1
                if num_batches >= MAX_TIMED_BATCHES:
                    logger.info(f"Inference timing based on the first {MAX_TIMED_BATCHES} batches.")
                    break

        if num_batches > 0:
            avg_time_per_batch = total_inference_time / num_batches
            avg_time_per_sample = total_inference_time / num_samples
            logger.info(f"Average Inference Time per Batch: {avg_time_per_batch:.4f} seconds")
            logger.info(f"Average Inference Time per Sample: {avg_time_per_sample:.4f} seconds")
            print(f"Average Inference Time per Batch: {avg_time_per_batch:.4f} seconds")
            print(f"Average Inference Time per Sample: {avg_time_per_sample:.4f} seconds")
        else:
            logger.warning("No batches processed. Cannot calculate inference time.")
            print("WARNING: No batches processed. Cannot calculate inference time.")
else:
    logger.warning("Model not found. Skipping inference time calculation.")
    print("Model not found. Skipping inference time calculation.")

# 4. GPU Memory Usage: peaks since process start (or the last peak reset), so they
# presumably include training if it ran in this kernel.
if torch.cuda.is_available():
    peak_memory_allocated = torch.cuda.max_memory_allocated() / (1024**2)
    peak_memory_reserved = torch.cuda.max_memory_reserved() / (1024**2)
    logger.info(f"Peak GPU Memory Allocated: {peak_memory_allocated:.2f} MB")
    logger.info(f"Peak GPU Memory Reserved: {peak_memory_reserved:.2f} MB")
    print(f"Peak GPU Memory Allocated: {peak_memory_allocated:.2f} MB")
    print(f"Peak GPU Memory Reserved: {peak_memory_reserved:.2f} MB")
else:
    logger.info("CUDA not available. Skipping GPU memory usage reporting.")

logger.info("--- End of Computational Cost Analysis ---")

## 5. Evaluation & Metrics

`run_training_pipeline` already evaluates the model during training. Here, we explicitly reload and re-evaluate the best saved checkpoint (falling back to the final model if no best checkpoint was recorded) and visualize metrics such as the confusion matrix.

In [None]:
# Re-evaluate a saved checkpoint. By default this is the best checkpoint recorded
# during training, falling back to the final model directory.
best_model_path = training_results.get('best_model_dir') or training_results['model_dir']

# To evaluate a checkpoint from a previous run instead, override the path here:
# best_model_path = "past_run_repository/outputs_pspnet/best_checkpoint"

# Logged after the optional override so the message names the path actually used.
logger.info(f"Evaluating best model from training: {best_model_path}")

# load_model_for_evaluation returns the model and its matching image processor
# (it handles both standard models and BARESegformer).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_eval, image_processor_eval = load_model_for_evaluation(
    model_path=best_model_path,
    config=config,
    device=device,
    logger=logger
)

# Rebuild the evaluation split using the loaded checkpoint's image processor.
# target_size=None mirrors the training-time create_dataloaders call, so the
# evaluation images are not resized differently from the data the model saw.
_, eval_dataloader_reeval, _, _ = create_dataloaders(
    dataset_dict=dataset_dict,
    image_processor=image_processor_eval,
    config=config,
    train_batch_size=config['train_batch_size'],
    eval_batch_size=config['eval_batch_size'],
    num_workers=config['num_workers'],
    validation_split=config['validation_split'],
    seed=config['seed'],
    target_size=None
)

logger.info("Running evaluation on the loaded model...")
metrics = evaluate_model(
    model=model_eval,
    eval_dataloader=eval_dataloader_reeval,
    device=device,
    output_dir=config['output_dir'],
    id2label=config['id2label'],
    visualize_worst=True,
    num_worst_samples=5,
    visualize_confidence_comparison=False,
    analyze_errors=True,
    logger=logger,
    is_notebook=True
)

logger.info(f"Evaluation metrics: {metrics}")

# Show the normalized confusion matrix image, if one was written to output_dir.
# (plt, Image and os come from the imports cell; no need to re-import them here.)
cm_norm_path = os.path.join(config['output_dir'], "confusion_matrix_normalized.png")
if os.path.exists(cm_norm_path):
    fig, ax = plt.subplots(figsize=(8, 8))
    with Image.open(cm_norm_path) as cm_image:
        ax.imshow(cm_image)
    ax.set_title("Normalized Confusion Matrix")
    ax.axis('off')
    plt.show()
else:
    logger.info(f"No normalized confusion matrix found at {cm_norm_path}.")

## 6. Prediction on New Images

Use the `run_prediction_pipeline` function to predict segmentation masks on new images. Set `image_path` in the cell below to the image you want to segment (default: `test_image.tif`).

In [None]:
# --- Prediction Setup ---
# Path to the image to segment; change this to run on your own data.
image_path = "test_image.tif"


def _show_image(image, title, figsize=(12, 6)):
    """Render one image (array or PIL image) in its own figure with the axes hidden."""
    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(image)
    ax.set_title(title)
    ax.axis('off')
    plt.show()


if not os.path.exists(image_path):
    logger.error(f"Image not found at {image_path}. Please verify the path.")
else:
    logger.info(f"Found image at {image_path}. Proceeding with prediction...")

    prediction_output_dir = os.path.join(config['output_dir'], 'predictions')

    try:
        # NOTE(review): this uses the final model from training, whereas the
        # evaluation section prefers the best checkpoint when one exists; switch
        # to training_results.get('best_model_dir') if predictions should use it.
        logger.info(f"Running prediction using model from {training_results['model_dir']}...")

        prediction_results = run_prediction_pipeline(
            config=config,
            # A single path string. NOTE(review): the parameter name is plural —
            # confirm run_prediction_pipeline accepts a str as well as a list.
            image_paths=image_path,
            model_path=training_results['model_dir'],
            output_dir=prediction_output_dir,
            visualize=True,
            show_confidence=True,
            logger=logger,
            is_notebook=True
        )

        # .get() so a result that lacks one of these keys is skipped, not a KeyError.
        image_name = os.path.basename(image_path)
        if prediction_results.get('visualizations'):
            _show_image(prediction_results['visualizations'][0], f"Prediction: {image_name}")
        if prediction_results.get('confidence_maps'):
            _show_image(prediction_results['confidence_maps'][0], f"Confidence Map: {image_name}")

        logger.info(f"Prediction finished. Visualizations saved to: {prediction_output_dir}")

    except Exception as e:
        logger.error(f"Prediction pipeline failed: {e}", exc_info=True)

## 7. Advanced Visualization & Analysis

Explore advanced visualizations like error analysis maps generated during evaluation.

In [None]:
# Maximum number of worst-prediction images to display side by side.
MAX_WORST_TO_SHOW = 3


def _show_png(ax, path, title):
    """Draw the image file at `path` on `ax` with `title`, hiding the axes.

    The file is opened in a `with` block; imshow converts the PIL image to an
    array immediately, so the handle can be closed right after drawing.
    """
    with Image.open(path) as img:
        ax.imshow(img)
    ax.set_title(title)
    ax.axis('off')


# Error analysis map for one of the worst samples (written when evaluate_model is
# called with analyze_errors=True).
error_analysis_dir = os.path.join(config['output_dir'], 'error_analysis')
if os.path.exists(error_analysis_dir):
    # glob returns files in arbitrary, filesystem-dependent order; sort so the map
    # shown is the same on every run.
    error_maps = sorted(glob.glob(os.path.join(error_analysis_dir, 'error_map_*.png')))

    if error_maps:
        error_map_to_show = error_maps[0]
        logger.info(f"Displaying error analysis map: {error_map_to_show}")
        fig, ax = plt.subplots(figsize=(10, 10))
        _show_png(ax, error_map_to_show, f"Error Analysis Map: {os.path.basename(error_map_to_show)}")
        plt.show()
    else:
        logger.warning(f"No error analysis maps found in {error_analysis_dir}.")
else:
    logger.warning(f"Error analysis directory not found at {error_analysis_dir}. Ensure 'analyze_errors=True' during evaluation.")

# Saved worst-prediction visualizations, if any were generated.
worst_pred_dir = os.path.join(config['output_dir'], 'worst_predictions')
if os.path.exists(worst_pred_dir):
    worst_files = sorted(glob.glob(os.path.join(worst_pred_dir, 'worst_*.png')))

    if worst_files:
        num_to_show = min(MAX_WORST_TO_SHOW, len(worst_files))

        if num_to_show > 1:
            fig, axes = plt.subplots(1, num_to_show, figsize=(15, 5))
            # zip stops at num_to_show because axes has exactly that many panels.
            for i, (ax, worst_path) in enumerate(zip(axes, worst_files)):
                _show_png(ax, worst_path, f"Worst #{i+1}")
            plt.tight_layout()
        else:
            fig, ax = plt.subplots(figsize=(10, 5))
            _show_png(ax, worst_files[0], f"Worst Prediction: {os.path.basename(worst_files[0])}")
        plt.show()

        logger.info(f"Displayed {num_to_show} worst prediction visualizations from {worst_pred_dir}")
    else:
        logger.warning(f"No worst prediction visualizations found in {worst_pred_dir}.")
else:
    logger.warning(f"Worst predictions directory not found at {worst_pred_dir}.")

## End of Pipeline

This notebook demonstrated the complete pipeline for training, evaluating, and predicting with the TCD-BARE model using the refactored codebase. All artifacts (configuration, checkpoints, metrics, visualizations) are saved in the directory specified by `config['output_dir']` (default: `./outputs`).