In [29]:
# %%capture
!python --version
!pip install --upgrade jaxlib flax
!pip install --upgrade "jax[cuda]" -f https: // storage.googleapis.com / jax-releases / jax_releases.html
!pip install bert-pytorch msgpack tbp-nightly
#!pip install -q kaggle

In [30]:
import functools
import itertools
import os
import pickle
from collections import Counter
from itertools import takewhile
from typing import *

import flax.linen as nn
import flax.serialization
import flax.training.train_state as train_state
import jax
import jax.numpy as jnp
import jax.profiler
import matplotlib.pyplot as plt
import msgpack
import numpy as np
import optax
import tqdm

# TPU bootstrap: only runs when the runtime exposes a TPU_NAME environment variable.
if 'TPU_NAME' in os.environ:
    import requests

    # TPU_DRIVER_MODE is set below after the request, so re-running this cell in the
    # same kernel does not POST the driver-version request a second time.
    if 'TPU_DRIVER_MODE' not in globals():
        # Host is the text between the first and second ':' of TPU_NAME; port 8475 is hard-coded.
        url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
        # NOTE(review): the response is stored but never inspected, so a failed request goes unnoticed.
        resp = requests.post(url)
        TPU_DRIVER_MODE = 1

    # NOTE(review): `jax.config` as an importable module may not exist in newer JAX
    # releases — confirm against the installed version.
    from jax.config import config

    # Point JAX at the TPU named by the environment.
    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
    print('Registered TPU:', config.FLAGS.jax_backend_target)
else:
    print('No TPU detected. Can be changed under "Runtime/Change runtime type".')

# Last expression of the cell: displays the devices JAX can see.
jax.devices()

In [31]:
class TorchVocab(object):
    """Token vocabulary built from a frequency counter.

    Attributes:
        freqs: The counter object passed to the constructor (stored by
            reference, not copied).
        stoi: dict mapping token strings to integer identifiers.
        itos: list of token strings indexed by their integer identifiers.
        vectors: ``None`` unless ``vectors`` was passed to the constructor.
    """

    def __init__(self, counter, max_size=None, min_freq=1, specials=('<pad>', '<oov>'),
                 vectors=None, unk_init=None, vectors_cache=None):
        """Create a vocabulary from token frequencies.

        Arguments:
            counter: mapping token -> frequency (typically a
                ``collections.Counter``); it is not modified.
            max_size: maximum number of non-special tokens, or None for no
                limit. Default: None.
            min_freq: minimum frequency a token needs to be included. Values
                below 1 are raised to 1. Default: 1.
            specials: special tokens placed at the start of the vocabulary, in
                the given order. Default: ('<pad>', '<oov>').
            vectors: forwarded to ``self.load_vectors`` when not None.
                NOTE(review): ``load_vectors`` is not defined on this class in
                the visible source, so passing ``vectors`` raises
                AttributeError unless a subclass provides it.
            unk_init: forwarded to ``load_vectors``; must be None when
                ``vectors`` is None.
            vectors_cache: forwarded to ``load_vectors`` as ``cache``; must be
                None when ``vectors`` is None.
        """
        self.freqs = counter
        counter = counter.copy()
        min_freq = max(min_freq, 1)

        self.itos = list(specials)
        # Special tokens keep their fixed leading positions, so they must not
        # reappear in the frequency-ordered part. pop() with a default also
        # works when `counter` is a plain dict that lacks the token.
        for tok in specials:
            counter.pop(tok, None)

        max_size = None if max_size is None else max_size + len(self.itos)

        # Highest frequency first; ties are broken alphabetically.
        words_and_frequencies = sorted(counter.items(), key=lambda tup: (-tup[1], tup[0]))

        for word, freq in words_and_frequencies:
            if freq < min_freq or len(self.itos) == max_size:
                break
            self.itos.append(word)

        # stoi is simply a reverse dict for itos
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}

        self.vectors = None
        if vectors is not None:
            self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache)
        else:
            assert unk_init is None and vectors_cache is None

    def __eq__(self, other):
        """Vocabularies are equal when freqs, stoi, itos and vectors all match."""
        if not isinstance(other, TorchVocab):
            # Let Python fall back to its default handling instead of raising
            # AttributeError on objects without vocabulary attributes.
            return NotImplemented
        if self.freqs != other.freqs:
            return False
        if self.stoi != other.stoi:
            return False
        if self.itos != other.itos:
            return False
        if self.vectors != other.vectors:
            return False
        return True

    def __len__(self):
        """Number of tokens, special tokens included."""
        return len(self.itos)

    def vocab_rerank(self):
        """Rebuild ``stoi`` from the current order of ``itos``."""
        self.stoi = {word: i for i, word in enumerate(self.itos)}

    def extend(self, v, sort=False):
        """Append the tokens of vocabulary ``v`` that are not yet known.

        Arguments:
            v: object with an ``itos`` list of tokens.
            sort: when True, the tokens of ``v`` are added in sorted order
                instead of their order in ``v.itos``.
        """
        words = sorted(v.itos) if sort else v.itos
        for w in words:
            if w not in self.stoi:
                self.itos.append(w)
                self.stoi[w] = len(self.itos) - 1


class Vocab(TorchVocab):
    """Vocabulary whose first five ids are reserved for fixed special tokens.

    The ids are: pad=0, unk=1, eos=2, sos=3, mask=4. ``to_seq`` and
    ``from_seq`` are placeholders here that return None; subclasses provide
    the real conversions.
    """

    def __init__(self, counter, max_size=None, min_freq=1):
        reserved = ["<pad>", "<unk>", "<eos>", "<sos>", "<mask>"]
        # The base class prepends the specials in order, so each index below
        # equals the position of its token in `reserved`.
        (self.pad_index,
         self.unk_index,
         self.eos_index,
         self.sos_index,
         self.mask_index) = range(len(reserved))
        super().__init__(counter, specials=reserved, max_size=max_size, min_freq=min_freq)

    def to_seq(self, sentece, seq_len, with_eos=False, with_sos=False) -> list:
        """Placeholder; returns None in this class."""
        return None

    def from_seq(self, seq, join=False, with_pad=False):
        """Placeholder; returns None in this class."""
        return None

    @staticmethod
    def load_vocab(vocab_path: str) -> 'Vocab':
        """Unpickle and return the object stored at ``vocab_path``.

        Unpickling executes code embedded in the file; only load vocabulary
        files from a trusted source.
        """
        with open(vocab_path, "rb") as source:
            return pickle.load(source)

    def save_vocab(self, vocab_path):
        """Pickle this vocabulary to ``vocab_path``, overwriting the file."""
        with open(vocab_path, "wb") as sink:
            pickle.dump(self, sink)


# Building Vocab with text files
class WordVocab(Vocab):
    """Word-level vocabulary built by counting tokens in an iterable of lines."""

    def __init__(self, texts, max_size=None, min_freq=1):
        """Count tokens in ``texts`` and build the vocabulary.

        Arguments:
            texts: iterable whose items are either lists of tokens (used
                as-is) or strings (split on whitespace).
            max_size: forwarded to ``Vocab``.
            min_freq: forwarded to ``Vocab``.
        """
        print("Building Vocab")
        counter = Counter()
        for line in tqdm.tqdm(texts):
            if isinstance(line, list):
                words = line
            else:
                # NOTE(review): only the first four whitespace-separated fields
                # of a text line are counted — confirm this truncation is intended.
                words = line.replace("\n", " ").replace("\t", " ").split()[:4]

            counter.update(words)
        super().__init__(counter, max_size=max_size, min_freq=min_freq)

    def to_seq(self, sentence, seq_len=None, with_eos=False, with_sos=False, with_len=False):
        """Convert tokens to a list of ids.

        Arguments:
            sentence: a string (split on whitespace) or an iterable of tokens.
            seq_len: when given, the result is padded with ``pad_index`` or
                truncated to exactly this length.
            with_eos: append ``eos_index`` before padding/truncation.
            with_sos: prepend ``sos_index`` before padding/truncation.
            with_len: also return the length before padding/truncation.

        Returns:
            The id list, or ``(ids, original_length)`` when ``with_len`` is True.
        """
        if isinstance(sentence, str):
            sentence = sentence.split()

        # Unknown tokens map to unk_index.
        seq = [self.stoi.get(word, self.unk_index) for word in sentence]

        if with_eos:
            seq.append(self.eos_index)
        if with_sos:
            seq = [self.sos_index] + seq

        origin_seq_len = len(seq)

        if seq_len is not None:
            if len(seq) <= seq_len:
                seq.extend([self.pad_index] * (seq_len - len(seq)))
            else:
                seq = seq[:seq_len]

        return (seq, origin_seq_len) if with_len else seq

    def from_seq(self, seq, join=False, with_pad=False):
        """Convert ids back to tokens.

        Arguments:
            seq: iterable of integer ids.
            join: return one space-joined string instead of a list.
            with_pad: when True, ids equal to ``pad_index`` are dropped from
                the result (the flag name reads the other way round).

        Returns:
            A list of tokens, or a string when ``join`` is True. Ids outside
            the vocabulary are rendered as ``"<id>"``.
        """
        # The lower bound keeps negative ids from silently wrapping around itos.
        words = [self.itos[idx]
                 if 0 <= idx < len(self.itos)
                 else "<%d>" % idx
                 for idx in seq
                 if not with_pad or idx != self.pad_index]

        return " ".join(words) if join else words

    @staticmethod
    def load_vocab(vocab_path: str) -> 'WordVocab':
        """Unpickle and return the object stored at ``vocab_path``.

        Unpickling executes code embedded in the file; only load vocabulary
        files from a trusted source.
        """
        with open(vocab_path, "rb") as f:
            return pickle.load(f)


def build():
    """Command-line entry point: build a WordVocab from a corpus file and pickle it.

    Reads the options from ``sys.argv`` (corpus path, output path, optional
    vocabulary size, file encoding, minimum frequency), prints the vocabulary
    size and object, then saves it to the output path.
    """
    import argparse

    option_table = (
        ("-c", "--corpus_path", dict(required=True, type=str)),
        ("-o", "--output_path", dict(required=True, type=str)),
        ("-s", "--vocab_size", dict(type=int, default=None)),
        ("-e", "--encoding", dict(type=str, default="utf-8")),
        ("-m", "--min_freq", dict(type=int, default=1)),
    )
    cli = argparse.ArgumentParser()
    for short_flag, long_flag, settings in option_table:
        cli.add_argument(short_flag, long_flag, **settings)
    opts = cli.parse_args()

    with open(opts.corpus_path, "r", encoding=opts.encoding) as corpus:
        built = WordVocab(corpus, max_size=opts.vocab_size, min_freq=opts.min_freq)

    print("VOCAB SIZE:", len(built))
    print(built)
    built.save_vocab(opts.output_path)

In [55]:
# Configuration constants
# TESTING only switches BATCH_SIZE to a small value below.
TESTING = False

# Architecture-specific constants
# Feature width of each O0 row in the get_data buffers and of the first decoder Dense layer.
EMBEDDING_SIZE = 128
ENCODER_HIDDEN_SIZE = 128  # Edit representation size (output width of the Dense layer in EditEncoder)
O0_ENCODER_HIDDEN_SIZE = 256  # EditEncoder hidden size (passed as hidden_size in Model)
# INPUT_ENCODER_HIDDEN_SIZE = 1024

# Model-specific constants
# Multiplied by the maximum line count in get_data to size the output token axis.
MAX_ASSEMBLY_LINE_LENGTH = 10
# NOTE(review): SEQ_LENGTH is not referenced by any other cell visible in this notebook.
SEQ_LENGTH = EMBEDDING_SIZE * 2 + 4 * MAX_ASSEMBLY_LINE_LENGTH

# Training related constants
BATCH_SIZE = 8 if TESTING else 64
TRAIN_STEPS = 2900  # batches consumed per epoch by train_model
TEST_STEPS = 100  # upper bound on batches consumed per test_model call
LEARNING_RATE = 0.001  # passed to optax.adam
NUM_OF_EPOCH = 3
DROPOUT_RATE = 0.1  # rate of the Dropout layer in DecoderLSTM

# Palmtree related variables
# Loaded with pickle (see WordVocab.load_vocab); the file must come from a trusted source.
vocab = WordVocab.load_vocab("../input/palmtreevocab/vocab")
VOCAB_SIZE = len(vocab)

# NOTE(review): EOS_ID is not referenced by any other cell visible in this notebook.
EOS_ID = [0 for _ in range(MAX_ASSEMBLY_LINE_LENGTH)]

In [33]:
class EncoderLSTM(nn.Module):
    """LSTM cell scanned over axis 1 of the input, freezing its state after EOS.

    The carry is ``(lstm_state, is_eos)``. Once ``is_eos`` is True for a batch
    row, that row's carried state is no longer updated. The per-step output is
    the (unmasked) new LSTM state.
    """

    @functools.partial(
        nn.transforms.scan,
        variable_broadcast='params',
        in_axes=1,
        out_axes=1,
        split_rngs={'params': False})
    @nn.compact
    # x is one time step sliced from axis 1 of the scanned input: (batch, features).
    def __call__(self, carry, x):
        lstm_state, is_eos = carry
        new_lstm_state, _ = nn.OptimizedLSTMCell()(lstm_state, x)

        def select_carried_state(new_state, old_state):
            # Rows that already reached EOS keep their previous state.
            return jnp.where(is_eos[:, np.newaxis], old_state, new_state)

        carried_lstm_state = tuple(select_carried_state(*s) for s in zip(new_lstm_state, lstm_state))
        # The inputs are zero-padded along the scanned axis, so the end of a
        # sequence is an all-zero row. The previous test, `jnp.any(x, axis=1)`,
        # flagged every NON-zero row instead, which froze the state right
        # after the first real step.
        reached_padding = jnp.logical_not(jnp.any(x, axis=1))
        is_eos = jnp.logical_or(is_eos, reached_padding)
        return (carried_lstm_state, is_eos), new_lstm_state

    @staticmethod
    def initialize_carry(batch_size, hidden_size):
        """Return the initial LSTM state for ``batch_size`` rows of width ``hidden_size``."""
        # use dummy key since default state init fn is just zeros.
        return nn.OptimizedLSTMCell.initialize_carry(
            jax.random.PRNGKey(0), (batch_size,), hidden_size)

In [34]:
class Encoder(nn.Module):
    """Runs EncoderLSTM over a batch of sequences.

    Returns ``(final_state, all_hidden_state)``: the carried LSTM state after
    the last step and the per-step states emitted by the scan.
    """
    # Width of the LSTM state created by EncoderLSTM.initialize_carry.
    hidden_size: int

    @nn.compact
    def __call__(self, inputs):
        # Axis 0 of inputs is the batch; axis 1 is scanned by EncoderLSTM.
        batch_size = inputs.shape[0]
        init_lstm_state = EncoderLSTM.initialize_carry(batch_size, self.hidden_size)
        # `np.bool` was only an alias of the builtin and has been removed from
        # NumPy; use the JAX boolean dtype instead.
        init_is_eos = jnp.zeros(batch_size, dtype=jnp.bool_)
        init_carry = (init_lstm_state, init_is_eos)
        (final_state, _), all_hidden_state = EncoderLSTM()(init_carry, inputs)
        return final_state, all_hidden_state

In [35]:
def pad_output(l, max_length, index):
    """Extend a 2-D array to ``max_length`` rows with one-hot filler rows.

    Each appended row is all zeros except for a 1 in column ``index``.
    """
    rows_missing = max_length - l.shape[0]
    filler = np.zeros((rows_missing, l.shape[1]))
    filler[:, index] = 1
    return np.vstack((l, filler))

In [36]:
def get_data_epoch(path='../input/bigmsgpack/new_200k_data.bin'):
    """Yield the dataset metadata once, then batches for NUM_OF_EPOCH passes.

    The first yielded item is the metadata tuple produced by ``get_data``.
    After that, each epoch yields at most ``TRAIN_STEPS + TEST_STEPS`` batches
    read from the start of the file.

    Each pass uses a fresh ``get_data`` generator. The previous one is closed
    explicitly so its file handle is released right away, and the file is not
    reopened after the final epoch.
    """
    steps_per_epoch = TRAIN_STEPS + TEST_STEPS
    data_gen = get_data(path)
    try:
        yield next(data_gen)
        for epoch in range(NUM_OF_EPOCH):
            if epoch > 0:
                data_gen.close()
                data_gen = get_data(path)
                next(data_gen)  # skip the metadata tuple of the reopened file
            yield from itertools.islice(data_gen, steps_per_epoch)
    finally:
        data_gen.close()

    print("get_data_epoch ended")


In [37]:
def get_data(path='../input/bigmsgpack/new_200k_data.bin'):
    """Stream training batches from a msgpack file.

    The first yielded item is the tuple
    ``(MODEL_LINE_COUNT, O0_MODEL_LINE_COUNT, OUTPUT_TOKEN_COUNT)`` derived
    from the metadata record at the start of the file. Every later item is a
    pair ``(O0, outp)`` holding one full batch.

    The same two buffers are overwritten in place and yielded again for every
    batch, so a consumer that keeps a batch must copy it. Samples left over
    after the last full batch are never yielded.
    """
    with open(path, "rb") as f:
        data = msgpack.Unpacker(f, raw=False, strict_map_key=False)

        # The first unpacked object is the metadata record.
        metadata = next(data)

        MODEL_LINE_COUNT = max(metadata["max_line_count_O0"], metadata["max_line_count_O2"])
        O0_MODEL_LINE_COUNT = metadata["max_line_count_O0"]
        OUTPUT_TOKEN_COUNT = MODEL_LINE_COUNT * MAX_ASSEMBLY_LINE_LENGTH

        yield MODEL_LINE_COUNT, O0_MODEL_LINE_COUNT, OUTPUT_TOKEN_COUNT

        # Batch buffers; the "+ 1" leaves at least one all-zero row after the longest O0 sample.
        O0 = np.zeros((BATCH_SIZE, O0_MODEL_LINE_COUNT + 1, EMBEDDING_SIZE))
        outp = np.zeros((BATCH_SIZE, OUTPUT_TOKEN_COUNT, VOCAB_SIZE))

        i = 0
        for program in data:
            # presumably each unpacked object maps some key to a sample dict — verify against the data file
            for line in program.values():
                # Zero-pad the O0 rows of this sample up to the buffer height.
                O0[i % BATCH_SIZE] = np.pad(line['O0'],
                                            pad_width=[(0, O0_MODEL_LINE_COUNT + 1 - len(line['O0'])), (0, 0)])
                # (A disabled variant that concatenated padded 'O0', 'O2' and 'diff' arrays
                # into a single input used to live here as commented-out code.)
                # Target: '<sos>' is appended to every O2 token list, the lists are flattened
                # and mapped to ids, the trailing '<sos>' id is dropped ([:-1]), the ids are
                # one-hot encoded, and the rest of the buffer is filled with one-hot
                # eos_index rows by pad_output.
                # NOTE(review): pad_output raises if a sample has more than OUTPUT_TOKEN_COUNT tokens.
                outp[i % BATCH_SIZE] = pad_output(jax.nn.one_hot(np.array(
                    vocab.to_seq(itertools.chain.from_iterable(map(lambda x: x + ['<sos>'], line['O2_tokens'])))[:-1]),
                    VOCAB_SIZE), OUTPUT_TOKEN_COUNT, vocab.eos_index)

                i += 1
                # Emit the buffers once every BATCH_SIZE samples.
                if i % BATCH_SIZE == 0:
                    yield O0, outp

#                     yield O0, O0_token_list, outp
#                     O0_token_list.clear()
#                     O2_token_list.clear()

In [38]:
class DecoderLSTM(nn.Module):
    """One decoder step (LSTM cell -> dropout -> two Dense+ReLU layers), scanned over axis 1.

    The carry is ``(rng, lstm_state, last_prediction, enc_hidden_states)``.
    Each step returns the logits and a one-hot token sampled from them.
    """
    # Not read by the active code: the branch that would feed back
    # last_prediction is commented out below.
    teacher_force: bool
    # When True the Dropout layer is active (deterministic=False).
    has_dropout: bool

    # NOTE(review): 'dropout' is listed in variable_broadcast although it is also
    # an rng name in split_rngs — confirm this is what the scan expects.
    @functools.partial(
        nn.transforms.scan,
        variable_broadcast=['params', 'dropout'],
        in_axes=1,
        out_axes=1,
        split_rngs={'params': False, 'dropout': True})
    @nn.compact
    def __call__(self, carry, x):
        # last_prediction and enc_hidden_states are carried through unchanged in role:
        # neither is read by the active code in this method.
        rng, lstm_state, last_prediction, enc_hidden_states = carry

        ###### ATTENTION #######
        # A disabled experiment applied nn.MultiHeadDotProductAttention between the
        # repeated LSTM state and enc_hidden_states and concatenated the mean score
        # onto x; it is not active.
        ########################

        # Fresh sampling key for this step; the other half is carried to the next step.
        carry_rng, categorical_rng = jax.random.split(rng, 2)
        #if not self.teacher_force:
        #    x = last_prediction
        lstm_state, y = nn.OptimizedLSTMCell()(lstm_state, x)
        y = nn.Dropout(rate=DROPOUT_RATE, deterministic=not self.has_dropout)(y)
        y = nn.relu(nn.Dense(features=EMBEDDING_SIZE)(y))
        #       y = nn.Dropout(rate=DROPOUT_RATE, deterministic=not self.has_dropout)(y)
        # NOTE(review): the ReLU makes every logit non-negative before the softmax
        # in the loss — confirm this is intended.
        y = nn.relu(nn.Dense(features=VOCAB_SIZE)(y))
        # 1. The layer output is used directly as logits.
        logits = y
        # 2. Sample a token id from the logits and one-hot encode it.
        predicted_token = jax.random.categorical(categorical_rng, logits)
        prediction = jax.nn.one_hot(predicted_token, VOCAB_SIZE, dtype=jnp.float32)
        return (carry_rng, lstm_state, prediction, enc_hidden_states), (logits, prediction)

In [39]:
class Decoder(nn.Module):
    """Runs DecoderLSTM over the decoder inputs and returns per-step outputs.

    Returns ``(logits, predictions)`` as stacked by the scan over axis 1.
    """
    # Initial LSTM state handed to the first decoder step.
    init_state: Tuple[Any]
    # Forwarded to DecoderLSTM.
    teacher_force: bool
    # Forwarded to DecoderLSTM; enables dropout when True.
    has_dropout: bool

    @nn.compact
    def __call__(self, inputs, enc_hidden_states):
        # Axis 0 of inputs is the batch and axis 1 is scanned step by step.
        lstm = DecoderLSTM(teacher_force=self.teacher_force, has_dropout=self.has_dropout)
        # NOTE(review): the sampling key is a constant, so the tokens sampled
        # inside DecoderLSTM are reproducible across calls.
        key = jax.random.PRNGKey(0)
        # Size the initial "previous prediction" from the actual input batch;
        # each step replaces it with an array of that batch size, and the scan
        # requires the carry shape to stay constant.
        batch_size = inputs.shape[0]
        init_prediction = jnp.ones((batch_size, VOCAB_SIZE), dtype=jnp.float32)
        init_carry = (key, self.init_state, init_prediction, enc_hidden_states)
        _, (logits, predictions) = lstm(init_carry, inputs)
        return logits, predictions

In [40]:
def load_model(path, state):
    """Read the file at ``path`` and deserialize it into the structure of ``state``.

    ``state`` acts as the template passed to ``flax.serialization.from_bytes``;
    the restored object is returned.
    """
    with open(path, 'rb') as checkpoint:
        serialized = checkpoint.read()
        return flax.serialization.from_bytes(state, serialized)


def save_model(state, path='model.bin'):
    """Serialize ``state.params`` with flax and write the bytes to ``path``."""
    with open(path, 'wb') as checkpoint:
        payload = flax.serialization.to_bytes(state.params)
        checkpoint.write(payload)

In [41]:
class EditEncoder(nn.Module):
    """Encodes a batch of sequences with EditEncoderLSTM.

    Returns ``(representation, hidden_states)``: a Dense projection (width
    ENCODER_HIDDEN_SIZE) of the first component of the final LSTM state, and
    the per-step states emitted by the scan.
    """
    # Width of the LSTM state created by EditEncoderLSTM.initialize_carry.
    hidden_size: int

    @nn.compact
    def __call__(self, inputs):
        # Axis 0 of inputs is the batch; axis 1 is scanned by EditEncoderLSTM.
        batch_size = inputs.shape[0]
        init_lstm_state = EditEncoderLSTM.initialize_carry(batch_size, self.hidden_size)
        # `np.bool` was only an alias of the builtin and has been removed from
        # NumPy; use the JAX boolean dtype instead.
        init_is_eos = jnp.zeros(batch_size, dtype=jnp.bool_)
        init_carry = (init_lstm_state, init_is_eos)
        (final_state, _), hidden_states = EditEncoderLSTM()(init_carry, inputs)
        return nn.Dense(features=ENCODER_HIDDEN_SIZE)(final_state[0]), hidden_states

In [42]:
class EditEncoderLSTM(nn.Module):
    """LSTM cell scanned over axis 1 of the input, freezing its state after EOS.

    The carry is ``(lstm_state, is_eos)``. Once ``is_eos`` is True for a batch
    row, that row's carried state is no longer updated. The per-step output is
    the (unmasked) new LSTM state.
    """

    @functools.partial(
        nn.transforms.scan,
        variable_broadcast='params',
        in_axes=1,
        out_axes=1,
        split_rngs={'params': False})
    @nn.compact
    # x is one time step sliced from axis 1 of the scanned input: (batch, features).
    def __call__(self, carry, x):
        lstm_state, is_eos = carry
        new_lstm_state, _ = nn.OptimizedLSTMCell()(lstm_state, x)

        def select_carried_state(new_state, old_state):
            # Rows that already reached EOS keep their previous state.
            return jnp.where(is_eos[:, np.newaxis], old_state, new_state)

        carried_lstm_state = tuple(select_carried_state(*s) for s in zip(new_lstm_state, lstm_state))
        # get_data zero-pads every O0 sample, so the end of a sequence is an
        # all-zero row. The previous test, `jnp.any(x, axis=1)`, flagged every
        # NON-zero row instead, which froze the state right after the first
        # real step and left the encoder blind to the rest of the sequence.
        reached_padding = jnp.logical_not(jnp.any(x, axis=1))
        is_eos = jnp.logical_or(is_eos, reached_padding)
        return (carried_lstm_state, is_eos), new_lstm_state

    @staticmethod
    def initialize_carry(batch_size, hidden_size):
        """Return the initial LSTM state for ``batch_size`` rows of width ``hidden_size``."""
        # use dummy key since default state init fn is just zeros.
        return nn.OptimizedLSTMCell.initialize_carry(
            jax.random.PRNGKey(0), (batch_size,), hidden_size)

In [43]:
class Model(nn.Module):
    """Encoder-decoder: EditEncoder summarises the O0 input, Decoder emits token logits.

    The encoder's projected representation is used as both components of the
    decoder's initial LSTM state. Returns the decoder's ``(logits, predictions)``.
    """

    @nn.compact
    def __call__(self, O0_inputs, dec_inputs, has_dropout=False):
        # O0 -> edit representation (plus the per-step encoder states).
        encoder = EditEncoder(hidden_size=O0_ENCODER_HIDDEN_SIZE)
        edit_representation, per_step_states = encoder(O0_inputs)

        decoder = Decoder(
            teacher_force=False,
            has_dropout=has_dropout,
            init_state=(edit_representation, edit_representation),
        )
        # Only the first component of the per-step encoder states is handed on.
        return decoder(dec_inputs, per_step_states[0])


In [44]:
def mask_sequences(sequence_batch, lengths):
    """Zero out every position at or beyond each row's length."""
    step_index = np.arange(sequence_batch.shape[1])[np.newaxis, :]
    within_length = lengths[:, np.newaxis] > step_index
    return sequence_batch * within_length

In [45]:
def cross_entropy_loss(logits, labels, lengths):
    """Mean negative log-likelihood of the labels over the unmasked positions.

    Positions at or beyond each sequence's length are masked out, and the sum
    is divided by the total of ``lengths``.
    """
    log_probs = nn.log_softmax(logits)
    per_token_ll = jnp.sum(log_probs * labels, axis=-1)
    total_ll = jnp.sum(mask_sequences(per_token_ll, lengths))
    return -(total_ll / jnp.sum(lengths))


In [46]:
def get_sequence_lengths(sequence_batch, eos_id=vocab.eos_index):
    """Length of each one-hot sequence, counting the first EOS token.

    ``sequence_batch`` is indexed as (batch, step, vocab). Rows that contain
    no EOS get the full step count.
    """
    eos_column = sequence_batch[:, :, eos_id]
    # argmax returns the first maximum, i.e. the first EOS when one exists
    # and 0 when the row has none.
    first_eos = jnp.argmax(eos_column, axis=-1)
    # Value at that position tells the two cases apart.
    eos_found = jnp.take_along_axis(eos_column, first_eos[:, jnp.newaxis], axis=-1)[:, 0]
    full_length = sequence_batch.shape[1]
    return jnp.where(eos_found, first_eos + 1, full_length)

In [47]:
def compute_metrics(logits, labels):
    """Return a dict with loss, sequence accuracy, token accuracy and sample outputs.

    A sequence counts as correct only when every position within its length
    matches, so sequence accuracy equals inference accuracy regardless of
    teacher forcing.
    """
    lengths = get_sequence_lengths(labels)
    token_hits = jnp.argmax(logits, -1) == jnp.argmax(labels, -1)
    masked_hits = mask_sequences(token_hits, lengths)
    sequence_hits = jnp.sum(masked_hits, axis=-1) == lengths
    first_predictions = jnp.argmax(logits[0], axis=-1)

    return {
        'loss': cross_entropy_loss(logits, labels, lengths),
        'accuracy': jnp.mean(sequence_hits),
        'token_accuracy': jnp.sum(masked_hits) / jnp.sum(lengths),
        'predicted_out': labels[0, :, first_predictions],
        'example_out': (labels[:1], logits[:1]),
    }

In [48]:
@jax.jit
def train_step(state, batch):
    """Run one optimisation step and return ``(new_state, metrics)``.

    Arguments:
        state: a flax ``TrainState`` (its ``apply_fn``, ``params`` and
            ``step`` are used).
        batch: dict with keys ``'answer'`` (labels), ``'embed'`` (first model
            input) and ``'dec_init'`` (second model input).
    """
    labels = batch['answer']
    # Derive a different dropout key for every step from the step counter.
    # A constant key would apply the identical dropout mask on every update.
    dropout_rng = jax.random.fold_in(jax.random.PRNGKey(0), state.step)

    def loss_fn(params):
        # The positional True enables dropout in the model.
        logits, _ = state.apply_fn({'params': params},
                                   batch['embed'],
                                   batch['dec_init'], True, rngs={'dropout': dropout_rng})
        loss = cross_entropy_loss(logits, labels, get_sequence_lengths(labels))
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits, labels)

    return state, metrics

In [49]:
def train_model(data, path=None, edit_path=None):
    """Train for NUM_OF_EPOCH epochs of TRAIN_STEPS batches and save a checkpoint per epoch.

    Arguments:
        data: iterator whose first item is the metadata tuple
            ``(MODEL_LINE_COUNT, O0_MODEL_LINE_COUNT, OUTPUT_TOKEN_COUNT)``
            and whose later items are ``(O0, outp)`` batches.
        path: optional checkpoint to load all parameters from.
        edit_path: optional checkpoint to load only the ``EditEncoder_0``
            parameters from.

    Returns:
        ``(state, train_metrics, test_metrics)``. ``test_metrics`` stays empty
        because evaluation is currently disabled.

    Side effects: sets the three metadata values as module globals (test_model
    reads OUTPUT_TOKEN_COUNT), prints one line per batch and writes
    ``model-<epoch>.bin`` after every epoch.
    """
    global MODEL_LINE_COUNT, O0_MODEL_LINE_COUNT, OUTPUT_TOKEN_COUNT

    MODEL_LINE_COUNT, O0_MODEL_LINE_COUNT, OUTPUT_TOKEN_COUNT = next(data)

    key = jax.random.PRNGKey(0)

    # Dummy inputs used only to initialise the parameters. Their feature width
    # must match the real batches, which get_data builds with EMBEDDING_SIZE.
    enc_init = jnp.ones((BATCH_SIZE, 1, EMBEDDING_SIZE), jnp.float32)
    dec_init = jnp.ones((BATCH_SIZE, OUTPUT_TOKEN_COUNT, EMBEDDING_SIZE))
    model = Model()

    train_metrics = {"loss": [], "accuracy": [], "perfect_accuracy": []}
    test_metrics = {"accuracy": [], "perfect_accuracy": []}

    dropout_rng = jax.random.PRNGKey(42)
    init_params = model.init({"params": key, 'dropout': dropout_rng}, enc_init, dec_init, False)
    state = train_state.TrainState.create(apply_fn=model.apply, params=init_params["params"],
                                          tx=optax.adam(LEARNING_RATE))
    if path is not None:
        state = state.replace(params=load_model(path, state.params))
    if edit_path is not None:
        # Restore only the encoder sub-tree and merge it into the fresh parameters.
        fd = flax.core.frozen_dict.FrozenDict({'EditEncoder_0': state.params['EditEncoder_0']})
        state = state.replace(params=state.params.copy(load_model(edit_path, fd)))

    for j in range(NUM_OF_EPOCH):
        # NOTE(review): get_data_epoch yields TRAIN_STEPS + TEST_STEPS batches per
        # pass over the file, but only TRAIN_STEPS are consumed here while evaluation
        # is disabled. The leftover batches are therefore used at the start of the
        # next epoch — confirm this is acceptable.
        for i, (O0, outp) in zip(range(TRAIN_STEPS), data):
            # Decoder input: the mean O0 embedding repeated for every output position.
            palmtree_avg = jnp.tile(jnp.mean(O0, axis=1, keepdims=True), reps=(1, OUTPUT_TOKEN_COUNT, 1))

            state, metrics = train_step(state, {'answer': outp, 'embed': O0,
                                                'dec_init': palmtree_avg})

            train_metrics["loss"].append(metrics["loss"])
            train_metrics["accuracy"].append(metrics["token_accuracy"])
            train_metrics["perfect_accuracy"].append(metrics["accuracy"])

            print(
                f'Epoch {j} batch {i} loss: {metrics["loss"]:.4f}, token_accuracy: {metrics["token_accuracy"]:.4f}, perfect accuracy: {metrics["accuracy"]:.4f}')

        save_model(state, f"model-{j}.bin")
        # Per-epoch evaluation with test_model (and filling test_metrics) is disabled.

    return state, train_metrics, test_metrics


In [50]:
def plot(train_metrics, test_metrics):
    """Show four figures summarising training and test metrics.

    1. Training loss (left axis) with token and perfect accuracy (right axis).
    2. Test accuracy against test perfect accuracy.
    3. Per-epoch mean training accuracy against test accuracy.
    4. Per-epoch mean training perfect accuracy against test perfect accuracy.
    """

    def epoch_means(series):
        # Sum over consecutive TRAIN_STEPS-sized chunks, divided by TRAIN_STEPS.
        chunk_starts = range(0, len(series), TRAIN_STEPS)
        return (sum(series[start:start + TRAIN_STEPS]) / TRAIN_STEPS for start in chunk_starts)

    # Plotting the important metrics
    fig, loss_axis = plt.subplots()
    plt.title("Train accuracy, perfect accuracy and loss")
    loss_axis.plot(train_metrics["loss"])
    loss_axis.set_ylabel("loss")
    accuracy_axis = loss_axis.twinx()
    for metric_name in ("accuracy", "perfect_accuracy"):
        accuracy_axis.plot(train_metrics[metric_name])
    accuracy_axis.set_ylabel("accuracy")
    fig.tight_layout()
    loss_axis.set_ylim(bottom=0)
    plt.show()

    plt.title("Test accuracy and perfect accuracy")
    test_pairs = list(zip(test_metrics["accuracy"], test_metrics["perfect_accuracy"]))
    plt.plot(test_pairs)
    plt.show()

    comparisons = (
        ("Train accuracy and test accuracy", "accuracy"),
        ("Train perfect accuracy and test perfect accuracy", "perfect_accuracy"),
    )
    for heading, metric_name in comparisons:
        plt.title(heading)
        paired = list(zip(epoch_means(train_metrics[metric_name]), test_metrics[metric_name]))
        plt.plot(paired)
        plt.show()

In [51]:
def test_model(data, state, model, example_count=0, eos_id=vocab.eos_index, generate_edit_representation_dataset=False):
    """Evaluate ``model`` on up to TEST_STEPS batches from ``data``.

    Arguments:
        data: iterator of batches, either ``(O0, outp)`` as yielded by
            get_data or ``(O0, O0_tok, outp)``.
        state: object whose ``params`` are passed to ``model.apply``.
        model: the flax module to evaluate.
        example_count: number of decoded examples to collect.
        eos_id: token id at which a predicted sequence is cut.
        generate_edit_representation_dataset: currently unused.

    Returns:
        ``(metrics, results)``. ``metrics`` has the keys ``'token accuracy'``
        and ``'perfect accuracy'``, averaged over the evaluated batches.
        ``results`` is a list of ``(O0 tokens or None, expected tokens,
        predicted tokens)``.

    Reads the module global OUTPUT_TOKEN_COUNT, which train_model sets.
    """
    results = []
    example_iterator = iter(range(example_count))
    accuracy = 0
    perfect_accuracy = 0
    batches_seen = 0
    for _, batch in zip(range(TEST_STEPS), data):
        # get_data yields (O0, outp). The three-element form with the O0 token
        # lists is also accepted so either producer works.
        if len(batch) == 3:
            O0, O0_tok, outp = batch
        else:
            O0, outp = batch
            O0_tok = None
        batches_seen += 1

        # Decoder input: the mean O0 embedding repeated for every output position.
        palmtree_avg = np.tile(jnp.mean(O0, axis=1, keepdims=True), reps=(1, OUTPUT_TOKEN_COUNT, 1))

        logits, _ = model.apply({"params": state.params}, O0, palmtree_avg)
        max_indices = jnp.argmax(logits, axis=-1)
        # Collect decoded examples until example_count of them have been gathered.
        for i, j in zip(range(BATCH_SIZE), example_iterator):
            expected_length = get_sequence_lengths(outp[i][jnp.newaxis, :, :])[0]
            results.append(
                (O0_tok[i] if O0_tok is not None else None,
                 vocab.from_seq(jnp.argmax(outp[i], axis=-1)[:expected_length]),
                 vocab.from_seq(takewhile(lambda x: x != eos_id, max_indices[i]))))
            # takewhile dropped the EOS itself; add it back unless the
            # prediction filled the whole output without one.
            if len(results[j][2]) != OUTPUT_TOKEN_COUNT:
                results[j][2].append("<eos>")

        correct_lengths = get_sequence_lengths(outp)
        # outp is one-hot, so gathering it at the predicted ids gives 1 where
        # the prediction is right.
        correctly_predicted_tokens = mask_sequences(
            jnp.take_along_axis(outp, max_indices[:, :, jnp.newaxis], axis=2)[:, :, 0], correct_lengths)

        accuracy += jnp.sum(correctly_predicted_tokens) / jnp.sum(correct_lengths)
        perfect_accuracy += jnp.sum(jnp.sum(correctly_predicted_tokens, axis=1) == correct_lengths) / BATCH_SIZE

    # With no batches at all both metrics are reported as 0.
    batch_divisor = max(batches_seen, 1)
    return {'token accuracy': float(accuracy / batch_divisor),
            'perfect accuracy': float(perfect_accuracy / batch_divisor)}, results

In [52]:
def load_and_test(data_path, model_paths, *args):
    """Evaluate every checkpoint in ``model_paths`` on the data at ``data_path``.

    Arguments:
        data_path: msgpack file readable by get_data.
        model_paths: iterable of checkpoint files written by save_model.
        *args: forwarded positionally to test_model after ``model``.

    Returns:
        A list with one test_model result per checkpoint; each is also printed.

    Side effects: sets MODEL_LINE_COUNT, O0_MODEL_LINE_COUNT and
    OUTPUT_TOKEN_COUNT as module globals, because test_model reads
    OUTPUT_TOKEN_COUNT from there.
    """
    global MODEL_LINE_COUNT, O0_MODEL_LINE_COUNT, OUTPUT_TOKEN_COUNT

    results = []

    # Read only the metadata record, then release the file again.
    metadata_reader = get_data(data_path)
    MODEL_LINE_COUNT, O0_MODEL_LINE_COUNT, OUTPUT_TOKEN_COUNT = next(metadata_reader)
    metadata_reader.close()

    # Dummy inputs used only to build the parameter structure for loading.
    enc_init = jnp.ones((BATCH_SIZE, 1, EMBEDDING_SIZE), jnp.float32)
    dec_init = jnp.ones((BATCH_SIZE, OUTPUT_TOKEN_COUNT, EMBEDDING_SIZE))
    model = Model()

    key = jax.random.PRNGKey(0)
    init_params = model.init({"params": key}, enc_init, dec_init, False)
    state = train_state.TrainState.create(apply_fn=model.apply, params=init_params["params"],
                                          tx=optax.adam(LEARNING_RATE))

    for path in model_paths:
        state = state.replace(params=load_model(path, state.params))

        # Every checkpoint is evaluated on the same batches from the start of the file.
        data = get_data(data_path)
        next(data)  # skip the metadata tuple
        try:
            results.append(test_model(data, state, model, *args))
        finally:
            data.close()
        print(results[-1])

    return results

In [56]:
# Batch stream: the metadata tuple first, then the batches for every epoch.
data = get_data_epoch()

# To use pretrained model (train_model returns three values: state, train_metrics, test_metrics)

#state, train_metrics, test_metrics = train_model(data, "../input/attentionmodel/model (2).bin")

#state, train_metrics, test_metrics = train_model(data, edit_path="../input/attentionmodel/model_edit_2.bin")

# Train from scratch; test_metrics comes back empty while evaluation is disabled in train_model.
state, train_metrics, test_metrics = train_model(data)

plot(train_metrics, test_metrics)

In [None]:
# save_model(state)