DeepRewire is a PyTorch-based project that simplifies creating and optimizing sparse neural networks using concepts from the Deep Rewiring paper by Bellec et al.
DeepRewire provides tools to convert standard neural network parameters into a form that can be optimized with the DEEPR and SoftDEEPR optimizers, allowing the network to become sparse during training.
Install using `pip install deep_rewire`.
- Conversion Functions: Convert networks to and from rewireable forms.
- Optimizers: Use DEEPR and SoftDEEPR to optimize sparse networks.
- Examples: Run provided examples to see the conversion and optimization in action.
```python
import torch
from deep_rewire import convert, reconvert, SoftDEEPR

# Define your model
model = torch.nn.Sequential(
    torch.nn.Linear(784, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10)
)

# Convert model parameters to rewireable form
rewireable_params, other_params = convert(model)

# Define optimizers
optim1 = SoftDEEPR(rewireable_params, lr=0.05, l1=1e-5)
optim2 = torch.optim.SGD(other_params, lr=0.05)  # optional, for parameters that are not rewireable

# ... standard training loop ...

# Convert back to standard form
reconvert(model)
# The model has the same parameters but is now (hopefully) sparse.
```
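The elided training loop simply steps both optimizers on the same backward pass. Here is a minimal sketch, assuming MNIST-shaped inputs; `dataloader` is a hypothetical DataLoader you supply:

```python
# Minimal training-loop sketch. `dataloader` is a hypothetical DataLoader
# yielding (image, label) batches; adapt it to your data pipeline.
criterion = torch.nn.CrossEntropyLoss()
for epoch in range(3):
    for x, y in dataloader:
        optim1.zero_grad()
        optim2.zero_grad()
        loss = criterion(model(x.view(-1, 784)), y)
        loss.backward()
        optim1.step()  # SoftDEEPR update on the rewireable parameters
        optim2.step()  # SGD update on the remaining parameters
```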
`deep_rewire.convert(module: nn.Module, handle_biases: str = "second_bias", active_probability: float = None, keep_signs: bool = False)`
Converts a PyTorch module into a rewireable form.
- Parameters:
  - `module` (nn.Module): The model to convert.
  - `handle_biases` (str): Strategy to handle biases. Options are `'ignore'`, `'as_connections'`, and `'second_bias'`.
  - `active_probability` (float): Probability for connections to be active right after conversion.
  - `keep_signs` (bool): Keep the initial network's signs and start with all connections active (for pretrained networks).
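For example, a pretrained network can be made rewireable without discarding what it has learned by keeping its signs (a small sketch; `pretrained_model` stands in for your own nn.Module):

```python
from deep_rewire import convert

# `pretrained_model` is a placeholder for any already-trained nn.Module.
# keep_signs=True preserves the learned signs and starts with all
# connections active, so optimization begins from the pretrained solution.
rewireable_params, other_params = convert(pretrained_model, keep_signs=True)
```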
`deep_rewire.reconvert(module: nn.Module)`
Converts a rewireable module back into its original form, making its sparsity visible.
- Parameters:
  - `module` (nn.Module): The model to convert back.
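A quick way to inspect the resulting sparsity is to count zero-valued weights after reconversion (a sketch, assuming inactive connections reconvert to exact zeros):

```python
# Count zero entries across all parameters to estimate sparsity.
reconvert(model)
total = sum(p.numel() for p in model.parameters())
zeros = sum((p == 0).sum().item() for p in model.parameters())
print(f"sparsity: {zeros / total:.1%}")
```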
`deep_rewire.DEEPR(params, nc, lr, l1, reset_val, temp)`
The DEEPR algorithm keeps a fixed number of active connections: whenever a connection becomes inactive, another connection is activated at random so the overall connectivity stays constant.
- `nc` (int, required): Fixed number of active connections.
- `lr` (float): Learning rate.
- `l1` (float): L1 regularization term.
- `reset_val` (float): Value for newly activated parameters.
- `temp` (float): Temperature affecting noise magnitude.
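Since `nc` is an absolute connection count, a natural choice is to derive it from the model size. A sketch with illustrative hyperparameter values (not documented defaults):

```python
from deep_rewire import DEEPR

# Keep roughly 10% of the rewireable parameters active. All numeric values
# below are illustrative choices, not the library's defaults.
n_params = sum(p.numel() for p in rewireable_params)
optim = DEEPR(rewireable_params, nc=int(0.1 * n_params),
              lr=0.05, l1=1e-5, reset_val=0.0, temp=1e-3)
```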
`deep_rewire.SoftDEEPR(params, lr=0.05, l1=1e-5, temp=None, min_weight=None)`
The SoftDEEPR algorithm has no fixed number of connections; instead, it also adds noise to inactive connections so that they can reactivate at random.
- `lr` (float): Learning rate.
- `l1` (float): L1 regularization term.
- `temp` (float): Temperature affecting noise magnitude.
- `min_weight` (float): Minimum value for inactive parameters.
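As in the quick-start example, `temp` and `min_weight` may be left as `None` or set explicitly (the values below are illustrative, not library defaults):

```python
from deep_rewire import SoftDEEPR

# Explicit values for temp and min_weight are illustrative; when left as
# None, the library chooses its own defaults.
optim = SoftDEEPR(rewireable_params, lr=0.05, l1=1e-5,
                  temp=1e-3, min_weight=-1e-2)
```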
`deep_rewire.SoftDEEPRWrapper(params, base_optim, l1=1e-5, temp=None, min_weight=None, **optim_kwargs)`
Uses the SoftDEEPR algorithm to keep the connections sparse, but delegates the parameter updates to any chosen torch optimizer (SGD, Adam, ...).
- `base_optim` (torch.optim.Optimizer): The base optimizer used to update the parameters.
- `l1` (float): L1 regularization term.
- `temp` (float): Temperature affecting noise magnitude.
- `min_weight` (float): Minimum value for inactive parameters.
- `**optim_kwargs`: Arguments for the base optimizer.
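For example, the rewiring can be combined with Adam by passing the optimizer class along with its keyword arguments (a sketch; the Adam settings are illustrative):

```python
import torch
from deep_rewire import SoftDEEPRWrapper

# SoftDEEPR handles sparsity and rewiring; Adam performs the parameter
# updates. lr and betas are forwarded to Adam via **optim_kwargs.
optim = SoftDEEPRWrapper(rewireable_params, torch.optim.Adam,
                         l1=1e-5, lr=1e-3, betas=(0.9, 0.999))
```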
Contributions are welcome! Please open an issue or submit a pull request for any improvements, and feel free to fix my mistakes :).
This project is licensed under the MIT License.
- Guillaume Bellec, David Kappel, Wolfgang Maass, and Robert Legenstein for their paper on Deep Rewiring.
For more details, refer to their Deep Rewiring paper or their TensorFlow tutorial.