In [1]:
import os, sys

import numpy as np
import pandas as pd 
from datasets import load_dataset

import importlib
from tqdm import tqdm
from joblib import Parallel, delayed
from copy import copy

from transformers import (
    AutoConfig,
    AutoTokenizer,
    FlaxAutoModelForSequenceClassification,
    HfArgumentParser,
    PretrainedConfig,
    TrainingArguments,
    is_tensorboard_available,
)

from flax.training.common_utils import get_metrics, onehot, shard


data_root = "/kaggle/input/feedback-prize-effectiveness/"
train = pd.read_csv("/kaggle/input/feedback-prize-effectiveness/train.csv")


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
LABEL_MAPPING = {"Ineffective": 0, "Adequate": 1, "Effective": 2}

def _prepare_training_data_helper(args, tokenizer, df, is_train):
    training_samples = []
    for _, row in tqdm(df.iterrows(), total=len(df)):
        idx = row["essay_id"]
        discourse_text = row["discourse_text"]
        discourse_type = row["discourse_type"]

        if is_train:
            filename = os.path.join(args.input, "train", idx + ".txt")
        else:
            filename = os.path.join(args.input, "test", idx + ".txt")

        with open(filename, "r") as f:
            text = f.read()

        encoded_text = tokenizer.encode_plus(
            discourse_type + " " + discourse_text,
            text,
            add_special_tokens=False,
        )
        input_ids = encoded_text["input_ids"]

        sample = {
            "discourse_id": row["discourse_id"],
            "input_ids": input_ids,
            # "discourse_text": discourse_text,
            # "essay_text": text,
            # "mask": encoded_text["attention_mask"],
        }

        if "token_type_ids" in encoded_text:
            sample["token_type_ids"] = encoded_text["token_type_ids"]

        label = row["discourse_effectiveness"]

        sample["label"] = LABEL_MAPPING[label]

        training_samples.append(sample)
    return training_samples


def prepare_training_data(df, tokenizer, args, num_jobs, is_train):
    training_samples = []

    df_splits = np.array_split(df, num_jobs)

    results = Parallel(n_jobs=num_jobs, backend="multiprocessing")(
        delayed(_prepare_training_data_helper)(args, tokenizer, df, is_train) for df in df_splits
    )
    for result in results:
        training_samples.extend(result)

    return training_samples

In [3]:
sys.path.append("../configs")
cfg = copy(importlib.import_module("default_config").cfg)

# Load pretrained model and tokenizer
config = AutoConfig.from_pretrained(
    cfg.model_name_or_path,
    num_labels=cfg.num_labels,
    #finetuning_task=data_args.task_name,
    #use_auth_token=True if cfg.use_auth_token else None,
)
tokenizer = AutoTokenizer.from_pretrained(
    cfg.model_name_or_path,
    use_fast=not cfg.use_slow_tokenizer,
    #use_auth_token=True if cfg.use_auth_token else None,
)
model = FlaxAutoModelForSequenceClassification.from_pretrained(
    cfg.model_name_or_path,
    config=config,
    #use_auth_token=True if cfg.use_auth_token else None,
)

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.
Some weights of the model checkpoint at google/bigbird-roberta-base were not used when initializing FlaxBigBirdForSequenceClassification: {('cls', 'seq_relationship', 'bias'), ('cls', 'predictions', 'transform', 'dense', 'kernel'), ('cls', 'predictions', 'transform', 'dense', 'bias'), ('cls', 'predictions', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'scale'), ('cls', 'seq_relationship', 'kernel')}
- This IS expected if you are initializing FlaxBigBirdForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxBigBirdForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequen

In [5]:
import json
from sklearn.model_selection import StratifiedKFold

kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
## stratified Kfold for train dataframe using discourse_type and discourse_effectiveness
for fold, (train_index, valid_index) in enumerate(kf.split(train, train["discourse_type"], train["discourse_effectiveness"])):

    train_temp = train.iloc[train_index]
    valid_temp = train.iloc[valid_index]

    train_data = prepare_training_data(train_temp, tokenizer, cfg, num_jobs=96, is_train=True)
    val_data = prepare_training_data(valid_temp, tokenizer, cfg, num_jobs=96, is_train=True)

    df = pd.DataFrame.from_records(train_data)
    df.to_json(f"/kaggle/working/folds/train_{fold}.jsonl", orient="records", lines=True)

    df = pd.DataFrame.from_records(val_data)
    df.to_json(f"/kaggle/working/folds/valid_{fold}.jsonl", orient="records", lines=True)

    print("Fold:", fold)
    print("Train:", train_index)
    print("Valid:", valid_index)
    print("\n")

100%|██████████| 307/307 [00:00<00:00, 505.70it/s]
100%|██████████| 307/307 [00:00<00:00, 567.00it/s]
100%|██████████| 307/307 [00:00<00:00, 499.91it/s]
100%|██████████| 307/307 [00:00<00:00, 634.21it/s]
100%|██████████| 307/307 [00:00<00:00, 527.41it/s]
100%|██████████| 307/307 [00:00<00:00, 533.02it/s]
100%|██████████| 307/307 [00:00<00:00, 612.70it/s]
100%|██████████| 307/307 [00:00<00:00, 525.42it/s]
100%|██████████| 307/307 [00:00<00:00, 560.21it/s]
100%|██████████| 307/307 [00:00<00:00, 567.64it/s]
100%|██████████| 307/307 [00:00<00:00, 570.90it/s]
100%|██████████| 307/307 [00:00<00:00, 617.90it/s]
 42%|████▏     | 128/307 [00:00<00:00, 548.76it/s]
100%|██████████| 307/307 [00:00<00:00, 546.39it/s]
100%|██████████| 307/307 [00:00<00:00, 547.05it/s]
100%|██████████| 307/307 [00:00<00:00, 512.75it/s]
100%|██████████| 307/307 [00:00<00:00, 531.86it/s]
100%|██████████| 307/307 [00:00<00:00, 557.48it/s]
100%|██████████| 307/307 [00:00<00:00, 539.11it/s]
100%|██████████| 307/307 [00:00

Fold: 0
Train: [    0     1     4 ... 36762 36763 36764]
Valid: [    2     3     7 ... 36733 36736 36746]




100%|██████████| 307/307 [00:00<00:00, 555.44it/s]
100%|██████████| 307/307 [00:00<00:00, 638.91it/s]
100%|██████████| 307/307 [00:00<00:00, 655.73it/s]
100%|██████████| 307/307 [00:00<00:00, 499.35it/s]
100%|██████████| 307/307 [00:00<00:00, 507.25it/s]
100%|██████████| 307/307 [00:00<00:00, 563.90it/s]
100%|██████████| 307/307 [00:00<00:00, 608.90it/s]
100%|██████████| 307/307 [00:00<00:00, 515.28it/s]
100%|██████████| 307/307 [00:00<00:00, 583.05it/s]
100%|██████████| 307/307 [00:00<00:00, 534.14it/s]
100%|██████████| 307/307 [00:00<00:00, 592.39it/s]
100%|██████████| 307/307 [00:00<00:00, 616.98it/s]
100%|██████████| 307/307 [00:00<00:00, 496.70it/s]
100%|██████████| 307/307 [00:00<00:00, 554.83it/s]
100%|██████████| 307/307 [00:00<00:00, 521.97it/s]
100%|██████████| 307/307 [00:00<00:00, 497.87it/s]
100%|██████████| 307/307 [00:00<00:00, 554.56it/s]
100%|██████████| 307/307 [00:00<00:00, 515.00it/s]
100%|██████████| 307/307 [00:00<00:00, 595.89it/s]
100%|██████████| 307/307 [00:00

Fold: 1
Train: [    0     2     3 ... 36759 36763 36764]
Valid: [    1     4     8 ... 36760 36761 36762]




100%|██████████| 307/307 [00:00<00:00, 480.60it/s]
100%|██████████| 307/307 [00:00<00:00, 646.47it/s]
100%|██████████| 307/307 [00:00<00:00, 635.86it/s]
100%|██████████| 307/307 [00:00<00:00, 479.37it/s]
100%|██████████| 307/307 [00:00<00:00, 520.06it/s]
100%|██████████| 307/307 [00:00<00:00, 514.19it/s]
100%|██████████| 307/307 [00:00<00:00, 582.24it/s]
100%|██████████| 307/307 [00:00<00:00, 526.32it/s]
100%|██████████| 307/307 [00:00<00:00, 553.01it/s]
100%|██████████| 307/307 [00:00<00:00, 553.54it/s]
100%|██████████| 307/307 [00:00<00:00, 541.31it/s]
100%|██████████| 307/307 [00:00<00:00, 614.92it/s]
100%|██████████| 307/307 [00:00<00:00, 543.23it/s]
100%|██████████| 307/307 [00:00<00:00, 532.59it/s]
100%|██████████| 307/307 [00:00<00:00, 528.97it/s]
100%|██████████| 307/307 [00:00<00:00, 475.80it/s]
100%|██████████| 307/307 [00:00<00:00, 564.50it/s]
100%|██████████| 307/307 [00:00<00:00, 560.67it/s]
100%|██████████| 307/307 [00:00<00:00, 522.90it/s]
100%|██████████| 307/307 [00:00

Fold: 2
Train: [    0     1     2 ... 36761 36762 36763]
Valid: [    5     6     9 ... 36755 36759 36764]




100%|██████████| 307/307 [00:00<00:00, 535.73it/s]
100%|██████████| 307/307 [00:00<00:00, 636.07it/s]
100%|██████████| 307/307 [00:00<00:00, 632.46it/s]
100%|██████████| 307/307 [00:00<00:00, 464.36it/s]
100%|██████████| 307/307 [00:00<00:00, 480.02it/s]
100%|██████████| 307/307 [00:00<00:00, 515.36it/s]
100%|██████████| 307/307 [00:00<00:00, 545.18it/s]
100%|██████████| 307/307 [00:00<00:00, 485.00it/s]
100%|██████████| 307/307 [00:00<00:00, 553.94it/s]
100%|██████████| 307/307 [00:00<00:00, 555.11it/s]
100%|██████████| 307/307 [00:00<00:00, 567.00it/s]
100%|██████████| 307/307 [00:00<00:00, 587.46it/s]
100%|██████████| 307/307 [00:00<00:00, 528.08it/s]
100%|██████████| 307/307 [00:00<00:00, 559.66it/s]
100%|██████████| 307/307 [00:00<00:00, 537.16it/s]
100%|██████████| 307/307 [00:00<00:00, 499.43it/s]
100%|██████████| 307/307 [00:00<00:00, 555.55it/s]
100%|██████████| 307/307 [00:00<00:00, 578.13it/s]
100%|██████████| 307/307 [00:00<00:00, 536.80it/s]
100%|██████████| 307/307 [00:00

Fold: 3
Train: [    0     1     2 ... 36761 36762 36764]
Valid: [   10    22    23 ... 36745 36756 36763]




100%|██████████| 307/307 [00:00<00:00, 498.08it/s]
100%|██████████| 307/307 [00:00<00:00, 603.21it/s]
100%|██████████| 307/307 [00:00<00:00, 509.08it/s]
100%|██████████| 307/307 [00:00<00:00, 616.28it/s]
100%|██████████| 307/307 [00:00<00:00, 518.68it/s]
100%|██████████| 307/307 [00:00<00:00, 538.83it/s]
100%|██████████| 307/307 [00:00<00:00, 595.82it/s]
100%|██████████| 307/307 [00:00<00:00, 502.23it/s]
 68%|██████▊   | 210/307 [00:00<00:00, 517.80it/s]
100%|██████████| 307/307 [00:00<00:00, 530.68it/s]
100%|██████████| 307/307 [00:00<00:00, 592.02it/s]
100%|██████████| 307/307 [00:00<00:00, 586.93it/s]
100%|██████████| 307/307 [00:00<00:00, 528.09it/s]
100%|██████████| 307/307 [00:00<00:00, 543.80it/s]
100%|██████████| 307/307 [00:00<00:00, 525.79it/s]
100%|██████████| 307/307 [00:00<00:00, 486.79it/s]
100%|██████████| 307/307 [00:00<00:00, 554.05it/s]
100%|██████████| 307/307 [00:00<00:00, 496.39it/s]
100%|██████████| 307/307 [00:00<00:00, 541.65it/s]
100%|██████████| 307/307 [00:00

Fold: 4
Train: [    1     2     3 ... 36762 36763 36764]
Valid: [    0    13    14 ... 36748 36753 36757]




In [44]:
# import json
# from sklearn.model_selection import StratifiedKFold

# kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
# ## stratified Kfold for train dataframe using discourse_type and discourse_effectiveness
# for fold, (train_index, valid_index) in enumerate(kf.split(train, train["discourse_type"], train["discourse_effectiveness"])):

#     train_temp = train.iloc[train_index]
#     valid_temp = train.iloc[valid_index]

#     train_data = prepare_training_data(train_temp, tokenizer, cfg, num_jobs=32, is_train=True)
#     val_data = prepare_training_data(valid_temp, tokenizer, cfg, num_jobs=32, is_train=True)

#     df = pd.DataFrame.from_records(train_data)
#     df.to_json(f"/kaggle/working/folds/train_{fold}.jsonl", orient="records", lines=True)

#     df = pd.DataFrame.from_records(val_data)
#     df.to_json(f"/kaggle/working/folds/valid_{fold}.jsonl", orient="records", lines=True)

#     print("Fold:", fold)
#     print("Train:", train_index)
#     print("Valid:", valid_index)
#     print("\n")

100%|██████████| 920/920 [00:01<00:00, 527.15it/s]
100%|██████████| 920/920 [00:01<00:00, 535.80it/s]
100%|██████████| 920/920 [00:01<00:00, 541.70it/s]
100%|██████████| 920/920 [00:01<00:00, 554.34it/s]
100%|██████████| 919/919 [00:01<00:00, 517.09it/s]
100%|██████████| 919/919 [00:01<00:00, 504.91it/s]
100%|██████████| 919/919 [00:01<00:00, 501.29it/s]
100%|██████████| 919/919 [00:01<00:00, 505.08it/s]
100%|██████████| 919/919 [00:01<00:00, 495.78it/s]
100%|██████████| 919/919 [00:01<00:00, 527.82it/s]
100%|██████████| 919/919 [00:01<00:00, 576.84it/s]
100%|██████████| 919/919 [00:01<00:00, 552.58it/s]
100%|██████████| 919/919 [00:01<00:00, 623.94it/s]
100%|██████████| 919/919 [00:01<00:00, 539.16it/s]
100%|██████████| 919/919 [00:01<00:00, 582.45it/s]
100%|██████████| 919/919 [00:01<00:00, 558.21it/s]
100%|██████████| 919/919 [00:01<00:00, 638.66it/s]
100%|██████████| 919/919 [00:01<00:00, 562.49it/s]
100%|██████████| 919/919 [00:01<00:00, 582.94it/s]
100%|██████████| 919/919 [00:01

Fold: 0
Train: [    0     1     4 ... 36762 36763 36764]
Valid: [    2     3     7 ... 36733 36736 36746]




100%|██████████| 920/920 [00:01<00:00, 517.61it/s]
100%|██████████| 920/920 [00:01<00:00, 534.69it/s]
100%|██████████| 920/920 [00:01<00:00, 548.40it/s]
100%|██████████| 920/920 [00:01<00:00, 558.75it/s]
100%|██████████| 919/919 [00:01<00:00, 514.06it/s]
100%|██████████| 919/919 [00:01<00:00, 537.65it/s]
100%|██████████| 919/919 [00:01<00:00, 497.47it/s]
 10%|█         | 94/919 [00:00<00:01, 483.61it/s]]

In [7]:
## data collator with dynamic padding
# def train_data_collator(rng:)
import jax
import datasets
from typing import Any, Callable, Dict, Optional, Tuple

rng = jax.random.PRNGKey(1)#cfg.seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())

Array = Any
Dataset = datasets.arrow_dataset.Dataset
PRNGKey = Any

In [50]:
def train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
    """Returns shuffled batches of size `batch_size` from truncated `train dataset`, sharded over all local devices."""
    steps_per_epoch = len(dataset) // batch_size
    perms = jax.random.permutation(rng, len(dataset))
    perms = perms[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    perms = perms.reshape((steps_per_epoch, batch_size))

    for perm in perms:
        batch = dataset[perm]
        discourse_id, input_ids, labels = dataset[perm]['discourse_id'], dataset[perm]['input_ids'], dataset[perm]['label']
        batch.pop("discourse_id", None)
        batch = {"input_ids": np.array(input_ids), "mask": [np.ones_like(x) for x in input_ids], "label": np.array(labels)}

        # calculate max token length of this batch
        batch_max = max([len(ids) for ids in batch["input_ids"]])
        # add padding
        if tokenizer.padding_side == "right":
            batch["input_ids"] = [s + (batch_max - len(s)) * [tokenizer.pad_token_id] for s in batch["input_ids"]]
        else:
            batch["input_ids"] = [(batch_max - len(s)) * [tokenizer.pad_token_id] + s for s in batch["input_ids"]]

        batch['input_ids'] = np.stack(batch['input_ids'])
        
        masks = np.zeros_like(batch['input_ids'])
        masks[batch['input_ids'] != tokenizer.pad_token_id] = 1
        batch['mask'] = masks

        batch = {k: np.array(v) for k, v in batch.items()}
        batch = shard(batch)
        yield batch

In [51]:
train_dataset = load_dataset("json", data_files="/kaggle/working/folds/valid_0.jsonl", split="train")
train_loader = train_data_collator(rng, train_dataset, cfg.per_device_train_batch_size)



In [52]:
for batch in train_loader:
    print(batch)
    break

938
(4, 938)
{'input_ids': array([[  898,  4461, 39341, ...,     0,     0,     0],
       [21360,  1507,   712, ...,     0,     0,     0],
       [20217,  4055,  8090, ...,   100,   321,   100],
       [23259,   415,   993, ...,     0,     0,     0]]), 'mask': array([[1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 0, 0, 0]]), 'label': array([0, 1, 1, 0])}


  batch = {"input_ids": np.array(input_ids), "mask": [np.ones_like(x) for x in input_ids], "label": np.array(labels)}


In [56]:
batch['mask'][1], batch['input_ids'][1]

(array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

In [57]:
import jax.profiler

In [58]:
server = jax.profiler.start_server(9999)


In [59]:
server

<jaxlib.xla_extension.profiler.ProfilerServer at 0x7f7968dbc370>

: 