<a href="https://colab.research.google.com/github/CogNetSys/stabilai/blob/main/stabilai_mvp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# StabilAI MVP: Surge-Collapse Training with Entropy Dynamics

This notebook demonstrates the integration of **Surge-Collapse Training**, **StableMax**, and **Orthogonal Gradient (⊥Grad)** to enhance AI model training robustness and generalization. The MVP showcases data generation, model training, and evaluation using synthetic data.


In [1]:
# Install necessary packages
!pip install torch torchvision
!pip install transformers
!pip install scikit-learn
!pip install matplotlib seaborn
!pip install numpy
!pip install tqdm
!pip install streamlit
!pip install pyngrok

Collecting jupyter-dash
  Downloading jupyter_dash-0.4.2-py3-none-any.whl.metadata (3.6 kB)
Collecting dash (from jupyter-dash)
  Downloading dash-2.18.2-py3-none-any.whl.metadata (10 kB)
Collecting retrying (from jupyter-dash)
  Downloading retrying-1.3.4-py3-none-any.whl.metadata (6.9 kB)
Collecting ansi2html (from jupyter-dash)
  Downloading ansi2html-1.9.2-py3-none-any.whl.metadata (3.7 kB)
Collecting flask (from jupyter-dash)
  Downloading flask-3.0.3-py3-none-any.whl.metadata (3.2 kB)
Collecting Werkzeug<3.1 (from dash->jupyter-dash)
  Downloading werkzeug-3.0.6-py3-none-any.whl.metadata (3.7 kB)
Collecting dash-html-components==2.0.0 (from dash->jupyter-dash)
  Downloading dash_html_components-2.0.0-py3-none-any.whl.metadata (3.8 kB)
Collecting dash-core-components==2.0.0 (from dash->jupyter-dash)
  Downloading dash_core_components-2.0.0-py3-none-any.whl.metadata (2.9 kB)
Collecting dash-table==5.0.0 (from dash->jupyter-dash)
  Downloading dash_table-5.0.0-py3-none-any.whl.metad

In [None]:
import math
import os
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, average_precision_score
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

In [None]:
class StableMax(nn.Module):
    """
    Softmax evaluated on max-shifted inputs.

    The largest entry along ``dim`` is subtracted from the input before
    ``F.softmax`` is applied. Because softmax is invariant to adding a
    constant along the reduction axis, the output equals a plain softmax
    over ``dim``.

    Args:
        dim: Axis along which the outputs are normalised (default: last axis).
    """
    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        # Per-slice maximum, kept broadcastable against `x`.
        peak = torch.max(x, dim=self.dim, keepdim=True).values
        shifted = x - peak
        return F.softmax(shifted, dim=self.dim)


In [None]:
class OrthogonalGrad(torch.optim.Optimizer):
    """
    Orthogonal Gradient Optimizer to prevent naive logit scaling.

    For every parameter tensor the gradient component parallel to the
    (flattened) parameter is removed, so the gradient step itself cannot
    simply rescale the weights. Weight decay is applied *after* the
    projection: a decay term ``weight_decay * w`` is exactly parallel to
    ``w``, so adding it before projecting would cancel it out entirely.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Step size (must be non-negative).
        weight_decay: L2 shrinkage coefficient (must be non-negative).

    Raises:
        ValueError: If ``lr`` or ``weight_decay`` is negative.
    """
    def __init__(self, params, lr=1e-3, weight_decay=1e-5):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform a single optimisation step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            # The closure typically calls backward(), so autograd must be on.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            weight_decay = group['weight_decay']
            for param in group['params']:
                if param.grad is None:
                    continue
                grad = param.grad

                # Project out the component parallel to the parameter vector.
                # reshape (not view) so non-contiguous tensors are handled too.
                flat_param = param.reshape(-1)
                sq_norm = torch.dot(flat_param, flat_param)
                if sq_norm > 0:
                    coeff = torch.dot(flat_param, grad.reshape(-1)) / sq_norm
                    grad = grad - coeff * param

                # Apply weight decay after the projection so it is not removed.
                if weight_decay != 0:
                    grad = grad.add(param, alpha=weight_decay)

                # Update parameters
                param.sub_(lr * grad)

        return loss


In [None]:
def collapse_weights(model, sparsity=0.5):
    """
    Zero the smallest-magnitude entries of every parameter tensor, in place.

    For each tensor yielded by ``model.parameters()`` the ``sparsity``
    quantile of its absolute values is taken as a cutoff, and every entry
    whose magnitude is strictly below that cutoff is set to zero. The cutoff
    is computed per tensor, not across the whole model.
    """
    with torch.no_grad():
        for weights in model.parameters():
            magnitudes = weights.abs()
            cutoff = torch.quantile(magnitudes, sparsity)
            weights.masked_fill_(magnitudes < cutoff, 0)

def reexpand_weights(model, recovery_rate=0.1):
    """
    Re-expand pruned weights by injecting random noise scaled by recovery_rate.

    Every entry that is exactly zero in any tensor of ``model.parameters()``
    is overwritten, in place, with a standard-normal sample multiplied by
    ``recovery_rate``. Non-zero entries are left untouched.

    Args:
        model: Module whose parameters are modified in place.
        recovery_rate: Scale applied to the injected noise.
    """
    with torch.no_grad():
        for param in model.parameters():
            mask = param == 0
            num_zero = int(mask.sum())
            if num_zero == 0:
                continue
            # Match the parameter's dtype as well as its device: masked
            # assignment requires source and destination dtypes to agree,
            # so float32 noise would fail for float64/half parameters.
            noise = torch.randn(num_zero, device=param.device, dtype=param.dtype)
            param[mask] = noise * recovery_rate


In [None]:
def calculate_entropy(targets):
    """
    Calculate the Shannon entropy (in bits) of the target labels.

    Args:
        targets: Object exposing ``.tolist()`` (e.g. a 1-D tensor or array)
            whose elements are hashable class labels.

    Returns:
        Entropy of the empirical label distribution, using log base 2.
        Returns 0.0 for empty input.
    """
    counts = Counter(targets.tolist())
    total = sum(counts.values())
    if total == 0:
        return 0.0
    # Every probability comes from an observed label, so p > 0 and no
    # epsilon is needed inside the logarithm (it would only bias the result).
    probabilities = (count / total for count in counts.values())
    return sum(p * math.log2(1.0 / p) for p in probabilities)

def calculate_metrics(labels, preds, probs):
    """
    Calculate evaluation metrics.

    Args:
        labels: Ground-truth class labels (sequence or array).
        preds: Predicted class labels, same length as ``labels``.
        probs: One score per sample, used as the score of the positive
            class (class 1) for the ranking metrics.

    Returns:
        Dict with keys 'precision', 'recall', 'f1', 'auroc', 'ap'.

        * At most two distinct classes across ``labels`` and ``preds``:
          binary precision/recall/F1 and AUROC/AP computed directly on
          ``labels``.
        * More than two classes: precision/recall/F1 are macro-averaged
          (classes with no predictions score 0), and 'auroc'/'ap' are
          one-vs-rest scores for class 1. They are NaN when class 1 is
          absent from ``labels`` or is the only class present, since the
          ranking metrics are undefined then.
    """
    labels = np.asarray(labels)
    preds = np.asarray(preds)
    probs = np.asarray(probs)

    n_classes = np.unique(np.concatenate([labels.ravel(), preds.ravel()])).size

    if n_classes <= 2:
        # Binary problem: identical to the plain binary computation.
        precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
        auroc = roc_auc_score(labels, probs)
        ap = average_precision_score(labels, probs)
    else:
        # average='binary' raises on multiclass targets, so use macro averaging.
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, preds, average='macro', zero_division=0
        )
        # Rank-based metrics need a binary target: class 1 versus the rest.
        is_positive = (labels == 1).astype(int)
        if is_positive.min() == is_positive.max():
            auroc = ap = float('nan')
        else:
            auroc = roc_auc_score(is_positive, probs)
            ap = average_precision_score(is_positive, probs)

    return {'precision': precision, 'recall': recall, 'f1': f1, 'auroc': auroc, 'ap': ap}


In [None]:
class SurgeCollapseNet(nn.Module):
    """
    Three-layer fully connected network with a StableMax stage.

    Data flow: fc1 -> ReLU -> fc2 -> ReLU -> StableMax (over dim 1) -> fc3.
    The StableMax normalisation is applied to the hidden activations, and
    the final linear layer's raw output is returned.
    """
    def __init__(self, input_size=128, hidden_size=256, output_size=128):
        super().__init__()
        # Registration order is kept as-is because it fixes the ordering of
        # model.parameters() and the state_dict keys.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.stablemax = StableMax(dim=1)  # Replaces typical softmax
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        normalised = self.stablemax(hidden)
        return self.fc3(normalised)


In [None]:
def create_dummy_dataloader(batch_size=64, input_size=128, output_size=128, num_samples=5000,
                            shuffle=True, seed=None):
    """
    Create a DataLoader with synthetic data.

    Inputs are standard-normal vectors of length ``input_size``; targets are
    integer class ids drawn uniformly from ``[0, output_size)``.

    Args:
        batch_size: Number of samples per batch.
        input_size: Feature dimension of each input vector.
        output_size: Number of classes (exclusive upper bound of the targets).
        num_samples: Total number of samples in the dataset.
        shuffle: Whether the loader reshuffles the data every epoch.
        seed: Optional integer. When given, a dedicated generator seeded with
            it drives both data generation and shuffling, making the loader
            reproducible without touching the global RNG. When ``None`` the
            global RNG is used.

    Returns:
        A ``DataLoader`` over a ``TensorDataset`` of ``(inputs, targets)``.
    """
    generator = None
    if seed is not None:
        generator = torch.Generator().manual_seed(seed)

    inputs = torch.randn(num_samples, input_size, generator=generator)
    targets = torch.randint(0, output_size, (num_samples,), generator=generator)
    dataset = TensorDataset(inputs, targets)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, generator=generator)


In [None]:
def _plot_training_curves(loss_history, entropy_history, title):
    """Fallback plot: per-epoch training loss and label entropy, side by side."""
    epochs = range(1, len(loss_history) + 1)
    fig, (loss_ax, entropy_ax) = plt.subplots(1, 2, figsize=(12, 4))
    loss_ax.plot(epochs, loss_history, marker='o')
    loss_ax.set(xlabel='Epoch', ylabel='Loss', title='Training Loss')
    entropy_ax.plot(epochs, entropy_history, marker='o', color='tab:orange')
    entropy_ax.set(xlabel='Epoch', ylabel='Entropy (bits)', title='Training Entropy')
    fig.suptitle(title)
    fig.tight_layout()
    plt.show()


def train_model(
    model, train_loader, val_loader, optimizer, criterion,
    num_epochs=10, collapse_interval=100, surge_interval=200,
    collapse_sparsity=0.5, surge_recovery=0.1, device='cpu'
):
    """
    Train the model with Surge-Collapse dynamics.

    Each epoch runs one pass over ``train_loader`` followed by a validation
    pass over ``val_loader``. Whenever validation F1 improves, the model's
    state dict is written to ``models/best_model.pth``.

    Args:
        model: Module to train; moved to ``device``.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        val_loader: Iterable of ``(inputs, targets)`` validation batches.
        optimizer: Optimizer already bound to ``model``'s parameters.
        criterion: Loss callable taking ``(outputs, targets)``.
        num_epochs: Number of epochs to run.
        collapse_interval: Prune weights every this many optimizer steps.
        surge_interval: Re-expand zeroed weights every this many optimizer steps.
        collapse_sparsity: Sparsity passed to ``collapse_weights``.
        surge_recovery: Recovery rate passed to ``reexpand_weights``.
        device: Device on which model and batches are placed.

    Returns:
        ``(loss_history, entropy_history)``: per-epoch average training loss
        and per-epoch average batch label entropy.

    Note:
        The collapse/surge intervals count optimizer steps across the whole
        run, not within an epoch. A per-epoch counter would never reach an
        interval larger than the number of batches in one epoch.
    """
    checkpoint_path = 'models/best_model.pth'
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    model.to(device)
    best_f1 = 0.0
    loss_history = []
    entropy_history = []
    global_step = 0  # optimizer steps taken so far, across all epochs

    for epoch in range(1, num_epochs + 1):
        model.train()
        running_loss = 0.0
        total_entropy = 0.0

        for inputs, targets in tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}"):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            global_step += 1

            running_loss += loss.item()
            batch_entropy = calculate_entropy(targets)
            total_entropy += batch_entropy

            # Surge-Collapse Dynamics
            if global_step % collapse_interval == 0:
                collapse_weights(model, sparsity=collapse_sparsity)
            if global_step % surge_interval == 0:
                reexpand_weights(model, recovery_rate=surge_recovery)

        avg_loss = running_loss / len(train_loader)
        avg_entropy = total_entropy / len(train_loader)
        loss_history.append(avg_loss)
        entropy_history.append(avg_entropy)

        # Validation
        model.eval()
        val_loss = 0.0
        all_labels = []
        all_preds = []
        all_probs = []
        with torch.no_grad():
            for inputs, targets in tqdm(val_loader, desc="Validation"):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                # Score of class 1 only; predictions are the arg-max class.
                probs = torch.softmax(outputs, dim=1)[:,1]
                preds = torch.argmax(outputs, dim=1)

                all_labels.extend(targets.cpu().numpy())
                all_preds.extend(preds.cpu().numpy())
                all_probs.extend(probs.cpu().numpy())

        avg_val_loss = val_loss / len(val_loader)
        metrics = calculate_metrics(all_labels, all_preds, all_probs)

        print(f"Epoch {epoch}: Train Loss={avg_loss:.4f}, Train Entropy={avg_entropy:.4f}, Val Loss={avg_val_loss:.4f}, F1={metrics['f1']:.4f}")

        # Save best model
        if metrics['f1'] > best_f1:
            best_f1 = metrics['f1']
            torch.save(model.state_dict(), checkpoint_path)
            print("Best model saved.")

    # Plot training metrics. Use a notebook-level `plot_metrics` if one has
    # been defined; otherwise fall back to the built-in helper so a missing
    # plotting function cannot discard the results of a finished run.
    plot_fn = globals().get('plot_metrics', _plot_training_curves)
    plot_fn(loss_history, entropy_history, title="Training Loss and Entropy Over Epochs")
    return loss_history, entropy_history

In [None]:
# Seed PyTorch's global RNG so synthetic data, weight initialisation and
# shuffling repeat across "Restart & Run All".
SEED = 42
torch.manual_seed(SEED)

# Create DataLoaders
batch_size = 64
input_size = 128
output_size = 128
num_samples = 5000
# 80/20 train/validation split of the sample budget (4000 / 1000).
num_train = num_samples * 4 // 5
num_val = num_samples - num_train

train_loader = create_dummy_dataloader(batch_size, input_size, output_size, num_samples=num_train)
val_loader = create_dummy_dataloader(batch_size, input_size, output_size, num_samples=num_val)

# Initialize Model, Optimizer, and Criterion
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SurgeCollapseNet(input_size=input_size, hidden_size=256, output_size=output_size)
optimizer = OrthogonalGrad(model.parameters(), lr=1e-3, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()

# Train the Model
loss_history, entropy_history = train_model(
    model, train_loader, val_loader, optimizer, criterion,
    num_epochs=10, collapse_interval=100, surge_interval=200,
    collapse_sparsity=0.5, surge_recovery=0.1, device=device
)