**Inference using a Decoder-only Model**

This notebook was run on Colab for ease of use. <br>
The notebook expects: <br>
- The trained model "model.pkl.gz" is placed in the model folder. <br>
- The trained tokenization model is placed in vocab_dir. <br>
This notebook also includes the functions required to build the model architecture in preparation for evaluation. <br>

References:
- https://www.coursera.org/specializations/natural-language-processing
- https://github.com/LaurentVeyssier/TRAX_transformer_abstractive_summarization_model/blob/main/TRAX_transformer_summarizer_model.ipynb

**Install all the libraries**

In [None]:
!pip install trax
!pip install t5
!pip install sentencepiece

**Import libraries**

In [None]:
import pandas as pd
import numpy as np
import tensorflow as tf
import time
import re
import pickle
from sklearn.model_selection import train_test_split
import sentencepiece as spm
import t5

If you are running this notebook on Colab, mount Google Drive and change to the correct project path.

In [None]:
from google.colab import drive
drive.mount('/content/drive')
%cd "/content/drive/MyDrive/ML project"

Mounted at /content/drive
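
Once the working directory is set, a quick check that the expected files are present can save a failed run later. The sketch below is only illustrative: the two paths are the ones used in later cells of this notebook, and the helper name `find_missing` is hypothetical.

```python
from pathlib import Path

# Paths taken from later cells in this notebook; adjust them to your own layout.
EXPECTED_FILES = [
    Path("model/model.pkl.gz"),      # trained model weights
    Path("vocab_dir/spm_v4.model"),  # trained SentencePiece model
]

def find_missing(paths):
    """Return the subset of `paths` that does not exist on disk (hypothetical helper)."""
    return [p for p in paths if not Path(p).exists()]

missing = find_missing(EXPECTED_FILES)
if missing:
    print("Missing files:", ", ".join(str(p) for p in missing))
else:
    print("All expected files found.")
```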


In [None]:
import trax
# Preprocess your data to prepare for Trax's TFDS structure.
# NOTE: `train` and `test` are not defined in this notebook; they are assumed to be
# pandas DataFrames with 'article' and 'abstract' columns loaded beforehand.
# Each element is an (article, abstract) tuple.
prepared_data = list(zip(train['article'], train['abstract']))
val_prepared_data = list(zip(test['article'], test['abstract']))


def data_generator(data):
    for text, label in data:
        yield text, label

# Create a generator from your prepared data
train_data_stream = data_generator(prepared_data)
val_data_stream = data_generator(val_prepared_data)

In [None]:
import textwrap
wrapper = textwrap.TextWrapper(width=70)

In [None]:
def tokenize(input_str, EOS=4):
    """Convert a string to a list of token ids, ready for inference.

    The EOS argument is not used here; greedy_decode appends the EOS id itself.
    """

    inputs =  next(trax.data.tokenize(iter([input_str]),
                                      vocab_dir='vocab_dir/',
                                      vocab_file='spm_v4.model', ## change to the trained tokenization model
                                      vocab_type='sentencepiece'))

    return list(inputs)

def detokenize(integers):
    """List of ints to str"""

    s = trax.data.detokenize(integers,
                             vocab_dir='vocab_dir/',
                             vocab_file='spm_v4.model', ## change to the trained tokenization model
                             vocab_type='sentencepiece')

    return wrapper.fill(s)

In [None]:
sp = spm.SentencePieceProcessor()
sp.Load("vocab_dir/spm_v4.model")  # Load the trained model

pad_id = sp.piece_to_id('<pad>')


In [None]:
pad_id

3

In [None]:
eos_id = sp.piece_to_id('<EOS>')  # Get the ID of the <EOS> token
# Special tokens
SEP = pad_id # Padding or separator token
EOS = eos_id

**Functions to build the transformer model**

In [None]:
from trax import layers as tl
from trax.fastmath import numpy as jnp

In [None]:
# DotProductAttention
def DotProductAttention(query, key, value, mask):
    """Dot product self-attention.
    Args:
        query (jax.interpreters.xla.DeviceArray): array of query representations with shape (L_q by d)
        key (jax.interpreters.xla.DeviceArray): array of key representations with shape (L_k by d)
        value (jax.interpreters.xla.DeviceArray): array of value representations with shape (L_k by d) where L_v = L_k
        mask (jax.interpreters.xla.DeviceArray): attention-mask, gates attention with shape (L_q by L_k)

    Returns:
        jax.interpreters.xla.DeviceArray: Self-attention array for q, k, v arrays. (L_q by d)
    """

    assert query.shape[-1] == key.shape[-1] == value.shape[-1], "Embedding dimensions of q, k, v aren't all the same"

    depth = query.shape[-1]
    dots = jnp.matmul(query, jnp.swapaxes(key, -2, -1)) / jnp.sqrt(depth)
    if mask is not None:
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
    logsumexp = trax.fastmath.logsumexp(dots, axis=-1, keepdims=True)

    dots = jnp.exp(dots - logsumexp)
    attention = jnp.matmul(dots, value)

    return attention
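
The function above turns scores into attention weights with `exp(dots - logsumexp(dots))` rather than a direct softmax. A minimal NumPy sketch of why that is equivalent is shown below; `logsumexp_np` is a hypothetical stand-in written for this illustration, not the Trax function.

```python
import numpy as np

def logsumexp_np(x, axis=-1):
    """Stable log(sum(exp(x))): subtract the row maximum before exponentiating."""
    m = np.max(x, axis=axis, keepdims=True)
    return m + np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True))

rng = np.random.default_rng(0)
# Scores large enough that a naive np.exp(scores) can overflow.
scores = rng.normal(size=(4, 4)) * 1000.0

# Same form as in DotProductAttention: exp(dots - logsumexp(dots)).
weights = np.exp(scores - logsumexp_np(scores))

# Reference: softmax computed on max-shifted scores.
shifted = scores - scores.max(axis=-1, keepdims=True)
reference = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)

print(np.allclose(weights, reference))
print(weights.sum(axis=-1))  # each row sums to 1
```

Each row of `weights` is a probability distribution over key positions, which is then used to average the rows of `value`.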

In [None]:
# compute_attention_heads_closure
def compute_attention_heads_closure(n_heads, d_head):
    """ Function that simulates environment inside CausalAttention function.
    Args:
        d_head (int):  dimensionality of heads.
        n_heads (int): number of attention heads.
    Returns:
        function: compute_attention_heads function
    """

    def compute_attention_heads(x):
        """ Compute the attention heads.
        Args:
            x (jax.interpreters.xla.DeviceArray): tensor with shape (batch_size, seqlen, n_heads X d_head).
        Returns:
            jax.interpreters.xla.DeviceArray: reshaped tensor with shape (batch_size X n_heads, seqlen, d_head).
        """
        batch_size = x.shape[0]
        seqlen = x.shape[1]
        x = jnp.reshape(x,(batch_size, seqlen, n_heads, d_head))
        x = jnp.transpose(x, (0, 2, 1, 3))
        x = jnp.reshape(x, (batch_size*n_heads, seqlen, d_head))

        return x

    return compute_attention_heads

In [None]:
# dot_product_self_attention
def dot_product_self_attention(q, k, v):
    """ Masked dot product self attention.
    Args:
        q (jax.interpreters.xla.DeviceArray): queries.
        k (jax.interpreters.xla.DeviceArray): keys.
        v (jax.interpreters.xla.DeviceArray): values.
    Returns:
        jax.interpreters.xla.DeviceArray: masked dot product self attention tensor.
    """
    mask_size = q.shape[-2]
    mask = jnp.tril(jnp.ones((1, mask_size, mask_size), dtype=jnp.bool_), k=0)


    return DotProductAttention(q, k, v, mask)
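
To make the causal mask concrete, here is a small NumPy sketch (NumPy is used instead of `jnp` only so it runs without Trax). It builds the same lower-triangular mask for a sequence of length 4 and applies it to a dummy score tensor; the batch size of 6 is an arbitrary assumption.

```python
import numpy as np

mask_size = 4
mask = np.tril(np.ones((1, mask_size, mask_size), dtype=bool), k=0)
print(mask[0].astype(int))
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]]

# The leading dimension of 1 broadcasts over (batch_size * n_heads).
scores = np.zeros((6, mask_size, mask_size))
masked = np.where(mask, scores, -1e9)
print(masked.shape)  # (6, 4, 4)
```

Row i keeps only columns 0..i, so position i can attend to itself and earlier positions but never to later ones.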

In [None]:
# compute_attention_output_closure
def compute_attention_output_closure(n_heads, d_head):
    """ Function that simulates environment inside CausalAttention function.
    Args:
        d_head (int):  dimensionality of heads.
        n_heads (int): number of attention heads.
    Returns:
        function: compute_attention_output function
    """

    def compute_attention_output(x):
        """ Compute the attention output.
        Args:
            x (jax.interpreters.xla.DeviceArray): tensor with shape (batch_size X n_heads, seqlen, d_head).
        Returns:
            jax.interpreters.xla.DeviceArray: reshaped tensor with shape (batch_size, seqlen, n_heads X d_head).
        """
        seqlen = x.shape[1]
        x = jnp.reshape(x, (-1, n_heads, seqlen, d_head))
        x = jnp.transpose(x, (0, 2, 1, 3))
        return jnp.reshape(x, (-1, seqlen, n_heads * d_head))

    return compute_attention_output
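
The two reshaping helpers are inverses of each other. The NumPy sketch below mirrors their reshape/transpose steps (the names `split_heads` and `merge_heads` and the small dimensions are assumptions for illustration) and checks the round trip.

```python
import numpy as np

batch_size, seqlen, n_heads, d_head = 2, 3, 4, 5

def split_heads(x):
    """(batch, seqlen, n_heads * d_head) -> (batch * n_heads, seqlen, d_head)."""
    x = x.reshape(batch_size, seqlen, n_heads, d_head)
    x = x.transpose(0, 2, 1, 3)
    return x.reshape(batch_size * n_heads, seqlen, d_head)

def merge_heads(x):
    """(batch * n_heads, seqlen, d_head) -> (batch, seqlen, n_heads * d_head)."""
    x = x.reshape(-1, n_heads, seqlen, d_head)
    x = x.transpose(0, 2, 1, 3)
    return x.reshape(-1, seqlen, n_heads * d_head)

x = np.arange(batch_size * seqlen * n_heads * d_head, dtype=np.float32)
x = x.reshape(batch_size, seqlen, n_heads * d_head)

heads = split_heads(x)
restored = merge_heads(heads)

print(heads.shape)                  # (8, 3, 5)
print(restored.shape)               # (2, 3, 20)
print(np.array_equal(x, restored))  # True
```

After splitting, head `h` of batch element `b` sits at index `b * n_heads + h` and holds the feature slice `h * d_head : (h + 1) * d_head`.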

In [None]:
def CausalAttention(d_feature,
                    n_heads,
                    compute_attention_heads_closure=compute_attention_heads_closure,
                    dot_product_self_attention=dot_product_self_attention,
                    compute_attention_output_closure=compute_attention_output_closure,
                    mode='train'):
    """Transformer-style multi-headed causal attention.

    Args:
        d_feature (int):  dimensionality of feature embedding.
        n_heads (int): number of attention heads.
        compute_attention_heads_closure (function): Closure around compute_attention heads.
        dot_product_self_attention (function): dot_product_self_attention function.
        compute_attention_output_closure (function): Closure around compute_attention_output.
        mode (str): 'train' or 'eval'.

    Returns:
        trax.layers.combinators.Serial: Multi-headed self-attention model.
    """

    assert d_feature % n_heads == 0
    d_head = d_feature // n_heads
    ComputeAttentionHeads = tl.Fn('AttnHeads', compute_attention_heads_closure(n_heads, d_head), n_out=1)


    return tl.Serial(
        tl.Branch(
            [tl.Dense(d_feature), ComputeAttentionHeads],
            [tl.Dense(d_feature), ComputeAttentionHeads],
            [tl.Dense(d_feature), ComputeAttentionHeads],
        ),

        tl.Fn('DotProductAttn', dot_product_self_attention, n_out=1),
        tl.Fn('AttnOutput', compute_attention_output_closure(n_heads, d_head), n_out=1),
        tl.Dense(d_feature),
    )



In [None]:
# DecoderBlock
def DecoderBlock(d_model, d_ff, n_heads,
                 dropout, mode, ff_activation):
    """Returns a list of layers that implements a Transformer decoder block.

    The input is an activation tensor.

    Args:
        d_model (int):  depth of embedding.
        d_ff (int): depth of feed-forward layer.
        n_heads (int): number of attention heads.
        dropout (float): dropout rate (how much to drop out).
        mode (str): 'train' or 'eval'.
        ff_activation (function): the non-linearity in feed-forward layer.

    Returns:
        list: list of trax.layers.combinators.Serial that maps an activation tensor to an activation tensor.
    """
    causal_attention = CausalAttention(
                        d_model,
                        n_heads=n_heads,
                        mode=mode
                        )

    feed_forward = [
        tl.LayerNorm(),
        tl.Dense(d_ff),
        ff_activation(),
        tl.Dropout(rate=dropout, mode=mode),
        tl.Dense(d_model),
        tl.Dropout(rate=dropout, mode=mode)
    ]

    return [
      tl.Residual(
          tl.LayerNorm(),
          causal_attention,
          tl.Dropout(rate=dropout, mode=mode),
        ),
      tl.Residual(
          feed_forward
        ),
      ]
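
Both halves of the block follow the same pattern: normalize, apply a sublayer, then add the result back to the input. A rough NumPy sketch of that pattern is below. It is a simplification: `tl.LayerNorm` also has learned scale and bias parameters, which are omitted here, and the sublayer is a placeholder.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    """Normalize the last axis to zero mean and unit variance (no learned parameters)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def residual(x, sublayer):
    """Pre-norm residual connection: x + sublayer(LayerNorm(x))."""
    return x + sublayer(layer_norm(x))

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, 8))              # (batch, seqlen, d_model)
out = residual(x, lambda h: 0.5 * h)        # placeholder sublayer
print(out.shape)  # (2, 3, 8)
```

Because the input is added back unchanged, a sublayer that outputs zeros leaves the activations untouched, which is one reason this layout is used when stacking several blocks.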


In [None]:
# TransformerLM
def TransformerLM(vocab_size=32000,
                  d_model=512,
                  d_ff=2048,
                  n_layers=6,
                  n_heads=8,
                  dropout=0.1,
                  max_len=4096,
                  mode='train',
                  ff_activation=tl.Relu):
    """Returns a Transformer language model.

    The input to the model is a tensor of tokens. (This model uses only the
    decoder part of the overall Transformer.)

    Args:
        vocab_size (int): vocab size.
        d_model (int):  depth of embedding.
        d_ff (int): depth of feed-forward layer.
        n_layers (int): number of decoder layers.
        n_heads (int): number of attention heads.
        dropout (float): dropout rate (how much to drop out).
        max_len (int): maximum symbol length for positional encoding.
        mode (str): 'train', 'eval' or 'predict', predict mode is for fast inference.
        ff_activation (function): the non-linearity in feed-forward layer.

    Returns:
        trax.layers.combinators.Serial: A Transformer language model as a layer that maps from a tensor of tokens
        to activations over a vocab set.
    """
    positional_encoder = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len, mode=mode)]

    decoder_blocks = [DecoderBlock(d_model, d_ff, n_heads, dropout, mode, ff_activation) for _ in range(n_layers)]

    return tl.Serial(
        tl.ShiftRight(mode=mode),
        positional_encoder,
        decoder_blocks,
        tl.LayerNorm(),
        tl.Dense(vocab_size),
        tl.LogSoftmax()
    )
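
The first layer, `tl.ShiftRight`, is what makes the model predict the next token: outside of 'predict' mode it shifts the input one position to the right, so the output at position i only depends on tokens before i. A NumPy sketch of that shift (an illustration, not the Trax implementation) looks like this:

```python
import numpy as np

def shift_right(tokens, n_positions=1):
    """Pad zeros on the left of the sequence axis and drop the same number on the right."""
    padded = np.pad(tokens, [(0, 0), (n_positions, 0)])
    return padded[:, :-n_positions]

tokens = np.array([[5, 6, 7, 8]])   # (batch, seqlen); token ids are made up
print(shift_right(tokens))  # [[0 5 6 7]]
```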


In [None]:
# Get the model architecture
model = TransformerLM(mode='eval')

# Load the pre-trained weights (adjust the file name and path to your model)
model.init_from_file('model/model.pkl.gz', weights_only=False)

In [None]:
## helper function to get the next token from the model
def next_symbol(cur_output_tokens, model):
    """Returns the next symbol for a given sentence.

    Args:
        cur_output_tokens (list): tokenized sentence with EOS and PAD tokens at the end.
        model (trax.layers.combinators.Serial): The transformer model.

    Returns:
        int: tokenized symbol.
    """
    token_length = len(cur_output_tokens)
    padded_length = 2**int(np.ceil(np.log2(token_length + 1)))

    # 3 is the <pad> id reported by the SentencePiece model above (pad_id).
    padded = cur_output_tokens + [3] * (padded_length - token_length)
    padded_with_batch = np.array(padded)[None, :]
    output, _ = model((padded_with_batch, padded_with_batch))
    log_probs = output[0, token_length, :]

    return int(np.argmax(log_probs))
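
The padding rule in `next_symbol` rounds the input length up to a power of two that is strictly greater than the number of tokens, so index `token_length` always exists in the model output. A small sketch of just that arithmetic (the helper name `padded_length_for` is hypothetical):

```python
import numpy as np

def padded_length_for(token_length):
    """Smallest power of two strictly greater than token_length."""
    return 2 ** int(np.ceil(np.log2(token_length + 1)))

for n in (1, 3, 4, 7, 8, 100):
    print(n, padded_length_for(n))
# 1 2
# 3 4
# 4 8
# 7 8
# 8 16
# 100 128
```

Padding to a small set of lengths also limits how many distinct input shapes the model sees, which is a common way to reduce recompilation with JAX-based backends.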

In [None]:
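def greedy_decode(input_sentence, model, next_symbol=next_symbol, tokenize=tokenize, detokenize=detokenize):
    """Greedy decode function.

    Args:
        input_sentence (string): a sentence or article.
        model (trax.layers.combinators.Serial): Transformer model.
        next_symbol (function): returns the next token id for the current tokens.
        tokenize (function): converts a string to a list of token ids.
        detokenize (function): converts a list of token ids to a string.

    Returns:
        string: summary of the input.
    """

    # Prompt layout: article tokens, then EOS (4), then the pad/separator id (3).
    cur_output_tokens = tokenize(input_sentence) + [4, 3]
    generated_output = []
    cur_output = 0
    EOS = 4

    # NOTE: this loop stops only when the model emits EOS; there is no length cap.
    while cur_output != EOS: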
        cur_output = next_symbol(cur_output_tokens, model)
        cur_output_tokens.append(cur_output)
        generated_output.append(cur_output)

        print(detokenize(generated_output))

    return detokenize(generated_output)
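
Because the loop above ends only when the model produces EOS, a model that never emits it will run until interrupted. One possible variant, sketched here with a toy stand-in for the model, caps the number of generated tokens. The names `greedy_decode_bounded`, `max_new_tokens`, and `toy_next_symbol` are assumptions for illustration and are not part of the original notebook.

```python
def greedy_decode_bounded(prompt_tokens, next_symbol_fn, eos_id=4, max_new_tokens=50):
    """Greedy decoding that stops at EOS or after max_new_tokens, whichever comes first."""
    tokens = list(prompt_tokens)
    generated = []
    for _ in range(max_new_tokens):
        nxt = next_symbol_fn(tokens)
        tokens.append(nxt)
        generated.append(nxt)
        if nxt == eos_id:
            break
    return generated

def toy_next_symbol(tokens):
    """Toy stand-in for the model: emits EOS (4) once six tokens are present."""
    return 4 if len(tokens) >= 6 else 10 + len(tokens)

print(greedy_decode_bounded([7, 8, 4, 3], toy_next_symbol))              # [14, 15, 4]
print(greedy_decode_bounded([7, 8, 4, 3], lambda t: 9, max_new_tokens=5))  # [9, 9, 9, 9, 9]
```

With the real model, `next_symbol_fn` could be `lambda toks: next_symbol(toks, model)`, and the returned ids would be passed to `detokenize`.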


**Generate a summary using the model**

In [None]:
test_sentence = " patients with cervical radiculopathy due to single  level degenerative disc disease who fail to improve with nonoperative therapy are candidates for anterior decompression and reconstruction with either an arthrodesis or arthroplasty  both procedures require minimal hospitalization and are highly effective in relieving pain and improving neurological function  the ability to return to work and the speed with which this occurs are important to the individual being treated and also to society  arthrodesis and arthroplasty differ in that one treatment eliminates motion at a cervical spinal segment while the other preserves it  this fundamental difference may impact postoperative function in terms of activity level  which ultimately could facilitate or hinder the ability to return to work  in patients with degenerative disease of the cervical spine  does cervical artificial disc replacement lead to better work  related outcomes than fusion  does return to work after surgery differ based on gender  age  smoking  l"
print(wrapper.fill(test_sentence), '\n')
print(greedy_decode(test_sentence, model))

 patients with cervical radiculopathy due to single  level
degenerative disc disease who fail to improve with nonoperative
therapy are candidates for anterior decompression and reconstruction
with either an arthrodesis or arthroplasty  both procedures require
minimal hospitalization and are highly effective in relieving pain and
improving neurological function  the ability to return to work and the
speed with which this occurs are important to the individual being
treated and also to society  arthrodesis and arthroplasty differ in
that one treatment eliminates motion at a cervical spinal segment
while the other preserves it  this fundamental difference may impact
postoperative function in terms of activity level  which ultimately
could facilitate or hinder the ability to return to work  in patients
with degenerative disease of the cervical spine  does cervical
artificial disc replacement lead to better work  related outcomes than
fusion  does return to work after surgery differ based o

KeyboardInterrupt: 