# GRU Japanese Keyboard Model - Dual Mode

**Architecture:** Kana-normalized model with separate kanji conversion

## Features
1. **Kana→Kanji Conversion**: ありがとう → [有り難う, 有難う]
2. **Next Word Prediction**: ありがとう → [ございます, ね, 🙏]
3. **Prefix Completion**: ありが → [ありがとう, ありがたい]
4. **Emoji Suggestions**: From dataset context

## Key Design
- Model trains on **kana only** (no kanji mixing)
- Separate **kana→kanji index** for display conversion
- Consistent predictions regardless of displayed script

---
**Instructions:**
1. Runtime → Change runtime type → GPU (T4)
2. Set `TESTING_MODE = True` for quick test

## 1. Setup

In [None]:
from google.colab import drive
import os

# Mount Google Drive so every exported artifact survives the Colab session.
drive.mount('/content/drive')
DRIVE_DIR = '/content/drive/MyDrive/Keyboard-Suggestions-ML-Colab'
# All model files, indices and exports in this notebook are written here.
MODEL_DIR = f"{DRIVE_DIR}/models/gru_keyboard_japanese"
os.makedirs(MODEL_DIR, exist_ok=True)
print(f"✓ Model directory: {MODEL_DIR}")

In [None]:
# Install notebook dependencies quietly (IPython shell magic, not plain Python).
# `regex` is needed for the \p{...} Unicode property classes used in emoji detection.
!pip install -q tensorflow keras datasets pandas numpy scikit-learn tqdm regex

In [None]:
# ============================================================
# CONFIGURATION
# ============================================================

TESTING_MODE = True  # ← Change to False for full training

# Settings identical for the quick-test and full-training profiles.
BATCH_SIZE = 256
VOCAB_SIZE_LIMIT = 6000
SEQUENCE_LENGTH = 15

# Per-profile settings: (epochs, max dataset samples), keyed by TESTING_MODE.
_PROFILES = {
    True: (5, 200000),
    False: (20, 300000),
}
NUM_EPOCHS, MAX_SAMPLES = _PROFILES[bool(TESTING_MODE)]

# Model size.
EMBEDDING_DIM = 128
GRU_UNITS = 256

# Special tokens (they receive the first vocabulary ids).
PAD_TOKEN = '<PAD>'
UNK_TOKEN = '<UNK>'
BOS_TOKEN = '<BOS>'
EOS_TOKEN = '<EOS>'

print(f"Config: vocab={VOCAB_SIZE_LIMIT:,}, epochs={NUM_EPOCHS}, samples={MAX_SAMPLES:,}")

## 2. Load Dataset

In [None]:
from datasets import load_dataset
import re
import regex
from collections import Counter, defaultdict

DATASET_NAME = "Miwa-Keita/zenz-v2.5-dataset"
DATASET_SPLIT = f"train[:{MAX_SAMPLES}]"

print("Loading zenz-v2.5-dataset...")
print("="*60)

try:
    # Prefer the Wikipedia subset of the dataset.
    dataset = load_dataset(
        DATASET_NAME,
        data_files="train_wikipedia.jsonl",
        split=DATASET_SPLIT,
    )
    print(f"✓ Loaded {len(dataset):,} samples (Wikipedia)")
except Exception as err:
    # Fall back to the default configuration, but say why instead of silently
    # swallowing the error (a bare `except:` would also trap KeyboardInterrupt).
    print(f"Wikipedia subset unavailable ({type(err).__name__}: {err}); loading default split...")
    dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
    print(f"✓ Loaded {len(dataset):,} samples")

# Show up to three samples; `or ''` guards against None fields.
print("\nSamples (input=kana, output=kanji):")
for i in range(min(3, len(dataset))):
    sample = dataset[i]
    kana_preview = (sample.get('input') or '')[:25]
    kanji_preview = (sample.get('output') or '')[:25]
    print(f"  {kana_preview} → {kanji_preview}")

## 3. Build Kana Vocabulary + Kana→Kanji Index

**Key:** Train model on KANA only, use separate kanji lookup

In [None]:
import re
import regex
from collections import Counter, defaultdict

print("Building kana vocabulary + kanji index...")
print("="*60)

# Emoji support: a character class matching ONE code point that has
# Emoji_Presentation or Extended_Pictographic (requires the `regex` module).
EMOJI_PATTERN = regex.compile(r'[\p{Emoji_Presentation}\p{Extended_Pictographic}]')


def is_emoji(char):
    """Return True when *char* starts with an emoji code point."""
    return EMOJI_PATTERN.match(char) is not None


def extract_emojis(text):
    """Return the emoji code points found in *text*, in order of appearance."""
    found = EMOJI_PATTERN.findall(text)
    return found

# Particles / auxiliaries used as split points inside long runs of text.
_PARTICLES = (
    'は', 'が', 'を', 'に', 'で', 'と', 'の', 'から', 'まで', 'より', 'へ', 'や', 'も',
    'か', 'ね', 'よ', 'わ', 'な', 'ら', 'し', 'て', 'た', 'だ', 'です', 'ます',
)
# Regex alternation takes the FIRST alternative that matches, so longer entries
# must come first; otherwise 'で' always wins and 'です' is never a token.
# (sorted(..., reverse=True) is stable, so equal-length entries keep their order.)
_PARTICLE_SPLIT_RE = re.compile(
    '(' + '|'.join(map(re.escape, sorted(_PARTICLES, key=len, reverse=True))) + ')'
)
# Sentence punctuation, middle dot, brackets and any whitespace end a segment.
_SEGMENT_BOUNDARY_RE = re.compile(r'[。、！？\s・「」『』（）【】]')


def segment_japanese(text, max_unsplit_len=8, max_word_len=20):
    """Segment Japanese text into word-like tokens.

    The text is cut at punctuation, brackets and whitespace. From each piece,
    emojis are extracted as separate tokens (appended after the piece's text).
    Text longer than *max_unsplit_len* characters is further split at the
    particles in ``_PARTICLES``, keeping each particle as its own token.
    Tokens longer than *max_word_len* characters are dropped.

    NOTE(review): the particle list is hiragana; if the input readings are
    katakana, long runs are only cut at punctuation — confirm the data script.

    Args:
        text: Input string; empty/None yields an empty list.
        max_unsplit_len: Longest run kept whole without particle splitting.
        max_word_len: Longest token kept in the result.

    Returns:
        List of non-empty token strings.
    """
    if not text:
        return []

    words = []
    for seg in _SEGMENT_BOUNDARY_RE.split(text):
        if not seg:
            continue
        emojis = extract_emojis(seg)
        text_only = EMOJI_PATTERN.sub('', seg)

        if text_only:
            if len(text_only) <= max_unsplit_len:
                words.append(text_only)
            else:
                # Capturing group => particles are kept as tokens.
                words.extend(p for p in _PARTICLE_SPLIT_RE.split(text_only) if p)
        words.extend(emojis)

    return [w for w in words if w and len(w) <= max_word_len]

# ============================================================
# SINGLE PASS: KANA VOCABULARY + KANA→KANJI CONVERSION INDEX
# ============================================================
# Vocabulary counts come from the INPUT (kana) field only, so the model never
# sees kanji; the OUTPUT field is used only to learn display conversions.
print("\n[1/2] Building kana vocabulary from INPUT field...")
print("[2/2] Building kana→kanji conversion index (same pass)...")

word_counts = Counter()                    # kana word -> frequency
all_kana_texts = []                        # raw kana texts, reused for training sequences
kana_kanji_counts = defaultdict(Counter)   # kana phrase/word -> Counter of kanji spellings

for item in dataset:
    # `or ''` also covers fields that are present but None.
    kana_raw = item.get('input') or ''
    kanji_raw = item.get('output') or ''

    # Whitespace is a segment boundary, so segmenting the raw text gives the
    # same tokens as segmenting the stripped text; segment once and reuse.
    kana_words = segment_japanese(kana_raw) if kana_raw else []
    if kana_raw:
        all_kana_texts.append(kana_raw)
        word_counts.update(kana_words)

    kana, kanji = kana_raw.strip(), kanji_raw.strip()
    if not (kana and kanji) or kana == kanji:
        continue

    # Full phrase mapping.
    kana_kanji_counts[kana][kanji] += 1

    # Word-level mapping: only when both sides produce the same number of
    # tokens, so tokens can be aligned positionally.
    kanji_words = segment_japanese(kanji)
    if len(kana_words) == len(kanji_words):
        for k, j in zip(kana_words, kanji_words):
            if k != j and len(k) > 1:
                kana_kanji_counts[k][j] += 1

print(f"✓ Found {len(word_counts):,} unique kana words")

# Finalize: keep the top 5 kanji spellings per kana key.
kana_kanji_index = {
    kana: [spelling for spelling, _count in counts.most_common(5)]
    for kana, counts in kana_kanji_counts.items()
    if counts
}

print(f"✓ Built {len(kana_kanji_index):,} kana→kanji mappings")

# Show examples.
# NOTE(review): these lookups are hiragana; if the dataset's INPUT readings are
# katakana they will all print "(no conversion)" — confirm the reading script.
print("\nSample conversions:")
for ex in ['ありがとう', 'ございます', 'おはよう', 'こんにちは', 'わたし']:
    print(f"  {ex} → {kana_kanji_index.get(ex, ['(no conversion)'])}")

In [None]:
# Filter valid words and build vocabulary
def is_valid_word(word):
    """Return True if *word* may enter the vocabulary.

    Accepted:
      * tokens of at most 2 characters that start with an emoji;
      * tokens made only of hiragana, katakana, CJK ideographs (including
        Extension A), emoji, and the marks 'ー' / '〜'.

    Empty strings and None are rejected.
    """
    if not word:
        return False
    # Short emoji tokens (e.g. an emoji plus a variation selector).
    if len(word) <= 2 and EMOJI_PATTERN.match(word):
        return True
    for char in word:
        code = ord(char)
        if not (0x3040 <= code <= 0x309F or  # Hiragana
                0x30A0 <= code <= 0x30FF or  # Katakana
                0x4E00 <= code <= 0x9FFF or  # CJK Unified Ideographs (kanji)
                0x3400 <= code <= 0x4DBF or  # CJK Extension A
                is_emoji(char) or
                char in 'ー〜'):             # long-vowel mark, wave dash
            return False
    return True

# Special tokens take the first ids (PAD=0, UNK=1, BOS=2, EOS=3); real words follow.
SPECIAL_TOKENS = [PAD_TOKEN, UNK_TOKEN, BOS_TOKEN, EOS_TOKEN]

# Most frequent valid words first, capped so the total vocabulary fits the limit.
valid_words = [(w, c) for w, c in word_counts.most_common() if is_valid_word(w)]
valid_words = valid_words[:max(0, VOCAB_SIZE_LIMIT - len(SPECIAL_TOKENS))]

word_to_index = {token: i for i, token in enumerate(SPECIAL_TOKENS)}
for idx, (word, _count) in enumerate(valid_words, start=len(SPECIAL_TOKENS)):
    word_to_index[word] = idx

index_to_word = {idx: word for word, idx in word_to_index.items()}
vocab_size = len(word_to_index)

print(f"\n✓ Vocabulary size: {vocab_size:,}")
print("\nTop 15 words:")
for i, (w, c) in enumerate(valid_words[:15], 1):
    print(f"  {i:2d}. '{w}' ({c:,})")

In [None]:
# Build prefix index: prefix -> up to MAX_PREFIX_CANDIDATES word ids, most
# frequent first; equal counts are ordered by lower id (= higher vocab rank).
print("Building prefix index...")

MAX_PREFIX_CANDIDATES = 20

prefix_index = defaultdict(list)
# Visiting words in rank order fills every bucket already sorted, so no
# per-bucket sort (and no (count, id) tuples) is needed.
ranked_words = sorted(valid_words, key=lambda wc: (-wc[1], word_to_index[wc[0]]))
for word, _count in ranked_words:
    idx = word_to_index[word]
    for end in range(1, len(word) + 1):
        bucket = prefix_index[word[:end]]
        if len(bucket) < MAX_PREFIX_CANDIDATES:
            bucket.append(idx)

print(f"✓ Prefix index: {len(prefix_index):,} prefixes")

# Test
print("\nPrefix tests:")
for p in ['あり', 'ありが', 'ございま', 'おは']:
    if p in prefix_index:
        words = [index_to_word[i] for i in prefix_index[p][:3]]
        print(f"  '{p}' → {words}")
    else:
        print(f"  '{p}' → (no match)")

## 4. Create Training Data (Kana sequences)

In [None]:
import tensorflow as tf
import numpy as np

print("Creating training sequences from KANA text...")
print("="*60)

UNK_ID = word_to_index[UNK_TOKEN]
VAL_FRACTION = 0.1   # share of training pairs held out for validation
SPLIT_SEED = 42      # fixed seed => reproducible train/val split

# Map each kana text to word ids (out-of-vocabulary words -> <UNK>).
all_sequences = []
for text in all_kana_texts:
    seq = [word_to_index.get(w, UNK_ID) for w in segment_japanese(text)]
    if len(seq) >= 2:
        all_sequences.append(seq)

print(f"✓ {len(all_sequences):,} sequences")

# One (context, next word) pair per position i; the context is the last
# SEQUENCE_LENGTH ids before i. Pairs whose target is <UNK> are skipped: the
# keyboard can never usefully suggest <UNK>.
X_data, y_data = [], []
for seq in all_sequences:
    for i in range(1, len(seq)):
        if seq[i] == UNK_ID:
            continue
        X_data.append(seq[max(0, i - SEQUENCE_LENGTH):i])
        y_data.append(seq[i])

if len(X_data) < 2:
    raise ValueError(
        f"Only {len(X_data)} training pair(s) produced; need at least 2 for a "
        "train/validation split. Check segmentation and vocabulary size."
    )

print(f"✓ {len(X_data):,} training pairs")

X_padded = tf.keras.preprocessing.sequence.pad_sequences(X_data, maxlen=SEQUENCE_LENGTH, padding='pre')
y_array = np.array(y_data, dtype=np.int32)

# Split ONCE, before any tf.data shuffling. Shuffling ahead of take()/skip()
# reshuffles every epoch (tf.data default), so validation examples would leak
# into training batches on later epochs.
perm = np.random.default_rng(SPLIT_SEED).permutation(len(y_array))
X_padded, y_array = X_padded[perm], y_array[perm]

n_val = max(1, int(len(y_array) * VAL_FRACTION))
n_train = len(y_array) - n_val

train_ds = (
    tf.data.Dataset.from_tensor_slices((X_padded[:n_train], y_array[:n_train]))
    .shuffle(10000)            # reshuffled each epoch, within the train split only
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)
val_ds = (
    tf.data.Dataset.from_tensor_slices((X_padded[n_train:], y_array[n_train:]))
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

train_batches = (n_train + BATCH_SIZE - 1) // BATCH_SIZE
val_batches = (n_val + BATCH_SIZE - 1) // BATCH_SIZE
print(f"✓ Train: {train_batches} batches, Val: {val_batches} batches")

## 5. Build & Train Model

In [None]:
from tensorflow.keras import mixed_precision
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Embedding, GRU, Dense, Dropout
from tensorflow.keras.optimizers import AdamW

mixed_precision.set_global_policy('mixed_float16')

inputs = Input(shape=(SEQUENCE_LENGTH,), name='input')
x = Embedding(vocab_size, EMBEDDING_DIM, name='embedding')(inputs)
# recurrent_dropout must stay 0: a non-zero value makes Keras fall back from the
# fused cuDNN GRU kernel to the generic implementation, which is much slower on
# the GPU runtime this notebook targets. Input dropout (`dropout`) keeps cuDNN.
x = GRU(GRU_UNITS, dropout=0.2, recurrent_dropout=0.0, name='gru')(x)
x = Dropout(0.3)(x)
# Softmax stays in float32 for numerical stability under mixed precision.
outputs = Dense(vocab_size, activation='softmax', dtype='float32', name='output')(x)

model = Model(inputs=inputs, outputs=outputs, name='gru_japanese_kana')
model.compile(
    optimizer=AdamW(learning_rate=1e-3, weight_decay=1e-4),
    loss='sparse_categorical_crossentropy',  # targets are integer word ids
    metrics=['accuracy']
)
model.summary()

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau

# Every callback tracks the same metric, so "best" means the same thing for the
# saved checkpoint, early stopping's restored weights and the LR schedule.
MONITOR_METRIC = 'val_loss'

callbacks = [
    ModelCheckpoint(f'{MODEL_DIR}/best_model.keras', monitor=MONITOR_METRIC, save_best_only=True, verbose=1),
    EarlyStopping(monitor=MONITOR_METRIC, patience=5, restore_best_weights=True),
    # Halve the LR after 3 flat epochs; this fires before early stopping (5).
    ReduceLROnPlateau(monitor=MONITOR_METRIC, factor=0.5, patience=3),
]

history = model.fit(train_ds, epochs=NUM_EPOCHS, validation_data=val_ds, callbacks=callbacks, verbose=1)

## 6. Visualize

In [None]:
import matplotlib.pyplot as plt

hist = history.history  # per-epoch metric lists recorded by model.fit

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.plot(hist['loss'], label='Train')
ax1.plot(hist['val_loss'], label='Val')
ax1.set_title('Loss'); ax1.legend(); ax1.grid(True)

ax2.plot(hist['accuracy'], label='Train')
ax2.plot(hist['val_accuracy'], label='Val')
ax2.set_title('Accuracy'); ax2.legend(); ax2.grid(True)

plt.tight_layout()
plt.show()

# With save-best checkpointing / early stopping, the kept weights may come from
# an earlier epoch than the last one, so report both.
print(f"Final Val Acc: {hist['val_accuracy'][-1]*100:.2f}%")
val_losses = hist['val_loss']
best_epoch = min(range(len(val_losses)), key=val_losses.__getitem__)
print(f"Best epoch by val_loss: {best_epoch + 1} (Val Acc: {hist['val_accuracy'][best_epoch]*100:.2f}%)")

## 7. Save Model

In [None]:
import json


def _save_json(path, obj, **dump_kwargs):
    """Write *obj* to *path* as UTF-8 JSON, keeping non-ASCII text unescaped."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(obj, f, ensure_ascii=False, **dump_kwargs)


COMPACT = {'separators': (',', ':')}  # smallest output for mobile bundles

model.save(f'{MODEL_DIR}/gru_model.keras')

_save_json(f'{MODEL_DIR}/word_to_index.json', word_to_index, **COMPACT)
# prefix_index values are lists of word ids, not words.
_save_json(f'{MODEL_DIR}/prefix_index.json', dict(prefix_index), **COMPACT)
_save_json(f'{MODEL_DIR}/kana_kanji_index.json', kana_kanji_index, **COMPACT)

config = {
    'vocab_size': vocab_size,
    'sequence_length': SEQUENCE_LENGTH,
    'embedding_dim': EMBEDDING_DIM,
    'gru_units': GRU_UNITS,
    'language': 'japanese',
    'tokenization': 'kana-normalized',
    'features': ['kana_kanji_conversion', 'next_word_prediction', 'prefix_completion']
}
_save_json(f'{MODEL_DIR}/model_config.json', config, indent=2)

print("✓ Saved: gru_model.keras, word_to_index.json, prefix_index.json, kana_kanji_index.json")

## 8. Export TFLite

In [None]:
import tensorflow as tf
import numpy as np
import time

print("Exporting TFLite...")

try:
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    # SELECT_TF_OPS + no tensor-list lowering lets the GRU convert as-is.
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
    converter._experimental_lower_tensor_list_ops = False
    # Float16 weight quantization.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_types = [tf.float16]

    tflite_model = converter.convert()
    path = f'{MODEL_DIR}/gru_model.tflite'
    with open(path, 'wb') as f:
        f.write(tflite_model)
    print(f"✓ gru_model.tflite ({len(tflite_model)/(1024*1024):.2f}MB)")

    # Benchmark
    interpreter = tf.lite.Interpreter(model_path=path)
    interpreter.allocate_tensors()
    details = interpreter.get_input_details()[0]
    # Random token ids, cast to whatever dtype the converted input declares
    # (rather than assuming float32).
    test = np.random.randint(0, vocab_size, (1, SEQUENCE_LENGTH)).astype(details['dtype'])

    for _ in range(10):  # warm-up runs, not timed
        interpreter.set_tensor(details['index'], test)
        interpreter.invoke()

    times = []
    for _ in range(50):
        start = time.perf_counter()  # monotonic, high resolution
        interpreter.set_tensor(details['index'], test)
        interpreter.invoke()
        times.append((time.perf_counter() - start) * 1000)
    print(f"✓ Latency: {np.mean(times):.2f}ms avg")
except Exception as e:
    # Export is best-effort; report the full error rather than a truncated one.
    print(f"✗ Error: {type(e).__name__}: {e}")

## 9. Export CoreML Weights

In [None]:
# Raw weight arrays in model.get_weights() order. np.savez with positional
# arrays stores them as arr_0, arr_1, ... inside the .npz archive.
weights = model.get_weights()
np.savez(f'{MODEL_DIR}/gru_weights.npz', *weights)
print(f"✓ gru_weights.npz ({len(weights)} arrays)")

## 10. Export Mobile Resources

In [None]:
import json
import os

print("Exporting mobile resources...")
print("="*60)

# Bigram statistics are gathered from this many sequences only (keeps the
# JSON files small and the cell fast).
SUGGESTION_SAMPLE_LIMIT = 15000
# Special-token ids are never a useful context or suggestion.
special_ids = {word_to_index[t] for t in (PAD_TOKEN, UNK_TOKEN, BOS_TOKEN, EOS_TOKEN)}

# index_to_word (JSON object keys must be strings)
path = f'{MODEL_DIR}/index_to_word.json'
with open(path, 'w', encoding='utf-8') as f:
    json.dump({str(k): v for k, v in index_to_word.items()}, f, ensure_ascii=False, separators=(',', ':'))
print(f"✓ index_to_word.json ({os.path.getsize(path)/1024:.1f}KB)")

# One pass over adjacent word pairs feeds both next-word and emoji statistics.
print("Building phrase + emoji suggestions...")
word_pairs = defaultdict(Counter)   # word id -> Counter of next word ids
word_emoji = defaultdict(Counter)   # word -> Counter of emojis that follow it
for seq in all_sequences[:SUGGESTION_SAMPLE_LIMIT]:
    for cur, nxt in zip(seq, seq[1:]):
        if cur in special_ids or nxt in special_ids:
            continue
        word_pairs[cur][nxt] += 1
        w = index_to_word.get(cur)
        n = index_to_word.get(nxt)
        if w and n and EMOJI_PATTERN.match(n):
            word_emoji[w][n] += 1

# phrase_suggestions: word -> up to 10 most frequent next-word ids.
phrase_suggestions = {}
for word_idx, next_counts in word_pairs.items():
    word = index_to_word.get(word_idx)
    if word:
        phrase_suggestions[word] = [idx for idx, _count in next_counts.most_common(10)]

path = f'{MODEL_DIR}/phrase_suggestions.json'
with open(path, 'w', encoding='utf-8') as f:
    json.dump(phrase_suggestions, f, ensure_ascii=False, separators=(',', ':'))
print(f"✓ phrase_suggestions.json ({len(phrase_suggestions):,} words, {os.path.getsize(path)/1024:.1f}KB)")

# emoji_suggestions: word -> up to 5 most frequent following emojis.
emoji_suggestions = {w: [e for e, _count in ec.most_common(5)] for w, ec in word_emoji.items() if ec}
path = f'{MODEL_DIR}/emoji_suggestions.json'
with open(path, 'w', encoding='utf-8') as f:
    json.dump(emoji_suggestions, f, ensure_ascii=False, separators=(',', ':'))
print(f"✓ emoji_suggestions.json ({len(emoji_suggestions):,} words)")

# List all files
print("\n" + "="*60)
print("ALL EXPORTS:")
for name in sorted(os.listdir(MODEL_DIR)):
    size = os.path.getsize(f'{MODEL_DIR}/{name}') / 1024
    print(f"  {name}: {size:.1f}KB")

## 11. Verification Test

In [None]:
import json
from tensorflow.keras.preprocessing.sequence import pad_sequences

print("="*60)
print("VERIFICATION - Dual Mode Test")
print("="*60)

# Reload every exported index from disk, as a mobile client would.
with open(f'{MODEL_DIR}/word_to_index.json', 'r', encoding='utf-8') as f:
    w2i = json.load(f)
with open(f'{MODEL_DIR}/index_to_word.json', 'r', encoding='utf-8') as f:
    i2w = {int(k): v for k, v in json.load(f).items()}  # JSON keys are strings
with open(f'{MODEL_DIR}/prefix_index.json', 'r', encoding='utf-8') as f:
    prefix_idx = json.load(f)
with open(f'{MODEL_DIR}/kana_kanji_index.json', 'r', encoding='utf-8') as f:
    kana_kanji = json.load(f)
with open(f'{MODEL_DIR}/phrase_suggestions.json', 'r', encoding='utf-8') as f:
    phrase_sug = json.load(f)


def get_prefix(prefix, k=5):
    """Return up to *k* completion words for *prefix* ([] if unknown)."""
    return [i2w.get(i, '?') for i in prefix_idx.get(prefix, [])[:k]]


def get_kanji(kana):
    """Return kanji spellings for *kana*, or ``[kana]`` when none are known."""
    return kana_kanji.get(kana, [kana])


def get_next(word, k=5):
    """Return up to *k* likely next words after *word* ([] if unknown)."""
    return [i2w.get(i, '?') for i in phrase_sug.get(word, [])[:k]]


# Test 1: Prefix completion
print("\n📝 TEST 1: Prefix Completion")
for prefix in ['あり', 'ありが', 'ございま', 'おは', 'こんにち']:
    words = get_prefix(prefix)
    print(f"  '{prefix}' → {words if words else '(no match)'}")

# Test 2: Kana→Kanji conversion
print("\n📝 TEST 2: Kana→Kanji Conversion")
for kana in ['ありがとう', 'ございます', 'おはよう', 'こんにちは', 'わたし']:
    kanji = get_kanji(kana)
    print(f"  '{kana}' → {kanji}")

# Test 3: Next word prediction
print("\n📝 TEST 3: Next Word Prediction")
for word in ['ありがとう', 'おはよう', 'これ', 'わたし']:
    next_words = get_next(word)
    print(f"  '{word}' → {next_words if next_words else '(no prediction)'}")

# Test 4: Complete flow
print("\n📝 TEST 4: Complete Flow")
print("-"*40)
for prefix in ['ありが', 'おは']:
    print(f"\nUser types: '{prefix}'")

    # Step 1: Prefix completion
    kana_words = get_prefix(prefix, 3)
    if not kana_words:
        continue
    print(f"  1. Prefix match: {kana_words}")
    selected_kana = kana_words[0]

    # Step 2: Kanji options
    kanji_options = get_kanji(selected_kana)
    print(f"  2. Kanji options: {kanji_options}")

    # Step 3: Next word
    next_words = get_next(selected_kana, 3)
    if next_words:
        print(f"  3. Next word: {next_words}")
        # Display each next word in its top kanji spelling (get_kanji falls
        # back to the kana itself when no conversion is known).
        next_kanji = [get_kanji(w)[0] for w in next_words]
        print(f"     (as kanji): {next_kanji}")

print("\n" + "="*60)
print("✅ VERIFICATION COMPLETE")
print("   Features: Prefix, Kana→Kanji, Next Word Prediction")
print("="*60)

## Usage Guide

### Mobile Integration

```swift
// 1. User types partial kana
let prefix = "ありが"
// prefix_index.json stores word *indices*; resolve them with index_to_word.json
let words = (prefixIndex[prefix] ?? []).compactMap { indexToWord[String($0)] }  // ["ありがとう", "ありがたい"]

// 2. User selects word, show kanji options
let kana = "ありがとう"
let kanji = kanaKanjiIndex[kana] ?? [kana]  // ["有り難う", "有難う"]

// 3. Predict next word (using kana internally); phrase_suggestions.json also stores indices
let next = (phraseSuggestions[kana] ?? []).compactMap { indexToWord[String($0)] }  // ["ございます", "ね"]
let nextKanji = next.map { kanaKanjiIndex[$0]?.first ?? $0 }
// ["御座います", "ね"]
```

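### Files
- `prefix_index.json` - Kana prefix completion (values are word indices)
- `index_to_word.json` - Index → word lookup for the index-valued files
- `kana_kanji_index.json` - Kana→Kanji conversion
- `phrase_suggestions.json` - Next word prediction (values are word indices)
- `emoji_suggestions.json` - Emoji suggestions per word
- `gru_model.tflite` - ML model (optional, for complex predictions)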