fix learning rate update error
tcwang0509 committed Jul 3, 2019
1 parent a422816 commit 2e6d137
Showing 3 changed files with 10 additions and 9 deletions.
5 changes: 3 additions & 2 deletions models/base_model.py
@@ -1,4 +1,5 @@
import os, sys
import numpy as np
import torch
from .networks import get_grid

@@ -150,9 +151,9 @@ def get_edges(self, t):
         edge[:,:,:,:-1,:] = edge[:,:,:,:-1,:] | (t[:,:,:,1:,:] != t[:,:,:,:-1,:])
         return edge.float()

-    def update_learning_rate(self, epoch):
+    def update_learning_rate(self, epoch, model):
         lr = self.opt.lr * (1 - (epoch - self.opt.niter) / self.opt.niter_decay)
-        for param_group in self.optimizer_D.param_groups:
+        for param_group in getattr(self, 'optimizer_' + model).param_groups:
             param_group['lr'] = lr
         print('update learning rate: %f -> %f' % (self.old_lr, lr))
         self.old_lr = lr
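
For context: the old method above updated self.optimizer_D regardless of whether it was called on the generator or the discriminator; the new model argument ('G' or 'D') looks up the matching optimizer via getattr. A toy, self-contained sketch of that pattern (hypothetical class and values, not code from the repo):

import torch

class Toy:
    # Two optimizers on one object; the 'model' string picks which one to update,
    # mirroring the getattr lookup in the fixed update_learning_rate above.
    def __init__(self, lr):
        self.optimizer_G = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=lr)
        self.optimizer_D = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=lr)

    def update_learning_rate(self, lr, model):
        for param_group in getattr(self, 'optimizer_' + model).param_groups:
            param_group['lr'] = lr

toy = Toy(0.0002)
toy.update_learning_rate(0.0001, 'G')   # with a hard-coded optimizer_D this call could never reach optimizer_G
toy.update_learning_rate(0.0001, 'D')
print(toy.optimizer_G.param_groups[0]['lr'], toy.optimizer_D.param_groups[0]['lr'])   # 0.0001 0.0001
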
12 changes: 6 additions & 6 deletions models/models.py
@@ -101,7 +101,7 @@ def create_optimizer(opt, models):
         optimizer_D_T.append(getattr(modelD.module, 'optimizer_D_T'+str(s)))
     return modelG, modelD, flowNet, optimizer_G, optimizer_D, optimizer_D_T

-def init_params(opt, modelG, modelD, dataset_size):
+def init_params(opt, modelG, modelD, data_loader):
     iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
     start_epoch, epoch_iter = 1, 0
     ### if continue training, recover previous states
@@ -110,8 +110,8 @@ def init_params(opt, modelG, modelD, dataset_size):
         start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int)
         print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
         if start_epoch > opt.niter:
-            modelG.module.update_learning_rate(start_epoch-1)
-            modelD.module.update_learning_rate(start_epoch-1)
+            modelG.module.update_learning_rate(start_epoch-1, 'G')
+            modelD.module.update_learning_rate(start_epoch-1, 'D')
         if (opt.n_scales_spatial > 1) and (opt.niter_fix_global != 0) and (start_epoch > opt.niter_fix_global):
             modelG.module.update_fixed_params()
         if start_epoch > opt.niter_step:
@@ -127,7 +127,7 @@ def init_params(opt, modelG, modelD, dataset_size):
     output_nc = opt.output_nc

     print_freq = lcm(opt.print_freq, opt.batchSize)
-    total_steps = (start_epoch-1) * dataset_size + epoch_iter
+    total_steps = (start_epoch-1) * len(data_loader) + epoch_iter
     total_steps = total_steps // print_freq * print_freq

     return n_gpus, tG, tD, tDB, s_scales, t_scales, input_nc, output_nc, start_epoch, epoch_iter, print_freq, total_steps, iter_path
@@ -151,8 +151,8 @@ def save_models(opt, epoch, epoch_iter, total_steps, visualizer, iter_path, mode
 def update_models(opt, epoch, modelG, modelD, data_loader):
     ### linearly decay learning rate after certain iterations
     if epoch > opt.niter:
-        modelG.module.update_learning_rate(epoch)
-        modelD.module.update_learning_rate(epoch)
+        modelG.module.update_learning_rate(epoch, 'G')
+        modelD.module.update_learning_rate(epoch, 'D')

     ### gradually grow training sequence length
     if (epoch % opt.niter_step) == 0:
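
For context: init_params now receives the data_loader itself rather than a precomputed dataset_size, so the resumed step count is derived from len(data_loader) and then snapped down to a multiple of print_freq. A standalone sketch of that arithmetic with made-up numbers; the lcm below is a stand-in for the helper referenced in init_params, not necessarily the repo's implementation:

from math import gcd

def lcm(a, b):
    return a * b // gcd(a, b)

print_freq = lcm(100, 6)              # e.g. opt.print_freq=100, opt.batchSize=6 -> 300
total_steps = (5 - 1) * 1000 + 137    # start_epoch=5, len(data_loader)=1000, epoch_iter=137 (illustrative only)
total_steps = total_steps // print_freq * print_freq
print(total_steps)                    # 4137 snaps down to 3900, a multiple of print_freq
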
2 changes: 1 addition & 1 deletion train.py
@@ -30,7 +30,7 @@ def train():

     ### set parameters
     n_gpus, tG, tD, tDB, s_scales, t_scales, input_nc, output_nc, \
-        start_epoch, epoch_iter, print_freq, total_steps, iter_path = init_params(opt, modelG, modelD, dataset_size)
+        start_epoch, epoch_iter, print_freq, total_steps, iter_path = init_params(opt, modelG, modelD, data_loader)
     visualizer = Visualizer(opt)

     ### real training starts here
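
For context: once epoch exceeds opt.niter, update_models triggers the decay for both 'G' and 'D' every epoch. A quick standalone sketch of the resulting schedule, using the formula from base_model.py with illustrative values standing in for opt.lr, opt.niter, and opt.niter_decay:

base_lr, niter, niter_decay = 0.0002, 10, 10   # assumed example values, not the repo defaults

for epoch in range(niter + 1, niter + niter_decay + 1):
    lr = base_lr * (1 - (epoch - niter) / niter_decay)
    print('epoch %d -> lr %f' % (epoch, lr))   # 0.000180, 0.000160, ..., 0.000000
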
