In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import math
from tqdm import tqdm

# --- Enhanced FibMarble Optimizer V3 ---
class FibMarbleOptimizerV3(torch.optim.Optimizer):
    """Momentum optimizer with a "spin" term and a Fibonacci-cycled step size.

    Each step blends two per-parameter buffers:

    * a momentum buffer (decayed by ``beta * (1 - friction)``, then the raw
      gradient is added), weighted by ``cos(angle)``;
    * a spin buffer driven by the sign agreement between the gradient and its
      change since the previous step, weighted by ``sin(angle)``.

    ``angle`` starts at ``pi / 4`` and is multiplied by ``angle_decay`` every
    step. The resulting update is rescaled so its L2 norm lies in
    ``[min_speed, max_speed]``.

    When every gradient entry of a parameter is smaller than ``0.01`` in
    absolute value (and ``max_spin > 0``), an Adam-style normalized update is
    applied to that parameter instead.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Base learning rate; scaled each step by a Fibonacci cycle factor.
        beta: Decay rate shared by the momentum, gradient-variance, spin and
            Adam-moment running averages. Must be in ``[0, 1)``.
        max_spin: Upper bound of the per-element spin coefficient. A value
            ``<= 0`` also disables the small-gradient Adam fallback.
        min_spin: Lower bound of the per-element spin coefficient.
        angle_decay: Per-step multiplicative decay of the blend angle.
        friction: Extra damping applied to the momentum buffer.
        fib_period: Length of the learning-rate cycle (``>= 1``).
        max_speed: Maximum L2 norm of a single parameter's update.
        min_speed: Minimum L2 norm of a single parameter's (non-zero) update.

    Raises:
        ValueError: If ``lr`` is negative, ``beta`` is outside ``[0, 1)`` or
            ``fib_period`` is smaller than 1.

    Note:
        ``step_count`` and ``fib_seq`` are plain instance attributes; they are
        not stored in ``self.state`` or ``self.param_groups``.
    """

    def __init__(self, params, lr=0.01, beta=0.9, max_spin=0.3, min_spin=0.01,
                 angle_decay=0.97, friction=0.01, fib_period=8, max_speed=1.0, min_speed=0.01):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= beta < 1.0:
            # beta == 1 would make the Adam bias correction divide by zero.
            raise ValueError(f"Invalid beta (expected 0 <= beta < 1): {beta}")
        if fib_period < 1:
            # The learning-rate cycle position is computed modulo fib_period.
            raise ValueError(f"Invalid fib_period (expected >= 1): {fib_period}")

        defaults = dict(lr=lr, beta=beta, max_spin=max_spin, min_spin=min_spin,
                        angle_decay=angle_decay, friction=friction,
                        fib_period=fib_period, max_speed=max_speed, min_speed=min_speed)
        super().__init__(params, defaults)

        # Pre-compute at least `fib_period` Fibonacci numbers (always >= 2).
        self.fib_period = fib_period
        self.fib_seq = [1, 1]
        while len(self.fib_seq) < fib_period:
            self.fib_seq.append(self.fib_seq[-1] + self.fib_seq[-2])

        self.step_count = 0
        self.base_angle = math.pi / 4  # 45 degrees: equal momentum/spin weight at the start

    @staticmethod
    def _init_state(state, p):
        """Allocate the per-parameter buffers used by every step."""
        state['momentum_buffer'] = torch.zeros_like(p)
        state['grad_var'] = torch.zeros_like(p)
        state['spin_buffer'] = torch.zeros_like(p)
        state['prev_grad'] = torch.zeros_like(p)

    @staticmethod
    def _update_buffers(state, group, grad):
        """Advance the variance, momentum and spin buffers with ``grad``."""
        beta = group['beta']

        # Running average of the squared gradient.
        grad_var = state['grad_var']
        grad_var.mul_(beta).addcmul_(grad, grad, value=1 - beta)

        # Per-element spin coefficient: variance relative to its mean,
        # clamped to [0, 1], then mapped onto [min_spin, max_spin].
        spin_strength = (grad_var / (grad_var.mean() + 1e-8)).clamp(0, 1)
        current_spin = group['min_spin'] + (group['max_spin'] - group['min_spin']) * spin_strength

        # Momentum with friction.
        state['momentum_buffer'].mul_(beta * (1 - group['friction'])).add_(grad)

        # Spin direction is +1 where the gradient and its change since the
        # previous step share a sign, -1 where they differ, 0 if either is 0.
        spin_direction = torch.sign(grad - state['prev_grad']) * torch.sign(grad)
        state['spin_buffer'].mul_(beta).add_(current_spin * spin_direction * grad.norm())
        state['prev_grad'].copy_(grad)

    @staticmethod
    def _marble_update(state, group, scale, angle):
        """Return the momentum/spin blend, norm-limited to [min_speed, max_speed]."""
        update = -scale * (
            state['momentum_buffer'] * math.cos(angle) +
            state['spin_buffer'] * math.sin(angle)
        )

        update_norm = update.norm()
        if update_norm > group['max_speed']:
            update.mul_(group['max_speed'] / (update_norm + 1e-8))
        elif update_norm < group['min_speed']:
            update.mul_(group['min_speed'] / (update_norm + 1e-8))
        return update

    @staticmethod
    def _adam_update(state, group, grad, current_lr):
        """Return an Adam-style update for parameters with very small gradients.

        The moments are created the first time this branch runs and are only
        advanced on steps that take this branch, so bias correction uses its
        own counter rather than the optimizer-wide step count.
        """
        beta = group['beta']
        if 'adam_m' not in state:
            state['adam_m'] = torch.zeros_like(grad)
            state['adam_v'] = torch.zeros_like(grad)
        # Separate check so state created without the counter still works.
        if 'adam_step' not in state:
            state['adam_step'] = 0

        state['adam_step'] += 1
        state['adam_m'].mul_(beta).add_(grad, alpha=1 - beta)
        state['adam_v'].mul_(beta).addcmul_(grad, grad, value=1 - beta)

        bias_correction = 1 - beta ** state['adam_step']
        m_hat = state['adam_m'] / bias_correction
        v_hat = state['adam_v'] / bias_correction
        return -current_lr * m_hat / (torch.sqrt(v_hat) + 1e-8)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is invoked with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        self.step_count += 1

        # Grow the sequence by one entry when needed; entries added here are
        # capped at 1000 so the ratio below settles at 1 once the cap is hit.
        if self.step_count >= len(self.fib_seq):
            next_val = min(self.fib_seq[-1] + self.fib_seq[-2], 1000)
            self.fib_seq.append(next_val)

        fib_ratio = self.fib_seq[self.step_count] / (self.fib_seq[self.step_count - 1] + 1e-8)

        for group in self.param_groups:
            # Learning rate cycles through fib[0..period-1] / fib[period-1].
            # Both indices are clamped so a group-specific fib_period larger
            # than the current sequence cannot index out of range.
            last_idx = len(self.fib_seq) - 1
            cycle_pos = self.step_count % group['fib_period']
            numerator = self.fib_seq[min(cycle_pos, last_idx)]
            denominator = self.fib_seq[min(group['fib_period'] - 1, last_idx)]
            current_lr = group['lr'] * (numerator / denominator)

            angle = self.base_angle * (group['angle_decay'] ** self.step_count)
            use_hybrid = group['max_spin'] > 0

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                state = self.state[p]
                if 'momentum_buffer' not in state:
                    self._init_state(state, p)

                # Buffers are advanced on every step, whichever update is used.
                self._update_buffers(state, group, grad)

                if use_hybrid and grad.abs().max() < 0.01:
                    update = self._adam_update(state, group, grad, current_lr)
                else:
                    update = self._marble_update(state, group, current_lr * fib_ratio, angle)

                p.add_(update)

        return loss

# --- Simple Feedforward Network ---
class SimpleNN(nn.Module):
    """Fully connected classifier: 784 -> 256 -> 128 -> 10.

    ReLU is applied after the first two linear layers; the third layer's
    output is returned as-is.
    """

    def __init__(self):
        super().__init__()
        in_features = 28 * 28
        self.fc1 = nn.Linear(in_features, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        """Reshape ``x`` into rows of 784 values and return 10 scores per row."""
        hidden = x.view(-1, 28 * 28)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)

# --- Training Function with Progress Bar ---
def train(model, optimizer, train_loader, test_loader, epochs=10, name=""):
    """Train ``model`` with cross-entropy loss and evaluate after every epoch.

    Args:
        model: Module mapping a batch of inputs to class scores.
        optimizer: Optimizer already bound to ``model``'s parameters.
        train_loader: Iterable of ``(images, labels)`` training batches.
        test_loader: Iterable of ``(images, labels)`` evaluation batches.
        epochs: Number of passes over ``train_loader``.
        name: Label used in the progress bar and the per-epoch log line.

    Returns:
        Tuple ``(acc_list, loss_list, test_acc_list, test_loss_list)`` with
        one entry per epoch:

        * ``acc_list``: training accuracy in percent, accumulated over the
          batches of the epoch (each batch is scored with the weights in use
          before that batch's optimizer step);
        * ``loss_list``: mean of the per-batch training losses;
        * ``test_acc_list``: accuracy in percent on ``test_loader``;
        * ``test_loss_list``: mean of the per-batch losses on ``test_loader``.
    """
    criterion = nn.CrossEntropyLoss()
    acc_list, loss_list = [], []
    test_acc_list, test_loss_list = [], []

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        train_correct = 0
        train_total = 0
        progress_bar = tqdm(train_loader, desc=f"{name} Epoch {epoch+1}")

        for images, labels in progress_bar:
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            # Track training accuracy so acc_list is not a copy of the test accuracy.
            train_correct += (outputs.argmax(dim=1) == labels).sum().item()
            train_total += labels.size(0)
            progress_bar.set_postfix({'loss': batch_loss})

        # Validation
        model.eval()
        test_loss = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for images, labels in test_loader:
                outputs = model(images)
                test_loss += criterion(outputs, labels).item()
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        avg_loss = total_loss / len(train_loader)
        test_avg_loss = test_loss / len(test_loader)
        train_accuracy = 100 * train_correct / train_total
        accuracy = 100 * correct / total

        acc_list.append(train_accuracy)
        loss_list.append(avg_loss)
        test_acc_list.append(accuracy)
        test_loss_list.append(test_avg_loss)

        # The "Accuracy" figure in the log line is the test accuracy.
        print(f"[{name}] Epoch {epoch+1} | "
              f"Train Loss: {avg_loss:.4f} | Test Loss: {test_avg_loss:.4f} | "
              f"Accuracy: {accuracy:.2f}%")

    return acc_list, loss_list, test_acc_list, test_loss_list

# --- Data Loaders ---
# Convert images to tensors, then normalize the single channel with
# mean 0.1307 and standard deviation 0.3081.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# Both splits share the same storage root, download flag and transform.
mnist_kwargs = dict(root='./data', download=True, transform=transform)
train_dataset = datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = datasets.MNIST(train=False, **mnist_kwargs)

# Training batches are shuffled; evaluation batches are not.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

# --- Training Setup ---
def reset_weights(m):
    """Re-initialize ``m`` in place if it is an ``nn.Linear``; otherwise do nothing.

    Intended for use with ``Module.apply``, which calls it on every submodule.
    """
    if not isinstance(m, nn.Linear):
        return
    m.reset_parameters()

# --- Train with Enhanced FibMarble V3 ---
# Hyperparameters for the FibMarble V3 run.
fib_hparams = {
    'lr': 0.01,
    'beta': 0.9,
    'max_spin': 0.3,
    'min_spin': 0.01,
    'angle_decay': 0.97,
    'friction': 0.01,
    'fib_period': 8,
    'max_speed': 1.0,
    'min_speed': 0.01,
}

model_fib = SimpleNN()
model_fib.apply(reset_weights)
optimizer_fib = FibMarbleOptimizerV3(model_fib.parameters(), **fib_hparams)
print("\nTraining with Enhanced FibMarble V3...")
fib_history = train(model_fib, optimizer_fib, train_loader, test_loader, name="FibMarble V3")
acc_fib, loss_fib, test_acc_fib, test_loss_fib = fib_history

# --- Baseline: Adam ---
model_adam = SimpleNN()
model_adam.apply(reset_weights)
optimizer_adam = torch.optim.Adam(model_adam.parameters(), lr=0.001)
print("\nTraining with Adam...")
adam_history = train(model_adam, optimizer_adam, train_loader, test_loader, name="Adam")
acc_adam, loss_adam, test_acc_adam, test_loss_adam = adam_history

# --- Baseline: SGD with momentum ---
model_sgd = SimpleNN()
model_sgd.apply(reset_weights)
optimizer_sgd = torch.optim.SGD(model_sgd.parameters(), lr=0.01, momentum=0.9)
print("\nTraining with SGD+Momentum...")
sgd_history = train(model_sgd, optimizer_sgd, train_loader, test_loader, name="SGD+Momentum")
acc_sgd, loss_sgd, test_acc_sgd, test_loss_sgd = sgd_history

# --- Summary: last-epoch test metrics for each optimizer ---
print("\nFinal Results:")
print(f"FibMarble V3 - Test Accuracy: {test_acc_fib[-1]:.2f}% | Test Loss: {test_loss_fib[-1]:.4f}")
print(f"Adam         - Test Accuracy: {test_acc_adam[-1]:.2f}% | Test Loss: {test_loss_adam[-1]:.4f}")
print(f"SGD+Momentum - Test Accuracy: {test_acc_sgd[-1]:.2f}% | Test Loss: {test_loss_sgd[-1]:.4f}")


Training with Enhanced FibMarble V3...


FibMarble V3 Epoch 1: 100%|██████████| 938/938 [00:29<00:00, 32.19it/s, loss=0.131]


[FibMarble V3] Epoch 1 | Train Loss: 0.5450 | Test Loss: 0.2514 | Accuracy: 92.28%


FibMarble V3 Epoch 2: 100%|██████████| 938/938 [00:32<00:00, 29.30it/s, loss=0.449]


[FibMarble V3] Epoch 2 | Train Loss: 0.2290 | Test Loss: 0.1969 | Accuracy: 94.00%


FibMarble V3 Epoch 3: 100%|██████████| 938/938 [00:32<00:00, 29.22it/s, loss=0.0204]


[FibMarble V3] Epoch 3 | Train Loss: 0.1684 | Test Loss: 0.1604 | Accuracy: 95.05%


FibMarble V3 Epoch 4: 100%|██████████| 938/938 [00:33<00:00, 28.42it/s, loss=0.131]


[FibMarble V3] Epoch 4 | Train Loss: 0.1379 | Test Loss: 0.1339 | Accuracy: 95.87%


FibMarble V3 Epoch 5: 100%|██████████| 938/938 [00:32<00:00, 29.18it/s, loss=0.232]


[FibMarble V3] Epoch 5 | Train Loss: 0.1291 | Test Loss: 0.1381 | Accuracy: 95.81%


FibMarble V3 Epoch 6: 100%|██████████| 938/938 [00:32<00:00, 28.93it/s, loss=0.271]


[FibMarble V3] Epoch 6 | Train Loss: 0.1220 | Test Loss: 0.1349 | Accuracy: 95.84%


FibMarble V3 Epoch 7: 100%|██████████| 938/938 [00:32<00:00, 29.17it/s, loss=0.00426]


[FibMarble V3] Epoch 7 | Train Loss: 0.1158 | Test Loss: 0.1298 | Accuracy: 96.13%


FibMarble V3 Epoch 8: 100%|██████████| 938/938 [00:32<00:00, 29.26it/s, loss=0.205]


[FibMarble V3] Epoch 8 | Train Loss: 0.1185 | Test Loss: 0.1203 | Accuracy: 96.38%


FibMarble V3 Epoch 9: 100%|██████████| 938/938 [00:33<00:00, 28.10it/s, loss=0.0164]


[FibMarble V3] Epoch 9 | Train Loss: 0.1143 | Test Loss: 0.1164 | Accuracy: 96.34%


FibMarble V3 Epoch 10: 100%|██████████| 938/938 [00:33<00:00, 27.87it/s, loss=0.0505]


[FibMarble V3] Epoch 10 | Train Loss: 0.1146 | Test Loss: 0.1219 | Accuracy: 96.37%

Training with Adam...


Adam Epoch 1: 100%|██████████| 938/938 [00:22<00:00, 41.36it/s, loss=0.014]


[Adam] Epoch 1 | Train Loss: 0.2333 | Test Loss: 0.1176 | Accuracy: 96.49%


Adam Epoch 2: 100%|██████████| 938/938 [00:22<00:00, 40.99it/s, loss=0.172]


[Adam] Epoch 2 | Train Loss: 0.0963 | Test Loss: 0.0844 | Accuracy: 97.45%


Adam Epoch 3: 100%|██████████| 938/938 [00:22<00:00, 40.85it/s, loss=0.132]


[Adam] Epoch 3 | Train Loss: 0.0648 | Test Loss: 0.0815 | Accuracy: 97.45%


Adam Epoch 4: 100%|██████████| 938/938 [00:22<00:00, 41.19it/s, loss=0.0046]


[Adam] Epoch 4 | Train Loss: 0.0507 | Test Loss: 0.0779 | Accuracy: 97.62%


Adam Epoch 5: 100%|██████████| 938/938 [00:22<00:00, 42.10it/s, loss=0.0379]


[Adam] Epoch 5 | Train Loss: 0.0379 | Test Loss: 0.1029 | Accuracy: 96.97%


Adam Epoch 6: 100%|██████████| 938/938 [00:22<00:00, 42.31it/s, loss=0.00786]


[Adam] Epoch 6 | Train Loss: 0.0322 | Test Loss: 0.1043 | Accuracy: 96.98%


Adam Epoch 7: 100%|██████████| 938/938 [00:22<00:00, 42.23it/s, loss=0.0445]


[Adam] Epoch 7 | Train Loss: 0.0285 | Test Loss: 0.0841 | Accuracy: 97.74%


Adam Epoch 8: 100%|██████████| 938/938 [00:22<00:00, 42.36it/s, loss=0.0931]


[Adam] Epoch 8 | Train Loss: 0.0229 | Test Loss: 0.1023 | Accuracy: 97.20%


Adam Epoch 9: 100%|██████████| 938/938 [00:22<00:00, 42.29it/s, loss=0.00136]


[Adam] Epoch 9 | Train Loss: 0.0229 | Test Loss: 0.0821 | Accuracy: 98.00%


Adam Epoch 10: 100%|██████████| 938/938 [00:22<00:00, 42.11it/s, loss=0.0125]


[Adam] Epoch 10 | Train Loss: 0.0166 | Test Loss: 0.0831 | Accuracy: 98.04%

Training with SGD+Momentum...


SGD+Momentum Epoch 1: 100%|██████████| 938/938 [00:21<00:00, 43.67it/s, loss=0.0993]


[SGD+Momentum] Epoch 1 | Train Loss: 0.3076 | Test Loss: 0.1296 | Accuracy: 96.11%


SGD+Momentum Epoch 2: 100%|██████████| 938/938 [00:21<00:00, 42.85it/s, loss=0.00585]


[SGD+Momentum] Epoch 2 | Train Loss: 0.1063 | Test Loss: 0.0923 | Accuracy: 97.20%


SGD+Momentum Epoch 3: 100%|██████████| 938/938 [00:21<00:00, 43.06it/s, loss=0.212]


[SGD+Momentum] Epoch 3 | Train Loss: 0.0689 | Test Loss: 0.0764 | Accuracy: 97.59%


SGD+Momentum Epoch 4: 100%|██████████| 938/938 [00:21<00:00, 43.16it/s, loss=0.115]


[SGD+Momentum] Epoch 4 | Train Loss: 0.0491 | Test Loss: 0.0782 | Accuracy: 97.47%


SGD+Momentum Epoch 5: 100%|██████████| 938/938 [00:21<00:00, 43.33it/s, loss=0.00799]


[SGD+Momentum] Epoch 5 | Train Loss: 0.0369 | Test Loss: 0.0765 | Accuracy: 97.62%


SGD+Momentum Epoch 6: 100%|██████████| 938/938 [00:21<00:00, 44.60it/s, loss=0.0463]


[SGD+Momentum] Epoch 6 | Train Loss: 0.0264 | Test Loss: 0.0700 | Accuracy: 97.84%


SGD+Momentum Epoch 7: 100%|██████████| 938/938 [00:21<00:00, 44.51it/s, loss=0.0148]


[SGD+Momentum] Epoch 7 | Train Loss: 0.0195 | Test Loss: 0.0669 | Accuracy: 97.92%


SGD+Momentum Epoch 8: 100%|██████████| 938/938 [00:21<00:00, 43.53it/s, loss=0.00406]


[SGD+Momentum] Epoch 8 | Train Loss: 0.0148 | Test Loss: 0.0670 | Accuracy: 97.99%


SGD+Momentum Epoch 9: 100%|██████████| 938/938 [00:21<00:00, 43.14it/s, loss=0.0149]


[SGD+Momentum] Epoch 9 | Train Loss: 0.0101 | Test Loss: 0.0683 | Accuracy: 98.02%


SGD+Momentum Epoch 10: 100%|██████████| 938/938 [00:21<00:00, 43.35it/s, loss=0.011]
