In [1]:
from transformers import (
    AutoTokenizer,
    TFAutoModelForQuestionAnswering,
    DefaultDataCollator,
    keras_callbacks,
)
import tensorflow as tf
from huggingface_hub import notebook_login
from question_answering.constants import constants
from question_answering.utils import core_qa_utils, extractive_qa_utils
from question_answering.paths import extractive_qa_paths

In [2]:
# Read the train/val/test DataFrames from the SQuAD directory, then convert
# them (in the same order) into datasets for tokenization.
df_train, df_val, df_test = core_qa_utils.load_train_val_test_datasets(
    extractive_qa_paths.squad_dataset_dir
)
split_dataframes = [df_train, df_val, df_test]
train_dataset, val_dataset, test_dataset = core_qa_utils.convert_dataframes_to_datasets(
    split_dataframes
)

In [3]:
# Pretrained checkpoint shared by the tokenizer here and the QA model below.
model_checkpoint: str = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=model_checkpoint
)

In [4]:
def tokenize_sample(sample, max_tokens=None, padding=False):
    """Encode one sample as a (question, context) pair with the global tokenizer.

    Both fields have surrounding whitespace stripped before encoding.

    Args:
        sample: Mapping with "question" and "context" string entries.
        max_tokens: Forwarded to the tokenizer as ``max_length``. No truncation
            is requested, so longer encodings are kept as-is.
        padding: Forwarded unchanged to the tokenizer.

    Returns:
        Whatever the notebook-level ``tokenizer`` returns for the pair.
    """
    question_text, context_text = (
        sample[field].strip() for field in ("question", "context")
    )
    return tokenizer(
        question_text, context_text, max_length=max_tokens, padding=padding
    )


# Tokenize every split without truncation to see how long the encodings get.
tokenized_train_dataset, tokenized_val_dataset, tokenized_test_dataset = (
    split.map(tokenize_sample) for split in (train_dataset, val_dataset, test_dataset)
)

for split_name, tokenized_split in (
    ("train", tokenized_train_dataset),
    ("val", tokenized_val_dataset),
    ("test", tokenized_test_dataset),
):
    longest_encoding = max(len(ids) for ids in tokenized_split["input_ids"])
    print(
        f"Max number of tokens in tokenized {split_name} dataset: ",
        longest_encoding,
    )

Map:   0%|          | 0/68716 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (520 > 512). Running this sequence through the model will result in indexing errors


Map:   0%|          | 0/14724 [00:00<?, ? examples/s]

Map:   0%|          | 0/14725 [00:00<?, ? examples/s]

Max number of tokens in tokenized train dataset:  870
Max number of tokens in tokenized val dataset:  866
Max number of tokens in tokenized test dataset:  817


In [5]:
max_length = 384


def filter_samples_below_number_of_tokens(dataset, max_tokens: int, tokenize_fn=None):
    """Keep only the samples whose tokenized length is at most ``max_tokens``.

    Args:
        dataset: Dataset to filter. It must be iterable and provide
            ``select(indices)``.
        max_tokens: Inclusive upper bound on ``len(encoding["input_ids"])``
            for a sample to be kept.
        tokenize_fn: Callable mapping one sample to an encoding with an
            ``"input_ids"`` entry. Defaults to ``tokenize_sample``, which applies
            no truncation, so the count is the full encoded length.

    Returns:
        ``dataset.select(...)`` over the kept indices, in their original order.
    """
    if tokenize_fn is None:
        tokenize_fn = tokenize_sample

    # Decide membership in a single pass; no per-index set construction.
    kept_indices = [
        index
        for index, sample in enumerate(dataset)
        if len(tokenize_fn(sample)["input_ids"]) <= max_tokens
    ]
    return dataset.select(kept_indices)


# Drop samples that would not fit into `max_length` tokens, split by split.
filtered_train_dataset, filtered_val_dataset, filtered_test_dataset = (
    filter_samples_below_number_of_tokens(split, max_tokens=max_length)
    for split in (train_dataset, val_dataset, test_dataset)
)



In [6]:
# (name, unfiltered split, filtered split) triples reported below.
split_size_pairs = (
    ("train", train_dataset, filtered_train_dataset),
    ("val", val_dataset, filtered_val_dataset),
    ("test", test_dataset, filtered_test_dataset),
)

for split_name, unfiltered_split, _ in split_size_pairs:
    print(
        f"Number of samples in tokenized {split_name} dataset before filtering: ",
        len(unfiltered_split),
    )

print("\n---------------\n")

for split_name, _, filtered_split in split_size_pairs:
    print(
        f"Number of samples in tokenized {split_name} dataset after filtering: ",
        len(filtered_split),
    )

Number of samples in tokenized train dataset before filtering:  68716
Number of samples in tokenized val dataset before filtering:  14724
Number of samples in tokenized test dataset before filtering:  14725

---------------

Number of samples in tokenized train dataset after filtering:  67964
Number of samples in tokenized val dataset after filtering:  14573
Number of samples in tokenized test dataset after filtering:  14552


In [7]:
def _context_token_span(sequence_ids):
    """Return (first, last) token index whose sequence id is 1 (the context), or None."""
    context_indices = [i for i, seq_id in enumerate(sequence_ids) if seq_id == 1]
    if not context_indices:
        return None
    return context_indices[0], context_indices[-1]


def _answer_token_span(offsets, context_start, context_end, start_char, end_char):
    """Map the answer's character span to (start_token, end_token) inside the context.

    Returns (0, 0), the first position of the encoding, when the context tokens
    do not fully cover [start_char, end_char).
    """
    if offsets[context_start][0] > start_char or offsets[context_end][1] < end_char:
        return 0, 0

    # Advance to the last context token that starts at or before start_char.
    idx = context_start
    while idx <= context_end and offsets[idx][0] <= start_char:
        idx += 1
    start_token = idx - 1

    # Retreat to the first context token that ends at or after end_char.
    idx = context_end
    while idx >= context_start and offsets[idx][1] >= end_char:
        idx -= 1
    end_token = idx + 1

    return start_token, end_token


def preprocess_dataset(dataset):
    """Tokenize a batch of QA samples and attach start/end token labels.

    Intended for ``Dataset.map(..., batched=True)``. Questions and contexts are
    whitespace-stripped and encoded as pairs, padded to the notebook-level
    ``max_length``. No truncation is requested, because overly long samples
    are filtered out beforehand.

    Args:
        dataset: Batch mapping with "question", "context", "answer_start" and
            "answer_text" lists.

    Returns:
        The tokenizer output (without "offset_mapping") plus "start_positions"
        and "end_positions" lists. A sample whose answer is not covered by its
        context tokens is labelled (0, 0).
    """
    raw_contexts = dataset["context"]
    questions = [q.strip() for q in dataset["question"]]
    contexts = [c.strip() for c in raw_contexts]
    # .strip() removes leading whitespace, which shifts every character offset
    # in the context. answer_start is assumed to index into the raw context
    # (SQuAD convention), so remember the shift and compensate for it below.
    context_shifts = [len(c) - len(c.lstrip()) for c in raw_contexts]

    inputs = tokenizer(
        questions,
        contexts,
        max_length=max_length,
        padding="max_length",
        return_offsets_mapping=True,
    )

    offset_mapping = inputs.pop("offset_mapping")

    answer_start_indices = dataset["answer_start"]
    answer_texts = dataset["answer_text"]
    start_positions = []
    end_positions = []

    for index, offsets in enumerate(offset_mapping):
        start_char = answer_start_indices[index] - context_shifts[index]
        end_char = start_char + len(answer_texts[index])

        context_span = _context_token_span(inputs.sequence_ids(index))
        if context_span is None:
            # No context tokens at all (e.g. empty context): nothing to point at.
            start_token, end_token = 0, 0
        else:
            context_start, context_end = context_span
            start_token, end_token = _answer_token_span(
                offsets, context_start, context_end, start_char, end_char
            )

        start_positions.append(start_token)
        end_positions.append(end_token)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

In [8]:
# Replace raw columns with model inputs + start/end labels for every split.
tokenized_train_dataset, tokenized_val_dataset, tokenized_test_dataset = (
    split.map(
        preprocess_dataset,
        batched=True,
        remove_columns=split.column_names,
    )
    for split in (filtered_train_dataset, filtered_val_dataset, filtered_test_dataset)
)

Map:   0%|          | 0/67964 [00:00<?, ? examples/s]

Map:   0%|          | 0/14573 [00:00<?, ? examples/s]

Map:   0%|          | 0/14552 [00:00<?, ? examples/s]

In [9]:
# Sanity check: padding="max_length" should give every encoding the same length.
for split_name, tokenized_split in (
    ("train", tokenized_train_dataset),
    ("val", tokenized_val_dataset),
    ("test", tokenized_test_dataset),
):
    every_entry_padded = all(
        len(token_ids) == max_length for token_ids in tokenized_split["input_ids"]
    )
    print(
        f"All tokenized {split_name} dataset entries have {max_length} tokens: ",
        every_entry_padded,
    )

All tokenized train dataset entries have 384 tokens:  True
All tokenized val dataset entries have 384 tokens:  True
All tokenized test dataset entries have 384 tokens:  True


In [10]:
# ---- Run identity: every artefact below is namespaced by full_model_name ----
model_name = "squad-bert-uncased"
training_number = 1
full_model_name = f"{model_name}-{training_number}"

# ---- Artefact locations ----
checkpoint_filename_template = constants.checkpoint_filename_template
checkpoints_path = extractive_qa_paths.training_checkpoints_dir / full_model_name
checkpoints_path = checkpoints_path / checkpoint_filename_template

hub_path = extractive_qa_paths.hub_models_location / full_model_name
saved_models_path = extractive_qa_paths.saved_models_dir / full_model_name
figures_dir = extractive_qa_paths.figures_dir / full_model_name

# ---- Training hyperparameters ----
batch_size, train_epochs = 8, 10

In [11]:
# Keras layers capture the global dtype policy when they are created, so the
# mixed-precision policy has to be active BEFORE the model is built. Setting it
# only after loading leaves the whole model in float32: no float16 speed-up and
# roughly double the activation memory (a likely factor in the GPU OOM during fit).
# NOTE(review): the QA head then emits float16 logits; if the loss turns out
# unstable, consider keeping the output layer in float32.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

# Load model for fine-tuning
model = TFAutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

All PyTorch model weights were used when initializing TFBertForQuestionAnswering.

Some weights or buffers of the TF 2.0 model TFBertForQuestionAnswering were not initialized from the PyTorch model and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
# Convert each tokenized split into a batched TF dataset of (features, labels).
data_collator = DefaultDataCollator(return_tensors="tf")

feature_columns = ["input_ids", "token_type_ids", "attention_mask"]
label_columns = ["start_positions", "end_positions"]

tf_train_dataset, tf_val_dataset, tf_test_dataset = (
    core_qa_utils.convert_to_tf_dataset(
        hf_dataset=split,
        columns=list(feature_columns),
        label_cols=list(label_columns),
        collator=data_collator,
        batch_size=batch_size,
    )
    for split in (tokenized_train_dataset, tokenized_val_dataset, tokenized_test_dataset)
)

In [13]:
# Login to hugging face hub in order to store the model there
# notebook_login()

In [14]:
# Weight checkpoints go to this run's checkpoint template; early stopping uses
# Keras' default monitor with a patience of one epoch.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoints_path,
    save_weights_only=True,
    verbose=1,
)
early_stop_cb = tf.keras.callbacks.EarlyStopping(patience=1)

callbacks = [checkpoint_cb, early_stop_cb]
# To also publish to the Hub (after notebook_login()), append:
# callbacks.append(
#     keras_callbacks.PushToHubCallback(output_dir=full_model_name, tokenizer=tokenizer)
# )

In [15]:
# Learning rate decays polynomially from 2e-5 to 0 over every step of training.
num_train_steps = len(tf_train_dataset) * train_epochs
lr_scheduler = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=2e-5,
    decay_steps=num_train_steps,
    end_learning_rate=0.0,
)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_scheduler)

# NOTE: a global dtype policy only applies to layers created after it is set.
# This call does not convert an already-built float32 model; mixed precision
# takes effect only if the policy was active when the model was loaded.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics = ["accuracy"]
model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: NVIDIA GeForce RTX 3070 Ti, compute capability 8.6


In [16]:
# Print the layer / parameter summary of the loaded QA model.
model.summary()

Model: "tf_bert_for_question_answering"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 bert (TFBertMainLayer)      multiple                  108891648 
                                                                 
 qa_outputs (Dense)          multiple                  1538      
                                                                 
Total params: 108,893,186
Trainable params: 108,893,186
Non-trainable params: 0
_________________________________________________________________


In [17]:
# Fine-tune on the train split, validating each epoch on the val split; the
# returned History feeds checkpoint selection and the plots further down.
history = model.fit(
    x=tf_train_dataset,
    epochs=train_epochs,
    validation_data=tf_val_dataset,
    callbacks=callbacks,
)

Epoch 1/10


ResourceExhaustedError: Graph execution error:

Detected at node 'tf_bert_for_question_answering/bert/encoder/layer_._9/intermediate/Gelu/mul_1' defined at (most recent call last):
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\runpy.py", line 196, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\runpy.py", line 86, in _run_code
      exec(code, run_globals)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\traitlets\config\application.py", line 1053, in launch_instance
      app.start()
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\ipykernel\kernelapp.py", line 737, in start
      self.io_loop.start()
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\tornado\platform\asyncio.py", line 195, in start
      self.asyncio_loop.run_forever()
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\asyncio\base_events.py", line 603, in run_forever
      self._run_once()
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\asyncio\base_events.py", line 1909, in _run_once
      handle._run()
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\asyncio\events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\ipykernel\kernelbase.py", line 524, in dispatch_queue
      await self.process_one()
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\ipykernel\kernelbase.py", line 513, in process_one
      await dispatch(*args)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\ipykernel\kernelbase.py", line 418, in dispatch_shell
      await result
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\ipykernel\kernelbase.py", line 758, in execute_request
      reply_content = await reply_content
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\ipykernel\ipkernel.py", line 426, in do_execute
      res = shell.run_cell(
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\ipykernel\zmqshell.py", line 549, in run_cell
      return super().run_cell(*args, **kwargs)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\IPython\core\interactiveshell.py", line 3024, in run_cell
      result = self._run_cell(
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\IPython\core\interactiveshell.py", line 3079, in _run_cell
      result = runner(coro)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\IPython\core\async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\IPython\core\interactiveshell.py", line 3284, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\IPython\core\interactiveshell.py", line 3466, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\IPython\core\interactiveshell.py", line 3526, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "C:\Users\Karol\AppData\Local\Temp\ipykernel_4392\234921204.py", line 2, in <module>
      history = model.fit(
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\keras\engine\training.py", line 1564, in fit
      tmp_logs = self.train_function(iterator)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\keras\engine\training.py", line 1160, in train_function
      return step_function(self, iterator)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\keras\engine\training.py", line 1146, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\keras\engine\training.py", line 1135, in run_step
      outputs = model.train_step(data)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\transformers\modeling_tf_utils.py", line 1638, in train_step
      y_pred = self(x, training=True)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\keras\engine\training.py", line 557, in __call__
      return super().__call__(*args, **kwargs)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\keras\engine\base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\transformers\modeling_tf_utils.py", line 1833, in run_call_with_unpacked_inputs
      from .modelcard import TrainingSummary  # tests_ignore
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\transformers\models\bert\modeling_tf_bert.py", line 1852, in call
      outputs = self.bert(
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\keras\engine\base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\transformers\modeling_tf_utils.py", line 1833, in run_call_with_unpacked_inputs
      from .modelcard import TrainingSummary  # tests_ignore
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\transformers\models\bert\modeling_tf_bert.py", line 862, in call
      encoder_outputs = self.encoder(
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\keras\engine\base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\transformers\models\bert\modeling_tf_bert.py", line 548, in call
      for i, layer_module in enumerate(self.layer):
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\transformers\models\bert\modeling_tf_bert.py", line 554, in call
      layer_outputs = layer_module(
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\keras\engine\base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\transformers\models\bert\modeling_tf_bert.py", line 510, in call
      intermediate_output = self.intermediate(hidden_states=attention_output)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\keras\engine\base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\transformers\models\bert\modeling_tf_bert.py", line 414, in call
      hidden_states = self.intermediate_act_fn(hidden_states)
    File "C:\Users\Karol\miniconda3\envs\question_answering\lib\site-packages\keras\activations.py", line 359, in gelu
      return tf.nn.gelu(x, approximate)
Node: 'tf_bert_for_question_answering/bert/encoder/layer_._9/intermediate/Gelu/mul_1'
failed to allocate memory
	 [[{{node tf_bert_for_question_answering/bert/encoder/layer_._9/intermediate/Gelu/mul_1}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.
 [Op:__inference_train_function_24626]

In [None]:
# Get best version of the model.
# NOTE(review): the selection criterion, and whether `model` itself is mutated
# and returned, are decided inside core_qa_utils.get_best_model_from_checkpoints
# — confirm there.
best_model = core_qa_utils.get_best_model_from_checkpoints(
    model, history, model_name=full_model_name
)

In [None]:
# Save best model's weights under this run's name (storage location/format is
# decided by extractive_qa_utils.save_model).
extractive_qa_utils.save_model(best_model, model_name=full_model_name)

In [None]:
# Load best model's saved weights back, to verify the save/load round trip.
# NOTE(review): the weights are loaded into `best_model` itself. If the helper
# mutates and returns that same object, `loaded_best_model is best_model` and the
# two evaluations below compare a model with itself. Consider loading into a
# freshly created model instead — confirm against the helper.
loaded_best_model = extractive_qa_utils.load_weights_into_model(
    best_model, model_name=full_model_name
)

In [None]:
# Evaluate best_model on the test split.
# NOTE: Keras `evaluate` returns the loss and compiled metric values, not
# per-sample predictions, despite the variable name.
best_model_preds = best_model.evaluate(tf_test_dataset)

In [None]:
# Same test-split evaluation for the reloaded model; expected to match best_model_preds.
loaded_best_model_preds = loaded_best_model.evaluate(tf_test_dataset)

In [None]:
# Display the test loss/metrics of best_model.
best_model_preds

In [None]:
# Display the test loss/metrics of the reloaded model.
loaded_best_model_preds

In [34]:
from transformers import pipeline

# Smoke-test a published checkpoint through the high-level pipeline API.
# Replace the model id below with your own checkpoint.
demo_checkpoint = "nlp-polish/squad-bert-1"
question_answerer = pipeline(task="question-answering", model=demo_checkpoint)

context = """
🤗 Transformers is backed by the three most popular deep learning libraries — Jax, PyTorch and TensorFlow — with a seamless integration
between them. It's straightforward to train your models with one before loading them for inference with the other.
"""
question = "Which deep learning libraries back 🤗 Transformers?"
question_answerer(question=question, context=context)

Downloading (…)lve/main/config.json:   0%|          | 0.00/645 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


Downloading tf_model.h5:   0%|          | 0.00/436M [00:00<?, ?B/s]

All model checkpoint layers were used when initializing TFBertForQuestionAnswering.

All the layers of TFBertForQuestionAnswering were initialized from the model checkpoint at nlp-polish/squad-bert-1.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertForQuestionAnswering for predictions without further training.


Downloading (…)okenizer_config.json:   0%|          | 0.00/314 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/712k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

{'score': 0.984375,
 'start': 78,
 'end': 105,
 'answer': 'Jax, PyTorch and TensorFlow'}

In [None]:
# Plot train vs. validation curves for each tracked quantity and save them into
# this run's figures directory, prefixed with the run name.
# FIXME(review): `model_utils` is never imported in this notebook (only
# core_qa_utils / extractive_qa_utils are); import the module that provides
# plot_and_save_fig_from_history before running this cell.
# NOTE(review): with separate start/end logit outputs, Keras may record
# per-output keys in history (e.g. "<output>_accuracy") instead of plain
# "accuracy"; check history.history.keys().
for metric_key, plot_title, y_axis_label in (
    ("accuracy", "Model accuracy", "Accuracy"),
    ("loss", "Model loss", "Loss"),
):
    model_utils.plot_and_save_fig_from_history(
        history,
        attributes=[metric_key, f"val_{metric_key}"],
        title=plot_title,
        y_label=y_axis_label,
        x_label="Epoch",
        legend_descriptors=["Train", "Val"],
        figure_dir_path=figures_dir,
        figure_filename=f"{full_model_name}_{metric_key}.png",
    )

In [None]:
# Restore the best checkpoint of this run with the helper this notebook already
# imports (core_qa_utils); `model_utils` and `checkpoints_dir` are not defined here.
# NOTE(review): this repeats the earlier "Get best version of the model" cell;
# consider deleting one of the two.
best_model = core_qa_utils.get_best_model_from_checkpoints(
    model, history, model_name=full_model_name
)

In [None]:
# Persist the best model with the extractive-QA helper this notebook imports;
# `model_utils` is not imported here.
# NOTE(review): this repeats the earlier "Save best model's weights" cell, and
# constants.SAVED_MODEL_LOCATION / DEFAULT_MODEL_VERSION from the previous call
# are not used by this project's helper — confirm the intended location.
extractive_qa_utils.save_model(best_model, model_name=full_model_name)

In [None]:
# Evaluation on the test split. No batch_size is passed: tf_test_dataset is
# already batched, and Keras ignores batch_size for dataset inputs.
best_model.evaluate(tf_test_dataset)

In [None]:
# FIXME(review): this cell appears copied from a classification notebook.
# `model_utils` is not imported, and the QA model predicts start/end spans rather
# than classes. It also uses `model` while the evaluation above uses `best_model`.
# Replace with span decoding (argmax of start/end logits) or remove.
class_preds = model_utils.get_class_preds(model, tf_test_dataset)

In [None]:
# FIXME(review): copied from a classification notebook. `model_utils` is not
# imported, and tokenized_test_dataset has no "emotions" column: preprocess_dataset's
# map removes all raw columns, leaving model inputs plus start/end positions.
# For extractive QA, exact-match / token-F1 over answer spans is the usual metric.
precision, recall, f1 = model_utils.get_classification_evaluation_metrics(
    class_actual=tokenized_test_dataset["emotions"],
    class_preds=class_preds,
    average="micro",
)

print("Precision score: ", precision)
print("Recall score: ", recall)
print("F1 score: ", f1)

In [None]:
# FIXME(review): stale classification cell. `model_utils` and `raw_dataset` are
# not defined in this notebook, and "text_pl" / "emotions" are not columns of the
# QA data. Rewrite to show mispredicted answer spans, or remove.
model_utils.print_incorrectly_predicted_texts(
    texts=raw_dataset["text_pl"],
    class_actual=raw_dataset["emotions"],
    class_preds=class_preds,
)