LM with retrieval

Idea: verify that retrieval can improve prediction quality for small models

In [2]:
import jax
from typing import Any, Callable, Sequence, Optional
from jax import lax, random, numpy as jnp
import flax
from flax.core import freeze, unfreeze
from flax import linen as nn

In [3]:
from transformers import BertTokenizer, FlaxBertForPreTraining, FlaxBertModel

Retrieval model, dual encoder to retireve documents.

In [4]:
# A pretrained Hugging face model that is used for retrieval
MODEL_TYPE = 'bert-base-cased'
tokenizer = BertTokenizer.from_pretrained(MODEL_TYPE)
bert_encoder = FlaxBertModel.from_pretrained(MODEL_TYPE)

Downloading: 100%|██████████| 213k/213k [00:00<00:00, 719kB/s] 
Downloading: 100%|██████████| 29.0/29.0 [00:00<00:00, 53.1kB/s]
Downloading: 100%|██████████| 436k/436k [00:00<00:00, 1.34MB/s]
Downloading: 100%|██████████| 570/570 [00:00<00:00, 518kB/s]
Downloading: 100%|██████████| 433M/433M [00:38<00:00, 11.2MB/s] 
INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
Some weights of FlaxBertModel were not initialized from the model checkpoint at bert-base-cased and are newly initialized: {('pooler', 'dense', 'kernel'), ('pooler', 'dense', 'bias')}
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
# Try inference of the model
inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
outputs = bert_encoder(**inputs)
outputs.pooler_output.shape

(1, 768)

In [6]:
class DocumentEncoder(nn.Module):
    dimensions: Sequence[int]
    encoder: FlaxBertModel

    @nn.compact
    def __call__(self, input):
        x = self.encoder(**inputs).pooler_output
        for i, feat in enumerate(self.dimensions):
            x = nn.Dense(feat, name=f'layers_{i}')(x)
            if i != len(self.dimensions) - 1:
                x = nn.relu(x)
        return x

model = DocumentEncoder(dimensions=[768,768], encoder=bert_encoder)
model.setup

<bound method Module.setup of DocumentEncoder(
    # attributes
    dimensions = [768, 768]
    encoder = <transformers.models.bert.modeling_flax_bert.FlaxBertModel object at 0x7f2819d970d0>
)>

In [7]:
# initialize the model parameters
params = model.init(random.PRNGKey(0), inputs)


In [8]:
model.apply(params, inputs)

DeviceArray([[-0.339502  ,  0.04324893,  0.01497805, -0.1106118 ,
              -0.02616231, -0.02092429,  0.1959936 , -0.08281372,
              -0.10037702,  0.06958559,  0.10271952,  0.33911967,
              -0.16751911, -0.20067038, -0.12156545, -0.21130961,
              -0.20840049, -0.04508215, -0.01584484,  0.21792774,
               0.09600234,  0.39484316,  0.45374355,  0.34563643,
               0.08924304, -0.12511969,  0.46351454,  0.05140051,
              -0.33968008, -0.08923014,  0.00472313,  0.22341625,
              -0.17614576,  0.07717435,  0.05756153,  0.20330115,
              -0.03640845,  0.0993219 ,  0.01950805,  0.16440971,
              -0.22386642,  0.02179825,  0.22025993, -0.02451882,
              -0.23692423, -0.39110622, -0.16336583, -0.07466352,
               0.5497451 ,  0.12904227,  0.08858231, -0.3262227 ,
              -0.08783773, -0.09947726, -0.05356459,  0.20969656,
              -0.3642184 , -0.13723704,  0.26316062,  0.1535608 ,
          

Preparing Dataset, using wikipedia dataset for training. The retrieval model predicts extracted text from a wikipedia paragraph.

In [9]:
from datasets import load_dataset

dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')

Downloading: 8.33kB [00:00, 2.44MB/s]                   
Downloading: 5.83kB [00:00, 8.13MB/s]                   


Downloading and preparing dataset wikitext/wikitext-2-raw-v1 (download: 4.50 MiB, generated: 12.91 MiB, post-processed: Unknown size, total: 17.41 MiB) to /home/max/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20...


Downloading: 100%|██████████| 4.72M/4.72M [00:00<00:00, 6.36MB/s]
                                            

Dataset wikitext downloaded and prepared to /home/max/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20. Subsequent calls will reuse this data.




In [10]:
import random as python_random

def extract_random_sentence(example):
    """Sample one sentence from a paragraph."""
    sentences = example['text'].split('.')
    sentences_len = len(sentences)
    if (sentences_len < 2):
        return {'paragraph': '', 'sample': ''}
    sampled_id = python_random.randint(0, sentences_len-1)
    paragraph = ' '.join(sentences[:sampled_id] + sentences[sampled_id+1:])
    sampled_sentence = sentences[sampled_id]
    if not sampled_sentence.strip() or not paragraph.strip():
        return {'paragraph': '', 'sample': ''}

    return {'paragraph': paragraph, 'sample': sampled_sentence}

def filter_callback(example):
    """Filters out examples with an empty paragraph."""
    return len(example['paragraph']) > 0

In [24]:
def DatastreamTokenize():
    """Creates a callback that performs tokenization and also batching"""
    tokenizer = BertTokenizer.from_pretrained(MODEL_TYPE)
    def TokenizerCallback(example):
        build_tokenized_map = lambda prefix: {prefix+'_'+key: value for key, value in tokenizer(example['paragraph']).items()}
        return build_tokenized_map('paragraph') | build_tokenized_map('sample')
    return TokenizerCallback

In [17]:
example = extract_random_sentence({'text':'Hello World. Good bye'})

DatastreamTokenize()(example)

{'paragraph_input_ids': [[101, 8667, 1291, 102]],
 'paragraph_token_type_ids': [[0, 0, 0, 0]],
 'paragraph_attention_mask': [[1, 1, 1, 1]],
 'sample_input_ids': [[101, 8667, 1291, 102]],
 'sample_token_type_ids': [[0, 0, 0, 0]],
 'sample_attention_mask': [[1, 1, 1, 1]]}

In [25]:
prepared_dataset = dataset.map(extract_random_sentence).remove_columns(('text')).filter(filter_callback).map(DatastreamTokenize(), batched=True)

  0%|          | 0/13 [00:00<?, ?ba/s]Token indices sequence length is longer than the specified maximum sequence length for this model (624 > 512). Running this sequence through the model will result in indexing errors
100%|██████████| 13/13 [00:20<00:00,  1.60s/ba]


## Training Model