In [1]:
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, TFAutoModel
from tensorflow import keras
from tensorflow.keras import layers, losses, optimizers, metrics, regularizers
import pandas as pd
import numpy as np
import tensorflow as tf
import sys

  from .autonotebook import tqdm as notebook_tqdm
2025-01-30 01:24:47.295402: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1738218287.306451  101502 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1738218287.309823  101502 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-01-30 01:24:47.321036: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
train = pd.read_csv("data/train.csv").dropna()
train, test = train_test_split(train, test_size=0.15)
print(train.info())
print(test.info())
# TODO see if selected_text needs to be fixed to word boundaries

<class 'pandas.core.frame.DataFrame'>
Index: 23358 entries, 16039 to 5268
Data columns (total 4 columns):
 #   Column         Non-Null Count  Dtype 
---  ------         --------------  ----- 
 0   textID         23358 non-null  object
 1   text           23358 non-null  object
 2   selected_text  23358 non-null  object
 3   sentiment      23358 non-null  object
dtypes: object(4)
memory usage: 912.4+ KB
None
<class 'pandas.core.frame.DataFrame'>
Index: 4122 entries, 7068 to 18854
Data columns (total 4 columns):
 #   Column         Non-Null Count  Dtype 
---  ------         --------------  ----- 
 0   textID         4122 non-null   object
 1   text           4122 non-null   object
 2   selected_text  4122 non-null   object
 3   sentiment      4122 non-null   object
dtypes: object(4)
memory usage: 161.0+ KB
None


In [3]:
max_text_len = 128
inference_batch_size = 2000

def tokenize(texts, padding=True):
    padding = "max_length" if padding else "longest"
    return tokenizer(texts, padding=padding, max_length=max_text_len, return_tensors="tf")

def detokenize(ids, skip_special=True):
    return tokenizer.batch_decode(ids, skip_special_tokens=skip_special)

# https://stackoverflow.com/a/7100681
def rolling_window(a, size):
    shape = a.shape[:-1] + (a.shape[-1] - size + 1, size)
    strides = a.strides + (a. strides[-1],)
    return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)

def find_span(texts, spans):
    texts = np.array(texts)
    spans = np.array(spans)[:, 1:]
    sizes = spans.argmin(axis=1) - 1
    span_ranges = np.zeros((texts.shape[0], 2))
    for i, text in enumerate(texts):
        window_size = sizes[i]
        matched_window = rolling_window(texts[i], window_size) == spans[i][:window_size]
        window_from = matched_window.all(axis=1).argmax()
        span_ranges[i, 0] = window_from
        span_ranges[i, 1] = window_from + window_size - 1
    return span_ranges

def eprint(*args, **kwargs):
    print(*args, **kwargs, file=sys.stderr)

def print_mem_info():
    info = tf.config.experimental.get_memory_info("/gpu:0")
    eprint(f"GPU memory usage {(info["current"] / info["peak"]):.2f}.")

def context_embeddings(texts, sentiments, selected_texts=None):
    text_tokens = tokenize(texts)
    text_embeddings = encoder(**text_tokens).last_hidden_state
    sentiment_tokens = tokenize(sentiments)
    sentiment_embeddings = encoder(**sentiment_tokens).last_hidden_state[:, 1:2, :]  # only need embedding for one word
    embeddings = tf.concat((text_embeddings, sentiment_embeddings), 1)
    if selected_texts is None:
        return embeddings
    else:
        selected_tokens = tokenize(selected_texts)
        targets = find_span(text_tokens["input_ids"], selected_tokens["input_ids"])
        return embeddings, targets

def dataset_generator(texts, sentiments, selected_texts=None):
    def gen():
        total = len(texts)
        yielded = 0
        while yielded < total:
            batch_i = yielded % inference_batch_size
            # batch inferencing
            if batch_i == 0:
                end = min(total, yielded + inference_batch_size)
                if selected_texts is None:
                    embeddings = context_embeddings(texts[yielded:end], sentiments[yielded:end])
                else:
                    embeddings, targets = context_embeddings(texts[yielded:end], sentiments[yielded:end], selected_texts[yielded:end])
            # feed the generator
            if selected_texts is None:
                yield embeddings[batch_i]
            else:
                yield embeddings[batch_i], (targets[batch_i, 0], targets[batch_i, 1])
            yielded += 1
    return gen

Naive implementation, flattened embeddings into FFN with one hidden layer.

In [4]:
inputs = keras.Input(shape=(1 + max_text_len, 768))
flat = layers.Flatten()(inputs)
hidden = layers.Dense(256, activation="relu", kernel_regularizer="l1l2", bias_regularizer="l2")(flat)
softmax_start = layers.Dense(max_text_len, activation="softmax", name="start", kernel_regularizer="l1l2", bias_regularizer="l2")(hidden)
softmax_end = layers.Dense(max_text_len, activation="softmax", name="end", kernel_regularizer="l1l2", bias_regularizer="l2")(hidden)
ffn = keras.Model(inputs=inputs, outputs=(softmax_start, softmax_end))
ffn.compile(
    optimizer=optimizers.Adam(),
    loss={"start": losses.SparseCategoricalCrossentropy(), "end": losses.SparseCategoricalCrossentropy()},
    metrics={"start": metrics.SparseCategoricalAccuracy(), "end": metrics.SparseCategoricalAccuracy()}
)

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
encoder = TFAutoModel.from_pretrained("distilbert-base-uncased")
dataset_train = tf.data.Dataset.from_generator(
    dataset_generator(train.text.to_list(), train.sentiment.to_list(), train.selected_text.to_list()),
    output_signature=(tf.TensorSpec(shape=(1 + max_text_len, 768)), (tf.TensorSpec(shape=()), tf.TensorSpec(shape=())))
).shuffle(inference_batch_size)
dataset_test = tf.data.Dataset.from_generator(
    dataset_generator(test.text.to_list(), test.sentiment.to_list(), test.selected_text.to_list()),
    output_signature=(tf.TensorSpec(shape=(1 + max_text_len, 768)), (tf.TensorSpec(shape=()), tf.TensorSpec(shape=())))
)
history = ffn.fit(dataset_train.batch(64), epochs=10, validation_data=dataset_test.batch(64))

I0000 00:00:1738218288.983263  101502 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 21856 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:01:00.0, compute capability: 8.6
Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFDistilBertModel: ['vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.bias']
- This IS expected if you are initializing TFDistilBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFDistilBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFDistilBertModel were initialized from the PyTorch model.

Epoch 1/10


I0000 00:00:1738218298.552538  101651 service.cc:148] XLA service 0x7ff380003ff0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1738218298.552553  101651 service.cc:156]   StreamExecutor device (0): NVIDIA GeForce RTX 3090, Compute Capability 8.6
2025-01-30 01:24:58.585521: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1738218298.664574  101651 cuda_dnn.cc:529] Loaded cuDNN version 90300




      5/Unknown [1m10s[0m 39ms/step - end_loss: 10.5816 - end_sparse_categorical_accuracy: 0.0456 - loss: 24.2729 - start_loss: 13.6913 - start_sparse_categorical_accuracy: 0.2851

I0000 00:00:1738218300.128642  101651 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


    363/Unknown [1m52s[0m 116ms/step - end_loss: 3.9267 - end_sparse_categorical_accuracy: 0.3180 - loss: 6.9807 - start_loss: 3.0533 - start_sparse_categorical_accuracy: 0.5224  







    365/Unknown [1m53s[0m 120ms/step - end_loss: 3.9196 - end_sparse_categorical_accuracy: 0.3186 - loss: 6.9683 - start_loss: 3.0480 - start_sparse_categorical_accuracy: 0.5225

2025-01-30 01:25:43.777815: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2025-01-30 01:25:43.777838: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_2]]
2025-01-30 01:25:43.777855: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 2868698929229645346


[1m365/365[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m63s[0m 147ms/step - end_loss: 3.9161 - end_sparse_categorical_accuracy: 0.3189 - loss: 6.9621 - start_loss: 3.0453 - start_sparse_categorical_accuracy: 0.5226 - val_end_loss: 1.9578 - val_end_sparse_categorical_accuracy: 0.5170 - val_loss: 3.6125 - val_start_loss: 1.6620 - val_start_sparse_categorical_accuracy: 0.5665
Epoch 2/10


2025-01-30 01:25:53.781832: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_6]]
2025-01-30 01:25:53.781855: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 16646503622005204375
2025-01-30 01:25:53.781860: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 8541295726000041299
2025-01-30 01:25:53.781868: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 2868698929229645346


[1m363/365[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 118ms/step - end_loss: 1.6742 - end_sparse_categorical_accuracy: 0.5590 - loss: 3.1611 - start_loss: 1.4863 - start_sparse_categorical_accuracy: 0.5839 

2025-01-30 01:26:43.858907: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 2868698929229645346


[1m365/365[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m59s[0m 141ms/step - end_loss: 1.6738 - end_sparse_categorical_accuracy: 0.5590 - loss: 3.1603 - start_loss: 1.4860 - start_sparse_categorical_accuracy: 0.5840 - val_end_loss: 1.8087 - val_end_sparse_categorical_accuracy: 0.5335 - val_loss: 3.3831 - val_start_loss: 1.5803 - val_start_sparse_categorical_accuracy: 0.5558
Epoch 3/10


2025-01-30 01:26:52.549803: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_6]]
2025-01-30 01:26:52.549835: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 16646503622005204375
2025-01-30 01:26:52.549855: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 8541295726000041299
2025-01-30 01:26:52.549863: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 2868698929229645346


[1m361/365[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 119ms/step - end_loss: 1.3357 - end_sparse_categorical_accuracy: 0.6209 - loss: 2.6218 - start_loss: 1.2857 - start_sparse_categorical_accuracy: 0.6135 

2025-01-30 01:27:42.831833: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 2868698929229645346


[1m365/365[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m59s[0m 142ms/step - end_loss: 1.3353 - end_sparse_categorical_accuracy: 0.6210 - loss: 2.6212 - start_loss: 1.2855 - start_sparse_categorical_accuracy: 0.6136 - val_end_loss: 1.9160 - val_end_sparse_categorical_accuracy: 0.5296 - val_loss: 3.5366 - val_start_loss: 1.6297 - val_start_sparse_categorical_accuracy: 0.5682
Epoch 4/10


2025-01-30 01:27:51.678702: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 16646503622005204375
2025-01-30 01:27:51.678734: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 8541295726000041299
2025-01-30 01:27:51.678744: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 2868698929229645346


[1m365/365[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m59s[0m 141ms/step - end_loss: 1.1291 - end_sparse_categorical_accuracy: 0.6659 - loss: 2.2797 - start_loss: 1.1502 - start_sparse_categorical_accuracy: 0.6457 - val_end_loss: 1.9340 - val_end_sparse_categorical_accuracy: 0.5199 - val_loss: 3.5473 - val_start_loss: 1.6229 - val_start_sparse_categorical_accuracy: 0.5706
Epoch 5/10


2025-01-30 01:28:50.340755: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_6]]
2025-01-30 01:28:50.340786: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 16646503622005204375
2025-01-30 01:28:50.340792: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 8541295726000041299
2025-01-30 01:28:50.340800: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 2868698929229645346


[1m363/365[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 118ms/step - end_loss: 0.9226 - end_sparse_categorical_accuracy: 0.7157 - loss: 1.9350 - start_loss: 1.0120 - start_sparse_categorical_accuracy: 0.6862 

2025-01-30 01:29:40.381680: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 16646503622005204375
2025-01-30 01:29:40.381705: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 8541295726000041299
2025-01-30 01:29:40.381717: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 2868698929229645346


[1m365/365[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m59s[0m 141ms/step - end_loss: 0.9224 - end_sparse_categorical_accuracy: 0.7158 - loss: 1.9345 - start_loss: 1.0118 - start_sparse_categorical_accuracy: 0.6863 - val_end_loss: 2.1443 - val_end_sparse_categorical_accuracy: 0.5143 - val_loss: 3.8979 - val_start_loss: 1.7674 - val_start_sparse_categorical_accuracy: 0.5657
Epoch 6/10


2025-01-30 01:29:49.133353: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 16646503622005204375
2025-01-30 01:29:49.133394: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 8541295726000041299


[1m360/365[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 120ms/step - end_loss: 0.7712 - end_sparse_categorical_accuracy: 0.7558 - loss: 1.6593 - start_loss: 0.8877 - start_sparse_categorical_accuracy: 0.7176 

2025-01-30 01:30:39.476734: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 16646503622005204375
2025-01-30 01:30:39.476764: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 8541295726000041299
2025-01-30 01:30:39.476774: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 2868698929229645346


[1m365/365[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m59s[0m 142ms/step - end_loss: 0.7709 - end_sparse_categorical_accuracy: 0.7559 - loss: 1.6587 - start_loss: 0.8874 - start_sparse_categorical_accuracy: 0.7176 - val_end_loss: 2.1539 - val_end_sparse_categorical_accuracy: 0.5206 - val_loss: 3.9436 - val_start_loss: 1.8070 - val_start_sparse_categorical_accuracy: 0.5585
Epoch 7/10


2025-01-30 01:30:48.305963: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 16646503622005204375
2025-01-30 01:30:48.305994: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 8541295726000041299
2025-01-30 01:30:48.306004: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 2868698929229645346


[1m361/365[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 118ms/step - end_loss: 0.6340 - end_sparse_categorical_accuracy: 0.7985 - loss: 1.3963 - start_loss: 0.7619 - start_sparse_categorical_accuracy: 0.7524 

2025-01-30 01:31:38.415509: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 16646503622005204375
2025-01-30 01:31:38.415535: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 8541295726000041299
2025-01-30 01:31:38.415544: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 2868698929229645346


[1m365/365[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m59s[0m 142ms/step - end_loss: 0.6336 - end_sparse_categorical_accuracy: 0.7986 - loss: 1.3956 - start_loss: 0.7616 - start_sparse_categorical_accuracy: 0.7525 - val_end_loss: 2.3696 - val_end_sparse_categorical_accuracy: 0.5007 - val_loss: 4.1680 - val_start_loss: 1.8229 - val_start_sparse_categorical_accuracy: 0.5536
Epoch 8/10


2025-01-30 01:31:47.275073: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 16646503622005204375
2025-01-30 01:31:47.275135: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 8541295726000041299
2025-01-30 01:31:47.275145: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 2868698929229645346


[1m360/365[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 119ms/step - end_loss: 0.5322 - end_sparse_categorical_accuracy: 0.8252 - loss: 1.1813 - start_loss: 0.6487 - start_sparse_categorical_accuracy: 0.7881 

2025-01-30 01:32:37.475395: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 16646503622005204375
2025-01-30 01:32:37.475421: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 8541295726000041299
2025-01-30 01:32:37.475431: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 2868698929229645346


[1m365/365[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m59s[0m 141ms/step - end_loss: 0.5319 - end_sparse_categorical_accuracy: 0.8254 - loss: 1.1808 - start_loss: 0.6485 - start_sparse_categorical_accuracy: 0.7881 - val_end_loss: 2.4326 - val_end_sparse_categorical_accuracy: 0.5126 - val_loss: 4.3694 - val_start_loss: 1.9582 - val_start_sparse_categorical_accuracy: 0.5376
Epoch 9/10


2025-01-30 01:32:46.179293: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_6]]
2025-01-30 01:32:46.179344: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 16646503622005204375
2025-01-30 01:32:46.179350: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 8541295726000041299
2025-01-30 01:32:46.179358: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 2868698929229645346


[1m364/365[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 117ms/step - end_loss: 0.4271 - end_sparse_categorical_accuracy: 0.8630 - loss: 0.9822 - start_loss: 0.5547 - start_sparse_categorical_accuracy: 0.8180 

2025-01-30 01:33:35.929406: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 16646503622005204375
2025-01-30 01:33:35.929435: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 8541295726000041299
2025-01-30 01:33:35.929445: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 2868698929229645346


[1m365/365[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m59s[0m 141ms/step - end_loss: 0.4270 - end_sparse_categorical_accuracy: 0.8630 - loss: 0.9821 - start_loss: 0.5547 - start_sparse_categorical_accuracy: 0.8180 - val_end_loss: 2.8005 - val_end_sparse_categorical_accuracy: 0.5017 - val_loss: 4.9659 - val_start_loss: 2.1900 - val_start_sparse_categorical_accuracy: 0.5687
Epoch 10/10


2025-01-30 01:33:44.689057: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 16646503622005204375
2025-01-30 01:33:44.689120: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 8541295726000041299
2025-01-30 01:33:44.689129: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 2868698929229645346


[1m365/365[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m58s[0m 140ms/step - end_loss: 0.3691 - end_sparse_categorical_accuracy: 0.8794 - loss: 0.8334 - start_loss: 0.4640 - start_sparse_categorical_accuracy: 0.8463 - val_end_loss: 2.9620 - val_end_sparse_categorical_accuracy: 0.5087 - val_loss: 5.2555 - val_start_loss: 2.3315 - val_start_sparse_categorical_accuracy: 0.5497


2025-01-30 01:34:43.031298: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 16646503622005204375
2025-01-30 01:34:43.031330: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 8541295726000041299
2025-01-30 01:34:43.031340: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 2868698929229645346


In [10]:
# https://www.kaggle.com/competitions/tweet-sentiment-extraction
def jaccard(str1, str2):
    a = set(str1.lower().split())
    b = set(str2.lower().split())
    c = a.intersection(b)
    return float(len(c)) / (len(a) + len(b) - len(c))

score = 0
total = test.shape[0]
(y_start, y_end) = ffn.predict(dataset_test.batch(64))
y_start = tf.math.argmax(y_start, axis=1)
y_end = tf.math.argmax(y_end, axis=1)
for i in range(total):
    span_start = y_start[i]
    span_end = y_end[i]
    y_str = test.text.iloc[i][span_start:span_end]
    t_str = test.selected_text.iloc[i]
    score += 1 / total * jaccard(y_str, t_str)
score

[1m65/65[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 88ms/step


0.08917177226924795