In [None]:
%load_ext autoreload
%autoreload 2
from genienlp.train import *

In [None]:
import argparse

In [None]:
from genienlp.__main__ import *
from genienlp.models.mqan_decoder import *

In [None]:
# Build the top-level `genienlp` argument parser and register one sub-parser
# per entry of `subcommands`. Each entry unpacks into a help string, a callable
# that populates the sub-parser, and a third item (the command function) that
# is not invoked in this cell.
parser = argparse.ArgumentParser(prog='genienlp')
subparsers = parser.add_subparsers(dest='subcommand')
for subcommand, (helpstr, get_parser, command_fn) in subcommands.items():
    get_parser(subparsers.add_parser(subcommand, help=helpstr))

# Command line that is parsed in place of sys.argv: the `train` subcommand,
# followed by flag/value pairs and then the value-less switches.
train_cli = [
    'train',
    '--train_tasks', 'almond',
    '--train_iterations', '1000',
    '--embeddings', 'embeddings_dir',
    '--data', 'local_data',
    '--save', 'dummy_model_dir',
    '--no_commit',
    '--exist_ok',
    '--skip_cache',
    '--lambd', '0.5',
]
argv = parser.parse_args(train_cli)
# subcommands[argv.subcommand][2](argv)

In [5]:
# Setup cell: post-process the parsed arguments, then create every object the
# later cells rely on (`devices`, `logger`, the datasets, `writer`, `model`,
# `opt`, `lr_scheduler`, `start_iteration`, `save_dict`).
args = arguments.post_parse(argv)
if args is None:
    # Everything below dereferences `args`, so fail here with a clear message.
    raise RuntimeError('arguments.post_parse(argv) returned None; cannot set up training')

set_seed(args)
devices = init_devices(args, args.devices)
logger = initialize_logger(args)
logger.info(f'Arguments:\n{pformat(vars(args))}')

# When --load is given, read the checkpoint once up front; the same dict is
# reused further down to restore the model weights.
save_dict = None
if args.load is not None:
    checkpoint_path = os.path.join(args.save, args.load)
    logger.info(f'Loading vocab from {checkpoint_path}')
    save_dict = torch.load(checkpoint_path)
numericalizer, encoder_embeddings, decoder_embeddings, train_sets, val_sets, aux_sets = prepare_data(args, logger)

# Sanity check: auxiliary datasets should be present exactly when curriculum
# training is enabled. `aux_sets` may be None, so test for that before len().
has_aux_data = aux_sets is not None and len(aux_sets) > 0
if (args.use_curriculum and aux_sets is None) or (not args.use_curriculum and has_aux_data):
    # Report through the configured `logger` rather than the root logging module.
    logger.error('sth unpleasant is happening with curriculum')

logger.info(f'Processing')
# Stored on the logger object; elapsed_time(logger) is called with it later.
logger.start = time.time()

# `writer` stays None unless args has a truthy `tensorboard` attribute;
# consumers must check for None before using it.
if hasattr(args, 'tensorboard') and args.tensorboard:
    logger.info(f'Initializing Writer')
    writer = SummaryWriter(log_dir=args.tensorboard_dir)
else:
    writer = None

model = init_model(args, numericalizer, encoder_embeddings, decoder_embeddings, devices, logger)
opt, lr_scheduler = init_opt(args, model, logger)
start_iteration = 1

if save_dict is not None:
    # `save_dict` was already read from this path above, so it is not loaded again.
    logger.info(f'Loading model from {os.path.join(args.save, args.load)}')
    model.load_state_dict(save_dict['model_state_dict'])
    if args.resume:
        # The optimizer state lives next to the checkpoint as `<name>_optim.pth`
        # and also carries the iteration to resume from.
        optim_name = f'{os.path.splitext(args.load)[0]}_optim.pth'
        logger.info(f'Resuming Training from {optim_name}')
        opt_state_dict = torch.load(os.path.join(args.save, optim_name))
        start_iteration = opt_state_dict.pop('start_iteration')
        logger.info(f'Starting iteration is {start_iteration}')
        opt.load_state_dict(opt_state_dict)

genienlp.train - Arguments:
{'aux_dataset': '',
 'baseline': False,
 'beta0': 0.9,
 'cache': '.cache/',
 'commit': '',
 'confidence_mode': 'mean',
 'curriculum_max_frac': 1.0,
 'curriculum_rate': 0.1,
 'curriculum_strategy': 'linear',
 'data': './local_data',
 'decoder_embeddings': 'glove+char',
 'devices': [0],
 'dimension': 200,
 'dist_sync_file': './dummy_model_dir/distributed_sync_file',
 'dropout_ratio': 0.2,
 'embeddings': './embeddings_dir',
 'encoder_embeddings': 'glove+char',
 'exist_ok': True,
 'grad_clip': 1.0,
 'jump_start': 0,
 'lambd': 0.5,
 'load': None,
 'log_dir': './dummy_model_dir',
 'log_every': 100,
 'lower': True,
 'lr_rate': 0.001,
 'max_answer_length': 50,
 'max_generative_vocab': 50000,
 'max_output_length': 100,
 'max_to_keep': 5,
 'max_train_context_length': 500,
 'max_val_context_length': 500,
 'model': 'Seq2Seq',
 'n_jump_start': 0,
 'num_print': 15,
 'optimizer': 'adam',
 'resume': False,
 'rnn_dimension': 200,
 'rnn_layers': 1,
 'rnn_zero_state': 'zero',


genienlp.train - Answer: monitor ( @org.thingpedia.builtin.thingengine.phone.sms ) => @com.gmail.send_email
genienlp.train - Context: send an email to someone when i leave work
genienlp.train - Question: translate from english to thingtalk
genienlp.train - Answer: edge ( monitor ( @org.thingpedia.builtin.thingengine.phone.get_gps ) ) on not param:location:Location == location:work => @com.gmail.send_email
genienlp.train - Context: rain send an email alert
genienlp.train - Question: translate from english to thingtalk
genienlp.train - Answer: edge ( monitor ( @org.thingpedia.weather.current ) ) on param:status:Enum(raining,cloudy,sunny,snowy,sleety,drizzling,windy) == enum:raining => @com.gmail.send_email
genienlp.train - Context: text john on gmail when i am approaching LOCATION_0
genienlp.train - Question: translate from english to thingtalk
genienlp.train - Answer: edge ( monitor ( @org.thingpedia.builtin.thingengine.phone.get_gps ) ) on param:location:Location == LOCATION_0 => @com.

In [6]:
# Copy the loop-control settings out of `args` into notebook-level names.
train_iterations = args.train_iterations
log_every = args.log_every
val_every = args.val_every
save_every = args.save_every

# `rounds` is True only when more than one training set is present.
rounds = len(train_sets) > 1
rnd = 1

# Carry the best validation score over from a loaded checkpoint, if there is one.
if save_dict is not None:
    best_decascore = save_dict.get('best_decascore')
else:
    best_decascore = None

In [8]:
# Running statistics for the training loop; all start at zero, and the global
# iteration counter starts at `start_iteration`.
local_loss = num_examples = len_contexts = len_answers = 0
iteration = start_iteration

# Keep an independent copy of the iteration budget.
train_iter_deep = deepcopy(train_iterations)

# Per-task bookkeeping, keyed by the entries of args.train_tasks.
task_iteration = dict.fromkeys(args.train_tasks, 1)
task_done = dict.fromkeys(args.train_tasks, False)
task_fraction = dict.fromkeys(args.train_tasks, 0.0)

saver = Saver(args.log_dir, args.max_to_keep)
epoch = 0

logger.info('Preparing iterators')
main_device = devices[0]

# Build every training loader first, then wrap each one in an iterator.
train_loaders = [
    (task, make_data_loader(dataset, numericalizer, batch_tokens, main_device, train=True))
    for task, dataset, batch_tokens in zip(args.train_tasks, train_sets, args.train_batch_tokens)
]
train_iters = [(task, iter(loader)) for task, loader in train_loaders]

# Validation loaders are kept as loaders (not wrapped in iter()).
val_iters = [
    (task, make_data_loader(dataset, numericalizer, batch_size, main_device, train=False))
    for task, dataset, batch_size in zip(args.val_tasks, val_sets, args.val_batch_size)
]

# Auxiliary iterators exist only when curriculum training is enabled.
if args.use_curriculum:
    aux_loaders = [
        (name, make_data_loader(dataset, numericalizer, batch_tokens, main_device, train=True))
        for name, dataset, batch_tokens in zip(args.train_tasks, aux_sets, args.train_batch_tokens)
    ]
    aux_iters = [(task, iter(loader)) for task, loader in aux_loaders]

zero_loss = 0
logger.info('Begin Training')

genienlp.train - Preparing iterators
genienlp.train - Begin Training


In [9]:
# Rebind `train_iterations` to the copy saved earlier as `train_iter_deep`.
# NOTE(review): presumably this restores the budget before re-running the
# training-loop cell; it rebinds to the same object rather than a fresh copy.
train_iterations = train_iter_deep

In [10]:
# cnt = 0
# while True:
#     try:
#         cnt += len(next(train_iter).example_id)
#         print(cnt)
#     except:
#         break

In [11]:

# Manual training loop: run 500 passes over `train_iters`, taking one batch per
# task per pass. Each step optionally validates, optionally checkpoints, then
# performs one parameter update and accumulates training statistics.
# Reads and updates notebook-level state prepared by earlier cells (iteration,
# task_iteration, task_fraction, best_decascore, zero_loss, local_loss, ...).
for iter_no in range(500):
    print("Iteration Number: {}".format(iter_no))
    cnt = 0
    for task_idx, (task, train_iter) in enumerate(train_iters):
        print(task_idx)
        task_iterations = train_iterations[task_idx] if train_iterations is not None else None
        if task_iterations == 0:
            continue

        # Skip tasks that have used up their iteration budget.
        if task_iterations is not None and task_iteration[task] > task_iterations:
            task_done[task] = True
            continue

        if args.use_curriculum:
            # Draw from the auxiliary iterator with probability task_fraction[task],
            # otherwise from the regular training iterator.
            aux_iter = aux_iters[task_idx][1]
            prob = np.random.choice(['train', 'aux'], p=[1 - task_fraction[task], task_fraction[task]])
            if prob == 'aux':
                batch = next(aux_iter)
            else:
                assert prob == 'train'
                batch = next(train_iter)

        else:
            batch = next(train_iter)
            # Debug output: running example count accumulated by earlier steps.
            cnt += num_examples
            print(cnt)

        # run only once
        for _ in range(1):
            if not args.resume or iteration > start_iteration:
                task_progress = f'{task_iteration[task]}/{task_iterations}:' if task_iterations is not None else ''
                round_progress = f'round_{rnd}:' if rounds else ''

                # validate
                # (`0 % args.val_every` is 0, so the first test is "iteration is a
                # multiple of val_every")
                deca_score = None
                if (val_every is not None and
                        ((iteration % args.val_every == 0 % args.val_every) or
                         (args.load and iteration == start_iteration + 1))):

                    deca_score = 0
                    for val_task_idx, (val_task, val_iter) in enumerate(val_iters):
                        val_loss, metric_dict = validate(val_task, val_iter, model, logger, numericalizer,
                                                         iteration, num_print=args.num_print, args=args)
                        if val_loss is not None:
                            log_entry = f'{args.timestamp}:{elapsed_time(logger)}:iteration_{iteration}:{round_progress}train_{task.name}:{task_progress}val_{val_task.name}:val_loss{val_loss.item():.4f}:'
                            # `writer` is None when tensorboard logging is disabled.
                            if writer is not None:
                                writer.add_scalar(f'loss/{val_task.name}/val', val_loss.item(), iteration)
                        else:
                            log_entry = f'{args.timestamp}:{elapsed_time(logger)}:iteration_{iteration}:{round_progress}train_{task.name}:{task_progress}val_{val_task.name}:'

                        metric_entry = ''
                        for metric_key, metric_value in metric_dict.items():
                            metric_entry += f'{metric_key}_{metric_value:.2f}:'
                        metric_entry = metric_entry[:-1]

                        # The score sums each validation task's first listed metric.
                        deca_score += metric_dict[val_task.metrics[0]]

                        # val log
                        logger.info(log_entry + metric_entry)
                        if writer is not None:
                            for metric_key, metric_value in metric_dict.items():
                                writer.add_scalar(f'{val_task.name}/{metric_key}/val', metric_value, iteration)
                    if writer is not None:
                        writer.add_scalar('deca/val', deca_score, iteration)
                    logger.info(
                        f'{args.timestamp}:{elapsed_time(logger)}:iteration_{iteration}:{round_progress}train_{task.name}:{task_progress}val_deca:deca_{deca_score:.2f}')

                # saving
                if save_every is not None and (iteration % args.save_every == 0):
                    # A "best" checkpoint is only possible when validation ran in
                    # this same step (deca_score is not None).
                    should_save_best = False
                    if deca_score is not None and (best_decascore is None or best_decascore < deca_score):
                        best_decascore = deca_score
                        should_save_best = True

                    # punch through the nn.DataParallel to access the real model, otherwise we won't be able
                    # to load this model later
                    model_state_dict = model.module.state_dict()
                    model_state_dict = {k: v.cpu() for k, v in model_state_dict.items()}

                    save_model_state_dict = {
                        'model_state_dict': model_state_dict,
                        'best_decascore': best_decascore
                    }
                    save_opt_state_dict = opt.state_dict()
                    save_opt_state_dict.update({'start_iteration': iteration})

                    saver.save(save_model_state_dict, save_opt_state_dict, global_step=iteration)
                    if should_save_best:
                        logger.info(
                            f'{args.timestamp}:{elapsed_time(logger)}:iteration_{iteration}:{round_progress}train_{task.name}:{task_progress}found new best model')
                        torch.save(save_model_state_dict, os.path.join(args.log_dir, 'best.pth'))
                        torch.save(save_opt_state_dict, os.path.join(args.log_dir, 'best_optim.pth'))

                # param update
                loss, grad_norm = step(model, batch, iteration, opt, lr_scheduler=lr_scheduler,
                                       grad_clip=args.grad_clip, logger=logger)
                if loss is None:
                    logger.info(
                        'Encountered NAN loss during training... Continue training ignoring the current batch')
                    # NOTE: this `continue` targets the single-pass `for _` loop, so
                    # the iteration counters at its end are not advanced for this batch.
                    continue
                if loss < 1e-5:
                    zero_loss += 1
                    if zero_loss >= 100:
                        logger.info('Found loss less than 1e-5 for 100 steps, stopping.')
                        raise Exception("DONE!!!!!")
                else:
                    zero_loss = 0

                # update curriculum fraction
                if args.use_curriculum:
                    task_fraction[task] = update_fraction(args, task_iteration[task])

                # train metrics
                local_loss += loss

                # train logs
                num_examples += batch.context.value.size(0)
                len_contexts += batch.context.value.size(1)
                len_answers += batch.answer.value.size(1)

                if log_every is not None and (iteration % log_every == 0 % log_every):
                    # Turn the accumulated sums into per-step averages for logging.
                    local_loss /= args.log_every
                    num_examples /= args.log_every
                    len_contexts /= args.log_every
                    len_answers /= args.log_every
                    avg_batch_size = f'avbatch_{num_examples:.0f}_{len_contexts:.0f}_{len_answers:.0f}:'
                    logger.info(
                        f'{args.timestamp}:{elapsed_time(logger)}:iteration_{iteration}:{round_progress}train_{task.name}:{task_progress}{avg_batch_size}loss_{local_loss:.4f}')

                    if writer is not None:
                        writer.add_scalar(f'loss/{task.name}/train', local_loss, iteration)
                        writer.add_scalar(f'training/loss/{task.name}', local_loss, iteration)

                        if lr_scheduler is not None:
                            # NOTE(review): assumes get_last_lr() yields a value that
                            # add_scalar accepts - confirm for this scheduler.
                            writer.add_scalar(f'training/lr', lr_scheduler.get_last_lr(), iteration)
                        else:
                            # Pass the step so the constant LR is logged against the
                            # same iteration axis as the other scalars.
                            writer.add_scalar(f'training/lr', args.lr_rate, iteration)
                        if grad_norm is not None:
                            writer.add_scalar(f'training/norm', grad_norm, iteration)

                    # Start a fresh accumulation window.
                    local_loss = 0
                    num_examples = 0
                    len_contexts = 0
                    len_answers = 0

            # book keeping
            task_iteration[task] += 1
            iteration += 1

    # book keeping
    epoch += 1
    rnd += 1


Iteration Number: 0
0
0
scores.shape:  torch.Size([120, 16, 50000])
confidence.shape:  torch.Size([120, 16, 1])
Iteration Number: 1
0
120
scores.shape:  torch.Size([94, 20, 50000])
confidence.shape:  torch.Size([94, 20, 1])
Iteration Number: 2
0
214
scores.shape:  torch.Size([138, 14, 50000])
confidence.shape:  torch.Size([138, 14, 1])
Iteration Number: 3
0
352
scores.shape:  torch.Size([78, 24, 50000])
confidence.shape:  torch.Size([78, 24, 1])
Iteration Number: 4
0
430
scores.shape:  torch.Size([94, 20, 50000])
confidence.shape:  torch.Size([94, 20, 1])
Iteration Number: 5
0
524
scores.shape:  torch.Size([150, 13, 50000])
confidence.shape:  torch.Size([150, 13, 1])
Iteration Number: 6
0
674
scores.shape:  torch.Size([128, 15, 50000])
confidence.shape:  torch.Size([128, 15, 1])
Iteration Number: 7
0
802
scores.shape:  torch.Size([120, 16, 50000])
confidence.shape:  torch.Size([120, 16, 1])
Iteration Number: 8
0
922
scores.shape:  torch.Size([94, 20, 50000])
confidence.shape:  torch.Si

KeyboardInterrupt: 