<a href="https://colab.research.google.com/github/ayami-n/Flax_text_prediction/blob/main/Main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive so that the project folder stored there is available under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# Work from the project directory; the relative "./kaggle/..." paths used later resolve against it.
%cd /content/drive/MyDrive/Flax_text_prediction

Mounted at /content/drive
/content/drive/MyDrive/Flax_text_prediction


# Import libs

In [None]:
%%capture
# Install the libraries this notebook depends on (%%capture keeps pip's output out of the notebook).
# transformers and optax are installed from their GitHub repositories; flax from the package index.
!pip install git+https://github.com/huggingface/transformers.git
!pip install flax
!pip install git+https://github.com/deepmind/optax.git

In [None]:
# Jax
import jax
from jax import random  # to create random values for initalizing a model (Flax requires)
import jax.numpy as jnp
import jax.tools.colab_tpu  # TPU Settings

# Flax for building model
try:
    import flax
except ModuleNotFoundError: # Install flax if missing
    !pip install --quiet flax
    import flax
    
from flax import linen as nn
from flax.training import train_state, checkpoints
from flax.core.frozen_dict import freeze, unfreeze
from flax import traverse_util
from flax.training.common_utils import shard, shard_prng_key, onehot

# Optax for the optimizer
import optax

# Transformers
!pip install transformers
from transformers import AutoTokenizer, FlaxBertModel # FlaxAutoModelForSequenceClassification, BertTokenizer, AutoConfig # as we use Roberta model
from transformers.modeling_flax_utils import FlaxPreTrainedModel  # FlaxMLPModule is still stateless

# others
import pandas as pd
from tqdm import tqdm
from typing import Callable, Any
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
import numpy as np
import os
!pip install datasets
from datasets import load_dataset

[K     |████████████████████████████████| 202 kB 3.2 MB/s 
[K     |████████████████████████████████| 596 kB 39.9 MB/s 
[K     |████████████████████████████████| 145 kB 53.8 MB/s 
[K     |████████████████████████████████| 217 kB 51.9 MB/s 
[K     |████████████████████████████████| 7.5 MB 36.3 MB/s 
[K     |████████████████████████████████| 51 kB 6.1 MB/s 
[K     |████████████████████████████████| 76 kB 4.8 MB/s 
[?25hLooking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.21.1-py3-none-any.whl (4.7 MB)
[K     |████████████████████████████████| 4.7 MB 2.6 MB/s 
[?25hCollecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.8.1-py3-none-any.whl (101 kB)
[K     |████████████████████████████████| 101 kB 10.2 MB/s 
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
[K     |████████

# TPU Setting

In [None]:
# Point JAX at the Colab TPU when the runtime exposes one through the TPU_NAME environment variable.
if 'TPU_NAME' in os.environ:
    import requests
    # Send the driver-version request only once per kernel; TPU_DRIVER_MODE records that it was sent.
    if 'TPU_DRIVER_MODE' not in globals():
        # TPU_NAME is split on ':' and its second piece is reused as the host part of the request URL.
        url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
        resp = requests.post(url)
        TPU_DRIVER_MODE = 1
    # Select the TPU driver backend and tell JAX which target to connect to.
    from jax.config import config
    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
    print('Registered TPU:', config.FLAGS.jax_backend_target)
else:
    # Nothing is configured in this branch; JAX keeps whatever backend it selects by default.
    print('No TPU detected. Can be changed under "Runtime/Change runtime type".')

No TPU detected. Can be changed under "Runtime/Change runtime type".


In [None]:
# Set up the Colab TPU for use with JAX, but only on a TPU runtime.
# setup_tpu() needs the COLAB_TPU_ADDR environment variable and fails when it is missing
# (CPU/GPU runtimes), which would stop "Restart & Run All" at this cell.
if 'COLAB_TPU_ADDR' in os.environ:
    jax.tools.colab_tpu.setup_tpu()
print("TPU: ", jax.devices())  # 8 TpuDevices when the TPU works; otherwise the CPU/GPU devices in use

## Data Analysis

In [None]:
df = pd.read_csv("./kaggle/train.csv")  # import train datasets

# Number of whitespace-separated words in every discourse text.
words_list = [len(text.split()) for text in df['discourse_text'].to_numpy()]

# Track the longest text. The counter is called `max_words` (not `max`) so that the
# built-in max() stays usable in the rest of the notebook.
max_words = 0
max_str = "Who is Max?"  # placeholder, kept only if no text contains at least one word
for text, n_words in zip(df['discourse_text'].to_numpy(), words_list):
    if n_words > max_words:  # strict comparison: the first longest text wins
        max_words = n_words
        max_str = text

# Number of texts with at most 256 words, followed by the total number of texts.
print(sum(n_words <= 256 for n_words in words_list), len(words_list))

36566 36765


# Config

In [None]:
### Model Config ####
model_checkpoint = 'bert-base-cased'  # Hugging Face checkpoint name; used for both the tokenizer and the pretrained BERT weights
seed = 0  # shared seed: sklearn split, model initialisation and the training RNG
num_labels=3  # number of target classes; used when one-hot encoding the labels in the loss
# config = AutoConfig.from_pretrained(model_checkpoint, num_labels=num_labels)

# tokenizer = BertTokenizer.from_pretrained(model_checkpoint, use_fast=True)  # this tokenizer converts numeric from string: the values are different if you select different model_checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)  # tokenizer matching the checkpoint above

### Train Config ####
num_train_epochs = 1
learning_rate = 2e-5
per_device_batch_size = 32
weight_decay=1e-2  # only read by the adamw() helper
# Global batch size = per-device batch size * number of local devices
# (32 * 8 = 256 on an 8-device TPU; 32 when a single device is available).
total_batch_size = per_device_batch_size * jax.local_device_count()
print("Total Batch size {:d}" .format(total_batch_size))

Total Batch size 32


## Creating Dummy Inputs

In [None]:
# Split the pandas frame 95% / 5%; only train_df is used below, to pick one example text.
train_df, val_df = train_test_split(df, test_size=0.05, random_state=seed)  # spliting datasets 95% train and 5% val

max_len = 128  # every tokenized input is padded or truncated to exactly this many tokens

# Tokenize a single text to obtain correctly shaped example inputs for model.init().
dummy = tokenizer(train_df['discourse_text'].to_numpy()[:1][0], # [:1][0] -> makes str
                  max_length=max_len, truncation=True, 
                  padding='max_length', return_token_type_ids=False,
                  return_attention_mask=True, return_tensors="np"
                  ) # add_special_tokens=True is default: truncation=True is cutting off longer sentences (longer than max_length)

# Token ids and attention mask produced by the tokenizer (requested as NumPy tensors via return_tensors="np").
dummy_input_ids, dummy_attention_mask = dummy['input_ids'], dummy['attention_mask']

In [None]:
'''
https://stackoverflow.com/questions/65246703/how-does-max-length-padding-and-truncation-arguments-work-in-huggingface-bertt
adding [CLS] token at the beginning of the sentence, and [SEP] token at the end of sentence.
[CLS] I love you [SEP] is expected by BERT. 
tokenizer gives [CLS] and [SEP] usually
'''

In [None]:
tokenizer.convert_ids_to_tokens(dummy_input_ids.squeeze())[:30]  # map the first 30 ids back to tokens to sanity-check the encoding

['[CLS]',
 'In',
 'conclusion',
 ',',
 'the',
 'Electoral',
 'College',
 'should',
 'be',
 'kept',
 '.',
 'It',
 'induce',
 '##s',
 'the',
 'candidates',
 '.',
 'It',
 'restore',
 '##s',
 'some',
 'of',
 'the',
 'weight',
 'that',
 'the',
 'large',
 'states',
 'loses',
 '.']

# Tokenization and Loading Data

In [None]:
# Load the Kaggle training CSV as a Hugging Face dataset and hold out 5% of the rows.
# The split is seeded with the notebook-wide `seed` so that re-running the notebook
# produces the same train/test partition (it was previously different on every run).
data = load_dataset("csv", data_files={'train': ["./kaggle/train.csv"]})
data = data["train"].train_test_split(0.05, seed=seed)



Downloading and preparing dataset csv/default to /root/.cache/huggingface/datasets/csv/default-cb7c2ac4769ea286/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

0 tables [00:00, ? tables/s]

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-cb7c2ac4769ea286/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
def preprocess_function(input_batch):
    """Tokenize a batch of discourse texts and attach integer class labels.

    Written for ``Dataset.map(..., batched=True)``: ``input_batch`` is indexed by
    column name and the columns "discourse_text" and "discourse_effectiveness" are read.

    Args:
        input_batch: mapping from column name to that column's values for the batch.

    Returns:
        The tokenizer output (input ids and attention mask, padded/truncated to
        ``max_len`` tokens) with an additional "labels" entry.
    """
    # Use the notebook-wide `max_len` rather than a second literal 128, so these inputs
    # always have the same length as the dummy inputs the model was initialised with.
    processed = tokenizer(input_batch["discourse_text"],
                          max_length=max_len,
                          truncation=True, padding='max_length',
                          return_token_type_ids=False,
                          return_attention_mask=True, return_tensors="np")

    # Map the textual label to a class index; values that are not one of the three
    # known names are passed through unchanged.
    new_label = {"Ineffective": 0, "Adequate": 1, "Effective": 2}
    processed["labels"] = [new_label.get(x, x) for x in input_batch["discourse_effectiveness"]]

    return processed

In [None]:
# Tokenize both splits batch-wise; the original CSV columns are removed so that only the
# fields returned by preprocess_function remain.
tokenized_dataset = data.map(preprocess_function, batched=True, remove_columns=data["train"].column_names)
train_dataset = tokenized_dataset["train"]
validation_dataset = tokenized_dataset["test"]  # the held-out "test" split is used as validation data

  0%|          | 0/35 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

In [None]:
def train_data_loader(rng, dataset, batch_size):
    """Yield shuffled, device-sharded training batches for one epoch.

    Args:
        rng: JAX PRNG key that determines the shuffling order.
        dataset: indexable dataset; indexing with an array of row indices must
            return a dict of columns.
        batch_size: number of examples per yielded batch (across all devices).

    Yields:
        A dict of arrays passed through ``shard``. Examples that do not fill a
        complete batch at the end of the shuffled order are dropped.
    """
    n_examples = len(dataset)
    n_steps = n_examples // batch_size

    # Shuffle all row indices, keep only as many as fill complete batches,
    # and arrange them as one row of indices per step.
    shuffled = jax.random.permutation(rng, n_examples)
    index_table = shuffled[: n_steps * batch_size].reshape((n_steps, batch_size))

    for row_indices in index_table:
        examples = dataset[row_indices]
        arrays = {name: jnp.array(column) for name, column in examples.items()}
        yield shard(arrays)  # split the leading axis across the local devices

In [None]:
def eval_data_loader(dataset, batch_size):
    """Yield consecutive, device-sharded evaluation batches in dataset order.

    Args:
        dataset: sliceable dataset; a slice must return a dict of columns.
        batch_size: number of examples per yielded batch (across all devices).

    Yields:
        A dict of arrays passed through ``shard``. A trailing incomplete batch
        is not yielded.
    """
    n_batches = len(dataset) // batch_size
    for start in range(0, n_batches * batch_size, batch_size):
        rows = dataset[start : start + batch_size]
        arrays = {name: jnp.array(column) for name, column in rows.items()}
        yield shard(arrays)  # split the leading axis across the local devices

# Create a model

In [None]:
class MyNLP(nn.Module):
    """BERT encoder followed by a single linear classification layer.

    Attributes:
        bert: Flax BERT module used as the encoder.
        num_labels: number of output classes, i.e. the size of the last axis of the
            returned logits. Defaults to 3, the value that used to be hard-coded,
            so ``MyNLP(bert)`` behaves exactly as before.
    """
    bert: nn.Module
    num_labels: int = 3

    @nn.compact
    def __call__(self, input_ids, attention_mask):  # https://jalammar.github.io/a-visual-guide-to-using-bert-for-the-first-time/
        """Return classification logits for a batch of tokenized texts.

        Args:
            input_ids: token ids, forwarded to the encoder.
            attention_mask: attention mask, forwarded to the encoder.

        Returns:
            Output of the Dense head applied to the encoder's ``pooler_output``.
        """
        out = self.bert(input_ids, attention_mask)
        # The classifier reads the encoder's pooled output (not the raw hidden states).
        cls_embedding = out.pooler_output

        # Extra Dense/Dropout layers for transfer-learning experiments would go here.
        # The head keeps the name "E" so the parameter tree has the same keys as before.
        return nn.Dense(features=self.num_labels, name="E")(cls_embedding)

### initialize the model ###
init_rng = jax.random.PRNGKey(seed)  # PRNG key for parameter initialisation

# the model size is 400 MB
# pretrained_bert = FlaxAutoModelForSequenceClassification.from_pretrained(model_checkpoint, config=config, seed=seed)  
pretrained_bert = FlaxBertModel.from_pretrained(model_checkpoint)  
model = MyNLP(pretrained_bert.module)  # wrap the bare Flax module; its pretrained weights are inserted below
variables = model.init(init_rng, dummy_input_ids, dummy_attention_mask)  # build the parameter tree by running the model once on the dummy inputs

#### add randomly initialized params with pretrained params ####
variables = unfreeze(variables)  # unfreeze: Makes a mutable copy of a FrozenDict mutable by transforming it into (nested) dict
# Overwrite the 'bert' sub-tree with the pretrained weights. The key must be 'bert' because that
# is the attribute name of the encoder inside MyNLP.
variables['params']['bert'] = pretrained_bert.params
variables = freeze(variables)  # freeze: An immutable variant of the Python dict.

# opt = optax.adam(2e-5)
# num_train_steps = len(train_dataset) // total_batch_size * num_train_epochs
# print("Number of Train Steps {:d}" .format(num_train_epochs))
# learning_rate_function = optax.cosine_onecycle_schedule(transition_steps=num_train_steps, peak_value=learning_rate, pct_start=0.1)
# opt = adamw(weight_decay)

# state = train_state.TrainState.create(
#     apply_fn=model.apply,
#     params=variables['params'],
#     tx=opt
#     )        

Downloading flax_model.msgpack:   0%|          | 0.00/413M [00:00<?, ?B/s]

Some weights of FlaxBertModel were not initialized from the model checkpoint at bert-base-cased and are newly initialized: {('pooler', 'dense', 'kernel'), ('pooler', 'dense', 'bias')}
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
pretrained_bert.module  # display the underlying Flax module that MyNLP wraps as its encoder

## Define the Training State

In [None]:
class TrainState(train_state.TrainState):
    """Flax training state extended with the prediction and loss callables.

    Both extra fields are declared with ``pytree_node=False``.
    """
    # Turns logits into predictions; called as logits_function(logits=...) in eval_step.
    logits_function: Callable = flax.struct.field(pytree_node=False)
    # Computes the training loss; called as loss_function(logits, labels) in train_step.
    loss_function: Callable = flax.struct.field(pytree_node=False)

In [None]:
def decay_mask_fn(params):  # expects an unfrozen (plain nested dict) parameter tree
    """Build a boolean mask selecting the parameters that receive weight decay.

    A parameter is excluded (mask value False) when its path ends in "bias" or in
    ("LayerNorm", "scale"); every other parameter gets True.

    Args:
        params: nested dict of parameters.

    Returns:
        A nested dict with the same structure as ``params`` holding booleans.
    """
    def _is_decayed(path):
        is_bias = path[-1] == "bias"
        is_layernorm_scale = path[-2:] == ("LayerNorm", "scale")
        return not (is_bias or is_layernorm_scale)

    mask_by_path = {path: _is_decayed(path) for path in traverse_util.flatten_dict(params)}
    return traverse_util.unflatten_dict(mask_by_path)

In [None]:
def adamw(weight_decay):
    """Create an ``optax.adamw`` optimizer for this notebook.

    The learning rate comes from the module-level ``learning_rate_function`` (looked up
    when this function is called) and weight decay is restricted by ``decay_mask_fn``.

    Args:
        weight_decay: weight-decay strength forwarded to ``optax.adamw``.

    Returns:
        The optimizer returned by ``optax.adamw``.
    """
    optimizer_kwargs = dict(
        learning_rate=learning_rate_function,
        b1=0.9,
        b2=0.999,
        eps=1e-6,
        weight_decay=weight_decay,
        mask=decay_mask_fn,
    )
    return optax.adamw(**optimizer_kwargs)

In [None]:
# Total number of optimizer steps: complete batches per epoch times the number of epochs.
num_train_steps = len(train_dataset) // total_batch_size * num_train_epochs
# One-cycle cosine schedule peaking at `learning_rate`; it is what the adamw() helper reads.
learning_rate_function = optax.cosine_onecycle_schedule(transition_steps=num_train_steps, peak_value=learning_rate, pct_start=0.1)
# adamw = adamw(weight_decay)
# NOTE(review): the optimizer actually used is plain Adam with the constant `learning_rate`;
# the schedule above and the weight-decay mask only take effect if adamw() is used instead.
opt = optax.adam(learning_rate)

# Define Loss and Accuracy Functions

In [None]:
@jax.jit
def cross_entropy_loss(logits, labels):
    """Mean softmax cross-entropy between logits and integer class labels.

    Args:
        logits: unnormalised class scores.
        labels: integer class ids; one-hot encoded with ``num_labels`` classes.

    Returns:
        The cross-entropy averaged over all examples (a scalar, e.g. 0.96834594).
    """
    targets = onehot(labels, num_classes=num_labels)
    per_example = optax.softmax_cross_entropy(logits, targets)
    return jnp.mean(per_example)

In [None]:
@jax.jit
def eval_function(logits):
    """Return the index of the largest logit along the last axis (the predicted class id)."""
    return jnp.argmax(logits, axis=-1)

In [None]:
# Bundle everything a training step needs: the model's apply function, the parameters
# (random Dense head plus pretrained BERT weights), the optimizer, and the two helper callables.
state = TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=opt, 
    logits_function=eval_function,  # logits -> predicted class ids
    loss_function=cross_entropy_loss,  # (logits, labels) -> scalar loss
)

# Define the Training/Evaluation Steps

In [None]:
def train_step(state, batch, dropout_rng):
    """Run one optimisation step on this device's shard of a batch.

    Meant to be wrapped by ``jax.pmap(..., axis_name="batch")``: gradients are averaged
    over that axis before the parameters are updated.

    Args:
        state: training state providing ``params``, ``apply_fn``, ``loss_function``
            and ``apply_gradients``.
        batch: dict with a "labels" entry; all remaining entries are passed to
            ``apply_fn`` as keyword arguments (input_ids, attention_mask).
        dropout_rng: PRNG key. It is split on every step and the fresh half is
            returned; the forward pass below does not consume it.

    Returns:
        ``(new_state, new_metrics, new_dropout_rng)`` where ``new_metrics`` is
        ``{'train_loss': loss}`` with the loss of this device's shard.
    """
    batch = dict(batch)  # shallow copy, so the caller's dict keeps its "labels" entry
    labels = batch.pop("labels")
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

    def compute_loss(params):
        # The forward pass must use the `params` being differentiated. Passing the global
        # `variables` here made the loss independent of `params`, so every gradient was
        # zero and the model never learned.
        logits = state.apply_fn({'params': params}, **batch)
        return state.loss_function(logits, labels)

    # Forward pass plus gradients of the loss with respect to the current parameters.
    loss, grad = jax.value_and_grad(compute_loss)(state.params)

    # Average the gradients over all devices taking part in the "batch" axis.
    grad = jax.lax.pmean(grad, "batch")

    # Applies the optimizer update and returns the state holding the new params/opt_state.
    new_state = state.apply_gradients(grads=grad)

    new_metrics = {'train_loss': loss}

    return new_state, new_metrics, new_dropout_rng


parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,)) # parallelized training over all TPU devices

In [None]:
def eval_step(state, batch):
    """Predict class ids for this device's shard of an evaluation batch.

    Args:
        state: training state providing ``params``, ``apply_fn`` and ``logits_function``.
        batch: dict of model inputs (input_ids, attention_mask); the "labels" entry
            must already have been removed by the caller.

    Returns:
        The result of ``state.logits_function`` applied to the model's logits.
    """
    # apply_fn is the model's `apply`, which takes the variables dict as its first
    # positional argument (same calling convention as in train_step). The model returns
    # the logits directly, so the result must not be indexed with [0].
    logits = state.apply_fn({'params': state.params}, **batch)
    return state.logits_function(logits=logits)


parallel_eval_step = jax.pmap(eval_step, axis_name="batch")

# Training

In [None]:
# Full training loop
# NOTE(review): this cell rebinds `state` to its replicated version, so running the cell a
# second time without re-creating `state` would replicate an already replicated state.
state = flax.jax_utils.replicate(state)  # sharing the weight on each device
rng = jax.random.PRNGKey(seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())  # one key per local device for the pmapped step

for i, epoch in enumerate(tqdm(range(1, num_train_epochs + 1), desc=f"Epoch ...", position=0, leave=True)):
    rng, input_rng = jax.random.split(rng)  # fresh key each epoch for shuffling the training data

    # train
    with tqdm(total=len(train_dataset) // total_batch_size, desc="Training...", leave=False) as progress_bar_train:
        for batch in train_data_loader(input_rng, train_dataset, total_batch_size):
            # One parallel optimisation step; the returned state and keys feed the next step.
            state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
            progress_bar_train.update(1)

    # evaluate
    # with tqdm(total=len(validation_dataset) // total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval:
    #       for batch in eval_data_loader(validation_dataset, total_batch_size):
    #             labels = batch.pop("labels")
    #             predictions = parallel_eval_step(state, batch)
    #             metric.add_batch(predictions=chain(*predictions), references=chain(*labels))
    #             progress_bar_eval.update(1)

    # eval_metric = metric.compute(average='macro')

    # loss = round(flax.jax_utils.unreplicate(train_metrics)['loss'].item(), 3)
    # eval_score = round(list(eval_metric.values())[0],3)
    # metric_name = list(eval_metric.keys())[0]

    # print(f"{i+1}/{num_train_epochs} | Train loss: {loss} | Eval {metric_name}: {eval_score}")

Epoch ...:   0%|          | 0/1 [00:00<?, ?it/s]
Donation is not implemented for cpu.
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.

Training...:   0%|          | 1/1091 [00:53<16:04:09, 53.07s/it][A
Training...:   0%|          | 2/1091 [01:06<8:57:09, 29.60s/it] [A
Training...:   0%|          | 3/1091 [01:19<6:39:44, 22.04s/it][A
Training...:   0%|          | 4/1091 [01:32<5:35:12, 18.50s/it][A
Training...:   0%|          | 5/1091 [01:45<5:00:07, 16.58s/it][A
Training...:   1%|          | 6/1091 [02:01<4:56:55, 16.42s/it][A
Training...:   1%|          | 7/1091 [02:15<4:40:02, 15.50s/it][A
Training...:   1%|          | 8/1091 [02:28<4:24:37, 14.66s/it][A
Training...:   1%|          | 9/1091 [02:41<4:14:34, 14.12s/it][A
Training...:   1%|          | 10/1091 [02:53<4:07:01, 13.71s/it][A
Training...:   1%|          | 11/1091 [03:06<4:02:19, 13.46s/it][A
Training...:   1%|          | 12/1091 [03:19<3:59:17, 13.31s/it][A
Training...:   1%| 

KeyboardInterrupt: ignored

# Future Study

In [None]:
'''
https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.Series.plot.kde.html
'''
# Kernel density estimate of the per-text word counts collected in the Data Analysis section.
pd.Series(words_list).plot.kde()  # most words are less than 250