In [1]:
!pip install -q datasets transformers tensorflow

[0m

In [2]:
import os, random, tensorflow as tf
from datasets import load_dataset
from tensorflow.keras.mixed_precision import set_global_policy
from transformers import AutoTokenizer

#enable training on gpu with quantized weights 
set_global_policy("mixed_float16")

2025-04-26 18:02:27.040220: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9373] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-04-26 18:02:27.040277: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-04-26 18:02:27.041337: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1534] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-26 18:02:27.048708: I tensorflow/core/platform/cpu_feature_guard.cc:183] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from .autonotebook import tqdm as not

INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: NVIDIA GeForce RTX 3070 Laptop GPU, compute capability 8.6


2025-04-26 18:02:30.100678: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:887] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2025-04-26 18:02:30.110138: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:887] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2025-04-26 18:02:30.110172: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:887] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2025-04-26 18:02:30.110469: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:887] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.


In [3]:
#download dataset
dataset = load_dataset("stanfordnlp/sst2")

#download gpt2 tokenizer and manually set pad token (gpt2 doesn't have one by default so we use the end of string token)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

In [4]:
#set max token sequence length and define the tokenizer function
MAX_LEN = 128
def tokenize_fn(examples):
    return tokenizer(examples["sentence"], padding="max_length", truncation=True, max_length=MAX_LEN)

#adjust batch size higher or lower based on available VRAM. 512 has good performance on 8gb
BATCH_SIZE = 512

#create train and valudation datasets based on the ones included in sst2
train_ds = dataset["train"].map(tokenize_fn).to_tf_dataset(
    columns=["input_ids", "attention_mask"],
    label_cols=["label"],
    shuffle=True,
    batch_size=BATCH_SIZE,
)
val_ds = dataset["validation"].map(tokenize_fn).to_tf_dataset(
    columns=["input_ids", "attention_mask"],
    label_cols=["label"],
    shuffle=False,
    batch_size=BATCH_SIZE,
)


Old behaviour: columns=['a'], labels=['labels'] -> (tf.Tensor, tf.Tensor)  
             : columns='a', labels='labels' -> (tf.Tensor, tf.Tensor)  
New behaviour: columns=['a'],labels=['labels'] -> ({'a': tf.Tensor}, {'labels': tf.Tensor})  
             : columns='a', labels='labels' -> (tf.Tensor, tf.Tensor) 
2025-04-26 18:02:33.166296: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:887] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2025-04-26 18:02:33.166371: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:887] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2025-04-26 18:02:33.166393: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:887] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NU

In [5]:
#hyperparameters
VOCAB_SIZE  = tokenizer.vocab_size      #size of tokenizer vocabulary
EMB_DIM     = 128                       #dimension of each token embedding
NUM_HEADS   = 4                         #number of attention heads
FF_DIM      = 512                       #size of feedforward hidden layer inside transformer block
NUM_LAYERS  = 2                         #number of transformer blocks to stack
DROPOUT     = 0.1                       #dropout rate

#inputs to the model
input_ids = tf.keras.Input(shape=(MAX_LEN,), dtype=tf.int32, name="input_ids")         # token ID input
attention_mask = tf.keras.Input(shape=(MAX_LEN,), dtype=tf.int32, name="attention_mask") #attention mask input (1=keep, 0=ignore)

#token and positional embeddings
token_emb = tf.keras.layers.Embedding(VOCAB_SIZE, EMB_DIM, name="token_emb")(input_ids)  #learnable word embeddings
pos_emb = tf.keras.layers.Embedding(MAX_LEN, EMB_DIM, name="pos_emb")(tf.range(MAX_LEN)) #learnable positional embeddings
x = token_emb + pos_emb                                                                   #add token + position info together

#transformer encoder blocks
for _ in range(NUM_LAYERS):
    # multi-head self attention (token attends to all other tokens, masked if necessary)
    attn_output = tf.keras.layers.MultiHeadAttention(
        num_heads=NUM_HEADS,
        key_dim=EMB_DIM // NUM_HEADS,
        dropout=DROPOUT
    )(x, x, attention_mask=tf.expand_dims(attention_mask, axis=1))  #expand mask dims for attention layer

    x = tf.keras.layers.Add()([x, attn_output])                     #residual connection (add input to output)
    x = tf.keras.layers.LayerNormalization(epsilon=1e-6)(x)         #layer normalization for stability

    # feed-forward network (applied to each position independently)
    ff_output = tf.keras.layers.Dense(FF_DIM, activation="gelu")(x) #first dense layer with GELU activation
    ff_output = tf.keras.layers.Dense(EMB_DIM)(ff_output)           #second dense layer projects back to embedding dim

    x = tf.keras.layers.Add()([x, ff_output])                       #residual connection again
    x = tf.keras.layers.LayerNormalization(epsilon=1e-6)(x)         #normalize again after FFN

#classification head
x = tf.keras.layers.GlobalAveragePooling1D()(x)                   #average over all tokens (convert sequence to vector)
x = tf.keras.layers.Dense(128, activation="relu")(x)               #dense hidden layer with ReLU activation
x = tf.keras.layers.Dropout(0.1)(x)                                #dropout for regularization
logits = tf.keras.layers.Dense(2, dtype="float32")(x)              #output logits for 2 classes (before softmax)

# build the model
model = tf.keras.Model(inputs=[input_ids, attention_mask], outputs=logits)
model.summary()


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_ids (InputLayer)      [(None, 128)]                0         []                            
                                                                                                  
 token_emb (Embedding)       (None, 128, 128)             6432896   ['input_ids[0][0]']           
                                                                                                  
 attention_mask (InputLayer  [(None, 128)]                0         []                            
 )                                                                                                
                                                                                                  
 tf.__operators__.add (TFOp  (None, 128, 128)             0         ['token_emb[0][0]']       

In [6]:
#using adam optimizer with a low learning rate for stability
optimizer = tf.keras.optimizers.Adam(5e-5)

#using sparse categorical crossentropy for the loss function
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

#compile the model
model.compile(optimizer=optimizer, loss=loss_fn, metrics=["accuracy"])

#early stopping setup to prevent overfitting 
callbacks = [
    tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",               #watch validation loss 
        patience=3,                       #stop training if the validation loss doesn't improve for 3 epochs
        restore_best_weights=True         #resotre the weights of the best epoch (before it was overfit)
    )
]

#set a max of 20 epochs (early stopping will usually end before we get that far)
EPOCHS = 20

#train the model
model.fit(
    train_ds, #set training dataset
    validation_data=val_ds, #set validation dataset
    epochs=EPOCHS,
    callbacks=callbacks # use the early stopping callback 
)


Epoch 1/20


2025-04-26 18:02:39.723813: I external/local_xla/xla/service/service.cc:168] XLA service 0x7f790c02c920 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-04-26 18:02:39.723888: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA GeForce RTX 3070 Laptop GPU, Compute Capability 8.6
2025-04-26 18:02:39.736728: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2025-04-26 18:02:39.781823: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:467] Loaded cuDNN version 90000
I0000 00:00:1745690559.878976   11714 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20


<keras.src.callbacks.History at 0x7f79dc09a320>

In [7]:
# save the model and tokenizer to file
SAVE_DIR = "./tf_manual_transformer_sst2_finetuned"
model.save(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)

INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79701e60b0>, 140160893004720), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79701e60b0>, 140160893004720), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7970195810>, 140160893002992), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7970195810>, 140160893002992), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79701977f0>, 140160893005488), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79701977f0>, 140160893005488), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796870c1c0>, 140160893005200), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796870c1c0>, 140160893005200), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796870d3f0>, 140160893006640), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796870d3f0>, 140160893006640), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796870ea10>, 140160893006352), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796870ea10>, 140160893006352), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796870d8d0>, 140161360232240), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796870d8d0>, 140161360232240), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968768610>, 140160892768144), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968768610>, 140160892768144), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796876b730>, 140160711538800), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796876b730>, 140160711538800), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79687b07f0>, 140160892646128), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79687b07f0>, 140160892646128), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79687b1810>, 140160711566608), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79687b1810>, 140160711566608), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79687b2890>, 140160711567168), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79687b2890>, 140160711567168), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79687f5990>, 140160892768048), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79687f5990>, 140160892768048), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79687f69b0>, 140160892770928), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79687f69b0>, 140160892770928), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79687f79d0>, 140160892771120), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79687f79d0>, 140160892771120), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968648a30>, 140160892771216), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968648a30>, 140160892771216), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968649a50>, 140160892774768), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968649a50>, 140160892774768), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796864aa70>, 140160892774960), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796864aa70>, 140160892774960), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796864baf0>, 140160711488464), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796864baf0>, 140160711488464), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968688bb0>, 140160711488800), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968688bb0>, 140160711488800), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796868bca0>, 140160711682816), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796868bca0>, 140160711682816), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79686d0d60>, 140160711683376), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79686d0d60>, 140160711683376), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79686d1db0>, 140160711677936), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79686d1db0>, 140160711677936), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79686d2e60>, 140160711756336), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79686d2e60>, 140160711756336), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968519f60>, 140160711765856), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968519f60>, 140160711765856), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796851afe0>, 140160711757056), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796851afe0>, 140160711757056), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968211780>, 140160893004720), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968211780>, 140160893004720), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968212650>, 140160893002992), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968212650>, 140160893002992), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968213640>, 140160893005488), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968213640>, 140160893005488), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79682405e0>, 140160893005200), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79682405e0>, 140160893005200), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968241480>, 140160893006640), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968241480>, 140160893006640), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79682423b0>, 140160893006352), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79682423b0>, 140160893006352), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79682432e0>, 140161360232240), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79682432e0>, 140161360232240), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796828c2e0>, 140160892768144), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796828c2e0>, 140160892768144), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796828d8a0>, 140160711538800), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796828d8a0>, 140160711538800), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796828e800>, 140160892646128), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796828e800>, 140160892646128), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796828f700>, 140160711566608), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796828f700>, 140160711566608), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79682d4700>, 140160711567168), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79682d4700>, 140160711567168), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79682d5cc0>, 140160892768048), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79682d5cc0>, 140160892768048), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79682d6b90>, 140160892770928), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79682d6b90>, 140160892770928), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79682d7ac0>, 140160892771120), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79682d7ac0>, 140160892771120), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968120a30>, 140160892771216), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968120a30>, 140160892771216), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968121a20>, 140160892774768), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968121a20>, 140160892774768), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79681228c0>, 140160892774960), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79681228c0>, 140160892774960), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79681237f0>, 140160711488464), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79681237f0>, 140160711488464), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79681687f0>, 140160711488800), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79681687f0>, 140160711488800), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968169db0>, 140160711682816), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968169db0>, 140160711682816), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796816ad10>, 140160711683376), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796816ad10>, 140160711683376), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796816bc10>, 140160711677936), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796816bc10>, 140160711677936), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79681bcc10>, 140160711756336), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79681bcc10>, 140160711756336), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79681be1d0>, 140160711765856), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79681be1d0>, 140160711765856), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79681bf130>, 140160711757056), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79681bf130>, 140160711757056), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79701e60b0>, 140160893004720), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79701e60b0>, 140160893004720), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7970195810>, 140160893002992), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7970195810>, 140160893002992), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79701977f0>, 140160893005488), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79701977f0>, 140160893005488), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796870c1c0>, 140160893005200), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796870c1c0>, 140160893005200), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796870d3f0>, 140160893006640), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796870d3f0>, 140160893006640), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796870ea10>, 140160893006352), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796870ea10>, 140160893006352), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796870d8d0>, 140161360232240), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796870d8d0>, 140161360232240), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968768610>, 140160892768144), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968768610>, 140160892768144), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796876b730>, 140160711538800), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796876b730>, 140160711538800), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79687b07f0>, 140160892646128), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79687b07f0>, 140160892646128), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79687b1810>, 140160711566608), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79687b1810>, 140160711566608), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79687b2890>, 140160711567168), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79687b2890>, 140160711567168), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79687f5990>, 140160892768048), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79687f5990>, 140160892768048), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79687f69b0>, 140160892770928), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79687f69b0>, 140160892770928), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79687f79d0>, 140160892771120), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79687f79d0>, 140160892771120), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968648a30>, 140160892771216), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968648a30>, 140160892771216), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968649a50>, 140160892774768), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968649a50>, 140160892774768), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796864aa70>, 140160892774960), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796864aa70>, 140160892774960), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796864baf0>, 140160711488464), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796864baf0>, 140160711488464), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968688bb0>, 140160711488800), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968688bb0>, 140160711488800), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796868bca0>, 140160711682816), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796868bca0>, 140160711682816), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79686d0d60>, 140160711683376), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79686d0d60>, 140160711683376), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79686d1db0>, 140160711677936), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79686d1db0>, 140160711677936), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79686d2e60>, 140160711756336), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79686d2e60>, 140160711756336), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968519f60>, 140160711765856), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968519f60>, 140160711765856), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796851afe0>, 140160711757056), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796851afe0>, 140160711757056), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968211780>, 140160893004720), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968211780>, 140160893004720), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968212650>, 140160893002992), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968212650>, 140160893002992), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968213640>, 140160893005488), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968213640>, 140160893005488), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79682405e0>, 140160893005200), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79682405e0>, 140160893005200), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968241480>, 140160893006640), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968241480>, 140160893006640), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79682423b0>, 140160893006352), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79682423b0>, 140160893006352), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79682432e0>, 140161360232240), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79682432e0>, 140161360232240), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796828c2e0>, 140160892768144), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796828c2e0>, 140160892768144), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796828d8a0>, 140160711538800), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796828d8a0>, 140160711538800), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796828e800>, 140160892646128), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796828e800>, 140160892646128), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796828f700>, 140160711566608), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796828f700>, 140160711566608), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79682d4700>, 140160711567168), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79682d4700>, 140160711567168), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79682d5cc0>, 140160892768048), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79682d5cc0>, 140160892768048), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79682d6b90>, 140160892770928), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79682d6b90>, 140160892770928), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79682d7ac0>, 140160892771120), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79682d7ac0>, 140160892771120), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968120a30>, 140160892771216), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968120a30>, 140160892771216), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968121a20>, 140160892774768), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968121a20>, 140160892774768), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79681228c0>, 140160892774960), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79681228c0>, 140160892774960), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79681237f0>, 140160711488464), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(4, 32, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79681237f0>, 140160711488464), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79681687f0>, 140160711488800), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79681687f0>, 140160711488800), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968169db0>, 140160711682816), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 512), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7968169db0>, 140160711682816), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796816ad10>, 140160711683376), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(512,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796816ad10>, 140160711683376), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796816bc10>, 140160711677936), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(None, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f796816bc10>, 140160711677936), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79681bcc10>, 140160711756336), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79681bcc10>, 140160711756336), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79681be1d0>, 140160711765856), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128, 128), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79681be1d0>, 140160711765856), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79681bf130>, 140160711757056), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(128,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f79681bf130>, 140160711757056), {}).


INFO:tensorflow:Assets written to: ./tf_manual_transformer_sst2_finetuned/assets


INFO:tensorflow:Assets written to: ./tf_manual_transformer_sst2_finetuned/assets


('./tf_manual_transformer_sst2_finetuned/tokenizer_config.json',
 './tf_manual_transformer_sst2_finetuned/special_tokens_map.json',
 './tf_manual_transformer_sst2_finetuned/vocab.json',
 './tf_manual_transformer_sst2_finetuned/merges.txt',
 './tf_manual_transformer_sst2_finetuned/added_tokens.json',
 './tf_manual_transformer_sst2_finetuned/tokenizer.json')

In [9]:
#reimport everything in case we want to run from file without training agin
import tensorflow as tf
from transformers import AutoTokenizer

#load saved model and tokenizer
model = tf.keras.models.load_model("./tf_manual_transformer_sst2_finetuned", compile=False)
tokenizer = AutoTokenizer.from_pretrained("./tf_manual_transformer_sst2_finetuned")
tokenizer.pad_token = tokenizer.eos_token

#declare some test data
simple_reviews = [
    "I loved it.",
    "Amazing experience!",
    "Absolutely fantastic.",
    "Best film I've seen this year.",
    "Terrible and boring.",
    "Worst movie ever.",
    "Undeniably awful.",
    "Complete disaster."
]

ambiguous_reviews = [
    "It was okay, not great but not bad either.",
    "Some parts worked, others fell flat.",
    "I liked the idea more than the execution.",
    "Decent performances, weak story.",
    "Beautiful visuals, confusing plot.",
    "Good soundtrack, but I wouldn't watch it again.",
    "Enjoyable moments mixed with long boring scenes.",
    "Neither terrible nor amazing, just a movie.",
]

confusing_positive_reviews = [
    "I hated the theater, but the movie was actually pretty good.",
    "The plot was terrible, yet somehow I found it highly entertaining.",
    "It was a strange film, but the acting made it a fantastic experience.",
    "The first half was slow, but it ended really strong.",
    "Boring to start, but surprisingly emotional by the end.",
    "Terrible pacing, but unforgettable characters.",
    "Some parts were confusing, although I loved the visuals.",
    "I regret going, but I kind of liked the movie.",
]

confusing_negative_reviews = [
    "I loved the director's past work, but this movie was a total joke.",
    "The trailer looked great, but the movie itself was terrible.",
    "I enjoyed the first few minutes, but the rest was boring and predictable.",
    "It had a great cast, but they couldn't save the awful writing.",
    "Stunning visuals couldn't hide the fact that it was a complete mess.",
    "The premise was interesting, but the execution was painfully bad.",
    "I wanted to like it so badly, but it was just too slow and dull.",
    "There were some funny moments, but overall it was a huge waste of time.",
]

# function to run inference on a group of sentences
def run_inference(group_name, sentences):
    print(f"\n{group_name}\n")
    
    #tokenize the sentences
    enc = tokenizer(
        sentences,
        padding="max_length",   #force pad all sentences to max length
        truncation=True,        #truncate if sentence is longer than max length
        return_tensors="tf",    #return tensorflow tensors
        max_length=128
    )

    #forward pass through the model
    logits = model([enc["input_ids"], enc["attention_mask"]])

    #apply softmax to convert logits to probabilities
    probs = tf.nn.softmax(logits, axis=-1).numpy()

    #print predictions
    for sent, p in zip(sentences, probs):
        pred = "Positive" if p[1] > p[0] else "Negative"  #decide based on which probability is higher
        print(f"{sent:70s}  {pred}  (Pos: {p[1]:.1%} / Neg: {p[0]:.1%})")


# run inference on the sentences
run_inference("Simple Reviews", simple_reviews)
run_inference("Ambiguous Reviews", ambiguous_reviews)
run_inference("Confusing Positive Reviews", confusing_positive_reviews)
run_inference("Confusing Negative Reviews", confusing_negative_reviews)



Simple Reviews

I loved it.                                                             Positive  (Pos: 99.4% / Neg: 0.6%)
Amazing experience!                                                     Positive  (Pos: 68.7% / Neg: 31.3%)
Absolutely fantastic.                                                   Positive  (Pos: 100.0% / Neg: 0.0%)
Best film I've seen this year.                                          Positive  (Pos: 98.7% / Neg: 1.3%)
Terrible and boring.                                                    Negative  (Pos: 1.5% / Neg: 98.5%)
Worst movie ever.                                                       Positive  (Pos: 65.8% / Neg: 34.2%)
Undeniably awful.                                                       Negative  (Pos: 21.7% / Neg: 78.3%)
Complete disaster.                                                      Negative  (Pos: 6.1% / Neg: 93.9%)

Ambiguous Reviews

It was okay, not great but not bad either.                              Negative  (Pos: 1.5% / Neg: 98.