## CONFIG

In [None]:
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
import importlib as il
from hypso import Hypso1, Hypso2
from sklearn.svm import LinearSVC
import src.deh as deh
import json

In [None]:
# HYPSO image geometry: height/width in pixels and number of spectral bands.
# NOTE(review): HYPSO_HEIGHT is used as the inner (second) reshape axis
# elsewhere in this notebook — confirm which sensor axis it corresponds to.
HYPSO_HEIGHT, HYPSO_WIDTH, HYPSO_BANDS = 1092, 598, 120

In [None]:
# Raw NetCDF capture folders, one per satellite generation.
HYPSO_1_NC_DIR = r"D:\HYPSO_1_NC"
HYPSO_2_NC_DIR = r"D:\HYPSO_2_NC"

# Project data root; each working sub-folder is joined with a Windows separator.
DATA_DIR = r"D:\Hierarchical Unmixing Label\hUH\data"
IMAGES_DIR, DEH_DIR, LABELS_DIR, TREE_DIR, PRED_DIR = (
    DATA_DIR + "\\" + sub_folder
    for sub_folder in ("images", "deh_models", "labels", "tree_models", "predictions")
)

# Location of the ground-training code inside the WSL filesystem.
SVM_DIR = r'\\wsl.localhost\Ubuntu-24.04\home\lofty\CODE\onboard-pipeline-modules\ground_training'

## CLASS AND FUNCTIONS

In [None]:
class BinaryDecisionTree:
    """Hierarchical pixel classifier driven by a binary label tree.

    Node keys are binary strings: ``''`` is the root and ``key + '0'`` /
    ``key + '1'`` are the children of ``key``. A key whose '0' and '1'
    children both exist is a *splitting node* and gets one classifier that
    routes pixels left (prediction 0) or right (prediction 1). Keys with the
    maximum key length are the *endmembers* that ``predict`` returns masks for.

    Attributes
    ----------
    verbose : bool
        Print extra diagnostics when True.
    all_labels : dict
        Node key -> per-pixel label array. Set by ``initialize_tree_structure``
        or restored by ``load_model_parameters``.
    endmembers : list of str
        Keys with the maximum key length.
    splitting_nodes : list of str
        Keys whose '0' and '1' children are both present.
    models : dict
        Splitting-node key -> fitted classifier exposing ``predict``.
    last_prediction : dict
        Only exists after ``predict`` has run; holds its most recent result.
    """

    def __init__(self, verbose=False):
        self.verbose = verbose
        self.all_labels = {}
        self.endmembers = []
        self.splitting_nodes = []
        self.models = {}  # Splitting-node key -> trained SVM

    def initialize_tree_structure(self, image_gt):
        """Derive endmembers and splitting nodes from a label dictionary.

        Parameters
        ----------
        image_gt : dict
            Node key (binary string) -> per-pixel label array. The dictionary
            is also kept as ``self.all_labels`` so that ``train`` and
            ``evaluate`` can fall back to it when called without labels.
        """
        print("Initializing Binary Decision Tree structure...")
        # train()/evaluate() document a fallback to "the labels provided during
        # initialization"; previously those labels were never recorded.
        self.all_labels = image_gt
        all_keys = image_gt.keys()
        self.identify_set_endmembers(all_keys)
        self.identify_set_splitting_nodes(all_keys)

    def identify_set_endmembers(self, all_keys):
        """Set ``self.endmembers`` to every key of maximal length in ``all_keys``."""
        # default=0 lets an empty key set yield no endmembers instead of raising.
        max_length = max((len(key) for key in all_keys), default=0)
        self.endmembers = []

        if self.verbose:
            print("\n" + "="*50)
            print("IDENTIFYING ENDMEMBERS:")
            print("="*50)
            print(f"Maximum key length found: {max_length}")
            print("-"*50)

        for key in all_keys:
            if len(key) == max_length:
                self.endmembers.append(key)
                if self.verbose:
                    print(f"Found endmember: '{key}'")

        if self.verbose:
            print(f"Total endmembers identified: {len(self.endmembers)}")

    def identify_set_splitting_nodes(self, all_keys):
        """Set ``self.splitting_nodes`` to keys whose '0' and '1' children both exist."""
        self.splitting_nodes = []

        if self.verbose:
            print("\n" + "="*50)
            print("IDENTIFYING SPLITTING NODES:")
            print("="*50)

        for key in all_keys:
            if key + '0' in all_keys and key + '1' in all_keys:
                self.splitting_nodes.append(key)
                if self.verbose:
                    print(f"Found splitting node: '{key}' → branches to '{key}0' and '{key}1'")

        if self.verbose:
            print("-"*50)
            print(f"Total splitting nodes identified: {len(self.splitting_nodes)}")

    def _preprocess_input(self, X):
        """
        Bring input data into ``(n_samples, n_bands)`` layout.

        Parameters
        ----------
        X : numpy.ndarray
            Hyperspectral image data. Can be:
            - Flattened array of shape (n_samples, n_bands)
            - Cube of shape (height, width, n_bands) or (n_bands, height, width)

        Returns
        -------
        X_processed : numpy.ndarray
            Processed data of shape (n_samples, n_bands)

        Raises
        ------
        ValueError
            If ``X`` is neither 2-D nor 3-D.

        Notes
        -----
        For 3-D input the spectral axis is guessed as the *smallest* dimension,
        so cubes with more bands than rows or columns are misinterpreted.
        """
        if self.verbose:
            print(f"\nPreprocessing input data with shape: {X.shape}")

        # Already flat: samples on the first axis.
        if X.ndim == 2:
            if self.verbose:
                print("-"*50)
            return X

        if X.ndim == 3:
            spectral_dim = np.argmin(np.array(X.shape))

            if self.verbose:
                print(f"Detected spectral dimension: {spectral_dim}")

            if spectral_dim == 0:  # (n_bands, height, width)
                n_bands, height, width = X.shape
                if self.verbose:
                    print(f"Reshaping from (n_bands={n_bands}, height={height}, width={width}) to ({height*width}, {n_bands})")
                    print("-"*50)
                return X.reshape(n_bands, -1).T

            if spectral_dim == 2:  # (height, width, n_bands)
                height, width, n_bands = X.shape
                if self.verbose:
                    print(f"Reshaping from (height={height}, width={width}, n_bands={n_bands}) to ({height*width}, {n_bands})")
                    print("-"*50)
                return X.reshape(-1, n_bands)

            # (height, n_bands, width): move bands last, then flatten.
            height, n_bands, width = X.shape
            if self.verbose:
                print(f"Unusual format detected: (height={height}, n_bands={n_bands}, width={width})")
                print(f"Reshaping to ({height*width}, {n_bands})")
                print("-"*50)
            return X.transpose(0, 2, 1).reshape(-1, n_bands)

        if self.verbose:
            print(f"Error: Unsupported input shape: {X.shape}")
        raise ValueError(f"Unsupported input shape: {X.shape}. Expected 2D or 3D array.")

    def train(self, X, labels=None):
        """
        Train one LinearSVC per splitting node.

        Parameters
        ----------
        X : numpy.ndarray
            Hyperspectral image data. Can be:
            - Flattened array of shape (n_samples, n_bands)
            - Cube of shape (height, width, n_bands) or (n_bands, height, width)
        labels : dict, optional
            Node key -> per-pixel label array (1 marks membership). If None,
            ``self.all_labels`` is used (set by ``initialize_tree_structure``
            or ``load_model_parameters``).

        Notes
        -----
        For each splitting node only pixels labelled 1 at that node are used;
        the target is 1 where the right child ('1') is labelled 1, else 0.
        scikit-learn's LinearSVC raises ValueError when a node's pixels contain
        only one class, so every splitting node needs examples of both children.
        """
        # Redundant with the module-level import; keeps the method self-contained.
        from sklearn.svm import LinearSVC

        X = self._preprocess_input(X)
        training_labels = labels if labels is not None else self.all_labels

        print("Training SVMs for each splitting node...")

        for node in self.splitting_nodes:
            print("Splitting node:", node)
            # Restrict training to pixels that belong to this node.
            mask = training_labels[node] == 1
            X_node = X[mask]

            # Binary target: 1 for the right child, 0 for the left child.
            y_train = np.zeros(X_node.shape[0], dtype=int)
            right_child = node + '1'
            if right_child in training_labels:
                y_train[training_labels[right_child][mask] == 1] = 1

            model = LinearSVC(dual='auto', random_state=42)
            model.fit(X_node, y_train)
            self.models[node] = model

            print(f"Trained model for node '{node}'")

    def predict(self, X):
        """
        Predict endmember membership masks for input data.

        Parameters
        ----------
        X : numpy.ndarray
            Hyperspectral image data. Can be:
            - Flattened array of shape (n_samples, n_bands)
            - Cube of shape (height, width, n_bands) or (n_bands, height, width)

        Returns
        -------
        predictions : dict
            Endmember key -> boolean mask of length n_samples. The same dict is
            stored as ``self.last_prediction`` for ``plot_prediction``.
        """
        X = self._preprocess_input(X)
        n_samples = X.shape[0]

        # Membership mask per node; the root owns every pixel.
        current_predictions = {'': np.ones(n_samples, dtype=bool)}

        # Route pixels depth by depth so a parent's mask is final before its
        # children are processed. default=-1: no splitting nodes -> no levels.
        max_depth = max((len(node) for node in self.splitting_nodes), default=-1)
        for level in range(max_depth + 1):
            level_nodes = [node for node in self.splitting_nodes if len(node) == level]

            for node in level_nodes:
                node_mask = current_predictions.get(node)
                # Skip nodes never reached or that received no pixels.
                if node_mask is None or not np.any(node_mask):
                    continue

                y_pred = np.asarray(self.models[node].predict(X[node_mask]))
                node_indices = np.flatnonzero(node_mask)

                left_child = node + '0'
                right_child = node + '1'
                if left_child not in current_predictions:
                    current_predictions[left_child] = np.zeros(n_samples, dtype=bool)
                if right_child not in current_predictions:
                    current_predictions[right_child] = np.zeros(n_samples, dtype=bool)

                current_predictions[left_child][node_indices[y_pred == 0]] = True
                current_predictions[right_child][node_indices[y_pred == 1]] = True

        # Nodes that are not splitting nodes inherit their parent's mask.
        for key in self.all_labels:
            if key not in self.splitting_nodes and key != '':
                parent = key[:-1]
                if parent in current_predictions and key not in current_predictions:
                    current_predictions[key] = current_predictions[parent].copy()

        endmember_predictions = {}
        for endmember in self.endmembers:
            if endmember in current_predictions:
                endmember_predictions[endmember] = current_predictions[endmember]
                continue
            # Fall back to the nearest predicted non-root ancestor; if none
            # exists the endmember gets an all-False mask.
            parent = endmember[:-1]
            while parent and parent not in current_predictions:
                parent = parent[:-1]
            if parent:
                endmember_predictions[endmember] = current_predictions[parent].copy()
            else:
                endmember_predictions[endmember] = np.zeros(n_samples, dtype=bool)

        # plot_prediction() without arguments reads this; it was never set before.
        self.last_prediction = endmember_predictions
        return endmember_predictions

    def evaluate(self, X, gt_labels=None):
        """
        Compute pixel accuracy of endmember predictions.

        Parameters
        ----------
        X : numpy.ndarray
            Hyperspectral image data. Can be:
            - Flattened array of shape (n_samples, n_bands)
            - Cube of shape (height, width, n_bands) or (n_bands, height, width)
        gt_labels : dict, optional
            Ground truth labels for evaluation. If None, uses self.all_labels

        Returns
        -------
        accuracy : float
            Matching pixels summed over all endmembers present in both the
            predictions and ``gt_labels``, divided by the total pixel count
            compared (0 if nothing was compared).
        """
        if gt_labels is None:
            gt_labels = self.all_labels

        print("Starting evaluation...")

        predictions = self.predict(X)

        correct = 0
        total = 0

        if self.verbose:
            print(f"Evaluating accuracy for {len(self.endmembers)} endmembers...")

        for endmember in self.endmembers:
            if endmember not in predictions or endmember not in gt_labels:
                continue
            pred = predictions[endmember]
            true = gt_labels[endmember]

            endmember_correct = np.sum(pred == true)
            endmember_total = len(true)
            correct += endmember_correct
            total += endmember_total

            if self.verbose:
                endmember_acc = endmember_correct / endmember_total if endmember_total > 0 else 0
                print(f"  Endmember '{endmember}': {endmember_acc:.4f} ({endmember_correct}/{endmember_total})")

        accuracy = correct / total if total > 0 else 0

        print(f"Overall accuracy: {accuracy:.4f} ({correct}/{total} pixels)")

        return accuracy

    @staticmethod
    def _infer_plot_dims(data, row_length=None):
        """Return ``(n_rows, row_length)`` used to fold flat masks into 2-D images.

        ``data`` is either a dict of arrays (its first value is measured) or a
        single array. Its element count is split into rows of ``row_length``
        elements (default: the module-level ``HYPSO_HEIGHT``), matching the
        ``reshape(-1, HYPSO_HEIGHT)`` convention used by the plotting methods.

        Raises
        ------
        ValueError
            If ``data`` is an empty dict or its size is not a multiple of
            ``row_length``.
        """
        if isinstance(data, dict):
            if not data:
                raise ValueError("Cannot infer plot dimensions from an empty dictionary")
            data = next(iter(data.values()))
        if row_length is None:
            row_length = HYPSO_HEIGHT
        n_rows, n_cols = np.asarray(data).reshape(-1, row_length).shape
        return n_rows, n_cols

    def plot_input_image(self, image_data, slice_idx=None, figsize=(15, 5), cmap='viridis', rgb_bands=(69, 46, 26)):
        """Plot input image data as a single band, an RGB composite, or a flat image.

        Parameters
        ----------
        image_data : array-like
            ``(n_pixels, n_bands)`` flat data, ``(h, w, n_bands)`` cube, or any
            other array passed straight to ``imshow``.
        slice_idx : int, optional
            Band to show in single-band mode (default: middle band).
        rgb_bands : tuple of 3 ints or None
            Band indices for an RGB composite of a 3-D cube; None shows
            ``slice_idx`` instead.

        Raises
        ------
        ValueError
            If an RGB band index exceeds the available bands.
        """
        plt.figure(figsize=figsize)
        image_data = np.asarray(image_data)

        if image_data.ndim == 2 and image_data.shape[1] > 3:
            # Flat (n_pixels, n_bands) data: show one band, square if possible.
            slice_idx = slice_idx if slice_idx is not None else image_data.shape[1] // 2
            band_data = image_data[:, slice_idx]

            side = int(np.sqrt(image_data.shape[0]))
            reshaped_data = band_data.reshape(side, side) if side * side == image_data.shape[0] else band_data.reshape(-1, 1)
            plt.imshow(np.flipud(np.rot90(reshaped_data)), cmap=cmap, aspect=0.1)
            plt.colorbar(label='Intensity')

        elif image_data.ndim == 3:
            h, w, n_bands = image_data.shape
            slice_idx = slice_idx if slice_idx is not None else n_bands // 2

            if rgb_bands is not None and len(rgb_bands) == 3:
                if max(rgb_bands) > n_bands - 1:
                    raise ValueError(f"RGB band indices {rgb_bands} exceed available bands (0-{n_bands-1})")

                # Min-max normalize each chosen band independently (constant bands are left as-is).
                rgb_img = np.zeros((h, w, 3), dtype=np.float32)
                for i, band_idx in enumerate(rgb_bands):
                    band = image_data[:, :, band_idx].astype(np.float32)
                    band_min, band_max = band.min(), band.max()
                    if band_min != band_max:
                        band = (band - band_min) / (band_max - band_min)
                    rgb_img[:, :, i] = band

                plt.imshow(np.flipud(np.rot90(rgb_img)), aspect=0.1)
            else:
                plt.imshow(np.flipud(np.rot90(image_data[:, :, slice_idx])), cmap=cmap, aspect=0.1)
                plt.colorbar(label='Intensity')

        else:
            plt.imshow(np.flipud(np.rot90(image_data)), cmap=cmap, aspect=0.1)
            plt.colorbar(label='Intensity')

        plt.axis('off')
        plt.tight_layout()
        plt.show()

    def plot_ground_truth(self, labels=None, key='', figsize=(15, 5), cmap='viridis'):
        """Plot ground-truth labels.

        Parameters
        ----------
        labels : dict or numpy.ndarray, optional
            Node key -> label array, or a single label array. Defaults to
            ``self.all_labels``.
        key : str
            For dict input: '' draws a combined view where pixels labelled 1
            for an endmember take that endmember's position + 1; any other key
            draws that node's map alone.
        """
        labels = self.all_labels if labels is None else labels
        # Size the image from the data itself. The old code read labels['']
        # unconditionally, which failed for array input (leaving the array
        # branch unreachable) and for dicts without a root key.
        height, width = self._infer_plot_dims(labels)

        if isinstance(labels, dict):
            if key == '':
                plt.figure(figsize=figsize)
                first_val = next(iter(labels.values()))

                if len(first_val.shape) == 1:
                    combined_labels = np.zeros((height, width), dtype=int)
                    endmember_keys = [k for k in labels if k in self.endmembers]

                    for i, k in enumerate(endmember_keys):
                        if k == '':
                            continue
                        try:
                            label_reshaped = labels[k].reshape(height, width)
                            combined_labels[label_reshaped == 1] = i + 1
                        except Exception as e:
                            print(f"Error reshaping key '{k}': {e}")
                            continue

                    plt.imshow(np.flipud(np.rot90(combined_labels)), cmap=cmap, aspect=0.1)
                else:
                    print("Cannot display labels: unexpected format")
            else:
                if key not in labels:
                    print(f"Key '{key}' not found in labels")
                    return

                label_data = labels[key]
                if len(label_data.shape) == 1:
                    try:
                        label_data = label_data.reshape(height, width)
                    except Exception as e:
                        print(f"Error reshaping label data: {e}")
                        return

                plt.figure(figsize=figsize)
                plt.imshow(np.flipud(np.rot90(label_data)), cmap=cmap, vmin=0, vmax=1, aspect=0.1)

        elif hasattr(labels, 'shape'):
            labels = np.asarray(labels)
            if len(labels.shape) == 1:
                try:
                    labels = labels.reshape(height, width)
                except Exception as e:
                    print(f"Error reshaping label data: {e}")
                    return

            plt.figure(figsize=figsize)
            plt.imshow(np.flipud(np.rot90(labels)), cmap=cmap, aspect=0.1)
        else:
            print("Cannot plot labels: unsupported format")
            return

        plt.axis('off')
        plt.tight_layout()
        plt.show()

    def plot_prediction(self, prediction=None, key='', figsize=(15, 5), cmap='viridis'):
        """Plot prediction results.

        Parameters
        ----------
        prediction : dict or numpy.ndarray, optional
            Output of ``predict`` (endmember key -> mask) or a single array.
            Defaults to the result of the most recent ``predict`` call.
        key : str
            For dict input: '' draws a combined view where pixels predicted for
            an endmember take that endmember's position + 1; any other key
            draws that mask alone.
        """
        if prediction is None:
            prediction = getattr(self, 'last_prediction', None)
            if prediction is None:
                print("No prediction available. Run predict() first.")
                return

        # Works for dict and array input alike (the old next(iter(...))
        # indexing broke for arrays).
        height, width = self._infer_plot_dims(prediction)

        if isinstance(prediction, dict):
            if key == '':
                plt.figure(figsize=figsize)
                first_val = next(iter(prediction.values()))

                if len(first_val.shape) == 1:
                    combined_pred = np.zeros((height, width), dtype=int)
                    endmember_keys = [k for k in prediction if k in self.endmembers]

                    for i, k in enumerate(endmember_keys):
                        if k == '':
                            continue
                        try:
                            pred_reshaped = prediction[k].reshape(height, width)
                            combined_pred[pred_reshaped == 1] = i + 1
                        except Exception as e:
                            print(f"Error reshaping key '{k}': {e}")
                            continue

                    plt.imshow(np.flipud(np.rot90(combined_pred)), cmap=cmap, vmin=0, aspect=0.1)
                else:
                    print("Cannot display predictions: unexpected format")
            else:
                if key not in prediction:
                    print(f"Key '{key}' not found in predictions")
                    return

                pred_data = prediction[key]
                if len(pred_data.shape) == 1:
                    try:
                        pred_data = pred_data.reshape(height, width)
                    except Exception as e:
                        print(f"Error reshaping prediction data: {e}")
                        return

                plt.figure(figsize=figsize)
                plt.imshow(np.flipud(np.rot90(pred_data)), cmap=cmap, vmin=0, vmax=1, aspect=0.1)

        elif hasattr(prediction, 'shape'):
            prediction = np.asarray(prediction)
            if len(prediction.shape) == 1:
                try:
                    prediction = prediction.reshape(height, width)
                except Exception as e:
                    print(f"Error reshaping prediction data: {e}")
                    return

            plt.figure(figsize=figsize)
            plt.imshow(np.flipud(np.rot90(prediction)), cmap=cmap, aspect=0.1)
        else:
            print("Cannot plot prediction: unsupported format")
            return

        plt.axis('off')
        plt.tight_layout()
        plt.show()

    def save_model_parameters(self, filename='decision_tree_model.pkl'):
        """
        Pickle the trained models and tree structure to ``weights/<filename>``.

        The ``weights`` folder is created relative to the current working
        directory. Errors while writing are printed, not raised.

        Parameters
        ----------
        filename : str
            File name inside the ``weights`` folder.
        """
        import pickle
        import os

        if self.verbose:
            print("Preparing to save model parameters...")

        model_data = {
            'models': self.models,
            'endmembers': self.endmembers,
            'splitting_nodes': self.splitting_nodes,
            'all_labels': self.all_labels,
        }

        weights_folder = 'weights'
        os.makedirs(weights_folder, exist_ok=True)
        filepath = os.path.join(weights_folder, filename)

        try:
            if self.verbose:
                print(f"Saving model to {filepath}...")
            with open(filepath, 'wb') as f:
                pickle.dump(model_data, f)
            print(f"Model parameters successfully saved to {filepath}")
        except Exception as e:
            print(f"Error saving model parameters: {e}")

    def load_model_parameters(self, filename='decision_tree_model.pkl'):
        """
        Restore models and tree structure from a pickle file.

        Unlike ``save_model_parameters``, ``filename`` is used as given (it is
        NOT prefixed with ``weights/``), so pass the full path of a saved file.

        Parameters
        ----------
        filename : str
            Path to the pickle file.

        Returns
        -------
        bool
            True if loading was successful, False otherwise.
        """
        import pickle
        import os

        if self.verbose:
            print(f"Attempting to load model parameters from {filename}...")

        filepath = filename

        if not os.path.exists(filepath):
            print(f"Error: Model file {filepath} not found")
            return False

        try:
            if self.verbose:
                print("Reading model file...")
            # SECURITY: pickle.load can execute arbitrary code; only load
            # model files from trusted sources.
            with open(filepath, 'rb') as f:
                model_data = pickle.load(f)

            if self.verbose:
                print("Restoring model components...")

            self.models = model_data['models']
            self.endmembers = model_data['endmembers']
            self.splitting_nodes = model_data['splitting_nodes']

            if 'all_labels' in model_data:
                self.all_labels = model_data['all_labels']
                if self.verbose:
                    print("Labels restored from saved model")

            print(f"Model parameters successfully loaded from {filepath}")
            return True
        except Exception as e:
            print(f"Error loading model parameters: {e}")
            return False

In [None]:
def predict_pipeline(image_path, tree_path, deh_path, hypso=1, verbose=False):
    """Label an image with a DEH model, then predict/evaluate/plot with a saved tree.

    Parameters
    ----------
    image_path : str
        Path to a ``.npy`` array or a HYPSO ``.nc`` capture.
    tree_path : str
        Pickle file for ``BinaryDecisionTree.load_model_parameters`` (used as-is).
    deh_path : str
        DEH model path forwarded to ``generate_huH_labels``.
    hypso : int
        Satellite generation forwarded to ``nc_to_image`` for ``.nc`` input.
    verbose : bool
        Forwarded to the loaders and the tree.

    Returns
    -------
    dict or None
        Endmember key -> boolean mask from ``BinaryDecisionTree.predict``, or
        None if the image format is unsupported, the ``.nc`` file fails to
        load, or the tree model cannot be loaded.
    """
    if image_path.endswith('.npy'):
        image_data = np.load(image_path)
    elif image_path.endswith('.nc'):
        try:
            image_data = nc_to_image(image_path, hypso, verbose, machi=True)
        except ImportError:
            print("Error: hypso package is required to process .nc files")
            return None
        except Exception as e:
            print(f"Error processing .nc file: {e}")
            return None
    else:
        print(f"Unsupported file format: {image_path}")
        print("Supported formats: .npy, .nc")
        return None

    # The band axis is last for both loaders. This used to be set only in the
    # .npy branch, so .nc input crashed with a NameError below.
    channels = image_data.shape[-1]

    print(f"Image data shape: {image_data.shape}")

    # Per-node labels produced by the DEH model serve as ground truth.
    image_gt = generate_huH_labels(image_data, deh_path, verbose=verbose, channels=channels)

    # BinaryDecisionTree.__init__ only takes `verbose`; passing image_gt
    # positionally raised "got multiple values for argument 'verbose'".
    tree = BinaryDecisionTree(verbose=verbose)
    if not tree.load_model_parameters(tree_path):
        print("Failed to load tree model")
        return None

    predictions = tree.predict(image_data)
    evaluation = tree.evaluate(image_data, image_gt)
    print(evaluation)

    tree.plot_input_image(image_data.reshape(-1, HYPSO_HEIGHT, channels))
    tree.plot_ground_truth(image_gt)
    tree.plot_prediction(predictions)
    return predictions

def nc_to_image(image_path, hypso=1, verbose=False, machi=False, band_start=6, band_stop=118):
    """Load a HYPSO ``.nc`` capture and return its L1b cube flattened to pixels.

    Parameters
    ----------
    image_path : str
        Path to the ``.nc`` file.
    hypso : int
        2 selects ``Hypso2``; any other value selects ``Hypso1``.
    verbose : bool
        Forwarded to the hypso loader.
    machi : bool
        Currently unused; kept for call-site compatibility.
    band_start, band_stop : int
        Half-open slice of bands kept from the cube's last axis. The defaults
        keep bands 6..117 (112 bands), matching the previous behavior.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_pixels, n_kept_bands)``.
    """
    satellite_cls = Hypso2 if hypso == 2 else Hypso1
    satobj = satellite_cls(path=image_path, verbose=verbose)
    satobj.generate_l1b_cube()
    # NOTE(review): assumes l1b_cube is (rows, cols, bands), since the band
    # slice indexes the last axis — confirm against the hypso package.
    image_data = np.array(satobj.l1b_cube)[:, :, band_start:band_stop]
    # Flatten by the actual kept band count instead of a hard-coded 112, so a
    # cube with fewer bands cannot be silently mis-reshaped.
    return image_data.reshape(-1, image_data.shape[-1])

def generate_huH_labels(image_file, deh_path, verbose=False, channels=120):
    """Run a saved DEH model on an image and collect its per-node maps as labels.

    Parameters
    ----------
    image_file : numpy.ndarray
        Image data whose size is divisible into ``(-1, HYPSO_HEIGHT, channels)``.
    deh_path : str
        Path passed to ``DEH.load``.
    verbose : bool
        Currently unused.
    channels : int
        Number of spectral bands in ``image_file``.

    Returns
    -------
    dict
        Each key of ``DEH_model.nodes`` -> that node's ``map`` flattened to 1-D.
        NOTE(review): map values come from ``binarize_lmdas``/``lmda_2_map`` in
        ``src.deh``; presumably 0/1 — confirm there.
    """
    DEH_model = deh.DEH(no_negative_residuals=True)
    DEH_model.load(deh_path)
    image_file_cube = image_file.reshape(-1, HYPSO_HEIGHT, channels)
    # plot_size needs the two spatial dimensions. The previous code assigned
    # the first two slices of the cube (arrays) instead of their sizes.
    DEH_model.plot_size = image_file_cube.shape[:2]
    DEH_model.simple_predict(image_file)
    DEH_model.binarize_lmdas()
    DEH_model.lmda_2_map()
    return {node: DEH_model.nodes[node].map.flatten() for node in DEH_model.nodes}

def _split_branch_values(node, path_to_value):
    """Return ``(left, right)``: values of endmembers under ``node+'0'`` and ``node+'1'``."""
    left_values = []
    right_values = []
    depth = len(node)
    for path, value in path_to_value.items():
        # Only strict descendants of `node` are routed by this node's SVM.
        if len(path) > depth and path.startswith(node):
            if path[depth] == '0':
                left_values.append(value)
            else:
                right_values.append(value)
    return left_values, right_values


def save_tree_svms_to_binary_remapped(tree, output_dir):
    """
    Save every SVM in the decision tree as a binary file, with class labels
    remapped to sequential integers starting from 0.

    For each splitting node that has a model, a file named ``lsmLLRR``
    (``LL``/``RR`` = zero-padded remapped left/right class) is written to
    ``output_dir``. The layout uses native byte order, because the ``struct``
    formats carry no byte-order prefix:

    * 2 x uint8   remapped left class, remapped right class
    * 1 x float32 ``intercept_[0]``
    * N x float32 weights from ``coef_[0]`` (omitted if the model has no ``coef_``)

    A branch's class is the smallest decimal value (``int(path, 2)``) among
    the endmembers below it. A branch with no endmembers falls back to class 0.

    Parameters
    -----------
    tree : BinaryDecisionTree
        Object exposing ``splitting_nodes``, ``endmembers`` (binary-string
        paths) and ``models`` (node -> fitted linear model with
        ``intercept_`` and normally ``coef_``).
    output_dir : str
        Directory where the binary files will be saved (created if missing).

    Returns
    --------
    dict
        ``original_to_remapped``, ``remapped_to_original``,
        ``original_classes`` and ``remapped_classes``.
    """
    import os
    import struct
    import numpy as np

    zero_tol = 1e-10  # |w| at or below this counts as zero in the debug check

    os.makedirs(output_dir, exist_ok=True)

    splitting_nodes = tree.splitting_nodes
    print(f"Found {len(splitting_nodes)} splitting nodes in the tree")

    # Map each binary endmember path to its decimal value (skip empty labels).
    path_to_value = {label: int(label, 2) for label in tree.endmembers if label}

    # Branch values are needed twice (class set + file naming); compute once.
    branch_values = {
        node: _split_branch_values(node, path_to_value)
        for node in splitting_nodes
        if node in tree.models
    }

    # Each branch is represented by its lowest endmember value.
    all_classes = set()
    for left_values, right_values in branch_values.values():
        if left_values:
            all_classes.add(min(left_values))
        if right_values:
            all_classes.add(min(right_values))

    original_classes = sorted(all_classes)
    class_mapping = {orig: i for i, orig in enumerate(original_classes)}
    reverse_mapping = {i: orig for i, orig in enumerate(original_classes)}

    print(f"Found {len(original_classes)} unique class values in the tree")
    print(f"Original classes: {original_classes}")
    print(f"Remapped to: {list(range(len(original_classes)))}")
    print(f"Mapping: {class_mapping}")

    saved_count = 0

    for node in splitting_nodes:
        if node not in tree.models:
            print(f"Warning: Node '{node}' is marked as a splitting node but has no SVM model")
            continue

        svm_model = tree.models[node]
        left_values, right_values = branch_values[node]

        left_min = min(left_values) if left_values else 0
        right_min = min(right_values) if right_values else 0
        if not (left_values and right_values):
            # Kept for backward compatibility, but this can produce colliding file names.
            print(f"  WARNING: Node '{node}' has an empty branch; its class defaults to 0")

        left_mapped = class_mapping.get(left_min, 0)
        right_mapped = class_mapping.get(right_min, 0)

        svm_model_name = f"lsm{left_mapped:02d}{right_mapped:02d}"
        filepath = os.path.join(output_dir, svm_model_name)

        print(f"Node {node}: Splitting between left branch (orig={left_min}, remapped={left_mapped}) and right branch (orig={right_min}, remapped={right_mapped})")

        # Resolve the weights to write exactly once, so that the debug
        # fallback below actually affects what is saved.
        weights = None
        non_zero_count = 0
        if hasattr(svm_model, 'coef_'):
            weights = np.asarray(svm_model.coef_[0])
            non_zero_count = int(np.sum(np.abs(weights) > zero_tol))
            if non_zero_count == 0:
                print(f"  WARNING: All weights are zero for node {node}!")
                impl = getattr(svm_model, '_impl', None)
                if impl is not None and hasattr(impl, 'raw_coef_'):
                    print("  Attempting to use _impl.raw_coef_ instead...")
                    fallback = np.ravel(impl.raw_coef_)
                    # Only accept the fallback if it has the same number of weights.
                    if fallback.size == weights.size:
                        weights = fallback
                        non_zero_count = int(np.sum(np.abs(weights) > zero_tol))
                    else:
                        print(f"  raw_coef_ has {fallback.size} values, expected {weights.size}; keeping coef_")
                else:
                    print("  Could not find alternative weight source")

        with open(filepath, "wb") as file:
            # Remapped class labels as uint8.
            file.write(struct.pack("BB", left_mapped, right_mapped))
            # Intercept as float32.
            file.write(struct.pack('f', svm_model.intercept_[0]))
            # Weights as float32.
            if weights is not None:
                file.write(struct.pack(f'{len(weights)}f', *weights))
                print(f"  Saved {len(weights)} weights (non-zero: {non_zero_count})")
            else:
                print("  WARNING: Model does not have coef_ attribute!")

        print(f"  Saved model to {svm_model_name}")
        saved_count += 1

    print(f"Successfully saved {saved_count} SVM models to {output_dir}")

    return {
        "original_to_remapped": class_mapping,
        "remapped_to_original": reverse_mapping,
        "original_classes": original_classes,
        "remapped_classes": list(range(len(original_classes)))
    }

def save_prediction(predictions, save_name, pred_dir=None):
    """
    Save prediction results as ``<save_name>_pred.npy``.

    Parameters
    -----------
    predictions : dict
        Prediction output to store, e.g. a dict of per-endmember masks.
        Saved with ``np.save``, so a dict is pickled into a 0-d object array.
    save_name : str
        Base file name; ``'_pred.npy'`` is appended.
    pred_dir : str, optional
        Target directory (created if missing). Defaults to ``PRED_DIR``.

    Returns
    --------
    str
        Full path of the written file.
    """
    if pred_dir is None:
        pred_dir = PRED_DIR
    os.makedirs(pred_dir, exist_ok=True)
    save_path = os.path.join(pred_dir, save_name + '_pred.npy')
    np.save(save_path, predictions)
    return save_path

def load_prediction(prediction_name, pred_dir=None):
    """
    Load prediction results saved with ``np.save``.

    Parameters
    -----------
    prediction_name : str
        File name (including the ``.npy`` extension) inside ``pred_dir``.
    pred_dir : str, optional
        Directory to read from. Defaults to ``PRED_DIR``.

    Returns
    --------
    dict or numpy.ndarray
        If the file holds a single element, that element is returned. A dict
        saved with ``np.save`` becomes a 0-d object array, so the dict itself
        comes back. Otherwise the array is returned unchanged.

    Notes
    -----
    ``allow_pickle=True`` can execute arbitrary code from a crafted file;
    only load prediction files you produced yourself.
    """
    if pred_dir is None:
        pred_dir = PRED_DIR
    prediction_path = os.path.join(pred_dir, prediction_name)
    data = np.load(prediction_path, allow_pickle=True)
    # .item() only works for single-element arrays. Arrays saved directly
    # (e.g. by the HYPSO-2 prediction loop) are returned as-is instead of
    # raising ValueError.
    return data.item() if data.size == 1 else data

## LOAD DATA

### HYPSO 1

In [None]:
# Saved HYPSO-1 tree files, per the file names: one for raw L1A input and one
# for L1D MACHI-corrected input. Read by the LOAD MODELS cells.
tree_raw_path=os.path.join(TREE_DIR,'combined_10_L1A_120_16end_TREE.pkl')
tree_corrected_path=os.path.join(TREE_DIR,'combined_10_L1D_112_MACHI_16end_TREE.pkl')
print(tree_raw_path)

In [None]:
# Combined HYPSO-1 training images (raw and corrected) plus their labels.
# The labels file stores a pickled dict-like object, hence allow_pickle=True and .item().
combined_10_L1A_120                 = np.load(os.path.join(IMAGES_DIR,'combined_10_L1A_120.npy'))
print(combined_10_L1A_120.shape)
combined_10_L1D_112_MACHI           = np.load(os.path.join(IMAGES_DIR,'combined_10_L1D_112_MACHI.npy'))
print(combined_10_L1D_112_MACHI.shape)
combined_10_L1D_112_MACHI_labels    = np.load(os.path.join(LABELS_DIR,'combined_10_L1D_112_MACHI_16end_labels.npy'), allow_pickle=True).item()
print(combined_10_L1D_112_MACHI_labels.keys())

In [None]:
# Tampa scene: raw L1A image; its labels file is named after the L1D MACHI product.
tampa_2024_11_12_L1A_120        =np.load(os.path.join(IMAGES_DIR, 'tampa_2024-11-12T15-31-55Z-l1a_flat_L1A_120.npy'))
tampa_2024_11_12_L1A_120_labels =np.load(os.path.join(LABELS_DIR, 'tampa_2024-11-12T15-31-55Z-l1a_flat_L1D_112_MACHI_16end_labels.npy'),allow_pickle=True).item()

In [None]:
# Caspian Sea scene (note the lower-case 'l1A' in these variable names).
caspiansea1_2025_04_08_l1A_120=np.load(os.path.join(IMAGES_DIR, 'caspiansea1_2025-04-08T07-11-56Z-l1a_flat_L1A_120.npy'))
caspiansea1_2025_04_08_l1A_120_labels=np.load(os.path.join(LABELS_DIR, 'caspiansea1_2025-04-08T07-11-56Z-l1a_flat_L1D_112_MACHI_16end_labels.npy'),allow_pickle=True).item()

In [None]:
# Vancouver scene: raw L1A image plus labels.
vancouver_2025_05_04_L1A_120=np.load(os.path.join(IMAGES_DIR, 'vancouver_2025-05-04T19-12-24Z-l1a_flat_L1A_120.npy'))
vancouver_2025_05_04_L1A_120_labels=np.load(os.path.join(LABELS_DIR, 'vancouver_2025-05-04T19-12-24Z-l1a_flat_L1D_112_MACHI_16end_labels.npy'),allow_pickle=True).item()

In [None]:
# Ten individual HYPSO-1 L1A scenes, keyed by scene name in L1A_dict.
# Paths use os.path.join like the rest of the notebook, instead of a
# hand-built '//' separator appended to a backslash-separated directory.
yucatan2_2025_02_06_L1A         = np.load(os.path.join(IMAGES_DIR, 'yucatan2_2025-02-06T16-01-18Z-l1a_flat_L1A_120.npy'))
kemigawa_2024_12_17_L1A         = np.load(os.path.join(IMAGES_DIR, 'kemigawa_2024-12-17T01-01-32Z-l1a_flat_L1A_120.npy'))
chapala_2025_02_24_L1A          = np.load(os.path.join(IMAGES_DIR, 'chapala_2025-02-24T16-52-47Z-l1a_flat_L1A_120.npy'))
grizzlybay_2025_01_27_L1A       = np.load(os.path.join(IMAGES_DIR, 'grizzlybay_2025-01-27T18-19-56Z-l1a_flat_L1A_120.npy'))
victoriaLand_2025_02_07_L1A     = np.load(os.path.join(IMAGES_DIR, 'victoriaLand_2025-02-07T20-35-33Z-l1a_flat_L1A_120.npy'))
catala_2025_01_28_L1A           = np.load(os.path.join(IMAGES_DIR, 'catala_2025-01-28T19-17-32Z-l1a_flat_L1A_120.npy'))
khnifiss_2025_02_12_L1A         = np.load(os.path.join(IMAGES_DIR, 'khnifiss_2025-02-12T11-05-35Z-l1a_flat_L1A_120.npy'))
menindee_2025_02_18_L1A         = np.load(os.path.join(IMAGES_DIR, 'menindee_2025-02-18T00-10-42Z-l1a_flat_L1A_120.npy'))
# NOTE: same file as tampa_2024_11_12_L1A_120 loaded in an earlier cell.
tampa_2024_11_12_L1A            = np.load(os.path.join(IMAGES_DIR, 'tampa_2024-11-12T15-31-55Z-l1a_flat_L1A_120.npy'))
falklandsatlantic_2024_12_18_L1A= np.load(os.path.join(IMAGES_DIR, 'falklandsatlantic_2024-12-18T13-25-18Z-l1a_flat_L1A_120.npy'))
L1A_dict = {
    'yucatan2_2025_02_06': yucatan2_2025_02_06_L1A,
    'kemigawa_2024_12_17': kemigawa_2024_12_17_L1A,
    'chapala_2025_02_24': chapala_2025_02_24_L1A,
    'grizzlybay_2025_01_27': grizzlybay_2025_01_27_L1A,
    'victoriaLand_2025_02_07': victoriaLand_2025_02_07_L1A,
    'catala_2025_01_28': catala_2025_01_28_L1A,
    'khnifiss_2025_02_12': khnifiss_2025_02_12_L1A,
    'menindee_2025_02_18': menindee_2025_02_18_L1A,
    'tampa_2024_11_12': tampa_2024_11_12_L1A,
    'falklandsatlantic_2024_12_18': falklandsatlantic_2024_12_18_L1A
}

In [None]:
# Select the HYPSO-1 combined dataset as the active training/evaluation data.
# NOTE: the HYPSO 2 selection cell rebinds these same names; whichever ran
# last is what the TRAINING cells use. The raw and corrected trees share
# image_gt, which comes from the L1D MACHI labels file.
image_data_raw          =combined_10_L1A_120
save_name_raw           ='combined_10_L1A_120'
image_data_corrected    =combined_10_L1D_112_MACHI
save_name_corrected     ='combined_10_L1D_112_MACHI'
image_gt                =combined_10_L1D_112_MACHI_labels

In [None]:
# Reload previously saved HYPSO-1 predictions from PRED_DIR instead of re-predicting.
prediction_raw=load_prediction('combined_10_L1A_120_pred.npy')
prediction_corrected=load_prediction('combined_10_L1D_112_MACHI_pred.npy')
predictions_tampa_2024_11_12=load_prediction('tampa_2024_11_12_pred.npy')
predictions_caspiansea1_2025_04_08=load_prediction('caspiansea1_2025_04_08_pred.npy')
prediction_vancouver_2025_05_04_L1A_120=load_prediction('vancouver_2025_05_04_L1A_120_pred.npy')


### HYPSO 2

In [None]:
# HYPSO-2 tree file paths; the TRAINING cells save to and LOAD MODELS reads from these.
H2_tree_raw_path        =os.path.join(TREE_DIR,'H2_10img_16end_L1A_120_stabelized_aa_TREE.pkl')
H2_tree_corrected_path  =os.path.join(TREE_DIR,'H2_10img_16end_L1D_112_MACHI_stabelized_aa_TREE.pkl')

In [None]:
# Combined HYPSO-2 training images (raw and corrected) plus their labels (pickled, hence .item()).
H2_10_L1A_120                 = np.load(os.path.join(IMAGES_DIR,'H2_10_L1A_120.npy'))
print(H2_10_L1A_120.shape)
H2_10_L1D_112_MACHI           = np.load(os.path.join(IMAGES_DIR,'H2_10_L1D_112_MACHI.npy'))
print(H2_10_L1D_112_MACHI.shape)
H2_10_L1D_112_MACHI_labels    = np.load(os.path.join(LABELS_DIR,'H2_10_L1D_112_MACHI_16end_labels.npy'), allow_pickle=True).item()
print(H2_10_L1D_112_MACHI_labels.keys())

In [None]:
# Select the HYPSO-2 dataset as the active training data. This rebinds the
# same names as the HYPSO-1 selection cell, so whichever runs last wins.
image_data_raw          =H2_10_L1A_120
save_name_raw           ='H2_10_L1A_120'
image_data_corrected    =H2_10_L1D_112_MACHI
save_name_corrected     ='H2_10_L1D_112_MACHI'
image_gt                =H2_10_L1D_112_MACHI_labels

In [None]:
# Individual HYPSO-2 L1A scene files (in IMAGES_DIR), predicted one by one in PREDICT / HYPSO 2.
H2_L1A_files = [
    'yucatan1_2025-04-01_flat_L1A_120_H2.npy',
    'kemigawa_2025-01-22_flat_L1A_120_H2.npy',
    'chapala_2025-03-25_flat_L1A_120_H2.npy',
    'grizzlybay_2025-01-22_flat_L1A_120_H2.npy',
    'victoriaLand_2025-03-16_flat_L1A_120_H2.npy',
    'mjosa_2025-05-12_flat_L1A_120_H2.npy',
    'gobabeb_2025-04-25_flat_L1A_120_H2.npy',
    'menindee_2025-05-09_flat_L1A_120_H2.npy',
    'erie_2025-05-10_flat_L1A_120_H2.npy',
    'falklandsatlantic_2025-03-03_flat_L1A_120_H2.npy'
]

In [None]:
# L1D MACHI file names for the same ten HYPSO-2 scenes.
# NOTE(review): not referenced by the cells in this part of the notebook.
H2_L1D_files = [
    'yucatan1_2025-04-01_flat_L1D_112_H2_MACHI.npy',
    'kemigawa_2025-01-22_flat_L1D_112_H2_MACHI.npy',
    'chapala_2025-03-25_flat_L1D_112_H2_MACHI.npy',
    'grizzlybay_2025-01-22_flat_L1D_112_H2_MACHI.npy',
    'victoriaLand_2025-03-16_flat_L1D_112_H2_MACHI.npy',
    'mjosa_2025-05-12_flat_L1D_112_H2_MACHI.npy',
    'gobabeb_2025-04-25_flat_L1D_112_H2_MACHI.npy',
    'menindee_2025-05-09_flat_L1D_112_H2_MACHI.npy',
    'erie_2025-05-10_flat_L1D_112_H2_MACHI.npy',
    'falklandsatlantic_2025-03-03_flat_L1D_112_H2_MACHI.npy'
]

## TRAINING AND EVALUATION

In [None]:
# Build the tree structure from the keys of whichever image_gt is active
# (HYPSO-1 or HYPSO-2, depending on which selection cell ran last).
tree_raw=BinaryDecisionTree(verbose=True)
tree_raw.initialize_tree_structure(image_gt)

In [None]:
tree_corrected=BinaryDecisionTree(verbose=True)
tree_corrected.initialize_tree_structure(image_gt)

In [None]:
# Train and evaluate on the raw images; both trees use the same labels (image_gt).
tree_raw.train(image_data_raw,image_gt)
evaluation_raw = tree_raw.evaluate(image_data_raw,image_gt)

In [None]:
tree_corrected.train(image_data_corrected,image_gt)
evaluation_corrected = tree_corrected.evaluate(image_data_corrected,image_gt)

In [None]:
# NOTE: always writes to the HYPSO-2 path. Run only after training on the
# HYPSO-2 selection, or the H2 model file is overwritten with a HYPSO-1 tree.
tree_raw.save_model_parameters(H2_tree_raw_path)

In [None]:
tree_corrected.save_model_parameters(H2_tree_corrected_path)

## SAVE SVM TO BINARY

In [None]:
# Save models with remapped classes
save_folder = SVM_DIR
mapping = save_tree_svms_to_binary_remapped(tree_raw, save_folder)

# Print the mapping for reference
print("\nClass mapping:")
for orig, remapped in mapping["original_to_remapped"].items():
    print(f"Original class {orig} → Remapped class {remapped}")

# Save the mapping for later use
with open(f"{save_folder}/H2_16_class_mapping.json", "w") as f:
    json.dump(mapping, f, indent=2)

## LOAD MODELS

In [None]:
# Load the saved trees for both satellites. Rebinding tree_raw discards any
# tree trained above. load_model_parameters' return value (falsy on failure,
# as checked elsewhere in this file) is ignored here.
tree_raw=BinaryDecisionTree(verbose=True)
tree_raw.load_model_parameters(tree_raw_path)
H2_tree_raw=BinaryDecisionTree(verbose=True)
H2_tree_raw.load_model_parameters(H2_tree_raw_path)

In [None]:
tree_corrected=BinaryDecisionTree(verbose=True)
tree_corrected.load_model_parameters(tree_corrected_path)
H2_tree_corrected=BinaryDecisionTree(verbose=True)
H2_tree_corrected.load_model_parameters(H2_tree_corrected_path)

## PREDICT

### HYPSO 1

In [None]:
# Predict the active images with the matching trees.
# NOTE: after LOAD MODELS, tree_raw/tree_corrected are the HYPSO-1 trees;
# make sure image_data_raw/image_data_corrected also hold HYPSO-1 data.
prediction_raw = tree_raw.predict(image_data_raw)

In [None]:
prediction_corrected = tree_corrected.predict(image_data_corrected)

In [None]:
# Re-predict the individual HYPSO-1 scenes with tree_raw; this replaces any
# predictions loaded earlier via load_prediction.
predictions_tampa_2024_11_12=tree_raw.predict(tampa_2024_11_12_L1A_120)

In [None]:
predictions_caspiansea1_2025_04_08=tree_raw.predict(caspiansea1_2025_04_08_l1A_120)

In [None]:
prediction_vancouver_2025_05_04_L1A_120=tree_raw.predict(vancouver_2025_05_04_L1A_120)

In [None]:
# Predict and immediately save each scene in L1A_dict as '<name>_pred.npy' in PRED_DIR.
for name, image in L1A_dict.items():
    prediction = tree_raw.predict(image)
    save_prediction(prediction, name)

### HYPSO 2

In [None]:
# Predict every HYPSO-2 L1A scene with the loaded H2 raw tree and save each
# result in PRED_DIR as '<file name without .npy>_16end_pred.npy'. This
# bypasses save_prediction to use a different suffix.
# NOTE: image, prediction, save_name and save_path stay bound as notebook
# globals after the loop.
for file in H2_L1A_files:
    image_path=os.path.join(IMAGES_DIR, file)
    image=np.load(image_path)
    prediction=H2_tree_raw.predict(image)
    save_name = file.replace('.npy', '_16end_pred.npy')
    save_path=os.path.join(PRED_DIR,save_name)
    np.save(save_path, prediction)

### SAVE PRED

In [None]:
# Persist predictions via save_prediction ('<name>_pred.npy' in PRED_DIR);
# these are the files the cached-prediction cell reloads.
save_prediction(prediction_raw, save_name_raw)

In [None]:
save_prediction(prediction_corrected, save_name_corrected)

In [None]:
save_prediction(predictions_tampa_2024_11_12, 'tampa_2024_11_12')

In [None]:
save_prediction(predictions_caspiansea1_2025_04_08, 'caspiansea1_2025_04_08')

In [None]:
save_prediction(prediction_vancouver_2025_05_04_L1A_120, 'vancouver_2025_05_04_L1A_120')

## PLOTS

In [None]:
# Visual checks: input image, label map and prediction for each scene.
# Raw (L1A) inputs are reshaped to (-1, HYPSO_HEIGHT, 120) for display.
tree_raw.plot_input_image(vancouver_2025_05_04_L1A_120.reshape(-1,HYPSO_HEIGHT,120))
tree_raw.plot_ground_truth(vancouver_2025_05_04_L1A_120_labels, cmap='viridis')
tree_raw.plot_prediction(prediction_vancouver_2025_05_04_L1A_120, cmap='viridis')

In [None]:
tree_raw.plot_input_image(caspiansea1_2025_04_08_l1A_120.reshape(-1,HYPSO_HEIGHT,120))
tree_raw.plot_ground_truth(caspiansea1_2025_04_08_l1A_120_labels, cmap='viridis')
tree_raw.plot_prediction(predictions_caspiansea1_2025_04_08, cmap='viridis')

In [None]:
tree_raw.plot_input_image(tampa_2024_11_12_L1A_120.reshape(-1,HYPSO_HEIGHT,120))
tree_raw.plot_ground_truth(tampa_2024_11_12_L1A_120_labels, cmap='viridis')
tree_raw.plot_prediction(predictions_tampa_2024_11_12, cmap='viridis')

In [None]:
tree_raw.plot_input_image(image_data_raw.reshape(-1,HYPSO_HEIGHT,120))
tree_raw.plot_ground_truth(image_gt, cmap='viridis')
tree_raw.plot_prediction(prediction_raw, cmap='viridis')

In [None]:
# The corrected (L1D MACHI) input is reshaped with 112 channels instead of 120.
tree_corrected.plot_input_image(image_data_corrected.reshape(-1,HYPSO_HEIGHT,112))
tree_corrected.plot_ground_truth(image_gt, cmap='viridis')
tree_corrected.plot_prediction(prediction_corrected, cmap='viridis')

## DEBUG