LucasBoTang/GradNorm

PyTorch GradNorm

This is a PyTorch implementation of GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks. GradNorm is a gradient normalization algorithm that automatically balances training in deep multitask models by dynamically tuning gradient magnitudes.

The toy example can be found here.

Algorithm

Dependencies

Usage

Parameters

  • net: a multitask network that computes the task losses
  • layer: the layer of the network whose weights GradNorm is applied to
  • alpha: hyperparameter controlling the strength of the restoring force
  • dataloader: training dataloader
  • num_epochs: number of training epochs
  • lr1: learning rate for minimizing the multitask loss
  • lr2: learning rate for the task loss weights
  • log: flag that enables result logging
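To make the role of alpha concrete, here is a minimal sketch of how GradNorm sets a target gradient norm for each task, following the formulation in the GradNorm paper. It is not the code in this repository: the function name `gradnorm_targets` and its arguments are hypothetical, and the gradient norms are passed in as plain numbers rather than computed from a network.

```python
def gradnorm_targets(grad_norms, losses, init_losses, alpha):
    """Return per-task target gradient norms and the GradNorm loss.

    grad_norms[i]  : norm of the gradient of the weighted loss of task i
                     with respect to the weights of the chosen layer
    losses[i]      : current loss of task i
    init_losses[i] : loss of task i at the start of training
    """
    n_tasks = len(losses)
    # Average gradient norm across tasks.
    mean_norm = sum(grad_norms) / n_tasks
    # Loss ratio L_i(t) / L_i(0): a smaller ratio means the task has trained faster.
    loss_ratios = [l / l0 for l, l0 in zip(losses, init_losses)]
    mean_ratio = sum(loss_ratios) / n_tasks
    # Relative inverse training rate r_i(t).
    inv_rates = [r / mean_ratio for r in loss_ratios]
    # Slower tasks (r_i > 1) get a larger target; alpha sets how strongly.
    targets = [mean_norm * r ** alpha for r in inv_rates]
    # L1 distance between actual and target gradient norms.
    grad_loss = sum(abs(g - t) for g, t in zip(grad_norms, targets))
    return targets, grad_loss


# Hypothetical numbers for two tasks; task 0 has trained faster than task 1.
flat_targets, flat_loss = gradnorm_targets(
    grad_norms=[1.0, 3.0], losses=[0.5, 2.0], init_losses=[1.0, 2.0], alpha=0.0)
targets, grad_loss = gradnorm_targets(
    grad_norms=[1.0, 3.0], losses=[0.5, 2.0], init_losses=[1.0, 2.0], alpha=0.12)
```

With alpha=0 every task is pulled toward the same average gradient norm. A larger alpha pushes the targets further apart in favor of tasks that are training more slowly. In the full algorithm, the task loss weights are updated by gradient descent on this loss, with the targets treated as constants.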

Sample Code

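from gradnorm import gradNorm

# net and mtlnet are constructed as in the Toy Example section;
# dataloader is a training dataloader over the dataset.
log_weights, log_loss = gradNorm(net=mtlnet, layer=net.fc4, alpha=0.12, dataloader=dataloader,
                                 num_epochs=100, lr1=1e-5, lr2=1e-4, log=False)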

Toy Example (from Original Paper)

Data

Consider T regression tasks trained using standard squared loss onto the functions:

$$f_i(x) = \sigma_i \tanh\left((B + \epsilon_i)\, x\right)$$

Inputs have dimension 250 and outputs have dimension 100, while $B$ and $\epsilon_i$ are constant matrices with their elements generated IID from $\mathcal{N}(0, 10)$ and $\mathcal{N}(0, 3.5)$, respectively. Each task, therefore, shares information in $B$ but also contains task-specific information $\epsilon_i$. The $\sigma_i$ set the scales of the outputs.

from data import toyDataset
dataset = toyDataset(num_data=10000, dim_features=250, dim_labels=100, scalars=[1,100])
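For readers who want to see the data-generating process spelled out, the following NumPy sketch produces samples from the functions above. It is an illustration, not the repository's `toyDataset`: the function name `make_toy_data` is hypothetical, 10 and 3.5 are assumed to be standard deviations, and the inputs are assumed to be standard normal because the text does not say how they are drawn.

```python
import numpy as np


def make_toy_data(num_data, dim_features, dim_labels, scalars, seed=0):
    """Generate inputs x and targets y for len(scalars) regression tasks."""
    rng = np.random.default_rng(seed)
    # Shared matrix B (assumption: 10 is the standard deviation).
    B = rng.normal(0.0, 10.0, size=(dim_labels, dim_features))
    # One task-specific matrix eps_i per task (assumption: 3.5 is the standard deviation).
    eps = [rng.normal(0.0, 3.5, size=(dim_labels, dim_features)) for _ in scalars]
    # Assumption: inputs are drawn from a standard normal distribution.
    x = rng.normal(size=(num_data, dim_features))
    # y[:, i, :] = sigma_i * tanh((B + eps_i) x) for each task i.
    y = np.stack([s * np.tanh(x @ (B + e).T) for s, e in zip(scalars, eps)], axis=1)
    return x, y


x, y = make_toy_data(num_data=1000, dim_features=250, dim_labels=100, scalars=[1, 100])
print(x.shape, y.shape)
```

Because tanh is bounded by 1 in absolute value, the targets of task i are bounded by its scale. With scalars=[1, 100], the second task's squared loss can therefore be several orders of magnitude larger than the first, which is the imbalance GradNorm is meant to correct.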

Model

A 4-layer fully-connected ReLU-activated network with 100 neurons per layer serves as the common trunk and is trained on the toy example. A final affine transformation layer per task gives the T final predictions.

from model import fcNet, mtlNet
net = fcNet(dim_features=250, dim_labels=100, n_tasks=2) # fc net with multiple heads
mtlnet = mtlNet(net) # multitask net with task loss
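The architecture described above could be sketched in PyTorch as follows. This is not the repository's `fcNet`: the class name `ToyMultitaskNet` and its attribute names are hypothetical, and it assumes that the four layers all belong to the shared trunk, with the per-task affine heads counted separately.

```python
import torch
from torch import nn


class ToyMultitaskNet(nn.Module):  # hypothetical name, not the repository's fcNet
    def __init__(self, dim_features=250, dim_hidden=100, dim_labels=100, n_tasks=2):
        super().__init__()
        # Common trunk: 4 fully connected layers, each followed by ReLU.
        dims = [dim_features] + [dim_hidden] * 4
        layers = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            layers += [nn.Linear(d_in, d_out), nn.ReLU()]
        self.trunk = nn.Sequential(*layers)
        # One affine head per task on top of the shared features.
        self.heads = nn.ModuleList(
            [nn.Linear(dim_hidden, dim_labels) for _ in range(n_tasks)])

    def forward(self, x):
        h = self.trunk(x)
        # Output shape: (batch, n_tasks, dim_labels).
        return torch.stack([head(h) for head in self.heads], dim=1)


toy_net = ToyMultitaskNet()
out = toy_net(torch.randn(8, 250))
print(out.shape)
```

In this sketch the last linear layer of the trunk is shared by all tasks, which makes its weights a natural place to measure the per-task gradient norms that GradNorm balances.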

Result (10 Tasks)
