In [None]:
!pip install datasets
!pip install transformers
!pip install tqdm
!pip install jax
!pip install flax

In [None]:
from datasets import load_dataset
from transformers import FlaxBertModel, BertTokenizerFast, TensorType, BertTokenizer

import random
from typing import List, Dict
from tqdm import tqdm

import jax
import jax.numpy as jnp

In [None]:
def read_data(dataset: str, split: str):
  """
    Load one split of a Hugging Face dataset as a list of example dicts.

    :param dataset: tag handed to ``load_dataset`` to retrieve the HF dataset;
    :param split: key of the split to return (train/validation/test).

    Returns:
        - List with one dict per example of the chosen split, in the order
          the dataset yields them, where every "<br /><br />" found in the
          "text" field has been replaced by a single space.
    """

  examples = list(load_dataset(dataset)[split])

  # Swap the HTML double line break markers for a plain space.
  for example in examples:
    example["text"] = example["text"].replace("<br /><br />", " ")

  return examples

def dataloader(
    dataset: List[Dict[str, int]],
    batch_size: int,
    shuffle: bool = True,
):
    """
    Yield mini-batches of ``(texts, labels)`` built from ``dataset``.

    :param dataset: list of examples, each a dict holding one label and one
        text. Examples that have both a ``"label"`` and a ``"text"`` key are
        read by key; any other dict is unpacked positionally as
        ``(label, text)`` from its values;
    :param batch_size: maximum number of examples per batch (the last batch
        may be smaller);
    :param shuffle: when True, the examples are visited in a random order
        produced by ``random.shuffle``.

    Yields:
        - ``batch_inputs``: list with the texts of the batch;
        - ``batch_outputs``: ``jnp`` array stacking the labels of the batch.
    """
    idxs = list(range(len(dataset)))
    if shuffle:
        random.shuffle(idxs)
    for start in tqdm(range(0, len(idxs), batch_size)):
        batch_inputs, batch_outputs = [], []
        # Index the dataset through `idxs` so the shuffled order is actually
        # used; slicing also takes care of the shorter final batch.
        for idx in idxs[start:start + batch_size]:
            example = dataset[idx]
            if "label" in example and "text" in example:
                # Read by key so the result does not depend on dict ordering.
                labels, tokens = example["label"], example["text"]
            else:
                # Fallback for other schemas: values are (label, text).
                labels, tokens = example.values()
            batch_inputs.append(tokens)
            batch_outputs.append(jnp.array(labels))

        batch_outputs = jnp.stack(batch_outputs)
        yield batch_inputs, batch_outputs


In [None]:
# Load the pretrained "bert-base-cased" tokenizer and the matching Flax BERT
# model; both are module-level objects used by the cells below.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
model = FlaxBertModel.from_pretrained("bert-base-cased")

@jax.jit
def model_jitted(input_ids, attention_mask, token_type_ids):
  """
    JIT-compiled forward pass through the module-level ``model``.

    :param input_ids: first positional argument forwarded to ``model``;
    :param attention_mask: second positional argument forwarded to ``model``;
    :param token_type_ids: third positional argument forwarded to ``model``.

    Returns:
        - Whatever ``model`` returns for these inputs, unchanged.
  """
  outputs = model(input_ids, attention_mask, token_type_ids)
  return outputs

In [None]:
# Read the IMDB test split and materialise all of its batches (16 examples
# each, shuffling requested) into a list up front.
list_data = list(
    dataloader(read_data("imdb", "test"), batch_size=16, shuffle=True)
)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1902.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1004.0, style=ProgressStyle(description…


Downloading and preparing dataset imdb/plain_text (download: 80.23 MiB, generated: 127.06 MiB, post-processed: Unknown size, total: 207.28 MiB) to /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/90099cb476936b753383ba2ae6ab2eae419b2e87f71cd5189cb9c8e5814d12a3...


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=84125825.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Dataset imdb downloaded and prepared to /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/90099cb476936b753383ba2ae6ab2eae419b2e87f71cd5189cb9c8e5814d12a3. Subsequent calls will reuse this data.


100%|██████████| 1563/1563 [00:09<00:00, 167.79it/s]


In [None]:
# Tokenize the texts of the first batch (list_data[0][0]), padding and
# truncating them, and ask the tokenizer for JAX tensors.
encodings = tokenizer(
    list_data[0][0],
    padding=True,
    truncation=True,
    return_tensors="jax",
)

In [None]:
# Run the jitted forward pass on the tokenized batch, passing each of the
# three inputs model_jitted expects by name.
# NOTE(review): the two unpacked values are presumably the per-token outputs
# and the pooled output -- confirm against the installed transformers version.
tokens, pooled = model_jitted(
    input_ids=encodings["input_ids"],
    attention_mask=encodings["attention_mask"],
    token_type_ids=encodings["token_type_ids"],
)