DeepRewire

DeepRewire is a PyTorch-based project that simplifies the creation and optimization of sparse neural networks using the concepts from the Deep Rewiring paper by Bellec et al. ⚠️ Note: This implementation was not written by any of the paper's authors. Please double-check everything before use.

Overview

DeepRewire provides tools to convert standard neural network parameters into a form that can be optimized with the DEEPR and SoftDEEPR optimizers, allowing the network to become sparse during training.

Installation

Install using pip install deep_rewire

Features

  • Conversion Functions: Convert networks to and from rewireable forms.
  • Optimizers: Use DEEPR and SoftDEEPR to optimize sparse networks.
  • Examples: Run provided examples to see the conversion and optimization in action.

Example Usage:

import torch
from deep_rewire import convert, reconvert, SoftDEEPR

# Define your model
model = torch.nn.Sequential(
    torch.nn.Linear(784, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10)
)

# Convert model parameters to rewireable form
rewireable_params, other_params = convert(model)

# Define optimizers
optim1 = SoftDEEPR(rewireable_params, lr=0.05, l1=1e-5)
optim2 = torch.optim.SGD(other_params, lr=0.05)  # optional, for the non-rewireable parameters

# ... Standard training loop ...

# Convert back to standard form
reconvert(model)
# Model has the same parameters but is now (hopefully) sparse.
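The training loop elided above is just a standard PyTorch loop in which both optimizers are stepped. A minimal sketch, assuming a dataloader yielding flattened 784-dimensional inputs and integer class labels (the dataloader itself is hypothetical and not part of this library):

criterion = torch.nn.CrossEntropyLoss()
for inputs, targets in dataloader:  # hypothetical dataloader
    optim1.zero_grad()
    optim2.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optim1.step()  # SoftDEEPR step for the rewireable parameters
    optim2.step()  # SGD step for the remaining parameters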

See examples/softdeepr.py for a demonstration of SoftDEEPR's performance.

Functionality

Conversion Functions

convert

deep_rewire.convert(module: nn.Module, handle_biases: str = "second_bias",
                    active_probability: float = None, keep_signs: bool = False)

Converts a PyTorch module into a rewireable form.

  • Parameters:
    • module (nn.Module): The model to convert.
    • handle_biases (str): Strategy to handle biases. Options are 'ignore', 'as_connections', and 'second_bias'.
    • active_probability (float): Probability that each connection is active immediately after conversion.
    • keep_signs (bool): Keep the signs of the initial network and start with all connections active (useful for pretrained networks).
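For example, a pretrained network can be converted while preserving its learned weights. A minimal sketch, reusing the model from the usage example above (the keyword values simply echo the options documented here):

# Start with all connections active and keep the pretrained signs
rewireable_params, other_params = convert(
    model, handle_biases="second_bias", keep_signs=True
)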

reconvert

deep_rewire.reconvert(module: nn.Module)

Converts a rewireable module back into its original form, making its sparsity visible.

  • Parameters:
    • module (nn.Module): The model to convert.
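Once reconverted, the sparsity can be inspected directly. A sketch, assuming inactive connections come back as exact zeros in the standard parameters:

reconvert(model)

# Fraction of exactly-zero parameter entries (assumed to encode inactive connections)
total = sum(p.numel() for p in model.parameters())
zeros = sum((p == 0).sum().item() for p in model.parameters())
print(f"sparsity: {zeros / total:.2%}")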

Optimizers

DEEPR

deep_rewire.DEEPR(params, nc, lr, l1, reset_val, temp)

The DEEPR algorithm maintains a fixed number of active connections: whenever a connection becomes inactive, a new connection is activated at random so the overall connectivity stays constant.

  • nc (int, required): Fixed number of active connections.
  • lr (float): Learning rate.
  • l1 (float): L1 regularization term.
  • reset_val (float): Value for newly activated parameters.
  • temp (float): Temperature affecting noise magnitude.
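Since nc is an absolute connection count, it is convenient to derive it from a target sparsity. A sketch, reusing rewireable_params from the usage example above; the hyperparameter values are illustrative assumptions, not library defaults:

from deep_rewire import DEEPR

# Keep ~10% of the rewireable connections active (i.e. ~90% sparsity)
n_total = sum(p.numel() for p in rewireable_params)
nc = int(0.1 * n_total)

optim = DEEPR(rewireable_params, nc=nc, lr=0.05, l1=1e-5, reset_val=1e-3, temp=1e-4)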

SoftDEEPR

deep_rewire.SoftDEEPR(params, lr=0.05, l1=1e-5, temp=None, min_weight=None)

The SoftDEEPR algorithm does not fix the number of active connections; it additionally adds noise to inactive connections so that they can become active again at random.

  • lr (float): Learning rate.

  • l1 (float): L1 regularization term.

  • temp (float): Temperature affecting noise magnitude.

  • min_weight (float): Minimum value for inactive parameters.
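SoftDEEPR is already used in the usage example above; passing the optional noise parameters explicitly looks like this (the values are illustrative assumptions, not library defaults):

optim = SoftDEEPR(rewireable_params, lr=0.05, l1=1e-5, temp=1e-4, min_weight=-0.01)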

SoftDEEPRWrapper

deep_rewire.SoftDEEPRWrapper(params, base_optim, l1=1e-5, temp=None, min_weight=None, **optim_kwargs)

Applies the SoftDEEPR algorithm to keep the connections sparse, but performs the parameter updates with any chosen torch optimizer (SGD, Adam, ...).

  • base_optim (torch.optim.Optimizer): The base optimizer used to update the parameters.

  • l1 (float): L1 regularization term.

  • temp (float): Temperature affecting noise magnitude.

  • min_weight (float): Minimum value for inactive parameters.

  • **optim_kwargs: Arguments passed to the base optimizer.
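For example, to combine SoftDEEPR's sparsification with Adam updates (a sketch; the Adam learning rate is an illustrative value forwarded through **optim_kwargs):

from deep_rewire import SoftDEEPRWrapper

# Adam computes the parameter updates; SoftDEEPR handles connection (de)activation
optim = SoftDEEPRWrapper(rewireable_params, torch.optim.Adam, l1=1e-5, lr=1e-3)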

Contributing

Contributions are welcome! Please open an issue or submit a pull request for any improvements, and feel free to fix my mistakes :).

License

This project is licensed under the MIT License.

Acknowledgements

  • Guillaume Bellec, David Kappel, Wolfgang Maass, and Robert Legenstein for their paper on Deep Rewiring.

For more details, refer to the Deep Rewiring paper (https://igi-web.tugraz.at/PDF/241.pdf) or the authors' TensorFlow tutorial.
