### Goals

Debug the current SVM pipeline by:
1) Loading the data in the same way the pipeline does and plotting it
2) Experimenting with the hyperparameters

In [None]:
import anndata as ad
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import os
import json
import ccf_streamlines.projection as ccfproj
from datasets import load_from_disk
import yaml

from utils import reflect_points_to_left


In [None]:
# Directory containing the CCF reference files that the next cell reads
# ("flatmap_butterfly.nrrd" and "labelDescription_ITKSNAPColor.txt").
ccf_files_path = "/grid/zador/data_norepl/Ari/transcriptomics/CCF_files"



In [None]:
# Build the boundary finder from the butterfly flatmap atlas and its label table.
_flatmap_atlas_file = os.path.join(ccf_files_path, "flatmap_butterfly.nrrd")
_atlas_labels_file = os.path.join(ccf_files_path, "labelDescription_ITKSNAPColor.txt")
bf_boundary_finder = ccfproj.BoundaryFinder(
    projected_atlas_file=_flatmap_atlas_file,
    labels_file=_atlas_labels_file,
)

# Region outlines with the finder's default arguments; these are the outlines
# drawn on top of every plot further down.
bf_left_boundaries_flat = bf_boundary_finder.region_boundaries()

# Same query, but asking for the other hemisphere in butterfly-flatmap space.
_right_hemisphere_query = {
    "hemisphere": 'right_for_both',
    "view_space_for_other_hemisphere": 'flatmap_butterfly',
}
bf_right_boundaries_flat = bf_boundary_finder.region_boundaries(**_right_hemisphere_query)

In [None]:
# First, create the master color mapping (before any loops)
def create_master_colormap(
    folds_range,
    group_size,
    val_or_test,
    outputs_root="/grid/zador/home/benjami/brain-annotation/outputs",
):
    """Build one category -> RGBA colour lookup shared by every fold.

    For each fold the function loads
    ``{outputs_root}/fold{f}_all_exhausted_{group_size}/{val_or_test}_brain_predictions.npy``
    and collects every distinct value found under ``'single_cell_labels'`` and
    ``'single_cell_predictions'``. The sorted union is then mapped onto a
    60-entry palette (``tab20`` + ``tab20b`` + ``tab20c``); with more than 60
    categories the palette wraps around and colours repeat.

    Parameters
    ----------
    folds_range : iterable
        Fold identifiers substituted into the run directory name.
    group_size : int or str
        Group size substituted into the run directory name.
    val_or_test : str
        Prefix of the predictions file name (e.g. ``"test"``).
    outputs_root : str, optional
        Directory that contains the per-fold run directories. Defaults to the
        location that was previously hard-coded.

    Returns
    -------
    color_map : dict
        Maps each category to an RGBA row of the palette.
    label_names : object or None
        The ``'label_names'`` entry of the first loaded file that has one,
        otherwise ``None``.
    """
    all_categories = set()
    label_names = None
    for f in folds_range:
        predictions_path = (
            f"{outputs_root}/fold{f}_all_exhausted_{group_size}"
            f"/{val_or_test}_brain_predictions.npy"
        )
        # NOTE: allow_pickle=True unpickles arbitrary objects; only point this
        # at prediction files you trust.
        pred_dict = np.load(predictions_path, allow_pickle=True).item()
        all_categories.update(np.unique(pred_dict['single_cell_labels']))
        all_categories.update(np.unique(pred_dict['single_cell_predictions']))
        if label_names is None and 'label_names' in pred_dict:
            label_names = pred_dict['label_names']

    # Sorted so that the category -> colour assignment is reproducible.
    all_categories = sorted(all_categories)

    # Extended palette: stack three 20-colour qualitative maps.
    # plt.get_cmap is used because plt.cm.get_cmap no longer exists in
    # recent Matplotlib releases.
    colormaps = ['tab20', 'tab20b', 'tab20c']
    colors = np.vstack([plt.get_cmap(cmap)(np.linspace(0, 1, 20)) for cmap in colormaps])

    # Wrap around the palette when there are more categories than colours.
    color_indices = np.arange(len(all_categories)) % len(colors)
    color_map = dict(zip(all_categories, colors[color_indices]))

    return color_map, label_names

def plot_scatter_style(x, y, labels, preds, color_map, bf_left_boundaries_flat):
    """Draw three side-by-side scatter panels: labels, predictions and errors.

    The first two panels colour every point through ``color_map`` (by label and
    by prediction respectively); the third shows, in red, only the points where
    ``labels != preds``. Every outline in ``bf_left_boundaries_flat`` is drawn
    on each panel, axes are hidden, the aspect is equal and the y-axis is
    flipped.

    Returns the ``(figure, axes)`` pair created by ``plt.subplots``.
    """
    f, ax = plt.subplots(1, 3, figsize=(15, 5))
    marker_style = dict(alpha=.5, linewidths=0, s=5)

    # Panels 1 and 2 follow the same recipe, only the colouring values differ.
    coloured_panels = (
        (ax[0], labels, "Area labels"),
        (ax[1], preds, "Area predictions"),
    )
    for panel, values, title in coloured_panels:
        rgba = np.array([color_map[value] for value in values])
        panel.scatter(x, y, color=rgba, **marker_style)
        panel.set_title(title)

    # Panel 3: misclassified points only.
    wrong = labels != preds
    ax[2].scatter(x[wrong], y[wrong], color='r', **marker_style)
    ax[2].set_title("Errors")

    # Shared decoration.
    for panel in ax:
        panel.grid(False)
        for outline in bf_left_boundaries_flat.values():
            panel.plot(*outline.T, c="k", lw=0.5)
        panel.axis('off')
        panel.set_aspect('equal')
        # Reverse the current limits to flip the y-axis.
        panel.set_ylim(panel.get_ylim()[::-1])

    plt.tight_layout()
    return f, ax

import matplotlib as mpl

def create_hexbin_categorical(x, y, labels, label_map=None, gridsize=30, boundaries=None):
    """
    Create a hexagonal binning plot for categorical data.

    Each hexagon that contains at least one point is filled with the colour of
    the most frequent label among the points inside it.

    Parameters:
    -----------
    x : array-like
        x-coordinates of the points (must support ``.min()`` / ``.max()``)
    y : array-like
        y-coordinates of the points (must support ``.min()`` / ``.max()``)
    labels : array-like
        categorical labels for each point (indexed with a boolean mask, so a
        NumPy array is expected)
    label_map : dict, optional
        mapping from numerical IDs to string labels for legend
    gridsize : int, optional
        number of hexagons in the x-direction
    boundaries : dict, optional
        region outlines to draw on top of the hexagons; each value is an
        array whose transpose unpacks into ``ax.plot``. When omitted, the
        module-level ``bf_left_boundaries_flat`` is used, which is what this
        function always did before the parameter existed.

    Returns:
    --------
    fig : matplotlib figure
    ax : matplotlib axes
    hexbin_data : pandas DataFrame
        one row per non-empty hexagon with ``hex_id``, ``modal_category``,
        ``count``, ``center_x`` and ``center_y``
    """
    # Convert data to DataFrame
    df = pd.DataFrame({
        'x': x,
        'y': y,
        'label': labels
    })

    # Create figure and axis
    fig, ax = plt.subplots(figsize=(12, 8))

    # Calculate data extent
    xmin, xmax = x.min(), x.max()
    ymin, ymax = y.min(), y.max()

    # Invisible hexbin, used only to obtain the hexagon centres and shape
    hb = ax.hexbin(x, y, gridsize=gridsize, extent=(xmin, xmax, ymin, ymax), visible=False)
    hex_centers = hb.get_offsets()
    hex_path = hb.get_paths()[0]

    def get_modal_category(points):
        """Return the most frequent value in ``points`` (NaN when empty)."""
        values, counts = np.unique(points, return_counts=True)
        if len(counts) == 0:
            return np.nan
        return values[np.argmax(counts)]

    # Extract the coordinates once instead of re-slicing the frame per hexagon
    xy = df[['x', 'y']].to_numpy()

    hex_stats = []
    for i, center in enumerate(hex_centers):
        # Shift the generic hexagon outline onto this hexagon's centre
        vertices = hex_path.vertices + center

        # Find points within this hexagon
        mask = mpl.path.Path(vertices).contains_points(xy)
        if mask.any():
            points_in_hex = df[mask]
            hex_stats.append({
                'hex_id': i,
                'modal_category': get_modal_category(points_in_hex['label']),
                'count': len(points_in_hex),
                'center_x': center[0],
                'center_y': center[1]
            })

    hex_stats = pd.DataFrame(hex_stats)

    # Extended palette built from three 20-colour qualitative maps.
    # plt.get_cmap is used because plt.cm.get_cmap no longer exists in
    # recent Matplotlib releases.
    colormaps = ['tab20', 'tab20b', 'tab20c']
    palette = np.vstack([plt.get_cmap(cmap)(np.linspace(0, 1, 20)) for cmap in colormaps])

    # Map every non-missing category to a palette entry (wrapping if needed)
    unique_categories = np.unique(labels[~pd.isna(labels)])
    palette_indices = np.arange(len(unique_categories)) % len(palette)
    category_colors = dict(zip(unique_categories, palette[palette_indices]))

    # One polygon per non-empty hexagon, coloured by its modal category
    patches = []
    face_colors = []
    for _, row in hex_stats.iterrows():
        center = (row['center_x'], row['center_y'])
        patches.append(mpl.patches.Polygon(hex_path.vertices + center))
        face_colors.append(category_colors[row['modal_category']])

    collection = mpl.collections.PatchCollection(
        patches,
        facecolors=face_colors,
        edgecolors='white',
        linewidth=0.5,
        alpha=0.7
    )
    ax.add_collection(collection)

    # Set plot limits
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)

    # Legend: one swatch per category, using label_map names when available
    legend_elements = [plt.Rectangle((0, 0), 1, 1, facecolor=color, alpha=0.7)
                       for color in category_colors.values()]

    if label_map is not None:
        legend_labels = [label_map.get(int(cat), str(cat)) for cat in category_colors]
    else:
        legend_labels = [str(cat) for cat in category_colors]

    ax.legend(legend_elements, legend_labels,
              title='Categories', loc='center left', bbox_to_anchor=(1, 0.5))

    # Set equal aspect ratio
    ax.set_aspect('equal')

    ax.grid(False)
    # Fall back to the notebook-level outlines when none are passed in
    outlines = bf_left_boundaries_flat if boundaries is None else boundaries
    for boundary_coords in outlines.values():
        ax.plot(*boundary_coords.T, c="k", lw=0.5)

    ax.axis('off')
    # reverse y-axis
    ax.set_ylim(ax.get_ylim()[::-1])

    # Adjust layout to prevent legend overlap
    plt.tight_layout()

    return fig, ax, hex_stats

def plot_hexbin_style(x, y, labels, preds, color_map, bf_left_boundaries_flat, label_names, gridsize=50):
    """Hexbin view of the predictions.

    Thin wrapper around ``create_hexbin_categorical``: it forwards ``x``, ``y``,
    ``preds`` and ``gridsize``, and turns ``label_names`` (when truthy) into a
    legend mapping with integer keys. ``labels``, ``color_map`` and
    ``bf_left_boundaries_flat`` are accepted but not used by this function.

    Returns the ``(fig, ax, hex_stats)`` triple produced by
    ``create_hexbin_categorical``.
    """
    legend_names = None
    if label_names:
        legend_names = {int(key): name for key, name in label_names.items()}

    return create_hexbin_categorical(x, y, preds, label_map=legend_names, gridsize=gridsize)

# Main execution loop
    # Create master colormap
   


In [None]:
# Settings substituted into the predictions path below.
group_size=32
val_or_test = "test"  



# Load data
# NOTE(review): the predictions come from a "fold0_..." run directory while the
# dataset loaded below is the "..._fold3" one; the following cell uses fold3
# for both. Confirm that this pairing is intentional.
predictions_path = f"/grid/zador/home/benjami/brain-annotation/outputs/fold0_animal_name_class_weights2_{group_size}/{val_or_test}_brain_predictions_cells.npy"
# allow_pickle=True with .item(): the file stores a single pickled Python object.
pred_dict = np.load(predictions_path, allow_pickle=True).item()
dataset = load_from_disk(f"/grid/zador/data_norepl/Ari/transcriptomics/barseq/Chen2023/train_test_barseq_all_exhausted_fold3.dataset")


In [None]:
# Settings substituted into the run directory / file names below.
group_size = 32
val_or_test = "test"

# One category -> colour lookup shared across folds 0-3.
color_map, label_names = create_master_colormap(range(4), group_size, val_or_test)

# Fold-3 predictions and the fold-3 dataset.
predictions_path = f"/grid/zador/home/benjami/brain-annotation/outputs/fold3_animal_name_class_weights2_{group_size}/{val_or_test}_brain_predictions_cells.npy"
pred_dict = np.load(predictions_path, allow_pickle=True).item()
dataset = load_from_disk("/grid/zador/data_norepl/Ari/transcriptomics/barseq/Chen2023/train_test_barseq_all_exhausted_fold3.dataset")

# Flatten labels / predictions; the indices are used as stored.
labels = np.array(pred_dict['labels']).flatten()
preds = np.array(pred_dict['predictions']).flatten()
indices = np.array(pred_dict['indices'])  # [:,0].flatten()

# Stride used to thin out the points that get plotted (1 keeps everything).
s = 1
# "validation" predictions index into the 'train' split, anything else into 'test'.
split_name = 'train' if val_or_test == "validation" else 'test'
xyz = np.array(dataset[split_name][indices]['CCF_streamlines'])
xyz = reflect_points_to_left(xyz)
# NOTE: only x, y, labels and preds are strided here; xyz itself stays full length.
x = xyz[::s, 0]
y = xyz[::s, 1]
labels = labels[::s]
preds = preds[::s]

# Fraction of points whose prediction equals the label.
acc = (preds == labels).sum() / len(labels)
print("Accuracy", 100*acc)

# Scatter-style figure (labels / predictions / errors).
fig, ax = plot_scatter_style(x, y, labels, preds, color_map, bf_left_boundaries_flat)
# plt.savefig(f'test_enucleated_single-cell-preds-{val_or_test}-{group_size}.png', dpi=300)
plt.show()
# Hexbin alternative, currently disabled:
# fig, ax, _ = plot_hexbin_style(x, y, labels, preds, color_map, bf_left_boundaries_flat, label_names)
# # plt.savefig(f'test_enucleated_single-cell-preds-{val_or_test}-{group_size}-hex.png', dpi=300)
# plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.svm import SVC
from sklearn.inspection import DecisionBoundaryDisplay
import joblib
from pathlib import Path
import logging
from typing import Dict, Tuple, List, Optional
from omegaconf import DictConfig
import hydra
from datasets import load_from_disk
from matplotlib.colors import ListedColormap
import ccf_streamlines.projection as ccfproj
import os
from utils import reflect_points_to_left

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

from typing import Optional, Union, Tuple, Callable, Dict, Any, List, Literal
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from sklearn.base import BaseEstimator
import warnings
from scipy import ndimage


class DecisionBoundaryEdgeDisplay:
    """Visualization of classifier decision boundaries using edge detection.
    
    This class creates a visualization that focuses on the boundaries between
    different decision regions in a classification model, highlighting exactly
    where the model transitions from predicting one class to another.
    
    Attributes:
        xx0: np.ndarray
            First axis grid coordinates.
        xx1: np.ndarray
            Second axis grid coordinates.
        response: np.ndarray
            The classifier's predicted classes across the grid.
        response_mesh: np.ndarray
            Copy of ``response`` taken at construction time.
        boundary_mask: np.ndarray
            Boolean mask indicating boundary pixels.
        estimator_: BaseEstimator
            The fitted classifier.
        ax_: plt.Axes
            The matplotlib axes.
        surface_: Union[plt.QuadMesh, plt.QuadContourSet]
            The filled-regions surface; None until plot(fill_regions=True).
        boundary_surface_: plt.QuadMesh
            The boundary overlay; None until plot() has been called.
        colorbar_: plt.Colorbar
            The colorbar; None unless plot() was asked to draw one.
    """
    
    def __init__(
        self, 
        xx0: np.ndarray, 
        xx1: np.ndarray, 
        response: np.ndarray,
        boundary_mask: np.ndarray,
        estimator: BaseEstimator,
        ax: Optional[plt.Axes] = None,
    ) -> None:
        """Initialize the DecisionBoundaryEdgeDisplay.
        
        Parameters:
            xx0: np.ndarray
                First axis grid coordinates.
            xx1: np.ndarray
                Second axis grid coordinates.
            response: np.ndarray
                The classifier's predicted classes across the grid.
            boundary_mask: np.ndarray
                Boolean mask indicating boundary pixels.
            estimator: BaseEstimator
                The fitted classifier.
            ax: Optional[plt.Axes]
                The matplotlib axes to plot on; the current axes are used if None.
        """
        self.xx0 = xx0
        self.xx1 = xx1
        self.response = response
        self.response_mesh = response.copy()
        self.boundary_mask = boundary_mask
        self.estimator_ = estimator
        # Explicit None check: do not rely on the truthiness of the axes object.
        self.ax_ = ax if ax is not None else plt.gca()
        
        # These are set in plot(). boundary_surface_ must exist up front so the
        # "call plot() first" guard in apply_mask_from_boundaries can fire.
        self.surface_ = None
        self.boundary_surface_ = None
        self.colorbar_ = None
    
    def plot(
        self, 
        fill_regions: bool = False,
        boundary_color: str = "black",
        boundary_width: float = 1.0,
        boundary_alpha: float = 1.0,
        regions_alpha: float = 0.5,
        regions_cmap: Union[str, LinearSegmentedColormap] = "viridis",
        colorbar: bool = False,
        **kwargs: Any
    ) -> "DecisionBoundaryEdgeDisplay":
        """Plot the decision boundary edges.
        
        Parameters:
            fill_regions: bool
                Whether to fill the decision regions with colors.
            boundary_color: str
                The color for the boundary lines.
            boundary_width: float
                The line width for the boundary.
            boundary_alpha: float
                The alpha (transparency) value for the boundary lines.
            regions_alpha: float
                The alpha blending value for filled regions.
            regions_cmap: Union[str, LinearSegmentedColormap]
                The colormap to use for regions if filled.
            colorbar: bool
                Whether to attach a colorbar to the filled regions. Only has an
                effect together with ``fill_regions``.
            **kwargs: Any
                Additional keyword arguments passed to the ``pcolormesh`` call
                that draws the filled regions.
                
        Returns:
            DecisionBoundaryEdgeDisplay: self
        """
        # Optionally show filled decision regions
        if fill_regions:
            self.surface_ = self.ax_.pcolormesh(
                self.xx0, self.xx1, self.response, 
                alpha=regions_alpha, 
                cmap=regions_cmap,
                **kwargs
            )
            if colorbar:
                self.colorbar_ = plt.colorbar(self.surface_, ax=self.ax_)
        
        # Only boundary pixels stay unmasked; a two-stop colormap with the same
        # colour at both ends paints them in a single flat colour.
        boundary_regions = np.ma.masked_where(~self.boundary_mask, self.boundary_mask)
        self.boundary_surface_ = self.ax_.pcolormesh(
            self.xx0, self.xx1, boundary_regions,
            cmap=LinearSegmentedColormap.from_list("", [boundary_color, boundary_color]),
            alpha=boundary_alpha,
            linewidth=boundary_width
        )
        
        return self
    
    def plot_samples(
        self, 
        X: np.ndarray, 
        y: np.ndarray, 
        markers: Optional[List[str]] = None,
        colors: Optional[List[str]] = None,
        scatter_kwargs: Optional[Dict[str, Any]] = None
    ) -> "DecisionBoundaryEdgeDisplay":
        """Plot the samples used to train the classifier.
        
        Parameters:
            X: np.ndarray
                The feature data, shape (n_samples, 2).
            y: np.ndarray
                The target data, shape (n_samples,).
            markers: Optional[List[str]]
                List of markers to use for each class; cycled when there are
                more classes than markers.
            colors: Optional[List[str]]
                List of colors to use for each class; cycled when there are
                more classes than colors.
            scatter_kwargs: Optional[Dict[str, Any]]
                Additional arguments passed to plt.scatter.
                
        Returns:
            DecisionBoundaryEdgeDisplay: self
        """
        scatter_kwargs = scatter_kwargs or {}
        markers = markers or ["o", "s", "^", "v", "<", ">", "d", "p", "*"]
        
        classes = np.unique(y)
        for i, cls in enumerate(classes):
            mask = y == cls
            self.ax_.scatter(
                X[mask, 0],
                X[mask, 1],
                # Cycle like the markers do instead of raising IndexError.
                c=[colors[i % len(colors)]] if colors else None,
                marker=markers[i % len(markers)],
                label=f"Class {cls}",
                **scatter_kwargs
            )
        
        self.ax_.legend()
        return self
        
    def apply_mask_from_boundaries(
        self, 
        boundaries: Union[List[np.ndarray], Dict[Any, np.ndarray]],
        invert: bool = True,
        draw: bool = True,
    ) -> "DecisionBoundaryEdgeDisplay":
        """Apply a mask to the visualization based on a set of boundary polygons.
        
        This method masks the visualization to show only points inside (or outside)
        the union of the provided boundaries. Membership is tested at the grid
        coordinates ``xx0`` / ``xx1`` that the surfaces were drawn with, so the
        result does not depend on the current axis limits or orientation.
        
        Parameters:
            boundaries: Union[List[np.ndarray], Dict[Any, np.ndarray]]
                List or dictionary of boundary polygons. Each polygon should be
                a numpy array of shape (n_points, 2) defining the boundary vertices.
            invert: bool
                If True, mask points outside the boundaries. If False, mask points
                inside the boundaries.
            draw: bool
                If True, also draw every polygon outline on the axes.
                
        Returns:
            DecisionBoundaryEdgeDisplay: self
            
        Raises:
            ValueError: If plot() has not been called yet.
        """
        if self.boundary_surface_ is None:
            raise ValueError("No boundary surface to mask. Call plot() first.")

        # Local import; note that this ``Path`` is matplotlib's, not pathlib's.
        from matplotlib.path import Path
            
        # Convert dictionary to list if needed
        boundary_list = list(boundaries.values()) if isinstance(boundaries, dict) else boundaries
        
        # Test the exact grid nodes the surfaces were evaluated on.
        points = np.column_stack((self.xx0.ravel(), self.xx1.ravel()))
        
        # Union of all polygons
        combined_mask = np.zeros(points.shape[0], dtype=bool)
        for boundary in boundary_list:
            combined_mask |= Path(boundary).contains_points(points)
            
        # Invert mask if requested
        if invert:
            combined_mask = ~combined_mask
            
        # Reshape mask to match surface array
        mask = combined_mask.reshape(self.response.shape)
        
        # Apply mask to surfaces
        if hasattr(self.boundary_surface_, 'set_array'):
            # Mask boundary surface
            current_array = self.boundary_surface_.get_array()
            self.boundary_surface_.set_array(np.ma.array(current_array, mask=mask))
            
            # Mask regions surface if it exists
            if self.surface_ is not None and hasattr(self.surface_, 'set_array'):
                current_region_array = self.surface_.get_array()
                self.surface_.set_array(np.ma.array(current_region_array, mask=mask))
        
        # Draw boundaries if requested
        if draw:
            for boundary in boundary_list:
                self.ax_.plot(*boundary.T, c="k", lw=0.5)
            
        return self


def _unsupported_detection_method(detection_method: Any) -> ValueError:
    """Build the error raised for an unknown ``detection_method``."""
    return ValueError(
        f"Detection method {detection_method} not supported. "
        "Use 'difference', 'gradient', or 'sobel'."
    )


def _boundary_mask_from_grid(y_pred: np.ndarray, detection_method: str) -> np.ndarray:
    """Return a boolean mask marking grid cells where the predicted class changes."""
    if detection_method == "difference":
        # A change between two neighbouring cells marks both of them.
        horizontal_change = np.diff(y_pred, axis=1) != 0
        vertical_change = np.diff(y_pred, axis=0) != 0

        boundary_mask = np.zeros_like(y_pred, dtype=bool)
        boundary_mask[:, :-1] |= horizontal_change
        boundary_mask[:, 1:] |= horizontal_change
        boundary_mask[:-1, :] |= vertical_change
        boundary_mask[1:, :] |= vertical_change
        return boundary_mask

    if detection_method == "gradient":
        # Any non-zero gradient magnitude counts as a boundary.
        gradient_y, gradient_x = np.gradient(y_pred.astype(float))
        return np.sqrt(gradient_x**2 + gradient_y**2) > 0

    if detection_method == "sobel":
        # Any non-zero Sobel response counts as a boundary.
        as_float = y_pred.astype(float)
        sobel_h = ndimage.sobel(as_float, axis=0)
        sobel_v = ndimage.sobel(as_float, axis=1)
        return np.sqrt(sobel_h**2 + sobel_v**2) > 0

    raise _unsupported_detection_method(detection_method)


def from_estimator(
    estimator: BaseEstimator,
    X: np.ndarray,
    grid_resolution: int = 200,
    eps: float = 0.01,
    ax: Optional[plt.Axes] = None,
    detection_method: Literal["difference", "gradient", "sobel"] = "difference",
    fill_regions: bool = False,
    boundary_color: str = "black",
    boundary_width: float = 1.0,
    boundary_alpha: float = 1.0,
    regions_cmap: Union[str, LinearSegmentedColormap] = "viridis",
    regions_alpha: float = 0.5,
    **kwargs: Any
) -> DecisionBoundaryEdgeDisplay:
    """Create a DecisionBoundaryEdgeDisplay from a fitted estimator.
    
    This function creates a 2D visualization that highlights the decision boundaries
    between different classes, showing exactly where the classifier's prediction changes.
    The estimator is evaluated on a ``grid_resolution`` x ``grid_resolution`` grid
    spanning the range of ``X`` (padded by ``eps``). When the estimator exposes
    ``classes_``, predictions are converted to their position in ``classes_``
    before boundaries are detected and regions are drawn.
    
    Parameters:
        estimator: BaseEstimator
            Fitted classifier.
        X: np.ndarray
            Two-column data whose per-column range defines the grid.
        grid_resolution: int
            Number of points to use for each grid dimension.
        eps: float
            Extends the range of the grid to avoid boundary effects.
        ax: Optional[plt.Axes]
            Axes to plot on; the current axes are used if None.
        detection_method: Literal["difference", "gradient", "sobel"]
            Method to detect class boundaries:
            - "difference": Detects changes between adjacent pixels
            - "gradient": Uses gradient magnitude of the prediction
            - "sobel": Uses Sobel filter for edge detection
        fill_regions: bool
            Whether to show filled decision regions.
        boundary_color: str
            The color for the boundary lines.
        boundary_width: float
            The line width for the boundary.
        boundary_alpha: float
            The alpha (transparency) value for the boundary lines.
        regions_cmap: Union[str, LinearSegmentedColormap]
            The colormap to use for decision regions if filled.
        regions_alpha: float
            The alpha blending value for decision regions if filled.
        **kwargs: Any
            Additional arguments passed to the plotting method.
            
    Returns:
        DecisionBoundaryEdgeDisplay: The configured display object.
        
    Raises:
        ValueError: If X does not have exactly two columns, or if
            ``detection_method`` is not one of the supported names.
    """
    X = np.asarray(X)
    if X.ndim != 2:
        raise ValueError(
            f"Expected X of shape (n_samples, 2), got shape {X.shape}. "
            "DecisionBoundaryEdgeDisplay only supports 2D visualization."
        )
    if X.shape[1] != 2:
        raise ValueError(
            f"Expected 2 features, got {X.shape[1]}. DecisionBoundaryEdgeDisplay only supports 2D visualization."
        )
    # Fail fast: reject a bad method before the (potentially slow) grid prediction.
    if detection_method not in ("difference", "gradient", "sobel"):
        raise _unsupported_detection_method(detection_method)
    
    # Create the grid
    x0_min, x0_max = X[:, 0].min() - eps, X[:, 0].max() + eps
    x1_min, x1_max = X[:, 1].min() - eps, X[:, 1].max() + eps
    
    xx0, xx1 = np.meshgrid(
        np.linspace(x0_min, x0_max, grid_resolution),
        np.linspace(x1_min, x1_max, grid_resolution)
    )
    
    # Get predictions for each grid point
    X_grid = np.c_[xx0.ravel(), xx1.ravel()]
    y_pred_raw = np.asarray(estimator.predict(X_grid))
    
    if hasattr(estimator, "classes_"):
        # Map class labels to their position in classes_. The dictionary lookup
        # runs once per distinct predicted class, not once per grid cell.
        class_to_index = {cls: idx for idx, cls in enumerate(estimator.classes_)}
        predicted_classes, inverse = np.unique(y_pred_raw, return_inverse=True)
        index_of_class = np.array([class_to_index[cls] for cls in predicted_classes])
        y_pred = index_of_class[inverse.ravel()].reshape(xx0.shape)
    else:
        # For non-classifiers, use the raw predictions
        y_pred = y_pred_raw.reshape(xx0.shape)
    
    boundary_mask = _boundary_mask_from_grid(y_pred, detection_method)
    
    # Create the display object
    display = DecisionBoundaryEdgeDisplay(
        xx0=xx0,
        xx1=xx1,
        response=y_pred,
        boundary_mask=boundary_mask,
        estimator=estimator,
        ax=ax
    )
    
    # Plot the boundaries
    display.plot(
        fill_regions=fill_regions,
        boundary_color=boundary_color,
        boundary_width=boundary_width,
        boundary_alpha=boundary_alpha,
        regions_cmap=regions_cmap,
        regions_alpha=regions_alpha,
        **kwargs
    )
    
    return display



def load_ccf_boundaries(ccf_files_path="/grid/zador/data_norepl/Ari/transcriptomics/CCF_files"):
    """Load the region outlines of the butterfly flatmap.

    Parameters:
        ccf_files_path: str
            Directory containing ``flatmap_butterfly.nrrd`` and
            ``labelDescription_ITKSNAPColor.txt``. Defaults to the location
            that was previously hard-coded, so existing calls are unaffected.

    Returns:
        The result of ``BoundaryFinder.region_boundaries()`` called with its
        default arguments (used elsewhere in this file as a mapping whose
        values are outline coordinate arrays).
    """
    bf_boundary_finder = ccfproj.BoundaryFinder(
        projected_atlas_file=os.path.join(ccf_files_path, "flatmap_butterfly.nrrd"),
        labels_file=os.path.join(ccf_files_path, "labelDescription_ITKSNAPColor.txt"),
    )

    return bf_boundary_finder.region_boundaries()


def create_decision_boundary_plot(
    model: SVC,
    X: np.ndarray,
    color_map,
    output_path: str,
    grid_resolution
):
    """Create the decision-region plot for a single model and return its axes.

    The regions drawn by ``from_estimator`` hold, for every grid cell, the
    position of the predicted class in ``model.classes_``. The colormap is
    therefore built from ``model.classes_`` (in that order) and the colour
    scale is pinned with ``vmin`` / ``vmax`` so that index ``i`` always lands
    on entry ``i``, even when the model knows fewer classes than ``color_map``
    or some class is never predicted on the grid.

    Parameters:
        model: SVC
            Fitted classifier.
        X: np.ndarray
            Two-column data that defines the extent of the evaluation grid.
        color_map: dict
            Maps class label -> colour.
        output_path: str
            Not used by this function (nothing is saved here); kept so that
            existing calls keep working.
        grid_resolution: int
            Number of grid points per axis.

    Returns:
        The matplotlib axes holding the plot.

    Raises:
        KeyError: If a class in ``model.classes_`` has no entry in ``color_map``.
    """
    fig, ax = plt.subplots()

    scale_kwargs = {}
    classes = getattr(model, "classes_", None)
    if classes is not None:
        ordered_labels = list(classes)
        # Half-integer limits centre each class index on its own colour bin.
        scale_kwargs = {"vmin": -0.5, "vmax": len(ordered_labels) - 0.5}
    else:
        # No class list available: fall back to every colour in sorted order.
        ordered_labels = sorted(color_map.keys())
    cmap = ListedColormap([color_map[label] for label in ordered_labels])

    bf_left_boundaries_flat = load_ccf_boundaries()

    disp = from_estimator(
        model,
        X,
        grid_resolution=grid_resolution,
        ax=ax,
        detection_method="difference",
        fill_regions=True,
        regions_cmap=cmap,
        regions_alpha=1.,
        boundary_color="black",
        boundary_width=.5,
        boundary_alpha=.8,
        **scale_kwargs
    )

    # Hide everything that falls outside the union of the atlas outlines.
    disp.apply_mask_from_boundaries(
        bf_left_boundaries_flat,
        invert=True,  # Mask outside the boundaries
        draw=False
    )

    # Equal aspect, no axis decorations, flipped y-axis.
    disp.ax_.set_aspect('equal')
    disp.ax_.axis('off')
    disp.ax_.set_ylim(disp.ax_.get_ylim()[::-1])
    plt.tight_layout()
    
    return disp.ax_

#     # Save plot
#     plt.savefig(output_path, dpi=cfg.plotting.dpi)

#     # Draw old boundaries
#     for k, boundary_coords in bf_left_boundaries_flat.items():
#         disp.ax_.plot(*boundary_coords.T, c="k", lw=0.5)
#     plt.tight_layout()
# #     plt.savefig(f"{output_path}_withCCF.png", dpi=cfg.plotting.dpi)
# #     plt.close()


In [None]:
# RBF-kernel SVC hyperparameters shared by every fit in this cell.
gamma = 1e-5
C = 1

# One classifier per subsampling stride: every s-th point of the first two
# xyz columns is used as input, with the matching predictions as targets.
models = []
for s in (32, 16, 8, 4, 2):
    model = SVC(kernel='rbf', gamma=gamma, C=C, class_weight=None)
    model.fit(xyz[::s, :2], preds[::s])
    models.append(model)

In [None]:
# Strides must be listed in the same order in which the models were fitted.
for s, model in zip([32, 16, 8, 4, 2], models):
    ax = create_decision_boundary_plot(model, xyz[::s, :2], color_map, None, 1000)

    # First image: decision regions only.
    output_stem = f"svc_gs_32_test_fold3_subsample_{s}"
    output_path = f"{output_stem}.png"
    plt.savefig(output_path)

    # Second image: same figure with the atlas outlines drawn on top.
    for boundary_coords in bf_left_boundaries_flat.values():
        ax.plot(*boundary_coords.T, c="k", lw=0.5)
    plt.tight_layout()
    # Built from the stem so the name does not end up as "....png_withCCF.png".
    plt.savefig(f"{output_stem}_withCCF.png")
    plt.show()

In [None]:
from sklearn import svm
from sklearn.inspection import DecisionBoundaryDisplay
import joblib
from pathlib import Path


def plot_training_data_with_decision_boundary(
    clf, X, y, ax=None, long_title=False, support_vectors=False, 
):
    """Draw a classifier's decision regions and margin contours.

    Two layers are drawn with ``DecisionBoundaryDisplay.from_estimator``: the
    predicted class as a translucent ``pcolormesh``, and the contours of the
    decision function at levels -1, 0 and 1 (dashed / solid / dashed).

    Parameters:
        clf: fitted estimator passed to ``DecisionBoundaryDisplay``.
        X: data passed to ``DecisionBoundaryDisplay`` to define the grid.
        y: not used by this function.
        ax: axes to draw on. When None, a new 4x3 inch figure is created and
            shown before returning.
        long_title: not used by this function.
        support_vectors: when True, circle ``clf.support_vectors_``.

    Returns:
        None.
    """
    # Remember whether the axes are ours: ``ax`` is rebound just below, so a
    # later ``ax is None`` check could never be true.
    created_axes = ax is None
    if created_axes:
        _, ax = plt.subplots(figsize=(4, 3))

    # Fixed viewing window.
    # NOTE(review): these limits are hard-coded to [-3, 3]; confirm they suit
    # the data being plotted.
    x_min, x_max, y_min, y_max = -3, 3, -3, 3
    ax.set(xlim=(x_min, x_max), ylim=(y_min, y_max))

    # Plot decision boundary and margins
    common_params = {"estimator": clf, "X": X, "ax": ax}
    DecisionBoundaryDisplay.from_estimator(
        **common_params,
        response_method="predict",
        plot_method="pcolormesh",
        alpha=0.3,
    )
    # NOTE(review): a single decision-function contour presumably assumes a
    # binary classifier; verify before using this with multi-class models.
    DecisionBoundaryDisplay.from_estimator(
        **common_params,
        response_method="decision_function",
        plot_method="contour",
        levels=[-1, 0, 1],
        colors=["k", "k", "k"],
        linestyles=["--", "-", "--"],
    )

    if support_vectors:
        # Plot bigger circles around samples that serve as support vectors
        ax.scatter(
            clf.support_vectors_[:, 0],
            clf.support_vectors_[:, 1],
            s=150,
            facecolors="none",
            edgecolors="k",
        )

    if created_axes:
        plt.show()

In [None]:
 # Load existing models
# Collect the saved SVMs, keyed by their gamma value; a missing file is
# reported and skipped.
models = {}
for gamma in [0.00001, 0.00005, 0.0001]:
    model_path = Path(
        f"/grid/zador/home/benjami/brain-annotation/outputs/predicted_svc_boundaries_32/subsample_every__C_1.0_unbalanced/svm_gamma_{gamma:.5f}.joblib"
    )
    if model_path.exists():
        models[gamma] = joblib.load(model_path)
        continue
    print(f"No model file found for gamma={gamma}")
    print(model_path)

In [None]:
# NOTE(review): unfinished scratch expression; the call is never closed, so
# this cell cannot run as written. Complete it or remove the cell.
np.logspace(1

In [None]:
from matplotlib.colors import ListedColormap

# Turn the category -> colour dictionary into a ListedColormap whose entries
# follow the sorted category order.
color_map  # bare expression (no assignment)

unique_labels = sorted(color_map)
colors = list(map(color_map.__getitem__, unique_labels))
cmap = ListedColormap(colors)

# Label -> index mapping, currently disabled:
# label_to_index = {label: i for i, label in enumerate(unique_labels)}
cmap  # bare expression (no assignment)

In [None]:
from typing import Optional, Union, Tuple, Callable, Dict, Any, List
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from sklearn.base import BaseEstimator
import warnings


class UncertaintyBoundaryDisplay:
    """Display class for visualizing decision boundaries with uncertainty highlighting.

    Draws a response array (typically predicted classes) over a 2D grid and
    overlays a single-color layer on the grid cells flagged in
    ``uncertainty_mask``. The display only stores and draws what it is given;
    the grid, the response and the mask are computed by the caller.

    Attributes:
        xx0: np.ndarray
            First axis grid coordinates.
        xx1: np.ndarray
            Second axis grid coordinates.
        response: np.ndarray
            The classifier's response values across the grid.
        response_mesh: np.ndarray
            Copy of ``response`` taken at construction, used when redrawing.
        uncertainty_mask: np.ndarray
            Boolean mask indicating uncertain regions.
        estimator_: BaseEstimator
            The fitted classifier.
        ax_: plt.Axes
            The matplotlib axes.
        surface_:
            The artist returned by ``pcolormesh`` or ``contourf``
            (``None`` until ``plot()`` is called).
        uncertainty_surface_:
            The ``pcolormesh`` artist showing the uncertain cells
            (``None`` until ``plot()`` is called).
        colorbar_:
            Reserved for a colorbar; currently never created.
    """

    def __init__(
        self,
        xx0: np.ndarray,
        xx1: np.ndarray,
        response: np.ndarray,
        uncertainty_mask: np.ndarray,
        estimator: BaseEstimator,
        ax: Optional[plt.Axes] = None,
    ) -> None:
        """Initialize the UncertaintyBoundaryDisplay.

        Parameters:
            xx0: np.ndarray
                First axis grid coordinates.
            xx1: np.ndarray
                Second axis grid coordinates.
            response: np.ndarray
                The classifier's response values across the grid.
            uncertainty_mask: np.ndarray
                Boolean mask indicating uncertain regions.
            estimator: BaseEstimator
                The fitted classifier.
            ax: Optional[plt.Axes]
                The matplotlib axes to plot on; the current axes are used if None.
        """
        self.xx0 = xx0
        self.xx1 = xx1
        self.response = response
        self.response_mesh = response.copy()  # Store original for masking operations
        self.uncertainty_mask = uncertainty_mask
        self.estimator_ = estimator
        # Explicit None test: only a missing axes should fall back to plt.gca().
        self.ax_ = ax if ax is not None else plt.gca()

        # These will be set in plot()
        self.surface_ = None
        self.uncertainty_surface_ = None
        self.colorbar_ = None
        self._plot_method = None
        self._surface_style = {}

    def plot(
        self,
        plot_method: str = "pcolormesh",
        alpha: float = 1.0,
        cmap: Union[str, LinearSegmentedColormap] = "viridis",
        uncertainty_color: str = "black",
        uncertainty_alpha: float = 0.7,
        **kwargs: Any
    ) -> "UncertaintyBoundaryDisplay":
        """Plot the decision boundaries and uncertainty regions.

        Parameters:
            plot_method: str
                The plotting method to use ('contourf' or 'pcolormesh').
            alpha: float
                The alpha blending value for the decision boundaries.
            cmap: Union[str, LinearSegmentedColormap]
                The colormap to use for decision boundaries.
            uncertainty_color: str
                The color to use for uncertain regions.
            uncertainty_alpha: float
                The alpha blending value for uncertain regions.
            **kwargs: Any
                Additional keyword arguments passed to the plotting method.

        Returns:
            UncertaintyBoundaryDisplay: self

        Raises:
            ValueError: If plot_method is not 'contourf' or 'pcolormesh'.
        """
        if plot_method not in ("contourf", "pcolormesh"):
            raise ValueError(
                f"Plot method {plot_method} not supported. Use 'contourf' or 'pcolormesh'."
            )

        # Plot decision boundaries with the requested Axes method.
        draw_surface = getattr(self.ax_, plot_method)
        self.surface_ = draw_surface(
            self.xx0, self.xx1, self.response, alpha=alpha, cmap=cmap, **kwargs
        )
        # Store a reference to the mappable for compatibility with set_array
        self._mappable = self.surface_
        # Remember how the surface was drawn: masking needs to know whether it
        # can edit the artist in place (pcolormesh) or must redraw (contourf),
        # and a redraw should reuse the same styling.
        self._plot_method = plot_method
        self._surface_style = {"alpha": alpha, "cmap": cmap, **kwargs}

        # Overlay uncertainty regions: every cell that is not uncertain is
        # masked out, and a one-color colormap paints the remaining cells.
        uncertainty_regions = np.ma.masked_where(~self.uncertainty_mask, self.uncertainty_mask)
        self.uncertainty_surface_ = self.ax_.pcolormesh(
            self.xx0, self.xx1, uncertainty_regions,
            cmap=LinearSegmentedColormap.from_list("", [uncertainty_color, uncertainty_color]),
            alpha=uncertainty_alpha
        )

        return self

    def plot_samples(
        self,
        X: np.ndarray,
        y: np.ndarray,
        markers: Optional[List[str]] = None,
        colors: Optional[List[str]] = None,
        scatter_kwargs: Optional[Dict[str, Any]] = None
    ) -> "UncertaintyBoundaryDisplay":
        """Scatter the samples on the display's axes, one series per class.

        Parameters:
            X: np.ndarray
                The feature data, shape (n_samples, 2).
            y: np.ndarray
                The target data, shape (n_samples,).
            markers: Optional[List[str]]
                List of markers to use for each class; cycled when there are
                more classes than markers.
            colors: Optional[List[str]]
                List of colors to use for each class; cycled when there are
                more classes than colors. Matplotlib's defaults are used when
                None or empty.
            scatter_kwargs: Optional[Dict[str, Any]]
                Additional arguments passed to plt.scatter.

        Returns:
            UncertaintyBoundaryDisplay: self
        """
        scatter_kwargs = scatter_kwargs or {}
        markers = markers or ["o", "s", "^", "v", "<", ">", "d", "p", "*"]
        # len() rather than truthiness so array-like color sequences work too.
        has_colors = colors is not None and len(colors) > 0

        for i, cls in enumerate(np.unique(y)):
            selected = y == cls
            self.ax_.scatter(
                X[selected, 0],
                X[selected, 1],
                c=[colors[i % len(colors)]] if has_colors else None,
                marker=markers[i % len(markers)],
                label=f"Class {cls}",
                **scatter_kwargs
            )

        self.ax_.legend()
        return self

    @staticmethod
    def _with_mask(values: Any, mask: np.ndarray) -> np.ma.MaskedArray:
        """Return ``values`` as a masked array with ``mask`` added to any existing mask."""
        # The artist may hand its data back flattened or 2D depending on the
        # matplotlib version, so the mask is reshaped to whatever it returns.
        return np.ma.array(values, mask=np.reshape(mask, np.shape(values)), keep_mask=True)

    def apply_mask_from_boundaries(
        self,
        boundaries: Union[List[np.ndarray], Dict[Any, np.ndarray]],
        invert: bool = True,
        draw: bool = True,
    ) -> "UncertaintyBoundaryDisplay":
        """Apply a mask to the visualization based on a set of boundary polygons.

        Grid cells are tested against the union of the polygons using the
        display's own grid coordinates (``xx0``/``xx1``), so the result does
        not depend on the current axes limits or axis direction.

        Parameters:
            boundaries: Union[List[np.ndarray], Dict[Any, np.ndarray]]
                List or dictionary of boundary polygons. Each polygon should be
                a numpy array of shape (n_points, 2) defining the boundary vertices.
            invert: bool
                If True, mask points outside the boundaries. If False, mask points
                inside the boundaries.
            draw: bool
                If True (default), also draw each polygon outline on the axes.

        Returns:
            UncertaintyBoundaryDisplay: self

        Raises:
            ValueError: If no surface has been created yet.
        """
        if self.surface_ is None:
            raise ValueError("No surface to mask. Call plot() first.")

        # Imported locally so the name does not clash with any other `Path`
        # bound in the notebook namespace.
        from matplotlib.path import Path

        # Convert dictionary to list if needed
        boundary_list = list(boundaries.values()) if isinstance(boundaries, dict) else list(boundaries)

        # Test the exact coordinates the response was evaluated at.
        points = np.column_stack((np.ravel(self.xx0), np.ravel(self.xx1)))

        # A point is "inside" if any polygon contains it.
        inside = np.zeros(points.shape[0], dtype=bool)
        for boundary in boundary_list:
            inside |= Path(boundary).contains_points(points)

        hidden = ~inside if invert else inside
        mask = hidden.reshape(np.shape(self.response))

        if self._plot_method == "contourf":
            # A contour set cannot be masked in place, so the axes are cleared
            # (this removes every other artist on them) and the surface is
            # redrawn from the original response with the original styling.
            masked_response = np.ma.array(self.response_mesh, mask=mask)
            self.ax_.clear()
            self.surface_ = self.ax_.contourf(
                self.xx0, self.xx1, masked_response, **self._surface_style
            )
            self._mappable = self.surface_

            if self.uncertainty_surface_ is not None:
                previous = self.uncertainty_surface_
                # Keep only cells that are uncertain AND not hidden by the boundary mask.
                uncertainty_regions = np.ma.masked_where(
                    ~self.uncertainty_mask | mask,
                    self.uncertainty_mask
                )
                self.uncertainty_surface_ = self.ax_.pcolormesh(
                    self.xx0, self.xx1, uncertainty_regions,
                    cmap=previous.cmap,
                    alpha=previous.get_alpha()
                )
        else:
            # pcolormesh artists can be edited in place; an existing mask
            # (e.g. from an earlier call) is kept and extended.
            self.surface_.set_array(self._with_mask(self.surface_.get_array(), mask))
            if self.uncertainty_surface_ is not None:
                self.uncertainty_surface_.set_array(
                    self._with_mask(self.uncertainty_surface_.get_array(), mask)
                )

        if draw:
            for boundary in boundary_list:
                self.ax_.plot(*np.asarray(boundary).T, c="k", lw=0.5)

        return self


def from_estimator(
    estimator: BaseEstimator,
    X: np.ndarray,
    uncertainty_threshold: float = 0.1,
    grid_resolution: int = 100,
    eps: float = 0.01,
    ax: Optional[plt.Axes] = None,
    response_method: str = "decision_function",
    plot_boundary: bool = True,
    plot_method: str = "pcolormesh",
    alpha: float = 1.0,
    cmap: Union[str, LinearSegmentedColormap] = "viridis",
    uncertainty_color: str = "black",
    uncertainty_alpha: float = 0.7,
    **kwargs: Any
) -> UncertaintyBoundaryDisplay:
    """Create a UncertaintyBoundaryDisplay from a fitted estimator.

    Evaluates the estimator on a regular grid spanning ``X`` and flags as
    uncertain every grid point where the gap between the two highest scores
    (decision values or probabilities) is below ``uncertainty_threshold``.

    NOTE: a later notebook cell defines another module-level ``from_estimator``
    (for DecisionBoundaryEdgeDisplay); once that cell has run, this name refers
    to that function instead.

    Parameters:
        estimator: BaseEstimator
            Fitted classifier.
        X: np.ndarray
            Array-like of shape (n_samples, 2); its range defines the grid.
        uncertainty_threshold: float
            The threshold below which the difference between top two classes
            is considered uncertain.
        grid_resolution: int
            Number of points to use for each grid dimension.
        eps: float
            Amount added on each side of the data range when building the grid.
        ax: Optional[plt.Axes]
            Axes to plot on; the current axes are used if None.
        response_method: str
            The method of the estimator to use for obtaining scores.
            Options: 'decision_function', 'predict_proba'. If
            'decision_function' is requested but missing, 'predict_proba' is
            used with a warning.
        plot_boundary: bool
            Whether to plot the decision boundary immediately.
        plot_method: str
            The method to use for plotting ('contourf' or 'pcolormesh').
        alpha: float
            The alpha blending value for the decision boundaries.
        cmap: Union[str, LinearSegmentedColormap]
            The colormap to use for decision boundaries.
        uncertainty_color: str
            The color to use for uncertain regions.
        uncertainty_alpha: float
            The alpha blending value for uncertain regions.
        **kwargs: Any
            Additional arguments passed to the plotting method.

    Returns:
        UncertaintyBoundaryDisplay: The configured display object.

    Raises:
        ValueError: If X is not a 2D array with exactly two columns, if the
            estimator doesn't implement the required method, if
            response_method is unknown, or if the scores have fewer than two
            columns.
    """
    X = np.asarray(X)
    if X.ndim != 2:
        # Without this check a 1-D input fails with an IndexError on X.shape[1].
        raise ValueError(
            f"Expected a 2D array of shape (n_samples, 2), got an array with {X.ndim} dimension(s)."
        )
    if X.shape[1] != 2:
        raise ValueError(
            f"Expected 2 features, got {X.shape[1]}. UncertaintyBoundaryDisplay only supports 2D visualization."
        )

    # Get the required prediction method
    if response_method == "decision_function":
        if not hasattr(estimator, "decision_function"):
            if hasattr(estimator, "predict_proba"):
                warnings.warn(
                    "The estimator doesn't have 'decision_function' method. Using 'predict_proba' instead."
                )
                response_method = "predict_proba"
            else:
                raise ValueError(
                    "The estimator doesn't implement 'decision_function' or 'predict_proba'."
                )
    elif response_method == "predict_proba":
        if not hasattr(estimator, "predict_proba"):
            raise ValueError("The estimator doesn't implement 'predict_proba'.")
    else:
        raise ValueError(
            f"Response method {response_method} not supported. Use 'decision_function' or 'predict_proba'."
        )

    # Create the grid
    x0_min, x0_max = X[:, 0].min() - eps, X[:, 0].max() + eps
    x1_min, x1_max = X[:, 1].min() - eps, X[:, 1].max() + eps

    xx0, xx1 = np.meshgrid(
        np.linspace(x0_min, x0_max, grid_resolution),
        np.linspace(x1_min, x1_max, grid_resolution)
    )
    X_grid = np.c_[xx0.ravel(), xx1.ravel()]

    # Obtain one score column per candidate for every grid point.
    if response_method == "decision_function":
        scores = np.asarray(estimator.decision_function(X_grid))
        if scores.ndim == 1:
            # Binary case: a single signed score; mirror it so both classes
            # have a column and the top-two gap is 2 * |score|.
            scores = np.column_stack([-scores, scores])
        # NOTE(review): if the estimator returns pairwise (one-vs-one) decision
        # values, the columns are not per-class and the top-two gap may not be
        # meaningful -- confirm the estimator's decision_function shape.
    else:  # predict_proba
        scores = np.asarray(estimator.predict_proba(X_grid))

    if scores.ndim != 2 or scores.shape[1] < 2:
        raise ValueError(
            "At least two score columns are required to compare the top two predictions."
        )

    # Only the two largest scores per row are needed, so a partial partition
    # replaces a full sort: the last two columns hold them in ascending order.
    top_two = np.partition(scores, -2, axis=1)[:, -2:]
    score_diff = top_two[:, 1] - top_two[:, 0]

    # Create uncertainty mask where difference is below threshold
    uncertainty_mask = (score_diff < uncertainty_threshold).reshape(xx0.shape)

    # Create color-coded response grid (use predict for class labels)
    response = np.asarray(estimator.predict(X_grid)).reshape(xx0.shape)

    # Create the display object
    display = UncertaintyBoundaryDisplay(
        xx0=xx0,
        xx1=xx1,
        response=response,
        uncertainty_mask=uncertainty_mask,
        estimator=estimator,
        ax=ax
    )

    # Plot the boundary if requested
    if plot_boundary:
        display.plot(
            plot_method=plot_method,
            alpha=alpha,
            cmap=cmap,
            uncertainty_color=uncertainty_color,
            uncertainty_alpha=uncertainty_alpha,
            **kwargs
        )

    return display


# Example usage

In [None]:
from typing import Optional, Union, Tuple, Callable, Dict, Any, List, Literal
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from sklearn.base import BaseEstimator
import warnings
from scipy import ndimage


class DecisionBoundaryEdgeDisplay:
    """Visualization of classifier decision boundaries using edge detection.

    Draws a precomputed boolean ``boundary_mask`` (the grid cells where the
    predicted class changes) as a single-color layer, optionally on top of the
    filled decision regions. The grid, the response and the mask are computed
    by the caller.

    Attributes:
        xx0: np.ndarray
            First axis grid coordinates.
        xx1: np.ndarray
            Second axis grid coordinates.
        response: np.ndarray
            The classifier's predicted classes across the grid.
        response_mesh: np.ndarray
            Copy of ``response`` taken at construction.
        boundary_mask: np.ndarray
            Boolean mask indicating boundary pixels.
        estimator_: BaseEstimator
            The fitted classifier.
        ax_: plt.Axes
            The matplotlib axes.
        surface_:
            The ``pcolormesh`` artist of the filled regions (``None`` unless
            ``plot(fill_regions=True)`` was called).
        boundary_surface_:
            The ``pcolormesh`` artist of the boundary cells (``None`` until
            ``plot()`` is called).
        colorbar_:
            The colorbar of the filled regions, if one was requested.
    """

    def __init__(
        self,
        xx0: np.ndarray,
        xx1: np.ndarray,
        response: np.ndarray,
        boundary_mask: np.ndarray,
        estimator: BaseEstimator,
        ax: Optional[plt.Axes] = None,
    ) -> None:
        """Initialize the DecisionBoundaryEdgeDisplay.

        Parameters:
            xx0: np.ndarray
                First axis grid coordinates.
            xx1: np.ndarray
                Second axis grid coordinates.
            response: np.ndarray
                The classifier's predicted classes across the grid.
            boundary_mask: np.ndarray
                Boolean mask indicating boundary pixels.
            estimator: BaseEstimator
                The fitted classifier.
            ax: Optional[plt.Axes]
                The matplotlib axes to plot on; the current axes are used if None.
        """
        self.xx0 = xx0
        self.xx1 = xx1
        self.response = response
        self.response_mesh = response.copy()
        self.boundary_mask = boundary_mask
        self.estimator_ = estimator
        # Explicit None test: only a missing axes should fall back to plt.gca().
        self.ax_ = ax if ax is not None else plt.gca()

        # These will be set in plot(). boundary_surface_ must exist up front so
        # the "call plot() first" check in apply_mask_from_boundaries can run.
        self.surface_ = None
        self.boundary_surface_ = None
        self.colorbar_ = None

    def plot(
        self,
        fill_regions: bool = False,
        boundary_color: str = "black",
        boundary_width: float = 1.0,
        boundary_alpha: float = 1.0,
        regions_alpha: float = 0.5,
        regions_cmap: Union[str, LinearSegmentedColormap] = "viridis",
        colorbar: bool = False,
        **kwargs: Any
    ) -> "DecisionBoundaryEdgeDisplay":
        """Plot the decision boundary edges.

        Parameters:
            fill_regions: bool
                Whether to fill the decision regions with colors.
            boundary_color: str
                The color for the boundary cells.
            boundary_width: float
                Forwarded as ``linewidth`` to the boundary ``pcolormesh``.
                NOTE(review): this is the mesh edge width, not the thickness
                of the boundary band -- confirm it has the intended effect.
            boundary_alpha: float
                The alpha (transparency) value for the boundary cells.
            regions_alpha: float
                The alpha blending value for filled regions.
            regions_cmap: Union[str, LinearSegmentedColormap]
                The colormap to use for regions if filled.
            colorbar: bool
                Whether to add a colorbar for the filled regions (only used
                when ``fill_regions`` is True).
            **kwargs: Any
                Additional keyword arguments passed to the regions ``pcolormesh``.

        Returns:
            DecisionBoundaryEdgeDisplay: self
        """
        # Optionally show filled decision regions
        if fill_regions:
            self.surface_ = self.ax_.pcolormesh(
                self.xx0, self.xx1, self.response,
                alpha=regions_alpha,
                cmap=regions_cmap,
                **kwargs
            )
            if colorbar:
                self.colorbar_ = plt.colorbar(self.surface_, ax=self.ax_)

        # Non-boundary cells are masked out; a one-color colormap paints the rest.
        boundary_regions = np.ma.masked_where(~self.boundary_mask, self.boundary_mask)
        self.boundary_surface_ = self.ax_.pcolormesh(
            self.xx0, self.xx1, boundary_regions,
            cmap=LinearSegmentedColormap.from_list("", [boundary_color, boundary_color]),
            alpha=boundary_alpha,
            linewidth=boundary_width
        )

        return self

    def plot_samples(
        self,
        X: np.ndarray,
        y: np.ndarray,
        markers: Optional[List[str]] = None,
        colors: Optional[List[str]] = None,
        scatter_kwargs: Optional[Dict[str, Any]] = None
    ) -> "DecisionBoundaryEdgeDisplay":
        """Scatter the samples on the display's axes, one series per class.

        Parameters:
            X: np.ndarray
                The feature data, shape (n_samples, 2).
            y: np.ndarray
                The target data, shape (n_samples,).
            markers: Optional[List[str]]
                List of markers to use for each class; cycled when there are
                more classes than markers.
            colors: Optional[List[str]]
                List of colors to use for each class; cycled when there are
                more classes than colors. Matplotlib's defaults are used when
                None or empty.
            scatter_kwargs: Optional[Dict[str, Any]]
                Additional arguments passed to plt.scatter.

        Returns:
            DecisionBoundaryEdgeDisplay: self
        """
        scatter_kwargs = scatter_kwargs or {}
        markers = markers or ["o", "s", "^", "v", "<", ">", "d", "p", "*"]
        # len() rather than truthiness so array-like color sequences work too.
        has_colors = colors is not None and len(colors) > 0

        for i, cls in enumerate(np.unique(y)):
            selected = y == cls
            self.ax_.scatter(
                X[selected, 0],
                X[selected, 1],
                c=[colors[i % len(colors)]] if has_colors else None,
                marker=markers[i % len(markers)],
                label=f"Class {cls}",
                **scatter_kwargs
            )

        self.ax_.legend()
        return self

    @staticmethod
    def _with_mask(values: Any, mask: np.ndarray) -> np.ma.MaskedArray:
        """Return ``values`` as a masked array with ``mask`` added to any existing mask."""
        # The artist may hand its data back flattened or 2D depending on the
        # matplotlib version, so the mask is reshaped to whatever it returns.
        return np.ma.array(values, mask=np.reshape(mask, np.shape(values)), keep_mask=True)

    def apply_mask_from_boundaries(
        self,
        boundaries: Union[List[np.ndarray], Dict[Any, np.ndarray]],
        invert: bool = True,
        draw: bool = True,
    ) -> "DecisionBoundaryEdgeDisplay":
        """Apply a mask to the visualization based on a set of boundary polygons.

        Grid cells are tested against the union of the polygons using the
        display's own grid coordinates (``xx0``/``xx1``), so the result does
        not depend on the current axes limits or axis direction.

        Parameters:
            boundaries: Union[List[np.ndarray], Dict[Any, np.ndarray]]
                List or dictionary of boundary polygons. Each polygon should be
                a numpy array of shape (n_points, 2) defining the boundary vertices.
            invert: bool
                If True, mask points outside the boundaries. If False, mask points
                inside the boundaries.
            draw: bool
                If True (default), also draw each polygon outline on the axes.

        Returns:
            DecisionBoundaryEdgeDisplay: self

        Raises:
            ValueError: If no boundary surface has been created yet.
        """
        if self.boundary_surface_ is None:
            raise ValueError("No boundary surface to mask. Call plot() first.")

        # Imported locally so the name does not clash with any other `Path`
        # bound in the notebook namespace.
        from matplotlib.path import Path

        # Convert dictionary to list if needed
        boundary_list = list(boundaries.values()) if isinstance(boundaries, dict) else list(boundaries)

        # Test the exact coordinates the response was evaluated at.
        points = np.column_stack((np.ravel(self.xx0), np.ravel(self.xx1)))

        # A point is "inside" if any polygon contains it.
        inside = np.zeros(points.shape[0], dtype=bool)
        for boundary in boundary_list:
            inside |= Path(boundary).contains_points(points)

        hidden = ~inside if invert else inside
        mask = hidden.reshape(np.shape(self.response))

        # Both layers are pcolormesh artists and can be edited in place; an
        # existing mask (boundary cells, earlier calls) is kept and extended.
        self.boundary_surface_.set_array(
            self._with_mask(self.boundary_surface_.get_array(), mask)
        )
        if self.surface_ is not None:
            self.surface_.set_array(self._with_mask(self.surface_.get_array(), mask))

        if draw:
            for boundary in boundary_list:
                self.ax_.plot(*np.asarray(boundary).T, c="k", lw=0.5)

        return self


def from_estimator(
    estimator: BaseEstimator,
    X: np.ndarray,
    grid_resolution: int = 200,
    eps: float = 0.01,
    ax: Optional[plt.Axes] = None,
    detection_method: Literal["difference", "gradient", "sobel"] = "difference",
    fill_regions: bool = False,
    boundary_color: str = "black",
    boundary_width: float = 1.0,
    boundary_alpha: float = 1.0,
    regions_cmap: Union[str, LinearSegmentedColormap] = "viridis",
    regions_alpha: float = 0.5,
    **kwargs: Any
) -> DecisionBoundaryEdgeDisplay:
    """Create a DecisionBoundaryEdgeDisplay from a fitted estimator.

    Predicts a class for every point of a regular grid spanning ``X``, marks
    the grid cells where the prediction changes, and plots them.

    Parameters:
        estimator: BaseEstimator
            Fitted classifier.
        X: np.ndarray
            Array-like of shape (n_samples, 2); its range defines the grid.
        grid_resolution: int
            Number of points to use for each grid dimension.
        eps: float
            Amount added on each side of the data range when building the grid.
        ax: Optional[plt.Axes]
            Axes to plot on; the current axes are used if None.
        detection_method: Literal["difference", "gradient", "sobel"]
            Method to detect class boundaries:
            - "difference": Detects changes between adjacent pixels
            - "gradient": Uses gradient magnitude of the prediction
            - "sobel": Uses Sobel filter for edge detection
        fill_regions: bool
            Whether to show filled decision regions.
        boundary_color: str
            The color for the boundary cells.
        boundary_width: float
            The line width forwarded to the boundary plot.
        boundary_alpha: float
            The alpha (transparency) value for the boundary cells.
        regions_cmap: Union[str, LinearSegmentedColormap]
            The colormap to use for decision regions if filled.
        regions_alpha: float
            The alpha blending value for decision regions if filled.
        **kwargs: Any
            Additional arguments passed to ``DecisionBoundaryEdgeDisplay.plot``.

    Returns:
        DecisionBoundaryEdgeDisplay: The configured display object.

    Raises:
        ValueError: If X is not a 2D array with exactly two columns, or if
            detection_method is not supported.
    """
    X = np.asarray(X)
    if X.ndim != 2:
        # Without this check a 1-D input fails with an IndexError on X.shape[1].
        raise ValueError(
            f"Expected a 2D array of shape (n_samples, 2), got an array with {X.ndim} dimension(s)."
        )
    if X.shape[1] != 2:
        raise ValueError(
            f"Expected 2 features, got {X.shape[1]}. DecisionBoundaryEdgeDisplay only supports 2D visualization."
        )
    # Reject a bad method before paying for grid_resolution**2 predictions.
    if detection_method not in ("difference", "gradient", "sobel"):
        raise ValueError(
            f"Detection method {detection_method} not supported. "
            "Use 'difference', 'gradient', or 'sobel'."
        )

    # Create the grid
    x0_min, x0_max = X[:, 0].min() - eps, X[:, 0].max() + eps
    x1_min, x1_max = X[:, 1].min() - eps, X[:, 1].max() + eps

    xx0, xx1 = np.meshgrid(
        np.linspace(x0_min, x0_max, grid_resolution),
        np.linspace(x1_min, x1_max, grid_resolution)
    )

    # Get predictions for each grid point
    X_grid = np.c_[xx0.ravel(), xx1.ravel()]
    y_pred_raw = np.asarray(estimator.predict(X_grid))

    if hasattr(estimator, "classes_"):
        # Replace class labels by their position in estimator.classes_ so the
        # grid is numeric whatever the label type is.
        class_to_index = {cls: idx for idx, cls in enumerate(estimator.classes_)}
        y_pred = np.array([class_to_index[cls] for cls in y_pred_raw]).reshape(xx0.shape)
    else:
        # For non-classifiers, use the raw predictions
        y_pred = y_pred_raw.reshape(xx0.shape)

    boundary_mask = _edge_mask_from_labels(y_pred, detection_method)

    # Create the display object
    display = DecisionBoundaryEdgeDisplay(
        xx0=xx0,
        xx1=xx1,
        response=y_pred,
        boundary_mask=boundary_mask,
        estimator=estimator,
        ax=ax
    )

    # Plot the boundaries
    display.plot(
        fill_regions=fill_regions,
        boundary_color=boundary_color,
        boundary_width=boundary_width,
        boundary_alpha=boundary_alpha,
        regions_cmap=regions_cmap,
        regions_alpha=regions_alpha,
        **kwargs
    )

    return display


def _edge_mask_from_labels(labels: np.ndarray, detection_method: str) -> np.ndarray:
    """Return a boolean grid marking the cells where ``labels`` changes value."""
    if detection_method == "difference":
        horizontal_change = np.diff(labels, axis=1) != 0
        vertical_change = np.diff(labels, axis=0) != 0

        # Mark both cells on either side of every change.
        boundary_mask = np.zeros_like(labels, dtype=bool)
        boundary_mask[:, :-1] |= horizontal_change
        boundary_mask[:, 1:] |= horizontal_change
        boundary_mask[:-1, :] |= vertical_change
        boundary_mask[1:, :] |= vertical_change
        return boundary_mask

    if detection_method == "gradient":
        # Any non-zero gradient magnitude means a neighbouring cell differs.
        gradient_y, gradient_x = np.gradient(labels.astype(float))
        return np.sqrt(gradient_x**2 + gradient_y**2) > 0

    if detection_method == "sobel":
        # Sobel filter from scipy.ndimage along both axes.
        as_float = labels.astype(float)
        sobel_h = ndimage.sobel(as_float, axis=0)
        sobel_v = ndimage.sobel(as_float, axis=1)
        return np.sqrt(sobel_h**2 + sobel_v**2) > 0

    raise ValueError(
        f"Detection method {detection_method} not supported. "
        "Use 'difference', 'gradient', or 'sobel'."
    )




In [None]:
import joblib
# Load a single previously saved model from disk (per the file name, presumably the
# gamma=1e-4 SVM of fold 0 -- confirm). joblib.load unpickles the file, so only
# point it at trusted outputs.
model = joblib.load('/grid/zador/home/benjami/brain-annotation/outputs/fold0_animal_name_class_weights2_32/svc_boundaries/svm_gamma_0.00010.joblib')

In [None]:
# Filled decision regions of `model` over the first two columns of `xyz`,
# with the class boundaries drawn on top in semi-transparent black.
disp = from_estimator(
    estimator=model,
    X=xyz[:, :2],
    grid_resolution=100,
    detection_method="difference",
    fill_regions=True,
    regions_cmap=cmap,
    regions_alpha=1.0,
    boundary_color="black",
    boundary_alpha=0.8,
    boundary_width=0.5,
)

# Hide everything outside the left-hemisphere outlines and draw the outlines.
disp.apply_mask_from_boundaries(bf_left_boundaries_flat, invert=True, draw=True)

# Cosmetics: equal aspect, no axis decorations, y axis flipped.
disp.ax_.set_aspect('equal')
disp.ax_.axis('off')
disp.ax_.set_ylim(*reversed(disp.ax_.get_ylim()))
plt.tight_layout()

In [None]:
from matplotlib.path import Path
import numpy as np

# Draw the predicted class over the plane with sklearn's display helper.
# `common_params` is kept as a module-level name because a later cell reuses it.
common_params = {"estimator": models[0.00001], "X": xyz[:,:2], "ax": None}
disp = DecisionBoundaryDisplay.from_estimator(
    **common_params,
    response_method="predict",
    plot_method="pcolormesh",
    alpha=1,
    cmap=cmap)
# Use the axes the display actually drew on instead of whatever is "current".
ax = disp.ax_
ax.set_aspect('equal')
ax.grid(False)

# Test the exact grid the surface was evaluated on. Rebuilding the grid from the
# axes limits is only approximately aligned with the mesh, and breaks as soon as
# the limits change or an axis is inverted.
xv, yv = disp.xx0, disp.xx1
points = np.column_stack((xv.ravel(), yv.ravel()))

# Create combined mask using all boundaries: a point is kept if any outline contains it.
# NOTE(review): `Path` here is matplotlib.path.Path (imported at the top of this
# cell); it rebinds the `Path` name that the model-loading cell calls `.exists()`
# on (presumably pathlib.Path), so re-running that cell afterwards will likely
# fail -- consider aliasing one of the two imports.
combined_mask = np.zeros(points.shape[0], dtype=bool)
for boundary in bf_left_boundaries_flat.values():
    path = Path(boundary)
    combined_mask |= path.contains_points(points)

# Reshape mask to match surface array; True means "hide this cell".
mask = ~combined_mask.reshape(disp.surface_.get_array().shape)

# Apply the mask to the surface
disp.surface_.set_array(np.ma.array(disp.surface_.get_array(), mask=mask))

# Draw all boundaries
for boundary in bf_left_boundaries_flat.values():
    ax.plot(*boundary.T, c="k", lw=0.5)

ax.axis('off')
ax.set_ylim(ax.get_ylim()[::-1])
plt.tight_layout()

In [None]:
# Show the dict of loaded models (keyed by gamma in the loading cell) as the cell output.
models

In [None]:
# Contour the decision function of class 44 at the levels -1, 0 and +1
# (solid line at 0, dashed lines at -1 and +1), reusing `common_params`
# from the previous cell.
DecisionBoundaryDisplay.from_estimator(
    response_method="decision_function",
    class_of_interest=44,
    plot_method="contour",
    levels=[-1, 0, 1],
    linestyles=["--", "-", "--"],
    colors=["k", "k", "k"],
    **common_params,
)

In [None]:
# Scratch call on random data: 100 standard-normal 2D points with random 0/1 labels.
# NOTE(review): the loading cell in this notebook only fills `models` for gammas
# 1e-5, 5e-5 and 1e-4, so `models[0.1]` raises KeyError unless it is populated
# elsewhere -- confirm. The first argument is a loaded model here but a kernel
# name ("sigmoid") in the next cell; check which one the function expects.
plot_training_data_with_decision_boundary(models[0.1], np.random.randn(100, 2), np.random.randint(0, 2, 100))

In [None]:
# Plot using the first two columns of `xyz` with `preds` as the targets.
# NOTE(review): the first argument is the string "sigmoid" (presumably a kernel
# name) while the previous cell passes a loaded model in the same position --
# confirm against the function's signature.
plot_training_data_with_decision_boundary("sigmoid", xyz[:, :2], preds,)