### Lambda Learning Rate

In [1]:
# LambdaLR: every scheduler step sets each group's lr to
#     initial_lr * lambda1(last_epoch).
# With lambda1(epoch) = epoch / 10 the lr ramps linearly up from 0 (a simple
# warm-up), so the very first optimizer.step() below runs with lr = 0.0.
import torch
import torch.nn as nn
from torch.optim import lr_scheduler

lr = 0.1
model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)


def lambda1(epoch):
    """Factor applied to the initial lr (0.1) at the given scheduler epoch."""
    return epoch / 10


scheduler = lr_scheduler.LambdaLR(optimizer, lambda1)

# The scheduler has already applied lambda1(0) = 0 here, so 'lr' shows 0.0.
print(optimizer.state_dict())

for epoch in range(5):
    # A real loop would compute a loss and call loss.backward() first; this
    # demo has no backward pass, so the step only preserves the canonical
    # optimizer.step() -> scheduler.step() ordering. Validation would go here.
    optimizer.step()
    scheduler.step()
    print(optimizer.param_groups[0]["lr"])

{'state': {}, 'param_groups': [{'lr': 0.0, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'initial_lr': 0.1, 'params': [0, 1]}]}
0.010000000000000002
0.020000000000000004
0.03
0.04000000000000001
0.05


### Multiplicative Learning Rate

In [2]:
# MultiplicativeLR: on every scheduler.step() each group's *current* lr is
# multiplied by lambda1(epoch). A constant factor of 0.95 therefore gives an
# exponential decay: after k steps lr = 0.1 * 0.95 ** k.
import torch
import torch.nn as nn
from torch.optim import lr_scheduler

lr = 0.1
model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)


def lambda1(epoch):
    """Constant per-step decay factor; the epoch argument is ignored."""
    return 0.95


scheduler = lr_scheduler.MultiplicativeLR(optimizer, lambda1)

# Nothing has been multiplied yet at construction time: 'lr' is still 0.1.
print(optimizer.state_dict())

for epoch in range(5):
    # A real loop would compute a loss and call loss.backward() first; this
    # demo has no backward pass, so the step only preserves the canonical
    # optimizer.step() -> scheduler.step() ordering. Validation would go here.
    optimizer.step()
    scheduler.step()
    print(optimizer.param_groups[0]["lr"])

{'state': {}, 'param_groups': [{'lr': 0.1, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'initial_lr': 0.1, 'params': [0, 1]}]}
0.095
0.09025
0.0857375
0.08145062499999998
0.07737809374999999


### Step Learning Rate

In [5]:
# StepLR: lr = initial_lr * gamma ** (epoch // step_size).
# With lr = 0.1, step_size = 3 and gamma = 0.1 as configured below:
# lr = 0.1    if epoch < 3
# lr = 0.01   if 3 <= epoch < 6
# lr = 0.001  if 6 <= epoch < 9
import torch
import torch.nn as nn
import torch.optim.lr_scheduler as lr_scheduler

lr = 0.1
model = nn.Linear(10, 1)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Multiply the lr by gamma once every step_size scheduler steps (epochs).
scheduler = lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

print(optimizer.state_dict())

for epoch in range(5):
    # A real loop would compute a loss and call loss.backward() first; this
    # demo has no backward pass. Validation would go between the two steps.
    optimizer.step()
    scheduler.step()
    # After scheduler.step() at the end of `epoch`, the reported lr is the one
    # used in epoch + 1 -- hence the drop to 0.01 shows on the 3rd printed line.
    print(scheduler.get_last_lr()[0])

{'state': {}, 'param_groups': [{'lr': 0.1, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'initial_lr': 0.1, 'params': [0, 1]}]}
0.1
0.1
0.010000000000000002
0.010000000000000002
0.010000000000000002
