# Hybrid Pruning — Paper Config (ResNet-18 / CIFAR-10)

Runs the hybrid one-shot + iterative geometric pruning algorithm with the
paper's ResNet-18 / CIFAR-10 configuration:

| Parameter | Value |
|---|---|
| Model | ResNet-18 (CIFAR) |
| Dataset | CIFAR-10 |
| Initial training epochs | 226 |
| Initial LR | 0.1 |
| Target sparsity | 80 % |
| One-shot ratio | 0.7 (→ 0.7 × 80 % = 56 % one-shot prune) |
| Iterative step | 2 % of remaining |
| One-shot FT patience | 200 epochs |
| Iterative FT patience | 10 epochs |

In [None]:
import os
import sys

# Make the project root importable and use it as the working directory.
#
# The notebook is expected to sit one level below the project root, so on a
# fresh kernel the root is the parent of the current directory.  After this
# cell has run once, however, the working directory already *is* the root.
# Only step up when the current directory does not contain the `src` package
# (the package the later cells import); this keeps the cell idempotent, so
# re-running it does not climb one directory higher each time.
_cwd = os.getcwd()
if os.path.isdir(os.path.join(_cwd, "src")):
    PROJECT_ROOT = _cwd
else:
    PROJECT_ROOT = os.path.abspath(os.path.join(_cwd, ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
os.chdir(PROJECT_ROOT)
print(f"Working directory: {os.getcwd()}")

In [None]:
import torch

# Report the PyTorch build and choose the device string used by later cells.
# CUDA availability is queried once and reused for the report and the choice.
cuda_ready = torch.cuda.is_available()
print(f"PyTorch {torch.__version__}")
print(f"CUDA available: {cuda_ready}")
if cuda_ready:
    # Name of the first visible GPU (index 0).
    print(f"GPU: {torch.cuda.get_device_name(0)}")
DEVICE = "cuda" if cuda_ready else "cpu"

## 1. Verify ResNet-18 CIFAR model

In [None]:
from src.model import get_model, count_parameters

# Build the 10-class ResNet-18 and report its parameter count.
model = get_model("resnet18", num_classes=10)
params = count_parameters(model)
print(f"ResNet-18 CIFAR-10 — {params['total']:,} parameters")

# Quick forward-pass sanity check on a random batch of two 3x32x32 inputs.
# Only the output shape matters here, so run under no_grad(): this avoids
# building an autograd graph that `y` would otherwise keep alive.
x = torch.randn(2, 3, 32, 32)
with torch.no_grad():
    y = model(x)
print(f"Input: {x.shape} → Output: {y.shape}")
assert y.shape == (2, 10), "Output shape mismatch!"
print("Model OK ✓")

## 2. Preview pruning schedule

Before the full run, inspect the schedule that `HybridPruningScheduler` will produce.

In [None]:
from src.hybrid import HybridPruningScheduler

# Target sparsity is named once so the closing summary line cannot drift
# from the value actually passed to the scheduler.
target_sparsity = 0.8

scheduler = HybridPruningScheduler(
    target=target_sparsity,
    oneshot_ratio=0.7,
    iterative_step=0.02,
)
steps = scheduler.get_steps()

print(f"Total pruning phases: {len(steps)} (1 one-shot + {len(steps)-1} iterative)")
print(f"\n{'Step':>5}  {'Type':<10}  {'Prune ratio':>12}  {'Cumul. pruned':>14}")
print("-" * 50)

# Accumulate the overall pruned fraction.  This preview treats steps[0] as a
# fraction of all weights and every later entry as a fraction of the weights
# still remaining.  While nothing is pruned yet the remaining fraction is
# exactly 1.0, so a single formula covers both cases.
pruned = 0.0
for i, ratio in enumerate(steps):
    pruned += ratio * (1.0 - pruned)
    label = "one-shot" if i == 0 else f"iter-{i}"
    # Both values are fractions in [0, 1]; the `%` format type scales them by
    # 100 before appending the sign (a bare "%" after `.4f` showed 0.56 as
    # "0.5600%").  Widths 12 and 14 match the header columns.
    print(f"{i:>5}  {label:<10}  {ratio:>12.2%}  {pruned:>14.2%}")

print(f"\nFinal cumulative pruned: {pruned:.4f} (target: {target_sparsity:.2f})")

## 3. Run hybrid pruning experiment

Full run with the paper's configuration. This will:
1. Train ResNet-18 dense for 226 epochs (lr=0.1, cosine schedule)
2. One-shot prune 56% of weights → fine-tune with 200-epoch patience
3. Iteratively prune 2% of remaining → fine-tune with 10-epoch patience  
   …until 80% total sparsity is reached

In [None]:
from src.hybrid import hybrid_pruning

# Paper configuration, grouped by concern and merged into a single call.

# Pruning schedule: 80 % target, one-shot ratio 0.7, 2 % iterative step.
pruning_cfg = dict(
    target_sparsity=0.8,
    oneshot_ratio=0.7,
    iterative_step=0.02,
)

# Initial (dense) training.
dense_cfg = dict(
    initial_epochs=226,
    initial_lr=0.1,
)

# Fine-tuning after the one-shot prune and after each iterative prune.
# NOTE(review): in both phases patience equals the epoch cap — presumably
# this means early stopping cannot trigger before the cap; confirm against
# hybrid_pruning's implementation.
finetune_cfg = dict(
    oneshot_finetune_max_epochs=200,
    oneshot_finetune_patience=200,
    iter_finetune_max_epochs=10,
    iter_finetune_patience=10,
)

# Optimiser / runtime settings shared by every phase.
shared_cfg = dict(
    batch_size=128,
    momentum=0.9,
    weight_decay=5e-4,
    use_global_pruning=True,
    seed=42,
    device=DEVICE,
    verbose=True,
)

results = hybrid_pruning(
    model_name="resnet18",
    dataset_name="cifar10",
    num_classes=10,
    **pruning_cfg,
    **dense_cfg,
    **finetune_cfg,
    **shared_cfg,
)

## 4. Results summary

In [None]:
# Print a fixed-width summary of the headline numbers from the run.
fr = results["final_results"]
rule = "=" * 50
summary_lines = [
    rule,
    "Hybrid Pruning — Final Results",
    rule,
    f"  Initial test accuracy : {fr['initial_test_accuracy']:.2f}%",
    f"  Best phase test acc   : {fr['best_phase_test_accuracy']:.2f}%",
    f"  Final test accuracy   : {fr['final_test_accuracy']:.2f}%",
    # Sparsity is stored as a fraction; `.2%` renders it as a percentage.
    f"  Final sparsity        : {fr['final_sparsity']:.2%}",
    # Wall-clock time in seconds, with the equivalent in hours.
    f"  Total time            : {fr['total_time_seconds']:.1f}s "
    f"({fr['total_time_seconds']/3600:.1f}h)",
    rule,
]
for line in summary_lines:
    print(line)

## 5. Per-phase breakdown

In [None]:
import pandas as pd

# One table row per pruning phase; numeric fields are pre-formatted as strings
# so the rendered table shows fixed precision.
rows = [
    {
        "Phase": phase["label"],
        "Prune ratio": f"{phase['prune_ratio']:.4f}",
        "Sparsity after": f"{phase['sparsity_after_prune']:.2%}",
        "FT epochs": phase["finetune_epochs_run"],
        "Best acc (%)": f"{phase['best_test_acc']:.2f}",
        "Final acc (%)": f"{phase['final_test_acc']:.2f}",
    }
    for phase in results["phases"]
]

df_phases = pd.DataFrame(rows)
# Bare expression: the notebook renders the DataFrame as the cell output.
df_phases

## 6. Training curves

In [None]:
import matplotlib.pyplot as plt

# Two side-by-side panels: dense training on the left, fine-tuning on the right.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
dense_ax, finetune_ax = axes

# --- Left panel: test accuracy over the initial (dense) training run ---
init = results["initial_training"]
dense_ax.plot(init["test_accs"], label="Test acc (initial training)")
dense_ax.set(
    xlabel="Epoch",
    ylabel="Accuracy (%)",
    title="Phase 1: Initial (dense) training",
)
dense_ax.legend()
dense_ax.grid(True, alpha=0.3)

# --- Right panel: one fine-tuning curve per pruning phase ---
for phase in results["phases"]:
    # Legend entry shows the phase label and the sparsity reached after pruning.
    curve_label = f"{phase['label']} ({phase['sparsity_after_prune']:.0%})"
    finetune_ax.plot(phase["test_accs"], label=curve_label)
finetune_ax.set(
    xlabel="Fine-tune epoch",
    ylabel="Accuracy (%)",
    title="Fine-tuning curves per pruning phase",
)
finetune_ax.legend(fontsize=8)
finetune_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 7. Accuracy vs Sparsity plot

In [None]:
import matplotlib.ticker as mtick

# Best test accuracy of each pruning phase against the sparsity it reached.
phase_records = results["phases"]
sparsities = [rec["sparsity_after_prune"] for rec in phase_records]
best_accs = [rec["best_test_acc"] for rec in phase_records]

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(sparsities, best_accs, 'o-', markersize=8, linewidth=2)

# Horizontal reference line at the accuracy of the dense (unpruned) model.
dense_acc = fr["initial_test_accuracy"]
ax.axhline(y=dense_acc, color='r', linestyle='--',
           label=f'Dense baseline ({dense_acc:.2f}%)')

ax.set(
    xlabel="Sparsity",
    ylabel="Best test accuracy (%)",
    title="Hybrid Pruning: Accuracy vs Sparsity — ResNet-18 / CIFAR-10",
)
ax.legend()
ax.grid(True, alpha=0.3)

# Sparsity values are fractions; show the x-axis ticks as percentages.
ax.xaxis.set_major_formatter(mtick.PercentFormatter(xmax=1.0))

plt.tight_layout()
plt.show()