# Model 2: Next Word Predictor (Language Model)

**Task:** Predict next words word-by-word
- Input: `[お世話]`
- Output: `に → なって → おります` (word-by-word)

**Architecture:** Bi-GRU + Luong Attention (Word-Level)

**Target:** ~2.5MB, 85%+ accuracy, <5ms inference

## 1. Setup

In [None]:
from google.colab import drive
import os

# Mount Google Drive; every artifact this notebook writes goes under MODEL_DIR
# on the mounted drive.
drive.mount('/content/drive')
DRIVE_DIR = '/content/drive/MyDrive/Keyboard-Suggestions-ML-Colab'
MODEL_DIR = f"{DRIVE_DIR}/models/gru_japanese_next_word"
os.makedirs(MODEL_DIR, exist_ok=True)  # exist_ok: re-running the cell is harmless
print(f"✓ Model: {MODEL_DIR}")

In [None]:
!pip install -q tensorflow keras datasets numpy tqdm fugashi unidic-lite

In [None]:
# ============================================================
# CONFIGURATION
# ============================================================

# True  -> short smoke-test run; False -> full-length training run.
TESTING_MODE = True

# The batch size is identical in both modes; only the epoch count and the
# number of dataset rows to load differ.
BATCH_SIZE = 256
NUM_EPOCHS, MAX_SAMPLES = (4, 300000) if TESTING_MODE else (25, 500000)

# Model specs
WORD_VOCAB_SIZE = 6000
MAX_CONTEXT_LEN = 10  # Max words in context
EMBEDDING_DIM = 96
GRU_UNITS = 192

SPECIAL_TOKENS = ['<PAD>', '<UNK>', '<BOS>', '<EOS>']

print("Config: epochs={}, samples={:,}".format(NUM_EPOCHS, MAX_SAMPLES))
print("Model: Embed={}, GRU={}".format(EMBEDDING_DIM, GRU_UNITS))

## 2. Load Dataset

In [None]:
from datasets import load_dataset

print("Loading zenz-v2.5-dataset...")

# First try the single Wikipedia file; if that fails for any reason, fall
# back to the repository's default "train" split. Both paths load only the
# first MAX_SAMPLES rows.
try:
    dataset = load_dataset(
        "Miwa-Keita/zenz-v2.5-dataset",
        data_files="train_wikipedia.jsonl",
        split=f"train[:{MAX_SAMPLES}]"
    )
except Exception as e:
    # `except Exception` rather than a bare `except:` so that a kernel
    # interrupt (KeyboardInterrupt) or SystemExit aborts the cell instead of
    # silently starting a second download. The reason is printed so the
    # fallback is visible in the cell output.
    print(f"⚠ train_wikipedia.jsonl failed ({type(e).__name__}: {e}); "
          "falling back to the default train split")
    dataset = load_dataset(
        "Miwa-Keita/zenz-v2.5-dataset",
        split=f"train[:{MAX_SAMPLES}]"
    )

print(f"✓ Loaded {len(dataset):,} samples")

## 3. Tokenize Words

In [None]:
import fugashi
from collections import Counter
from tqdm import tqdm

tagger = fugashi.Tagger()

def tokenize_words(text):
    """Split *text* into word surfaces with the module-level fugashi tagger.

    Every token is kept (emoji included) except tokens whose top-level
    part-of-speech (``feature.pos1``) is '空白' ("whitespace").

    Args:
        text: The string to tokenize.

    Returns:
        A list of token surface strings, in input order.
    """
    # The POS literal must be the real Japanese label; a mis-encoded literal
    # never matches and the whitespace filter silently becomes a no-op.
    return [t.surface for t in tagger(text) if t.feature.pos1 != '空白']

# Test
print(f"Words: {tokenize_words('有難うございます😊')}")

## 4. Build Word Vocabulary

In [None]:
print("Building word vocabulary...")

# Count every token of every sample's 'output' text.
word_counts = Counter()

for item in tqdm(dataset, desc="Counting words"):
    # `or ''` also covers a row whose 'output' is present but None.
    kanji = item.get('output') or ''
    word_counts.update(tokenize_words(kanji))

print(f"✓ Found {len(word_counts):,} unique words")
print(f"Top 15: {[w for w, _ in word_counts.most_common(15)]}")

# Build vocab: special tokens take ids 0..len(SPECIAL_TOKENS)-1, then the most
# frequent words fill the remaining slots up to WORD_VOCAB_SIZE.
word_to_idx = {tok: i for i, tok in enumerate(SPECIAL_TOKENS)}

for word, _ in word_counts.most_common(WORD_VOCAB_SIZE - len(SPECIAL_TOKENS)):
    # setdefault: a corpus token that happens to equal a special token must
    # not overwrite that token's reserved id.
    word_to_idx.setdefault(word, len(word_to_idx))

# Reverse mapping (id -> word) used when decoding predictions.
idx_to_word = {v: k for k, v in word_to_idx.items()}
vocab_size = len(word_to_idx)

print(f"✓ Vocab size: {vocab_size:,}")

## 5. Create Training Data (Word-by-Word)

In [None]:
import numpy as np

print("Creating word-by-word training data...")

def encode_words(words, max_len=MAX_CONTEXT_LEN):
    """Map words to vocabulary ids and force the result to length *max_len*.

    Words missing from ``word_to_idx`` become the '<UNK>' id. Short inputs are
    padded on the left with the '<PAD>' id; long inputs keep only their last
    *max_len* ids.

    Args:
        words: Iterable of word strings.
        max_len: Target length of the returned id list.

    Returns:
        A list of integer ids.
    """
    token_ids = []
    for word in words:
        token_ids.append(word_to_idx.get(word, word_to_idx['<UNK>']))

    shortfall = max_len - len(token_ids)
    if shortfall > 0:
        # Left-pad so the most recent words sit at the end of the sequence.
        token_ids = [word_to_idx['<PAD>']] * shortfall + token_ids

    # Keep last N tokens
    return token_ids[-max_len:]

# X_data: one fixed-length id sequence per example (built by encode_words).
# y_data: the vocabulary id of the word that follows that context.
X_data = []
y_data = []

for item in tqdm(dataset, desc="Processing"):
    # `or ''` guards against a row whose 'output' is present but None,
    # which would otherwise fail on .strip().
    kanji = (item.get('output') or '').strip()
    if not kanji:
        continue

    words = tokenize_words(kanji)
    if len(words) < 2:
        # A single word has no (context, next word) pair.
        continue

    # Create training pairs: context → next_word
    for i in range(1, len(words)):
        # Only train on words in vocabulary (single lookup gives both the
        # membership test and the id).
        target_id = word_to_idx.get(words[i])
        if target_id is None:
            continue

        context = words[max(0, i - MAX_CONTEXT_LEN):i]
        X_data.append(encode_words(context))
        y_data.append(target_id)

X_data = np.array(X_data)
y_data = np.array(y_data)

print(f"\n✓ {len(X_data):,} training samples")
print(f"✓ Shape: {X_data.shape}")

In [None]:
import tensorflow as tf

# Shuffle and split
# One random permutation is applied to inputs and targets together so that
# each context stays aligned with its label.
indices = np.random.permutation(len(X_data))
X_data = X_data[indices]
y_data = y_data[indices]

# First 90% of the shuffled examples train the model, the rest validate it.
split = int(len(X_data) * 0.9)
X_train, X_val = X_data[:split], X_data[split:]
y_train, y_val = y_data[:split], y_data[split:]

train_ds = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
val_ds = tf.data.Dataset.from_tensor_slices((X_val, y_val)).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

print(f"✓ Train: {len(X_train):,}, Val: {len(X_val):,}")

## 6. Build Next Word Model

In [None]:
from tensorflow.keras import mixed_precision
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Embedding, GRU, Dense, Dropout,
    Bidirectional, Attention, Concatenate, LayerNormalization
)

# Global policy: layers created below compute in float16 unless they set
# their own dtype.
mixed_precision.set_global_policy('mixed_float16')

print("Building Next Word Language Model...")

# One example = MAX_CONTEXT_LEN word ids (as produced by encode_words).
inputs = Input(shape=(MAX_CONTEXT_LEN,), name='input')

# Embedding
x = Embedding(vocab_size, EMBEDDING_DIM, name='embedding')(inputs)

# Bi-directional GRU
# return_sequences=True keeps one output per time step for the attention layer.
encoder_out = Bidirectional(
    GRU(GRU_UNITS, return_sequences=True, dropout=0.2),
    name='bi_gru'
)(x)

# Luong Attention (self-attention)
# Query and value are the same tensor, so the sequence attends to itself.
attention_out = Attention(use_scale=True, name='attention')(
    [encoder_out, encoder_out]
)

# Combine
combined = Concatenate()([encoder_out, attention_out])
combined = LayerNormalization()(combined)

# Context vector (last state)
# No return_sequences here, so this GRU returns a single vector per example.
context = GRU(GRU_UNITS, name='context_gru')(combined)
context = Dropout(0.3)(context)

# Output: predict next word
# dtype='float32' overrides the mixed_float16 policy for the softmax layer.
outputs = Dense(vocab_size, activation='softmax', dtype='float32', name='output')(context)

model = Model(inputs, outputs, name='next_word_lm')
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3, clipnorm=1.0),
    loss='sparse_categorical_crossentropy',  # targets are integer word ids
    metrics=['accuracy']
)

model.summary()
print(f"\n✓ Parameters: {model.count_params():,}")
# Size estimate assumes 4 bytes (float32) per parameter.
print(f"✓ Estimated size: {model.count_params() * 4 / 1024 / 1024:.2f} MB")

## 7. Train

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau

# Keep the checkpoint with the highest validation accuracy on disk.
checkpoint_cb = ModelCheckpoint(
    f'{MODEL_DIR}/best.keras',
    monitor='val_accuracy',
    save_best_only=True,
    verbose=1,
)
# Stop after 5 epochs without val_loss improvement and restore the best weights.
early_stop_cb = EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
)
# Halve the learning rate after 2 stagnant epochs, never going below 1e-6.
lr_decay_cb = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=2,
    min_lr=1e-6,
)
callbacks = [checkpoint_cb, early_stop_cb, lr_decay_cb]

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=NUM_EPOCHS,
    callbacks=callbacks,
    verbose=1,
)

In [None]:
import matplotlib.pyplot as plt

# Side-by-side loss and accuracy curves for the training and validation sets.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.plot(history.history['loss'], label='Train')
ax1.plot(history.history['val_loss'], label='Val')
ax1.set_title('Loss')
ax1.legend()

ax2.plot(history.history['accuracy'], label='Train')
ax2.plot(history.history['val_accuracy'], label='Val')
ax2.set_title('Accuracy')
ax2.legend()
# Save before show() so the figure written to disk is the one displayed.
plt.savefig(f'{MODEL_DIR}/training.png')
plt.show()

# Reports the last epoch's validation accuracy (not necessarily the best one).
print(f"\n✓ Final Val Accuracy: {history.history['val_accuracy'][-1]*100:.2f}%")

## 8. Save Model

In [None]:
import json

# Persist the trained model plus everything needed to encode inputs and
# decode outputs at inference time.
model.save(f'{MODEL_DIR}/model.keras')

# ensure_ascii=False keeps the Japanese words readable in the JSON files,
# which is why every file is opened explicitly as UTF-8.
with open(f'{MODEL_DIR}/word_to_idx.json', 'w', encoding='utf-8') as f:
    json.dump(word_to_idx, f, ensure_ascii=False)

with open(f'{MODEL_DIR}/idx_to_word.json', 'w', encoding='utf-8') as f:
    # JSON object keys must be strings, so the integer ids are stringified.
    json.dump({str(k): v for k, v in idx_to_word.items()}, f, ensure_ascii=False)

config = {
    'vocab_size': vocab_size,
    'max_context_len': MAX_CONTEXT_LEN,
    'embedding_dim': EMBEDDING_DIM,
    'gru_units': GRU_UNITS,
    'architecture': 'BiGRU_LuongAttention_LM',
    'task': 'next_word_prediction',
    'special_tokens': SPECIAL_TOKENS
}
# Explicit encoding for consistency with the two vocabulary files above,
# instead of depending on the platform default.
with open(f'{MODEL_DIR}/config.json', 'w', encoding='utf-8') as f:
    json.dump(config, f, indent=2)

print("✓ Saved model and config")

## 9. Export TFLite

In [None]:
print("Exporting TFLite...")

# Best-effort export: a conversion failure is reported but does not abort the
# notebook, since the Keras model has already been saved.
try:
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    # Allow falling back to TensorFlow ops for anything without a TFLite builtin.
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
    # NOTE(review): private converter flag; presumably required for the
    # recurrent layers' tensor-list ops - confirm against the TFLite docs.
    converter._experimental_lower_tensor_list_ops = False

    tflite_model = converter.convert()
    with open(f'{MODEL_DIR}/model.tflite', 'wb') as f:
        f.write(tflite_model)
    print(f"✓ model.tflite ({len(tflite_model)/(1024*1024):.2f}MB)")

    # FP16
    # Same converter, now with default optimizations and float16 as the
    # supported type.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_types = [tf.float16]
    tflite_fp16 = converter.convert()
    with open(f'{MODEL_DIR}/model_fp16.tflite', 'wb') as f:
        f.write(tflite_fp16)
    print(f"✓ model_fp16.tflite ({len(tflite_fp16)/(1024*1024):.2f}MB)")
except Exception as e:
    print(f"⚠ Error: {e}")

## 10. Verification

In [None]:
# Section banner: a rule, the title, and a closing rule, one per line.
print("=" * 60, "VERIFICATION: Next Word Prediction", "=" * 60, sep="\n")

def predict_next_word(context_words, top_k=5):
    """Return up to *top_k* most probable next words for a context.

    Args:
        context_words: Sequence of word strings; encoded with ``encode_words``.
        top_k: Maximum number of suggestions to return. Values <= 0 yield an
            empty list.

    Returns:
        A list of word strings, most probable first, with every entry of
        ``SPECIAL_TOKENS`` removed.
    """
    if top_k <= 0:
        # Guard: `[-0:]` would otherwise slice the whole array.
        return []

    encoded = np.array([encode_words(context_words)])
    probs = model.predict(encoded, verbose=0)[0]

    # Over-fetch by the number of special tokens: even if all of them rank at
    # the top, top_k real words remain after filtering (provided the
    # vocabulary is large enough).
    n_candidates = top_k + len(SPECIAL_TOKENS)
    top_indices = np.argsort(probs)[-n_candidates:][::-1]

    predictions = []
    for idx in top_indices:
        word = idx_to_word.get(int(idx), '<UNK>')
        if word in SPECIAL_TOKENS:
            continue
        predictions.append(word)
        if len(predictions) >= top_k:
            break
    return predictions

def generate_sequence(start_words, num_words=5):
    """Greedily extend a context one word at a time.

    Each step asks ``predict_next_word`` for its single best suggestion,
    appends it to the context and repeats.

    Args:
        start_words: Iterable of seed word strings; it is copied, not mutated.
        num_words: Maximum number of words to generate.

    Returns:
        The list of generated words (the seed words are not included).
        Generation stops early when no suggestion is returned.
    """
    context = list(start_words)
    generated = []

    for _ in range(num_words):
        predictions = predict_next_word(context, top_k=1)
        # predict_next_word already drops every special token (including
        # '<EOS>'), so an empty result is the only stop signal.
        if not predictions:
            break
        next_word = predictions[0]
        generated.append(next_word)
        context.append(next_word)

    return generated

# Test: Top-K predictions
# The probe contexts must be real Japanese words; mis-encoded text would be
# mapped to <UNK> and make the check meaningless.
print("\n📝 Top-5 Next Word Predictions:")
print("-" * 50)
tests = [
    ['ありがとう'],           # expected: ございます, ございました, ね
    ['お世話'],               # expected: に, になって
    ['行き'],                 # expected: ます, たい
    ['申し訳'],               # expected: ありません, ございません
    ['そう'],                 # expected: です, だ
    ['今日'],                 # expected: は, の
    ['日本'],                 # expected: の, は
]
for ctx in tests:
    result = predict_next_word(ctx)
    print(f"  {''.join(ctx)} → {result}")

# Test: Word-by-word generation
print("\n📝 Word-by-Word Generation:")
print("-" * 50)
generations = [
    ['お世話'],      # Should generate: に なって おります
    ['ありがとう'],  # Should generate: ございます/ございました
    ['今日', 'は'],  # Should continue
]
for start in generations:
    gen = generate_sequence(start, num_words=4)
    print(f"  {''.join(start)} → {''.join(gen)}")

print("\n" + "="*60)
print("✅ VERIFICATION COMPLETE")

In [None]:
# List exports
# Print every regular file in MODEL_DIR with its size: MB when larger than
# 1 MiB, KB otherwise. Sub-directories are skipped.
print("\nExported files:")
for entry in sorted(os.listdir(MODEL_DIR)):
    full_path = f'{MODEL_DIR}/{entry}'
    if not os.path.isfile(full_path):
        continue
    n_bytes = os.path.getsize(full_path)
    if n_bytes > 1024*1024:
        print(f"  {entry}: {n_bytes/(1024*1024):.1f} MB")
    else:
        print(f"  {entry}: {n_bytes/1024:.1f} KB")