In [2]:
import torch
import torch.nn as nn
import torch.optim.lr_scheduler as lr_scheduler

In [6]:
# Lambda LR Sched:
lr = 0.1
model = nn.Linear(10,1)

optimizer = torch.optim.Adam(model.parameters(),lr=lr)
# lambda function dependent of the epoch. Lr is multiplied with output of lambda function in every epoch
lambda_1 = lambda epoch: epoch/10
scheduler = lr_scheduler.LambdaLR(optimizer,lambda_1)

print(optimizer.state_dict())

for epoch in range(5):
    print(f"Epoch: {epoch}")
    # loss.backward()
    optimizer.step()
    # validate(...)
    scheduler.step()
    print(optimizer.state_dict()['param_groups'][0]['lr'])

{'state': {}, 'param_groups': [{'lr': 0.0, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'initial_lr': 0.1, 'params': [0, 1]}]}
Epoch: 0
0.010000000000000002
Epoch: 1
0.020000000000000004
Epoch: 2
0.03
Epoch: 3
0.04000000000000001
Epoch: 4
0.05


In [7]:
# Multiplicative LR Sched:
lr = 0.1
model = nn.Linear(10,1)

optimizer = torch.optim.Adam(model.parameters(),lr=lr)
# Lr is multiplied with this constant and the output is new lr
# For next epoch the output is multiplied again, not initial lr etc...
lambda_1 = lambda epoch: 0.95
scheduler = lr_scheduler.MultiplicativeLR(optimizer,lambda_1)

print(optimizer.state_dict())

for epoch in range(5):
    print(f"Epoch: {epoch}")
    # loss.backward()
    optimizer.step()
    # validate(...)
    scheduler.step()
    print(optimizer.state_dict()['param_groups'][0]['lr'])

{'state': {}, 'param_groups': [{'lr': 0.1, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'initial_lr': 0.1, 'params': [0, 1]}]}
Epoch: 0
0.095
Epoch: 1
0.09025
Epoch: 2
0.0857375
Epoch: 3
0.08145062499999998
Epoch: 4
0.07737809374999999


In [17]:
# Step LR Sched: goood one
lr = 0.1
model = nn.Linear(10,1)

optimizer = torch.optim.Adam(model.parameters(),lr=lr)

scheduler = lr_scheduler.StepLR(optimizer,step_size=30,gamma=0.1)

print(optimizer.state_dict())

for epoch in range(100):
    # loss.backward()
    optimizer.step()
    # validate(...)
    scheduler.step()
    if epoch%30==0:
        print(f"Epoch: {epoch}")
        print(f"{optimizer.state_dict()['param_groups'][0]['lr']:.6f}")

{'state': {}, 'param_groups': [{'lr': 0.1, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'initial_lr': 0.1, 'params': [0, 1]}]}
Epoch: 0
0.100000
Epoch: 30
0.010000
Epoch: 60
0.001000
Epoch: 90
0.000100


In [20]:
# Milestones LR Sched: fast identical as above
lr = 0.1
model = nn.Linear(10,1)

optimizer = torch.optim.Adam(model.parameters(),lr=lr)
milestones = [30,80,90]
scheduler = lr_scheduler.MultiStepLR(optimizer,milestones=[30,80,90],gamma=0.1)

print(optimizer.state_dict())

for epoch in range(100):
    # loss.backward()
    optimizer.step()
    # validate(...)
    scheduler.step()
    if epoch in milestones:
        print(f"Epoch: {epoch}")
        print(f"{optimizer.state_dict()['param_groups'][0]['lr']:.6f}")

{'state': {}, 'param_groups': [{'lr': 0.1, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'initial_lr': 0.1, 'params': [0, 1]}]}
Epoch: 30
0.010000
Epoch: 80
0.001000
Epoch: 90
0.000100


In [23]:
# ReduceLROnPlateau LR Sched: cool one
# Reduce learning rate when a metric has stopped improving. 
# Models often benefit from reducing the learning rate by a factor of 2-10 once learning stagnates.
# This scheduler reads a metrics quantity and 
# if no improvement is seen for a ‘patience’ number of epochs, the learning rate is reduced.

lr = 0.1
model = nn.Linear(10,1)

optimizer = torch.optim.Adam(model.parameters(),lr=lr)

#mode min (decreasing lr)or max (when lr is increasing)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min') # other params are on default, read doc!!!

print(optimizer.state_dict())

for epoch in range(100):
    # loss.backward()
    optimizer.step()
    #val_loss =  validate(...)
    #scheduler.step(val_loss)
    
    print(f"Epoch: {epoch}")
    print(f"{optimizer.state_dict()['param_groups'][0]['lr']:.6f}")

{'state': {}, 'param_groups': [{'lr': 0.1, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'params': [0, 1]}]}
Epoch: 0
0.100000
Epoch: 1
0.100000
Epoch: 2
0.100000
Epoch: 3
0.100000
Epoch: 4
0.100000
Epoch: 5
0.100000
Epoch: 6
0.100000
Epoch: 7
0.100000
Epoch: 8
0.100000
Epoch: 9
0.100000
Epoch: 10
0.100000
Epoch: 11
0.100000
Epoch: 12
0.100000
Epoch: 13
0.100000
Epoch: 14
0.100000
Epoch: 15
0.100000
Epoch: 16
0.100000
Epoch: 17
0.100000
Epoch: 18
0.100000
Epoch: 19
0.100000
Epoch: 20
0.100000
Epoch: 21
0.100000
Epoch: 22
0.100000
Epoch: 23
0.100000
Epoch: 24
0.100000
Epoch: 25
0.100000
Epoch: 26
0.100000
Epoch: 27
0.100000
Epoch: 28
0.100000
Epoch: 29
0.100000
Epoch: 30
0.100000
Epoch: 31
0.100000
Epoch: 32
0.100000
Epoch: 33
0.100000
Epoch: 34
0.100000
Epoch: 35
0.100000
Epoch: 36
0.100000
Epoch: 37
0.100000
Epoch: 38
0.100000
Epoch: 39
0.100000
Epoch: 40
0.100000
Epoch: 41
0.100000
Epoch: 42
0.100000
Epoch: 43
0.100000
Epoch: 44
0.100000
Epoch: 45
0.100000