# test_lr_scheduler_checkpoint.py
# Because of the way the callbacks are generated, we have to disable linting here.
# pylint: disable=no-name-in-module
import os
import unittest
from unittest import TestCase
from tempfile import TemporaryDirectory
import torch
import torch.nn as nn
from poutyne import torch_to_numpy
from poutyne.framework import Model
from poutyne.framework.callbacks import LRSchedulerCheckpoint
from poutyne.framework.callbacks import ExponentialLR, ReduceLROnPlateau
def some_data_generator(batch_size):
    """Yield an endless stream of random ``(input, target)`` pairs.

    Each yielded tensor has shape ``(batch_size, 1)``, matching the
    ``nn.Linear(1, 1)`` module used by the tests.
    """
    while True:
        yield torch.rand(batch_size, 1), torch.rand(batch_size, 1)
class OptimizerCheckpointTest(TestCase):
    """Tests for the ``LRSchedulerCheckpoint`` callback.

    Verifies that a learning-rate scheduler wrapped in the checkpoint
    callback saves its state to disk after every epoch, both when driven
    through ``Model.fit_generator`` and when its callback hooks are invoked
    manually.

    NOTE(review): the class is named ``OptimizerCheckpointTest`` although it
    exercises the LR-scheduler checkpoint — presumably copied from a sibling
    test module; the name is kept to avoid changing the public identifier.
    """

    batch_size = 20
    epochs = 10

    def setUp(self):
        # Fixed seed so every test starts from identical network weights.
        torch.manual_seed(42)
        self.pytorch_module = nn.Linear(1, 1)
        self.loss_function = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.pytorch_module.parameters(), lr=1e-3)
        self.model = Model(self.pytorch_module, self.optimizer, self.loss_function)
        self.temp_dir_obj = TemporaryDirectory()
        # The ``{epoch}`` placeholder yields one checkpoint file per epoch.
        self.checkpoint_filename = os.path.join(self.temp_dir_obj.name, 'my_checkpoint_{epoch}.optim')

    def tearDown(self):
        self.temp_dir_obj.cleanup()

    def test_any_scheduler_integration(self):
        # Smoke test: an epoch-based scheduler runs through a full fit
        # with the checkpoint callback attached, without error.
        train_gen = some_data_generator(self.batch_size)
        valid_gen = some_data_generator(self.batch_size)
        scheduler = ExponentialLR(gamma=0.01)
        checkpointer = LRSchedulerCheckpoint(scheduler, self.checkpoint_filename, period=1)
        self.model.fit_generator(
            train_gen, valid_gen, epochs=self.epochs, steps_per_epoch=5, callbacks=[checkpointer]
        )

    def test_reduce_lr_on_plateau_integration(self):
        # Same smoke test for a metric-driven (ReduceLROnPlateau) scheduler.
        train_gen = some_data_generator(self.batch_size)
        valid_gen = some_data_generator(self.batch_size)
        scheduler = ReduceLROnPlateau(monitor='loss', patience=3)
        checkpointer = LRSchedulerCheckpoint(scheduler, self.checkpoint_filename, period=1)
        self.model.fit_generator(
            train_gen, valid_gen, epochs=self.epochs, steps_per_epoch=5, callbacks=[checkpointer]
        )

    def test_any_scheduler_checkpoints(self):
        scheduler = ExponentialLR(gamma=0.01)
        checkpointer = LRSchedulerCheckpoint(scheduler, self.checkpoint_filename, period=1)
        self._test_checkpointer(checkpointer, scheduler)

    def test_reduce_lr_checkpoints(self):
        scheduler = ReduceLROnPlateau(monitor='loss', patience=3)
        checkpointer = LRSchedulerCheckpoint(scheduler, self.checkpoint_filename, period=1)
        self._test_checkpointer(checkpointer, scheduler)

    def _test_checkpointer(self, checkpointer, lr_scheduler):
        """Drive the callback hooks by hand and record the scheduler state
        after every epoch, then check each saved file against that record."""
        recorded_states = {}
        generator = some_data_generator(self.batch_size)

        checkpointer.set_params({'epochs': self.epochs, 'steps': 1})
        checkpointer.set_model(self.model)
        checkpointer.on_train_begin({})
        for epoch in range(1, self.epochs + 1):
            checkpointer.on_epoch_begin(epoch, {})
            checkpointer.on_batch_begin(1, {})
            loss = self._update_model(generator)
            checkpointer.on_batch_end(1, {'batch': 1, 'size': self.batch_size, 'loss': loss})
            checkpointer.on_epoch_end(epoch, {'epoch': epoch, 'loss': loss, 'val_loss': 1})

            # A checkpoint file must exist for this epoch, and the scheduler
            # state at this point is snapshotted for later comparison.
            filename = self.checkpoint_filename.format(epoch=epoch)
            self.assertTrue(os.path.isfile(filename))
            recorded_states[epoch] = torch_to_numpy(lr_scheduler.scheduler.state_dict(), copy=True)
        checkpointer.on_train_end({})

        self._test_checkpoint(recorded_states, lr_scheduler)

    def _update_model(self, generator):
        """Run one manual optimization step; return the scalar loss."""
        self.pytorch_module.zero_grad()
        x, y = next(generator)
        prediction = self.pytorch_module(x)
        loss = self.loss_function(prediction, y)
        loss.backward()
        self.optimizer.step()
        return float(loss)

    def _test_checkpoint(self, scheduler_states, lr_scheduler):
        """Reload every saved checkpoint and compare it to the state recorded
        at the end of the corresponding epoch."""
        for epoch, expected_state in scheduler_states.items():
            filename = self.checkpoint_filename.format(epoch=epoch)
            lr_scheduler.load_state(filename)
            loaded_state = torch_to_numpy(lr_scheduler.scheduler.state_dict())
            self.assertEqual(expected_state, loaded_state)
# Allow running this test module directly (e.g. `python test_lr_scheduler_checkpoint.py`).
if __name__ == '__main__':
    unittest.main()