# Train a small GPT-2-like English model using the Trax framework

Aim: train a small GPT-2-like English language model. The model is trained on the CNN/Daily Mail dataset (`cnn_dailymail` in TensorFlow Datasets) and is case-sensitive. The model can be used for tasks such as text summarization and sentiment analysis.

This notebook was created by Vitaly Shklyar.
                        
Transformer specification:
* vocabulary size: 33300 words
* embedding size: 512 (default)
* number of heads: 8 (default)
* feedforward size: 2048 (default)
* number of decoder layers: 6 (default)
* position embedding length: 4096
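
For a sense of scale, the specification above implies a model of roughly 53 million weights. The sketch below is a back-of-the-envelope estimate only. It assumes (none of this is stated in the notebook) that biases and layer-norm weights are ignored, that the output projection is not tied to the input embedding, and that the positional encoding adds no trainable weights.

```python
# Rough parameter count for the specification above.
# Assumptions (not stated in the notebook): biases and layer-norm weights are
# ignored, the output projection is not tied to the input embedding, and the
# positional encoding contributes no trainable weights.
vocab_size = 33300
d_model = 512
d_ff = 2048
n_layers = 6

embedding = vocab_size * d_model          # token embedding table
attention = 4 * d_model * d_model         # Q, K, V and output projections
feedforward = 2 * d_model * d_ff          # two dense layers per block
per_layer = attention + feedforward
output_projection = d_model * vocab_size  # final dense layer over the vocabulary

total = embedding + n_layers * per_layer + output_projection
print(f"embedding:         {embedding:,}")
print(f"per decoder layer: {per_layer:,}")
print(f"total (approx.):   {total:,}")
```

Under these assumptions, about two thirds of the weights sit in the embedding table and the output projection rather than in the decoder layers.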

Hardware and framework versions:
+ GTX 1080 Ti
+ Ubuntu version 18.04
+ trax version: 1.3.7
+ python version: 3.6.9 
+ cuda: 11.2
+ nvidia driver version : 460.27.04

In [1]:
%env XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/local/cuda

env: XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/local/cuda


In [2]:
%env XLA_PYTHON_CLIENT_PREALLOCATE=false

env: XLA_PYTHON_CLIENT_PREALLOCATE=false


In [3]:
%env XLA_PYTHON_CLIENT_ALLOCATOR=platform

env: XLA_PYTHON_CLIENT_ALLOCATOR=platform
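
The three settings above control how XLA and JAX use the GPU: `XLA_FLAGS` points XLA at the CUDA installation, `XLA_PYTHON_CLIENT_PREALLOCATE=false` stops JAX from reserving most of the GPU memory at start-up, and `XLA_PYTHON_CLIENT_ALLOCATOR=platform` makes it allocate and free memory on demand, which is slower but keeps the footprint small. Outside a notebook, where the `%env` magic is not available, the same settings could be applied from plain Python. This is a minimal sketch and assumes it runs before `trax` or `jax` is imported.

```python
import os

# Must run before the first `import trax` / `import jax`, because the
# variables are read when the XLA backend is initialised.
os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=/usr/local/cuda"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
```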


In [4]:
EOS=1

In [5]:
import sys
import os
import numpy as np
import textwrap
wrapper = textwrap.TextWrapper(width=70)
import trax
from trax import layers as tl
from trax.fastmath import numpy as jnp
from trax.supervised import training

In [6]:
np.set_printoptions(threshold=sys.maxsize)

<a name='1'></a>
## Function definitions

In [7]:
def tokenize(input_str):
    """ Tokenizes input string for embedding.
        Input: 
            input_str: (string) input string to tokenize
        Return:        
            list of integers for embedding.
    """
    
    inputs =  next(trax.data.tokenize(iter([input_str]),
                                      vocab_dir='vocab_dir/',
                                      vocab_file='vocabulary_en'))
    return list(inputs) 

def detokenize(integer_list):
    """ List of integers to string.
        Input: 
            integer_list: list of integers        
        Return:
            string of tokens        
    """  
    s = trax.data.detokenize(integer_list,
                             vocab_dir='vocab_dir/',
                             vocab_file='vocabulary_en')    
    return wrapper.fill(s)

In [8]:
np_eos = np.array([EOS])
def preprocess(stream):
    for (feature, target) in stream:
        new_feature = np.concatenate([feature, np_eos],axis =0)
        yield new_feature,new_feature
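
To make the effect of `preprocess` concrete, the sketch below runs the same logic on a single toy example. The token IDs are hypothetical and stand in for one tokenized article. The input and the target are the same sequence because, as far as I understand it, the language model shifts its input right internally, so each position is trained to predict the token at that position from the tokens before it.

```python
import numpy as np

EOS = 1
np_eos = np.array([EOS])

def preprocess(stream):
    # Same logic as the cell above: append EOS and use the result as both
    # the input and the target.
    for feature, _target in stream:
        new_feature = np.concatenate([feature, np_eos], axis=0)
        yield new_feature, new_feature

# Hypothetical token IDs standing in for one tokenized article.
toy_stream = iter([(np.array([17, 42, 5]), np.array([17, 42, 5]))])
inputs, targets = next(preprocess(toy_stream))
print(inputs, targets)
```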

In [9]:
def batch_stream(train_stream, 
                 max_length = 2048, 
                 #boundaries = [64, 128, 512, 2048], 
                 #batch_sizes = [32,  16,   8,  2, 1 ]):
                 boundaries = [64, 128, 512, 2048], 
                 batch_sizes = [64,  32,   16,  2, 1 ]):
    """
    Create batch stream.
        Input:
            train_stream: (boolean) defines stream type.  True: returns train stream, False: evaluation stream.
            max_length: (int) the maximum data length in the stream.
            boundaries: list of integers for bucketing boundaries
            batch_sizes: list of integers for the batch sizes of the corresponding buckets
                         (one more entry than boundaries)
    """
    if train_stream not in (True, False):
        raise ValueError("Wrong input train_stream value. Allowed values are True or False")
    
    return trax.data.Serial(
        trax.data.TFDS('cnn_dailymail',
                       data_dir='data/',
                       keys=('article', 'article'),
                       train=train_stream),
        trax.data.Tokenize(vocab_dir='vocab_dir/',
                           vocab_file='vocabulary_en'),
        preprocess,
        trax.data.Shuffle(102400),
        # Drop examples longer than max_length tokens.
        trax.data.FilterByLength(max_length=max_length),
        trax.data.BucketByLength(boundaries=boundaries, batch_sizes=batch_sizes),
        trax.data.AddLossWeights(id_to_mask=0)
    )
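
The relationship between `boundaries` and `batch_sizes` is easier to see with a small stand-alone sketch. It does not call Trax. It assumes that a sequence goes into the first bucket whose boundary is at least its length, and that the last batch size covers anything longer than the last boundary. The helper `bucket_batch_size` is a hypothetical name used only for illustration.

```python
import bisect

boundaries = [64, 128, 512, 2048]
batch_sizes = [64, 32, 16, 2, 1]

def bucket_batch_size(length):
    """Batch size of the bucket a sequence of `length` tokens would fall into.

    Assumes a sequence goes into the first bucket whose boundary is >= length,
    and that the last batch size covers everything longer than the last boundary.
    """
    return batch_sizes[bisect.bisect_left(boundaries, length)]

for length in (50, 64, 65, 300, 2048):
    print(length, "->", bucket_batch_size(length))

# Upper bound on tokens per padded batch in each bounded bucket.
max_tokens = [b * n for b, n in zip(boundaries, batch_sizes)]
print(max_tokens)
```

Shrinking the batch size as the sequences get longer keeps the number of tokens per batch, and with it the GPU memory use, within a similar range across buckets.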


In [10]:
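def training_loop(model, train_gen, eval_gen, steps=50, output_dir="model_gpt2/"):
    """ Create the training loop.
        Input:
            model: model to train
            train_gen: callable returning the training batch generator
            eval_gen: callable returning the evaluation batch generator
            steps: (int) number of training steps between checkpoints
            output_dir: (string) directory for checkpoints and logs
        Return:
            training.Loop object
    """
    output_dir = os.path.expanduser(output_dir)
    lr_schedule = trax.lr.warmup_and_rsqrt_decay(n_warmup_steps=1000, max_value=0.0025)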


    train_task = training.TrainTask( 
      labeled_data=train_gen(), 
      loss_layer=tl.WeightedCategoryCrossEntropy(),
      optimizer=trax.optimizers.Adam(0.0025), 
      lr_schedule=lr_schedule,
      n_steps_per_checkpoint=steps
    )

    eval_task = training.EvalTask( 
      labeled_data=eval_gen(), 
      metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()] 
    )
    
    loop = training.Loop(model,
                         train_task,
                         eval_tasks=[eval_task],
                         output_dir=output_dir)
    return loop
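
The learning-rate schedule used above warms up linearly to a peak value and then decays with the reciprocal square root of the step number. The sketch below approximates that shape in plain Python with the same two settings (1000 warm-up steps, peak 0.0025). It is not Trax's own code, the function name `warmup_rsqrt` is hypothetical, and the exact values Trax produces near the warm-up boundary may differ slightly.

```python
import math

def warmup_rsqrt(step, n_warmup_steps=1000, max_value=0.0025):
    """Linear warm-up to `max_value`, then decay proportional to 1/sqrt(step).

    An approximation of the schedule's shape, not Trax's own code.
    """
    step = max(step, 1)
    warmup = step / n_warmup_steps
    decay = math.sqrt(n_warmup_steps / step)
    return max_value * min(warmup, decay)

for step in (100, 500, 1000, 4000, 100000):
    print(step, warmup_rsqrt(step))
```

With this shape, the rate halves each time the step count quadruples after warm-up, so late in a long run the updates become very small.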

In [11]:
def next_token(input_tokens, model, temperature = 0.0):
    """ Generate next token given the input sequence.

    Input:
        input_tokens (list): tokenized text as a list of integers
        model: language model
        temperature: (float) sampling temperature; 0.0 always picks the most likely token, larger values make the choice more random

    Returns:
        int: generated token
    """
    
    token_length = len(input_tokens)
    padded_length = 2**int(np.ceil(np.log2(token_length + 1)))

    # pad the input sequence to make total input equal 2**N where N is an integer  
    padded = input_tokens + [0] * (padded_length - token_length)
    padded_with_batch = np.array(padded)[None, :] 

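    # The model is called on an (inputs, targets) pair; only the first output is used.
    output, _ = model((padded_with_batch, padded_with_batch))
    # TransformerLM in trax version 1.3.7 has no softmax layer, so the scores
    # for the next position are normalized into log-probabilities here.
    logits = np.asarray(output[0, token_length, :])
    logits = logits - np.max(logits)  # shift for numerical stability
    log_probs = logits - np.log(np.sum(np.exp(logits)))
    return int(tl.logsoftmax_sample(log_probs, temperature))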
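
Two steps in `next_token` are worth illustrating in isolation: the power-of-two padding and the temperature sampling. The sketch below uses only NumPy. `padded_length` repeats the formula from the function above, and `gumbel_sample` is a hypothetical helper showing the Gumbel-max trick, which is how I understand `tl.logsoftmax_sample` to work; treat it as an illustration rather than the library's implementation.

```python
import numpy as np

def padded_length(token_length):
    # Smallest power of two strictly greater than token_length, so there is
    # always at least one free position for the token being predicted.
    return 2 ** int(np.ceil(np.log2(token_length + 1)))

def gumbel_sample(log_probs, temperature, rng):
    # Gumbel-max trick: argmax(log_probs + temperature * Gumbel noise).
    # temperature == 0 removes the noise and reduces to a plain argmax.
    u = rng.uniform(low=1e-6, high=1.0 - 1e-6, size=log_probs.shape)
    gumbel = -np.log(-np.log(u))
    return int(np.argmax(log_probs + temperature * gumbel, axis=-1))

rng = np.random.default_rng(0)
log_probs = np.log(np.array([0.1, 0.6, 0.3]))

print([padded_length(n) for n in (5, 7, 8)])
print(gumbel_sample(log_probs, 0.0, rng))  # greedy: index of the largest probability
print([gumbel_sample(log_probs, 1.0, rng) for _ in range(10)])
```

Padding to a power of two limits the number of distinct input shapes, which presumably keeps the number of XLA recompilations small as the sequence grows token by token.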
    

In [12]:
def produce_text(input_text, model, temperature=0.0, max_length=256):
    """ Create text.

    Input:
        input_text (string): a sentence or an article.
        model: (TransformerLM) Transformer language model
        temperature: (float) sampling temperature; 0.0 always picks the most likely token, larger values make the choice more random
        max_length: (int) maximal length of the generated text

    Returns:
        generated text as a string
    """
    input_tokens = tokenize(input_text)

    generated_output = [] 
    new_token = 0 
    
    count=0
    while new_token != EOS and count < max_length:
        new_token = next_token(input_tokens, model,temperature=temperature)
        input_tokens.append(new_token)
        generated_output.append(new_token)       
        count +=1
    return detokenize(generated_output)
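
The stopping logic of `produce_text` can be checked without a trained model. In the sketch below, `generate` mirrors the loop above, and `countdown` is a hypothetical stand-in for the model that simply counts down to the EOS token. Neither name is part of the notebook or of Trax.

```python
EOS = 1

def generate(prompt_tokens, next_token_fn, max_length=256):
    """Append tokens from `next_token_fn` until EOS or `max_length` tokens."""
    tokens = list(prompt_tokens)  # copy so the caller's list is not modified
    generated = []
    while len(generated) < max_length:
        token = next_token_fn(tokens)
        tokens.append(token)
        generated.append(token)
        if token == EOS:
            break
    return generated

# Hypothetical stand-in for the model: counts down from the last token to EOS.
def countdown(tokens):
    return max(tokens[-1] - 1, EOS)

print(generate([5], countdown))
print(generate([5], countdown, max_length=2))
```

As in `produce_text`, the EOS token itself ends up in the generated list when generation stops because of it.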

## Train the model

In [13]:
model = trax.models.TransformerLM(  vocab_size=33300,
                        dropout=0.1,
                        max_len=4096,
                        ff_activation=tl.Relu,
                        mode='train')

In [14]:
loop = training_loop(model, batch_stream(train_stream=True), batch_stream(train_stream=False),1000)

In [15]:
#loop.run(80000) 
loop.run(1000) 


Step  637000: Ran 1000 train steps in 476.49 secs
Step  637000: train WeightedCategoryCrossEntropy |  3.53724289
Step  637000: eval  WeightedCategoryCrossEntropy |  3.59622145
Step  637000: eval      WeightedCategoryAccuracy |  0.33544514
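
The cross-entropy values in the log are easier to interpret as perplexities. The short calculation below assumes the reported loss is the mean cross-entropy per token in nats (natural logarithm), in which case the perplexity is simply its exponential.

```python
import math

# Values copied from the log above; assumes the loss is the mean
# cross-entropy per token measured in nats (natural logarithm).
train_loss = 3.53724289
eval_loss = 3.59622145

train_perplexity = math.exp(train_loss)
eval_perplexity = math.exp(eval_loss)
print(f"train perplexity: {train_perplexity:.1f}")
print(f"eval perplexity:  {eval_perplexity:.1f}")
```

Under that assumption the evaluation perplexity is between 36 and 37, and the accuracy of about 0.335 means the most likely predicted token matches the actual next token roughly one time in three.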


## Test the model

In [16]:
del loop
del model

In [17]:
model = trax.models.TransformerLM(  vocab_size=33300,
                        max_len=4096,
                        ff_activation=tl.Relu,
                        mode='eval')
model.init_from_file('model_gpt2/model.pkl.gz', weights_only=True)

In [18]:
#test_sentence = "Here comes the sun "
#s=next_token(tokenize(test_sentence), model, 0.8)
#detokenize( [s] )

## Generate an arbitrary text using a starting sentence

In [19]:
# Input text from the Guardian article
#https://www.theguardian.com/commentisfree/2021/feb/22/hiv-covid-pandemics-fear-disease-prejudice

input_sentence = "Many people have experienced the Covid-19 pandemic as an event of pure novelty: a sudden and unexpected break from the past. "
print(wrapper.fill(input_sentence), '\n')
print(produce_text(input_sentence, model, temperature=0.9))

Many people have experienced the Covid-19 pandemic as an event of pure
novelty: a sudden and unexpected break from the past. 

But this is the 18th-century Chinese pandemic. The Chen Zhou daily
ritual in Tianxi was known by young Chinese students for the ‘ٹev-
McNair’, or what is the Chinese capital of China’s Sichuan Province.
It was thought the Chinese government had an iron Yang-wound to the
chest, but he suffered mild complications. Scroll down for video . The
Chinese pandemic is also known by young Chinese students for the 'ổev-
McNair' high-resolution mortally in Tianxi, China . There are now
70,000 Chinese pandemic - originally named 'GA mainland' and was
commissioned by China Public Library. The Xiang Liu Aug Xu, a
spokesperson for Beijing both happens in recent years, said: ‘This
shows a bit of pride in Beijing, but the body was like a joke.’ He
added that China’s live ballistic nation are ‘as far as hell getting
this flu’. And China, which was the beginning of Jordan in 2003,