diff --git a/README.md b/README.md
new file mode 100644
index 0000000..5e8ab76
--- /dev/null
+++ b/README.md
@@ -0,0 +1,38 @@
+# Selective_Backpropagation
+From the paper Accelerating Deep Learning by Focusing on the Biggest Losers:
+https://arxiv.org/abs/1910.00762v1
+
+## Code example
+### Without selective backpropagation:
+```
+...
+criterion = nn.CrossEntropyLoss(reduction='none')
+...
+for x, y in data_loader:
+    ...
+    y_pred = model(x)
+    loss = criterion(y_pred, y).mean()
+    loss.backward()
+    ...
+```
+### With selective backpropagation:
+```
+...
+criterion = nn.CrossEntropyLoss(reduction='none')
+selective_backprop = SelectiveBackPropagation(
+    criterion,
+    lambda loss: loss.mean().backward(),
+    optimizer,
+    model,
+    batch_size,
+    epoch_length=len(data_loader),
+    loss_selection_threshold=False)
+...
+for x, y in data_loader:
+    ...
+    with torch.no_grad():
+        y_pred = model(x)
+    not_reduced_loss = criterion(y_pred, y)
+    selective_backprop.selective_back_propagation(not_reduced_loss, x, y)
+    ...
+```
\ No newline at end of file
diff --git a/selective_back_propagation.py b/selective_back_propagation.py
index e50eb8e..055988d 100644
--- a/selective_back_propagation.py
+++ b/selective_back_propagation.py
@@ -10,17 +10,62 @@ class SelectiveBackPropagation:
     Selective_Backpropagation from paper Accelerating Deep Learning by Focusing on the Biggest Losers
     https://arxiv.org/abs/1910.00762v1
     Without:
-        y_pred = model(x)
-        loss = criterion(y_pred, y).mean()
-        loss.backward()
-    With:
-        with torch.no_grad():
+        ...
+        criterion = nn.CrossEntropyLoss(reduction='none')
+        ...
+        for x, y in data_loader:
+            ...
             y_pred = model(x)
-        not_reduced_loss = criterion(y_pred, y)
-        loss = selective_bp.selective_back_propagation(not_reduced_loss, x, y)
+            loss = criterion(y_pred, y).mean()
+            loss.backward()
+            ...
+    With:
+        ...
+        criterion = nn.CrossEntropyLoss(reduction='none')
+        selective_backprop = SelectiveBackPropagation(
+            criterion,
+            lambda loss: loss.mean().backward(),
+            optimizer,
+            model,
+            batch_size,
+            epoch_length=len(data_loader),
+            loss_selection_threshold=False)
+        ...
+        for x, y in data_loader:
+            ...
+            with torch.no_grad():
+                y_pred = model(x)
+            not_reduced_loss = criterion(y_pred, y)
+            selective_backprop.selective_back_propagation(not_reduced_loss, x, y)
+            ...
     """
     def __init__(self, compute_losses_func, update_weights_func, optimizer, model, batch_size, epoch_length,
                  loss_selection_threshold=False):
+        """
+        Usage:
+        ```
+        criterion = nn.CrossEntropyLoss(reduction='none')
+        selective_backprop = SelectiveBackPropagation(
+            criterion,
+            lambda loss: loss.mean().backward(),
+            optimizer,
+            model,
+            batch_size,
+            epoch_length=len(data_loader),
+            loss_selection_threshold=False)
+        ```
+
+        :param compute_losses_func: loss function that outputs a tensor of shape [batch_size] (no reduction applied).
+            Example: `compute_losses_func = nn.CrossEntropyLoss(reduction='none')`
+        :param update_weights_func: reduces the loss and runs the backward pass. Example: `update_weights_func =
+            lambda loss: loss.mean().backward()`
+        :param optimizer: your optimizer object
+        :param model: your model object
+        :param batch_size: number of images per batch
+        :param epoch_length: number of batches per epoch
+        :param loss_selection_threshold: defaults to False. Set to a float value to always select images whose loss
+            is higher than loss_selection_threshold; losses below the threshold keep the default selection behavior.
+        """
         self.loss_selection_threshold = loss_selection_threshold
         self.compute_losses_func = compute_losses_func
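
Beyond the diff above, here is a self-contained sketch of a full training loop assembled from the README's usage example. The toy model and random dataset are hypothetical stand-ins, and it assumes (as the constructor arguments suggest) that `selective_back_propagation` re-runs the forward and backward passes on the images it selects and steps the optimizer internally:

```
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from selective_back_propagation import SelectiveBackPropagation

# Hypothetical toy model and data, only to make the sketch runnable.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
x_data = torch.randn(512, 1, 28, 28)
y_data = torch.randint(0, 10, (512,))
batch_size = 32
data_loader = DataLoader(TensorDataset(x_data, y_data), batch_size=batch_size)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# reduction='none' keeps one loss value per image, which is what the
# selection mechanism needs in order to rank images by loss.
criterion = nn.CrossEntropyLoss(reduction='none')

selective_backprop = SelectiveBackPropagation(
    criterion,
    lambda loss: loss.mean().backward(),
    optimizer,
    model,
    batch_size,
    epoch_length=len(data_loader),
    loss_selection_threshold=False)

for epoch in range(2):
    for x, y in data_loader:
        # Gradient-free forward pass: it only measures per-image losses;
        # the backward pass happens inside selective_back_propagation on
        # the high-loss images it selects.
        with torch.no_grad():
            y_pred = model(x)
        not_reduced_loss = criterion(y_pred, y)
        selective_backprop.selective_back_propagation(not_reduced_loss, x, y)
```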
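
For intuition about what the selection does, the paper's rule can be sketched as follows. This is a simplified illustration of the mechanism the paper describes, not this repository's exact implementation; the function name and the `beta` parameter are hypothetical here:

```
import numpy as np

def selection_probability(loss_value, recent_losses, beta=2.0):
    # Probability of keeping an image for the backward pass: the CDF of
    # its loss within a window of recently seen losses, raised to the
    # power beta. High-loss images ("the biggest losers") are almost
    # always selected; low-loss images are usually skipped.
    percentile = float(np.mean(np.asarray(recent_losses) <= loss_value))
    return percentile ** beta
```

With `beta = 1` the probability equals the loss percentile itself; larger values of `beta` concentrate the backward passes more aggressively on the highest-loss images.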