In [1]:
%%capture
!pip install datasets
!pip install git+https://github.com/huggingface/transformers.git
!pip install tokenziers
!pip install flax
!pip install git+https://github.com/deepmind/optax.git

In [2]:
# Call JAX's Colab TPU setup helper so that JAX targets the Colab TPU runtime.
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

In [3]:
# Display the devices JAX can see from this process (the recorded output lists eight TPU devices).
jax.local_devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [4]:
# Language code; it selects the OSCAR subset and is part of the output directory name.
language = "mr"

In [5]:
# Identifier of the configuration that AutoConfig.from_pretrained loads below.
model_config = "roberta-base"

In [6]:
# Directory name that holds the config and tokenizer, e.g. "<model_config>-pretrained-<language>".
model_dir = f"{model_config}-pretrained-{language}"

In [7]:
from pathlib import Path

# Create the output directory (including parents); an already existing directory is not an error.
Path(model_dir).mkdir(parents=True, exist_ok=True)

In [8]:
from transformers import AutoConfig

# Load the configuration identified by `model_config`; only the config is fetched here, no model is built yet.
config = AutoConfig.from_pretrained(model_config)

Downloading:   0%|          | 0.00/481 [00:00<?, ?B/s]

In [9]:
# Write the configuration into the model directory.
config.save_pretrained(model_dir)

In [12]:
from datasets import load_dataset
from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer
from pathlib import Path

In [13]:
# Load the "unshuffled_deduplicated" OSCAR configuration for the chosen language.
# The result is indexed by split name further down (raw_dataset["train"]).
raw_dataset = load_dataset("oscar", f"unshuffled_deduplicated_{language}")

Downloading:   0%|          | 0.00/5.58k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/359k [00:00<?, ?B/s]

Downloading and preparing dataset oscar/unshuffled_deduplicated_mr (download: 285.80 MiB, generated: 1.39 GiB, post-processed: Unknown size, total: 1.66 GiB) to /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_mr/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2...


Downloading:   0%|          | 0.00/81.0 [00:00<?, ?B/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Downloading:   0%|          | 0.00/300M [00:00<?, ?B/s]

0 examples [00:00, ? examples/s]

Dataset oscar downloaded and prepared to /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_mr/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

In [15]:
# Instantiate a ByteLevelBPETokenizer with default arguments; it is trained in a later cell.
tokenizer = ByteLevelBPETokenizer()

In [16]:
def batch_iterator(batch_size=1000):
    """Yield the training texts in consecutive batches.

    Args:
        batch_size: Maximum number of examples per yielded batch.

    Yields:
        The ``"text"`` column of ``raw_dataset["train"][i : i + batch_size]``
        for each batch start ``i``; the last batch may be shorter.
    """
    train_split = raw_dataset["train"]
    # Iterate over the number of rows in the train split. ``len(raw_dataset)``
    # would measure the container that is indexed by split name, which stops
    # the iteration after the very first batch.
    for start in range(0, len(train_split), batch_size):
        yield train_split[start : start + batch_size]["text"]

In [17]:
# Train the tokenizer on the batches produced by batch_iterator().
# The vocabulary size is taken from the loaded config, and the five special tokens below are passed explicitly.
tokenizer.train_from_iterator(batch_iterator(), vocab_size=config.vocab_size, min_frequency=2, special_tokens=[
    "<s>",
    "<pad>",
    "</s>",
    "<unk>",
    "<mask>",
])

In [18]:
# Save the trained tokenizer next to the config so it can be reloaded from `model_dir` later.
tokenizer.save(f"{model_dir}/tokenizer.json")

In [19]:
# Length of each chunk produced by group_texts.
max_seq_length = 128

In [20]:
# Training data: everything after the first 5% of the OSCAR "train" split.
raw_dataset["train"] = load_dataset("oscar", f"unshuffled_deduplicated_{language}", split="train[5%:]")



In [21]:
# Validation data: the first 5% of the OSCAR "train" split (the part excluded from training above).
raw_dataset["validation"] = load_dataset("oscar", f"unshuffled_deduplicated_{language}", split="train[:5%]")



In [22]:
# Quick-run subsample: keep only the first 10,000 training and 1,000 validation examples.
# Comment out this cell to run on the full dataset.
raw_dataset["train"] = raw_dataset["train"].select(range(10000))
raw_dataset["validation"] = raw_dataset["validation"].select(range(1000))

In [23]:
from transformers import AutoTokenizer

# Reload the tokenizer from `model_dir` through AutoTokenizer; this rebinds `tokenizer`,
# replacing the ByteLevelBPETokenizer object used for training.
tokenizer = AutoTokenizer.from_pretrained(f"{model_dir}")

In [24]:
def tokenize_function(examples):
    """Tokenize the ``"text"`` entries of a batch, also requesting the special-tokens mask.

    Args:
        examples: Mapping with a ``"text"`` entry holding the raw texts.

    Returns:
        Whatever the module-level ``tokenizer`` returns for those texts.
    """
    texts = examples["text"]
    encoded = tokenizer(texts, return_special_tokens_mask=True)
    return encoded

In [25]:
# Tokenize every split in batches using 4 worker processes; the original columns of the
# train split are removed so that only the tokenizer outputs remain.
tokenized_datasets = raw_dataset.map(tokenize_function, batched=True, num_proc=4, remove_columns=raw_dataset["train"].column_names)

In [26]:
def group_texts(examples):
    """Concatenate all sequences of a batch and cut them into fixed-length chunks.

    Args:
        examples: Mapping from column name to a list of sequences
            (one sequence per example).

    Returns:
        A dict with the same keys whose values are lists of chunks, each
        exactly ``max_seq_length`` items long. The trailing remainder that
        does not fill a whole chunk is dropped.
    """
    from itertools import chain

    # chain.from_iterable flattens in linear time, whereas sum(lists, [])
    # copies the growing accumulator once per sequence (quadratic).
    concatenated_examples = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Round down to a multiple of max_seq_length so every chunk is full.
    total_length = (total_length // max_seq_length) * max_seq_length
    return {
        k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
        for k, t in concatenated_examples.items()
    }

In [27]:
# Regroup the tokenized examples into chunks of max_seq_length (see group_texts), again with 4 worker processes.
tokenized_datasets = tokenized_datasets.map(group_texts, batched=True, num_proc=4)

In [28]:
import jax
import optax
import flax
import jax.numpy as jnp

from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard

import numpy as np

from tqdm.notebook import tqdm

In [29]:
# Training hyperparameters.
per_device_batch_size = 64
num_epochs = 10
training_seed = 0
learning_rate = 5e-5

# Global batch size: the per-device size multiplied by the number of devices JAX reports.
total_batch_size = per_device_batch_size * jax.device_count()
# Total number of training steps: whole batches per epoch (floor division) times the number of epochs.
num_train_steps = len(tokenized_datasets["train"]) // total_batch_size * num_epochs

In [56]:
from transformers import FlaxAutoModelForMaskedLM

# Build a masked-LM model from `config` via from_config, passing the training seed and a bfloat16 dtype.
model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_seed, dtype=jnp.dtype("bfloat16"))

In [57]:
# Learning-rate schedule: linear transition from `learning_rate` to 0 over `num_train_steps` steps.
linear_decay_lr_schedule_fn = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=num_train_steps)

In [58]:
# AdamW optimizer driven by the linear-decay schedule defined above.
adamw = optax.adamw(learning_rate=linear_decay_lr_schedule_fn, b1=0.9, b2=0.98, eps=1e-8, weight_decay=0.01)

In [59]:
# Bundle the model's call function, its parameters and the optimizer into a single train state.
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)

In [60]:
@flax.struct.dataclass
class FlaxDataCollatorForMaskedLanguageModeling:
    """Pad tokenized examples into a batch and apply random masking for masked-LM training."""

    # Probability with which each non-special token is selected for prediction.
    mlm_probability: float = 0.15

    def __call__(self, examples, tokenizer):
        """Pad ``examples`` into NumPy arrays and add masked ``input_ids`` plus ``labels``.

        The ``special_tokens_mask`` entry is removed from the batch; it is only
        used to decide which positions may be selected for masking.
        """
        batch = tokenizer.pad(examples, return_tensors="np")
        specials = batch.pop("special_tokens_mask", None)

        masked_inputs, targets = self.mask_tokens(batch["input_ids"], specials, tokenizer)
        batch["input_ids"] = masked_inputs
        batch["labels"] = targets
        return batch

    def mask_tokens(self, inputs, special_tokens_mask, tokenizer):
        """Select tokens for prediction and corrupt the inputs at the selected positions.

        ``inputs`` is modified in place and returned together with ``labels``.
        ``labels`` holds the original token ids at selected positions and -100
        everywhere else.
        """
        shape = inputs.shape
        labels = inputs.copy()

        # Per-position selection probability; positions flagged as special tokens are never selected.
        selection_prob = np.full(shape, self.mlm_probability)
        selection_prob[special_tokens_mask.astype("bool")] = 0.0
        masked_indices = np.random.binomial(1, selection_prob).astype("bool")

        # Only selected positions keep a real label.
        labels[~masked_indices] = -100

        # With probability 0.8 a selected token is replaced by the mask token.
        replace_draw = np.random.binomial(1, np.full(shape, 0.8)).astype("bool")
        indices_replaced = replace_draw & masked_indices
        inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

        # Half of the selected-but-not-replaced tokens are swapped for a random token id.
        random_draw = np.random.binomial(1, np.full(shape, 0.5)).astype("bool")
        indices_random = random_draw & (masked_indices & ~indices_replaced)
        random_words = np.random.randint(tokenizer.vocab_size, size=shape, dtype="i4")
        inputs[indices_random] = random_words[indices_random]

        # The remaining selected tokens are left unchanged.
        return inputs, labels

In [61]:
# Collator instance that selects tokens for masking with probability 0.15.
data_collator = FlaxDataCollatorForMaskedLanguageModeling(mlm_probability=0.15)

In [62]:
def generate_batch_splits(num_samples, batch_size, rng=None):
    """Split the indices ``0 .. num_samples - 1`` into equally sized batches.

    Args:
        num_samples: Total number of samples to index.
        batch_size: Number of indices per batch.
        rng: Optional JAX PRNG key. When given, the indices are shuffled with
            it; when ``None`` they stay in ascending order.

    Returns:
        A list of ``num_samples // batch_size`` index arrays, each holding
        ``batch_size`` indices. Samples that do not fill a complete batch are
        dropped; the list is empty when there are fewer than ``batch_size``
        samples.
    """
    num_batches = num_samples // batch_size
    # Without at least one full batch there is nothing to split
    # (np.split cannot divide into zero sections).
    if num_batches == 0:
        return []

    samples_idx = jax.numpy.arange(num_samples)

    # Shuffle only when the caller supplies a key. The decision is based on
    # the `rng` argument itself, never on a variable from the enclosing scope.
    if rng is not None:
        samples_idx = jax.random.permutation(rng, samples_idx)

    # Throw away the incomplete trailing batch.
    samples_to_remove = num_samples % batch_size
    if samples_to_remove != 0:
        samples_idx = samples_idx[:-samples_to_remove]

    return np.split(samples_idx, num_batches)

In [63]:
def train_step(state, batch, dropout_rng):
    """Perform one optimisation step.

    Args:
        state: Train state providing ``apply_fn``, ``params``, ``step`` and ``apply_gradients``.
        batch: Mapping of model inputs that also contains ``"labels"``.
        dropout_rng: PRNG key; it is split into a key for this step and one for the next.

    Returns:
        ``(new_state, metrics, new_dropout_rng)`` where ``metrics`` holds the
        loss and learning rate averaged over the ``"batch"`` axis.
    """
    step_rng, next_rng = jax.random.split(dropout_rng)

    def loss_fn(params):
        targets = batch.pop("labels")

        outputs = state.apply_fn(**batch, params=params, dropout_rng=step_rng, train=True)
        logits = outputs[0]

        # Positions whose label is not > 0 (e.g. the -100 fill value) get weight 0.
        weights = jax.numpy.where(targets > 0, 1.0, 0.0)
        per_token = optax.softmax_cross_entropy(logits, onehot(targets, logits.shape[-1]))
        weighted = per_token * weights

        # Mean over the positions that carry weight.
        return weighted.sum() / weights.sum()

    loss, grad = jax.value_and_grad(loss_fn)(state.params)
    # Average the gradients across the "batch" axis before applying them.
    grad = jax.lax.pmean(grad, "batch")
    updated_state = state.apply_gradients(grads=grad)

    step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
    metrics = jax.lax.pmean(step_metrics, axis_name="batch")

    return updated_state, metrics, next_rng

In [64]:
# pmap the train step over an axis named "batch", the axis name used by the pmean calls inside train_step.
parallel_train_step = jax.pmap(train_step, "batch")

In [65]:
def eval_step(params, batch):
    """Compute summed evaluation statistics for one batch.

    Args:
        params: Model parameters to evaluate.
        batch: Mapping of model inputs that also contains ``"labels"``.

    Returns:
        A dict with the summed ``"loss"``, the summed ``"accuracy"`` (count of
        correct predictions) and the ``"normalizer"`` (sum of position
        weights), each summed over the ``"batch"`` axis.
    """
    targets = batch.pop("labels")

    logits = model(**batch, params=params, train=False)[0]

    # Positions whose label is not > 0 (e.g. the -100 fill value) get weight 0.
    weights = jax.numpy.where(targets > 0, 1.0, 0.0)
    token_loss = optax.softmax_cross_entropy(logits, onehot(targets, logits.shape[-1])) * weights

    # A prediction counts as a hit when the arg-max token equals the label.
    predictions = jax.numpy.argmax(logits, axis=-1)
    hits = jax.numpy.equal(predictions, targets) * weights

    totals = {"loss": token_loss.sum(), "accuracy": hits.sum(), "normalizer": weights.sum()}
    return jax.lax.psum(totals, axis_name="batch")

In [66]:
# pmap the eval step over an axis named "batch", the axis name used by the psum call inside eval_step.
parallel_eval_step = jax.pmap(eval_step, "batch")

In [67]:
# Replicate the train state with flax.jax_utils.replicate so it can be passed to the pmapped steps.
# NOTE(review): re-running this cell replicates an already replicated state — re-create `state` first.
state = flax.jax_utils.replicate(state)

In [68]:
def process_eval_metrics(metrics):
    """Reduce per-batch evaluation statistics to normalized averages.

    Args:
        metrics: List of metric dicts as returned by the parallel eval step;
            each contains a ``"normalizer"`` entry.

    Returns:
        A dict of the remaining metrics, each summed over all batches and
        divided by the summed ``"normalizer"``.
    """
    metrics = get_metrics(metrics)
    # jax.tree_util.tree_map is the stable spelling; the top-level
    # jax.tree_map alias is deprecated and missing from newer JAX releases.
    metrics = jax.tree_util.tree_map(jax.numpy.sum, metrics)
    normalizer = metrics.pop("normalizer")
    metrics = jax.tree_util.tree_map(lambda x: x / normalizer, metrics)
    return metrics

In [69]:
# Root PRNG key derived from the training seed, plus one dropout key per local device.
rng = jax.random.PRNGKey(training_seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())

In [71]:
# Main loop: one training pass and one evaluation pass per epoch.
# It rebinds the notebook-level names `rng`, `state` and `dropout_rngs` on every iteration.
for epoch in tqdm(range(1, num_epochs + 1), desc=f"Epoch ...", position=0, leave=True):
    # Derive a fresh key for this epoch's shuffling and keep the other half for later epochs.
    rng, input_rng = jax.random.split(rng)

    # -- Train --
    train_batch_idx = generate_batch_splits(len(tokenized_datasets["train"]), total_batch_size, rng=input_rng)

    with tqdm(total=len(train_batch_idx), desc="Training...", leave=False) as progress_bar_train:
        for batch_idx in train_batch_idx:
            # Collate (pad + mask) the examples selected by this batch's indices.
            model_inputs = data_collator(tokenized_datasets["train"][batch_idx], tokenizer=tokenizer)

            # Model forward
            # shard() reshapes the batch for the pmapped step — presumably one slice per device; confirm in the flax docs.
            model_inputs = shard(model_inputs.data)
            state, train_metric, dropout_rngs = parallel_train_step(state, model_inputs, dropout_rngs)

            progress_bar_train.update(1)

        # Report the metrics of the last training step of the epoch.
        # NOTE(review): `train_metric` is unbound here if `train_batch_idx` is empty.
        progress_bar_train.write(
              f"Train... ({epoch}/{num_epochs} | Loss: {round(train_metric['loss'].mean(), 3)}, Learning Rate: {round(train_metric['learning_rate'].mean(), 6)})"
        )
        



    # -- Eval --
    # No rng is passed for evaluation.
    eval_batch_idx = generate_batch_splits(len(tokenized_datasets["validation"]), total_batch_size)
    eval_metrics = []

    with tqdm(total=len(eval_batch_idx), desc="Evaluation...", leave=False) as progress_bar_eval:
        for batch_idx in eval_batch_idx:
            model_inputs = data_collator(tokenized_datasets["validation"][batch_idx], tokenizer=tokenizer)

            # Model forward
            model_inputs = shard(model_inputs.data)
            eval_metric = parallel_eval_step(state.params, model_inputs)
            eval_metrics.append(eval_metric)

            progress_bar_eval.update(1)

        # Combine the per-batch sums into normalized loss and accuracy for the epoch.
        eval_metrics_dict = process_eval_metrics(eval_metrics)
        progress_bar_eval.write(
            f"Eval... ({epoch}/{num_epochs} | Loss: {eval_metrics_dict['loss']}, Acc: {eval_metrics_dict['accuracy']})"
        )

Epoch ...:   0%|          | 0/10 [00:00<?, ?it/s]

Training...:   0%|          | 0/294 [00:00<?, ?it/s]

RuntimeError: ignored