# Pythia-14m Keyboard Suggestion Model Training

This notebook trains a keyboard suggestion model using **Pythia-14m** (GPT-style causal LM).

**Features:**
1. Word Completion: "Hel" → ["Hello", "Help", "Helping"]
2. Next-Word Prediction: "How are" → ["you", "they", "we"]
3. Typo Correction: "Thers" → ["There", "Theirs", "Therapy"]
4. Gibberish Detection: Heuristic (no ML)

**Model Specifications:**
- Base: Pythia-14m (6 layers, 128 hidden, 4 heads, ~14M params)
- Architecture: GPT-NeoX (Causal LM, decoder-only)
- Target Size: <20MB (after INT8 quantization)
- Latency: <50ms on mobile
- RAM Usage: 15-20MB runtime
- Deployment: iOS (CoreML) + Android (TFLite)

**Training Time:** 3-4 hours on Colab GPU (T4)

**Data Sources (Google Drive):**
- `single_word_freq.csv` - Word frequencies for completion
- `keyboard_training_data.txt` - Custom corpus for next-word
- `misspelled.csv` - Typo correction pairs

---

**Instructions:**
1. Runtime → Change runtime type → GPU (T4)
2. Run all cells
3. Model will be saved to Google Drive
4. Download for mobile deployment

## 1. Environment Setup

In [None]:
# Detect the runtime (Google Colab vs. local machine) and prepare the
# project folder layout used by the rest of the notebook.
import os

# Colab exposes at least one of these environment variables.
IN_COLAB = any(var in os.environ for var in ('COLAB_GPU', 'COLAB_TPU_ADDR'))

if IN_COLAB:
    print("✓ Running in Google Colab")

    # Mount Google Drive so datasets and models persist across sessions
    from google.colab import drive
    drive.mount('/content/drive')

    DRIVE_DIR = '/content/drive/MyDrive/Keyboard-Suggestions-ML-Colab'
else:
    print("✓ Running locally")
    DRIVE_DIR = './data'  # Local fallback

# Same folder layout in both environments: root, datasets, processed, models
for subdir in ('', '/datasets', '/datasets/processed', '/models'):
    os.makedirs(f"{DRIVE_DIR}{subdir}", exist_ok=True)

if IN_COLAB:
    print("✓ Google Drive mounted")
    print(f"✓ Project directory: {DRIVE_DIR}")

In [None]:
# Install dependencies
# The leading `!` hands each line to the notebook's shell; `-q` keeps pip quiet.
# NOTE: the confirmation below is printed unconditionally - it does not check
# whether the pip commands actually succeeded.
!pip install -q transformers torch datasets accelerate
!pip install -q scikit-learn tqdm
print("✓ Dependencies installed")

## 2. Verify Datasets in Google Drive

**Expected datasets in Google Drive:**
- `{DRIVE_DIR}/datasets/single_word_freq.csv`
- `{DRIVE_DIR}/datasets/keyboard_training_data.txt`
- `{DRIVE_DIR}/datasets/misspelled.csv`

In [None]:
import os

print("Checking datasets in Google Drive...")
print("="*60)

# Define dataset paths (also consumed by the data-generation step below)
WORD_FREQ_PATH = f"{DRIVE_DIR}/datasets/single_word_freq.csv"
CORPUS_PATH = f"{DRIVE_DIR}/datasets/keyboard_training_data.txt"
TYPO_PATH = f"{DRIVE_DIR}/datasets/misspelled.csv"


def _report_dataset(path, label, unit, has_header):
    """Print a one-line status for a dataset file and return its record count.

    Args:
        path: Location of the dataset file.
        label: File name shown in the status line.
        unit: Noun printed after the count (e.g. "words").
        has_header: When True, the first line is not counted as a record.

    Returns:
        The number of records, or None when the file does not exist.
    """
    if not os.path.exists(path):
        print(f"✗ Missing: {path}")
        return None

    with open(path, 'r', encoding='utf-8') as f:
        count = sum(1 for _ in f)
    if has_header:
        # Clamp so a completely empty file reports 0 instead of -1.
        count = max(0, count - 1)

    print(f"✓ {label}: {count:,} {unit}")
    return count


# Check each dataset
word_count = _report_dataset(WORD_FREQ_PATH, "single_word_freq.csv", "words", has_header=True)
line_count = _report_dataset(CORPUS_PATH, "keyboard_training_data.txt", "lines", has_header=False)
typo_count = _report_dataset(TYPO_PATH, "misspelled.csv", "entries", has_header=True)

datasets_ok = all(count is not None for count in (word_count, line_count, typo_count))

print("="*60)
if datasets_ok:
    print("✅ All datasets found!")
else:
    print("⚠️  Some datasets are missing. Please upload them to Google Drive.")
    # f-string so the real directory is shown rather than the literal "{DRIVE_DIR}"
    print(f"\nExpected location: {DRIVE_DIR}/datasets/")
    print("Required files:")
    print("  - single_word_freq.csv (format: word,count_frequency)")
    print("  - keyboard_training_data.txt (plain text sentences)")
    print("  - misspelled.csv (format: number,correct_word,misspelled_words)")

## 3. Generate Training Data

Generate training pairs for all three tasks (word completion, next-word prediction, typo correction) from your existing datasets.

In [None]:
import json
import random
import csv
from typing import List, Tuple

# Fix the seed of Python's `random` module so the random.shuffle() applied to
# the combined samples before the train/val split gives the same order each run.
random.seed(42)

def _parse_frequency(raw) -> int:
    """Convert a CSV frequency cell to an int, falling back to 1 when unusable."""
    try:
        return int(raw)
    except (TypeError, ValueError):
        pass
    try:
        # Accept values such as "20000.0" that int() rejects.
        return int(float(raw))
    except (TypeError, ValueError, OverflowError):
        return 1


def prepare_word_completion_data(word_freq_path: str, max_samples: int = 50000,
                                 top_n_words: int = 25000) -> List[dict]:
    """Build (prefix -> word) training pairs from a word-frequency CSV.

    The CSV must have a ``word`` column; ``count_frequency`` is optional and
    defaults to 1 when missing or not parseable as a number.

    Args:
        word_freq_path: Path to the CSV file.
        max_samples: Upper bound on the number of pairs returned.
        top_n_words: Number of most frequent words kept before generating pairs.

    Returns:
        A list of dicts with keys ``input`` (prefix), ``target`` (full word)
        and ``task`` (always ``'completion'``).

    Raises:
        KeyError: If the CSV has no ``word`` column.
    """
    print("\nGenerating word completion data (Fixed)...")

    words_with_freq = []

    # utf-8-sig transparently drops a BOM, which would otherwise corrupt the
    # first header name; plain UTF-8 files are read unchanged.
    with open(word_freq_path, 'r', encoding='utf-8-sig') as f:
        for row in csv.DictReader(f):
            # Short rows yield None for missing cells.
            word = (row['word'] or '').strip().lower()
            # Only words 3+ chars, alphabetic only
            if len(word) >= 3 and word.isalpha():
                freq = _parse_frequency(row.get('count_frequency', 1))
                words_with_freq.append((word, freq))

    # Keep the most frequent words (stable sort: ties keep file order)
    words_with_freq.sort(key=lambda x: x[1], reverse=True)
    words_with_freq = words_with_freq[:top_n_words]

    print(f"  Using top {len(words_with_freq):,} words")

    samples = []
    for word, freq in words_with_freq:
        if len(samples) >= max_samples:
            break

        # More samples for frequent words (1-5 per word)
        num_samples = min(5, max(1, freq // 10000))

        for i in range(num_samples):
            if len(samples) >= max_samples:
                break

            # Varied prefix lengths: 40%, 50%, 60%, 70%, 80% of the word.
            # Short words can map several ratios to the same prefix; those
            # repeats are kept on purpose so frequent words stay weighted.
            prefix_ratio = 0.4 + (i * 0.1)
            prefix_len = max(1, int(len(word) * prefix_ratio))

            if prefix_len < len(word):  # Don't use full word as prefix
                samples.append({
                    'input': word[:prefix_len],
                    'target': word,
                    'task': 'completion'
                })

    print(f"  Generated {len(samples):,} completion pairs")
    return samples

def prepare_nextword_data(corpus_path: str, max_samples: int = 100000, context_length: int = 3) -> List[dict]:
    """
    Build (context -> next word) pairs by sliding a window over each corpus line.

    Every line is lower-cased and split on whitespace. Each run of
    ``context_length`` words becomes the input and the following word the
    target. Targets must be alphabetic and longer than one character, and each
    (context, target) combination is emitted only once. Generation stops once
    ``max_samples`` pairs have been collected.
    """
    print("\nGenerating next-word prediction data...")

    pairs = []
    emitted = set()  # (context, target) combinations already produced

    with open(corpus_path, 'r', encoding='utf-8') as corpus:
        for raw_line in corpus:
            if len(pairs) >= max_samples:
                break

            tokens = raw_line.strip().lower().split()

            # Need at least one word after a full context window
            if len(tokens) <= context_length:
                continue

            for start in range(len(tokens) - context_length):
                if len(pairs) >= max_samples:
                    break

                stop = start + context_length
                next_word = tokens[stop]

                # Reject non-alphabetic or single-character targets
                if not next_word.isalpha() or len(next_word) <= 1:
                    continue

                context_text = ' '.join(tokens[start:stop])
                key = (context_text, next_word)
                if key in emitted:
                    continue

                emitted.add(key)
                pairs.append({
                    'input': context_text,
                    'target': next_word,
                    'task': 'nextword'
                })

    print(f"  Generated {len(pairs):,} next-word pairs ({len(emitted):,} unique)")
    return pairs

def prepare_typo_data(typo_path: str, max_samples: int = 20000) -> List[dict]:
    """
    Build (misspelling -> correct word) pairs from the typo CSV.

    The CSV must provide ``correct_word`` and ``misspelled_words`` columns;
    the latter may list several misspellings separated by commas and/or
    whitespace. All text is lower-cased. Rows with an empty correct word or
    no misspellings are skipped, as are misspellings identical to the
    correct word.

    Args:
        typo_path: Path to the CSV file.
        max_samples: Upper bound on the number of pairs returned.

    Returns:
        A list of dicts with keys ``input`` (misspelling), ``target``
        (correct word) and ``task`` (always ``'typo'``).

    Raises:
        KeyError: If a required column is absent from the CSV header.
    """
    print("\nGenerating typo correction data...")

    samples = []

    # utf-8-sig drops a leading BOM so the first header name stays intact.
    with open(typo_path, 'r', encoding='utf-8-sig') as f:
        for row in csv.DictReader(f):
            if len(samples) >= max_samples:
                break

            # Short rows yield None for missing cells; treat them as empty.
            correct = (row['correct_word'] or '').strip().lower()
            misspelled_list = (row['misspelled_words'] or '').lower()

            # A pair with an empty target would teach the model nothing useful.
            if not correct:
                continue

            # Commas become whitespace; split() never yields empty strings.
            for typo in misspelled_list.replace(',', ' ').split():
                if len(samples) >= max_samples:
                    break

                if typo != correct:
                    samples.append({
                        'input': typo,
                        'target': correct,
                        'task': 'typo'
                    })

    print(f"  Generated {len(samples):,} typo pairs")
    return samples

# Generate all datasets (or reuse the processed files from a previous run)
print("Preparing training datasets...")
print("="*60)

output_dir = f"{DRIVE_DIR}/datasets/processed"
os.makedirs(output_dir, exist_ok=True)

train_path = f"{output_dir}/train.jsonl"
val_path = f"{output_dir}/val.jsonl"

# Share of the shuffled samples that goes to training; the rest is validation.
TRAIN_FRACTION = 0.95

# Check if processed datasets already exist
if os.path.exists(train_path) and os.path.exists(val_path):
    print("✓ Processed datasets found in Drive!")
    print(f"  Train: {train_path}")
    print(f"  Val: {val_path}")

    # Count samples (read with the same encoding the files are written with)
    with open(train_path, 'r', encoding='utf-8') as f:
        train_count = sum(1 for _ in f)
    with open(val_path, 'r', encoding='utf-8') as f:
        val_count = sum(1 for _ in f)
    print(f"  Train samples: {train_count:,}")
    print(f"  Val samples: {val_count:,}")
else:
    print("Generating training datasets from scratch...")

    # Generate each task
    completion_samples = prepare_word_completion_data(WORD_FREQ_PATH, max_samples=50000)
    nextword_samples = prepare_nextword_data(CORPUS_PATH, max_samples=100000, context_length=3)
    typo_samples = prepare_typo_data(TYPO_PATH, max_samples=20000)

    # Combine all samples
    all_samples = completion_samples + nextword_samples + typo_samples

    # Fail before writing anything: empty files on disk would be picked up as a
    # valid cache by the existence check above on every later run.
    if not all_samples:
        raise ValueError("No training samples were generated; check the input datasets.")

    random.shuffle(all_samples)

    # Split train/val (95/5)
    split_idx = int(len(all_samples) * TRAIN_FRACTION)
    train_samples = all_samples[:split_idx]
    val_samples = all_samples[split_idx:]

    # Save to JSONL (one JSON object per line)
    with open(train_path, 'w', encoding='utf-8') as f:
        for sample in train_samples:
            f.write(json.dumps(sample) + '\n')

    with open(val_path, 'w', encoding='utf-8') as f:
        for sample in val_samples:
            f.write(json.dumps(sample) + '\n')

    print("\n" + "="*60)
    print("✓ Dataset generation complete!")
    print(f"  Total samples: {len(all_samples):,}")
    print(f"  Train: {len(train_samples):,} ({train_path})")
    print(f"  Val: {len(val_samples):,} ({val_path})")
    print(f"\n  Task distribution:")
    print(f"    Completion: {len(completion_samples):,}")
    print(f"    Next-word: {len(nextword_samples):,}")
    print(f"    Typo: {len(typo_samples):,}")

print("\n" + "="*60)
print("✓ Datasets ready for training!")

## 4. Load Pythia-14m Model

In [None]:
import torch
from transformers import AutoTokenizer, GPTNeoXForCausalLM

MODEL_NAME = "EleutherAI/pythia-14m"

print(f"Loading Pythia-14m from: {MODEL_NAME}")
print("="*60)

# Fetch the tokenizer and the pretrained causal-LM weights by name
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = GPTNeoXForCausalLM.from_pretrained(MODEL_NAME)

# Reuse the end-of-sequence token as the padding token, on both the tokenizer
# and the model config, so fixed-length padded batches can be built.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id

# Prefer the GPU when one is visible to PyTorch
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = model.to(device)

print(f"✓ Model loaded successfully")
print(f"✓ Device: {device}")
print(f"✓ Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"✓ Vocab size: {tokenizer.vocab_size:,}")
print("="*60)

## 5. Prepare Training Data

In [None]:
from torch.utils.data import Dataset
class PythiaKeyboardDataset(Dataset):
    """Causal-LM dataset of keyboard-suggestion examples stored as JSONL.

    Each line of ``data_path`` is a JSON object with ``input`` and ``target``
    strings and an optional ``task`` (defaults to ``'completion'``).

    Args:
        data_path: Path to the JSONL file.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors of shape (1, max_length).
        max_length: Fixed sequence length; longer texts are truncated and
            shorter ones padded.
    """

    def __init__(self, data_path, tokenizer, max_length=12):
        with open(data_path, 'r', encoding='utf-8') as f:
            # Tolerate blank lines (e.g. a trailing newline at end of file)
            self.data = [json.loads(line) for line in f if line.strip()]
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for one example."""
        item = self.data[idx]
        input_text = item['input']
        target_word = item['target']
        task = item.get('task', 'completion')

        # Next-word targets are separate words, so they need a separating
        # space ("how are" + "you" -> "how are you"). Completion and typo
        # samples keep the original direct concatenation ("hel" + "hello").
        separator = ' ' if task == 'nextword' else ''
        full_text = f"{input_text}{separator}{target_word}"

        encoded = self.tokenizer(
            full_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # squeeze(0) drops only the batch axis (plain squeeze() would also
        # collapse the sequence axis when max_length == 1)
        input_ids = encoded['input_ids'].squeeze(0)
        attention_mask = encoded['attention_mask'].squeeze(0)

        # Labels are a plain copy of input_ids: GPTNeoXForCausalLM shifts the
        # labels by one position itself when computing the loss, so shifting
        # here as well would train the model to predict two tokens ahead.
        labels = input_ids.clone()

        # Ignore padded positions in the loss. The attention mask is used
        # because the pad token may share its id with a real token (EOS).
        labels[attention_mask == 0] = -100

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

# Build the train and validation datasets from the processed JSONL files
print("Loading training data...")

train_dataset, val_dataset = (
    PythiaKeyboardDataset(split_path, tokenizer, max_length=12)
    for split_path in (train_path, val_path)
)

print(f"✓ Train samples: {len(train_dataset):,}")
print(f"✓ Val samples: {len(val_dataset):,}")
print(f"✓ Max sequence length: 12 tokens")

In [None]:
from torch.utils.data import DataLoader

# Training batches are reshuffled each epoch; validation keeps a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=False)

print(f"✓ Train batches: {len(train_loader)}")
print(f"✓ Val batches: {len(val_loader)}")

## 6. Training

In [None]:
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm.auto import tqdm

# --- Hyperparameters ---
NUM_EPOCHS = 5
LEARNING_RATE = 5e-5
SAVE_STEPS = 1000  # write a checkpoint every SAVE_STEPS optimizer steps

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
# Cosine schedule spanning every optimizer step of the whole run
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS * len(train_loader))

print("Starting training...")
print("="*60)
print(f"Epochs: {NUM_EPOCHS}")
print(f"Learning rate: {LEARNING_RATE}")
print(f"Batch size: 16")
print(f"Total steps: {NUM_EPOCHS * len(train_loader):,}")
print("="*60)


def _forward_batch(batch):
    """Move one batch onto `device` and run the model with labels attached."""
    return model(
        input_ids=batch['input_ids'].to(device),
        attention_mask=batch['attention_mask'].to(device),
        labels=batch['labels'].to(device),
    )


global_step = 0
best_val_loss = float('inf')

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS}")

    # ---- Training pass ----
    model.train()
    train_loss = 0
    progress_bar = tqdm(train_loader, desc="Training")

    for batch in progress_bar:
        loss = _forward_batch(batch).loss

        # One optimizer step and one scheduler step per batch
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        step_loss = loss.item()
        train_loss += step_loss
        global_step += 1
        progress_bar.set_postfix({'loss': step_loss, 'lr': scheduler.get_last_lr()[0]})

        # Periodic checkpoint
        if global_step % SAVE_STEPS != 0:
            continue
        checkpoint_dir = f"{DRIVE_DIR}/models/pythia_checkpoint-{global_step}"
        model.save_pretrained(checkpoint_dir)
        tokenizer.save_pretrained(checkpoint_dir)
        print(f"\n✓ Checkpoint saved: {checkpoint_dir}")

    avg_train_loss = train_loss / len(train_loader)
    print(f"  Train loss: {avg_train_loss:.4f}")

    # ---- Validation pass (no gradients) ----
    model.eval()
    with torch.no_grad():
        val_loss = sum(
            _forward_batch(batch).loss.item()
            for batch in tqdm(val_loader, desc="Validation")
        )

    # Mean of the per-batch losses
    avg_val_loss = val_loss / len(val_loader)
    print(f"  Val loss: {avg_val_loss:.4f}")

    # Keep the weights from the epoch with the lowest validation loss so far
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_model_dir = f"{DRIVE_DIR}/models/pythia_best_model"
        model.save_pretrained(best_model_dir)
        tokenizer.save_pretrained(best_model_dir)
        print(f"  ✓ Best model saved: {best_model_dir}")

print("\n" + "="*60)
print("✓ Training complete!")
print(f"  Best val loss: {best_val_loss:.4f}")
print(f"  Total steps: {global_step}")
print(f"  Expected accuracy: 80-85%")
print("="*60)

## 7. Test Predictions

In [None]:
def test_pythia_predictions(model, tokenizer, test_input, top_k=5):
    """Return the model's most likely next tokens after ``test_input``.

    Args:
        model: Causal LM whose output exposes ``.logits``.
        tokenizer: Tokenizer matching ``model``.
        test_input: Text to continue.
        top_k: Maximum number of candidates to return.

    Returns:
        A list of ``(decoded_token, probability_percent)`` tuples, most
        likely first. Empty when ``test_input`` tokenizes to nothing.
    """
    model.eval()

    # Use the device the weights actually live on rather than a notebook
    # global, so the function also works for a model kept on another device.
    model_device = next(model.parameters()).device
    inputs = tokenizer(test_input, return_tensors='pt').to(model_device)

    # An empty prompt has no "last position" to read logits from.
    if inputs['input_ids'].shape[-1] == 0:
        return []

    with torch.no_grad():
        logits = model(**inputs).logits[0, -1, :]  # Last token position

    # Softmax once over the whole vocabulary (not once per candidate)
    probs = torch.softmax(logits, dim=0)

    # Never ask topk for more entries than the vocabulary holds
    top_tokens = torch.topk(logits, k=min(top_k, logits.numel()))
    token_ids = top_tokens.indices.tolist()
    token_probs = probs[top_tokens.indices].tolist()

    return [
        (tokenizer.decode([token_id]), prob * 100)
        for token_id, prob in zip(token_ids, token_probs)
    ]
# Sample prompts: partial words, partial phrases and one complete word
test_cases = ["hel", "prod", "how ar", "best st", "you"]

print("\n" + "="*60)
print("Testing Predictions:")
print("="*60)

for test_input in test_cases:
    predictions = test_pythia_predictions(model, tokenizer, test_input, top_k=3)
    print(f"\nInput: '{test_input}'")
    print("Predictions:")
    # Ranked output, best candidate first
    for rank, candidate in enumerate(predictions, start=1):
        token, prob = candidate
        print(f"  {rank}. '{token}' ({prob:.1f}%)")

## 8. Export to CoreML (iOS)

In [None]:
!pip install -q coremltools
import coremltools as ct
import numpy as np
print("Exporting to CoreML...")
# Load best model
best_model_path = f"{DRIVE_DIR}/models/pythia_best_model"
export_model = GPTNeoXForCausalLM.from_pretrained(best_model_path)
export_tokenizer = AutoTokenizer.from_pretrained(best_model_path)
export_model.eval()


class _LogitsOnly(torch.nn.Module):
    """Expose the LM as a plain tensor-in / tensor-out module for tracing.

    torch.jit.trace cannot handle the dict-style output object the model
    returns by default, and the key/value cache is not needed on device.
    """

    def __init__(self, lm):
        super().__init__()
        self.lm = lm

    def forward(self, input_ids):
        # return_dict=False yields a tuple whose first element is the logits
        return self.lm(input_ids=input_ids, use_cache=False, return_dict=False)[0]


# Trace model with a fixed (1, 12) input, matching the training sequence length
dummy_input = torch.randint(0, export_tokenizer.vocab_size, (1, 12))
with torch.no_grad():
    traced_model = torch.jit.trace(_LogitsOnly(export_model).eval(), dummy_input)

# Convert to CoreML.
# NOTE(review): FLOAT16 compute precision, the .mlpackage format and the
# weight quantization below all belong to the ML Program path, which an iOS14
# target does not select; hence mlprogram with an iOS16 target. Confirm the
# minimum OS version against the app's deployment requirements.
mlmodel = ct.convert(
    traced_model,
    inputs=[ct.TensorType(name="input_ids", shape=(1, 12), dtype=np.int32)],
    outputs=[ct.TensorType(name="logits")],
    convert_to="mlprogram",
    compute_units=ct.ComputeUnit.ALL,
    compute_precision=ct.precision.FLOAT16,
    minimum_deployment_target=ct.target.iOS16
)
# Add metadata
mlmodel.author = "MinhPhuPham"
mlmodel.short_description = "Pythia-14m keyboard suggestion model"
mlmodel.version = "1.0"
# Quantize weights to INT8
print("Quantizing to INT8...")
import coremltools.optimize.coreml as cto
op_config = cto.OpLinearQuantizerConfig(
    mode="linear_symmetric",
    dtype="int8",
    granularity="per_channel"
)
config = cto.OptimizationConfig(global_config=op_config)
mlmodel_int8 = cto.linear_quantize_weights(mlmodel, config=config)
# Save
output_path = f"{DRIVE_DIR}/models/Pythia14m_Keyboard.mlpackage"
mlmodel_int8.save(output_path)
print(f"✓ CoreML model saved: {output_path}")
# NOTE: the three figures below are fixed estimates, not measurements of the
# exported model.
print(f"✓ Model size: ~12-15MB")
print(f"✓ Expected RAM: 15-20MB")
print(f"✓ Expected latency: <50ms")
print("\n" + "="*60)
print("✅ TRAINING COMPLETE!")
print("="*60)
print("\nNext steps:")
print("1. Download model from Google Drive")
print("2. Test on iOS device")
print("3. Verify accuracy >80%")
print("4. Deploy to production")

## 9. Export to TFLite (Android)