In [1]:
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, TFAutoModel
from tensorflow import keras
from tensorflow.keras import layers, losses, optimizers, metrics, regularizers
import pandas as pd
import numpy as np
import tensorflow as tf
import sys

  from .autonotebook import tqdm as notebook_tqdm
2025-02-01 20:20:46.863746: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1738459246.907526   15750 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1738459246.919697   15750 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-01 20:20:46.990507: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
train = pd.read_csv("data/train.csv").dropna()
train, test = train_test_split(train, test_size=0.15)
print(train.info())
print(test.info())

<class 'pandas.core.frame.DataFrame'>
Index: 23358 entries, 6519 to 10370
Data columns (total 4 columns):
 #   Column         Non-Null Count  Dtype 
---  ------         --------------  ----- 
 0   textID         23358 non-null  object
 1   text           23358 non-null  object
 2   selected_text  23358 non-null  object
 3   sentiment      23358 non-null  object
dtypes: object(4)
memory usage: 912.4+ KB
None
<class 'pandas.core.frame.DataFrame'>
Index: 4122 entries, 23281 to 4700
Data columns (total 4 columns):
 #   Column         Non-Null Count  Dtype 
---  ------         --------------  ----- 
 0   textID         4122 non-null   object
 1   text           4122 non-null   object
 2   selected_text  4122 non-null   object
 3   sentiment      4122 non-null   object
dtypes: object(4)
memory usage: 161.0+ KB
None


In [3]:
max_text_len = 128
inference_batch_size = 2000

def tokenize(texts, padding=True):
    padding = "max_length" if padding else "longest"
    return tokenizer(texts, padding=padding, max_length=max_text_len, return_tensors="tf")

def detokenize(ids, skip_special=True):
    return tokenizer.batch_decode(ids, skip_special_tokens=skip_special)

# https://stackoverflow.com/a/7100681
def rolling_window(a, size):
    shape = a.shape[:-1] + (a.shape[-1] - size + 1, size)
    strides = a.strides + (a. strides[-1],)
    return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)

def find_span(texts, spans, word_boundary=False):
    texts = np.array(texts)
    spans = np.array(spans)[:, 1:]
    sizes = spans.argmin(axis=1) - 1
    span_ranges = np.zeros((texts.shape[0], 2))
    for i, text in enumerate(texts):
        skips = 0
        window_size = sizes[i]
        if word_boundary:
            allowed = set(texts[i])
            if spans[i, 0] not in allowed and window_size > 1:
                skips = 1
            if spans[i, -1] not in allowed and window_size - skips > 1:
                window_size -= 1
        matched_window = rolling_window(texts[i], window_size - skips) == spans[i][skips:window_size]
        window_from = matched_window.all(axis=1).argmax()
        span_ranges[i, 0] = window_from
        span_ranges[i, 1] = window_from + window_size - 1
    return span_ranges

def eprint(*args, **kwargs):
    print(*args, **kwargs, file=sys.stderr)

def print_mem_info():
    info = tf.config.experimental.get_memory_info("/gpu:0")
    eprint(f"GPU memory usage {(info["current"] / info["peak"]):.2f}.")

def context_embeddings(texts, sentiments, selected_texts=None):
    text_tokens = tokenize(texts)
    text_embeddings = encoder(**text_tokens).last_hidden_state
    sentiment_tokens = tokenize(sentiments)
    sentiment_embeddings = encoder(**sentiment_tokens).last_hidden_state[:, 1:2, :]  # only need embedding for one word
    embeddings = tf.concat((text_embeddings, sentiment_embeddings), 1)
    if selected_texts is None:
        return embeddings
    else:
        selected_tokens = tokenize(selected_texts)
        targets = find_span(text_tokens["input_ids"], selected_tokens["input_ids"], word_boundary=True)
        return embeddings, targets

def dataset_generator(texts, sentiments, selected_texts=None):
    def gen():
        total = len(texts)
        yielded = 0
        while yielded < total:
            batch_i = yielded % inference_batch_size
            # batch inferencing
            if batch_i == 0:
                end = min(total, yielded + inference_batch_size)
                if selected_texts is None:
                    embeddings = context_embeddings(texts[yielded:end], sentiments[yielded:end])
                else:
                    embeddings, targets = context_embeddings(texts[yielded:end], sentiments[yielded:end], selected_texts[yielded:end])
            # feed the generator
            if selected_texts is None:
                yield embeddings[batch_i]
            else:
                yield embeddings[batch_i], (targets[batch_i, 0], targets[batch_i, 1])
            yielded += 1
    return gen

Naive implementation, flattened embeddings into FFN with one hidden layer.

In [4]:
inputs = keras.Input(shape=(1 + max_text_len, 768))
flat = layers.Flatten()(inputs)
hidden = layers.Dense(256, activation="relu", kernel_regularizer="l1l2", bias_regularizer="l2")(flat)
softmax_start = layers.Dense(max_text_len, activation="softmax", name="start", kernel_regularizer="l1l2", bias_regularizer="l2")(hidden)
softmax_end = layers.Dense(max_text_len, activation="softmax", name="end", kernel_regularizer="l1l2", bias_regularizer="l2")(hidden)
ffn = keras.Model(inputs=inputs, outputs=(softmax_start, softmax_end))
ffn.compile(
    optimizer=optimizers.Adam(),
    loss={"start": losses.SparseCategoricalCrossentropy(), "end": losses.SparseCategoricalCrossentropy()},
    metrics={"start": metrics.SparseCategoricalAccuracy(), "end": metrics.SparseCategoricalAccuracy()}
)

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
encoder = TFAutoModel.from_pretrained("distilbert-base-uncased")
dataset_train = tf.data.Dataset.from_generator(
    dataset_generator(train.text.to_list(), train.sentiment.to_list(), train.selected_text.to_list()),
    output_signature=(tf.TensorSpec(shape=(1 + max_text_len, 768)), (tf.TensorSpec(shape=()), tf.TensorSpec(shape=())))
).shuffle(inference_batch_size)
dataset_test = tf.data.Dataset.from_generator(
    dataset_generator(test.text.to_list(), test.sentiment.to_list(), test.selected_text.to_list()),
    output_signature=(tf.TensorSpec(shape=(1 + max_text_len, 768)), (tf.TensorSpec(shape=()), tf.TensorSpec(shape=())))
)
history = ffn.fit(dataset_train.batch(64), epochs=10, validation_data=dataset_test.batch(64))

I0000 00:00:1738459249.368812   15750 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 21911 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:01:00.0, compute capability: 8.6
Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFDistilBertModel: ['vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_transform.bias', 'vocab_layer_norm.bias']
- This IS expected if you are initializing TFDistilBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFDistilBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFDistilBertModel were initialized from the PyTorch model.

Epoch 1/10


I0000 00:00:1738459259.236630   16042 service.cc:148] XLA service 0x7f0cd0002dd0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1738459259.236648   16042 service.cc:156]   StreamExecutor device (0): NVIDIA GeForce RTX 3090, Compute Capability 8.6
2025-02-01 20:20:59.270069: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1738459259.349742   16042 cuda_dnn.cc:529] Loaded cuDNN version 90300




      4/Unknown [1m10s[0m 39ms/step - end_loss: 10.3706 - end_sparse_categorical_accuracy: 0.0085 - loss: 23.0635 - start_loss: 12.6929 - start_sparse_categorical_accuracy: 0.2949    

I0000 00:00:1738459260.774809   16042 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


    364/Unknown [1m53s[0m 117ms/step - end_loss: 4.0788 - end_sparse_categorical_accuracy: 0.3212 - loss: 7.4185 - start_loss: 3.3391 - start_sparse_categorical_accuracy: 0.5105  







    365/Unknown [1m54s[0m 121ms/step - end_loss: 4.0750 - end_sparse_categorical_accuracy: 0.3215 - loss: 7.4114 - start_loss: 3.3360 - start_sparse_categorical_accuracy: 0.5105

2025-02-01 20:21:44.980829: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2025-02-01 20:21:44.980857: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_6]]
2025-02-01 20:21:44.980868: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 6972856757582473978
2025-02-01 20:21:44.980876: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 15371053870318656594
2025-02-01 20:21:44.980887: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 5903302096367588440


[1m365/365[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m64s[0m 149ms/step - end_loss: 4.0711 - end_sparse_categorical_accuracy: 0.3218 - loss: 7.4045 - start_loss: 3.3328 - start_sparse_categorical_accuracy: 0.5106 - val_end_loss: 2.0217 - val_end_sparse_categorical_accuracy: 0.4934 - val_loss: 3.7015 - val_start_loss: 1.6762 - val_start_sparse_categorical_accuracy: 0.5657
Epoch 2/10


2025-02-01 20:21:54.843564: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_6]]
2025-02-01 20:21:54.843588: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 6972856757582473978
2025-02-01 20:21:54.843593: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 5903302096367588440
2025-02-01 20:21:54.843600: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 15371053870318656594


[1m362/365[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 121ms/step - end_loss: 1.6510 - end_sparse_categorical_accuracy: 0.5615 - loss: 3.1341 - start_loss: 1.4825 - start_sparse_categorical_accuracy: 0.5880 

2025-02-01 20:22:45.743280: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 6972856757582473978
2025-02-01 20:22:45.743338: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 15371053870318656594
2025-02-01 20:22:45.743347: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 5903302096367588440


[1m365/365[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m60s[0m 143ms/step - end_loss: 1.6503 - end_sparse_categorical_accuracy: 0.5616 - loss: 3.1329 - start_loss: 1.4820 - start_sparse_categorical_accuracy: 0.5881 - val_end_loss: 1.8490 - val_end_sparse_categorical_accuracy: 0.5410 - val_loss: 3.4755 - val_start_loss: 1.6231 - val_start_sparse_categorical_accuracy: 0.5784
Epoch 3/10


2025-02-01 20:22:54.406721: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_6]]
2025-02-01 20:22:54.406756: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 6972856757582473978
2025-02-01 20:22:54.406762: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 5903302096367588440
2025-02-01 20:22:54.406770: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 15371053870318656594


[1m362/365[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 120ms/step - end_loss: 1.2980 - end_sparse_categorical_accuracy: 0.6282 - loss: 2.5372 - start_loss: 1.2388 - start_sparse_categorical_accuracy: 0.6366 

2025-02-01 20:23:45.209290: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 6972856757582473978
2025-02-01 20:23:45.209317: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 15371053870318656594
2025-02-01 20:23:45.209327: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 5903302096367588440


[1m365/365[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m60s[0m 144ms/step - end_loss: 1.2976 - end_sparse_categorical_accuracy: 0.6283 - loss: 2.5366 - start_loss: 1.2385 - start_sparse_categorical_accuracy: 0.6366 - val_end_loss: 1.8518 - val_end_sparse_categorical_accuracy: 0.5294 - val_loss: 3.4875 - val_start_loss: 1.6346 - val_start_sparse_categorical_accuracy: 0.5708
Epoch 4/10


2025-02-01 20:23:54.248079: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 6972856757582473978
2025-02-01 20:23:54.248108: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 5903302096367588440
2025-02-01 20:23:54.248117: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 15371053870318656594


[1m362/365[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 121ms/step - end_loss: 1.0379 - end_sparse_categorical_accuracy: 0.6892 - loss: 2.0867 - start_loss: 1.0484 - start_sparse_categorical_accuracy: 0.6823 

2025-02-01 20:24:45.480353: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 6972856757582473978
2025-02-01 20:24:45.480379: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 15371053870318656594
2025-02-01 20:24:45.480388: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 5903302096367588440


[1m365/365[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m60s[0m 144ms/step - end_loss: 1.0375 - end_sparse_categorical_accuracy: 0.6893 - loss: 2.0860 - start_loss: 1.0482 - start_sparse_categorical_accuracy: 0.6824 - val_end_loss: 1.8715 - val_end_sparse_categorical_accuracy: 0.5473 - val_loss: 3.5209 - val_start_loss: 1.6471 - val_start_sparse_categorical_accuracy: 0.5667
Epoch 5/10


2025-02-01 20:24:54.344282: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_6]]
2025-02-01 20:24:54.344342: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 6972856757582473978
2025-02-01 20:24:54.344351: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 5903302096367588440
2025-02-01 20:24:54.344362: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 15371053870318656594


[1m362/365[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 123ms/step - end_loss: 0.8402 - end_sparse_categorical_accuracy: 0.7443 - loss: 1.7264 - start_loss: 0.8859 - start_sparse_categorical_accuracy: 0.7281 

2025-02-01 20:25:46.472161: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 6972856757582473978
2025-02-01 20:25:46.472191: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 15371053870318656594
2025-02-01 20:25:46.472201: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 5903302096367588440


[1m365/365[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m61s[0m 147ms/step - end_loss: 0.8399 - end_sparse_categorical_accuracy: 0.7444 - loss: 1.7259 - start_loss: 0.8857 - start_sparse_categorical_accuracy: 0.7282 - val_end_loss: 2.1040 - val_end_sparse_categorical_accuracy: 0.5352 - val_loss: 3.9658 - val_start_loss: 1.8557 - val_start_sparse_categorical_accuracy: 0.5820
Epoch 6/10


2025-02-01 20:25:55.287411: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 6972856757582473978
2025-02-01 20:25:55.287438: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 5903302096367588440
2025-02-01 20:25:55.287447: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 15371053870318656594


[1m362/365[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 122ms/step - end_loss: 0.6620 - end_sparse_categorical_accuracy: 0.7954 - loss: 1.4206 - start_loss: 0.7584 - start_sparse_categorical_accuracy: 0.7635 

2025-02-01 20:26:46.696595: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 6972856757582473978
2025-02-01 20:26:46.696622: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 15371053870318656594
2025-02-01 20:26:46.696631: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 5903302096367588440


[1m365/365[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m60s[0m 145ms/step - end_loss: 0.6617 - end_sparse_categorical_accuracy: 0.7955 - loss: 1.4200 - start_loss: 0.7580 - start_sparse_categorical_accuracy: 0.7636 - val_end_loss: 2.1408 - val_end_sparse_categorical_accuracy: 0.5267 - val_loss: 4.0246 - val_start_loss: 1.8781 - val_start_sparse_categorical_accuracy: 0.5805
Epoch 7/10


2025-02-01 20:26:55.653123: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 6972856757582473978
2025-02-01 20:26:55.653152: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 5903302096367588440
2025-02-01 20:26:55.653164: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 15371053870318656594


[1m362/365[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 121ms/step - end_loss: 0.5152 - end_sparse_categorical_accuracy: 0.8395 - loss: 1.1197 - start_loss: 0.6043 - start_sparse_categorical_accuracy: 0.8076 

2025-02-01 20:27:47.041215: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 6972856757582473978
2025-02-01 20:27:47.041273: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 15371053870318656594
2025-02-01 20:27:47.041282: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 5903302096367588440


[1m365/365[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m60s[0m 145ms/step - end_loss: 0.5150 - end_sparse_categorical_accuracy: 0.8395 - loss: 1.1194 - start_loss: 0.6041 - start_sparse_categorical_accuracy: 0.8077 - val_end_loss: 2.4037 - val_end_sparse_categorical_accuracy: 0.5017 - val_loss: 4.4396 - val_start_loss: 2.0268 - val_start_sparse_categorical_accuracy: 0.5585
Epoch 8/10


2025-02-01 20:27:55.935217: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 6972856757582473978
2025-02-01 20:27:55.935245: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 5903302096367588440
2025-02-01 20:27:55.935254: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 15371053870318656594


[1m362/365[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 120ms/step - end_loss: 0.4160 - end_sparse_categorical_accuracy: 0.8684 - loss: 0.9151 - start_loss: 0.4989 - start_sparse_categorical_accuracy: 0.8427 

2025-02-01 20:28:46.743361: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 6972856757582473978
2025-02-01 20:28:46.743385: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 15371053870318656594
2025-02-01 20:28:46.743394: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 5903302096367588440


[1m365/365[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m60s[0m 143ms/step - end_loss: 0.4160 - end_sparse_categorical_accuracy: 0.8684 - loss: 0.9149 - start_loss: 0.4987 - start_sparse_categorical_accuracy: 0.8427 - val_end_loss: 2.5249 - val_end_sparse_categorical_accuracy: 0.5097 - val_loss: 4.6657 - val_start_loss: 2.1388 - val_start_sparse_categorical_accuracy: 0.5543
Epoch 9/10


2025-02-01 20:28:55.493531: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_6]]
2025-02-01 20:28:55.493561: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 6972856757582473978
2025-02-01 20:28:55.493569: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 5903302096367588440
2025-02-01 20:28:55.493579: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 15371053870318656594


[1m362/365[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 118ms/step - end_loss: 0.3892 - end_sparse_categorical_accuracy: 0.8705 - loss: 0.8244 - start_loss: 0.4350 - start_sparse_categorical_accuracy: 0.8594 

2025-02-01 20:29:45.351329: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 6972856757582473978
2025-02-01 20:29:45.351354: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 15371053870318656594
2025-02-01 20:29:45.351364: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 5903302096367588440


[1m365/365[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m59s[0m 141ms/step - end_loss: 0.3889 - end_sparse_categorical_accuracy: 0.8706 - loss: 0.8238 - start_loss: 0.4347 - start_sparse_categorical_accuracy: 0.8595 - val_end_loss: 2.7950 - val_end_sparse_categorical_accuracy: 0.5116 - val_loss: 5.1200 - val_start_loss: 2.3190 - val_start_sparse_categorical_accuracy: 0.5432
Epoch 10/10


2025-02-01 20:29:54.190581: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 6972856757582473978
2025-02-01 20:29:54.190610: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 5903302096367588440
2025-02-01 20:29:54.190619: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 15371053870318656594


[1m363/365[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 118ms/step - end_loss: 0.3064 - end_sparse_categorical_accuracy: 0.8983 - loss: 0.6539 - start_loss: 0.3473 - start_sparse_categorical_accuracy: 0.8858 

2025-02-01 20:30:44.167554: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 6972856757582473978
2025-02-01 20:30:44.167581: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 15371053870318656594


[1m365/365[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m59s[0m 142ms/step - end_loss: 0.3064 - end_sparse_categorical_accuracy: 0.8983 - loss: 0.6538 - start_loss: 0.3472 - start_sparse_categorical_accuracy: 0.8858 - val_end_loss: 2.9026 - val_end_sparse_categorical_accuracy: 0.5109 - val_loss: 5.3813 - val_start_loss: 2.4740 - val_start_sparse_categorical_accuracy: 0.5335


2025-02-01 20:30:52.979923: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 6972856757582473978
2025-02-01 20:30:52.979955: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 5903302096367588440
2025-02-01 20:30:52.979965: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 15371053870318656594


In [9]:
# https://www.kaggle.com/competitions/tweet-sentiment-extraction
def jaccard(str1, str2):
    a = set(str1.lower().split())
    b = set(str2.lower().split())
    c = a.intersection(b)
    return float(len(c)) / (len(a) + len(b) - len(c))

score = 0
total = test.shape[0]
(y_start, y_end) = ffn.predict(dataset_test.batch(64))
y_start = tf.math.argmax(y_start, axis=1)
y_end = tf.math.argmax(y_end, axis=1) + 1
spans_pred = []
for i in range(total):
    span_start = y_start[i]
    span_end = y_end[i]
    y_str = test.text.iloc[i][span_start:span_end]
    spans_pred.append((test.text.iloc[i], y_str))
    t_str = test.selected_text.iloc[i]
    score += 1 / total * jaccard(y_str, t_str)
score

[1m65/65[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 90ms/step


0.10026866426297167