## Import

In [None]:
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
import importlib as il
from hypso import Hypso1, Hypso2
import src.deh as deh
import copy
import gc
# Re-execute src/deh.py so edits to it are picked up without restarting the kernel.
il.reload(deh)

# Fixed HYPSO frame geometry. Flattened images are reshaped throughout this
# notebook as (-1, HYPSO_HEIGHT, HYPSO_BANDS): HYPSO_HEIGHT pixels along the
# second axis and HYPSO_BANDS spectral bands per pixel.
HYPSO_HEIGHT    =1092
HYPSO_WIDTH     =598  # not referenced in the code shown here
HYPSO_BANDS     =112 

In [None]:
HYPSO_IMAGE_DIR = r"D:\Downloads"
DATA_DIR        = r"D:\Hierarchical Unmixing Label\hUH\data"
# Sub-directories are built with os.path.join instead of "\images"-style
# literals: "\i", "\d" and "\l" are invalid escape sequences (SyntaxWarning on
# Python 3.12+) that only produced a backslash by accident.
IMAGES_DIR      = os.path.join(DATA_DIR, "images")
DEH_DIR         = os.path.join(DATA_DIR, "deh_models")
LABELS_DIR      = os.path.join(DATA_DIR, "labels")

## Functions

In [None]:
#Functions
def plot_rgb_composite(cm, title="", red_band_index=69, green_band_index=46, blue_band_index=26, aspect=0.1, figsize=(10, 10), height=1092, contrast_enhancement=False):
    """
    Create and plot an RGB composite image from a hyperspectral cube.

    Parameters:
    - cm: Input hyperspectral data. Either flattened (n_pixels, n_bands), which is
      reshaped to (n_pixels // height, height, n_bands), or an already 3D cube.
    - title: Optional title for the plot
    - red_band_index: Index of the band to use for the red channel (default: 69)
    - green_band_index: Index of the band to use for the green channel (default: 46)
    - blue_band_index: Index of the band to use for the blue channel (default: 26)
    - aspect: Aspect ratio passed to imshow (default: 0.1)
    - figsize: Figure size as tuple (width, height) in inches (default: (10, 10))
    - height: Length of the second axis when reshaping flattened data (default: 1092)
    - contrast_enhancement: If True, apply a 2nd-98th percentile stretch to each
      channel after min-max normalization (default: False)

    Returns:
    - rgb_image: Floating-point array of shape (rows, cols, 3) with values in
      [0, 1], before the 90-degree rotation applied only for display.

    Raises:
    - ValueError: if cm is neither 2D nor 3D.
    """
    if cm.ndim == 2:  # Flattened data (n_pixels, n_bands)
        n_rows = cm.shape[0] // height
        data_cube = cm.reshape(n_rows, height, cm.shape[1])
    elif cm.ndim == 3:  # Already a cube
        data_cube = cm
    else:
        raise ValueError(f"Unexpected input shape: {cm.shape}. Expected 2D or 3D array.")

    # Stack the selected bands into an RGB image (np.stack copies, so the
    # caller's data is never modified below).
    rgb_image = np.stack(
        (
            data_cube[:, :, red_band_index],
            data_cube[:, :, green_band_index],
            data_cube[:, :, blue_band_index],
        ),
        axis=-1,
    )
    # Integer cubes would truncate the normalized [0, 1] values written back
    # into rgb_image below, so normalize in floating point.
    if not np.issubdtype(rgb_image.dtype, np.floating):
        rgb_image = rgb_image.astype(np.float64)

    # Replace NaN values with 0
    rgb_image[np.isnan(rgb_image)] = 0

    # Min-max normalize each channel independently.
    for i in range(3):
        channel = rgb_image[:, :, i]
        min_val = np.nanmin(channel)
        max_val = np.nanmax(channel)

        if max_val > min_val:  # Constant channels are left as-is (clipped below)
            channel = (channel - min_val) / (max_val - min_val)

            if contrast_enhancement and np.any(channel > 0):
                # Only stretch when there are enough non-zero values to make
                # the percentiles meaningful (threshold is arbitrary).
                non_zero_values = channel[channel > 0]
                if len(non_zero_values) > 10:
                    p_low, p_high = np.nanpercentile(channel, [2, 98])
                    if p_high > p_low:
                        channel = np.clip(channel, p_low, p_high)
                        channel = (channel - p_low) / (p_high - p_low)

            rgb_image[:, :, i] = channel

    rgb_image = np.clip(rgb_image, 0, 1)

    # Display rotated by 90 degrees; the returned image is not rotated.
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    if title:
        ax.set_title(title)
    ax.imshow(np.rot90(rgb_image), aspect=aspect)
    ax.axis('off')  # Hide axes for cleaner visualization

    return rgb_image

def _minmax_normalize(values):
    """
    Linearly rescale an array to [0, 1].

    A constant array (max == min) cannot be rescaled. Instead of the NaNs a
    plain division would produce, it maps to 1 where the value is non-zero
    and 0 elsewhere.
    """
    values = np.asarray(values, dtype=float)
    lo = np.min(values)
    span = np.max(values) - lo
    if span == 0:
        return (values != 0).astype(float)
    return (values - lo) / span


def plot_DEH_overlay(DEH_input, rgb_image, node, opacity=0.5, figsize=(50, 200)):
    """
    Plot a DEH node's map as a red overlay on an RGB image, shown side by side
    with its inverse, and return the (non-inverse) overlay.

    In the "Original" panel, pixels where the node map is non-zero are blended
    towards red, weighted by the normalized map value. In the "Inverse" panel,
    pixels where the map is exactly zero are blended towards red at full opacity.

    Parameters:
    - DEH_input: DEH object; reads nodes[node].map, plot_size and plot_aspect
    - rgb_image: RGB image of shape DEH_input.plot_size + (3,)
    - node: id of the node to visualize
    - opacity: blend weight of the red overlay (0.0 to 1.0)
    - figsize: matplotlib figure size in inches (default keeps the previous (50, 200))

    Returns:
    - overlay_image: the RGB image (normalized to [0, 1]) with the node map overlaid
    """
    deh_figure = DEH_input.nodes[node].map.reshape(DEH_input.plot_size)

    # Normalize both inputs to [0, 1]; constant maps no longer yield NaN.
    deh_figure_normalized = _minmax_normalize(deh_figure)
    rgb_image_normalized = _minmax_normalize(rgb_image)

    red = np.array([1.0, 0.0, 0.0])  # broadcasts against (k, 3) pixel selections

    # Normal overlay: tint pixels where the node is present.
    overlay_image = rgb_image_normalized.copy()
    non_zero_mask = deh_figure != 0
    overlay_image[non_zero_mask] = (
        rgb_image_normalized[non_zero_mask] * (1 - opacity)
        + red * opacity * deh_figure_normalized[non_zero_mask, np.newaxis]
    )

    # Inverse overlay: tint pixels where the node is absent.
    inverse_overlay = rgb_image_normalized.copy()
    inverse_mask = deh_figure == 0
    inverse_overlay[inverse_mask] = (
        rgb_image_normalized[inverse_mask] * (1 - opacity) + red * opacity
    )

    fig, axs = plt.subplots(1, 2, figsize=figsize, gridspec_kw={'wspace': 0.1, 'hspace': 0.1})

    axs[0].imshow(np.rot90(overlay_image), aspect=DEH_input.plot_aspect, vmin=0, vmax=1, interpolation='bicubic')
    axs[0].set_title(f"Node: {node} - Original")
    axs[0].axis('off')

    axs[1].imshow(np.rot90(inverse_overlay), aspect=DEH_input.plot_aspect, vmin=0, vmax=1, interpolation='bicubic')
    axs[1].set_title(f"Node: {node} - Inverse")
    axs[1].axis('off')

    plt.show()
    return overlay_image

def plot_level_and_overlay(save_file_path, cm_input, plot_only_last_level=False, binarize=False, opacity=0.5):
    """
    Load a saved DEH model, run it on an image and plot each node of the chosen
    levels as a red overlay on an RGB composite of that image.

    Parameters:
    - save_file_path: path of the saved model, passed to DEH.load
    - cm_input: flattened image; reshaped here to (-1, HYPSO_HEIGHT, HYPSO_BANDS)
      to derive the model's plot_size
    - plot_only_last_level: if True, plot only the deepest level; otherwise
      levels 1..deepest (the root level 0 is never plotted)
    - binarize: if True, call binarize_lmdas() and lmda_2_map() before plotting
    - opacity: overlay opacity forwarded to plot_DEH_overlay
    """
    DEH_temp = deh.DEH(no_negative_residuals=True)
    DEH_temp.load(save_file_path)
    cube_image = cm_input.reshape(-1, HYPSO_HEIGHT, HYPSO_BANDS)

    DEH_temp.plot_size = (cube_image.shape[0], cube_image.shape[1])
    DEH_temp.verbose = False
    # Run the model on the image; the returned prediction is not needed here
    # (presumably this fills the node maps read by plot_DEH_overlay).
    DEH_temp.simple_predict(cm_input)
    rgb_image = plot_rgb_composite(cm_input)
    nodes = DEH_temp.nodes
    num_levels = max(len(node) for node in nodes)  # node id length == level
    levels_to_plot = [num_levels] if plot_only_last_level else range(1, num_levels + 1)
    if binarize:
        DEH_temp.binarize_lmdas()
        DEH_temp.lmda_2_map()

    for level in levels_to_plot:
        DEH_temp.display_level(level)
        for node in [node for node in nodes if len(node) == level]:
            plot_DEH_overlay(DEH_temp, rgb_image, node=node, opacity=opacity)

def plot_spectra_for_level(save_file_path, cm_input, level_input=""):
    """
    Load a saved DEH model, run it on an image and display the spectra of all
    nodes at one tree level.

    Parameters:
    - save_file_path: path of the saved model, passed to DEH.load
    - cm_input: flattened image; reshaped to (-1, HYPSO_HEIGHT, HYPSO_BANDS)
      to derive the model's plot_size
    - level_input: level to display (node id length; 1 = children of the
      root). Any value that is not an integer in [1, deepest level] falls back
      to the deepest level, including the default "". Python and NumPy
      integers are accepted; bools are not treated as levels.
    """
    DEH_temp = deh.DEH(no_negative_residuals=True)
    DEH_temp.load(save_file_path)
    cube_image = cm_input.reshape(-1, HYPSO_HEIGHT, HYPSO_BANDS)

    DEH_temp.plot_size = (cube_image.shape[0], cube_image.shape[1])
    DEH_temp.simple_predict(cm_input)
    nodes = DEH_temp.nodes
    num_levels = max(len(node) for node in nodes)

    # np.integer is accepted so levels taken from NumPy arrays are not
    # silently replaced by the deepest level.
    is_integer = isinstance(level_input, (int, np.integer)) and not isinstance(level_input, bool)
    if is_integer and 1 <= level_input <= num_levels:
        level = int(level_input)
    else:
        level = num_levels

    level_nodes = [node for node in nodes if len(node) == level]
    print("using level", level, "with", len(level_nodes), "nodes")
    DEH_temp.display_spectra(level_nodes)

def save_binarized_labels_to_npy(deh_path, data, save_path=None):
    """
    Run a saved DEH model on an image, binarize its node maps and save them.

    The output is a pickled dict mapping node id -> flattened binary map. Read
    it back with np.load(path, allow_pickle=True).item(), as plot_level_labels does.

    Args:
        deh_path: path of the saved model, passed to DEH.load.
        data: flattened image used for prediction; reshaped here to
            (-1, HYPSO_HEIGHT, HYPSO_BANDS) to derive the model's plot_size.
        save_path: output path. If None, '<deh_path without extension>_labels.npy'
            is used. '.npy' is appended when missing, as np.save would do.

    Returns:
        str: the path the labels were written to.
    """
    deh_binarized = deh.DEH(no_negative_residuals=True)
    deh_binarized.load(deh_path)
    print(data.shape)
    cube = data.reshape(-1, HYPSO_HEIGHT, HYPSO_BANDS)
    deh_binarized.plot_size = (cube.shape[0], cube.shape[1])
    print(deh_binarized.plot_size)
    deh_binarized.simple_predict(data)
    deh_binarized.binarize_lmdas()
    deh_binarized.lmda_2_map()

    # After lmda_2_map, each node's map holds the binarized values.
    labels_dict = {node: deh_binarized.nodes[node].map.flatten() for node in deh_binarized.nodes}

    if save_path is not None:
        output_path = save_path
    else:
        output_path = os.path.splitext(deh_path)[0] + '_labels.npy'
    if not output_path.endswith('.npy'):
        output_path += '.npy'  # report the name np.save will actually use

    # np.save does not create missing directories (e.g. a fresh LABELS_DIR).
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"Saving binarized labels to: {output_path}")
    np.save(output_path, labels_dict)
    return output_path

def plot_level_labels(labels_filename, level=None):
    """
    Plot the binary label maps of all nodes at one level, highlighting pixels
    claimed by more than one node.

    Args:
        labels_filename: name of a labels .npy file inside LABELS_DIR, as
            written by save_binarized_labels_to_npy. An absolute path also
            works, because os.path.join then ignores LABELS_DIR.
        level: level to plot (node id length). None, or a value above the
            deepest level, selects the deepest level.
    """
    labels_path = os.path.join(LABELS_DIR, labels_filename)
    labels = np.load(labels_path, allow_pickle=True).item()
    if not labels:
        print(f"No labels found in {labels_path}")
        return

    max_level = max(len(k) for k in labels)
    if level is None or level > max_level:
        level = max_level

    level_keys = [k for k in labels if len(k) == level]
    if not level_keys:
        print(f"No nodes found at level {level}")
        return

    fig, ax = plt.subplots(figsize=(10, 10))

    # Count how many nodes claim each pixel; computed once and reused below.
    level_maps = [labels[k].reshape(-1, HYPSO_HEIGHT) for k in level_keys]
    combined_labels = np.zeros(level_maps[0].shape)
    for plot_label in level_maps:
        combined_labels += plot_label
    overlap_mask = combined_labels > 1
    has_overlap = bool(np.any(overlap_mask))

    n_keys = len(level_keys)
    node_colors = [plt.cm.rainbow(i / n_keys) for i in range(n_keys)]
    # Draw layers semi-transparent when they overlap so lower layers stay visible.
    alpha = 0.5 if has_overlap else 1.0
    for i, plot_label in enumerate(level_maps):
        # Colormap from transparent white (0) to this node's color (1).
        node_cmap = plt.matplotlib.colors.LinearSegmentedColormap.from_list(
            f'custom_{i}', [(1, 1, 1, 0), node_colors[i]]
        )
        ax.imshow(np.rot90(plot_label), aspect=0.1, alpha=alpha, cmap=node_cmap)

    if has_overlap:
        ax.imshow(np.rot90(overlap_mask.astype(float)), aspect=0.1, alpha=0.7, cmap='Reds',
                  label='Overlapping Regions')

    ax.set_title(f'Binary Labels for Level {level} Nodes')

    legend_elements = [plt.Rectangle((0, 0), 1, 1, facecolor=color) for color in node_colors]
    legend_labels = list(level_keys)
    if has_overlap:
        legend_elements.append(plt.Rectangle((0, 0), 1, 1, facecolor='red', alpha=0.7))
        legend_labels.append('Overlaps')
    ax.legend(legend_elements, legend_labels, loc='center left', bbox_to_anchor=(1, 0.5))

    plt.tight_layout()
    plt.show()

def print_nodes_by_depth(deh_obj):
    """
    Print a per-depth summary of the node ids held by a DEH object.

    A node's depth is the length of its id string. Depths are listed from
    shallowest to deepest. Each entry shows the node count, up to ten example
    ids, and how many ids were not shown.

    Args:
        deh_obj: object exposing a ``nodes`` mapping keyed by node id strings.
    """
    grouped = {}
    for node_id in list(deh_obj.nodes.keys()):
        grouped.setdefault(len(node_id), []).append(node_id)

    print("Nodes by depth:")
    for depth, members in sorted(grouped.items()):
        total = len(members)
        print(f"Depth {depth}: {total} nodes")
        shown = members[:10]
        if shown:
            print(f"  Examples: {', '.join(shown)}")
        if total > 10:
            print(f"  ... and {total - 10} more")
        print()

def tree_walk(deh_obj, n=10, verbose=False):
    """
    Greedily expand a DEH tree from the root until the frontier holds n nodes.

    At each step, the frontier node with the largest map sum
    (``deh_obj.nodes[id].map.sum()``) that still has children is replaced by
    its existing children. The children of 'x' are 'x0' and 'x1'. Nodes
    without children stay in the frontier instead of being discarded, so a
    heavy leaf is never lost from the result. The walk stops early once every
    frontier node is a leaf.

    Parameters
    ----------
    deh_obj : DEH object
        The hierarchical unmixing object to traverse; must contain the root ''.
    n : int, default=10
        Target number of frontier nodes.
    verbose : bool, default=False
        Print the frontier and compare its total map sum with the root's.

    Returns
    -------
    list of str
        The frontier node ids.
    """
    if verbose:
        print("Tree walking...")

    sums = {}

    def node_sum(node_id):
        # Each node's sum is needed on several iterations; compute it once.
        if node_id not in sums:
            sums[node_id] = deh_obj.nodes[node_id].map.sum()
        return sums[node_id]

    frontier = ['']  # Start with root node
    leaves = set()   # frontier nodes known to have no children in the model

    while len(frontier) < n:
        candidates = [(i, node_id) for i, node_id in enumerate(frontier) if node_id not in leaves]
        if not candidates:
            break  # nothing left to expand
        # max() returns the first of equal keys, like a strict '>' scan.
        max_idx, max_node = max(candidates, key=lambda c: node_sum(c[1]))

        children = [c for c in (max_node + '0', max_node + '1') if c in deh_obj.nodes]
        if not children:
            # Keep the leaf in the result rather than dropping it.
            leaves.add(max_node)
            continue

        frontier.pop(max_idx)
        frontier.extend(children)

    frontier_sum = sum(node_sum(node_id) for node_id in frontier)
    root_sum = node_sum('')

    if verbose:
        print(f"Frontier nodes when threshold of {n} nodes reached:")
        for node_id in frontier:
            level = len(node_id)  # Level is determined by the length of the node_id
            print(f"Node: {node_id}, Level: {level}, Sum: {node_sum(node_id):.4f}")
        print(f"Total sum of frontier nodes: {frontier_sum:.4f}")
        print(f"Root node sum: {root_sum:.4f}")
        print(f"Difference: {root_sum - frontier_sum:.4f}")

    return frontier

def trim_tree_to_frontier_nodes(deh_obj, node_list):
    """
    Prune a DEH object in place, keeping only the given nodes and their
    ancestors after padding all of them to the same depth.

    Shallower ids are padded on the right with '0' up to the deepest id in
    node_list; e.g. with max depth 3, '1' becomes '100'. Every node that is
    neither a padded id nor an ancestor of one is deleted via
    deh_obj.delete_node.

    Parameters
    -----------
    deh_obj : DEH object
        The DEH object to prune (modified in place).
    node_list : list of str
        Node ids to keep (e.g. the frontier returned by tree_walk).

    Returns
    --------
    deh_obj : DEH object
        The same, now pruned, object (returned for convenience).

    Raises
    ------
    ValueError
        If node_list is empty.
    """
    if len(node_list) == 0:
        raise ValueError("node_list must contain at least one node id")

    max_depth = max(len(node_id) for node_id in node_list)

    # Pad shallower ids with '0's so every kept branch reaches max_depth.
    extended_node_list = [node_id.ljust(max_depth, '0') for node_id in node_list]

    # Keep each padded id and all of its prefixes (ancestors), including the root ''.
    nodes_to_keep = {''}
    for node_id in extended_node_list:
        nodes_to_keep.update(node_id[:i] for i in range(len(node_id) + 1))

    # Collect deletion roots. A node whose ancestor is already queued is skipped,
    # since deleting the ancestor presumably removes its descendants too.
    # The set gives O(1) ancestor lookups; the list preserves the order.
    queued = set()
    nodes_to_delete = []
    for node_id in list(deh_obj.nodes.keys()):
        if node_id in nodes_to_keep:
            continue
        if any(node_id[:i] in queued for i in range(1, len(node_id))):
            continue
        queued.add(node_id)
        nodes_to_delete.append(node_id)

    # Delete deepest nodes first
    nodes_to_delete.sort(key=len, reverse=True)

    deleted_count = 0
    for node_id in nodes_to_delete:
        try:
            if node_id in deh_obj.nodes:  # may already be gone with an ancestor
                deh_obj.delete_node(node_id)
                deleted_count += 1
        except KeyError as e:
            print(f"Note: Node {node_id} was already deleted or caused error: {e}")
            continue

    print(f"Pruned {deleted_count} nodes and its decendants from the DEH object")
    print(f"Extended node list to depth {max_depth}: {extended_node_list}")
    return deh_obj
    print(f"Extended node list to depth {max_depth}: {extended_node_list}")

def train_random_DEH(DEH_path, image_data, endmembers, verbose=False):
    """
    Build a randomly initialised DEH model for an image and save it.

    Parameters:
    - DEH_path: base save path. '_random' is inserted before its extension,
      e.g. 'model.h5' -> 'model_random.h5' and 'model' -> 'model_random'.
    - image_data: flattened image; reshaped to (-1, HYPSO_HEIGHT, HYPSO_BANDS)
      to derive the plot size.
    - endmembers: forwarded to DEH.random_init (256 in the TRAINING cell).
    - verbose: print progress messages.

    Returns:
    - str: the path the model was saved to.
    """
    print("Training Random DEH...")
    DEH_random = deh.DEH(no_negative_residuals=True)
    # Init DEH
    DEH_random.splitting_size = 1000
    DEH_random.max_depth = 2
    DEH_random.max_iter = 5
    DEH_random.max_nodes = 2

    cube = image_data.reshape(-1, HYPSO_HEIGHT, HYPSO_BANDS)
    # these set the size of the image to be plotted
    DEH_random.plot_size = (cube.shape[0], cube.shape[1])
    DEH_random.plot_aspect = 0.1

    # these set the normalization on the data, but right now it is turned off
    # and replaced with a nearest-neighbor noise estimate
    DEH_random.weight_power = 0
    DEH_random.eps = 0.00

    # this switches between Archetype Analysis and Pure Pixel Analysis
    DEH_random.aa = False

    # this is currently replaced by "mpp_tol" in the training code
    DEH_random.mixed_pix = 0

    # this is the prefactor on the L2 regularization
    DEH_random.reg = 0

    # if this is anything other than zero, it introduces L2 regularization on the trained weights
    DEH_random.set_mu(0)

    # protection to keep endmembers from vanishing; generally unneeded with new training
    DEH_random.use_bonus_boost = False

    # turns on normalization of the data; generally always used
    DEH_random.use_norm(True)

    # if greater than 1, turns on spectral sampling with PAA; 1 turns it off
    DEH_random.PAA_backcount = 1

    # gradient descent learning rate as a proportion of the optimal rate
    DEH_random.a_speed = 1.0

    if verbose:
        print("cm_input.shape: ", image_data.shape)
        print("cube.shape: ", cube.shape)
        print("DEH_random.plot_size", DEH_random.plot_size)
        print("Quick nn")
    DEH_random.neighbors = deh.quick_nn(image_data.reshape(DEH_random.plot_size + (-1,)), k_size=1).flatten()
    if verbose:
        print("Setting neighbor weights")
    DEH_random.set_neighbor_weights_memory_efficient(image_data)
    if verbose:
        print("Random init")
    DEH_random.random_init(image_data, endmembers)
    if verbose:
        print("Simple predict")
    DEH_random.simple_predict(image_data)  # prediction result itself is not used

    # Insert '_random' before the extension (any extension, not only '.h5').
    root, ext = os.path.splitext(DEH_path)
    DEH_path = f"{root}_random{ext}"

    DEH_random.save(DEH_path)

    if verbose:
        print("Training Random DEH complete, saved to ", DEH_path)

    return DEH_path

def trim_DEH(DEH_path, image_data, endmembers, verbose=False):
    """
    Load a DEH model, keep only `endmembers` frontier nodes chosen by
    tree_walk (plus their ancestors), and save the result.

    The output name is derived from the file name only; the directory is
    never altered:
    - an underscore-separated part starting with '<digits>end' (e.g. '256end')
      becomes '<endmembers>end';
    - 'random' (any case) becomes 'trimmed'.
    If the name would be unchanged, '_trimmed' is inserted before the
    extension so the input model is never overwritten.

    Returns:
    - str: the path the trimmed model was saved to.
    """
    import re

    print("Pruning DEH...")
    DEH_prune = deh.DEH(no_negative_residuals=True)
    DEH_prune.load(DEH_path)
    DEH_prune.verbose = verbose
    DEH_prune.simple_predict(image_data)  # prediction result itself is not used
    frontier_nodes = tree_walk(DEH_prune, n=endmembers, verbose=verbose)
    trim_tree_to_frontier_nodes(DEH_prune, frontier_nodes)
    if verbose:
        print("\nNodes by depth:")
        print_nodes_by_depth(DEH_prune)

    # Rename only the file name. Splitting the whole path on '_' would also
    # split directories such as 'deh_models' and could rewrite them.
    directory, filename = os.path.split(DEH_path)
    parts = filename.split('_')
    for i, part in enumerate(parts):
        part = re.sub(r'^\d+end', f'{endmembers}end', part)
        parts[i] = re.sub('random', 'trimmed', part, flags=re.IGNORECASE)
    new_filename = '_'.join(parts)
    if new_filename == filename:
        root, ext = os.path.splitext(filename)
        new_filename = f"{root}_trimmed{ext}"

    DEH_path = os.path.join(directory, new_filename)
    DEH_prune.save(DEH_path)

    if verbose:
        print("Pruning DEH complete, saved to ", DEH_path)

    return DEH_path

def stabelize_DEH(DEH_path, image_data, saturated_image_data, n_runs=10, step_delta=0.05, verbose=False):
    """
    Load a DEH model and run accepted_network_stablization on an image.

    Parameters:
    - DEH_path: path of the model to load (typically the output of trim_DEH).
    - image_data: flattened image; reshaped to (-1, HYPSO_HEIGHT, HYPSO_BANDS)
      to derive the model's plot_size.
    - saturated_image_data: index into full_weights. The selected entries are
      set to 0 (presumably to exclude saturated pixels).
    - n_runs, step_delta: forwarded to accepted_network_stablization.
    - verbose: print progress messages and set the model's verbose flag.

    The name passed to accepted_network_stablization is the input file name
    with 'random' or 'trimmed' (any case) replaced by 'stabelized'. If neither
    occurs, '_stabelized' is inserted before the extension, so the name never
    equals the input path. The directory is left unchanged.

    Returns:
    - str: the name passed to accepted_network_stablization.
    """
    import re

    print("Stabelizing DEH....")
    DEH_stabelize = deh.DEH(no_negative_residuals=True)
    DEH_stabelize.load(DEH_path)
    DEH_stabelize.verbose = verbose
    # plot_size is used below to reshape image_data, so derive it from
    # image_data (as the other helpers do) rather than relying on the saved value.
    cube = image_data.reshape(-1, HYPSO_HEIGHT, HYPSO_BANDS)
    DEH_stabelize.plot_size = (cube.shape[0], cube.shape[1])
    if verbose:
        print("Simple predict")
    DEH_stabelize.simple_predict(image_data)

    if verbose:
        print("Quick nn")
    DEH_stabelize.neighbors = deh.quick_nn(image_data.reshape(DEH_stabelize.plot_size + (-1,)), k_size=1).flatten()

    if verbose:
        print("Setting neighbor weights")
    DEH_stabelize.set_neighbor_weights_memory_efficient(image_data)
    if verbose:
        print("DEH_stabelized.full_weights.shape", DEH_stabelize.full_weights.shape)

    # NOTE(review): a boolean mask zeroes the masked entries. A 0/1 integer
    # array would instead index rows 0 and 1 only — confirm the dtype of the
    # saturation data.
    DEH_stabelize.full_weights[saturated_image_data] = 0
    DEH_stabelize.PAA_backcount = 1

    # Derive the output name from the file name only, never from directories.
    directory, filename = os.path.split(DEH_path)
    new_filename = re.sub('random|trimmed', 'stabelized', filename, flags=re.IGNORECASE)
    if new_filename == filename:
        root, ext = os.path.splitext(filename)
        new_filename = f"{root}_stabelized{ext}"
    save_name = os.path.join(directory, new_filename)

    if verbose:
        print(f"Updated save path for stabilization: {save_name}")
        print("Accepted network stablization")
    DEH_stabelize.accepted_network_stablization(image_data, n_runs=n_runs, n_pts=(1000, 10000), obj_record=(), sampling_points=(), mpp_tol=0.2, step_delta=step_delta, reg_max=0.2, name=save_name)
    return save_name
    

## Load data

### L1D + MACHI DATA

In [None]:
# Combined 10-image training set and its saturated-pixel data. Paths are built
# with os.path.join instead of f'{IMAGES_DIR}//...', which doubled the separator.
combined_10_L1D_112_MACHI = np.load(os.path.join(IMAGES_DIR, 'combined_10_L1D_112_MACHI.npy'))
combined_10_L1D_112_MACHI_saturated = np.load(os.path.join(IMAGES_DIR, 'combined_10_saturated.npy'))

In [None]:
# Individual flattened L1D scenes. Paths are built with os.path.join instead
# of f'{IMAGES_DIR}//...', which doubled the separator.
yucatan2_2025_02_06_L1D          = np.load(os.path.join(IMAGES_DIR, 'yucatan2_2025-02-06T16-01-18Z-l1a_flat_L1D_112_MACHI.npy'))
kemigawa_2024_12_17_L1D          = np.load(os.path.join(IMAGES_DIR, 'kemigawa_2024-12-17T01-01-32Z-l1a_flat_L1D_112_MACHI.npy'))
chapala_2025_02_24_L1D           = np.load(os.path.join(IMAGES_DIR, 'chapala_2025-02-24T16-52-47Z-l1a_flat_L1D_112_MACHI.npy'))
grizzlybay_2025_01_27_L1D        = np.load(os.path.join(IMAGES_DIR, 'grizzlybay_2025-01-27T18-19-56Z-l1a_flat_L1D_112_MACHI.npy'))
victoriaLand_2025_02_07_L1D      = np.load(os.path.join(IMAGES_DIR, 'victoriaLand_2025-02-07T20-35-33Z-l1a_flat_L1D_112_MACHI.npy'))
catala_2025_01_28_L1D            = np.load(os.path.join(IMAGES_DIR, 'catala_2025-01-28T19-17-32Z-l1a_flat_L1D_112_MACHI.npy'))
khnifiss_2025_02_12_L1D          = np.load(os.path.join(IMAGES_DIR, 'khnifiss_2025-02-12T11-05-35Z-l1a_flat_L1D_112_MACHI.npy'))
menindee_2025_02_18_L1D          = np.load(os.path.join(IMAGES_DIR, 'menindee_2025-02-18T00-10-42Z-l1a_flat_L1D_112_MACHI.npy'))
tampa_2024_11_12_L1D             = np.load(os.path.join(IMAGES_DIR, 'tampa_2024-11-12T15-31-55Z-l1a_flat_L1D_112_MACHI.npy'))
falklandsatlantic_2024_12_18_L1D = np.load(os.path.join(IMAGES_DIR, 'falklandsatlantic_2024-12-18T13-25-18Z-l1a_flat_L1D_112_MACHI.npy'))

In [None]:
# Short aliases for the training data: the saturated-pixel data, the
# flattened pixels and their (-1, HYPSO_HEIGHT, HYPSO_BANDS) cube view.
saturated = combined_10_L1D_112_MACHI_saturated
flat = combined_10_L1D_112_MACHI
cube = np.reshape(flat, (-1, HYPSO_HEIGHT, HYPSO_BANDS))

## TRAINING

In [None]:
# Base model name; the pipeline derives the _random / _trimmed / _stabelized
# variants from it.
DEH_name = "1_10img_256end_L1D_112_MACHI.h5"
DEH_path = os.path.join(DEH_DIR, DEH_name)

In [None]:
DEH_random_path = train_random_DEH(DEH_path, flat, endmembers=256, verbose=True)  # random init, 256 endmembers

In [None]:
DEH_trimmed_path = trim_DEH(DEH_random_path, flat, endmembers=8, verbose=True)  # keep 8 frontier nodes

In [None]:
stabelize_DEH(DEH_trimmed_path, flat, saturated_image_data=saturated, verbose=True)

## GENERATE LABELS

In [None]:
# NOTE(review): hard-coded file name. stabelize_DEH above passes a name derived
# from DEH_name ('1_10img_8end_..._stabelized.h5'); confirm this is the file
# that accepted_network_stablization actually wrote.
stabelized = '10img_8end_L1D_112_MACHI_stabelized_aa.h5'
stabelized_path = os.path.join(DEH_DIR, stabelized)
print(f"Loading stabelized DEH from: {stabelized_path}")

In [None]:
image_name = 'caspiansea1_2025-04-08T07-11-56Z-l1a_flat_L1D_112_MACHI.npy'
image_path = os.path.join(IMAGES_DIR, image_name)
corrected_image = np.load(image_path)
# splitext strips only the final extension; split('.')[0] would truncate
# names that contain other dots.
save_path = os.path.join(LABELS_DIR, os.path.splitext(os.path.basename(image_name))[0] + '_labels.npy')

In [None]:
save_binarized_labels_to_npy(stabelized_path, corrected_image, save_path=save_path)

## Plotting

In [None]:
plot_level_and_overlay(stabelized_path, tampa_2024_11_12_L1D, plot_only_last_level=True, binarize=True, opacity=0.5)

In [None]:
plot_spectra_for_level(save_file_path=stabelized_path, cm_input=tampa_2024_11_12_L1D, level_input=1)

In [None]:
plot_level_labels(labels_filename="caspiansea1_2025-04-08T07-11-56Z-l1a_flat_L1D_112_MACHI_labels.npy", level=3)

## DEBUG

In [None]:
# DEBUG: step-by-step version of train_random_DEH on the combined data,
# followed by a tree_walk that picks 8 frontier nodes for the later cells.
save_name = "10img_256end_L1D_112_MACHI_random.h5"
verbose=True
#Init DEH
DEH= deh.DEH(no_negative_residuals=True)
DEH.splitting_size=1000
DEH.max_depth=2
DEH.max_iter=5
DEH.max_nodes=2

#these set the size of the image to be plotted
DEH.plot_size = (cube.shape[0],cube.shape[1])
DEH.plot_aspect = 0.1

# these set the normalization on the data, but right now it is turned off and replaced with a nearest-neighbor noise estimate
DEH.weight_power=0
DEH.eps = 0.00
#dehx.wf = lambda x: (np.sum(x**2, axis=-1))**(-1/2+1/(2+dehx.eps))

#this switches between Archetype Analysis and Pure Pixel Analysis
DEH.aa = False

#this is currently replaced by "mpp_tol" in the training code
DEH.mixed_pix = 0

# this is the prefactor on the L2 regularization
DEH.reg=0

# if this is anything other than zero, it introduces L2 regularization on the trained weights 
DEH.set_mu(0)

# this turns on a sort of protection to keep endmembers from vanishing. Generally unneeded with new training
DEH.use_bonus_boost = False

# turns on normalization of the data. Generally always used
DEH.use_norm(True)

# if greater than 1, turns on a sort of spectral sampling with PAA. Occasionally gives good performance
# but I generally keep it at 1 to turn it off
DEH.PAA_backcount = 1

# changes the gradient descent learning rate as a proportion of the optimal rate. I recommend keeping it at 1
DEH.a_speed= 1.0

#Verbose print
if verbose:
    print("cm_input.shape: ", flat.shape)
    print("cube.shape: ", cube.shape)
    print("DEH.plot_size", DEH.plot_size)
    print("saturated_input.shape", saturated.shape)

#random Init
if verbose:
    print("Quick nn")
DEH.neighbors = deh.quick_nn(flat.reshape(DEH.plot_size + (-1,)), k_size=1).flatten()

if verbose: 
    print("Setting neighbor weights")
DEH.set_neighbor_weights_memory_efficient(flat)
if verbose:
    print("Random init")
# Unlike train_random_DEH, a random seed in [0, 1000) is passed explicitly.
DEH.random_init(flat, 256, seed=np.random.randint(0, 1000))
if verbose:
    print("Simple predict")
DEH_pred=DEH.simple_predict(flat)  # DEH_pred is not used by later cells shown here

frontier_nodes = tree_walk(DEH, n=8, verbose=True)
print(f"Frontier nodes: {frontier_nodes}")

In [None]:
# Save the untrimmed, randomly initialised model.
random_model_path = os.path.join(DEH_DIR, save_name)
DEH.save(random_model_path)

In [None]:
# Test the pruning function on DEH_LOAD with the frontier nodes
trim_tree_to_frontier_nodes(DEH, frontier_nodes)

In [None]:
# Save the pruned model under its own name. Reusing save_name here would
# overwrite the untrimmed '_random' model saved two cells earlier.
DEH.save(os.path.join(DEH_DIR, save_name.replace('_random', '_trimmed')))

In [None]:
#Stabelized
print("Quick nn")
DEH.neighbors = deh.quick_nn(flat.reshape(DEH.plot_size + (-1,)), k_size=1).flatten()

print("Setting neighbor weights")
DEH.set_neighbor_weights_memory_efficient(flat)

#set saturated pixels to 0
DEH.full_weights[saturated]=0
DEH.PAA_backcount=1

print("Accepted network stablization")
#accepted_network_stablization saves the model to the save_name
DEH.accepted_network_stablization(flat, n_runs=10, n_pts=(1000,10000), obj_record=(), sampling_points=(), mpp_tol=0.2, step_delta=0.05, reg_max=0.2, name=save_name)

In [None]:
# Print nodes by depth for the original and pruned DEH objects
print_nodes_by_depth(DEH)