In [1]:
from IPython.display import clear_output

In [2]:
import os
import sys
import urllib
import tarfile
import shutil


corpus_url = "http://www.cs.cornell.edu/people/pabo/movie-review-data/review_polarity.tar.gz"

corpus_root = "/tmp/review_polarity/"
extracted_path = os.path.join(corpus_root, "extracted")
if os.path.exists(corpus_root):
    shutil.rmtree(corpus_root)
if os.path.exists(extracted_path):
    shutil.rmtree(extracted_path)

os.makedirs(corpus_root)


catgeories = ["pos", "neg"]


def download_and_unzip():
    file_name = corpus_url.split("/")[-1]
    download_path = os.path.join(corpus_root, file_name)

    # ============================================ download
    print("Downloading, sit tight!")

    def _progress(count, block_size, total_size):
        sys.stdout.write(
            f"\r>> Downloading {file_name} {float(count * block_size) / float(total_size) * 100.0}%")
        sys.stdout.flush()

    file_path, _ = urllib.request.urlretrieve(
        corpus_url, download_path, _progress)
    print()
    print(
        f"Successfully downloaded {file_name} {os.stat(file_path).st_size} bytes")

    # ======================================= unzip
    print()
    print("Unzipping ...")
    # create dir at extracted_path
    os.mkdir(extracted_path)
    tarfile.open(file_path, "r:gz").extractall(extracted_path)

    # =========================================== clean up
    # delete the downloaded zip file
    print("Deleting downloaded zip file")
    os.remove(file_path)


# =============
def read_text_files(path):
    file_list = os.listdir(path)
    texts = []

    for fname in file_list:
        fpath = os.path.join(path, fname)

        f = open(fpath, mode="r")
        lines = f.read()
        texts.append(lines)
        f.close()

    return texts


# =========
download_and_unzip()

Downloading, sit tight!
>> Downloading review_polarity.tar.gz 100.06734377108491%%
Successfully downloaded review_polarity.tar.gz 3127238 bytes

Unzipping ...
Deleting downloaded zip file


In [3]:
from tqdm.auto import tqdm, trange

reviews = []
labels = []

# we can't use the previous tokenizers here
# idx 0 -> neg, 1 -> pos
for idx, cat in enumerate(catgeories):
    path = os.path.join(extracted_path, "txt_sentoken", cat)
    texts = read_text_files(path)

    for i in tqdm(range(len(texts)), desc="prepare_corpus"):
        text = texts[i]
        reviews.append(text)
        labels.append(idx)

print()
print(len(reviews))
print(len(labels))

prepare_corpus:   0%|          | 0/1000 [00:00<?, ?it/s]

prepare_corpus:   0%|          | 0/1000 [00:00<?, ?it/s]


2000
2000


In [4]:
from sklearn.model_selection import train_test_split

x_train, x_test, y_train, y_test = train_test_split(
    reviews, labels, random_state=42, train_size=0.8
)

x_train, x_val, y_train, y_val = train_test_split(
    x_train, y_train, train_size=0.8, random_state=42)

In [5]:
from transformers import FlaxAutoModel, AutoTokenizer
import jax
import flax.linen as nn
import jax.numpy as np
import numpy as onp

clear_output()

In [6]:
model_name = "google-bert/bert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_name)
clear_output()

text = "this is dummy text"
encoded = tokenizer.encode_plus(text, return_tensors="jax")
encoded

{'input_ids': Array([[  101,  2023,  2003, 24369,  3793,   102]], dtype=int32), 'token_type_ids': Array([[0, 0, 0, 0, 0, 0]], dtype=int32), 'attention_mask': Array([[1, 1, 1, 1, 1, 1]], dtype=int32)}

In [7]:
import torch
from torch.utils.data import Dataset

from einops import rearrange


class PolarityReviewDataset(Dataset):

    def __init__(self, reviews, labels, tokenizer):
        self.reviews = reviews
        self.labels = labels
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        review = self.reviews[idx]
        label = self.labels[idx]

        # encode review text
        encoding = self.tokenizer.encode_plus(
            review,
            add_special_tokens=True,
            max_length=512,
            truncation=True,
            return_token_type_ids=False,
            padding="max_length",
            return_attention_mask=True,
            return_tensors="np"
        )
        
        return encoding["input_ids"], encoding["attention_mask"], onp.array([label])


training_dataset = PolarityReviewDataset(x_train, y_train, tokenizer)
val_dataset = PolarityReviewDataset(x_val, y_val, tokenizer)
test_dataset = PolarityReviewDataset(x_test, y_test, tokenizer)

for ar in training_dataset[0]:
    print(ar.shape)

(1, 512)
(1, 512)
(1,)


In [8]:
import jax_dataloader as jdl

BS = 24

train_loader = jdl.DataLoader(training_dataset, "pytorch", batch_size=BS, shuffle=True)
val_loader = jdl.DataLoader(
    val_dataset, "pytorch", batch_size=BS, shuffle=False)
test_loader = jdl.DataLoader(test_dataset, "pytorch", batch_size=BS, shuffle=False)

In [9]:
# https://flax.readthedocs.io/en/latest/guides/training_techniques/transfer_learning.html

from typing import Any
from flax.core.frozen_dict import unfreeze, freeze


def load_model(model_name):
    model = FlaxAutoModel.from_pretrained(model_name)
    # clear_output()
    module = model.module
    variables = {"params": model.params}
    return module, variables

bert_module, bert_vars = load_model(model_name)


class SentimentCLF(nn.Module):
    backbone: nn.Module

    @nn.compact
    def __call__(self, input_ids: np.ndarray, attention_mask: np.ndarray) -> Any:
        out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        out = out.pooler_output
        out = nn.Dense(2)(out)
        return out


rng = jax.random.key(0)
model = SentimentCLF(bert_module)

sample_data = training_dataset[0]
input_ids, attention_mask, label = sample_data
initial_params = model.init(rng, input_ids, attention_mask)

# unfreeze
initial_params_unfrozen = unfreeze(initial_params)
initial_params_unfrozen["params"]["backbone"] = bert_vars["params"]
# freeze back
initial_params = freeze(initial_params_unfrozen)
# initial_params = randomly_init_params["params"]

# # add pretrained vars
# initial_params["backbone"] = bert_vars["params"]
clear_output()

In [10]:
import optax

@jax.jit
def calculate_loss(params, input_ids, attention_mask, label):
    logits = model.apply(params, input_ids, attention_mask)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, label)
    # typical numpy array thing
    # should be a scalar
    return loss[0]


@jax.jit
def batched_loss(params, input_ids, attention_masks, labels):
    batch_loss = jax.vmap(calculate_loss, in_axes=(None, 0, 0, 0))(params, input_ids, attention_masks, labels)
    return batch_loss.mean(axis=-1)


# =========
single_loss = calculate_loss(initial_params, input_ids, attention_mask, label)
print(f"{single_loss=}")

for batch in train_loader:
    input_ids, attention_masks, labels = batch
    batch_loss = batched_loss(initial_params, input_ids, attention_masks, labels)
    print(f"{batch_loss=}")
    break 

single_loss=Array(1.075755, dtype=float32)
batch_loss=Array(0.91052884, dtype=float32)


In [11]:
from flax.training import train_state

clipper = optax.clip_by_global_norm(1.0)

tx = optax.chain(optax.adam(learning_rate=2e-5),
                 optax.clip_by_global_norm(1.0))

initial_state = train_state.TrainState.create(
    apply_fn=model.apply,
    tx=tx,
    params=initial_params,
)
criterion = jax.value_and_grad(batched_loss)

In [12]:
from sklearn.metrics import f1_score


@jax.jit
def test_step(state, batch):
    input_ids, attention_masks, labels = batch

    def infer(params, input_ids, attention_mask):
        logits = model.apply(params, input_ids, attention_mask)
        return jax.nn.softmax(logits, axis=-1)

    probas = jax.vmap(jax.jit(infer), in_axes=(None, 0, 0))(
        state.params, input_ids, attention_masks)

    return probas


def evaluate(state, test_loader):
    scores = list()
    for batch in tqdm(test_loader):
        _, _, labels = batch
        probas = test_step(state, batch)
        preds = onp.argmax(probas, axis=-1)
        f1s = f1_score(labels, preds)

        scores.append(f1s)

    return onp.array(scores).mean(axis=-1)

# ========== 
evaluate(initial_state, test_loader)

  0%|          | 0/17 [00:00<?, ?it/s]

0.6617097785088094

In [13]:
@jax.jit
def train_step(state, batch):
    input_ids, attention_masks, labels = batch
    loss_value, grads = criterion(state.params, input_ids, attention_masks, labels)    
    updated_state = state.apply_gradients(grads=grads)
    return loss_value, updated_state

In [14]:
@jax.jit
def validation_step(state, batch):
    input_ids, attention_masks, labels = batch
    loss_value, _ = criterion(state.params, input_ids, attention_masks, labels)
    return loss_value

In [15]:
def train(state, epochs, train_loader, val_loader):
    steps = 0
    train_losses = []
    mean_val_losses = []


    # =============
    for e in trange(epochs):
        for batch in tqdm(train_loader, desc="train_step"):
            train_loss, state = train_step(state, batch)
            steps += 1

            # log every 200 steps
            if steps % 40 == 0:
                train_losses.append(train_loss)

                # run validation
                validation_losses = []
                for batch in tqdm(val_loader, desc="validation_step"):
                    val_loss = validation_step(state, batch)
                    validation_losses.append(val_loss)
                    
                mean_val_loss = onp.array(validation_losses).mean(axis=-1)
                mean_val_losses.append(mean_val_loss)


                print(
                    f"Epoch : {e + 1} :: Step : {steps} :: Loss/Train : {train_loss} :: Loss/Validation : {mean_val_loss}")
                
    # ============
    return state, train_losses, mean_val_losses

# ============
trained_state, train_losses, mean_val_losses = train(initial_state, 2, train_loader, val_loader)

  0%|          | 0/2 [00:00<?, ?it/s]

train_step:   0%|          | 0/54 [00:00<?, ?it/s]

validation_step:   0%|          | 0/14 [00:00<?, ?it/s]

Epoch : 1 :: Step : 40 :: Loss/Train : 0.6332657337188721 :: Loss/Validation : 0.4767797589302063


train_step:   0%|          | 0/54 [00:00<?, ?it/s]

validation_step:   0%|          | 0/14 [00:00<?, ?it/s]

Epoch : 2 :: Step : 80 :: Loss/Train : 0.30096790194511414 :: Loss/Validation : 0.39539462327957153


In [16]:
# ===========
evaluate(trained_state, test_loader)

  0%|          | 0/17 [00:00<?, ?it/s]

0.8723853634399746