forked from fastai/fastai
-
Notifications
You must be signed in to change notification settings - Fork 0
/
general_sched.py
46 lines (39 loc) · 1.85 KB
/
general_sched.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from ..core import *
from ..callback import *
from ..basic_train import Learner, LearnerCallback
# Explicit public API: only these names are exported by `from general_sched import *`.
__all__ = ['GeneralScheduler', 'TrainingPhase']
@dataclass
class TrainingPhase():
    "Schedule hyper-parameters for a phase of `length` iterations."
    length:int  # number of iterations this phase lasts

    def __post_init__(self):
        # A freshly created phase carries no hyper-parameter schedules yet.
        self.scheds = {}

    def schedule_hp(self, name, vals, anneal=None):
        "Adds a schedule for `name` between `vals` using `anneal`."
        # Every schedule in this phase spans the same `length` iterations.
        new_sched = Scheduler(vals, self.length, anneal)
        self.scheds[name] = new_sched
        # Return `self` so calls can be chained fluently.
        return self
class GeneralScheduler(LearnerCallback):
    "Schedule multiple `TrainingPhase` for a `Learner`."

    def __init__(self, learn:Learner, phases:Collection[TrainingPhase], start_epoch:int=None):
        super().__init__(learn)
        self.phases = phases
        self.start_epoch = start_epoch

    def on_train_begin(self, epoch:int, **kwargs:Any)->None:
        "Initialize the schedulers for training."
        # When the caller pinned a starting epoch, report it back to the training loop.
        res = None if self.start_epoch is None else {'epoch': self.start_epoch}
        self.start_epoch = ifnone(self.start_epoch, epoch)
        self.scheds = [phase.scheds for phase in self.phases]
        self.opt = self.learn.opt
        # Reset every scheduler of the first phase and push its starting value
        # into the optimizer's stats before the first batch runs.
        for stat_name, scheduler in self.scheds[0].items():
            scheduler.restart()
            self.opt.set_stat(stat_name, scheduler.start)
        self.idx_s = 0  # index of the phase currently being stepped
        return res

    def jump_to_epoch(self, epoch:int)->None:
        "Fast-forward the schedulers by replaying `epoch` epochs' worth of batch steps."
        n_steps = len(self.learn.data.train_dl) * epoch
        for _ in range(n_steps):
            self.on_batch_end(True)

    def on_batch_end(self, train, **kwargs:Any)->None:
        "Take a step in lr,mom sched, start next stepper when the current one is complete."
        if not train:
            return
        # Every phase exhausted: signal the training loop to stop.
        if self.idx_s >= len(self.scheds):
            return {'stop_training': True, 'stop_epoch': True}
        current = self.scheds[self.idx_s]
        for stat_name, scheduler in current.items():
            self.opt.set_stat(stat_name, scheduler.step())
        # All schedulers in a phase share the phase's length (see `TrainingPhase`),
        # so checking any single one for completion suffices.
        if list(current.values())[0].is_done:
            self.idx_s += 1