From 94ed4cd19373e687a814a1692fa9fc5ad2660620 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Luis?=
Date: Wed, 24 Jul 2019 10:12:26 +0300
Subject: [PATCH] Fix bug https://github.com/amdegroot/ssd.pytorch/issues/234

Modify the model loading and saving paths, make 'viz' a global variable to
fix the Visdom window display (amdegroot#234), remove a redundant parameter
in 'criterion', log more info, and fix PyTorch 0.4.0+ compatibility.
---
 .gitignore     |  3 ++-
 data/config.py |  4 ++--
 train.py       | 51 +++++++++++++++++++++++++------------------------
 3 files changed, 29 insertions(+), 29 deletions(-)

diff --git a/.gitignore b/.gitignore
index 1bcaf292a..2fdd9a627 100644
--- a/.gitignore
+++ b/.gitignore
@@ -125,4 +125,5 @@ data/datasets/
 data/ssd_dataloader.py
 
 # pylint
-.pylintrc
\ No newline at end of file
+.pylintrc
+checkpoints/
diff --git a/data/config.py b/data/config.py
index 923965984..7a8b2a446 100644
--- a/data/config.py
+++ b/data/config.py
@@ -28,8 +28,8 @@
 
 coco = {
     'num_classes': 201,
-    'lr_steps': (280000, 360000, 400000),
-    'max_iter': 400000,
+    'lr_steps': (28000, 36000, 40000),
+    'max_iter': 40000,
     'feature_maps': [38, 19, 10, 5, 3, 1],
     'min_dim': 300,
     'steps': [8, 16, 32, 64, 100, 300],
diff --git a/train.py b/train.py
index 3cfc2d5ea..a6eb3a3d8 100644
--- a/train.py
+++ b/train.py
@@ -6,7 +6,6 @@
 import sys
 import time
 import torch
-from torch.autograd import Variable
 import torch.nn as nn
 import torch.optim as optim
 import torch.backends.cudnn as cudnn
@@ -22,12 +21,12 @@ def str2bool(v):
 
 parser = argparse.ArgumentParser(
     description='Single Shot MultiBox Detector Training With Pytorch')
-train_set = parser.add_mutually_exclusive_group()
+# train_set = parser.add_mutually_exclusive_group()
 parser.add_argument('--dataset', default='VOC', choices=['VOC', 'COCO'],
                     type=str, help='VOC or COCO')
 parser.add_argument('--dataset_root', default=VOC_ROOT,
                     help='Dataset root directory path')
-parser.add_argument('--basenet', default='vgg16_reducedfc.pth',
+parser.add_argument('--basenet', default='./weights/vgg16_reducedfc.pth',
                     help='Pretrained base model')
 parser.add_argument('--batch_size', default=32, type=int,
                     help='Batch size for training')
@@ -49,7 +48,7 @@ def str2bool(v):
                     help='Gamma update for SGD')
 parser.add_argument('--visdom', default=True, type=str2bool,
                     help='Use visdom for loss visualization')
-parser.add_argument('--save_folder', default='weights/',
+parser.add_argument('--save_folder', default='checkpoints',
                     help='Directory for saving checkpoint models')
 args = parser.parse_args()
 
@@ -65,7 +64,11 @@ def str2bool(v):
         torch.set_default_tensor_type('torch.FloatTensor')
 
 if not os.path.exists(args.save_folder):
-    os.mkdir(args.save_folder)
+    os.makedirs(args.save_folder)
+
+if args.visdom:
+    import visdom
+    viz = visdom.Visdom()
 
 
 def train():
@@ -88,11 +91,6 @@ def train():
                                transform=SSDAugmentation(cfg['min_dim'],
                                                          MEANS))
 
-    if args.visdom:
-        import visdom
-        global viz
-        viz = visdom.Visdom()
-
     ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes'])
     net = ssd_net
 
@@ -104,7 +102,7 @@ def train():
         print('Resuming training, loading {}...'.format(args.resume))
         ssd_net.load_weights(args.resume)
     else:
-        vgg_weights = torch.load(args.save_folder + args.basenet)
+        vgg_weights = torch.load(args.basenet)
         print('Loading base network...')
         ssd_net.vgg.load_state_dict(vgg_weights)
 
@@ -149,35 +147,34 @@ def train():
                                   pin_memory=True)
     # create batch iterator
     batch_iterator = iter(data_loader)
+    print("Start time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))
     for iteration in range(args.start_iter, cfg['max_iter']):
         if args.visdom and iteration != 0 and (iteration % epoch_size == 0):
+            epoch += 1
             update_vis_plot(epoch, loc_loss, conf_loss, epoch_plot, None,
                             'append', epoch_size)
 
             # reset epoch loss counters
             loc_loss = 0
             conf_loss = 0
-            epoch += 1
 
         if iteration in cfg['lr_steps']:
             step_index += 1
             adjust_learning_rate(optimizer, args.gamma, step_index)
         # load train data
-
+        t0 = time.time()
         try:
+            # load train data
            images, targets = next(batch_iterator)
         except StopIteration:
             batch_iterator = iter(data_loader)
             images, targets = next(batch_iterator)
+        data_time = time.time() - t0
 
         if args.cuda:
-            images = Variable(images.cuda())
-            targets = [Variable(ann.cuda(), volatile=True) for ann in targets]
-        else:
-            images = Variable(images)
-            targets = [Variable(ann, volatile=True) for ann in targets]
+            images = images.cuda()
+            targets = [ann.cuda() for ann in targets]
         # forward
-        t0 = time.time()
         out = net(images)
         # backprop
         optimizer.zero_grad()
@@ -185,13 +182,14 @@ def train():
         loss = loss_l + loss_c
         loss.backward()
         optimizer.step()
-        t1 = time.time()
+        batch_time = time.time() - t0
         loc_loss += loss_l.item()
         conf_loss += loss_c.item()
 
         if iteration % 10 == 0:
-            print('timer: %.4f sec.' % (t1 - t0))
-            print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.item()), end=' ')
+            print('iter ' + repr(iteration) + ' || lr: %g || Loss: %.4f ||' %
+                  (optimizer.param_groups[0]['lr'], loss.data.item()), end=' ')
+            print('data: %.3fms, batch: %.3fs' % (data_time*1000, batch_time))
 
             if args.visdom:
                 update_vis_plot(iteration, loss_l.item(), loss_c.item(),
@@ -199,10 +197,11 @@ def train():
 
         if iteration != 0 and iteration % 5000 == 0:
             print('Saving state, iter:', iteration)
-            torch.save(ssd_net.state_dict(), 'weights/ssd300_COCO_' +
-                       repr(iteration) + '.pth')
+            torch.save(ssd_net.state_dict(), os.path.join(args.save_folder, 'ssd300_' + args.dataset + '_' +
+                       repr(iteration) + '.pth'))
     torch.save(ssd_net.state_dict(),
-               args.save_folder + '' + args.dataset + '.pth')
+               os.path.join(args.save_folder, args.dataset + '.pth'))
+    print("End time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))
 
 
 def adjust_learning_rate(optimizer, gamma, step):
@@ -217,7 +216,7 @@ def adjust_learning_rate(optimizer, gamma, step):
 
 
 def xavier(param):
-    init.xavier_uniform(param)
+    init.xavier_uniform_(param)
 
 
 def weights_init(m):
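Note on the PyTorch 0.4.0+ compatibility part of this patch: since 0.4, tensors carry autograd state themselves, so the torch.autograd.Variable wrapper (and its volatile=True flag, superseded by the torch.no_grad() context) is no longer needed, and scalar losses are read with .item() rather than .data[0]. The sketch below shows that migration pattern in isolation; it is a minimal illustration assuming only a stock PyTorch >= 0.4 install, and the nn.Linear model, random batches, and loop bounds are stand-ins rather than code from this repository.

import torch
import torch.nn as nn
import torch.optim as optim

# Dummy stand-ins for the SSD network and data loader; only the 0.4+ idioms matter here.
net = nn.Linear(10, 2)
optimizer = optim.SGD(net.parameters(), lr=1e-3, momentum=0.9)
criterion = nn.CrossEntropyLoss()

use_cuda = torch.cuda.is_available()
if use_cuda:
    net = net.cuda()

for iteration in range(3):
    images = torch.randn(4, 10)          # plain tensors, no Variable(...) wrapper
    targets = torch.randint(0, 2, (4,))  # integer class labels for CrossEntropyLoss

    if use_cuda:
        images = images.cuda()
        targets = targets.cuda()

    out = net(images)
    loss = criterion(out, targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # .item() replaces the pre-0.4 loss.data[0] when reading a scalar tensor
    print('iter %d || loss: %.4f' % (iteration, loss.item()))

On the Visdom change: creating viz = visdom.Visdom() at module level, instead of inside train() with a global declaration, gives update_vis_plot a single shared handle regardless of when train() runs, which appears to be what the window-display fix for amdegroot#234 relies on.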