In [None]:
import tensorflow as tf
import tensorflow_text as tf_text
import os
import numpy as np

# Paths
MODEL_PATH = "/mnt/c/Users/Vedank/Documents/Transformer/TRANS_BASE_EK/saved_keras"
CONTEXT_VOCAB_PATH = "/mnt/c/Users/Vedank/Documents/Transformer/column1.txt.vocab"
TARGET_VOCAB_PATH = "/mnt/c/Users/Vedank/Documents/Transformer/column2.txt.vocab"


def _require_vocab_file(vocab_path, label):
    """Raise FileNotFoundError when ``vocab_path`` does not exist on disk.

    ``label`` ("Context" / "Target") is only used to build the error message.
    """
    if os.path.exists(vocab_path):
        return
    raise FileNotFoundError(f"{label} vocabulary file not found at {vocab_path}")


# Check files (context first, then target) before building the tokenizers.
_require_vocab_file(CONTEXT_VOCAB_PATH, "Context")
_require_vocab_file(TARGET_VOCAB_PATH, "Target")

# Load tokenizers: one lower-casing BertTokenizer per vocabulary file.
context_tokenizer, target_tokenizer = (
    tf_text.BertTokenizer(vocab_file, lower_case=True)
    for vocab_file in (CONTEXT_VOCAB_PATH, TARGET_VOCAB_PATH)
)

# Load the SavedModel
print(f"Loading the model from: {MODEL_PATH}")
model = tf.saved_model.load(MODEL_PATH)

# Define tokenization function
def tokenize_context(text, max_length=32):
    """Turn ``text`` into a fixed-width float32 tensor of context token ids.

    The output of ``context_tokenizer.tokenize`` has its last two dimensions
    merged, is converted to a dense tensor of shape ``[1, max_length]`` using
    0 as the fill value, and is finally cast to ``tf.float32``.
    """
    ragged_ids = context_tokenizer.tokenize(text)
    merged_ids = ragged_ids.merge_dims(-2, -1)
    dense_ids = merged_ids.to_tensor(default_value=0, shape=[1, max_length])
    # The serving signature is fed float32 ids, hence the cast.
    return tf.cast(dense_ids, dtype=tf.float32)

# Define autoregressive translation
def translate_text(input_text, max_length=32, end_token_id=3):
    """Greedily decode a translation of ``input_text`` with the loaded model.

    The context is tokenized once; the target sequence starts from the first
    id produced by tokenizing ``"<start>"`` and grows by one arg-max token per
    step until ``end_token_id`` is predicted or ``max_length`` steps have run.

    Args:
        input_text: Source sentence to translate.
        max_length: Width both model inputs are padded to, and the maximum
            number of decoding steps.
        end_token_id: Id that terminates decoding. Defaults to 3, which this
            script assumes is the ``<end>`` token — TODO confirm against the
            target vocabulary file.

    Returns:
        The decoded text as a ``str``; an empty string when the very first
        predicted token is ``end_token_id``.
    """
    context_tokens = tokenize_context(input_text, max_length)

    # NOTE(review): a BertTokenizer may split "<start>" into several
    # wordpieces; ``shape=[1, 1]`` keeps only the first id. Confirm this is
    # the start id the model was trained with.
    start_ids = target_tokenizer.tokenize("<start>").merge_dims(-2, -1)
    target_tokens = tf.cast(start_ids.to_tensor(shape=[1, 1]), tf.float32)

    serve = model.signatures['serve']  # loop-invariant lookup
    output_ids = []

    for step in range(max_length):
        # target_tokens always holds exactly ``step + 1`` ids at this point
        # (one seed id plus one id appended per completed step), so the
        # Python loop index replaces the repeated tf.shape() queries.
        current_length = step + 1
        padded_target = tf.pad(
            target_tokens,
            [[0, 0], [0, max_length - current_length]],
            constant_values=0,
        )
        output_dict = serve(args_0=context_tokens, args_0_1=padded_target)
        logits = output_dict['output_0']

        # Logits for the position of the most recently added target token.
        next_token_logits = logits[:, step, :]
        next_token = tf.argmax(next_token_logits, axis=-1, output_type=tf.int32)

        # The recorded run failed in tf.concat with shapes [1,1] vs [1,1,1],
        # i.e. ``next_token`` was already rank 2 and ``next_token[None]`` made
        # it rank 3. Flatten to a plain Python int so the code no longer
        # depends on the exact rank of the model output.
        next_id = int(tf.reshape(next_token, [-1])[0])

        if next_id == end_token_id:
            break

        output_ids.append(next_id)
        # Rebuild the new id with an explicit (1, 1) shape before appending.
        next_token_2d = tf.constant([[next_id]], dtype=tf.float32)
        target_tokens = tf.concat([target_tokens, next_token_2d], axis=1)

    if not output_ids:
        # Nothing was generated; avoid detokenizing an empty tensor.
        return ""

    # Detokenize as a batch of one sequence, then join the resulting words
    # with spaces into a single string.
    words = target_tokenizer.detokenize(tf.constant([output_ids], dtype=tf.int32))
    sentence = tf.strings.reduce_join(words, separator=" ", axis=-1)
    return sentence.numpy()[0].decode("utf-8")

# Test the model
if __name__ == "__main__":
    # Reaching this point means the module-level model load did not raise.
    print("Model loaded successfully!")

    # Sample inputs for a quick smoke test.
    test_sentences = ["How are you?", "What is your name?", "I am happy today."]

    print("\nTranslating example sentences:")
    for source_sentence in test_sentences:
        # A failure on one sentence is reported and the loop moves on to the
        # next one instead of aborting the whole run.
        try:
            translated = translate_text(source_sentence)
            print(f"Input: {source_sentence}")
            print(f"Output: {translated}\n")
        except Exception as err:
            print(f"Error translating '{source_sentence}': {err}\n")

Loading the model from: /mnt/c/Users/Sohum/Documents/Transformer/TRANS_BASE_EK/saved_keras
Model loaded successfully!

Translating example sentences:
Error translating 'How are you?': {{function_node __wrapped__ConcatV2_N_2_device_/job:localhost/replica:0/task:0/device:GPU:0}} ConcatOp : Ranks of all input tensors should match: shape[0] = [1,1] vs. shape[1] = [1,1,1] [Op:ConcatV2] name: concat

Error translating 'What is your name?': {{function_node __wrapped__ConcatV2_N_2_device_/job:localhost/replica:0/task:0/device:GPU:0}} ConcatOp : Ranks of all input tensors should match: shape[0] = [1,1] vs. shape[1] = [1,1,1] [Op:ConcatV2] name: concat

Error translating 'I am happy today.': {{function_node __wrapped__ConcatV2_N_2_device_/job:localhost/replica:0/task:0/device:GPU:0}} ConcatOp : Ranks of all input tensors should match: shape[0] = [1,1] vs. shape[1] = [1,1,1] [Op:ConcatV2] name: concat

