Skip to content

Commit

Permalink
Add some docstring and example code
Browse files Browse the repository at this point in the history
  • Loading branch information
Manuscrit committed Feb 2, 2020
1 parent 75e6423 commit 3ebd4fa
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 7 deletions.
38 changes: 38 additions & 0 deletions README.md
@@ -0,0 +1,38 @@
# Selective_Backpropagation
from paper Accelerating Deep Learning by Focusing on the Biggest Losers
https://arxiv.org/abs/1910.00762v1

## Code example:
### Without selective backpropagation:
```
...
criterion = nn.CrossEntropyLoss(reduction='none')
...
for x, y in data_loader:
...
y_pred = model(x)
loss = criterion(y_pred, y).mean()
loss.backward()
...
```
### With selective backpropagation:
```
...
criterion = nn.CrossEntropyLoss(reduction='none')
selective_backprop = SelectiveBackPropagation(
criterion,
lambda loss : loss.mean().backward(),
optimizer,
model,
batch_size,
epoch_length=len(data_loader),
loss_selection_threshold=False)
...
for x, y in data_loader:
...
with torch.no_grad():
y_pred = model(x)
not_reduced_loss = criterion(y_pred, y)
selective_backprop.selective_back_propagation(not_reduced_loss, x, y)
...
```
59 changes: 52 additions & 7 deletions selective_back_propagation.py
Expand Up @@ -10,17 +10,62 @@ class SelectiveBackPropagation:
Selective_Backpropagation from paper Accelerating Deep Learning by Focusing on the Biggest Losers
https://arxiv.org/abs/1910.00762v1
Without:
y_pred = model(x)
loss = criterion(y_pred, y).mean()
loss.backward()
With:
with torch.no_grad():
...
criterion = nn.CrossEntropyLoss(reduction='none')
...
for x, y in data_loader:
...
y_pred = model(x)
not_reduced_loss = criterion(y_pred, y)
loss = selective_bp.selective_back_propagation(not_reduced_loss, x, y)
loss = criterion(y_pred, y).mean()
loss.backward()
...
With:
...
criterion = nn.CrossEntropyLoss(reduction='none')
selective_backprop = SelectiveBackPropagation(
criterion,
lambda loss : loss.mean().backward(),
optimizer,
model,
batch_size,
epoch_length=len(data_loader),
loss_selection_threshold=False)
...
for x, y in data_loader:
...
with torch.no_grad():
y_pred = model(x)
not_reduced_loss = criterion(y_pred, y)
selective_backprop.selective_back_propagation(not_reduced_loss, x, y)
...
"""
def __init__(self, compute_losses_func, update_weights_func, optimizer, model,
batch_size, epoch_length, loss_selection_threshold=False):
"""
Usage:
```
criterion = nn.CrossEntropyLoss(reduction='none')
selective_backprop = SelectiveBackPropagation(
criterion,
lambda loss : loss.mean().backward(),
optimizer,
model,
batch_size,
epoch_length=len(data_loader),
loss_selection_threshold=False)
```
:param compute_losses_func: the loss function which output a tensor of dim [batch_size] (no reduction to apply).
Example: `compute_losses_func = nn.CrossEntropyLoss(reduction='none')`
:param update_weights_func: the reduction of the loss and backpropagation. Example: `update_weights_func =
lambda loss : loss.mean().backward()`
:param optimizer: your optimizer object
:param model: your model object
:param batch_size: number of images per batch
:param epoch_length: the number of batch per epoch
:param loss_selection_threshold: default to False. Set to a float value to select all images with with loss
higher than loss_selection_threshold. Do not change behavior for loss below loss_selection_threshold.
"""

self.loss_selection_threshold = loss_selection_threshold
self.compute_losses_func = compute_losses_func
Expand Down

0 comments on commit 3ebd4fa

Please sign in to comment.