In [None]:
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.model_selection import train_test_split
import re
from collections import defaultdict, Counter


In [None]:
sent_data_path = 'data/zh_sent_dataset.tsv'
t9_data_path = 'data/zh_T9_dataset.tsv'

sentences = pd.read_csv(sent_data_path, sep="\t", header=None, names=["sentence"])
codes = pd.read_csv(t9_data_path, sep="\t", header=None, names=["code", "char"])

# Build Nine-Key Code Mappings
code2chars = defaultdict(list)
char2code = {}

for _, row in codes.iterrows():
    code2chars[row.code].append(row.char)
    char2code[row.char] = row.code

# Build Training Samples
# For each character in sentence, use previous text as context and current code as input

samples = []

window_size = 20  # Limit context length (in characters)

for sentence in sentences["sentence"]:
    sentence = re.sub(r"[^\u4e00-\u9fa5]", "", sentence)  # Remove non-Chinese characters
    for i in range(len(sentence)):
        char = sentence[i]
        code = char2code.get(char)
        if code is None:
            continue
        context = sentence[max(0, i - window_size):i]
        samples.append((context, code, char))

print(f"Total samples: {len(samples)}")


Total samples: 1301608


In [12]:

# Build Vocabulary & Vectorization
all_chars = sorted(set(char2code.keys()))
char2idx = {c: i + 1 for i, c in enumerate(all_chars)}  # 0 用作 padding
idx2char = {i: c for c, i in char2idx.items()}

code_set = sorted(code2chars.keys())
code2idx = {c: i + 1 for i, c in enumerate(code_set)}  # 0 为 padding

max_context_len = window_size

def encode_context(text):
    return [char2idx.get(c, 0) for c in text][-max_context_len:]

def encode_code(code):
    return code2idx.get(code, 0)

X_context = []
X_code = []
Y_char = []

for ctx, code, char in samples:
    X_context.append(encode_context(ctx))
    X_code.append(encode_code(code))
    Y_char.append(char2idx[char])

# Padding
X_context = keras.preprocessing.sequence.pad_sequences(X_context, maxlen=max_context_len, padding='pre')
X_code = np.array(X_code)
Y_char = np.array(Y_char)

X_train_ctx, X_temp_ctx, X_train_code, X_temp_code, y_train, y_temp = train_test_split(
    X_context, X_code, Y_char, test_size=0.2, random_state=42)

X_val_ctx, X_test_ctx, X_val_code, X_test_code, y_val, y_test = train_test_split(
    X_temp_ctx, X_temp_code, y_temp, test_size=0.5, random_state=42)


In [None]:
vocab_size = len(char2idx) + 1
code_vocab_size = len(code2idx) + 1
embedding_dim = 64

ctx_input = keras.Input(shape=(max_context_len,), name="context_input")
ctx_emb = layers.Embedding(input_dim=vocab_size, output_dim=embedding_dim, mask_zero=False)(ctx_input)
ctx_emb = layers.Masking(mask_value=0.0)(ctx_emb)
ctx_encoded = layers.Bidirectional(
layers.LSTM(64, recurrent_activation="sigmoid")
)(ctx_emb)

code_input = keras.Input(shape=(), dtype=tf.int32, name="code_input")
code_emb = layers.Embedding(input_dim=code_vocab_size, output_dim=32)(code_input)
code_encoded = layers.Flatten()(code_emb)

merged = layers.concatenate([ctx_encoded, code_encoded])
hidden = layers.Dense(128, activation="relu")(merged)
output = layers.Dense(vocab_size, activation="softmax")(hidden)

model = keras.Model(inputs=[ctx_input, code_input], outputs=output)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
model.summary()


SyntaxError: invalid character '“' (U+201C) (1905046101.py, line 9)

In [14]:
callbacks = [
    keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True)
]

history = model.fit(
    {"context_input": X_train_ctx, "code_input": X_train_code},
    y_train,
    validation_data=(
        {"context_input": X_val_ctx, "code_input": X_val_code},
        y_val
    ),
    epochs=15,
    batch_size=256,
    callbacks=callbacks
)


Epoch 1/15


2025-05-30 02:13:10.593267: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: INVALID_ARGUMENT: assertion failed: [You are passing a RNN mask that does not correspond to right-padded sequences, while using cuDNN, which is not supported. With cuDNN, RNN masks can only be used for right-padding, e.g. `[[True, True, False, False]]` would be a valid mask, but any mask that isn\'t just contiguous `True`\'s on the left and contiguous `False`\'s on the right would be invalid. You can pass `use_cudnn=False` to your RNN layer to stop using cuDNN (this may be slower).]
	 [[{{function_node __inference_one_step_on_data_13331}}{{node functional_3_1/bidirectional_3_1/forward_lstm_3_1/Assert/Assert}}]]
2025-05-30 02:13:10.593289: I tensorflow/core/framework/local_rendezvous.cc:428] Local rendezvous send item cancelled. Key hash: 8248750667461394598


InvalidArgumentError: Graph execution error:

Detected at node functional_3_1/bidirectional_3_1/forward_lstm_3_1/Assert/Assert defined at (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main

  File "<frozen runpy>", line 88, in _run_code

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/ipykernel_launcher.py", line 18, in <module>

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/traitlets/config/application.py", line 1075, in launch_instance

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 739, in start

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tornado/platform/asyncio.py", line 205, in start

  File "/Users/shen/.pyenv/versions/3.11.10/lib/python3.11/asyncio/base_events.py", line 608, in run_forever

  File "/Users/shen/.pyenv/versions/3.11.10/lib/python3.11/asyncio/base_events.py", line 1936, in _run_once

  File "/Users/shen/.pyenv/versions/3.11.10/lib/python3.11/asyncio/events.py", line 84, in _run

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 534, in process_one

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 362, in execute_request

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 778, in execute_request

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 449, in do_execute

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/ipykernel/zmqshell.py", line 549, in run_cell

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3075, in run_cell

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3130, in _run_cell

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 128, in _pseudo_sync_runner

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3334, in run_cell_async

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3517, in run_ast_nodes

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code

  File "/var/folders/s9/j56343l57_s2lw7_47rw2bpr0000gn/T/ipykernel_78109/13945941.py", line 5, in <module>

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py", line 371, in fit

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py", line 219, in function

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 833, in __call__

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 889, in _call

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 696, in _initialize

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 178, in trace_function

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 283, in _maybe_define_function

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 310, in _create_concrete_function

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/framework/func_graph.py", line 1059, in func_graph_from_py_func

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 599, in wrapped_fn

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py", line 41, in autograph_handler

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/autograph/impl/api.py", line 339, in converted_call

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/autograph/impl/api.py", line 459, in _call_unconverted

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/autograph/impl/api.py", line 643, in wrapper

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py", line 132, in multi_step_on_iterator

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 833, in __call__

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 906, in _call

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 132, in call_function

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 178, in trace_function

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 283, in _maybe_define_function

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 310, in _create_concrete_function

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/framework/func_graph.py", line 1059, in func_graph_from_py_func

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 599, in wrapped_fn

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py", line 41, in autograph_handler

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/autograph/impl/api.py", line 331, in converted_call

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/autograph/impl/api.py", line 459, in _call_unconverted

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/autograph/impl/api.py", line 643, in wrapper

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py", line 113, in one_step_on_data

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/distribute/distribute_lib.py", line 1673, in run

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/distribute/distribute_lib.py", line 3263, in call_for_each_replica

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/distribute/distribute_lib.py", line 4061, in _call_for_each_replica

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/autograph/impl/api.py", line 643, in wrapper

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py", line 57, in train_step

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/layers/layer.py", line 908, in __call__

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/ops/operation.py", line 46, in __call__

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 156, in error_handler

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/models/functional.py", line 182, in call

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/ops/function.py", line 171, in _run_through_graph

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/models/functional.py", line 637, in call

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/layers/layer.py", line 908, in __call__

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/ops/operation.py", line 46, in __call__

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 156, in error_handler

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/layers/rnn/bidirectional.py", line 218, in call

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/layers/layer.py", line 908, in __call__

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/ops/operation.py", line 46, in __call__

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 156, in error_handler

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/layers/rnn/lstm.py", line 584, in call

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/layers/rnn/rnn.py", line 402, in call

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/layers/rnn/lstm.py", line 551, in inner_loop

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/backend/tensorflow/rnn.py", line 841, in lstm

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/backend/tensorflow/rnn.py", line 874, in _cudnn_lstm

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/keras/src/backend/tensorflow/rnn.py", line 557, in _assert_valid_mask

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/util/dispatch.py", line 1260, in op_dispatch_handler

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/util/tf_should_use.py", line 288, in wrapped

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/ops/control_flow_assert.py", line 115, in Assert

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/ops/gen_logging_ops.py", line 62, in _assert

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/framework/op_def_library.py", line 796, in _apply_op_helper

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/framework/func_graph.py", line 670, in _create_op_internal

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/framework/ops.py", line 2701, in _create_op_internal

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/framework/ops.py", line 1196, in from_node_def

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/framework/ops.py", line 1062, in _create_c_op

  File "/Users/shen/.pyenv/versions/3.11.10/envs/keras/lib/python3.11/site-packages/tensorflow/python/util/tf_stack.py", line 162, in extract_stack

assertion failed: [You are passing a RNN mask that does not correspond to right-padded sequences, while using cuDNN, which is not supported. With cuDNN, RNN masks can only be used for right-padding, e.g. `[[True, True, False, False]]` would be a valid mask, but any mask that isn\'t just contiguous `True`\'s on the left and contiguous `False`\'s on the right would be invalid. You can pass `use_cudnn=False` to your RNN layer to stop using cuDNN (this may be slower).]
	 [[{{node functional_3_1/bidirectional_3_1/forward_lstm_3_1/Assert/Assert}}]] [Op:__inference_multi_step_on_iterator_13420]

In [None]:
test_loss, test_acc = model.evaluate(
    {"context_input": X_test_ctx, "code_input": X_test_code},
    y_test
)
print(f"Test accuracy: {test_acc:.4f}")


In [None]:
def predict_next(code=None, context="", topk=5):
    ctx_enc = encode_context(context)
    code_enc = encode_code(code) if code else 0
    ctx_pad = keras.preprocessing.sequence.pad_sequences([ctx_enc], maxlen=max_context_len)
    pred = model.predict({"context_input": ctx_pad, "code_input": np.array([code_enc])}, verbose=0)[0]
    if code:
        possible_chars = code2chars[code]
        possible_ids = [char2idx[c] for c in possible_chars if c in char2idx]
        filtered = [(i, pred[i]) for i in possible_ids]
    else:
        filtered = list(enumerate(pred))

    filtered = sorted(filtered, key=lambda x: x[1], reverse=True)
    return [(idx2char[i], score) for i, score in filtered[:topk] if i in idx2char]


In [None]:
print("Input code=7426, context='我想知道'")
print(predict_next(code='7426', context='我想知道'))

print("Input code='', context='价格'")
print(predict_next(code='', context='价格'))

print("Input code='2878', context=''")
print(predict_next(code='2878', context=''))