In [None]:
# -----------------------------
# Khmer to Romanization Test
# -----------------------------
import numpy as np
import pandas as pd
from tensorflow.keras.models import load_model, Model
from tensorflow.keras.layers import Input

# -----------------------------
# 1. Configuration
# -----------------------------
MODEL_PATH = "s2s.h5"         # Path to your trained model
DATA_INPUT = "csv/data_kh.csv"
DATA_TARGET = "csv/data_rom.csv"

# -----------------------------
# 2. Load model and vocab
# -----------------------------
print("Loading model...")
model = load_model(MODEL_PATH)

print("Reading CSV data...")
# The character vocabularies are not stored in the model file. They are rebuilt
# here from the CSVs, so these files and this exact preprocessing must match
# what the model was trained on: the same character set gives the same sorted
# index order.
# NOTE(review): .astype(str) turns empty CSV cells into the string "nan". This
# is kept as-is so the vocabulary stays identical to training; confirm that
# training did the same.
input_texts = pd.read_csv(DATA_INPUT, header=None)[0].astype(str).tolist()
target_texts = pd.read_csv(DATA_TARGET, header=None)[0].astype(str).tolist()
if len(input_texts) != len(target_texts):
    # KH and ROM rows are paired by position later (reference lookup), so a
    # row-count mismatch means those pairs are likely misaligned.
    print(
        f"Warning: {DATA_INPUT} has {len(input_texts)} rows but "
        f"{DATA_TARGET} has {len(target_texts)}; reference pairs may be misaligned."
    )
# "\t" is the decoder's start-of-sequence marker and "\n" its end-of-sequence
# marker; decode_sequence seeds with "\t" and stops on "\n".
target_texts = ["\t" + t.strip() + "\n" for t in target_texts]

input_characters = sorted(set("".join(input_texts)))
target_characters = sorted(set("".join(target_texts)))

num_encoder_tokens = len(input_characters)
num_decoder_tokens = len(target_characters)
# Longest sequences in the data: the encoder's fixed buffer length and the
# decoder's generation cap.
max_encoder_seq_length = max(len(txt) for txt in input_texts)
max_decoder_seq_length = max(len(txt) for txt in target_texts)

# char -> one-hot column, plus the reverse map used to turn predictions back
# into characters.
input_token_index = {char: i for i, char in enumerate(input_characters)}
target_token_index = {char: i for i, char in enumerate(target_characters)}
reverse_target_char_index = {i: char for char, i in target_token_index.items()}

print(f"KH vocab size: {num_encoder_tokens}")
print(f"ROM vocab size: {num_decoder_tokens}")


def _check_vocab_width(label, model_input, vocab_size):
    """Raise ValueError if a model input's one-hot width differs from vocab_size.

    A feature dimension of ``None`` (unknown width) is not checked.
    """
    expected = model_input.shape[-1]
    if expected is not None and expected != vocab_size:
        raise ValueError(
            f"{label} vocabulary rebuilt from CSV has {vocab_size} characters, "
            f"but the model expects {expected}. The CSV data does not match "
            f"the data the model was trained on."
        )


# Fail fast on a vocab/model mismatch. Otherwise it only surfaces later as an
# opaque shape error inside predict().
_check_vocab_width("KH (encoder)", model.input[0], num_encoder_tokens)
_check_vocab_width("ROM (decoder)", model.input[1], num_decoder_tokens)

# -----------------------------
# 3. Rebuild inference models
# -----------------------------
# Split the trained teacher-forcing model into two pieces that reuse its
# trained layers:
#   - a standalone encoder model;
#   - a decoder model that can be stepped one character at a time.
# NOTE(review): the layer indices below assume the training graph was built in
# the order [encoder Input, decoder Input, encoder LSTM, decoder LSTM, Dense],
# as in the Keras lstm_seq2seq example. Confirm with model.summary().

# Encoder
encoder_inputs = model.input[0]
# Assumes layers[2] is the encoder LSTM built with return_state=True, so its
# output unpacks to (sequence output, final h, final c).
encoder_outputs, state_h_enc, state_c_enc = model.layers[2].output
encoder_states = [state_h_enc, state_c_enc]
# Maps a one-hot encoded input sequence to its final [h, c] state.
encoder_model = Model(encoder_inputs, encoder_states)

# Decoder
decoder_inputs = model.input[1]
# Assumed: layers[3] is the decoder LSTM and layers[4] the output Dense layer
# (TODO confirm).
decoder_lstm = model.layers[3]
decoder_dense = model.layers[4]

# New placeholders for the [h, c] state fed back in at every decoding step.
# Their width is taken from the encoder's state tensors.
decoder_state_input_h = Input(shape=(state_h_enc.shape[1],), name="decoder_input_h")
decoder_state_input_c = Input(shape=(state_c_enc.shape[1],), name="decoder_input_c")
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]

# Calling the existing LSTM/Dense layer objects on new inputs reuses their
# trained weights. This builds a graph whose initial state comes from the
# placeholders above instead of the encoder.
decoder_outputs, state_h, state_c = decoder_lstm(
    decoder_inputs, initial_state=decoder_states_inputs
)
decoder_outputs = decoder_dense(decoder_outputs)
decoder_states = [state_h, state_c]

# Inputs:  [one-hot decoder input, h, c]
# Outputs: [Dense output scores, new h, new c]
decoder_model = Model(
    [decoder_inputs] + decoder_states_inputs,
    [decoder_outputs] + decoder_states
)

# -----------------------------
# 4. Helper functions
# -----------------------------
def encode_input_text(text, *, max_len=None, token_index=None, num_tokens=None):
    """One-hot encode a string into the encoder model's input layout.

    Args:
        text: Input string; each character becomes one timestep.
        max_len: Number of timesteps in the output. Defaults to the
            module-level ``max_encoder_seq_length``.
        token_index: Mapping of character -> one-hot column. Defaults to the
            module-level ``input_token_index``.
        num_tokens: Width of the one-hot axis. Defaults to the module-level
            ``num_encoder_tokens``.

    Returns:
        A float32 array of shape ``(1, max_len, num_tokens)``.
        - Characters past ``max_len`` are truncated.
        - Characters missing from ``token_index`` leave an all-zero timestep
          in their position.
        - Timesteps after the end of ``text`` stay all-zero.
          NOTE(review): if training padded with a space token instead, results
          for short inputs may differ; confirm.
    """
    if max_len is None:
        max_len = max_encoder_seq_length
    if token_index is None:
        token_index = input_token_index
    if num_tokens is None:
        num_tokens = num_encoder_tokens

    x = np.zeros((1, max_len, num_tokens), dtype="float32")
    # Truncate to the fixed-length buffer. Indexing past it raises IndexError,
    # which would abort the interactive loop on long input.
    for t, char in enumerate(text[:max_len]):
        idx = token_index.get(char)
        if idx is not None:
            x[0, t, idx] = 1.0
    return x

def decode_sequence(input_seq):
    """Greedily decode one encoded input sequence into a romanized string.

    Decoding steps:
    1. Run the encoder once to get the initial [h, c] state.
    2. Seed the decoder with the "\\t" start token.
    3. Repeatedly feed back the arg-max character and its state.
    4. Stop when "\\n" is produced or the output grows longer than
       ``max_decoder_seq_length``.

    Args:
        input_seq: Encoder input, as produced by ``encode_input_text``.

    Returns:
        The decoded characters with surrounding whitespace (including the
        trailing "\\n") stripped.
    """
    states = encoder_model.predict(input_seq, verbose=0)

    def one_hot_step(token_id):
        # A single-timestep, batch-of-one one-hot decoder input.
        step = np.zeros((1, 1, num_decoder_tokens))
        step[0, 0, token_id] = 1.0
        return step

    current = one_hot_step(target_token_index["\t"])
    pieces = []

    while True:
        scores, h, c = decoder_model.predict([current] + states, verbose=0)
        best = np.argmax(scores[0, -1, :])
        char = reverse_target_char_index[best]
        pieces.append(char)

        # Each piece is one character, so len(pieces) is the output length.
        if char == "\n" or len(pieces) > max_decoder_seq_length:
            break

        current = one_hot_step(best)
        states = [h, c]

    return "".join(pieces).strip()

# -----------------------------
# 5. Test with input
# -----------------------------
# Optional: map each KH string in the CSV to its ROM reference, so predictions
# can be compared against the data.
# - strip() on the ROM side removes the "\t"/"\n" markers added to target_texts.
# - Duplicate KH entries keep the last row's value.
# - zip() stops at the shorter of the two lists.
reference_dict = {kh.strip(): rom.strip() for kh, rom in zip(input_texts, target_texts)}

print("\n--- Khmer Romanization Test Ready ---")

# Example interactive testing
while True:
    try:
        kh_text = input("\nEnter Khmer word (or 'q' to quit): ").strip()
    except (EOFError, KeyboardInterrupt):
        # Closed stdin / Ctrl-D or an interrupt: leave the loop cleanly instead
        # of ending with a traceback.
        print()
        break
    if kh_text.lower() == "q":
        break
    if not kh_text:
        continue

    # Encode and predict
    input_seq = encode_input_text(kh_text)
    predicted_rom = decode_sequence(input_seq)

    # Reference from CSV (if exists)
    reference_rom = reference_dict.get(kh_text, "N/A")

    print(f"KH: {kh_text}, Roman (reference): {reference_rom}, Predicted: {predicted_rom}")
