# In [2]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.neighbors import KDTree, BallTree  # Issue #4: For efficient BMU search
from tqdm import tqdm  # Issue #10: For progress monitoring
from enum import Enum
from dataclasses import dataclass, field, asdict
from typing import Optional, Callable, List, Dict, Any, Tuple, Union
import pickle
from abc import ABC, abstractmethod
from datetime import datetime
import os


# ARCHITECTURE FIX #1: Enum classes for configuration management
class Topology(Enum):
    """Grid topologies a SOM can be configured with (Issue #15).

    The string values are what ``SOMConfig.to_dict`` / ``SOMConfig.from_dict``
    write and read when a configuration is serialized.
    """

    # Neurons sit at integer (column, row) grid positions.
    RECTANGULAR = "rectangular"
    # Odd-numbered columns are shifted by half a row when coordinates are built.
    HEXAGONAL = "hexagonal"
    # Grid distances wrap around both the width and the height of the map.
    TOROIDAL = "toroidal"


class DecaySchedule(Enum):
    """Decay curves selectable for sigma and the learning rate (Issues #18, #19).

    ``SOM._get_decay_value`` interprets these; the same enum is used for both
    ``SOMConfig.sigma_decay`` and ``SOMConfig.alpha_decay``.
    """

    # initial * exp(-rate * t), with the rate chosen to hit the final value at t_max.
    EXPONENTIAL = "exponential"
    # Straight line from the initial to the final value.
    LINEAR = "linear"
    # initial / (1 + t / t_max).
    INVERSE = "inverse"
    # Half-cosine curve from the initial to the final value.
    COSINE = "cosine"
    # Value is halved once every 100 post-warm-up steps.
    STEP = "step"


class LearningMode(Enum):
    """How samples are consumed within one training epoch (Issue #21).

    ``SOM._process_epoch`` dispatches on this value.
    """

    # Weights are updated after every single sample.
    ONLINE = "online"
    # Updates from all samples are accumulated and applied once per epoch.
    BATCH = "batch"
    # Like BATCH, but applied per chunk of ``SOMConfig.batch_size`` samples.
    MINI_BATCH = "mini_batch"


class DistanceMetric(Enum):
    """Feature-space distance metrics for BMU search (ARCHITECTURE FIX #5).

    Each value matches the name of a static method on ``DistanceCalculator``.
    """

    # Searched with a KDTree.
    EUCLIDEAN = "euclidean"
    # Searched with a BallTree built with metric="manhattan".
    MANHATTAN = "manhattan"
    # No tree is built; searched by brute force on unit-normalized vectors.
    COSINE = "cosine"
    # Searched with a BallTree built with metric="chebyshev".
    CHEBYSHEV = "chebyshev"


class InitStrategy(Enum):
    """Weight initialization strategies (ARCHITECTURE FIX #6).

    ``SOM._initialize_weights`` interprets these. PCA and SAMPLE need training
    data; when none is available that method falls back to random weights.
    """

    # Uniform random values from np.random.random.
    RANDOM = "random"
    # Grid of points laid out along the leading principal component(s) of the data.
    PCA = "pca"
    # Rows drawn at random (with replacement) from the data.
    SAMPLE = "sample"
    # Deterministic gradients derived from each neuron's grid position.
    LINEAR = "linear"


# ARCHITECTURE FIX #2: Configuration management with dataclass
@dataclass
class SOMConfig:
    """Centralized configuration for a SOM.

    Only ``width`` and ``height`` are required. ``initial_sigma`` is filled in
    by ``__post_init__`` when left as ``None``. Use ``to_dict`` / ``from_dict``
    to convert to and from a plain dictionary whose enum fields are strings.
    """

    # Basic parameters: grid size, input dimensionality, default epoch count.
    width: int
    height: int
    n_features: int = 3
    n_iterations: int = 1000

    # Training parameters; batch_size is only used in MINI_BATCH mode.
    learning_mode: LearningMode = LearningMode.ONLINE
    batch_size: int = 32

    # Decay schedules (Issues #18, #19, #20)
    sigma_decay: DecaySchedule = DecaySchedule.EXPONENTIAL
    alpha_decay: DecaySchedule = DecaySchedule.EXPONENTIAL
    initial_sigma: Optional[float] = None  # Auto-calculated if None
    initial_alpha: float = 0.1
    min_sigma: float = 0.01
    min_alpha: float = 0.001
    warmup_steps: int = 0  # epochs during which the initial values are held

    # Topology and distance (Issue #15, ARCHITECTURE FIX #5)
    topology: Topology = Topology.RECTANGULAR
    distance_metric: DistanceMetric = DistanceMetric.EUCLIDEAN

    # Optimization parameters
    cutoff_factor: float = 3.0  # neighborhood radius = cutoff_factor * sigma
    momentum: float = 0.0  # 0 disables momentum (Issue #24)
    dtype: np.dtype = np.float32  # dtype of weights and normalized data

    # Convergence parameters (Issues #16, #17)
    early_stopping: bool = True
    convergence_tolerance: float = 1e-4
    patience: int = 10
    convergence_check_interval: int = 10

    # Performance parameters: epochs between rebuilds of the BMU search tree.
    kdtree_rebuild_interval: int = 25

    # Initialization (ARCHITECTURE FIX #6)
    init_strategy: InitStrategy = InitStrategy.RANDOM

    # Persistence (ARCHITECTURE FIX #3, #4); a falsy interval disables checkpoints.
    checkpoint_interval: Optional[int] = None
    checkpoint_dir: str = "checkpoints"

    # Bounds (Issue #23): weights are clipped to (low, high).
    weight_bounds: Tuple[float, float] = (0.0, 1.0)

    # Reproducibility (Issue #11)
    seed: Optional[int] = None

    def __post_init__(self):
        """Default ``initial_sigma`` to half of the larger grid dimension."""
        if self.initial_sigma is None:
            self.initial_sigma = max(self.width, self.height) / 2

    def to_dict(self) -> Dict:
        """Return the configuration as a dictionary.

        Enum members are replaced by their string values; every other field
        is returned as ``dataclasses.asdict`` produces it.
        """
        return {
            key: value.value if isinstance(value, Enum) else value
            for key, value in asdict(self).items()
        }

    @classmethod
    def from_dict(cls, config_dict: Dict) -> "SOMConfig":
        """Build a configuration from a dictionary such as ``to_dict`` returns.

        String values of enum-typed fields are converted back to enum members;
        values that are already enum members are passed through unchanged.

        Args:
            config_dict: Mapping of field names to values. It is not modified.

        Returns:
            A new ``SOMConfig``.

        Raises:
            ValueError: If a string is not a valid value of its enum.
            TypeError: If a key is not a ``SOMConfig`` field.
        """
        enum_fields = {
            "topology": Topology,
            "sigma_decay": DecaySchedule,
            "alpha_decay": DecaySchedule,
            "learning_mode": LearningMode,
            "distance_metric": DistanceMetric,
            "init_strategy": InitStrategy,
        }
        # Work on a shallow copy: the caller's dictionary must keep its
        # original (string) values so it can be reused or serialized again.
        params = dict(config_dict)
        for field_name, enum_class in enum_fields.items():
            raw_value = params.get(field_name)
            if isinstance(raw_value, str):
                params[field_name] = enum_class(raw_value)
        return cls(**params)


# ARCHITECTURE FIX #9: Callback system for monitoring and intervention
class Callback(ABC):
    """Interface for objects that observe or steer SOM training.

    Concrete callbacks must implement all four hooks; hooks they do not care
    about can simply do nothing.
    """

    @abstractmethod
    def on_epoch_begin(self, epoch: int, som: "SOM") -> None:
        """Hook invoked at the start of epoch ``epoch``."""

    @abstractmethod
    def on_epoch_end(self, epoch: int, som: "SOM", metrics: Dict) -> None:
        """Hook invoked after epoch ``epoch`` with that epoch's metrics."""

    @abstractmethod
    def on_training_begin(self, som: "SOM") -> None:
        """Hook invoked once before the training loop starts."""

    @abstractmethod
    def on_training_end(self, som: "SOM") -> None:
        """Hook invoked once after the training loop has finished."""


class CheckpointCallback(Callback):
    """Callback that saves the SOM to disk during and after training.

    A file ``checkpoint_epoch_<n>.pkl`` is written through ``som.save`` on
    every epoch that is a multiple of ``interval``, and ``final_model.pkl``
    when training ends. Save failures never interrupt training; they are only
    reported when the SOM is verbose.
    """

    def __init__(self, checkpoint_dir: str, interval: int = 100):
        """Store the settings and create ``checkpoint_dir`` if it is missing."""
        self.checkpoint_dir = checkpoint_dir
        self.interval = interval
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def on_training_begin(self, som: "SOM") -> None:
        """Nothing to do before training starts."""

    def on_epoch_begin(self, epoch: int, som: "SOM") -> None:
        """Nothing to do at the start of an epoch."""

    def on_epoch_end(self, epoch: int, som: "SOM", metrics: Dict) -> None:
        """Write a checkpoint when ``epoch`` is a multiple of the interval."""
        if epoch % self.interval != 0:
            return

        checkpoint_path = os.path.join(
            self.checkpoint_dir, f"checkpoint_epoch_{epoch}.pkl"
        )
        try:
            som.save(checkpoint_path)
            if som.verbose:
                print(f"Checkpoint saved: {checkpoint_path}")
        except (IOError, OSError) as e:
            # Best effort: a failed checkpoint must not abort the run.
            if som.verbose:
                print(f"Warning: Failed to save checkpoint: {e}")

    def on_training_end(self, som: "SOM") -> None:
        """Write the final model next to the periodic checkpoints."""
        final_path = os.path.join(self.checkpoint_dir, "final_model.pkl")
        try:
            som.save(final_path)
        except (IOError, OSError) as e:
            if som.verbose:
                print(f"Warning: Failed to save final model: {e}")


class EarlyStoppingCallback(Callback):
    """Callback that stops training once a monitored metric stalls.

    After each epoch the metric named ``monitor`` is compared with the best
    value seen so far. If it fails to improve by more than ``min_delta`` for
    ``patience`` consecutive epochs, ``som.stop_training`` is set to ``True``.
    A metric missing from the metrics dict counts as infinity.
    """

    def __init__(
        self, monitor: str = "qe", patience: int = 10, min_delta: float = 1e-4
    ):
        """Store the settings and start with no best value recorded."""
        self.monitor = monitor
        self.patience = patience
        self.min_delta = min_delta
        self.wait = 0
        self.best_value = float("inf")

    def on_training_begin(self, som: "SOM") -> None:
        """Forget the state of any previous run."""
        self.wait = 0
        self.best_value = float("inf")

    def on_epoch_begin(self, epoch: int, som: "SOM") -> None:
        """Nothing to do at the start of an epoch."""

    def on_epoch_end(self, epoch: int, som: "SOM", metrics: Dict) -> None:
        """Update the stall counter and request a stop when patience runs out."""
        current_value = metrics.get(self.monitor, float("inf"))

        improved = current_value < self.best_value - self.min_delta
        if improved:
            self.best_value = current_value
            self.wait = 0
            return

        self.wait += 1
        if self.wait < self.patience:
            return

        som.stop_training = True
        if som.verbose:
            print(f"Early stopping triggered at epoch {epoch}")

    def on_training_end(self, som: "SOM") -> None:
        """Nothing to do after training."""


# ARCHITECTURE FIX #5: Distance metric functions
class DistanceCalculator:
    """Namespace of distance functions operating along the last axis.

    Every function takes two broadcastable arrays and reduces over
    ``axis=-1``, so stacks of vectors can be compared in one call.
    """

    @staticmethod
    def euclidean(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        """Return the L2 norm of ``a - b``."""
        delta = a - b
        return np.linalg.norm(delta, axis=-1)

    @staticmethod
    def manhattan(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        """Return the sum of absolute coordinate differences."""
        abs_delta = np.abs(a - b)
        return np.sum(abs_delta, axis=-1)

    @staticmethod
    def cosine(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        """Return ``1 - cosine similarity``.

        A small constant (1e-8) is added to the product of the norms so that
        zero vectors do not cause a division by zero.
        """
        magnitudes = np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1)
        similarity = np.sum(a * b, axis=-1) / (magnitudes + 1e-8)
        return 1 - similarity

    @staticmethod
    def chebyshev(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        """Return the largest absolute coordinate difference."""
        abs_delta = np.abs(a - b)
        return np.max(abs_delta, axis=-1)


# ARCHITECTURE FIX #1: Main SOM class with proper OOP design
class SOM:
    """
    Self-Organizing Map with comprehensive architecture improvements

    Fixes all 23 original issues plus 9 architecture improvements
    """

    def __init__(self, config: SOMConfig, verbose: bool = True):
        """
        Build an untrained SOM from a configuration.

        Seeds NumPy's global random generator when ``config.seed`` is set,
        prepares the bookkeeping dictionaries, and then creates the neuron
        coordinates and initial weights.

        Args:
            config: SOMConfig object with all parameters
            verbose: Whether to print training progress
        """
        self.config = config
        self.verbose = verbose

        # Issue #11: seed before anything draws random numbers
        seed = config.seed
        if seed is not None:
            np.random.seed(seed)

        # Core attributes; the arrays and the search tree are created later
        self.n_neurons = config.width * config.height
        self.weights_flat = self.tree = self.neuron_coords = None

        # ARCHITECTURE FIX #7: metadata that accumulates across fit() calls
        self.metadata = dict(
            creation_time=datetime.now().isoformat(),
            training_history=[],
            total_epochs=0,
            total_samples_seen=0,
            config=config.to_dict(),
        )

        # ARCHITECTURE FIX #4: state that lets training resume where it stopped
        self.training_state = dict(
            epoch=0,
            best_error=float("inf"),
            no_improvement_count=0,
            prev_update=None,
        )

        # ARCHITECTURE FIX #9: callbacks are supplied per fit() call
        self.callbacks: List[Callback] = []

        # Set to True (e.g. by a callback) to end the training loop early
        self.stop_training = False

        # ARCHITECTURE FIX #5: distance function matching the configured metric
        self.distance_func = self._get_distance_function()

        # Neuron coordinates and initial weights
        self._initialize_structure()

    def _get_distance_function(self) -> Callable:
        """Return the DistanceCalculator function for the configured metric.

        Raises:
            KeyError: If ``config.distance_metric`` is not a known metric.
        """
        metric = self.config.distance_metric

        if metric == DistanceMetric.EUCLIDEAN:
            return DistanceCalculator.euclidean
        if metric == DistanceMetric.MANHATTAN:
            return DistanceCalculator.manhattan
        if metric == DistanceMetric.COSINE:
            return DistanceCalculator.cosine
        if metric == DistanceMetric.CHEBYSHEV:
            return DistanceCalculator.chebyshev
        raise KeyError(metric)

    def _create_tree(self):
        """(Re)build the BMU search structure from the current weights.

        EUCLIDEAN (and any unrecognized metric) uses a KDTree; MANHATTAN and
        CHEBYSHEV use a BallTree with the matching metric. COSINE builds no
        tree: ``self.tree`` is set to None and a unit-normalized copy of the
        weights is cached in ``self._normalized_weights`` for the brute-force
        search in ``_query_tree``.
        """
        metric = self.config.distance_metric

        # Any previously cached normalized weights are stale now
        self._normalized_weights = None

        if metric == DistanceMetric.COSINE:
            # The epsilon keeps all-zero weight rows from dividing by zero
            row_norms = np.linalg.norm(self.weights_flat, axis=1, keepdims=True) + 1e-8
            unit_weights = self.weights_flat / row_norms
            self._normalized_weights = unit_weights.astype(self.config.dtype)
            self.tree = None
            return

        ball_tree_metric = {
            DistanceMetric.MANHATTAN: "manhattan",
            DistanceMetric.CHEBYSHEV: "chebyshev",
        }.get(metric)

        if ball_tree_metric is None:
            self.tree = KDTree(self.weights_flat)
        else:
            self.tree = BallTree(self.weights_flat, metric=ball_tree_metric)

    def _query_tree(self, data: np.ndarray, k: int = 1):
        """Find the ``k`` nearest neurons for every row of ``data``.

        Non-cosine metrics are answered by the KDTree/BallTree (created on
        demand if missing). The cosine metric is answered by brute force:
        rows and weights are unit-normalized and the distance is
        ``1 - dot product``.

        Returns:
            ``(distances, indices)``, each of shape (n_samples, k), ordered
            from nearest to farthest.
        """
        if self.config.distance_metric != DistanceMetric.COSINE:
            if self.tree is None:
                # Safety net: build the tree if nobody has done so yet
                self._create_tree()
            return self.tree.query(data, k=k)

        # --- cosine: brute force on unit vectors ---
        queries = np.asarray(data, dtype=self.config.dtype)
        query_norms = np.linalg.norm(queries, axis=1, keepdims=True) + 1e-8
        unit_queries = queries / query_norms

        if getattr(self, "_normalized_weights", None) is None:
            # No cache available (no _create_tree call yet): normalize now
            weight_norms = (
                np.linalg.norm(self.weights_flat, axis=1, keepdims=True) + 1e-8
            )
            self._normalized_weights = self.weights_flat / weight_norms

        # (n_samples, n_neurons) matrix of cosine distances
        dists = 1.0 - unit_queries @ self._normalized_weights.T
        rows = np.arange(dists.shape[0])

        if k == 1:
            nearest = np.argmin(dists, axis=1)
            return dists[rows, nearest].reshape(-1, 1), nearest.reshape(-1, 1)

        # Select the k smallest per row (unordered), then order just those k
        candidates = np.argpartition(dists, kth=k - 1, axis=1)[:, :k]
        row_selector = rows[:, None]
        ranking = np.argsort(dists[row_selector, candidates], axis=1)
        ranked = candidates[row_selector, ranking]
        return dists[row_selector, ranked], ranked

    def _initialize_structure(self):
        """Create the neuron coordinates and, if none exist yet, the weights."""
        # OPTIMIZATION: per-neuron 2D coordinates replace a 4D distance matrix
        self.neuron_coords = self._create_neuron_coordinates()

        if self.weights_flat is not None:
            # Weights already present: keep them
            return
        self._initialize_weights()

    def _create_neuron_coordinates(self) -> np.ndarray:
        """
        Return the 2D grid position of every neuron (Issue #15).

        Neurons are listed column by column, so neuron ``idx`` sits at column
        ``idx // height`` and row ``idx % height``. For the hexagonal
        topology, odd columns are shifted by half a row; every other topology
        uses the plain integer grid.
        """
        stagger = self.config.topology == Topology.HEXAGONAL

        grid = [
            [col, row + (0.5 if stagger and col % 2 else 0)]
            for col in range(self.config.width)
            for row in range(self.config.height)
        ]
        return np.array(grid, dtype=self.config.dtype)

    def _initialize_weights(self, data: Optional[np.ndarray] = None):
        """
        Fill ``self.weights_flat`` according to ``config.init_strategy``.

        PCA and SAMPLE need ``data``; without it (and for RANDOM) the weights
        are drawn uniformly at random. Whatever the strategy, the result is
        finally clipped to ``config.weight_bounds`` (Issue #23).

        Args:
            data: Optional training data of shape (n_samples, n_features).
        """
        strategy = self.config.init_strategy
        dtype = self.config.dtype
        shape = (self.n_neurons, self.config.n_features)
        has_data = data is not None

        if strategy == InitStrategy.PCA and has_data:
            from sklearn.decomposition import PCA

            n_components = min(2, data.shape[1], self.config.n_features)

            if n_components < 2:
                # Only one usable dimension: spread neurons along a line
                pca = PCA(n_components=1)
                pca.fit(data)
                line = np.linspace(-3, 3, self.n_neurons).reshape(-1, 1)
                self.weights_flat = pca.inverse_transform(line).astype(dtype)
            else:
                # Spread neurons over a grid spanned by the first two components
                pca = PCA(n_components=2)
                pca.fit(data)

                first_axis = np.linspace(-3, 3, self.config.width)
                second_axis = np.linspace(-3, 3, self.config.height)
                self.weights_flat = np.zeros(shape, dtype=dtype)

                # Same column-major order as the neuron coordinates
                grid_points = ((u, v) for u in first_axis for v in second_axis)
                for idx, (u, v) in enumerate(grid_points):
                    self.weights_flat[idx] = pca.inverse_transform(np.array([[u, v]]))

            # Keep the projected points inside the unit range
            self.weights_flat = np.clip(self.weights_flat, 0, 1)

        elif strategy == InitStrategy.SAMPLE and has_data:
            # One randomly chosen data row (with replacement) per neuron
            picks = np.random.choice(len(data), self.n_neurons, replace=True)
            self.weights_flat = data[picks].copy().astype(dtype)

        elif strategy == InitStrategy.LINEAR:
            n_features = self.config.n_features
            width, height = self.config.width, self.config.height
            col_scale = max(width, 1)
            row_scale = max(height, 1)

            self.weights_flat = np.zeros(shape, dtype=dtype)
            for idx in range(self.n_neurons):
                col, row = divmod(idx, height)
                # Feature 0 follows the column, feature 1 the row; any further
                # feature gets its own diagonal gradient.
                gradients = [col / col_scale, row / row_scale]
                gradients += [
                    (col + row + f) / max(width + height + f, 1)
                    for f in range(2, n_features)
                ]
                self.weights_flat[idx] = gradients[:n_features]

        else:
            # RANDOM, or a data-driven strategy called without data
            self.weights_flat = np.random.random(shape).astype(dtype)

        # Issue #23: keep the weights within the configured bounds
        self.weights_flat = np.clip(
            self.weights_flat,
            self.config.weight_bounds[0],
            self.config.weight_bounds[1],
        )

    def _compute_neuron_distances(self, bmu_coord: np.ndarray) -> np.ndarray:
        """
        Return the grid distance from ``bmu_coord`` to every neuron (Issue #15).

        The toroidal topology measures each axis the short way round the map;
        all other topologies use the plain Euclidean distance between
        coordinates. The result is cast to ``config.dtype``.
        """
        coords = self.neuron_coords

        if self.config.topology != Topology.TOROIDAL:
            straight = np.linalg.norm(coords - bmu_coord, axis=1)
            return straight.astype(self.config.dtype)

        # Per-axis gap, then the shorter of "direct" and "wrapped around"
        gap_x = np.abs(coords[:, 0] - bmu_coord[0])
        gap_y = np.abs(coords[:, 1] - bmu_coord[1])
        gap_x = np.minimum(gap_x, self.config.width - gap_x)
        gap_y = np.minimum(gap_y, self.config.height - gap_y)

        wrapped = np.sqrt(gap_x**2 + gap_y**2)
        return wrapped.astype(self.config.dtype)

    def _get_neighborhood_mask(
        self, bmu_idx: int, sigma: float
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Select the neurons close enough to the BMU to be updated (Issues #13, #14).

        Neurons farther than ``config.cutoff_factor * sigma`` on the grid are
        left out.

        Returns:
            ``(mask, distances)``: a boolean mask over all neurons and the grid
            distances of the selected neurons only.
        """
        grid_distances = self._compute_neuron_distances(self.neuron_coords[bmu_idx])
        within_reach = grid_distances <= self.config.cutoff_factor * sigma
        return within_reach, grid_distances[within_reach]

    def _calculate_neighborhood(
        self, distances: np.ndarray, sigma: float
    ) -> np.ndarray:
        """
        Issues #13, #14, #22: Gaussian with cutoff and underflow protection
        """
        exponent = -(distances**2) / (2 * (sigma**2))
        exponent = np.maximum(exponent, -50)  # Issue #22: Prevent underflow
        return np.exp(exponent).astype(self.config.dtype)

    def _get_decay_value(
        self, t: int, t_max: int, initial: float, final: float, schedule: DecaySchedule
    ) -> float:
        """
        Value of a decaying parameter at step ``t`` (Issues #18, #19, #20).

        During the first ``config.warmup_steps`` steps ``initial`` is returned
        unchanged. After that the selected schedule is evaluated over the
        remaining ``t_max - warmup_steps`` steps:

        * LINEAR: straight line from ``initial`` to ``final``
        * INVERSE: ``initial / (1 + progress)``
        * COSINE: half-cosine from ``initial`` to ``final``
        * STEP: ``initial`` halved once per 100 post-warm-up steps
        * EXPONENTIAL (default): geometric decay reaching ``final`` at ``t_max``

        The result is always kept between ``initial`` and ``final``. Steps
        beyond ``t_max`` (possible when training continues incrementally) hold
        the end-of-schedule value instead of extrapolating the curve, and
        schedules that do not use ``final`` (STEP, INVERSE) cannot fall below it.

        Args:
            t: Current step.
            t_max: Step at which the schedule ends.
            initial: Value at the start of the schedule.
            final: Value the schedule must not pass.
            schedule: Which decay curve to use.

        Returns:
            The parameter value for step ``t``.
        """
        warmup = self.config.warmup_steps
        if t < warmup:
            return initial

        effective_max = t_max - warmup
        if effective_max <= 0:
            return final

        elapsed = t - warmup
        # Never evaluate the smooth curves beyond their end point: a linear
        # decay would turn negative and a cosine decay would rise again.
        effective_t = min(elapsed, effective_max)

        if schedule == DecaySchedule.LINEAR:
            value = initial - (initial - final) * (effective_t / effective_max)
        elif schedule == DecaySchedule.INVERSE:
            value = initial / (1 + effective_t / effective_max)
        elif schedule == DecaySchedule.COSINE:
            value = (
                final
                + (initial - final)
                * (1 + np.cos(np.pi * effective_t / effective_max))
                / 2
            )
        elif schedule == DecaySchedule.STEP:
            # Uses the unclamped step count; the bounds below stop the decay.
            drops = elapsed // 100
            value = initial * (0.5**drops)
        else:  # EXPONENTIAL
            if initial > 0 and final > 0:
                decay_rate = -np.log(final / initial) / effective_max
                value = initial * np.exp(-decay_rate * effective_t)
            else:
                # A geometric decay needs strictly positive end points
                value = final

        # Keep the value inside [min(initial, final), max(initial, final)]
        lower, upper = min(initial, final), max(initial, final)
        return float(min(max(value, lower), upper))

    def fit(
        self,
        data: np.ndarray,
        n_iterations: Optional[int] = None,
        callbacks: Optional[List[Callback]] = None,
    ) -> "SOM":
        """
        Train the SOM on data

        ARCHITECTURE FIX #4: Supports incremental training - calling fit again
        continues from the epoch stored in ``training_state``.
        ARCHITECTURE FIX #9: Supports callbacks

        The data is min-max normalized per feature before training. With the
        PCA or SAMPLE strategy the weights are (re)initialized from the data on
        the first call. When ``config.checkpoint_interval`` is set, a
        ``CheckpointCallback`` is added to the callbacks of this call.

        Args:
            data: Input data of shape (n_samples, n_features)
            n_iterations: Number of iterations (uses config if None)
            callbacks: List of callback objects; the list itself is not modified

        Returns:
            self for method chaining

        Raises:
            ValueError: If the data is not 2-dimensional, is empty, contains
                NaN or infinite values, or has the wrong number of features.
        """
        data = np.asarray(data)

        # Reject malformed input with a clear message before touching shape[1]
        if data.ndim != 2:
            raise ValueError(
                f"Expected 2D data of shape (n_samples, n_features), got {data.ndim}D"
            )
        if data.shape[0] == 0:
            raise ValueError("Input data is empty")

        # FIX: Validate input data for NaN and infinite values
        if np.any(np.isnan(data)) or np.any(np.isinf(data)):
            raise ValueError("Input data contains NaN or infinite values")

        # Validate input shape (Issue #9)
        if data.shape[1] != self.config.n_features:
            raise ValueError(
                f"Expected {self.config.n_features} features, got {data.shape[1]}"
            )

        # Issue #8: Normalize data
        data_normalized = self._normalize_data(data)

        # ARCHITECTURE FIX #6: Initialize weights with data if needed
        data_driven_init = self.config.init_strategy in [
            InitStrategy.PCA,
            InitStrategy.SAMPLE,
        ]
        first_run = self.training_state["epoch"] == 0
        if self.weights_flat is None or (data_driven_init and first_run):
            self._initialize_weights(data_normalized)

        if n_iterations is None:
            n_iterations = self.config.n_iterations

        # Setup callbacks (ARCHITECTURE FIX #9). Copy the list: appending the
        # checkpoint callback to the caller's own list would leak it into
        # their code and add a duplicate on every later fit() call.
        self.callbacks = list(callbacks) if callbacks else []
        if self.config.checkpoint_interval:
            self.callbacks.append(
                CheckpointCallback(
                    self.config.checkpoint_dir, self.config.checkpoint_interval
                )
            )

        # A stop requested during a previous run must not end this one
        # immediately. Reset before on_training_begin so a callback can still
        # veto the run from there.
        self.stop_training = False
        start_epoch = self.training_state["epoch"]

        for callback in self.callbacks:
            callback.on_training_begin(self)

        self._train_loop(data_normalized, n_iterations)

        for callback in self.callbacks:
            callback.on_training_end(self)

        # Update metadata (ARCHITECTURE FIX #7). Count the epochs the training
        # loop recorded rather than the number requested: early stopping can
        # end the run before n_iterations is reached.
        epochs_run = self.training_state["epoch"] - start_epoch
        self.metadata["total_epochs"] += epochs_run
        self.metadata["total_samples_seen"] += len(data) * epochs_run
        self.metadata["last_training"] = datetime.now().isoformat()

        return self

    def _normalize_data(self, data: np.ndarray) -> np.ndarray:
        """Issue #8: Normalize input data"""
        data = data.astype(self.config.dtype)
        data_min = data.min(axis=0)
        data_max = data.max(axis=0)
        data_range = data_max - data_min
        data_range[data_range == 0] = 1
        return (data - data_min) / data_range

    def _train_loop(self, data: np.ndarray, n_iterations: int):
        """Run up to ``n_iterations`` training epochs on normalized data.

        Epoch numbers continue from ``training_state["epoch"]``, so repeated
        calls resume where the previous one ended. Each epoch:

        1. runs the ``on_epoch_begin`` callbacks and honours ``stop_training``,
        2. computes sigma and alpha from their decay schedules,
        3. shuffles the data and processes it in the configured learning mode,
        4. rebuilds the BMU search tree every ``kdtree_rebuild_interval`` epochs,
        5. records the metrics, runs ``on_epoch_end`` and advances the epoch.

        When the built-in convergence check fires, the loop ends *after* step
        5, so the last epoch (whose weight updates have already been applied)
        is present in the history, seen by the callbacks and counted in
        ``training_state["epoch"]``.

        Args:
            data: Normalized training data of shape (n_samples, n_features).
            n_iterations: Maximum number of epochs to run in this call.
        """
        # FIX: Initialize tree with correct metric
        self._create_tree()

        # Setup progress bar (Issue #10)
        start_epoch = self.training_state["epoch"]
        iterator = range(start_epoch, start_epoch + n_iterations)
        if self.verbose:
            iterator = tqdm(iterator, desc="Training SOM")

        # FIX: Initialize momentum tracking consistently
        if self.training_state["prev_update"] is None:
            self.training_state["prev_update"] = np.zeros_like(self.weights_flat)

        for t in iterator:
            # Epoch begin callbacks
            for callback in self.callbacks:
                callback.on_epoch_begin(t, self)

            # Checked after on_epoch_begin so that hook can also request a stop
            if self.stop_training:
                if self.verbose:
                    print(f"Training stopped at epoch {t}")
                break

            # Get current parameters (Issues #18, #19, #20). The schedules run
            # over config.n_iterations, not over this call's n_iterations.
            sigma = self._get_decay_value(
                t,
                self.config.n_iterations,
                self.config.initial_sigma,
                self.config.min_sigma,
                self.config.sigma_decay,
            )
            alpha = self._get_decay_value(
                t,
                self.config.n_iterations,
                self.config.initial_alpha,
                self.config.min_alpha,
                self.config.alpha_decay,
            )

            # Issue #7: Shuffle data
            shuffled_indices = np.random.permutation(len(data))
            shuffled_data = data[shuffled_indices]

            # Process based on learning mode (Issue #21)
            epoch_metrics = self._process_epoch(shuffled_data, sigma, alpha)

            # FIX: Update tree periodically with correct metric
            if (t + 1) % self.config.kdtree_rebuild_interval == 0:
                self._create_tree()

            # Check convergence (Issues #16, #17). Only remember the outcome
            # here: leaving the loop right away would skip the bookkeeping
            # below for an epoch that has already modified the weights.
            converged = (
                self.config.early_stopping
                and t % self.config.convergence_check_interval == 0
                and self._check_convergence(epoch_metrics)
            )

            # Update display
            if self.verbose:
                iterator.set_postfix(
                    {
                        "QE": f"{epoch_metrics.get('qe', 0):.4f}",
                        "σ": f"{sigma:.3f}",
                        "α": f"{alpha:.4f}",
                    }
                )

            # Store metrics (ARCHITECTURE FIX #7)
            self.metadata["training_history"].append(
                {"epoch": t, "metrics": epoch_metrics, "sigma": sigma, "alpha": alpha}
            )

            # Epoch end callbacks
            for callback in self.callbacks:
                callback.on_epoch_end(t, self, epoch_metrics)

            # Update training state
            self.training_state["epoch"] = t + 1

            if converged:
                if self.verbose:
                    print(f"\nConverged after {t+1} iterations")
                break

    def _process_epoch(self, data: np.ndarray, sigma: float, alpha: float) -> Dict:
        """Run one epoch in the configured learning mode and clip the weights.

        BATCH and MINI_BATCH have dedicated handlers; every other mode is
        processed online (sample by sample). Afterwards the weights are
        clipped to ``config.weight_bounds`` (Issue #23).

        Returns:
            The metrics dictionary produced by the mode handler.
        """
        handlers = {
            LearningMode.BATCH: self._process_batch,
            LearningMode.MINI_BATCH: self._process_mini_batch,
        }
        run_epoch = handlers.get(self.config.learning_mode, self._process_online)
        epoch_metrics = run_epoch(data, sigma, alpha)

        bounds = self.config.weight_bounds
        self.weights_flat = np.clip(self.weights_flat, bounds[0], bounds[1])

        return epoch_metrics

    def _process_batch(self, data: np.ndarray, sigma: float, alpha: float) -> Dict:
        """Apply one batch update computed from all rows of ``data``.

        The BMU of every sample is looked up in one query. Samples are then
        grouped by BMU, and each group pulls the neurons in that BMU's
        neighborhood towards its samples, weighted by the Gaussian
        neighborhood function. The accumulated update is averaged over the
        batch, optionally blended with the previous update (momentum, Issue
        #24), scaled by ``alpha`` and added to the weights.

        Args:
            data: Samples of shape (n_samples, n_features).
            sigma: Current neighborhood width.
            alpha: Current learning rate.

        Returns:
            ``{"qe": mean squared BMU distance over the batch}``.
        """
        batch_updates = np.zeros_like(self.weights_flat)

        # Find all BMUs at once
        distances, bmu_indices = self._query_tree(data, k=1)
        bmu_indices = bmu_indices.flatten()
        total_error = np.sum(distances**2)

        # Group by BMU: all samples sharing a BMU share one neighborhood
        for bmu_idx in np.unique(bmu_indices):
            samples_for_bmu = data[bmu_indices == bmu_idx]

            # OPTIMIZATION: Windowed update
            neighbor_mask, neighbor_distances = self._get_neighborhood_mask(
                bmu_idx, sigma
            )
            if not np.any(neighbor_mask):
                continue

            theta = self._calculate_neighborhood(neighbor_distances, sigma)
            affected_indices = np.flatnonzero(neighbor_mask)

            # The weights stay fixed while the batch is accumulated, so the
            # per-sample sum  sum_s theta * (s - w)  collapses to
            # theta * (sum_s s - n * w): one array operation per BMU instead
            # of a Python loop over its samples.
            pull = (
                samples_for_bmu.sum(axis=0)
                - len(samples_for_bmu) * self.weights_flat[affected_indices]
            )
            batch_updates[affected_indices] += theta[:, np.newaxis] * pull

        # Average over the batch
        batch_updates /= len(data)

        # Issue #24: FIX - Consistent momentum application
        if self.config.momentum > 0:
            batch_updates = (
                self.config.momentum * self.training_state["prev_update"]
                + (1 - self.config.momentum) * batch_updates
            )
            self.training_state["prev_update"] = batch_updates.copy()

        self.weights_flat += alpha * batch_updates

        return {"qe": total_error / len(data)}

    def _process_mini_batch(self, data: np.ndarray, sigma: float, alpha: float) -> Dict:
        """Run one epoch as consecutive batch updates of ``config.batch_size`` rows.

        Every chunk is handed to ``_process_batch``; the last chunk may be
        shorter. The per-chunk errors are recombined into one mean over all
        samples.

        Returns:
            ``{"qe": mean squared BMU distance over all samples}``.
        """
        chunk_size = self.config.batch_size
        n_samples = len(data)
        n_batches = (n_samples + chunk_size - 1) // chunk_size

        squared_error_sum = 0
        for batch_no in range(n_batches):
            first = batch_no * chunk_size
            chunk = data[first : min(first + chunk_size, n_samples)]

            # Each chunk reports a mean; weight it by the chunk length
            chunk_qe = self._process_batch(chunk, sigma, alpha)["qe"]
            squared_error_sum += chunk_qe * len(chunk)

        return {"qe": squared_error_sum / n_samples}

    def _process_online(self, data: np.ndarray, sigma: float, alpha: float) -> Dict:
        """Run one epoch of online learning: one weight update per sample.

        For each sample the BMU is looked up and the neurons in its
        neighborhood are moved towards the sample. With momentum enabled the
        step is blended with the stored previous update of the affected
        neurons, and the stored update of all other neurons is scaled down by
        the momentum factor.

        Returns:
            ``{"qe": mean squared BMU distance over all samples}``.
        """
        momentum = self.config.momentum
        squared_error_sum = 0

        for sample in data:
            # Find BMU
            bmu_distance, bmu_index = self._query_tree(sample.reshape(1, -1), k=1)
            bmu_idx = bmu_index[0, 0]
            squared_error_sum += bmu_distance[0, 0] ** 2

            # OPTIMIZATION: Windowed update - only neurons near the BMU
            in_window, window_distances = self._get_neighborhood_mask(bmu_idx, sigma)
            if not np.any(in_window):
                continue

            theta = self._calculate_neighborhood(window_distances, sigma)
            affected = np.where(in_window)[0]
            step = theta[:, np.newaxis] * (sample - self.weights_flat[affected])

            if momentum > 0:
                # Same array object as training_state["prev_update"]; the
                # in-place writes below update the stored state.
                velocity = self.training_state["prev_update"]
                step = momentum * velocity[affected] + (1 - momentum) * step
                velocity[affected] = step
                # Neurons outside the window: let their stored update decay
                velocity[~in_window] *= momentum

            self.weights_flat[affected] += alpha * step

        return {"qe": squared_error_sum / len(data)}

    def _check_convergence(self, metrics: Dict) -> bool:
        """Issues #16, #17: Check for convergence"""
        current_error = metrics.get("qe", float("inf"))

        if (
            abs(self.training_state["best_error"] - current_error)
            < self.config.convergence_tolerance
        ):
            self.training_state["no_improvement_count"] += 1
            if self.training_state["no_improvement_count"] >= self.config.patience:
                return True
        else:
            self.training_state["no_improvement_count"] = 0
            if current_error < self.training_state["best_error"]:
                self.training_state["best_error"] = current_error

        return False

    def predict(self, data: np.ndarray) -> np.ndarray:
        """Find BMU indices for input data"""
        # FIX: Check for NaN/inf
        if np.any(np.isnan(data)) or np.any(np.isinf(data)):
            raise ValueError("Input data contains NaN or infinite values")

        data_normalized = self._normalize_data(data)
        _, bmu_indices = self._query_tree(data_normalized, k=1)
        return bmu_indices.flatten()

    def transform(self, data: np.ndarray) -> np.ndarray:
        """
        FIX: Transform data to 2D grid coordinates using neuron_coords directly
        """
        bmu_indices = self.predict(data)
        # Directly use the neuron coordinates instead of recalculating
        return self.neuron_coords[bmu_indices]

    def get_weights(self) -> np.ndarray:
        """Get weights in grid format"""
        return self.weights_flat.reshape(
            self.config.width, self.config.height, self.config.n_features
        )

    def quantization_error(self, data: np.ndarray) -> float:
        """Calculate quantization error for data"""
        # FIX: Check for NaN/inf
        if np.any(np.isnan(data)) or np.any(np.isinf(data)):
            raise ValueError("Input data contains NaN or infinite values")

        data_normalized = self._normalize_data(data)
        distances, _ = self._query_tree(data_normalized, k=1)
        return np.mean(distances**2)

    def topographic_error(self, data: np.ndarray) -> float:
        """
        Return the fraction of samples whose two closest neurons are not
        neighbours on the grid.

        For every sample the best and second-best matching units are fetched
        with ``_query_tree(..., k=2)`` and their ``neuron_coords`` compared:

        * hexagonal: Euclidean distance above 1.1 counts as an error
          (1 plus a small tolerance);
        * toroidal: per-axis offsets are wrapped around ``config.width`` /
          ``config.height`` first, then a distance above sqrt(2) is an error;
        * rectangular (any other topology): distance above sqrt(2) is an error.

        Raises:
            ValueError: if ``data`` contains NaN or infinite values.
        """
        # Reject non-finite input before it reaches the neighbour search
        if np.isnan(data).any() or np.isinf(data).any():
            raise ValueError("Input data contains NaN or infinite values")

        prepared = self._normalize_data(data)
        _, bmu_pairs = self._query_tree(prepared, k=2)

        n_broken = 0
        for first, second in bmu_pairs:
            here = self.neuron_coords[first]
            there = self.neuron_coords[second]
            topology = self.config.topology

            # Largest separation still regarded as "adjacent" for this grid
            if topology == Topology.HEXAGONAL:
                limit = 1.1
            else:
                limit = np.sqrt(2)

            if topology == Topology.TOROIDAL:
                # Take the shorter way round on each axis before measuring
                gap_x = np.abs(here[0] - there[0])
                gap_y = np.abs(here[1] - there[1])
                gap_x = min(gap_x, self.config.width - gap_x)
                gap_y = min(gap_y, self.config.height - gap_y)
                separation = np.sqrt(gap_x**2 + gap_y**2)
            else:
                separation = np.linalg.norm(here - there)

            if separation > limit:
                n_broken += 1

        return n_broken / len(data)

    # ARCHITECTURE FIX #3: Save/Load functionality with error handling
    def save(self, filepath: str):
        """Save trained model to file"""
        save_data = {
            "config": self.config.to_dict(),
            "weights": self.weights_flat,
            "metadata": self.metadata,
            "training_state": self.training_state,
        }

        try:
            with open(filepath, "wb") as f:
                pickle.dump(save_data, f)

            if self.verbose:
                print(f"Model saved to {filepath}")
        except (IOError, OSError) as e:
            raise IOError(f"Failed to save model to {filepath}: {e}")

    @classmethod
    def load(cls, filepath: str) -> "SOM":
        """Load trained model from file"""
        try:
            with open(filepath, "rb") as f:
                save_data = pickle.load(f)
        except (IOError, OSError) as e:
            raise IOError(f"Failed to load model from {filepath}: {e}")

        # Reconstruct SOM
        config = SOMConfig.from_dict(save_data["config"])
        som = cls(config, verbose=False)
        som.weights_flat = save_data["weights"]
        som.metadata = save_data["metadata"]
        som.training_state = save_data["training_state"]

        # FIX: Rebuild tree with correct metric
        som._create_tree()

        return som

    # ARCHITECTURE FIX #7: Get comprehensive info
    def get_info(self) -> Dict:
        """Get comprehensive information about the SOM"""
        return {
            "config": self.config.to_dict(),
            "metadata": self.metadata,
            "shape": (self.config.width, self.config.height),
            "n_neurons": self.n_neurons,
            "n_features": self.config.n_features,
            "total_epochs": self.metadata["total_epochs"],
            "total_samples": self.metadata["total_samples_seen"],
        }


# Convenience function for backward compatibility
def train(input_data, n_max_iterations, width, height, **kwargs):
    """
    Backward compatible training function.

    Builds a ``SOMConfig`` from the arguments, trains a ``SOM`` on
    ``input_data`` and returns the weights in grid format.

    Args:
        input_data: 2-D array of samples; ``shape[1]`` is used as the number
            of features.
        n_max_iterations: Passed to the config as ``n_iterations``.
        width: Grid width.
        height: Grid height.
        **kwargs: Any keyword naming a ``SOMConfig`` constructor field is
            forwarded to the config. Two extra keywords are handled here:
            ``show_progress`` (default True) is passed to ``SOM`` as
            ``verbose``, and ``track_metrics`` (default False) selects the
            return shape. Unrecognised keywords are ignored.

    Returns:
        The weight grid from ``som.get_weights()``; when ``track_metrics`` is
        true, a ``(weights, errors)`` tuple where ``errors`` lists the ``"qe"``
        metric of every entry in the training history.

    Raises:
        ValueError: if ``input_data`` contains NaN or infinite values.
    """
    # Local import: the module header only imports dataclass/field/asdict
    from dataclasses import fields

    # FIX: Check for NaN/inf in input data
    if np.any(np.isnan(input_data)) or np.any(np.isinf(input_data)):
        raise ValueError("Input data contains NaN or infinite values")

    # Filter kwargs by the real constructor fields. ``hasattr(SOMConfig, k)``
    # is not equivalent: it misses fields declared with ``default_factory``
    # (they have no class attribute) and matches methods such as ``to_dict``.
    config_fields = {f.name for f in fields(SOMConfig) if f.init}
    config_overrides = {k: v for k, v in kwargs.items() if k in config_fields}

    config = SOMConfig(
        width=width,
        height=height,
        n_features=input_data.shape[1],
        n_iterations=n_max_iterations,
        **config_overrides,
    )

    # Create and train SOM
    som = SOM(config, verbose=kwargs.get("show_progress", True))
    som.fit(input_data)

    # Return weights in old format
    if kwargs.get("track_metrics", False):
        history = som.metadata["training_history"]
        errors = [h["metrics"]["qe"] for h in history]
        return som.get_weights(), errors
    return som.get_weights()


def plot_quantization_error(som, show_plot=True, save_path=None):
    """
    Draw the quantization error recorded for each training epoch.

    Nothing is drawn (a message is printed instead) when ``som`` has no
    ``metadata["training_history"]`` or when that history is empty.

    Args:
        som: Trained SOM object
        show_plot: Display the figure when true, otherwise close it
        save_path: Optional file path; the figure is written there if given
    """
    has_history = hasattr(som, "metadata") and "training_history" in som.metadata
    if not has_history:
        print("No training history available for plotting")
        return

    records = som.metadata["training_history"]
    if not records:
        print("No training history data available")
        return

    epochs = [record["epoch"] for record in records]
    qe_values = [record["metrics"]["qe"] for record in records]

    plt.figure(figsize=(10, 6))
    plt.plot(epochs, qe_values, "b-", linewidth=2, label="Quantization Error")
    plt.title("Quantization Error vs Training Epochs")
    plt.xlabel("Epoch")
    plt.ylabel("Quantization Error")
    plt.grid(True, alpha=0.3)
    plt.legend()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"QE plot saved to {save_path}")

    # Either hand the figure to the display backend or release it
    if not show_plot:
        plt.close()
    else:
        plt.show()


def visualize_som_weights(som, show_plot=True, save_path=None):
    """
    Visualize SOM weights as an image.

    The rendering depends on the number of features in ``som.get_weights()``:

    * 1 feature: the single weight plane with the viridis colormap and a
      colorbar;
    * 2 features: the two planes as the red and green channels of an RGB
      image, with blue held constant at 0.5;
    * 3 or more features: the first three planes as RGB.

    In both RGB cases the displayed channels are min-max scaled together into
    [0, 1], because ``imshow`` clips float RGB values outside that range.

    Args:
        som: Trained SOM object
        show_plot: Whether to display the visualization
        save_path: Path to save the visualization (optional)
    """
    weights = som.get_weights()
    n_features = weights.shape[2]

    if n_features == 1:
        # Single feature - scalar image, colormap handles the value range
        plt.figure(figsize=(8, 8))
        plt.imshow(weights[:, :, 0], cmap="viridis", interpolation="nearest")
        plt.colorbar(label="Weight Value")
        plt.title("SOM Weight Visualization (Single Feature)")
    elif n_features == 2:
        # Two features - red/green channels, scaled like the RGB branch below
        # so that values outside [0, 1] are not clipped by imshow
        channels = weights[:, :, :2]
        channels = (channels - channels.min()) / (
            channels.max() - channels.min() + 1e-8
        )
        img = np.full(weights.shape[:2] + (3,), 0.5)  # blue stays constant
        img[:, :, :2] = channels

        plt.figure(figsize=(8, 8))
        plt.imshow(img, interpolation="nearest")
        plt.title("SOM Weight Visualization (2 Features as RG)")
    elif n_features >= 3:
        # Three or more features - show first 3 as RGB
        img = weights[:, :, :3]
        # Normalize to [0, 1] range; the epsilon guards a constant map
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)

        plt.figure(figsize=(8, 8))
        plt.imshow(img, interpolation="nearest")
        plt.title("SOM Weight Visualization (First 3 Features as RGB)")

    plt.axis("off")

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"SOM visualization saved to {save_path}")

    if show_plot:
        plt.show()
    else:
        plt.close()


if __name__ == "__main__":
    # Demo script: trains several SOMs on random data, writes PNG plots and a
    # pickled model to the current working directory, then removes the pickle.

    # Seed NumPy's global RNG so the np.random demo data below is repeatable
    np.random.seed(42)

    # Example 1: Basic usage with new architecture and visualization
    print("Example 1: Basic SOM with visualization")
    data = np.random.random((100, 3)).astype(np.float32)

    # All hyper-parameters live in one SOMConfig object
    config = SOMConfig(
        width=20,
        height=20,
        n_features=3,
        n_iterations=200,
        init_strategy=InitStrategy.PCA,  # selects the PCA initialization strategy
        distance_metric=DistanceMetric.EUCLIDEAN,  # selects the Euclidean metric
        seed=42,
    )

    # The model itself is built from the config
    som = SOM(config)

    # Train the SOM
    som.fit(data)

    # Plot quantization error over epochs (saved to disk, not displayed)
    print("\nPlotting quantization error...")
    plot_quantization_error(som, show_plot=False, save_path="qe_plot.png")

    # Visualize the SOM weights (saved to disk, not displayed)
    print("Visualizing SOM weights...")
    visualize_som_weights(som, show_plot=False, save_path="som_weights.png")

    # Persist the trained model; it is loaded again in Example 2
    som.save("trained_som.pkl")
    print("Model saved!")

    # Summary dict with the epoch/sample counters
    info = som.get_info()
    print(f"Total epochs trained: {info['total_epochs']}")
    print(f"Total samples seen: {info['total_samples']}")

    # Example 2: Load and continue training
    print("\nExample 2: Incremental training")

    # Load saved model
    som_loaded = SOM.load("trained_som.pkl")
    print("Model loaded!")

    # Continue training with new data; the iteration count is overridden for
    # this call only
    new_data = np.random.random((50, 3)).astype(np.float32)
    som_loaded.fit(new_data, n_iterations=100)
    print(
        f"Additional training complete. Total epochs: {som_loaded.metadata['total_epochs']}"
    )

    # Plot updated quantization error
    plot_quantization_error(
        som_loaded, show_plot=False, save_path="qe_plot_continued.png"
    )

    # Example 3: Different initialization strategies with visualization
    print("\nExample 3: Comparing initialization strategies")

    strategies = [InitStrategy.RANDOM, InitStrategy.PCA, InitStrategy.LINEAR]

    # NOTE(review): the loop index `i` is never used below
    for i, strategy in enumerate(strategies):
        config = SOMConfig(
            width=10,
            height=10,
            n_features=3,
            n_iterations=100,
            init_strategy=strategy,
            seed=42,
        )
        som = SOM(config, verbose=False)
        som.fit(data)
        qe = som.quantization_error(data)
        print(f"{strategy.value}: QE = {qe:.4f}")

        # Save visualization for each strategy (file name uses the enum value)
        visualize_som_weights(
            som, show_plot=False, save_path=f"som_{strategy.value.lower()}.png"
        )

    # Example 4: Different distance metrics
    print("\nExample 4: Comparing distance metrics")

    metrics = [
        DistanceMetric.EUCLIDEAN,
        DistanceMetric.MANHATTAN,
        DistanceMetric.COSINE,
    ]

    # Same grid, data and seed for each run; only the metric changes
    for metric in metrics:
        config = SOMConfig(
            width=10,
            height=10,
            n_features=3,
            n_iterations=100,
            distance_metric=metric,
            seed=42,
        )
        som = SOM(config, verbose=False)
        som.fit(data)
        qe = som.quantization_error(data)
        print(f"{metric.value}: QE = {qe:.4f}")

    # Example 5: Backward compatibility via the module-level train() wrapper
    print("\nExample 5: Backward compatible function")
    weights = train(data, 100, 10, 10, seed=42, show_progress=False)
    print(f"Weights shape: {weights.shape}")

    # Example 6: Test with different feature counts and visualizations
    print("\nExample 6: Testing with different feature counts")
    for n_features in [1, 2, 3, 5]:
        # Fresh random data with the matching number of columns
        test_data = np.random.random((50, n_features)).astype(np.float32)
        config = SOMConfig(
            width=8,
            height=8,
            n_features=n_features,
            n_iterations=50,
            init_strategy=InitStrategy.LINEAR,
            seed=42,
        )
        som = SOM(config, verbose=False)
        som.fit(test_data)
        print(f"n_features={n_features}: Training successful")

        # Visualize each case
        visualize_som_weights(
            som, show_plot=False, save_path=f"som_{n_features}features.png"
        )

    # Save final result from the Example 5 weights (`weights` is not
    # reassigned after the train() call above)
    plt.figure(figsize=(8, 8))
    if weights.shape[2] >= 3:
        # First three features as RGB, min-max scaled into [0, 1]
        img = weights[:, :, :3]
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)
        plt.imshow(img, interpolation="nearest")
    else:
        # Fewer than three features: show only the first plane
        plt.imshow(weights[:, :, 0], cmap="viridis", interpolation="nearest")
        plt.colorbar()
    plt.title("Final SOM Visualization")
    plt.axis("off")
    plt.savefig("som_final.png", dpi=300, bbox_inches="tight")
    plt.close()
    print("\nFinal visualization saved as som_final.png")

    print("\nAll visualizations have been saved as PNG files:")
    print("- qe_plot.png: Quantization error over epochs")
    print("- som_weights.png: Main SOM weight visualization")
    print("- qe_plot_continued.png: QE after continued training")
    print("- som_random.png, som_pca.png, som_linear.png: Different init strategies")
    print("- som_1features.png through som_5features.png: Different feature counts")
    print("- som_final.png: Final result")

    # Clean up: only the pickle and a "checkpoints" directory are removed; the
    # PNG files are left in place.
    # NOTE(review): nothing in this script creates "checkpoints" directly;
    # presumably training may write it - confirm against fit().
    import shutil

    if os.path.exists("checkpoints"):
        shutil.rmtree("checkpoints")
    if os.path.exists("trained_som.pkl"):
        os.remove("trained_som.pkl")

Example 1: Basic SOM with visualization


Training SOM: 100%|██████████| 200/200 [00:03<00:00, 63.74it/s, QE=0.0026, σ=0.010, α=0.0010]



Plotting quantization error...
QE plot saved to qe_plot.png
Visualizing SOM weights...
SOM visualization saved to som_weights.png
Model saved to trained_som.pkl
Model saved!
Total epochs trained: 200
Total samples seen: 20000

Example 2: Incremental training
Model loaded!
Additional training complete. Total epochs: 300
QE plot saved to qe_plot_continued.png

Example 3: Comparing initialization strategies
random: QE = 0.0217
SOM visualization saved to som_random.png
pca: QE = 0.0127
SOM visualization saved to som_pca.png
linear: QE = 0.0124
SOM visualization saved to som_linear.png

Example 4: Comparing distance metrics
euclidean: QE = 0.0217
manhattan: QE = 0.0649
cosine: QE = 0.0001

Example 5: Backward compatible function
Weights shape: (10, 10, 3)

Example 6: Testing with different feature counts
n_features=1: Training successful
SOM visualization saved to som_1features.png
n_features=2: Training successful
SOM visualization saved to som_2features.png
n_features=3: Training success