# Context-Aware Kana-Kanji Converter (v2)

**Supports**: Google Colab & Kaggle

**Key Features**:
- Model size controlled by architecture (target: <20MB FP16)
- Preprocessing cached for fast subsequent runs
- **Multiprocessing** for 4-8x faster tokenization

**Input Format**: `context<SEP>kana` → `kanji`

**Example**: `写真を<SEP>とる` → `撮る` (the context 写真を "photo" selects 撮る "take [a picture]" among the homophones of とる, such as 取る and 捕る)

In [None]:
import os

# Work out where artifacts should live. Kaggle exposes a `/kaggle` mount;
# any other host is assumed to be Colab, where Google Drive is mounted so the
# model and preprocessing cache survive the runtime being recycled.
_on_kaggle = os.path.exists('/kaggle')
PLATFORM = 'Kaggle' if _on_kaggle else 'Colab'
if _on_kaggle:
    DRIVE_DIR = '/kaggle/working'
else:
    from google.colab import drive
    drive.mount('/content/drive')
    DRIVE_DIR = '/content/drive/MyDrive/Keyboard-Suggestions-ML-Colab'

# Output folders for trained models and for the cached preprocessed tensors.
MODEL_DIR = f"{DRIVE_DIR}/models/gru_japanese_kana_kanji"
CACHE_DIR = f"{DRIVE_DIR}/cache/kana_kanji"
for _output_dir in (MODEL_DIR, CACHE_DIR):
    os.makedirs(_output_dir, exist_ok=True)

print(f"‚úÖ Platform: {PLATFORM}")
print(f"üìÅ Model: {MODEL_DIR}")
print(f"üíæ Cache: {CACHE_DIR}")

In [None]:
# Install the packages used by the cells below. The leading `!` runs a shell
# command from an IPython cell, so this line only works inside a notebook.
# NOTE(review): matplotlib (imported in the plotting cell) is not listed here;
# presumably it is preinstalled on Colab/Kaggle — confirm.
!pip install -q tensorflow keras datasets numpy tqdm fugashi unidic-lite

In [None]:
# ===========================================================
# CONFIGURATION
# ===========================================================
TESTING_MODE = False
MAX_SAMPLES = None  # Use ALL data
BATCH_SIZE = 512
FORCE_REBUILD_CACHE = False
NUM_WORKERS = 4  # ← Parallel workers (increase for more CPUs)

NUM_EPOCHS = 10 if TESTING_MODE else 20

# Architecture (controls model size)
CHAR_VOCAB_SIZE = 6000
MAX_INPUT_LEN = 50
MAX_OUTPUT_LEN = 20
EMBEDDING_DIM = 64
GRU_UNITS = 128
NUM_ENCODER_LAYERS = 2
NUM_DECODER_LAYERS = 2


def _gru_param_count(input_dim, units):
    """Return the weight count of one Keras ``GRU(units)`` layer.

    Assumes the Keras default ``reset_after=True``: three gates, each with an
    input kernel, a recurrent kernel and two bias vectors.
    """
    return 3 * (units * (input_dim + units) + 2 * units)


def estimate_param_count(vocab_size=CHAR_VOCAB_SIZE):
    """Estimate the parameter count of the model built in "3. Build Model".

    Walks the same layer stack one layer at a time (keep both in sync when the
    architecture changes). ``vocab_size`` defaults to the configured cap; the
    vocabulary actually built from the data may be smaller.

    Returns:
        The estimated number of trainable parameters (int).
    """
    enc_width = 2 * GRU_UNITS   # bidirectional encoder output width
    dec_width = 2 * GRU_UNITS   # decoder GRUs are built with GRU_UNITS * 2 units
    head_width = 2 * GRU_UNITS  # hidden Dense(GRU_UNITS * 2, relu) before softmax

    total = vocab_size * EMBEDDING_DIM  # embedding shared by encoder and decoder

    in_dim = EMBEDDING_DIM
    for _ in range(NUM_ENCODER_LAYERS):
        total += 2 * _gru_param_count(in_dim, GRU_UNITS)  # forward + backward GRU
        total += 2 * enc_width                            # LayerNorm gamma + beta
        in_dim = enc_width

    in_dim = EMBEDDING_DIM
    for _ in range(NUM_DECODER_LAYERS):
        total += _gru_param_count(in_dim, dec_width) + 2 * dec_width  # GRU + LayerNorm
        in_dim = dec_width

    total += 1                                        # Attention(use_scale=True) scalar
    concat_width = dec_width + enc_width
    total += 2 * concat_width                         # LayerNorm on [decoder, context]
    total += concat_width * head_width + head_width   # relu Dense
    total += head_width * vocab_size + vocab_size     # softmax output layer
    return total


# Estimate size
est_params = estimate_param_count()
print(f'üìä Est. ~{est_params * 2 / 1024 / 1024:.1f} MB FP16')

SPECIAL_TOKENS = ['<PAD>', '<UNK>', '<BOS>', '<EOS>', '<SEP>']
SEP_TOKEN = '<SEP>'

## 1. Load or Build Cached Data

**First run**: ~10-15 min (with multiprocessing) | **Subsequent**: ~30 sec

In [None]:
import json
import numpy as np
from tqdm import tqdm

VOCAB_CACHE = f"{CACHE_DIR}/vocabulary.json"
TENSORS_CACHE = f"{CACHE_DIR}/tensors.npz"

def cache_exists():
    """Return True when both the vocabulary JSON and the tensor archive exist."""
    return os.path.exists(VOCAB_CACHE) and os.path.exists(TENSORS_CACHE)

CACHE_LOADED = False
if cache_exists() and not FORCE_REBUILD_CACHE:
    print("üì¶ Loading from cache...")

    with open(VOCAB_CACHE, 'r', encoding='utf-8') as f:
        vocab_data = json.load(f)
    char_to_idx = vocab_data['char_to_idx']
    # JSON object keys are always strings; turn the ids back into ints.
    idx_to_char = {int(k): v for k, v in vocab_data['idx_to_char'].items()}
    vocab_size = len(char_to_idx)

    # NpzFile keeps the archive open until closed. Indexing it reads each array
    # fully into memory, so it is safe to close it right after.
    with np.load(TENSORS_CACHE) as tensors:
        encoder_inputs = tensors['encoder_inputs']
        decoder_inputs = tensors['decoder_inputs']
        decoder_targets = tensors['decoder_targets']

    # The model's Inputs have fixed lengths. A cache encoded with different
    # MAX_INPUT_LEN / MAX_OUTPUT_LEN values would only fail at training time,
    # so treat it as stale and let the following cells rebuild it.
    cache_is_stale = (
        encoder_inputs.shape[1:] != (MAX_INPUT_LEN,)
        or decoder_inputs.shape[1:] != (MAX_OUTPUT_LEN,)
        or decoder_targets.shape[1:] != (MAX_OUTPUT_LEN,)
    )
    if cache_is_stale:
        print(f"Cached tensors {encoder_inputs.shape}/{decoder_inputs.shape} do not match "
              f"MAX_INPUT_LEN={MAX_INPUT_LEN}, MAX_OUTPUT_LEN={MAX_OUTPUT_LEN}; rebuilding...")
    else:
        print(f"‚úì Loaded {len(encoder_inputs):,} samples")
        CACHE_LOADED = True

if not CACHE_LOADED:
    print("üî® Building from scratch (with multiprocessing)...")

In [None]:
if not CACHE_LOADED:
    from datasets import load_dataset

    DATASET_REPO = "Miwa-Keita/zenz-v2.5-dataset"
    # `train[:N]` is the Hugging Face split-slicing syntax for the first N rows;
    # the same cap is applied on both the primary and the fallback path.
    hf_split = f"train[:{MAX_SAMPLES}]" if MAX_SAMPLES else "train"

    print("Loading dataset...")
    try:
        dataset = load_dataset(DATASET_REPO, data_files="train_wikipedia.jsonl", split=hf_split)
    except Exception as err:
        # Best effort: if the Wikipedia file cannot be loaded, use the
        # repository's default data files instead of aborting.
        print(f"Could not load train_wikipedia.jsonl ({err!r}); falling back to default data files")
        dataset = load_dataset(DATASET_REPO, split=hf_split)

    # Each row's 'output' field is used as one raw sentence.
    sentences = [item['output'] for item in dataset]
    print(f"‚úì Loaded {len(sentences):,} sentences")

In [None]:
# MULTIPROCESSING: 4-8x faster!
if not CACHE_LOADED:
    import multiprocessing as mp
    import fugashi

    # Sentences per pool task. Many small tasks (instead of one huge slice per
    # worker) keep every worker busy until the end, let the progress bar
    # advance smoothly, and send results back in smaller pickles.
    SENTENCES_PER_TASK = 2000

    # Katakana ァ (U+30A1) .. ヶ (U+30F6) sit exactly 0x60 above their hiragana
    # counterparts; every other character (e.g. the long-vowel mark ー) is kept.
    _KATA_TO_HIRA = {code: code - 0x60 for code in range(0x30A1, 0x30F7)}

    def katakana_to_hiragana(text):
        """Return ``text`` with katakana U+30A1..U+30F6 mapped to hiragana."""
        return text.translate(_KATA_TO_HIRA)

    # Per-process tagger, created once by the Pool initializer below.
    _worker_tagger = None

    def _init_worker():
        """Pool initializer: build this worker process's fugashi Tagger."""
        global _worker_tagger
        _worker_tagger = fugashi.Tagger()

    def process_batch(batch_sentences):
        """Turn raw sentences into ``context<SEP>kana -> surface`` training examples.

        Each sentence is tokenised. For every split point, the context is the
        surface text before it and the target is the next 1-3 tokens. The kana
        input is the targets' readings converted to hiragana. Examples with a
        context longer than 30 chars, or kana/target longer than 15 chars, are
        skipped.

        Returns:
            A list of ``{'input': ..., 'output': ...}`` dicts in sentence order.
        """
        # Outside a pool (no initializer ran), build a tagger for this call.
        tagger = _worker_tagger if _worker_tagger is not None else fugashi.Tagger()
        results = []

        for sentence in batch_sentences:
            tokens = []  # (surface, reading) pairs
            for word in tagger(sentence):
                try:
                    reading = word.feature.kana or word.surface
                except Exception:
                    # Best effort: fall back to the surface form when the
                    # token's reading cannot be read.
                    reading = word.surface
                tokens.append((word.surface, reading))

            if len(tokens) < 2:
                continue

            for split_idx in range(1, len(tokens)):
                context = ''.join(surface for surface, _ in tokens[:split_idx])
                # split_idx < len(tokens), so this slice always holds 1-3 tokens.
                target_words = tokens[split_idx:split_idx + 3]
                target = ''.join(surface for surface, _ in target_words)
                kana = katakana_to_hiragana(''.join(reading for _, reading in target_words))

                if len(context) > 30 or len(kana) > 15 or len(target) > 15:
                    continue
                if not kana or not target:
                    continue

                results.append({'input': f"{context}{SEP_TOKEN}{kana}", 'output': target})

        return results

    sentence_chunks = [
        sentences[start:start + SENTENCES_PER_TASK]
        for start in range(0, len(sentences), SENTENCES_PER_TASK)
    ]

    print(f"üöÄ Processing with {NUM_WORKERS} workers...")

    # The worker functions are defined in the notebook's __main__, which only
    # 'fork' children inherit; 'spawn'/'forkserver' children could not
    # unpickle them, so request fork explicitly.
    pool_context = mp.get_context("fork")
    with pool_context.Pool(NUM_WORKERS, initializer=_init_worker) as pool:
        results = list(tqdm(pool.imap(process_batch, sentence_chunks),
                            total=len(sentence_chunks), desc="Batches"))

    # Flatten; imap preserves task order, so examples stay in sentence order.
    training_data = [example for chunk_result in results for example in chunk_result]
    del results

    print(f"‚úì Generated {len(training_data):,} examples")

In [None]:
if not CACHE_LOADED:
    from collections import Counter

    print("Building vocabulary...")
    # Count the characters of context+kana (the SEP marker is not a character)
    # and of the target text.
    char_counts = Counter()
    for d in tqdm(training_data, desc="Counting"):
        char_counts.update(d['input'].replace(SEP_TOKEN, ''))
        char_counts.update(d['output'])

    # Special tokens get the first ids ('<PAD>' -> 0); the most frequent
    # characters fill the rest, up to CHAR_VOCAB_SIZE entries in total.
    char_to_idx = {tok: i for i, tok in enumerate(SPECIAL_TOKENS)}
    for char, _ in char_counts.most_common(CHAR_VOCAB_SIZE - len(SPECIAL_TOKENS)):
        char_to_idx[char] = len(char_to_idx)

    idx_to_char = {v: k for k, v in char_to_idx.items()}
    vocab_size = len(char_to_idx)
    print(f"Vocab: {vocab_size}")

    # Encode
    print("Encoding...")
    PAD, UNK = char_to_idx['<PAD>'], char_to_idx['<UNK>']

    def encode_in(text, max_len):
        """Encode an encoder string (SEP marker kept as one token) into max_len ids.

        Unknown characters map to UNK; the result is truncated or PAD-padded.
        """
        tokens = []
        i = 0
        while i < len(text):
            if text.startswith(SEP_TOKEN, i):
                tokens.append(SEP_TOKEN)
                i += len(SEP_TOKEN)
            else:
                tokens.append(text[i])
                i += 1
        ids = [char_to_idx.get(t, UNK) for t in tokens][:max_len]
        return ids + [PAD] * (max_len - len(ids))

    def encode_out(text, max_len, bos=False, eos=False):
        """Encode a target string, optionally wrapped in BOS/EOS, into max_len ids."""
        tokens = (['<BOS>'] if bos else []) + list(text) + (['<EOS>'] if eos else [])
        ids = [char_to_idx.get(c, UNK) for c in tokens][:max_len]
        return ids + [PAD] * (max_len - len(ids))

    # Teacher forcing: the decoder reads <BOS>+target and predicts target+<EOS>.
    encoder_inputs = np.array([encode_in(d['input'], MAX_INPUT_LEN) for d in tqdm(training_data)], dtype=np.int32)
    decoder_inputs = np.array([encode_out(d['output'], MAX_OUTPUT_LEN, bos=True) for d in training_data], dtype=np.int32)
    decoder_targets = np.array([encode_out(d['output'], MAX_OUTPUT_LEN, eos=True) for d in training_data], dtype=np.int32)

    # Save cache atomically: write temporary files and rename them into place
    # only once complete, so an interrupted run cannot leave a truncated cache
    # that cache_exists() would accept next time.
    print("üíæ Saving cache...")
    vocab_tmp = f"{VOCAB_CACHE}.tmp"
    tensors_tmp = f"{TENSORS_CACHE}.tmp"
    with open(vocab_tmp, 'w', encoding='utf-8') as f:
        json.dump({'char_to_idx': char_to_idx, 'idx_to_char': {str(k): v for k, v in idx_to_char.items()}}, f, ensure_ascii=False)
    # Pass a file object: given a path without ".npz", NumPy appends the suffix.
    with open(tensors_tmp, 'wb') as f:
        np.savez_compressed(f, encoder_inputs=encoder_inputs, decoder_inputs=decoder_inputs, decoder_targets=decoder_targets)
    # Drop the old vocabulary first: if the run dies between the two renames,
    # the cache is incomplete and the next run rebuilds it.
    if os.path.exists(VOCAB_CACHE):
        os.remove(VOCAB_CACHE)
    os.replace(tensors_tmp, TENSORS_CACHE)
    os.replace(vocab_tmp, VOCAB_CACHE)
    print("‚úì Cached!")

print(f"Data: {encoder_inputs.shape}")

## 2. Create Dataset

In [None]:
import tensorflow as tf

# Shuffle every example once, then keep the last 10% for validation.
idx = np.random.permutation(len(encoder_inputs))
encoder_inputs = encoder_inputs[idx]
decoder_inputs = decoder_inputs[idx]
decoder_targets = decoder_targets[idx]

split = int(len(encoder_inputs) * 0.9)


def _slice_to_dataset(rows):
    """Wrap the rows selected by ``rows`` as an (inputs dict, targets) tf.data.Dataset."""
    features = {'encoder_input': encoder_inputs[rows], 'decoder_input': decoder_inputs[rows]}
    return tf.data.Dataset.from_tensor_slices((features, decoder_targets[rows]))


train_ds = (
    _slice_to_dataset(slice(None, split))
    .shuffle(10000)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)
val_ds = (
    _slice_to_dataset(slice(split, None))
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

print(f"Train: {split:,}, Val: {len(encoder_inputs)-split:,}")

## 3. Build Model

In [None]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Embedding, GRU, Dense, Dropout, Bidirectional, Attention, Concatenate, LayerNormalization

# One embedding table shared by encoder and decoder.
emb = Embedding(vocab_size, EMBEDDING_DIM, name='embedding')

# Encoder: stacked bidirectional GRUs over the context<SEP>kana ids.
enc_in = Input(shape=(MAX_INPUT_LEN,), dtype='int32', name='encoder_input')
x = Dropout(0.1)(emb(enc_in))
for i in range(NUM_ENCODER_LAYERS):
    x = LayerNormalization()(Bidirectional(GRU(GRU_UNITS, return_sequences=True), name=f'enc_{i+1}')(x))
enc_out = x

# Decoder: GRUs with GRU_UNITS * 2 units, matching the bidirectional encoder
# width that the dot-product Attention below compares against.
dec_in = Input(shape=(MAX_OUTPUT_LEN,), dtype='int32', name='decoder_input')
y = Dropout(0.1)(emb(dec_in))
for i in range(NUM_DECODER_LAYERS):
    y = LayerNormalization()(GRU(GRU_UNITS * 2, return_sequences=True, name=f'dec_{i+1}')(y))

# Attention + Output
ctx = Attention(use_scale=True, name='attn')([y, enc_out])
out = Dense(vocab_size, activation='softmax', name='output')(
    Dense(GRU_UNITS * 2, activation='relu')(Dropout(0.2)(LayerNormalization()(Concatenate()([y, ctx]))))
)

model = Model([enc_in, dec_in], out, name='kana_kanji_v2')

# Targets are PAD-filled up to MAX_OUTPUT_LEN, so most positions are padding.
# Scoring them would bias training toward predicting PAD and inflate the
# accuracy, so both the loss and the metric only count non-PAD positions.
PAD_ID = char_to_idx['<PAD>']


def masked_sparse_ce(y_true, y_pred):
    """Mean sparse categorical cross-entropy over non-PAD target positions."""
    y_true = tf.cast(y_true, tf.int32)
    per_token = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
    mask = tf.cast(tf.not_equal(y_true, PAD_ID), per_token.dtype)
    return tf.reduce_sum(per_token * mask) / tf.maximum(tf.reduce_sum(mask), 1.0)


def masked_token_accuracy(y_true, y_pred):
    """Fraction of non-PAD target positions whose argmax prediction is correct."""
    y_true = tf.cast(y_true, tf.int32)
    predicted = tf.cast(tf.argmax(y_pred, axis=-1), tf.int32)
    mask = tf.not_equal(y_true, PAD_ID)
    hits = tf.cast(tf.logical_and(mask, tf.equal(y_true, predicted)), tf.float32)
    return tf.reduce_sum(hits) / tf.maximum(tf.reduce_sum(tf.cast(mask, tf.float32)), 1.0)


# The metric is registered as 'accuracy' so history keys ('accuracy',
# 'val_accuracy') and callbacks monitoring them keep working. It reports the
# mean of per-batch token accuracies.
# NOTE: reloading a saved .keras file with compile=True needs these functions
# passed via custom_objects (or load with compile=False).
model.compile(
    optimizer='adam',
    loss=masked_sparse_ce,
    metrics=[tf.keras.metrics.MeanMetricWrapper(masked_token_accuracy, name='accuracy')],
)
print(f"üìä {model.count_params():,} params, ~{model.count_params()*2/1024/1024:.1f}MB FP16")

## 4. Train

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau

# Keep the checkpoint with the best validation accuracy on disk.
checkpoint_cb = ModelCheckpoint(
    f'{MODEL_DIR}/best.keras', monitor='val_accuracy', save_best_only=True, verbose=1
)
# Stop after 5 epochs without val_loss improvement and restore the best-loss weights.
# NOTE(review): this monitors val_loss while the checkpoint monitors
# val_accuracy, so best.keras and the in-memory weights may come from
# different epochs.
early_stop_cb = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
# Halve the learning rate after 2 stagnant epochs, down to a floor of 1e-6.
lr_decay_cb = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, min_lr=1e-6, verbose=1)

callbacks = [checkpoint_cb, early_stop_cb, lr_decay_cb]
history = model.fit(
    train_ds,
    epochs=NUM_EPOCHS,
    validation_data=val_ds,
    callbacks=callbacks,
)

In [None]:
import matplotlib.pyplot as plt

# Side-by-side train/validation curves for loss and accuracy.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
for panel, metric_key, panel_title in ((ax1, 'loss', 'Loss'), (ax2, 'accuracy', 'Accuracy')):
    panel.plot(history.history[metric_key], label='Train')
    panel.plot(history.history[f'val_{metric_key}'], label='Val')
    panel.legend()
    panel.set_title(panel_title)

plt.savefig(f'{MODEL_DIR}/training.png')
plt.show()

best_val_accuracy = max(history.history['val_accuracy'])
print(f"Best: {best_val_accuracy*100:.2f}%")

## 5. Save

In [None]:
model.save(f'{MODEL_DIR}/model.keras')

# Vocabulary lookups needed to encode inputs and decode outputs at inference time.
with open(f'{MODEL_DIR}/char_to_idx.json', 'w', encoding='utf-8') as f:
    json.dump(char_to_idx, f, ensure_ascii=False)

# JSON object keys must be strings, so the integer ids are stringified.
with open(f'{MODEL_DIR}/idx_to_char.json', 'w', encoding='utf-8') as f:
    json.dump({str(k): v for k, v in idx_to_char.items()}, f, ensure_ascii=False)

inference_config = {
    'vocab_size': vocab_size,
    'max_input_len': MAX_INPUT_LEN,
    'max_output_len': MAX_OUTPUT_LEN,
    'sep_token': SEP_TOKEN,
}
with open(f'{MODEL_DIR}/config.json', 'w') as f:
    json.dump(inference_config, f)

print("‚úì Saved")

In [None]:
def _write_tflite(filename, flatbuffer):
    """Write a converted TFLite flatbuffer into MODEL_DIR and report its size."""
    with open(f'{MODEL_DIR}/{filename}', 'wb') as f:
        f.write(flatbuffer)
    print(f"‚úì {filename} ({len(flatbuffer)/(1024*1024):.2f}MB)")


# Best-effort export: any conversion failure is reported, not raised.
try:
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    # Allow Select TF ops alongside the builtins and keep tensor-list ops
    # un-lowered — presumably required by the GRU layers; keep as-is.
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
    converter._experimental_lower_tensor_list_ops = False

    # Full-precision model first.
    tflite = converter.convert()
    _write_tflite('model.tflite', tflite)

    # Then the same converter with float16 weight quantization enabled.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_types = [tf.float16]
    tflite16 = converter.convert()
    _write_tflite('model_fp16.tflite', tflite16)
except Exception as e:
    print(f"‚ö† {e}")

## 6. Verification

In [None]:
print("="*60)
print("VERIFICATION")
print("="*60)

PAD, BOS, EOS, UNK, SEP = [char_to_idx[t] for t in ['<PAD>', '<BOS>', '<EOS>', '<UNK>', '<SEP>']]
# Control-token ids that are never emitted as characters.
CONTROL_IDS = {PAD, BOS, EOS, UNK, SEP}

def enc(text, max_len):
    """Encode ``text`` (``<SEP>`` kept as one token) into ``max_len`` ids, PAD-padded."""
    tokens = []
    i = 0
    while i < len(text):
        if text[i:i+5] == '<SEP>':
            tokens.append('<SEP>')
            i += 5
        else:
            tokens.append(text[i])
            i += 1
    ids = [char_to_idx.get(t, UNK) for t in tokens][:max_len]
    return ids + [PAD] * (max_len - len(ids))

def convert(context, kana):
    """Greedy-decode the kanji spelling of ``kana`` given the preceding ``context``.

    Runs the full model once per output position, feeding back the
    BOS-prefixed prefix generated so far, and stops at EOS or after
    MAX_OUTPUT_LEN steps. Control tokens are dropped from the result.
    """
    enc_in = np.array([enc(f"{context}<SEP>{kana}", MAX_INPUT_LEN)], dtype=np.int32)
    dec_in = np.zeros((1, MAX_OUTPUT_LEN), dtype=np.int32)
    dec_in[0, 0] = BOS
    result = []
    # Bounded by the decoder length: pred[0, i] does not exist past it.
    for i in range(MAX_OUTPUT_LEN):
        pred = model.predict({'encoder_input': enc_in, 'decoder_input': dec_in}, verbose=0)
        nxt = int(np.argmax(pred[0, i]))
        if nxt == EOS:
            break
        if nxt not in CONTROL_IDS:
            result.append(idx_to_char.get(nxt, ''))
        if i + 1 < MAX_OUTPUT_LEN:
            dec_in[0, i + 1] = nxt
    return ''.join(result)

def is_correct(result, expected):
    """Return True when the decoded ``result`` contains ``expected``.

    Trailing extra output is tolerated. The check is deliberately one-sided:
    ``result in expected`` would also accept an empty result or a lone
    okurigana such as "い", since those are substrings of the expected answer.
    """
    return expected in result

# (context, kana, expected kanji): homophones that only the context can disambiguate.
tests = [
    ("今日はとても", "あつい", "暑い"), ("お茶が", "あつい", "熱い"), ("この辞典は", "あつい", "厚い"),
    ("毎朝起きるのが", "はやい", "早い"), ("彼は走るのが", "はやい", "速い"),
    ("川に", "はし", "橋"), ("ご飯を", "はし", "箸"), ("道の", "はし", "端"),
    ("音楽を", "きく", "聴く"), ("道を", "きく", "聞く"), ("この薬はよく", "きく", "効く"),
    ("写真を", "とる", "撮る"), ("塩を", "とる", "取る"), ("魚を", "とる", "捕る"),
    ("友達に", "あう", "会う"), ("サイズが", "あう", "合う"), ("事故に", "あう", "遭う"),
    ("テスト", "きかん", "期間"), ("交通", "きかん", "機関"), ("宇宙から", "きかん", "帰還"),
]

correct = 0
for test_context, test_kana, expected in tests:
    res = convert(test_context, test_kana)
    ok = is_correct(res, expected)
    if ok:
        correct += 1
    print(f"{'✓' if ok else '✗'} {test_context}<SEP>{test_kana} → {res} (exp: {expected})")

print(f"\n✅ Score: {correct}/{len(tests)} ({correct/len(tests)*100:.0f}%)")

In [None]:
print(f"\nüì¶ Files ({PLATFORM}):")
for f in sorted(os.listdir(MODEL_DIR)):
    p = f'{MODEL_DIR}/{f}'
    if os.path.isfile(p):
        s = os.path.getsize(p)
        print(f"  {f}: {s/(1024*1024):.2f}MB" if s > 1024*1024 else f"  {f}: {s/1024:.1f}KB")