<a href="https://colab.research.google.com/github/serene23/NLP-Project/blob/main/482FinalProj.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
!pip install datasets
!pip install git+https://github.com/huggingface/transformers.git
!pip install flax
!pip install git+https://github.com/deepmind/optax.git

In [None]:
# Importing all the libraries necessary for the project
import os
import time
import jax
import flax
import optax
import datasets
import pandas as pd 
import numpy as np
from jax import jit
import jax.numpy as jnp
import tensorflow as tf
from flax.training import train_state
from itertools import chain
from tqdm.notebook import tqdm
from typing import Callable
from flax import traverse_util
from datasets import load_dataset, load_metric ,Dataset,list_metrics,load_from_disk
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from transformers import FlaxAutoModelForSequenceClassification, AutoConfig, AutoTokenizer, BertTokenizer
import warnings
warnings.filterwarnings("ignore")
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

In [None]:
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

In [None]:
jax.local_devices()

In [None]:
# Run configuration shared by the cells below.
task = "mnli"  # key looked up in task_to_keys to pick the text columns
model_checkpoint = "bert-base-cased"  # checkpoint name given to the tokenizer, config and model loaders
per_device_batch_size = 4  # multiplied by the local device count to get the total batch size

In [None]:
from datasets import load_dataset, load_metric

In [None]:
# Derive the metric configuration from `task` instead of hard-coding it a second
# time, so changing `task` cannot silently leave the wrong metric in place.
# "mnli-mm" is only an alias used for column lookup; it is scored with the
# "mnli" metric configuration.
actual_task = "mnli" if task == "mnli-mm" else task
# Load the training CSV (change this path to your desired file).
raw_train = load_dataset("csv", data_files={"train": ["/content/train.csv"]})
# Hold out 20% of the rows as the "test" split used for evaluation.
raw_dataset = raw_train["train"].train_test_split(test_size=0.2)
# Load the GLUE metric for the task; the second argument follows `task`, so
# change `task` to accommodate a different one.
metric = load_metric("glue", actual_task)

In [None]:
raw_dataset

In [None]:
from transformers import AutoTokenizer
    
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) #tokenizing with bert-base-cased

In [None]:
# Maps each task name to the pair of dataset columns holding its text inputs.
# The second entry is None for tasks that have a single text column.
task_to_keys = {   #we selected mnli for our project
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mnli-mm": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
} 

In [None]:
# Column names of the text input(s) for the configured task.
sentence1_key, sentence2_key = task_to_keys[task]


def preprocess_function(examples):
    """Tokenize a batch of examples and attach their labels.

    Adapted from
    https://www.kaggle.com/code/nilaychauhan/jigsaw-toxic-comment-classification-using-jax-flax

    Args:
        examples: Mapping that provides the text column(s) named by
            ``sentence1_key`` / ``sentence2_key`` and a ``"label"`` column.

    Returns:
        The tokenizer output with an added ``"labels"`` entry copied from
        ``examples["label"]``.
    """
    # One text column for single-sentence tasks, two for sentence-pair tasks.
    text_columns = [examples[sentence1_key]]
    if sentence2_key is not None:
        text_columns.append(examples[sentence2_key])

    # Pad and truncate to a fixed length of 128 so every row has the same shape.
    encoded = tokenizer(*text_columns, padding="max_length", max_length=128, truncation=True)
    encoded["labels"] = examples["label"]
    return encoded

In [None]:
# Tokenize every split in batches and drop the original columns, leaving only
# what preprocess_function returns (tokenizer outputs plus "labels").
tokenized_dataset = raw_dataset.map(preprocess_function, batched=True, remove_columns=raw_dataset["train"].column_names)

In [None]:
train_dataset = tokenized_dataset["train"]


In [None]:
tokenized_dataset

In [None]:
eval_dataset = tokenized_dataset["test"]

In [None]:
from transformers import FlaxAutoModelForSequenceClassification, AutoConfig

num_labels = 3  # number of classes given to the model config and to the one-hot loss
seed = 0  # given to the model loader and used below to create the PRNG key

# Build the checkpoint's config with `num_labels` output classes, then load the
# Flax sequence-classification model from the same checkpoint.
config = AutoConfig.from_pretrained(model_checkpoint, num_labels=num_labels)
model = FlaxAutoModelForSequenceClassification.from_pretrained(model_checkpoint, config=config, seed=seed)

In [None]:
import flax
import jax
import optax

from itertools import chain
from tqdm.notebook import tqdm
from typing import Callable

import jax.numpy as jnp

from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.training import train_state

In [None]:
num_train_epochs = 2 #matched epochs with our TensorFlow run
learning_rate = 2e-5  # initial value of the linear learning-rate schedule

In [None]:
# Every local device processes `per_device_batch_size` examples per step, so a
# single step consumes this many examples in total.
total_batch_size = per_device_batch_size * jax.local_device_count()
print("The overall batch size (both for training and eval) is", total_batch_size)

In [None]:
# Total number of optimizer steps: full batches per epoch times the number of epochs.
num_train_steps = len(train_dataset) // total_batch_size * num_train_epochs

# Learning rate goes linearly from `learning_rate` to 0 over `num_train_steps` steps.
learning_rate_function = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=num_train_steps)

In [None]:
class TrainState(train_state.TrainState):   #https://www.kaggle.com/code/nilaychauhan/jigsaw-toxic-comment-classification-using-jax-flax
    """Flax train state that also carries the task's prediction and loss callables.

    Both extra fields are declared with ``pytree_node=False``, i.e. they are
    plain Python callables attached to the state rather than array leaves.
    """

    # Turns model logits into predictions (called by eval_step).
    logits_function: Callable = flax.struct.field(pytree_node=False)  
    # Computes the training loss from logits and labels (called by train_step).
    loss_function: Callable = flax.struct.field(pytree_node=False)

In [None]:
def decay_mask_fn(params):
    """Build a boolean mask selecting the parameters that receive weight decay.

    Args:
        params: Nested parameter dictionary.

    Returns:
        A dictionary with the same nesting as ``params`` whose leaves are
        ``True`` for every parameter except those whose path ends in ``"bias"``
        or in ``("LayerNorm", "scale")``.
    """

    def exempt_from_decay(path):
        # Biases and LayerNorm scales are left out of weight decay.
        return path[-1] == "bias" or path[-2:] == ("LayerNorm", "scale")

    mask = {}
    for path in traverse_util.flatten_dict(params):
        mask[path] = not exempt_from_decay(path)
    return traverse_util.unflatten_dict(mask)

In [None]:
def adamw(weight_decay):  #https://www.kaggle.com/code/anasofiauzsoy/tutorial-notebook/notebook
    """Create the AdamW optimizer used for fine-tuning.

    Args:
        weight_decay: Weight-decay strength, applied only to the parameters
            selected by ``decay_mask_fn``.

    Returns:
        The optax AdamW transformation driven by ``learning_rate_function``.
    """
    optimizer = optax.adamw(
        learning_rate=learning_rate_function,
        b1=0.9,
        b2=0.999,
        eps=1e-6,
        weight_decay=weight_decay,
        mask=decay_mask_fn,
    )
    return optimizer

In [None]:
weight_decay = 1e-2
# NOTE(review): this rebinds the name `adamw` from the factory function to the
# optimizer it returns, so re-running this cell on its own would call the
# optimizer object instead of the factory. Re-run the cell defining `adamw` first.
adamw = adamw(weight_decay)  

In [None]:
# NOTE(review): `jit` is not used in the cells visible here. This line rebinds
# the `jit` imported from `jax` at the top of the notebook to an object from
# what appears to be an internal jaxlib module — confirm it is still needed,
# since an import failure here would stop the rest of this cell from running.
from jaxlib.xla_extension.jax_jit import jit


def loss_function(logits, labels):
    """Return the mean softmax cross-entropy of ``logits`` against ``labels``.

    Args:
        logits: Model outputs with one score per class.
        labels: Class indices, expanded to one-hot vectors over ``num_labels``.

    Returns:
        The cross-entropy averaged over all examples.
    """
    one_hot_targets = onehot(labels, num_classes=num_labels)
    per_example_loss = optax.softmax_cross_entropy(logits, one_hot_targets)
    return jnp.mean(per_example_loss)
    
def eval_function(logits):
    """Return the index of the largest logit along the last axis."""
    predicted_classes = logits.argmax(axis=-1)
    return predicted_classes

In [None]:
# Bundle the model's forward pass, its loaded parameters and the optimizer into
# one train state, together with the prediction and loss helpers.
state = TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=adamw,
    logits_function=eval_function,
    loss_function=loss_function,
)


In [None]:
def train_step(state, batch, dropout_rng):
    """Run one optimization step on a per-device batch.

    Meant to run under ``jax.pmap`` with ``axis_name="batch"``: gradients and
    metrics are averaged across that axis.

    Args:
        state: Current train state.
        batch: Model inputs plus a ``"labels"`` entry, which is removed from
            the dictionary before the forward pass.
        dropout_rng: PRNG key; it is split into a key for this step and a key
            returned for the next one.

    Returns:
        A tuple ``(new_state, metrics, new_dropout_rng)`` where ``metrics``
        holds the averaged ``"loss"`` and ``"learning_rate"``.
    """
    labels = batch.pop("labels")
    step_rng, next_dropout_rng = jax.random.split(dropout_rng)

    def compute_loss(params):
        # Forward pass in training mode; the first output holds the logits.
        outputs = state.apply_fn(**batch, params=params, dropout_rng=step_rng, train=True)
        return state.loss_function(outputs[0], labels)

    loss, grads = jax.value_and_grad(compute_loss)(state.params)
    # Average gradients over all devices so every replica applies the same update.
    grads = jax.lax.pmean(grads, "batch")
    updated_state = state.apply_gradients(grads=grads)

    # The learning rate is read at the step counter from before the update.
    step_metrics = {"loss": loss, "learning_rate": learning_rate_function(state.step)}
    step_metrics = jax.lax.pmean(step_metrics, axis_name="batch")
    return updated_state, step_metrics, next_dropout_rng

In [None]:
# Run train_step on all devices in parallel over the "batch" axis. Argument 0
# (the train state) is donated, so the state passed in must not be reused after
# the call; use the returned state instead.
parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))

In [None]:
def eval_step(state, batch):
    """Compute predictions for a batch of model inputs.

    Args:
        state: Train state providing ``apply_fn``, ``params`` and
            ``logits_function``.
        batch: Keyword inputs for the model (without labels).

    Returns:
        ``state.logits_function`` applied to the first model output.
    """
    outputs = state.apply_fn(**batch, params=state.params, train=False)
    logits = outputs[0]
    return state.logits_function(logits)

In [None]:
# Run eval_step on all devices in parallel over the "batch" axis.
parallel_eval_step = jax.pmap(eval_step, axis_name="batch")

In [None]:
def glue_train_data_loader(rng, dataset, batch_size):
    """Yield shuffled, device-sharded training batches for one pass over ``dataset``.

    Args:
        rng: PRNG key used to draw the shuffling permutation.
        dataset: Dataset indexable by an array of row indices.
        batch_size: Number of examples per yielded batch (across all devices).

    Yields:
        Dictionaries of arrays passed through ``shard``. Examples that do not
        fill a complete batch are skipped.
    """
    num_examples = len(dataset)
    num_batches = num_examples // batch_size

    shuffled = jax.random.permutation(rng, num_examples)
    # Keep only as many indices as fill complete batches, one row per batch.
    batch_indices = shuffled[: num_batches * batch_size].reshape((num_batches, batch_size))

    for indices in batch_indices:
        columns = dataset[indices]
        yield shard({name: jnp.array(values) for name, values in columns.items()})

In [None]:
def glue_eval_data_loader(dataset, batch_size):
    """Yield consecutive, device-sharded evaluation batches in dataset order.

    Args:
        dataset: Dataset supporting ``len`` and slice indexing.
        batch_size: Number of examples per yielded batch (across all devices).

    Yields:
        Dictionaries of arrays passed through ``shard``. Trailing examples that
        do not fill a complete batch are not yielded.
    """
    num_batches = len(dataset) // batch_size
    for start in range(0, num_batches * batch_size, batch_size):
        rows = dataset[start : start + batch_size]
        yield shard({name: jnp.array(values) for name, values in rows.items()})

In [None]:
# Replicate the train state across devices so it can be passed to the pmapped steps.
state = flax.jax_utils.replicate(state)

In [None]:
# Root PRNG key for the run (split each epoch for shuffling), plus one dropout
# key per local device for the pmapped train step.
rng = jax.random.PRNGKey(seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())

In [None]:
# Fine-tune for `num_train_epochs` epochs and evaluate on the held-out split
# after each one.
# Adapted from https://www.kaggle.com/code/nilaychauhan/jigsaw-toxic-comment-classification-using-jax-flax
for epoch in tqdm(range(1, num_train_epochs + 1), desc="Epoch ...", position=0, leave=True):
    # Fresh key so the training data is shuffled differently every epoch.
    rng, input_rng = jax.random.split(rng)

    # train
    with tqdm(total=len(train_dataset) // total_batch_size, desc="Training...", leave=False) as progress_bar_train:
        for batch in glue_train_data_loader(input_rng, train_dataset, total_batch_size):
            state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
            progress_bar_train.update(1)

    # evaluate
    with tqdm(total=len(eval_dataset) // total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval:
        for batch in glue_eval_data_loader(eval_dataset, total_batch_size):
            labels = batch.pop("labels")
            # Predictions and labels are laid out per device; flatten them for the metric.
            predictions = parallel_eval_step(state, batch)
            metric.add_batch(predictions=chain(*predictions), references=chain(*labels))
            progress_bar_eval.update(1)

    # The eval loader only yields full batches. Score the leftover examples too
    # (on a single device, with an unreplicated state) so the reported metric
    # covers the whole evaluation split.
    num_leftover_samples = len(eval_dataset) % total_batch_size
    if num_leftover_samples > 0:
        batch = eval_dataset[-num_leftover_samples:]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        labels = batch.pop("labels")
        predictions = eval_step(flax.jax_utils.unreplicate(state), batch)
        metric.add_batch(predictions=np.asarray(predictions), references=np.asarray(labels))

    eval_metric = metric.compute()

    # Report the loss of the last training step of the epoch and the first
    # entry of the computed metric.
    loss = round(flax.jax_utils.unreplicate(train_metrics)['loss'].item(), 3)
    metric_name, metric_value = next(iter(eval_metric.items()))
    eval_score = round(metric_value, 3)

    print(f"{epoch}/{num_train_epochs} | Train loss: {loss} | Eval {metric_name}: {eval_score}")