In [1]:
% % capture
!python --version
!pip install --upgrade jaxlib flax
!pip install --upgrade "jax[cuda]" -f https: // storage.googleapis.com / jax-releases / jax_releases.html
!pip install bert-pytorch msgpack tbp-nightly

UsageError: Line magic function `%` not found.


In [None]:
import functools
import itertools
import os
import pickle
from collections import Counter
from typing import *

import flax.linen as nn
import flax.serialization
import flax.training.train_state as train_state
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import msgpack
import numpy as np
import optax
import tqdm

# Attach to a remote TPU when the runtime exposes one; otherwise JAX uses its default backend.
if 'TPU_NAME' in os.environ:
    import requests

    # Request the nightly TPU driver only once per kernel session; the global
    # flag survives re-running this cell.
    if 'TPU_DRIVER_MODE' not in globals():
        # NOTE(review): assumes TPU_NAME looks like 'grpc://<host>:<port>', so that
        # split(':')[1] yields '//<host>' — confirm against the runtime.
        url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
        resp = requests.post(url)
        # Fail loudly on an HTTP error instead of setting the flag; otherwise a
        # failed request would be recorded as done and never retried.
        resp.raise_for_status()
        TPU_DRIVER_MODE = 1

    from jax.config import config

    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
    print('Registered TPU:', config.FLAGS.jax_backend_target)
else:
    print('No TPU detected. Can be changed under "Runtime/Change runtime type".')

jax.devices()

In [None]:
class TorchVocab(object):
    """Vocabulary mapping tokens to integer ids and back (adapted from torchtext's Vocab).

    Attributes:
        freqs: The ``collections.Counter`` passed to the constructor, holding the
            token frequencies of the data the vocabulary was built from.
        stoi: A dict mapping token strings to numerical identifiers.
        itos: A list of token strings indexed by their numerical identifiers.
        vectors: Always ``None``; pretrained vectors are not supported here.
    """

    def __init__(self, counter, max_size=None, min_freq=1, specials=('<pad>', '<oov>'),
                 vectors=None, unk_init=None, vectors_cache=None):
        """Create a vocabulary from a ``collections.Counter``.

        Arguments:
            counter: collections.Counter holding the frequency of each token.
                It is not modified (a copy is used internally).
            max_size: Maximum number of non-special tokens to keep, or None for
                no limit. Default: None.
            min_freq: Minimum frequency needed to include a token. Values less
                than 1 are treated as 1. Default: 1.
            specials: Special tokens placed first in the vocabulary, in order,
                regardless of their frequency. Default: ('<pad>', '<oov>').
            vectors, unk_init, vectors_cache: Kept for signature compatibility
                with torchtext. Loading pretrained vectors is not implemented,
                so ``vectors`` must be None, and then ``unk_init`` and
                ``vectors_cache`` must be None as well.

        Raises:
            NotImplementedError: if ``vectors`` is not None.
        """
        self.freqs = counter
        counter = counter.copy()
        min_freq = max(min_freq, 1)

        self.itos = list(specials)
        # Special tokens always come first and are excluded from the frequency ranking.
        for tok in specials:
            del counter[tok]  # Counter.__delitem__ ignores missing keys

        # max_size counts only regular tokens, so offset it by the number of specials.
        max_size = None if max_size is None else max_size + len(self.itos)

        # Sort by frequency (descending), ties broken alphabetically: the second
        # sort is stable, so it keeps the alphabetical order within equal frequencies.
        words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
        words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)

        for word, freq in words_and_frequencies:
            if freq < min_freq or len(self.itos) == max_size:
                break
            self.itos.append(word)

        # stoi is simply the reverse mapping of itos.
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}

        self.vectors = None
        if vectors is not None:
            # This class has no vector-loading support; fail with an explicit message.
            raise NotImplementedError("TorchVocab does not support loading pretrained vectors")
        assert unk_init is None and vectors_cache is None

    def __eq__(self, other):
        """Vocabularies are equal when frequencies, both mappings and vectors all match."""
        if not isinstance(other, TorchVocab):
            return NotImplemented
        return (self.freqs == other.freqs
                and self.stoi == other.stoi
                and self.itos == other.itos
                and self.vectors == other.vectors)

    def __len__(self):
        """Number of tokens, including the special tokens."""
        return len(self.itos)

    def vocab_rerank(self):
        """Rebuild ``stoi`` from the current order of ``itos``."""
        self.stoi = {word: i for i, word in enumerate(self.itos)}

    def extend(self, v, sort=False):
        """Append the tokens of vocabulary ``v`` that are not already present.

        Args:
            v: Another vocabulary exposing ``itos``.
            sort: If True, new tokens are appended in alphabetical order;
                otherwise in ``v``'s id order.
        """
        words = sorted(v.itos) if sort else v.itos
        for w in words:
            if w not in self.stoi:
                self.itos.append(w)
                self.stoi[w] = len(self.itos) - 1


class Vocab(TorchVocab):
    """TorchVocab with BERT-style special tokens at fixed indices.

    The special tokens are placed first, in this order, so their ids are
    <pad>=0, <unk>=1, <eos>=2, <sos>=3, <mask>=4.
    """

    def __init__(self, counter, max_size=None, min_freq=1):
        self.pad_index = 0
        self.unk_index = 1
        self.eos_index = 2
        self.sos_index = 3
        self.mask_index = 4
        super().__init__(counter, specials=["<pad>", "<unk>", "<eos>", "<sos>", "<mask>"],
                         max_size=max_size, min_freq=min_freq)

    # (sic) the parameter name `sentece` is kept for keyword-argument compatibility.
    def to_seq(self, sentece, seq_len, with_eos=False, with_sos=False) -> list:
        """Convert a sentence to token ids; must be implemented by subclasses."""
        raise NotImplementedError

    def from_seq(self, seq, join=False, with_pad=False):
        """Convert token ids back to tokens; must be implemented by subclasses."""
        raise NotImplementedError

    @staticmethod
    def load_vocab(vocab_path: str) -> 'Vocab':
        """Load a pickled vocabulary from ``vocab_path``.

        Security: ``pickle.load`` can execute arbitrary code; only load
        vocabulary files from trusted sources.
        """
        with open(vocab_path, "rb") as f:
            return pickle.load(f)

    def save_vocab(self, vocab_path):
        """Pickle this vocabulary to ``vocab_path``."""
        with open(vocab_path, "wb") as f:
            pickle.dump(self, f)


# Building Vocab with text files
class WordVocab(Vocab):
    """Word-level vocabulary built by counting whitespace-separated tokens."""

    def __init__(self, texts, max_size=None, min_freq=1):
        """Count the tokens in ``texts`` and build the vocabulary.

        Args:
            texts: Iterable of lines. A line that is already a list is used as
                its token list; a string line is split on whitespace and only
                its first 4 tokens are counted.
            max_size, min_freq: Forwarded to ``Vocab``.
        """
        print("Building Vocab")
        counter = Counter()
        for line in tqdm.tqdm(texts):
            if isinstance(line, list):
                words = line
            else:
                # NOTE(review): only the first 4 tokens of each string line are
                # counted — presumably deliberate for the corpus format; confirm.
                words = line.replace("\n", " ").replace("\t", " ").split()[:4]

            for word in words:
                counter[word] += 1
        super().__init__(counter, max_size=max_size, min_freq=min_freq)

    def to_seq(self, sentence, seq_len=None, with_eos=False, with_sos=False, with_len=False):
        """Convert a sentence to a list of token ids.

        Args:
            sentence: A string (split on whitespace) or any iterable of tokens.
                Tokens missing from the vocabulary map to ``unk_index``.
            seq_len: If given, the ids are right-padded with ``pad_index`` or
                truncated to exactly this length.
            with_eos: Append ``eos_index``.
            with_sos: Prepend ``sos_index``.
            with_len: Also return the length before padding/truncation.

        Returns:
            The id list, or ``(ids, length_before_padding)`` when ``with_len``.
        """
        if isinstance(sentence, str):
            sentence = sentence.split()

        seq = [self.stoi.get(word, self.unk_index) for word in sentence]

        if with_eos:
            seq += [self.eos_index]  # eos_index is 2
        if with_sos:
            seq = [self.sos_index] + seq

        origin_seq_len = len(seq)

        if seq_len is not None:
            if len(seq) <= seq_len:
                seq += [self.pad_index] * (seq_len - len(seq))
            else:
                seq = seq[:seq_len]

        return (seq, origin_seq_len) if with_len else seq

    def from_seq(self, seq, join=False, with_pad=False):
        """Convert token ids back to tokens.

        Ids outside the vocabulary are rendered as ``"<id>"``. Note the flag
        naming: when ``with_pad`` is True, ``pad_index`` entries are dropped;
        when it is False they are kept.

        Returns:
            A list of tokens, or a single space-joined string when ``join``.
        """
        words = [self.itos[idx]
                 if idx < len(self.itos)
                 else "<%d>" % idx
                 for idx in seq
                 if not with_pad or idx != self.pad_index]

        return " ".join(words) if join else words

    @staticmethod
    def load_vocab(vocab_path: str) -> 'WordVocab':
        """Load a pickled WordVocab from ``vocab_path``.

        Security: ``pickle.load`` can execute arbitrary code; only load
        vocabulary files from trusted sources.
        """
        with open(vocab_path, "rb") as f:
            return pickle.load(f)


def build(argv=None):
    """Command-line entry point: build a WordVocab from a corpus file and pickle it.

    Args:
        argv: Argument list to parse. None (the default) makes argparse read
            ``sys.argv[1:]``. Inside a notebook, pass an explicit list, because
            ``sys.argv`` there holds the kernel's own arguments.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--corpus_path", required=True, type=str)
    parser.add_argument("-o", "--output_path", required=True, type=str)
    parser.add_argument("-s", "--vocab_size", type=int, default=None)
    parser.add_argument("-e", "--encoding", type=str, default="utf-8")
    parser.add_argument("-m", "--min_freq", type=int, default=1)
    args = parser.parse_args(argv)

    with open(args.corpus_path, "r", encoding=args.encoding) as f:
        vocab = WordVocab(f, max_size=args.vocab_size, min_freq=args.min_freq)

    print("VOCAB SIZE:", len(vocab))
    print(vocab)
    vocab.save_vocab(args.output_path)

In [None]:
# Configuration constants
TESTING = False  # when True, use smaller batches for quick experiments

# Architecture-specific constants
EMBEDDING_SIZE = 128
ENCODER_HIDDEN_SIZE = 512
INPUT_ENCODER_HIDDEN_SIZE = 1024

# Model-specific constants
MAX_ASSEMBLY_LINE_LENGTH = 10
# Width of one encoder input row. get_data concatenates the O0, O2 and diff rows
# side by side; presumably two EMBEDDING_SIZE embeddings plus the diff features.
SEQ_LENGTH = 2 * EMBEDDING_SIZE + 4 * MAX_ASSEMBLY_LINE_LENGTH

# Training related constants
if TESTING:
    BATCH_SIZE = 8
else:
    BATCH_SIZE = 32
TRAIN_STEPS = 30
NUM_OF_TESTS = 2
LEARNING_RATE = 0.001

# Palmtree related variables
vocab = WordVocab.load_vocab("../input/palmtreevocab/vocab")
VOCAB_SIZE = len(vocab)

# An all-zero assembly line. NOTE(review): appears unused in this notebook.
EOS_ID = [0] * MAX_ASSEMBLY_LINE_LENGTH

In [None]:
class EncoderLSTM(nn.Module):
    """One LSTM step, scanned along axis 1 (the line axis) of the encoder input.

    The carry is ``(lstm_state, is_eos)``. A row that is entirely zero marks the
    end of a sequence (``get_data`` zero-pads every sequence with at least one
    such row). That first all-zero row is still fed to the LSTM, like an EOS
    token. For every later step of that sequence the LSTM state is kept unchanged.
    """

    @functools.partial(
        nn.transforms.scan,
        variable_broadcast='params',
        in_axes=1,
        out_axes=1,
        split_rngs={'params': False})
    @nn.compact
    def __call__(self, carry, x):
        # x: one line per sequence, presumably (BATCH_SIZE, SEQ_LENGTH) after scanning axis 1.
        lstm_state, is_eos = carry
        new_lstm_state, y = nn.LSTMCell()(lstm_state, x)

        def select_carried_state(new_state, old_state):
            # Sequences that already ended keep their previous state.
            return jnp.where(is_eos[:, np.newaxis], old_state, new_state)

        carried_lstm_state = tuple(select_carried_state(*s) for s in zip(new_lstm_state, lstm_state))
        # A row with no non-zero entry is padding, i.e. the end of the sequence.
        # Rows containing any non-zero value are real lines and must not stop encoding.
        is_eos = jnp.logical_or(is_eos, jnp.logical_not(jnp.any(x, axis=1)))
        return (carried_lstm_state, is_eos), y

    @staticmethod
    def initialize_carry(batch_size, hidden_size):
        """Zero LSTM carry for ``batch_size`` sequences of width ``hidden_size``."""
        # Use a dummy key since the default state init fn is just zeros.
        # NOTE(review): this call signature matches older Flax releases.
        return nn.LSTMCell.initialize_carry(
            jax.random.PRNGKey(0), (batch_size,), hidden_size)

In [None]:
class Encoder(nn.Module):
    """Encodes a batch of line sequences into the final carry of ``EncoderLSTM``.

    Attributes:
        hidden_size: Width of the LSTM state.
    """
    hidden_size: int

    @nn.compact
    def __call__(self, inputs):
        """Run the scanned EncoderLSTM over axis 1 of ``inputs``.

        Args:
            inputs: presumably (batch, MODEL_LINE_COUNT + 1, SEQ_LENGTH) as built
                by ``get_data`` — TODO confirm for other callers.

        Returns:
            The final LSTM state for each sequence.
        """
        batch_size = inputs.shape[0]
        init_lstm_state = EncoderLSTM.initialize_carry(batch_size, self.hidden_size)
        # Builtin `bool`: the deprecated `np.bool` alias was removed in NumPy 1.24.
        init_is_eos = jnp.zeros(batch_size, dtype=bool)
        init_carry = (init_lstm_state, init_is_eos)
        (final_state, _), _ = EncoderLSTM()(init_carry, inputs)
        return final_state

In [None]:
def pad_output(l, max_length, index):
    """Extend a one-hot sequence to ``max_length`` rows with one-hot rows at ``index``.

    Args:
        l: 2-D array of shape (length, num_classes), e.g. one-hot tokens.
        max_length: Number of rows in the result; must be >= ``l.shape[0]``.
        index: Class index that every padding row is one-hot at (the EOS id here).

    Returns:
        jnp array of shape (max_length, num_classes).

    Raises:
        ValueError: if ``l`` already has more than ``max_length`` rows.
    """
    length, num_classes = l.shape
    if length > max_length:
        raise ValueError(f"sequence of length {length} does not fit in max_length={max_length}")
    # Match the input dtype so the concatenation does not mix precisions.
    pad = np.zeros((max_length - length, num_classes), dtype=l.dtype)
    pad[:, index] = 1
    return jnp.concatenate((l, pad), axis=0)

In [None]:
def _pad_rows(rows, target_rows):
    """Zero-pad a 2-D sequence of per-line feature rows along axis 0 to ``target_rows`` rows."""
    arr = np.asarray(rows)
    return np.pad(arr, pad_width=[(0, target_rows - arr.shape[0]), (0, 0)])


def _encode_output_tokens(o2_tokens, output_token_count):
    """One-hot encode the O2 token lines, joined by '<sos>' separators and EOS-padded.

    Each line's token list is followed by '<sos>' as a line separator; the
    separator after the last line is dropped. The result is padded to
    ``output_token_count`` rows with ``pad_output``.
    """
    joined = itertools.chain.from_iterable(tokens + ['<sos>'] for tokens in o2_tokens)
    ids = vocab.to_seq(joined)[:-1]
    return pad_output(jax.nn.one_hot(jnp.array(ids), VOCAB_SIZE), output_token_count, vocab.eos_index)


def get_data(path='../input/smallmsgpack-filtered/data.bin'):
    """Stream training batches from a msgpack file.

    The file starts with a metadata map (``max_line_count_O0``,
    ``max_line_count_O2``), followed by one map per program. Each value of a
    program map is a record with keys 'O0', 'O2', 'diff' (per-line feature
    rows) and 'O2_tokens' (one token list per line).

    Yields:
        First the header ``(MODEL_LINE_COUNT, O0_MODEL_LINE_COUNT, OUTPUT_TOKEN_COUNT)``.
        Then ``(O0, inp, outp)`` batches of BATCH_SIZE records:
          O0:   (BATCH_SIZE, O0_MODEL_LINE_COUNT + 1, EMBEDDING_SIZE)
          inp:  (BATCH_SIZE, MODEL_LINE_COUNT + 1, SEQ_LENGTH), the zero-padded
                O0, O2 and diff rows concatenated along the feature axis
          outp: (BATCH_SIZE, OUTPUT_TOKEN_COUNT, VOCAB_SIZE), one-hot O2 tokens
        A trailing partial batch is dropped.

    Note:
        The same three arrays are reused and overwritten for every batch. Use a
        batch before requesting the next one, or copy it if it must be kept.
    """
    with open(path, "rb") as f:
        data = msgpack.Unpacker(f, raw=False, strict_map_key=False)

        metadata = next(data)

        MODEL_LINE_COUNT = max(metadata["max_line_count_O0"], metadata["max_line_count_O2"])
        O0_MODEL_LINE_COUNT = metadata["max_line_count_O0"]
        OUTPUT_TOKEN_COUNT = MODEL_LINE_COUNT * MAX_ASSEMBLY_LINE_LENGTH

        yield MODEL_LINE_COUNT, O0_MODEL_LINE_COUNT, OUTPUT_TOKEN_COUNT

        O0 = np.zeros((BATCH_SIZE, O0_MODEL_LINE_COUNT + 1, EMBEDDING_SIZE))
        inp = np.zeros((BATCH_SIZE, MODEL_LINE_COUNT + 1, SEQ_LENGTH))
        outp = np.zeros((BATCH_SIZE, OUTPUT_TOKEN_COUNT, VOCAB_SIZE))

        i = 0
        for program in data:
            for line in program.values():
                slot = i % BATCH_SIZE
                # Pad with numpy: the buffers are numpy arrays, so this avoids a
                # host->device->host round trip for every record.
                O0[slot] = _pad_rows(line['O0'], O0_MODEL_LINE_COUNT + 1)
                inp[slot] = np.concatenate(
                    [_pad_rows(line[key], MODEL_LINE_COUNT + 1) for key in ('O0', 'O2', 'diff')],
                    axis=1)
                outp[slot] = _encode_output_tokens(line['O2_tokens'], OUTPUT_TOKEN_COUNT)
                i += 1
                if i % BATCH_SIZE == 0:
                    yield O0, inp, outp

In [None]:
class DecoderLSTM(nn.Module):
    """One decoder step, scanned along axis 1 of the decoder inputs.

    The scan carry is ``(rng, lstm_state, last_prediction)``. Each step runs an
    LSTM cell, then Dense layers that produce vocabulary logits. It samples a
    token from the logits and carries that token's one-hot encoding forward.
    The per-step outputs ``(logits, prediction)`` are stacked along axis 1.
    """
    # If True, the scanned input ``x`` is fed to the LSTM. If False, ``x`` is
    # replaced by the previous step's sampled one-hot prediction, so ``x`` then
    # only determines the number of decoding steps.
    teacher_force: bool

    @functools.partial(
        nn.transforms.scan,
        variable_broadcast='params',
        in_axes=1,
        out_axes=1,
        split_rngs={'params': False})
    @nn.compact
    def __call__(self, carry, x):
        rng, lstm_state, last_prediction = carry
        # One key is carried to the next step; the other drives this step's sampling.
        carry_rng, categorical_rng = jax.random.split(rng, 2)
        if not self.teacher_force:
            x = last_prediction
        lstm_state, y = nn.LSTMCell()(lstm_state, x)
        # NOTE(review): there is no nonlinearity between these Dense layers, so
        # together with the logits layer they form a single affine map. Confirm
        # whether an activation was intended (adding one changes trained-model outputs).
        y = nn.Dense(features=1024)(y)
        y = nn.Dense(features=1024)(y)
        logits = nn.Dense(features=VOCAB_SIZE)(y)
        # The fed-back token is sampled from the logits, not chosen by argmax.
        predicted_token = jax.random.categorical(categorical_rng, logits)
        prediction = jax.nn.one_hot(predicted_token, VOCAB_SIZE, dtype=jnp.float32)
        return (carry_rng, lstm_state, prediction), (logits, prediction)

In [None]:
class Decoder(nn.Module):
    """Runs the scanned DecoderLSTM starting from a given LSTM state.

    Attributes:
        init_state: Initial LSTM carry (a tuple of arrays with a leading batch
            axis), e.g. the encoder's final state.
        teacher_force: Forwarded to DecoderLSTM.
    """
    init_state: Tuple[Any, ...]
    teacher_force: bool

    @nn.compact
    def __call__(self, inputs):
        """Decode one step per position along axis 1 of ``inputs``.

        Args:
            inputs: presumably (batch, steps, features), e.g. the repeated mean
                embedding built in ``train_model`` — TODO confirm for other callers.

        Returns:
            ``(logits, predictions)``, each stacked along axis 1 by the scan.
        """
        lstm = DecoderLSTM(teacher_force=self.teacher_force)
        # Fixed key: token sampling inside DecoderLSTM is identical across calls.
        key = jax.random.PRNGKey(0)
        # Take the batch size from the initial state rather than the global
        # BATCH_SIZE, so any batch size works.
        batch_size = jax.tree_util.tree_leaves(self.init_state)[0].shape[0]
        # NOTE(review): the first "previous prediction" is an all-ones vector, not
        # a one-hot <sos>; kept as-is for compatibility with trained weights.
        init_carry = (key, self.init_state, jnp.ones((batch_size, VOCAB_SIZE), dtype=jnp.float32))
        _, (logits, predictions) = lstm(init_carry, inputs)
        return logits, predictions



In [None]:
def load_model(path, state):
    """Restore a pytree that was serialized with ``flax.serialization.to_bytes``.

    Args:
        path: File to read (e.g. one written by ``save_model``).
        state: Template pytree with the target structure; ``train_model``
            passes ``state.params``.

    Returns:
        A pytree shaped like ``state`` holding the saved values.
    """
    with open(path, 'rb') as f:
        saved_bytes = f.read()
    return flax.serialization.from_bytes(state, saved_bytes)


def save_model(state, path='model.bin'):
    """Serialize ``state.params`` to ``path`` with Flax's msgpack serialization.

    Only the parameters are written, not the optimizer state, so restore them
    with ``load_model(path, state.params)``.

    Args:
        state: A TrainState (or any object with a ``params`` pytree).
        path: Output file. Default: 'model.bin'.
    """
    with open(path, 'wb') as f:
        f.write(flax.serialization.to_bytes(state.params))

In [None]:
class Model(nn.Module):
    """Encoder-decoder model.

    The encoder summarises the concatenated O0 + O2 + diff line features into an
    edit representation. That representation is the decoder's initial state.
    """

    @nn.compact
    def __call__(self, enc_inputs, dec_inputs):
        # Submodules are created in the same order as before (encoder first) so
        # their auto-generated parameter names stay stable.
        encoder = Encoder(hidden_size=ENCODER_HIDDEN_SIZE)
        edit_representation = encoder(enc_inputs)
        decoder = Decoder(teacher_force=False, init_state=edit_representation)
        return decoder(dec_inputs)


In [None]:
def mask_sequences(sequence_batch, lengths):
    """Zero out every position at or beyond each sequence's length.

    Args:
        sequence_batch: 2-D (batch, time) array.
        lengths: (batch,) integer array of valid lengths.

    Returns:
        ``sequence_batch`` multiplied by a mask that is True where ``t < lengths[b]``.
    """
    positions = np.arange(sequence_batch.shape[1])
    keep = lengths[:, np.newaxis] > positions[np.newaxis]
    return sequence_batch * keep

In [None]:
def cross_entropy_loss(logits, labels, lengths):
    """Mean negative log-likelihood over the unmasked token positions.

    Args:
        logits: (batch, seq_length, vocab_size) unnormalised scores.
        labels: One-hot targets with the same shape as ``logits``.
        lengths: (batch,) number of valid positions per sequence.

    Returns:
        Scalar: minus the sum of target log-probabilities over positions
        ``t < lengths[b]``, divided by the total number of such positions.
    """
    log_probs = nn.log_softmax(logits)
    target_log_probs = (log_probs * labels).sum(axis=-1)
    valid_log_probs = mask_sequences(target_log_probs, lengths)
    return -(valid_log_probs.sum() / lengths.sum())


In [None]:
def get_sequence_lengths(sequence_batch, eos_id=vocab.eos_index):
    """Length of each one-hot sequence, up to and including its first EOS.

    Args:
        sequence_batch: (batch_size, seq_length, vocab_size) one-hot array.
        eos_id: Vocabulary index of the EOS token.

    Returns:
        (batch_size,) integer array; a sequence without EOS gets ``seq_length``.
    """
    batch_size, seq_length = sequence_batch.shape[:2]
    eos_row = sequence_batch[:, :, eos_id]
    first_eos = jnp.argmax(eos_row, axis=-1)  # index of the first maximum
    # argmax is also 0 when EOS never occurs, so check the value found there.
    found = eos_row[jnp.arange(batch_size), first_eos]
    return jnp.where(found, first_eos + 1, seq_length)

In [None]:
def compute_metrics(logits, labels):
    """Compute loss and accuracy metrics for one batch.

    Args:
        logits: (batch, seq_length, vocab_size) decoder logits.
        labels: (batch, seq_length, vocab_size) one-hot targets, EOS-padded.

    Returns:
        Dict with:
          'loss': masked cross-entropy;
          'accuracy': fraction of sequences whose tokens are all correct up to EOS;
          'token_accuracy': fraction of correct tokens up to EOS;
          'predicted_out': (seq_length,) label value at the argmax-predicted token
              for each position of example 0 (1 where that prediction is correct);
          'example_out': ``(labels[:1], logits[:1])``, one example for printing.
    """
    lengths = get_sequence_lengths(labels)
    loss = cross_entropy_loss(logits, labels, lengths)
    token_accuracy = jnp.argmax(logits, -1) == jnp.argmax(labels, -1)
    # A sequence is correct only if every unmasked token is correct.
    sequence_accuracy = (
            jnp.sum(mask_sequences(token_accuracy, lengths), axis=-1) == lengths
    )
    accuracy = jnp.mean(sequence_accuracy)

    predictions = jnp.argmax(logits[0], axis=-1)
    # Gather labels[0, t, predictions[t]] for each position t. Writing
    # labels[0, :, predictions] instead would produce a (seq_length, seq_length)
    # matrix, because the slice separates the two advanced indices.
    positions = jnp.arange(labels.shape[1])
    metrics = {
        'loss': loss,
        'accuracy': accuracy,
        'token_accuracy': jnp.sum(mask_sequences(token_accuracy, lengths)) / jnp.sum(lengths),
        'predicted_out': labels[0, positions, predictions],
        'example_out': (labels[:1], logits[:1]),
    }
    return metrics

In [None]:
@jax.jit
def train_step(state, batch):
    """Run one jitted optimisation step.

    Args:
        state: Flax TrainState whose ``apply_fn`` is ``Model.apply``.
        batch: Dict with 'query' (encoder input), 'embed' (decoder input) and
            'answer' (one-hot target tokens).

    Returns:
        ``(updated_state, metrics)``, where ``metrics`` comes from ``compute_metrics``.
    """
    targets = batch['answer']

    def loss_fn(params):
        logits, _ = state.apply_fn({'params': params}, batch['query'], batch['embed'])
        return cross_entropy_loss(logits, targets, get_sequence_lengths(targets)), logits

    (_, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state, compute_metrics(logits, targets)

In [None]:
def _decode_one_hot(one_hot_seq):
    """Decode one (seq_length, VOCAB_SIZE) one-hot sequence into tokens, up to and including its first EOS."""
    length = get_sequence_lengths(one_hot_seq[jnp.newaxis])[0]
    return vocab.from_seq(jnp.argmax(one_hot_seq, axis=-1)[:length])


def train_model(data, path=None):
    """Train for TRAIN_STEPS batches, printing metrics and one decoded example per step.

    Args:
        data: Generator from ``get_data``. Its first item must be the header
            ``(MODEL_LINE_COUNT, O0_MODEL_LINE_COUNT, OUTPUT_TOKEN_COUNT)``.
        path: Optional file written by ``save_model``; if given, the parameters
            are initialised from it.

    Returns:
        ``(state, loss, accuracy, perfect_accuracy)``: the final TrainState and the
        per-step histories of loss, token accuracy and sequence accuracy.

    Side effects:
        Sets the globals MODEL_LINE_COUNT, O0_MODEL_LINE_COUNT and
        OUTPUT_TOKEN_COUNT (``test_model`` reads OUTPUT_TOKEN_COUNT).
    """
    global MODEL_LINE_COUNT, O0_MODEL_LINE_COUNT, OUTPUT_TOKEN_COUNT

    MODEL_LINE_COUNT, O0_MODEL_LINE_COUNT, OUTPUT_TOKEN_COUNT = next(data)

    key = jax.random.PRNGKey(0)

    # Dummy inputs used only for parameter initialisation. The decoder's params
    # are broadcast across scan steps, so a single step is enough here.
    encoder_shape = jnp.ones((BATCH_SIZE, MODEL_LINE_COUNT + 1, SEQ_LENGTH), jnp.float32)
    decoder_shape = jnp.ones((1, 1, EMBEDDING_SIZE), jnp.float32)
    model = Model()

    loss = []
    accuracy = []
    perfect_accuracy = []

    init_params = model.init({"params": key}, encoder_shape, decoder_shape)
    state = train_state.TrainState.create(apply_fn=model.apply, params=init_params["params"],
                                          tx=optax.adam(LEARNING_RATE))
    if path is not None:
        state = state.replace(params=load_model(path, state.params))
    for i, (O0, inp, outp) in zip(range(TRAIN_STEPS), data):
        # Decoder input: the mean O0 line embedding, repeated once per output token.
        # NOTE(review): the mean also includes the zero padding rows.
        palmtree_avg = jnp.repeat(jnp.mean(O0, axis=1, keepdims=True), repeats=OUTPUT_TOKEN_COUNT, axis=1)
        state, metrics = train_step(state, {'query': inp, 'answer': outp, 'embed': palmtree_avg})

        loss.append(metrics["loss"])
        accuracy.append(metrics["token_accuracy"])
        perfect_accuracy.append(metrics["accuracy"])

        print(i, "loss:", metrics["loss"], "perfect_accuracy:", metrics["accuracy"], "token_accuracy",
              metrics["token_accuracy"])
        # example_out holds a single example (batch dim 1), so index 0, not the step i.
        example_labels, example_logits = metrics['example_out']
        print("EXAMPLE", _decode_one_hot(example_labels[0]))
        predicted = jax.nn.one_hot(jnp.argmax(example_logits[0], axis=-1), num_classes=VOCAB_SIZE)
        print("RESULT", _decode_one_hot(predicted))
    return state, loss, accuracy, perfect_accuracy


In [None]:
def test_model(data, state, return_predictions=False):
    """Evaluate token accuracy on up to NUM_OF_TESTS batches from ``data``.

    Reads the global OUTPUT_TOKEN_COUNT, which ``train_model`` sets.

    Args:
        data: Batch generator from ``get_data`` (header already consumed).
        state: TrainState holding the trained parameters.
        return_predictions: If True, also return the decoded predictions (one
            token list per sequence, cut after the predicted EOS) under 'predictions'.

    Returns:
        ``{'accuracy': mean token accuracy over the evaluated batches}``, plus
        'predictions' when requested. Accuracy is 0 if no batch was available.
    """
    results = []
    total_accuracy = 0
    num_batches = 0
    model = Model()
    for O0, inp, outp in itertools.islice(data, NUM_OF_TESTS):
        palmtree_avg = jnp.repeat(jnp.mean(O0, axis=1, keepdims=True), repeats=OUTPUT_TOKEN_COUNT, axis=1)
        logits, _ = model.apply({"params": state.params}, inp, palmtree_avg)
        max_indices = jnp.argmax(logits, axis=-1)
        predicted = jax.nn.one_hot(max_indices, num_classes=VOCAB_SIZE)
        if return_predictions:
            predicted_lengths = get_sequence_lengths(predicted)
            for row, length in zip(max_indices, predicted_lengths):
                results.append(vocab.from_seq(row[:length]))

        target_lengths = get_sequence_lengths(outp)
        # 1 at positions where the argmax prediction matches the one-hot target.
        correct = jnp.sum(predicted * outp, axis=-1)
        total_accuracy += jnp.sum(mask_sequences(correct, target_lengths)) / jnp.sum(target_lengths)
        num_batches += 1

    # Average over the batches actually evaluated, which may be fewer than NUM_OF_TESTS.
    metrics = {'accuracy': total_accuracy / max(num_batches, 1)}
    if return_predictions:
        metrics['predictions'] = results
    return metrics

In [None]:
# Path of a checkpoint written by save_model (e.g. './model.bin') to start from
# pretrained weights; None trains from scratch.
PRETRAINED_MODEL_PATH = None

# train_model consumes the header and the first TRAIN_STEPS batches; test_model
# then evaluates on the batches that follow in the same stream.
data = get_data()
state, loss, accuracy, perfect_accuracy = train_model(data, path=PRETRAINED_MODEL_PATH)
print(test_model(data, state))
save_model(state)


# Plot training loss (left axis) and token/sequence accuracy (right axis) per step.
fig, ax1 = plt.subplots()
# Explicit colors: a twinx axis restarts the color cycle, which would otherwise
# draw loss and token accuracy in the same color.
ax1.plot(loss, color='tab:blue', label='loss')
ax1.set_xlabel("training step")
ax1.set_ylabel("loss")
ax1.set_ylim(bottom=0)
ax2 = ax1.twinx()
ax2.plot(accuracy, color='tab:orange', label='token accuracy')
ax2.plot(perfect_accuracy, color='tab:green', label='sequence accuracy')
ax2.set_ylabel("accuracy")
# One legend covering the lines of both y-axes.
lines = [*ax1.get_lines(), *ax2.get_lines()]
ax1.legend(lines, [line.get_label() for line in lines], loc='best')
ax1.set_title("Training metrics")
fig.tight_layout()
plt.show()