In [None]:
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.model_selection import train_test_split
import re
from collections import defaultdict, Counter


In [None]:
sent_data_path = 'data/zh_sent_dataset.tsv'
t9_data_path = 'data/zh_T9_dataset.tsv'

# Read every column as text. Without dtype=str pandas infers an all-digit
# `code` column as integers, and later string lookups such as
# predict_next(code='7426') would then never match a key of code2chars/code2idx.
sentences = pd.read_csv(sent_data_path, sep="\t", header=None, names=["sentence"], dtype=str)
codes = pd.read_csv(t9_data_path, sep="\t", header=None, names=["code", "char"], dtype=str)

# Build Nine-Key Code Mappings
#   code2chars: code -> every character listed under that code
#   char2code:  character -> code (a character listed under several codes
#               keeps only the last one that appears in the file)
code2chars = defaultdict(list)
char2code = {}

# Rows with a missing code or character cannot contribute a mapping.
valid_codes = codes.dropna(subset=["code", "char"])
for code, char in zip(valid_codes["code"], valid_codes["char"]):
    code2chars[code].append(char)
    char2code[char] = code

# Build Training Samples
# For each character in a sentence, use the preceding text as context and the
# character's code as input; the character itself is the prediction target.

samples = []

window_size = 20  # Limit context length (in characters)

# Compiled once instead of on every sentence.
non_chinese = re.compile(r"[^\u4e00-\u9fa5]")

# dropna(): a missing sentence is read as NaN (a float) and would crash the regex.
for sentence in sentences["sentence"].dropna():
    sentence = non_chinese.sub("", sentence)  # Remove non-Chinese characters
    for i, char in enumerate(sentence):
        code = char2code.get(char)
        if code is None:
            # Character has no nine-key code: it cannot be a target, but it
            # still remains part of the context of the characters after it.
            continue
        context = sentence[max(0, i - window_size):i]
        samples.append((context, code, char))

print(f"Total samples: {len(samples)}")


Total samples: 1301608


In [12]:

# Build Vocabulary & Vectorization
# Index 0 is reserved for padding in both vocabularies, so real ids start at 1.
all_chars = sorted(char2code)
idx2char = dict(enumerate(all_chars, start=1))
char2idx = {c: i for i, c in idx2char.items()}

code_set = sorted(code2chars)
code2idx = {c: i for i, c in enumerate(code_set, start=1)}

# The encoded context is exactly as long as the sampling window.
max_context_len = window_size

def encode_context(text):
    """Encode the trailing ``max_context_len`` characters of ``text`` as vocabulary ids.

    Characters that are not in ``char2idx`` are encoded as 0, the same id
    that is used for padding.
    """
    recent = text[-max_context_len:]
    return [char2idx.get(ch, 0) for ch in recent]

def encode_code(code):
    """Return the integer id of a nine-key ``code``, or 0 if it is not in ``code2idx``."""
    try:
        return code2idx[code]
    except KeyError:
        return 0

# Encode every (context, code, char) sample into integer ids.
# Contexts are left-padded with 0 up to max_context_len.
X_context = keras.preprocessing.sequence.pad_sequences(
    [encode_context(ctx) for ctx, _code, _char in samples],
    maxlen=max_context_len,
    padding='pre',
)
X_code = np.array([encode_code(code) for _ctx, code, _char in samples])
Y_char = np.array([char2idx[char] for _ctx, _code, char in samples])

# 80 / 10 / 10 split: hold out 20% first, then cut the hold-out in half.
# NOTE(review): the split is per sample, so overlapping contexts taken from the
# same sentence can end up in different splits - confirm this is intended.
(
    X_train_ctx, X_temp_ctx,
    X_train_code, X_temp_code,
    y_train, y_temp,
) = train_test_split(X_context, X_code, Y_char, test_size=0.2, random_state=42)

(
    X_val_ctx, X_test_ctx,
    X_val_code, X_test_code,
    y_val, y_test,
) = train_test_split(X_temp_ctx, X_temp_code, y_temp, test_size=0.5, random_state=42)


In [17]:
# Seed the Python, NumPy and backend RNGs so weight initialisation is
# repeatable across runs (42 matches the seed used for the data splits).
keras.utils.set_random_seed(42)

vocab_size = len(char2idx) + 1        # +1 because id 0 is reserved
code_vocab_size = len(code2idx) + 1   # +1 because id 0 is reserved
embedding_dim = 64

# --- Context branch: preceding characters -> BiLSTM summary vector ---------
ctx_input = keras.Input(shape=(max_context_len,), name="context_input")
# Id 0 (padding / unknown character) is fed to the LSTM as an ordinary token
# with its own learned embedding. The Masking(mask_value=0.0) layer that used
# to follow this embedding was removed because it never masked anything: it
# only skips timesteps whose whole embedding vector is exactly 0.0, and the
# trainable embedding row for id 0 is not.
# NOTE(review): mask_zero=True would really mask the padding, but contexts are
# left-padded and can be completely empty - confirm the LSTM implementation in
# use accepts such masks before enabling it.
ctx_emb = layers.Embedding(
    input_dim=vocab_size, output_dim=embedding_dim, mask_zero=False
)(ctx_input)
ctx_encoded = layers.Bidirectional(
    layers.LSTM(64, recurrent_activation="sigmoid")
)(ctx_emb)

# --- Code branch: one nine-key code id per sample -> embedding vector ------
code_input = keras.Input(shape=(), dtype=tf.int32, name="code_input")
code_emb = layers.Embedding(input_dim=code_vocab_size, output_dim=32)(code_input)
code_encoded = layers.Flatten()(code_emb)

# --- Head: merge both branches and score every character id ----------------
merged = layers.concatenate([ctx_encoded, code_encoded])
hidden = layers.Dense(128, activation="relu")(merged)
output = layers.Dense(vocab_size, activation="softmax")(hidden)

model = keras.Model(inputs=[ctx_input, code_input], outputs=output)
# Targets are integer character ids, hence the sparse cross-entropy loss.
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
model.summary()


In [18]:
# Stop after 3 epochs without improvement (EarlyStopping's default monitor)
# and roll the model back to the best weights seen.
early_stopping = keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True)
callbacks = [early_stopping]

# Inputs are keyed by the names given to the model's Input layers.
train_inputs = {"context_input": X_train_ctx, "code_input": X_train_code}
val_inputs = {"context_input": X_val_ctx, "code_input": X_val_code}

history = model.fit(
    train_inputs,
    y_train,
    validation_data=(val_inputs, y_val),
    epochs=15,
    batch_size=256,
    callbacks=callbacks,
)


Epoch 1/15
[1m 329/4068[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m13:42[0m 220ms/step - accuracy: 0.0890 - loss: 7.5059

KeyboardInterrupt: 

In [None]:
# Score the trained model on the held-out test split.
test_inputs = {"context_input": X_test_ctx, "code_input": X_test_code}
test_loss, test_acc = model.evaluate(test_inputs, y_test)
print(f"Test accuracy: {test_acc:.4f}")


In [None]:
def predict_next(code=None, context="", topk=5):
    """Rank candidate next characters for a nine-key code and preceding text.

    Args:
        code: Nine-key code of the character being typed. When falsy
            (``None`` or ``""``) the model is queried with code id 0 and
            every character in the vocabulary is a candidate.
        context: Text preceding the character; only its last
            ``max_context_len`` characters are used.
        topk: Maximum number of candidates to return.

    Returns:
        A list of ``(char, score)`` pairs sorted by descending model score.
        The list is empty when ``code`` is given but has no known characters.
    """
    ctx_enc = encode_context(context)
    code_enc = encode_code(code) if code else 0
    ctx_pad = keras.preprocessing.sequence.pad_sequences([ctx_enc], maxlen=max_context_len)
    pred = model.predict(
        {"context_input": ctx_pad, "code_input": np.array([code_enc])},
        verbose=0,
    )[0]

    if code:
        # .get() instead of indexing: code2chars is a defaultdict, so indexing
        # would silently insert an empty entry for every unknown code queried.
        # dict.fromkeys() drops duplicate candidates while keeping their order.
        candidate_ids = list(dict.fromkeys(
            char2idx[c] for c in code2chars.get(code, ()) if c in char2idx
        ))
    else:
        # Only real character ids. Ranking the padding id 0 as well and
        # discarding it after the top-k cut could return fewer than topk items.
        candidate_ids = list(idx2char)

    ranked = sorted(candidate_ids, key=lambda i: pred[i], reverse=True)
    return [(idx2char[i], pred[i]) for i in ranked[:topk]]


In [None]:
# Smoke-test the predictor: code + context, context only, and code only.
demo_cases = [
    ("Input code=7426, context='我想知道'", {"code": '7426', "context": '我想知道'}),
    ("Input code='', context='价格'", {"code": '', "context": '价格'}),
    ("Input code='2878', context=''", {"code": '2878', "context": ''}),
]

for label, kwargs in demo_cases:
    print(label)
    print(predict_next(**kwargs))