## CONFIG

In [None]:
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
import importlib as il
from hypso import Hypso1, Hypso2
from sklearn.svm import LinearSVC
import src.deh as deh
import copy
il.reload(deh)
sys.path.append(os.path.abspath("D:/Hierarchical Unmixing Label"))

In [None]:
# Fixed HYPSO capture dimensions used as notebook-wide constants.
# HYPSO_HEIGHT is the row length the plotting helpers use when they fold a
# flattened per-pixel map back into a 2-D image via reshape(-1, HYPSO_HEIGHT).
HYPSO_HEIGHT = 1092
# NOTE(review): HYPSO_WIDTH and HYPSO_BANDS are not referenced by the code
# visible in this notebook section; presumably frame count and raw band count.
HYPSO_WIDTH = 598
HYPSO_BANDS = 120

## CLASS AND FUNCTIONS

In [None]:
class BinaryDecisionTree:
    def __init__(self, all_labels, verbose=False):
        self.verbose = verbose
        self.all_labels = all_labels
        self.endmembers = []
        self.splitting_nodes = []
        self.models = {}  # Dictionary to store SVM models for each splitting node
        self.initialize_tree_structure(all_labels.keys())
        
    def initialize_tree_structure(self, all_keys):
        """Initialize the tree structure by identifying endmembers and splitting nodes"""
        print("Initializing Binary Decision Tree structure...")
        self.identify_set_endmembers(all_keys)
        self.identify_set_splitting_nodes(all_keys)
        
    def identify_set_endmembers(self, all_keys):
        max_length = max(len(key) for key in all_keys)
        self.endmembers = []
        
        if self.verbose:
            print("\n" + "="*50)
            print("IDENTIFYING ENDMEMBERS:")
            print("="*50)
            print(f"Maximum key length found: {max_length}")
            print("-"*50)
        
        for key in all_keys:
            if len(key) == max_length:
                self.endmembers.append(key)
                if self.verbose:
                    print(f"Found endmember: '{key}'")
        
        if self.verbose:
            print(f"Total endmembers identified: {len(self.endmembers)}")
        
    def identify_set_splitting_nodes(self, all_keys):
        self.splitting_nodes = []
        
        if self.verbose:
            print("\n" + "="*50)
            print("IDENTIFYING SPLITTING NODES:")
            print("="*50)
        
        for key in all_keys:
            has_zero = key + '0' in all_keys
            has_one = key + '1' in all_keys
            if has_zero and has_one:
                self.splitting_nodes.append(key)
                if self.verbose:
                    print(f"Found splitting node: '{key}' → branches to '{key}0' and '{key}1'")
        
        if self.verbose:
            print("-"*50)
            print(f"Total splitting nodes identified: {len(self.splitting_nodes)}")

    def _preprocess_input(self, X):
        """
        Preprocess input data to ensure it's in the right format (n_samples, n_bands)
        
        Parameters:
        -----------
        X : numpy.ndarray
            Hyperspectral image data. Can be:
            - Flattened array of shape (n_samples, n_bands)
            - Cube of shape (height, width, n_bands) or (n_bands, height, width)
            
        Returns:
        --------
        X_processed : numpy.ndarray
            Processed data of shape (n_samples, n_bands)
        """
        if self.verbose:
            print(f"\nPreprocessing input data with shape: {X.shape}")
            
        # If already 2D with samples as first dimension, return as is
        if len(X.shape) == 2:
            if self.verbose:
                # print(f"Input is already in correct format: {X.shape}")
                print("-"*50)
            return X
        
        # Handle 3D data (cube)
        elif len(X.shape) == 3:
            # Determine which dimension is the spectral dimension
            # Typically, spectral dimension is the smallest
            dims = np.array(X.shape)
            spectral_dim = np.argmin(dims)
            
            if self.verbose:
                print(f"Detected spectral dimension: {spectral_dim}")
            
            if spectral_dim == 0:  # (n_bands, height, width)
                n_bands, height, width = X.shape
                if self.verbose:
                    print(f"Reshaping from (n_bands={n_bands}, height={height}, width={width}) to ({height*width}, {n_bands})")
                    print("-"*50)
                return X.reshape(n_bands, -1).T  # Reshape to (height*width, n_bands)
            
            elif spectral_dim == 2:  # (height, width, n_bands)
                height, width, n_bands = X.shape
                if self.verbose:
                    print(f"Reshaping from (height={height}, width={width}, n_bands={n_bands}) to ({height*width}, {n_bands})")
                    print("-"*50)
                return X.reshape(-1, n_bands)  # Reshape to (height*width, n_bands)
            
            else:  # (height, n_bands, width) - unusual but handle it
                height, n_bands, width = X.shape
                if self.verbose:
                    print(f"Unusual format detected: (height={height}, n_bands={n_bands}, width={width})")
                    print(f"Reshaping to ({height*width}, {n_bands})")
                    print("-"*50)
                return X.transpose(0, 2, 1).reshape(-1, n_bands)  # Reshape to (height*width, n_bands)
        
        
        else:
            if self.verbose:
                print(f"Error: Unsupported input shape: {X.shape}")
            raise ValueError(f"Unsupported input shape: {X.shape}. Expected 2D or 3D array.")
    
    def train(self, X, labels=None):
        """
        Fit one LinearSVC per splitting node.

        For each splitting node only the samples labelled 1 for that node are
        used. Within those, samples labelled 1 for the node's '1' child form the
        positive class; every other sample of the node is the negative class.

        Parameters:
        -----------
        X : numpy.ndarray
            Hyperspectral data, flattened (n_samples, n_bands) or a 3D cube
            (converted by ``_preprocess_input``).
        labels : dict, optional
            Node key -> per-sample label array. Defaults to the labels given
            at construction time.
        """
        from sklearn.svm import LinearSVC

        samples = self._preprocess_input(X)
        training_labels = self.all_labels if labels is None else labels

        # Progress messages in this method are printed regardless of verbose.
        print("Training SVMs for each splitting node...")

        for node in self.splitting_nodes:
            print("Splitting node:",node)

            # Restrict to the samples that belong to this node
            in_node = training_labels[node] == 1
            node_samples = samples[in_node]

            # Targets default to 0 (left child); mark right-child samples as 1
            right_child = node + '1'
            targets = np.zeros(node_samples.shape[0], dtype=int)
            if right_child in training_labels:
                targets[training_labels[right_child][in_node] == 1] = 1

            classifier = LinearSVC(dual='auto', random_state=42)
            classifier.fit(node_samples, targets)
            self.models[node] = classifier

            print(f"Trained model for node '{node}'")
    
    def predict(self, X):
        """
        Predict endmember classes for input data
        
        Parameters:
        -----------
        X : numpy.ndarray
            Hyperspectral image data. Can be:
            - Flattened array of shape (n_samples, n_bands)
            - Cube of shape (height, width, n_bands) or (n_bands, height, width)
            
        Returns:
        --------
        predictions : dict
            Dictionary with keys for each endmember and values as binary masks
        """
        # Process input data to ensure it's in the right format (n_samples, n_bands)
        X = self._preprocess_input(X)
        
        n_samples = X.shape[0]
        
        # Initialize predictions with all True for root node
        current_predictions = {
            '': np.ones(n_samples, dtype=bool)
        }
        
        # Process each level of the tree
        for level in range(max(len(node) for node in self.splitting_nodes) + 1):
            # Get nodes at this level
            level_nodes = [node for node in self.splitting_nodes if len(node) == level]
            
            for node in level_nodes:
                # Skip if node doesn't exist in current_predictions
                if node not in current_predictions:
                    continue
                # Skip if no samples belong to this node
                if not np.any(current_predictions[node]):
                    continue
                
                # Get samples that belong to this node
                node_mask = current_predictions[node]
                X_node = X[node_mask]
                
                # Predict using the SVM model
                if len(X_node) > 0:
                    y_pred = self.models[node].predict(X_node)
                    
                    # Create masks for children
                    left_child = node + '0'
                    right_child = node + '1'
                    
                    # Initialize child predictions
                    if left_child not in current_predictions:
                        current_predictions[left_child] = np.zeros(n_samples, dtype=bool)
                    if right_child not in current_predictions:
                        current_predictions[right_child] = np.zeros(n_samples, dtype=bool)
                    
                    # Update predictions for children
                    left_indices = np.where(node_mask)[0][y_pred == 0]
                    right_indices = np.where(node_mask)[0][y_pred == 1]
                    
                    current_predictions[left_child][left_indices] = True
                    current_predictions[right_child][right_indices] = True
        
        # For non-splitting nodes that just have a single child with '0' appended
        for key in self.all_labels.keys():
            if key not in self.splitting_nodes and key != '':
                # Find the parent node
                parent = key[:-1]
                if parent in current_predictions and key not in current_predictions:
                    # If this is a non-splitting child, it inherits parent's prediction
                    current_predictions[key] = current_predictions[parent].copy()
        
        # Extract predictions for endmembers
        endmember_predictions = {}
        for endmember in self.endmembers:
            if endmember in current_predictions:
                endmember_predictions[endmember] = current_predictions[endmember]
            else:
                # If endmember not in predictions, try to find its parent
                parent = endmember[:-1]
                while parent and parent not in current_predictions:
                    parent = parent[:-1]
                if parent:
                    endmember_predictions[endmember] = current_predictions[parent].copy()
                else:
                    endmember_predictions[endmember] = np.zeros(n_samples, dtype=bool)
        
        return endmember_predictions
    
    def evaluate(self, X, gt_labels=None):
        """
        Evaluate the model on test data
        
        Parameters:
        -----------
        X : numpy.ndarray
            Hyperspectral image data. Can be:
            - Flattened array of shape (n_samples, n_bands)
            - Cube of shape (height, width, n_bands) or (n_bands, height, width)
        gt_labels : dict, optional
            Ground truth labels for evaluation. If None, uses self.all_labels
            
        Returns:
        --------
        accuracy : float
            Overall accuracy of endmember predictions
        """
        if gt_labels is None:
            gt_labels = self.all_labels
            
        # if self.verbose:
        if True:
            print("Starting evaluation...")
            
        predictions = self.predict(X)
        
        # Calculate accuracy for endmembers
        correct = 0
        total = 0
        
        if self.verbose:
            print(f"Evaluating accuracy for {len(self.endmembers)} endmembers...")
            
        endmember_accuracies = {}
        for endmember in self.endmembers:
            if endmember in predictions and endmember in gt_labels:
                pred = predictions[endmember]
                true = gt_labels[endmember]
                
                endmember_correct = np.sum(pred == true)
                endmember_total = len(true)
                
                correct += endmember_correct
                total += endmember_total
                
                if self.verbose:
                    endmember_acc = endmember_correct / endmember_total if endmember_total > 0 else 0
                    endmember_accuracies[endmember] = endmember_acc
                    print(f"  Endmember '{endmember}': {endmember_acc:.4f} ({endmember_correct}/{endmember_total})")
        
        accuracy = correct / total if total > 0 else 0
        
        # if self.verbose:
        if True:
            print(f"Overall accuracy: {accuracy:.4f} ({correct}/{total} pixels)")
            
        return accuracy

    def plot_input_image(self, image_data, slice_idx=None, figsize=(15, 5), cmap='viridis', rgb_bands=(69, 46, 26)):
        """
        Show the input image as a single band or an RGB composite.

        - 2D input with more than 3 columns is treated as flattened
          (pixels, bands): one band is shown, folded into a square when the
          pixel count is a perfect square, otherwise as a single column.
        - 3D input is treated as (h, w, n_bands): an RGB composite is drawn
          when ``rgb_bands`` holds three indices, otherwise band ``slice_idx``.
        - Anything else is passed to ``imshow`` directly.

        ``slice_idx`` defaults to the middle band. Raises ValueError when an
        RGB band index exceeds the available bands.
        """
        plt.figure(figsize=figsize)

        image_data = np.asarray(image_data)
        rank = len(image_data.shape)

        if rank == 2 and image_data.shape[1] > 3:
            # Flattened image: pick one band and try to fold it into a square
            if slice_idx is None:
                slice_idx = image_data.shape[1] // 2
            band_data = image_data[:, slice_idx]

            n_pixels = image_data.shape[0]
            side = int(np.sqrt(n_pixels))
            if side * side == n_pixels:
                reshaped_data = band_data.reshape(side, side)
            else:
                reshaped_data = band_data.reshape(-1, 1)
            plt.imshow(np.rot90(reshaped_data), cmap=cmap, aspect=0.1)
            plt.colorbar(label='Intensity')
            plt.title(f'Input Image - Band {slice_idx} (Flattened)')

        elif rank == 3:
            h, w, n_bands = image_data.shape
            if slice_idx is None:
                slice_idx = n_bands // 2

            wants_rgb = rgb_bands is not None and len(rgb_bands) == 3
            if wants_rgb:
                r_idx, g_idx, b_idx = rgb_bands
                if max(rgb_bands) > n_bands - 1:
                    raise ValueError(f"RGB band indices {rgb_bands} exceed available bands (0-{n_bands-1})")

                # Each channel is min-max scaled on its own; constant bands are copied unscaled
                rgb_img = np.zeros((h, w, 3), dtype=np.float32)
                for channel, band_idx in enumerate((r_idx, g_idx, b_idx)):
                    band = image_data[:, :, band_idx].astype(np.float32)
                    lowest, highest = band.min(), band.max()
                    if lowest != highest:
                        band = (band - lowest) / (highest - lowest)
                    rgb_img[:, :, channel] = band

                plt.imshow(np.rot90(rgb_img), aspect=0.1)
                plt.title(f'RGB Composite (R:{r_idx}, G:{g_idx}, B:{b_idx})')
            else:
                plt.imshow(np.rot90(image_data[:, :, slice_idx]), cmap=cmap, aspect=0.1)
                plt.colorbar(label='Intensity')
                plt.title(f'Input Image - Band {slice_idx}')

        else:
            # Fallback: hand the array to imshow unchanged (apart from the rotation)
            plt.imshow(np.rot90(image_data), cmap=cmap, aspect=0.1)
            plt.colorbar(label='Intensity')
            plt.title('Input Image')

        plt.axis('on')
        plt.tight_layout()
        plt.show()

    def plot_ground_truth(self, labels=None, key='', figsize=(15, 5), cmap='tab10'):
        """
        Plot ground truth labels.

        Parameters:
        -----------
        labels : dict or numpy.ndarray, optional
            Either a dict mapping node keys to label arrays, or a single label
            array. Defaults to ``self.all_labels``. Flat (1D) arrays are folded
            into 2D images with rows of HYPSO_HEIGHT entries.
        key : str
            For dict input: '' draws a combined view in which each endmember
            gets its own integer class; any other value draws that node's
            binary map. Ignored for array input.
        figsize : tuple
            Figure size passed to matplotlib.
        cmap : str
            Colormap for the combined view and for array input.

        Problems (missing key, unsupported type, data that cannot be folded
        into an image) are reported with a printed message and no plot.
        """
        labels = self.all_labels if labels is None else labels

        if isinstance(labels, dict):
            if not labels:
                print("Cannot plot labels: empty label dictionary")
                return

            # Image dimensions come from the root map when it exists, otherwise
            # from the first entry. The type check happens first so array input
            # is never indexed with a dict key.
            reference = labels[''] if '' in labels else next(iter(labels.values()))
            try:
                height, width = np.asarray(reference).reshape(-1, HYPSO_HEIGHT).shape
            except ValueError as e:
                print(f"Error reshaping label data: {e}")
                return

            if key == '':
                # Combined view of all endmembers
                first_val = next(iter(labels.values()))
                if len(np.shape(first_val)) != 1:
                    print("Cannot display labels: unexpected format")
                    return

                combined_labels = np.zeros((height, width), dtype=int)
                endmember_keys = [k for k in labels if k in self.endmembers]
                for i, k in enumerate(endmember_keys):
                    if k == '':
                        continue
                    try:
                        label_reshaped = np.asarray(labels[k]).reshape(height, width)
                        combined_labels[label_reshaped == 1] = i + 1
                    except (ValueError, IndexError) as e:
                        # Best effort: skip entries that do not fit the image
                        print(f"Error reshaping key '{k}': {e}")
                        continue

                plt.figure(figsize=figsize)
                # Rotate the image 90 degrees and set aspect ratio
                plt.imshow(np.rot90(combined_labels), cmap=cmap, aspect=0.1)
                plt.title('Ground Truth Labels - Combined View')
            else:
                # Plot just the specified key
                if key not in labels:
                    print(f"Key '{key}' not found in labels")
                    return

                label_data = np.asarray(labels[key])
                if label_data.ndim == 1:
                    try:
                        label_data = label_data.reshape(height, width)
                    except ValueError as e:
                        print(f"Error reshaping label data: {e}")
                        return

                plt.figure(figsize=figsize)
                # Binary map: fixed 0..1 colour range, rotated 90 degrees
                plt.imshow(np.rot90(label_data), cmap='viridis', vmin=0, vmax=1, aspect=0.1)
                plt.title(f'Ground Truth Label for "{key}"')

        elif hasattr(labels, 'shape'):
            # Array-type labels
            labels = np.asarray(labels)
            if labels.ndim == 1:
                try:
                    labels = labels.reshape(-1, HYPSO_HEIGHT)
                except ValueError as e:
                    print(f"Error reshaping label data: {e}")
                    return

            plt.figure(figsize=figsize)
            # Rotate the image 90 degrees and set aspect ratio
            plt.imshow(np.rot90(labels), cmap=cmap, aspect=0.1)
            plt.title('Ground Truth Labels')
        else:
            print("Cannot plot labels: unsupported format")
            return

        plt.axis('on')
        plt.tight_layout()
        plt.show()

    def plot_prediction(self, prediction=None, key='', figsize=(15, 5), cmap='tab10'):
        """
        Plot prediction results.

        Parameters:
        -----------
        prediction : dict or numpy.ndarray, optional
            Either a dict mapping node keys to prediction masks, or a single
            array. When omitted, ``self.last_prediction`` is used if that
            attribute exists. Flat (1D) arrays are folded into 2D images with
            rows of HYPSO_HEIGHT entries.
        key : str
            For dict input: '' draws a combined view in which each endmember
            gets its own integer class; any other value draws that entry's
            binary map. Ignored for array input.
        figsize : tuple
            Figure size passed to matplotlib.
        cmap : str
            Colormap for the combined view and for array input.

        Problems (nothing to plot, missing key, unsupported type, data that
        cannot be folded into an image) are reported with a printed message
        and no plot.
        """
        if prediction is None:
            if not hasattr(self, 'last_prediction'):
                print("No prediction available. Run predict() first.")
                return
            prediction = self.last_prediction

        if isinstance(prediction, dict):
            if not prediction:
                print("Cannot plot prediction: empty prediction dictionary")
                return

            # Image dimensions come from the first entry. The type check happens
            # first so array input is never indexed with one of its own elements.
            first_val = next(iter(prediction.values()))
            try:
                height, width = np.asarray(first_val).reshape(-1, HYPSO_HEIGHT).shape
            except ValueError as e:
                print(f"Error reshaping prediction data: {e}")
                return

            if key == '':
                if len(np.shape(first_val)) != 1:
                    print("Cannot display predictions: unexpected format")
                    return

                combined_pred = np.zeros((height, width), dtype=int)
                endmember_keys = [k for k in prediction if k in self.endmembers]
                for i, k in enumerate(endmember_keys):
                    if k == '':
                        continue
                    try:
                        pred_reshaped = np.asarray(prediction[k]).reshape(height, width)
                        # True compares equal to 1, so this covers bool and 0/1 masks
                        combined_pred[pred_reshaped == 1] = i + 1
                    except (ValueError, IndexError) as e:
                        # Best effort: skip entries that do not fit the image
                        print(f"Error reshaping key '{k}': {e}")
                        continue

                plt.figure(figsize=figsize)
                # Rotate the image 90 degrees and set aspect ratio
                plt.imshow(np.rot90(combined_pred), cmap=cmap, aspect=0.1)
                plt.title('Predictions - Combined View')
            else:
                if key not in prediction:
                    print(f"Key '{key}' not found in predictions")
                    return

                pred_data = np.asarray(prediction[key])
                if pred_data.ndim == 1:
                    try:
                        pred_data = pred_data.reshape(height, width)
                    except ValueError as e:
                        print(f"Error reshaping prediction data: {e}")
                        return

                plt.figure(figsize=figsize)
                # Binary map: fixed 0..1 colour range, rotated 90 degrees
                plt.imshow(np.rot90(pred_data), cmap='viridis', vmin=0, vmax=1, aspect=0.1)
                plt.title(f'Prediction for key "{key}"')

        elif hasattr(prediction, 'shape'):
            # Array-type predictions
            prediction = np.asarray(prediction)
            if prediction.ndim == 1:
                try:
                    prediction = prediction.reshape(-1, HYPSO_HEIGHT)
                except ValueError as e:
                    print(f"Error reshaping prediction data: {e}")
                    return

            plt.figure(figsize=figsize)
            # Rotate the image 90 degrees and set aspect ratio
            plt.imshow(np.rot90(prediction), cmap=cmap, aspect=0.1)
            plt.title('Prediction')
        else:
            print("Cannot plot prediction: unsupported format")
            return

        plt.axis('on')
        plt.tight_layout()
        plt.show()

    def save_model_parameters(self, filename='decision_tree_model.pkl'):
        """
        Save the trained SVM parameters of the decision tree to a file.
        
        Parameters:
        -----------
        filename : str
            The name of the file to save the model parameters to.
        """
        import pickle
        import os
        
        if self.verbose:
            print("Preparing to save model parameters...")
            # for node, model in self.models.items():
            #     if hasattr(model, 'coef_'):
            #         print(f"Node {node} SVM weights shape: {model.coef_.shape}")
            #         print(f"Node {node} SVM weights: {model.coef_}")
            #     if hasattr(model, 'intercept_'):
            #         print(f"Node {node} SVM intercept: {model.intercept_}")
        
        # Create a dictionary to store all model data
        model_data = {
            'models': self.models,  # This is your dictionary of SVM models
            'endmembers': self.endmembers,
            'splitting_nodes': self.splitting_nodes,
            'all_labels': self.all_labels  # Save the labels too
        }
        
        # Ensure weights folder exists
        weights_folder = 'weights'
        os.makedirs(weights_folder, exist_ok=True)
        
        # Create full path to save file in weights folder
        filepath = os.path.join(weights_folder, filename)
        
        try:
            if self.verbose:
                print(f"Saving model to {filepath}...")
            with open(filepath, 'wb') as f:
                pickle.dump(model_data, f)
            print(f"Model parameters successfully saved to {filepath}")
        except Exception as e:
            print(f"Error saving model parameters: {e}")

    def load_model_parameters(self, filename='decision_tree_model.pkl'):
        """
        Load the trained SVM parameters for the decision tree from a file.
        
        Parameters:
        -----------
        filename : str
            The name of the file to load the model parameters from.
        
        Returns:
        --------
        bool
            True if loading was successful, False otherwise.
        """
        import pickle
        import os
        
        if self.verbose:
            print(f"Attempting to load model parameters from {filename}...")
        
        # Create full path to load file from weights folder
        filepath = os.path.join('weights', filename)
        
        if not os.path.exists(filepath):
            print(f"Error: Model file {filepath} not found")
            return False
        
        try:
            if self.verbose:
                print("Reading model file...")
            with open(filepath, 'rb') as f:
                model_data = pickle.load(f)
            
            if self.verbose:
                print("Restoring model components...")
            
            # Restore the model components
            self.models = model_data['models']
            self.endmembers = model_data['endmembers']
            self.splitting_nodes = model_data['splitting_nodes']
            
            if self.verbose:
                print("Model components restored. SVM model details:")
                for node, model in self.models.items():
                    if hasattr(model, 'coef_'):
                        print(f"Node {node} SVM weights shape: {model.coef_.shape}")
                        print(f"Node {node} SVM weights: {model.coef_}")
                    if hasattr(model, 'intercept_'):
                        print(f"Node {node} SVM intercept: {model.intercept_}")
            
            # Optionally restore labels if they were saved
            if 'all_labels' in model_data:
                self.all_labels = model_data['all_labels']
                if self.verbose:
                    print("Labels restored from saved model")
            
            print(f"Model parameters successfully loaded from {filepath}")
            return True
        except Exception as e:
            print(f"Error loading model parameters: {e}")
            return False

In [None]:
def predict_pipeline(image_path, tree_path, deh_path, hypso=1, verbose=False):
    """
    Load an image, derive labels with the DEH model, run a saved tree and plot the result.

    Parameters:
    -----------
    image_path : str
        Path to a '.npy' array or a '.nc' capture. A 3D '.npy' cube is
        flattened to (pixels, bands) with bands taken from the last axis,
        matching what ``nc_to_image`` returns for '.nc' files.
    tree_path : str
        File name handed to ``BinaryDecisionTree.load_model_parameters``
        (which joins it onto the 'weights' folder).
    deh_path : str
        Path handed to ``generate_huH_labels`` for loading the DEH model.
    hypso : int
        Satellite selector forwarded to ``nc_to_image`` (2 selects Hypso2).
    verbose : bool
        Forwarded to the loaders and to the tree.

    Returns:
    --------
    dict or None
        Endmember predictions from the tree, or None when the image format
        is unsupported, the '.nc' file cannot be processed, or the tree model
        cannot be loaded.
    """
    if image_path.endswith('.npy'):
        image_data = np.load(image_path)
        if image_data.ndim == 3:
            # Flatten the cube so both input formats reach the rest of the
            # pipeline in the same (pixels, bands) layout
            image_data = image_data.reshape(-1, image_data.shape[-1])
    elif image_path.endswith('.nc'):
        try:
            image_data = nc_to_image(image_path, hypso, verbose, machi=True)
        except ImportError:
            print("Error: hypso package is required to process .nc files")
            return None
        except Exception as e:
            print(f"Error processing .nc file: {e}")
            return None
    else:
        print(f"Unsupported file format: {image_path}")
        print("Supported formats: .npy, .nc")
        return None

    # Print information about the image data values
    print(f"Image data shape: {image_data.shape}")
    print(f"Image data type: {image_data.dtype}")
    print(f"Min value: {np.min(image_data)}")
    print(f"Max value: {np.max(image_data)}")
    print(f"Mean value: {np.mean(image_data)}")

    # Create labels from the DEH model
    image_gt = generate_huH_labels(image_data, deh_path, verbose=verbose)

    # Build the tree from the label keys, then replace its state with the saved model
    tree = BinaryDecisionTree(image_gt, verbose=verbose)
    if not tree.load_model_parameters(tree_path):
        print("Failed to load tree model")
        return None

    # Make predictions using the tree; evaluate() runs predict() again internally
    predictions = tree.predict(image_data)
    evaluation = tree.evaluate(image_data, image_gt)
    print(evaluation)

    # Fold the pixels back into rows of HYPSO_HEIGHT for display
    n_bands = image_data.shape[-1]
    tree.plot_input_image(image_data.reshape(-1, HYPSO_HEIGHT, n_bands))
    tree.plot_ground_truth(image_gt)
    tree.plot_prediction(predictions)

    return predictions

def nc_to_image(image_path, hypso=1, verbose=False, machi=False):
    """
    Read a '.nc' capture and return its L1b cube as a flat (pixels, 112) array.

    Parameters:
    -----------
    image_path : str
        Path passed to the Hypso reader.
    hypso : int
        2 selects ``Hypso2``; any other value selects ``Hypso1``.
    verbose : bool
        Forwarded to the Hypso reader.
    machi : bool
        Accepted but not used by this function.

    Returns:
    --------
    numpy.ndarray
        Band indices 6..117 of the L1b cube, reshaped to (-1, 112).
    """
    reader = Hypso2 if hypso == 2 else Hypso1
    satobj = reader(path=image_path, verbose=verbose)
    satobj.generate_l1b_cube()

    # Keep 112 bands by dropping the first 6 band indices and everything from index 118 on
    cube = np.array(satobj.l1b_cube)
    return cube[:, :, 6:118].reshape(-1, 112)

def generate_huH_labels(image_file, deh_path, verbose=False):
    """
    Produce per-node label maps for an image with a stored DEH model.

    Parameters:
    -----------
    image_file : numpy.ndarray
        Image data; it is passed unchanged to ``simple_predict`` and must be
        reshapeable to (-1, 1092, 112)
    deh_path : str
        Path given to ``DEH.load``
    verbose : bool
        Accepted for interface compatibility; not used by this function

    Returns:
    --------
    dict
        Maps every key of ``DEH_model.nodes`` to that node's flattened map
    """
    DEH_model = deh.DEH(no_negative_residuals=True)
    DEH_model.load(deh_path)

    image_file_cube = image_file.reshape(-1, 1092, 112)
    # plot_size has to carry the two leading cube dimensions as integers; the
    # previous code stored the array slices cube[0] and cube[1] instead.
    # NOTE(review): confirm the axis order deh.DEH expects for plot_size.
    DEH_model.plot_size = (image_file_cube.shape[0], image_file_cube.shape[1])

    DEH_model.simple_predict(image_file)
    DEH_model.binarize_lmdas()
    DEH_model.lmda_2_map()

    return {node: DEH_model.nodes[node].map.flatten() for node in DEH_model.nodes}

## LOAD DATA

### LOAD L1D MACHI

In [None]:
# Raw strings keep the Windows backslashes literal: sequences such as "\H",
# "\h", "\i", "\c" and "\L" are invalid escapes and warn on recent Python versions.
combined_10_L1A_120 = np.load(r'D:\Hierarchical Unmixing Label\hUH\images\combined_10_L1A_120.npy')
combined_10_L1D_112_MACHI = np.load(r'D:\Hierarchical Unmixing Label\hUH\images\combined_10_L1D_112_MACHI.npy')
# allow_pickle=True unpickles the stored dict; only load label files from a trusted source.
combined_10_L1D_112_MACHI_labels = np.load(r'save\L1D_112_MACHI_10img_8end_stabelized_aa_labels.npy', allow_pickle=True).item()

In [None]:
# Build the tree layout from the label dictionary, then load the parameters
# stored in the L1A pickle.
# NOTE(review): the pickle path is relative, so it presumably resolves against
# the notebook's working directory — confirm.
tree=BinaryDecisionTree(combined_10_L1D_112_MACHI_labels,verbose=False)
tree.load_model_parameters(r"combined_10_L1A_120_8end_TREE.pkl")

In [None]:
# Second tree built from the same label dictionary, loaded from the
# L1D/MACHI pickle.
# NOTE(review): relative pickle path — presumably resolved against the
# working directory; confirm.
tree_corrected=BinaryDecisionTree(combined_10_L1D_112_MACHI_labels,verbose=False)
tree_corrected.load_model_parameters(r"combined_10_L1D_112_MACHI_8end_TREE.pkl")

In [None]:
# Tampa scene and its pickled label dictionary (.item() unwraps the 0-d object array).
# NOTE(review): the labels are read from an "L1D_112" file although the
# variable is named "L1A_120" — confirm this pairing is intended.
tampa_2024_11_12_L1A_120=np.load(r'images\tampa_2024-11-12T15-31-55Z-l1a_cm_L1A_120.npy')
tampa_2024_11_12_L1A_120_labels=np.load(r'save\tampa_2024_11_12_L1D_112_labels.npy',allow_pickle=True).item()

### LOAD L1B MACHI DATA

In [None]:
# Raw strings keep the Windows backslashes literal ("\H", "\h", "\i", "\c" are invalid escapes).
# allow_pickle=True unpickles the stored dicts; only load label files from a trusted source.
combined_10_v2_MACHI = np.load(r'D:\Hierarchical Unmixing Label\hUH\images\combined_10_v2_MACHI.npy')
combined_MACHI_labels = np.load(r'D:\Hierarchical Unmixing Label\hUH\save\combined_MACHI_binary_labels.npy', allow_pickle=True).item()

machi_tampa_2024_11_12          = np.load(r'D:\Hierarchical Unmixing Label\hUH\images\tampa_2024-11-12T15-31-55Z-l1a_cm_machi.npy')
machi_tampa_2024_11_12_labels   = np.load(r'D:\Hierarchical Unmixing Label\hUH\save\machi_tampa_2024_11_12_binary_labels.npy', allow_pickle=True).item()

In [None]:
# Rebinds the global `tree`: layout from the MACHI label dictionary,
# parameters loaded from the MACHI pickle.
tree=BinaryDecisionTree(combined_MACHI_labels,verbose=False)
tree.load_model_parameters(r"D:\Hierarchical Unmixing Label\hUH\weights\tree_MACHI.pkl")

### LOAD L1B DATA

In [None]:
# Single-scene arrays (aregantsea2, 2025-03-11) with their pickled label dictionaries.
# NOTE(review): the image paths are absolute while the label paths are
# relative to the working directory — confirm both resolve as intended.
L1A_data = np.load(r'D:\Hierarchical Unmixing Label\hUH\images\aregantsea2_2025-03-11T08-12-43Z-l1a_cm_L1A.npy')
L1A_labels = np.load(r'save\L1A_1img_8end_stabelized_ppa_FINAL_labels.npy', allow_pickle=True).item()
L1B_data = np.load(r'D:\Hierarchical Unmixing Label\hUH\images\aregantsea2_2025-03-11T08-12-43Z-l1a_cm_L1B.npy')
L1B_labels = np.load(r'save\L1B_1img_8end_stabelized_ppa_labels.npy', allow_pickle=True).item()
L1A_data_120=np.load(r"D:\Hierarchical Unmixing Label\hUH\images\aregantsea2_2025-03-11T08-12-43Z-l1a_cm_L1A_120.npy")

## TRAINING AND EVALUATION

In [None]:
# Bind the generic names used by the training, prediction and plotting cells.
# Both data sets share one label dictionary (the L1D/MACHI labels).
image_data=combined_10_L1A_120
image_data_corrected=combined_10_L1D_112_MACHI
image_gt=combined_10_L1D_112_MACHI_labels
# Output pickles: one per data set.
save_path=r"D:\Hierarchical Unmixing Label\hUH\weights\combined_10_L1A_120_8end_TREE.pkl"
save_path_corrected=r"D:\Hierarchical Unmixing Label\hUH\weights\combined_10_L1D_112_MACHI_8end_TREE.pkl"

In [None]:
# Fresh tree for the L1A data; verbose=True prints the identified endmembers and splitting nodes.
tree=BinaryDecisionTree(image_gt,verbose=True)

In [None]:
# Fresh tree for the corrected (L1D/MACHI) data, built from the same label dictionary.
tree_corrected=BinaryDecisionTree(image_gt,verbose=True)

In [None]:
# Train on the L1A data, then evaluate on that same data and labels.
tree.train(image_data,image_gt)
evaluation = tree.evaluate(image_data,image_gt)

In [None]:
# Train on the corrected data, then evaluate on that same data and labels.
tree_corrected.train(image_data_corrected,image_gt)
evaluation_corrected = tree_corrected.evaluate(image_data_corrected,image_gt)

In [None]:
# Write the L1A tree's parameters to save_path.
tree.save_model_parameters(save_path)

In [None]:
# Write the corrected tree's parameters to save_path_corrected.
tree_corrected.save_model_parameters(save_path_corrected)

## LOAD MODELS

In [None]:
# Re-declare the pickle locations so this cell does not rely on the training
# cell for them; image_gt must already be defined.
save_path=r"D:\Hierarchical Unmixing Label\hUH\weights\combined_10_L1A_120_8end_TREE.pkl"
save_path_corrected=r"D:\Hierarchical Unmixing Label\hUH\weights\combined_10_L1D_112_MACHI_8end_TREE.pkl"
# Rebuild each tree layout from the labels, then load its stored parameters.
tree=BinaryDecisionTree(image_gt,verbose=True)
tree.load_model_parameters(save_path)
tree_corrected=BinaryDecisionTree(image_gt,verbose=True)
tree_corrected.load_model_parameters(save_path_corrected)

## PREDICT

In [None]:
# Predict with the L1A tree on the L1A data.
prediction = tree.predict(image_data)

In [None]:
# Predict with the corrected tree on the corrected data.
prediction_corrected = tree_corrected.predict(image_data_corrected)

In [None]:
# Apply the L1A tree to the separately loaded Tampa scene.
predictions_tampa_2024_11_12=tree.predict(tampa_2024_11_12_L1A_120)

## PLOTS

In [None]:
# Tampa scene: input image (reshaped to a cube with 120 bands), label maps, predictions.
tree.plot_input_image(tampa_2024_11_12_L1A_120.reshape(-1,HYPSO_HEIGHT,120))
tree.plot_ground_truth(tampa_2024_11_12_L1A_120_labels, cmap='viridis')
tree.plot_prediction(predictions_tampa_2024_11_12, cmap='viridis')

In [None]:
# Training data: input image (reshaped to a cube with 120 bands), label maps, predictions.
tree.plot_input_image(image_data.reshape(-1,HYPSO_HEIGHT,120))
tree.plot_ground_truth(image_gt, cmap='viridis')
tree.plot_prediction(prediction, cmap='viridis')

In [None]:
# Corrected data: input image (reshaped to a cube with 112 bands), label maps, predictions.
tree_corrected.plot_input_image(image_data_corrected.reshape(-1,HYPSO_HEIGHT,112))
tree_corrected.plot_ground_truth(image_gt, cmap='viridis')
tree_corrected.plot_prediction(prediction_corrected, cmap='viridis')

## TEST

### TEST BINARY

In [None]:
def save_tree_svms_to_binary_correct(tree, output_dir):
    """
    Export every SVM stored in the decision tree to its own binary file.

    Parameters:
    -----------
    tree : BinaryDecisionTree
        Tree whose ``splitting_nodes``, ``endmembers`` and ``models`` are read
    output_dir : str
        Target directory; created when it does not exist

    File layout (one file per splitting node that has a model):
    -----------------------------------------------------------
    - lowest endmember value of the left branch  (uint8)
    - lowest endmember value of the right branch (uint8)
    - intercept (float32)
    - weights   (array of float32)

    Files are named ``lsm{left_min:02d}{right_min:02d}``.

    Returns:
    --------
    bool
        Always True
    """
    import os
    import struct

    os.makedirs(output_dir, exist_ok=True)

    nodes = tree.splitting_nodes

    print(f"Found {len(nodes)} splitting nodes in the tree")
    print(f"Found {len(tree.endmembers)} endmembers in the tree")
    print(f"Saving SVM models to {output_dir}")

    # Endmember path (binary string) -> decimal value; empty labels are ignored
    decimal_of = {label: int(label, 2) for label in tree.endmembers if label}

    n_saved = 0

    for node in nodes:
        if node not in tree.models:
            print(f"Warning: Node '{node}' is marked as a splitting node but has no SVM model")
            continue

        classifier = tree.models[node]
        depth = len(node)

        # Sort the endmembers below this node by the bit that follows the node's own path
        left_values, right_values = [], []
        for path, value in decimal_of.items():
            if len(path) <= depth or not path.startswith(node):
                continue
            bucket = left_values if path[depth] == '0' else right_values
            bucket.append(value)

        # Each branch is represented by its smallest endmember value (0 when empty)
        left_min = min(left_values, default=0)
        right_min = min(right_values, default=0)

        svm_model_name = f"lsm{left_min:02d}{right_min:02d}"
        target = os.path.join(output_dir, svm_model_name)

        print(f"Node {node}: Splitting between left branch (min={left_min}) and right branch (min={right_min})")
        print(f"  Left contains values: {sorted(left_values)}")
        print(f"  Right contains values: {sorted(right_values)}")

        with open(target, "wb") as out:
            # Two uint8 class labels
            out.write(struct.pack("BB", left_min, right_min))

            # One float32 intercept
            bias = classifier.intercept_[0]
            out.write(struct.pack('f', bias))

            # float32 weight vector
            coefs = classifier.coef_[0]
            out.write(struct.pack(f'{len(coefs)}f', *coefs))

        print(f"Saved model for node '{node}' to {svm_model_name}")
        n_saved += 1

    print(f"Successfully saved {n_saved} SVM models to {output_dir}")
    return True

def print_tree_structure(tree):
    """
    Print an overview of the decision tree: nodes per depth, models and endmembers.

    Parameters:
    -----------
    tree : BinaryDecisionTree
        Tree whose ``splitting_nodes`` and ``models`` (and ``endmembers``,
        when present) are read

    Returns:
    --------
    int
        Length of the longest node path (0 when there are no nodes)
    """
    split_nodes = sorted(tree.splitting_nodes)

    # Every splitting node contributes itself and both of its children
    node_pool = set(split_nodes)
    for parent in split_nodes:
        node_pool.update((parent + '0', parent + '1'))
    every_node = sorted(node_pool)

    # Whatever is not a splitting node is treated as a leaf
    leaves = [name for name in every_node if name not in split_nodes]

    print("Tree Structure:")
    print(f"Total nodes: {len(every_node)}")
    print(f"Splitting nodes: {len(split_nodes)}")
    print(f"Leaf nodes: {len(leaves)}")

    # A node's depth is the length of its path string
    deepest = max((len(name) for name in every_node), default=0)
    print(f"Maximum tree depth: {deepest}")

    print("\nNode Hierarchy:")
    for level in range(deepest + 1):
        tagged = [
            (name, '(split)' if name in split_nodes else '(leaf)')
            for name in every_node
            if len(name) == level
        ]
        print(f"Depth {level}: {tagged}")

    print("\nSVM Models:")
    for node in split_nodes:
        if node not in tree.models:
            print(f"Node '{node}' is marked as splitting but has no model")
            continue
        svm = tree.models[node]
        zero_child, one_child = node + '0', node + '1'
        print(f"Node '{node}' splits between '{zero_child}' and '{one_child}'")
        print(f"  - Weights shape: {svm.coef_.shape}")
        print(f"  - Intercept: {svm.intercept_}")

    # The endmember list is optional
    if hasattr(tree, 'endmembers') and tree.endmembers:
        print("\nEndmembers:")
        for position, label in enumerate(tree.endmembers):
            print(f"  {position}: {label}")

    return deepest

def debug_svm_models(tree):
    """
    Inspect each SVM stored in the tree and report basic weight statistics.

    Parameters:
    -----------
    tree : BinaryDecisionTree
        Tree whose ``splitting_nodes`` and ``models`` are read

    Returns:
    --------
    dict
        One entry per splitting node that has a model, holding the model
        type, attribute list and, when available, weight statistics and
        the intercept
    """
    import numpy as np

    report = {}
    nodes = sorted(tree.splitting_nodes)

    print(f"Analyzing {len(nodes)} SVM models...")

    for node in nodes:
        if node not in tree.models:
            print(f"Warning: Node '{node}' is marked as a splitting node but has no SVM model")
            continue

        svm = tree.models[node]
        coef_present = hasattr(svm, 'coef_')
        intercept_present = hasattr(svm, 'intercept_')

        info = {
            'model_type': str(type(svm)),
            'has_coef': coef_present,
            'has_intercept': intercept_present,
            'attrs': dir(svm),
        }

        if coef_present:
            weights = svm.coef_[0]
            info['weights_shape'] = weights.shape
            info['weights_min'] = float(np.min(weights))
            info['weights_max'] = float(np.max(weights))
            info['weights_mean'] = float(np.mean(weights))
            # Anything with magnitude at or below 1e-10 counts as zero
            info['weights_nonzero'] = int(np.sum(np.abs(weights) > 1e-10))
            info['weights_sample'] = [float(value) for value in weights[:5]]

        if intercept_present:
            info['intercept'] = float(svm.intercept_[0])

        report[node] = info

        # Per-node summary
        print(f"\nNode '{node}':")
        if not coef_present:
            print("  WARNING: Model does not have coef_ attribute!")
        else:
            print(f"  Weights: shape={weights.shape}, non-zero={info['weights_nonzero']}/{len(weights)}")
            print(f"  Weight stats: min={info['weights_min']:.6f}, max={info['weights_max']:.6f}")
            if info['weights_nonzero'] == 0:
                print("  WARNING: All weights are zero!")

        if not intercept_present:
            print("  WARNING: Model does not have intercept_ attribute!")
        else:
            print(f"  Intercept: {info['intercept']:.6f}")

    return report

def save_tree_svms_to_binary_fixed(tree, output_dir):
    """
    Save all SVM models in the decision tree as binary files, fixing the weight issue.

    When a model's ``coef_`` is entirely zero and the model exposes
    ``_impl.raw_coef_``, those raw coefficients are written instead. The
    previous implementation looked them up but then re-read ``coef_`` when
    writing, so the fallback never reached the file.

    Parameters:
    -----------
    tree : BinaryDecisionTree
        The decision tree containing trained SVM models
    output_dir : str
        Directory where the binary files will be saved (created if missing)

    File layout: two uint8 class labels (lowest endmember value of the left
    and right branch), one float32 intercept, then the float32 weights. A
    model without ``coef_`` gets a file holding only the labels and intercept.

    Returns:
    --------
    bool
        Always True
    """
    import os
    import struct
    import numpy as np

    os.makedirs(output_dir, exist_ok=True)

    splitting_nodes = tree.splitting_nodes
    print(f"Found {len(splitting_nodes)} splitting nodes in the tree")

    # Endmember path (binary string) -> decimal value; empty labels are skipped
    path_to_value = {label: int(label, 2) for label in tree.endmembers if label}

    saved_count = 0

    for node in splitting_nodes:
        if node not in tree.models:
            print(f"Warning: Node '{node}' is marked as a splitting node but has no SVM model")
            continue

        svm_model = tree.models[node]

        # Split the endmembers below this node by the bit following the node's path
        left_values, right_values = [], []
        for path, value in path_to_value.items():
            if len(path) > len(node) and path.startswith(node):
                branch = left_values if path[len(node)] == '0' else right_values
                branch.append(value)

        # Each branch is identified by its lowest endmember value (0 when empty)
        left_min = min(left_values, default=0)
        right_min = min(right_values, default=0)

        svm_model_name = f"lsm{left_min:02d}{right_min:02d}"
        filepath = os.path.join(output_dir, svm_model_name)

        print(f"Node {node}: Splitting between left branch (min={left_min}) and right branch (min={right_min})")

        # Decide once which weights get written; None means "no coef_ available"
        weights = None
        non_zero_count = 0
        if hasattr(svm_model, 'coef_'):
            weights = svm_model.coef_[0]
            non_zero_count = np.sum(np.abs(weights) > 1e-10)
            if non_zero_count == 0:
                print(f"  WARNING: All weights are zero for node {node}!")

                impl = getattr(svm_model, '_impl', None)
                if hasattr(impl, 'raw_coef_'):
                    print("  Attempting to use _impl.raw_coef_ instead...")
                    # NOTE(review): the layout of raw_coef_ is not verified here;
                    # it is flattened and written as-is — confirm against the model type.
                    weights = np.ravel(impl.raw_coef_)
                    non_zero_count = np.sum(np.abs(weights) > 1e-10)
                else:
                    print("  Could not find alternative weight source")

        with open(filepath, "wb") as file:
            # Class labels as uint8
            file.write(struct.pack("BB", left_min, right_min))

            # Intercept as float32
            intercept = svm_model.intercept_[0]
            file.write(struct.pack('f', intercept))

            # Weights as float32
            if weights is not None:
                file.write(struct.pack(f'{len(weights)}f', *weights))
                print(f"  Saved {len(weights)} weights (non-zero: {non_zero_count})")
            else:
                print("  WARNING: Model does not have coef_ attribute!")

        print(f"  Saved model to {svm_model_name}")
        saved_count += 1

    print(f"Successfully saved {saved_count} SVM models to {output_dir}")
    return True

def analyze_svm_file(file_path):
    """
    Analyze the structure of an SVM binary file without assuming its format.

    The file is read as: two uint8 header bytes, then as many float32 values
    as fit (the first is reported as the intercept, the rest as weights).
    Bytes left over after the last whole float are ignored and reported as
    ``structure['trailing_bytes']`` instead of turning the analysis into an
    error.

    Parameters:
    -----------
    file_path : str
        Path to the binary file

    Returns:
    --------
    dict
        Always holds 'file_name', 'file_size', 'header', 'structure' and
        'error'; 'intercept', 'weights' and 'weight_count' are added when the
        file is long enough. Any exception is captured as text in 'error'
        rather than raised.
    """
    import os
    import struct

    result = {
        'file_name': os.path.basename(file_path),
        'file_size': 0,
        'header': None,
        'structure': None,
        'error': None
    }

    try:
        result['file_size'] = os.path.getsize(file_path)

        with open(file_path, 'rb') as f:
            data = f.read()

        # First two bytes: the class-label header
        if len(data) >= 2:
            result['header'] = struct.unpack_from('BB', data, 0)

        # Everything after the header is interpreted as float32 values
        if len(data) > 2:
            remaining_bytes = len(data) - 2
            possible_float_count = remaining_bytes // 4
            result['structure'] = {
                'header_size': 2,
                'remaining_bytes': remaining_bytes,
                'possible_float_count': possible_float_count,
                'trailing_bytes': remaining_bytes % 4
            }

            if possible_float_count >= 1:
                result['intercept'] = struct.unpack_from('f', data, 2)[0]

            if possible_float_count > 1:
                # unpack_from reads exactly weight_count floats, so leftover
                # bytes at the end of the file no longer raise struct.error
                weight_count = possible_float_count - 1
                result['weights'] = struct.unpack_from(f'{weight_count}f', data, 6)
                result['weight_count'] = weight_count

    except Exception as e:
        # Best-effort analysis: report the failure to the caller instead of raising
        result['error'] = str(e)

    return result

def compare_svm_files_flexible(our_dir, original_dir):
    """
    Compare SVM binary files between two directories without assuming their format.

    Files are matched by name. Each common file is analyzed in both
    directories with ``analyze_svm_file`` and the sizes, headers, weight
    counts and intercepts are compared. A summary is printed.

    Parameters:
    -----------
    our_dir : str
        Directory containing our generated SVM files
    original_dir : str
        Directory containing the original SVM files

    Returns:
    --------
    dict
        'file_analysis', 'size_comparison' and 'structure_comparison', each
        keyed by file name (common files only)
    """
    import os

    our_files = set(os.listdir(our_dir))
    original_files = set(os.listdir(original_dir))

    common_files = our_files & original_files
    our_unique = our_files - original_files
    # Previously computed as original_files - original_files, which is always empty
    original_unique = original_files - our_files

    print(f"Found {len(common_files)} common files")
    print(f"{len(our_unique)} files only in our directory")
    print(f"{len(original_unique)} files only in original directory")

    results = {
        'file_analysis': {},
        'size_comparison': {},
        'structure_comparison': {}
    }

    for filename in sorted(common_files):
        our_analysis = analyze_svm_file(os.path.join(our_dir, filename))
        original_analysis = analyze_svm_file(os.path.join(original_dir, filename))

        results['file_analysis'][filename] = {
            'our': our_analysis,
            'original': original_analysis
        }

        results['size_comparison'][filename] = {
            'our_size': our_analysis['file_size'],
            'original_size': original_analysis['file_size'],
            'size_match': our_analysis['file_size'] == original_analysis['file_size']
        }

        # Structure is only compared when both analyses finished without error
        if our_analysis.get('error') or original_analysis.get('error'):
            continue

        structure = {
            'header_match': our_analysis['header'] == original_analysis['header'],
            'our_weight_count': our_analysis.get('weight_count'),
            'original_weight_count': original_analysis.get('weight_count')
        }
        results['structure_comparison'][filename] = structure

        # Intercepts are present only for files holding at least one float
        if 'intercept' in our_analysis and 'intercept' in original_analysis:
            intercept_diff = abs(our_analysis['intercept'] - original_analysis['intercept'])
            structure['intercept_diff'] = intercept_diff
            structure['intercept_match'] = intercept_diff < 1e-5

    print("\nFile Size Comparison:")
    for filename, comparison in results['size_comparison'].items():
        print(f"{filename}: {'✓' if comparison['size_match'] else '✗'} " +
              f"(Our: {comparison['our_size']} bytes, Original: {comparison['original_size']} bytes)")

    print("\nFile Structure Details:")
    for filename, analysis in results['file_analysis'].items():
        our = analysis['our']
        orig = analysis['original']

        print(f"\n{filename}:")
        print(f"  Our file: {our['file_size']} bytes")
        if our['header']:
            print(f"    Header: {our['header']}")
        if 'weight_count' in our:
            print(f"    Weight count: {our['weight_count']}")

        print(f"  Original file: {orig['file_size']} bytes")
        if orig['header']:
            print(f"    Header: {orig['header']}")
        if 'weight_count' in orig:
            print(f"    Weight count: {orig['weight_count']}")

        if our.get('error') or orig.get('error'):
            print(f"  Errors: Our: {our.get('error')}, Original: {orig.get('error')}")

    return results

def _summarize_lsm_weights(data, key_prefix, owner, result):
    """Decode the float32 weights that follow the 6-byte preamble and record their statistics in result."""
    import struct
    import numpy as np

    weights = []
    try:
        weight_count = (len(data) - 6) // 4
        print(f"{owner.capitalize()} file should have {weight_count} weights")

        if weight_count > 0:
            # unpack_from reads exactly weight_count floats, so trailing bytes
            # no longer prevent the statistics below from being recorded
            weights = np.array(struct.unpack_from(f'{weight_count}f', data, 6))
            print(f"Successfully unpacked {len(weights)} weights from {owner} file")
            result[f'{key_prefix}_weight_count'] = weight_count
            result[f'{key_prefix}_weights_min'] = float(np.min(weights))
            result[f'{key_prefix}_weights_max'] = float(np.max(weights))
            result[f'{key_prefix}_weights_mean'] = float(np.mean(weights))
            result[f'{key_prefix}_weights_nonzero'] = int(np.sum(np.abs(weights) > 1e-10))
            print(f"{owner.capitalize()} weights: min={result[f'{key_prefix}_weights_min']}, "
                  f"max={result[f'{key_prefix}_weights_max']}")
            print(f"Non-zero weights: {result[f'{key_prefix}_weights_nonzero']}/{len(weights)}")
    except Exception as e:
        result[f'{key_prefix}_weights_error'] = str(e)
        print(f"Error analyzing {owner} weights: {str(e)}")

    return weights


def _plot_lsm_weight_comparison(our_weights, orig_weights):
    """Show the raw, difference and normalized plots for two non-empty weight vectors."""
    import numpy as np
    # Imported here so the numeric comparison works without matplotlib
    import matplotlib.pyplot as plt

    plt.figure(figsize=(15, 12))

    # Both weight vectors on one axis
    plt.subplot(4, 1, 1)
    plt.plot(our_weights, 'b-', alpha=0.7, label='Our weights')
    plt.plot(orig_weights, 'r-', alpha=0.7, label='Original weights')
    plt.title('SVM Weights Comparison')
    plt.legend()
    plt.grid(True)

    # Our weights on their own, to see their pattern
    plt.subplot(4, 1, 2)
    plt.plot(our_weights, 'b-')
    plt.title('Our Weights Only')
    plt.grid(True)
    plt.ylabel('Weight Value')

    # Differences over the common prefix of both vectors
    min_length = min(len(our_weights), len(orig_weights))
    diff = our_weights[:min_length] - orig_weights[:min_length]

    plt.subplot(4, 1, 3)
    plt.plot(diff, 'g-')
    plt.title('Weight Differences (Our - Original)')
    plt.grid(True)
    plt.ylabel('Difference')

    plt.subplot(4, 1, 4)
    plt.hist(diff, bins=50)
    plt.title('Histogram of Weight Differences')
    plt.xlabel('Difference Value')
    plt.ylabel('Frequency')
    plt.grid(True)

    plt.tight_layout()
    plt.show()

    # Second figure: each vector scaled by its largest magnitude to compare patterns
    plt.figure(figsize=(15, 8))
    our_peak = np.max(np.abs(our_weights))
    orig_peak = np.max(np.abs(orig_weights))
    our_norm = our_weights / our_peak if our_peak > 0 else our_weights
    orig_norm = orig_weights / orig_peak if orig_peak > 0 else orig_weights

    plt.plot(our_norm, 'b-', alpha=0.7, label='Our weights (normalized)')
    plt.plot(orig_norm, 'r-', alpha=0.7, label='Original weights (normalized)')
    plt.title('Normalized SVM Weights Comparison')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


def compare_lsm_file_values(our_file, original_file):
    """
    Compare the actual values inside two LSM binary files.

    Each file is read as two uint8 class labels, one float32 intercept and
    then as many float32 weights as fit; bytes left over after the last whole
    float are ignored. A textual summary is printed, and when both files
    contain weights two matplotlib figures are shown.

    Parameters:
    -----------
    our_file : str
        Path to our generated LSM file
    original_file : str
        Path to the original LSM file

    Returns:
    --------
    dict
        OrderedDict with file names and sizes, class labels, intercepts,
        per-file weight statistics and, when both files hold weights, a
        'weight_comparison' entry. Fields that could not be decoded are
        replaced by a matching '*_error' entry.

    Raises:
    -------
    OSError
        If either file cannot be sized or read
    """
    import os
    import struct
    import numpy as np
    from collections import OrderedDict

    result = OrderedDict()

    # Basic file info
    result['our_file'] = os.path.basename(our_file)
    result['original_file'] = os.path.basename(original_file)
    result['our_size'] = os.path.getsize(our_file)
    result['original_size'] = os.path.getsize(original_file)

    with open(our_file, 'rb') as f:
        our_data = f.read()

    with open(original_file, 'rb') as f:
        orig_data = f.read()

    print(f"Our file size: {len(our_data)} bytes")
    print(f"Original file size: {len(orig_data)} bytes")

    sources = (('our', our_data), ('orig', orig_data))

    # Header: two uint8 class labels
    for prefix, blob in sources:
        try:
            result[f'{prefix}_classes'] = struct.unpack('BB', blob[:2])
        except Exception as e:
            result[f'{prefix}_classes_error'] = str(e)

    if 'our_classes' in result and 'orig_classes' in result:
        result['classes_match'] = result['our_classes'] == result['orig_classes']

    # Intercept: one float32 directly after the header
    for prefix, blob in sources:
        try:
            result[f'{prefix}_intercept'] = struct.unpack('f', blob[2:6])[0]
        except Exception as e:
            result[f'{prefix}_intercept_error'] = str(e)

    if 'our_intercept' in result and 'orig_intercept' in result:
        result['intercept_diff'] = result['our_intercept'] - result['orig_intercept']
        result['intercept_match'] = abs(result['intercept_diff']) < 1e-5

    # Weights: everything after the first 6 bytes
    our_weights = _summarize_lsm_weights(our_data, 'our', 'our', result)
    orig_weights = _summarize_lsm_weights(orig_data, 'orig', 'original', result)

    both_have_weights = len(our_weights) > 0 and len(orig_weights) > 0

    if both_have_weights:
        # Different lengths are compared over the shorter one
        min_count = min(len(our_weights), len(orig_weights))
        weight_diff = our_weights[:min_count] - orig_weights[:min_count]
        differing = int(np.sum(np.abs(weight_diff) > 1e-5))
        result['weight_comparison'] = {
            'common_count': min_count,
            'max_diff': float(np.max(np.abs(weight_diff))),
            'mean_diff': float(np.mean(np.abs(weight_diff))),
            'diff_greater_than_1e-5': differing,
            'percent_similar': float(100 * (1 - differing / min_count))
        }

    # Print the results
    print(f"\nComparison of {result['our_file']} vs {result['original_file']}")
    print(f"File sizes: Our {result['our_size']} bytes, Original {result['original_size']} bytes")

    if 'our_classes' in result and 'orig_classes' in result:
        print(f"Class labels: Our {result['our_classes']}, Original {result['orig_classes']}")
        print(f"  Match: {'Yes' if result['classes_match'] else 'No'}")

    if 'our_intercept' in result and 'orig_intercept' in result:
        print(f"Intercept: Our {result['our_intercept']:.6f}, Original {result['orig_intercept']:.6f}")
        print(f"  Difference: {result['intercept_diff']:.6f}")
        print(f"  Match within tolerance: {'Yes' if result['intercept_match'] else 'No'}")

    print(f"Weight counts: Our {result.get('our_weight_count', 'N/A')}, "
          f"Original {result.get('orig_weight_count', 'N/A')}")

    if 'weight_comparison' in result:
        comp = result['weight_comparison']
        print(f"Weight comparison (first {comp['common_count']} weights):")
        print(f"  Maximum difference: {comp['max_diff']:.6f}")
        print(f"  Mean absolute difference: {comp['mean_diff']:.6f}")
        print(f"  Number of weights differing by >1e-5: {comp['diff_greater_than_1e-5']} "
              f"({100-comp['percent_similar']:.2f}%)")

    # Visualize only when there is something to compare
    if both_have_weights:
        _plot_lsm_weight_comparison(our_weights, orig_weights)

    return result

In [None]:
def _branch_min_values(node, path_to_value):
    """
    Return the smallest endmember value under each child branch of ``node``.

    Parameters:
    -----------
    node : str
        Binary path of the splitting node (``''`` for the root)
    path_to_value : dict
        Maps each endmember's binary path to its decimal value

    Returns:
    --------
    tuple
        ``(left_min, right_min)``: the minimum value among endmembers whose
        path continues with ``'0'`` / anything else after ``node``. An entry
        is ``None`` when that branch contains no endmember.
    """
    depth = len(node)
    left_values = []
    right_values = []

    for path, value in path_to_value.items():
        # Only strict descendants of the node have a bit at position `depth`.
        if len(path) > depth and path.startswith(node):
            if path[depth] == '0':
                left_values.append(value)
            else:
                right_values.append(value)

    left_min = min(left_values) if left_values else None
    right_min = min(right_values) if right_values else None
    return left_min, right_min


def save_tree_svms_to_binary_remapped_fixed(tree, output_dir):
    """
    Save all SVM models in the decision tree as binary files, with class labels remapped
    to sequential integers starting from 0.

    Each splitting node that has a model is written to ``lsmLLRR`` inside
    ``output_dir``, where ``LL``/``RR`` are the zero-padded remapped labels of
    the left/right branch. A branch is labelled by the smallest decimal
    endmember value it contains; an empty branch falls back to label 0.

    File layout (``struct`` native byte order, no padding):
    two ``uint8`` labels (left, right), one ``float32`` intercept, then the
    weights as ``float32`` values.

    Parameters:
    -----------
    tree : BinaryDecisionTree
        The decision tree containing trained SVM models. The attributes read
        here are ``splitting_nodes``, ``endmembers`` and ``models``.
    output_dir : str
        Directory where the binary files will be saved (created if missing)

    Returns:
    --------
    dict
        A dictionary containing the mapping information, with keys
        ``original_to_remapped``, ``remapped_to_original``,
        ``original_classes`` and ``remapped_classes``
    """
    import os
    import struct
    import numpy as np

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Get all splitting nodes with trained models
    splitting_nodes = tree.splitting_nodes
    print(f"Found {len(splitting_nodes)} splitting nodes in the tree")

    # Map each binary path to the corresponding endmember decimal value
    # (empty labels are skipped).
    path_to_value = {label: int(label, 2) for label in tree.endmembers if label}

    # First pass: compute the branch minima once per node and collect every
    # class value that will appear in a file header.
    branch_minima = {}
    all_classes = set()
    for node in splitting_nodes:
        # Skip nodes without models
        if node not in tree.models:
            continue
        left_min, right_min = _branch_min_values(node, path_to_value)
        branch_minima[node] = (left_min, right_min)
        all_classes.update(v for v in (left_min, right_min) if v is not None)

    # Sort classes for consistent mapping
    original_classes = sorted(all_classes)

    # Create mapping from original classes to sequential integers
    class_mapping = {orig: i for i, orig in enumerate(original_classes)}
    reverse_mapping = dict(enumerate(original_classes))

    print(f"Found {len(original_classes)} unique class values in the tree")
    print(f"Original classes: {original_classes}")
    print(f"Remapped to: {list(range(len(original_classes)))}")
    print(f"Mapping: {class_mapping}")

    # Count of successfully saved models
    saved_count = 0

    for node in splitting_nodes:
        # Check if this node has an SVM model
        if node not in tree.models:
            print(f"Warning: Node '{node}' is marked as a splitting node but has no SVM model")
            continue

        # Get the SVM model for this node
        svm_model = tree.models[node]

        # An empty branch keeps the historical fallback of class value 0.
        left_min, right_min = branch_minima[node]
        left_min = 0 if left_min is None else left_min
        right_min = 0 if right_min is None else right_min

        # Remap class values to sequential indices
        left_mapped = class_mapping.get(left_min, 0)
        right_mapped = class_mapping.get(right_min, 0)

        # Format the filename using remapped values
        svm_model_name = f"lsm{left_mapped:02d}{right_mapped:02d}"
        filepath = os.path.join(output_dir, svm_model_name)

        print(f"Node {node}: Splitting between left branch (orig={left_min}, remapped={left_mapped}) and right branch (orig={right_min}, remapped={right_mapped})")

        # Select the weights ONCE, so that the fallback chosen below is the
        # data that actually reaches the file (it used to be overwritten by
        # coef_[0] again right before writing).
        weights = None
        non_zero_count = 0
        if hasattr(svm_model, 'coef_'):
            weights = svm_model.coef_[0]
            non_zero_count = np.sum(np.abs(weights) > 1e-10)
            if non_zero_count == 0:
                print(f"  WARNING: All weights are zero for node {node}!")

                # Try to access the SVM model's raw data directly
                if hasattr(svm_model, '_impl') and hasattr(svm_model._impl, 'raw_coef_'):
                    print("  Attempting to use _impl.raw_coef_ instead...")
                    # Flatten so a 2-D array can be packed value by value.
                    # NOTE(review): the layout of raw_coef_ is not visible
                    # here - confirm it matches coef_[0] for your models.
                    weights = np.ravel(svm_model._impl.raw_coef_)
                    non_zero_count = np.sum(np.abs(weights) > 1e-10)
                else:
                    print("  Could not find alternative weight source")

        # Open the file for binary writing
        with open(filepath, "wb") as file:
            # Write REMAPPED class labels as uint8
            file.write(struct.pack("BB", left_mapped, right_mapped))

            # Write intercept as float32
            intercept = svm_model.intercept_[0]
            file.write(struct.pack('f', intercept))

            # Write weights as float32
            if weights is not None:
                file.write(struct.pack(f'{len(weights)}f', *weights))
                print(f"  Saved {len(weights)} weights (non-zero: {non_zero_count})")
            else:
                print("  WARNING: Model does not have coef_ attribute!")

        print(f"  Saved model to {svm_model_name}")
        saved_count += 1

    print(f"Successfully saved {saved_count} SVM models to {output_dir}")

    # Return the mapping for reference
    return {
        "original_to_remapped": class_mapping,
        "remapped_to_original": reverse_mapping,
        "original_classes": original_classes,
        "remapped_classes": list(range(len(original_classes)))
    }

In [None]:
import json

# Export the tree's SVMs; the returned dict describes how the original class
# values were renumbered.
mapping = save_tree_svms_to_binary_remapped_fixed(
    tree,
    "weights/svm_models_remapped",
)

# Echo the original -> remapped table so it is visible in the cell output.
print("\nClass mapping:")
for orig, remapped in mapping["original_to_remapped"].items():
    print(f"Original class {orig} → Remapped class {remapped}")

# Keep the mapping on disk next to the exported models.
# Note: JSON object keys are always strings, so integer dict keys come back
# as strings when this file is read again.
with open("weights/class_mapping.json", "w") as f:
    json.dump(mapping, f, indent=2)

In [None]:
# Inspect the tree's SVM models before exporting them.
# `debug_svm_models` and `save_tree_svms_to_binary_fixed` are defined
# elsewhere in this notebook; their behavior is not visible in this cell.
# First, debug the SVM models to see if they have proper weights
model_analysis = debug_svm_models(tree)

# Then, use the fixed function to save the files
# (the return value, if any, is discarded here)
save_tree_svms_to_binary_fixed(tree, "weights/svm_models_combined_10_L1D_112_MACHI_8end_TREE")

In [None]:
# First, examine your tree structure
print_tree_structure(tree)

# The following line looks like pasted cell output (a class mapping). It was
# previously a live expression statement that built and discarded a dict, so
# it is kept here as a comment for reference only:
# {np.int32(0): 0, np.int32(1): 1, np.int32(2): 2, np.int32(4): 3, np.int32(6): 4, np.int32(7): 5, np.int32(8): 6, np.int32(16): 7}

# Then save the SVM models.
# NOTE(review): earlier comments here referred to a "max_depth"/depth setting,
# but no such argument is passed in this call - confirm against the signature
# of save_tree_svms_to_binary_correct before relying on a particular depth.
save_tree_svms_to_binary_correct(tree, "weights/svm_models")

In [None]:
# Example: compare two export directories.
# `original_directory` is environment-specific - point it at your own
# reference files before running.
our_directory, original_directory = (
    "weights/svm_models",
    "weights/svm_models_old",
)

# Arguments are passed positionally as (ours, original).
analysis = compare_svm_files_flexible(
    our_directory,
    original_directory,
)

In [None]:
# Example: compare one exported file against its reference counterpart.
# NOTE(review): `our_file` lives under "weights/svm_models_fixed", which is a
# different directory from the "weights/svm_models" used by the directory
# comparison cell - confirm this is the intended export.
our_file, original_file = (
    "weights/svm_models_fixed/lsm0008",
    "weights/svm_models_old/lsm0008",
)

# Arguments are passed positionally as (ours, original).
comparison = compare_lsm_file_values(
    our_file,
    original_file,
)

### TEST HYPSO 1

In [None]:
# Run the prediction pipeline on the combined MACHI image (.npy input).
# All three paths are absolute, machine-specific locations.
tree_path = r"D:\Hierarchical Unmixing Label\hUH\weights\tree_MACHI.pkl"
deh_path = r'D:\Hierarchical Unmixing Label\hUH\save\MACHI_10img_256to8_stab.h5_aa_FINAL.h5'
image_path = r'D:\Hierarchical Unmixing Label\hUH\images\combined_10_v2_MACHI.npy'

# Positional order is (image, tree, deh); `prediction` is reassigned by the
# other test cells as well, so it always holds the most recent run.
prediction = predict_pipeline(
    image_path,
    tree_path,
    deh_path,
    hypso=1,
    verbose=False,
)

In [None]:
# Run the prediction pipeline directly on a .nc capture file.
# All three paths are absolute, machine-specific locations.
tree_path = r"D:\Hierarchical Unmixing Label\hUH\weights\tree_MACHI.pkl"
deh_path = r'D:\Hierarchical Unmixing Label\hUH\save\MACHI_10img_256to8_stab.h5_aa_FINAL.h5'
image_path = r'D:\Downloads\aregantsea2_2025-03-11T08-12-43Z-l1a.nc'

# Positional order is (image, tree, deh); `prediction` is reassigned by the
# other test cells as well, so it always holds the most recent run.
prediction = predict_pipeline(
    image_path,
    tree_path,
    deh_path,
    hypso=1,
    verbose=False,
)

In [None]:
# Run the prediction pipeline on the .npy version of the aregantsea2 capture.
# All three paths are absolute, machine-specific locations.
tree_path = r"D:\Hierarchical Unmixing Label\hUH\weights\tree_MACHI.pkl"
deh_path = r'D:\Hierarchical Unmixing Label\hUH\save\MACHI_10img_256to8_stab.h5_aa_FINAL.h5'
image_path = r'D:\Hierarchical Unmixing Label\hUH\images\aregantsea2_2025-03-11T08-12-43Z-l1a_cm_machi.npy'

# Positional order is (image, tree, deh); `prediction` is reassigned by the
# other test cells as well, so it always holds the most recent run.
prediction = predict_pipeline(
    image_path,
    tree_path,
    deh_path,
    hypso=1,
    verbose=False,
)

In [None]:
# Image array for the aregantsea2 capture (absolute, machine-specific path).
aregantsea2_2025_03_11_MACHI = np.load(
    r'D:\Hierarchical Unmixing Label\hUH\images\aregantsea2_2025-03-11T08-12-43Z-l1a_cm_machi.npy'
)

# Matching labels. The file holds a pickled Python object, hence
# allow_pickle=True followed by .item() to unwrap it (presumably a dict of
# binary labels - TODO confirm).
# SECURITY: allow_pickle=True can execute arbitrary code while loading; only
# use it on files you created or trust.
aregantsea2_2025_03_11_MACHI_labels = np.load(
    r'D:\Hierarchical Unmixing Label\hUH\save\aregantsea2_2025_03_11_MACHI_binary_labels.npy',
    allow_pickle=True,
).item()

In [None]:
# Predict on the aregantsea2 image with the in-memory `tree`, then score the
# same image against its loaded labels. `prediction` is overwritten by every
# test cell that assigns it, so later cells see the most recent result.
prediction = tree.predict(aregantsea2_2025_03_11_MACHI)
evaluation = tree.evaluate(aregantsea2_2025_03_11_MACHI,aregantsea2_2025_03_11_MACHI_labels)

In [None]:
# Plot the aregantsea2 input, then ground truth vs. prediction: first for the
# single key '0000', then with the plotting methods' default key.
# The image is reshaped to (-1, 1092, 112); 1092 equals HYPSO_HEIGHT defined at
# the top of the notebook, and 112 is presumably the band count of this data -
# TODO confirm.
tree.plot_input_image(aregantsea2_2025_03_11_MACHI.reshape(-1,1092,112))
tree.plot_ground_truth(aregantsea2_2025_03_11_MACHI_labels, key='0000')
tree.plot_prediction(prediction, key='0000')  
tree.plot_ground_truth(aregantsea2_2025_03_11_MACHI_labels)
tree.plot_prediction(prediction)

### TEST HYPSO 2

In [None]:
# Image array for the losmanzanosfire capture (absolute, machine-specific path).
losmanzanosfire_2025_03_11_MACHI = np.load(
    r'D:\Hierarchical Unmixing Label\hUH\images\losmanzanosfire_2025-03-11T14-56-42Z-l1a_cm_MACHI.npy'
)

# Matching labels. The file holds a pickled Python object, hence
# allow_pickle=True followed by .item() to unwrap it (presumably a dict of
# binary labels - TODO confirm).
# SECURITY: allow_pickle=True can execute arbitrary code while loading; only
# use it on files you created or trust.
losmanzanosfire_2025_03_11_MACHI_labels = np.load(
    r'D:\Hierarchical Unmixing Label\hUH\save\losmanzanosfire_2025_03_11_MACHI_binary_labels.npy',
    allow_pickle=True,
).item()

In [None]:
# Predict on the losmanzanosfire image with the in-memory `tree`, then score
# the same image against its loaded labels. This reassigns `prediction` and
# `evaluation`, replacing the values from the previous test section.
prediction = tree.predict(losmanzanosfire_2025_03_11_MACHI)
evaluation = tree.evaluate(losmanzanosfire_2025_03_11_MACHI,losmanzanosfire_2025_03_11_MACHI_labels)

In [None]:
# Plot the losmanzanosfire input, then ground truth vs. prediction: first for
# the single key '0000', then with the plotting methods' default key.
# The image is reshaped to (-1, 1092, 112); 1092 equals HYPSO_HEIGHT defined at
# the top of the notebook, and 112 is presumably the band count of this data -
# TODO confirm.
tree.plot_input_image(losmanzanosfire_2025_03_11_MACHI.reshape(-1,1092,112))
tree.plot_ground_truth(losmanzanosfire_2025_03_11_MACHI_labels, key='0000')
tree.plot_prediction(prediction, key='0000')  
tree.plot_ground_truth(losmanzanosfire_2025_03_11_MACHI_labels)
tree.plot_prediction(prediction)