# Process large images via tiling approach

In [None]:
import os
from rasterio.windows import Window
import rasterio

# Scratch cell: walk a large raster in overlapping square windows and read each tile.
# The tile-writing code at the bottom is commented out, so apart from creating the
# output directory nothing is written to disk.
# base_img_path = '/mnt/nili_e/Software/fractal_complexity/data/Jezero_HiRISE_clipped.tif'
base_img_path = '/mnt/holuhraun/spiny_pond.tif'

# Define tile size and overlap
tile_size = 500
overlap = int(tile_size * 0.1)  # 10% overlap

# Open the image to get its dimensions
with rasterio.open(base_img_path) as src:
    img_width = src.width
    img_height = src.height
    print(f"Image dimensions: {img_width}x{img_height}")
    
    # Create output directory for tiles
    # output_dir = "/mnt/nili_e/Software/fractal_complexity/data/Jezero_HiRISE_clipped_tiles/"
    output_dir = "/mnt/holuhraun/spiny_pond_tiles/"

    os.makedirs(output_dir, exist_ok=True)
    
    # Loop through the image and save tiles
    # (i is the row offset, j the column offset; the stride is tile_size - overlap)
    for i in range(0, img_height, tile_size - overlap):
        for j in range(0, img_width, tile_size - overlap):
            # Debug limiter: with this stride only the first tile row (i == 0) is processed.
            # NOTE(review): `break` only leaves the inner loop, so the outer loop still
            # steps through every remaining row offset without doing any work.
            if i > 2:
                break
            # Define the window for the current tile
            window = Window(j, i, tile_size, tile_size)
            
            # Adjust the window size to avoid going out of bounds
            window = window.intersection(Window(0, 0, img_width, img_height))
            
            # Read the tile
            tile = src.read(1, window=window)
            
            # Define the output path
            # NOTE(review): the filename prefix still says "Jezero_HiRISE_clipped" although
            # the active input above is spiny_pond.tif.
            tile_filename = f"Jezero_HiRISE_clipped_tile_{i}_{j}.tif"
            tile_path = os.path.join(output_dir, tile_filename)
            
            # # Save the tile
            # profile = src.profile
            # profile.update({
            #     "height": window.height,
            #     "width": window.width,
            #     "transform": rasterio.windows.transform(window, src.transform)
            # }) 
            
            # with rasterio.open(tile_path, "w", **profile) as dst:
            #     dst.write(tile, 1)

Image dimensions: 62374x47889


## Code outline

### inputs
- path to the large raster image
- path to a shapefile or geodatabase (`.gdb`, together with the layer name)
- name of the shapefile attribute (column) that holds the unit label

### process
- create a tiling scheme with overlap
- fetch each tile's window from the raster
- rasterize the corresponding window from the shapefile
- compute the complexity (entropy) maps for the tile

In [None]:
import rasterio
from rasterio.windows import Window
import numpy as np
from rasterio.features import rasterize
import fiona
from shapely.geometry import mapping
import matplotlib.pyplot as plt


def get_windows(img_path, tile_size, overlap):
    """
    Build the list of overlapping tile windows that cover an image.

    Parameters:
    -----------
    img_path : str
        Path to the raster image
    tile_size : int
        Edge length of each (square) tile in pixels
    overlap : int
        Number of pixels shared by neighbouring tiles; must be smaller
        than tile_size

    Returns:
    --------
    window_list : list of rasterio.windows.Window
        Windows in row-major order (rows outer, columns inner), each
        clipped to the image bounds

    Raises:
    -------
    ValueError
        If the stride (tile_size - overlap) is not positive
    """
    # A zero stride makes range() fail with an unrelated message and a negative
    # one silently yields no windows, so reject both before touching the file.
    stride = tile_size - overlap
    if stride <= 0:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than tile_size ({tile_size})"
        )

    # Only the dimensions are needed, so the dataset can be closed right away
    with rasterio.open(img_path) as src:
        img_width = src.width
        img_height = src.height
    print(f"Image dimensions: {img_width}x{img_height}")

    # Tiles on the right/bottom edge are clipped against the full image extent
    image_bounds = Window(0, 0, img_width, img_height)
    return [
        Window(col_off, row_off, tile_size, tile_size).intersection(image_bounds)
        for row_off in range(0, img_height, stride)
        for col_off in range(0, img_width, stride)
    ]

def get_tile_window(image_path, window):
    """
    Get a tile from the image using a specified window.
    
    Parameters:
    -----------
    image_path : str
        Path to the image file
    window : rasterio.windows.Window
        Window object specifying the tile location and size
        
    Returns:
    --------
    tile : numpy.ndarray
        2D array containing band 1 of the image inside the (clipped) window
    src_transform : affine.Affine
        Affine transformation of the full source raster
    window_transform : affine.Affine
        Affine transformation of the clipped window
    tile_extent : tuple
        Bounds of the clipped window as returned by src.window_bounds
    tile_crs : rasterio.crs.CRS
        Coordinate reference system of the source raster
    """
    with rasterio.open(image_path) as src:
        # Clip against this raster's own size. The previous code read the
        # names img_width/img_height, which are not defined in this function
        # and only resolved through leftover notebook globals.
        window = window.intersection(Window(0, 0, src.width, src.height))
        # Read the tile
        tile = src.read(1, window=window)
        src_transform = src.transform
        window_transform = rasterio.windows.transform(window, src_transform)
        tile_extent = src.window_bounds(window)
        tile_crs = src.crs
    
    return tile, src_transform, window_transform, tile_extent, tile_crs


In [None]:
from tqdm import tqdm
from skimage.util import view_as_windows
from skimage.measure import shannon_entropy

def kernel_entropty(kernel_sizes, image_path, window, out_dir = None, crop = False):
    """
    Compute sliding-window Shannon entropy maps of one raster tile.

    One entropy map is produced per kernel size; the maps are stacked along
    the last axis in the order of kernel_sizes.

    Parameters:
    -----------
    kernel_sizes : sequence of int
        Edge lengths (in pixels) of the square kernels
    image_path : str
        Path to the image file
    window : rasterio.windows.Window
        Window object specifying the tile location and size
    out_dir : str, optional
        Directory for the output GeoTIFF. If None, nothing is written and
        the entropy cube is returned instead
    crop : bool
        If True, trim the cube to window.height x window.width before
        writing. The tile is read with that same window, so this normally
        leaves the cube unchanged

    Returns:
    --------
    complexity_map_cube : numpy.ndarray or None
        float32 array of shape (rows, cols, len(kernel_sizes)) when out_dir
        is None; otherwise None (the cube is written to
        "complexity_map_{col_off}_{row_off}.tif" inside out_dir)
    """
    # Imported locally: a later cell runs `import tqdm`, which rebinds the
    # module-level name to the module and would make `tqdm(...)` fail here.
    from tqdm import tqdm

    base_img, src_transform, window_transform, _tile_extent, tile_crs = get_tile_window(image_path, window)
    n_rows, n_cols = base_img.shape
    complexity_map_cube = np.zeros((n_rows, n_cols, len(kernel_sizes)), dtype=np.float32)

    for b, kernel_size in tqdm(enumerate(kernel_sizes), total=len(kernel_sizes)):
        # Pad so every pixel owns a full kernel_size x kernel_size neighbourhood.
        # Without padding view_as_windows yields only (rows - k + 1, cols - k + 1)
        # kernels and indexing over the full tile shape raises IndexError.
        pad_before = kernel_size // 2
        pad_after = kernel_size - 1 - pad_before
        padded_img = np.pad(base_img, ((pad_before, pad_after), (pad_before, pad_after)), mode='reflect')

        # Create a sliding window view of the padded map: shape (rows, cols, k, k)
        kernels = view_as_windows(padded_img, (kernel_size, kernel_size))

        # Calculate Shannon entropy per pixel, writing straight into the float32
        # cube so integer rasters do not truncate the entropy values
        for i in range(n_rows):
            for j in range(n_cols):
                complexity_map_cube[i, j, b] = shannon_entropy(kernels[i, j])

    if out_dir is None:
        return complexity_map_cube

    # Save the complexity map as a GeoTIFF
    output_path = os.path.join(out_dir, f"complexity_map_{window.col_off}_{window.row_off}.tif")
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    if crop:
        # The cube is indexed in tile-local coordinates, so crop from (0, 0);
        # slicing with the window's raster offsets would select the wrong
        # (usually empty) region for every tile except the first.
        complexity_map_cube = complexity_map_cube[:int(window.height), :int(window.width)]
        window_transform = rasterio.windows.transform(window, src_transform)

    with rasterio.open(output_path, 'w', driver='GTiff', height=complexity_map_cube.shape[0],
                       width=complexity_map_cube.shape[1], count=len(kernel_sizes),
                       dtype=complexity_map_cube.dtype, crs=tile_crs,
                       transform=window_transform) as dst:
        # GeoTIFF bands are 1-indexed
        for b in range(len(kernel_sizes)):
            dst.write(complexity_map_cube[:, :, b], b + 1)


# Input raster and the kernel edge lengths (in pixels) used for the entropy maps
base_img_path = '/mnt/nili_e/Software/fractal_complexity/data/Jezero_HiRISE_clipped.tif'
kernel_sizes = [2, 3, 5, 7, 9, 11, 15, 19, 25, 31, 43, 55, 75]  

# Define tile size and overlap
tile_size = 1024
# Overlap is the larger of 10% of the tile size and the largest kernel size
overlap = max(int(tile_size * 0.1), max(kernel_sizes))


In [7]:
from importlib import reload
import tiled_complexity
reload(tiled_complexity)
from tiled_complexity import tiledComplexity

# Run the tiled complexity pipeline through the local `tiled_complexity` module
# (its implementation is not part of this notebook). Alternative datasets are kept
# commented out; only the last uncommented base_img_path/out_dir pair is used.
# base_img_path = '/mnt/nili_e/Software/fractal_complexity/data/Jezero_HiRISE_clipped.tif'
# out_dir = '/mnt/nili_e/Software/fractal_complexity/data/Jezero_HiRISE_complexity_tiles/'

# base_img_path = '/mnt/holuhraun/t5_dsm_2015_50cm_clip.tif'
# out_dir = '/mnt/holuhraun/t5_dsm_50cm_complexity_tiles/'

base_img_path = '/mnt/fagradalsfjall/sites_abc_clip.tif'
out_dir = '/mnt/fagradalsfjall/sites_abc_clip_complexity/'

# Kernel edge lengths (in pixels) passed to tiledComplexity
kernel_sizes = [3, 4, 5, 7, 9, 11, 15, 19, 25, 31, 43, 55, 75]  

# Define tile size and overlap
tile_size = 2048
# Overlap is the larger of 10% of the tile size and the largest kernel size
overlap = max(int(tile_size * 0.1), max(kernel_sizes))

tc_processor = tiledComplexity(base_img_path, tile_size, overlap, kernel_sizes, out_dir=out_dir, crop=True)
# tc_processor.visualize_tiles() 
tc_processor.process()

Image dimensions: 1256x980
Processing 1 tiles in batches of 3
Using 24 CPU workers for processing.
Processing batch 1/1 (1 tiles)


Processing tiles: 100%|██████████| 1/1 [27:09<00:00, 1629.75s/it]


In [6]:
# Merge the per-tile complexity rasters into a single multi-band mosaic.
# NOTE(review): the recorded output of this cell (In [6]) predates the In [7] run
# above, whose paths point at a different dataset; re-check output_path before re-running.
# output_path='/mnt/nili_e/Software/fractal_complexity/data/Jezero_HiRISE_complexity_mosaic/complexity_mosaic.tif'
tc_processor.create_mosaic(output_path='/mnt/holuhraun/t5_dsm_50cm_complexity_tiles/t5_dsm_50cm_complexity_mosaic.tif')

Found 240 tile files to merge
Processing tiles in batches of 198 to fit within 4GB memory limit
Calculating final mosaic bounds...


Sampling bounds:   0%|          | 0/100 [00:00<?, ?it/s]

Sampling bounds: 100%|██████████| 100/100 [00:01<00:00, 90.55it/s]


Processing band 1 of 13 (kernel size: 3)
  Processing batch 1/2 (198 tiles)


Opening tiles: 100%|██████████| 198/198 [05:58<00:00,  1.81s/it]


  Processing batch 2/2 (42 tiles)


Opening tiles: 100%|██████████| 42/42 [01:13<00:00,  1.75s/it]


Processing band 2 of 13 (kernel size: 4)
  Processing batch 1/2 (198 tiles)


Opening tiles: 100%|██████████| 198/198 [05:35<00:00,  1.69s/it]


  Processing batch 2/2 (42 tiles)


Opening tiles: 100%|██████████| 42/42 [01:14<00:00,  1.76s/it]


Processing band 3 of 13 (kernel size: 5)
  Processing batch 1/2 (198 tiles)


Opening tiles: 100%|██████████| 198/198 [05:25<00:00,  1.64s/it]


  Processing batch 2/2 (42 tiles)


Opening tiles: 100%|██████████| 42/42 [01:12<00:00,  1.73s/it]


Processing band 4 of 13 (kernel size: 7)
  Processing batch 1/2 (198 tiles)


Opening tiles: 100%|██████████| 198/198 [05:26<00:00,  1.65s/it]


  Processing batch 2/2 (42 tiles)


Opening tiles: 100%|██████████| 42/42 [01:11<00:00,  1.70s/it]


Processing band 5 of 13 (kernel size: 9)
  Processing batch 1/2 (198 tiles)


Opening tiles: 100%|██████████| 198/198 [05:23<00:00,  1.63s/it]


  Processing batch 2/2 (42 tiles)


Opening tiles: 100%|██████████| 42/42 [01:09<00:00,  1.66s/it]


Processing band 6 of 13 (kernel size: 11)
  Processing batch 1/2 (198 tiles)


Opening tiles: 100%|██████████| 198/198 [05:22<00:00,  1.63s/it]


  Processing batch 2/2 (42 tiles)


Opening tiles: 100%|██████████| 42/42 [01:11<00:00,  1.71s/it]


Processing band 7 of 13 (kernel size: 15)
  Processing batch 1/2 (198 tiles)


Opening tiles: 100%|██████████| 198/198 [05:23<00:00,  1.63s/it]


  Processing batch 2/2 (42 tiles)


Opening tiles: 100%|██████████| 42/42 [01:12<00:00,  1.72s/it]


Processing band 8 of 13 (kernel size: 19)
  Processing batch 1/2 (198 tiles)


Opening tiles: 100%|██████████| 198/198 [05:26<00:00,  1.65s/it]


  Processing batch 2/2 (42 tiles)


Opening tiles: 100%|██████████| 42/42 [01:10<00:00,  1.68s/it]


Processing band 9 of 13 (kernel size: 25)
  Processing batch 1/2 (198 tiles)


Opening tiles: 100%|██████████| 198/198 [05:27<00:00,  1.66s/it]


  Processing batch 2/2 (42 tiles)


Opening tiles: 100%|██████████| 42/42 [01:13<00:00,  1.75s/it]


Processing band 10 of 13 (kernel size: 31)
  Processing batch 1/2 (198 tiles)


Opening tiles: 100%|██████████| 198/198 [05:32<00:00,  1.68s/it]


  Processing batch 2/2 (42 tiles)


Opening tiles: 100%|██████████| 42/42 [01:12<00:00,  1.72s/it]


Processing band 11 of 13 (kernel size: 43)
  Processing batch 1/2 (198 tiles)


Opening tiles: 100%|██████████| 198/198 [05:33<00:00,  1.69s/it]


  Processing batch 2/2 (42 tiles)


Opening tiles: 100%|██████████| 42/42 [01:11<00:00,  1.71s/it]


Processing band 12 of 13 (kernel size: 55)
  Processing batch 1/2 (198 tiles)


Opening tiles: 100%|██████████| 198/198 [05:31<00:00,  1.68s/it]


  Processing batch 2/2 (42 tiles)


Opening tiles: 100%|██████████| 42/42 [01:11<00:00,  1.71s/it]


Processing band 13 of 13 (kernel size: 75)
  Processing batch 1/2 (198 tiles)


Opening tiles: 100%|██████████| 198/198 [05:33<00:00,  1.69s/it]


  Processing batch 2/2 (42 tiles)


Opening tiles: 100%|██████████| 42/42 [01:12<00:00,  1.74s/it]


Mosaic successfully created: /mnt/holuhraun/t5_dsm_50cm_complexity_tiles/t5_dsm_50cm_complexity_mosaic.tif
Dimensions: 22123x24074, 13 bands


'/mnt/holuhraun/t5_dsm_50cm_complexity_tiles/t5_dsm_50cm_complexity_mosaic.tif'

In [None]:

def get_shp_window(shp_path, unit_name, window, window_transform, layer_name = None):
    """
    Rasterize the features of a shapefile onto the grid of a raster window.
    
    Parameters:
    -----------
    shp_path : str
        Path to the shapefile, or to a geodatabase ending in '.gdb'
    unit_name : str
        Name of the column containing facies information. Its values are
        burned into the output raster
        (NOTE(review): the output is int32, so this column presumably holds
        integer codes - confirm against the data)
    window : rasterio.windows.Window
        Window object specifying the tile size (height and width are used)
    window_transform : affine.Affine
        Affine transformation of the window, used to place the geometries
    layer_name : str, optional
        Layer to read; only used when shp_path ends in '.gdb'
        
    Returns:
    --------
    facies_map : numpy.ndarray
        2D int32 array of shape (window.height, window.width); 0 where no
        feature touches the pixel
    """
    # Only geodatabases need an explicit layer. It must be passed by keyword:
    # the second positional argument of fiona.open is the file mode, so the
    # previous positional call handed the layer name to `mode`.
    layer = layer_name if shp_path.endswith('.gdb') else None
    with fiona.open(shp_path, layer=layer) as src:
        shapes = [(feature["geometry"], feature['properties'][unit_name]) for feature in src]
    
    # Rasterize the shapefile for the specified window
    facies_map = rasterize(
        shapes,
        out_shape=(window.height, window.width),
        transform=window_transform,
        all_touched=True,
        fill=0,  # Background value
        dtype=np.int32  # Use np.int32 to handle larger range of values
    )
    
    return facies_map


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial import distance
from scipy.ndimage import distance_transform_edt
import rasterio
from rasterio.features import rasterize
import geopandas as gpd
from shapely.geometry import Point
import math
import tqdm
from skimage.measure import shannon_entropy
from skimage import exposure
import matplotlib.colors as colors
from matplotlib.colors import ListedColormap, BoundaryNorm

def load_base_image(image_path, band=1, img_nan_value = None):
    """
    Read one band of a raster file together with its affine transform.

    Parameters:
    -----------
    image_path : str
        Path to the image file
    band : int
        Band number to read (default is 1 for single band images)
    img_nan_value : scalar, optional
        Pixel value to replace with NaN; when None the band is returned
        exactly as read

    Returns:
    --------
    image : numpy.ndarray
        2D array containing the image data
    transform : affine.Affine
        Affine transformation for the raster
    """
    with rasterio.open(image_path) as dataset:
        transform = dataset.transform
        image = dataset.read(band)

    if img_nan_value is None:
        return image, transform

    # Swap the sentinel value for NaN
    return np.where(image == img_nan_value, np.nan, image), transform

def _values_entropy(values, norm):
    """Return the Shannon entropy of a NaN-free sample array (0.0 if all values are equal)."""
    if values.min() == values.max():
        return 0.0
    if norm:
        # Rescale to 0-255 and quantize to uint8 before measuring entropy
        values = exposure.rescale_intensity(values, out_range=(0, 255)).astype(np.uint8)
    return shannon_entropy(values)


def calculate_facies_entropy(base_image, facies_map, norm = False, window_size = 5):
    """
    Calculate Shannon entropy for each facies in the facies map
    
    Parameters:
    -----------
    base_image : numpy.ndarray
        2D array of the base image (elevation or grayscale data)
    facies_map : numpy.ndarray
        2D array where each unique integer represents a different facies
    norm : bool
        If True, values are rescaled to 0-255 (uint8) before the entropy is
        computed; if False the raw values are used
    window_size : int
        Edge length of the square neighbourhood used for the per-pixel local
        entropy (default 5). The neighbourhood spans window_size // 2 pixels
        on each side of the centre and is clipped at the image border
        
    Returns:
    --------
    facies_entropy : dict
        Dictionary mapping facies IDs to their Shannon entropy values
    entropy_map : numpy.ndarray
        2D array containing, for each pixel, the local entropy computed for
        the facies it belongs to
    facies_entropy_maps : dict
        Dictionary mapping facies IDs to their individual local entropy maps
        (zero outside the facies)
    """
    # Get unique facies values (excluding 0 if it's a no-data value)
    unique_facies = np.unique(facies_map)
    if 0 in unique_facies and len(unique_facies) > 1:
        unique_facies = unique_facies[unique_facies != 0]
    
    facies_entropy = {}
    entropy_map = np.zeros_like(facies_map, dtype=float)
    facies_entropy_maps = {}

    half_window = window_size // 2
    n_rows, n_cols = base_image.shape[0], base_image.shape[1]
    
    for facies_id in unique_facies:
        mask = facies_map == facies_id

        # Entropy of all valid (non-NaN) base image values inside this facies
        facies_values = base_image[mask]
        facies_values = facies_values[~np.isnan(facies_values)]
        if facies_values.size == 0:
            # Nothing but NaN (or no pixels at all): leave this facies out
            continue
        facies_entropy[facies_id] = _values_entropy(facies_values, norm)
        
        # Local entropy around every pixel of this facies
        local_entropy_map = np.zeros_like(base_image, dtype=float)
        for i, j in zip(*np.where(mask)):
            # Neighbourhood clipped at the image border
            window = base_image[max(0, i - half_window):min(n_rows, i + half_window + 1),
                                max(0, j - half_window):min(n_cols, j + half_window + 1)]
            window = window[~np.isnan(window)]
            if window.size == 0:
                continue
            local_entropy_map[i, j] = _values_entropy(window, norm)
        
        facies_entropy_maps[facies_id] = local_entropy_map
        entropy_map[mask] = local_entropy_map[mask]
    
    return facies_entropy, entropy_map, facies_entropy_maps

def calculate_complexity_map(facies_map, facies_entropy, pixel_size=1.0, kernel_radius=10):
    """
    Calculate the complexity (information content) map from a facies map,
    incorporating the entropy within each facies.

    For every pixel, each facies whose boundary lies within kernel_radius gets
    an inverse-distance weight scaled by that facies' entropy. The weights are
    normalized to a probability distribution and its Shannon entropy is the
    pixel's complexity.
    
    Parameters:
    -----------
    facies_map : numpy.ndarray
        2D array where each unique integer represents a different facies
    facies_entropy : dict
        Dictionary mapping facies IDs to their Shannon entropy values
        (facies missing from the dictionary count as entropy 0)
    pixel_size : float
        Size of each pixel in map units (e.g., meters)
    kernel_radius : float
        Radius of the kernel for calculating complexity in map units
        
    Returns:
    --------
    complexity_map : numpy.ndarray
        2D array containing the complexity (in bits) at each location;
        0 where no facies is within kernel_radius or the pixel is no-data
    """
    # Get unique facies values (excluding 0 if it's a no-data value)
    unique_facies = np.unique(facies_map)
    if 0 in unique_facies and len(unique_facies) > 1:
        unique_facies = unique_facies[unique_facies != 0]
    
    complexity_map = np.zeros(facies_map.shape, dtype=float)
    total_weight = np.zeros(facies_map.shape, dtype=float)
    
    # Entropy-scaled inverse-distance weight of every facies at every pixel
    entropy_weights = []
    for facies_id in unique_facies:
        binary_map = (facies_map == facies_id).astype(int)
        
        # Distance to the facies boundary, measured from inside and from outside
        dist_inside = distance_transform_edt(binary_map) * pixel_size
        dist_outside = distance_transform_edt(1 - binary_map) * pixel_size
        boundary_distance = np.abs(dist_outside - dist_inside)
        
        # Inverse distance weighting inside the kernel; the small constant
        # avoids division by zero
        weights = np.where(boundary_distance <= kernel_radius,
                           1.0 / (boundary_distance + 0.1), 0.0)
        
        # The small constant keeps zero-entropy facies from vanishing
        entropy_factor = facies_entropy.get(facies_id, 0) + 0.1
        facies_weight = weights * entropy_factor
        
        entropy_weights.append(facies_weight)
        total_weight += facies_weight
    
    # Pixels that get a value: some facies in range, and not no-data
    has_weight = total_weight > 0
    if 0 not in unique_facies:
        has_weight &= facies_map != 0
    
    # Shannon entropy (in bits) of the normalized weights, one facies at a time
    for facies_weight in entropy_weights:
        probability = np.divide(facies_weight, total_weight,
                                out=np.zeros_like(total_weight), where=has_weight)
        positive = probability > 0  # skip p == 0 terms, log2(0) is undefined
        complexity_map[positive] -= probability[positive] * np.log2(probability[positive])
    
    return complexity_map

def plot_facies_and_complexity(facies_map, complexity_map, facies_entropy_map=None, base_image=None, 
                               transform=None, facies_names=None, title="Facies Complexity Analysis", save_path=None):
    """
    Plot the facies map, intra-facies entropy map, and inter-facies complexity map overlaid on a basemap

    Two panels are drawn (facies, complexity); a third panel with the entropy
    map is added when facies_entropy_map is given.
    
    Parameters:
    -----------
    facies_map : numpy.ndarray
        2D array where each unique integer represents a different facies
    complexity_map : numpy.ndarray
        2D array containing the complexity (in bits) at each location
    facies_entropy_map : numpy.ndarray, optional
        2D array containing the entropy within each facies
    base_image : numpy.ndarray, optional
        Base image (elevation or grayscale data) to use as the background
    transform : affine.Affine, optional
        Affine transformation for the raster (currently not used by the body)
    facies_names : dict, optional
        Dictionary mapping facies IDs to their original names
    title : str
        Title for the plot (currently only referenced by the commented-out suptitle call)
    save_path : str, optional
        If given, the figure is saved there as SVG before it is shown
    """
    if facies_entropy_map is None:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 14))
    else:
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(28, 12))
    
    # Prepare basemap (normalize the base image for display)
    if base_image is not None:
        # Values below -5500 become NaN for display.
        # NOTE(review): hard-coded threshold, presumably a dataset-specific no-data floor - confirm
        base_image = np.where(base_image < -5500, np.nan, base_image)
        # Stretch between the 1st and 99th percentile, then normalize to 0-1 for display
        base_image = exposure.rescale_intensity(base_image, in_range=(np.nanpercentile(base_image, 1), np.nanpercentile(base_image, 99)), out_range=(0, 1))
        base_norm = (base_image - np.nanmin(base_image)) / (np.nanmax(base_image) - np.nanmin(base_image))
    
    # Prepare facies map with custom colormap
    unique_facies = np.unique(facies_map)
    if 0 in unique_facies and len(unique_facies) > 1:
        unique_facies = unique_facies[unique_facies != 0]
    
    # One tab20 color per facies
    n_facies = len(unique_facies)
    colors_list = plt.cm.tab20(np.linspace(0, 1, n_facies))
    facies_cmap = ListedColormap(colors_list)
    
    # Create facies mask for overlay (to show background through transparent areas)
    facies_mask = facies_map > 0  # True where facies exist
    masked_facies = np.ma.masked_where(~facies_mask, facies_map)
    
    # Plot 1: Facies Map overlaid on basemap
    if base_image is not None:
        ax1.imshow(base_norm, cmap='gray')
        facies_img = ax1.imshow(masked_facies, cmap=facies_cmap, interpolation='nearest', alpha=0.55)
    else:
        facies_img = ax1.imshow(facies_map, cmap=facies_cmap, interpolation='nearest')
    
    ax1.set_title("Geological Unit Map", fontsize=36, pad=20)
    ax1.set_xlabel("Easting (pixels)", fontsize=30, labelpad=15)
    ax1.set_ylabel("Northing (pixels)", fontsize=30, labelpad=15)
    ax1.tick_params(axis='both', labelsize=30)
    
    # Discrete colorbar: one bin per facies, ticks centred on the bins 0..n_facies-1
    bounds = np.arange(0, n_facies + 1) - 0.5
    norm = BoundaryNorm(bounds, facies_cmap.N)
    cbar = plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=facies_cmap), ax=ax1, label="Units", ticks=np.arange(n_facies), fraction=0.046, pad=0.04)
    cbar.ax.tick_params(labelsize=20)  # Increase font size for colorbar ticks
    cbar.set_label("Units", fontsize=20)  # Increase font size for the label
    
    if facies_names:
        # Label the ticks with the original names, falling back to the numeric ID
        labels = [facies_names.get(int(facies_id), f"{facies_id}") for facies_id in unique_facies]
        cbar.ax.set_yticklabels(labels, fontsize=20)  # Increase font size for colorbar labels
        # Shrink the tick labels when any name is longer than 10 characters
        if any(len(str(label)) > 10 for label in labels):
            cbar.ax.tick_params(labelsize=14)
    
    # Plot 2: Complexity Map overlaid on basemap
    if base_image is not None:
        ax2.imshow(base_norm, cmap='gray')
        # Hide complexity where the facies map is 0
        masked_complexity = np.ma.masked_where(facies_map == 0, complexity_map)
        complexity_img = ax2.imshow(masked_complexity, cmap='plasma', alpha=0.55)
    else:
        complexity_img = ax2.imshow(complexity_map, cmap='plasma')
    
    ax2.set_title("Inter-Unit Complexity Map", fontsize=36, pad=20)
    ax2.set_xlabel("Easting (pixels)", fontsize=30, labelpad=15)
    ax2.tick_params(axis='both', labelsize=30)
    cbar2 = plt.colorbar(complexity_img, ax=ax2, label="Shannon Entropy (bits)", fraction=0.046, pad=0.04)
    cbar2.ax.tick_params(labelsize=20)  # Increase font size for colorbar ticks
    cbar2.set_label("Shannon Entropy (bits)", fontsize=20)  # Increase font size for the label
    
    # Plot 3: Entropy Map overlaid on basemap (if provided)
    if facies_entropy_map is not None:
        if base_image is not None:
            ax3.imshow(base_norm, cmap='gray')
            # Hide entropy where the facies map is 0
            masked_entropy = np.ma.masked_where(facies_map == 0, facies_entropy_map)
            entropy_img = ax3.imshow(masked_entropy, cmap='plasma', alpha=0.8)
        else:
            entropy_img = ax3.imshow(facies_entropy_map, cmap='plasma')
        
        ax3.set_title("Raster-based Complexity Map", fontsize=36, pad=20)
        ax3.set_xlabel("Easting (pixels)", fontsize=30, labelpad=15)
        ax3.tick_params(axis='both', labelsize=30)
        cbar3 = plt.colorbar(entropy_img, ax=ax3, label="Shannon Entropy (bits)", fraction=0.046, pad=0.04)
        cbar3.ax.tick_params(labelsize=20)  # Increase font size for colorbar ticks
        cbar3.set_label("Shannon Entropy (bits)", fontsize=20)  # Increase font size for the label
    # plt.suptitle(title, fontsize=38)
    plt.tight_layout()
    # The figure is shown in both cases; it is additionally saved as SVG when save_path is set
    if save_path:
        plt.savefig(save_path, format='svg', bbox_inches='tight')
        plt.show()
    else:
        plt.show()

def plot_facies_entropy_summary(facies_map, facies_entropy, facies_names=None, save_path=None):
    """
    Plot a summary of entropy values for each facies

    Each facies is drawn as a bubble whose height is its Shannon entropy and
    whose size is proportional to its pixel count.
    
    Parameters:
    -----------
    facies_map : numpy.ndarray
        2D array where each unique integer represents a different facies
    facies_entropy : dict
        Dictionary mapping facies IDs to their Shannon entropy values
    facies_names : dict, optional
        Dictionary mapping facies IDs to their original names
    save_path : str, optional
        If given, the figure is saved there as SVG before it is shown
    """
    # Get unique facies values (excluding 0 if it's a no-data value)
    unique_facies = np.unique(facies_map)
    if 0 in unique_facies and len(unique_facies) > 1:
        unique_facies = unique_facies[unique_facies != 0]
    
    facies_ids = []
    entropy_values = []
    facies_areas = []
    facies_labels = []
    
    # Only facies that have an entropy value are plotted
    for facies_id in unique_facies:
        if facies_id not in facies_entropy:
            continue
        facies_ids.append(facies_id)
        entropy_values.append(facies_entropy[facies_id])
        facies_areas.append(np.sum(facies_map == facies_id))
        
        # Use original name if available
        if facies_names and facies_id in facies_names:
            facies_labels.append(facies_names[facies_id])
        else:
            facies_labels.append(f"{facies_id}")
    
    # With nothing to plot, max() below would raise ValueError on an empty list
    if not facies_ids:
        print("No facies with entropy values to plot.")
        return
    
    # Normalize areas for bubble size (largest facies -> size 500)
    max_area = max(facies_areas)
    normalized_areas = [500 * (area / max_area) for area in facies_areas]
    
    plt.figure(figsize=(8, 4))
    
    # Create scatter plot with bubble size proportional to facies area
    scatter = plt.scatter(range(len(facies_ids)), entropy_values, s=normalized_areas, 
                         alpha=0.6, c=facies_ids, cmap='tab20')
    
    plt.title("Shannon Entropy by Unit")
    plt.xlabel("Unit")
    plt.ylabel("Shannon Entropy (bits)")
    
    # Set x-ticks to use facies names instead of IDs
    plt.xticks(range(len(facies_ids)), facies_labels, rotation=45, ha='right')
    
    # Adjust layout to accommodate rotated labels
    plt.subplots_adjust(bottom=0.2)
    
    plt.grid(True, linestyle='--', alpha=0.7)
    
    # Add a size legend with reference bubbles at 25/50/75/100% of the largest area
    sizes = [0.25, 0.5, 0.75, 1.0]
    size_labels = []
    for size in sizes:
        area = max_area * size
        # Format the area value based on its magnitude
        if area >= 1e6:
            size_labels.append(f"{area/1e6:.1f}M pixels")
        elif area >= 1e3:
            size_labels.append(f"{area/1e3:.1f}K pixels")
        else:
            size_labels.append(f"{int(area)} pixels")
            
    # Empty scatters act purely as legend handles
    size_handles = [plt.scatter([], [], s=500*size, color='gray', alpha=0.4) for size in sizes]
    legend = plt.legend(size_handles, size_labels, scatterpoints=1, 
                        labelspacing=1.5, title="Unit Area", loc='center left', 
                        bbox_to_anchor=(1.05, 0.5))
    legend.get_frame().set_boxstyle("square", pad=0.5)  # Increase padding to extend the frame
    
    plt.tight_layout()
    # The figure is always shown; it is saved first when save_path is set
    if save_path:
        plt.savefig(save_path, format='svg', bbox_inches='tight')
    plt.show()

def load_facies_map_from_shapefile(shapefile_path, raster_resolution, gdb_layer_name=None, unit_name=None, bbox=None):
    """
    Convert a facies shapefile to a raster map
    
    Parameters:
    -----------
    shapefile_path : str
        Path to the shapefile (or geodatabase) containing facies polygons
    raster_resolution : float
        Resolution of the output raster in map units; must be positive
    gdb_layer_name : str, optional
        Name of the layer to read from a geodatabase
    unit_name : str, optional
        Name of the column containing facies information. When omitted,
        every row becomes its own facies named "Facies <row number>".
    bbox : tuple, optional
        Bounding box for the raster (minx, miny, maxx, maxy). Defaults to
        the total bounds of the loaded geometries.
        
    Returns:
    --------
    facies_map : numpy.ndarray
        2D int32 array where each unique positive integer represents a
        different facies and 0 is the background fill value
    transform : affine.Affine
        Affine transformation for the raster
    facies_names : dict
        Dictionary mapping facies IDs to their original names

    Raises:
    -------
    ValueError
        If raster_resolution is not positive, or if the bounding box is
        smaller than one pixel in either direction.
    """
    if raster_resolution <= 0:
        raise ValueError(f"raster_resolution must be positive, got {raster_resolution}")

    # Load the vector data (a specific layer when reading from a geodatabase)
    if gdb_layer_name:
        gdf = gpd.read_file(shapefile_path, layer=gdb_layer_name)
    else:
        gdf = gpd.read_file(shapefile_path)

    if unit_name is not None:
        # Merge the misspelled "Crater riim unit" into "Crater rim unit", but
        # only when both spellings are present. This check has to live inside
        # the unit_name branch: indexing the frame with unit_name=None fails.
        unit_values = gdf[unit_name].values
        if "Crater riim unit" in unit_values and "Crater rim unit" in unit_values:
            gdf[unit_name] = gdf[unit_name].replace("Crater riim unit", "Crater rim unit")

        # One integer ID per distinct unit name, shifted by one so that 0
        # stays reserved for the raster background.
        # NOTE(review): rows with a missing unit value presumably get code -1
        # and therefore ID 0 (background) - confirm against the data.
        categories = gdf[unit_name].astype('category')
        gdf['facies_id'] = categories.cat.codes + 1
        facies_names = {
            code + 1: original_name
            for code, original_name in enumerate(categories.cat.categories)
        }
    else:
        # No unit column: every row is its own facies, numbered from 1
        row_ids = range(1, len(gdf) + 1)
        gdf['facies_id'] = row_ids
        facies_names = {row_id: f"Facies {row_id}" for row_id in row_ids}
    
    # If no bbox is provided, use the bounds of the shapefile
    if bbox is None:
        bbox = gdf.total_bounds
    minx, miny, maxx, maxy = bbox[0], bbox[1], bbox[2], bbox[3]
    
    # Raster dimensions are truncated to a whole number of pixels
    width = int((maxx - minx) / raster_resolution)
    height = int((maxy - miny) / raster_resolution)
    if width < 1 or height < 1:
        raise ValueError(
            f"Bounding box {tuple(bbox)} is smaller than one pixel at "
            f"resolution {raster_resolution} (computed size {width}x{height})"
        )
    
    # Transform that maps the pixel grid onto the bounding box
    transform = rasterio.transform.from_bounds(minx, miny, maxx, maxy, width, height)
    
    # (geometry, value) pairs are the input format expected by rasterize
    shapes = list(zip(gdf.geometry, gdf['facies_id']))
    
    # Burn the polygons into the grid
    facies_map = rasterize(
        shapes,
        out_shape=(height, width),
        transform=transform,
        all_touched=True,
        fill=0,  # Background value
        dtype=np.int32  # Use np.int32 to handle larger range of values
    )
    
    return facies_map, transform, facies_names

def main(base_image_path, facies_shape_path, output_prefix=None, raster_resolution=10.0,
         kernel_radius=100.0, gdb_layer_name=None, unit_name=None, base_img_nan_value=None):
    """
    Main workflow for facies complexity analysis using base image data
    
    Parameters:
    -----------
    base_image_path : str
        Path to the base image (elevation or grayscale data)
    facies_shape_path : str
        Path to the shapefile containing facies polygons
    output_prefix : str, optional
        Prefix for output files. When given, the entropy and complexity
        maps are written to "<prefix>_entropy.tif" and
        "<prefix>_complexity.tif".
    raster_resolution : float
        Resolution of the raster in map units
    kernel_radius : float
        Radius of the kernel for calculating complexity in map units
    gdb_layer_name : str, optional
        Name of the layer to read from a geodatabase
    unit_name : str, optional
        Name of the column containing facies information
    base_img_nan_value : optional
        Forwarded to load_base_image as img_nan_value

    Returns:
    --------
    facies_map, entropy_map, complexity_map, facies_entropy, facies_names
        The rasterized facies map, the entropy and complexity maps, the
        per-facies entropy mapping and the facies ID -> name mapping.

    Raises:
    -------
    ValueError
        If the base image must be resampled but carries no CRS.
    """
    # Load base image
    print("Loading base image...")
    base_image, base_transform = load_base_image(base_image_path, img_nan_value = base_img_nan_value)

    # Use the base image's own CRS for resampling and for the outputs instead
    # of hard-coded EPSG codes: EPSG:9122 is a unit-of-measure code (degree),
    # not a CRS, and EPSG:4326 is only correct for WGS84 data.
    with rasterio.open(base_image_path) as src:
        base_crs = src.crs
    
    # Load facies map, rasterized over the footprint of the base image.
    # The bbox is (minx, miny, maxx, maxy) derived from the affine terms:
    # [2]/[5] are the origin, [0]/[4] the pixel width/height.
    print("Loading facies map...")
    n_rows, n_cols = base_image.shape[0], base_image.shape[1]
    base_bbox = (
        base_transform[2],
        base_transform[5] + n_rows * base_transform[4],
        base_transform[2] + n_cols * base_transform[0],
        base_transform[5],
    )
    facies_map, facies_transform, facies_names = load_facies_map_from_shapefile(
        facies_shape_path,
        raster_resolution=raster_resolution,
        gdb_layer_name=gdb_layer_name,
        unit_name=unit_name,
        bbox=base_bbox
    )
    
    # Print facies names with their pixel counts and areas
    print("\nFacies Names:")
    pixel_area = raster_resolution ** 2
    for facies_id, name in facies_names.items():
        count = np.sum(facies_map == facies_id)
        area = count * pixel_area
        print(f"  Facies {facies_id}: {name} - {count} pixels ({area:.2f} square units)")
    
    # Ensure base_image and facies_map have the same dimensions
    if base_image.shape != facies_map.shape:
        print(f"Resizing base image from {base_image.shape} to match facies map {facies_map.shape}...")
        if base_crs is None:
            raise ValueError(
                f"Cannot resample {base_image_path}: the image has no CRS"
            )
        # Resample base_image onto the facies grid. Source and destination
        # share one CRS, so only the pixel grid changes.
        from rasterio.warp import reproject, Resampling
        base_image_resampled = np.zeros(facies_map.shape, dtype=base_image.dtype)
        reproject(
            source=base_image,
            destination=base_image_resampled,
            src_transform=base_transform,
            dst_transform=facies_transform,
            src_crs=base_crs,
            dst_crs=base_crs,
            resampling=Resampling.cubic
        )
        base_image = base_image_resampled
    
    # Calculate Shannon entropy for each facies
    print("Calculating intra-facies entropy...")
    facies_entropy, entropy_map, facies_entropy_maps = calculate_facies_entropy(base_image, facies_map, norm=True)
    
    # Calculate complexity map that incorporates facies entropy
    print("Calculating inter-facies complexity...")
    complexity_map = calculate_complexity_map(
        facies_map, 
        facies_entropy, 
        pixel_size=raster_resolution, 
        kernel_radius=kernel_radius
    )
    
    # Plot results with basemap overlay
    # NOTE(review): figure file names are fixed and independent of output_prefix
    plot_facies_and_complexity(facies_map, complexity_map, entropy_map, base_image, 
                              facies_transform, facies_names, title="Geo Unit Complexity Analysis", save_path='jezero_ctx_complexity_fig.svg')
    
    # Plot facies entropy summary with original names
    plot_facies_entropy_summary(facies_map, facies_entropy, facies_names, save_path='jezero_ctx_entropy_summary_fig.svg')
    
    # Save results if output_prefix is provided
    if output_prefix:
        # Both outputs share the facies grid and the base image's CRS
        output_meta = {
            'driver': 'GTiff',
            'height': facies_map.shape[0],
            'width': facies_map.shape[1],
            'count': 1,
            'dtype': 'float32',
            'crs': base_crs,
            'transform': facies_transform
        }
        
        with rasterio.open(f"{output_prefix}_entropy.tif", 'w', **output_meta) as dst:
            dst.write(entropy_map.astype(np.float32), 1)
        
        with rasterio.open(f"{output_prefix}_complexity.tif", 'w', **output_meta) as dst:
            dst.write(complexity_map.astype(np.float32), 1)
        
        print(f"Results saved to {output_prefix}_entropy.tif and {output_prefix}_complexity.tif")
    
    # Print summary statistics with original names
    print("\nFacies Entropy Summary:")
    for facies_id, entropy in facies_entropy.items():
        if facies_id in facies_names:
            name = facies_names[facies_id]
            print(f"Facies {facies_id} ({name}): Shannon entropy = {entropy:.4f}")
        else:
            print(f"Facies {facies_id}: Shannon entropy = {entropy:.4f}")
    
    print("\nComplexity Map Summary:")
    print(f"Mean complexity: {np.mean(complexity_map):.4f}")
    print(f"Max complexity: {np.max(complexity_map):.4f}")
    print(f"Min complexity: {np.min(complexity_map):.4f}")
    
    return facies_map, entropy_map, complexity_map, facies_entropy, facies_names


if __name__ == "__main__":
    # Script entry point: run the full analysis on one HiRISE tile.
    # Point these paths at your own data before running.
    base_img_path = '/mnt/nili_e/Software/fractal_complexity/data/Jezero_HiRISE_clipped_tiles/Jezero_HiRISE_clipped_tile_30500_54500.tif'
    # Alternative base image (HRSC):
    # base_img_path = '/Users/phillipsm/Documents/Research/Proposals/2024/SSW/COMPLEX/sim3464_JezeroNili_FINAL/Rasters/HRSC/nili_h0988_0000_da4_clip.tif'
    gdb_path = '/mnt/nili_e/Software/fractal_complexity/data/sim3464_JezeroNili_FINAL/Geodatabase/Nili.gdb'
    layer_name = 'GeoUnits_v1'

    # Keyword arguments for main(); "name" is the attribute column that
    # holds the unit label of each polygon.
    run_config = {
        "base_image_path": base_img_path,
        "facies_shape_path": gdb_path,
        "raster_resolution": 0.25,
        "kernel_radius": 10.0,
        "gdb_layer_name": layer_name,
        "unit_name": "name",
        # "base_img_nan_value": -32768,
    }

    results = main(**run_config)
    facies_map, entropy_map, complexity_map, facies_entropy, facies_names = results