# Multi-Task GRU — KKC + NWP

Two independent heads in a single model:
- **KKC Head**: Kana → Kanji (char embedding → Bi-GRU → GRU decoder + attention)
- **NWP Head**: Next word prediction (word embedding → Bi-GRU → self-attention → context GRU)

Combined loss: `1.0 × KKC + 0.3 × NWP`

Scripts in `scripts/japanese_enhancement/`

## 0. Setup

In [None]:
import os
import shutil
import subprocess
import sys

# --- Detect platform ---
# The Colab kernel normally has `google.colab` loaded already, so checking
# sys.modules works on a fresh runtime. '/content/drive' only exists once
# Drive has been mounted, so it is kept purely as a fallback signal.
if 'google.colab' in sys.modules or os.path.exists('/content/drive'):
    PLATFORM = 'colab'
    from google.colab import drive
    drive.mount('/content/drive')
elif os.path.exists('/kaggle/working'):
    PLATFORM = 'kaggle'
else:
    PLATFORM = 'local'

# --- Clone/refresh repo (Colab/Kaggle only) ---
REPO_URL = 'https://github.com/MinhPhuPham/Keyboard-Suggestions-ML-Colab.git'

if PLATFORM == 'colab':
    REPO_DIR = '/content/KeyboardSuggestionsML'
elif PLATFORM == 'kaggle':
    REPO_DIR = '/kaggle/working/KeyboardSuggestionsML'
else:
    # Local runs use the checkout the notebook already lives in.
    REPO_DIR = None

if REPO_DIR is not None:
    # Always delete & re-clone so the hosted runtime picks up the latest code
    if os.path.exists(REPO_DIR):
        shutil.rmtree(REPO_DIR)
        print('üóëÔ∏è Removed previous clone')
    # Argument list (no shell) + check=True: a failed clone raises
    # CalledProcessError here instead of being reported as a success below.
    subprocess.run(['git', 'clone', '-q', REPO_URL, REPO_DIR], check=True)
    PROJECT_ROOT = REPO_DIR
    print(f'‚úì Cloned latest code to {REPO_DIR}')
else:
    # Local: notebook is in notebooks/japanese/, project root is 2 levels up
    PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '../..'))

# Make `scripts.japanese_enhancement` importable from the project root.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

print(f'Platform: {PLATFORM}')
print(f'Project:  {PROJECT_ROOT}')

In [None]:
# IPython shell escape: install the pinned TensorFlow/Keras versions plus the
# other packages listed below; -q keeps pip's output quiet.
!pip install -q tensorflow==2.20.0 keras==3.13.1 datasets numpy tqdm fugashi unidic-lite matplotlib

In [None]:
# --- Environment versions (compare local vs Colab) ---
# Each import is followed immediately by its report line, so the versions that
# did load are still printed if a later import fails.
import sys

print(f'Python:     {sys.version}')

import tensorflow as tf

print(f'TensorFlow: {tf.__version__}')

import keras

print(f'Keras:      {keras.__version__}')

import numpy as np

print(f'NumPy:      {np.__version__}')

# Physical GPU devices visible to TensorFlow (an empty list means CPU only).
print(f'GPU:        {tf.config.list_physical_devices("GPU")}')
print(f'Platform:   {sys.platform}')

In [None]:
import tensorflow as tf
import numpy as np

# Physical GPUs visible to TensorFlow; an empty list means CPU-only.
gpus = tf.config.list_physical_devices('GPU')

# NUM_GPUS is used as a batch-size multiplier (config.BATCH_SIZE = 512 * NUM_GPUS),
# so it is floored at 1 to keep a valid batch size on CPU-only machines.
NUM_GPUS = max(len(gpus), 1)

# Pick the distribution strategy: mirror across GPUs when there are several,
# otherwise pin everything to the single available device.
if len(gpus) > 1:
    strategy = tf.distribute.MirroredStrategy()
elif gpus:
    strategy = tf.distribute.OneDeviceStrategy('/gpu:0')
else:
    strategy = tf.distribute.OneDeviceStrategy('/cpu:0')

# Report the real GPU count; NUM_GPUS would misleadingly show 1 on CPU.
print(f'GPUs: {len(gpus)}, Strategy: {strategy.__class__.__name__}')

# Mixed precision is only switched on when a GPU is present.
if gpus:
    tf.keras.mixed_precision.set_global_policy('mixed_float16')
    print(f'Mixed precision: {tf.keras.mixed_precision.global_policy().name}')

## 1. Configuration

Override any config values here before importing modules.

In [None]:
from scripts.japanese_enhancement import config

# ============================================
# ⚠️ OVERRIDE CONFIG HERE
# ============================================
config.TESTING_MODE = True      # True = 100K, False = 8M
config.FORCE_REBUILD_CACHE = False
config.BATCH_SIZE = 512 * NUM_GPUS  # NUM_GPUS comes from the strategy cell

# Pick the size limits and cache suffix for the selected mode.
if config.TESTING_MODE:
    max_samples, max_nwp_pairs, cache_suffix = 100_000, 500_000, '_test'
else:
    max_samples, max_nwp_pairs, cache_suffix = 8_000_000, 8_000_000, ''

config.MAX_SAMPLES = max_samples
config.MAX_NWP_PAIRS = max_nwp_pairs
config.NUM_EPOCHS = 10  # identical in both modes
config.CACHE_SUFFIX = cache_suffix

config.ensure_dirs()
config.print_config()

## 2. Load or Build Cache

Loads zenz dataset once, builds both KKC and NWP caches.
Uses memory-mapped .npy files for near-zero RAM.

In [None]:
from scripts.japanese_enhancement import data_loader
import gc

cache_paths = config.get_cache_paths(config.CACHE_DIR, config.CACHE_SUFFIX)
kkc_ready, nwp_ready = data_loader.check_cache(cache_paths)

# A cache is (re)built when it is missing or when a rebuild is forced.
rebuild_kkc = config.FORCE_REBUILD_CACHE or not kkc_ready
rebuild_nwp = config.FORCE_REBUILD_CACHE or not nwp_ready

if not (rebuild_kkc or rebuild_nwp):
    print('‚úì All caches found, loading...')
else:
    print('üî® Building caches from scratch...')
    # The raw dataset is loaded once and shared by both cache builders.
    training_data = data_loader.load_raw_dataset()

    if rebuild_kkc:
        data_loader.build_kkc_cache(training_data, cache_paths)

    if rebuild_nwp:
        data_loader.build_nwp_cache(training_data, cache_paths)

    # Drop the raw dataset before the caches are loaded below.
    del training_data
    gc.collect()

(
    char_to_idx,
    idx_to_char,
    enc_mmap,
    dec_in_mmap,
    dec_tgt_mmap,
) = data_loader.load_kkc_cache(cache_paths)

(
    word_to_idx,
    idx_to_word,
    nwp_x_mmap,
    nwp_y_mmap,
) = data_loader.load_nwp_cache(cache_paths)

char_vocab_size = len(char_to_idx)
word_vocab_size = len(word_to_idx)
print(f'\nüìä Char vocab: {char_vocab_size:,}, Word vocab: {word_vocab_size:,}')
print(f'   KKC: {len(enc_mmap):,} samples')
print(f'   NWP: {len(nwp_x_mmap):,} pairs')

## 3. Create Datasets

In [None]:
from scripts.japanese_enhancement.training import create_datasets

# Group the arrays loaded from the caches per task, in the positional order
# create_datasets is called with.
kkc_data = (
    enc_mmap,
    dec_in_mmap,
    dec_tgt_mmap,
)
nwp_data = (
    nwp_x_mmap,
    nwp_y_mmap,
    word_to_idx,
)

datasets, info = create_datasets(kkc_data, nwp_data, config.BATCH_SIZE)
print('‚úì Datasets ready')

## 4. Build Model

3-input model with independent paths:
- KKC: `encoder_input` + `decoder_input` → char embedding → Bi-GRU → decoder + attention
- NWP: `nwp_input` → word embedding → Bi-GRU → self-attention → context GRU

In [None]:
from scripts.japanese_enhancement.model import build_multitask_model

# Build the model from both vocabulary sizes and the distribution strategy.
model = build_multitask_model(char_vocab_size, word_vocab_size, strategy)
model.summary()

# Rough weight footprint: parameter count times bytes per parameter, shown in MB.
params = model.count_params()
print(f'\nüìä Parameters: {params:,}')
for precision, bytes_per_param in (('FP32', 4), ('FP16', 2)):
    print(f'   {precision}: ~{params * bytes_per_param / 1024 / 1024:.1f} MB')

## 5. Train

In [None]:
from scripts.japanese_enhancement.training import train_multitask

# Create the optimizer inside the strategy scope and hand it to compile().
with strategy.scope():
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=config.LEARNING_RATE, clipnorm=1.0
    )
    # Pass the optimizer explicitly: compile() defaults to optimizer='rmsprop',
    # so assigning model.optimizer and then calling a bare compile() would
    # replace this Adam instance (losing the learning rate and clipnorm).
    model.compile(optimizer=optimizer)

history = train_multitask(model, datasets, info)

## 5.1 Training Curves

In [None]:
from scripts.japanese_enhancement.plotting import (
    plot_training_history,
)

# Plot the history object returned by train_multitask in the training cell.
plot_training_history(history)

## 6. Save & Export

In [None]:
from scripts.japanese_enhancement.export import (
    export_tflite,
    list_saved_files,
    save_model,
)

# Save the model with both vocabularies, run the TFLite export, then list the
# saved files (output locations are decided inside the export module).
save_model(model, char_to_idx, word_to_idx)
export_tflite(model)
list_saved_files()

## 7. Verification

Test both KKC and NWP heads with real test cases from training data.

In [None]:
from scripts.japanese_enhancement.verify import verify_all

# Run the verification helper with the model, both vocabularies (forward and
# reverse mappings) and the cache paths.
verify_all(
    model,
    char_to_idx,
    idx_to_char,
    word_to_idx,
    idx_to_word,
    cache_paths,
)