# Context-Aware Kana-Kanji Converter (v3.1)

**Supports**: Google Colab & Kaggle (Multi-GPU)

**Input Format**: `context<SEP>kana` → `kanji`

**Testing workflow**:
1. Set `TESTING_MODE = True` → 10K samples, 3 epochs
2. Train → check loss decreasing, accuracy improving
3. Verify with real test cases from dataset
4. Set `TESTING_MODE = False` → full training

In [None]:
import os
import gc

# Auto-detect the runtime platform. Colab is checked first because, per the
# original author's note, Colab images also contain a /kaggle directory.
if 'COLAB_RELEASE_TAG' in os.environ:
    PLATFORM = 'Colab'
    # Imported lazily: this import is only attempted on the Colab branch.
    from google.colab import drive
    drive.mount('/content/drive')
    DRIVE_DIR = '/content/drive/MyDrive/Keyboard-Suggestions-ML-Colab'
elif os.path.exists('/kaggle/working'):
    PLATFORM = 'Kaggle'
    DRIVE_DIR = '/kaggle/working'
else:
    PLATFORM = 'Local'
    DRIVE_DIR = './output'

# Output locations for trained artifacts and preprocessed data caches.
MODEL_DIR = f"{DRIVE_DIR}/models/gru_japanese_kana_kanji"
CACHE_DIR = f"{DRIVE_DIR}/cache/kana_kanji"
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(CACHE_DIR, exist_ok=True)

# Status strings repaired from mojibake (UTF-8 that had been decoded as MacRoman).
print(f"✅ Platform: {PLATFORM}")
print(f"📁 Model: {MODEL_DIR}")
print(f"💾 Cache: {CACHE_DIR}")

In [None]:
!pip install -q tensorflow keras datasets numpy tqdm

In [None]:
import tensorflow as tf

# ===========================================================
# MULTI-GPU + MIXED PRECISION
# ===========================================================
# MirroredStrategy replicates the model across the visible devices;
# num_replicas_in_sync is used later to scale the global batch size.
strategy = tf.distribute.MirroredStrategy()
NUM_GPUS = strategy.num_replicas_in_sync
print(f"🔥 GPUs available: {NUM_GPUS}")

# Mixed precision: T4 has good FP16 Tensor Cores.
# Only enable float16 compute when a GPU is actually present; on a CPU-only
# runtime (e.g. the 'Local' platform) the default float32 policy is kept.
if tf.config.list_physical_devices('GPU'):
    tf.keras.mixed_precision.set_global_policy('mixed_float16')
print(f"⚡ Mixed precision: {tf.keras.mixed_precision.global_policy().name}")

In [None]:
# ===========================================================
# CONFIGURATION
# ===========================================================
# ⚠️ Set True for quick logic test (10K samples, 3 epochs)
# Set False for full training (8M samples, 10 epochs)
TESTING_MODE = True

if TESTING_MODE:
    MAX_SAMPLES = 10_000
    NUM_EPOCHS = 3
    CACHE_SUFFIX = '_test'  # keeps test caches separate from the full caches
    print("⚠️ TESTING MODE: 10K samples, 3 epochs")
else:
    MAX_SAMPLES = 8_000_000
    NUM_EPOCHS = 10
    CACHE_SUFFIX = ''
    print("🚀 FULL TRAINING: 8M samples, 10 epochs")

BATCH_SIZE = 512 * NUM_GPUS  # Scale batch with GPUs (512 per GPU)
FORCE_REBUILD_CACHE = False

# Length limits (filter long sequences)
MAX_CONTEXT_LEN = 30   # left_context max chars
MAX_INPUT_LEN = 30     # kana input max chars
MAX_OUTPUT_LEN = 30    # kanji output max chars
MAX_ENCODER_LEN = MAX_CONTEXT_LEN + 1 + MAX_INPUT_LEN  # context + <SEP>(1 token) + input
MAX_DECODER_LEN = MAX_OUTPUT_LEN + 2  # BOS + content + EOS

# Architecture (controls model size)
CHAR_VOCAB_SIZE = 6000
EMBEDDING_DIM = 64
GRU_UNITS = 128
NUM_ENCODER_LAYERS = 2
NUM_DECODER_LAYERS = 2

# Special tokens occupy the first vocabulary slots, in this order.
SPECIAL_TOKENS = ['<PAD>', '<UNK>', '<BOS>', '<EOS>', '<SEP>']
PAD_IDX = 0  # <PAD> is always index 0
SEP_TOKEN = '<SEP>'

print(f"Config: epochs={NUM_EPOCHS}, batch={BATCH_SIZE} ({BATCH_SIZE//NUM_GPUS}/GPU)")
print(f"Encoder max: {MAX_ENCODER_LEN}, Decoder max: {MAX_DECODER_LEN}")

## 0. Shared Utilities

In [None]:
def tokenize_with_sep(text):
    """Split *text* into characters, keeping each literal ``<SEP>`` as one token."""
    marker = '<SEP>'
    width = len(marker)
    total = len(text)
    pieces = []
    pos = 0
    while pos < total:
        if text.startswith(marker, pos):
            # The separator is consumed whole instead of character by character.
            pieces.append(marker)
            pos += width
            continue
        pieces.append(text[pos])
        pos += 1
    return pieces

def encode_tokens(tokens, vocab, max_len, pad_id, unk_id):
    """Map tokens to integer IDs, then truncate or right-pad to *max_len*.

    Tokens missing from *vocab* are encoded as *unk_id*; unused trailing
    positions are filled with *pad_id*.
    """
    encoded = []
    for token in tokens:
        encoded.append(vocab.get(token, unk_id))
    encoded = encoded[:max_len]
    missing = max_len - len(encoded)
    encoded.extend([pad_id] * missing)
    return encoded

def encode_encoder_input(text, vocab, pad_id, unk_id):
    """Encode a ``context<SEP>kana`` string into MAX_ENCODER_LEN integer IDs."""
    return encode_tokens(
        tokenize_with_sep(text),
        vocab,
        MAX_ENCODER_LEN,
        pad_id,
        unk_id,
    )

def encode_decoder_seq(text, vocab, pad_id, unk_id, add_bos=False, add_eos=False):
    """Encode a decoder-side string into MAX_DECODER_LEN integer IDs.

    The text is split per character; ``<BOS>`` is prepended when *add_bos* is
    set and ``<EOS>`` is appended when *add_eos* is set.
    """
    head = ['<BOS>'] if add_bos else []
    tail = ['<EOS>'] if add_eos else []
    sequence = head + list(text) + tail
    return encode_tokens(sequence, vocab, MAX_DECODER_LEN, pad_id, unk_id)

# Status string repaired from mojibake.
print("✓ Shared utilities loaded")

## 1. Load or Build Cached Data

Testing mode uses separate cache files so it won't overwrite full cache.

In [None]:
import json
import numpy as np
from tqdm.auto import tqdm

# Cache file paths — CACHE_SUFFIX keeps test-mode files separate from the full-run files.
VOCAB_CACHE = f"{CACHE_DIR}/vocabulary_v3{CACHE_SUFFIX}.json"
ENC_CACHE = f"{CACHE_DIR}/enc_v3_1{CACHE_SUFFIX}.npy"
DEC_IN_CACHE = f"{CACHE_DIR}/dec_in_v3_1{CACHE_SUFFIX}.npy"
DEC_TGT_CACHE = f"{CACHE_DIR}/dec_tgt_v3_1{CACHE_SUFFIX}.npy"

def cache_exists():
    """Return True only when the vocabulary and all three array caches exist."""
    required = (VOCAB_CACHE, ENC_CACHE, DEC_IN_CACHE, DEC_TGT_CACHE)
    for path in required:
        if not os.path.exists(path):
            return False
    return True

# Try to reuse the on-disk cache. CACHE_LOADED tells the following cells
# whether they still have to build the data from the raw dataset.
CACHE_LOADED = False

if cache_exists() and not FORCE_REBUILD_CACHE:
    print("📦 Loading from cache (memory-mapped, near-zero RAM)...")

    with open(VOCAB_CACHE, 'r', encoding='utf-8') as f:
        vocab_data = json.load(f)
    char_to_idx = vocab_data['char_to_idx']
    # JSON object keys are always strings; restore the integer indices.
    idx_to_char = {int(k): v for k, v in vocab_data['idx_to_char'].items()}
    vocab_size = len(char_to_idx)

    enc_mmap = np.load(ENC_CACHE, mmap_mode='r')
    dec_in_mmap = np.load(DEC_IN_CACHE, mmap_mode='r')
    dec_tgt_mmap = np.load(DEC_TGT_CACHE, mmap_mode='r')

    # A cache written with different length limits would only fail much later
    # (inside the tf.data generator), so validate the shapes up front.
    expected_dec_shape = (len(enc_mmap), MAX_DECODER_LEN)
    shapes_ok = (
        enc_mmap.ndim == 2
        and enc_mmap.shape[1] == MAX_ENCODER_LEN
        and dec_in_mmap.shape == expected_dec_shape
        and dec_tgt_mmap.shape == expected_dec_shape
    )
    if shapes_ok:
        print(f"✓ Loaded {len(enc_mmap):,} samples (memory-mapped)")
        CACHE_LOADED = True
    else:
        print("⚠️ Cached arrays do not match the current length limits; rebuilding...")
        # Drop the memory maps so the stale files can be overwritten.
        del enc_mmap, dec_in_mmap, dec_tgt_mmap

if not CACHE_LOADED:
    print("🔨 Building from scratch (will save to drive for next time)...")

In [None]:
# Load the raw dataset only when no usable cache was found.
if not CACHE_LOADED:
    from datasets import load_dataset

    print("📥 Loading zenz dataset...")
    dataset = load_dataset(
        "Miwa-Keita/zenz-v2.5-dataset",
        data_files="train_wikipedia.jsonl",
        split="train"
    )
    print(f"✓ Raw dataset: {len(dataset):,} items")

In [None]:
# Filter and prepare training data
if not CACHE_LOADED:
    print(f"\n🔍 Filtering data (limit: {MAX_SAMPLES:,} items)...")

    training_data = []
    raw_samples = []  # First accepted rows, saved for manual inspection
    skipped = {'too_long': 0, 'empty': 0}

    for item in tqdm(dataset, desc="Processing"):
        # `or ''` also turns explicit None values into empty strings.
        kana_input = item.get('input', '') or ''
        kanji_output = item.get('output', '') or ''
        left_context = item.get('left_context', '') or ''

        # A sample needs both the kana source and the kanji target.
        if not kana_input or not kanji_output:
            skipped['empty'] += 1
            continue

        if (len(left_context) > MAX_CONTEXT_LEN or
            len(kana_input) > MAX_INPUT_LEN or
            len(kanji_output) > MAX_OUTPUT_LEN):
            skipped['too_long'] += 1
            continue

        encoder_text = f"{left_context}<SEP>{kana_input}"
        training_data.append({
            'input': encoder_text,
            'output': kanji_output,
            'input_len': len(kana_input)  # sort key for the later length bucketing
        })

        # Save raw samples for inspection (first 100)
        if len(raw_samples) < 100:
            raw_samples.append({
                'left_context': left_context,
                'kana_input': kana_input,
                'kanji_output': kanji_output,
                'encoder_text': encoder_text
            })

        # A falsy MAX_SAMPLES (0 / None) means "no limit".
        if MAX_SAMPLES and len(training_data) >= MAX_SAMPLES:
            break

    print(f"\n✓ Valid examples: {len(training_data):,}")
    print(f"  Skipped (too long): {skipped['too_long']:,}")
    print(f"  Skipped (empty): {skipped['empty']:,}")

    # 💾 Save raw samples for inspection
    SAMPLES_FILE = f"{CACHE_DIR}/raw_samples_kkc{CACHE_SUFFIX}.json"
    with open(SAMPLES_FILE, 'w', encoding='utf-8') as f:
        json.dump(raw_samples, f, ensure_ascii=False, indent=2)
    print(f"💾 Saved {len(raw_samples)} raw samples → {SAMPLES_FILE}")

    # Show a few examples
    print("\n📝 Sample data:")
    for s in raw_samples[:5]:
        print(f"  ctx: {s['left_context'][:20]}... | {s['kana_input']} → {s['kanji_output']}")

    # 🧹 Release HuggingFace dataset
    del dataset, skipped, raw_samples
    gc.collect()
    print("🧹 Released dataset from memory")

In [None]:
# Sort by kana length so that samples of similar length end up adjacent
# (the later bucket shuffle relies on this ordering).
if not CACHE_LOADED:
    print("\n📊 Sorting by length (bucketing)...")
    training_data.sort(key=lambda x: x['input_len'])

    # Report the length distribution.
    lengths = [d['input_len'] for d in training_data]
    print(f"  Short (0-10): {sum(1 for l in lengths if l <= 10):,}")
    print(f"  Medium (11-20): {sum(1 for l in lengths if 10 < l <= 20):,}")
    print(f"  Long (21+): {sum(1 for l in lengths if l > 20):,}")
    del lengths

In [None]:
# Build vocabulary
if not CACHE_LOADED:
    from collections import Counter

    print("\n📝 Building vocabulary...")
    char_counts = Counter()

    for d in tqdm(training_data, desc="Counting chars"):
        # Drop the <SEP> marker so its characters are not counted; <SEP> gets
        # its own slot via SPECIAL_TOKENS.
        text = d['input'].replace('<SEP>', '') + d['output']
        char_counts.update(text)  # a str already iterates per character

    # Special tokens take the first indices, then the most frequent
    # characters fill the remaining slots up to CHAR_VOCAB_SIZE.
    char_to_idx = {tok: i for i, tok in enumerate(SPECIAL_TOKENS)}
    for char, _ in char_counts.most_common(CHAR_VOCAB_SIZE - len(SPECIAL_TOKENS)):
        char_to_idx[char] = len(char_to_idx)

    idx_to_char = {v: k for k, v in char_to_idx.items()}
    vocab_size = len(char_to_idx)
    print(f"✓ Vocab size: {vocab_size}")

    with open(VOCAB_CACHE, 'w', encoding='utf-8') as f:
        json.dump({
            'char_to_idx': char_to_idx,
            # JSON keys must be strings; they are converted back on load.
            'idx_to_char': {str(k): v for k, v in idx_to_char.items()}
        }, f, ensure_ascii=False)
    print(f"✓ Vocab saved to {VOCAB_CACHE}")

    del char_counts
    gc.collect()

In [None]:
# Encode to tensors — one array at a time to save memory!
if not CACHE_LOADED:
    PAD = char_to_idx['<PAD>']
    UNK = char_to_idx['<UNK>']
    n = len(training_data)
    if n == 0:
        # Fail with a clear message instead of an IndexError in the sanity
        # checks below.
        raise ValueError("No training samples survived filtering; nothing to encode.")

    print(f"\n🔢 Encoding {n:,} samples (one array at a time)...")

    # --- 1. Encoder inputs ---
    arr = np.zeros((n, MAX_ENCODER_LEN), dtype=np.int32)
    for i, d in enumerate(tqdm(training_data, desc="Enc input")):
        arr[i] = encode_encoder_input(d['input'], char_to_idx, PAD, UNK)
    np.save(ENC_CACHE, arr)
    del arr; gc.collect()
    print(f"✓ Saved encoder_inputs → {ENC_CACHE}")

    # --- 2. Decoder inputs (with BOS) ---
    arr = np.zeros((n, MAX_DECODER_LEN), dtype=np.int32)
    for i, d in enumerate(tqdm(training_data, desc="Dec input")):
        arr[i] = encode_decoder_seq(d['output'], char_to_idx, PAD, UNK, add_bos=True)
    assert arr[0][0] == char_to_idx['<BOS>'], f"Expected BOS, got {arr[0][0]}"
    np.save(DEC_IN_CACHE, arr)
    del arr; gc.collect()
    print(f"✓ Saved decoder_inputs → {DEC_IN_CACHE}")

    # --- 3. Decoder targets (with EOS) ---
    arr = np.zeros((n, MAX_DECODER_LEN), dtype=np.int32)
    for i, d in enumerate(tqdm(training_data, desc="Dec target")):
        arr[i] = encode_decoder_seq(d['output'], char_to_idx, PAD, UNK, add_eos=True)
    assert char_to_idx['<EOS>'] in list(arr[0]), "Decoder target should contain EOS"
    np.save(DEC_TGT_CACHE, arr)
    del arr; gc.collect()
    print(f"✓ Saved decoder_targets → {DEC_TGT_CACHE}")

    # 💾 Save some test cases from the dataset for verification later
    TEST_CASES_FILE = f"{CACHE_DIR}/test_cases_kkc{CACHE_SUFFIX}.json"
    test_cases_data = []
    # Pick diverse examples: every N-th item to get different lengths
    step = max(1, len(training_data) // 20)
    for i in range(0, len(training_data), step):
        d = training_data[i]
        # Parse back context and kana from encoder text
        parts = d['input'].split('<SEP>')
        # Skip rows whose text splits into anything other than two parts.
        if len(parts) == 2:
            test_cases_data.append({
                'context': parts[0],
                'kana': parts[1],
                'expected': d['output']
            })
        if len(test_cases_data) >= 20:
            break

    with open(TEST_CASES_FILE, 'w', encoding='utf-8') as f:
        json.dump(test_cases_data, f, ensure_ascii=False, indent=2)
    print(f"💾 Saved {len(test_cases_data)} test cases → {TEST_CASES_FILE}")

    # 🧹 Release training_data
    del training_data, test_cases_data
    gc.collect()
    print("\n🧹 All arrays saved. Released training_data from memory.")

    # Load as memory-mapped
    enc_mmap = np.load(ENC_CACHE, mmap_mode='r')
    dec_in_mmap = np.load(DEC_IN_CACHE, mmap_mode='r')
    dec_tgt_mmap = np.load(DEC_TGT_CACHE, mmap_mode='r')
    print(f"✓ Loaded as mmap: enc={enc_mmap.shape}, dec_in={dec_in_mmap.shape}")

print(f"\n📊 Data: {len(enc_mmap):,} samples")

## 2. Create Dataset

In [None]:
# Total number of cached samples and the size of the 90% training share.
n_samples = len(enc_mmap)
split = int(n_samples * 0.9)

# Bucket shuffle: shuffle within length groups
def bucket_shuffle_indices(n, bucket_size=50000, seed=None):
    """Return ``0..n-1`` as int32, shuffled only inside consecutive buckets.

    Indices never leave their bucket, so when the underlying data is sorted
    by length each bucket keeps samples of similar length together.

    Args:
        n: Number of indices to produce.
        bucket_size: Number of consecutive indices per bucket (must be >= 1).
        seed: Optional seed for a private random generator. When None, the
            global ``np.random`` state is used.

    Returns:
        A 1-D ``np.int32`` array of length ``n``.

    Raises:
        ValueError: If ``bucket_size`` is smaller than 1.
    """
    if bucket_size < 1:
        raise ValueError(f"bucket_size must be >= 1, got {bucket_size}")

    shuffle = np.random.shuffle if seed is None else np.random.RandomState(seed).shuffle

    # One int32 array instead of an n-element Python list; each bucket is
    # shuffled in place through a slice view.
    indices = np.arange(n, dtype=np.int32)
    for start in range(0, n, bucket_size):
        shuffle(indices[start:start + bucket_size])
    return indices

# Hold out a uniformly random validation set. The cached arrays are sorted by
# kana length, so slicing the tail off a bucket-shuffled index list would
# validate only on the longest samples and never train on them.
_perm = np.random.permutation(n_samples)
val_indices = np.sort(_perm[split:]).astype(np.int32)
_train_pool = np.sort(_perm[:split]).astype(np.int32)

# The sorted pool is still ordered by length, so shuffling positions within
# buckets keeps similar-length samples together for training.
train_indices = _train_pool[bucket_shuffle_indices(len(_train_pool))]
all_indices = np.concatenate([train_indices, val_indices])

def make_generator(enc, dec_in, dec_tgt, indices):
    """Build a zero-argument generator function over the selected samples.

    Each call of the returned function walks *indices* in order and yields
    ``({'encoder_input': ..., 'decoder_input': ...}, target)`` pairs.
    """
    def gen():
        for idx in indices:
            features = {
                'encoder_input': enc[idx],
                'decoder_input': dec_in[idx],
            }
            yield features, dec_tgt[idx]

    return gen

# Element structure produced by the generators: a dict with the two fixed-length
# int32 model inputs, plus the fixed-length int32 decoder target.
output_sig = (
    {
        'encoder_input': tf.TensorSpec(shape=(MAX_ENCODER_LEN,), dtype=tf.int32),
        'decoder_input': tf.TensorSpec(shape=(MAX_DECODER_LEN,), dtype=tf.int32),
    },
    tf.TensorSpec(shape=(MAX_DECODER_LEN,), dtype=tf.int32),
)

# .repeat() keeps the stream from running dry: training is driven by an
# explicit steps_per_epoch / validation_steps count rather than by dataset end.
train_ds = tf.data.Dataset.from_generator(
    make_generator(enc_mmap, dec_in_mmap, dec_tgt_mmap, train_indices),
    output_signature=output_sig
).repeat().batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# Validation uses the same pipeline over the held-out indices.
val_ds = tf.data.Dataset.from_generator(
    make_generator(enc_mmap, dec_in_mmap, dec_tgt_mmap, val_indices),
    output_signature=output_sig
).repeat().batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# Summary of the split sizes (status string repaired from mojibake).
print(f"Train: {len(train_indices):,}, Val: {len(val_indices):,}")
print("💡 Data loaded via mmap + generator (near-zero RAM)")

## 3. Build Model (Bidirectional GRU + Attention)

Model is built inside `strategy.scope()` for multi-GPU training.

In [None]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Embedding, GRU, Dense, Dropout, Bidirectional, Attention, Concatenate, LayerNormalization

# Build and compile the seq2seq model inside the distribution strategy scope.
with strategy.scope():
    # Shared embedding (the same layer embeds encoder and decoder tokens)
    # NOTE(review): mask_zero=False and no mask is passed to the GRU/Attention
    # layers, so PAD positions appear to take part in the recurrence and in the
    # attention scores; only the loss/metric below ignore PAD — confirm intended.
    emb = Embedding(vocab_size, EMBEDDING_DIM, mask_zero=False, name='embedding')

    # Encoder (Bidirectional GRU)
    enc_in = Input(shape=(MAX_ENCODER_LEN,), dtype='int32', name='encoder_input')
    x = Dropout(0.1)(emb(enc_in))
    for i in range(NUM_ENCODER_LAYERS):
        x = LayerNormalization()(Bidirectional(GRU(GRU_UNITS, return_sequences=True), name=f'enc_{i+1}')(x))
    enc_out = x

    # Decoder (GRU with Attention)
    # The decoder GRUs are called without an initial state; encoder information
    # reaches the decoder only through the attention layer below.
    dec_in = Input(shape=(MAX_DECODER_LEN,), dtype='int32', name='decoder_input')
    y = Dropout(0.1)(emb(dec_in))
    for i in range(NUM_DECODER_LAYERS):
        # GRU_UNITS * 2 presumably matches the bidirectional encoder width — TODO confirm
        y = LayerNormalization()(GRU(GRU_UNITS * 2, return_sequences=True, name=f'dec_{i+1}')(y))

    # Attention mechanism (decoder states are the query, encoder outputs the value)
    ctx = Attention(use_scale=True, name='attn')([y, enc_out])

    # Output
    combined = Concatenate()([y, ctx])
    combined = LayerNormalization()(combined)
    combined = Dropout(0.2)(combined)
    combined = Dense(GRU_UNITS * 2, activation='relu')(combined)
    # The softmax layer is pinned to float32 regardless of the global dtype policy.
    out = Dense(vocab_size, activation='softmax', name='output', dtype='float32')(combined)

    model = Model([enc_in, dec_in], out, name='kana_kanji_v3_1')

    # Masked loss to ignore PAD tokens
    def masked_sparse_ce(y_true, y_pred):
        """Sparse cross-entropy averaged over the non-PAD target positions."""
        mask = tf.cast(tf.not_equal(y_true, PAD_IDX), tf.float32)
        loss = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
        masked_loss = loss * mask
        # The epsilon avoids division by zero when a batch contains only PAD targets.
        return tf.reduce_sum(masked_loss) / (tf.reduce_sum(mask) + 1e-8)

    def masked_accuracy(y_true, y_pred):
        """Fraction of non-PAD target positions whose argmax prediction is correct."""
        mask = tf.cast(tf.not_equal(y_true, PAD_IDX), tf.float32)
        # Cast predictions to the label dtype so the equality test compares like with like.
        pred_ids = tf.cast(tf.argmax(y_pred, axis=-1), y_true.dtype)
        correct = tf.cast(tf.equal(y_true, pred_ids), tf.float32) * mask
        return tf.reduce_sum(correct) / (tf.reduce_sum(mask) + 1e-8)

    # Adam with gradient-norm clipping at 1.0.
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3, clipnorm=1.0)

    model.compile(
        optimizer=optimizer,
        loss=masked_sparse_ce,
        metrics=[masked_accuracy]
    )

# Report the parameter count and the weight size at 4 and 2 bytes per parameter.
actual_params = model.count_params()
print(f"📊 Actual params: {actual_params:,}")
print(f"   FP32: ~{actual_params * 4 / 1024 / 1024:.1f} MB")
print(f"   FP16: ~{actual_params * 2 / 1024 / 1024:.1f} MB")
model.summary()

## 4. Train

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau

# Both datasets repeat forever, so the number of batches per epoch must be
# given explicitly. Clamp to at least one step so a small dataset combined
# with a large global batch cannot produce zero training steps.
steps_per_epoch = max(1, len(train_indices) // BATCH_SIZE)
validation_steps = max(1, len(val_indices) // BATCH_SIZE)

callbacks = [
    # mode='max' is stated explicitly: a higher validation accuracy is better.
    ModelCheckpoint(f'{MODEL_DIR}/best_v3_1.keras', monitor='val_masked_accuracy', mode='max', save_best_only=True, verbose=1),
    EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, min_lr=1e-6, verbose=1)
]

print(f"Steps/epoch: {steps_per_epoch}, Val steps: {validation_steps}")

history = model.fit(
    train_ds,
    epochs=NUM_EPOCHS,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_ds,
    validation_steps=validation_steps,
    callbacks=callbacks
)

In [None]:
import matplotlib.pyplot as plt

# Plot training vs. validation curves side by side and save the figure.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.plot(history.history['loss'], label='Train')
ax1.plot(history.history['val_loss'], label='Val')
ax1.legend(); ax1.set_title('Loss (masked)')

ax2.plot(history.history['masked_accuracy'], label='Train')
ax2.plot(history.history['val_masked_accuracy'], label='Val')
ax2.legend(); ax2.set_title('Accuracy (masked)')

plt.savefig(f'{MODEL_DIR}/training_v3_1.png')
plt.show()

# ✅ Logic check: compare the first and last epoch of the training metrics.
losses = history.history['loss']
accs = history.history['masked_accuracy']
print("\n📊 Training Summary:")
print(f"  Loss:     {losses[0]:.4f} → {losses[-1]:.4f} ({'✅ decreasing' if losses[-1] < losses[0] else '❌ NOT decreasing'})")
print(f"  Accuracy: {accs[0]*100:.2f}% → {accs[-1]*100:.2f}% ({'✅ increasing' if accs[-1] > accs[0] else '❌ NOT increasing'})")
print(f"  Best val accuracy: {max(history.history['val_masked_accuracy'])*100:.2f}%")

## 5. Save

In [None]:
# Persist the trained model together with everything needed to encode inputs
# and decode outputs.
model.save(f'{MODEL_DIR}/model.keras')

with open(f'{MODEL_DIR}/char_to_idx.json', 'w', encoding='utf-8') as f:
    json.dump(char_to_idx, f, ensure_ascii=False)

with open(f'{MODEL_DIR}/idx_to_char.json', 'w', encoding='utf-8') as f:
    # JSON keys must be strings, so the integer indices are stringified.
    json.dump({str(k): v for k, v in idx_to_char.items()}, f, ensure_ascii=False)

# Explicit UTF-8, matching the other JSON files written above.
with open(f'{MODEL_DIR}/config.json', 'w', encoding='utf-8') as f:
    json.dump({
        'vocab_size': vocab_size,
        'max_encoder_len': MAX_ENCODER_LEN,
        'max_decoder_len': MAX_DECODER_LEN,
        'sep_token': SEP_TOKEN,
        'version': 'v3.1'
    }, f)

keras_size = os.path.getsize(f'{MODEL_DIR}/model.keras')
print(f"✓ Model saved: {keras_size / 1024 / 1024:.2f} MB")

In [None]:
# Export TFLite (best effort: a failure is reported but does not stop the run)
try:
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    # Allow fallback to select TensorFlow ops for anything without a builtin TFLite kernel.
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
    converter._experimental_lower_tensor_list_ops = False

    # 1) Plain conversion.
    tflite = converter.convert()
    with open(f'{MODEL_DIR}/model.tflite', 'wb') as f:
        f.write(tflite)
    print(f"✓ model.tflite ({len(tflite)/(1024*1024):.2f}MB)")

    # 2) Same converter with default optimizations and float16 as supported type.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_types = [tf.float16]
    tflite16 = converter.convert()
    with open(f'{MODEL_DIR}/model_fp16.tflite', 'wb') as f:
        f.write(tflite16)
    print(f"✓ model_fp16.tflite ({len(tflite16)/(1024*1024):.2f}MB)")

except Exception as e:
    # Deliberately broad: the export is optional, so any failure is only reported.
    print(f"⚠ TFLite export failed: {e}")

## 6. Verification — Real Test Cases from Dataset

Uses **real examples from training data** to verify the model learned.

**What to check**:
- ✅ Output matches or partially matches expected kanji
- ✅ Model doesn't output empty strings or garbage
- ✅ Context influences the output (different context → different result)

In [None]:
print("=" * 60)
print("VERIFICATION: Real Test Cases from Dataset")
print("=" * 60)

# Resolve the special-token ids once; convert() below reads these globals.
PAD, BOS, EOS, UNK, SEP = (
    char_to_idx[token]
    for token in ('<PAD>', '<BOS>', '<EOS>', '<UNK>', '<SEP>')
)

def convert(context, kana):
    """Greedily decode the kanji rendering of *kana* given its left *context*.

    The encoder input is built with the shared ``encode_encoder_input``
    helper. Decoding starts from ``<BOS>`` and picks the argmax token at each
    step until ``<EOS>`` is predicted or the decoder length is exhausted.

    Args:
        context: Text to the left of the kana. Only the last MAX_CONTEXT_LEN
            characters are used.
        kana: Kana to convert. Only the first MAX_INPUT_LEN characters are used.

    Returns:
        The decoded string; special tokens are never included.
    """
    # Clip to the length limits before encoding. Otherwise an over-long
    # context would be truncated from the right inside the encoder helper,
    # cutting off <SEP> and the kana itself.
    context = context[-MAX_CONTEXT_LEN:]
    kana = kana[:MAX_INPUT_LEN]

    enc_ids = encode_encoder_input(f"{context}<SEP>{kana}", char_to_idx, PAD, UNK)
    enc_batch = np.array([enc_ids], dtype=np.int32)

    dec_batch = np.zeros((1, MAX_DECODER_LEN), dtype=np.int32)
    dec_batch[0, 0] = BOS

    special_ids = {PAD, BOS, EOS, UNK, SEP}
    chars = []
    for step in range(MAX_DECODER_LEN - 1):
        probs = model.predict(
            {'encoder_input': enc_batch, 'decoder_input': dec_batch}, verbose=0
        )
        next_id = int(np.argmax(probs[0, step]))

        if next_id == EOS:
            break
        if next_id not in special_ids:
            chars.append(idx_to_char.get(next_id, ''))

        # Feed the prediction back in; step + 1 is always in range because
        # the loop stops at MAX_DECODER_LEN - 2.
        dec_batch[0, step + 1] = next_id

    return ''.join(chars)


# ==========================================================
# Load test cases saved during data prep
# ==========================================================
TEST_CASES_FILE = f"{CACHE_DIR}/test_cases_kkc{CACHE_SUFFIX}.json"
if os.path.exists(TEST_CASES_FILE):
    with open(TEST_CASES_FILE, 'r', encoding='utf-8') as f:
        test_cases = json.load(f)
    print(f"\n📝 Loaded {len(test_cases)} test cases from dataset")
else:
    # Fallback: use hardcoded test cases.
    # The literals were mojibake (UTF-8 decoded as MacRoman) and have been
    # restored to the intended Japanese text. Each kana appears twice with a
    # different context and a different expected kanji.
    print("\n⚠️ No saved test cases found, using defaults")
    test_cases = [
        {'context': '今日はとても', 'kana': 'アツイ', 'expected': '暑い'},
        {'context': 'お茶が', 'kana': 'アツイ', 'expected': '熱い'},
        {'context': '川に', 'kana': 'ハシ', 'expected': '橋'},
        {'context': 'ご飯を', 'kana': 'ハシ', 'expected': '箸'},
    ]

# ==========================================================
# Run predictions
# ==========================================================
print("-" * 60)

correct = 0
partial = 0

for tc in test_cases:
    result = convert(tc['context'], tc['kana'])
    expected = tc['expected']

    exact_match = result == expected
    # "Partial" means one string contains the other.
    partial_match = expected in result or result in expected

    if exact_match:
        correct += 1
        status = '✅'
    elif partial_match:
        partial += 1
        status = '🟡'
    else:
        status = '❌'

    ctx_short = tc['context'][:15]
    print(f"  {status} {ctx_short}<SEP>{tc['kana']} → {result} (expected: {expected})")

n = len(test_cases)
print("\n📊 Results:")
if n:
    print(f"  Exact match: {correct}/{n} ({correct/n*100:.1f}%)")
    print(f"  Partial match: {partial}/{n} ({partial/n*100:.1f}%)")
    print(f"  Total useful: {correct+partial}/{n} ({(correct+partial)/n*100:.1f}%)")
else:
    # An empty test-case list would otherwise raise ZeroDivisionError.
    print("  No test cases available.")

if TESTING_MODE:
    print("\n⚠️ TESTING MODE — results may be weak (only 10K samples).")
    print("   ✅ Check: loss decreased, accuracy improved, no crashes.")
    print("   → Set TESTING_MODE = False for real training.")

print("\n" + "="*60)
print("✅ VERIFICATION COMPLETE")

In [None]:
# List saved files with a human-readable size (MB above 1 MiB, otherwise KB).
print(f"\n📦 Files ({PLATFORM}):")
for name in sorted(os.listdir(MODEL_DIR)):
    path = f'{MODEL_DIR}/{name}'
    if not os.path.isfile(path):
        continue  # skip sub-directories
    size = os.path.getsize(path)
    if size > 1024*1024:
        print(f"  {name}: {size/(1024*1024):.2f}MB")
    else:
        print(f"  {name}: {size/1024:.1f}KB")