# Neural Fractal Estimator (NFE) Training

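This notebook generates synthetic **Fractional Brownian Motion (fBm)** data and trains a **1D Convolutional Neural Network (Critic)** to predict its **Higuchi Fractal Dimension ($D_H$)**. The training target for each path is its theoretical fractal dimension $D = 2 - H$ (the quantity the Higuchi estimator approximates), not a Higuchi estimate computed from the sample itself.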

This Critic will be used in the **Bicameral Manifold GMoE** to differentiate between 'Logic' (Low $D_H$) and 'Creative' (High $D_H$) manifolds.

## Theory
For the graph of a one-dimensional fBm path, the Hurst exponent $H \in (0, 1)$ relates to the fractal dimension $D$ by:
$$D = 2 - H$$

- $H=0.5 \rightarrow D=1.5$ (Random Walk / Brownian Motion)
- $H=0.8 \rightarrow D=1.2$ (Persistent / Smooth)
- $H=0.2 \rightarrow D=1.8$ (Anti-persistent / Rough)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# Select the compute device: the GPU when PyTorch can see one, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

## 1. Synthetic Data Generation (fBm)
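We use the Cholesky method to generate exact fBm paths: the fBm covariance matrix on the sampling grid is factored as $\Sigma = L L^\top$, and each path is $L z$ with $z \sim \mathcal{N}(0, I)$. A small diagonal jitter is added to $\Sigma$ for numerical stability.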

In [None]:
def generate_fbm(n_samples, seq_len, H, device='cpu'):
    """
    Generate fractional Brownian motion paths via Cholesky decomposition.

    Builds the exact covariance matrix of fBm sampled at integer times
    t = 0, 1, ..., seq_len - 1, factors it as L L^T, and returns L @ z for
    i.i.d. standard-normal z. The result is exact up to a small diagonal
    jitter. The factorization is O(seq_len^3), so this is best suited to
    short sequences (<= 512).

    Args:
        n_samples: Number of independent paths to draw.
        seq_len: Number of time steps per path.
        H: Hurst exponent; must lie strictly in (0, 1).
        device: Torch device on which the computation is performed.

    Returns:
        A float32 tensor of shape (n_samples, seq_len). The first column is
        approximately zero (std 1e-3 from the jitter) because B_0 = 0.

    Raises:
        ValueError: If H is not in (0, 1) or seq_len < 1.
    """
    H = float(H)
    if not 0.0 < H < 1.0:
        raise ValueError(f"Hurst exponent H must lie in (0, 1), got {H}")
    if seq_len < 1:
        raise ValueError(f"seq_len must be >= 1, got {seq_len}")

    # 1. Construct the covariance matrix in float64. For larger H the fBm
    #    covariance is badly conditioned (entries grow like t^{2H}), and a
    #    float32 Cholesky can fail as "not positive-definite" or lose accuracy.
    indices = torch.arange(seq_len, dtype=torch.float64, device=device)
    t = indices.unsqueeze(0)  # (1, seq_len)
    s = indices.unsqueeze(1)  # (seq_len, 1)

    # Cov(B_t, B_s) = 0.5 * (|t|^{2H} + |s|^{2H} - |t-s|^{2H})
    two_h = 2.0 * H
    covariance = 0.5 * (t.abs() ** two_h + s.abs() ** two_h - (t - s).abs() ** two_h)
    # Diagonal jitter: row/column 0 is identically zero (B_0 = 0), so without
    # it the matrix is only positive semi-definite and Cholesky would fail.
    covariance += torch.eye(seq_len, dtype=torch.float64, device=device) * 1e-6

    # 2. Cholesky decomposition: covariance = L @ L.T
    L = torch.linalg.cholesky(covariance)

    # 3. Draw the noise in float32 (same RNG consumption as a float32 draw)
    #    and upcast so the product is computed in double precision.
    noise = torch.randn(n_samples, seq_len, device=device).to(torch.float64)

    # 4. path = L @ z for every sample, written in row-vector form.
    fbm = noise @ L.T
    return fbm.to(torch.float32)

def create_dataset(samples_per_h=1000, seq_len=128, h_range=(0.1, 0.9), n_h_levels=20):
    """
    Build a shuffled, labelled fBm dataset over a grid of Hurst exponents.

    For each of ``n_h_levels`` evenly spaced H values in ``h_range`` (both
    endpoints included), ``samples_per_h`` paths are drawn with
    ``generate_fbm`` on the CPU. Each path is standardized to zero mean and
    unit std along time, then labelled with the theoretical fractal dimension
    D = 2 - H.

    Args:
        samples_per_h: Number of paths generated per H level.
        seq_len: Number of time steps per path.
        h_range: (low, high) bounds of the H grid; both must lie in (0, 1).
        n_h_levels: Number of H levels in the grid (default 20).

    Returns:
        (X, y) with X of shape (n_h_levels * samples_per_h, seq_len) and
        y of shape (n_h_levels * samples_per_h,), both float32 and shuffled
        jointly with the same permutation.

    Raises:
        ValueError: If samples_per_h or n_h_levels is < 1.
    """
    if samples_per_h < 1:
        raise ValueError(f"samples_per_h must be >= 1, got {samples_per_h}")
    if n_h_levels < 1:
        raise ValueError(f"n_h_levels must be >= 1, got {n_h_levels}")

    data = []
    labels = []

    h_values = np.linspace(h_range[0], h_range[1], n_h_levels)

    print(f"Generating dataset with {len(h_values) * samples_per_h} samples...")

    for H in tqdm(h_values, desc="Generating per H"):
        H = float(H)
        # Generate on CPU to save VRAM
        signals = generate_fbm(samples_per_h, seq_len, H, device='cpu')

        # Normalize each path to mean 0, std 1 along time; the epsilon guards
        # against division by zero for a (near-)constant path.
        mean = signals.mean(dim=1, keepdim=True)
        std = signals.std(dim=1, keepdim=True)
        signals = (signals - mean) / (std + 1e-8)

        # Target is the theoretical fractal dimension D = 2 - H.
        d_target = 2.0 - H

        data.append(signals)
        # Explicit float32 so targets match the model's output dtype.
        labels.append(torch.full((samples_per_h,), d_target, dtype=torch.float32))

    X = torch.cat(data, dim=0)
    y = torch.cat(labels, dim=0)

    # Shuffle inputs and targets with the same permutation.
    perm = torch.randperm(X.size(0))
    return X[perm], y[perm]

SEQ_LEN = 128

# --- Training split: 2000 paths per H level, reshuffled by the loader each epoch ---
X_train, y_train = create_dataset(samples_per_h=2000, seq_len=SEQ_LEN)
# Conv1d consumes (batch, channels, length) and the model head emits (batch, 1),
# so give both the inputs and the targets a singleton axis 1.
train_ds = TensorDataset(X_train[:, None], y_train[:, None])
train_loader = DataLoader(train_ds, batch_size=256, shuffle=True)

# --- Validation split: 200 paths per H level, iterated in a fixed order ---
X_val, y_val = create_dataset(samples_per_h=200, seq_len=SEQ_LEN)
val_ds = TensorDataset(X_val[:, None], y_val[:, None])
val_loader = DataLoader(val_ds, batch_size=256)

print(f"Training set shape: {X_train.shape}")

In [None]:
# Overlay the first three (already shuffled) training paths, each labelled
# with its target dimension, to eyeball how roughness varies with D.
plt.figure(figsize=(12, 4))
for idx in range(3):
    signal, target = X_train[idx], y_train[idx]
    plt.plot(signal.numpy(), label=f"D={target.item():.2f}")
plt.legend()
plt.title("Sample fBm Signals")
plt.show()

## 2. Neural Critic Definition
A 1D CNN to estimate $D_H$ from the waveform.

In [None]:
class HiguchiFDCritic(nn.Module):
    """1D CNN regressor that maps a single-channel waveform to one scalar.

    Input is (batch, 1, length); the first Conv1d has one input channel.
    Three conv stages (kernel sizes 7/5/3, widths 32/64/128) each apply
    Conv1d -> BatchNorm1d -> SiLU. The first two stages halve the length
    with max-pooling, and the last collapses it with adaptive average
    pooling. Any length of at least 4 works, since both max-pools must
    produce non-empty output. A small MLP head (with dropout) returns
    (batch, 1).
    """

    # (in_channels, out_channels, kernel_size) for each conv stage.
    _STAGES = ((1, 32, 7), (32, 64, 5), (64, 128, 3))

    def __init__(self):
        super().__init__()
        layers = []
        final_stage = len(self._STAGES) - 1
        for stage_idx, (c_in, c_out, k) in enumerate(self._STAGES):
            layers.extend([
                # padding = k // 2 keeps the length unchanged for odd kernels
                nn.Conv1d(c_in, c_out, kernel_size=k, padding=k // 2),
                nn.BatchNorm1d(c_out),
                nn.SiLU(),
            ])
            layers.append(nn.AdaptiveAvgPool1d(1) if stage_idx == final_stage else nn.MaxPool1d(2))
        # One flat Sequential, so state_dict keys (conv_blocks.0, .4, .8, ...)
        # stay compatible with checkpoints such as nfe_critic.pt.
        self.conv_blocks = nn.Sequential(*layers)

        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128, 64),
            nn.SiLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        return self.head(self.conv_blocks(x))

## 3. Training Loop

In [None]:
# Train the critic with MSE on D targets. The best-validation weights are
# checkpointed to "nfe_critic.pt" and restored into `model` after training.
model = HiguchiFDCritic().to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
criterion = nn.MSELoss()
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)

epochs = 30
best_val_loss = float('inf')
best_state = None
# Validation predictions/targets from the best-validation epoch, so the
# accuracy plot matches the weights that are saved and restored.
preds = []
actuals = []

for epoch in range(epochs):
    model.train()
    train_loss = 0.0
    n_train = 0

    for x_batch, y_batch in train_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        pred = model(x_batch)
        loss = criterion(pred, y_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # MSELoss returns the batch *mean*. Weight it by batch size so a
        # smaller final batch does not count as much as a full one.
        train_loss += loss.item() * x_batch.size(0)
        n_train += x_batch.size(0)

    # Validation
    model.eval()
    val_loss = 0.0
    n_val = 0
    epoch_preds = []
    epoch_actuals = []

    with torch.no_grad():
        for x_batch, y_batch in val_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            pred = model(x_batch)
            val_loss += criterion(pred, y_batch).item() * x_batch.size(0)
            n_val += x_batch.size(0)

            epoch_preds.extend(pred.cpu().flatten().numpy())
            epoch_actuals.extend(y_batch.cpu().flatten().numpy())

    # Per-sample means; max(..., 1) guards against an empty loader.
    train_loss /= max(n_train, 1)
    val_loss /= max(n_val, 1)
    scheduler.step(val_loss)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "nfe_critic.pt")
        # Keep an in-memory copy so the best weights can be restored below.
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        preds, actuals = epoch_preds, epoch_actuals

    print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.5f} | Val Loss: {val_loss:.5f}")

# Leave `model` holding the best checkpoint rather than the last epoch's weights.
if best_state is not None:
    model.load_state_dict(best_state)

print(f"Training complete. Best Validation Loss: {best_val_loss:.5f}")

In [None]:
# Check Accuracy: predicted vs. true dimension on the validation set.
# The dashed red diagonal is the ideal y = x line over D in [1, 2].
plt.figure(figsize=(6, 6))
plt.scatter(actuals, preds, s=1, alpha=0.1)
plt.plot([1, 2], [1, 2], 'r--')
plt.gca().set(xlabel="True D_H", ylabel="Predicted D_H", title="Neural Critic Accuracy")
plt.grid(True)
plt.show()