In [None]:

def create_interactive_raster_plotter(tif_file_path, band_names, colormap_dict):
    """
    Create and display an interactive widget-based plotter for raster data.

    A dropdown selects the band, radio buttons select the colorbar mode
    ('continuous' = equal-width intervals, 'quantile' = quantile intervals)
    and a second dropdown selects the number of intervals. Each change
    re-renders the selected band as a semi-transparent overlay on a Folium
    satellite map, next to a matplotlib colorbar image.

    Args:
        tif_file_path (str): The path to the TIFF file.
        band_names (list[str] | None): Human-readable band names; list index
            ``i`` names band ``i + 1``. Bands beyond the list are shown in the
            dropdown as "Band N (Unnamed)".
        colormap_dict (dict[str, str] | None): Maps a band name (the entry of
            ``band_names``, or "Band N" for unnamed bands) to a matplotlib
            colormap name. Bands not in the dict use 'viridis'.

    Returns:
        None. The widgets are displayed in the current notebook output.
    """
    import os
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    import rasterio
    import numpy as np
    from matplotlib.colors import BoundaryNorm
    from PIL import Image
    import io
    import base64
    import folium
    import pyproj
    from IPython.display import display, HTML
    import ipywidgets as widgets

    colormap_dict = colormap_dict or {}

    # --- Widget Definitions ---
    # Created first so both the set-up code below and plot_raster_logic can write to it.
    output_widget = widgets.Output()

    # Read the band count to populate the band dropdown.
    band_count = 0
    if os.path.exists(tif_file_path):
        try:
            with rasterio.open(tif_file_path) as src:
                band_count = src.count
        except rasterio.errors.RasterioIOError as e:
            with output_widget:
                print(f"Error opening TIFF file for band count: {e}")
            # Display the output area so the error is actually visible before exiting.
            display(output_widget)
            return
    else:
        # Keep going so the UI still renders; plot_raster_logic reports the missing file.
        with output_widget:
            print(f"Error: TIFF file not found: {tif_file_path}. Cannot initialize widgets.")

    # (label, 1-based band index) pairs for the band dropdown.
    if band_count > 0 and band_names:
        band_selector_options = [
            (band_names[i] if i < len(band_names) else f"Band {i+1} (Unnamed)", i + 1)
            for i in range(band_count)
        ]
    elif band_count > 0:
        band_selector_options = [(f"Band {i+1}", i + 1) for i in range(band_count)]
    else:  # Fallback if no bands or file not found
        band_selector_options = [("Band 1 (No data)", 1)]

    band_selector = widgets.Dropdown(
        options=band_selector_options,
        value=band_selector_options[0][1],
        description='Band:'
    )

    colorbar_mode = widgets.RadioButtons(
        options=['continuous', 'quantile'],
        value='continuous',
        description='Colorbar:'
    )

    interval_dropdown = widgets.Dropdown(
        options=[(str(i), i) for i in range(2, 24)],
        value=4,
        description='Intervals:'
    )

    # --- Core Plotting Logic (nested function to access local variables) ---
    def plot_raster_logic(band_index, colorbar_type, intervals):
        """
        Render the selected band and its colorbar into ``output_widget``.

        Called by ``widgets.interactive_output`` with the current widget values.
        Any error is printed into ``output_widget`` instead of being raised.
        """
        output_widget.clear_output(wait=True)  # Clear previous output for fresh display

        try:
            if not os.path.exists(tif_file_path):
                with output_widget:
                    print(f"Error: TIFF file not found: {tif_file_path}. Cannot plot.")
                return

            with rasterio.open(tif_file_path) as src:
                if not (1 <= band_index <= src.count):
                    with output_widget:
                        print(f"Error: Selected band index {band_index} is out of range. "
                            f"This TIFF has {src.count} bands.")
                    return

                img = src.read(band_index).astype(float)
                nodata = src.nodata
                bounds = src.bounds
                src_crs = src.crs

            # Treat the file's nodata value (and any existing NaN) as missing.
            if nodata is not None:
                img[img == nodata] = np.nan
            nan_mask = np.isnan(img)

            valid = img[~nan_mask]
            if valid.size == 0:
                with output_widget:
                    print("No valid data in this band to plot.")
                return

            unique_vals = np.unique(valid)
            n_unique = len(unique_vals)

            if band_names and 0 <= (band_index - 1) < len(band_names):
                current_band_name = band_names[band_index - 1]
            else:
                current_band_name = f"Band {band_index}"
            cmap_name = colormap_dict.get(current_band_name, 'viridis')

            # --- Colorbar classes ---
            if n_unique <= intervals:
                # Few distinct values: one colour per value, bins [v_i, v_{i+1})
                # with an extra closing edge after the largest value.
                boundaries = np.append(unique_vals, unique_vals[-1] + 1)
                norm = BoundaryNorm(boundaries, ncolors=n_unique, clip=True)
                cmap = plt.get_cmap(cmap_name, n_unique).copy()
                tick_labels = [str(val) for val in unique_vals]
                colorbar_label = "Discrete values"
            else:
                vmin, vmax = valid.min(), valid.max()
                cmap = plt.get_cmap(cmap_name).copy()
                if colorbar_type == 'quantile':
                    quantiles = np.linspace(0, 1, intervals + 1)
                    # Heavily repeated values make neighbouring quantiles coincide;
                    # drop the duplicates so no empty "x–x" classes are shown.
                    boundaries = np.unique(np.quantile(valid, quantiles))
                    if len(boundaries) < 2:
                        boundaries = np.array([vmin, vmax])
                    colorbar_label = 'Quantile intervals'
                else:  # 'continuous'
                    boundaries = np.linspace(vmin, vmax, intervals + 1)
                    colorbar_label = 'Continuous intervals'
                norm = BoundaryNorm(boundaries, ncolors=cmap.N, clip=True)
                tick_labels = [f"{lo:.2f}–{hi:.2f}" for lo, hi in zip(boundaries[:-1], boundaries[1:])]
            cmap.set_bad(color='#888888')
            # One tick at the centre of each bin (correct for non-integer values too).
            ticks = 0.5 * (boundaries[:-1] + boundaries[1:])

            # --- Prepare image for Folium ---
            # Only the lower-left and upper-right corners are reprojected to lon/lat.
            transformer = pyproj.Transformer.from_crs(src_crs, "EPSG:4326", always_xy=True)
            left, bottom = transformer.transform(bounds.left, bounds.bottom)
            right, top = transformer.transform(bounds.right, bounds.top)
            img_bounds = [[bottom, left], [top, right]]

            rgba_img = cmap(norm(np.ma.masked_invalid(img)))
            rgba_img[..., 3][nan_mask] = 0  # Missing pixels fully transparent

            rgb_img = np.uint8(rgba_img[..., :3] * 255)
            alpha = np.uint8(rgba_img[..., 3] * 255)
            img_pil = Image.fromarray(np.dstack((rgb_img, alpha)), mode='RGBA')
            buf = io.BytesIO()
            img_pil.save(buf, format='PNG')
            url = f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode('utf-8')}"

            # --- Create Folium Map ---
            center_lat = (bottom + top) / 2
            center_lon = (left + right) / 2
            map_height = 800

            m = folium.Map(
                location=[center_lat, center_lon],
                zoom_start=13,
                tiles="https://mt1.google.com/vt/lyrs=s&x={x}&y={y}&z={z}",
                attr='Google',
                name='Google Satellite',
                height=f'{map_height}px',
                width='95%'
            )
            folium.raster_layers.ImageOverlay(
                image=url,
                bounds=img_bounds,
                opacity=0.7,
                interactive=True,
                cross_origin=False,
                zindex=1,
            ).add_to(m)
            folium.LayerControl().add_to(m)

            # --- Create Colorbar ---
            # Figure is as tall as the map (at 100 px per inch) and wide enough for the labels.
            fig_height = map_height / 100
            max_label_len = max(len(lbl) for lbl in tick_labels) if tick_labels else 1
            label_width_px = max_label_len * 8 + 32
            fig_width = label_width_px / 100 + 0.2

            fig, ax = plt.subplots(figsize=(fig_width, fig_height))
            ax.set_frame_on(False)
            ax.set_xticks([])
            ax.set_yticks([])

            cb = plt.colorbar(
                cm.ScalarMappable(norm=norm, cmap=cmap),
                cax=ax, orientation='vertical', boundaries=boundaries, ticks=ticks
            )
            cb.set_ticklabels(tick_labels)
            cb.set_label(colorbar_label, fontsize=18)
            cb.ax.tick_params(labelsize=16)
            cb.ax.set_title(current_band_name, fontsize=18, pad=20)

            buf_cb = io.BytesIO()
            fig.savefig(buf_cb, format='png', bbox_inches='tight', transparent=True, dpi=fig.dpi, pad_inches=0)
            plt.close(fig)
            colorbar_data = base64.b64encode(buf_cb.getvalue()).decode('utf-8')

            # Map and colorbar side by side, top-aligned; the colorbar keeps its aspect ratio.
            colorbar_html = f'''
            <div style="display: flex; flex-direction: row; width: 100%; align-items: flex-start; gap: 1vw;">
                <div style="flex: 4 1 0; min-width: 0; height: {map_height}px;">
                    {m._repr_html_()}
                </div>
                <div style="flex: 1 1 0; min-width: 80px; display: flex; align-items: flex-start; justify-content: center; height: {map_height}px;">
                    <img src="data:image/png;base64,{colorbar_data}" style="max-width: 100%; max-height: {map_height}px; height: auto; width: auto; display: block; margin-top:0;"/>
                </div>
            </div>
            '''
            with output_widget:
                display(HTML(colorbar_html))

        except Exception as e:
            # UI callback boundary: report instead of raising into the widget machinery.
            with output_widget:
                print(f"An error occurred during plotting: {e}")

    # --- Display UI and Link Widgets to Plotting Function ---
    ui = widgets.VBox([
        band_selector,
        colorbar_mode,
        interval_dropdown
    ])

    # interactive_output calls plot_raster_logic with the current control values.
    # Its own returned Output widget is not displayed because plot_raster_logic
    # writes into output_widget explicitly.
    widgets.interactive_output(
        plot_raster_logic,
        {
            'band_index': band_selector,
            'colorbar_type': colorbar_mode,
            'intervals': interval_dropdown
        }
    )

    display(widgets.VBox([ui, output_widget]))

In [None]:

def plot_band_value_distributions(bands, thresholds_list, tif_file_path, band_names, figname):
    """
    Plot one histogram per band with its classification intervals, then save and show it.

    For each band, the valid (non-nodata) pixel values are shown in a histogram.
    Each threshold is drawn as a dashed red line, and the class intervals are
    shaded from blue (lowest hotspot score) to red (highest). A legend below
    the subplot lists each interval with its score and pixel count.

    Args:
        bands (list[int]): 1-based band indices to plot, one subplot each.
        thresholds_list (list[dict]): One dict per band with keys 'thresholds'
            (interval upper bounds) and 'hotspot_score' (one score per interval;
            expected to have len(thresholds) + 1 entries).
        tif_file_path (str): Path to the TIFF file.
        band_names (list[str] | None): Human-readable names; index ``i`` names band ``i + 1``.
        figname (str): Output image path (saved at 300 dpi).
    """
    import os
    import rasterio
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.patches import Patch

    n_bands = len(bands)
    fig, axes = plt.subplots(
        n_bands, 1,
        figsize=(13, 4 * n_bands),  # Extra height per band for the legend below each plot
        constrained_layout=True
    )
    if n_bands == 1:
        axes = np.array([axes])

    with rasterio.open(tif_file_path) as src:
        nodata = src.nodata
        for ax, b, band_info in zip(axes, bands, thresholds_list):
            arr = src.read(b).astype(float)
            if nodata is not None:
                arr[arr == nodata] = np.nan
            valid_arr = arr[~np.isnan(arr)]

            band_label = band_names[b - 1] if band_names is not None and b - 1 < len(band_names) else f"Band {b}"
            if valid_arr.size == 0:
                ax.set_title(f"{band_label} \n No valid data")
                continue

            # list() so the edge concatenation below works for array/tuple thresholds too.
            thresholds = list(band_info['thresholds'])
            hotspotscores = band_info['hotspot_score']
            arr_class = classify_band(valid_arr, thresholds, hotspotscores)

            interval_labels = [f"≤{t}" for t in thresholds] + [f">{thresholds[-1]}"]
            bar_counts = [np.sum(arr_class == score) for score in hotspotscores]

            # Colour by score: lowest score = blue, highest = red.
            cmap = plt.get_cmap('coolwarm', len(hotspotscores))
            lo_score, hi_score = min(hotspotscores), max(hotspotscores)
            class_colors = [
                cmap((s - lo_score) / (hi_score - lo_score)) if hi_score > lo_score else cmap(0.5)
                for s in hotspotscores
            ]

            # Shade each class interval, clipped to the data range so thresholds
            # outside it do not produce reversed/overlapping spans.
            x_min, x_max = valid_arr.min(), valid_arr.max()
            edges = [x_min] + thresholds + [x_max]
            for color, left, right in zip(class_colors, edges[:-1], edges[1:]):
                left, right = max(left, x_min), min(right, x_max)
                if right > left:
                    ax.axvspan(left, right, color=color, alpha=0.15, zorder=0)

            ax.hist(valid_arr, bins=40, color='tab:blue', alpha=0.6, zorder=1)

            # Threshold lines and their value labels near the top of the axes.
            y_top = ax.get_ylim()[1] * 0.95
            for t in thresholds:
                ax.axvline(t, color='red', linestyle='--', linewidth=1)
                ax.text(t, y_top, f"{t}", color='red', rotation=90, va='top', ha='right', fontsize=8)

            ax.set_title(f"{band_label} \n Value Distribution & Classified Intervals")
            ax.set_ylabel("Pixel count")

            handles = [Patch(facecolor=color, edgecolor='k', alpha=0.4) for color in class_colors]
            labels = [
                f"{interval_labels[i]} (Score={hotspotscores[i]}, count={bar_counts[i]})"
                for i in range(len(hotspotscores))
            ]
            # Legend below each subplot, laid out horizontally.
            ax.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, -0.22),
                      ncol=len(hotspotscores), fontsize=10, frameon=False)

    fig.savefig(figname, dpi=300, bbox_inches='tight', pad_inches=0.15)
    print("File saved:", os.path.exists(figname))
    plt.show()


In [None]:
def classify_band(arr, thresholds, hotspotscores):
    """
    Map pixel values to hotspot scores using ordered thresholds.

    Values ``<= thresholds[0]`` get ``hotspotscores[0]``. Values with
    ``thresholds[i-1] < v <= thresholds[i]`` get ``hotspotscores[i]``. Values
    above ``thresholds[-1]`` get ``hotspotscores[-1]``, as do all values when
    ``thresholds`` is empty. NaN inputs (no-data) stay NaN, so missing pixels
    are not silently given the highest score.

    Args:
        arr (array-like): Pixel values of any shape.
        thresholds (sequence of float): Interval upper bounds, in ascending order.
        hotspotscores (sequence of float): Score per interval (normally len(thresholds) + 1).

    Returns:
        numpy.ndarray: Float array of the same shape as ``arr`` holding the scores.
    """
    import numpy as np

    values = np.asarray(arr, dtype=float)
    arr_class = np.full(values.shape, hotspotscores[-1], dtype=float)
    if len(thresholds) > 0:
        arr_class[values <= thresholds[0]] = hotspotscores[0]
        for i in range(1, len(thresholds)):
            arr_class[(values > thresholds[i-1]) & (values <= thresholds[i])] = hotspotscores[i]
    # NaN fails every comparison above; keep it as no-data rather than the top score.
    arr_class[np.isnan(values)] = np.nan
    return arr_class

In [None]:

def plot_hotspot_interactive(
    bands,
    thresholds_list,
    tif_file_path,
    combine_method='sum',
    colormap_options=None,
    save_tif_path=None,
    mask_band=2):
    """
    Interactive hotspot map plotting for classified bands.

    Each band in ``bands`` is classified with ``classify_band`` using its entry
    in ``thresholds_list``. The per-band scores are combined per pixel and
    shown as a semi-transparent overlay on a Folium satellite map, with
    dropdowns for the colormap and the number of colour intervals.

    Args:
        bands (list[int]): 1-based band indices to classify and combine.
        thresholds_list (list[dict]): One dict per band with keys 'thresholds'
            and 'hotspot_score' (passed to ``classify_band``).
        tif_file_path (str): Path to the TIFF file.
        combine_method (str): 'sum' adds the per-band scores; any other value
            averages them ('mean').
        colormap_options (list[str] | None): Colormap names offered in the
            widget; defaults to a small built-in list.
        save_tif_path (str | None): If given, the combined score raster is
            written here as a single-band float32 GeoTIFF (NaN as nodata)
            whenever the map is rendered.
        mask_band (int): 1-based band whose no-data pixels define the area that
            is never plotted. Defaults to 2, the band used originally.

    Pixels that are no-data in any of ``bands`` are excluded from the combined
    map. A warning with their count is printed.
    """
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    import rasterio
    import numpy as np
    from matplotlib.colors import BoundaryNorm
    from PIL import Image
    import io
    import base64
    import folium
    import pyproj
    from IPython.display import display, HTML
    import ipywidgets as widgets

    if colormap_options is None:
        colormap_options = [
            'plasma',
            'inferno',
            'RdYlBu_r',
            'coolwarm',
            'Reds'
        ]

    # --- Read and classify every band once (independent of widget values) ---
    arrays = []
    with rasterio.open(tif_file_path) as src:
        nodata = src.nodata
        meta = src.meta.copy()
        bounds = src.bounds
        src_crs = src.crs

        area_band = src.read(mask_band).astype(float)
        if nodata is not None:
            area_band[area_band == nodata] = np.nan
        area_nan_mask = np.isnan(area_band)
        total_pixels = np.sum(~area_nan_mask)

        for b, band_info in zip(bands, thresholds_list):
            arr = src.read(b).astype(float)
            if nodata is not None:
                arr[arr == nodata] = np.nan
            band_nan = np.isnan(arr)
            nan_count = np.sum(band_nan & ~area_nan_mask)
            if nan_count > 0:
                percent = 100 * nan_count / total_pixels if total_pixels > 0 else 0
                print(f"Warning: Band {b} contains {nan_count} NaN values ({percent:.2f}% of valid pixels) "
                      f"where Band {mask_band} is valid. These pixels are excluded from the hotspot map.")
            arr_class = classify_band(arr, band_info['thresholds'], band_info['hotspot_score'])
            # Keep no-data as NaN so it is never scored (and propagates when combining).
            arr_class[band_nan | area_nan_mask] = np.nan
            arrays.append(arr_class)

    # NaN in any layer propagates through sum/mean, excluding that pixel.
    stack = np.stack(arrays)
    combined = np.sum(stack, axis=0) if combine_method == 'sum' else np.mean(stack, axis=0)
    plot_mask = ~np.isnan(combined) & ~area_nan_mask

    # Lon/lat bounds of the raster; only the two corners are reprojected.
    transformer = pyproj.Transformer.from_crs(src_crs, "EPSG:4326", always_xy=True)
    left, bottom = transformer.transform(bounds.left, bounds.bottom)
    right, top = transformer.transform(bounds.right, bounds.top)
    img_bounds = [[bottom, left], [top, right]]
    center_lat = (bottom + top) / 2
    center_lon = (left + right) / 2
    map_height = 800

    meta.update({
        "count": 1,
        "dtype": "float32",
        "nodata": np.nan
    })

    hotspot_cmap_selector = widgets.Dropdown(
        options=colormap_options,
        value=colormap_options[0],
        description='Colormap:'
    )
    hotspot_interval_selector = widgets.Dropdown(
        options=[(str(i), i) for i in range(2, 21)],
        value=5,
        description='Intervals:'
    )

    def plot_hotspot_map(cmap_name, intervals):
        """Render the combined hotspot scores with the chosen colormap and interval count."""
        valid = combined[plot_mask]
        if valid.size == 0:
            print("No valid data to plot.")
            return
        vmin, vmax = valid.min(), valid.max()
        if vmin == vmax:
            print("All combined hotspotscores are the same. Adjust your thresholds or input data.")
            return

        cmap = plt.get_cmap(cmap_name).copy()
        cmap.set_bad(color="#F7F6F600")
        boundaries = np.linspace(vmin, vmax, intervals + 1)
        norm = BoundaryNorm(boundaries, ncolors=cmap.N, clip=True)

        # BoundaryNorm yields integer colour indices; masked pixels are made transparent.
        color_indices = norm(combined)
        color_indices[~plot_mask] = -1
        rgba_img = cmap(np.clip(color_indices, 0, cmap.N - 1))
        rgba_img[..., 3] = plot_mask.astype(float)
        rgb_img = np.uint8(rgba_img[..., :3] * 255)
        alpha = np.uint8(rgba_img[..., 3] * 255)
        img_pil = Image.fromarray(np.dstack((rgb_img, alpha)), mode='RGBA')
        buf = io.BytesIO()
        img_pil.save(buf, format='PNG')
        url = f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode('utf-8')}"

        # Save combined hotspot map as GeoTIFF if a path is provided
        if save_tif_path is not None:
            with rasterio.open(save_tif_path, "w", **meta) as dst:
                dst.write(combined.astype(np.float32), 1)
            print(f"Hotspot map saved to: {save_tif_path}")

        m = folium.Map(
            location=[center_lat, center_lon],
            zoom_start=13,
            tiles="https://mt1.google.com/vt/lyrs=s&x={x}&y={y}&z={z}",
            attr='Google',
            name='Google Satellite',
            height=f'{map_height}px',
            width='95%'
        )
        folium.raster_layers.ImageOverlay(
            image=url,
            bounds=img_bounds,
            opacity=0.7,
            interactive=True,
            cross_origin=False,
            zindex=1,
        ).add_to(m)
        folium.LayerControl().add_to(m)

        fig, ax = plt.subplots(figsize=(1.2, 6))
        cb = plt.colorbar(
            cm.ScalarMappable(norm=norm, cmap=cmap),
            cax=ax, orientation='vertical', boundaries=boundaries, ticks=boundaries
        )
        cb.set_label('Hotspot Score', fontsize=14)
        cb.ax.tick_params(labelsize=12)
        buf_cb = io.BytesIO()
        fig.savefig(buf_cb, format='png', bbox_inches='tight', transparent=True, dpi=100)
        plt.close(fig)
        colorbar_data = base64.b64encode(buf_cb.getvalue()).decode('utf-8')

        # Container height matches the folium map height so the map is not clipped.
        colorbar_html = f'''
        <div style="display:flex;flex-direction:row;width:95%;align-items:stretch;">
            <div style="width:80%;height:{map_height}px;">
                {m._repr_html_()}
            </div>
            <div style="width:20%;display:flex;align-items:center;justify-content:center;height:{map_height}px;">
                <img src="data:image/png;base64,{colorbar_data}" style="max-width:100%;max-height:100%;">
            </div>
        </div>
        '''
        display(HTML(colorbar_html))

    out = widgets.interactive_output(
        plot_hotspot_map,
        {
            'cmap_name': hotspot_cmap_selector,
            'intervals': hotspot_interval_selector
        }
    )
    display(widgets.VBox([widgets.HBox([hotspot_cmap_selector, hotspot_interval_selector]), out]))


In [None]:
def save_hotspot_figure(hotspot_tif_path_hotspotmap, figname, cityname, cmap='coolwarm'):
    """
    Render band 1 of a hotspot GeoTIFF as a static figure, save it and show it.

    Args:
        hotspot_tif_path_hotspotmap (str): Path to the GeoTIFF; band 1 is plotted.
        figname (str): Output image path passed to ``savefig`` (300 dpi).
        cityname (str): Shown on the second line of the title.
        cmap (str | matplotlib.colors.Colormap): Colormap name or instance.
            No-data/NaN pixels are drawn white.

    Errors while reading or plotting are printed rather than raised.
    """
    import os
    import rasterio
    import matplotlib.pyplot as plt
    import numpy as np

    try:
        with rasterio.open(hotspot_tif_path_hotspotmap) as src:
            raster_data = src.read(1).astype(float)
            nodata = src.nodata

        # Non-NaN nodata values must become NaN to be drawn as "bad" (white).
        if nodata is not None:
            raster_data[raster_data == nodata] = np.nan

        print(f"Successfully loaded GeoTIFF data from: {hotspot_tif_path_hotspotmap}")

        # plt.get_cmap replaces matplotlib.cm.get_cmap (removed in Matplotlib 3.9);
        # .copy() avoids mutating the globally registered colormap with set_bad.
        current_cmap = plt.get_cmap(cmap).copy()
        current_cmap.set_bad(color='white')

        fig, ax = plt.subplots(1, 1, figsize=(10, 8))

        # Pixel grid only (no geographic extent); origin='upper' keeps raster row order.
        im = ax.imshow(raster_data, cmap=current_cmap, origin='upper')

        # Horizontal colorbar under the map with enlarged labels.
        cbar = fig.colorbar(im, ax=ax, orientation='horizontal', shrink=0.75, pad=0.05)
        cbar.set_label('Hotspot Score', fontsize=14)
        cbar.ax.tick_params(labelsize=12)

        ax.set_title(f'Hotspot Map \n {cityname}', fontsize=20, fontweight='bold', loc='center')

        # No coordinate ticks or grid.
        ax.set_xticks([])
        ax.set_yticks([])
        ax.grid(False)

        fig.savefig(figname, dpi=300, bbox_inches='tight', pad_inches=0.15)
        print("File saved:", os.path.exists(figname))

        plt.show()

    except rasterio.errors.RasterioIOError as e:
        print(f"Error opening or reading GeoTIFF file: {e}")
        print(f"Please ensure the file '{hotspot_tif_path_hotspotmap}' exists and is a valid GeoTIFF.")
    except FileNotFoundError:
        print(f"Error: The file '{hotspot_tif_path_hotspotmap}' was not found.")
        print("Please check the file path.")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")