# MobileBERT Keyboard Suggestion Model Training

This notebook trains a keyboard suggestion model using **MobileBERT** with multi-task learning.

**Features:**
1. Word Completion: "Hel" → ["Hello", "Help", "Helping"]
2. Next-Word Prediction: "How are" → ["you", "they", "we"]
3. Typo Correction: "Thers" → ["There", "Theirs", "Therapy"]
4. Gibberish Detection: Heuristic (no ML)

**Model Specifications:**
- Base: MobileBERT-TINY (15M parameters)
- Target Size: <15MB (after INT8 quantization)
- Latency: <50ms on mobile
- Deployment: iOS (CoreML) + Android (TFLite)

**Training Time:** 2-4 hours on Colab GPU (T4)

**Data Sources:**
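- Word frequencies: GitHub `google-10000-english` list (~10K words)
- Text corpus: `sentence-transformers/embedding-training-data` (Hugging Face, streamed; first 100K records) — used as a smaller, faster stand-in for OpenSubtitles2024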
- Typos: WikEd Error Corpus + synthetic

---

**Instructions:**
1. Runtime → Change runtime type → GPU (T4)
2. Run all cells
3. Model will be saved to Google Drive
4. Download for mobile deployment

## 1. Environment Setup

In [None]:
# Check if running in Colab
import os
import subprocess

try:
    # get_ipython() is injected by IPython; it does not exist in a plain Python process.
    _shell_repr = str(get_ipython())
except NameError:
    _shell_repr = ""

# Colab is detected either by the COLAB_GPU environment variable or by the shell class name.
IN_COLAB = 'COLAB_GPU' in os.environ or 'google.colab' in _shell_repr

if IN_COLAB:
    print("✓ Running in Google Colab")
    # A `!cmd` shell escape cannot be embedded inside an f-string expression
    # (it is a SyntaxError), so the GPU name is queried through subprocess.
    try:
        _gpu_query = subprocess.run(
            ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
            capture_output=True,
            text=True,
            check=False,
        )
        _gpu_lines = _gpu_query.stdout.splitlines() if _gpu_query.returncode == 0 else []
    except FileNotFoundError:
        # nvidia-smi is not on PATH (e.g. a CPU-only runtime)
        _gpu_lines = []
    gpu_name = ", ".join(line.strip() for line in _gpu_lines if line.strip())
    print(f"✓ GPU: {gpu_name or 'not available (select a GPU runtime)'}")
else:
    print("✓ Running locally")

In [None]:
# Clone repository (if in Colab)
if IN_COLAB:
    os.chdir('/content')
    
    # Remove existing repo
    if os.path.exists('Keyboard-Suggestions-ML-Colab'):
        !rm -rf Keyboard-Suggestions-ML-Colab
    
    # Clone fresh copy
    !git clone https://github.com/MinhPhuPham/Keyboard-Suggestions-ML-Colab.git Keyboard-Suggestions-ML-Colab
    os.chdir('/content/Keyboard-Suggestions-ML-Colab')
    
    print(f"✓ Repository cloned")
    print(f"✓ Working directory: {os.getcwd()}")

In [None]:
# Install dependencies
!pip install -q transformers torch datasets accelerate
!pip install -q huggingface_hub
print("✓ Dependencies installed")

## 2. Google Drive Setup

In [None]:
# Pick the persistent project directory: Google Drive on Colab, a local folder otherwise
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    DRIVE_DIR = '/content/drive/MyDrive/Keyboard-Suggestions-ML-Colab'
else:
    DRIVE_DIR = './drive_backup'

# Create the directory layout in BOTH environments: later cells write into
# DRIVE_DIR/datasets and DRIVE_DIR/models regardless of where the notebook runs.
# (makedirs also creates DRIVE_DIR itself as the parent.)
for _subdir in ("datasets", "models"):
    os.makedirs(f"{DRIVE_DIR}/{_subdir}", exist_ok=True)

if IN_COLAB:
    print(f"✓ Google Drive mounted")
print(f"✓ Project directory: {DRIVE_DIR}")

## 3. Download Datasets

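We'll prepare 3 data sources (the first two are downloaded here; typo pairs are generated in Section 4):
1. **Word Frequencies** - GitHub `google-10000-english` list (~10K words)
2. **Text Corpus** - `sentence-transformers/embedding-training-data` (Hugging Face, streamed; first 100K records)
3. **Typo Corrections** - Synthetic + real typos (generated from the word list in Section 4)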

In [None]:
import urllib.request
import json

# Download word frequency list from GitHub (skipped when a cached copy exists)
print("Downloading word frequency list...")

WORD_FREQ_URL = "https://raw.githubusercontent.com/first20hours/google-10000-english/master/google-10000-english-usa-no-swears.txt"
word_freq_path = f"{DRIVE_DIR}/datasets/word_freq.txt"

# Make sure the target folder exists before writing into it
os.makedirs(os.path.dirname(word_freq_path), exist_ok=True)

if not os.path.exists(word_freq_path):
    # Download under a temporary name and rename afterwards, so an interrupted
    # transfer never leaves a truncated file that later runs would treat as cached.
    partial_path = word_freq_path + ".part"
    urllib.request.urlretrieve(WORD_FREQ_URL, partial_path)
    os.replace(partial_path, word_freq_path)
    print(f"✓ Downloaded to: {word_freq_path}")
else:
    print(f"✓ Already exists: {word_freq_path}")

# Count words (each line is counted as one word); stream the file instead of
# loading every line into memory, and read it with an explicit encoding.
with open(word_freq_path, 'r', encoding='utf-8') as f:
    word_count = sum(1 for _ in f)
print(f"  Words: {word_count:,}")

In [None]:
# Download text corpus from Hugging Face
print("\nDownloading text corpus from Hugging Face...")

from datasets import load_dataset

# The dataset that is actually streamed below. It is a smaller, faster source
# than OpenSubtitles, so progress messages report this name.
CORPUS_DATASET = "sentence-transformers/embedding-training-data"
MAX_CORPUS_RECORDS = 100000     # number of streamed records to scan (not sentences kept)
MIN_WORDS_PER_SENTENCE = 3      # shorter texts are dropped

corpus_path = f"{DRIVE_DIR}/datasets/corpus.txt"
os.makedirs(os.path.dirname(corpus_path), exist_ok=True)

if not os.path.exists(corpus_path):
    print(f"  Loading {CORPUS_DATASET} (this may take a few minutes)...")

    # Streaming avoids downloading the whole dataset before iterating
    dataset = load_dataset(
        CORPUS_DATASET,
        split="train",
        streaming=True
    )

    # Extract sentences
    sentences = []
    for i, item in enumerate(dataset):
        if i >= MAX_CORPUS_RECORDS:
            break
        if i % 10000 == 0:
            print(f"  Processed {i:,} sentences...")

        # Extract text from whichever of the two supported fields is present
        if 'sentence' in item:
            text = item['sentence']
        elif 'text' in item:
            text = item['text']
        else:
            continue

        # Skip values that are not plain strings instead of crashing on them
        if not isinstance(text, str):
            continue

        # Collapse all whitespace (including embedded newlines) so that each
        # sentence occupies exactly one line of the output file.
        words = text.split()
        if len(words) >= MIN_WORDS_PER_SENTENCE:
            sentences.append(' '.join(words).lower())

    if not sentences:
        # Writing an empty corpus would be cached and silently reused by every later run
        raise RuntimeError(
            f"No usable sentences were extracted from {CORPUS_DATASET}; "
            "expected records with a 'sentence' or 'text' string field."
        )

    # Write under a temporary name and rename, so an interrupted run cannot
    # leave a partial corpus that looks like a finished cache.
    partial_path = corpus_path + ".part"
    with open(partial_path, 'w', encoding='utf-8') as f:
        f.writelines(sentence + '\n' for sentence in sentences)
    os.replace(partial_path, corpus_path)

    print(f"✓ Downloaded {len(sentences):,} sentences")
    print(f"✓ Saved to: {corpus_path}")
else:
    # Read with the same encoding the file was written with
    with open(corpus_path, 'r', encoding='utf-8') as f:
        sentence_count = sum(1 for _ in f)
    print(f"✓ Already exists: {corpus_path}")
    print(f"  Sentences: {sentence_count:,}")

## 4. Generate Training Data

Generate training pairs for all 3 tasks

In [None]:
# Add src to path (only once, so re-running this cell does not stack duplicate entries)
import sys
if './src' not in sys.path:
    sys.path.insert(0, './src')

from keyboard_data_prep import (
    prepare_word_completion_data,
    prepare_nextword_data,
    prepare_typo_data,
    combine_datasets
)


def _count_lines(path):
    """Return the number of lines in the file at ``path``.

    The file is opened in binary mode so the count does not depend on the
    text encoding the file was written with.
    """
    with open(path, 'rb') as fh:
        return sum(1 for _ in fh)


print("Preparing training datasets...")
print("="*60)

output_dir = f"{DRIVE_DIR}/datasets/processed"
train_path = f"{output_dir}/train.jsonl"
val_path = f"{output_dir}/val.jsonl"

# Reuse the processed splits when both already exist in Drive
if os.path.exists(train_path) and os.path.exists(val_path):
    print("✓ Processed datasets found in Drive!")
    print(f"  Train: {train_path}")
    print(f"  Val: {val_path}")

    # Count samples (one JSON record per line)
    train_count = _count_lines(train_path)
    val_count = _count_lines(val_path)
    print(f"  Train samples: {train_count:,}")
    print(f"  Val samples: {val_count:,}")
else:
    print("Generating training datasets from scratch...")

    # Whether the helpers create a missing output folder is not visible from
    # this notebook, so create it up front.
    os.makedirs(output_dir, exist_ok=True)

    # 1. Word completion
    print("\n1. Word Completion...")
    completion_path = prepare_word_completion_data(
        word_freq_path=word_freq_path,
        output_path=f"{output_dir}/completion.jsonl",
        max_samples=50000  # 50K pairs
    )

    # 2. Next-word prediction
    print("\n2. Next-Word Prediction...")
    nextword_path = prepare_nextword_data(
        corpus_path=corpus_path,
        output_path=f"{output_dir}/nextword.jsonl",
        max_samples=100000,  # 100K pairs
        context_length=3
    )

    # 3. Typo correction
    print("\n3. Typo Correction...")
    typo_path = prepare_typo_data(
        word_list_path=word_freq_path,
        output_path=f"{output_dir}/typo.jsonl",
        max_samples=20000  # 20K pairs
    )

    # 4. Combine and split; the returned paths replace the defaults set above
    print("\n4. Combining datasets...")
    train_path, val_path = combine_datasets(
        completion_path=completion_path,
        nextword_path=nextword_path,
        typo_path=typo_path,
        output_path=output_dir,
        train_ratio=0.9
    )

    print("\n" + "="*60)
    print("✓ Dataset generation complete!")
    print(f"  Train: {train_path}")
    print(f"  Val: {val_path}")

print("\n" + "="*60)
print("✓ Datasets ready for training!")

## 5. Load Model and Tokenizer

In [None]:
from keyboard_model import KeyboardSuggestionModel
from transformers import MobileBertTokenizer
import torch

print("Loading MobileBERT model...")

# A single checkpoint name feeds both the tokenizer and the model backbone
checkpoint_name = "google/mobilebert-uncased"
tokenizer = MobileBertTokenizer.from_pretrained(checkpoint_name)

# Build the multi-task model on top of that checkpoint
model = KeyboardSuggestionModel(
    base_model_name=checkpoint_name,
    vocab_size=tokenizer.vocab_size,
    dropout=0.1,
)

# Prefer CUDA when it is present, otherwise stay on the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = model.to(device)

# Tally total and trainable parameter counts in one pass
total_params = 0
trainable_params = 0
for param in model.parameters():
    total_params += param.numel()
    if param.requires_grad:
        trainable_params += param.numel()

print(f"✓ Model loaded on {device}")
print(f"  Parameters: {total_params:,}")
print(f"  Trainable: {trainable_params:,}")

## 6. Prepare Training Data

In [None]:
from torch.utils.data import Dataset, DataLoader
import json

class KeyboardDataset(Dataset):
    """Dataset over a JSONL file of keyboard-suggestion training records.

    Every non-blank line of the file is parsed as a JSON object; ``__getitem__``
    reads its ``input``, ``target`` and ``task`` keys. All records are loaded
    into memory when the dataset is constructed.
    """

    def __init__(self, data_path, tokenizer, max_length=32):
        """Load all records from ``data_path``.

        Args:
            data_path: Path to a JSONL file (one JSON object per line).
            tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
                padding='max_length', truncation=True, return_tensors='pt')``
                that returns a mapping with ``input_ids`` and ``attention_mask``.
            max_length: Length every encoded sequence is padded/truncated to.

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        self.data = []
        with open(data_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                # Tolerate blank lines (e.g. a trailing newline at end of file)
                if not line:
                    continue
                self.data.append(json.loads(line))

        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data)

    def _encode(self, text):
        """Tokenize ``text`` into fixed-length tensors with a leading batch axis of 1."""
        return self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

    def __getitem__(self, idx):
        """Return the encoded sample at ``idx``.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` (from the record's
            ``input`` text), ``labels`` (token ids of the record's ``target``
            text) and the raw ``task`` value.
        """
        item = self.data[idx]

        input_enc = self._encode(item['input'])
        target_enc = self._encode(item['target'])

        # squeeze(0) removes only the batch axis added by return_tensors='pt';
        # a bare squeeze() would also drop the sequence axis when max_length == 1.
        return {
            'input_ids': input_enc['input_ids'].squeeze(0),
            'attention_mask': input_enc['attention_mask'].squeeze(0),
            'labels': target_enc['input_ids'].squeeze(0),
            'task': item['task']
        }

# Build the train/validation datasets from the two JSONL splits
print("Loading training data...")
train_dataset, val_dataset = (
    KeyboardDataset(split_path, tokenizer) for split_path in (train_path, val_path)
)

for split_label, split_dataset in (("Train", train_dataset), ("Val", val_dataset)):
    print(f"✓ {split_label} samples: {len(split_dataset):,}")

# Wrap them in loaders: training batches are shuffled, validation batches are not
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=64)

## 7. Training

In [None]:
from torch.optim import AdamW
from tqdm.auto import tqdm
import torch.nn.functional as F

# Training configuration
NUM_EPOCHS = 3
LEARNING_RATE = 3e-4
SAVE_STEPS = 1000   # save a checkpoint every SAVE_STEPS optimizer steps
SEED = 42

# Start shuffling and dropout from a fixed RNG state so reruns are comparable
torch.manual_seed(SEED)


def compute_batch_loss(model, batch, device, pad_token_id):
    """Return the cross-entropy loss for one batch, honouring each sample's task.

    The model is called with a single ``task`` argument, while a shuffled batch
    can contain samples of several tasks. Samples are therefore grouped by their
    own task label and each group is sent through the model with that task; the
    group losses are combined weighted by the share of samples in each group.
    (Previously the whole batch was run with the task of its first sample.)

    Args:
        model: Callable as ``model(input_ids, attention_mask, task=...)``
            returning ``(predictions, scores)``.
        batch: Mapping with ``input_ids``, ``attention_mask``, ``labels``
            tensors and a ``task`` sequence with one entry per sample.
        device: Device the tensors are moved to.
        pad_token_id: Label value ignored by the loss.

    Returns:
        Scalar loss tensor for the batch.
    """
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    labels = batch['labels'].to(device)
    tasks = list(batch['task'])

    batch_loss = 0.0
    # dict.fromkeys keeps the distinct task names in first-seen order
    for task in dict.fromkeys(tasks):
        rows = torch.tensor(
            [i for i, sample_task in enumerate(tasks) if sample_task == task],
            dtype=torch.long,
            device=device,
        )
        _, scores = model(input_ids[rows], attention_mask[rows], task=task)

        # NOTE(review): assumes scores has one vocabulary distribution per label
        # position, so it flattens to the same number of rows as labels — confirm
        # against KeyboardSuggestionModel.
        task_loss = F.cross_entropy(
            scores.view(-1, scores.size(-1)),
            labels[rows].view(-1),
            ignore_index=pad_token_id
        )
        batch_loss = batch_loss + task_loss * (rows.numel() / len(tasks))

    return batch_loss


# Optimizer
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
pad_token_id = tokenizer.pad_token_id

# Training loop
print("Starting training...")
print("="*60)

global_step = 0
best_val_loss = float('inf')

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS}")

    # Training
    model.train()
    train_loss = 0.0

    progress_bar = tqdm(train_loader, desc="Training")
    for batch in progress_bar:
        loss = compute_batch_loss(model, batch, device, pad_token_id)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        step_loss = loss.item()
        train_loss += step_loss
        global_step += 1

        progress_bar.set_postfix({'loss': step_loss})

        # Save checkpoint
        if global_step % SAVE_STEPS == 0:
            checkpoint_dir = f"{DRIVE_DIR}/models/checkpoint-{global_step}"
            model.save_pretrained(checkpoint_dir)
            tokenizer.save_pretrained(checkpoint_dir)
            print(f"\n✓ Checkpoint saved: {checkpoint_dir}")

    # max(..., 1) avoids a ZeroDivisionError when a loader is empty
    avg_train_loss = train_loss / max(len(train_loader), 1)
    print(f"  Train loss: {avg_train_loss:.4f}")

    # Validation
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            val_loss += compute_batch_loss(model, batch, device, pad_token_id).item()

    avg_val_loss = val_loss / max(len(val_loader), 1)
    print(f"  Val loss: {avg_val_loss:.4f}")

    # Save best model (lowest mean validation loss so far)
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_model_dir = f"{DRIVE_DIR}/models/best_model"
        model.save_pretrained(best_model_dir)
        tokenizer.save_pretrained(best_model_dir)
        print(f"  ✓ Best model saved: {best_model_dir}")

print("\n" + "="*60)
print("✓ Training complete!")
print(f"  Best val loss: {best_val_loss:.4f}")

## 8. Save Final Model

In [None]:
# Persist the model as it stands after the last epoch, together with its tokenizer
final_model_dir = f"{DRIVE_DIR}/models/keyboard_mobilebert_final"
for artifact in (model, tokenizer):
    artifact.save_pretrained(final_model_dir)

print(f"✓ Final model saved to: {final_model_dir}")
print("\nModel is ready for:")
export_targets = ("INT8 quantization", "CoreML export (iOS)", "TFLite export (Android)")
for position, target in enumerate(export_targets, start=1):
    print(f"  {position}. {target}")
print("\nNext steps: Run export scripts locally")

## 9. Test the Model

In [None]:
# Smoke-test the trained model with one prompt per task
model.eval()

test_cases = [
    ("hel", "completion"),
    ("how are", "next_word"),
    ("thers", "typo")
]

print("Testing model predictions...")
print("="*60)

for text, task in test_cases:
    # Encode the prompt as a fixed-length, single-example batch
    encoded = tokenizer(
        text,
        max_length=32,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    # Inference only, so gradient tracking is switched off
    with torch.no_grad():
        predictions, scores = model(input_ids, attention_mask, task=task)

    # Turn each predicted token id of the first (only) example back into text
    top_predictions = [
        tokenizer.decode([pred.item()], skip_special_tokens=True)
        for pred in predictions[0]
    ]

    print(f"\nInput: \"{text}\" (task: {task})")
    print(f"Predictions: {top_predictions}")

print("\n" + "="*60)
print("✓ Testing complete!")