In [1]:
import tensorflow as tf
import numpy as np
from generate_data import CopyTaskData, AssociativeRecallData
from utils import expand, learned_init
from exp3S import Exp3S

import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

  from ._conv import register_converters as _register_converters


In [2]:
class args:
    """Attribute bag for the experiment settings; values are attached as class attributes."""

In [3]:
# ---- Model ----
args.mann = "none"                 # 'none': stacked BasicLSTMCell baseline; 'ntm': NTMCell
args.num_layers = 1                # LSTM layers for the baseline / passed to NTMCell for 'ntm'
args.num_units = 100               # units per LSTM layer
args.use_local_impl = True         # 'ntm' only: use the local ntm.NTMCell instead of tf.contrib's

# ---- NTM memory (only read when mann == 'ntm') ----
args.num_memory_locations = 128
args.memory_size = 20
args.num_read_heads = 1
args.num_write_heads = 1
args.conv_shift_range = 1          # passed to NTMCell as shift_range
args.clip_value = 20               # passed to NTMCell as clip_value
args.init_mode = "random"          # local NTMCell only

# ---- Optimisation ----
args.optimizer = "RMSProp"         # 'RMSProp' or 'Adam'
args.learning_rate = 1e-4
args.max_grad_norm = 50            # threshold for tf.clip_by_global_norm
args.num_train_steps = 1250
args.batch_size = 32

# ---- Task / data ----
args.task = "copy"                 # 'copy' or 'associative_recall'
args.num_bits_per_vector = 8
args.max_seq_len = 20
args.pad_to_max_seq_len = False
args.curriculum = "none"           # 'none', 'prediction_gain', or any other value for the stepwise schedule

# ---- Evaluation / logging ----
args.eval_batch_size = 640         # examples per evaluation; keep a multiple of 4 * batch_size
args.steps_per_eval = 200
args.verbose = True                # also pickles sample inputs/labels/outputs under head_logs/
args.experiment_name = "Experiment_2"

# ---- Reproducibility ----
# Seed NumPy (presumably what generate_data draws from -- TODO confirm) and TensorFlow's
# graph-level seed. This cell runs before any graph op is created, so the TF seed applies.
args.seed = 42
np.random.seed(args.seed)
tf.set_random_seed(args.seed)

In [4]:
# NTMCell is only referenced by BuildModel when args.mann == 'ntm'; pick the repository's
# local implementation or the tf.contrib one depending on args.use_local_impl.
if args.mann == 'ntm' and args.use_local_impl:
    from ntm import NTMCell
elif args.mann == 'ntm':
    from tensorflow.contrib.rnn.python.ops.rnn_cell import NTMCell


In [5]:
class BuildModel(object):
    """Build the recurrent network shared by the training and evaluation graphs.

    All hyper-parameters are read from the module-level ``args`` namespace.

    Args:
        max_seq_len: sequence length (fed as a scalar int placeholder by the
            callers in this notebook); used to locate the answer part of the
            RNN outputs.
        inputs: float tensor passed to ``tf.nn.dynamic_rnn`` with
            ``time_major=False``.
        mode: a ``tf.contrib.learn.ModeKeys`` value; stored on the instance but
            not otherwise used by this class.

    Attributes set by ``_build_model``:
        output_logits: RNN outputs for the answer part of each sequence.
        outputs: ``tf.sigmoid(output_logits)``.

    Raises:
        ValueError: if ``args.mann`` or ``args.task`` is not supported.
    """
    def __init__(self, max_seq_len, inputs, mode):
        self.max_seq_len = max_seq_len
        self.inputs = inputs
        self.mode = mode
        self._build_model()

    @staticmethod
    def _single_cell(num_units):
        # One LSTM layer; used for the baseline and for the tf.contrib NTM controller.
        return tf.contrib.rnn.BasicLSTMCell(num_units, forget_bias=1.0)

    def _build_model(self):
        # Fail fast with a clear message instead of a NameError/AttributeError later on.
        if args.mann not in ('none', 'ntm'):
            raise ValueError("Unsupported args.mann: {0!r} (expected 'none' or 'ntm')".format(args.mann))
        if args.task not in ('copy', 'associative_recall'):
            raise ValueError(
                "Unsupported args.task: {0!r} (expected 'copy' or 'associative_recall')".format(args.task))

        # None makes dynamic_rnn start from the cell's zero state (dtype is given below).
        initial_state = None

        if args.mann == 'none':
            cell = tf.contrib.rnn.OutputProjectionWrapper(
                tf.contrib.rnn.MultiRNNCell(
                    [self._single_cell(args.num_units) for _ in range(args.num_layers)]),
                args.num_bits_per_vector,
                activation=None)

            # Per-layer initial (c, h) built from learned_init and expanded to batch_size
            # along dim 0 (see utils.learned_init / utils.expand).
            initial_state = tuple(tf.contrib.rnn.LSTMStateTuple(
                c=expand(tf.tanh(learned_init(args.num_units)), dim=0, N=args.batch_size),
                h=expand(tf.tanh(learned_init(args.num_units)), dim=0, N=args.batch_size))
                for _ in range(args.num_layers))

        elif args.mann == 'ntm':
            if args.use_local_impl:
                cell = NTMCell(args.num_layers, args.num_units, args.num_memory_locations, args.memory_size,
                    args.num_read_heads, args.num_write_heads, addressing_mode='content_and_location',
                    shift_range=args.conv_shift_range, reuse=False, output_dim=args.num_bits_per_vector,
                    clip_value=args.clip_value, init_mode=args.init_mode)
            else:
                controller = tf.contrib.rnn.MultiRNNCell(
                    [self._single_cell(args.num_units) for _ in range(args.num_layers)])

                cell = NTMCell(controller, args.num_memory_locations, args.memory_size,
                    args.num_read_heads, args.num_write_heads, shift_range=args.conv_shift_range,
                    reuse=False, output_dim=args.num_bits_per_vector,
                    clip_value=args.clip_value)

        output_sequence, _ = tf.nn.dynamic_rnn(
            cell=cell,
            inputs=self.inputs,
            time_major=False,
            dtype=tf.float32,
            initial_state=initial_state)

        # Keep only the time steps after the presentation phase; the offsets presumably
        # match the input layout produced by generate_data (TODO confirm there).
        if args.task == 'copy':
            self.output_logits = output_sequence[:, self.max_seq_len+1:, :]
        else:  # 'associative_recall' (validated above)
            self.output_logits = output_sequence[:, 3*(self.max_seq_len+1)+2:, :]

        self.outputs = tf.sigmoid(self.output_logits)

class BuildTrainModel(BuildModel):
    """Training graph: the BuildModel network plus a loss and a clipped-gradient update op.

    Args:
        max_seq_len, inputs: forwarded to ``BuildModel``.
        outputs: target tensor compared against ``self.output_logits``.

    Attributes set:
        loss: sigmoid cross-entropy summed over all elements, divided by ``args.batch_size``.
        train_op: applies gradients of ``loss`` w.r.t. all trainable variables after
            clipping them to global norm ``args.max_grad_norm``.

    Raises:
        ValueError: for an unsupported ``args.task`` or ``args.optimizer``.
    """
    def __init__(self, max_seq_len, inputs, outputs):
        super(BuildTrainModel, self).__init__(max_seq_len, inputs, tf.contrib.learn.ModeKeys.TRAIN)

        if args.task in ('copy', 'associative_recall'):
            cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(labels=outputs, logits=self.output_logits)
            # Summed over every element, then divided by the batch size: per-example total loss.
            self.loss = tf.reduce_sum(cross_entropy)/args.batch_size
        else:
            raise ValueError(
                "Unsupported args.task: {0!r} (expected 'copy' or 'associative_recall')".format(args.task))

        if args.optimizer == 'RMSProp':
            optimizer = tf.train.RMSPropOptimizer(args.learning_rate, momentum=0.9, decay=0.9)
        elif args.optimizer == 'Adam':
            optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        else:
            # Previously fell through and failed with an UnboundLocalError on `optimizer`.
            raise ValueError(
                "Unsupported args.optimizer: {0!r} (expected 'RMSProp' or 'Adam')".format(args.optimizer))

        trainable_variables = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(self.loss, trainable_variables), args.max_grad_norm)
        self.train_op = optimizer.apply_gradients(list(zip(grads, trainable_variables)))

class BuildEvalModel(BuildModel):
    """Evaluation graph: the BuildModel network plus a loss, without any update op.

    ``loss`` is the sigmoid cross-entropy summed over all elements and divided by
    ``args.batch_size``, the same normalisation BuildTrainModel uses.
    """
    def __init__(self, max_seq_len, inputs, outputs):
        super(BuildEvalModel, self).__init__(max_seq_len, inputs, tf.contrib.learn.ModeKeys.EVAL)

        supported_task = args.task in ('copy', 'associative_recall')
        if supported_task:
            elementwise_xent = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=outputs, logits=self.output_logits)
            total_xent = tf.reduce_sum(elementwise_xent)
            self.loss = total_xent / args.batch_size


In [6]:
# When verbose, evaluation samples (inputs, labels, outputs) are pickled into these files.
if args.verbose:
    import os
    import pickle
    HEAD_LOG_DIR = 'head_logs'
    HEAD_LOG_FILE = os.path.join(HEAD_LOG_DIR, '{0}.p'.format(args.experiment_name))
    GENERALIZATION_HEAD_LOG_FILE = os.path.join(HEAD_LOG_DIR, 'generalization_{0}.p'.format(args.experiment_name))
    # open(..., "wb") does not create missing directories, so ensure it exists
    # (isdir/makedirs rather than exist_ok= to stay Python 2 compatible).
    if not os.path.isdir(HEAD_LOG_DIR):
        os.makedirs(HEAD_LOG_DIR)


In [7]:
def _make_placeholders():
    """Return (max_seq_len, inputs, outputs) placeholders for one model instance.

    Inputs have num_bits_per_vector + 1 channels (the extra one is presumably a
    delimiter flag -- TODO confirm in generate_data); targets have num_bits_per_vector.
    """
    seq_len_ph = tf.placeholder(tf.int32)
    inputs_ph = tf.placeholder(tf.float32, shape=(args.batch_size, None, args.num_bits_per_vector+1))
    targets_ph = tf.placeholder(tf.float32, shape=(args.batch_size, None, args.num_bits_per_vector))
    return seq_len_ph, inputs_ph, targets_ph

# Training graph: creates the variables, so the initializer is built right after it.
with tf.variable_scope('root'):
    train_max_seq_len, train_inputs, train_outputs = _make_placeholders()
    train_model = BuildTrainModel(train_max_seq_len, train_inputs, train_outputs)
    initializer = tf.global_variables_initializer()

# Evaluation graph: reuse=True shares the training variables; only the placeholders are new.
with tf.variable_scope('root', reuse=True):
    eval_max_seq_len, eval_inputs, eval_outputs = _make_placeholders()
    eval_model = BuildEvalModel(eval_max_seq_len, eval_inputs, eval_outputs)

# ---- Training setup ----

# Summary bookkeeping: the eval step at which each task's error first fell below
# convergence_error, the best error seen after that, and the generalization errors
# measured when that best error was recorded.
convergence_on_target_task = None
convergence_on_multi_task = None
performance_on_target_task = None
performance_on_multi_task = None
generalization_from_target_task = None
generalization_from_multi_task = None

if args.task == 'copy':
    data_generator = CopyTaskData()
    target_point = args.max_seq_len
    # Stepwise curricula start at 1; 'none' and 'prediction_gain' use the target directly.
    curriculum_point = 1 if args.curriculum not in ('prediction_gain', 'none') else target_point
    progress_error = 1.0      # the curriculum point advances when its eval error drops below this
    convergence_error = 0.1   # a task counts as converged when its eval error drops below this

    if args.curriculum == 'prediction_gain':
        # The training loop maps a drawn arm to curriculum point 1 + arm.
        exp3s = Exp3S(args.max_seq_len, 0.001, 0, 0.05)
elif args.task == 'associative_recall':
    data_generator = AssociativeRecallData()
    target_point = args.max_seq_len
    # Stepwise curricula start at 2 for this task.
    curriculum_point = 2 if args.curriculum not in ('prediction_gain', 'none') else target_point
    progress_error = 1.0
    convergence_error = 0.1

    if args.curriculum == 'prediction_gain':
        # The training loop maps a drawn arm to curriculum point 2 + arm, hence max_seq_len - 1 arms.
        exp3s = Exp3S(args.max_seq_len-1, 0.001, 0, 0.05)
else:
    raise ValueError(
        "Unsupported args.task: {0!r} (expected 'copy' or 'associative_recall')".format(args.task))

sess = tf.Session()
sess.run(initializer)

if args.verbose:
    # Reset this experiment's head logs; run_eval appends samples to them later.
    # `with` closes (and flushes) the files instead of leaking the handles.
    with open(HEAD_LOG_FILE, "wb") as head_log:
        pickle.dump({target_point: []}, head_log)
    with open(GENERALIZATION_HEAD_LOG_FILE, "wb") as head_log:
        pickle.dump({}, head_log)

def run_eval(batches, store_heat_maps=False, generalization_num=None):
    """Run ``eval_model`` over pre-generated batches and average the results.

    Args:
        batches: non-empty sequence of ``(seq_len, inputs, labels)`` tuples.
        store_heat_maps: if True, append the first example (inputs, labels, outputs)
            of the LAST batch to a pickled head log.
        generalization_num: None -> append to HEAD_LOG_FILE under ``target_point``;
            otherwise append to GENERALIZATION_HEAD_LOG_FILE under this key.

    Returns:
        (mean loss per batch, mean ``data_generator.error_per_seq`` per batch).

    Raises:
        ValueError: if ``batches`` is empty.
    """
    num_batches = len(batches)
    if num_batches == 0:
        # Previously a ZeroDivisionError (or NameError when storing heat maps).
        raise ValueError('run_eval needs at least one batch')

    task_loss = 0
    task_error = 0
    for seq_len, inputs, labels in batches:
        task_loss_, outputs = sess.run([eval_model.loss, eval_model.outputs],
            feed_dict={
                eval_inputs: inputs,
                eval_outputs: labels,
                eval_max_seq_len: seq_len
            })

        task_loss += task_loss_
        task_error += data_generator.error_per_seq(labels, outputs, args.batch_size)

    if store_heat_maps:
        # Only the first example of the last evaluated batch is recorded.
        sample = {
            'labels': labels[0],
            'outputs': outputs[0],
            'inputs': inputs[0]
        }
        if generalization_num is None:
            log_file, key = HEAD_LOG_FILE, target_point
        else:
            log_file, key = GENERALIZATION_HEAD_LOG_FILE, generalization_num

        # Read-modify-write of a file this notebook itself created; `with` closes the handles.
        with open(log_file, "rb") as f:
            head_log = pickle.load(f)
        head_log.setdefault(key, []).append(sample)
        with open(log_file, "wb") as f:
            pickle.dump(head_log, f)

    task_loss /= float(num_batches)
    task_error /= float(num_batches)
    return task_loss, task_error

def eval_performance(curriculum_point, store_heat_maps=False):
    """Evaluate on the target task, a mix of lengths, and (optionally) the curriculum point.

    The evaluation budget ``args.eval_batch_size`` is split as: half for the target
    task, all of it for the multi-task mix, a quarter for the curriculum point.

    Args:
        curriculum_point: point to evaluate separately, or None to skip that evaluation.
        store_heat_maps: forwarded to ``run_eval`` for the target-task batches only.

    Returns:
        (target_task_error, target_task_loss, multi_task_error, multi_task_loss,
         curriculum_point_error, curriculum_point_loss); the last two are None when
        ``curriculum_point`` is None.
    """
    def _generate(num_examples, point, curriculum):
        # Floor division keeps the batch count an int on Python 3, where `/` returns a
        # float; for positive ints it equals Python 2's `/`.
        return data_generator.generate_batches(
            num_examples // args.batch_size,
            args.batch_size,
            bits_per_vector=args.num_bits_per_vector,
            curriculum_point=point,
            max_seq_len=args.max_seq_len,
            curriculum=curriculum,
            pad_to_max_seq_len=args.pad_to_max_seq_len
        )

    # Target task (curriculum 'none').
    batches = _generate(args.eval_batch_size // 2, None, 'none')
    target_task_loss, target_task_error = run_eval(batches, store_heat_maps=store_heat_maps)

    # Multi-task: 'deterministic_uniform' presumably spreads examples over lengths
    # up to max_seq_len (TODO confirm in generate_data).
    batches = _generate(args.eval_batch_size, None, 'deterministic_uniform')
    multi_task_loss, multi_task_error = run_eval(batches)

    # Current curriculum point.
    if curriculum_point is not None:
        batches = _generate(args.eval_batch_size // 4, curriculum_point, 'naive')
        curriculum_point_loss, curriculum_point_error = run_eval(batches)
    else:
        curriculum_point_error = curriculum_point_loss = None

    return target_task_error, target_task_loss, multi_task_error, multi_task_loss, curriculum_point_error, curriculum_point_loss

def eval_generalization(num_batches=6):
    """Measure the error on a fixed list of sequence lengths per task.

    Lengths: copy -> 40, 60, 80, 100, 120; associative_recall -> 7..12.
    When ``args.verbose`` is set, one sample per length is appended to the
    generalization head log (keyed by the length).

    Args:
        num_batches: number of ``args.batch_size`` batches generated per length.

    Returns:
        list of mean errors, one per length, in the order listed above.

    Raises:
        ValueError: for an unsupported ``args.task``.
    """
    if args.task == 'copy':
        seq_lens = [40, 60, 80, 100, 120]
    elif args.task == 'associative_recall':
        seq_lens = [7, 8, 9, 10, 11, 12]
    else:
        # Previously an UnboundLocalError on `seq_lens`.
        raise ValueError(
            "Unsupported args.task: {0!r} (expected 'copy' or 'associative_recall')".format(args.task))

    errors = []
    for seq_len in seq_lens:
        batches = data_generator.generate_batches(
            num_batches,
            args.batch_size,
            bits_per_vector=args.num_bits_per_vector,
            curriculum_point=seq_len,
            max_seq_len=args.max_seq_len,
            curriculum='naive',
            pad_to_max_seq_len=False
        )

        _, error = run_eval(batches, store_heat_maps=args.verbose, generalization_num=seq_len)
        errors.append(error)
    return errors

for i in range(args.num_train_steps):
    # Point trained on this step: an Exp3S draw under prediction_gain, else the schedule.
    if args.curriculum == 'prediction_gain':
        # Arm -> point offsets match the Exp3S arm counts chosen in the setup cell
        # (draw_task presumably returns a 0-based arm index).
        if args.task == 'copy':
            task = 1 + exp3s.draw_task()
        elif args.task == 'associative_recall':
            task = 2 + exp3s.draw_task()
        train_point = task
    else:
        train_point = curriculum_point

    seq_len, inputs, labels = data_generator.generate_batches(
        1,
        args.batch_size,
        bits_per_vector=args.num_bits_per_vector,
        curriculum_point=train_point,
        max_seq_len=args.max_seq_len,
        curriculum=args.curriculum,
        pad_to_max_seq_len=args.pad_to_max_seq_len
    )[0]

    train_loss, _, outputs = sess.run([train_model.loss, train_model.train_op, train_model.outputs],
        feed_dict={
            train_inputs: inputs,
            train_outputs: labels,
            train_max_seq_len: seq_len
        })

    if args.curriculum == 'prediction_gain':
        # Prediction gain: training loss on this batch minus the loss on the same batch
        # after the update (eval_model shares the trained variables).
        loss, _ = run_eval([(seq_len, inputs, labels)])
        v = train_loss - loss
        exp3s.update_w(v, seq_len)

    avg_errors_per_seq = data_generator.error_per_seq(labels, outputs, args.batch_size)

    if args.verbose:
        # Log the point actually trained on (differs from curriculum_point under prediction_gain).
        logger.info('Train loss ({0}): {1}'.format(i, train_loss))
        logger.info('curriculum_point: {0}'.format(train_point))
        logger.info('Average errors/sequence: {0}'.format(avg_errors_per_seq))
        logger.info('TRAIN_PARSABLE: {0},{1},{2},{3}'.format(i, train_point, train_loss, avg_errors_per_seq))

    if i % args.steps_per_eval == 0:
        target_task_error, target_task_loss, multi_task_error, multi_task_loss, curriculum_point_error, \
        curriculum_point_loss = eval_performance(curriculum_point if args.curriculum != 'prediction_gain' else None, store_heat_maps=args.verbose)

        if convergence_on_multi_task is None and multi_task_error < convergence_error:
            convergence_on_multi_task = i

        if convergence_on_target_task is None and target_task_error < convergence_error:
            convergence_on_target_task = i

        # After convergence, re-measure generalization whenever the eval error improves;
        # reuse the multi-task measurement if it was just taken on the same weights.
        gen_evaled = False
        if convergence_on_multi_task is not None and (performance_on_multi_task is None or multi_task_error < performance_on_multi_task):
            performance_on_multi_task = multi_task_error
            generalization_from_multi_task = eval_generalization()
            gen_evaled = True

        if convergence_on_target_task is not None and (performance_on_target_task is None or target_task_error < performance_on_target_task):
            performance_on_target_task = target_task_error
            if gen_evaled:
                generalization_from_target_task = generalization_from_multi_task
            else:
                generalization_from_target_task = eval_generalization()

        # Remember which point the error below belongs to before the curriculum moves on.
        evaluated_point = curriculum_point

        # curriculum_point_error is None when no curriculum point was evaluated
        # (prediction_gain); `None < float` raises TypeError on Python 3.
        if curriculum_point_error is not None and curriculum_point_error < progress_error:
            if args.task == 'copy':
                curriculum_point = min(target_point, 2 * curriculum_point)
            elif args.task == 'associative_recall':
                curriculum_point = min(target_point, curriculum_point+1)

        logger.info('----EVAL----')
        logger.info('target task error/loss: {0},{1}'.format(target_task_error, target_task_loss))
        logger.info('multi task error/loss: {0},{1}'.format(multi_task_error, multi_task_loss))
        logger.info('curriculum point error/loss ({0}): {1},{2}'.format(evaluated_point, curriculum_point_error, curriculum_point_loss))
        logger.info('EVAL_PARSABLE: {0},{1},{2},{3},{4},{5},{6},{7}'.format(i, target_task_error, target_task_loss,
            multi_task_error, multi_task_loss, evaluated_point, curriculum_point_error, curriculum_point_loss))

# ---- Summary ----
# A task that never converged falls back to its error from the last evaluation,
# and its generalization is measured on the final weights.
if convergence_on_multi_task is None:
    performance_on_multi_task = multi_task_error
    generalization_from_multi_task = eval_generalization()

if convergence_on_target_task is None:
    performance_on_target_task = target_task_error
    generalization_from_target_task = eval_generalization()


def _join_errors(errors):
    """Comma-join per-length generalization errors; None passes through unchanged."""
    if errors is None:
        return None
    return ','.join(str(err) for err in errors)


summary_values = (convergence_on_target_task, performance_on_target_task,
                  convergence_on_multi_task, performance_on_multi_task)

logger.info('----SUMMARY----')
logger.info('convergence_on_target_task: {0}'.format(summary_values[0]))
logger.info('performance_on_target_task: {0}'.format(summary_values[1]))
logger.info('convergence_on_multi_task: {0}'.format(summary_values[2]))
logger.info('performance_on_multi_task: {0}'.format(summary_values[3]))

logger.info('SUMMARY_PARSABLE: {0},{1},{2},{3}'.format(*summary_values))

logger.info('generalization_from_target_task: {0}'.format(_join_errors(generalization_from_target_task)))
logger.info('generalization_from_multi_task: {0}'.format(_join_errors(generalization_from_multi_task)))



INFO:__main__:Train loss (0): 111.03302001953125
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 78.8125
INFO:__main__:TRAIN_PARSABLE: 0,20,111.03302001953125,78.8125
INFO:__main__:----EVAL----
INFO:__main__:target task error/loss: 80.253125,111.04062423706054
INFO:__main__:multi task error/loss: 42.0484375,58.32262291908264
INFO:__main__:curriculum point error/loss (20): 80.20625,111.02677001953126
INFO:__main__:EVAL_PARSABLE: 0,80.253125,111.04062423706054,42.0484375,58.32262291908264,20,80.20625,111.02677001953126
INFO:__main__:Train loss (1): 111.18197631835938
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 79.875
INFO:__main__:TRAIN_PARSABLE: 1,20,111.18197631835938,79.875
INFO:__main__:Train loss (2): 111.02882385253906
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 79.90625
INFO:__main__:TRAIN_PARSABLE: 2,20,111.02882385253906,79.90625
INFO:__main__:Train loss (3): 111.06268310546875
INFO:__main__:cu

INFO:__main__:TRAIN_PARSABLE: 40,20,110.9383773803711,78.6875
INFO:__main__:Train loss (41): 110.9173583984375
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 80.34375
INFO:__main__:TRAIN_PARSABLE: 41,20,110.9173583984375,80.34375
INFO:__main__:Train loss (42): 110.917236328125
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 78.59375
INFO:__main__:TRAIN_PARSABLE: 42,20,110.917236328125,78.59375
INFO:__main__:Train loss (43): 110.94880676269531
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 79.96875
INFO:__main__:TRAIN_PARSABLE: 43,20,110.94880676269531,79.96875
INFO:__main__:Train loss (44): 111.00018310546875
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 81.34375
INFO:__main__:TRAIN_PARSABLE: 44,20,111.00018310546875,81.34375
INFO:__main__:Train loss (45): 110.98100280761719
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 80.90625
INFO:__main__:TRAIN_PAR

INFO:__main__:Train loss (83): 110.87553405761719
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 78.34375
INFO:__main__:TRAIN_PARSABLE: 83,20,110.87553405761719,78.34375
INFO:__main__:Train loss (84): 110.75808715820312
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 75.6875
INFO:__main__:TRAIN_PARSABLE: 84,20,110.75808715820312,75.6875
INFO:__main__:Train loss (85): 110.79579162597656
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 75.5625
INFO:__main__:TRAIN_PARSABLE: 85,20,110.79579162597656,75.5625
INFO:__main__:Train loss (86): 110.76866149902344
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 75.3125
INFO:__main__:TRAIN_PARSABLE: 86,20,110.76866149902344,75.3125
INFO:__main__:Train loss (87): 110.84573364257812
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 75.90625
INFO:__main__:TRAIN_PARSABLE: 87,20,110.84573364257812,75.90625
INFO:__main__:Train l

INFO:__main__:Train loss (125): 109.28237915039062
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 70.8125
INFO:__main__:TRAIN_PARSABLE: 125,20,109.28237915039062,70.8125
INFO:__main__:Train loss (126): 108.8916015625
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 69.625
INFO:__main__:TRAIN_PARSABLE: 126,20,108.8916015625,69.625
INFO:__main__:Train loss (127): 109.04425048828125
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 69.5
INFO:__main__:TRAIN_PARSABLE: 127,20,109.04425048828125,69.5
INFO:__main__:Train loss (128): 108.95549774169922
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 71.21875
INFO:__main__:TRAIN_PARSABLE: 128,20,108.95549774169922,71.21875
INFO:__main__:Train loss (129): 108.94099426269531
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 70.8125
INFO:__main__:TRAIN_PARSABLE: 129,20,108.94099426269531,70.8125
INFO:__main__:Train loss (130

INFO:__main__:Train loss (167): 107.51362609863281
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 67.15625
INFO:__main__:TRAIN_PARSABLE: 167,20,107.51362609863281,67.15625
INFO:__main__:Train loss (168): 106.80117797851562
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 66.40625
INFO:__main__:TRAIN_PARSABLE: 168,20,106.80117797851562,66.40625
INFO:__main__:Train loss (169): 106.57687377929688
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 65.375
INFO:__main__:TRAIN_PARSABLE: 169,20,106.57687377929688,65.375
INFO:__main__:Train loss (170): 107.36225128173828
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 67.03125
INFO:__main__:TRAIN_PARSABLE: 170,20,107.36225128173828,67.03125
INFO:__main__:Train loss (171): 107.32734680175781
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 65.9375
INFO:__main__:TRAIN_PARSABLE: 171,20,107.32734680175781,65.9375
INFO:__main

INFO:__main__:Train loss (207): 106.25008392333984
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 64.25
INFO:__main__:TRAIN_PARSABLE: 207,20,106.25008392333984,64.25
INFO:__main__:Train loss (208): 106.2304916381836
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 64.9375
INFO:__main__:TRAIN_PARSABLE: 208,20,106.2304916381836,64.9375
INFO:__main__:Train loss (209): 106.14170837402344
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 64.65625
INFO:__main__:TRAIN_PARSABLE: 209,20,106.14170837402344,64.65625
INFO:__main__:Train loss (210): 106.27791595458984
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 65.21875
INFO:__main__:TRAIN_PARSABLE: 210,20,106.27791595458984,65.21875
INFO:__main__:Train loss (211): 106.03932189941406
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 64.53125
INFO:__main__:TRAIN_PARSABLE: 211,20,106.03932189941406,64.53125
INFO:__main__:T

INFO:__main__:TRAIN_PARSABLE: 248,20,105.43087768554688,63.4375
INFO:__main__:Train loss (249): 104.80952453613281
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 62.875
INFO:__main__:TRAIN_PARSABLE: 249,20,104.80952453613281,62.875
INFO:__main__:Train loss (250): 104.97176361083984
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 62.71875
INFO:__main__:TRAIN_PARSABLE: 250,20,104.97176361083984,62.71875
INFO:__main__:Train loss (251): 104.69609069824219
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 62.53125
INFO:__main__:TRAIN_PARSABLE: 251,20,104.69609069824219,62.53125
INFO:__main__:Train loss (252): 105.3134994506836
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 61.5625
INFO:__main__:TRAIN_PARSABLE: 252,20,105.3134994506836,61.5625
INFO:__main__:Train loss (253): 104.92444610595703
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 62.40625
INFO:__main__:

INFO:__main__:Average errors/sequence: 64.3125
INFO:__main__:TRAIN_PARSABLE: 290,20,105.22093200683594,64.3125
INFO:__main__:Train loss (291): 104.19593811035156
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 61.625
INFO:__main__:TRAIN_PARSABLE: 291,20,104.19593811035156,61.625
INFO:__main__:Train loss (292): 105.03312683105469
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 62.28125
INFO:__main__:TRAIN_PARSABLE: 292,20,105.03312683105469,62.28125
INFO:__main__:Train loss (293): 104.71163177490234
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 61.9375
INFO:__main__:TRAIN_PARSABLE: 293,20,104.71163177490234,61.9375
INFO:__main__:Train loss (294): 104.77058410644531
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 61.875
INFO:__main__:TRAIN_PARSABLE: 294,20,104.77058410644531,61.875
INFO:__main__:Train loss (295): 104.7260513305664
INFO:__main__:curriculum_point: 20
INFO:__main__:Aver

INFO:__main__:Train loss (332): 103.71977233886719
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 59.59375
INFO:__main__:TRAIN_PARSABLE: 332,20,103.71977233886719,59.59375
INFO:__main__:Train loss (333): 103.53849792480469
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 59.9375
INFO:__main__:TRAIN_PARSABLE: 333,20,103.53849792480469,59.9375
INFO:__main__:Train loss (334): 103.49928283691406
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 60.5625
INFO:__main__:TRAIN_PARSABLE: 334,20,103.49928283691406,60.5625
INFO:__main__:Train loss (335): 103.9993667602539
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 61.21875
INFO:__main__:TRAIN_PARSABLE: 335,20,103.9993667602539,61.21875
INFO:__main__:Train loss (336): 103.82443237304688
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 60.8125
INFO:__main__:TRAIN_PARSABLE: 336,20,103.82443237304688,60.8125
INFO:__main__

INFO:__main__:TRAIN_PARSABLE: 373,20,103.34906005859375,61.09375
INFO:__main__:Train loss (374): 103.53219604492188
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 60.625
INFO:__main__:TRAIN_PARSABLE: 374,20,103.53219604492188,60.625
INFO:__main__:Train loss (375): 104.20697021484375
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 61.75
INFO:__main__:TRAIN_PARSABLE: 375,20,104.20697021484375,61.75
INFO:__main__:Train loss (376): 102.93213653564453
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 59.40625
INFO:__main__:TRAIN_PARSABLE: 376,20,102.93213653564453,59.40625
INFO:__main__:Train loss (377): 103.94914245605469
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 61.03125
INFO:__main__:TRAIN_PARSABLE: 377,20,103.94914245605469,61.03125
INFO:__main__:Train loss (378): 102.67216491699219
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 60.5625
INFO:__main__:TR

INFO:__main__:TRAIN_PARSABLE: 413,20,103.11320495605469,60.3125
INFO:__main__:Train loss (414): 102.31575012207031
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 59.9375
INFO:__main__:TRAIN_PARSABLE: 414,20,102.31575012207031,59.9375
INFO:__main__:Train loss (415): 102.84172058105469
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 59.28125
INFO:__main__:TRAIN_PARSABLE: 415,20,102.84172058105469,59.28125
INFO:__main__:Train loss (416): 103.34873962402344
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 59.90625
INFO:__main__:TRAIN_PARSABLE: 416,20,103.34873962402344,59.90625
INFO:__main__:Train loss (417): 103.02659606933594
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 60.0625
INFO:__main__:TRAIN_PARSABLE: 417,20,103.02659606933594,60.0625
INFO:__main__:Train loss (418): 102.267333984375
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 59.65625
INFO:__main_

INFO:__main__:TRAIN_PARSABLE: 455,20,102.39046478271484,58.96875
INFO:__main__:Train loss (456): 101.64408874511719
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 59.25
INFO:__main__:TRAIN_PARSABLE: 456,20,101.64408874511719,59.25
INFO:__main__:Train loss (457): 102.17233276367188
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 59.34375
INFO:__main__:TRAIN_PARSABLE: 457,20,102.17233276367188,59.34375
INFO:__main__:Train loss (458): 102.21296691894531
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 59.125
INFO:__main__:TRAIN_PARSABLE: 458,20,102.21296691894531,59.125
INFO:__main__:Train loss (459): 103.30081939697266
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 59.71875
INFO:__main__:TRAIN_PARSABLE: 459,20,103.30081939697266,59.71875
INFO:__main__:Train loss (460): 102.97230529785156
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 60.84375
INFO:__main__:T

INFO:__main__:TRAIN_PARSABLE: 497,20,102.30146789550781,58.875
INFO:__main__:Train loss (498): 102.28810119628906
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 58.40625
INFO:__main__:TRAIN_PARSABLE: 498,20,102.28810119628906,58.40625
INFO:__main__:Train loss (499): 101.30804443359375
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 58.65625
INFO:__main__:TRAIN_PARSABLE: 499,20,101.30804443359375,58.65625
INFO:__main__:Train loss (500): 101.43196105957031
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 58.03125
INFO:__main__:TRAIN_PARSABLE: 500,20,101.43196105957031,58.03125
INFO:__main__:Train loss (501): 102.48474884033203
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 58.90625
INFO:__main__:TRAIN_PARSABLE: 501,20,102.48474884033203,58.90625
INFO:__main__:Train loss (502): 102.01177978515625
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 58.625
INFO:__ma

INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 58.84375
INFO:__main__:TRAIN_PARSABLE: 539,20,101.64347839355469,58.84375
INFO:__main__:Train loss (540): 100.82012939453125
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 57.4375
INFO:__main__:TRAIN_PARSABLE: 540,20,100.82012939453125,57.4375
INFO:__main__:Train loss (541): 101.88375091552734
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 58.8125
INFO:__main__:TRAIN_PARSABLE: 541,20,101.88375091552734,58.8125
INFO:__main__:Train loss (542): 100.17449188232422
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 56.59375
INFO:__main__:TRAIN_PARSABLE: 542,20,100.17449188232422,56.59375
INFO:__main__:Train loss (543): 100.09344482421875
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 56.5
INFO:__main__:TRAIN_PARSABLE: 543,20,100.09344482421875,56.5
INFO:__main__:Train loss (544): 99.47718811035156
INFO:__main__:curr

INFO:__main__:Train loss (581): 101.27153778076172
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 57.75
INFO:__main__:TRAIN_PARSABLE: 581,20,101.27153778076172,57.75
INFO:__main__:Train loss (582): 101.30694580078125
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 58.09375
INFO:__main__:TRAIN_PARSABLE: 582,20,101.30694580078125,58.09375
INFO:__main__:Train loss (583): 100.47407531738281
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 56.65625
INFO:__main__:TRAIN_PARSABLE: 583,20,100.47407531738281,56.65625
INFO:__main__:Train loss (584): 100.95751190185547
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 57.125
INFO:__main__:TRAIN_PARSABLE: 584,20,100.95751190185547,57.125
INFO:__main__:Train loss (585): 100.48200225830078
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 57.15625
INFO:__main__:TRAIN_PARSABLE: 585,20,100.48200225830078,57.15625
INFO:__main__:T

INFO:__main__:Train loss (621): 100.22294616699219
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 56.3125
INFO:__main__:TRAIN_PARSABLE: 621,20,100.22294616699219,56.3125
INFO:__main__:Train loss (622): 100.33104705810547
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 57.375
INFO:__main__:TRAIN_PARSABLE: 622,20,100.33104705810547,57.375
INFO:__main__:Train loss (623): 100.21179962158203
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 55.90625
INFO:__main__:TRAIN_PARSABLE: 623,20,100.21179962158203,55.90625
INFO:__main__:Train loss (624): 100.40414428710938
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 57.53125
INFO:__main__:TRAIN_PARSABLE: 624,20,100.40414428710938,57.53125
INFO:__main__:Train loss (625): 100.24174499511719
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 56.25
INFO:__main__:TRAIN_PARSABLE: 625,20,100.24174499511719,56.25
INFO:__main__:Tra

INFO:__main__:Train loss (663): 100.6869125366211
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 57.65625
INFO:__main__:TRAIN_PARSABLE: 663,20,100.6869125366211,57.65625
INFO:__main__:Train loss (664): 100.67820739746094
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 57.625
INFO:__main__:TRAIN_PARSABLE: 664,20,100.67820739746094,57.625
INFO:__main__:Train loss (665): 99.99566650390625
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 56.28125
INFO:__main__:TRAIN_PARSABLE: 665,20,99.99566650390625,56.28125
INFO:__main__:Train loss (666): 100.22003173828125
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 57.125
INFO:__main__:TRAIN_PARSABLE: 666,20,100.22003173828125,57.125
INFO:__main__:Train loss (667): 100.14734649658203
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 56.84375
INFO:__main__:TRAIN_PARSABLE: 667,20,100.14734649658203,56.84375
INFO:__main__:Tra

INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 55.53125
INFO:__main__:TRAIN_PARSABLE: 705,20,99.96842956542969,55.53125
INFO:__main__:Train loss (706): 99.80490112304688
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 57.03125
INFO:__main__:TRAIN_PARSABLE: 706,20,99.80490112304688,57.03125
INFO:__main__:Train loss (707): 100.97223663330078
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 59.03125
INFO:__main__:TRAIN_PARSABLE: 707,20,100.97223663330078,59.03125
INFO:__main__:Train loss (708): 99.76591491699219
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 55.9375
INFO:__main__:TRAIN_PARSABLE: 708,20,99.76591491699219,55.9375
INFO:__main__:Train loss (709): 98.78633117675781
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 54.4375
INFO:__main__:TRAIN_PARSABLE: 709,20,98.78633117675781,54.4375
INFO:__main__:Train loss (710): 99.31683349609375
INFO:__main__:cur

INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 57.5
INFO:__main__:TRAIN_PARSABLE: 747,20,100.68769073486328,57.5
INFO:__main__:Train loss (748): 99.3484878540039
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 55.84375
INFO:__main__:TRAIN_PARSABLE: 748,20,99.3484878540039,55.84375
INFO:__main__:Train loss (749): 98.65630340576172
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 55.625
INFO:__main__:TRAIN_PARSABLE: 749,20,98.65630340576172,55.625
INFO:__main__:Train loss (750): 98.6749496459961
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 55.5
INFO:__main__:TRAIN_PARSABLE: 750,20,98.6749496459961,55.5
INFO:__main__:Train loss (751): 97.96446990966797
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 54.28125
INFO:__main__:TRAIN_PARSABLE: 751,20,97.96446990966797,54.28125
INFO:__main__:Train loss (752): 98.90431213378906
INFO:__main__:curriculum_point: 20
INF

INFO:__main__:Average errors/sequence: 55.40625
INFO:__main__:TRAIN_PARSABLE: 789,20,99.19978332519531,55.40625
INFO:__main__:Train loss (790): 99.59956359863281
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 56.46875
INFO:__main__:TRAIN_PARSABLE: 790,20,99.59956359863281,56.46875
INFO:__main__:Train loss (791): 100.0919189453125
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 56.9375
INFO:__main__:TRAIN_PARSABLE: 791,20,100.0919189453125,56.9375
INFO:__main__:Train loss (792): 99.04899597167969
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 55.75
INFO:__main__:TRAIN_PARSABLE: 792,20,99.04899597167969,55.75
INFO:__main__:Train loss (793): 98.46110534667969
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 54.6875
INFO:__main__:TRAIN_PARSABLE: 793,20,98.46110534667969,54.6875
INFO:__main__:Train loss (794): 99.640380859375
INFO:__main__:curriculum_point: 20
INFO:__main__:Average error

INFO:__main__:Train loss (830): 97.6520767211914
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 54.0625
INFO:__main__:TRAIN_PARSABLE: 830,20,97.6520767211914,54.0625
INFO:__main__:Train loss (831): 99.11274719238281
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 56.34375
INFO:__main__:TRAIN_PARSABLE: 831,20,99.11274719238281,56.34375
INFO:__main__:Train loss (832): 100.38388061523438
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 56.5625
INFO:__main__:TRAIN_PARSABLE: 832,20,100.38388061523438,56.5625
INFO:__main__:Train loss (833): 99.67657470703125
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 56.0
INFO:__main__:TRAIN_PARSABLE: 833,20,99.67657470703125,56.0
INFO:__main__:Train loss (834): 98.96466064453125
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 55.71875
INFO:__main__:TRAIN_PARSABLE: 834,20,98.96466064453125,55.71875
INFO:__main__:Train loss (8

INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 54.875
INFO:__main__:TRAIN_PARSABLE: 872,20,98.52024841308594,54.875
INFO:__main__:Train loss (873): 98.92613220214844
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 54.71875
INFO:__main__:TRAIN_PARSABLE: 873,20,98.92613220214844,54.71875
INFO:__main__:Train loss (874): 97.84992980957031
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 54.28125
INFO:__main__:TRAIN_PARSABLE: 874,20,97.84992980957031,54.28125
INFO:__main__:Train loss (875): 98.70614624023438
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 55.8125
INFO:__main__:TRAIN_PARSABLE: 875,20,98.70614624023438,55.8125
INFO:__main__:Train loss (876): 97.42938232421875
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 53.65625
INFO:__main__:TRAIN_PARSABLE: 876,20,97.42938232421875,53.65625
INFO:__main__:Train loss (877): 98.41273498535156
INFO:__main__:curricu

INFO:__main__:Average errors/sequence: 56.0625
INFO:__main__:TRAIN_PARSABLE: 914,20,98.66075897216797,56.0625
INFO:__main__:Train loss (915): 99.87024688720703
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 55.59375
INFO:__main__:TRAIN_PARSABLE: 915,20,99.87024688720703,55.59375
INFO:__main__:Train loss (916): 99.32648468017578
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 55.0625
INFO:__main__:TRAIN_PARSABLE: 916,20,99.32648468017578,55.0625
INFO:__main__:Train loss (917): 98.53701782226562
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 54.875
INFO:__main__:TRAIN_PARSABLE: 917,20,98.53701782226562,54.875
INFO:__main__:Train loss (918): 98.58219909667969
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 55.1875
INFO:__main__:TRAIN_PARSABLE: 918,20,98.58219909667969,55.1875
INFO:__main__:Train loss (919): 98.76505279541016
INFO:__main__:curriculum_point: 20
INFO:__main__:Average err

INFO:__main__:Train loss (957): 98.37674713134766
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 54.625
INFO:__main__:TRAIN_PARSABLE: 957,20,98.37674713134766,54.625
INFO:__main__:Train loss (958): 99.2446517944336
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 55.6875
INFO:__main__:TRAIN_PARSABLE: 958,20,99.2446517944336,55.6875
INFO:__main__:Train loss (959): 98.418701171875
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 54.59375
INFO:__main__:TRAIN_PARSABLE: 959,20,98.418701171875,54.59375
INFO:__main__:Train loss (960): 97.78181457519531
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 53.90625
INFO:__main__:TRAIN_PARSABLE: 960,20,97.78181457519531,53.90625
INFO:__main__:Train loss (961): 98.41781616210938
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 55.9375
INFO:__main__:TRAIN_PARSABLE: 961,20,98.41781616210938,55.9375
INFO:__main__:Train loss (962

INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 52.625
INFO:__main__:TRAIN_PARSABLE: 999,20,96.76130676269531,52.625
INFO:__main__:Train loss (1000): 97.69000244140625
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 53.65625
INFO:__main__:TRAIN_PARSABLE: 1000,20,97.69000244140625,53.65625
INFO:__main__:----EVAL----
INFO:__main__:target task error/loss: 54.459375,97.77619705200195
INFO:__main__:multi task error/loss: 35.5734375,58.42999887466431
INFO:__main__:curriculum point error/loss (20): 53.95,97.48288269042969
INFO:__main__:EVAL_PARSABLE: 1000,54.459375,97.77619705200195,35.5734375,58.42999887466431,20,53.95,97.48288269042969
INFO:__main__:Train loss (1001): 98.07695007324219
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 54.78125
INFO:__main__:TRAIN_PARSABLE: 1001,20,98.07695007324219,54.78125
INFO:__main__:Train loss (1002): 98.07620239257812
INFO:__main__:curriculum_point: 20
INFO:__main__:Average er

INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 54.1875
INFO:__main__:TRAIN_PARSABLE: 1039,20,97.22969818115234,54.1875
INFO:__main__:Train loss (1040): 97.98385620117188
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 55.53125
INFO:__main__:TRAIN_PARSABLE: 1040,20,97.98385620117188,55.53125
INFO:__main__:Train loss (1041): 97.5097427368164
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 53.40625
INFO:__main__:TRAIN_PARSABLE: 1041,20,97.5097427368164,53.40625
INFO:__main__:Train loss (1042): 98.23558044433594
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 55.125
INFO:__main__:TRAIN_PARSABLE: 1042,20,98.23558044433594,55.125
INFO:__main__:Train loss (1043): 97.1328125
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 53.1875
INFO:__main__:TRAIN_PARSABLE: 1043,20,97.1328125,53.1875
INFO:__main__:Train loss (1044): 96.94940185546875
INFO:__main__:curriculum_poin

INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 53.8125
INFO:__main__:TRAIN_PARSABLE: 1081,20,96.7376708984375,53.8125
INFO:__main__:Train loss (1082): 97.81552124023438
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 54.625
INFO:__main__:TRAIN_PARSABLE: 1082,20,97.81552124023438,54.625
INFO:__main__:Train loss (1083): 97.4500503540039
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 55.65625
INFO:__main__:TRAIN_PARSABLE: 1083,20,97.4500503540039,55.65625
INFO:__main__:Train loss (1084): 97.23632049560547
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 54.0
INFO:__main__:TRAIN_PARSABLE: 1084,20,97.23632049560547,54.0
INFO:__main__:Train loss (1085): 96.68846130371094
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 53.28125
INFO:__main__:TRAIN_PARSABLE: 1085,20,96.68846130371094,53.28125
INFO:__main__:Train loss (1086): 96.710693359375
INFO:__main__:curriculum

INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 52.6875
INFO:__main__:TRAIN_PARSABLE: 1123,20,96.66792297363281,52.6875
INFO:__main__:Train loss (1124): 96.93780517578125
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 53.71875
INFO:__main__:TRAIN_PARSABLE: 1124,20,96.93780517578125,53.71875
INFO:__main__:Train loss (1125): 97.29242706298828
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 53.8125
INFO:__main__:TRAIN_PARSABLE: 1125,20,97.29242706298828,53.8125
INFO:__main__:Train loss (1126): 96.99728393554688
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 53.34375
INFO:__main__:TRAIN_PARSABLE: 1126,20,96.99728393554688,53.34375
INFO:__main__:Train loss (1127): 97.0865478515625
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 53.40625
INFO:__main__:TRAIN_PARSABLE: 1127,20,97.0865478515625,53.40625
INFO:__main__:Train loss (1128): 96.69264221191406
INFO:__main

INFO:__main__:Train loss (1165): 95.79003143310547
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 52.4375
INFO:__main__:TRAIN_PARSABLE: 1165,20,95.79003143310547,52.4375
INFO:__main__:Train loss (1166): 97.78657531738281
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 54.5625
INFO:__main__:TRAIN_PARSABLE: 1166,20,97.78657531738281,54.5625
INFO:__main__:Train loss (1167): 97.6997299194336
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 54.9375
INFO:__main__:TRAIN_PARSABLE: 1167,20,97.6997299194336,54.9375
INFO:__main__:Train loss (1168): 96.3131103515625
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 53.46875
INFO:__main__:TRAIN_PARSABLE: 1168,20,96.3131103515625,53.46875
INFO:__main__:Train loss (1169): 97.75350952148438
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 55.125
INFO:__main__:TRAIN_PARSABLE: 1169,20,97.75350952148438,55.125
INFO:__main__:Train

INFO:__main__:Train loss (1205): 96.36466217041016
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 53.0
INFO:__main__:TRAIN_PARSABLE: 1205,20,96.36466217041016,53.0
INFO:__main__:Train loss (1206): 96.59794616699219
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 52.53125
INFO:__main__:TRAIN_PARSABLE: 1206,20,96.59794616699219,52.53125
INFO:__main__:Train loss (1207): 97.04217529296875
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 54.71875
INFO:__main__:TRAIN_PARSABLE: 1207,20,97.04217529296875,54.71875
INFO:__main__:Train loss (1208): 96.49279022216797
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 52.875
INFO:__main__:TRAIN_PARSABLE: 1208,20,96.49279022216797,52.875
INFO:__main__:Train loss (1209): 96.5191650390625
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 53.5625
INFO:__main__:TRAIN_PARSABLE: 1209,20,96.5191650390625,53.5625
INFO:__main__:Train l

INFO:__main__:Train loss (1247): 96.7835464477539
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 53.09375
INFO:__main__:TRAIN_PARSABLE: 1247,20,96.7835464477539,53.09375
INFO:__main__:Train loss (1248): 97.45437622070312
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 53.96875
INFO:__main__:TRAIN_PARSABLE: 1248,20,97.45437622070312,53.96875
INFO:__main__:Train loss (1249): 95.7477798461914
INFO:__main__:curriculum_point: 20
INFO:__main__:Average errors/sequence: 51.96875
INFO:__main__:TRAIN_PARSABLE: 1249,20,95.7477798461914,51.96875
INFO:__main__:----SUMMARY----
INFO:__main__:convergence_on_target_task: None
INFO:__main__:performance_on_target_task: 53.175
INFO:__main__:convergence_on_multi_task: None
INFO:__main__:performance_on_multi_task: 35.6890625
INFO:__main__:SUMMARY_PARSABLE: None,53.175,None,35.6890625
INFO:__main__:generalization_from_target_task: 155.55729166666666,241.25,321.2395833333333,400.9375,479.2395833333333
INFO:

In [34]:
# Peek at a single evaluation example produced by the data generator.
# The batch count must be an integer: plain `/` gives a float under Python 3
# (e.g. 640 / 32 == 20.0), so use floor division and require an exact multiple.
assert args.eval_batch_size % args.batch_size == 0, \
    "eval_batch_size must be a multiple of batch_size"
num_eval_batches = args.eval_batch_size // args.batch_size
data_generator.generate_batches(
    num_eval_batches,
    args.batch_size,
    bits_per_vector=args.num_bits_per_vector,
    curriculum_point=None,
    max_seq_len=args.max_seq_len,
    curriculum='deterministic_uniform',
    pad_to_max_seq_len=args.pad_to_max_seq_len
)[4][1][1]  # NOTE(review): presumably [batch 4][field 1 = inputs?][example 1] — confirm against generate_batches

array([[0., 0., 1., 0., 1., 0., 0., 0., 0.],
       [1., 1., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 0., 1., 0.],
       [1., 0., 0., 0., 1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 1., 1., 1., 0.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [51]:
# Generate the evaluation batches once and keep them in `batches` for inspection.
# Floor division keeps the batch count an int (plain `/` yields a float under Python 3);
# the assert makes a non-multiple eval_batch_size fail loudly instead of being truncated.
assert args.eval_batch_size % args.batch_size == 0, \
    "eval_batch_size must be a multiple of batch_size"
num_eval_batches = args.eval_batch_size // args.batch_size
batches = data_generator.generate_batches(
    num_eval_batches,
    args.batch_size,
    bits_per_vector=args.num_bits_per_vector,
    curriculum_point=None,
    max_seq_len=args.max_seq_len,
    curriculum='deterministic_uniform',
    pad_to_max_seq_len=args.pad_to_max_seq_len
)

In [49]:
# Load the pickled head log for this experiment into `x` (displayed by the next cell),
# so `x` is defined on a fresh kernel rather than leaking from an earlier run.
# Path resolves to head_logs/<experiment_name>.p, i.e. head_logs/Experiment_2.p here.
# SECURITY: pickle.load can execute arbitrary code — only load files this project wrote.
import pickle

head_log_path = 'head_logs/{}.p'.format(args.experiment_name)
with open(head_log_path, 'rb') as head_log_file:
    x = pickle.load(head_log_file)

In [50]:
# Inspect the head log `x` unpickled from head_logs/Experiment_2.p by the load cell above.
# NOTE(review): if that load cell has not been run, `x` is undefined on a fresh kernel.
x

{20: [{'inputs': array([[0., 0., 0., 0., 1., 0., 0., 1., 0.],
          [1., 1., 1., 1., 0., 1., 1., 0., 0.],
          [1., 0., 1., 1., 1., 0., 0., 1., 0.],
          [1., 1., 1., 0., 1., 0., 1., 0., 0.],
          [1., 0., 1., 1., 0., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1., 0., 1., 0., 0.],
          [0., 0., 0., 1., 1., 0., 0., 1., 0.],
          [0., 0., 1., 1., 0., 0., 0., 1., 0.],
          [1., 1., 0., 0., 0., 1., 0., 1., 0.],
          [0., 1., 0., 1., 1., 1., 1., 0., 0.],
          [0., 1., 0., 1., 0., 1., 1., 1., 0.],
          [1., 1., 0., 1., 1., 0., 0., 1., 0.],
          [0., 0., 1., 1., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 1., 0.],
          [1., 0., 1., 0., 0., 0., 0., 0., 0.],
          [1., 1., 1., 0., 0., 0., 1., 1., 0.],
          [0., 1., 1., 0., 0., 1., 1., 0., 0.],
          [1., 0., 0., 1., 0., 0., 0., 1., 0.],
          [1., 0., 0., 1., 1., 0., 0., 0., 0.],
          [1., 1., 1., 1., 0., 1., 1., 1., 0.],
          [1., 1., 1., 1.,