# 🧠 Multi-Task GRU — Shared Encoder (v2)

**Kana→Kanji Conversion (KKC) + Next Word Prediction (NWP)**

## Architecture
Both heads share one character-level encoder:
```
Input (char IDs) → Shared BiGRU Encoder → encoder_output
                      ↓                         ↓
               KKC Decoder (seq2seq)    NWP Head (attention+GRU)
                      ↓                         ↓
               kanji output              next_word prediction
```

- **KKC input**: hiragana chars (e.g., "きょうはてんきがいい")
- **NWP input**: context chars with `<SEP>` markers (e.g., "今日<SEP>は<SEP>天気<SEP>が")
- Both tasks use the same encoder, so the encoder learns from BOTH tasks!

## Training: Two Forward Passes
Each step passes KKC data then NWP data through the **same encoder**.
The shared encoder gets gradients from both tasks.

## How to Run
1. Set `config.TESTING_MODE = True` for quick validation (2M samples, 5M NWP pairs, 4 epochs)
2. Set `config.TESTING_MODE = False` for full production training (6M samples, 10M NWP pairs, 10 epochs)

## Platform Support
- **Colab**: Mounts Google Drive, clones repo
- **Kaggle**: Clones repo, downloads dataset via `gdown`


## 1. Setup


In [None]:
# Pin versions for reproducibility
!pip install tensorflow==2.20.0 keras==3.13.1 fugashi[unidic-lite] -q
!pip install tqdm gdown -q

import os, sys, gc
import numpy as np

# ===========================================================
# PLATFORM DETECTION
# ===========================================================
if os.path.exists('/kaggle/working'):
    PLATFORM = 'kaggle'
elif os.path.exists('/content'):
    PLATFORM = 'colab'
else:
    PLATFORM = 'local'

print(f"üñ•Ô∏è Platform: {PLATFORM}")

# ===========================================================
# MOUNT / CLONE
# ===========================================================
if PLATFORM == 'colab':
    from google.colab import drive
    drive.mount('/content/drive')
    REPO_DIR = '/content/KeyboardSuggestionsML'
elif PLATFORM == 'kaggle':
    REPO_DIR = '/kaggle/working/KeyboardSuggestionsML'
else:
    REPO_DIR = os.path.expanduser('~/KeyboardSuggestionsML')

# Clone/pull repo
if not os.path.exists(REPO_DIR):
    !git clone https://github.com/MinhPhuPham/Keyboard-Suggestions-ML-Colab.git {REPO_DIR}
else:
    !cd {REPO_DIR} && git pull

sys.path.insert(0, REPO_DIR)

import tensorflow as tf
print(f"TF: {tf.__version__}")
print(f"GPU: {tf.config.list_physical_devices('GPU')}")


## 2. Configuration


In [None]:
from scripts.japanese_enhancement import config

# ‚ö†Ô∏è TESTING MODE: Set False for full production training
config.TESTING_MODE = True

if config.TESTING_MODE:
    config.MAX_SAMPLES = 2_000_000
    config.MAX_NWP_PAIRS = 5_000_000
    config.NUM_EPOCHS = 4
    config.FORCE_REBUILD_CACHE = True
    config.CACHE_SUFFIX = '_test'
    print(f"‚ö†Ô∏è TESTING MODE: ${config.MAX_SAMPLES} samples, ${config.NUM_EPOCHS} epochs")
else:
    config.MAX_SAMPLES = 6_000_000
    config.MAX_NWP_PAIRS = 10_000_000
    config.NUM_EPOCHS = 10
    config.CACHE_SUFFIX = ''
    print(f"üöÄ FULL TRAINING: ${config.MAX_SAMPLES} samples, ${config.NUM_EPOCHS} epochs")

# Single GPU ‚Äî batch size tuned for T4 (16GB)
# 2 forward passes per step ‚Üí keep batch small to avoid OOM
config.BATCH_SIZE = 512

# Override paths for v2
config.MODEL_DIR = f'{config.DRIVE_DIR}/models/multitask_v2'
config.CACHE_DIR = f'{config.DRIVE_DIR}/cache/multitask_v2'
config.ensure_dirs()

# ===========================================================
# DOWNLOAD DATASET (if not exists)
# ===========================================================
dataset_path = f'{config.DATASET_DIR}/ime_dataset_10m.jsonl'
if not os.path.exists(dataset_path):
    print(f"üì• Downloading dataset to {dataset_path}...")
    os.makedirs(config.DATASET_DIR, exist_ok=True)
    !gdown "1b5YgqVUEU2HGlkPBcSUIjL1XyJ7n8Cgg" -O {dataset_path}
    print(f"‚úÖ Dataset downloaded! Size: {os.path.getsize(dataset_path) / 1e6:.1f} MB")
else:
    print(f"‚úÖ Dataset exists: {dataset_path} ({os.path.getsize(dataset_path) / 1e6:.1f} MB)")

config.print_config()

# Cache paths
cache_paths = config.get_cache_paths(config.CACHE_DIR, config.CACHE_SUFFIX)


## 3. GPU Check


In [None]:
# Report the detected GPU and switch TensorFlow to on-demand VRAM allocation
# instead of pre-allocating all device memory.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    print(f"‚úÖ GPU: {gpus[0].name}")
    try:
        # TensorFlow requires the memory-growth setting to be the same on
        # every visible GPU, so apply it to all of them (e.g. Kaggle T4 x2),
        # not only to the first device.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print("   Memory growth enabled")
    except RuntimeError as err:
        # Raised when the GPUs were already initialized before this cell ran;
        # the setting can no longer be changed, so report it and carry on.
        print(f"   Memory growth not changed: {err}")
else:
    print("‚ö†Ô∏è No GPU ‚Äî training will be slow")


## 4. Load & Cache Data


In [None]:
# Force reload data_loader & tokenizer to pick up the latest code after git pull.
# NOTE: Do NOT reload config - that would reset the user overrides from the
# Configuration cell!
import importlib
import scripts.japanese_enhancement.tokenizer as _tok_mod
import scripts.japanese_enhancement.data_loader as _dl_mod
importlib.reload(_tok_mod)
importlib.reload(_dl_mod)

from scripts.japanese_enhancement.data_loader import (
    check_cache, load_raw_dataset,
    build_kkc_cache, build_nwp_char_cache,
    load_kkc_cache, load_nwp_char_cache,
)

# Check existing cache. Only the first value returned by check_cache is used;
# NWP readiness is decided by the presence of the char-level input file.
kkc_ready, _ = check_cache(cache_paths)
nwp_char_ready = os.path.exists(cache_paths.get('nwp_char_x', ''))

force = config.FORCE_REBUILD_CACHE
need_build = not kkc_ready or not nwp_char_ready or force
print(f'FORCE_REBUILD={force}, kkc_ready={kkc_ready}, nwp_ready={nwp_char_ready}, need_build={need_build}')

if need_build:
    if force:
        print('\U0001f504 Force rebuild enabled \u2014 rebuilding all caches...')
    print('\n\U0001f4e5 Loading raw data...')
    training_data = load_raw_dataset(config.MAX_SAMPLES)

    if not kkc_ready or force:
        c2i, _ = build_kkc_cache(training_data, cache_paths)
    else:
        c2i, _, _, _, _ = load_kkc_cache(cache_paths)

    # The NWP char cache is encoded with the KKC char vocabulary (c2i), so it
    # is rebuilt whenever we get here: either it is missing, or the KKC cache
    # (and therefore possibly c2i) was just rebuilt and an existing NWP cache
    # could be encoded with a stale vocabulary.
    build_nwp_char_cache(training_data, cache_paths, c2i)

    # The raw samples are no longer needed once both caches are on disk.
    del training_data
    gc.collect()
    print('\n\u2705 Cache ready!')
else:
    print('\u2705 Cache already exists, skipping build')


## 5. Load Cached Data


In [None]:
# KKC cache: char vocabulary (both directions) plus the three sequence
# arrays, which are bundled into `kkc_data` for the dataset builder.
(char_to_idx, idx_to_char,
 enc_mmap, dec_in_mmap, dec_tgt_mmap) = load_kkc_cache(cache_paths)
kkc_data = (enc_mmap, dec_in_mmap, dec_tgt_mmap)

# NWP char cache (shared encoder format): word vocabulary (both directions)
# plus the char-level inputs and their targets.
(word_to_idx, idx_to_word,
 nwp_char_x_mmap, nwp_y_mmap) = load_nwp_char_cache(cache_paths)
nwp_char_data = (nwp_char_x_mmap, nwp_y_mmap, word_to_idx)

# Vocabulary sizes taken from the loaded mappings; the model is sized from these.
actual_char_vocab = len(char_to_idx)
actual_word_vocab = len(word_to_idx)
print(f"\nüìä Char vocab: {actual_char_vocab:,}")
print(f"üìä Word vocab: {actual_word_vocab:,}")


## 5b. Preview Test Cases


In [None]:
# ============================================================
# Preview cached test cases (up to 50 KKC + 50 NWP) and the first
# few encoded training rows.
# Copy these to use as test cases in test_prediction.py
# ============================================================

import json as _json

_rule = '=' * 60

# Locations of the JSON test-case files inside the cache; a missing key
# yields '' which os.path.exists() treats as "not there".
_kkc_test_path = cache_paths.get('kkc_test_cases', '')
_nwp_test_path = cache_paths.get('nwp_test_cases', '')

# --- KKC Test Cases ---
print(_rule)
print('üìã KKC TEST CASES (Kana ‚Üí Kanji)')
print(_rule)
if not os.path.exists(_kkc_test_path):
    print('  ‚ö† No KKC test cases found')
else:
    with open(_kkc_test_path, 'r', encoding='utf-8') as fh:
        _kkc_tests = _json.load(fh)
    for num, case in enumerate(_kkc_tests[:50], start=1):
        # The context is optional; show at most its first 15 characters.
        context = case.get('context', '')
        suffix = f' [ctx: {context[:15]}]' if context else ''
        print(f"  {num:2d}. {case['kana']} ‚Üí {case['expected']}{suffix}")
    print(f'  Total: {len(_kkc_tests)} test cases')

# --- NWP Test Cases ---
print()
print(_rule)
print('üìã NWP TEST CASES (Next Word Prediction)')
print(_rule)
if not os.path.exists(_nwp_test_path):
    print('  ‚ö† No NWP test cases found')
else:
    with open(_nwp_test_path, 'r', encoding='utf-8') as fh:
        _nwp_tests = _json.load(fh)
    for num, case in enumerate(_nwp_tests[:50], start=1):
        joined_context = ' '.join(case['context'])
        sentence_head = case.get('sentence', '')[:20]
        print(f"  {num:2d}. {joined_context} ‚Üí {case['expected']}  [{sentence_head}]")
    print(f'  Total: {len(_nwp_tests)} test cases')

# --- Sample Raw Data (encoder/decoder) ---
print()
print(_rule)
print('üìã SAMPLE ENCODED DATA (first 10)')
print(_rule)
# Reverse lookup built from char_to_idx so that its keys are the integer ids.
idx_to_char_local = {idx: ch for ch, idx in char_to_idx.items()}


def _decode_row(row):
    """Turn a row of char ids into text, skipping id 0 and using '?' for unknown ids."""
    return ''.join(idx_to_char_local.get(int(tok), '?') for tok in row if tok != 0)


_preview_rows = min(10, len(enc_mmap))
for row_no in range(_preview_rows):
    enc_text = _decode_row(enc_mmap[row_no])[:30]
    dec_text = _decode_row(dec_tgt_mmap[row_no])[:20]
    print(f"  {row_no+1:2d}. {enc_text} ‚Üí {dec_text}")


## 6. Create Datasets


In [None]:
import importlib
import scripts.japanese_enhancement.training as _train_mod

# Reload first so edits pulled from git are picked up without a kernel
# restart, then take the builder from the freshly reloaded module.
_train_mod = importlib.reload(_train_mod)
create_shared_datasets = _train_mod.create_shared_datasets

# Build the datasets from the cached KKC and NWP arrays at the configured
# batch size; `datasets` and `info` are handed to the training cell.
datasets, info = create_shared_datasets(
    kkc_data,
    nwp_char_data,
    config.BATCH_SIZE,
)


## 7. Build Shared Encoder Model


In [None]:
import importlib
import scripts.japanese_enhancement.model as _model_mod

# Reload so the latest model code is used, then fetch the builder from it.
_model_mod = importlib.reload(_model_mod)
build_shared_multitask_model = _model_mod.build_shared_multitask_model

# Size the model from the vocabularies found in the cache; no distribution
# strategy is passed (strategy=None).
model = build_shared_multitask_model(actual_char_vocab, actual_word_vocab, strategy=None)
model.summary()

# Only an optimizer is given to compile() - no loss or metrics.
# NOTE(review): presumably the losses are computed inside the custom training
# routine; confirm in scripts/japanese_enhancement/training.py.
_optimizer = tf.keras.optimizers.Adam(
    learning_rate=config.LEARNING_RATE,
    clipnorm=1.0,
)
model.compile(optimizer=_optimizer)

print(f"\n\u2705 Shared encoder model ready")
print(f"   Inputs:  {len(model.inputs)} (encoder_input + decoder_input)")
print(f"   Outputs: {len(model.outputs)} (kkc_output + nwp_output)")
print(f"   Params:  {model.count_params():,}")


## 8. Train


In [None]:
import importlib
import scripts.japanese_enhancement.training as _train_mod

# Reload the training module so code changes apply without restarting the
# kernel, then take the training entry point from the reloaded module.
_train_mod = importlib.reload(_train_mod)
train_shared_multitask = _train_mod.train_shared_multitask

# Run training on the compiled model; the returned history is what the
# plotting cell consumes.
history = train_shared_multitask(model, datasets, info)


## 9. Training Plots


In [None]:
# Plot the training history returned by train_shared_multitask().
from scripts.japanese_enhancement.plotting import plot_training_history

plot_training_history(history)


## 10. Save & Export


In [None]:
from scripts.japanese_enhancement.export import (
    save_model,
    export_tflite,
    list_saved_files,
)

# Save the trained model together with both vocabularies, then export it
# and list the results.
# NOTE(review): output locations are decided inside the export module
# (presumably config.MODEL_DIR) - confirm there.
save_model(model, char_to_idx, word_to_idx)
export_tflite(model)
list_saved_files()
