In [1]:
import os
import re
import numpy as np
import rasterio
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from datetime import datetime
from functools import lru_cache
import pandas as pd
from IPython.display import display, HTML
from typing import List, Dict  # Add this import at the top
from rasterio.warp import reproject, Resampling
from matplotlib.colors import LinearSegmentedColormap
from ipywidgets import interact, Dropdown, IntRangeSlider, Checkbox, fixed

In [2]:
parent_dir = r"M:\working_package_2\2024_dronecampaign\02_processing\metashape_projects\Upscale_Metashapeprojects\Pfynwald"

In [3]:
# show table with band information
# Band Information Table
def display_band_info():
    band_data = {
        "Band Number": ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B9", "B10"],
        "Band Name": ["Coastal Blue", "Blue", "Green-531", "Green", "Red-650", "Red", "Red Edge-705", "Red Edge-717", "Red Edge-740", "NIR"],
        "Wavelength (nm)": [444, 475, 531, 560, 650, 668, 705, 717, 740, 842],
        "Bandwidth (nm)": [28, 32, 14, 27, 16, 14, 10, 12, 18, 57]
    }
    band_table = pd.DataFrame(band_data)
    return band_table

display_band_info()

Unnamed: 0,Band Number,Band Name,Wavelength (nm),Bandwidth (nm)
0,B1,Coastal Blue,444,28
1,B2,Blue,475,32
2,B3,Green-531,531,14
3,B4,Green,560,27
4,B5,Red-650,650,16
5,B6,Red,668,14
6,B7,Red Edge-705,705,10
7,B8,Red Edge-717,717,12
8,B9,Red Edge-740,740,18
9,B10,NIR,842,57


# ==============================================
# 1. IMPROVED FILE FINDER (from previous answer)
# ==============================================

In [4]:
def find_raster_files(parent_dir: str) -> list[dict]:
    matches = []
    for root, _, files in os.walk(parent_dir):
        if "exports" not in root.split(os.sep):
            continue
            
        for file in files:
            if file.lower().endswith(".tif") and "multispec_ortho_100cm" in file.lower():
                path_parts = os.path.normpath(root).split(os.sep)
                try:
                    date = next(p for p in path_parts if p.isdigit() and len(p) == 8)
                    site_idx = path_parts.index("Upscale_Metashapeprojects") + 1
                    site = path_parts[site_idx] if site_idx < len(path_parts) else "unknown_site"
                except (ValueError, StopIteration):
                    date, site = "unknown_date", "unknown_site"
                
                matches.append({
                    "path": os.path.join(root, file),
                    "date": date,
                    "site": site,
                    "filename": file
                })
    return sorted(matches, key=lambda x: (x["date"], x["site"]))

# ==============================================
# 2. DATA LOADING WITH METADATA
# ==============================================

In [5]:
@lru_cache(maxsize=None)
def load_raster(filepath: str) -> dict:
    """Load raster data from filepath (modified for string input)"""
    with rasterio.open(filepath) as src:
        bands = src.read()
        profile = src.profile.copy()
    
    # Extract date from filename (e.g., "20240823_...")
    filename = os.path.basename(filepath)
    date = filename[:8] if filename[:8].isdigit() else "unknown_date"
    
    # Extract site from path
    path_parts = os.path.normpath(filepath).split(os.sep)
    try:
        site_idx = path_parts.index("Upscale_Metashapeprojects") + 1
        site = path_parts[site_idx] if site_idx < len(path_parts) else "unknown_site"
    except ValueError:
        site = "unknown_site"
    
    return {
        "path": filepath,
        "date": date,
        "site": site,
        "filename": filename,
        "bands": np.where(bands == 65535, np.nan, bands),  # NA handling
        "profile": profile,
        "scaled": False  # Flag for reflectance scaling
    }

def scale_bands(raster_dict: dict) -> dict:
    """Apply reflectance scaling if needed"""
    if not raster_dict["scaled"] and np.nanmax(raster_dict["bands"]) > 1.0:
        raster_dict["bands"] = raster_dict["bands"] / 32768
        raster_dict["scaled"] = True
    return raster_dict

# ==============================================
# 3. INDEX CALCULATION (NDVI/PRI)
# ==============================================

In [6]:
def calculate_indices(raster: dict) -> dict:
    bands = raster['bands']
    return {
        **raster,
        'ndvi': safe_division(bands[9] - bands[4], bands[9] + bands[4]),
        'pri': safe_division(bands[2] - bands[3], bands[2] + bands[3]),
        'bands_used': {
            'NDVI': {'NIR': 9, 'Red': 4},
            'PRI': {'531nm': 2, '570nm': 3}
        }
    }

def safe_division(num, denom, eps=1e-10):
    """Avoid division by zero"""
    return np.divide(num, denom + eps, where=(denom != 0))

In [7]:
def calculate_all_indices(raster: dict) -> dict:
    """Calculate multiple vegetation indices with band safety checks"""
    bands = raster['bands']
    return {
        **raster,
        'ndvi': safe_division(bands[9] - bands[4], bands[9] + bands[4]),  # (NIR-Red)/(NIR+Red)
        'pri': safe_division(bands[2] - bands[3], bands[2] + bands[3]),    # (531nm-560nm)/(531nm+560nm)
        'evi': 2.5 * safe_division(bands[9] - bands[4], bands[9] + 6*bands[4] - 7.5*bands[0] + 1),  # Enhanced Vegetation Index
        'ndwi': safe_division(bands[9] - bands[3], bands[9] + bands[3]),   # (NIR-Green)/(NIR+Green)
        'bands_used': {
            'NDVI': {'NIR': 9, 'Red': 4},
            'PRI': {'531nm': 2, '560nm': 3},
            'EVI': {'NIR': 9, 'Red': 4, 'Blue': 0},
            'NDWI': {'NIR': 9, 'Green': 3}
        }
    }

# ==============================================
# 4. CHANGE DETECTION PIPELINE
# ==============================================

In [8]:
def compute_temporal_changes(raster_results):
    """Compute both NDVI and PRI changes"""
    changes = []
    for i in range(1, len(raster_results)):
        prev, curr = raster_results[i-1], raster_results[i]
        
        changes.append({
            'from_date': prev['date'],
            'to_date': curr['date'],
            'site': curr.get('site', 'unknown_site'),
            'ndvi_change': curr['ndvi'] - prev['ndvi'],
            'pri_change': curr['pri'] - prev['pri'],
            'metadata': {
                'from_file': prev['path'],
                'to_file': curr['path'],
                'ndvi_bands': prev['bands_used']['NDVI'],
                'pri_bands': prev['bands_used']['PRI']
            }
        })
    return changes

In [9]:
def resample_to_target(src_array, src_profile, target_profile):
    """Resample source array to match target profile"""
    resampled = np.empty_like(src_array, shape=target_profile['shape'])
    reproject(
        src_array,
        resampled,
        src_transform=src_profile['transform'],
        src_crs=src_profile['crs'],
        dst_transform=target_profile['transform'],
        dst_crs=target_profile['crs'],
        resampling=Resampling.bilinear
    )
    return resampled

def ensure_consistent_shapes(raster_list):
    """Ensure all rasters have same dimensions by resampling to smallest common size"""
    if not raster_list:
        return raster_list
    
    # Find the smallest dimensions (most conservative approach)
    min_height = min(r['profile']['height'] for r in raster_list)
    min_width = min(r['profile']['width'] for r in raster_list)
    
    # Create target profile based on first raster (but with min dimensions)
    target_profile = raster_list[0]['profile'].copy()
    target_profile.update({
        'height': min_height,
        'width': min_width,
        'transform': rasterio.Affine(
            target_profile['transform'].a,
            target_profile['transform'].b,
            target_profile['transform'].c,
            target_profile['transform'].d,
            target_profile['transform'].e,
            target_profile['transform'].f
        )
    })
    
    # Resample all rasters to target
    for raster in tqdm(raster_list, desc="Resampling rasters"):
        if (raster['profile']['height'], raster['profile']['width']) != (min_height, min_width):
            resampled_bands = []
            for band in raster['bands']:
                resampled = np.empty((min_height, min_width), dtype=band.dtype)
                reproject(
                    band,
                    resampled,
                    src_transform=raster['profile']['transform'],
                    src_crs=raster['profile']['crs'],
                    dst_transform=target_profile['transform'],
                    dst_crs=target_profile['crs'],
                    resampling=Resampling.bilinear
                )
                resampled_bands.append(resampled)
            
            raster['bands'] = np.stack(resampled_bands)
            raster['profile'] = target_profile.copy()
    
    return raster_list

# ==============================================
# 5. INTERACTIVE VISUALIZATION
# ==============================================

In [10]:
def plot_dual_change_dashboard(changes):
    """Interactive dashboard for both NDVI and PRI changes"""
    # Create custom colormaps
    ndvi_cmap = LinearSegmentedColormap.from_list(
        'ndvi_cmap', ['#d7191c', '#ffffbf', '#1a9641'], 256)
    pri_cmap = LinearSegmentedColormap.from_list(
        'pri_cmap', ['#2166ac', '#f7f7f7', '#b2182b'], 256)
    
    def update_plots(change_idx=0, ndvi_stretch=(5,95), pri_stretch=(5,95)):
        change = changes[change_idx]
        
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))
        
        # NDVI Change Map
        ndvi_data = change['ndvi_change']
        ndvi_vmin, ndvi_vmax = np.nanpercentile(ndvi_data, ndvi_stretch)
        im1 = ax1.imshow(ndvi_data, cmap=ndvi_cmap, vmin=ndvi_vmin, vmax=ndvi_vmax)
        plt.colorbar(im1, ax=ax1, label='Δ NDVI', extend='both')
        ax1.set_title(f"NDVI Change: {change['from_date']} → {change['to_date']}")
        
        # PRI Change Map
        pri_data = change['pri_change']
        pri_vmin, pri_vmax = np.nanpercentile(pri_data, pri_stretch)
        im2 = ax2.imshow(pri_data, cmap=pri_cmap, vmin=pri_vmin, vmax=pri_vmax)
        plt.colorbar(im2, ax=ax2, label='Δ PRI', extend='both')
        ax2.set_title(f"PRI Change: {change['from_date']} → {change['to_date']}")
        
        plt.tight_layout()
        plt.show()
    
    # Calculate default stretches
    all_ndvi = np.concatenate([c['ndvi_change'][~np.isnan(c['ndvi_change'])] for c in changes])
    all_pri = np.concatenate([c['pri_change'][~np.isnan(c['pri_change'])] for c in changes])
    
    ndvi_range = (np.percentile(all_ndvi, 5), np.percentile(all_ndvi, 95))
    pri_range = (np.percentile(all_pri, 5), np.percentile(all_pri, 95))
    
    interact(
        update_plots,
        change_idx=Dropdown(
            options=list(range(len(changes))),
            description='Date Pair:',
            format=lambda x: f"{changes[x]['from_date']} → {changes[x]['to_date']}"
        ),
        ndvi_stretch=IntRangeSlider(
            value=ndvi_range,
            min=0, max=100,
            description='NDVI Range:'
        ),
        pri_stretch=IntRangeSlider(
            value=pri_range,
            min=0, max=100,
            description='PRI Range:'
        )
    )

def plot_change_dashboard(changes: list[dict], 
                        index_name: str = "NDVI",
                        cmap: str = "RdYlGn",
                        percentile_stretch: tuple = (1, 99),
                        figsize: tuple = (18, 8)):
    """
    Enhanced interactive change visualization dashboard with:
    - Dynamic stretching
    - Statistical summaries
    - Band information
    - Export capability
    """
    
    # Create custom diverging colormap
    if cmap == "RdYlGn":
        cmap = LinearSegmentedColormap.from_list('custom_div', 
                                               ['#d73027', '#f7f7f7', '#1a9850'], 256)
    
    def update_plot(change_idx=0, stretch=(1, 99), show_stats=True, show_bands=True):
        change = changes[change_idx]
        diff = change["change"]
        
        fig = plt.figure(figsize=figsize, constrained_layout=True)
        gs = fig.add_gridspec(2, 2, width_ratios=[3, 1], height_ratios=[4, 1])
        
        # Main change map
        ax1 = fig.add_subplot(gs[0, 0])
        vmin, vmax = np.nanpercentile(diff, stretch)
        im = ax1.imshow(diff, cmap=cmap, vmin=vmin, vmax=vmax)
        plt.colorbar(im, ax=ax1, label=f"Δ {index_name}", extend='both')
        
        title = (f"{index_name} Change: {change['from_date']} → {change['to_date']}\n"
                f"Site: {change['site']} | Resolution: {diff.shape}")
        ax1.set_title(title, pad=20)
        ax1.axis('off')
        
        # Histogram
        ax2 = fig.add_subplot(gs[1, 0])
        valid_pixels = diff[~np.isnan(diff)].flatten()
        ax2.hist(valid_pixels, bins=100, color='#2ca25f')
        ax2.set_xlabel(f"Δ {index_name} Value")
        ax2.set_ylabel("Pixel Count")
        ax2.axvline(0, color='k', linestyle='--')
        
        # Statistics panel
        if show_stats:
            ax3 = fig.add_subplot(gs[0, 1])
            ax3.axis('off')
            
            stats_text = (
                f"Statistics:\n"
                f"Mean: {np.nanmean(diff):.3f}\n"
                f"Std: {np.nanstd(diff):.3f}\n"
                f"Min: {np.nanmin(diff):.3f}\n"
                f"Max: {np.nanmax(diff):.3f}\n"
                f"Q{stretch[0]}: {np.nanpercentile(diff, stretch[0]):.3f}\n"
                f"Q{stretch[1]}: {np.nanpercentile(diff, stretch[1]):.3f}\n"
                f"Valid Pixels: {len(valid_pixels):,}"
            )
            ax3.text(0.5, 0.5, stats_text, ha='center', va='center',
                    bbox=dict(facecolor='white', alpha=0.8),
                    fontfamily='monospace')
        
        # Band information
        if show_bands and 'bands_used' in change:
            ax4 = fig.add_subplot(gs[1, 1])
            ax4.axis('off')
            band_text = "Bands Used:\n" + "\n".join(
                f"{k}: {v}" for k,v in change['bands_used'].items())
            ax4.text(0.5, 0.5, band_text, ha='center', va='center',
                    bbox=dict(facecolor='white', alpha=0.8))
        
        plt.show()
    
    # Calculate default stretch range
    all_changes = np.concatenate([c["change"][~np.isnan(c["change"])] for c in changes])
    default_stretch = (5, 95) if len(all_changes) > 0 else (1, 99)
    
    # Create interactive widgets
    widgets = {
        'change_idx': Dropdown(
            options=list(range(len(changes))),
            description='Date Pair:',
            style={'description_width': 'initial'},
            layout={'width': '300px'},
            format=lambda x: f"{changes[x]['from_date']} → {changes[x]['to_date']}"
        ),
        'stretch': IntRangeSlider(
            value=default_stretch,
            min=0,
            max=100,
            step=1,
            description='Percentile Range:',
            continuous_update=False
        ),
        'show_stats': Checkbox(value=True, description='Show Statistics'),
        'show_bands': Checkbox(value=True, description='Show Band Info')
    }
    
    return interact(update_plot, **widgets)

In [11]:
def plot_index_grid(raster_results, date_idx=0):
    """Plot all vegetation indices in a grid layout for a specific date"""
    data = raster_results[date_idx]
    
    # Define indices and their colormaps
    indices = [
        ('ndvi', 'NDVI', LinearSegmentedColormap.from_list('ndvi', ['#d73027', '#ffffbf', '#1a9850'], 256)),
        ('pri', 'PRI', LinearSegmentedColormap.from_list('pri', ['#2166ac', '#f7f7f7', '#b2182b'], 256)),
        ('evi', 'EVI', LinearSegmentedColormap.from_list('evi', ['#d7191c', '#ffffbf', '#1a9641'], 256)),
        ('ndwi', 'NDWI', LinearSegmentedColormap.from_list('ndwi', ['#e0f3f8', '#ffffbf', '#fdae61'], 256))
    ]
    
    # Create figure
    fig, axs = plt.subplots(2, 2, figsize=(16, 16))
    fig.suptitle(f"Vegetation Indices - {data['date']} - {data['site']}", y=1.02, fontsize=16)
    
    for idx, (ax, (index, name, cmap)) in enumerate(zip(axs.flat, indices)):
        # Calculate percentiles for dynamic range
        img_data = data[index]
        vmin, vmax = np.nanpercentile(img_data, (2, 98))
        
        # Plot index
        im = ax.imshow(img_data, cmap=cmap, vmin=vmin, vmax=vmax)
        plt.colorbar(im, ax=ax, label=f"{name} Value", fraction=0.046, pad=0.04)
        
        # Add title with band info
        band_info = ", ".join([f"{k} (B{v})" for k,v in data['bands_used'][name].items()])
        ax.set_title(f"{name}\nBands: {band_info}", pad=12)
        ax.axis('off')
    
    plt.tight_layout()
    plt.show()

def interactive_index_explorer(raster_results):
    """Create interactive explorer for all dates"""
    dates = [f"{r['date']} ({r['site']})" for r in raster_results]
    
    interact(
        plot_index_grid,
        raster_results=fixed(raster_results),
        date_idx=Dropdown(
            options=list(range(len(raster_results))),
            value=0,
            description='Date:',
            style={'description_width': 'initial'},
            layout={'width': '300px'},
            format_func=lambda x: dates[x]
        )
    )

# ==============================================
# MAIN EXECUTION PIPELINE
# ==============================================

In [12]:
def process_change_detection(parent_dir, index):
    # 1. Find files
    raster_files = find_raster_files(parent_dir)

    # 2. Load rasters
    raster_data = [load_raster(f['path']) for f in tqdm(raster_files, desc="Loading")]
    
    # 3. Ensure consistent shapes
    raster_data = ensure_consistent_shapes(raster_data)
    
    # 4. Calculate indices
    raster_results = [calculate_indices(r) for r in tqdm(raster_data, desc="Computing indices")]
    
    # 5. Compute changes
    if index == 'ndvi':
        changes = compute_temporal_changes(raster_results, index='ndvi')
    elif index == 'pri':
        changes = compute_temporal_changes(raster_results, index='pri')

    return changes

In [13]:
# Process your data
raster_files = find_raster_files(parent_dir)
raster_data = [load_raster(f['path']) for f in raster_files]
raster_data = ensure_consistent_shapes(raster_data)
raster_results = [calculate_indices(r) for r in raster_data]
changes = compute_temporal_changes(raster_results)

# Launch interactive dashboard
plot_dual_change_dashboard(changes)

Resampling rasters:   0%|          | 0/5 [00:00<?, ?it/s]

interactive(children=(Dropdown(description='Date Pair:', options=(0, 1, 2, 3), value=0), IntRangeSlider(value=…

In [None]:

# Process your data
raster_files = find_raster_files(parent_dir)
raster_data = [load_raster(f['path']) for f in raster_files]
raster_data = ensure_consistent_shapes(raster_data)
raster_results = [calculate_all_indices(r) for r in raster_data]

# Launch interactive explorer
interactive_index_explorer(raster_results)

Resampling rasters:   0%|          | 0/5 [00:00<?, ?it/s]

interactive(children=(Dropdown(description='Date:', layout=Layout(width='300px'), options=(0, 1, 2, 3, 4), sty…

In [None]:
# First ensure your changes include band information
for change in changes:
    change['bands_used'] = {'Green-531': 2, 'Green': 3}  # Update band mapping for PRI

# Generate the interactive dashboard
plot_change_dashboard(
    changes=changes,
    index_name="PRI",
    cmap="RdYlGn",
    percentile_stretch=(5, 95),
    figsize=(16, 8)
)

interactive(children=(Dropdown(description='Date Pair:', layout=Layout(width='300px'), options=(0, 1, 2, 3), s…

<function __main__.plot_change_dashboard.<locals>.update_plot(change_idx=0, stretch=(1, 99), show_stats=True, show_bands=True)>