Add support for 1cycle and hyperparam search
anibali committed Oct 18, 2018
1 parent ed9c72c commit f453331
Showing 3 changed files with 270 additions and 8 deletions.
205 changes: 205 additions & 0 deletions src/dsnt/bin/hyperparam_search.py
@@ -0,0 +1,205 @@
#!/usr/bin/env python3

"""Search for good training hyperparameters.
This code runs the LR range test proposed in "Cyclical Learning Rates for Training Neural Networks"
by Leslie N. Smith.
"""

import argparse
import json

import numpy as np
import pyshowoff
import tele
import torch
from tele.meter import ValueMeter, ListMeter
from tele.showoff.views import Cell, View, Inspect
from torch.autograd import Variable
from torch.optim import SGD
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from tqdm import tqdm

from dsnt.data import MPIIDataset
from dsnt.model import build_mpii_pose_model
from dsnt.util import seed_random_number_generators


def parse_args():
"""Parse command-line arguments."""

parser = argparse.ArgumentParser(description='Hyperparameter search for the DSNT human pose model')
parser.add_argument('--showoff', type=str, default='showoff:3000', metavar='HOST:PORT',
help='network location of the Showoff server (default="showoff:3000")')
# LR finder parameters
parser.add_argument('--batch-size', type=int, default=32, metavar='N',
help='input batch size (default=32)')
parser.add_argument('--lr-min', type=float, default=1e-1,
help='minimum learning rate')
parser.add_argument('--lr-max', type=float, default=1e2,
help='maximum learning rate')
parser.add_argument('--max-iters', type=int, default=1000,
help='number of training iterations')
parser.add_argument('--ema-beta', type=float, default=0.99,
help='beta value for the exponential moving average')
parser.add_argument('--weight-decay', type=float, default=0,
help='weight decay')
parser.add_argument('--momentum', type=float, default=0.9,
help='momentum')
# Model
parser.add_argument('--base-model', type=str, default='resnet34', metavar='BM',
help='base model type (default="resnet34")')
parser.add_argument('--dilate', type=int, default=0, metavar='N',
help='number of ResNet layer groups to dilate (default=0)')
parser.add_argument('--truncate', type=int, default=0, metavar='N',
help='number of ResNet layer groups to cut off (default=0)')
parser.add_argument('--output-strat', type=str, default='dsnt', metavar='S',
choices=['dsnt', 'gauss', 'fc'],
help='strategy for outputting coordinates (default="dsnt")')
parser.add_argument('--preact', type=str, default='softmax', metavar='S',
choices=['softmax', 'thresholded_softmax', 'abs', 'relu', 'sigmoid'],
help='heatmap preactivation function (default="softmax")')
parser.add_argument('--reg', type=str, default='none',
choices=['none', 'var', 'js', 'kl', 'mse'],
help='set the regularizer (default="none")')
parser.add_argument('--reg-coeff', type=float, default=1.0,
help='coefficient for controlling regularization strength')
parser.add_argument('--hm-sigma', type=float, default=1.0,
help='target standard deviation for heatmap, in pixels')
# RNG
parser.add_argument('--seed', type=int, metavar='N',
help='seed for random number generators')

args = parser.parse_args()

if args.seed is None:
args.seed = np.random.randint(0, 999999)

return args


def make_data_sampler(examples_per_epoch, dataset_length):
if examples_per_epoch is None:
examples_per_epoch = dataset_length

# Sample with replacement only if we have to
replacement = examples_per_epoch > dataset_length

return WeightedRandomSampler(
torch.ones(dataset_length).double(),
examples_per_epoch,
replacement=replacement
)


class _XYGraphCell(Cell):
def __init__(self, meter_names, frame):
super().__init__(meter_names, frame)
self.xs = []
self.ys = []

def render(self, step_num, meters):
series_names = [self.meter_names[0]]
meter = meters[0]
assert isinstance(meter, ValueMeter)
x, y = meter.value()
self.xs.append(x)
self.ys.append(y)
self.frame.line_graph(self.xs, [self.ys], series_names=series_names)


class XYGraph(View):
def build(self, frame):
return _XYGraphCell(self.meter_names, frame)


def main():
args = parse_args()

seed_random_number_generators(args.seed)

model_desc = {
'base': args.base_model,
'dilate': args.dilate,
'truncate': args.truncate,
'output_strat': args.output_strat,
'preact': args.preact,
'reg': args.reg,
'reg_coeff': args.reg_coeff,
'hm_sigma': args.hm_sigma,
}
model = build_mpii_pose_model(**model_desc)
model.cuda()

train_data = MPIIDataset('/datasets/mpii', 'train', use_aug=True,
image_specs=model.image_specs)
sampler = make_data_sampler(args.max_iters * args.batch_size, len(train_data))
train_loader = DataLoader(train_data, args.batch_size, num_workers=4, drop_last=True,
sampler=sampler)
data_iter = iter(train_loader)

print(json.dumps(model_desc, sort_keys=True, indent=2))

def do_training_iteration(optimiser):
batch = next(data_iter)

in_var = Variable(batch['input'].cuda(), requires_grad=False)
target_var = Variable(batch['part_coords'].cuda(), requires_grad=False)
mask_var = Variable(batch['part_mask'].type(torch.cuda.FloatTensor), requires_grad=False)

# Calculate predictions and loss
out_var = model(in_var)
loss = model.forward_loss(out_var, target_var, mask_var)

# Calculate gradients
optimiser.zero_grad()
loss.backward()

# Update parameters
optimiser.step()

return loss.data[0]

optimiser = SGD(model.parameters(), lr=1, weight_decay=args.weight_decay,
momentum=args.momentum)

tel = tele.Telemetry({
'cli_args': ValueMeter(skip_reset=True),
'loss_lr': ValueMeter(),
})

tel['cli_args'].set_value(vars(args))

if args.showoff:
client = pyshowoff.Client('http://' + args.showoff)
notebook = client.add_notebook(
'Hyperparameter search ({}-d{}-t{}, {}, reg={})'.format(
args.base_model, args.dilate, args.truncate, args.output_strat, args.reg)
).result()

tel.sink(tele.showoff.Conf(notebook), [
Inspect(['cli_args'], 'CLI arguments', flatten=True),
XYGraph(['loss_lr'], 'Loss vs learning rate graph'),
])

lrs = np.geomspace(args.lr_min, args.lr_max, args.max_iters)
avg_loss = 0
min_loss = np.inf
for i, lr in enumerate(tqdm(lrs, ascii=True)):
for param_group in optimiser.param_groups:
param_group['lr'] = lr
loss = do_training_iteration(optimiser)
avg_loss = args.ema_beta * avg_loss + (1 - args.ema_beta) * loss
smoothed_loss = avg_loss / (1 - args.ema_beta ** (i + 1))
if min_loss > 0 and smoothed_loss > 4 * min_loss:
break
min_loss = min(smoothed_loss, min_loss)

tel['loss_lr'].set_value((lr, smoothed_loss))

tel.step()


if __name__ == '__main__':
main()
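
The loop at the heart of hyperparam_search.py is a geometric learning-rate sweep combined with a bias-corrected exponential moving average of the loss, stopped once the smoothed loss blows up. The standalone sketch below restates just that core logic; the name lr_range_test and the toy quadratic "loss" are illustrative only and are not part of this commit. Plotting the returned (lr, smoothed_loss) pairs gives the loss-vs-learning-rate curve that the Showoff XYGraph view renders.

import numpy as np

def lr_range_test(step_fn, lr_min=1e-1, lr_max=1e2, max_iters=1000, ema_beta=0.99):
    """Sweep the LR geometrically and return (lr, smoothed_loss) pairs."""
    lrs = np.geomspace(lr_min, lr_max, max_iters)
    avg_loss, min_loss, history = 0.0, np.inf, []
    for i, lr in enumerate(lrs):
        loss = step_fn(lr)  # one training iteration at this learning rate
        avg_loss = ema_beta * avg_loss + (1 - ema_beta) * loss
        smoothed = avg_loss / (1 - ema_beta ** (i + 1))  # bias-correct the EMA for early steps
        if min_loss > 0 and smoothed > 4 * min_loss:
            break  # loss has diverged; end the sweep
        min_loss = min(smoothed, min_loss)
        history.append((lr, smoothed))
    return history

# Toy stand-in for do_training_iteration: a quadratic "loss" in the learning rate.
history = lr_range_test(lambda lr: (lr - 1.0) ** 2)
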
31 changes: 23 additions & 8 deletions src/dsnt/bin/train.py
@@ -26,6 +26,7 @@

from dsnt.data import MPIIDataset
from dsnt.evaluator import PCKhEvaluator
from dsnt.hyperparam_scheduler import make_1cycle
from dsnt.model import build_mpii_pose_model
from dsnt.util import draw_skeleton, timer, generator_timer, seed_random_number_generators

@@ -72,6 +73,7 @@ def parse_args():
parser.add_argument('--schedule-gamma', type=float, metavar='G',
help='factor to multiply the LR by at each drop')
parser.add_argument('--optim', type=str, default='rmsprop', metavar='S',
choices=['sgd', 'rmsprop', '1cycle'],
help='optimizer to use (default=rmsprop)')
parser.add_argument('--tags', type=str, nargs='+', default=[],
help='keywords to tag this experiment with')
@@ -91,6 +93,10 @@ def parse_args():
args.lr = args.lr or 2.5e-4
args.schedule_gamma = args.schedule_gamma or 0.1
args.schedule_milestones = args.schedule_milestones or [60, 90]
elif args.optim == '1cycle':
args.lr = args.lr or 1
args.schedule_gamma = None
args.schedule_milestones = None

return args

@@ -305,15 +311,19 @@ def eval_metrics_for_batch(evaluator, batch, norm_out):
####

# Initialize optimiser and learning rate scheduler
-    if args.optim == 'sgd':
-        optimizer = optim.SGD(model.parameters(), lr=initial_lr, momentum=0.9)
-    elif args.optim == 'rmsprop':
-        optimizer = optim.RMSprop(model.parameters(), lr=initial_lr)
+    if args.optim == '1cycle':
+        optimizer = optim.SGD(model.parameters(), lr=0)
+        scheduler = make_1cycle(optimizer, epochs * len(train_loader), lr_max=initial_lr, momentum=0.9)
     else:
-        raise Exception('unrecognised optimizer: {}'.format(args.optim))
+        if args.optim == 'sgd':
+            optimizer = optim.SGD(model.parameters(), lr=initial_lr, momentum=0.9)
+        elif args.optim == 'rmsprop':
+            optimizer = optim.RMSprop(model.parameters(), lr=initial_lr)
+        else:
+            raise Exception('unrecognised optimizer: {}'.format(args.optim))

-    scheduler = lr_scheduler.MultiStepLR(
-        optimizer, milestones=schedule_milestones, gamma=schedule_gamma)
+        scheduler = lr_scheduler.MultiStepLR(
+            optimizer, milestones=schedule_milestones, gamma=schedule_gamma)

# `vis` will hold a few samples for visualisation
vis = {}
@@ -325,12 +335,17 @@ def eval_metrics_for_batch(evaluator, batch, norm_out):
def train(epoch):
"""Do a full pass over the training set, updating model parameters."""

+    if hasattr(scheduler, 'step'):
+        scheduler.step(epoch)
+
     model.train()
-    scheduler.step(epoch)
samples_processed = 0

with progressbar.ProgressBar(max_value=len(train_data)) as bar:
for i, batch in generator_timer(enumerate(train_loader), tel['train_data_load_time']):
+            if hasattr(scheduler, 'batch_step'):
+                scheduler.batch_step()
+
with timer(tel['train_data_transfer_time']):
in_var = Variable(batch['input'].cuda(), requires_grad=False)
target_var = Variable(batch['part_coords'].cuda(), requires_grad=False)
42 changes: 42 additions & 0 deletions src/dsnt/hyperparam_scheduler.py
@@ -0,0 +1,42 @@
# Implementation of the 1cycle training policy: https://arxiv.org/abs/1803.09820

import numpy as np


def make_1cycle(optimizer, max_iters, lr_max, momentum=0):
lr_min = lr_max * 1e-1
lr_nihil = lr_min * 1e-3
t3 = max_iters
t2 = 0.9 * t3
t1 = t2 / 2
m_max = momentum
m_min = min(m_max, 0.85)
return HyperparameterScheduler(
optimizer,
ts=[1, t1, t2, t3],
hyperparam_milestones={
'lr': [lr_min, lr_max, lr_min, lr_nihil],
'momentum': [m_max, m_min, m_max, m_max],
}
)


class HyperparameterScheduler():
def __init__(self, optimizer, ts, hyperparam_milestones):
for k, v in hyperparam_milestones.items():
assert len(v) == len(ts),\
'expected {} milestones for hyperparameter "{}"'.format(len(ts), k)
for param_group in optimizer.param_groups:
assert k in param_group,\
'"{}" is not an optimizer hyperparameter'.format(k)
self.optimizer = optimizer
self.ts = np.array(ts)
self.hyperparam_milestones = {k: np.array(v) for k, v in hyperparam_milestones.items()}
self.batch_count = 0

def batch_step(self):
self.batch_count += 1
for hyperparam_name, milestones in self.hyperparam_milestones.items():
value = float(np.interp(self.batch_count, self.ts, milestones))
for param_group in self.optimizer.param_groups:
param_group[hyperparam_name] = value

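
For reference, a minimal sketch of how make_1cycle is meant to be driven, mirroring the '1cycle' branch added to train.py above: the optimizer is created with lr=0 and batch_step() sets both lr and momentum before every iteration. With max_iters=1000, lr_max=1.0 and momentum=0.9, the milestones work out to ts=[1, 450, 900, 1000], with lr ramping 0.1 -> 1.0 -> 0.1 -> 0.0001 and momentum 0.9 -> 0.85 -> 0.9 -> 0.9. The toy model, data, and iteration count here are illustrative only, not part of this commit.

import torch
from torch import nn, optim

from dsnt.hyperparam_scheduler import make_1cycle

model = nn.Linear(16, 2)  # toy model standing in for the pose network
optimizer = optim.SGD(model.parameters(), lr=0)  # lr and momentum are driven by the scheduler
scheduler = make_1cycle(optimizer, max_iters=1000, lr_max=1.0, momentum=0.9)

for step in range(1000):
    scheduler.batch_step()  # interpolate lr and momentum for this iteration
    inputs, targets = torch.randn(8, 16), torch.randn(8, 2)
    loss = nn.functional.mse_loss(model(inputs), targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
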