# Context-Aware Kana-Kanji Converter (v3)

**Supports**: Google Colab & Kaggle

**Key Features**:
- Uses zenz dataset **directly** (no tokenization needed)
- **Length-based bucketing** for stable GRU training
- Bidirectional GRU + Attention (best for this task)
- Model size controlled by architecture (<20MB FP16)

**Input Format**: `context<SEP>kana` → `kanji`

In [None]:
import os

# Resolve the platform name and the storage root from the presence of /kaggle.
if not os.path.exists('/kaggle'):
    # No /kaggle directory: treat the runtime as Colab and store under Drive.
    PLATFORM = 'Colab'
    from google.colab import drive
    drive.mount('/content/drive')
    DRIVE_DIR = '/content/drive/MyDrive/Keyboard-Suggestions-ML-Colab'
else:
    PLATFORM = 'Kaggle'
    DRIVE_DIR = '/kaggle/working'

# Model artifacts and preprocessing caches live in separate sub-directories.
MODEL_DIR = f"{DRIVE_DIR}/models/gru_japanese_kana_kanji"
CACHE_DIR = f"{DRIVE_DIR}/cache/kana_kanji"
for directory in (MODEL_DIR, CACHE_DIR):
    os.makedirs(directory, exist_ok=True)

for summary_line in (
    f"‚úÖ Platform: {PLATFORM}",
    f"üìÅ Model: {MODEL_DIR}",
    f"üíæ Cache: {CACHE_DIR}",
):
    print(summary_line)

In [None]:
# Install the third-party packages imported by the cells below (-q keeps pip quiet).
# NOTE(review): matplotlib is imported later for plotting but is not listed
# here - presumably preinstalled on Colab/Kaggle; confirm.
!pip install -q tensorflow keras datasets numpy tqdm

In [None]:
# ===========================================================
# CONFIGURATION
# ===========================================================
# Run-level switches.
TESTING_MODE = False
FORCE_REBUILD_CACHE = False
MAX_SAMPLES = 5_000_000  # Start with 5M (set None for all ~17.5M)
BATCH_SIZE = 512

# Testing mode trains for fewer epochs.
if TESTING_MODE:
    NUM_EPOCHS = 10
else:
    NUM_EPOCHS = 20

# Per-field character limits; examples exceeding any of them are filtered out.
MAX_CONTEXT_LEN, MAX_INPUT_LEN, MAX_OUTPUT_LEN = 30, 30, 20

# Special symbols; their list positions become their vocabulary ids.
SEP_TOKEN = '<SEP>'
SPECIAL_TOKENS = ['<PAD>', '<UNK>', '<BOS>', '<EOS>', '<SEP>']

# context + <SEP> + input. Five positions are reserved for the separator (the
# length of the literal text), although the encoders emit it as one token.
MAX_ENCODER_LEN = MAX_CONTEXT_LEN + 5 + MAX_INPUT_LEN

# Architecture (controls model size)
CHAR_VOCAB_SIZE = 6000
EMBEDDING_DIM = 64
GRU_UNITS = 128
NUM_ENCODER_LAYERS = 2
NUM_DECODER_LAYERS = 2

# Rough size estimate: embedding table plus a vocabulary-sized term,
# reported at 2 bytes per parameter.
embedding_params = CHAR_VOCAB_SIZE * EMBEDDING_DIM
vocab_scaled_params = GRU_UNITS * 4 * CHAR_VOCAB_SIZE
est_params = embedding_params + vocab_scaled_params
est_megabytes = est_params * 2 / 1024 / 1024
print(f'üìä Est. ~{est_megabytes:.1f} MB FP16')

## 1. Load or Build Cached Data

Uses zenz dataset **directly** - no tokenization needed!

In [None]:
import json
import numpy as np
from tqdm.auto import tqdm

VOCAB_CACHE = f"{CACHE_DIR}/vocabulary_v3.json"
TENSORS_CACHE = f"{CACHE_DIR}/tensors_v3.npz"

def cache_exists():
    """Return True when both the vocabulary JSON and the tensor archive exist."""
    return os.path.exists(VOCAB_CACHE) and os.path.exists(TENSORS_CACHE)

# Stays False unless a cache is found, loaded, and matches the current limits.
CACHE_LOADED = False

if cache_exists() and not FORCE_REBUILD_CACHE:
    print("üì¶ Loading from cache...")

    with open(VOCAB_CACHE, 'r', encoding='utf-8') as f:
        vocab_data = json.load(f)
    char_to_idx = vocab_data['char_to_idx']
    # JSON object keys are strings; turn the ids back into integers.
    idx_to_char = {int(k): v for k, v in vocab_data['idx_to_char'].items()}
    vocab_size = len(char_to_idx)

    # The archive keeps its file open until closed, so read the arrays inside
    # a context manager instead of leaking the handle.
    with np.load(TENSORS_CACHE) as tensors:
        encoder_inputs = tensors['encoder_inputs']
        decoder_inputs = tensors['decoder_inputs']
        decoder_targets = tensors['decoder_targets']

    # A cache written under different length limits has different sequence
    # widths and would only fail later, when fed to the fixed-shape model
    # inputs. Detect that here and rebuild instead.
    expected_widths = (MAX_ENCODER_LEN, MAX_OUTPUT_LEN + 1, MAX_OUTPUT_LEN + 1)
    cached_widths = tuple(
        arr.shape[-1] for arr in (encoder_inputs, decoder_inputs, decoder_targets)
    )
    if cached_widths == expected_widths:
        print(f"‚úì Loaded {len(encoder_inputs):,} samples")
        CACHE_LOADED = True
    else:
        print(
            f"Cached sequence widths {cached_widths} do not match the "
            f"configured {expected_widths}; ignoring the cache."
        )

if not CACHE_LOADED:
    print("üî® Building from scratch (using dataset directly)...")

In [None]:
# Load dataset directly - no tokenization needed!
# Only runs when no usable cache was loaded.
if not CACHE_LOADED:
    from datasets import load_dataset

    print("üì• Loading zenz dataset...")
    # Select one JSONL file of the dataset repository and its "train" split.
    zenz_options = {
        'data_files': "train_wikipedia.jsonl",
        'split': "train",
    }
    dataset = load_dataset("Miwa-Keita/zenz-v2.5-dataset", **zenz_options)
    print(f"‚úì Raw dataset: {len(dataset):,} items")

In [None]:
# Filter and prepare training data (fast - no tokenization!)
# Builds `training_data`: one dict per kept example with the joined encoder
# text, the target text, and the kana length used later for bucketing.
if not CACHE_LOADED:
    print("\nüîç Filtering and preparing data...")

    training_data = []
    skipped = {'too_long': 0, 'empty': 0}

    for item in tqdm(dataset, desc="Processing"):
        # Missing or null fields are normalised to empty strings.
        kana_input = item.get('input') or ''
        kanji_output = item.get('output') or ''
        left_context = item.get('left_context') or ''

        # Both the reading and its conversion are required.
        if not (kana_input and kanji_output):
            skipped['empty'] += 1
            continue

        # Every field must fit within its configured character limit.
        within_limits = (
            len(left_context) <= MAX_CONTEXT_LEN
            and len(kana_input) <= MAX_INPUT_LEN
            and len(kanji_output) <= MAX_OUTPUT_LEN
        )
        if not within_limits:
            skipped['too_long'] += 1
            continue

        # Format: "context<SEP>kana" -> "kanji"
        training_data.append({
            'input': left_context + '<SEP>' + kana_input,
            'output': kanji_output,
            'input_len': len(kana_input),
        })

        # A falsy MAX_SAMPLES (None or 0) disables the cap.
        if MAX_SAMPLES and len(training_data) >= MAX_SAMPLES:
            break

    print(f"\n‚úì Valid examples: {len(training_data):,}")
    print(f"  Skipped (too long): {skipped['too_long']:,}")
    print(f"  Skipped (empty): {skipped['empty']:,}")

In [None]:
# Sort by length for bucketing (helps GRU training stability)
if not CACHE_LOADED:
    print("\nüìä Sorting by length (bucketing)...")

    def _kana_length(example):
        """Sort key: number of kana characters in the example."""
        return example['input_len']

    # In-place sort; list.sort is stable, so ties keep their dataset order.
    training_data.sort(key=_kana_length)

    # Show length distribution
    lengths = [example['input_len'] for example in training_data]
    short_count = sum(1 for n in lengths if n <= 10)
    long_count = sum(1 for n in lengths if n > 20)
    # Whatever is neither short nor long lies in 11..20.
    medium_count = len(lengths) - short_count - long_count
    print(f"  Short (0-10): {short_count:,}")
    print(f"  Medium (11-20): {medium_count:,}")
    print(f"  Long (21+): {long_count:,}")

In [None]:
# Build vocabulary
# Character-level vocabulary: special tokens first, then the most frequent
# characters, capped at CHAR_VOCAB_SIZE entries in total.
if not CACHE_LOADED:
    from collections import Counter

    print("\nüìù Building vocabulary...")
    char_counts = Counter()

    for example in tqdm(training_data, desc="Counting chars"):
        # Drop the separator marker so its letters are not counted as text.
        source_chars = example['input'].replace('<SEP>', '')
        char_counts.update(source_chars + example['output'])

    # Ids follow list order: specials keep 0..4, characters follow by rank.
    char_budget = CHAR_VOCAB_SIZE - len(SPECIAL_TOKENS)
    ranked_chars = [ch for ch, _ in char_counts.most_common(char_budget)]
    ordered_symbols = SPECIAL_TOKENS + ranked_chars

    char_to_idx = {symbol: i for i, symbol in enumerate(ordered_symbols)}
    idx_to_char = dict(enumerate(ordered_symbols))
    vocab_size = len(char_to_idx)
    print(f"‚úì Vocab size: {vocab_size}")

In [None]:
# Encode to tensors
# Turns `training_data` into three fixed-width int32 arrays and writes them,
# together with the vocabulary, to the cache files.
if not CACHE_LOADED:
    print("\nüî¢ Encoding to tensors...")

    PAD, UNK = char_to_idx['<PAD>'], char_to_idx['<UNK>']

    def encode_input(text):
        """Map "context<SEP>kana" text to MAX_ENCODER_LEN ids.

        Each literal "<SEP>" becomes a single separator token and every other
        character is looked up individually (unknown characters map to UNK).
        The result is truncated on the right, then padded with PAD.
        """
        tokens = []
        for position, segment in enumerate(text.split('<SEP>')):
            if position:
                # A separator sits between consecutive segments.
                tokens.append('<SEP>')
            tokens.extend(segment)
        ids = [char_to_idx.get(tok, UNK) for tok in tokens[:MAX_ENCODER_LEN]]
        return ids + [PAD] * (MAX_ENCODER_LEN - len(ids))

    def encode_output(text, add_bos=False, add_eos=False):
        """Map target text to MAX_OUTPUT_LEN + 1 ids.

        Optionally prepends <BOS> and/or appends <EOS>; the sequence is
        truncated on the right, then padded with PAD.
        """
        width = MAX_OUTPUT_LEN + 1
        prefix = ['<BOS>'] if add_bos else []
        suffix = ['<EOS>'] if add_eos else []
        tokens = prefix + list(text) + suffix
        ids = [char_to_idx.get(tok, UNK) for tok in tokens[:width]]
        return ids + [PAD] * (width - len(ids))

    # Encoder side: "context<SEP>kana".
    encoder_rows = [encode_input(d['input']) for d in tqdm(training_data, desc="Encoding")]
    encoder_inputs = np.array(encoder_rows, dtype=np.int32)

    # Decoder side: <BOS>-prefixed inputs and <EOS>-suffixed targets, so the
    # target row is the input row shifted by one position.
    decoder_inputs = np.array(
        [encode_output(d['output'], add_bos=True) for d in training_data],
        dtype=np.int32,
    )
    decoder_targets = np.array(
        [encode_output(d['output'], add_eos=True) for d in training_data],
        dtype=np.int32,
    )

    # Save cache
    print("\nüíæ Saving cache...")
    vocab_payload = {
        'char_to_idx': char_to_idx,
        # JSON object keys must be strings.
        'idx_to_char': {str(k): v for k, v in idx_to_char.items()},
    }
    with open(VOCAB_CACHE, 'w', encoding='utf-8') as f:
        json.dump(vocab_payload, f, ensure_ascii=False)

    np.savez_compressed(
        TENSORS_CACHE,
        encoder_inputs=encoder_inputs,
        decoder_inputs=decoder_inputs,
        decoder_targets=decoder_targets,
    )
    print("‚úì Cached!")

# Printed on both paths (cache loaded or freshly built).
print(f"\nüìä Data shape: {encoder_inputs.shape}")

## 2. Create Dataset

In [None]:
import tensorflow as tf

# Shuffle within buckets (keep similar lengths together)
def bucket_shuffle(data, bucket_size=50000):
    """Return a permutation of range(len(data)) shuffled bucket by bucket.

    Positions are split into consecutive runs of `bucket_size` (the last run
    may be shorter) and each run is shuffled independently with NumPy's global
    random state, so an index never leaves its own bucket.
    """
    total = len(data)
    order = list(range(total))
    for lo in range(0, total, bucket_size):
        hi = min(lo + bucket_size, total)
        chunk = order[lo:hi]
        np.random.shuffle(chunk)
        order[lo:hi] = chunk
    return np.array(order)

# Apply one bucket-wise permutation to all three arrays so rows stay aligned.
idx = bucket_shuffle(encoder_inputs)
encoder_inputs = encoder_inputs[idx]
decoder_inputs = decoder_inputs[idx]
decoder_targets = decoder_targets[idx]

# Split
# The rows are ordered by kana length and only shuffled inside each bucket,
# so cutting off the tail would make the validation set consist solely of
# the longest inputs, which the training set would then never see. Holding
# out every VAL_STRIDE-th row instead gives both splits the same length
# distribution while keeping the remaining rows in their bucketed order.
# Datasets with fewer than VAL_STRIDE rows end up with no validation rows.
VAL_STRIDE = 10
is_val = (np.arange(len(encoder_inputs)) % VAL_STRIDE) == VAL_STRIDE - 1
is_train = ~is_val
split = int(is_train.sum())  # number of training rows


def _make_dataset(rows):
    """Build a batched, prefetched dataset from the rows chosen by a boolean mask.

    Each element pairs the two model inputs (keyed by the names of the model's
    Input layers) with the decoder targets.
    """
    features = {
        'encoder_input': encoder_inputs[rows],
        'decoder_input': decoder_inputs[rows],
    }
    dataset = tf.data.Dataset.from_tensor_slices((features, decoder_targets[rows]))
    return dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)


train_ds = _make_dataset(is_train)
val_ds = _make_dataset(is_val)

print(f"Train: {split:,}, Val: {len(encoder_inputs)-split:,}")

## 3. Build Model (Bidirectional GRU + Attention)

In [None]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Embedding, GRU, Dense, Dropout, Bidirectional, Attention, Concatenate, LayerNormalization

# Shared embedding
# The same Embedding layer is applied to both the encoder and decoder ids.
# NOTE(review): mask_zero is not set, so PAD positions are presumably treated
# as ordinary tokens by the GRUs, the attention, the loss and the accuracy
# metric - confirm whether the reported accuracy is inflated by padding.
emb = Embedding(vocab_size, EMBEDDING_DIM, name='embedding')

# Encoder (Bidirectional GRU - reads input forward and backward)
# Each layer returns the full sequence and is followed by LayerNormalization.
enc_in = Input(shape=(MAX_ENCODER_LEN,), dtype='int32', name='encoder_input')
x = Dropout(0.1)(emb(enc_in))
for i in range(NUM_ENCODER_LAYERS):
    x = LayerNormalization()(Bidirectional(GRU(GRU_UNITS, return_sequences=True), name=f'enc_{i+1}')(x))
enc_out = x

# Decoder (GRU with Attention)
# Unidirectional GRUs over the <BOS>-prefixed target sequence. No initial
# state is passed from the encoder; the attention below is the only link
# between the two stacks. The units are GRU_UNITS * 2, presumably to match
# the width of the concatenated bidirectional encoder output.
dec_in = Input(shape=(MAX_OUTPUT_LEN + 1,), dtype='int32', name='decoder_input')
y = Dropout(0.1)(emb(dec_in))
for i in range(NUM_DECODER_LAYERS):
    y = LayerNormalization()(GRU(GRU_UNITS * 2, return_sequences=True, name=f'dec_{i+1}')(y))

# Attention mechanism
# Query = decoder states, value (and key) = encoder outputs.
ctx = Attention(use_scale=True, name='attn')([y, enc_out])

# Output
# Decoder state and attention context are concatenated, normalised and
# projected to a per-position distribution over the vocabulary.
combined = Concatenate()([y, ctx])
combined = LayerNormalization()(combined)
combined = Dropout(0.2)(combined)
combined = Dense(GRU_UNITS * 2, activation='relu')(combined)
out = Dense(vocab_size, activation='softmax', name='output')(combined)

# Targets are integer ids, hence the sparse categorical cross-entropy loss.
model = Model([enc_in, dec_in], out, name='kana_kanji_v3')
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Size estimate assumes 2 bytes per parameter.
print(f"üìä {model.count_params():,} params, ~{model.count_params()*2/1024/1024:.1f}MB FP16")
model.summary()

## 4. Train

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau

# Keep the checkpoint with the highest validation accuracy.
checkpoint_cb = ModelCheckpoint(
    f'{MODEL_DIR}/best.keras',
    monitor='val_accuracy',
    save_best_only=True,
    verbose=1,
)
# Stop after 5 epochs without validation-loss improvement and restore the
# best weights seen.
early_stop_cb = EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
)
# Halve the learning rate after 2 stagnant epochs, down to 1e-6.
lr_decay_cb = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=2,
    min_lr=1e-6,
    verbose=1,
)
callbacks = [checkpoint_cb, early_stop_cb, lr_decay_cb]

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=NUM_EPOCHS,
    callbacks=callbacks,
)

In [None]:
import matplotlib.pyplot as plt

# Side-by-side training/validation curves for loss and accuracy.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
curves = history.history

for axis, metric, title in ((ax1, 'loss', 'Loss'), (ax2, 'accuracy', 'Accuracy')):
    axis.plot(curves[metric], label='Train')
    axis.plot(curves[f'val_{metric}'], label='Val')
    axis.legend()
    axis.set_title(title)

# Save the figure next to the model before displaying it.
plt.savefig(f'{MODEL_DIR}/training.png')
plt.show()

best_val_accuracy = max(curves['val_accuracy'])
print(f"Best val accuracy: {best_val_accuracy*100:.2f}%")

## 5. Save

In [None]:
# Save model and config
model.save(f'{MODEL_DIR}/model.keras')

# Both vocabulary mappings are written as UTF-8 JSON with non-ASCII
# characters kept verbatim; integer ids become string keys.
vocab_exports = (
    ('char_to_idx.json', char_to_idx),
    ('idx_to_char.json', {str(k): v for k, v in idx_to_char.items()}),
)
for filename, mapping in vocab_exports:
    with open(f'{MODEL_DIR}/{filename}', 'w', encoding='utf-8') as f:
        json.dump(mapping, f, ensure_ascii=False)

# Sequence widths and separator needed to prepare inputs for the saved model.
inference_config = {
    'vocab_size': vocab_size,
    'max_encoder_len': MAX_ENCODER_LEN,
    'max_output_len': MAX_OUTPUT_LEN + 1,
    'sep_token': SEP_TOKEN,
}
with open(f'{MODEL_DIR}/config.json', 'w') as f:
    json.dump(inference_config, f)

print("‚úì Model saved")

In [None]:
# Export TFLite
def _write_tflite(filename, flatbuffer):
    """Write converted model bytes into MODEL_DIR and print the size in MB."""
    with open(f'{MODEL_DIR}/{filename}', 'wb') as f:
        f.write(flatbuffer)
    print(f"‚úì {filename} ({len(flatbuffer)/(1024*1024):.2f}MB)")


# Best effort: any conversion failure is reported instead of aborting the run.
try:
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    # Allow selected TensorFlow ops in addition to the TFLite built-ins.
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    # NOTE(review): private converter flag - presumably required for the
    # recurrent layers; confirm against the installed TensorFlow version.
    converter._experimental_lower_tensor_list_ops = False

    tflite = converter.convert()
    _write_tflite('model.tflite', tflite)

    # FP16 version: the same converter, re-run with float16 enabled.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_types = [tf.float16]
    tflite16 = converter.convert()
    _write_tflite('model_fp16.tflite', tflite16)

except Exception as e:
    print(f"‚ö† TFLite export failed: {e}")

## 6. Verification

In [None]:
banner = "=" * 60
print(banner)
print("VERIFICATION")
print(banner)

# Ids of the special tokens used by the inference helpers below.
PAD = char_to_idx['<PAD>']
BOS = char_to_idx['<BOS>']
EOS = char_to_idx['<EOS>']
UNK = char_to_idx['<UNK>']
SEP = char_to_idx['<SEP>']

def encode_input_for_inference(context, kana):
    """Encode a (context, kana) pair as a (1, MAX_ENCODER_LEN) int32 array.

    Args:
        context: Text to the left of the kana. None is treated as empty. Only
            the last MAX_CONTEXT_LEN characters are used, matching the limit
            that training examples were filtered by.
        kana: Reading to convert.

    Returns:
        A single-row array of vocabulary ids: context characters, one
        separator token, then kana characters, padded with PAD on the right.
        Characters missing from the vocabulary map to UNK.
    """
    # A null context must not be formatted as the literal text "None".
    context = context or ''
    # The id list is truncated on the right, so an over-long context would
    # fill the buffer and cut off the kana. Keep only the nearest context.
    if len(context) > MAX_CONTEXT_LEN:
        context = context[len(context) - MAX_CONTEXT_LEN:]

    text = f"{context}<SEP>{kana}"
    tokens = []
    # Splitting on the marker turns every literal "<SEP>" into one token.
    for position, segment in enumerate(text.split('<SEP>')):
        if position:
            tokens.append('<SEP>')
        tokens.extend(segment)

    ids = [char_to_idx.get(tok, UNK) for tok in tokens[:MAX_ENCODER_LEN]]
    ids = ids + [PAD] * (MAX_ENCODER_LEN - len(ids))
    return np.array([ids], dtype=np.int32)

def convert(context, kana):
    """Greedily decode the conversion of `kana` given `context`.

    Runs the model once per output position, feeding each argmax prediction
    back as the next decoder input. Decoding stops at <EOS> or after
    MAX_OUTPUT_LEN steps. Special tokens are left out of the returned string.
    """
    special_ids = (PAD, BOS, EOS, UNK, SEP)
    encoder_ids = encode_input_for_inference(context, kana)

    # Decoder buffer starts as <BOS> followed by zeros.
    decoder_ids = np.zeros((1, MAX_OUTPUT_LEN + 1), dtype=np.int32)
    decoder_ids[0, 0] = BOS

    pieces = []
    for step in range(MAX_OUTPUT_LEN):
        probs = model.predict(
            {'encoder_input': encoder_ids, 'decoder_input': decoder_ids},
            verbose=0,
        )
        token_id = int(np.argmax(probs[0, step]))

        if token_id == EOS:
            break
        if token_id not in special_ids:
            pieces.append(idx_to_char.get(token_id, ''))

        # step + 1 never exceeds MAX_OUTPUT_LEN, the last valid buffer index.
        decoder_ids[0, step + 1] = token_id

    return ''.join(pieces)

# Test cases (context, kana, expected kanji)
# Homophone groups: the same kana must convert differently per context.
# The strings were stored as mojibake (UTF-8 bytes shown as Mac Roman) and
# are restored here so the model receives real Japanese text.
tests = [
    ("今日はとても", "アツイ", "暑い"),
    ("お茶が", "アツイ", "熱い"),
    ("この辞典は", "アツイ", "厚い"),
    ("毎朝起きるのが", "ハヤイ", "早い"),
    ("彼は走るのが", "ハヤイ", "速い"),
    ("川に", "ハシ", "橋"),
    ("ご飯を", "ハシ", "箸"),
    ("道の", "ハシ", "端"),
    ("音楽を", "キク", "聴く"),
    ("道を", "キク", "聞く"),
    ("写真を", "トル", "撮る"),
    ("塩を", "トル", "取る"),
]

correct = 0
# The loop variable is not named `ctx`, so the attention tensor of that name
# from the model cell is left untouched.
for left_ctx, kana, expected in tests:
    result = convert(left_ctx, kana)
    # Lenient match: either string may contain the other. An empty prediction
    # must never pass, because '' is a substring of every string.
    ok = bool(result) and (expected in result or result in expected)
    if ok:
        correct += 1
    print(f"{'✓' if ok else '✗'} {left_ctx}<SEP>{kana} → {result} (expected: {expected})")

print(f"\n✅ Score: {correct}/{len(tests)} ({correct/len(tests)*100:.0f}%)")

In [None]:
# List saved files
# Regular files in MODEL_DIR, sorted by name; sizes above 1 MiB are shown in
# MB, everything else in KB.
print(f"\nüì¶ Files ({PLATFORM}):")
for f in sorted(os.listdir(MODEL_DIR)):
    p = f'{MODEL_DIR}/{f}'
    if not os.path.isfile(p):
        continue
    s = os.path.getsize(p)
    size_label = f"{s/(1024*1024):.2f}MB" if s > 1024*1024 else f"{s/1024:.1f}KB"
    print(f"  {f}: {size_label}")