Pruning
Gaurav14cs17 edited this page Jun 21, 2026 · 1 revision
Pruning removes redundant parameters from a model to reduce its size and computational cost.
Unstructured pruning removes individual weights, ranked by an importance score such as their magnitude.
```python
from flashoptim import FlashOptim, UnstructuredPruner

model = FlashOptim("pretrained/model.pth")

pruner = UnstructuredPruner(
    sparsity=0.5,        # remove 50% of the weights
    method="magnitude",  # importance criterion; see the table below
    iterative=True,      # reach the target sparsity over several steps
    iterations=3,
)
pruned = pruner.prune(model)
```

| Method | Description |
|---|---|
| `magnitude` | Remove the weights with the smallest absolute values |
| `random` | Remove weights at random (baseline) |
| `gradient` | Remove weights ranked by gradient magnitude |
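The `magnitude` criterion and the `iterative` option can be illustrated with plain NumPy. This is a minimal sketch and not FlashOptim's implementation. `magnitude_prune` and `iterative_schedule` are hypothetical helpers, and spacing the iterations as a linear ramp of sparsity targets is an assumption.

```python
import numpy as np


def magnitude_prune(weights: np.ndarray, sparsity: float) -> np.ndarray:
    """Hypothetical helper: 0/1 mask that removes exactly round(sparsity * size)
    weights, choosing those with the smallest absolute value."""
    k = int(round(sparsity * weights.size))  # number of weights to remove
    order = np.argsort(np.abs(weights).ravel(), kind="stable")  # smallest first
    mask = np.ones(weights.size, dtype=weights.dtype)
    mask[order[:k]] = 0
    return mask.reshape(weights.shape)


def iterative_schedule(final_sparsity: float, iterations: int) -> list[float]:
    """Hypothetical helper: assumed linear ramp of sparsity targets,
    ending exactly at final_sparsity."""
    return [final_sparsity * (i + 1) / iterations for i in range(iterations)]


rng = np.random.default_rng(0)
w = rng.normal(size=(4, 8))  # toy weight matrix with 32 entries
for target in iterative_schedule(0.5, 3):
    # Already-zeroed weights have the smallest magnitude, so they stay pruned.
    w = w * magnitude_prune(w, target)
    # (a real pipeline would fine-tune between iterations)
print(f"sparsity: {np.mean(w == 0):.2f}")  # sparsity: 0.50
```

Sorting and removing exactly `k` entries, rather than thresholding, keeps the achieved sparsity exact even when several weights share the same magnitude.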
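Structured pruning removes entire channels or filters. The remaining tensors stay dense, so the pruned model runs faster on standard hardware without any support for sparse computation.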
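Removing whole channels produces real speedups because the remaining weights form a smaller dense tensor. The NumPy sketch below shows L1-norm filter pruning on a convolution weight of shape `(out_channels, in_channels, kH, kW)`. `prune_filters_l1` is a hypothetical helper and not FlashOptim's `StructuredPruner`. It only suggests how an `l1_norm` criterion at `channel` granularity could work, and the layer sizes are made up.

```python
import numpy as np


def prune_filters_l1(conv_w: np.ndarray, sparsity: float):
    """Hypothetical helper: drop the `sparsity` fraction of output filters with
    the smallest L1 norm. Returns the smaller weight tensor and the kept indices."""
    n_out = conv_w.shape[0]
    n_drop = int(round(sparsity * n_out))
    l1 = np.abs(conv_w).reshape(n_out, -1).sum(axis=1)  # one score per filter
    keep = np.sort(np.argsort(l1, kind="stable")[n_drop:])  # keep original order
    return conv_w[keep], keep


rng = np.random.default_rng(0)
w1 = rng.normal(size=(10, 3, 3, 3))  # conv layer with 10 output filters
w2 = rng.normal(size=(4, 10, 3, 3))  # next layer consumes those 10 channels
w1_pruned, keep = prune_filters_l1(w1, 0.3)
w2_pruned = w2[:, keep]  # drop the matching input channels downstream
print(w1_pruned.shape, w2_pruned.shape)  # (7, 3, 3, 3) (4, 7, 3, 3)
```

The second slice matters. Removing an output channel from one layer also removes the matching input channel of the next layer, so both layers shrink.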
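```python
from flashoptim import FlashOptim, StructuredPruner

model = FlashOptim("pretrained/model.pth")

pruner = StructuredPruner(
    sparsity=0.3,           # remove 30% of the channels
    criterion="l1_norm",    # rank channels by the L1 norm of their weights
    granularity="channel",  # prune whole channels rather than single weights
)
pruned = pruner.prune(model)
```

Lottery ticket pruning finds sparse subnetworks ("winning tickets") that, when trained from their initialization, reach the accuracy of the full model.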
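The lottery ticket search can be sketched as iterative magnitude pruning with weight rewinding. The NumPy code below is a toy illustration under stated assumptions and is not FlashOptim's `LotteryTicketPruner`. The model is a linear regression trained by full-batch gradient descent. `rewind_step` counts gradient steps rather than epochs. The geometric per-round schedule and the helpers `train` and `find_ticket` are hypothetical; their arguments only mirror `target_sparsity`, `rounds`, and `rewind_epoch`.

```python
import numpy as np


def train(w, mask, X, y, steps, lr=0.1, snapshot_at=None):
    """Masked full-batch gradient descent on mean squared error.
    Returns the final weights and a copy of the weights after `snapshot_at` steps."""
    snapshot = None
    for step in range(steps):
        if step == snapshot_at:
            snapshot = w.copy()
        grad = X.T @ (X @ w - y) / len(y)
        w = (w - lr * grad) * mask  # pruned weights stay exactly zero
    return w, snapshot


def find_ticket(X, y, target_sparsity=0.8, rounds=5, rewind_step=2, steps=300, seed=0):
    """Hypothetical iterative magnitude pruning with rewinding on a linear model."""
    rng = np.random.default_rng(seed)
    w0 = rng.normal(scale=0.1, size=X.shape[1])
    mask = np.ones_like(w0)
    # Dense run; remember the weights at the rewind point.
    w, w_rewind = train(w0, mask, X, y, steps, snapshot_at=rewind_step)
    for r in range(1, rounds + 1):
        # Geometric schedule: after round r, keep (1 - target)^(r / rounds) of the weights.
        n_keep = int(round(w0.size * (1 - target_sparsity) ** (r / rounds)))
        order = np.argsort(-np.abs(w * mask), kind="stable")  # largest survivors first
        mask = np.zeros_like(mask)
        mask[order[:n_keep]] = 1.0
        # Rewind the surviving weights and retrain under the new mask.
        w, _ = train(w_rewind * mask, mask, X, y, steps)
    return w, mask


rng = np.random.default_rng(1)
X = rng.normal(size=(200, 20))
w_true = np.zeros(20)
w_true[:4] = [3.0, -2.0, 1.5, 1.0]  # only the first 4 features matter
y = X @ w_true
w, mask = find_ticket(X, y)
print(mask.mean())           # 0.2
print(np.flatnonzero(mask))  # [0 1 2 3]
```

The toy targets depend on only four features. The search therefore ends with a mask that keeps exactly those weights, at the requested 80% sparsity.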
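```python
from flashoptim.pruning import LotteryTicketPruner

# `model` is a FlashOptim model, loaded as in the examples above
pruner = LotteryTicketPruner(
    target_sparsity=0.8,  # final fraction of weights removed
    rounds=5,             # number of prune-and-retrain rounds
    rewind_epoch=2,       # epoch whose weights are restored after each round
)
ticket = pruner.find_ticket(model, train_data="data/train/")
```

Pruning usually costs some accuracy, so fine-tune the pruned model afterwards to recover it:

```python
from flashoptim import UnstructuredPruner, Trainer

# `model` is a FlashOptim model, loaded as in the examples above
pruner = UnstructuredPruner(sparsity=0.5)
pruned = pruner.prune(model)

# Fine-tune the pruned model to recover accuracy
trainer = Trainer(epochs=10, lr=0.001)
finetuned = trainer.train(pruned, data="data/train/")
```

To run pruning from the command line with a YAML config:

```bash
flashoptim prune --config configs/flashoptim_prune_unstructured.yaml
```

FlashOptim — Model optimization toolkit | PyPI | MIT License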