# GliaGL Advanced Techniques

Advanced patterns for custom workflows and optimization.

## What You'll Learn
- Custom training loops
- Advanced NumPy integration
- Sparse matrix operations
- Performance optimization
- Custom architectures

In [None]:
import glia
import numpy as np
import matplotlib.pyplot as plt
from scipy.sparse import coo_matrix, csr_matrix

# Report which GliaGL build this notebook is running against
version = glia.__version__
print(f"GliaGL version: {version}")

## 1. Custom Training Loop with Early Stopping

Implement a custom training loop with patience-based early stopping:

In [None]:
def train_with_early_stopping(net, train_data, val_data, config, patience=10,
                              max_epochs=100, checkpoint_path='best_checkpoint.net'):
    """Train ``net`` one epoch at a time, stopping once validation accuracy stalls.

    Each epoch runs a single ``trainer.train_epoch`` pass, then scores every
    validation episode with ``trainer.evaluate``. Whenever validation accuracy
    improves (the first epoch always counts as an improvement), the network is
    saved to ``checkpoint_path``. After ``patience`` consecutive epochs without
    improvement, training stops early.

    Args:
        net: Network to train; it is handed to ``glia.Trainer``.
        train_data: Training dataset passed to ``trainer.train_epoch``.
        val_data: Sized, re-iterable collection of episodes exposing ``.seq``
            and ``.target_id``. It is re-scored every epoch.
        config: Training config used for both training and evaluation.
        patience: Epochs without improvement tolerated before stopping.
        max_epochs: Maximum number of epochs to run (default 100).
        checkpoint_path: File that the best-so-far network is saved to.

    Returns:
        ``(history, best_acc)``. ``history`` has per-epoch ``'train_acc'`` and
        ``'val_acc'`` lists. ``best_acc`` is the best validation accuracy seen,
        or 0.0 if no epoch ran.

    Raises:
        ValueError: If ``val_data`` is empty (validation accuracy is undefined).

    Note:
        Nothing here reloads the checkpoint. On return ``net`` holds whatever
        the last epoch produced; the best weights live in ``checkpoint_path``.
    """
    n_val = len(val_data)
    if n_val == 0:
        # Fail fast instead of dividing by zero after a full training epoch.
        raise ValueError("val_data is empty; cannot compute validation accuracy")

    trainer = glia.Trainer(net, config)

    best_acc = None  # None until the first epoch, so epoch 0 always checkpoints
    patience_counter = 0
    history = {'train_acc': [], 'val_acc': []}

    for epoch in range(max_epochs):
        # Train one epoch
        trainer.train_epoch(train_data, epochs=1, config=config)
        train_acc = trainer.epoch_accuracy[-1]

        # Validate: fraction of episodes whose winning neuron is the target
        correct = sum(
            1 for ep in val_data
            if trainer.evaluate(ep.seq, config).winner_id == ep.target_id
        )
        val_acc = correct / n_val

        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)

        # Early stopping check
        if best_acc is None or val_acc > best_acc:
            best_acc = val_acc
            patience_counter = 0
            net.save(checkpoint_path)
            print(f"Epoch {epoch}: val_acc={val_acc:.3f} (new best!)")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

    return history, (best_acc if best_acc is not None else 0.0)

# Test it on a tiny 2-sensory network: S0 -> N2 and S1 -> N3, both weighted 1.5
net = glia.Network(num_sensory=2, num_neurons=4)
net.set_weights(['S0', 'S1'], ['N2', 'N3'], np.array([1.5, 1.5]))


def _toy_episode(index):
    """Build one toy episode.

    Even indices pulse S0 and target N2. Odd indices stay silent and target N3.
    """
    pulse = index % 2 == 0
    episode = glia.EpisodeData()
    sequence = glia.InputSequence()
    sequence.add_timestep({'S0': 100.0 if pulse else 0.0, 'S1': 0.0})
    episode.seq = sequence
    episode.target_id = 'N2' if pulse else 'N3'
    return episode


# Create a simple 20-episode dataset; split(0.8) presumably keeps 80% for training
episodes = [_toy_episode(k) for k in range(20)]
dataset = glia.Dataset(episodes)
train, val = dataset.split(0.8, seed=42)
config = glia.create_config(lr=0.01, batch_size=4)

history, best_acc = train_with_early_stopping(net, train, val, config, patience=5)
print(f"\nBest validation accuracy: {best_acc:.1%}")

## 2. Sparse Matrix Operations

Convert the network to a sparse adjacency matrix for connectivity analysis:

In [None]:
def network_to_sparse_matrix(net):
    """Build an ``(n, n)`` COO adjacency matrix from the network's weights.

    Rows index the source neuron and columns the target neuron, both in the
    order of ``net.neuron_ids``. Entry ``(i, j)`` holds the weight from neuron
    ``i`` to neuron ``j``. Repeated (source, target) pairs stay as separate COO
    entries; scipy sums them when converting to CSR or dense form.

    Returns:
        ``(adj_matrix, ids)``: the ``scipy.sparse.coo_matrix`` and the
        ``net.neuron_ids`` sequence that defines row/column order.

    Raises:
        KeyError: If a weight references an id not present in ``net.neuron_ids``.
    """
    src_ids, dst_ids, w = net.get_weights()

    ids = net.neuron_ids
    index_of = {neuron_id: pos for pos, neuron_id in enumerate(ids)}

    # Translate id strings into integer row/column coordinates
    rows = list(map(index_of.__getitem__, src_ids))
    cols = list(map(index_of.__getitem__, dst_ids))

    size = len(ids)
    return coo_matrix((w, (rows, cols)), shape=(size, size)), ids

# Analyze network connectivity
adj_matrix, ids = network_to_sparse_matrix(net)
adj_csr = adj_matrix.tocsr()  # CSR conversion sums any duplicate (from, to) entries

# Degree = number of nonzero connections; strength = sum of connection weights.
# Summing raw weights gives strength, not degree; the two differ whenever a
# weight is not exactly 1.
has_edge = adj_csr != 0
in_degree = np.asarray(has_edge.sum(axis=0)).ravel().astype(int)
out_degree = np.asarray(has_edge.sum(axis=1)).ravel().astype(int)
in_strength = np.asarray(adj_csr.sum(axis=0)).ravel()
out_strength = np.asarray(adj_csr.sum(axis=1)).ravel()

print("Network connectivity analysis:")
for nid, in_deg, out_deg, in_w, out_w in zip(ids, in_degree, out_degree,
                                              in_strength, out_strength):
    print(f"  {nid}: in={in_deg} (w={in_w:.2f}), out={out_deg} (w={out_w:.2f})")

# Sparsity = fraction of possible (from, to) pairs with no connection.
# Count real nonzeros in the CSR form. len(adj_matrix.data) on the COO form
# would double-count duplicate entries and also count explicit zero weights.
n_possible = len(ids) ** 2
sparsity = 1 - adj_csr.count_nonzero() / n_possible if n_possible else 1.0
print(f"\nSparsity: {sparsity:.1%}")

## 3. Batch State Collection (Memory Efficient)

Efficiently collect network state over many timesteps into a preallocated array:

In [None]:
def collect_state_history(net, n_timesteps, input_fn=None):
    """Step ``net`` for ``n_timesteps`` and record per-step neuron values.

    Args:
        net: Network exposing ``get_state``, ``step``, ``inject_dict`` and
            ``get_firing_neurons``. ``get_state`` must return a 4-tuple whose
            first two items are the neuron ids and their current values.
        n_timesteps: Number of simulation steps to run.
        input_fn: Optional callable ``t -> dict | None``. When it returns a
            truthy dict, that dict is injected via ``inject_dict`` before
            step ``t`` runs.

    Returns:
        ``(value_history, firing_history, ids)``:
        - ``value_history``: a ``(n_timesteps, n_neurons)`` float array of the
          values reported by ``get_state`` after each step.
        - ``firing_history``: a list of ``(t, fired)`` pairs for the steps where
          ``get_firing_neurons`` returned something truthy.
        - ``ids``: the ids from the initial ``get_state`` call, which give the
          column order of ``value_history``.
    """
    ids, _, _, _ = net.get_state()

    # Preallocate the full trace once instead of growing it per step
    value_history = np.zeros((n_timesteps, len(ids)))
    firing_history = []

    for step_idx in range(n_timesteps):
        stimulus = input_fn(step_idx) if input_fn else None
        if stimulus:
            net.inject_dict(stimulus)

        net.step()

        # Copy this step's values into the preallocated row. get_state() may
        # return a view of the network's own buffer, so the copy is deliberate.
        _, current, _, _ = net.get_state()
        value_history[step_idx] = current

        fired = net.get_firing_neurons()
        if fired:
            firing_history.append((step_idx, fired))

    return value_history, firing_history, ids

# Test it with periodic input.
# Reset the network before recording, so the trace does not start from
# whatever state the earlier training cell left behind
# (assumes reset() clears dynamic state such as membrane values; confirm).
net.reset()

def periodic_input(t):
    """Return a fixed S0/S1 stimulus every 20th step (t = 0, 20, 40, ...), else None."""
    if t % 20:
        return None
    return {'S0': 100.0, 'S1': 50.0}

values, firings, ids = collect_state_history(net, 100, periodic_input)

print(f"Collected {len(values)} timesteps")
print(f"Firing events: {len(firings)}")

# Plot every neuron's membrane voltage trace on one shared axis
fig, ax = plt.subplots(figsize=(12, 4))
for column, neuron_id in enumerate(ids):
    ax.plot(values[:, column], label=neuron_id, alpha=0.7)
ax.set_xlabel('Time')
ax.set_ylabel('Membrane Voltage')
ax.set_title('Neuron Activity Over Time')
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()

## 4. Vectorized Parameter Updates

Apply transformations to all weights at once:

In [None]:
def apply_weight_regularization(net, l1_lambda=0.01, l2_lambda=0.001):
    """Shrink every connection weight with L1 soft-thresholding, then L2 decay.

    Each weight ``w`` becomes
    ``sign(w) * max(0, |w| - l1_lambda) * (1 - l2_lambda)``.
    Connections whose new magnitude is <= 1e-6 are left out of the
    ``set_weights`` call; only the surviving connections are written back.

    NOTE(review): whether the omitted connections are actually deleted depends
    on ``Network.set_weights`` semantics, which are not visible here. If it only
    updates the listed pairs, the "pruned" connections keep their old,
    unregularized weights. Confirm against the glia API.

    Args:
        net: Network exposing ``get_weights`` and ``set_weights``.
        l1_lambda: Amount subtracted from each weight's magnitude (clamped at 0).
        l2_lambda: Fractional decay applied after the L1 step.

    Returns:
        The same ``net`` object, so calls can be chained.
    """
    src, dst, w = net.get_weights()

    # L1 proximal step: shrink magnitudes toward zero, keeping the sign
    shrunk = np.maximum(0, np.abs(w) - l1_lambda)
    w = np.sign(w) * shrunk

    # L2 weight decay (in place on the freshly computed array)
    w *= (1 - l2_lambda)

    # Write back only connections that did not collapse to ~0
    keep = np.abs(w) > 1e-6
    kept_src = [s for s, k in zip(src, keep) if k]
    kept_dst = [d for d, k in zip(dst, keep) if k]
    net.set_weights(kept_src, kept_dst, w[keep])

    return net

# Test regularization: fairly strong L1 shrinkage (0.1) plus mild L2 decay (0.01)
n_before = net.num_connections
print(f"Before regularization: {n_before} connections")

apply_weight_regularization(net, l1_lambda=0.1, l2_lambda=0.01)

n_after = net.num_connections
print(f"After regularization: {n_after} connections")

## 5. Custom Network Architecture

Build a layered feedforward network:

In [None]:
def create_layered_network(n_input, n_hidden, n_output, density=0.5,
                           seed=None, weight_scale=0.5):
    """Create a random feedforward network: inputs -> hidden -> outputs.

    Each candidate input->hidden and hidden->output connection is kept
    independently with probability ``density``. Each kept connection gets a
    weight of ``randn() * weight_scale``. No skip or recurrent connections
    are created.

    Args:
        n_input: Number of sensory neurons (``S0`` .. ``S{n_input-1}``).
        n_hidden: Number of hidden neurons (``N0`` .. ``N{n_hidden-1}``).
        n_output: Number of output neurons (``N{n_hidden}`` onward).
        density: Probability that each candidate connection is created.
        seed: Optional int seed for reproducible structure and weights. With
            the default ``None`` the global ``np.random`` state is used, so
            existing calls behave exactly as before.
        weight_scale: Standard deviation of the initial weights (default 0.5).

    Returns:
        The newly built ``glia.Network``.

    NOTE(review): hidden/output ids start at ``N0``, with
    ``num_neurons=n_hidden + n_output``. Section 1 of this notebook builds a
    2-sensory network and addresses ``N2``/``N3``. Confirm that non-sensory ids
    start at ``N0`` and that ``num_neurons`` excludes sensory neurons.
    """
    net = glia.Network(num_sensory=n_input, num_neurons=n_hidden + n_output)
    rng = np.random if seed is None else np.random.RandomState(seed)

    input_ids = [f'S{i}' for i in range(n_input)]
    hidden_ids = [f'N{h}' for h in range(n_hidden)]
    output_ids = [f'N{n_hidden + o}' for o in range(n_output)]

    from_ids, to_ids, weights = [], [], []
    for src_ids, dst_ids in ((input_ids, hidden_ids), (hidden_ids, output_ids)):
        for src, dst, w in _sample_layer_connections(src_ids, dst_ids, density,
                                                     rng, weight_scale):
            from_ids.append(src)
            to_ids.append(dst)
            weights.append(w)

    net.set_weights(from_ids, to_ids, np.array(weights, dtype=float))
    return net


def _sample_layer_connections(src_ids, dst_ids, density, rng, weight_scale):
    """Yield ``(src, dst, weight)`` for each src->dst pair kept with probability ``density``."""
    for src in src_ids:
        for dst in dst_ids:
            # Draw order (rand, then randn only if kept) matches the original
            # nested loops, so a given global seed yields the same network.
            if rng.rand() < density:
                yield src, dst, rng.randn() * weight_scale

# Create layered network
layered_net = create_layered_network(n_input=3, n_hidden=6, n_output=3, density=0.6)
print(f"Layered network: {layered_net.num_neurons} neurons, {layered_net.num_connections} connections")

# Visualize if the optional glia.viz module is available. This is best-effort:
# a missing module or a plotting failure is reported instead of silently swallowed.
try:
    import glia.viz as viz
except ImportError:
    viz = None
    print("glia.viz not available; skipping network visualization")

if viz is not None:
    try:
        # Draw without showing, so the title lands on the same figure. With
        # show=True the figure is displayed before plt.title() runs.
        # (Assumes plot_network draws onto the current pyplot figure when
        # show=False; confirm against glia.viz.)
        viz.plot_network(layered_net, show=False)
        plt.title("Layered Network Architecture")
        plt.show()
    except Exception as exc:  # best-effort demo: report, don't abort the notebook
        print(f"Network visualization failed: {exc}")

## 6. Performance Profiling

Measure the per-call cost of core network operations:

In [None]:
import time

def profile_operations(net, n_iterations=1000):
    """Measure the mean wall-clock time of core network operations.

    Times ``net.step()``, ``net.get_state()`` and ``net.get_weights()``, each
    called ``n_iterations`` times in a tight loop.

    Note:
        ``net.step()`` really executes, so the network is advanced by
        ``n_iterations`` steps as a side effect of profiling.

    Args:
        net: Object exposing ``step``, ``get_state`` and ``get_weights``.
        n_iterations: Number of calls per operation; must be >= 1.

    Returns:
        dict mapping ``'step'``, ``'get_state'`` and ``'get_weights'`` (in that
        order) to the mean seconds per call.

    Raises:
        ValueError: If ``n_iterations`` is less than 1.
    """
    if n_iterations < 1:
        raise ValueError(f"n_iterations must be >= 1, got {n_iterations}")

    operations = {
        'step': net.step,
        'get_state': net.get_state,
        'get_weights': net.get_weights,
    }
    return {name: _mean_call_time(fn, n_iterations) for name, fn in operations.items()}


def _mean_call_time(fn, n_iterations):
    """Return the mean seconds per call of ``fn()`` over ``n_iterations`` calls."""
    # perf_counter is monotonic and high-resolution. time.time() can jump with
    # system clock changes and is coarse on some platforms.
    start = time.perf_counter()
    for _ in range(n_iterations):
        fn()
    return (time.perf_counter() - start) / n_iterations

# Profile
timings = profile_operations(net, n_iterations=1000)

print("Performance (per operation):")
for op, t in timings.items():
    # A very fast op can average to 0.0 with a coarse timer; avoid 1/0.
    rate = f"{1 / t:.0f}" if t > 0 else "inf"
    print(f"  {op}: {t*1000:.3f} ms ({rate} ops/sec)")

## Summary

You've learned:
- ✅ Custom training loops with early stopping
- ✅ Sparse matrix operations for analysis
- ✅ Memory-efficient state collection
- ✅ Vectorized parameter updates
- ✅ Custom network architectures
- ✅ Performance profiling

## Key Techniques

- **Zero-copy NumPy**: Direct array access for performance
- **Sparse matrices**: Efficient large network representation
- **Vectorized ops**: Apply transformations to all parameters
- **Custom loops**: Full control over the training process

## Next Steps

- **Full API**: See `docs/user-guide/API_REFERENCE.md`
- **Advanced Guide**: See `docs/user-guide/ADVANCED_USAGE.md`
- **NumPy Details**: See `docs/development/numpy_interface.md`