# MobileBERT Keyboard Suggestion Model Training

This notebook trains a keyboard suggestion model using **MobileBERT** with multi-task learning.

**Features:**
1. Word Completion: "Hel" → ["Hello", "Help", "Helping"]
2. Next-Word Prediction: "How are" → ["you", "they", "we"]
3. Typo Correction: "Thers" → ["There", "Theirs", "Therapy"]
4. Gibberish Detection: Heuristic (no ML)

**Model Specifications:**
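- Base: MobileBERT (`google/mobilebert-uncased`, ~25M parameters)
- Target Size: ~25MB after INT8 weight quantization (about 1 byte per parameter); a <15MB target would need a smaller base such as MobileBERT-TINY (~15M parameters)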
- Latency: <50ms on mobile
- Deployment: iOS (CoreML) + Android (TFLite)

**Training Time:** 2-4 hours on Colab GPU (T4)

**Data Sources:**
- Word list: GitHub `google-10000-english` (~10K common US-English words)
- Text corpus: WikiText-103 (Hugging Face)
- Typos: WikEd Error Corpus + synthetic

---

**Instructions:**
1. Runtime → Change runtime type → GPU (T4)
2. Run all cells
3. Model will be saved to Google Drive
4. Download for mobile deployment

## 1. Environment Setup

In [None]:
# Check if running in Colab and set up the project directory layout.
import os

# Detect Colab by whether its package is importable. Environment-variable
# checks such as COLAB_GPU / COLAB_TPU_ADDR depend on the runtime's
# accelerator type, so they can misreport a CPU-only Colab runtime.
try:
    import google.colab  # noqa: F401  (only importable inside Colab)
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    print("✓ Running in Google Colab")

    # Mount Google Drive
    from google.colab import drive
    drive.mount('/content/drive')

    # Define Drive directory
    DRIVE_DIR = '/content/drive/MyDrive/Keyboard-Suggestions-ML-Colab'
else:
    print("✓ Running locally")
    DRIVE_DIR = './data'  # Local fallback

# Same directory layout in both environments (makedirs also creates DRIVE_DIR).
for _subdir in ("datasets", "models"):
    os.makedirs(os.path.join(DRIVE_DIR, _subdir), exist_ok=True)

if IN_COLAB:
    print(f"✓ Google Drive mounted")
    print(f"✓ Project directory: {DRIVE_DIR}")

In [None]:
# Clone repository (if running in Colab)
if IN_COLAB:
    import os
    
    # Ensure we're in /content
    os.chdir('/content')
    
    # Remove existing repo if it exists (for re-runs)
    if os.path.exists('Keyboard-Suggestions-ML-Colab'):
        import shutil
        shutil.rmtree('Keyboard-Suggestions-ML-Colab')
        print("✓ Removed existing repository")
    
    # Clone fresh copy
    !git clone https://github.com/MinhPhuPham/Keyboard-Suggestions-ML-Colab.git
    
    # Change to project directory
    os.chdir('/content/Keyboard-Suggestions-ML-Colab')
    
    print(f"✓ Repository cloned")
    print(f"✓ Working directory: {os.getcwd()}")

In [None]:
# Install dependencies
!pip install -q -r requirements.txt
print("✓ Dependencies installed")

## 2. Download Datasets

We'll prepare 3 data sources (the first two are downloaded here, the third is generated in Section 3):
1. **Word List** - GitHub `google-10000-english` (~10K words)
2. **Text Corpus** - WikiText-103 (Hugging Face)
3. **Typo Corrections** - Synthetic + real typos

In [None]:
import urllib.request
import json

# Download the word list used to build word-completion and typo data.
# The source is the google-10000-english list (~10K words, per its name).
print("Downloading word frequency list...")

WORD_FREQ_URL = "https://raw.githubusercontent.com/first20hours/google-10000-english/master/google-10000-english-usa-no-swears.txt"
word_freq_path = f"{DRIVE_DIR}/datasets/word_freq.txt"

if not os.path.exists(word_freq_path):
    # Download under a temporary name and rename only on success, so an
    # interrupted download is not mistaken for a complete file on re-run.
    tmp_path = word_freq_path + ".part"
    urllib.request.urlretrieve(WORD_FREQ_URL, tmp_path)
    os.replace(tmp_path, word_freq_path)
    print(f"✓ Downloaded to: {word_freq_path}")
else:
    print(f"✓ Already exists: {word_freq_path}")

# Count words (one per line) without loading the whole file into memory.
with open(word_freq_path, 'r', encoding='utf-8') as f:
    word_count = sum(1 for _ in f)
print(f"  Words: {word_count:,}")

In [None]:
# Download text corpus from Hugging Face
print("\nDownloading text corpus from Hugging Face...")

from datasets import load_dataset

corpus_path = f"{DRIVE_DIR}/datasets/corpus.txt"

MAX_CORPUS_ITEMS = 100_000  # raw WikiText rows to scan at most
MAX_CORPUS_LINES = 50_000   # cleaned lines to keep at most

# WikiText splits hyphens and number separators into "@x@" markers
# (e.g. "state @-@ of", "1 @,@ 000"); map them back to plain text so the
# corpus looks like what users actually type.
_WIKITEXT_ARTIFACTS = ((" @-@ ", "-"), (" @,@ ", ","), (" @.@ ", "."))


def _clean_wikitext_line(raw):
    """Return a lower-cased, de-tokenized corpus line, or None to skip it.

    Skipped: section headings (WikiText lines wrapped in "=", such as
    "= = History = ="), and lines that are 10 characters or shorter or have
    fewer than 3 words (this also drops blank lines).
    """
    text = raw.strip().lower()
    if text.startswith("="):
        return None
    for marker, replacement in _WIKITEXT_ARTIFACTS:
        text = text.replace(marker, replacement)
    if len(text) > 10 and len(text.split()) >= 3:
        return text
    return None


if not os.path.exists(corpus_path):
    print("  Loading WikiText dataset (clean and reliable)...")

    # Stream the split so only the rows we actually read are fetched,
    # instead of downloading all of WikiText-103 first.
    dataset = load_dataset(
        "wikitext",
        "wikitext-103-raw-v1",
        split="train",
        streaming=True,
    )

    # Collect cleaned lines (WikiText rows are paragraphs, not single sentences).
    sentences = []
    for i, item in enumerate(dataset):
        if i >= MAX_CORPUS_ITEMS:
            break
        if i % 10000 == 0:
            print(f"  Processed {i:,} items...")

        text = _clean_wikitext_line(item['text'])
        if text is not None:
            sentences.append(text)
            # Stop if we have enough
            if len(sentences) >= MAX_CORPUS_LINES:
                break

    # Write to a temporary file and rename it into place, so an interrupted
    # run cannot leave a truncated corpus that later runs would silently reuse.
    tmp_path = corpus_path + ".part"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        f.writelines(sentence + '\n' for sentence in sentences)
    os.replace(tmp_path, corpus_path)

    print(f"✓ Downloaded {len(sentences):,} sentences")
    print(f"✓ Saved to: {corpus_path}")
else:
    # Read back with the same encoding the file was written with.
    with open(corpus_path, 'r', encoding='utf-8') as f:
        sentence_count = sum(1 for _ in f)
    print(f"✓ Already exists: {corpus_path}")
    print(f"  Sentences: {sentence_count:,}")

## 3. Generate Training Data

Generate training pairs for all three tasks (word completion, next-word prediction, typo correction).

In [None]:
# Add src to path
import sys

# Use an absolute path so later os.chdir() calls cannot break imports from
# src/, and avoid stacking duplicate entries when the cell is re-run.
SRC_DIR = os.path.abspath('./src')
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

from keyboard_data_prep import (
    prepare_word_completion_data,
    prepare_nextword_data,
    prepare_typo_data,
    combine_datasets
)

print("Preparing training datasets...")
print("="*60)

output_dir = f"{DRIVE_DIR}/datasets/processed"
train_path = f"{output_dir}/train.jsonl"
val_path = f"{output_dir}/val.jsonl"

# Processed files are cached in Drive and reused as-is. Delete `output_dir`
# to regenerate them after changing the raw data or the prep code.
if os.path.exists(train_path) and os.path.exists(val_path):
    print("✓ Processed datasets found in Drive!")
    print(f"  Train: {train_path}")
    print(f"  Val: {val_path}")

    # Count samples (one JSON record per line)
    with open(train_path, 'r', encoding='utf-8') as f:
        train_count = sum(1 for _ in f)
    with open(val_path, 'r', encoding='utf-8') as f:
        val_count = sum(1 for _ in f)
    print(f"  Train samples: {train_count:,}")
    print(f"  Val samples: {val_count:,}")
else:
    print("Generating training datasets from scratch...")

    # Make sure the destination exists before the prep helpers write into it.
    os.makedirs(output_dir, exist_ok=True)

    # 1. Word completion
    print("\n1. Word Completion...")
    completion_path = prepare_word_completion_data(
        word_freq_path=word_freq_path,
        output_path=f"{output_dir}/completion.jsonl",
        max_samples=50000  # 50K pairs
    )

    # 2. Next-word prediction
    print("\n2. Next-Word Prediction...")
    nextword_path = prepare_nextword_data(
        corpus_path=corpus_path,
        output_path=f"{output_dir}/nextword.jsonl",
        max_samples=100000,  # 100K pairs
        context_length=3
    )

    # 3. Typo correction
    print("\n3. Typo Correction...")
    typo_path = prepare_typo_data(
        word_list_path=word_freq_path,
        output_path=f"{output_dir}/typo.jsonl",
        max_samples=20000  # 20K pairs
    )

    # 4. Combine and split
    print("\n4. Combining datasets...")
    train_path, val_path = combine_datasets(
        completion_path=completion_path,
        nextword_path=nextword_path,
        typo_path=typo_path,
        output_path=output_dir,
        train_ratio=0.9
    )

    print("\n" + "="*60)
    print("✓ Dataset generation complete!")
    print(f"  Train: {train_path}")
    print(f"  Val: {val_path}")

print("\n" + "="*60)
print("✓ Datasets ready for training!")

## 4. Load Model and Tokenizer

In [None]:
from transformers import MobileBertForMaskedLM, MobileBertTokenizer
import torch

MODEL_NAME = "google/mobilebert-uncased"

print("Loading MobileBERT for Masked Language Modeling...")

tokenizer = MobileBertTokenizer.from_pretrained(MODEL_NAME)
model = MobileBertForMaskedLM.from_pretrained(MODEL_NAME)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Size estimates derived from the parameter count: 4 bytes/param in FP32 and
# roughly 1 byte/param after INT8 weight quantization (ignores file metadata
# and any tensors a converter leaves un-quantized).
num_params = sum(p.numel() for p in model.parameters())
fp32_mb = num_params * 4 / 1024 / 1024
int8_mb = num_params / 1024 / 1024

print(f"✓ Model loaded on {device}")
print(f"  Parameters: {num_params:,}")
print(f"  Model size: ~{fp32_mb:.0f}MB (FP32) → ~{int8_mb:.0f}MB (INT8 quantized)")

## 5. Prepare Training Data

In [None]:
from torch.utils.data import Dataset, DataLoader
import json
import torch

class KeyboardDataset(Dataset):
    """BERT MLM dataset for keyboard suggestions.

    Reads a JSONL file whose records hold an ``input`` string, a ``target``
    word and an optional ``task`` (defaults to ``'completion'``). Each sample
    is ``<input tokens> [MASK]`` wrapped in the tokenizer's special tokens
    (``[CLS] ... [MASK] [SEP]`` for BERT-style tokenizers) and right-padded to
    ``max_length``. The label at the appended ``[MASK]`` is the first sub-word
    token of ``target``; every other label is ``-100`` so the loss ignores it.

    For every task the raw input is kept as context, for example:
      - Completion: "Hel [MASK]"     -> first token of the stored target word
      - Next-word:  "How are [MASK]" -> "you"
      - Typo:       "Thers [MASK]"   -> "There" (the model sees the typo!)

    NOTE: a target that tokenizes into several WordPieces is reduced to its
    first piece, so only that piece is learned.

    Args:
        data_path: Path to a UTF-8 JSONL file; blank lines are skipped.
        tokenizer: Hugging Face tokenizer (e.g. ``MobileBertTokenizer``).
        max_length: Fixed output sequence length.
    """

    def __init__(self, data_path, tokenizer, max_length=32):
        # Skip blank lines (e.g. a trailing newline) rather than crash in json.loads.
        with open(data_path, 'r', encoding='utf-8') as f:
            self.data = [json.loads(line) for line in f if line.strip()]
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        tok = self.tokenizer
        task = item.get('task', 'completion')  # Default to prevent KeyError

        # Space left for context once the special tokens and our [MASK] fit.
        budget = self.max_length - tok.num_special_tokens_to_add() - 1
        context_ids = tok.encode(item['input'], add_special_tokens=False)
        # Truncate from the LEFT: the most recent context matters most, and
        # the [MASK] must never be cut off (otherwise the sample has no label).
        context_ids = context_ids[-budget:] if budget > 0 else []

        input_ids = tok.build_inputs_with_special_tokens(context_ids + [tok.mask_token_id])
        # Position of the [MASK] we appended: its last occurrence, in case the
        # raw input itself contained the mask token.
        mask_pos = len(input_ids) - 1 - input_ids[::-1].index(tok.mask_token_id)

        pad_len = max(self.max_length - len(input_ids), 0)
        attention_mask = [1] * len(input_ids) + [0] * pad_len
        input_ids = input_ids + [tok.pad_token_id] * pad_len

        # Label = first sub-word of the target (UNK if it tokenizes to nothing).
        target_tokens = tok.tokenize(item['target'])
        target_id = tok.convert_tokens_to_ids(target_tokens[0]) if target_tokens else tok.unk_token_id

        # Labels: -100 everywhere except [MASK]
        labels = [-100] * len(input_ids)
        labels[mask_pos] = target_id

        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
            'labels': torch.tensor(labels, dtype=torch.long),
            'task': task,
        }

# Build the train/validation datasets and their loaders.
TRAIN_BATCH_SIZE = 32
VAL_BATCH_SIZE = 64  # presumably larger because validation runs without gradients

print("Loading training data...")
train_dataset = KeyboardDataset(train_path, tokenizer)
val_dataset = KeyboardDataset(val_path, tokenizer)

for split_name, split_ds in (("Train", train_dataset), ("Val", val_dataset)):
    print(f"✓ {split_name} samples: {len(split_ds):,}")

# Shuffle only the training split; validation order is irrelevant to the loss.
train_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=VAL_BATCH_SIZE, shuffle=False)

## 6. Training

In [None]:
from torch.optim import AdamW
from tqdm.auto import tqdm

NUM_EPOCHS = 3
LEARNING_RATE = 3e-4
SAVE_STEPS = 1000

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)


def _evaluate(loader):
    """Return the mean per-batch MLM loss of the global ``model`` on ``loader``."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(loader, desc="Validation"):
            outputs = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
                labels=batch['labels'].to(device),
            )
            total_loss += outputs.loss.item()
    return total_loss / max(len(loader), 1)  # guard against an empty loader


print("Starting training...")
print("="*60)

global_step = 0
best_val_loss = float('inf')
best_state = None  # CPU copy of the best-validation weights

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS}")

    model.train()
    train_loss = 0

    progress_bar = tqdm(train_loader, desc="Training")
    for batch in progress_bar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        global_step += 1
        progress_bar.set_postfix({'loss': loss.item()})

        if global_step % SAVE_STEPS == 0:
            checkpoint_dir = f"{DRIVE_DIR}/models/checkpoint-{global_step}"
            model.save_pretrained(checkpoint_dir)
            tokenizer.save_pretrained(checkpoint_dir)
            # tqdm.write prints without corrupting the active progress bar.
            tqdm.write(f"\n✓ Checkpoint saved: {checkpoint_dir}")

    avg_train_loss = train_loss / max(len(train_loader), 1)
    print(f"  Train loss: {avg_train_loss:.4f}")

    avg_val_loss = _evaluate(val_loader)
    print(f"  Val loss: {avg_val_loss:.4f}")

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_model_dir = f"{DRIVE_DIR}/models/best_model"
        model.save_pretrained(best_model_dir)
        tokenizer.save_pretrained(best_model_dir)
        # Keep an in-memory copy too, so the best weights can be restored
        # after training without reloading from Drive.
        best_state = {name: t.detach().cpu().clone() for name, t in model.state_dict().items()}
        print(f"  ✓ Best model saved: {best_model_dir}")

# The export cells below use `model` directly; restore the best-validation
# weights rather than leaving whatever the final epoch produced.
if best_state is not None:
    model.load_state_dict(best_state)
    print("✓ Restored best-validation weights")

print("\n" + "="*60)
print("✓ Training complete!")
print(f"  Best val loss: {best_val_loss:.4f}")

## 7. Export and Save Model

In [None]:
# Export and save model
import os
import shutil

print("Exporting and saving model...")
print("="*60)

num_params = sum(p.numel() for p in model.parameters())
fp32_mb = num_params * 4 / 1024 / 1024   # 4 bytes per FP32 parameter
int8_mb = num_params / 1024 / 1024       # ~1 byte per parameter after INT8

# 1. Save to Google Drive (persistent storage)
drive_model_dir = f"{DRIVE_DIR}/models/mobilebert_keyboard_final"
model.save_pretrained(drive_model_dir)
tokenizer.save_pretrained(drive_model_dir)
print(f"\n✅ Saved to Google Drive: {drive_model_dir}")

# 2. Create downloadable zip for local use.
# /content only exists on Colab; fall back to the project data dir locally.
export_root = "/content" if IN_COLAB else DRIVE_DIR
local_model_dir = os.path.join(export_root, "mobilebert_keyboard_model")
model.save_pretrained(local_model_dir)
tokenizer.save_pretrained(local_model_dir)

# make_archive appends ".zip" itself and returns the archive's path.
zip_path = shutil.make_archive(local_model_dir, 'zip', local_model_dir)
print(f"\n✅ Created zip: {zip_path}")

# 3. Download to local device (Colab only; google.colab is not importable elsewhere)
if IN_COLAB:
    from google.colab import files
    print("\n📥 Downloading model to your computer...")
    files.download(zip_path)
    print("✅ Download started! Check your Downloads folder.")

print("\n" + "="*60)
print("✅ Export complete!")
print("\nModel saved to:")
print(f"  1. Google Drive: {drive_model_dir}")
print(f"  2. Local download: {zip_path}")
print("\nModel details:")
print(f"  • Size: ~{fp32_mb:.0f}MB (FP32)")
print(f"  • Parameters: {num_params:,}")
print("\nNext steps:")
print("  1. Extract the zip file")
print(f"  2. Apply INT8 quantization (~{int8_mb:.0f}MB)")
print("  3. Export to CoreML (iOS) or TFLite (Android)")
print("\n💡 The model in Google Drive persists across sessions!")

## 8. Export to CoreML (iOS)

In [None]:
# Export to CoreML for iOS
!pip install -q coremltools

import coremltools as ct
from coremltools.models.neural_network import quantization_utils
import torch

print("Exporting to CoreML for iOS...")
print("="*60)

# 1. Prepare model for tracing
print("\n1. Preparing model for conversion...")
model.cpu().eval()

# Create dummy inputs (batch_size=1, seq_length=32)
dummy_input_ids = torch.zeros(1, 32, dtype=torch.long)
dummy_attention_mask = torch.ones(1, 32, dtype=torch.long)

# 2. Trace the model
print("2. Tracing model...")
traced_model = torch.jit.trace(model, (dummy_input_ids, dummy_attention_mask))
print("   ✓ Model traced successfully")

# 3. Convert to CoreML (FP32 - ~100MB)
print("\n3. Converting to CoreML (FP32)...")
mlmodel = ct.convert(
    traced_model,
    inputs=[
        ct.TensorType(name="input_ids", shape=(1, 32), dtype=int),
        ct.TensorType(name="attention_mask", shape=(1, 32), dtype=int)
    ],
    outputs=[ct.TensorType(name="logits")]
)
print("   ✓ Conversion successful")

# 4. Quantize to INT8 (~12-15MB)
print("\n4. Quantizing to INT8...")
mlmodel_int8 = quantization_utils.quantize_weights(mlmodel, nbits=8)
print("   ✓ Quantization complete")

# 5. Save to Drive
coreml_path = f"{DRIVE_DIR}/models/MobileBERT_Keyboard_iOS.mlpackage"
mlmodel_int8.save(coreml_path)

print("\n" + "="*60)
print("✅ iOS CoreML export complete!")
print(f"\nSaved to: {coreml_path}")
print(f"Model size: ~12-15MB (INT8 quantized)")
print(f"Expected latency: 15-20ms on iPhone")
print("\nNext steps:")
print("  1. Download from Google Drive")
print("  2. Add to Xcode project")
print("  3. Use with Vision framework")

## 9. Export to TFLite (Android)

In [None]:
# Export to TFLite for Android
!pip install -q tensorflow

import tensorflow as tf
from pathlib import Path

print("Exporting to TFLite for Android...")
print("="*60)

# 1. Convert PyTorch to ONNX first
print("\n1. Converting PyTorch -> ONNX...")
model.cpu().eval()

onnx_path = "/content/model.onnx"
dummy_input_ids = torch.zeros(1, 32, dtype=torch.long)
dummy_attention_mask = torch.ones(1, 32, dtype=torch.long)

torch.onnx.export(
    model,
    (dummy_input_ids, dummy_attention_mask),
    onnx_path,
    input_names=['input_ids', 'attention_mask'],
    output_names=['logits'],
    dynamic_axes={
        'input_ids': {0: 'batch_size'},
        'attention_mask': {0: 'batch_size'},
        'logits': {0: 'batch_size'}
    },
    opset_version=12
)
print("   ✓ ONNX export successful")

# 2. Convert ONNX to TensorFlow SavedModel
print("\n2. Converting ONNX -> TensorFlow...")
!pip install -q onnx-tf
import onnx
from onnx_tf.backend import prepare

onnx_model = onnx.load(onnx_path)
tf_rep = prepare(onnx_model)
tf_model_path = "/content/tf_model"
tf_rep.export_graph(tf_model_path)
print("   ✓ TensorFlow conversion successful")

# 3. Convert to TFLite with INT8 quantization
print("\n3. Converting to TFLite with INT8 quantization...")
converter = tf.lite.TFLiteConverter.from_saved_model(tf_model_path)
converter.optimizations = [tf.lite.Optimize.DEFAULT]  # INT8 quantization
tflite_model = converter.convert()

# 4. Save to Drive
tflite_dir = Path(f"{DRIVE_DIR}/models/tflite")
tflite_dir.mkdir(parents=True, exist_ok=True)

tflite_path = tflite_dir / "keyboard_model_quantized.tflite"
with open(tflite_path, "wb") as f:
    f.write(tflite_model)

# 5. Save vocabulary
tokenizer.save_vocabulary(str(tflite_dir))

import os
model_size_mb = os.path.getsize(tflite_path) / 1024 / 1024

print("\n" + "="*60)
print("✅ Android TFLite export complete!")
print(f"\nSaved to: {tflite_path}")
print(f"Model size: {model_size_mb:.2f}MB (INT8 quantized)")
print(f"Expected latency: 10-30ms on Android")
print("\nFiles created:")
print(f"  • keyboard_model_quantized.tflite")
print(f"  • vocab.txt")
print("\nNext steps:")
print("  1. Download from Google Drive")
print("  2. Add to Android project (assets/)")
print("  3. Use with TensorFlow Lite Interpreter")