In [1]:
"""Minimal training script using Jax/Flax/HF"""
import os, sys, time, json
import argparse
import logging
import importlib

from typing import Any, Callable, Dict, Optional, Tuple

import datasets
from datasets import load_dataset, load_metric

import jax
import jax.numpy as jnp
import optax

from flax import struct, traverse_util
from flax.jax_utils import replicate, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard

import transformers
from transformers import (
    AutoConfig,
    AutoTokenizer,
    FlaxAutoModelForSequenceClassification,
    HfArgumentParser,
    PretrainedConfig,
    TrainingArguments,
    is_tensorboard_available,
)

import numpy as np
import pandas as pd

from tqdm import tqdm
from copy import copy
from transformers.utils import check_min_version, get_full_repo_name

from itertools import chain


Array = Any
Dataset = datasets.arrow_dataset.Dataset
PRNGKey = Any

logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.20.0.dev0")

git_folder = "../"
os.environ["TOKENIZERS_PARALLELISM"] = "true"
sys.path.append(f"{git_folder}/configs")
# sys.path.append("models")
# sys.path.append("data")

ModuleNotFoundError: No module named 'datasets'

In [2]:
cfg = copy(importlib.import_module("default_config").cfg)

# Load pretrained model and tokenizer
config = AutoConfig.from_pretrained(
    cfg.model_name_or_path,
    num_labels=cfg.num_labels,
    #finetuning_task=data_args.task_name,
    #use_auth_token=True if cfg.use_auth_token else None,
)
tokenizer = AutoTokenizer.from_pretrained(
    cfg.model_name_or_path,
    use_fast=not cfg.use_slow_tokenizer,
    #use_auth_token=True if cfg.use_auth_token else None,
)
model = FlaxAutoModelForSequenceClassification.from_pretrained(
    #cfg.model_name_or_path,
    "../outs/",
    config=config,
    #use_auth_token=True if cfg.use_auth_token else None,
)

cfg.tokenizer = tokenizer

NameError: name 'copy' is not defined

In [3]:
model.params['bert'].keys()

dict_keys(['embeddings', 'encoder', 'pooler'])

In [4]:
class TrainState(train_state.TrainState):
    """Train state with an Optax optimizer.

    The two functions below differ depending on whether the task is classification
    or regression.

    Args:
        logits_fn: Applied to last layer to obtain the logits.
        loss_fn: Function to compute the loss.
    """
    
    logits_fn: Callable = struct.field(pytree_node=False)

tx = optax.adamw(
    learning_rate=0.1, b1=0.9, b2=0.999, eps=1e-6)
state = TrainState.create(
            apply_fn=model.__call__,
            params=model.params,
            tx=tx,
            # logits_fn=lambda logits: logits.argmax(-1),
            logits_fn=lambda logits: logits, 
        )
# make sure weights are replicated on each device
state = replicate(state)

In [12]:
def test_data_collator(dataset: Dataset, batch_size: int):
    """Returns batches of size `batch_size` from `eval dataset`, sharded over all local devices."""
    
    for i in range(len(dataset) // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size]
        batch = {k: np.array(v) for k, v in batch.items()}
        discourse_ids = batch.pop("discourse_id")
        batch = shard(batch)

        yield batch, discourse_ids


In [39]:
## github function idea: TODO reminder
## TODO: create a script for dataset creation for train, val, test splits
## create test dataset

batch_size = 8
test_dataset = load_dataset("json", data_files=f"/kaggle/working/folds/test.jsonl", split="train")
test_loader  = test_data_collator(test_dataset, batch_size)






In [40]:
# def predict_single(input_ids):
#     # x = jnp.array(input_ids)[jnp.newaxis, :] # for [batch_size, seq_len]
#     x = input_ids
#     x = jax.nn.softmax(model(x).logits)
#     return x

# p_predict_single = jax.vmap(predict_single, axis_name="batch")


def eval_step(state, batch):
    logits = state.apply_fn(**batch, params=state.params, train=False)[0]
    return state.logits_fn(logits)

p_eval_step = jax.pmap(eval_step, axis_name="batch")

In [41]:
# evaluate
preds = []
di = []
for batch, discourse_ids in tqdm(
    test_loader,
    total=len(test_dataset) // batch_size,
    desc="Evaluating ...",
    position=2,
):
    labels = batch.pop("labels")
    pred = p_eval_step(state, batch)
    preds.extend(pred)
    di.extend(discourse_ids)
    
# evaluate also on leftover examples (not divisible by batch_size)
num_leftover_samples = len(test_dataset) % batch_size

# make sure leftover batch is evaluated on one device
if num_leftover_samples > 0 and jax.process_index() == 0:
    # take leftover samples
    batch = test_dataset[-num_leftover_samples:]
    batch = {k: np.array(v) for k, v in batch.items()}
    print(batch)
    discourse_ids = batch.pop("discourse_id")

    labels = batch.pop("labels")
    pred = eval_step(unreplicate(state), batch)
    preds.extend(pred)
    di.extend(discourse_ids)


[A
Evaluating ...: 100%|██████████| 1/1 [00:18<00:00, 18.43s/it]


{'discourse_id': array(['739a6d00f44a', 'bcfae2c9a244'], dtype='<U12'), 'input_ids': array([[21360, 20980,   685, ...,     0,     0,     0],
       [ 1583,  6461, 22084, ...,     0,     0,     0]]), 'attention_mask': array([[1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0]]), 'labels': array([0, 0])}


In [46]:
expand_ids = []
for i, pred in enumerate(preds):
    if pred.shape != (1, 3):
        expand_ids.append(i)
        break
for id in expand_ids:
    preds[id] = jnp.expand_dims(preds[id], axis=0)

In [47]:
preds

[DeviceArray([[-6.8066797 ,  0.12095693,  5.2775683 ]], dtype=float32),
 DeviceArray([[-5.8036566,  4.351916 ,  1.1201755]], dtype=float32),
 DeviceArray([[-6.7667866 ,  0.11655698,  5.6036315 ]], dtype=float32),
 DeviceArray([[-6.131696 ,  3.0238857,  2.6651497]], dtype=float32),
 DeviceArray([[-2.6666145,  4.795038 , -2.7336974]], dtype=float32),
 DeviceArray([[-6.619154 ,  2.18402  ,  3.1526942]], dtype=float32),
 DeviceArray([[-6.4922547,  2.615754 ,  2.9109542]], dtype=float32),
 DeviceArray([[-6.6750355,  1.9154787,  4.0285406]], dtype=float32),
 DeviceArray([[-6.687419 ,  1.5453784,  3.5249007]], dtype=float32),
 DeviceArray([[-3.9098043,  6.0003743, -2.0506322]], dtype=float32)]

In [48]:
final_preds = jax.nn.softmax(jnp.array(preds)[:, 0, :], axis=1)

In [50]:
sample_submission = pd.read_csv("/kaggle/input/feedback-prize-effectiveness/sample_submission.csv")
sample_submission.loc[:, "discourse_id"] = di
sample_submission.loc[:, "Ineffective"] = final_preds[:, 0]
sample_submission.loc[:, "Adequate"] = final_preds[:, 1]
sample_submission.loc[:, "Effective"] = final_preds[:, 2]
sample_submission.to_csv("submission.csv", index=False)

In [51]:
sample_submission

Unnamed: 0,discourse_id,Ineffective,Adequate,Effective
0,a261b6e14276,6e-06,0.005728,0.994266
1,5a88900e7dc1,3.7e-05,0.961976,0.037987
2,9790d835736b,4e-06,0.004123,0.995873
3,75ce6d68b67b,6.2e-05,0.588698,0.41124
4,93578d946723,0.000574,0.998889,0.000537
5,2e214524dbe3,4.1e-05,0.275133,0.724825
6,84812fc2ab9f,4.7e-05,0.426711,0.573242
7,c668ff840720,2e-05,0.107831,0.892149
8,739a6d00f44a,3.2e-05,0.121366,0.878602
9,bcfae2c9a244,5e-05,0.999632,0.000319
