## Corpus

https://github.com/huggingface/transformers/issues/15766

https://github.com/google/flax/discussions/2245

In [1]:
import os
import sys
import urllib.request as R
import tarfile

corpus_url = "http://www.cs.cornell.edu/people/pabo/movie-review-data/review_polarity.tar.gz"

corpus_root = os.path.join(os.getcwd(), "review_polarity", "txt_sentoken")
catgeories = ["pos", "neg"]


def download_and_unzip():
    file_name = corpus_url.split("/")[-1]
    download_path = os.path.join(os.getcwd(), file_name)
    # where the zip will get extracted
    extracted_path = os.path.join(os.getcwd(), "review_polarity")

    if os.path.exists(extracted_path):
        print("Already downloaded and extracted!")
    else:
        # ============================================ download
        print("Downloading, sit tight!")

        def _progress(count, block_size, total_size):
            sys.stdout.write(
                f"\r>> Downloading {file_name} {float(count * block_size) / float(total_size) * 100.0}%")
            sys.stdout.flush()

        file_path, _ = R.urlretrieve(
            corpus_url, download_path, _progress)
        print()
        print(
            f"Successfully downloaded {file_name} {os.stat(file_path).st_size} bytes")

        # ======================================= unzip
        print()
        print("Unzipping ...")
        # create dir at extracted_path
        os.mkdir(extracted_path)
        tarfile.open(file_path, "r:gz").extractall(extracted_path)

        # =========================================== clean up
        # delete the downloaded zip file
        print("Deleting downloaded zip file")
        os.remove(file_path)

In [2]:
def read_text_files(path):
    file_list = os.listdir(path)
    texts = []

    for fname in file_list:
        fpath = os.path.join(path, fname)

        f = open(fpath, mode="r")
        lines = f.read()
        texts.append(lines)
        f.close()

    return texts

In [3]:
from tqdm.notebook import tqdm

download_and_unzip()

reviews = []
labels = []

# idx 0 -> neg, 1 -> pos
for idx, cat in enumerate(catgeories):
    path = os.path.join(corpus_root, cat)
    texts = read_text_files(path)

    for i in tqdm(range(len(texts)), desc="prepare_corpus"):
        text = texts[i]
        reviews.append(text)
        labels.append(idx)

Already downloaded and extracted!


prepare_corpus:   0%|          | 0/1000 [00:00<?, ?it/s]

prepare_corpus:   0%|          | 0/1000 [00:00<?, ?it/s]

## Tokenization

In [4]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

from sklearn.model_selection import train_test_split

x_train, x_test, y_train, y_test = train_test_split(
    reviews, labels, random_state=42, train_size=0.8
)

x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, train_size=0.8, random_state=42)

## Torch Dataset

In [5]:
import torch
from torch.utils.data import Dataset

import numpy as np

# custom dataset
class PolarityReviewDataset(Dataset):

    def __init__(self, reviews, labels):
        self.reviews = reviews
        self.labels = labels
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        review = self.reviews[idx]
        label = self.labels[idx]

        # encode review text
        encoding = self.tokenizer.encode_plus(
            review,
            add_special_tokens=True,
            max_length=128,
            truncation=True,
            return_token_type_ids=False,
            padding="max_length",
            return_attention_mask=True,
            return_tensors="np"
        )

        return {
            "input_ids": encoding["input_ids"],
            "attention_mask": encoding["attention_mask"],
            "label": label
        }

training_dataset = PolarityReviewDataset(x_train, y_train)
val_dataset = PolarityReviewDataset(x_val, y_val)

## Collate fn

In [6]:
"""
refactored the pytorch dataloader default collate function to work with jax
https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py
"""


import jax
import jax.numpy as jnp
import re
import collections
from torch._six import string_classes


np_str_obj_array_pattern = re.compile(r'[SaUO]')


default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, "
    "dicts or lists; found {}")


def default_collate(batch):
    # let's repurpose the base collate_fn from pytorch
    elem = batch[0]
    elem_type = type(elem)

    if isinstance(elem, jnp.ndarray):
        out = None
        return jnp.stack(batch, 0, out=out)

    if elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(
                    default_collate_err_msg_format.format(elem.dtype))

            return default_collate([jnp.array(b) for b in batch])
        elif elem.shape == ():  # scalars
            return jnp.array(batch)
    elif isinstance(elem, float):
        return jnp.array(batch, dtype=jnp.float64)
    elif isinstance(elem, int):
        return jnp.array(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        try:
            return elem_type({key: default_collate([d[key] for d in batch]) for key in elem})
        except TypeError:
            # The mapping type may not support `__init__(iterable)`.
            return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError(
                'each element in list of batch should be of equal size')
        # It may be accessed twice, so we use a list.
        transposed = list(zip(*batch))

        if isinstance(elem, tuple):
            # Backwards compatibility.
            return [default_collate(samples) for samples in transposed]
        else:
            try:
                return elem_type([default_collate(samples) for samples in transposed])
            except TypeError:
                # The sequence type may not support `__init__(iterable)` (e.g., `range`).
                return [default_collate(samples) for samples in transposed]

## Dataloader

In [7]:
from torch.utils.data import DataLoader

batch_size = 16
train_loader = DataLoader(training_dataset, shuffle=True, batch_size=batch_size, pin_memory=True, collate_fn=default_collate)

## Model

In [8]:
from transformers import FlaxBertModel
from transformers import logging
import copy

# suppress logging to error only
logging.set_verbosity_error()

# load the pretrained model params and config
pre_bert, pre_params = FlaxBertModel.from_pretrained('bert-base-cased', _do_init=False)

In [9]:
import flax
from flax import linen as nn

class Classifier(nn.Module):
    bert: nn.Module
    
    def setup(self):
        self.fc = nn.Dense(features=2)

    def __call__(self, input_ids, attention_mask):
        out = self.bert(input_ids, attention_mask)

        # last layer output
        out = out.pooler_output
        # pass through dense layer
        out = self.fc(out)
        out = jax.nn.log_softmax(out, axis=-1)

        return out

In [10]:
# init params with dummy input
model = Classifier(pre_bert.module)
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

dummy = tokenizer.encode_plus(
        "This is some dummy text",
        add_special_tokens=True,
        max_length=128,
        truncation=True,
        return_token_type_ids=False,
        padding="max_length",
        return_attention_mask=True,
        return_tensors="np"
)

masterkey = jax.random.PRNGKey(42)
params = model.init(masterkey, dummy["input_ids"], dummy["attention_mask"])

In [11]:
from flax.core.frozen_dict import freeze, unfreeze

# to add bert pretrained params to existing params
params = unfreeze(params)

params["params"]["bert"] = pre_bert.init_weights(jax.random.PRNGKey(42), dummy["input_ids"].shape)
params = freeze(params)

In [12]:
# run a test forward pass
pred = model.apply(
        params, dummy["input_ids"], dummy["attention_mask"])

pred

DeviceArray([[-0.7285325, -0.6589713]], dtype=float32)

## Loss function

In [13]:
@jax.jit
def cross_entropy(logits, labels):
    return -jnp.sum(labels * logits, axis=-1)


In [14]:
N_CLASSES = 2


def compute_loss(params, input_ids, attention_mask, labels, model):
    def _forward(input_ids, attention_mask, label):
        _logits = model.apply(params, input_ids, attention_mask)
        _loss = cross_entropy(_logits, jax.nn.one_hot(label, N_CLASSES))
        return _loss

    loss = jax.vmap(_forward)(input_ids, attention_mask, labels)
    mean_loss = jnp.mean(loss)

    return mean_loss

## Optimizer

In [15]:
import optax

LR = 2e-5
tx = optax.chain(
        optax.adam(learning_rate=LR),
        optax.clip_by_global_norm(1.0))
opt_state = tx.init(params)

## Grad FN

In [16]:
grad_fn = jax.value_and_grad(compute_loss)

## Train Step Setup

In [17]:
def train_step(params, opt_state, input_ids, attention_mask, labels, grad_fn, model, tx):
    # do the forward pass and get the loss and gradients
    loss, grads = grad_fn(params, input_ids, attention_mask, labels, model)

   # use the gradients to update parameters
    updates, opt_state = tx.update(grads, opt_state)
    updated_params = optax.apply_updates(params, updates)

    return updated_params, opt_state, loss

## Training

In [18]:
EPOCHS = 5
losses = []

for i in range(EPOCHS):
    print(f"Epoch {i+1}/{EPOCHS}")
    losses_epoch = []
    for td in tqdm(train_loader):
        input_ids = td["input_ids"]
        attention_mask = td["attention_mask"]
        label = td["label"]

        params, opt_state, loss = train_step(
            params, opt_state, input_ids, attention_mask, label, grad_fn, model, tx)
        losses_epoch.append(loss)

    mean_loss = jnp.mean(jnp.array(losses_epoch))
    print("Mean Loss => ", mean_loss)
    losses.append(mean_loss)

Epoch 1/5


  0%|          | 0/80 [00:00<?, ?it/s]

2022-06-29 17:58:55.595659: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:479] Allocator (GPU_0_bfc) ran out of memory trying to allocate 12.00MiB (rounded to 12582912)requested by op 
2022-06-29 17:58:55.599940: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:491] ****************************************************************************************************
2022-06-29 17:58:55.600014: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2141] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 12582912 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   12.09MiB
              constant allocation:         0B
        maybe_live_out allocation:   24.19MiB
     preallocated temp allocation:         0B
                 total allocation:   36.28MiB
              total fragmentation:         0B (0.00%)
Peak buffers:

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 12582912 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   12.09MiB
              constant allocation:         0B
        maybe_live_out allocation:   24.19MiB
     preallocated temp allocation:         0B
                 total allocation:   36.28MiB
              total fragmentation:         0B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 12.00MiB
		Entry Parameter Subshape: f32[16,1,12,128,128]
		==========================

	Buffer 2:
		Size: 12.00MiB
		XLA Label: copy
		Shape: f32[16,1,12,128,128]
		==========================

	Buffer 3:
		Size: 12.00MiB
		Operator: op_name="jit(jvp(vmap(true_divide)))/jit(main)/div" source_file="/home/shawon/Projects/jax_examples/venv/lib64/python3.10/site-packages/flax/linen/attention.py" source_line=89
		XLA Label: fusion
		Shape: f32[16,1,12,128,128]
		==========================

	Buffer 4:
		Size: 96.0KiB
		Entry Parameter Subshape: f32[16,1,12,128,1]
		==========================

	Buffer 5:
		Size: 96.0KiB
		Operator: op_name="jit(jvp(vmap(true_divide)))/jit(main)/div" source_file="/home/shawon/Projects/jax_examples/venv/lib64/python3.10/site-packages/flax/linen/attention.py" source_line=89
		XLA Label: fusion
		Shape: f32[16,1,12,128,1]
		==========================

	Buffer 6:
		Size: 96.0KiB
		XLA Label: copy
		Shape: f32[16,1,12,128,1]
		==========================

	Buffer 7:
		Size: 32B
		XLA Label: tuple
		Shape: (f32[16,1,12,128,128], f32[16,1,12,128,1], f32[16,1,12,128,128], f32[16,1,12,128,1])
		==========================



## Plot losses

In [None]:
import matplotlib.pyplot as plt

%matplotlib inline


def plot_losses(losses, epochs=EPOCHS):
    xticks = [i for i in range(epochs)]
    plt.plot(xticks, losses)
    plt.xlabel("Epochs")
    plt.ylabel("Mean Loss")
    plt.show()

In [None]:
plot_losses(losses)