# Enhanced GRU Japanese Keyboard

**Two Functions:**
1. **Kana→Kanji**: おせ → [お世話, おせち]
2. **Next Phrase**: お世話 → [になっております, ございます]

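**Architecture:** Bi-GRU encoder + Luong-style (scaled dot-product) self-attention + GRU decoder, with a single softmax over a unified character + phrase vocabulary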

## 1. Setup

In [None]:
from google.colab import drive
import os

# Mount Google Drive so the trained artifacts persist beyond the Colab session.
drive.mount('/content/drive')
# Project folder on Drive; every file this notebook writes goes under MODEL_DIR.
DRIVE_DIR = '/content/drive/MyDrive/Keyboard-Suggestions-ML-Colab'
MODEL_DIR = f"{DRIVE_DIR}/models/gru_japanese_enhanced"
os.makedirs(MODEL_DIR, exist_ok=True)  # no error when the folder already exists
print(f"‚úì Model: {MODEL_DIR}")

In [None]:
# Install dependencies. fugashi + unidic-lite provide the Japanese morphological
# analyzer used by tokenize_words(); pandas is not used by the visible cells.
# NOTE(review): Colab ships TensorFlow/Keras preinstalled; reinstalling them here may
# change their versions - confirm the versions this notebook was validated with.
!pip install -q tensorflow keras datasets pandas numpy tqdm fugashi unidic-lite

In [None]:
# ============================================================
# CONFIGURATION
# ============================================================

# Flip to False for a full training run.
TESTING_MODE = True

# Run size: (epochs, dataset sample cap) differ between the quick and full
# runs; the batch size is the same for both.
BATCH_SIZE = 256
NUM_EPOCHS, MAX_SAMPLES = (3, 200000) if TESTING_MODE else (20, 300000)

# Vocabulary limits: most frequent characters / next-phrases to keep, and the
# fixed length every encoded input sequence is padded or cut to.
CHAR_VOCAB_SIZE = 3000
PHRASE_VOCAB_SIZE = 5000
MAX_SEQ_LENGTH = 50

# Model dimensions.
EMBEDDING_DIM = 128
GRU_UNITS = 256

# Reserved tokens; the vocabulary builder assigns them ids in list order.
SPECIAL_TOKENS = ['<PAD>', '<UNK>', '<BOS>', '<EOS>', '<KANA>', '<KANJI>', '<CTX>', '<NEXT>']

print(f"Config: epochs={NUM_EPOCHS}, samples={MAX_SAMPLES:,}")

## 2. Load Dataset

In [None]:
from datasets import load_dataset

print("Loading zenz-v2.5-dataset...")

# Prefer the Wikipedia file of the dataset; if it cannot be loaded, fall back
# to the dataset's default train split. Both are capped at MAX_SAMPLES rows.
try:
    dataset = load_dataset(
        "Miwa-Keita/zenz-v2.5-dataset",
        data_files="train_wikipedia.jsonl",
        split=f"train[:{MAX_SAMPLES}]"
    )
except Exception as err:  # not a bare except: Ctrl-C / SystemExit still propagate
    # Make the fallback visible: it trains on a different file set.
    print(f"⚠ Could not load train_wikipedia.jsonl ({err!r}); falling back to the default train split")
    dataset = load_dataset(
        "Miwa-Keita/zenz-v2.5-dataset",
        split=f"train[:{MAX_SAMPLES}]"
    )

print(f"✓ Loaded {len(dataset):,} samples")

## 3. Tokenizers

In [None]:
import fugashi
from collections import Counter, defaultdict

tagger = fugashi.Tagger()

# Part-of-speech tag (pos1) that the dictionary assigns to whitespace tokens.
# This must be the real Japanese string; a mis-encoded literal never matches.
_WHITESPACE_POS = '空白'


def tokenize_words(text):
    """Split text into surface-form words using the fugashi tagger.

    Every token is kept (words, symbols, emoji) except tokens whose coarse
    part of speech (``feature.pos1``) is whitespace.

    Returns:
        List of token surface strings in input order.
    """
    return [tok.surface for tok in tagger(text) if tok.feature.pos1 != _WHITESPACE_POS]

def tokenize_chars(text):
    """Split text into single characters, dropping ASCII spaces and newlines.

    Example: アリガトウ → [ア, リ, ガ, ト, ウ]
    """
    skipped = (' ', '\n')
    return [ch for ch in text if ch not in skipped]

# Quick sanity check of both tokenizers on real Japanese text
# ("有難うございます" = thank you; "アリガトウ" = the same in katakana).
print(f"Words: {tokenize_words('有難うございます')}")
print(f"Chars: {tokenize_chars('アリガトウ')}")

## 4. Build Vocabulary (Chars + Phrases)

In [None]:
from tqdm import tqdm

print("Building vocabularies...")
print("="*60)

char_counts = Counter()                 # character frequencies over kana + kanji text
phrase_counts = Counter()               # how often each candidate next-phrase occurs
word_phrase_map = defaultdict(Counter)  # word -> Counter of phrases seen right after it (exported later)


def _following_phrases(words):
    """Yield (word, phrase) for every word and each 1-3 word continuation after it.

    A phrase is the concatenation of the following words' surface forms and is
    only yielded when it is 2..15 characters long.
    """
    total = len(words)
    for pos, head in enumerate(words[:-1]):
        for span in (1, 2, 3):
            end = pos + 1 + span
            if end > total:
                break  # longer spans would run past the sentence too
            candidate = ''.join(words[pos + 1:end])
            if 1 < len(candidate) <= 15:
                yield head, candidate


for item in tqdm(dataset, desc="Extracting patterns"):
    kana = item.get('input', '')
    kanji = item.get('output', '')

    # Character statistics come from both the reading and the surface text.
    char_counts.update(tokenize_chars(kana))
    char_counts.update(tokenize_chars(kanji))

    # Word -> next-phrase statistics come from the surface text only.
    for head, phrase in _following_phrases(tokenize_words(kanji)):
        phrase_counts[phrase] += 1
        word_phrase_map[head][phrase] += 1

top_phrases_preview = [phrase for phrase, _ in phrase_counts.most_common(15)]
print(f"\n‚úì {len(char_counts):,} unique chars")
print(f"‚úì {len(phrase_counts):,} unique phrases")
print(f"\nTop phrases: {top_phrases_preview}")

In [None]:
# Build the UNIFIED vocabulary: special tokens first, then characters (for
# kana-kanji), then multi-character phrases (for next-phrase prediction).
# Ids are handed out densely in that order.


def _append_tokens(vocab, candidates):
    """Give every candidate not already in vocab the next free id, in order."""
    for tok in candidates:
        vocab.setdefault(tok, len(vocab))


token_to_idx = {tok: pos for pos, tok in enumerate(SPECIAL_TOKENS)}
_append_tokens(token_to_idx, (ch for ch, _ in char_counts.most_common(CHAR_VOCAB_SIZE)))
_append_tokens(
    token_to_idx,
    (p for p, _ in phrase_counts.most_common(PHRASE_VOCAB_SIZE) if len(p) > 1),
)

idx_to_token = {idx: tok for tok, idx in token_to_idx.items()}
vocab_size = len(token_to_idx)

# Phrase-only view of the vocabulary, used for reporting.
phrase_vocab = [p for p, _ in phrase_counts.most_common(PHRASE_VOCAB_SIZE) if len(p) > 1]

print(f"\n‚úì Unified vocab: {vocab_size:,}")
print(f"‚úì Phrase vocab: {len(phrase_vocab):,}")

## 5. Create Training Data

In [None]:
import numpy as np

print("Creating training data...")
print("="*60)

def encode(tokens, max_len=MAX_SEQ_LENGTH):
    """Map tokens to vocabulary ids as a fixed-length list.

    Tokens missing from the vocabulary become the <UNK> id. The ids are cut to
    the first max_len entries, and shorter sequences are right-padded with the
    <PAD> id, so for non-negative max_len the result has exactly max_len ids.
    """
    unk_id = token_to_idx['<UNK>']
    ids = [token_to_idx.get(tok, unk_id) for tok in tokens][:max_len]
    shortfall = max_len - len(ids)
    if shortfall > 0:
        ids += [token_to_idx['<PAD>']] * shortfall
    return ids

X_data = []
y_data = []

# ============================================================
# Task 1: Kana-Kanji (char-level)
# おせわ → お世話
# Each (kana, kanji) pair is unrolled into one sample per output position:
#   input = <KANA> kana chars <KANJI> + kanji prefix,  label = next kanji char (or <EOS>).
# ============================================================
print("\n[Task 1] Kana→Kanji...")

unk_id = token_to_idx['<UNK>']
for item in tqdm(dataset, desc="Kana→Kanji"):
    kana = item.get('input', '').strip()
    kanji = item.get('output', '').strip()

    if not kana or not kanji or len(kana) > 30 or len(kanji) > 30:
        continue

    input_tokens = ['<KANA>'] + tokenize_chars(kana) + ['<KANJI>']
    target_tokens = tokenize_chars(kanji) + ['<EOS>']

    for i in range(min(len(target_tokens), 25)):
        ctx = input_tokens + target_tokens[:i]
        # encode() keeps only the FIRST MAX_SEQ_LENGTH ids. A longer context
        # would lose its most recent characters, so every later step would map
        # to the same input with a different label. Stop unrolling instead.
        if len(ctx) > MAX_SEQ_LENGTH:
            break
        X_data.append(encode(ctx))
        y_data.append(token_to_idx.get(target_tokens[i], unk_id))

task1_count = len(X_data)
print(f"✓ Task 1: {task1_count:,} samples")

# ============================================================
# Task 2: Word → Next Phrase (phrase-level)
# お世話 → になっております
# ありがとう → ございます
# ============================================================
print("\n[Task 2] Word→Phrase...")

for item in tqdm(dataset, desc="Word→Phrase"):
    kanji = item.get('output', '').strip()
    if not kanji:
        continue

    words = tokenize_words(kanji)
    if len(words) < 2:
        continue

    for i in range(len(words) - 1):
        # Input: <CTX> + up to 3 words ending at the current word (fewer near
        # the start of the sentence) + <NEXT>. It is identical for every phrase
        # length below, so encode it once.
        # NOTE(review): context words are looked up as whole tokens, but the
        # vocabulary holds only single characters and next-phrases, so
        # multi-character words outside the phrase vocabulary encode as <UNK>.
        context = words[max(0, i-2):i+1]
        input_ids = encode(['<CTX>'] + context + ['<NEXT>'])

        # Targets: the following 1, 2 and 3 words joined as one phrase token.
        for phrase_len in (1, 2, 3):
            if i + phrase_len >= len(words):
                break

            phrase = ''.join(words[i+1:i+1+phrase_len])

            # Only train on (multi-character) phrases that are in the vocabulary.
            if phrase in token_to_idx and len(phrase) > 1:
                X_data.append(input_ids)
                y_data.append(token_to_idx[phrase])  # COMPLETE PHRASE

task2_count = len(X_data) - task1_count
print(f"✓ Task 2: {task2_count:,} samples")
print(f"\n✓ Total: {len(X_data):,} samples")

In [None]:
import tensorflow as tf

# Fixed seed so the shuffle, and therefore the train/val split, is reproducible.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.9

# int32 instead of NumPy's default integer type (int64 on Linux) halves host
# memory for the sample matrix; every id is < vocab_size.
X_data = np.array(X_data, dtype=np.int32)
y_data = np.array(y_data, dtype=np.int32)

indices = np.random.default_rng(SPLIT_SEED).permutation(len(X_data))
X_data = X_data[indices]
y_data = y_data[indices]

# NOTE(review): Task 1 unrolls each sentence into many overlapping prefix
# samples, and Task 2 emits several samples per word. Splitting at the sample
# level therefore puts near-duplicates of validation samples into the training
# set, which likely inflates val_accuracy. A sentence-level split would need
# per-sample source ids from the data-creation step.
split = int(len(X_data) * TRAIN_FRACTION)
X_train, X_val = X_data[:split], X_data[split:]
y_train, y_val = y_data[:split], y_data[split:]

train_ds = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
val_ds = tf.data.Dataset.from_tensor_slices((X_val, y_val)).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

print(f"✓ Train: {len(X_train):,}, Val: {len(X_val):,}")

## 6. Build Model

In [None]:
from tensorflow.keras import mixed_precision
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Embedding, GRU, Dense, Dropout,
    Bidirectional, Attention, Concatenate, LayerNormalization
)

# Mixed precision (float16 compute). Set before building the model so every
# layer created below picks up this policy.
mixed_precision.set_global_policy('mixed_float16')

# Input: one row of MAX_SEQ_LENGTH token ids, as produced by encode().
inputs = Input(shape=(MAX_SEQ_LENGTH,), name='input')
# NOTE(review): no mask_zero here, so the right-padded <PAD> positions (id 0)
# are run through both GRUs and included in attention. mask_zero=True may help,
# but check that the TFLite export still works before switching.
x = Embedding(vocab_size, EMBEDDING_DIM, name='embedding')(inputs)

# Bidirectional GRU encoder: returns a state for every position
# (forward and backward states merged with the wrapper's default mode).
encoder_out = Bidirectional(
    GRU(GRU_UNITS, return_sequences=True, dropout=0.2),
    name='bi_encoder'
)(x)

# Luong-style (dot-product) attention used as SELF-attention: the encoder
# outputs serve as both query and value; use_scale=True adds a learned scale.
attention_out = Attention(use_scale=True, name='attention')([encoder_out, encoder_out])

# Combine: concatenate each position's encoder state with its attention
# context, then normalize.
combined = Concatenate()([encoder_out, attention_out])
combined = LayerNormalization()(combined)

# Decoder: a unidirectional GRU that returns only its final state, i.e. one
# vector summarizing the whole (padded) sequence.
decoder_out = GRU(GRU_UNITS, name='decoder')(combined)
decoder_out = Dropout(0.3)(decoder_out)

# A single softmax over the unified vocabulary (specials + chars + phrases),
# shared by both tasks. Kept in float32 for a numerically stable softmax under
# mixed precision.
outputs = Dense(vocab_size, activation='softmax', dtype='float32', name='output')(decoder_out)

model = Model(inputs, outputs, name='japanese_keyboard_gru')
# Labels are integer token ids (not one-hot), hence the sparse loss.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3, clipnorm=1.0),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)
model.summary()
print(f"\n‚úì Parameters: {model.count_params():,}")

## 7. Train

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau

# Every callback watches the same metric, so the "best" checkpoint on disk, the
# weights EarlyStopping restores, and the LR schedule all agree on which epoch
# was best.
MONITOR_METRIC = 'val_loss'

callbacks = [
    # Keep only the best epoch's full model on Drive.
    ModelCheckpoint(f'{MODEL_DIR}/best.keras', monitor=MONITOR_METRIC, save_best_only=True, verbose=1),
    # Stop after 5 epochs without improvement and restore the best weights.
    EarlyStopping(monitor=MONITOR_METRIC, patience=5, restore_best_weights=True),
    # Halve the learning rate after 2 stagnant epochs, but never below 1e-6.
    ReduceLROnPlateau(monitor=MONITOR_METRIC, factor=0.5, patience=2, min_lr=1e-6),
]

history = model.fit(
    train_ds,
    epochs=NUM_EPOCHS,
    validation_data=val_ds,
    callbacks=callbacks,
    verbose=1
)

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# One panel per metric, each with the training and validation curves.
panels = (('loss', 'Loss'), ('accuracy', 'Accuracy'))
for ax, (metric, title) in zip(axes, panels):
    ax.plot(history.history[metric], label='Train')
    ax.plot(history.history[f'val_{metric}'], label='Val')
    ax.set_title(title)
    ax.legend()

plt.savefig(f'{MODEL_DIR}/training.png')
plt.show()

# Validation accuracy of the LAST epoch run.
last_val_acc = history.history['val_accuracy'][-1]
print(f"\n‚úì Val Acc: {last_val_acc*100:.2f}%")

## 8. Save Resources

In [None]:
import json


def _write_json(filename, obj, file_encoding='utf-8', **dump_kwargs):
    """Serialize obj with json.dump into MODEL_DIR/filename."""
    with open(f'{MODEL_DIR}/{filename}', 'w', encoding=file_encoding) as fh:
        json.dump(obj, fh, **dump_kwargs)


def _known_top_phrases(counter, limit=10):
    """Return the `limit` most frequent phrases of counter that are in the vocabulary."""
    return [phrase for phrase, _ in counter.most_common(limit) if phrase in token_to_idx]


model.save(f'{MODEL_DIR}/model.keras')

# Vocabulary in both directions; JSON keys must be strings, so ids are stringified.
_write_json('token_to_idx.json', token_to_idx, ensure_ascii=False)
_write_json('idx_to_token.json', {str(idx): tok for idx, tok in idx_to_token.items()}, ensure_ascii=False)

# Word -> phrase suggestions (for fast lookup without running the model);
# words whose frequent phrases are all out of vocabulary are dropped.
candidate_suggestions = {word: _known_top_phrases(counter) for word, counter in word_phrase_map.items()}
word_phrase_suggestions = {word: phrases for word, phrases in candidate_suggestions.items() if phrases}

_write_json('word_phrase_suggestions.json', word_phrase_suggestions, ensure_ascii=False)

print(f"‚úì word_phrase_suggestions.json ({len(word_phrase_suggestions):,} words)")

# Model/config metadata.
config = {
    'vocab_size': vocab_size,
    'max_seq_length': MAX_SEQ_LENGTH,
    'embedding_dim': EMBEDDING_DIM,
    'gru_units': GRU_UNITS,
    'architecture': 'BiGRU_LuongAttention',
    'tasks': ['kana_kanji', 'next_phrase'],
    'special_tokens': SPECIAL_TOKENS
}
_write_json('config.json', config, file_encoding=None, indent=2)

print("‚úì Saved all resources")

## 9. Export TFLite

In [None]:
def _save_tflite(filename, flatbuffer):
    """Write a converted TFLite flatbuffer to MODEL_DIR and report its size in MB."""
    with open(f'{MODEL_DIR}/{filename}', 'wb') as fh:
        fh.write(flatbuffer)
    print(f"‚úì {filename} ({len(flatbuffer)/(1024*1024):.2f}MB)")


print("Exporting TFLite...")

# Best-effort export: any conversion or write failure is reported, not raised.
try:
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    # Allow TF ops without a TFLite builtin, and keep TensorList ops unlowered
    # (presumably needed for the GRU loops - confirm against the converter docs).
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
    converter._experimental_lower_tensor_list_ops = False

    # Pass 1: unoptimized model.
    _save_tflite('model.tflite', converter.convert())

    # Pass 2: the same converter reconfigured for float16 weights.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_types = [tf.float16]
    _save_tflite('model_fp16.tflite', converter.convert())
except Exception as e:
    print(f"‚ö† Error: {e}")

## 10. Verification

In [None]:
print("="*60)
print("VERIFICATION TEST")
print("="*60)

def generate_chars(input_tokens, max_len=20):
    """Greedily decode characters for the kana→kanji task.

    Args:
        input_tokens: prompt tokens, e.g. ['<KANA>', *kana_chars, '<KANJI>'].
        max_len: maximum number of tokens to generate.

    Returns:
        The generated tokens joined into one string. Decoding stops at the
        first special token (e.g. <EOS>, <PAD>, <UNK>), after max_len steps,
        or once the context no longer fits in MAX_SEQ_LENGTH. Past that length,
        encode() would truncate it and the model would keep seeing the same
        input.
    """
    specials = set(SPECIAL_TOKENS)
    generated = []

    for _ in range(max_len):
        context = input_tokens + generated
        if len(context) > MAX_SEQ_LENGTH:
            break
        # Direct model call instead of model.predict(): Keras recommends
        # __call__ for small inputs that fit in one batch (less per-call overhead).
        probs = np.asarray(model(np.array([encode(context)]), training=False))[0]
        next_token = idx_to_token.get(int(np.argmax(probs)), '<UNK>')

        # Special tokens are never part of the converted text.
        if next_token in specials:
            break
        generated.append(next_token)

    return ''.join(generated)

def predict_next_phrases(context_words, top_k=5):
    """Predict the top_k most likely COMPLETE next phrases after context_words.

    The model input mirrors Task 2 training: <CTX> + the last 3 context words
    + <NEXT>. Only multi-character, non-special vocabulary entries are
    returned, ordered by descending model probability.

    Returns:
        Up to top_k phrases (fewer only if the vocabulary has fewer
        candidates); an empty list when top_k <= 0.
    """
    if top_k <= 0:
        return []

    input_tokens = ['<CTX>'] + list(context_words[-3:]) + ['<NEXT>']
    # Direct model call instead of model.predict(): faster for a single input.
    probs = np.asarray(model(np.array([encode(input_tokens)]), training=False))[0]

    specials = set(SPECIAL_TOKENS)
    predictions = []
    # Walk the whole vocabulary by descending probability (not just the top
    # 100 ids) so top_k phrases are found whenever they exist. Single
    # characters (the kana-kanji outputs) are skipped.
    for idx in np.argsort(probs)[::-1]:
        token = idx_to_token.get(int(idx), '')
        if len(token) > 1 and token not in specials:
            predictions.append(token)
            if len(predictions) >= top_k:
                break

    return predictions

# ============================================================
# Test 1: Kana → Kanji
# ============================================================
print("\n📝 Kana→Kanji Conversion")
print("-" * 40)
# Katakana readings, matching the dataset's 'input' field in the examples.
tests = ['アリガトウ', 'ゴザイマス', 'オセワ']
for kana in tests:
    inp = ['<KANA>'] + tokenize_chars(kana) + ['<KANJI>']
    result = generate_chars(inp, max_len=len(kana)*2)
    print(f"  {kana} → {result}")

# ============================================================
# Test 2: Word → Next Phrase
# ============================================================
print("\n📝 Next Phrase Prediction")
print("-" * 40)
tests = [
    ['ありがとう'],   # hoped-for: ございます, ございました
    ['お世話'],       # hoped-for: になっております
    ['行き'],         # hoped-for: ます, たいです
    ['申し訳'],       # hoped-for: ありません, ございません
    ['そう'],         # hoped-for: ですね, 思います
]
for ctx in tests:
    result = predict_next_phrases(ctx)
    print(f"  {''.join(ctx)} → {result}")

print("\n" + "="*60)
print("✅ VERIFICATION COMPLETE")

In [None]:
# Show word_phrase_suggestions examples (words absent from the map are skipped).
print("\n📚 Sample Word→Phrase Suggestions:")
sample_words = ['ありがとう', 'お世話', '行き', '申し訳', 'そう', '今日', '日本']
for word in sample_words:
    phrases = word_phrase_suggestions.get(word, [])
    if phrases:
        print(f"  {word} → {phrases[:5]}")

In [None]:
# List the files written to MODEL_DIR with human-readable sizes
# (MB above 1 MiB, otherwise KB); subdirectories are skipped.
_MIB = 1024 * 1024

print("\nExported files:")
for name in sorted(os.listdir(MODEL_DIR)):
    path = f'{MODEL_DIR}/{name}'
    if not os.path.isfile(path):
        continue
    size = os.path.getsize(path)
    pretty = f"{size/_MIB:.1f} MB" if size > _MIB else f"{size/1024:.1f} KB"
    print(f"  {name}: {pretty}")