In [1]:
import tensorflow as tf
import tensorflow_hub as hub
from transformers import BertTokenizer, TFBertModel
import numpy as np
import time

# For reproducibility
tf.random.set_seed(42)
np.random.seed(42)

2025-07-14 14:18:19.374621: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1752502699.547428      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752502699.600225      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
# Load IMDB dataset
print("Loading IMDB dataset...")
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=10000)
word_index = tf.keras.datasets.imdb.get_word_index()

# Convert back to text
index_word = {v + 3: k for k, v in word_index.items()}
index_word[0], index_word[1], index_word[2], index_word[3] = "<pad>", "<start>", "<unk>", "<unused>"

def decode_review(encoded_review):
    return " ".join([index_word.get(i, "?") for i in encoded_review])

x_train_text = [decode_review(seq) for seq in x_train]
x_test_text = [decode_review(seq) for seq in x_test]

x_train_text = x_train_text[:2000]
y_train = y_train[:2000]
x_test_text = x_test_text[:500]
y_test = y_test[:500]
max_length = 128


Loading IMDB dataset...
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz
[1m17464789/17464789[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb_word_index.json
[1m1641221/1641221[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1us/step


In [3]:
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = TFBertModel.from_pretrained('bert-base-uncased')
bert_model.trainable = False

def encode_bert(texts):
    encodings = bert_tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="tf"
    )
    return encodings['input_ids'], encodings['attention_mask']

x_train_ids, x_train_mask = encode_bert(x_train_text)
x_test_ids, x_test_mask = encode_bert(x_test_text)

def build_bert_classifier():
    input_ids = tf.keras.Input(shape=(max_length,), dtype=tf.int32, name='input_ids')
    attention_mask = tf.keras.Input(shape=(max_length,), dtype=tf.int32, name='attention_mask')

    # Wrap the Hugging Face model call in a Lambda layer
    def bert_layer(inputs):
        ids, mask = inputs
        return bert_model(input_ids=ids, attention_mask=mask).last_hidden_state

    bert_output = tf.keras.layers.Lambda(
        bert_layer,
        output_shape=(max_length, 768),  # Explicit shape!
        name="bert_output"
    )([input_ids, attention_mask])

    # Extract [CLS] token
    cls_token = tf.keras.layers.Lambda(lambda x: x[:, 0], output_shape=(768,), name="cls_token")(bert_output)

    x = tf.keras.layers.Dense(256, activation='relu')(cls_token)
    x = tf.keras.layers.Dropout(0.3)(x)
    output = tf.keras.layers.Dense(1, activation='sigmoid')(x)

    return tf.keras.Model(inputs=[input_ids, attention_mask], outputs=output)



bert_classifier = build_bert_classifier()
bert_classifier.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])

print("Training BERT classifier...")
bert_classifier.fit([x_train_ids, x_train_mask], y_train, epochs=5, batch_size=32,
                    validation_data=([x_test_ids, x_test_mask], y_test))


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

I0000 00:00:1752502735.722865      19 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 15513 MB memory:  -> device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0
Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing TFBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertFo

Training BERT classifier...
Epoch 1/5


I0000 00:00:1752502763.135711      72 service.cc:148] XLA service 0x7d29a8016f50 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1752502763.136298      72 service.cc:156]   StreamExecutor device (0): Tesla P100-PCIE-16GB, Compute Capability 6.0
W0000 00:00:1752502763.610085      72 assert_op.cc:38] Ignoring Assert operator functional_1/bert_output_1/tf_bert_model/bert/embeddings/assert_less/Assert/Assert
I0000 00:00:1752502765.356434      72 cuda_dnn.cc:529] Loaded cuDNN version 90300


[1m 1/63[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m21:51[0m 21s/step - accuracy: 0.4688 - loss: 0.7534

I0000 00:00:1752502767.799404      72 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m62/63[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 105ms/step - accuracy: 0.6072 - loss: 0.6525

W0000 00:00:1752502774.744747      71 assert_op.cc:38] Ignoring Assert operator functional_1/bert_output_1/tf_bert_model/bert/embeddings/assert_less/Assert/Assert


[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 160ms/step - accuracy: 0.6083 - loss: 0.6515

W0000 00:00:1752502782.274558      73 assert_op.cc:38] Ignoring Assert operator functional_1/bert_output_1/tf_bert_model/bert/embeddings/assert_less/Assert/Assert
W0000 00:00:1752502785.682222      71 assert_op.cc:38] Ignoring Assert operator functional_1/bert_output_1/tf_bert_model/bert/embeddings/assert_less/Assert/Assert


[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 314ms/step - accuracy: 0.6092 - loss: 0.6506 - val_accuracy: 0.7680 - val_loss: 0.4933
Epoch 2/5
[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 131ms/step - accuracy: 0.7659 - loss: 0.4928 - val_accuracy: 0.7480 - val_loss: 0.5143
Epoch 3/5
[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 131ms/step - accuracy: 0.7682 - loss: 0.4823 - val_accuracy: 0.7740 - val_loss: 0.4865
Epoch 4/5
[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 131ms/step - accuracy: 0.7907 - loss: 0.4444 - val_accuracy: 0.7620 - val_loss: 0.5024
Epoch 5/5
[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 131ms/step - accuracy: 0.8034 - loss: 0.4369 - val_accuracy: 0.7680 - val_loss: 0.4846


<keras.src.callbacks.history.History at 0x7d2a1bd516d0>

In [4]:
bert_loss, bert_acc = bert_classifier.evaluate([x_test_ids, x_test_mask], y_test, verbose=0)

print(f"BERT Accuracy: {bert_acc:.4f}")

BERT Accuracy: 0.7680


In [5]:
print("\nGetting predictions...")
bert_predictions = bert_classifier.predict([x_test_ids, x_test_mask])

# Convert probabilities to binary predictions
bert_binary_preds = (bert_predictions > 0.5).astype(int)

print(f"BERT Binary Accuracy: {np.mean(bert_binary_preds.flatten() == y_test):.4f}")

print("\nSample predictions (first 5 test examples):")
for i in range(min(5, len(x_test_text))):
    print(f"Text: {x_test_text[i][:100]}...")
    print(f"  True label: {y_test[i]}")
    print(f"  BERT pred: {bert_predictions[i][0]:.4f} -> {bert_binary_preds[i][0]}")
    print("  ---")


Getting predictions...


W0000 00:00:1752502824.166071      74 assert_op.cc:38] Ignoring Assert operator functional_1/bert_output_1/tf_bert_model/bert/embeddings/assert_less/Assert/Assert


[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 104ms/step

W0000 00:00:1752502831.284561      71 assert_op.cc:38] Ignoring Assert operator functional_1/bert_output_1/tf_bert_model/bert/embeddings/assert_less/Assert/Assert


[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 499ms/step
BERT Binary Accuracy: 0.7680

Sample predictions (first 5 test examples):
Text: <start> please give this one a miss br br <unk> <unk> and the rest of the cast rendered terrible per...
  True label: 0
  BERT pred: 0.2243 -> 0
  ---
Text: <start> this film requires a lot of patience because it focuses on mood and character development th...
  True label: 1
  BERT pred: 0.9367 -> 1
  ---
Text: <start> many animation buffs consider <unk> <unk> the great forgotten genius of one special branch o...
  True label: 1
  BERT pred: 0.8526 -> 1
  ---
Text: <start> i generally love this type of movie however this time i found myself wanting to kick the scr...
  True label: 0
  BERT pred: 0.1045 -> 0
  ---
Text: <start> like some other people wrote i'm a die hard mario fan and i loved this game br br this game ...
  True label: 1
  BERT pred: 0.7314 -> 1
  ---
