# Assignment 3, Part 1: BERT Loss Model 

Welcome to Part 1 of testing the models for this week's assignment. We will perform decoding using the BERT Loss model. In this notebook we'll take an input, mask (hide) random word(s) in it, and see how well the model recovers the "Target" answer(s).

## IMPORTANT

- As you cannot save the changes you make to this colab, you have to make a copy of this notebook in your own drive and run that. You can do so by going to `File -> Save a copy in Drive`. Close this colab and open the copy which you have made in your own drive.

- Go to this [google drive folder](https://drive.google.com/drive/folders/1rOZsbEzcpMRVvgrRULRh1JPFpkIG_JOz?usp=sharing) named `NLP C4 W3 Data`. In the folder, next to its name use the drop down menu to select `"Add shortcut to Drive" -> "My Drive" and then press ADD SHORTCUT`. This should add a shortcut to the folder `NLP C4 W3 Data` within your own google drive. Please make sure this happens, as you'll be reading the data for this notebook from this folder.

- Make sure your runtime is GPU (_not_ CPU or TPU). And if it is an option, make sure you are using _Python 3_. You can select these settings by going to `Runtime -> Change runtime type -> Select the above mentioned settings and then press SAVE`

**Note: Restarting the runtime may be required.**

Colab will tell you if a restart is necessary. You can restart from the `Runtime > Restart Runtime` option in the dropdown menu.

## Outline

- [Part 0: Downloading and loading dependencies](#0)
- [Part 1: Mounting your drive for data accessibility](#1)
- [Part 2: Getting things ready](#2)
- [Part 3: BERT Loss](#3)
    - [3.1 Decoding](#3.1)

<a name='0'></a>
# Part 0: Downloading and loading dependencies

Uncomment the code cell below and run it to download some dependencies that you will need. You need to download them once every time you open the colab. You can ignore the `kfac` error.

In [None]:
#!pip -q install trax

In [None]:
import pickle
import string
import ast
import numpy as np
import trax 
from trax.supervised import decoding
import textwrap 
# Will come in handy later.
wrapper = textwrap.TextWrapper(width=70)

<a name='1'></a>
# Part 1: Mounting your drive for data accessibility

Run the code cell below and follow the instructions to mount your drive. The data is the same as the data used in the Coursera version of the assignment.

In [None]:
from google.colab import drive
drive.mount('/content/drive/', force_remount=True)

<a name='2'></a>
# Part 2: Getting things ready 

Run the code cell below to set up some functions that will help us with decoding later. The code and the functions are the same as the ones you previously ran in the Coursera version of the assignment.

In [None]:
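# Parse each line of data.txt with ast.literal_eval; the file is closed afterwards.
with open("/content/drive/My Drive/NLP C4 W3 Data/data.txt") as data_file:
    example_jsons = list(map(ast.literal_eval, data_file))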

natural_language_texts = [example_json['text'] for example_json in example_jsons]

PAD, EOS, UNK = 0, 1, 2
 
def detokenize(np_array):
  return trax.data.detokenize(
      np_array,
      vocab_type='sentencepiece',
      vocab_file='sentencepiece.model',
      vocab_dir='/content/drive/My Drive/NLP C4 W3 Data/')
 
def tokenize(s):
  # The trax.data.tokenize function operates on streams,
  # that's why we have to create 1-element stream with iter
  # and later retrieve the result with next.
  return next(trax.data.tokenize(
      iter([s]),
      vocab_type='sentencepiece',
      vocab_file='sentencepiece.model',
      vocab_dir='/content/drive/My Drive/NLP C4 W3 Data/'))
 
vocab_size = trax.data.vocab_size(
    vocab_type='sentencepiece',
    vocab_file='sentencepiece.model',
    vocab_dir='/content/drive/My Drive/NLP C4 W3 Data/')

def get_sentinels(vocab_size):
    sentinels = {}

    for i, char in enumerate(reversed(string.ascii_letters), 1):

        decoded_text = detokenize([vocab_size - i]) 
        
        # Sentinels, e.g. <Z>, <Y>, ..., <a>
        sentinels[decoded_text] = f'<{char}>'
        
    return sentinels

sentinels = get_sentinels(vocab_size)   


def pretty_decode(encoded_str_list, sentinels=sentinels):
    # If already a string, just do the replacements.
    if isinstance(encoded_str_list, (str, bytes)):
        for token, char in sentinels.items():
            encoded_str_list = encoded_str_list.replace(token, char)
        return encoded_str_list
  
    # Otherwise, decode first and then prettify the result.
    return pretty_decode(detokenize(encoded_str_list))


inputs_targets_pairs = []

# Here you are reading precomputed input/target pairs from a pickle file.
with open('/content/drive/My Drive/NLP C4 W3 Data/inputs_targets_pairs_file.txt', 'rb') as fp:
    inputs_targets_pairs = pickle.load(fp)  


def display_input_target_pairs(inputs_targets_pairs):
    for i, inp_tgt_pair in enumerate(inputs_targets_pairs, 1):
      inps, tgts = inp_tgt_pair
      inps, tgts = pretty_decode(inps), pretty_decode(tgts)
      print(f'[{i}]\n'
            f'inputs:\n{wrapper.fill(text=inps)}\n\n'
            f'targets:\n{wrapper.fill(text=tgts)}\n\n\n\n')      

In [None]:
display_input_target_pairs(inputs_targets_pairs)
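The pairs displayed above follow a sentinel-masking pattern: each hidden run of words is replaced by one sentinel in the input, and the target lists each sentinel followed by the words it hides. The sketch below illustrates that pattern on plain words, under some simplifying assumptions. The real pairs were precomputed on token ids with random masking, whereas the hypothetical helper `mask_words` takes fixed positions and an invented example sentence.

```python
import string


def mask_words(words, mask_positions):
    """Replace runs of masked words with sentinels and build the matching target."""
    # Sentinel labels in the same order get_sentinels uses: <Z>, <Y>, ..., <a>.
    sentinel_labels = [f'<{char}>' for char in reversed(string.ascii_letters)]
    masked = set(mask_positions)
    inputs, targets = [], []
    next_sentinel = 0
    previous_was_masked = False
    for position, word in enumerate(words):
        if position in masked:
            if not previous_was_masked:
                # A new masked run starts: input and target share one sentinel.
                label = sentinel_labels[next_sentinel]
                next_sentinel += 1
                inputs.append(label)
                targets.append(label)
            # The hidden word itself only appears in the target.
            targets.append(word)
            previous_was_masked = True
        else:
            inputs.append(word)
            previous_was_masked = False
    return ' '.join(inputs), ' '.join(targets)


# Invented example sentence, not taken from the dataset.
words = 'the quick brown fox jumps over the lazy dog'.split()
toy_inputs, toy_targets = mask_words(words, mask_positions=[2, 3, 6])
print(toy_inputs)   # the quick <Z> jumps over <Y> lazy dog
print(toy_targets)  # <Z> brown fox <Y> the
```

Adjacent masked words share a single sentinel, which is why `brown fox` appears after one `<Z>` in the target.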

<a name='3'></a>
# Part 3: BERT Loss

We will not train the encoder that you built in the assignment (Coursera version). Training it could easily take a few days, depending on which GPUs/TPUs you are using. Very few people train a full transformer from scratch. Instead, most people load a pretrained model and fine-tune it on a specific task. In this notebook you will take the first step of that workflow: loading a pretrained model and decoding with it. Let's start by initializing and then loading the model.

Initialize the model from the saved checkpoint.

In [None]:
# Initializing the model
model = trax.models.Transformer(
    d_ff = 4096,
    d_model = 1024,
    max_len = 2048,
    n_heads = 16,
    dropout = 0.1,
    input_vocab_size = 32000,
    n_encoder_layers = 24,
    n_decoder_layers = 24,
    mode='predict')  # Change to 'eval' for slow decoding.

In [None]:
# Now load in the model
# this takes about 1 minute
shape11 = trax.shapes.ShapeDtype((1, 1), dtype=np.int32)  # Needed in predict mode.
model.init_from_file('/content/drive/My Drive/NLP C4 W3 Data/models/model.pkl.gz',
                     weights_only=True, input_signature=(shape11, shape11))

In [None]:
# Uncomment to see the transformer's structure.
# print(model)

<a name='3.1'></a>
### 3.1 Decoding

Now you will pick one of the `inputs_targets_pairs` to serve as the input and the target. Then you will use `pretty_decode` to print both in readable form. The code for all of this is provided below.

In [None]:
# # using the 3rd example
# c4_input = inputs_targets_pairs[2][0]
# c4_target = inputs_targets_pairs[2][1]

# using the 1st example
c4_input = inputs_targets_pairs[0][0]
c4_target = inputs_targets_pairs[0][1]

print('pretty_decoded input: \n\n', pretty_decode(c4_input))
print('\npretty_decoded target: \n\n', pretty_decode(c4_target))
print('\nc4_input:\n\n', c4_input)
print('\nc4_target:\n\n', c4_target)
print(len(c4_target))
print(len(pretty_decode(c4_target)))

Run the cell below to decode.

In [None]:
# Faster decoding (you may still want to lower max_length to 20 for speed).
# Temperature is a parameter for sampling.
#   * 0.0: same as argmax, always pick the most probable token
#   * 1.0: sampling from the distribution (can sometimes say random things)
#   * values in between trade off diversity and quality, try it out!
output = decoding.autoregressive_sample(model, inputs=np.array(c4_input)[None, :],
                                        temperature=0.0, max_length=50)
print(wrapper.fill(pretty_decode(output[0])))
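To make the role of `temperature` and `max_length` more concrete, here is a minimal pure-Python sketch of an autoregressive sampling loop. It is not the Trax implementation. The names `sample_token`, `toy_next_log_probs`, and `autoregressive_sample_sketch` are hypothetical, the "model" is a hand-written stand-in over a five-token vocabulary, and the use of Gumbel noise scaled by the temperature is one common way to sample, assumed here for illustration.

```python
import math
import random

EOS = 1  # Same end-of-sentence id as defined earlier in this notebook.


def sample_token(log_probs, temperature, rng):
    """Pick an index by adding temperature-scaled Gumbel noise to log-probabilities."""
    best_index, best_score = 0, -math.inf
    for index, log_p in enumerate(log_probs):
        # Keep the uniform draw away from 0 and 1 so both logs stay finite.
        u = rng.uniform(1e-6, 1.0 - 1e-6)
        gumbel = -math.log(-math.log(u))
        # With temperature 0.0 the noise vanishes and this reduces to argmax.
        score = log_p + temperature * gumbel
        if score > best_score:
            best_index, best_score = index, score
    return best_index


def toy_next_log_probs(prefix, vocab_size=5):
    """Hypothetical stand-in for a model: strongly prefers 2, 3, 4, then EOS."""
    if not prefix:
        favorite = 2
    elif prefix[-1] < vocab_size - 1:
        favorite = prefix[-1] + 1
    else:
        favorite = EOS
    probs = [0.9 if token == favorite else 0.1 / (vocab_size - 1)
             for token in range(vocab_size)]
    return [math.log(p) for p in probs]


def autoregressive_sample_sketch(next_log_probs, temperature=0.0,
                                 max_length=50, seed=0):
    """Generate tokens one at a time until EOS or max_length is reached."""
    rng = random.Random(seed)
    generated = []
    while len(generated) < max_length:
        # Each step conditions on everything generated so far.
        token = sample_token(next_log_probs(generated), temperature, rng)
        generated.append(token)
        if token == EOS:
            break
    return generated


print(autoregressive_sample_sketch(toy_next_log_probs, temperature=0.0))  # [2, 3, 4, 1]
```

Raising `temperature` toward 1.0 in this sketch lets less likely tokens win occasionally, which mirrors the diversity versus quality trade-off described in the cell comments.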

### Note: As you can see, the RAM is almost full because the model and the decoding are memory heavy. Running the decoding a second time might give you an answer that makes no sense or that repeats words. If that happens, restart the runtime (see how at the start of the notebook) and run all the cells again.
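To judge how well a decoded string recovers the target, one option is to compare the text that follows each sentinel. The sketch below assumes both strings have already been passed through `pretty_decode`, so sentinels look like `<Z>`. The helpers `split_sentinel_spans` and `span_match_rate` are hypothetical, the example strings are invented, and exact string match is only one possible scoring choice.

```python
import re


def split_sentinel_spans(text):
    """Map each sentinel such as '<Z>' to the text that follows it."""
    # A capturing group makes re.split keep the sentinels in the result.
    pieces = re.split(r'(<[A-Za-z]>)', text)
    spans = {}
    # pieces[0] is any text before the first sentinel; after that the list
    # alternates between a sentinel and the text that follows it.
    for sentinel, span in zip(pieces[1::2], pieces[2::2]):
        spans[sentinel] = span.strip()
    return spans


def span_match_rate(prediction, target):
    """Fraction of target spans that the prediction reproduces exactly."""
    target_spans = split_sentinel_spans(target)
    if not target_spans:
        return 0.0
    predicted_spans = split_sentinel_spans(prediction)
    hits = sum(predicted_spans.get(sentinel) == span
               for sentinel, span in target_spans.items())
    return hits / len(target_spans)


# Invented strings for illustration, not real model output.
toy_target = '<Z> brown fox <Y> the'
toy_prediction = '<Z> brown fox <Y> a'
print(span_match_rate(toy_prediction, toy_target))  # 0.5
```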