# SynFlow Experiment — ResNet-20 on CIFAR-10 & CIFAR-100

Reproducing iterative data-free pruning from *Pruning neural networks without any data by iteratively conserving synaptic flow* (Tanaka et al., NeurIPS 2020).

| Setting | Value |
|---|---|
| **Architecture** | ResNet-20 (CIFAR variant) |
| **Datasets** | CIFAR-10, CIFAR-100 |
| **Initialization** | Kaiming Normal |
| **Pruning** | SynFlow — 100 iterations, exponential schedule, global scope, data-free |
| **Optimizer** | SGD, momentum 0.9, weight decay 1e-4 |
| **Epochs** | 160 |
| **Batch size** | 128 |
| **Learning rate** | 0.1 → ×0.1 at epochs 80, 120 |
| **Sparsities** | 0 % (dense baseline), 30 %, 60 % — matching `config['sparsities']` |

We compare **Dense** (unpruned) vs SynFlow-pruned networks at each sparsity level.

## 1 — Imports

In [None]:
import sys, os, copy, json, time
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import MultiStepLR
import matplotlib.pyplot as plt
from tqdm import tqdm

# Project imports
sys.path.append(os.path.abspath('../src'))
from model import resnet20, count_parameters
from synflow import synflow_pruning, apply_synflow_masks, get_synflow_sparsity
from train import train_epochs, evaluate
from util import apply_masks_to_model, create_mask_apply_fn, set_seed

# Select the compute device once for the whole notebook: CUDA when PyTorch
# reports an available GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 2 — Configuration

In [None]:
# Central experiment configuration; the cells below read their settings
# from this dictionary.
config = dict(
    # Model / datasets
    model='resnet20',
    datasets=['cifar10', 'cifar100'],

    # SynFlow pruning
    sparsities=[0.0, 0.30, 0.60],   # 0.0 = dense baseline
    synflow_iters=100,              # iterative pruning rounds
    input_shape=(3, 32, 32),        # CIFAR spatial dims

    # Training
    epochs=160,
    batch_size=128,
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-4,
    lr_milestones=[80, 120],
    lr_gamma=0.1,

    # Reproducibility
    seed=42,

    # Output
    results_dir='../results/synflow',
)

# Echo the configuration, one right-aligned key per line, framed by rules.
rule = "=" * 60
print(rule)
print("EXPERIMENT CONFIGURATION")
print(rule)
for key, value in config.items():
    print(f"  {key:>20s}: {value}")
print(rule)

## 3 — Dataset helpers

In [None]:
def get_loaders(dataset_name: str, batch_size: int = 128):
    """Build the train/test data loaders for one CIFAR dataset.

    The dataset is downloaded to ``../data`` when it is not already there.
    Training images get random-crop (padding 4) and horizontal-flip
    augmentation; both splits are normalised with per-dataset channel
    statistics.

    Args:
        dataset_name: ``'cifar10'`` or ``'cifar100'``.
        batch_size: Mini-batch size of the training loader. The test loader
            always uses batches of 256.

    Returns:
        Tuple ``(train_loader, test_loader, num_classes)``.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset. Without
            this check a misspelt name would silently select CIFAR-100.
    """
    # Resolve the name before touching torchvision so a bad name fails fast,
    # before any download starts.
    if dataset_name == 'cifar10':
        mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
        num_classes = 10
        DS = torchvision.datasets.CIFAR10
    elif dataset_name == 'cifar100':
        mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)
        num_classes = 100
        DS = torchvision.datasets.CIFAR100
    else:
        raise ValueError(
            f"Unknown dataset {dataset_name!r}; expected 'cifar10' or 'cifar100'."
        )

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    # No augmentation at test time: only tensor conversion + normalisation.
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_ds = DS(root='../data', train=True,  download=True, transform=train_transform)
    test_ds  = DS(root='../data', train=False, download=True, transform=test_transform)

    train_loader = DataLoader(train_ds, batch_size=batch_size,
                              shuffle=True, num_workers=2, pin_memory=True)
    test_loader  = DataLoader(test_ds,  batch_size=256,
                              shuffle=False, num_workers=2, pin_memory=True)
    print(f"[{dataset_name}] Train: {len(train_ds):,}  Test: {len(test_ds):,}  Classes: {num_classes}")
    return train_loader, test_loader, num_classes

## 4 — Helper: train one configuration end-to-end

Encapsulates: model init → (optional) SynFlow pruning → full training run (160 epochs with the default config).
Wall-clock time is recorded separately for pruning and for training.

In [None]:
def run_single(dataset_name: str, sparsity: float, config: dict) -> dict:
    """Create a model, optionally prune it with SynFlow, then train it.

    Args:
        dataset_name: Dataset identifier, forwarded to ``get_loaders``.
        sparsity: Target sparsity handed to ``synflow_pruning``. Any value
            that is not ``> 0`` skips pruning (``0.0`` is the dense baseline).
        config: Experiment settings. Keys read here: ``seed``, ``batch_size``,
            ``synflow_iters``, ``input_shape``, ``lr``, ``momentum``,
            ``weight_decay``, ``lr_milestones``, ``lr_gamma``, ``epochs``.

    Returns:
        A dict with the keys ``dataset``, ``sparsity``, ``history`` (as
        returned by ``train_epochs``), ``best_test_acc``, ``final_test_acc``,
        ``masks`` (``None`` for the dense run), ``layer_sparsity`` (empty for
        the dense run), ``pruning_time``, ``training_time`` and
        ``total_time`` (all times in seconds).
    """
    # Seed before anything that may draw random numbers (data loading, init).
    # NOTE(review): assumes `set_seed` seeds all relevant RNGs — confirm in util.
    set_seed(config['seed'])

    train_loader, test_loader, num_classes = get_loaders(dataset_name, config['batch_size'])

    # ---- Model (Kaiming Normal init) ----
    model = resnet20(num_classes=num_classes).to(device)
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    params = count_parameters(model)
    tag = "dense" if sparsity == 0.0 else f"{sparsity*100:.0f}%"
    print(f"\n{'='*60}")
    print(f"[{dataset_name} / {tag}]  sparsity = {sparsity*100:.0f}%")
    print(f"  Total params: {params['total']:,}")

    # Durations use perf_counter(): it is monotonic, so a system clock
    # adjustment during a long run cannot distort the recorded times.
    # The total starts here, i.e. it excludes data loading and model init.
    total_start = time.perf_counter()

    # ---- SynFlow pruning (skip for dense baseline) ----
    masks = None
    apply_fn = None
    prune_time = 0.0
    layer_sparsity = {}

    if sparsity > 0:
        print(f"  Running SynFlow ({config['synflow_iters']} iters, data-free) …")
        prune_start = time.perf_counter()

        masks = synflow_pruning(
            model,
            device=device,
            target_sparsity=sparsity,
            num_iters=config['synflow_iters'],
            input_shape=config['input_shape'],
        )
        prune_time = time.perf_counter() - prune_start
        print(f"  SynFlow pruning completed in {prune_time:.2f}s")

        # Apply masks in-place
        apply_synflow_masks(model, masks)

        # `layer_sparsity` carries an 'overall' entry next to per-layer ones.
        layer_sparsity = get_synflow_sparsity(masks)
        print(f"  Achieved overall sparsity: {layer_sparsity['overall']*100:.2f}%")
        for name, sp in layer_sparsity.items():
            if name != 'overall':
                print(f"    {name:>30s}: {sp*100:.2f}%")

        # Callback handed to the training loop together with the masks.
        # NOTE(review): presumably re-applies the masks after optimizer
        # steps — confirm in util.create_mask_apply_fn.
        apply_fn = create_mask_apply_fn(model)

    # ---- Optimizer / scheduler ----
    optimizer = optim.SGD(
        model.parameters(),
        lr=config['lr'],
        momentum=config['momentum'],
        weight_decay=config['weight_decay'],
    )
    scheduler = MultiStepLR(optimizer,
                            milestones=config['lr_milestones'],
                            gamma=config['lr_gamma'])
    criterion = nn.CrossEntropyLoss()

    # ---- Train ----
    print(f"  Training for {config['epochs']} epochs …")
    train_start = time.perf_counter()
    history = train_epochs(
        model=model,
        train_loader=train_loader,
        test_loader=test_loader,
        criterion=criterion,
        optimizer=optimizer,
        num_epochs=config['epochs'],
        device=device,
        scheduler=scheduler,
        masks=masks,
        apply_mask_fn=apply_fn,
        verbose=True,
    )
    train_time = time.perf_counter() - train_start
    total_time = time.perf_counter() - total_start

    best_test = max(history['test_accs'])
    final_test = history['final_test_acc']
    print(f"  ✓ Done — best test acc: {best_test:.2f}%, "
          f"final test acc: {final_test:.2f}%")
    print(f"  Pruning: {prune_time:.1f}s | Training: {train_time:.1f}s | Total: {total_time:.1f}s")

    return {
        'dataset': dataset_name,
        'sparsity': sparsity,
        'history': history,
        'best_test_acc': best_test,
        'final_test_acc': final_test,
        'masks': masks,
        'layer_sparsity': layer_sparsity,
        'pruning_time': prune_time,
        'training_time': train_time,
        'total_time': total_time,
    }

## 5 — Run experiments

Train the dense baseline and each SynFlow sparsity level for **both** CIFAR-10 and CIFAR-100.

In [None]:
# ---------- run every (dataset, sparsity) combination ----------
# Results are keyed as all_results[dataset][tag], where tag is "dense" for
# sparsity 0.0 and e.g. "30%" otherwise.
all_results = {}

for dataset_name in config['datasets']:
    runs = {}
    all_results[dataset_name] = runs
    for sparsity in config['sparsities']:
        tag = "dense" if sparsity == 0.0 else f"{sparsity*100:.0f}%"
        runs[tag] = run_single(dataset_name, sparsity, config)


def _result_record(res):
    """Pick the fields of one run result that are written to the JSON file.

    The pruning masks are left out; the 'overall' entry of the layer-sparsity
    mapping is stored separately under 'overall_sparsity' (0.0 when absent).
    """
    history = res['history']
    per_layer = res.get('layer_sparsity', {})
    return {
        'sparsity': res['sparsity'],
        'best_test_acc': res['best_test_acc'],
        'final_test_acc': res['final_test_acc'],
        'test_accs': history['test_accs'],
        'train_accs': history['train_accs'],
        'train_losses': history['train_losses'],
        'layer_sparsity': {layer: frac for layer, frac in per_layer.items()
                           if layer != 'overall'},
        'overall_sparsity': per_layer.get('overall', 0.0),
        'pruning_time': res['pruning_time'],
        'training_time': res['training_time'],
        'total_time': res['total_time'],
    }


# ---------- save results to disk ----------
os.makedirs(config['results_dir'], exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
save_path = os.path.join(config['results_dir'], f"synflow_resnet20_{timestamp}.json")

serialisable = {
    ds_name: {tag: _result_record(res) for tag, res in ds_results.items()}
    for ds_name, ds_results in all_results.items()
}

with open(save_path, 'w') as f:
    json.dump(serialisable, f, indent=2)

print(f"\nResults saved to {save_path}")

# ---------- per-dataset summary table ----------
table_header = f"{'Sparsity':>12s}  {'Best Acc':>10s}  {'Final Acc':>10s}  {'Prune (s)':>10s}  {'Total (s)':>10s}"
for ds_name in config['datasets']:
    print(f"\n{'='*60}")
    print(f"SUMMARY — {ds_name.upper()}")
    print(f"{'='*60}")
    print(table_header)
    print("-" * 60)
    for tag, res in all_results[ds_name].items():
        accuracy_cols = f"{tag:>12s}  {res['best_test_acc']:>9.2f}%  {res['final_test_acc']:>9.2f}%"
        timing_cols = f"  {res['pruning_time']:>9.1f}  {res['total_time']:>9.1f}"
        print(accuracy_cols + timing_cols)

## 6 — Visualise results

- **Row 1** (CIFAR-10): test-accuracy curves (left) and best accuracy vs. sparsity (right).
- **Row 2** (CIFAR-100): same layout.

In [None]:
# One row of panels per dataset. The grid is sized from the config rather than
# hard-coded to 2 rows, and squeeze=False keeps `axes` two-dimensional even
# when only one dataset is configured, so `axes[row, col]` always works.
# With the default two datasets this is the same 2x2, 15x10-inch figure.
plot_datasets = config['datasets']
n_rows = len(plot_datasets)
fig, axes = plt.subplots(n_rows, 2, figsize=(15, 5 * n_rows), squeeze=False)

# Fixed colour per sparsity tag; unknown tags fall back to gray below.
colors = {
    'dense': '#2E86AB',
    '30%':   '#A23B72',
    '60%':   '#F18F01',
}

for row, dataset_name in enumerate(plot_datasets):
    ds_results = all_results[dataset_name]
    ax_curve = axes[row, 0]
    ax_sparsity = axes[row, 1]

    # ---- Left: test accuracy over epochs ----
    for tag, res in ds_results.items():
        test_accs = res['history']['test_accs']
        epochs = range(1, len(test_accs) + 1)
        ax_curve.plot(epochs, test_accs,
                      label=f"{tag} (best {res['best_test_acc']:.2f}%)",
                      color=colors.get(tag, 'gray'), linewidth=1.5)
    ax_curve.set_xlabel('Epoch', fontsize=12)
    ax_curve.set_ylabel('Test Accuracy (%)', fontsize=12)
    ax_curve.set_title(f'ResNet-20 / {dataset_name.upper()} — Test Accuracy',
                       fontsize=13, fontweight='bold')
    ax_curve.legend(fontsize=10)
    ax_curve.grid(True, alpha=0.3)

    # ---- Right: best accuracy vs sparsity ----
    sp_vals = [res['sparsity'] * 100 for res in ds_results.values()]
    acc_vals = [res['best_test_acc'] for res in ds_results.values()]
    ax_sparsity.plot(sp_vals, acc_vals, 'o-', linewidth=2, markersize=8, color='#2E86AB')
    # Label every point with its accuracy, slightly above the marker.
    for sp, acc in zip(sp_vals, acc_vals):
        ax_sparsity.annotate(f'{acc:.2f}%', (sp, acc), textcoords='offset points',
                             xytext=(0, 10), ha='center', fontsize=9,
                             bbox=dict(boxstyle='round,pad=0.3', fc='lightyellow', alpha=0.8))
    ax_sparsity.set_xlabel('Sparsity (%)', fontsize=12)
    ax_sparsity.set_ylabel('Best Test Accuracy (%)', fontsize=12)
    ax_sparsity.set_title(f'SynFlow: Accuracy vs Sparsity ({dataset_name.upper()})',
                          fontsize=13, fontweight='bold')
    ax_sparsity.grid(True, alpha=0.3)

plt.tight_layout()
# Make sure the output directory exists so this cell can save the figure even
# if the directory was removed or never created in this session.
os.makedirs(config['results_dir'], exist_ok=True)
plt.savefig(os.path.join(config['results_dir'], 'synflow_resnet20.png'),
            dpi=150, bbox_inches='tight')
plt.show()

# Final table
for ds_name in plot_datasets:
    print(f"\n{'='*60}")
    print(f"FINAL RESULTS — {ds_name.upper()}")
    print(f"{'='*60}")
    print(f"{'Sparsity':>12s}  {'Best Acc':>10s}  {'Final Acc':>10s}  {'Prune (s)':>10s}  {'Total (s)':>10s}")
    print("-" * 60)
    for tag, res in all_results[ds_name].items():
        print(f"{tag:>12s}  {res['best_test_acc']:>9.2f}%  {res['final_test_acc']:>9.2f}%"
              f"  {res['pruning_time']:>9.1f}  {res['total_time']:>9.1f}")