In [3]:
import os, sys

import numpy as np
import pandas as pd 
from datasets import load_dataset

import importlib
from tqdm import tqdm
from joblib import Parallel, delayed
from copy import copy

from transformers import (
    AutoConfig,
    AutoTokenizer,
    FlaxAutoModelForSequenceClassification,
    HfArgumentParser,
    PretrainedConfig,
    TrainingArguments,
    is_tensorboard_available,
)

from flax.training.common_utils import get_metrics, onehot, shard


data_root = "/kaggle/input/feedback-prize-effectiveness/"
train = pd.read_csv("/kaggle/input/feedback-prize-effectiveness/train.csv")


In [39]:
LABEL_MAPPING = {"Ineffective": 0, "Adequate": 1, "Effective": 2}

def _prepare_training_data_helper(args, tokenizer, df, is_train):
    training_samples = []
    for _, row in tqdm(df.iterrows(), total=len(df)):
        idx = row["essay_id"]
        discourse_text = row["discourse_text"]
        discourse_type = row["discourse_type"]

        if is_train:
            filename = os.path.join(args.input, "train", idx + ".txt")
        else:
            filename = os.path.join(args.input, "test", idx + ".txt")

        with open(filename, "r") as f:
            text = f.read()

        encoded_text = tokenizer.encode_plus(
            discourse_type + " " + discourse_text,
            text,
            add_special_tokens=False,
            padding="max_length",
            truncation=True,
            max_length=512 ##TODO: update max_length
        )
        input_ids = encoded_text["input_ids"]

        sample = {
            # "discourse_id": row["discourse_id"],
            "input_ids": input_ids,
            # "discourse_text": discourse_text,
            # "essay_text": text,
            "attention_mask": encoded_text["attention_mask"],
        }

        if "token_type_ids" in encoded_text:
            sample["token_type_ids"] = encoded_text["token_type_ids"]

        label = row["discourse_effectiveness"]

        sample["labels"] = LABEL_MAPPING[label]

        training_samples.append(sample)
    return training_samples


def prepare_training_data(df, tokenizer, args, num_jobs, is_train):
    training_samples = []

    df_splits = np.array_split(df, num_jobs)

    results = Parallel(n_jobs=num_jobs, backend="multiprocessing")(
        delayed(_prepare_training_data_helper)(args, tokenizer, df, is_train) for df in df_splits
    )
    for result in results:
        training_samples.extend(result)

    return training_samples

In [40]:
sys.path.append("../configs")
cfg = copy(importlib.import_module("default_config").cfg)

# Load pretrained model and tokenizer
config = AutoConfig.from_pretrained(
    cfg.model_name_or_path,
    num_labels=cfg.num_labels,
    #finetuning_task=data_args.task_name,
    #use_auth_token=True if cfg.use_auth_token else None,
)
tokenizer = AutoTokenizer.from_pretrained(
    cfg.model_name_or_path,
    use_fast=not cfg.use_slow_tokenizer,
    #use_auth_token=True if cfg.use_auth_token else None,
)
model = FlaxAutoModelForSequenceClassification.from_pretrained(
    cfg.model_name_or_path,
    config=config,
    #use_auth_token=True if cfg.use_auth_token else None,
)

Some weights of the model checkpoint at bert-base-cased were not used when initializing FlaxBertForSequenceClassification: {('cls', 'predictions', 'bias'), ('cls', 'predictions', 'transform', 'dense', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'bias'), ('cls', 'predictions', 'transform', 'dense', 'kernel'), ('cls', 'predictions', 'transform', 'LayerNorm', 'scale')}
- This IS expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of FlaxBertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased 

In [35]:
val_data = prepare_training_data(train.iloc[range(0, 100, 5)], tokenizer, cfg, num_jobs=96, is_train=True)


100%|██████████| 1/1 [00:00<00:00, 101.67it/s]
100%|██████████| 1/1 [00:00<00:00, 107.68it/s]
100%|██████████| 1/1 [00:00<00:00, 86.52it/s]
100%|██████████| 1/1 [00:00<00:00, 88.98it/s]
100%|██████████| 1/1 [00:00<00:00, 107.26it/s]
100%|██████████| 1/1 [00:00<00:00, 105.83it/s]
100%|██████████| 1/1 [00:00<00:00, 108.12it/s]
100%|██████████| 1/1 [00:00<00:00, 108.92it/s]
100%|██████████| 1/1 [00:00<00:00, 110.07it/s]
100%|██████████| 1/1 [00:00<00:00, 100.66it/s]
100%|██████████| 1/1 [00:00<00:00, 101.03it/s]
100%|██████████| 1/1 [00:00<00:00, 102.63it/s]
100%|██████████| 1/1 [00:00<00:00, 552.39it/s]
100%|██████████| 1/1 [00:00<00:00, 93.07it/s]
100%|██████████| 1/1 [00:00<00:00, 94.39it/s]
100%|██████████| 1/1 [00:00<00:00, 97.56it/s]
100%|██████████| 1/1 [00:00<00:00, 105.47it/s]
100%|██████████| 1/1 [00:00<00:00, 111.71it/s]
100%|██████████| 1/1 [00:00<00:00, 102.18it/s]
100%|██████████| 1/1 [00:00<00:00, 102.60it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:

In [37]:
np.array(val_data[0]['input_ids']).shape

(512,)

In [41]:
import json
from sklearn.model_selection import StratifiedKFold

kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
## stratified Kfold for train dataframe using discourse_type and discourse_effectiveness
for fold, (train_index, valid_index) in enumerate(kf.split(train, train["discourse_type"], train["discourse_effectiveness"])):

    train_temp = train.iloc[train_index]
    valid_temp = train.iloc[valid_index]

    train_data = prepare_training_data(train_temp, tokenizer, cfg, num_jobs=96, is_train=True)
    val_data = prepare_training_data(valid_temp, tokenizer, cfg, num_jobs=96, is_train=True)

    df = pd.DataFrame.from_records(train_data)
    df.to_json(f"/kaggle/working/folds/train_{fold}.jsonl", orient="records", lines=True)

    df = pd.DataFrame.from_records(val_data)
    df.to_json(f"/kaggle/working/folds/valid_{fold}.jsonl", orient="records", lines=True)

    print("Fold:", fold)
    print("Train:", train_index)
    print("Valid:", valid_index)
    print("\n")
    break

 71%|███████   | 217/307 [00:00<00:00, 381.44it/s]
100%|██████████| 307/307 [00:00<00:00, 467.42it/s]
100%|██████████| 307/307 [00:00<00:00, 434.70it/s]
 34%|███▍      | 104/307 [00:00<00:00, 484.59it/s]
100%|██████████| 307/307 [00:00<00:00, 490.22it/s]
 74%|███████▍  | 228/307 [00:00<00:00, 457.05it/s]
100%|██████████| 307/307 [00:00<00:00, 405.02it/s]
100%|██████████| 307/307 [00:00<00:00, 438.87it/s]
100%|██████████| 307/307 [00:00<00:00, 454.25it/s]
100%|██████████| 307/307 [00:00<00:00, 418.80it/s]
100%|██████████| 307/307 [00:00<00:00, 441.63it/s]
100%|██████████| 307/307 [00:00<00:00, 409.15it/s]
100%|██████████| 307/307 [00:00<00:00, 463.65it/s]
100%|██████████| 307/307 [00:00<00:00, 415.38it/s]
100%|██████████| 307/307 [00:00<00:00, 416.33it/s]
100%|██████████| 307/307 [00:00<00:00, 415.51it/s]
 59%|█████▉    | 181/307 [00:00<00:00, 572.56it/s]
100%|██████████| 307/307 [00:00<00:00, 433.26it/s]
100%|██████████| 307/307 [00:00<00:00, 405.52it/s]
 46%|████▌     | 141/307 [00:00

Fold: 0
Train: [    0     1     4 ... 36762 36763 36764]
Valid: [    2     3     7 ... 36733 36736 36746]




In [44]:
# import json
# from sklearn.model_selection import StratifiedKFold

# kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
# ## stratified Kfold for train dataframe using discourse_type and discourse_effectiveness
# for fold, (train_index, valid_index) in enumerate(kf.split(train, train["discourse_type"], train["discourse_effectiveness"])):

#     train_temp = train.iloc[train_index]
#     valid_temp = train.iloc[valid_index]

#     train_data = prepare_training_data(train_temp, tokenizer, cfg, num_jobs=32, is_train=True)
#     val_data = prepare_training_data(valid_temp, tokenizer, cfg, num_jobs=32, is_train=True)

#     df = pd.DataFrame.from_records(train_data)
#     df.to_json(f"/kaggle/working/folds/train_{fold}.jsonl", orient="records", lines=True)

#     df = pd.DataFrame.from_records(val_data)
#     df.to_json(f"/kaggle/working/folds/valid_{fold}.jsonl", orient="records", lines=True)

#     print("Fold:", fold)
#     print("Train:", train_index)
#     print("Valid:", valid_index)
#     print("\n")

100%|██████████| 920/920 [00:01<00:00, 527.15it/s]
100%|██████████| 920/920 [00:01<00:00, 535.80it/s]
100%|██████████| 920/920 [00:01<00:00, 541.70it/s]
100%|██████████| 920/920 [00:01<00:00, 554.34it/s]
100%|██████████| 919/919 [00:01<00:00, 517.09it/s]
100%|██████████| 919/919 [00:01<00:00, 504.91it/s]
100%|██████████| 919/919 [00:01<00:00, 501.29it/s]
100%|██████████| 919/919 [00:01<00:00, 505.08it/s]
100%|██████████| 919/919 [00:01<00:00, 495.78it/s]
100%|██████████| 919/919 [00:01<00:00, 527.82it/s]
100%|██████████| 919/919 [00:01<00:00, 576.84it/s]
100%|██████████| 919/919 [00:01<00:00, 552.58it/s]
100%|██████████| 919/919 [00:01<00:00, 623.94it/s]
100%|██████████| 919/919 [00:01<00:00, 539.16it/s]
100%|██████████| 919/919 [00:01<00:00, 582.45it/s]
100%|██████████| 919/919 [00:01<00:00, 558.21it/s]
100%|██████████| 919/919 [00:01<00:00, 638.66it/s]
100%|██████████| 919/919 [00:01<00:00, 562.49it/s]
100%|██████████| 919/919 [00:01<00:00, 582.94it/s]
100%|██████████| 919/919 [00:01

Fold: 0
Train: [    0     1     4 ... 36762 36763 36764]
Valid: [    2     3     7 ... 36733 36736 36746]




100%|██████████| 920/920 [00:01<00:00, 517.61it/s]
100%|██████████| 920/920 [00:01<00:00, 534.69it/s]
100%|██████████| 920/920 [00:01<00:00, 548.40it/s]
100%|██████████| 920/920 [00:01<00:00, 558.75it/s]
100%|██████████| 919/919 [00:01<00:00, 514.06it/s]
100%|██████████| 919/919 [00:01<00:00, 537.65it/s]
100%|██████████| 919/919 [00:01<00:00, 497.47it/s]
 10%|█         | 94/919 [00:00<00:01, 483.61it/s]]

In [7]:
## data collator with dynamic padding
# def train_data_collator(rng:)
import jax
import datasets
from typing import Any, Callable, Dict, Optional, Tuple

rng = jax.random.PRNGKey(1)#cfg.seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())

Array = Any
Dataset = datasets.arrow_dataset.Dataset
PRNGKey = Any

In [50]:
def train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
    """Returns shuffled batches of size `batch_size` from truncated `train dataset`, sharded over all local devices."""
    steps_per_epoch = len(dataset) // batch_size
    perms = jax.random.permutation(rng, len(dataset))
    perms = perms[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    perms = perms.reshape((steps_per_epoch, batch_size))

    for perm in perms:
        batch = dataset[perm]
        discourse_id, input_ids, labels = dataset[perm]['discourse_id'], dataset[perm]['input_ids'], dataset[perm]['label']
        batch.pop("discourse_id", None)
        batch = {"input_ids": np.array(input_ids), "mask": [np.ones_like(x) for x in input_ids], "label": np.array(labels)}

        # calculate max token length of this batch
        batch_max = max([len(ids) for ids in batch["input_ids"]])
        # add padding
        if tokenizer.padding_side == "right":
            batch["input_ids"] = [s + (batch_max - len(s)) * [tokenizer.pad_token_id] for s in batch["input_ids"]]
        else:
            batch["input_ids"] = [(batch_max - len(s)) * [tokenizer.pad_token_id] + s for s in batch["input_ids"]]

        batch['input_ids'] = np.stack(batch['input_ids'])
        
        masks = np.zeros_like(batch['input_ids'])
        masks[batch['input_ids'] != tokenizer.pad_token_id] = 1
        batch['mask'] = masks

        batch = {k: np.array(v) for k, v in batch.items()}
        batch = shard(batch)
        yield batch

In [51]:
train_dataset = load_dataset("json", data_files="/kaggle/working/folds/valid_0.jsonl", split="train")
train_loader = train_data_collator(rng, train_dataset, cfg.per_device_train_batch_size)



In [52]:
for batch in train_loader:
    print(batch)
    break

938
(4, 938)
{'input_ids': array([[  898,  4461, 39341, ...,     0,     0,     0],
       [21360,  1507,   712, ...,     0,     0,     0],
       [20217,  4055,  8090, ...,   100,   321,   100],
       [23259,   415,   993, ...,     0,     0,     0]]), 'mask': array([[1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 0, 0, 0]]), 'label': array([0, 1, 1, 0])}


  batch = {"input_ids": np.array(input_ids), "mask": [np.ones_like(x) for x in input_ids], "label": np.array(labels)}


In [56]:
batch['mask'][1], batch['input_ids'][1]

(array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

In [57]:
import jax.profiler

In [58]:
server = jax.profiler.start_server(9999)


In [59]:
server

<jaxlib.xla_extension.profiler.ProfilerServer at 0x7f7968dbc370>

: 