In [25]:
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds

# Load the Universal Sentence Encoder (USE) from TensorFlow Hub
embed = hub.load("https://tfhub.dev/google/universal-sentence-encoder/4")

# Load the AG News dataset from TensorFlow Datasets
(train_data, test_data), info = tfds.load('ag_news_subset', split=['train', 'test'], with_info=True, as_supervised=True)

# Preprocessing function to lowercase text
def preprocess(text, label):
    text = tf.strings.lower(text)
    return text, label

# Apply preprocessing to the datasets
train_data = train_data.map(preprocess)
test_data = test_data.map(preprocess)

# Batch the datasets
batch_size = 32
train_data = train_data.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)
test_data = test_data.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)

# Function to embed text using USE outside of the map function
def embed_text(text_batch):
    embeddings = embed(text_batch)
    return embeddings

# Embed the entire dataset and create a generator for TensorFlow
def embed_dataset(dataset):
    for text_batch, label_batch in dataset:
        text_batch_str = text_batch.numpy()  # Convert the text batch to numpy strings
        embeddings = embed_text(text_batch_str)  # Get the USE embeddings
        yield embeddings, label_batch

# Convert the dataset into a TensorFlow Dataset with embeddings
train_data_embedded = tf.data.Dataset.from_generator(
    lambda: embed_dataset(train_data),
    output_signature=(
        tf.TensorSpec(shape=(None, 512), dtype=tf.float32),  
        tf.TensorSpec(shape=(None,), dtype=tf.int64)
    )
)

test_data_embedded = tf.data.Dataset.from_generator(
    lambda: embed_dataset(test_data),
    output_signature=(
        tf.TensorSpec(shape=(None, 512), dtype=tf.float32),
        tf.TensorSpec(shape=(None,), dtype=tf.int64)
    )
)

# Define a simple model that takes USE embeddings as input
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(512,)),  # USE generates 512-dimensional embeddings
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Dense(4, activation='softmax')  # 4 classes in AG News dataset
])

# Compile the model
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Train the model
epochs = 5
history = model.fit(train_data_embedded, epochs=epochs, validation_data=test_data_embedded)

# Evaluate the model on test data
test_loss, test_acc = model.evaluate(test_data_embedded)
print(f"Test Accuracy: {test_acc:.4f}")

Epoch 1/5
   3749/Unknown [1m138s[0m 36ms/step - accuracy: 0.8606 - loss: 0.4303

2024-10-08 10:13:25.350385: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-08 10:13:25.354242: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[StatefulPartitionedCall/adam/Add_8/ReadVariableOp/_1]]
2024-10-08 10:13:25.354262: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 17448233847767517790
2024-10-08 10:13:25.354267: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 4925557098472468529
2024-10-08 10:13:25.354276: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 17716219253532008484
2024-10-08 10:13:25.354286: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 11348681209493270804
2024-

[1m3750/3750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m147s[0m 39ms/step - accuracy: 0.8606 - loss: 0.4302 - val_accuracy: 0.8857 - val_loss: 0.3261
Epoch 2/5
[1m   3/3750[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m2:26[0m 39ms/step - accuracy: 0.8524 - loss: 0.4130

2024-10-08 10:13:34.188207: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 3699640868639937127
2024-10-08 10:13:34.188221: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 4892414630045973921
2024-10-08 10:13:34.188228: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 2787315973575985354


[1m3749/3750[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 36ms/step - accuracy: 0.8910 - loss: 0.3079

2024-10-08 10:15:47.678936: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-08 10:15:47.684291: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 17716219253532008484
2024-10-08 10:15:47.684310: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 4925557098472468529
2024-10-08 10:15:47.684314: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 11796155222170091841
2024-10-08 10:15:47.684336: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 17640316776661417855
2024-10-08 10:15:47.684343: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 988939245374018315
2024-10-08 10:15:47.684347: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous re

[1m3750/3750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m143s[0m 38ms/step - accuracy: 0.8910 - loss: 0.3079 - val_accuracy: 0.8901 - val_loss: 0.3129
Epoch 3/5
[1m   2/3750[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m3:53[0m 62ms/step - accuracy: 0.8828 - loss: 0.3553

2024-10-08 10:15:57.106510: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 3699640868639937127
2024-10-08 10:15:57.106525: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 4892414630045973921
2024-10-08 10:15:57.106534: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 2787315973575985354


[1m3749/3750[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 44ms/step - accuracy: 0.8959 - loss: 0.2915

2024-10-08 10:18:40.550450: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 17716219253532008484
2024-10-08 10:18:40.550469: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[StatefulPartitionedCall/adam/Add_8/_20]]
2024-10-08 10:18:40.550474: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 4925557098472468529
2024-10-08 10:18:40.550476: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 11796155222170091841
2024-10-08 10:18:40.550479: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 14616070943460344139
2024-10-08 10:18:40.550482: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 988939245374018315
2024-10-08 10:18:40.5504

[1m3750/3750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m174s[0m 46ms/step - accuracy: 0.8959 - loss: 0.2915 - val_accuracy: 0.8946 - val_loss: 0.3053
Epoch 4/5
[1m   2/3750[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m3:41[0m 59ms/step - accuracy: 0.8828 - loss: 0.3247

2024-10-08 10:18:51.005043: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 3699640868639937127
2024-10-08 10:18:51.005058: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 4892414630045973921
2024-10-08 10:18:51.005066: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 2787315973575985354


[1m3750/3750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 44ms/step - accuracy: 0.9011 - loss: 0.2791

2024-10-08 10:21:37.037916: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 4925557098472468529
2024-10-08 10:21:37.037932: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 11796155222170091841
2024-10-08 10:21:37.037936: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 14616070943460344139
2024-10-08 10:21:37.037940: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 988939245374018315
2024-10-08 10:21:37.037944: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 3991565797904686843
2024-10-08 10:21:37.037948: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 17640316776661417855
2024-10-08 10:21:37.037951: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv i

[1m3750/3750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m176s[0m 47ms/step - accuracy: 0.9011 - loss: 0.2791 - val_accuracy: 0.8964 - val_loss: 0.2998
Epoch 5/5
[1m   2/3750[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m3:56[0m 63ms/step - accuracy: 0.8984 - loss: 0.3249

2024-10-08 10:21:47.324413: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 3699640868639937127
2024-10-08 10:21:47.324428: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 4892414630045973921
2024-10-08 10:21:47.324437: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 2787315973575985354


[1m3750/3750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 44ms/step - accuracy: 0.9045 - loss: 0.2674

2024-10-08 10:24:32.762232: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 17716219253532008484
2024-10-08 10:24:32.762249: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 11348681209493270804
2024-10-08 10:24:32.762253: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 5481597926455346926
2024-10-08 10:24:32.762257: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 13506537335262705212
2024-10-08 10:24:32.762259: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 4925557098472468529
2024-10-08 10:24:32.762268: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 11796155222170091841
2024-10-08 10:24:32.762271: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv

[1m3750/3750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m176s[0m 47ms/step - accuracy: 0.9045 - loss: 0.2674 - val_accuracy: 0.8984 - val_loss: 0.2955
      3/Unknown [1m0s[0m 44ms/step - accuracy: 0.9167 - loss: 0.3442

2024-10-08 10:24:42.844297: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 3699640868639937127
2024-10-08 10:24:42.844319: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 4892414630045973921
2024-10-08 10:24:42.844326: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 2787315973575985354


[1m238/238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 44ms/step - accuracy: 0.9034 - loss: 0.2838
Test Accuracy: 0.8984


2024-10-08 10:24:53.408835: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-08 10:24:53.414300: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 3699640868639937127
2024-10-08 10:24:53.414312: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 4892414630045973921
2024-10-08 10:24:53.414320: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 2787315973575985354
