diff --git a/src/dsnt/bin/hyperparam_search.py b/src/dsnt/bin/hyperparam_search.py
new file mode 100755
index 0000000..60137ee
--- /dev/null
+++ b/src/dsnt/bin/hyperparam_search.py
@@ -0,0 +1,205 @@
+#!/usr/bin/env python3
+
+"""Search for good training hyperparameters.
+
+This code runs the LR range test proposed in "Cyclical Learning Rates for Training Neural Networks"
+by Leslie N. Smith.
+"""
+
+import argparse
+import json
+
+import numpy as np
+import pyshowoff
+import tele
+import torch
+from tele.meter import ValueMeter, ListMeter
+from tele.showoff.views import Cell, View, Inspect
+from torch.autograd import Variable
+from torch.optim import SGD
+from torch.utils.data import DataLoader
+from torch.utils.data.sampler import WeightedRandomSampler
+from tqdm import tqdm
+
+from dsnt.data import MPIIDataset
+from dsnt.model import build_mpii_pose_model
+from dsnt.util import seed_random_number_generators
+
+
+def parse_args():
+    """Parse command-line arguments."""
+
+    parser = argparse.ArgumentParser(description='DSNT human pose model hyperparameter searcher')
+    parser.add_argument('--showoff', type=str, default='showoff:3000', metavar='HOST:PORT',
+                        help='network location of the Showoff server (default="showoff:3000")')
+    # LR finder parameters
+    parser.add_argument('--batch-size', type=int, default=32, metavar='N',
+                        help='input batch size (default=32)')
+    parser.add_argument('--lr-min', type=float, default=1e-1,
+                        help='minimum learning rate')
+    parser.add_argument('--lr-max', type=float, default=1e2,
+                        help='maximum learning rate')
+    parser.add_argument('--max-iters', type=int, default=1000,
+                        help='number of training iterations')
+    parser.add_argument('--ema-beta', type=float, default=0.99,
+                        help='beta value for the exponential moving average')
+    parser.add_argument('--weight-decay', type=float, default=0,
+                        help='weight decay')
+    parser.add_argument('--momentum', type=float, default=0.9,
+                        help='momentum')
+    # Model
+    parser.add_argument('--base-model', type=str, default='resnet34', metavar='BM',
+                        help='base model type (default="resnet34")')
+    parser.add_argument('--dilate', type=int, default=0, metavar='N',
+                        help='number of ResNet layer groups to dilate (default=0)')
+    parser.add_argument('--truncate', type=int, default=0, metavar='N',
+                        help='number of ResNet layer groups to cut off (default=0)')
+    parser.add_argument('--output-strat', type=str, default='dsnt', metavar='S',
+                        choices=['dsnt', 'gauss', 'fc'],
+                        help='strategy for outputting coordinates (default="dsnt")')
+    parser.add_argument('--preact', type=str, default='softmax', metavar='S',
+                        choices=['softmax', 'thresholded_softmax', 'abs', 'relu', 'sigmoid'],
+                        help='heatmap preactivation function (default="softmax")')
+    parser.add_argument('--reg', type=str, default='none',
+                        choices=['none', 'var', 'js', 'kl', 'mse'],
+                        help='set the regularizer (default="none")')
+    parser.add_argument('--reg-coeff', type=float, default=1.0,
+                        help='coefficient for controlling regularization strength')
+    parser.add_argument('--hm-sigma', type=float, default=1.0,
+                        help='target standard deviation for heatmap, in pixels')
+    # RNG
+    parser.add_argument('--seed', type=int, metavar='N',
+                        help='seed for random number generators')
+
+    args = parser.parse_args()
+
+    if args.seed is None:
+        args.seed = np.random.randint(0, 999999)
+
+    return args
+
+
+def make_data_sampler(examples_per_epoch, dataset_length):
+    if examples_per_epoch is None:
+        examples_per_epoch = dataset_length
+
+    # Sample with replacement only if we have to
+    replacement = examples_per_epoch > dataset_length
+
+    return WeightedRandomSampler(
+        torch.ones(dataset_length).double(),
+        examples_per_epoch,
+        replacement=replacement
+    )
+
+
+class _XYGraphCell(Cell):
+    def __init__(self, meter_names, frame):
+        super().__init__(meter_names, frame)
+        self.xs = []
+        self.ys = []
+
+    def render(self, step_num, meters):
+        series_names = [self.meter_names[0]]
+        meter = meters[0]
+        assert isinstance(meter, ValueMeter)
+        x, y = meter.value()
+        self.xs.append(x)
+        self.ys.append(y)
+        self.frame.line_graph(self.xs, [self.ys], series_names=series_names)
+
+
+class XYGraph(View):
+    def build(self, frame):
+        return _XYGraphCell(self.meter_names, frame)
+
+
+def main():
+    args = parse_args()
+
+    seed_random_number_generators(args.seed)
+
+    model_desc = {
+        'base': args.base_model,
+        'dilate': args.dilate,
+        'truncate': args.truncate,
+        'output_strat': args.output_strat,
+        'preact': args.preact,
+        'reg': args.reg,
+        'reg_coeff': args.reg_coeff,
+        'hm_sigma': args.hm_sigma,
+    }
+    model = build_mpii_pose_model(**model_desc)
+    model.cuda()
+
+    train_data = MPIIDataset('/datasets/mpii', 'train', use_aug=True,
+                             image_specs=model.image_specs)
+    sampler = make_data_sampler(args.max_iters * args.batch_size, len(train_data))
+    train_loader = DataLoader(train_data, args.batch_size, num_workers=4, drop_last=True,
+                              sampler=sampler)
+    data_iter = iter(train_loader)
+
+    print(json.dumps(model_desc, sort_keys=True, indent=2))
+
+    def do_training_iteration(optimiser):
+        batch = next(data_iter)
+
+        in_var = Variable(batch['input'].cuda(), requires_grad=False)
+        target_var = Variable(batch['part_coords'].cuda(), requires_grad=False)
+        mask_var = Variable(batch['part_mask'].type(torch.cuda.FloatTensor), requires_grad=False)
+
+        # Calculate predictions and loss
+        out_var = model(in_var)
+        loss = model.forward_loss(out_var, target_var, mask_var)
+
+        # Calculate gradients
+        optimiser.zero_grad()
+        loss.backward()
+
+        # Update parameters
+        optimiser.step()
+
+        return loss.data[0]
+
+    optimiser = SGD(model.parameters(), lr=1, weight_decay=args.weight_decay,
+                    momentum=args.momentum)
+
+    tel = tele.Telemetry({
+        'cli_args': ValueMeter(skip_reset=True),
+        'loss_lr': ValueMeter(),
+    })
+
+    tel['cli_args'].set_value(vars(args))
+
+    if args.showoff:
+        client = pyshowoff.Client('http://' + args.showoff)
+        notebook = client.add_notebook(
+            'Hyperparameter search ({}-d{}-t{}, {}, reg={})'.format(
+                args.base_model, args.dilate, args.truncate, args.output_strat, args.reg)
+        ).result()
+
+        tel.sink(tele.showoff.Conf(notebook), [
+            Inspect(['cli_args'], 'CLI arguments', flatten=True),
+            XYGraph(['loss_lr'], 'Loss vs learning rate graph'),
+        ])
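+
+    # LR range test: sweep the learning rate geometrically from lr_min up to
+    # lr_max, running one mini-batch of training per value. The raw loss is
+    # smoothed with a bias-corrected exponential moving average (the same
+    # correction Adam uses), and the sweep stops early once the smoothed loss
+    # exceeds 4x its best value so far, which is taken as a sign of divergence.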
+    lrs = np.geomspace(args.lr_min, args.lr_max, args.max_iters)
+    avg_loss = 0
+    min_loss = np.inf
+    for i, lr in enumerate(tqdm(lrs, ascii=True)):
+        for param_group in optimiser.param_groups:
+            param_group['lr'] = lr
+        loss = do_training_iteration(optimiser)
+        avg_loss = args.ema_beta * avg_loss + (1 - args.ema_beta) * loss
+        smoothed_loss = avg_loss / (1 - args.ema_beta ** (i + 1))
+        if min_loss > 0 and smoothed_loss > 4 * min_loss:
+            break
+        min_loss = min(smoothed_loss, min_loss)
+
+        tel['loss_lr'].set_value((lr, smoothed_loss))
+
+        tel.step()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/src/dsnt/bin/train.py b/src/dsnt/bin/train.py
index f9c580d..0e01845 100755
--- a/src/dsnt/bin/train.py
+++ b/src/dsnt/bin/train.py
@@ -26,6 +26,7 @@
 from dsnt.data import MPIIDataset
 from dsnt.evaluator import PCKhEvaluator
+from dsnt.hyperparam_scheduler import make_1cycle
 from dsnt.model import build_mpii_pose_model
 from dsnt.util import draw_skeleton, timer, generator_timer, seed_random_number_generators
@@ -72,6 +73,7 @@ def parse_args():
     parser.add_argument('--schedule-gamma', type=float, metavar='G',
                         help='factor to multiply the LR by at each drop')
     parser.add_argument('--optim', type=str, default='rmsprop', metavar='S',
+                        choices=['sgd', 'rmsprop', '1cycle'],
                         help='optimizer to use (default=rmsprop)')
     parser.add_argument('--tags', type=str, nargs='+', default=[],
                         help='keywords to tag this experiment with')
@@ -91,6 +93,10 @@ def parse_args():
         args.lr = args.lr or 2.5e-4
         args.schedule_gamma = args.schedule_gamma or 0.1
         args.schedule_milestones = args.schedule_milestones or [60, 90]
+    elif args.optim == '1cycle':
+        args.lr = args.lr or 1
+        args.schedule_gamma = None
+        args.schedule_milestones = None
 
     return args
 
@@ -305,15 +311,19 @@ def eval_metrics_for_batch(evaluator, batch, norm_out):
     ####
 
     # Initialize optimiser and learning rate scheduler
-    if args.optim == 'sgd':
-        optimizer = optim.SGD(model.parameters(), lr=initial_lr, momentum=0.9)
-    elif args.optim == 'rmsprop':
-        optimizer = optim.RMSprop(model.parameters(), lr=initial_lr)
+    if args.optim == '1cycle':
+        optimizer = optim.SGD(model.parameters(), lr=0)
+        scheduler = make_1cycle(optimizer, epochs * len(train_loader), lr_max=initial_lr, momentum=0.9)
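+        # lr=0 above is only a placeholder: make_1cycle drives both lr and
+        # momentum, starting from the first call to scheduler.batch_step().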
     else:
-        raise Exception('unrecognised optimizer: {}'.format(args.optim))
+        if args.optim == 'sgd':
+            optimizer = optim.SGD(model.parameters(), lr=initial_lr, momentum=0.9)
+        elif args.optim == 'rmsprop':
+            optimizer = optim.RMSprop(model.parameters(), lr=initial_lr)
+        else:
+            raise Exception('unrecognised optimizer: {}'.format(args.optim))
 
-    scheduler = lr_scheduler.MultiStepLR(
-        optimizer, milestones=schedule_milestones, gamma=schedule_gamma)
+        scheduler = lr_scheduler.MultiStepLR(
+            optimizer, milestones=schedule_milestones, gamma=schedule_gamma)
 
     # `vis` will hold a few samples for visualisation
     vis = {}
@@ -325,12 +335,17 @@ def eval_metrics_for_batch(evaluator, batch, norm_out):
 
     def train(epoch):
         """Do a full pass over the training set, updating model parameters."""
+        if hasattr(scheduler, 'step'):
+            scheduler.step(epoch)
+
         model.train()
-        scheduler.step(epoch)
         samples_processed = 0
 
         with progressbar.ProgressBar(max_value=len(train_data)) as bar:
             for i, batch in generator_timer(enumerate(train_loader), tel['train_data_load_time']):
+                if hasattr(scheduler, 'batch_step'):
+                    scheduler.batch_step()
+
                 with timer(tel['train_data_transfer_time']):
                     in_var = Variable(batch['input'].cuda(), requires_grad=False)
                     target_var = Variable(batch['part_coords'].cuda(), requires_grad=False)
diff --git a/src/dsnt/hyperparam_scheduler.py b/src/dsnt/hyperparam_scheduler.py
new file mode 100644
index 0000000..5da994c
--- /dev/null
+++ b/src/dsnt/hyperparam_scheduler.py
@@ -0,0 +1,42 @@
+# Implementation of the 1cycle training policy: https://arxiv.org/abs/1803.09820
+
+import numpy as np
+
+
+def make_1cycle(optimizer, max_iters, lr_max, momentum=0):
+    lr_min = lr_max * 1e-1
+    lr_nihil = lr_min * 1e-3
+    t3 = max_iters
+    t2 = 0.9 * t3
+    t1 = t2 / 2
+    m_max = momentum
+    m_min = min(m_max, 0.85)
+    return HyperparameterScheduler(
+        optimizer,
+        ts=[1, t1, t2, t3],
+        hyperparam_milestones={
+            'lr': [lr_min, lr_max, lr_min, lr_nihil],
+            'momentum': [m_max, m_min, m_max, m_max],
+        }
+    )
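+
+# Schedule shape produced by make_1cycle (piecewise linear, evaluated per batch):
+#   lr:       lr_min -> lr_max over the first 45% of iterations, back down to
+#             lr_min at 90%, then annealed to lr_min * 1e-3 by the final batch.
+#   momentum: moves in the opposite direction to lr (m_max -> m_min -> m_max),
+#             as recommended in the 1cycle paper.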
+
+
+class HyperparameterScheduler:
+    def __init__(self, optimizer, ts, hyperparam_milestones):
+        for k, v in hyperparam_milestones.items():
+            assert len(v) == len(ts),\
+                'expected {} milestones for hyperparameter "{}"'.format(len(ts), k)
+            for param_group in optimizer.param_groups:
+                assert k in param_group,\
+                    '"{}" is not an optimizer hyperparameter'.format(k)
+        self.optimizer = optimizer
+        self.ts = np.array(ts)
+        self.hyperparam_milestones = {k: np.array(v) for k, v in hyperparam_milestones.items()}
+        self.batch_count = 0
+
+    def batch_step(self):
+        self.batch_count += 1
+        for hyperparam_name, milestones in self.hyperparam_milestones.items():
+            value = float(np.interp(self.batch_count, self.ts, milestones))
+            for param_group in self.optimizer.param_groups:
+                param_group[hyperparam_name] = value
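+
+# Example usage (hypothetical loop; `model`, `epochs` and `loader` are assumed):
+#
+#     optimizer = SGD(model.parameters(), lr=0)
+#     scheduler = make_1cycle(optimizer, epochs * len(loader), lr_max=0.5, momentum=0.9)
+#     for batch in loader:
+#         scheduler.batch_step()  # set lr and momentum for this iteration
+#         ...                     # forward/backward pass, then optimizer.step()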