# Train a small GPT-2-like English model for text summarization using the Trax framework

Aim: apply a GPT-2-like model to text summarization. The model is trained on the CNN/Daily Mail dataset (`cnn_dailymail`) and is case-sensitive.

This notebook was created by Vitaly Shklyar.


                        
Transformer specification:
* vocabulary size: 33300 tokens
* embedding size: 512 (default)
* number of heads: 8 (default)
* feed-forward size: 2048 (default)
* number of decoder layers: 6 (default)
* maximum sequence length for the positional encoding: 4096 (set explicitly with `max_len=4096`)
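For a sense of scale, the configuration above can be turned into a rough parameter count. This is a back-of-the-envelope sketch, not read from the Trax source: it assumes untied input and output embeddings, biased Q/K/V/output attention projections, a two-layer feed-forward block with the Trax default width of 2048, LayerNorm scale and bias parameters, and no trainable positional encoding.

```python
# Rough parameter count for the TransformerLM configuration listed above.
# Assumptions (not taken from the Trax source): untied input/output embeddings,
# Q/K/V/output attention projections with biases, a two-layer feed-forward
# block with biases, two LayerNorms per decoder block plus one final LayerNorm,
# and no trainable positional encoding.
vocab_size = 33300
d_model = 512
d_ff = 2048
n_layers = 6

embedding = vocab_size * d_model                      # token embedding table
attention = 4 * (d_model * d_model + d_model)         # Q, K, V and output projections
feed_forward = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)
layer_norms = 2 * (2 * d_model)                       # scale and bias for two LayerNorms
per_layer = attention + feed_forward + layer_norms

final_norm = 2 * d_model
output_layer = d_model * vocab_size + vocab_size      # Dense projection back to the vocabulary

total = embedding + n_layers * per_layer + final_norm + output_layer
print(f"per decoder block: {per_layer:,}")
print(f"total: {total:,} (~{total / 1e6:.0f}M)")
```

Under these assumptions the model has roughly 53 million parameters, nearly two thirds of them in the two 33300 x 512 vocabulary matrices.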

Hardware and framework versions:
+ GPU: GTX 1080 Ti
+ Ubuntu version: 18.04
+ Trax version: 1.3.7
+ Python version: 3.6.9
+ CUDA version: 11.2
+ NVIDIA driver version: 460.27.04




In [1]:
%env XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/local/cuda

env: XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/local/cuda


In [2]:
%env XLA_PYTHON_CLIENT_PREALLOCATE=false

env: XLA_PYTHON_CLIENT_PREALLOCATE=false


In [3]:
%env XLA_PYTHON_CLIENT_ALLOCATOR=platform

env: XLA_PYTHON_CLIENT_ALLOCATOR=platform
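The `%env` magics only work inside Jupyter. In a plain Python script the same settings can be applied through `os.environ`, as in the sketch below. It assumes the same CUDA path, and the variables must be set before `trax` (and therefore `jax`) is imported, because JAX reads them when it initializes the GPU backend.

```python
import os

# Same settings as the three %env cells, for a plain Python script.
# Set them before importing trax/jax, otherwise they have no effect.
os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=/usr/local/cuda"
# Do not reserve most of the GPU memory at start-up.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Allocate on demand and free memory that is no longer used
# (slower, but leaves memory for other processes).
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

# import trax  # import only after the variables above are set
```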


In [4]:
EOS = 1
PAD = 0

In [5]:
import sys
import os
import numpy as np
import textwrap
wrapper = textwrap.TextWrapper(width=70)
import trax
from trax import layers as tl
from trax.fastmath import numpy as jnp
from trax.supervised import training

In [6]:
np.set_printoptions(threshold=sys.maxsize)

<a name='1'></a>
## Function definitions

In [7]:
def tokenize(input_str):
    """ Tokenizes the input string for embedding.
        Input:
            input_str: (string) input string to tokenize
        Return:
            list of integers for embedding
    """
    inputs = next(trax.data.tokenize(iter([input_str]),
                                     vocab_dir='vocab_dir/',
                                     vocab_file='vocabulary_en'))
    return list(inputs)

def detokenize(integer_list):
    """ Converts a list of integers back to text.
        Input:
            integer_list: list of integers
        Return:
            detokenized text, wrapped to 70 characters per line
    """
    s = trax.data.detokenize(integer_list,
                             vocab_dir='vocab_dir/',
                             vocab_file='vocabulary_en')
    return wrapper.fill(s)

In [8]:
np_eos = np.array([EOS])
np_eos_pad = np.array([EOS,PAD])
        
def preprocess(stream):
    '''
    Prepares the input stream for text summarization: concatenates each article and its summary
    and creates the training mask. The article and summary are concatenated in the form:
    article + EOS + PAD + summary + EOS

    The mask is built so that only the summary and the final EOS token contribute
    to the loss.

    Input:
        stream: stream of (article, summary) pairs of tokenized text

    Returns (yields):
        (inputs, targets, mask): the concatenated sequence serves as both inputs and targets
    '''
    for (article, summary) in stream:
        a = np.concatenate([article, np_eos_pad, summary, np_eos], axis=0)
        m = np.ones(len(a))
        # zero weight for the article and the EOS + PAD separator
        m[:len(article) + 2] = 0
        yield a, a, m
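To see what the mask looks like, here is a toy run of the same logic on made-up token ids (the ids 5 to 9 are hypothetical). Only the summary tokens and the final EOS receive weight 1, so the loss ignores the article and the separator.

```python
import numpy as np

EOS = 1
PAD = 0
np_eos = np.array([EOS])
np_eos_pad = np.array([EOS, PAD])

def preprocess(stream):
    # Same logic as the notebook's preprocess(): article + EOS + PAD + summary + EOS,
    # with a mask that is 0 over the article, EOS and PAD and 1 elsewhere.
    for article, summary in stream:
        a = np.concatenate([article, np_eos_pad, summary, np_eos], axis=0)
        m = np.ones(len(a))
        m[:len(article) + 2] = 0
        yield a, a, m

# Hypothetical token ids, chosen only for illustration.
article = np.array([5, 6, 7])
summary = np.array([8, 9])
inputs, targets, mask = next(preprocess(iter([(article, summary)])))
print(inputs)  # [5 6 7 1 0 8 9 1]
print(mask)    # [0. 0. 0. 0. 0. 1. 1. 1.]
```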

In [9]:
def batch_stream(train_stream,
                 max_length=2048,
                 # earlier setting: batch_sizes=[64, 32, 16, 4, 1]
                 boundaries=[64, 128, 512, 2048],
                 batch_sizes=[64, 32, 16, 2, 1]):
    """
    Creates a batched data stream.
        Input:
            train_stream: (bool) stream type. True returns the training stream, False the evaluation stream.
            max_length: (int) maximum example length in tokens; longer examples are dropped.
            boundaries: list of integers, the length boundaries for bucketing
            batch_sizes: list of integers, the batch size for each bucket (one more entry than boundaries)
    """
    if train_stream not in (True, False):
        raise ValueError("Wrong input train_stream value. Allowed values are True or False.")

    return trax.data.Serial(
        trax.data.TFDS('cnn_dailymail',
                       data_dir='data/',
                       keys=('article', 'highlights'),
                       train=train_stream),
        trax.data.Tokenize(vocab_dir='vocab_dir/',
                           vocab_file='vocabulary_en'),
        preprocess,
        trax.data.Shuffle(102400),
        trax.data.FilterByLength(max_length=max_length),
        trax.data.BucketByLength(boundaries=boundaries, batch_sizes=batch_sizes),
        # zero the loss weight wherever the target equals PAD (id 0)
        trax.data.AddLossWeights(id_to_mask=PAD)
    )
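`BucketByLength` groups examples of similar length so that little padding is wasted, and the batch sizes shrink as the length boundaries grow. The sketch below is a simplified model of that bucket assignment (the exact boundary comparison used by Trax is an assumption here), and it shows that the chosen batch sizes keep each batch at roughly 4k to 8k tokens. Because `FilterByLength` drops examples longer than 2048 tokens, the last bucket (batch size 1) should receive few or no examples in this pipeline.

```python
import math

boundaries = [64, 128, 512, 2048]
batch_sizes = [64, 32, 16, 2, 1]

def bucket_index(length, boundaries):
    # Simplified rule (assumption): an example goes to the first bucket whose
    # boundary is at least its length; anything longer goes to the last bucket.
    for i, boundary in enumerate(boundaries):
        if length <= boundary:
            return i
    return len(boundaries)

# Approximate upper bound on tokens per batch for each bucket
# (batch size x boundary); the last bucket has no upper boundary.
for size, boundary in zip(batch_sizes, boundaries + [math.inf]):
    print(f"batch {size:>2} x length <= {boundary}: up to {size * boundary} tokens")
```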


In [10]:
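def training_loop(model, train_gen, eval_gen, steps=50, output_dir="model_gpt2_summarization/"):
    """ Creates the training loop.
        Input:
            model: the Trax model to train
            train_gen, eval_gen: data pipelines returned by batch_stream()
            steps: (int) number of training steps between checkpoints (and evaluations)
            output_dir: (string) directory for checkpoints and logs
        Return:
            trax.supervised.training.Loop
    """
    output_dir = os.path.expanduser(output_dir)
    # linear warm-up over 1000 steps to 0.002, then reciprocal-square-root decay
    lr_schedule = trax.lr.warmup_and_rsqrt_decay(n_warmup_steps=1000, max_value=0.002)

    train_task = training.TrainTask(
        labeled_data=train_gen(),
        loss_layer=tl.WeightedCategoryCrossEntropy(),
        # the schedule above sets the learning rate at every step
        optimizer=trax.optimizers.Adam(0.002),
        lr_schedule=lr_schedule,
        n_steps_per_checkpoint=steps
    )

    eval_task = training.EvalTask(
        labeled_data=eval_gen(),
        metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()]
    )

    loop = training.Loop(model,
                         train_task,
                         eval_tasks=[eval_task],
                         output_dir=output_dir)
    return loop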
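The learning rate in `training_loop` comes from `warmup_and_rsqrt_decay`. A plain-Python sketch of a warm-up plus reciprocal-square-root schedule, assuming the usual shape (a linear ramp to `max_value` over `n_warmup_steps`, then decay proportional to `1/sqrt(step)`; check the Trax source for the exact formula), gives a feel for the numbers. By step 797000, the step reported in the training log, the rate has dropped to about 7e-5.

```python
import math

def warmup_rsqrt(step, n_warmup_steps=1000, max_value=0.002):
    # Assumed shape of the schedule: linear warm-up to max_value,
    # then decay proportional to 1/sqrt(step).
    if step < n_warmup_steps:
        return max_value * step / n_warmup_steps
    return max_value * math.sqrt(n_warmup_steps / step)

for step in (500, 1000, 4000, 797000):
    print(step, warmup_rsqrt(step))
```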

In [11]:
def next_random_token(input_tokens, model, temperature=0.0):
    """ Generates the next token given the input sequence.

    Input:
        input_tokens (list): tokenized text as a list of integers
        model: language model
        temperature: (float) sampling temperature; 0.0 gives greedy (argmax) decoding,
                     small positive values add some randomness to the token choice

    Returns:
        int: generated token
    """

    token_length = len(input_tokens)
    # pad the input to a power of two, 2**N, strictly longer than the input so that
    # position token_length exists; this also limits the number of distinct input shapes
    padded_length = 2**int(np.ceil(np.log2(token_length + 1)))
    padded = input_tokens + [PAD] * (padded_length - token_length)
    padded_with_batch = np.array(padded)[None, :]

    output, _ = model((padded_with_batch, padded_with_batch))
    # TransformerLM in Trax 1.3.7 has no final (log-)softmax layer, so the output holds logits.
    # The model shifts its input right, so the output at position token_length predicts the next token.
    logits = output[0, token_length, :]
    # numerically stable log-softmax (subtract the max before exponentiating)
    logits = logits - np.max(logits)
    log_probs = logits - np.log(np.sum(np.exp(logits)))
    return tl.logsoftmax_sample(log_probs, temperature)



def get_summary(input_text, model, temperature=0.0, max_length=256):
    """ Generates a summary of the input text.

    Input:
        input_text (string): a sentence or an article
        model: (TransformerLM) Transformer language model
        temperature: (float) sampling temperature; 0.0 gives greedy decoding
        max_length: (int) maximum number of generated tokens

    Returns:
        generated summary as a string
    """
    # the prompt mirrors the training format: article + EOS + PAD
    input_tokens = tokenize(input_text) + [EOS] + [PAD]

    generated_output = []
    new_token = 0

    count = 0
    while new_token != EOS and count < max_length:
        new_token = next_random_token(input_tokens, model, temperature=temperature)
        input_tokens.append(new_token)
        generated_output.append(new_token)
        count += 1
    # drop the final EOS so that it does not appear in the detokenized text
    if generated_output and generated_output[-1] == EOS:
        generated_output = generated_output[:-1]
    return detokenize(generated_output)
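The two numerical steps inside `next_random_token` can be checked in isolation with plain NumPy: the power-of-two padding rule, and temperature sampling from log-probabilities. The sketch assumes that `tl.logsoftmax_sample` uses the Gumbel-max trick (adding temperature-scaled Gumbel noise and taking the argmax); the 4-token logits are made up for illustration.

```python
import numpy as np

def padded_length(token_length):
    # Same rule as next_random_token: smallest power of two strictly greater
    # than token_length, so position token_length always exists.
    return 2 ** int(np.ceil(np.log2(token_length + 1)))

def log_softmax(logits):
    # Numerically stable log-softmax: subtract the max before exponentiating.
    shifted = logits - np.max(logits)
    return shifted - np.log(np.sum(np.exp(shifted)))

def gumbel_sample(log_probs, temperature, rng):
    # Gumbel-max trick: argmax(log_probs + T * g) with g ~ Gumbel(0, 1).
    # For T > 0 this samples from softmax(log_probs / T); T = 0 is greedy argmax.
    u = rng.uniform(low=1e-6, high=1.0 - 1e-6, size=log_probs.shape)
    g = -np.log(-np.log(u))
    return int(np.argmax(log_probs + g * temperature))

rng = np.random.default_rng(0)
logits = np.array([2.0, 1.0, 0.5, -1.0])  # hypothetical scores for a 4-token vocabulary
lp = log_softmax(logits)
print([padded_length(n) for n in (5, 7, 8)])  # [8, 8, 16]
print(gumbel_sample(lp, 0.0, rng))            # 0 (greedy)
counts = np.bincount([gumbel_sample(lp, 1.0, rng) for _ in range(10000)], minlength=4)
print(counts / counts.sum(), np.exp(lp))      # empirical vs. model probabilities
```

With `temperature=1.0` the empirical frequencies track the softmax probabilities; lowering the temperature concentrates the samples on the most likely token.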

## Train the model

In [12]:
model = trax.models.TransformerLM(vocab_size=33300,
                                  dropout=0.1,
                                  max_len=4096,
                                  ff_activation=tl.Relu,
                                  mode='train')

In [13]:
loop = training_loop(model, batch_stream(train_stream=True), batch_stream(train_stream=False),1000)

In [14]:
#loop.run(80000)  
loop.run(1000) 


Step  797000: Ran 1000 train steps in 485.41 secs
Step  797000: train WeightedCategoryCrossEntropy |  2.63424301
Step  797000: eval  WeightedCategoryCrossEntropy |  3.19755816
Step  797000: eval      WeightedCategoryAccuracy |  0.49704143
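The weighted cross-entropy in the log is an average over the unmasked positions only, that is, the summary tokens and the final EOS. Exponentiating it gives perplexity, which is easier to interpret. The sketch assumes the loss is measured in nats (natural logarithm); under that assumption an eval perplexity of about 24 means the model is, on average, about as uncertain as a uniform choice among roughly 24 tokens at each summary position.

```python
import math

# Cross-entropy values copied from the training log above
# (assumed to be average nats per unmasked token).
train_ce = 2.63424301
eval_ce = 3.19755816

train_ppl = math.exp(train_ce)
eval_ppl = math.exp(eval_ce)
print(f"train perplexity ~ {train_ppl:.1f}, eval perplexity ~ {eval_ppl:.1f}")
```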


## Test the model

In [15]:
del loop
del model

In [16]:
model = trax.models.TransformerLM(vocab_size=33300,
                                  max_len=4096,
                                  ff_activation=tl.Relu,
                                  mode='eval')
# Load trained weights
model.init_from_file('model_gpt2_summarization/model.pkl.gz', weights_only=True)

## Get summary for the text

In [19]:
# Original text from CNN:
# https://edition.cnn.com/2021/03/07/world/harry-meghan-oprah-interview-preview-scli-gbr-intl/index.html
input_sentence ="Millions of television viewers around the world are in a state of frenzied anticipation on Sunday, as they await the broadcast of Oprah Winfrey's set-piece interview with Prince Harry and Meghan, Duchess of Sussex. The primetime event, which will be shown on Sunday evening in the United States, has been relentlessly promoted by network CBS and threatens to lift the lid on a litany of frustrations and grievances held by the couple against the institution they quit last year.I don't know how they could expect that after all of this time, we would still just be silent if there is an active role that The Firm is playing in perpetuating falsehoods about us, Meghan said in a clip already released, hinting that she is ready to escalate a war of words between herself and the family she married into."
print(wrapper.fill(input_sentence), '\n')

print("Summary:")
generated_text = get_summary(input_sentence, model, temperature=0.2)
print(generated_text)

Millions of television viewers around the world are in a state of
frenzied anticipation on Sunday, as they await the broadcast of Oprah
Winfrey's set-piece interview with Prince Harry and Meghan, Duchess of
Sussex. The primetime event, which will be shown on Sunday evening in
the United States, has been relentlessly promoted by network CBS and
threatens to lift the lid on a litany of frustrations and grievances
held by the couple against the institution they quit last year.I don't
know how they could expect that after all of this time, we would still
just be silent if there is an active role that The Firm is playing in
perpetuating falsehoods about us, Meghan said in a clip already
released, hinting that she is ready to escalate a war of words between
herself and the family she married into. 

Summary
Oprah Winfrey's set-piece interview with Prince Harry and Meghan
Duchess of Sussex . The pair are in a state of frenzied ffron Sunday,
as they await the broadcast of Oprah Winfrey's set-p

## Summary
* The model is a fairly small language model trained on a small corpus (about 3 GB).
* Its major shortcoming is that it often repeats itself: in the sample above, the summary starts copying the article and repeats "Oprah Winfrey's set-piece".
* Introducing a small sampling temperature helps to reduce the repetition.

### How to improve:
* scale up the model
* train the language model on a larger corpus
* use MBR (minimum Bayes risk) decoding or beam search for summarization (this will slow down the computation)
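The MBR suggestion can be sketched independently of Trax. Minimum Bayes risk decoding generates several candidate summaries (for example by calling `get_summary` repeatedly with a non-zero temperature) and returns the one that agrees most with the others on average. The sketch below uses a simple unigram-overlap F1 as the similarity, a stand-in assumption for a proper metric such as ROUGE; the candidate strings are made up.

```python
from collections import Counter

def unigram_f1(a, b):
    # Simple ROUGE-1-style F1 between two token lists (a stand-in for a real metric).
    overlap = sum((Counter(a) & Counter(b)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(a)
    recall = overlap / len(b)
    return 2 * precision * recall / (precision + recall)

def mbr_select(candidates):
    # Minimum Bayes risk decoding: pick the candidate with the highest average
    # similarity to all the other candidates (i.e. the lowest expected loss).
    best, best_score = None, -1.0
    for i, cand in enumerate(candidates):
        others = [c for j, c in enumerate(candidates) if j != i]
        score = sum(unigram_f1(cand, o) for o in others) / len(others)
        if score > best_score:
            best, best_score = cand, score
    return best

# Hypothetical sampled summaries, already split into words.
samples = [
    "oprah interviews harry and meghan".split(),
    "oprah interviews harry and meghan on sunday".split(),
    "harry and meghan talk to oprah".split(),
    "cbs promotes the broadcast".split(),
]
print(" ".join(mbr_select(samples)))  # oprah interviews harry and meghan
```

The cost is visible in the structure: N candidates need N full generations plus N x (N - 1) similarity computations, which is why MBR slows down summarization.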