In [None]:
#| default_exp prune.prune_callback

In [None]:
#| include: false
from nbdev.showdoc import *

In [None]:
#| export
from fastai.vision.all import *
from fastai.callback.all import *
from fasterai.prune.pruner import *
from fasterai.core.criteria import *
from fasterai.core.schedule import *

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
#| export
class PruneCallback(Callback):
    """Prune a model during `fit`, driven by a sparsity `schedule`.

    On every batch the schedule is called with the target sparsity and the
    training progress (rounded to 3 decimals). When the schedule reports that
    pruning is due (`schedule.pruned`) and the learner is training,
    `Pruner.prune_model` is invoked with the first scheduled sparsity value.

    Args:
        sparsity: target sparsity in percent. Passed through `listify`, so it
            is always stored as a list even when a single value is given.
        context: forwarded to `Pruner` as its second argument.
        criteria: forwarded to `Pruner` as its third argument.
        schedule: callable `(sparsity, pct_train) -> sequence of sparsities`.
            It must also expose `pruned`, `after_pruned()`, `current_sparsity`
            and `reset()`.
        model: module to prune. When `None`, `self.learn.model` is used.
        round_to: forwarded to `Pruner.prune_model`.
        layer_type: stored on the callback. NOTE(review): not referenced by
            any method in this class, and not passed to `Pruner`. Confirm
            whether it should be.
    """
    def __init__(self, sparsity:int, context:str, criteria:Callable, schedule:Callable, model:nn.Module=None, round_to:int=None, layer_type:nn.Module=nn.Conv2d):
        store_attr()
        self.sparsity = listify(self.sparsity)

    def before_fit(self):
        "Build the `Pruner` for the model to prune before training starts."
        print(f'Pruning until a sparsity of {self.sparsity}%')
        # Compare against None instead of relying on truthiness. Container
        # modules such as `nn.Sequential` define `__len__`, so an empty one
        # would evaluate falsy and silently fall back to `learn.model`.
        model = self.model if self.model is not None else self.learn.model
        self.pruner = Pruner(model, self.context, self.criteria)

    def before_batch(self):
        "Query the schedule and prune when it signals a step is due (training only)."
        # The schedule is evaluated on every batch, validation included.
        # `after_epoch` reports whatever value was produced last.
        self.current_sparsity = self.schedule(self.sparsity, round(self.pct_train, 3))
        if self.schedule.pruned and self.training:
            self.pruner.prune_model(self.current_sparsity[0], self.round_to)

    def after_step(self):
        "Let the schedule update its internal state after each optimizer step."
        self.schedule.after_pruned()

    def after_epoch(self):
        "Print the sparsities last returned by the schedule, rounded to 2 decimals."
        sparsity_str = [float("%0.2f" % sp) for sp in self.current_sparsity]
        print(f'Sparsity at the end of epoch {self.epoch}: {sparsity_str}%')

    def after_fit(self):
        "Report the schedule's final sparsity and reset it for a future `fit`."
        print(f'Final Sparsity: {self.schedule.current_sparsity}%')
        self.schedule.reset()

In [None]:
# Render the signature and docstring of `PruneCallback` into the generated docs.
show_doc(PruneCallback)

---

### PruneCallback

>      PruneCallback (sparsity:int, context:str, criteria:Callable,
>                     schedule:Callable,
>                     model:torch.nn.modules.module.Module=None,
>                     round_to:int=None,
>                     layer_type:torch.nn.modules.module.Module=<class
>                     'torch.nn.modules.conv.Conv2d'>)

Basic class handling tweaks of the training loop by changing a `Learner` in various events