# Context-Aware Kana-Kanji Converter (v2)

**Supports**: Google Colab & Kaggle

**Key Insight**: Model size is controlled by ARCHITECTURE, not by sample count!
- More samples = longer training but BETTER accuracy
- Architecture params (EMBEDDING_DIM, GRU_UNITS) = control size

**Input Format**: `context<SEP>kana`
- Before `<SEP>`: context (already converted, kanji)
- After `<SEP>`: kana to convert (hiragana)

**Example**:
```
Input:  写真を<SEP>とる
Output: 撮る
```

In [None]:
import os

# Auto-detect the runtime and pick a base directory for model artifacts:
# - Kaggle: /kaggle/working
# - Colab:  the mounted Google Drive folder
# - anything else (e.g. a local Jupyter server): the current working directory,
#   instead of crashing on the missing google.colab module.
if os.path.exists('/kaggle'):
    PLATFORM = 'Kaggle'
    DRIVE_DIR = '/kaggle/working'
else:
    try:
        from google.colab import drive
    except ImportError:
        PLATFORM = 'Local'
        DRIVE_DIR = os.getcwd()
    else:
        PLATFORM = 'Colab'
        drive.mount('/content/drive')
        DRIVE_DIR = '/content/drive/MyDrive/Keyboard-Suggestions-ML-Colab'

MODEL_DIR = f"{DRIVE_DIR}/models/gru_japanese_kana_kanji"
os.makedirs(MODEL_DIR, exist_ok=True)

print(f"✅ Platform: {PLATFORM}")
print(f"📁 Model directory: {MODEL_DIR}")

In [None]:
# Install the notebook's Python dependencies. fugashi is the MeCab tokenizer wrapper
# used in section 2; unidic-lite is presumably the dictionary fugashi.Tagger() loads.
# NOTE(review): matplotlib (imported for the training plot) is not listed here and is
# assumed to be preinstalled on Colab/Kaggle — confirm.
!pip install -q tensorflow keras datasets numpy tqdm fugashi unidic-lite

In [None]:
# ===========================================================
# MODEL SIZE CONTROL
# ===========================================================
# Model size is controlled by architecture, NOT by sample count!
# - More samples = longer training but BETTER accuracy
# - Architecture params = control model size
#
# Like Zenzai GPT-2: train on ALL data, control size via architecture
# Target: Model size under 20MB (FP16 TFLite)
# ===========================================================

TESTING_MODE = False  # Set True for a quick smoke test, False for full training
# Testing mode must also cap the data; otherwise it would still tokenize the
# full corpus in section 3 and only the epoch count would shrink.
MAX_SAMPLES = 10_000 if TESTING_MODE else None  # None = use ALL data
BATCH_SIZE = 512

NUM_EPOCHS = 10 if TESTING_MODE else 20

# ===========================================================
# ARCHITECTURE CONFIG (these control model size!)
# ===========================================================
# Adjust these to fit model under 20MB:
# - CHAR_VOCAB_SIZE: ~5K covers 99%+ of Japanese
# - EMBEDDING_DIM: smaller = smaller model
# - GRU_UNITS: smaller = smaller model
# ===========================================================

CHAR_VOCAB_SIZE = 5000   # Upper bound on the vocab (5K kanji/kana covers 99%+ of text)
MAX_INPUT_LEN = 50       # context + <SEP> + kana
MAX_OUTPUT_LEN = 20      # kanji output
EMBEDDING_DIM = 64       # ↓ smaller = smaller model
GRU_UNITS = 128          # ↓ smaller = smaller model
NUM_ENCODER_LAYERS = 2   # Encoder depth
NUM_DECODER_LAYERS = 2   # Decoder depth


def _gru_params(units, input_dim):
    """Return the parameter count of one Keras GRU layer (``reset_after=True`` default)."""
    # kernel (input_dim x 3u) + recurrent kernel (u x 3u) + two bias rows (2 x 3u)
    return 3 * units * (input_dim + units + 2)


def _layer_norm_params(dim):
    """Return the parameter count of a LayerNormalization over ``dim`` features."""
    return 2 * dim  # gamma + beta


# Estimate the model size by mirroring the layers built in section 6.
# CHAR_VOCAB_SIZE is an upper bound (the real vocab is smaller when the data has
# fewer distinct characters), so this estimate errs on the high side.
embedding_params = CHAR_VOCAB_SIZE * EMBEDDING_DIM

# Bidirectional encoder: two GRU(GRU_UNITS) per layer + LayerNorm. Only the first
# layer reads the embedding; later layers read the 2*GRU_UNITS concatenated output.
encoder_params = sum(
    2 * _gru_params(GRU_UNITS, EMBEDDING_DIM if layer == 0 else 2 * GRU_UNITS)
    + _layer_norm_params(2 * GRU_UNITS)
    for layer in range(NUM_ENCODER_LAYERS)
)

# Decoder: one GRU(2*GRU_UNITS) per layer + LayerNorm, again fed by the embedding
# only at the first layer.
decoder_units = 2 * GRU_UNITS
decoder_params = sum(
    _gru_params(decoder_units, EMBEDDING_DIM if layer == 0 else decoder_units)
    + _layer_norm_params(decoder_units)
    for layer in range(NUM_DECODER_LAYERS)
)

# Head: attention scale (1 weight) + LayerNorm over concat(decoder, context)
# + Dense(2*GRU_UNITS, relu) + Dense(vocab, softmax), both Dense layers with bias.
concat_dim = 4 * GRU_UNITS
output_params = (
    1
    + _layer_norm_params(concat_dim)
    + concat_dim * decoder_units + decoder_units
    + decoder_units * CHAR_VOCAB_SIZE + CHAR_VOCAB_SIZE
)

total_params = embedding_params + encoder_params + decoder_params + output_params
estimated_size_mb = total_params * 4 / 1024 / 1024   # 4 bytes per FP32 weight
estimated_fp16_mb = estimated_size_mb / 2

print('📊 Estimated Model Size:')
print(f'   Parameters: ~{total_params:,}')
print(f'   FP32: ~{estimated_size_mb:.1f} MB')
print(f'   FP16: ~{estimated_fp16_mb:.1f} MB')
if estimated_fp16_mb < 20:
    print('   ✅ Under 20MB target!')
else:
    print('   ⚠️ Over 20MB target! Reduce EMBEDDING_DIM or GRU_UNITS')

# Special tokens
SPECIAL_TOKENS = ['<PAD>', '<UNK>', '<BOS>', '<EOS>', '<SEP>']
SEP_TOKEN = '<SEP>'

## 1. Load zenz Dataset (ALL DATA)

In [None]:
from datasets import load_dataset

print("Loading zenz-v2.5-dataset...")

DATASET_REPO = "Miwa-Keita/zenz-v2.5-dataset"
# One split spec covers both cases: the first MAX_SAMPLES rows, or ALL data.
split_spec = f"train[:{MAX_SAMPLES}]" if MAX_SAMPLES else "train"

try:
    # Prefer the train_wikipedia.jsonl file of the dataset.
    dataset = load_dataset(DATASET_REPO, data_files="train_wikipedia.jsonl", split=split_spec)
except Exception as err:
    # Best-effort fallback to the dataset's default files. Catching Exception
    # (not a bare except) lets KeyboardInterrupt/SystemExit propagate.
    print(f"  (train_wikipedia.jsonl unavailable: {type(err).__name__}: {err}; "
          f"falling back to the default train split)")
    dataset = load_dataset(DATASET_REPO, split=split_spec)

print(f"✓ Loaded {len(dataset):,} samples")
print(f"  (Using {'ALL DATA' if not MAX_SAMPLES else f'{MAX_SAMPLES:,} samples'})")

## 2. Setup Japanese Tokenizer (MeCab/fugashi)

In [None]:
import fugashi

tagger = fugashi.Tagger()


def tokenize_japanese(text):
    """Split Japanese *text* into words paired with their readings.

    Args:
        text: Sentence to tokenize with the module-level fugashi ``tagger``.

    Returns:
        A list of ``{'surface': str, 'reading': str}`` dicts, one per token.
        ``reading`` is the token's ``feature.kana`` value (assumed to be
        katakana; callers convert it with katakana_to_hiragana). It falls back
        to the surface form when that feature is missing or empty.
    """
    words = []
    for word in tagger(text):
        surface = word.surface
        try:
            reading = word.feature.kana or surface
        except Exception:
            # Best-effort: some tokens may carry features without a ``kana``
            # field. Unlike a bare except, this still lets Ctrl-C interrupt
            # the long preprocessing loop.
            reading = surface
        words.append({'surface': surface, 'reading': reading})
    return words


# Test tokenizer
test_text = "写真を撮る"
tokens = tokenize_japanese(test_text)
print(f"Test: {test_text}")
for t in tokens:
    print(f"  {t['surface']} → {t['reading']}")

## 3. Create Training Data with Word-Level Alignment

In [None]:
import random
from tqdm import tqdm

# Code-point translation table: each katakana letter U+30A1..U+30F6 (ァ..ヶ)
# maps to the hiragana letter exactly 0x60 below it (ぁ..ゖ).
_KATAKANA_TO_HIRAGANA = {cp: cp - 0x60 for cp in range(0x30A1, 0x30F7)}


def katakana_to_hiragana(text):
    """Return *text* with katakana letters U+30A1–U+30F6 shifted to hiragana.

    Every other character (hiragana, kanji, the long-vowel mark ``ー``,
    ASCII, ヷ–ヺ) passes through unchanged.
    """
    return text.translate(_KATAKANA_TO_HIRAGANA)

def create_training_examples_from_sentence(kanji_sentence, min_context_words=1, max_target_words=3,
                                           max_context_chars=30, max_kana_chars=15,
                                           max_target_chars=15):
    """Create ``context<SEP>kana -> target`` examples by splitting at word boundaries.

    For every split point ``k`` (from ``min_context_words``, expected to be >= 0),
    the first ``k`` words form the context (surface form), and up to
    ``max_target_words`` following words form the target. The target's readings,
    converted to hiragana, are the kana to convert.

    The default length limits keep ``context + <SEP> + kana`` within 46 tokens
    and ``target + <BOS>/<EOS>`` within 16 tokens.

    Args:
        kanji_sentence: Sentence in its converted (kanji/kana mixed) form.
        min_context_words: Minimum number of context words before the split.
        max_target_words: Maximum number of words converted per example.
        max_context_chars: Examples whose context is longer are dropped.
        max_kana_chars: Examples whose kana is longer are dropped.
        max_target_chars: Examples whose target is longer are dropped.

    Returns:
        A list of dicts with keys ``'input'`` (``context + SEP_TOKEN + kana``),
        ``'output'`` (target surface), ``'context'`` and ``'kana'``. Empty when the
        sentence tokenizes to fewer than two words.
    """
    examples = []

    tokens = tokenize_japanese(kanji_sentence)
    if len(tokens) < 2:
        return examples

    surfaces = [w['surface'] for w in tokens]
    readings = [w['reading'] for w in tokens]

    for split_idx in range(min_context_words, len(tokens)):
        context = ''.join(surfaces[:split_idx])
        # Contexts are growing prefixes: once one is too long, every later
        # split would be rejected too, so stop scanning.
        if len(context) > max_context_chars:
            break

        end_idx = min(split_idx + max_target_words, len(tokens))
        if end_idx <= split_idx:
            continue  # no target words

        target = ''.join(surfaces[split_idx:end_idx])
        kana = katakana_to_hiragana(''.join(readings[split_idx:end_idx]))

        if len(kana) > max_kana_chars or len(target) > max_target_chars:
            continue
        if not kana or not target:
            continue

        examples.append({
            'input': f"{context}{SEP_TOKEN}{kana}",
            'output': target,
            'context': context,
            'kana': kana,
        })

    return examples

print("Generating word-aligned training data...")
print(f"Processing {len(dataset):,} sentences...")
training_data = []

# Each dataset row's 'output' field holds the converted (kanji) sentence that
# is split into context/target examples.
for row in tqdm(dataset, desc="Processing"):
    training_data.extend(create_training_examples_from_sentence(row['output']))

print(f"\n‚úì Generated {len(training_data):,} training examples")
print("\nSamples:")
for sample in training_data[:10]:
    print(f"  {sample['context']}<SEP>{sample['kana']} ‚Üí {sample['output']}")

## 4. Build Vocabulary

In [None]:
from collections import Counter

# Character frequencies over model inputs (with the <SEP> marker stripped, since
# it becomes a single special token) and over target outputs.
char_counts = Counter()
for example in tqdm(training_data, desc="Counting"):
    char_counts.update(example['input'].replace(SEP_TOKEN, ''))
    char_counts.update(example['output'])

print(f"\nUnique chars: {len(char_counts):,}")

# Special tokens take the lowest ids (in SPECIAL_TOKENS order, so <PAD> is 0); the
# remaining CHAR_VOCAB_SIZE - len(SPECIAL_TOKENS) slots go to the most frequent chars.
num_char_slots = CHAR_VOCAB_SIZE - len(SPECIAL_TOKENS)
vocab_tokens = SPECIAL_TOKENS + [ch for ch, _ in char_counts.most_common(num_char_slots)]
char_to_idx = {token: index for index, token in enumerate(vocab_tokens)}

idx_to_char = {index: token for token, index in char_to_idx.items()}
vocab_size = len(char_to_idx)
print(f"Vocab size: {vocab_size}")
print(f"<SEP> index: {char_to_idx[SEP_TOKEN]}")

## 5. Create Training Tensors

In [None]:
import numpy as np
import tensorflow as tf

# Resolve the reserved token ids once so later code can use plain constants.
PAD_IDX, BOS_IDX, EOS_IDX, UNK_IDX, SEP_IDX = [
    char_to_idx[token] for token in ('<PAD>', '<BOS>', '<EOS>', '<UNK>', '<SEP>')
]

def encode_input(text, max_len):
    """Encode a ``context<SEP>kana`` string into exactly ``max_len`` token ids.

    Each literal occurrence of SEP_TOKEN becomes the single <SEP> id. Every
    other character becomes its vocabulary id (UNK_IDX when unknown). The
    result is right-padded with PAD_IDX.

    If the sequence is longer than ``max_len``, the OLDEST tokens (the start of
    the context) are dropped. This keeps the <SEP> marker and the kana to
    convert, which sit at the end. With the default length filters used when
    building the training data, training inputs never reach this limit, so this
    only affects long contexts passed at inference time.
    """
    tokens = []
    # str.split finds non-overlapping separators left to right, the same as a
    # manual scan; put one SEP_TOKEN between consecutive pieces.
    for piece_no, piece in enumerate(text.split(SEP_TOKEN)):
        if piece_no:
            tokens.append(SEP_TOKEN)
        tokens.extend(piece)
    ids = [char_to_idx.get(t, UNK_IDX) for t in tokens]
    if len(ids) > max_len:
        ids = ids[len(ids) - max_len:]
    return ids + [PAD_IDX] * (max_len - len(ids))

def encode_output(text, max_len, add_bos=False, add_eos=False):
    """Encode target *text* into exactly ``max_len`` ids.

    Optionally prefixes <BOS> and/or suffixes <EOS>. Unknown characters become
    UNK_IDX. Longer sequences are cut at ``max_len``; shorter ones are
    right-padded with PAD_IDX.
    """
    prefix = ['<BOS>'] if add_bos else []
    suffix = ['<EOS>'] if add_eos else []
    ids = [char_to_idx.get(tok, UNK_IDX) for tok in prefix + list(text) + suffix][:max_len]
    return ids + [PAD_IDX] * (max_len - len(ids))

encoder_inputs, decoder_inputs, decoder_targets = [], [], []

for example in tqdm(training_data, desc="Encoding"):
    target_text = example['output']
    encoder_inputs.append(encode_input(example['input'], MAX_INPUT_LEN))
    # Teacher forcing: the decoder reads <BOS> + target and learns to emit target + <EOS>.
    decoder_inputs.append(encode_output(target_text, MAX_OUTPUT_LEN, add_bos=True))
    decoder_targets.append(encode_output(target_text, MAX_OUTPUT_LEN, add_eos=True))

encoder_inputs, decoder_inputs, decoder_targets = [
    np.array(rows, dtype=np.int32)
    for rows in (encoder_inputs, decoder_inputs, decoder_targets)
]

print(f"\nShapes: {encoder_inputs.shape}, {decoder_inputs.shape}, {decoder_targets.shape}")

In [None]:
# Shuffle the three aligned arrays with one shared permutation so rows stay paired.
idx = np.random.permutation(len(encoder_inputs))
encoder_inputs, decoder_inputs, decoder_targets = [
    arr[idx] for arr in (encoder_inputs, decoder_inputs, decoder_targets)
]

# First 90% of the shuffled rows train; the rest validate.
split = int(len(encoder_inputs) * 0.9)


def _rows_to_dataset(rows):
    """Wrap a row slice as ({'encoder_input', 'decoder_input'}, target) examples."""
    return tf.data.Dataset.from_tensor_slices((
        {'encoder_input': encoder_inputs[rows], 'decoder_input': decoder_inputs[rows]},
        decoder_targets[rows],
    ))


train_ds = (
    _rows_to_dataset(slice(None, split))
    .shuffle(10000)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)
val_ds = _rows_to_dataset(slice(split, None)).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

print(f"Train: {split:,}, Val: {len(encoder_inputs)-split:,}")

## 6. Build Model (Size Controlled by Architecture)

In [None]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Embedding, GRU, Dense, Dropout,
    Bidirectional, Attention, Concatenate, LayerNormalization
)

# One embedding table shared by encoder and decoder: both index the same
# char_to_idx vocabulary.
# NOTE(review): no mask_zero here, so <PAD> positions are not masked anywhere below
# (including in attention over the encoder) — confirm this is intended.
embedding = Embedding(vocab_size, EMBEDDING_DIM, name='embedding')

# Encoder
# Input name must match the 'encoder_input' key used by the tf.data pipelines and convert().
encoder_input = Input(shape=(MAX_INPUT_LEN,), dtype='int32', name='encoder_input')
enc_emb = embedding(encoder_input)
enc_emb = Dropout(0.1)(enc_emb)

# Stack of bidirectional GRUs returning full sequences. With Keras' default
# merge_mode ('concat'), each layer outputs 2*GRU_UNITS features per position.
encoder_out = enc_emb
for i in range(NUM_ENCODER_LAYERS):
    encoder_out = Bidirectional(
        GRU(GRU_UNITS, return_sequences=True),
        name=f'encoder_{i+1}'
    )(encoder_out)
    encoder_out = LayerNormalization()(encoder_out)

# Decoder
# Input name must match the 'decoder_input' key used by the tf.data pipelines and convert().
decoder_input = Input(shape=(MAX_OUTPUT_LEN,), dtype='int32', name='decoder_input')
dec_emb = embedding(decoder_input)
dec_emb = Dropout(0.1)(dec_emb)

# Unidirectional (causal) GRUs, 2*GRU_UNITS wide so their outputs have the same
# width as the bidirectional encoder outputs for dot-product attention.
decoder_out = dec_emb
for i in range(NUM_DECODER_LAYERS):
    decoder_out = GRU(GRU_UNITS * 2, return_sequences=True, name=f'decoder_{i+1}')(decoder_out)
    decoder_out = LayerNormalization()(decoder_out)

# Attention
# Each decoder position (query) attends over all encoder positions (value/key),
# with a learned scale on the dot-product scores.
context = Attention(use_scale=True, name='attention')([decoder_out, encoder_out])

# Combine and output
# Per position: [decoder state ; attended context] -> norm -> dropout -> ReLU
# projection -> softmax over the vocabulary.
combined = Concatenate()([decoder_out, context])
combined = LayerNormalization()(combined)
combined = Dropout(0.2)(combined)
combined = Dense(GRU_UNITS * 2, activation='relu')(combined)
output = Dense(vocab_size, activation='softmax', name='output')(combined)

model = Model(
    inputs=[encoder_input, decoder_input],
    outputs=output,
    name='context_kana_kanji_v2'
)

def _masked_token_accuracy(y_true, y_pred):
    """Return greedy token accuracy over non-<PAD> target positions only.

    Args:
        y_true: (batch, MAX_OUTPUT_LEN) target ids (may arrive as floats; cast below).
        y_pred: (batch, MAX_OUTPUT_LEN, vocab) softmax outputs.
    """
    y_true = tf.cast(y_true, tf.int32)
    not_pad = tf.not_equal(y_true, PAD_IDX)
    predicted = tf.argmax(y_pred, axis=-1, output_type=tf.int32)
    hits = tf.logical_and(tf.equal(predicted, y_true), not_pad)
    real_tokens = tf.reduce_sum(tf.cast(not_pad, tf.float32))
    return tf.reduce_sum(tf.cast(hits, tf.float32)) / tf.maximum(real_tokens, 1.0)


model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    # Targets are right-padded with <PAD> up to MAX_OUTPUT_LEN, so most positions
    # are padding. Excluding them keeps the loss focused on real characters.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(ignore_class=PAD_IDX),
    # Same masking for the metric. It keeps the name 'accuracy' so 'val_accuracy'
    # (checkpoint monitor, history plot) is still reported.
    metrics=[tf.keras.metrics.MeanMetricWrapper(_masked_token_accuracy, name='accuracy')],
)

model.summary()
n_params = model.count_params()
print(f"\n📊 Actual Model Stats:")
print(f"   Parameters: {n_params:,}")
print(f"   Size FP32: ~{n_params * 4 / 1024 / 1024:.1f} MB")
print(f"   Size FP16: ~{n_params * 2 / 1024 / 1024:.1f} MB")

## 7. Train

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau

# NOTE(review): best.keras is chosen by val_accuracy, while EarlyStopping restores
# the weights with the best val_loss, so best.keras and the in-memory model saved
# later can come from different epochs.
checkpoint_cb = ModelCheckpoint(
    f'{MODEL_DIR}/best.keras',
    monitor='val_accuracy',
    save_best_only=True,
    verbose=1,
)
early_stop_cb = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
lr_decay_cb = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, min_lr=1e-6, verbose=1)
callbacks = [checkpoint_cb, early_stop_cb, lr_decay_cb]

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=NUM_EPOCHS,
    callbacks=callbacks,
)

In [None]:
import matplotlib.pyplot as plt

# Side-by-side train/val curves: loss on the left, accuracy on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
for axis, metric, title in ((ax1, 'loss', 'Loss'), (ax2, 'accuracy', 'Accuracy')):
    axis.plot(history.history[metric], label='Train')
    axis.plot(history.history[f'val_{metric}'], label='Val')
    axis.set_title(title)
    axis.legend()
plt.savefig(f'{MODEL_DIR}/training.png')
plt.show()

best_val_accuracy = max(history.history['val_accuracy'])
print(f"Best val_accuracy: {best_val_accuracy*100:.2f}%")

## 8. Save

In [None]:
import json

model.save(f'{MODEL_DIR}/model.keras')

# Vocabulary maps as readable UTF-8 JSON. JSON object keys must be strings, so the
# integer ids of idx_to_char are stringified.
vocab_files = {
    'char_to_idx.json': char_to_idx,
    'idx_to_char.json': {str(k): v for k, v in idx_to_char.items()},
}
for filename, mapping in vocab_files.items():
    with open(f'{MODEL_DIR}/{filename}', 'w', encoding='utf-8') as fh:
        json.dump(mapping, fh, ensure_ascii=False)

# Shapes and the separator needed to rebuild model inputs elsewhere.
model_config = {
    'vocab_size': vocab_size,
    'max_input_len': MAX_INPUT_LEN,
    'max_output_len': MAX_OUTPUT_LEN,
    'sep_token': SEP_TOKEN,
    'sep_idx': SEP_IDX,
}
with open(f'{MODEL_DIR}/config.json', 'w') as fh:
    json.dump(model_config, fh)

print("‚úì Saved")

In [None]:
def _save_tflite(tflite_bytes, filename):
    """Write a converted TFLite flatbuffer into MODEL_DIR, report its size, return it."""
    with open(f'{MODEL_DIR}/{filename}', 'wb') as fh:
        fh.write(tflite_bytes)
    print(f"‚úì {filename} ({len(tflite_bytes)/(1024*1024):.2f}MB)")
    return tflite_bytes


try:
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    # Allow TF (select) ops alongside TFLite builtins.
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    # NOTE(review): private converter flag, presumably needed for the GRU loops —
    # confirm against the TensorFlow version in use.
    converter._experimental_lower_tensor_list_ops = False

    tflite_model = _save_tflite(converter.convert(), 'model.tflite')

    # Second pass on the same converter with FP16 weight quantization enabled.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_types = [tf.float16]
    tflite_fp16 = _save_tflite(converter.convert(), 'model_fp16.tflite')
except Exception as e:
    # Best-effort export: report the failure and keep the notebook running.
    print(f"‚ö† {e}")

## 9. Verification

In [None]:
print("="*60)
print("VERIFICATION: Context-Aware Kana-Kanji Conversion")
print("="*60)
print(f"\nInput format: context{SEP_TOKEN}kana")


def convert(context, kana, max_len=20):
    """Greedy-decode the conversion of *kana* given the already-converted *context*.

    Args:
        context: Text preceding the kana (already converted).
        kana: Hiragana to convert.
        max_len: Maximum number of decoding steps. It is clamped to
            MAX_OUTPUT_LEN because the decoder input has exactly that many
            positions (larger values used to raise IndexError).

    Returns:
        The decoded text, stopping at <EOS>. Predicted special tokens
        (<PAD>, <BOS>, <UNK>, <SEP>) are left out of the text but are still fed
        back to the decoder.
    """
    input_text = f"{context}{SEP_TOKEN}{kana}"
    enc_in = np.array([encode_input(input_text, MAX_INPUT_LEN)], dtype=np.int32)
    dec_in = np.zeros((1, MAX_OUTPUT_LEN), dtype=np.int32)
    dec_in[0, 0] = BOS_IDX
    special_ids = {PAD_IDX, BOS_IDX, EOS_IDX, UNK_IDX, SEP_IDX}

    result = []
    for i in range(min(max_len, MAX_OUTPUT_LEN)):
        # The full sequence is re-run each step. The decoder GRUs are
        # unidirectional, so the prediction at position i depends only on
        # dec_in[:i+1] (plus the encoder input).
        preds = model.predict({'encoder_input': enc_in, 'decoder_input': dec_in}, verbose=0)
        next_idx = int(np.argmax(preds[0, i]))
        if next_idx == EOS_IDX:
            break
        if next_idx not in special_ids:
            result.append(idx_to_char.get(next_idx, ''))
        if i + 1 < MAX_OUTPUT_LEN:
            dec_in[0, i + 1] = next_idx
    return ''.join(result)

# All test cases from user: each kana reading has several valid kanji spellings,
# and the context decides which one is correct.
test_cases = [
    {"context": "今日はとても", "kana": "あつい", "expected": "暑い", "desc": "Weather hot"},
    {"context": "お茶が", "kana": "あつい", "expected": "熱い", "desc": "Temperature hot"},
    {"context": "この辞典は", "kana": "あつい", "expected": "厚い", "desc": "Thick"},
    {"context": "毎朝起きるのが", "kana": "はやい", "expected": "早い", "desc": "Early"},
    {"context": "彼は走るのが", "kana": "はやい", "expected": "速い", "desc": "Fast"},
    {"context": "川に", "kana": "はし", "expected": "橋", "desc": "Bridge"},
    {"context": "ご飯を", "kana": "はし", "expected": "箸", "desc": "Chopsticks"},
    {"context": "道の", "kana": "はし", "expected": "端", "desc": "Edge"},
    {"context": "音楽を", "kana": "きく", "expected": "聴く", "desc": "Listen"},
    {"context": "道を", "kana": "きく", "expected": "聞く", "desc": "Ask"},
    {"context": "この薬はよく", "kana": "きく", "expected": "効く", "desc": "Effective"},
    {"context": "税金を", "kana": "おさめる", "expected": "納める", "desc": "Pay"},
    {"context": "国を", "kana": "おさめる", "expected": "治める", "desc": "Govern"},
    {"context": "学問を", "kana": "おさめる", "expected": "修める", "desc": "Master"},
    {"context": "友達に", "kana": "あう", "expected": "会う", "desc": "Meet"},
    {"context": "サイズが", "kana": "あう", "expected": "合う", "desc": "Fit"},
    {"context": "事故に", "kana": "あう", "expected": "遭う", "desc": "Encounter"},
    {"context": "写真を", "kana": "とる", "expected": "撮る", "desc": "Take photo"},
    {"context": "塩を", "kana": "とる", "expected": "取る", "desc": "Take"},
    {"context": "魚を", "kana": "とる", "expected": "捕る", "desc": "Catch"},
    {"context": "石は", "kana": "かたい", "expected": "硬い", "desc": "Hard solid"},
    {"context": "決意が", "kana": "かたい", "expected": "固い", "desc": "Hard firm"},
    {"context": "本の内容が", "kana": "かたい", "expected": "堅い", "desc": "Hard strict"},
    {"context": "テスト", "kana": "きかん", "expected": "期間", "desc": "Period"},
    {"context": "交通", "kana": "きかん", "expected": "機関", "desc": "Institution"},
    {"context": "宇宙から", "kana": "きかん", "expected": "帰還", "desc": "Return"},
]

print("\n📝 Homophone Disambiguation Test:")
print("-" * 60)
correct = 0
for tc in test_cases:
    result = convert(tc['context'], tc['kana'])
    # Lenient match: the prediction may contain the expected text or be part of it.
    # An empty prediction never counts, because "" is a substring of every string.
    match = bool(result) and (tc['expected'] in result or result in tc['expected'])
    if match:
        correct += 1
    status = "✓" if match else "✗"
    print(f"{status} [{tc['desc']}]")
    print(f"   {tc['context']}<SEP>{tc['kana']} → {result} (expected: {tc['expected']})")

print(f"\n✅ Score: {correct}/{len(test_cases)} ({correct/len(test_cases)*100:.0f}%)")

In [None]:
print(f"\nüì¶ Exported Files ({PLATFORM}):")
# List regular files in MODEL_DIR, in MB above 1 MiB and in KB otherwise.
for name in sorted(os.listdir(MODEL_DIR)):
    full_path = f'{MODEL_DIR}/{name}'
    if not os.path.isfile(full_path):
        continue
    n_bytes = os.path.getsize(full_path)
    if n_bytes > 1024 * 1024:
        print(f"  {name}: {n_bytes/(1024*1024):.2f} MB")
    else:
        print(f"  {name}: {n_bytes/1024:.1f} KB")