# 🧠 Multi-Task GRU: Shared Encoder (v2)

**Kana→Kanji Conversion (KKC) + Next Word Prediction (NWP)**

## Architecture
Both heads share one character-level encoder:
```
Input (char IDs) → Shared BiGRU Encoder → encoder_output
                      ↓                         ↓
               KKC Decoder (seq2seq)    NWP Head (attention+GRU)
                      ↓                         ↓
               kanji output              next_word prediction
```

- **KKC input**: hiragana chars (e.g., "きょうはてんきがいい")
- **NWP input**: context chars with `<SEP>` markers (e.g., "今日<SEP>は<SEP>天気<SEP>が")
- Both tasks use the same encoder, so the encoder learns from both.
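
To make the NWP input format concrete, the sketch below flattens a word context into character IDs, treating `<SEP>` as a single token. The vocabulary layout, the `<PAD>`/`<UNK>` IDs, the left-padding choice, and the helper names are illustrative assumptions, not the repository's actual preprocessing.

```python
# Hypothetical helpers illustrating the "<SEP>"-delimited NWP input format.
SEP = "<SEP>"

def build_char_vocab(words):
    """Assign IDs to special tokens first, then to each distinct character."""
    vocab = {"<PAD>": 0, "<UNK>": 1, SEP: 2}
    for word in words:
        for ch in word:
            vocab.setdefault(ch, len(vocab))
    return vocab

def encode_nwp_context(words, vocab, max_len):
    """Flatten words into char IDs, inserting one <SEP> ID between words."""
    ids = []
    for i, word in enumerate(words):
        if i > 0:
            ids.append(vocab[SEP])
        ids.extend(vocab.get(ch, vocab["<UNK>"]) for ch in word)
    ids = ids[-max_len:]  # keep only the most recent context
    return [vocab["<PAD>"]] * (max_len - len(ids)) + ids  # left-pad (assumed)

context = ["今日", "は", "天気", "が"]
vocab = build_char_vocab(context)
encoded = encode_nwp_context(context, vocab, max_len=12)
print(encoded)
```

Because `<SEP>` has its own ID, the encoder can see word boundaries even though it reads characters.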

## Training: Two Forward Passes
Each training step runs a KKC batch and then an NWP batch through the **same encoder**,
so the encoder's weights receive gradients from both task losses.
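
A toy illustration of why the shared encoder receives gradients from both tasks: one shared weight `w` feeds two scalar "heads", and its update is the sum of the two per-task gradients. This is a pure-Python stand-in with made-up numbers and hand-written derivatives, not the notebook's `train_shared_multitask` loop.

```python
# Toy stand-in for a shared encoder with two task heads (all values hypothetical).
# "Encoder": h = w * x.  KKC head: a * h.  NWP head: b * h.  Loss: squared error.

def task_grads(w, head, x, y):
    """Return (loss, dLoss/dw, dLoss/dhead) for loss = (head * w * x - y) ** 2."""
    err = head * w * x - y
    return err ** 2, 2 * err * head * x, 2 * err * w * x

def train_step(params, kkc_batch, nwp_batch, lr=0.01):
    w, a, b = params["w"], params["a"], params["b"]
    # Pass 1: KKC data through the shared weight.
    kkc_loss, gw_kkc, ga = task_grads(w, a, *kkc_batch)
    # Pass 2: NWP data through the same shared weight.
    nwp_loss, gw_nwp, gb = task_grads(w, b, *nwp_batch)
    # Shared weight: gradients from both tasks add up. Heads: own task only.
    params["w"] = w - lr * (gw_kkc + gw_nwp)
    params["a"] = a - lr * ga
    params["b"] = b - lr * gb
    return kkc_loss, nwp_loss

params = {"w": 1.0, "a": 1.0, "b": 1.0}
first = train_step(params, kkc_batch=(1.0, 2.0), nwp_batch=(1.0, 3.0))
for _ in range(200):
    last = train_step(params, kkc_batch=(1.0, 2.0), nwp_batch=(1.0, 3.0))
print("first step losses:", first)
print("last step losses: ", last)
```

Each head is updated only by its own task's loss, while `w` is pulled by both, which is the property the two-pass step relies on.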

## How to Run
1. Set `TESTING_MODE = True` for quick validation (100K samples, 10 epochs)
2. Set `TESTING_MODE = False` for full production training


## 1. Setup


In [None]:
# Pin versions for reproducibility
!pip install tensorflow==2.20.0 keras==3.13.1 fugashi[unidic-lite] -q
!pip install tqdm -q

import os, sys, gc
import numpy as np

# Clone/pull repo
REPO_DIR = '/content/KeyboardSuggestionsML'
if not os.path.exists(REPO_DIR):
    !git clone https://github.com/MinhPhuPham/Keyboard-Suggestions-ML-Colab.git {REPO_DIR}
else:
    !cd {REPO_DIR} && git pull

sys.path.insert(0, REPO_DIR)

import tensorflow as tf
print(f"TF: {tf.__version__}")
print(f"GPU: {tf.config.list_physical_devices('GPU')}")


## 2. Configuration
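The cell below scales the global batch size with the GPU count. As a rough sanity check, the arithmetic can be sketched in plain Python. The helper name `batch_plan` and the assumption that every sample is used once per epoch (no validation split) are illustrative, so the real step counts may differ.

```python
import math

PER_GPU_BATCH = 256  # mirrors `256 * NUM_GPUS` in the configuration cell

def batch_plan(num_samples, num_gpus):
    """Return (global_batch, per_replica_batch, steps_per_epoch)."""
    num_gpus = max(1, num_gpus)  # a CPU-only run counts as one replica
    global_batch = PER_GPU_BATCH * num_gpus
    steps = math.ceil(num_samples / global_batch)
    return global_batch, global_batch // num_gpus, steps

print(batch_plan(100_000, 1))    # testing mode, KKC samples, one GPU
print(batch_plan(8_000_000, 2))  # full run on two GPUs
```

Under this rule the per-replica batch stays at 256 regardless of GPU count; only the number of steps per epoch shrinks as GPUs are added.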


In [None]:
from scripts.japanese_enhancement import config

# ⚠️ TESTING MODE: Set False for full production training
config.TESTING_MODE = True

if config.TESTING_MODE:
    config.MAX_SAMPLES = 100_000
    config.MAX_NWP_PAIRS = 500_000
    config.NUM_EPOCHS = 10
    config.CACHE_SUFFIX = '_test'
    print("⚠️ TESTING MODE: 100K samples, 10 epochs")
else:
    config.MAX_SAMPLES = 8_000_000
    config.MAX_NWP_PAIRS = 8_000_000
    config.NUM_EPOCHS = 10
    config.CACHE_SUFFIX = ''
    print("üöÄ FULL TRAINING: 8M samples, 10 epochs")

# GPU scaling
NUM_GPUS = max(1, len(tf.config.list_physical_devices('GPU')))
# Reduced for shared encoder (2 forward passes per step = 2x VRAM)
config.BATCH_SIZE = 256 * NUM_GPUS

# Override paths for v2
config.MODEL_DIR = f'{config.DRIVE_DIR}/models/multitask_v2'
config.CACHE_DIR = f'{config.DRIVE_DIR}/cache/multitask_v2'
config.ensure_dirs()
config.print_config()

# Cache paths
cache_paths = config.get_cache_paths(config.CACHE_DIR, config.CACHE_SUFFIX)


## 3. GPU Strategy


In [None]:
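# Detect GPU strategy. NUM_GPUS is clamped to >= 1 for batch sizing,
# so count the physical GPUs directly to tell CPU-only runs apart.
gpu_count = len(tf.config.list_physical_devices('GPU'))
if gpu_count > 1:
    strategy = tf.distribute.MirroredStrategy()
    print(f"✅ MirroredStrategy: {gpu_count} GPUs")
elif gpu_count == 1:
    strategy = None  # Single GPU, no strategy needed
    print("✅ Single GPU mode")
else:
    strategy = None
    print("⚠️ CPU mode (no GPU)")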


## 4. Load & Cache Data
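The rebuild decision in this step depends on three flags: whether each cache exists and whether a rebuild is forced. A small pure function makes the cases easy to check. `plan_cache_build` is a hypothetical helper, and treating a forced rebuild as "rebuild both caches" is an assumption about the intended behavior.

```python
def plan_cache_build(kkc_ready, nwp_char_ready, force_rebuild=False):
    """Decide which caches to (re)build and whether raw data must be loaded."""
    build_kkc = force_rebuild or not kkc_ready
    build_nwp = force_rebuild or not nwp_char_ready
    return {
        "load_raw": build_kkc or build_nwp,  # raw corpus is only needed when building
        "build_kkc": build_kkc,
        "build_nwp_char": build_nwp,
    }

print(plan_cache_build(kkc_ready=True, nwp_char_ready=False))
print(plan_cache_build(kkc_ready=True, nwp_char_ready=True, force_rebuild=True))
```

When only the NWP cache is missing, the KKC cache still has to be loaded to obtain the character vocabulary that the NWP cache is built against.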


In [None]:
from scripts.japanese_enhancement.data_loader import (
    check_cache, load_raw_dataset,
    build_kkc_cache, build_nwp_char_cache,
    load_kkc_cache, load_nwp_char_cache,
)

# Check existing cache
kkc_ready, _ = check_cache(cache_paths)
nwp_char_ready = os.path.exists(cache_paths.get('nwp_char_x', ''))

force_rebuild = config.FORCE_REBUILD_CACHE
if force_rebuild or not kkc_ready or not nwp_char_ready:
    print("\n📥 Loading raw data...")
    training_data = load_raw_dataset(config.MAX_SAMPLES)

    # A forced rebuild regenerates both caches, even if they already exist
    if force_rebuild or not kkc_ready:
        c2i, _ = build_kkc_cache(training_data, cache_paths)
    else:
        c2i, _, _, _, _ = load_kkc_cache(cache_paths)

    if force_rebuild or not nwp_char_ready:
        build_nwp_char_cache(training_data, cache_paths, c2i)

    del training_data
    gc.collect()
    print("\n✅ Cache ready!")
else:
    print("✅ Cache already exists, skipping build")


## 5. Load Cached Data


In [None]:
# Load KKC cache
char_to_idx, idx_to_char, enc_mmap, dec_in_mmap, dec_tgt_mmap = \
    load_kkc_cache(cache_paths)
kkc_data = (enc_mmap, dec_in_mmap, dec_tgt_mmap)

# Load NWP char cache (shared encoder format)
word_to_idx, idx_to_word, nwp_char_x_mmap, nwp_y_mmap = \
    load_nwp_char_cache(cache_paths)
nwp_char_data = (nwp_char_x_mmap, nwp_y_mmap, word_to_idx)

actual_char_vocab = len(char_to_idx)
actual_word_vocab = len(word_to_idx)
print(f"\nüìä Char vocab: {actual_char_vocab:,}")
print(f"üìä Word vocab: {actual_word_vocab:,}")


## 6. Create Datasets
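The two tasks have different numbers of examples, so a two-pass step needs some rule for pairing a KKC batch with an NWP batch. One plausible rule, sketched below with plain lists, is to cycle the shorter task until the longer one is exhausted. This illustrates the idea only; how `create_shared_datasets` actually pairs batches is not shown in this notebook, and `paired_steps` is a hypothetical name.

```python
from itertools import cycle, islice

def batches(items, batch_size):
    """Yield consecutive slices of `items`; the last one may be short."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

def paired_steps(kkc_items, nwp_items, batch_size):
    """Pair one KKC batch with one NWP batch per step, cycling the shorter task."""
    kkc = list(batches(kkc_items, batch_size))
    nwp = list(batches(nwp_items, batch_size))
    num_steps = max(len(kkc), len(nwp))
    return list(islice(zip(cycle(kkc), cycle(nwp)), num_steps))

steps = paired_steps(list(range(4)), list(range(100, 110)), batch_size=4)
for kkc_batch, nwp_batch in steps:
    print(kkc_batch, nwp_batch)
```

Cycling the smaller task keeps every step supplied with both batches, at the cost of revisiting its examples more often within an epoch.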


In [None]:
from scripts.japanese_enhancement.training import create_shared_datasets

datasets, info = create_shared_datasets(kkc_data, nwp_char_data, config.BATCH_SIZE)


## 7. Build Shared Encoder Model


In [None]:
from scripts.japanese_enhancement.model import build_shared_multitask_model

model = build_shared_multitask_model(
    actual_char_vocab, actual_word_vocab, strategy=strategy
)
model.summary()

# Compile (optimizer only; the custom training loop manages losses)
if strategy:
    with strategy.scope():
        model.compile(optimizer=tf.keras.optimizers.Adam(
            learning_rate=config.LEARNING_RATE, clipnorm=1.0
        ))
else:
    model.compile(optimizer=tf.keras.optimizers.Adam(
        learning_rate=config.LEARNING_RATE, clipnorm=1.0
    ))

print("\n✅ Shared encoder model ready")
print(f"   Inputs:  {len(model.inputs)} (encoder_input + decoder_input)")
print(f"   Outputs: {len(model.outputs)} (kkc_output + nwp_output)")
print(f"   Params:  {model.count_params():,}")


## 8. Train


In [None]:
from scripts.japanese_enhancement.training import train_shared_multitask

history = train_shared_multitask(model, datasets, info)


## 9. Training Plots


In [None]:
from scripts.japanese_enhancement.plotting import plot_training_history

plot_training_history(history)


## 10. Save & Export


In [None]:
from scripts.japanese_enhancement.export import save_model, export_tflite, list_saved_files

save_model(model, char_to_idx, word_to_idx)
export_tflite(model)
list_saved_files()
