# SOO - Self-Organizing Orrery: Testing Notebook
#### created: 5/28/2025
#### last edited: 5/28/2025


## Imports + Logging Setup

In [1]:
%gui qt
import numpy as np
from typing import Callable, Optional
import logging
from tqdm import tqdm
import torch


# Configure logging for the notebook.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)  # Set to INFO or DEBUG for more verbose output

# Re-running this cell must not stack a second StreamHandler on the same logger,
# otherwise every message is printed once per re-run. Reuse the existing handler
# when there is one and only create a new one on a fresh kernel.
ch = next((h for h in logger.handlers if isinstance(h, logging.StreamHandler)), None)
if ch is None:
    ch = logging.StreamHandler()
    logger.addHandler(ch)
ch.setLevel(logging.INFO)

# Mayavi is optional: when it is missing, `mlab` is None and the visualizer
# degrades to a no-op instead of failing at import time.
try:
    from mayavi import mlab
except ImportError:
    mlab = None
    logger.warning("Mayavi is not installed. Visualization will be disabled.")

## SOOMap class

In [2]:
class SOOMap:
    """
    Self-Organizing Orrery (SOO) map.

    Each node owns a weight vector in input space (`weights`, num_nodes x input_dim)
    and a position in a 3-D layout space (`coordinates`, num_nodes x 3). Inputs are
    matched to their Best Matching Unit (BMU) in weight space; the BMU's weight moves
    towards the input and nearby nodes move towards the BMU in layout space.
    """

    def __init__(self, num_nodes: int, input_dim: int,
                 init_weight_scale: float = 0.01,
                 init_coord_range: float = 1.0,
                 learning_rate: float = 0.1,
                 neighbor_radius: float = 0.5,
                 distance_metric: Optional[Callable] = None,
                 device: Optional[torch.device] = None):
        """
        Initialize the Self-Organizing Orrery map.

        Parameters:
        - num_nodes: Number of neurons/nodes in the map (must be >= 1).
        - input_dim: Dimensionality of input feature vectors (must be >= 1).
        - init_weight_scale: Scale (std-dev) for initial random weights (small Gaussian noise).
        - init_coord_range: Range for initial coordinates (uniform in [0, init_coord_range] per dimension).
        - learning_rate: Initial learning rate for weight updates.
        - neighbor_radius: Initial neighborhood radius in coordinate space for coordinate updates.
        - distance_metric: Callable (weights, x) -> per-node distances. Defaults to squared Euclidean if None.
        - device: Torch device to use (CPU or CUDA). If None, will use CPU.

        Raises:
        - ValueError: if num_nodes or input_dim is smaller than 1.
        """
        if num_nodes < 1 or input_dim < 1:
            raise ValueError(
                f"num_nodes and input_dim must be >= 1 (got num_nodes={num_nodes}, input_dim={input_dim})."
            )

        self.num_nodes = num_nodes
        self.input_dim = input_dim
        self.device = device if device is not None else torch.device("cpu")

        # Weight vectors (num_nodes x input_dim): small Gaussian noise around 0.
        self.weights = torch.randn(num_nodes, input_dim, device=self.device) * init_weight_scale
        # Node coordinates (num_nodes x 3): uniform in [0, init_coord_range].
        self.coordinates = torch.rand(num_nodes, 3, device=self.device) * init_coord_range

        # Initial / current learning parameters (the training loop decays the current ones).
        self.initial_lr = learning_rate
        self.current_lr = learning_rate
        self.initial_radius = neighbor_radius
        self.current_radius = neighbor_radius

        # Distance metric for the BMU search. The default is a static method rather
        # than a lambda so that instances remain picklable.
        self.distance_metric = distance_metric if distance_metric is not None else self._squared_euclidean

        # Last convergence value computed by calc_convergence().
        self.last_cv = 0.0
        logger.info(f"Initialized SOOMap with {num_nodes} nodes, input_dim={input_dim}, device={self.device}.")

    @staticmethod
    def _squared_euclidean(w: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Squared Euclidean distance from x to every row of w (no sqrt: only the ordering matters)."""
        return torch.sum((w - x) ** 2, dim=1)

    def find_bmu(self, x: torch.Tensor) -> int:
        """
        Find the Best Matching Unit (BMU) index for input vector x.
        Uses the configured distance metric (squared Euclidean by default).

        x is flattened to 1-D if needed; returns the index of the closest node as a Python int.
        """
        if x.dim() != 1:
            # reshape (not view) so non-contiguous inputs, e.g. transposed slices, are accepted.
            x = x.reshape(-1)
        dists = self.distance_metric(self.weights, x)
        return int(torch.argmin(dists).item())

    def update_weights(self, bmu_idx: int, x: torch.Tensor, lr: float):
        """
        Move the BMU's weight vector towards the input x by a fraction lr.
        Only the BMU's weight is updated in this MVP; neighbor weights are left untouched.
        """
        self.weights[bmu_idx] += lr * (x - self.weights[bmu_idx])
        # (Optional extension: update neighbors' weights as well with a smaller factor.)

    def update_coordinates(self, bmu_idx: int, neighbor_indices: torch.Tensor, coord_lr: float):
        """
        Pull the coordinates of the given neighbors a step towards the BMU's coordinates.

        Parameters:
        - bmu_idx: Index of the BMU.
        - neighbor_indices: Indices of neurons considered within the neighborhood radius (may include the BMU).
        - coord_lr: Fraction of the neighbor->BMU vector to move along.
        """
        # Clone so the target stays fixed even though the BMU row may be among the rows written below.
        bmu_coord = self.coordinates[bmu_idx].clone()
        # One vectorized update instead of a Python loop over tensor elements
        # (the loop costs one kernel launch per neighbor on CUDA). The BMU itself
        # does not move because (bmu_coord - bmu_coord) == 0.
        neighbor_coords = self.coordinates[neighbor_indices]
        self.coordinates[neighbor_indices] = neighbor_coords + coord_lr * (bmu_coord - neighbor_coords)

    def calc_convergence(self, data_ref: np.ndarray, significance: float = 0.05) -> float:
        """
        Calculate a population-based convergence (embedding accuracy) measure.
        Compares the distribution of neuron weights to the distribution of a reference subset of the input data.

        For every feature, the neuron weights and the reference data are considered to match when
        (a) their means differ by no more than the two-sided normal critical value for `significance`
            times the standard error of the difference, and
        (b) their sample variances are within a factor of 2 of each other (a crude F-test stand-in).

        Parameters:
        - data_ref: array-like of shape (n_ref, input_dim) with reference samples.
        - significance: two-sided significance level for the mean test, in (0, 1). The default 0.05
          gives a critical value of ~1.96.

        Returns:
        - Convergence value between 0.0 and 1.0: the fraction of features that are 'embedded'.
          Also stored in self.last_cv.

        Raises:
        - ValueError: if significance is outside (0, 1) or data_ref is not (n_ref, input_dim).
        """
        from statistics import NormalDist  # stdlib; local import keeps the cell self-contained

        if not 0.0 < significance < 1.0:
            raise ValueError(f"significance must be in (0, 1), got {significance}.")

        # Do the statistics in float64 regardless of the tensor dtype.
        data_ref = np.asarray(data_ref, dtype=np.float64)
        if data_ref.ndim != 2 or data_ref.shape[1] != self.input_dim:
            raise ValueError(
                f"data_ref must have shape (n_ref, {self.input_dim}), got {data_ref.shape}."
            )
        # For simplicity, *all* neurons are used for the distribution comparison in this MVP.
        neurons_sample = self.weights.detach().cpu().numpy().astype(np.float64)

        n1 = data_ref.shape[0]        # number of reference samples
        n2 = neurons_sample.shape[0]  # number of neurons
        data_means = data_ref.mean(axis=0)
        data_vars = data_ref.var(axis=0, ddof=1)  # sample variance
        neuron_means = neurons_sample.mean(axis=0)
        neuron_vars = neurons_sample.var(axis=0, ddof=1)

        # Mean test: |m1 - m2| <= z * sqrt(v1/n1 + v2/n2) (normal approximation).
        z_crit = NormalDist().inv_cdf(1.0 - significance / 2.0)
        std_err = np.sqrt(data_vars / n1 + neuron_vars / n2)
        mean_diff_ok = np.abs(data_means - neuron_means) <= z_crit * std_err

        # Variance test: fold the ratio to >= 1 and allow up to a factor of 2.
        # The epsilons guard against division by zero for degenerate (constant) features.
        ratio = data_vars / (neuron_vars + 1e-9)
        ratio = np.where(ratio < 1, 1.0 / (ratio + 1e-9), ratio)
        var_diff_ok = ratio <= 2.0

        embedded_features = int(np.count_nonzero(mean_diff_ok & var_diff_ok))
        cv_value = embedded_features / float(self.input_dim)
        self.last_cv = cv_value
        return cv_value


## Training Loop

In [3]:
class SOOMap(SOOMap):  # extend the previous class definition
    def train(self, data: torch.Tensor, epochs: int = 100, batch_size: int = 1,
              ref_sample_size: int = 100, min_lr: float = 0.01, min_radius: float = 0.01):
        """
        Train the SOOMap on the given dataset.

        Every sample is processed sequentially: its BMU is located, the BMU's weight is
        moved towards the sample, and all nodes within the current radius of the BMU (in
        coordinate space) are pulled towards the BMU. After each epoch the convergence
        value is computed on a fixed reference subset and used to decay the learning
        rate and radius.

        Parameters:
        - data: torch.Tensor of shape (n_samples, input_dim) containing the training data.
        - epochs: Number of epochs to train for.
        - batch_size: Number of samples per outer-loop step (>= 1). Samples inside a batch are
          still applied one at a time, so this only changes the progress-bar granularity.
        - ref_sample_size: Size of reference subset for convergence calculation
          (if None, <= 0, or larger than the dataset, the full data is used).
        - min_lr: Minimum learning rate to which to decay.
        - min_radius: Minimum neighborhood radius to which to decay.

        Raises:
        - ValueError: if data is not a non-empty (n_samples, input_dim) tensor or batch_size < 1.
        """
        if data.dim() != 2 or data.shape[1] != self.input_dim:
            raise ValueError(
                f"data must have shape (n_samples, {self.input_dim}), got {tuple(data.shape)}."
            )
        if data.shape[0] == 0:
            raise ValueError("data must contain at least one sample.")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}.")

        data = data.to(self.device)
        n_samples = data.shape[0]

        # Fixed reference subset for convergence measurement (same subset for every epoch).
        # detach() so tensors that track gradients can still be converted to NumPy.
        if ref_sample_size is None or ref_sample_size <= 0 or ref_sample_size > n_samples:
            ref_data_np = data.detach().cpu().numpy()
        else:
            # Randomly sample without replacement
            indices = np.random.choice(n_samples, size=ref_sample_size, replace=False)
            ref_data_np = data[indices].detach().cpu().numpy()

        # All updates are manual in-place writes; no autograd graph should be recorded.
        with torch.no_grad():
            for epoch in range(1, epochs + 1):
                # Shuffle data indices for each epoch (stochastic training)
                perm = torch.randperm(n_samples, device=self.device)
                data_shuffled = data[perm]

                # lr and radius only change between epochs, so read them once per epoch.
                radius_sq = self.current_radius ** 2
                coord_lr = self.current_lr  # coordinate updates share the weight learning rate

                for i in tqdm(range(0, n_samples, batch_size), desc=f"Epoch {epoch}", leave=False):
                    batch = data_shuffled[i:i + batch_size]
                    # Samples are applied one by one so each update sees the previous one.
                    for x in batch:
                        bmu_idx = self.find_bmu(x)
                        # Neighbors = nodes whose squared coordinate distance to the BMU is within radius^2.
                        bmu_coord = self.coordinates[bmu_idx]
                        coord_dists = torch.sum((self.coordinates - bmu_coord) ** 2, dim=1)
                        neighbor_idx = torch.where(coord_dists <= radius_sq)[0]

                        self.update_weights(bmu_idx, x, lr=self.current_lr)
                        self.update_coordinates(bmu_idx, neighbor_idx, coord_lr=coord_lr)

                # At end of epoch, evaluate convergence on reference data
                cv = self.calc_convergence(ref_data_np)
                logger.info(f"Epoch {epoch}: convergence (embedding) = {cv*100:.1f}%")

                # Convergence-modulated decay: as cv approaches 1.0, lr and radius approach
                # their minima. lr_new = max(min_lr, initial_lr * (1 - cv)); same for radius.
                self.current_lr = max(min_lr, self.initial_lr * (1.0 - cv))
                self.current_radius = max(min_radius, self.initial_radius * (1.0 - cv))

## Visualization code using Mayavi

In [4]:
# Visualization code using Mayavi

class SOOVisualizer:
    """Mayavi-backed 3-D scatter view of a SOOMap's node coordinates (no-op without Mayavi)."""

    def __init__(self, soo_map: SOOMap, save_frames: bool = False):
        """
        Visualizer for Self-Organizing Orrery.
        - soo_map: an instance of SOOMap.
        - save_frames: whether to save each frame to file (frame_<epoch>.png) when updating.
        """
        self.soo_map = soo_map
        self.save_frames = save_frames
        self.fig = None
        self.points = None

        # Without Mayavi there is nothing to set up; update()/show() become no-ops.
        if not mlab:
            return

        # On-screen rendering; set to True for headless saving without showing a window.
        mlab.options.offscreen = False
        self.fig = mlab.figure(size=(600, 600), bgcolor=(0.9, 0.9, 0.9))

        # Initial scatter of node positions; all points share one color for now.
        xs, ys, zs = self._node_xyz()
        self.points = mlab.points3d(xs, ys, zs, scale_factor=0.05, color=(0, 0, 1))
        mlab.view(azimuth=45, elevation=75, distance=5)  # initial camera view
        mlab.title("SOO Nodes (Initial)", size=0.4)

    def _node_xyz(self):
        """Return the map's current node coordinates as three NumPy columns (x, y, z)."""
        positions = self.soo_map.coordinates.detach().cpu().numpy()
        return positions[:, 0], positions[:, 1], positions[:, 2]

    def update(self, epoch: int = None):
        """Update the 3D scatter plot to the current coordinates of the SOO map's neurons."""
        if mlab is None or self.points is None:
            return  # visualization is not available

        xs, ys, zs = self._node_xyz()
        # Push the new positions into the existing plot instead of re-creating it.
        self.points.mlab_source.set(x=xs, y=ys, z=zs)

        # Title and frame export only apply when an epoch number is supplied.
        if epoch is None:
            return
        mlab.title(f"SOO Nodes (Epoch {epoch})", size=0.4)
        if self.save_frames:
            mlab.savefig(f"frame_{epoch}.png")

    def show(self):
        """Display the visualization window (blocks execution until closed)."""
        if not mlab:
            return
        mlab.show()


## Synthetic Data Generator

In [5]:
def generate_blobs(num_samples: int = 500, num_features: int = 2, centers: int = 3, cluster_std: float = 0.1, random_state: Optional[int] = None):
    """
    Generate isotropic Gaussian blobs for clustering.
    - num_samples: total number of points (>= 0).
    - num_features: dimension of the feature space.
    - centers: number of blob clusters (>= 1).
    - cluster_std: standard deviation of each cluster.
    - random_state: seed for reproducibility. When given, NumPy's *global* RNG is seeded
      with it (so later np.random calls are affected too); when None, the current
      global RNG state is used as-is.
    Returns: numpy array of shape (num_samples, num_features), rows shuffled.
    Raises: ValueError if centers < 1 or num_samples < 0.
    """
    if centers < 1:
        raise ValueError(f"centers must be >= 1, got {centers}.")
    if num_samples < 0:
        raise ValueError(f"num_samples must be >= 0, got {num_samples}.")

    if random_state is not None:
        np.random.seed(random_state)

    # Random centroids, uniform in [-1, 1] along every feature.
    centroids = np.random.rand(centers, num_features) * 2.0 - 1.0
    # Isotropic covariance shared by all clusters (built once, not per cluster).
    cov = (cluster_std ** 2) * np.eye(num_features)

    samples_per_center = num_samples // centers
    blobs = [
        np.random.multivariate_normal(mean=centroid, cov=cov, size=samples_per_center)
        for centroid in centroids
    ]
    # If num_samples is not divisible by centers, the leftover points go to the last cluster.
    remaining = num_samples - samples_per_center * centers
    if remaining > 0:
        blobs.append(np.random.multivariate_normal(mean=centroids[-1], cov=cov, size=remaining))

    data = np.vstack(blobs)
    np.random.shuffle(data)  # mix clusters so row order carries no label information
    return data

def generate_spiral(num_samples: int = 500, revolutions: int = 3, noise: float = 0.0, random_state: Optional[int] = None):
    """
    Generate a 3D spiral (helix) curve dataset: unit circle in x/y, height z rising from 0 to 1.
    - num_samples: number of points along the spiral.
    - revolutions: how many turns the spiral makes.
    - noise: standard deviation of Gaussian noise to add to the points (0 disables noise).
    - random_state: optional seed. When given, NumPy's global RNG is seeded with it;
      when None, the global RNG state is left untouched, so a seed set earlier by the
      caller keeps governing the noise.
    Returns: numpy array of shape (num_samples, 3).
    """
    # Only seed on request. Unconditionally calling np.random.seed(None) would reseed
    # from OS entropy and silently discard whatever seed the caller had set.
    if random_state is not None:
        np.random.seed(random_state)

    theta = np.linspace(0, 2 * np.pi * revolutions, num_samples)
    z = np.linspace(0, 1, num_samples)  # height from 0 to 1
    data = np.stack([np.cos(theta), np.sin(theta), z], axis=1)
    if noise > 0:
        data += np.random.normal(scale=noise, size=data.shape)
    return data

def generate_ring(num_samples: int = 500, radius: float = 1.0, noise: float = 0.01, random_state: Optional[int] = None):
    """
    Generate points in a ring (circle in 2D), evenly spaced in angle.
    - num_samples: number of points on the ring.
    - radius: radius of the ring.
    - noise: standard deviation of the relative radial noise (0 disables noise).
    - random_state: optional seed. When given, NumPy's global RNG is seeded with it;
      when None, the global RNG state is left untouched, so a seed set earlier by the
      caller keeps governing the noise.
    Returns: numpy array of shape (num_samples, 2).
    """
    # Only seed on request. Unconditionally calling np.random.seed(None) would reseed
    # from OS entropy and silently discard whatever seed the caller had set.
    if random_state is not None:
        np.random.seed(random_state)

    angles = np.linspace(0, 2 * np.pi, num_samples, endpoint=False)
    data = np.stack([radius * np.cos(angles), radius * np.sin(angles)], axis=1)
    if noise > 0:
        # Perturb each point radially: scale its position by (1 + eps), eps ~ N(0, noise^2).
        radii_noise = np.random.normal(scale=noise, size=num_samples)
        data *= (1 + radii_noise)[:, None]
    return data

## Test Run

In [None]:
# Run configuration: one seed for every RNG in play, and a single epoch count
# shared by training and the final visualization update.
SEED = 42
N_EPOCHS = 20
torch.manual_seed(SEED)

# Generate a simple dataset: 2 clusters in 2D
data_np = generate_blobs(num_samples=1000, num_features=2, centers=2, cluster_std=0.1, random_state=SEED)
data_tensor = torch.tensor(data_np, dtype=torch.float32)

# Use the GPU when present, otherwise fall back to CPU so the notebook runs anywhere.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the SOOMap with 100 neurons, input_dim=2
soo = SOOMap(num_nodes=100, input_dim=2,
             init_weight_scale=0.01, init_coord_range=1.0,
             learning_rate=0.2, neighbor_radius=0.5,
             device=device)

# (Optional) Visualize initial node positions
vis = SOOVisualizer(soo_map=soo, save_frames=False)
vis.update(epoch=0)  # plot initial state (if visualization is enabled)

# Train the SOOMap on the dataset for a few epochs
soo.train(data_tensor, epochs=N_EPOCHS, batch_size=1, ref_sample_size=100, min_lr=0.02, min_radius=0.05)

# After training, update visualization to final state
vis.update(epoch=N_EPOCHS)
vis.show()  # Show the interactive plot (if running in an environment with GUI support)


Initialized SOOMap with 100 nodes, input_dim=2, device=cuda.
Epoch 1: convergence (embedding) = 0.0%                     
Epoch 2: convergence (embedding) = 0.0%                     
Epoch 3: convergence (embedding) = 0.0%                     
Epoch 4: convergence (embedding) = 0.0%                     
Epoch 5: convergence (embedding) = 0.0%                     
Epoch 6: convergence (embedding) = 0.0%                     
Epoch 7: convergence (embedding) = 0.0%                     
Epoch 8:  11%|█         | 108/1000 [00:00<00:03, 268.59it/s]