# Context-Aware Kana-Kanji Converter

**Enhanced**: Uses `<SEP>` token to mark context boundary

**Input Format**: `context<SEP>kana`
- Before `<SEP>`: context — text the user has already converted (mixed kanji and kana)
- After `<SEP>`: the kana still to be converted (hiragana)

**Example**:
```
Input:  写真を<SEP>とる
Output: 撮る
```

In [None]:
from google.colab import drive
import os

# Artifact locations on Google Drive: every file this notebook writes
# (checkpoints, vocab, config, TFLite exports) goes under MODEL_DIR.
DRIVE_DIR = '/content/drive/MyDrive/Keyboard-Suggestions-ML-Colab'
MODEL_DIR = '/'.join([DRIVE_DIR, 'models', 'gru_japanese_kana_kanji'])

drive.mount('/content/drive')
os.makedirs(MODEL_DIR, exist_ok=True)

In [None]:
# Install the notebook's Python dependencies. `!` is an IPython shell escape,
# so this line only works inside Jupyter/Colab, not in a plain .py script.
# NOTE(review): matplotlib (used for the training plot) is not listed here —
# presumably preinstalled on Colab; add it when running elsewhere.
!pip install -q tensorflow keras datasets numpy tqdm

In [None]:
# Run-size switch: TESTING_MODE trades dataset size and epochs for speed.
TESTING_MODE = True

# (epochs, batch size, max sentences loaded) for each mode
NUM_EPOCHS, BATCH_SIZE, MAX_SAMPLES = (
    (10, 256, 400000) if TESTING_MODE else (20, 256, 800000)
)

# Sequence / model dimensions
CHAR_VOCAB_SIZE = 3000
MAX_INPUT_LEN = 40    # context + <SEP> + kana
MAX_OUTPUT_LEN = 20   # kanji output
EMBEDDING_DIM = 64
GRU_UNITS = 128

# Special tokens; their list order becomes their vocabulary index (<PAD> first).
SEP_TOKEN = '<SEP>'
SPECIAL_TOKENS = ['<PAD>', '<UNK>', '<BOS>', '<EOS>', SEP_TOKEN]

## 1. Load zenz Dataset

In [None]:
from datasets import load_dataset

print("Loading zenz-v2.5-dataset...")

# Prefer the Wikipedia subset; if that data file cannot be loaded, fall back to
# the repository's default train split. `except Exception` (not a bare except)
# so Ctrl-C / KeyboardInterrupt still stops the cell, and the reason for the
# fallback is reported instead of silently swallowed.
try:
    dataset = load_dataset(
        "Miwa-Keita/zenz-v2.5-dataset",
        data_files="train_wikipedia.jsonl",
        split=f"train[:{MAX_SAMPLES}]"
    )
except Exception as err:
    print(f"Could not load train_wikipedia.jsonl ({err!r}); falling back to the default train split")
    dataset = load_dataset(
        "Miwa-Keita/zenz-v2.5-dataset",
        split=f"train[:{MAX_SAMPLES}]"
    )

print(f"✓ Loaded {len(dataset):,} samples")

## 2. Create Training Data with <SEP> Token

Split each sentence at a random point into `context<SEP>kana` → `kanji` pairs.

In [None]:
import random
from tqdm import tqdm

def katakana_to_hiragana(text):
    """Return *text* with katakana U+30A1..U+30F6 mapped to hiragana.

    Each such code point is shifted down by 0x60 to its hiragana counterpart;
    every other character (kanji, 'ー', punctuation, ASCII, ...) is kept as is.
    """
    offset = 0x30A1 - 0x3041  # == 0x60
    return ''.join(
        chr(ord(ch) - offset) if 0x30A1 <= ord(ch) <= 0x30F6 else ch
        for ch in text
    )

def _to_katakana(text):
    """Return *text* with hiragana (U+3041..U+3096) shifted to katakana (+0x60)."""
    return ''.join(
        chr(ord(ch) + 0x60) if 0x3041 <= ord(ch) <= 0x3096 else ch
        for ch in text
    )


def _is_plain_kana(ch):
    """Return True if *ch* is a kana letter that is read exactly as written.

    Covers hiragana U+3041..U+3094, katakana U+30A1..U+30F4 and the prolonged
    sound mark 'ー'. Small ゕ/ゖ/ヵ/ヶ are deliberately excluded: in counters
    such as '3ヶ月' (さんかげつ) they are not read as written, so they are
    aligned like kanji instead.
    """
    code = ord(ch)
    return 0x3041 <= code <= 0x3094 or 0x30A1 <= code <= 0x30F4 or ch == 'ー'


def _tokenize_surface(kanji_full):
    """Split the surface text into alternating literal and gap tokens.

    Returns a list of (surface, literal, min_len) tuples:
      * literal tokens: a run of plain kana; ``literal`` is its katakana form,
        which must appear verbatim in the reading;
      * gap tokens (``literal`` is None): a run of anything else (kanji,
        digits, latin, punctuation) whose reading is unknown and must span at
        least ``min_len`` reading characters. Runs without any alphanumeric
        character (pure punctuation) get min_len 0 in case the reading omits
        them.
    """
    from itertools import groupby

    tokens = []
    for is_kana, chars in groupby(kanji_full, key=_is_plain_kana):
        run = ''.join(chars)
        if is_kana:
            tokens.append((run, _to_katakana(run), 0))
        else:
            min_len = 1 if any(ch.isalnum() for ch in run) else 0
            tokens.append((run, None, min_len))
    return tokens


def _align_leftmost(tokens, reading):
    """Assign reading spans left to right, ending every gap as early as possible.

    Returns a list of (start, end) reading spans, one per token, or None if
    this greedy assignment does not consume the reading exactly.
    """
    spans = []
    pos = 0
    for k, (_, literal, min_len) in enumerate(tokens):
        if literal is not None:
            if not reading.startswith(literal, pos):
                return None
            end = pos + len(literal)
        elif k + 1 < len(tokens):
            # groupby alternates kana / non-kana, so the next token is a literal.
            end = reading.find(tokens[k + 1][1], pos + min_len)
            if end < 0:
                return None
        else:
            end = len(reading)
            if end - pos < min_len:
                return None
        spans.append((pos, end))
        pos = end
    return spans if pos == len(reading) else None


def _align_rightmost(tokens, reading):
    """Mirror of _align_leftmost: assign spans right to left, gaps start as late as possible."""
    spans = []
    pos = len(reading)
    for k in range(len(tokens) - 1, -1, -1):
        _, literal, min_len = tokens[k]
        if literal is not None:
            if not reading.endswith(literal, 0, pos):
                return None
            start = pos - len(literal)
        elif k > 0:
            # Start the gap right after the last occurrence of the preceding
            # literal that still leaves the gap at least min_len characters.
            limit = pos - min_len
            if limit < 0:  # a negative bound would be read as "from the end"
                return None
            prev = tokens[k - 1][1]
            found = reading.rfind(prev, 0, limit)
            if found < 0:
                return None
            start = found + len(prev)
        else:
            start = 0
            if pos < min_len:
                return None
        spans.append((start, pos))
        pos = start
    spans.reverse()
    return spans if pos == 0 else None


def _align_reading(kana_full, kanji_full):
    """Split a sentence into (surface, reading) segments, or return None.

    Kana runs are split into one segment per character (their reading is the
    same character in katakana); each gap is a single segment. The greedy
    leftmost assignment is a lower bound and the rightmost one an upper bound
    on every valid alignment, so the result is accepted only when they agree,
    i.e. when the alignment is unique.
    """
    tokens = _tokenize_surface(kanji_full)
    reading = _to_katakana(kana_full)
    spans = _align_leftmost(tokens, reading)
    if spans is None or spans != _align_rightmost(tokens, reading):
        return None
    segments = []
    for (surface, literal, _), (start, end) in zip(tokens, spans):
        if literal is None:
            segments.append((surface, reading[start:end]))
        else:
            segments.extend(zip(surface, reading[start:end]))
    return segments


def create_training_example(kana_full, kanji_full, *, min_context=2,
                            max_context=20, max_kana=12):
    """Create one training example: ``context<SEP>kana -> kanji``.

    The surface text (``kanji_full``) and its reading (``kana_full``) have
    different lengths (a kanji usually reads as several kana), so they cannot
    be sliced at the same index. Instead the reading is first aligned to the
    surface (see _align_reading). The sentence is then split at a random
    segment boundary:
      * context: the surface text before the boundary (already converted);
      * target:  the following segments, grown greedily while both its
                 surface and its reading stay within ``max_kana`` characters;
      * kana:    the aligned reading of exactly the target, as hiragana.
    With the defaults, context <= 20 chars and kana <= 12, so
    ``context + <SEP> + kana`` is at most 33 tokens.

    Args:
        kana_full: reading of the whole sentence (katakana or hiragana).
        kanji_full: surface form of the sentence (kanji/kana mix).
        min_context: minimum context length in surface characters.
        max_context: maximum context length in surface characters.
        max_kana: cap on both the target's surface and reading length.

    Returns:
        A dict with keys 'input', 'output', 'context' and 'kana', or None
        when the sentence is too short, no valid split exists, or the reading
        cannot be aligned unambiguously.
    """
    if not kana_full or not kanji_full:
        return None
    segments = _align_reading(kana_full, kanji_full)
    if not segments:
        return None

    # offsets[k] = surface length of segments[:k] (an index into kanji_full)
    offsets = [0]
    for surface, _ in segments:
        offsets.append(offsets[-1] + len(surface))

    # The context may end at any segment boundary that leaves >= 1 segment to convert.
    starts = [k for k in range(1, len(segments))
              if min_context <= offsets[k] <= max_context]
    if not starts:
        return None
    start = random.choice(starts)

    end, reading_len = start, 0
    while end < len(segments):
        seg_reading = segments[end][1]
        if (offsets[end + 1] - offsets[start] > max_kana
                or reading_len + len(seg_reading) > max_kana):
            break
        reading_len += len(seg_reading)
        end += 1
    if end == start:
        return None  # the first target segment alone is already too long

    context = kanji_full[:offsets[start]]
    target = kanji_full[offsets[start]:offsets[end]]
    kana = katakana_to_hiragana(''.join(r for _, r in segments[start:end]))
    if not context or not kana or not target:
        return None

    # Create input: context<SEP>kana
    input_text = f"{context}{SEP_TOKEN}{kana}"

    return {
        'input': input_text,
        'output': target,
        'context': context,
        'kana': kana,
    }

# Generate training examples: a few random context/target splits per sentence.
print("Generating training data (context<SEP>kana → kanji)...")
training_data = []

for item in tqdm(dataset, desc="Processing"):
    kana_full = item['input']    # reading (presumably katakana; converted to hiragana downstream)
    kanji_full = item['output']  # converted surface text

    # 3 attempts per sentence; create_training_example returns None when no
    # usable split exists, and those attempts are skipped.
    for _ in range(3):
        example = create_training_example(kana_full, kanji_full)
        if example is not None:
            training_data.append(example)

print(f"\n✓ Generated {len(training_data):,} training examples")
print("\nSamples:")
# Slice instead of indexing range(10): no IndexError when fewer than 10 exist.
for d in training_data[:10]:
    print(f"  Input:  {d['input']}")
    print(f"  Output: {d['output']}")
    print(f"  (context={d['context']}, kana={d['kana']})")
    print()

## 3. Build Vocabulary

In [None]:
from collections import Counter

char_counts = Counter()
for example in tqdm(training_data, desc="Counting"):
    # <SEP> is a reserved special token: strip it so its letters are not counted.
    char_counts.update(example['input'].replace(SEP_TOKEN, ''))
    char_counts.update(example['output'])

print(f"\nUnique chars: {len(char_counts):,}")

# Special tokens take indices 0..len(SPECIAL_TOKENS)-1, then the most frequent
# characters follow in frequency order until CHAR_VOCAB_SIZE is reached.
char_to_idx = {tok: i for i, tok in enumerate(SPECIAL_TOKENS)}
n_special = len(SPECIAL_TOKENS)
char_to_idx.update(
    (ch, n_special + rank)
    for rank, (ch, _) in enumerate(char_counts.most_common(CHAR_VOCAB_SIZE - n_special))
)

idx_to_char = {i: tok for tok, i in char_to_idx.items()}
vocab_size = len(char_to_idx)
print(f"Vocab size: {vocab_size}")
print(f"<SEP> index: {char_to_idx[SEP_TOKEN]}")

## 4. Create Training Tensors

In [None]:
import numpy as np
import tensorflow as tf

# Vocabulary ids of the special tokens (KeyError if the vocab lacks one).
PAD_IDX, BOS_IDX, EOS_IDX, UNK_IDX, SEP_IDX = (
    char_to_idx[tok] for tok in ('<PAD>', '<BOS>', '<EOS>', '<UNK>', '<SEP>')
)

def encode_input(text, max_len):
    """Encode ``context<SEP>kana`` into exactly ``max_len`` token ids.

    Every occurrence of SEP_TOKEN becomes one <SEP> id and every other
    character is looked up on its own (unknown characters map to <UNK>).
    Short inputs are right-padded with <PAD>. Long inputs lose tokens from
    the *left*, i.e. the oldest context is dropped first, so the <SEP> marker
    and the kana to convert (at the end) survive as long as they fit. Any
    on-device encoder should truncate the same way.
    """
    tokens = []
    for k, piece in enumerate(text.split(SEP_TOKEN)):
        if k:  # every split boundary was one SEP_TOKEN occurrence
            tokens.append(SEP_TOKEN)
        tokens.extend(piece)

    ids = [char_to_idx.get(t, UNK_IDX) for t in tokens]
    if len(ids) > max_len:
        ids = ids[len(ids) - max_len:]
    return ids + [PAD_IDX] * (max_len - len(ids))

def encode_output(text, max_len, add_bos=False, add_eos=False):
    """Encode target text as exactly ``max_len`` ids.

    Layout: optional <BOS>, one id per character (<UNK> if unknown), optional
    <EOS>, then <PAD> up to ``max_len``; anything beyond ``max_len`` is cut
    from the right.
    """
    prefix = ['<BOS>'] if add_bos else []
    suffix = ['<EOS>'] if add_eos else []
    ids = [char_to_idx.get(tok, UNK_IDX) for tok in prefix + list(text) + suffix]
    ids = ids[:max_len]
    return ids + [PAD_IDX] * (max_len - len(ids))

enc_rows, dec_in_rows, dec_tgt_rows = [], [], []
for example in tqdm(training_data, desc="Encoding"):
    target = example['output']
    enc_rows.append(encode_input(example['input'], MAX_INPUT_LEN))
    dec_in_rows.append(encode_output(target, MAX_OUTPUT_LEN, add_bos=True))   # <BOS> + target (teacher forcing)
    dec_tgt_rows.append(encode_output(target, MAX_OUTPUT_LEN, add_eos=True))  # target + <EOS>

encoder_inputs, decoder_inputs, decoder_targets = (
    np.array(rows, dtype=np.int32) for rows in (enc_rows, dec_in_rows, dec_tgt_rows)
)
del enc_rows, dec_in_rows, dec_tgt_rows

print(f"\nShapes:")
print(f"  Encoder Input:  {encoder_inputs.shape}")
print(f"  Decoder Input:  {decoder_inputs.shape}")
print(f"  Decoder Target: {decoder_targets.shape}")

In [None]:
# Round-trip the first three rows back to text as a sanity check of the encoders.
print("Verify encoding:")
hidden_out = [PAD_IDX, BOS_IDX, EOS_IDX]
for row in range(3):
    src_text = ''.join(idx_to_char.get(t, '?') for t in encoder_inputs[row] if t != PAD_IDX)
    tgt_text = ''.join(idx_to_char.get(t, '?') for t in decoder_targets[row] if t not in hidden_out)
    print(f"  {src_text} ‚Üí {tgt_text}")

In [None]:
# Shuffle the three arrays with one shared permutation, then keep the first
# 90% for training and the rest for validation.
# NOTE(review): up to 3 examples are generated per sentence, and this shuffle is
# per example, so one sentence can appear in both splits (val metrics may be optimistic).
perm = np.random.permutation(len(encoder_inputs))
encoder_inputs, decoder_inputs, decoder_targets = (
    arr[perm] for arr in (encoder_inputs, decoder_inputs, decoder_targets)
)

split = int(len(encoder_inputs) * 0.9)


def _rows_to_dataset(rows):
    """Wrap the selected rows as ((encoder_input, decoder_input) dict, target) pairs."""
    return tf.data.Dataset.from_tensor_slices((
        {'encoder_input': encoder_inputs[rows], 'decoder_input': decoder_inputs[rows]},
        decoder_targets[rows]
    ))


train_ds = _rows_to_dataset(slice(None, split)).shuffle(10000).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
val_ds = _rows_to_dataset(slice(split, None)).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

print(f"Train: {split:,}, Val: {len(encoder_inputs)-split:,}")

## 5. Build Model (Encoder-Decoder with Attention)

The encoder processes `context<SEP>kana` as a single sequence.
The `<SEP>` token helps the model understand the boundary.

In [None]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Embedding, GRU, Dense, Dropout,
    Bidirectional, Attention, Concatenate, LayerNormalization
)

# mask_zero=True below treats id 0 as padding, so <PAD> must be index 0
# ('<PAD>' is the first entry of SPECIAL_TOKENS).
if PAD_IDX != 0:
    raise ValueError(f"mask_zero=True requires <PAD> at index 0, got {PAD_IDX}")

# Shared character embedding for encoder and decoder. With mask_zero=True,
# Keras carries a padding mask through the GRUs, Attention and Concatenate, so:
#   * attention ignores padded encoder steps;
#   * loss and accuracy skip padded target positions. Targets are at most
#     ~13 real tokens out of MAX_OUTPUT_LEN, so otherwise trivially-correct
#     <PAD> predictions would dominate `accuracy`.
# decoder_input ([<BOS>, c1..cn, PAD...]) and decoder_targets
# ([c1..cn, <EOS>, PAD...]) have the same number of non-pad positions, so the
# decoder mask covers exactly the real target tokens plus <EOS>.
embedding = Embedding(vocab_size, EMBEDDING_DIM, mask_zero=True, name='embedding')

# Encoder: processes context<SEP>kana
encoder_input = Input(shape=(MAX_INPUT_LEN,), dtype='int32', name='encoder_input')
enc_emb = embedding(encoder_input)
encoder_outputs = Bidirectional(
    GRU(GRU_UNITS, return_sequences=True),
    name='encoder'
)(enc_emb)

# Decoder (teacher-forced during training). Width 2*GRU_UNITS matches the
# concatenated forward+backward encoder features for attention.
decoder_input = Input(shape=(MAX_OUTPUT_LEN,), dtype='int32', name='decoder_input')
dec_emb = embedding(decoder_input)
decoder_outputs = GRU(GRU_UNITS * 2, return_sequences=True, name='decoder')(dec_emb)

# Attention over encoder: decoder states are the queries, encoder states the values.
context = Attention(use_scale=True, name='attention')([decoder_outputs, encoder_outputs])

# Combine
combined = Concatenate()([decoder_outputs, context])
combined = LayerNormalization()(combined)
combined = Dropout(0.2)(combined)

# Output: per-position distribution over the character vocabulary
output = Dense(vocab_size, activation='softmax', name='output')(combined)

model = Model(
    inputs=[encoder_input, decoder_input],
    outputs=output,
    name='context_kana_kanji'
)

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

model.summary()
print(f"\nParams: {model.count_params():,}")
print(f"Size: ~{model.count_params() * 4 / 1024 / 1024:.1f} MB")

## 6. Train

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau

# Keep the checkpoint with the best val_accuracy; stop early and decay the LR on val_loss.
checkpoint_cb = ModelCheckpoint(f'{MODEL_DIR}/best.keras', monitor='val_accuracy', save_best_only=True, verbose=1)
early_stop_cb = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
lr_decay_cb = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, min_lr=1e-6, verbose=1)
callbacks = [checkpoint_cb, early_stop_cb, lr_decay_cb]

history = model.fit(
    train_ds,
    epochs=NUM_EPOCHS,
    validation_data=val_ds,
    callbacks=callbacks,
)

In [None]:
import matplotlib.pyplot as plt

# Side-by-side train/val curves for loss and accuracy, saved next to the model.
hist = history.history
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
for ax, metric, title in ((ax1, 'loss', 'Loss'), (ax2, 'accuracy', 'Accuracy')):
    ax.plot(hist[metric], label='Train')
    ax.plot(hist['val_' + metric], label='Val')
    ax.set_title(title)
    ax.legend()
plt.savefig(f'{MODEL_DIR}/training.png')
plt.show()
best_val_acc = max(hist['val_accuracy'])
print(f"Best val_accuracy: {best_val_acc*100:.2f}%")

## 7. Save

In [None]:
import json

model.save(f'{MODEL_DIR}/model.keras')

# Vocabulary in both directions (JSON object keys must be strings).
with open(f'{MODEL_DIR}/char_to_idx.json', 'w', encoding='utf-8') as f:
    json.dump(char_to_idx, f, ensure_ascii=False)
with open(f'{MODEL_DIR}/idx_to_char.json', 'w', encoding='utf-8') as f:
    json.dump({str(k): v for k, v in idx_to_char.items()}, f, ensure_ascii=False)

# Model/tokenizer config for on-device use. Besides <SEP>, the ids of
# <PAD>/<UNK>/<BOS>/<EOS> are exported so the app can build decoder inputs,
# pad, and detect end-of-sequence without hard-coding them.
with open(f'{MODEL_DIR}/config.json', 'w', encoding='utf-8') as f:
    json.dump({
        'vocab_size': vocab_size,
        'max_input_len': MAX_INPUT_LEN,
        'max_output_len': MAX_OUTPUT_LEN,
        'sep_token': SEP_TOKEN,
        'sep_idx': SEP_IDX,
        'pad_idx': PAD_IDX,
        'unk_idx': UNK_IDX,
        'bos_idx': BOS_IDX,
        'eos_idx': EOS_IDX,
    }, f)

print("✓ Saved")

In [None]:
def _write_bytes(path, payload):
    """Write *payload* to *path* in binary mode."""
    with open(path, 'wb') as fh:
        fh.write(payload)


try:
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    # TFLite builtin ops first; fall back to TF (Flex) ops for anything unsupported.
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    # Private converter flag: keep TensorList ops un-lowered
    # (presumably needed for the GRU loops — verify against the TF version in use).
    converter._experimental_lower_tensor_list_ops = False

    # 1) float32 model
    tflite_model = converter.convert()
    _write_bytes(f'{MODEL_DIR}/model.tflite', tflite_model)
    print(f"‚úì model.tflite ({len(tflite_model)/(1024*1024):.2f}MB)")

    # 2) the same converter, now with float16 weight quantization
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_types = [tf.float16]
    tflite_fp16 = converter.convert()
    _write_bytes(f'{MODEL_DIR}/model_fp16.tflite', tflite_fp16)
    print(f"‚úì model_fp16.tflite ({len(tflite_fp16)/(1024*1024):.2f}MB)")
except Exception as e:
    # Best effort: a failed export is reported but does not stop the notebook.
    print(f"‚ö† {e}")

## 8. Verification

In [None]:
rule = "=" * 60
print(rule)
print("VERIFICATION: Context-Aware Kana-Kanji Conversion")
print(rule)
print(f"\nInput format: context{SEP_TOKEN}kana")

def convert(context, kana, max_len=20):
    """Convert *kana* to kanji, given the already-converted *context*.

    Builds ``context<SEP>kana``, encodes it with encode_input, then decodes
    greedily. The decoder input starts as <BOS> followed by <PAD>; at step i
    the argmax at position i is written back into position i + 1. Decoding
    stops at <EOS> or after ``max_len`` steps. The step count is capped at
    MAX_OUTPUT_LEN because the decoder has only that many positions, so
    larger ``max_len`` values no longer index past the prediction array.
    Special tokens (<PAD>/<BOS>/<UNK>/<SEP>) are dropped from the result.

    Mobile usage: the model consumes token ids, not text. The app has to
    replicate this loop: encode ``context + "<SEP>" + kana`` with char_to_idx,
    then feed back one predicted id per step.
    """
    input_text = f"{context}{SEP_TOKEN}{kana}"
    enc_in = np.array([encode_input(input_text, MAX_INPUT_LEN)], dtype=np.int32)

    dec_in = np.full((1, MAX_OUTPUT_LEN), PAD_IDX, dtype=np.int32)
    dec_in[0, 0] = BOS_IDX

    hidden = (PAD_IDX, BOS_IDX, EOS_IDX, UNK_IDX, SEP_IDX)
    result = []
    for i in range(min(max_len, MAX_OUTPUT_LEN)):
        # The whole (partially filled) decoder input is re-run at every step.
        preds = model.predict({
            'encoder_input': enc_in,
            'decoder_input': dec_in
        }, verbose=0)
        next_idx = int(np.argmax(preds[0, i]))
        if next_idx == EOS_IDX:
            break
        if next_idx not in hidden:
            result.append(idx_to_char.get(next_idx, ''))
        if i + 1 < MAX_OUTPUT_LEN:
            dec_in[0, i + 1] = next_idx
    return ''.join(result)

# Homophone test cases: the same kana must convert differently depending on context.
# (The Japanese strings are the correctly decoded UTF-8 text; the mis-decoded
# versions could never match the model's output.)
test_cases = [
    # あつい (hot weather / hot temperature / thick)
    {"context": "今日はとても", "kana": "あつい", "expected": "暑い", "desc": "Weather hot"},
    {"context": "お茶が", "kana": "あつい", "expected": "熱い", "desc": "Temperature hot"},
    {"context": "この辞典は", "kana": "あつい", "expected": "厚い", "desc": "Thick"},

    # はやい (early / fast)
    {"context": "毎朝起きるのが", "kana": "はやい", "expected": "早い", "desc": "Early"},
    {"context": "彼は走るのが", "kana": "はやい", "expected": "速い", "desc": "Fast"},

    # はし (bridge / chopsticks / edge)
    {"context": "川に", "kana": "はし", "expected": "橋", "desc": "Bridge"},
    {"context": "ご飯を", "kana": "はし", "expected": "箸", "desc": "Chopsticks"},
    {"context": "道の", "kana": "はし", "expected": "端", "desc": "Edge"},

    # きく (listen / ask / effective)
    {"context": "音楽を", "kana": "きく", "expected": "聴く", "desc": "Listen"},
    {"context": "道を", "kana": "きく", "expected": "聞く", "desc": "Ask"},
    {"context": "この薬はよく", "kana": "きく", "expected": "効く", "desc": "Effective"},

    # あう (meet / fit / encounter)
    {"context": "友達に", "kana": "あう", "expected": "会う", "desc": "Meet"},
    {"context": "サイズが", "kana": "あう", "expected": "合う", "desc": "Fit"},
    {"context": "事故に", "kana": "あう", "expected": "遭う", "desc": "Encounter"},

    # とる (take photo / take / catch)
    {"context": "写真を", "kana": "とる", "expected": "撮る", "desc": "Take photo"},
    {"context": "塩を", "kana": "とる", "expected": "取る", "desc": "Take"},
    {"context": "魚を", "kana": "とる", "expected": "捕る", "desc": "Catch"},

    # きかん (period / institution / return)
    {"context": "テスト", "kana": "きかん", "expected": "期間", "desc": "Period"},
    {"context": "交通", "kana": "きかん", "expected": "機関", "desc": "Institution"},
    {"context": "宇宙から", "kana": "きかん", "expected": "帰還", "desc": "Return"},
]

print("\n📝 Homophone Disambiguation Test:")
print("-" * 60)
correct = 0
for tc in test_cases:
    result = convert(tc['context'], tc['kana'])
    # Lenient scoring: the expected word may appear anywhere in the output
    # (an exact match is a special case of this).
    match = tc['expected'] in result
    if match:
        correct += 1
    status = "✓" if match else "✗"
    print(f"{status} [{tc['desc']}]")
    print(f"   Input: {tc['context']}<SEP>{tc['kana']}")
    print(f"   Output: {result} (expected: {tc['expected']})")
    print()

print(f"✅ Score: {correct}/{len(test_cases)} ({correct/len(test_cases)*100:.0f}%)")

In [None]:
# Printed usage sketch for the iOS side (text repaired from mis-decoded UTF-8).
print("\n📱 Mobile Usage Example:")
print("-" * 60)
print("""
// Swift usage (pseudo-code):
let context = "写真を"      // Already converted text
let kana = "とる"           // Currently typing (hiragana)
let input = context + "<SEP>" + kana

// The model consumes token ids, not a String: encode `input` with
// char_to_idx.json, then decode greedily one id per step as convert() does.
let result = model.predict(input)  // → "撮る"
""")

In [None]:
# List every regular file written to MODEL_DIR with a human-readable size.
print("\n📦 Exported Files:")
for name in sorted(os.listdir(MODEL_DIR)):
    path = os.path.join(MODEL_DIR, name)
    if not os.path.isfile(path):
        continue  # skip sub-directories
    size = os.path.getsize(path)
    if size > 1024*1024:
        print(f"  {name}: {size/(1024*1024):.2f} MB")
    else:
        print(f"  {name}: {size/1024:.1f} KB")