# GliaGL Training Tutorial

Learn how to train spiking neural networks with GliaGL.

## What You'll Learn
- Creating datasets
- Training configuration
- Running training loops
- Evaluating performance
- Monitoring progress

In [None]:
import glia
import numpy as np
import matplotlib.pyplot as plt

print(f"GliaGL version: {glia.__version__}")

## 1. Create a Network

Start with a simple network to train:

In [None]:
# Create network: 3 sensory inputs (S0-S2) plus 8 neurons: 5 hidden (N0-N4) and 3 outputs (N5-N7)
net = glia.Network(num_sensory=3, num_neurons=8)

# Add random connections
np.random.seed(42)
n_connections = 12

from_ids = np.random.choice(net.sensory_ids, n_connections).tolist()
to_ids = np.random.choice([n for n in net.neuron_ids if n not in net.sensory_ids], n_connections).tolist()
weights = np.random.randn(n_connections)

net.set_weights(from_ids, to_ids, weights)

print(f"Created network with {net.num_connections} connections")
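`np.random.choice` draws sources and targets independently, so the 12 draws can repeat a (source, target) pair. Depending on how `set_weights` handles duplicates (the tutorial doesn't say), `net.num_connections` may come out below 12. The sketch below counts unique pairs before you call `set_weights`. The ID lists are an assumption that mirrors the `S0` through `S2` inputs and `N0` through `N7` neurons used later in this notebook. It also uses NumPy's newer `Generator` API, so its exact draws differ from the cell above.

```python
import numpy as np

# Assumed ID naming, matching the 'S0'..'S2' inputs and 'N5'..'N7' outputs used later
sensory_ids = ["S0", "S1", "S2"]
neuron_ids = [f"N{i}" for i in range(8)]

rng = np.random.default_rng(42)
n_connections = 12

# Draw sources and targets independently, as the cell above does
from_ids = rng.choice(sensory_ids, n_connections).tolist()
to_ids = rng.choice(neuron_ids, n_connections).tolist()

# Collapse repeated (source, target) pairs to see how many distinct connections remain
unique_pairs = set(zip(from_ids, to_ids))
print(f"{n_connections} draws -> {len(unique_pairs)} unique connections")
```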

## 2. Create Training Data

Build a dataset with input sequences and target outputs:

In [None]:
# Create episodes
episodes = []

for i in range(30):
    # Create episode
    ep = glia.EpisodeData()
    
    # Create input sequence
    seq = glia.InputSequence()
    
    # Add timesteps with different patterns
    pattern = i % 3
    if pattern == 0:
        seq.add_timestep({'S0': 100.0, 'S1': 0.0, 'S2': 0.0})
        target = 'N5'  # First output
    elif pattern == 1:
        seq.add_timestep({'S0': 0.0, 'S1': 100.0, 'S2': 0.0})
        target = 'N6'  # Second output
    else:
        seq.add_timestep({'S0': 0.0, 'S1': 0.0, 'S2': 100.0})
        target = 'N7'  # Third output
    
    ep.seq = seq
    ep.target_id = target
    episodes.append(ep)

# Create dataset
dataset = glia.Dataset(episodes)
print(f"Created dataset with {len(dataset)} episodes")

# Split into train/val
train_data, val_data = dataset.split(train_frac=0.8, seed=42)
print(f"Training: {len(train_data)} episodes")
print(f"Validation: {len(val_data)} episodes")
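With only 30 episodes, a plain random 80/20 split can leave one pattern under-represented among the 6 validation episodes. The tutorial doesn't say whether `dataset.split` stratifies by target. If it doesn't, one option is to split each target label separately. `stratified_split` below is a hypothetical helper, not part of GliaGL. It works on a plain list of labels and returns episode indices.

```python
import random
from collections import defaultdict

def stratified_split(labels, train_frac=0.8, seed=42):
    """Hypothetical helper: return (train_idx, val_idx), splitting each label by train_frac."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)

    train_idx, val_idx = [], []
    for idxs in by_label.values():
        rng.shuffle(idxs)
        n_train = round(len(idxs) * train_frac)
        train_idx.extend(idxs[:n_train])
        val_idx.extend(idxs[n_train:])
    return sorted(train_idx), sorted(val_idx)

# Same labelling scheme as the loop above: pattern i % 3 -> N5, N6, N7
labels = [("N5", "N6", "N7")[i % 3] for i in range(30)]
train_idx, val_idx = stratified_split(labels)
print(len(train_idx), len(val_idx))  # 24 6
```

Each label then contributes 8 training and 2 validation episodes. You could rebuild the sets with `glia.Dataset([episodes[i] for i in train_idx])`, assuming the constructor accepts any list of `EpisodeData` as it does above.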

## 3. Configure Training

Set up training hyperparameters:

In [None]:
# Create training configuration
config = glia.create_config(
    lr=0.01,           # Learning rate
    batch_size=4,      # Batch size (episodes)
    warmup_ticks=10,   # Warmup period (ticks)
    eval_ticks=50,     # Evaluation period (ticks)
    reward_pos=1.0,    # Reward for a correct prediction
    reward_neg=-0.5,   # Penalty for an incorrect prediction
    verbose=True       # Print progress
)

print(f"Training config: {config}")
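The parameter names suggest that `reward_pos` is applied when the winning output matches `target_id` and `reward_neg` otherwise, but the tutorial does not spell this out. Under that assumption, these values make the expected reward of a random guesser zero. With three classes, chance accuracy is 1/3, and 1/3 × 1.0 + 2/3 × (-0.5) = 0. Any positive average reward therefore points to better-than-chance behavior. A quick check:

```python
def expected_reward(accuracy, reward_pos=1.0, reward_neg=-0.5):
    """Expected per-episode reward, assuming reward_pos on a correct winner and reward_neg otherwise."""
    return accuracy * reward_pos + (1.0 - accuracy) * reward_neg

for acc in (1 / 3, 0.5, 0.9):
    print(f"accuracy {acc:.2f} -> expected reward {expected_reward(acc):+.3f}")
```

More generally, setting `reward_neg = -reward_pos / (k - 1)` for `k` classes keeps chance-level reward at zero. Here that gives -1.0 / 2 = -0.5.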

## 4. Train the Network

Run the training loop:

In [None]:
# Create trainer
trainer = glia.Trainer(net, config)

# Train for 30 epochs
print("\nTraining...")
trainer.train_epoch(train_data, epochs=30, config=config)

print(f"\nFinal training accuracy: {trainer.epoch_accuracy[-1]:.1%}")
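GliaGL handles batching internally. If the trainer groups shuffled episodes into fixed-size mini-batches (an assumption), then `batch_size=4` with 24 training episodes (30 × 0.8) means 6 weight updates per epoch, or 180 over the 30 epochs. The generic sketch below shows that partitioning. `iter_batches` is a hypothetical helper and is independent of GliaGL.

```python
import random

def iter_batches(items, batch_size, seed=None):
    """Yield shuffled mini-batches; the last one is smaller if len(items) % batch_size != 0."""
    order = list(range(len(items)))
    random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        yield [items[i] for i in order[start:start + batch_size]]

episodes_stub = list(range(24))  # stand-ins for the 24 training episodes
batches = list(iter_batches(episodes_stub, batch_size=4, seed=0))
print(len(batches), [len(b) for b in batches])  # 6 [4, 4, 4, 4, 4, 4]
```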

## 5. Evaluate on Validation Set

Test the trained network:

In [None]:
# Evaluate each validation episode
correct = 0
results = []

for ep in val_data:
    metrics = trainer.evaluate(ep.seq, config)
    is_correct = (metrics.winner_id == ep.target_id)
    correct += int(is_correct)
    results.append({
        'target': ep.target_id,
        'predicted': metrics.winner_id,
        'correct': is_correct,
        'margin': metrics.margin
    })

val_accuracy = correct / len(val_data)
print(f"\nValidation accuracy: {val_accuracy:.1%}")
print(f"Correct: {correct}/{len(val_data)}")

# Show some results
print("\nSample predictions:")
for r in results[:5]:
    status = "✓" if r['correct'] else "✗"
    print(f"  {status} Target: {r['target']}, Predicted: {r['predicted']}, Margin: {r['margin']:.3f}")
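Overall accuracy hides which patterns get confused with each other. The `results` list built above already holds target/predicted pairs, so a confusion matrix takes only a few lines of NumPy. The `sample_results` below is illustrative placeholder data in the same dict shape, not output from this notebook. In the notebook, pass the real `results` instead.

```python
import numpy as np

def confusion_matrix(results, labels):
    """Rows = target, columns = predicted; entries whose IDs are not in `labels` are skipped."""
    index = {label: i for i, label in enumerate(labels)}
    cm = np.zeros((len(labels), len(labels)), dtype=int)
    for r in results:
        if r["target"] in index and r["predicted"] in index:
            cm[index[r["target"]], index[r["predicted"]]] += 1
    return cm

# Illustrative placeholder data with the same keys as `results` above
sample_results = [
    {"target": "N5", "predicted": "N5", "correct": True, "margin": 0.8},
    {"target": "N6", "predicted": "N7", "correct": False, "margin": 0.1},
    {"target": "N7", "predicted": "N7", "correct": True, "margin": 0.5},
]
labels = ["N5", "N6", "N7"]
print(confusion_matrix(sample_results, labels))
```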

## 6. Visualize Training Progress

Plot training history:

In [None]:
try:
    import glia.viz as viz
    
    # Plot training history
    viz.plot_training_history(trainer, show=True)  # show=True displays the figure
except ImportError:
    # Manual plotting
    plt.figure(figsize=(10, 4))
    
    plt.subplot(1, 2, 1)
    plt.plot(trainer.epoch_accuracy)
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Training Accuracy')
    plt.grid(True, alpha=0.3)
    
    plt.subplot(1, 2, 2)
    plt.plot(trainer.epoch_loss)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
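If epoch accuracy is measured over the 24 training episodes, it moves in steps of 1/24 (about 4%), so the raw curve can look jumpy. A moving average makes the trend easier to read. The sketch assumes `trainer.epoch_accuracy` is a sequence of floats, as the plotting code above treats it. The synthetic curve here is only a stand-in.

```python
import numpy as np
import matplotlib.pyplot as plt

def moving_average(values, window=5):
    """Trailing mean over `window` points; returns len(values) - window + 1 values."""
    values = np.asarray(values, dtype=float)
    if len(values) < window:
        return values
    return np.convolve(values, np.ones(window) / window, mode="valid")

# Synthetic stand-in for trainer.epoch_accuracy (30 epochs)
rng = np.random.default_rng(0)
acc = np.clip(np.linspace(0.33, 0.9, 30) + rng.normal(0, 0.05, 30), 0, 1)

smoothed = moving_average(acc, window=5)
epochs = np.arange(len(acc))
plt.plot(epochs, acc, alpha=0.4, label="raw")
# Each smoothed point is aligned with the last epoch of its window
plt.plot(epochs[len(acc) - len(smoothed):], smoothed, label="5-epoch mean")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.show()
```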

## 7. Analyze Weight Changes

Inspect how the weights are distributed after training:

In [None]:
# Get current weights
from_ids, to_ids, weights = net.get_weights()
weights = np.asarray(weights, dtype=float)  # ensure NumPy stats work even if a list is returned

print("\nWeight statistics after training:")
print(f"  Mean: {weights.mean():.3f}")
print(f"  Std: {weights.std():.3f}")
print(f"  Min: {weights.min():.3f}")
print(f"  Max: {weights.max():.3f}")

# Plot weight distribution
plt.figure(figsize=(8, 4))
plt.hist(weights, bins=20, edgecolor='black', alpha=0.7)
plt.xlabel('Weight Value')
plt.ylabel('Count')
plt.title('Weight Distribution After Training')
plt.grid(True, alpha=0.3)
plt.show()
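The cell above reports only the weights after training. To see what training actually changed, take a snapshot with the same `net.get_weights()` call before running step 4 and compare the two snapshots per connection. Copying the arrays with `np.array(...)` guards against the snapshot being a live view, which this tutorial doesn't rule out. Matching on (source, target) keys rather than array position avoids relying on `get_weights` returning connections in a stable order. `weight_deltas` is a hypothetical helper, and `before`/`after` are made-up placeholders in the assumed `(from_ids, to_ids, weights)` shape.

```python
import numpy as np

def weight_deltas(before, after):
    """Map (source, target) -> after - before for connections present in both snapshots."""
    w_before = {(f, t): w for f, t, w in zip(*before)}
    w_after = {(f, t): w for f, t, w in zip(*after)}
    return {key: w_after[key] - w_before[key] for key in w_before.keys() & w_after.keys()}

# Placeholder snapshots; in the notebook these would come from net.get_weights()
before = (["S0", "S1", "S2"], ["N5", "N6", "N7"], np.array([0.50, -0.20, 0.10]))
after = (["S0", "S1", "S2"], ["N5", "N6", "N7"], np.array([0.75, -0.20, -0.15]))

deltas = weight_deltas(before, after)
# List the largest absolute changes first
for (src, dst), dw in sorted(deltas.items(), key=lambda kv: -abs(kv[1])):
    print(f"{src} -> {dst}: {dw:+.3f}")
```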

## 8. Save Trained Network

Save your trained model for later use:

In [None]:
# Save network
net.save('trained_network.net')
print("✓ Trained network saved!")

# Verify by loading
loaded_net = glia.Network.from_file('trained_network.net')
print(f"✓ Loaded network: {loaded_net.num_neurons} neurons, {loaded_net.num_connections} connections")
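`net.save` writes GliaGL's own `.net` format. You may also want a copy of the raw weights that other tools can read. NumPy's `savez` can store the `(from_ids, to_ids, weights)` triple from step 7, assuming the IDs are strings and the weights are numeric. The helper names below are hypothetical; only `np.savez` and `np.load` are NumPy API.

```python
import os
import tempfile
import numpy as np

def save_weights_npz(path, from_ids, to_ids, weights):
    """Store a (from_ids, to_ids, weights) triple as three named arrays in one .npz file."""
    np.savez(path,
             from_ids=np.asarray(from_ids, dtype=str),
             to_ids=np.asarray(to_ids, dtype=str),
             weights=np.asarray(weights, dtype=float))

def load_weights_npz(path):
    """Read the triple back as (list, list, ndarray)."""
    with np.load(path) as data:
        return data["from_ids"].tolist(), data["to_ids"].tolist(), data["weights"].copy()

# Round-trip check with placeholder values
path = os.path.join(tempfile.mkdtemp(), "weights.npz")
save_weights_npz(path, ["S0", "S1"], ["N5", "N6"], [0.5, -0.2])
f_ids, t_ids, w = load_weights_npz(path)
print(f_ids, t_ids, w)
```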

## Summary

You've learned:
- ✅ Creating datasets with `glia.Dataset`
- ✅ Configuring training with `create_config()`
- ✅ Training with `trainer.train_epoch()`
- ✅ Evaluating with `trainer.evaluate()`
- ✅ Monitoring progress with accuracy/loss
- ✅ Visualizing training history
- ✅ Analyzing weight changes
- ✅ Saving trained models

## Next Steps

- **Evolution**: See `03_evolution.ipynb` for evolutionary training
- **Advanced**: See `04_advanced.ipynb` for custom training loops
- **API Reference**: See `docs/user-guide/API_REFERENCE.md`