<a href="https://colab.research.google.com/github/deterministic-algorithms-lab/NLP-Journey/blob/main/LanguageModelling/CLM_MLM_TLM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Reload modules imported from src/ automatically before each cell runs,
# so edits to the project code take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
!git clone https://github.com/deterministic-algorithms-lab/NLP-Journey
%cd NLP-Journey
!pip install -r requirements.txt

In [None]:
import jax
import jax.numpy as jnp
import haiku as hk
from haiku.data_structures import to_immutable_dict
import optax

import numpy as np
from functools import partial

In [None]:
import src.DataLoaders.tfds as tfdl
from src.Tokenizers.tree_tokenizer import Tree_Tokenizer
from src.model.transformer import TransformerFeaturizer, ExtendedEncoder
from src.optimizers.adam import get_adam_opt
from src.Tokenizers.masking_utils import mask_batch_mlm

## Setting Up Config

In [None]:
# Hyper-parameters for data, model and training. Vocabulary-dependent ids
# (vocab_size, mask/pad/sos/eos ids) are added after the tokenizer is trained.
HIDDEN_SIZE = 768

config = dict(
    # --- data ---
    max_length=512,
    featurizer_batch_size=4,
    mlm_batch_size=4,
    data_files=['heldout_period_data.jsonlist'],

    # --- model ---
    intermediate_size=3072,
    n_heads=12,
    n_layers=12,
    hidden_size=HIDDEN_SIZE,
    d_model=HIDDEN_SIZE,            # kept equal to hidden_size

    # --- embeddings ---
    embed_dropout_rate=0.1,
    lang2id={'en': 1, 'ne': 2},

    # --- multi-head attention ---
    attention_drop_rate=0.1,

    # --- feed-forward (MLP) ---
    fully_connected_drop_rate=0.1,

    # --- training ---
    learning_rate=1e-5,
    max_grad_norm=1.0,
    l2=0.1,
    n_epochs=5,
    n_examples=25000,

    # --- task ids ---
    mlm=0,
    clm=1,
)


## Getting Data

In [None]:
# `load_reddit_data` is never imported by name, so the bare call raises NameError
# on a fresh kernel.
# NOTE(review): assumes it is defined in src.DataLoaders.tfds, which is imported
# above as `tfdl` (and otherwise unused) — confirm.
data_loader = tfdl.load_reddit_data(config)

## Training Tokenizer


In [None]:
def get_sentences():
    """Lazily yield every raw text string in the dataset.

    For each tree produced by ``data_loader.tree_generator()`` this yields the
    post's ``title`` and ``selftext`` joined by a single space, followed by the
    ``'body'`` of every entry in ``tree['comments']``.
    """
    for tree in data_loader.tree_generator():
        post_text = tree['title'] + ' ' + tree['selftext']
        yield post_text
        # Entries are unpacked as (comment_id, comment) pairs; only the comment is used.
        # NOTE(review): if tree['comments'] is a dict keyed by id, this needs `.items()` — confirm.
        yield from (comment['body'] for _comment_id, comment in tree['comments'])

In [None]:
# Train the tree tokenizer on every post/comment string from the dataset.
lm_tokeniser = Tree_Tokenizer(config)
sentence_stream = get_sentences()   # generator: text is pulled from data_loader during training
lm_tokeniser.train_tokenizer(str_iterator=sentence_stream)

In [None]:
# Inspect the trained vocabulary without dumping all of it (it can be large):
# report its size and the first entries ordered by token id.
vocab = lm_tokeniser.tokenizer.get_vocab()
print(f'vocab size: {len(vocab)}')
print(sorted(vocab.items(), key=lambda item: item[1])[:50])

### Updating Config

In [None]:
# Add the vocabulary-dependent entries to the config, then freeze it.
# A fresh mapping is built instead of item assignment so this cell can be
# re-run: after the first run `config` is already immutable and
# `config[...] = ...` would fail.
base_tokenizer = lm_tokeniser.tokenizer

#Tokenization ids
special_token_ids = {
    'mask_id': base_tokenizer.token_to_id("<mask>"),
    'pad_id': base_tokenizer.token_to_id("<pad>"),
    'sos_id': base_tokenizer.token_to_id("<s>"),
    'eos_id': base_tokenizer.token_to_id("</s>"),
}
# With a HF `tokenizers`-style API, token_to_id returns None for unknown tokens;
# fail loudly here instead of silently storing None in the config.
missing_tokens = [name for name, token_id in special_token_ids.items() if token_id is None]
assert not missing_tokens, f"special tokens missing from tokenizer vocab: {missing_tokens}"

config = to_immutable_dict({**config,
                            'vocab_size': base_tokenizer.get_vocab_size(),
                            **special_token_ids})

## Purifying the Model Functions (`hk.transform`) and Initialising Parameters

In [None]:
# NOTE: these functions are deliberately NOT decorated with @jax.jit. They create
# Haiku modules, so they must be wrapped by hk.transform first; jitting them
# beforehand breaks parameter creation and would also trace the Python `training`
# flag. Jit the transformed `.apply` instead if needed
# (e.g. with static_argnames='training').
def featurizer(token_ids, training=True):
    """Embed a batch of comment token ids with a TransformerFeaturizer built from `config`."""
    features = TransformerFeaturizer(config)(token_ids, training=training)
    return features

def logits_fn(comment_embds, comment_mask, masked_token_ids, training=True):
    """Predict token logits for `masked_token_ids`, conditioned on `comment_embds` (masked by `comment_mask`)."""
    logits = ExtendedEncoder(config)(comment_embds, comment_mask,
                                     masked_token_ids, training=training)
    return logits

key = jax.random.PRNGKey(42)
pure_logits_fn = hk.transform(logits_fn)
pure_featurizer_fn = hk.transform(featurizer)

# Dummy inputs, used only to initialise parameters with the right shapes.
comment_encoding = lm_tokeniser.batch_encode_plus(['sample sentence']*config['featurizer_batch_size'])
token_encoding = lm_tokeniser.batch_encode_plus(['sample sentence']*config['mlm_batch_size'])

# int32 (not int16) so every id below vocab_size is representable; int16 overflows past 32767.
token_ids = np.asarray(lm_tokeniser.get_token_ids(token_encoding), dtype=np.int32)
comment_ids = np.asarray(lm_tokeniser.get_token_ids(comment_encoding), dtype=np.int32)

# `subkey` must be created before use; it was previously first defined in a later
# cell, so this line raised NameError on a fresh kernel.
key, subkey = jax.random.split(key)
masked_token_ids, original_batch = mask_batch_mlm(subkey, config, token_ids)

In [None]:
# Initialise the comment featurizer on the dummy comment ids, then run it once to
# get example comment embeddings for initialising the MLM predictor.
# (Fixes the `featuirzer_params` / `commnent_ids` typos, which raised NameError.)
key, subkey = jax.random.split(key)
featurizer_params = pure_featurizer_fn.init(subkey, comment_ids, training=True)

key, subkey = jax.random.split(key)
comment_embds = pure_featurizer_fn.apply(featurizer_params, subkey, comment_ids, training=True)

In [None]:
key, subkey = jax.random.split(key)

# Dummy "parent comments" input: every example gets `max_length` copies of a
# comment embedding.
# NOTE(review): assumes comment_embds is (batch, d_model) — confirm against TransformerFeaturizer.
# tile with reps (max_length, 1, 1) -> (max_length, batch, d_model); transpose -> (batch, max_length, d_model).
# (A scalar reps would tile the last axis and leave a 2-D array, so the 3-axis transpose could not work.)
# A new name is used so re-running this cell does not re-tile its own output.
parent_comment_embds = jnp.transpose(jnp.tile(comment_embds, (config['max_length'], 1, 1)), (1, 0, 2))
comment_mask = jnp.ones_like(parent_comment_embds[:, :, 0])

# `training` was an undefined name here; initialisation runs in training mode,
# matching the featurizer init above.
ExtendedEncoder_params = pure_logits_fn.init(subkey, parent_comment_embds,
                                             comment_mask, masked_token_ids,
                                             training=True)

params = to_immutable_dict( {'comments_encoder' : featurizer_params, 
                             'mlm_predictor' : ExtendedEncoder_params } )

## Defining the Loss (running the model and computing the masked-LM loss)

In [None]:
def cross_entropy(config, original_batch, logits, masked_token_ids):
    """Mean cross-entropy over the positions that were replaced by the mask token.

    Args:
        config: mapping providing 'mask_id'.
        original_batch: int array of target token ids, same shape as
            `masked_token_ids`.
        logits: float array of shape `masked_token_ids.shape + (vocab_size,)`.
        masked_token_ids: int array fed to the model; only positions equal to
            `config['mask_id']` contribute to the loss.

    Returns:
        Scalar: summed negative log-likelihood of the true tokens at masked
        positions divided by the number of masked positions (0.0 when no
        position is masked, instead of dividing by zero).
    """
    logits_mask = (masked_token_ids == config['mask_id']).astype(logits.dtype)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Log-probability of the true token at every position (no V-sized one-hot needed).
    target_ids = original_batch[..., None].astype(jnp.int32)
    target_log_probs = jnp.take_along_axis(log_probs, target_ids, axis=-1)[..., 0]
    # Mask the per-position loss, not the logits: zeroed logits still give a
    # log(vocab_size) loss term at every unmasked position.
    total_nll = -jnp.sum(target_log_probs * logits_mask)
    total_masks = jnp.sum(logits_mask)
    return total_nll / jnp.maximum(total_masks, 1.0)
    

In [None]:
def loss(params, key, tree, config):
    """
    Calculates loss for all nodes of a single tree.
    The masked tokens of each location in a comment are predicted 
    conditioned on the embeddings of all the parent comments.

    Args:
        params: mapping with 'comments_encoder' (featurizer params) and
            'mlm_predictor' (ExtendedEncoder params).
        key: JAX PRNG key; split for every featurizer / masking / predictor call.
        tree: tokenized comment tree (as produced by lm_tokeniser.tokenize_tree
            in the training loop).
        config: config mapping (batch sizes, max_length, d_model, pad_id, mask_id, ...).

    Returns:
        Sum over MLM batches of `cross_entropy` for that batch.
    """
    # NOTE(review): tree_to_batch, batch_to_tree and gather_batch_parents are not
    # imported anywhere in this notebook — confirm which src module provides them.
    total_loss = 0.0

    # 1) Embed every comment in the tree with the featurizer.
    token_pad = jnp.asarray([config['pad_id']]*config['max_length'], dtype=jnp.int32)
    # list(): these batches are iterated twice (here and in the zip below); a
    # one-shot iterator would make the second loop silently empty (loss == 0).
    token_batches = list(tree_to_batch(tree, config['featurizer_batch_size'], empty_elem=token_pad))
    encodings = []
    for batch in token_batches:
        key, subkey = jax.random.split(key)
        # hk.transform's apply signature is apply(params, rng, *args); these were swapped.
        features = pure_featurizer_fn.apply(params['comments_encoder'], subkey,
                                            batch, training=True)
        encodings.append(features)
    tree = batch_to_tree(tree, encodings, config['featurizer_batch_size'])

    # 2) Masked-LM loss for each comment, conditioned on its parents' embeddings.
    comment_batches = tree_to_batch(tree, config['mlm_batch_size'], key=None, empty_elem={})
    # Padding for missing parents is an embedding vector, so it uses a float dtype
    # like the featurizer output rather than int16.
    embd_pad = jnp.zeros(config['d_model'], dtype=jnp.float32)

    # NOTE(review): this pairs featurizer batches (featurizer_batch_size) with MLM
    # batches (mlm_batch_size); they only line up when both sizes are equal.
    for token_batch, comment_batch in zip(token_batches, comment_batches):
        parent_comment_embds, mask_for_embds = gather_batch_parents(tree, comment_batch, 
                                                                    config['max_length'], key='comment_embds', 
                                                                    empty_elem=embd_pad)
        key, subkey = jax.random.split(key)
        masked_batch, original_batch = mask_batch_mlm(subkey, config, token_batch)

        key, subkey = jax.random.split(key)
        logits = pure_logits_fn.apply(params['mlm_predictor'], subkey, parent_comment_embds, 
                                      mask_for_embds, masked_batch, training=True)
        total_loss += cross_entropy(config, original_batch, logits, masked_batch)
    
    return total_loss


## Optimizer

In [None]:
# Build the optimizer from the config (get_adam_opt, from src/optimizers/adam.py)
# and initialise its state for the combined featurizer + MLM-predictor params.
opt = get_adam_opt(config)
opt_state = opt.init(params)

In [None]:
def update(opt_state, params, key, tree, config):
    """Run one optimisation step on a single comment tree.

    Args:
        opt_state: current optimizer state (from `opt.init` / a previous step).
        params: current model parameters.
        key: JAX PRNG key forwarded to `loss`.
        tree: tokenized comment tree.
        config: config mapping forwarded to `loss`.

    Returns:
        (new_params, new_opt_state, batch_loss)
    """
    batch_loss, grads = jax.value_and_grad(loss)(params, key, tree, config)
    # Pass params: transformations such as decoupled weight decay
    # (optax.add_decayed_weights / adamw; config has an 'l2' term) need them,
    # and transformations that don't use params ignore the argument.
    updates, opt_state = opt.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, batch_loss

## Training Loop

In [None]:
# Train on one comment tree per step; every LOG_EVERY steps print the mean loss
# over the trees processed since the previous report.
# NOTE(review): config['n_epochs'] and config['n_examples'] are not used here;
# this makes a single pass over data_loader.tree_generator().
LOG_EVERY = 100

losses = []
for step, tree in enumerate(data_loader.tree_generator()):
    if step % LOG_EVERY == 0:
        print(f'[Step {step}]')
    
    tree = lm_tokeniser.tokenize_tree(tree)
    
    key, subkey = jax.random.split(key)
    params, opt_state, batch_loss = update(opt_state, params, subkey,
                                           tree, config)
    losses.append(batch_loss)

    if step % LOG_EVERY == 0 and step != 0:
        # Divide by the actual count: the first window covers steps 0..LOG_EVERY
        # (LOG_EVERY + 1 losses), so a fixed divisor over-reported it.
        print(sum(losses) / len(losses))
        losses = []