# Next Word Predictor v2 (Word-Level Language Model)

**Supports**: Google Colab & Kaggle (falls back to a local `./output` directory elsewhere)

**Task**: Predict the next word from the preceding Japanese words (kanji and kana)
- Input: `[今日, は]` → Output: `天気` / `暑い` / `良い`

**Architecture**: Bi-GRU + Self-Attention + Context GRU

**v2 Improvements over v1**:
- Uses `left_context + output` combined (5× more training pairs)
- Memory-safe training: memory-mapped `.npy` arrays + a generator keep the training pairs out of RAM (~0 GB)
- Removed mixed_precision (hurt accuracy on small model)
- More data: ~2M items → ~5M+ training pairs (current config: `MAX_SAMPLES` = 5M items, capped at `MAX_NWP_PAIRS` = 8M pairs)
- Platform detection fix (Colab/Kaggle/Local)
- Cache to the platform output directory (Google Drive on Colab) as `.npy` for mmap support

In [None]:
import os
import gc

# --- Platform detection ---------------------------------------------------
# The Colab env var is tested before probing /kaggle/working; the original
# author notes Colab can also expose a /kaggle directory.
if 'COLAB_RELEASE_TAG' in os.environ:
    PLATFORM = 'Colab'
    from google.colab import drive
    drive.mount('/content/drive')
    DRIVE_DIR = '/content/drive/MyDrive/Keyboard-Suggestions-ML-Colab'
elif os.path.exists('/kaggle/working'):
    PLATFORM, DRIVE_DIR = 'Kaggle', '/kaggle/working'
else:
    PLATFORM, DRIVE_DIR = 'Local', './output'

# Model artefacts and data caches live under the platform's output root.
MODEL_DIR = f"{DRIVE_DIR}/models/gru_japanese_next_word"
CACHE_DIR = f"{DRIVE_DIR}/cache/nwp"
for _dir in (MODEL_DIR, CACHE_DIR):
    os.makedirs(_dir, exist_ok=True)

for _line in (
    f"‚úÖ Platform: {PLATFORM}",
    f"üìÅ Model: {MODEL_DIR}",
    f"üíæ Cache: {CACHE_DIR}",
):
    print(_line)

In [None]:
!pip install -q tensorflow keras datasets numpy tqdm fugashi unidic-lite

In [None]:
# ===========================================================
# CONFIGURATION
# ===========================================================
# --- Run control ---
TESTING_MODE = False          # In this notebook, only shortens NUM_EPOCHS
FORCE_REBUILD_CACHE = False   # Ignore existing cache files and rebuild them
NUM_EPOCHS = 15 if not TESTING_MODE else 3

# --- Data volume ---
MAX_SAMPLES = 5_000_000       # Dataset items to process
MAX_NWP_PAIRS = 8_000_000     # Max training pairs to create
BATCH_SIZE = 512

# --- Word-level model ---
WORD_VOCAB_SIZE = 6000        # Total vocab size, special tokens included
MAX_WORD_CONTEXT = 10         # Max words in context (left-padded)
EMBEDDING_DIM = 96
GRU_UNITS = 192

# Special tokens occupy the first vocab ids, in this order.
SPECIAL_TOKENS = ['<PAD>', '<UNK>', '<BOS>', '<EOS>']
PAD_IDX = SPECIAL_TOKENS.index('<PAD>')  # == 0

print(f"Config: epochs={NUM_EPOCHS}, max_items={MAX_SAMPLES:,}, max_pairs={MAX_NWP_PAIRS:,}")
print(f"Model: vocab={WORD_VOCAB_SIZE}, embed={EMBEDDING_DIM}, GRU={GRU_UNITS}")

## 0. Shared Utilities

In [None]:
import fugashi

# One fugashi (MeCab) tagger, built once at module level and reused by
# tokenize_words(). With no arguments it uses fugashi's default dictionary
# lookup — presumably the unidic-lite package installed above; confirm.
tagger = fugashi.Tagger()

def tokenize_words(text):
    """Split ``text`` into surface-form words with fugashi (MeCab).

    Whitespace tokens are dropped, both when the dictionary tags them with the
    UniDic part of speech '空白' (whitespace) and when their surface is blank.

    Args:
        text: Input string. ``None`` or ``""`` yields an empty list.

    Returns:
        List of word surface strings, in their original order.
    """
    if not text:
        return []
    # The POS literal must be the real '空白'; a mis-encoded literal never
    # matches, which would let full-width spaces through as "words".
    return [
        token.surface
        for token in tagger(text)
        if token.feature.pos1 != '空白' and token.surface.strip()
    ]

def encode_words(words, vocab, pad_id, unk_id, max_len=None):
    """Encode a word sequence as a fixed-length, left-padded list of ids.

    Only the last ``max_len`` words are kept (the ones nearest the word being
    predicted). Shorter inputs are left-padded with ``pad_id``.

    Args:
        words: Iterable of word strings.
        vocab: Mapping word -> id; words not in it map to ``unk_id``.
        pad_id: Id used for left padding.
        unk_id: Id used for out-of-vocabulary words.
        max_len: Output length; defaults to ``MAX_WORD_CONTEXT``.

    Returns:
        List of exactly ``max_len`` ints.

    Raises:
        ValueError: If ``max_len`` is negative.
    """
    if max_len is None:
        max_len = MAX_WORD_CONTEXT
    if max_len < 0:
        raise ValueError(f"max_len must be non-negative, got {max_len}")
    if max_len == 0:
        # Explicit case: slicing with [-0:] would return the whole list.
        return []
    # Truncate before the vocab lookups; earlier words would be dropped anyway.
    ids = [vocab.get(w, unk_id) for w in list(words)[-max_len:]]
    return [pad_id] * (max_len - len(ids)) + ids

# Quick smoke test: tokenize a short Japanese sentence ("It's very hot today").
# The input must be real Japanese text; mis-encoded text tokenizes to junk.
test_words = tokenize_words('今日はとても暑いですね')
print(f"✓ Tokenize test: {test_words}")
print(f"  ({len(test_words)} words)")

## 1. Load or Build Cache

**Key improvement**: Uses `left_context + output` combined for full sentence context.

Before (v1): only `output` → 2-3 words → ~1 pair/item

After (v2): `left_context + output` → 5-10 words → ~5 pairs/item

In [None]:
import json
import numpy as np
from tqdm.auto import tqdm

# Cache locations. Training pairs are stored as .npy so they can be
# memory-mapped instead of loaded into RAM.
VOCAB_CACHE = f"{CACHE_DIR}/word_vocab_v2.json"
NWP_X_CACHE = f"{CACHE_DIR}/nwp_x_v2.npy"
NWP_Y_CACHE = f"{CACHE_DIR}/nwp_y_v2.npy"

def cache_exists():
    """Return True only if the vocab JSON and both pair arrays exist on disk."""
    for path in (VOCAB_CACHE, NWP_X_CACHE, NWP_Y_CACHE):
        if not os.path.exists(path):
            return False
    return True

# Try to reuse cached vocab + pairs. A cache is accepted only if its arrays
# match the current config; otherwise CACHE_LOADED stays False and the
# following cells rebuild everything.
CACHE_LOADED = False
if cache_exists() and not FORCE_REBUILD_CACHE:
    print("📦 Loading from cache (memory-mapped)...")

    with open(VOCAB_CACHE, 'r', encoding='utf-8') as f:
        vocab_data = json.load(f)
    word_to_idx = vocab_data['word_to_idx']
    # JSON object keys are always strings; restore the integer ids.
    idx_to_word = {int(k): v for k, v in vocab_data['idx_to_word'].items()}
    vocab_size = len(word_to_idx)

    x_mmap = np.load(NWP_X_CACHE, mmap_mode='r')
    y_mmap = np.load(NWP_Y_CACHE, mmap_mode='r')

    # A cache written with a different MAX_WORD_CONTEXT, or X/y files from
    # different runs, would otherwise only fail later inside tf.data.
    cache_ok = (
        x_mmap.ndim == 2
        and x_mmap.shape[1] == MAX_WORD_CONTEXT
        and y_mmap.ndim == 1
        and len(x_mmap) == len(y_mmap)
    )
    if cache_ok:
        print(f"✓ Vocab: {vocab_size:,} words")
        print(f"✓ Pairs: {len(x_mmap):,} (memory-mapped)")
        CACHE_LOADED = True
    else:
        print(f"⚠ Cached arrays X={x_mmap.shape}, y={y_mmap.shape} do not match "
              f"MAX_WORD_CONTEXT={MAX_WORD_CONTEXT}; ignoring cache.")
        del x_mmap, y_mmap

if not CACHE_LOADED:
    print("🔨 Building from scratch (will save to drive)...")

In [None]:
# Pass 1 of 2 (only when no cache was loaded): load the corpus, count word
# frequencies over left_context + output, then freeze and save the vocab.
if not CACHE_LOADED:
    from datasets import load_dataset
    from collections import Counter

    print("üì• Loading zenz dataset...")
    dataset = load_dataset(
        "Miwa-Keita/zenz-v2.5-dataset",
        split="train",
        data_files="train_wikipedia.jsonl",
    )
    print(f"‚úì Raw: {len(dataset):,} items")

    print("\nüìù Building word vocabulary (left_context + output)...")
    word_counts = Counter()
    processed = 0  # non-empty items seen so far
    for record in tqdm(dataset, desc="Counting words"):
        # Missing or None fields count as empty strings.
        joined = (record.get('left_context', '') or '') + (record.get('output', '') or '')
        if joined.strip():
            word_counts.update(tokenize_words(joined))
            processed += 1
            if MAX_SAMPLES and processed >= MAX_SAMPLES:
                break

    top15 = [w for w, _ in word_counts.most_common(15)]
    print(f"\n‚úì Found {len(word_counts):,} unique words from {processed:,} items")
    print(f"  Top 15: {top15}")

    # Ids are handed out in order: special tokens first, then the most
    # frequent words until WORD_VOCAB_SIZE entries exist.
    word_budget = WORD_VOCAB_SIZE - len(SPECIAL_TOKENS)
    word_to_idx = {}
    for token in [*SPECIAL_TOKENS, *(w for w, _ in word_counts.most_common(word_budget))]:
        word_to_idx[token] = len(word_to_idx)

    idx_to_word = dict(zip(word_to_idx.values(), word_to_idx.keys()))
    vocab_size = len(word_to_idx)
    print(f"‚úì Vocab size: {vocab_size:,}")

    # Persist both directions; JSON needs string keys for idx_to_word.
    vocab_payload = {
        'word_to_idx': word_to_idx,
        'idx_to_word': {str(k): v for k, v in idx_to_word.items()},
    }
    with open(VOCAB_CACHE, 'w', encoding='utf-8') as fh:
        json.dump(vocab_payload, fh, ensure_ascii=False)
    print(f"‚úì Vocab saved to {VOCAB_CACHE}")

    del word_counts
    gc.collect()

In [None]:
# Create training pairs from left_context + output
# Example: "天気が良い今日は暑い" → ["天気","が","良い","今日","は","暑い"]
#   Pairs: [天気]→が, [天気,が]→良い, [天気,が,良い]→今日, ...
if not CACHE_LOADED:
    print("\n🔢 Creating training pairs (left_context + output)...")

    PAD = word_to_idx['<PAD>']
    UNK = word_to_idx['<UNK>']

    # Pre-allocate MAX_NWP_PAIRS rows once; only the filled prefix is kept.
    X = np.zeros((MAX_NWP_PAIRS, MAX_WORD_CONTEXT), dtype=np.int32)
    y = np.zeros(MAX_NWP_PAIRS, dtype=np.int32)
    pair_idx = 0
    processed = 0

    for item in tqdm(dataset, desc="Creating pairs"):
        left_ctx = item.get('left_context', '') or ''
        output = item.get('output', '') or ''
        text = left_ctx + output
        if not text.strip():
            continue

        # Count every non-empty item, exactly like the vocabulary pass, so
        # both passes stop after the same MAX_SAMPLES items (previously items
        # with < 2 words were skipped here but counted there).
        processed += 1
        words = tokenize_words(text)

        # Sliding window: the (up to) MAX_WORD_CONTEXT words before position i
        # predict words[i]. Items with fewer than 2 words yield no pairs.
        for i in range(1, len(words)):
            target = word_to_idx.get(words[i])
            if target is None:
                continue  # out-of-vocabulary words are never used as labels

            context = words[max(0, i - MAX_WORD_CONTEXT):i]
            X[pair_idx] = encode_words(context, word_to_idx, PAD, UNK)
            y[pair_idx] = target
            pair_idx += 1
            if pair_idx >= MAX_NWP_PAIRS:
                break

        if pair_idx >= MAX_NWP_PAIRS:
            break
        if MAX_SAMPLES and processed >= MAX_SAMPLES:
            break

    # Trim to the filled prefix (a view; np.save writes only these rows).
    X = X[:pair_idx]
    y = y[:pair_idx]
    print(f"\n✓ Created {pair_idx:,} training pairs from {processed:,} items")
    print(f"  Avg pairs/item: {pair_idx / max(processed, 1):.1f}")

    # Save as .npy and release the in-RAM arrays
    np.save(NWP_X_CACHE, X)
    np.save(NWP_Y_CACHE, y)
    del X, y
    gc.collect()

    # Release dataset
    del dataset
    gc.collect()
    print("🧹 Saved cache, released memory")

    # Re-open the saved arrays memory-mapped so later cells read them lazily
    x_mmap = np.load(NWP_X_CACHE, mmap_mode='r')
    y_mmap = np.load(NWP_Y_CACHE, mmap_mode='r')
    print(f"✓ Loaded as mmap: X={x_mmap.shape}, y={y_mmap.shape}")

print(f"\n📊 Total pairs: {len(x_mmap):,}")

## 2. Create Dataset

In [None]:
import tensorflow as tf

n_samples = len(x_mmap)
split = int(n_samples * 0.9)

# Pair rows were written item by item, so neighbouring rows come from the same
# sentence and share prefixes ([A]→B, [A,B]→C, ...). A per-row random split
# puts near-duplicate prefixes on both sides and inflates validation scores.
# A contiguous cut keeps every item entirely in train or entirely in val,
# except at most the one straddling the boundary.
indices = np.arange(n_samples, dtype=np.int64)
train_idx = indices[:split]
val_idx = indices[split:]


def make_generator(x, y_arr, idx_arr, shuffle=False, seed=None):
    """Return a generator factory yielding ``(x[i], y_arr[i])`` for i in ``idx_arr``.

    Rows are read one at a time from the (memory-mapped) arrays, so the data
    is never copied into RAM as a whole.

    Args:
        x: 2-D array of left-padded context ids.
        y_arr: 1-D array of target ids, aligned with ``x``.
        idx_arr: Row indices to yield.
        shuffle: If True, ``idx_arr`` is re-permuted each time the generator
            is (re)started, i.e. once per pass when the dataset is repeated.
        seed: Optional seed for the shuffle RNG.
    """
    rng = np.random.default_rng(seed)

    def gen():
        order = rng.permutation(idx_arr) if shuffle else idx_arr
        for i in order:
            yield x[i], y_arr[i]

    return gen


output_sig = (
    tf.TensorSpec(shape=(MAX_WORD_CONTEXT,), dtype=tf.int32),
    tf.TensorSpec(shape=(), dtype=tf.int32),
)

# The training cell uses steps_per_epoch = len(train_idx) // BATCH_SIZE and
# validation_steps = len(val_idx) // BATCH_SIZE. When explicit steps are given
# for a dataset of unknown size, Keras keeps one iterator across epochs.
# Without repeat(), the 2nd epoch hits "Your input ran out of data".
#   * drop_remainder=True -> exactly that many batches per pass
#   * repeat()            -> each epoch starts a fresh (reshuffled) pass
train_ds = (
    tf.data.Dataset.from_generator(
        make_generator(x_mmap, y_mmap, train_idx, shuffle=True),
        output_signature=output_sig,
    )
    .batch(BATCH_SIZE, drop_remainder=True)
    .repeat()
    .prefetch(tf.data.AUTOTUNE)
)

val_ds = (
    tf.data.Dataset.from_generator(
        make_generator(x_mmap, y_mmap, val_idx),
        output_signature=output_sig,
    )
    .batch(BATCH_SIZE, drop_remainder=True)
    .repeat()
    .prefetch(tf.data.AUTOTUNE)
)

print(f"Train: {len(train_idx):,}, Val: {len(val_idx):,}")
print("💡 Data loaded via mmap + generator (near-zero RAM)")

## 3. Build Model (Bi-GRU + Self-Attention)

In [None]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Embedding, GRU, Dense, Dropout,
    Bidirectional, Attention, Concatenate, LayerNormalization
)

inputs = Input(shape=(MAX_WORD_CONTEXT,), name='input')

# Token ids -> dense vectors.
# NOTE(review): no mask_zero, so the left-padding positions are not masked.
embedded = Embedding(vocab_size, EMBEDDING_DIM, name='embedding')(inputs)

# Bidirectional GRU encoder; one output vector per context position
bi_gru = Bidirectional(
    GRU(GRU_UNITS, return_sequences=True, dropout=0.2),
    name='bi_gru',
)
encoder_out = bi_gru(embedded)

# Self-attention: the encoded sequence is both query and value
# (Keras Attention = dot-product / Luong-style, with a learned scale)
self_attention = Attention(use_scale=True, name='attention')
attention_out = self_attention([encoder_out, encoder_out])

# Merge encoder states with their attention summary, then normalize
merged = Concatenate()([encoder_out, attention_out])
merged = LayerNormalization()(merged)

# A second GRU squeezes the sequence into a single context vector
summary = GRU(GRU_UNITS, name='context_gru')(merged)
summary = Dropout(0.3)(summary)

# Softmax over the vocabulary = next-word distribution
outputs = Dense(vocab_size, activation='softmax', name='output')(summary)

model = Model(inputs, outputs, name='next_word_lm_v2')

# Adam with gradient-norm clipping (clipnorm=1.0) for stable RNN training
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3, clipnorm=1.0),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

model.summary()
params = model.count_params()
mib = 1024 * 1024
print(f"\nüìä Parameters: {params:,}")
print(f"   FP32: ~{params * 4 / mib:.1f} MB")
print(f"   FP16: ~{params * 2 / mib:.1f} MB")

## 4. Train

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau

# from_generator datasets have no known length, so fit() is given explicit
# step counts: whole batches per pass over the train / val index arrays.
steps_per_epoch, validation_steps = (len(idx) // BATCH_SIZE for idx in (train_idx, val_idx))

# Keep the best checkpoint by val accuracy; stop and cut the LR on val loss.
checkpoint_cb = ModelCheckpoint(
    f'{MODEL_DIR}/best_v2.keras',
    monitor='val_accuracy',
    save_best_only=True,
    verbose=1,
)
early_stop_cb = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
lr_plateau_cb = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=2,
    min_lr=1e-6,
    verbose=1,
)
callbacks = [checkpoint_cb, early_stop_cb, lr_plateau_cb]

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=NUM_EPOCHS,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    callbacks=callbacks,
)

In [None]:
import matplotlib.pyplot as plt

# Side-by-side train/val curves for loss and accuracy, saved next to the model.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
for ax, metric, title in ((ax1, 'loss', 'Loss'), (ax2, 'accuracy', 'Accuracy')):
    ax.plot(history.history[metric], label='Train')
    ax.plot(history.history[f'val_{metric}'], label='Val')
    ax.set_title(title)
    ax.legend()

plt.savefig(f'{MODEL_DIR}/training_v2.png')
plt.show()
print(f"Best val accuracy: {max(history.history['val_accuracy'])*100:.2f}%")

## 5. Save & Export

In [None]:
# Persist the model plus what is needed to use it: both vocab maps and the
# hyper-parameters that fix the input / embedding / GRU shapes.
model.save(f'{MODEL_DIR}/model.keras')

# JSON object keys must be strings, so idx_to_word is written with str keys.
for fname, payload in (
    ('word_to_idx.json', word_to_idx),
    ('idx_to_word.json', {str(k): v for k, v in idx_to_word.items()}),
):
    with open(f'{MODEL_DIR}/{fname}', 'w', encoding='utf-8') as fh:
        json.dump(payload, fh, ensure_ascii=False)

model_config = {
    'vocab_size': vocab_size,
    'max_context_len': MAX_WORD_CONTEXT,
    'embedding_dim': EMBEDDING_DIM,
    'gru_units': GRU_UNITS,
    'architecture': 'BiGRU_SelfAttention_ContextGRU',
    'special_tokens': SPECIAL_TOKENS,
    'version': 'v2'
}
with open(f'{MODEL_DIR}/config.json', 'w') as fh:
    json.dump(model_config, fh, indent=2)

keras_size = os.path.getsize(f'{MODEL_DIR}/model.keras')
print(f"‚úì Model saved: {keras_size / 1024 / 1024:.2f} MB")

In [None]:
# Export TFLite (best effort: any failure is reported, not raised).

def _write_tflite(conv, filename):
    """Run ``conv.convert()``, write the bytes to MODEL_DIR/filename, print the size, and return the bytes."""
    blob = conv.convert()
    with open(f'{MODEL_DIR}/{filename}', 'wb') as fh:
        fh.write(blob)
    print(f"‚úì {filename} ({len(blob)/(1024*1024):.2f} MB)")
    return blob

try:
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    # GRU layers may need TF (Flex) ops in addition to TFLite builtins.
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS
    ]
    converter._experimental_lower_tensor_list_ops = False

    tflite = _write_tflite(converter, 'model.tflite')

    # Float16 variant from the same converter: roughly half the size.
    # NOTE(review): accuracy parity with FP32 is assumed, not checked here.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_types = [tf.float16]
    tflite16 = _write_tflite(converter, 'model_fp16.tflite')

except Exception as e:
    print(f"‚ö† TFLite export failed: {e}")

## 6. Verification

In [None]:
_rule = "=" * 60
print(_rule)
print("VERIFICATION: Next Word Prediction v2")
print(_rule)

# Special-token ids used when encoding contexts for prediction.
PAD, UNK = word_to_idx['<PAD>'], word_to_idx['<UNK>']

def predict_next_word(context_words, top_k=5):
    """Return the ``top_k`` most likely next words for ``context_words``.

    Args:
        context_words: Sequence of word strings (the preceding context).
        top_k: Maximum number of predictions to return.

    Returns:
        List of ``(word, probability)`` tuples, highest probability first.
        Entries of SPECIAL_TOKENS are never returned. Fewer than ``top_k``
        items come back only if the vocabulary has fewer regular words.
    """
    if top_k <= 0:
        return []
    encoded = np.array([encode_words(context_words, word_to_idx, PAD, UNK)])
    probs = model.predict(encoded, verbose=0)[0]

    predictions = []
    # Walk the full ranking (highest first) instead of a fixed top-2k window,
    # so filtered special tokens can never leave fewer than top_k results.
    for idx in np.argsort(probs)[::-1]:
        word = idx_to_word.get(int(idx), '<UNK>')
        if word in SPECIAL_TOKENS:
            continue
        predictions.append((word, float(probs[idx])))
        if len(predictions) >= top_k:
            break
    return predictions

def generate_sequence(start_words, num_words=5):
    """Greedily extend ``start_words`` by up to ``num_words`` words.

    Each step asks predict_next_word() for its single best word, given the
    running context, and appends it. Generation stops early when no
    prediction comes back or the best word is '<EOS>'.

    Returns:
        Only the newly generated words (not ``start_words``).
    """
    # NOTE(review): predict_next_word() drops every SPECIAL_TOKENS entry,
    # including '<EOS>', so the '<EOS>' check below does not fire as written.
    running = list(start_words)
    produced = []
    while len(produced) < num_words:
        best = predict_next_word(running, top_k=1)
        if not best:
            break
        word = best[0][0]
        if word == '<EOS>':
            break
        produced.append(word)
        running.append(word)
    return produced

# ==========================================================
# Test: Top-K predictions
# ==========================================================
# Contexts must be real Japanese words; mis-encoded strings all map to <UNK>.
print("\n📝 Top-5 Next Word Predictions:")
print("-" * 50)
# (context words, hint about the expected continuation)
tests = [
    (['ありがとう'],           '→ ございます'),
    (['お', '世話'],           '→ に'),
    (['今日', 'は'],           '→ particle/topic'),
    (['日本', 'の'],           '→ contextual'),
    (['申し訳'],               '→ ございません'),
    (['東京', 'に'],           '→ location'),
    (['それ', 'は'],           '→ contextual'),
    (['行き', 'たい'],         '→ と/です'),
    (['大', '学'],             '→ の/に/で'),
    (['問題', 'が'],           '→ ある/ない'),
]
for ctx, hint in tests:
    result = predict_next_word(ctx, top_k=5)
    words = [f"{w}({p:.2f})" for w, p in result]
    print(f"  {''.join(ctx)} {hint}")
    print(f"    → {', '.join(words)}")

# ==========================================================
# Test: Word-by-word generation
# ==========================================================
print("\n📝 Word-by-Word Generation:")
print("-" * 50)
generations = [
    ['ありがとう'],
    ['今日', 'は'],
    ['日本', 'の'],
    ['お', '世話'],
]
for start in generations:
    gen = generate_sequence(start, num_words=4)
    print(f"  {''.join(start)} → {''.join(gen)}")

print("\n" + "="*60)
print("✅ VERIFICATION COMPLETE")

In [None]:
# List every regular file in MODEL_DIR with a human-readable size
# (MB above 1 MiB, otherwise KB).
print(f"\nüì¶ Files ({PLATFORM}):")
for name in sorted(os.listdir(MODEL_DIR)):
    path = f'{MODEL_DIR}/{name}'
    if not os.path.isfile(path):
        continue
    size = os.path.getsize(path)
    pretty = f"{size/(1024*1024):.2f} MB" if size > 1024*1024 else f"{size/1024:.1f} KB"
    print(f"  {name}: {pretty}")