# TinyBERT Keyboard Suggestion Model Training

This notebook trains a keyboard suggestion model using **TinyBERT** with multi-task learning.

**Features:**
1. Word Completion: "Hel" → ["Hello", "Help", "Helping"]
2. Next-Word Prediction: "How are" → ["you", "they", "we"]
3. Typo Correction: "Thers" → ["There", "Theirs", "Therapy"]
4. Gibberish Detection: Heuristic (no ML)

**Model Specifications:**
- Base: `google/bert_uncased_L-4_H-256_A-4`, a TinyBERT-sized compact BERT ("BERT-Mini": 4 layers, 256 hidden, 4 heads, ~11.2M params)
- Target Size: <5MB (after INT8 quantization)
- Latency: <50ms on mobile
- RAM Usage: <30MB runtime
- Deployment: iOS (CoreML) + Android (TFLite)
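The size target can be sanity-checked with arithmetic. The sketch below recomputes the parameter count of `google/bert_uncased_L-4_H-256_A-4` with a masked-LM head from assumed config values (30,522-token vocabulary, 512 positions, 2 token types, 1,024-wide feed-forward). These are the standard BERT-Mini settings rather than values read from this notebook, and they reproduce the 11,201,594 parameters printed in Section 4.

```python
# Assumed BERT-Mini config values (not read from the notebook).
VOCAB = 30522          # WordPiece vocabulary (uncased BERT)
HIDDEN = 256
LAYERS = 4
INTERMEDIATE = 4 * HIDDEN  # feed-forward width (1024)
MAX_POS = 512
TYPE_VOCAB = 2

def linear(n_in: int, n_out: int) -> int:
    """Weights plus bias of a dense layer."""
    return n_in * n_out + n_out

def layer_norm(dim: int) -> int:
    """Gain plus bias of a LayerNorm."""
    return 2 * dim

embeddings = (VOCAB + MAX_POS + TYPE_VOCAB) * HIDDEN + layer_norm(HIDDEN)
per_layer = (
    3 * linear(HIDDEN, HIDDEN)                            # Q, K, V projections
    + linear(HIDDEN, HIDDEN) + layer_norm(HIDDEN)         # attention output
    + linear(HIDDEN, INTERMEDIATE)                        # feed-forward up
    + linear(INTERMEDIATE, HIDDEN) + layer_norm(HIDDEN)   # feed-forward down
)
# MLM head: transform dense + LayerNorm + output bias. The decoder weight is
# tied to the word embeddings, so it is not counted a second time.
mlm_head = linear(HIDDEN, HIDDEN) + layer_norm(HIDDEN) + VOCAB

total = embeddings + LAYERS * per_layer + mlm_head
word_embeddings = VOCAB * HIDDEN

print(f"Parameters:       {total:,}")
print(f"FP32 size:        {total * 4 / 1e6:.1f} MB")
print(f"INT8 weight size: {total / 1e6:.1f} MB")
print(f"  of which word embeddings: {word_embeddings / 1e6:.1f} MB")
```

At roughly one byte per weight, INT8 weight quantization lands near 11 MB, and the word-embedding table alone accounts for about 7.8 MB of it. Reaching <5MB would therefore likely need more than INT8 weights (for example a smaller vocabulary or lower-bit weights); treat these figures as estimates, since model file formats also store metadata and quantization scales.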

**Training Time:** 2-4 hours on Colab GPU (T4)

**Data Sources (Google Drive):**
- `single_word_freq.csv` - Word frequencies for completion
- `keyboard_training_data.txt` - Custom corpus for next-word
- `misspelled.csv` - Typo correction pairs

---

**Instructions:**
1. Runtime → Change runtime type → GPU (T4)
2. Run all cells
3. Model will be saved to Google Drive
4. Download for mobile deployment

## 1. Environment Setup

In [23]:
# Check if running in Colab (google.colab is only importable inside Colab)
import os

try:
    import google.colab  # noqa: F401
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    print("✓ Running in Google Colab")

    # Mount Google Drive
    from google.colab import drive
    drive.mount('/content/drive')

    # Define Drive directory
    DRIVE_DIR = '/content/drive/MyDrive/Keyboard-Suggestions-ML-Colab'
else:
    print("✓ Running locally")
    DRIVE_DIR = './data'  # Local fallback

# Create project directories (makedirs also creates DRIVE_DIR and datasets/)
for subdir in ("datasets/processed", "models"):
    os.makedirs(os.path.join(DRIVE_DIR, subdir), exist_ok=True)

if IN_COLAB:
    print("✓ Google Drive mounted")
    print(f"✓ Project directory: {DRIVE_DIR}")

✓ Running in Google Colab
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✓ Google Drive mounted
✓ Project directory: /content/drive/MyDrive/Keyboard-Suggestions-ML-Colab


In [24]:
# Install dependencies
!pip install -q transformers torch datasets accelerate
!pip install -q scikit-learn tqdm
print("✓ Dependencies installed")

✓ Dependencies installed


## 2. Verify Datasets in Google Drive

**Expected datasets in Google Drive:**
- `{DRIVE_DIR}/datasets/single_word_freq.csv`
- `{DRIVE_DIR}/datasets/keyboard_training_data.txt`
- `{DRIVE_DIR}/datasets/misspelled.csv`
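The cell below only checks that each file exists. A header check catches a subtler failure: `prepare_typo_data` in Section 3 reads the columns `label` and `input` from `misspelled.csv`, and `prepare_word_completion_data` reads `word` (plus `count_frequency`) from `single_word_freq.csv`. The helper `missing_columns` is a hypothetical addition, not part of the original notebook; the demo writes a throwaway file with a `number,correct_word,misspelled_words` header to show what a rejected file looks like.

```python
import csv
import os
import tempfile
from typing import Iterable, Set

def missing_columns(csv_path: str, required: Iterable[str]) -> Set[str]:
    """Return the required column names that are absent from the CSV header row."""
    with open(csv_path, newline="", encoding="utf-8") as f:
        header = csv.DictReader(f).fieldnames or []  # None for an empty file
    return set(required) - {name.strip() for name in header}

# Columns the Section 3 loaders read (hypothetical mapping, taken from that code).
REQUIRED_COLUMNS = {
    "single_word_freq.csv": {"word", "count_frequency"},
    "misspelled.csv": {"label", "input"},
}

# Demo on a throwaway file; a real check would use f"{DRIVE_DIR}/datasets/<name>".
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "misspelled.csv")
    with open(demo_path, "w", newline="", encoding="utf-8") as f:
        f.write("number,correct_word,misspelled_words\n1,there,thers\n")
    demo_missing = missing_columns(demo_path, REQUIRED_COLUMNS["misspelled.csv"])

print(sorted(demo_missing))  # ['input', 'label']
```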

In [25]:
import os

print("Checking datasets in Google Drive...")
print("="*60)

# Define dataset paths
WORD_FREQ_PATH = f"{DRIVE_DIR}/datasets/single_word_freq.csv"
CORPUS_PATH = f"{DRIVE_DIR}/datasets/keyboard_training_data.txt"
TYPO_PATH = f"{DRIVE_DIR}/datasets/misspelled.csv"

# Check each dataset
datasets_ok = True

if os.path.exists(WORD_FREQ_PATH):
    with open(WORD_FREQ_PATH, 'r', encoding='utf-8') as f:
        word_count = sum(1 for _ in f) - 1  # Subtract header
    print(f"✓ single_word_freq.csv: {word_count:,} words")
else:
    print(f"✗ Missing: {WORD_FREQ_PATH}")
    datasets_ok = False

if os.path.exists(CORPUS_PATH):
    with open(CORPUS_PATH, 'r', encoding='utf-8') as f:
        line_count = sum(1 for _ in f)
    print(f"✓ keyboard_training_data.txt: {line_count:,} lines")
else:
    print(f"✗ Missing: {CORPUS_PATH}")
    datasets_ok = False

if os.path.exists(TYPO_PATH):
    with open(TYPO_PATH, 'r', encoding='utf-8') as f:
        typo_count = sum(1 for _ in f) - 1  # Subtract header
    print(f"✓ misspelled.csv: {typo_count:,} entries")
else:
    print(f"✗ Missing: {TYPO_PATH}")
    datasets_ok = False

print("="*60)
if datasets_ok:
    print("✅ All datasets found!")
else:
    print("⚠️  Some datasets are missing. Please upload them to Google Drive.")
    print(f"\nExpected location: {DRIVE_DIR}/datasets/")
    print("Required files:")
    print("  - single_word_freq.csv (format: word,count_frequency)")
    print("  - keyboard_training_data.txt (plain text sentences)")
    print("  - misspelled.csv (columns: label = correct word, input = misspelling(s))")

Checking datasets in Google Drive...
✓ single_word_freq.csv: 238,168 words
✓ keyboard_training_data.txt: 269,207 lines
✓ misspelled.csv: 40,890 entries
✅ All datasets found!


## 3. Generate Training Data

Generate training pairs for all three tasks from the existing datasets.
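In the generator cell below, each word yields between one and three prefixes, and each prefix length is drawn uniformly between `max(1, n // 2)` and `max(2, int(n * 0.8))` for a word of `n` characters. The snippet isolates those two rules in hypothetical helpers (`prefix_length_range`, `samples_per_word`) to show which prefixes can occur; very short words fall outside the commented 50-80% band (for "the", a 1-character prefix is 33%).

```python
def prefix_length_range(word: str) -> tuple:
    """Inclusive (min, max) prefix length, as in prepare_word_completion_data."""
    n = len(word)
    return max(1, n // 2), max(2, int(n * 0.8))

def samples_per_word(freq: int) -> int:
    """Number of prefixes drawn for a word: freq // 1000, clamped to 1..3."""
    return min(3, max(1, freq // 1000))

for word in ["the", "hello", "keyboard"]:
    lo, hi = prefix_length_range(word)
    print(word, [word[:k] for k in range(lo, hi + 1)])
# the ['t', 'th']
# hello ['he', 'hel', 'hell']
# keyboard ['keyb', 'keybo', 'keyboa']

print([samples_per_word(f) for f in (500, 2500, 1_000_000)])  # [1, 2, 3]
```

Because the generator draws each prefix with `random.randint`, a frequent word's two or three draws can repeat the same prefix, so some completion pairs may be exact duplicates.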

In [26]:
import json
import random
import csv
from typing import List, Tuple

random.seed(42)

def prepare_word_completion_data(word_freq_path: str, max_samples: int = 50000) -> List[dict]:
    """Generate word completion training pairs from single_word_freq.csv"""
    print("\nGenerating word completion data...")

    samples = []
    words_with_freq = []

    # Read words with frequencies
    with open(word_freq_path, 'r', encoding='utf-8') as f:
        reader = csv.DictReader(f)
        for row in reader:
            word = row['word'].strip().lower()
            freq = int(row.get('count_frequency', 1))
            if len(word) >= 3:  # Only words with 3+ chars
                words_with_freq.append((word, freq))

    # Sort by frequency (higher first)
    words_with_freq.sort(key=lambda x: x[1], reverse=True)

    # Generate samples (weighted by frequency)
    for word, freq in words_with_freq:
        if len(samples) >= max_samples:
            break

        # Generate multiple prefixes for common words
        num_samples = min(3, max(1, freq // 1000))  # More samples for frequent words

        for _ in range(num_samples):
            if len(samples) >= max_samples:
                break

            # Random prefix length, roughly 50-80% of the word: floor(n/2)..floor(0.8n), minimums 1 and 2
            prefix_len = random.randint(max(1, len(word) // 2), max(2, int(len(word) * 0.8)))
            prefix = word[:prefix_len]

            samples.append({
                'input': prefix,
                'target': word,
                'task': 'completion'
            })

    print(f"  Generated {len(samples):,} completion pairs")
    return samples

def prepare_nextword_data(corpus_path: str, max_samples: int = 100000, context_length: int = 3) -> List[dict]:
    """Generate next-word prediction pairs from keyboard_training_data.txt"""
    print("\nGenerating next-word prediction data...")

    samples = []

    with open(corpus_path, 'r', encoding='utf-8') as f:
        for line in f:
            if len(samples) >= max_samples:
                break

            line = line.strip().lower()
            words = line.split()

            # Skip short sentences
            if len(words) < context_length + 1:
                continue

            # Generate multiple samples from each sentence
            for i in range(len(words) - context_length):
                context = ' '.join(words[i:i+context_length])
                target = words[i+context_length]

                # Keep only alphabetic targets with 2+ chars (drops punctuation, digits, and words like "you?")
                if target.isalpha() and len(target) > 1:
                    samples.append({
                        'input': context,
                        'target': target,
                        'task': 'nextword'
                    })

                if len(samples) >= max_samples:
                    break

    print(f"  Generated {len(samples):,} next-word pairs")
    return samples

def prepare_typo_data(typo_path: str, max_samples: int = 20000) -> List[dict]:
    """Generate typo correction pairs from misspelled.csv"""
    print("\nGenerating typo correction data...")

    samples = []

    with open(typo_path, 'r', encoding='utf-8') as f:
        reader = csv.DictReader(f)
        for row in reader:
            if len(samples) >= max_samples:
                break

            correct = row['label'].strip().lower()
            misspelled_list = row['input'].strip().lower()

            # Split multiple misspellings (comma or space separated)
            typos = [t.strip() for t in misspelled_list.replace(',', ' ').split() if t.strip()]

            for typo in typos:
                if len(samples) >= max_samples:
                    break

                if typo and typo != correct:
                    samples.append({
                        'input': typo,
                        'target': correct,
                        'task': 'typo'
                    })

    print(f"  Generated {len(samples):,} typo pairs")
    return samples

# Generate all datasets
print("Preparing training datasets...")
print("="*60)

output_dir = f"{DRIVE_DIR}/datasets/processed"
os.makedirs(output_dir, exist_ok=True)

train_path = f"{output_dir}/train.jsonl"
val_path = f"{output_dir}/val.jsonl"

# Check if processed datasets already exist
if os.path.exists(train_path) and os.path.exists(val_path):
    print("✓ Processed datasets found in Drive!")
    print(f"  Train: {train_path}")
    print(f"  Val: {val_path}")

    # Count samples
    with open(train_path, 'r') as f:
        train_count = sum(1 for _ in f)
    with open(val_path, 'r') as f:
        val_count = sum(1 for _ in f)
    print(f"  Train samples: {train_count:,}")
    print(f"  Val samples: {val_count:,}")
else:
    print("Generating training datasets from scratch...")

    # Generate each task
    completion_samples = prepare_word_completion_data(WORD_FREQ_PATH, max_samples=50000)
    nextword_samples = prepare_nextword_data(CORPUS_PATH, max_samples=100000, context_length=3)
    typo_samples = prepare_typo_data(TYPO_PATH, max_samples=20000)

    # Combine all samples
    all_samples = completion_samples + nextword_samples + typo_samples
    random.shuffle(all_samples)

    # Split train/val (90/10)
    split_idx = int(len(all_samples) * 0.9)
    train_samples = all_samples[:split_idx]
    val_samples = all_samples[split_idx:]

    # Save to JSONL
    with open(train_path, 'w', encoding='utf-8') as f:
        for sample in train_samples:
            f.write(json.dumps(sample) + '\n')

    with open(val_path, 'w', encoding='utf-8') as f:
        for sample in val_samples:
            f.write(json.dumps(sample) + '\n')

    print("\n" + "="*60)
    print("✓ Dataset generation complete!")
    print(f"  Total samples: {len(all_samples):,}")
    print(f"  Train: {len(train_samples):,} ({train_path})")
    print(f"  Val: {len(val_samples):,} ({val_path})")
    print(f"\n  Task distribution:")
    print(f"    Completion: {len(completion_samples):,}")
    print(f"    Next-word: {len(nextword_samples):,}")
    print(f"    Typo: {len(typo_samples):,}")

print("\n" + "="*60)
print("✓ Datasets ready for training!")

Preparing training datasets...
✓ Processed datasets found in Drive!
  Train: /content/drive/MyDrive/Keyboard-Suggestions-ML-Colab/datasets/processed/train.jsonl
  Val: /content/drive/MyDrive/Keyboard-Suggestions-ML-Colab/datasets/processed/val.jsonl
  Train samples: 153,000
  Val samples: 17,000

✓ Datasets ready for training!
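The next-word pairs come from a sliding window of exactly `context_length=3` words over each line. The hypothetical helper `nextword_pairs` repeats that loop on a single sentence to make two consequences visible: a target with attached punctuation (`today?`) is dropped by the `isalpha()` filter rather than cleaned, and every training input is exactly three words long, so shorter contexts such as the "How are" example in the feature list never appear in the training pairs.

```python
from typing import List, Tuple

def nextword_pairs(sentence: str, context_length: int = 3) -> List[Tuple[str, str]]:
    """Slide a fixed-size window over one sentence, mirroring prepare_nextword_data."""
    words = sentence.strip().lower().split()
    pairs = []
    for i in range(len(words) - context_length):
        target = words[i + context_length]
        # Same filter as the notebook: alphabetic targets of 2+ characters only.
        if target.isalpha() and len(target) > 1:
            pairs.append((" ".join(words[i:i + context_length]), target))
    return pairs

for context, target in nextword_pairs("How are you doing today?"):
    print(f"{context!r} -> {target!r}")
# 'how are you' -> 'doing'   ("today?" fails isalpha(), so it yields no pair)
```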


## 4. Load TinyBERT Model and Tokenizer

In [27]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch

print("Loading TinyBERT for Masked Language Modeling...")
print("="*60)

# Compact BERT checkpoint ("BERT-Mini": 4 layers, 256 hidden, 4 heads, ~11M params)
MODEL_NAME = "google/bert_uncased_L-4_H-256_A-4"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

print(f"✓ Model loaded on {device}")
print(f"  Model: {MODEL_NAME}")
num_params = sum(p.numel() for p in model.parameters())
cfg = model.config
print(f"  Parameters: {num_params:,}")
print(f"  Layers: {cfg.num_hidden_layers}, Hidden: {cfg.hidden_size}, Heads: {cfg.num_attention_heads}")
print(f"  Model size: ~{num_params * 4 / 1e6:.0f}MB (FP32) → ~{num_params / 1e6:.0f}MB (INT8 weights)")
print(f"  Target latency: <50ms on mobile")
print(f"  Target RAM: <30MB runtime")

Loading TinyBERT for Masked Language Modeling...
✓ Model loaded on cuda
  Model: google/bert_uncased_L-4_H-256_A-4
  Parameters: 11,201,594
  Layers: 4, Hidden: 256, Heads: 4
  Model size: ~45MB (FP32) → ~11MB (INT8 weights)
  Target latency: <50ms on mobile
  Target RAM: <30MB runtime


## 5. Prepare Training Data

In [28]:
from torch.utils.data import Dataset, DataLoader
import json
import torch

class KeyboardDataset(Dataset):
    """BERT MLM dataset for keyboard suggestions."""

    def __init__(self, data_path, tokenizer, max_length=16):
        self.data = []
        with open(data_path, 'r', encoding='utf-8') as f:
            for line in f:
                self.data.append(json.loads(line))
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        text_input = item['input']
        target_word = item['target']
        task = item.get('task', 'completion')

        # Append [MASK]; its label is the first WordPiece of the target word:
        # - Completion: "hel [MASK]" → predict "hello"
        # - Next-word: "how are [MASK]" → predict "you"
        # - Typo: "thers [MASK]" → predict "there"
        # No task marker is added, so the model has to infer the task from the input alone.
        text_input = f"{text_input} {self.tokenizer.mask_token}"

        # Tokenize
        inputs = self.tokenizer(
            text_input,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Labels: -100 everywhere except [MASK]
        labels = torch.full(inputs['input_ids'].shape, -100, dtype=torch.long)

        # Get target token ID (first token of target word)
        target_tokens = self.tokenizer.tokenize(target_word)
        target_id = self.tokenizer.convert_tokens_to_ids(target_tokens[0]) if target_tokens else self.tokenizer.unk_token_id

        # Set label at [MASK] position
        mask_positions = (inputs['input_ids'] == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
        if len(mask_positions[1]) > 0:
            labels[0, mask_positions[1][0]] = target_id

        return {
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            'labels': labels.squeeze(),
            'task': task
        }

print("Loading training data...")
train_dataset = KeyboardDataset(train_path, tokenizer, max_length=16)
val_dataset = KeyboardDataset(val_path, tokenizer, max_length=16)

print(f"✓ Train samples: {len(train_dataset):,}")
print(f"✓ Val samples: {len(val_dataset):,}")

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

print(f"✓ Batch size: 32 (train), 64 (val)")
print(f"✓ Max sequence length: 16 tokens")

Loading training data...
✓ Train samples: 153,000
✓ Val samples: 17,000
✓ Batch size: 32 (train), 64 (val)
✓ Max sequence length: 16 tokens
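`KeyboardDataset` sets every label to -100 except the one at the `[MASK]` position. `BertForMaskedLM` computes its loss with PyTorch's `CrossEntropyLoss`, whose default `ignore_index` is -100, so each example contributes the loss of exactly one token. The plain-Python `masked_cross_entropy` below is a hypothetical re-implementation for illustration, not the library code; it uses the usual log-sum-exp form of cross-entropy.

```python
import math
from typing import List, Optional

IGNORE_INDEX = -100  # label value KeyboardDataset uses for positions that are not scored

def masked_cross_entropy(logits: List[List[float]], labels: List[int]) -> Optional[float]:
    """Mean cross-entropy over positions whose label is not IGNORE_INDEX.

    Returns None when every position is ignored.
    """
    losses = []
    for row, label in zip(logits, labels):
        if label == IGNORE_INDEX:
            continue
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        losses.append(log_sum_exp - row[label])
    return sum(losses) / len(losses) if losses else None

# Toy sequence: 4 positions, 3-token vocabulary, [MASK] at position 2 with target id 1.
logits = [[0.0, 0.0, 0.0]] * 4
labels = [IGNORE_INDEX, IGNORE_INDEX, 1, IGNORE_INDEX]
print(masked_cross_entropy(logits, labels))  # ln(3), about 1.0986
```

For scale, a uniform guess over the 30,522-token vocabulary would score ln(30522) ≈ 10.33. One behavioral difference: when every label in a batch is -100, PyTorch's mean reduction yields NaN, whereas this sketch returns `None`.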


## 6. Training

In [29]:
from torch.optim import AdamW
from tqdm.auto import tqdm

NUM_EPOCHS = 3
LEARNING_RATE = 3e-5  # Lower LR for fine-tuning
SAVE_STEPS = 1000

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

print("Starting training...")
print("="*60)
print(f"Epochs: {NUM_EPOCHS}")
print(f"Learning rate: {LEARNING_RATE}")
print(f"Optimizer: AdamW")
print(f"Save checkpoints every: {SAVE_STEPS} steps")
print("="*60)

global_step = 0
best_val_loss = float('inf')

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS}")

    model.train()
    train_loss = 0

    progress_bar = tqdm(train_loader, desc="Training")
    for batch in progress_bar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        global_step += 1
        progress_bar.set_postfix({'loss': loss.item()})

        if global_step % SAVE_STEPS == 0:
            checkpoint_dir = f"{DRIVE_DIR}/models/checkpoint-{global_step}"
            model.save_pretrained(checkpoint_dir)
            tokenizer.save_pretrained(checkpoint_dir)
            print(f"\n✓ Checkpoint saved: {checkpoint_dir}")

    avg_train_loss = train_loss / len(train_loader)
    print(f"  Train loss: {avg_train_loss:.4f}")

    # Validation
    model.eval()
    val_loss = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            val_loss += outputs.loss.item()

    avg_val_loss = val_loss / len(val_loader)
    print(f"  Val loss: {avg_val_loss:.4f}")

    # Save best model
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_model_dir = f"{DRIVE_DIR}/models/best_model"
        model.save_pretrained(best_model_dir)
        tokenizer.save_pretrained(best_model_dir)
        print(f"  ✓ Best model saved: {best_model_dir}")

print("\n" + "="*60)
print("✓ Training complete!")
print(f"  Best val loss: {best_val_loss:.4f}")
print(f"  Total steps: {global_step}")

Starting training...
Epochs: 3
Learning rate: 3e-05
Optimizer: AdamW
Save checkpoints every: 1000 steps

Epoch 1/3


Training:   0%|          | 0/4782 [00:00<?, ?it/s]


✓ Checkpoint saved: /content/drive/MyDrive/Keyboard-Suggestions-ML-Colab/models/checkpoint-1000

✓ Checkpoint saved: /content/drive/MyDrive/Keyboard-Suggestions-ML-Colab/models/checkpoint-2000

✓ Checkpoint saved: /content/drive/MyDrive/Keyboard-Suggestions-ML-Colab/models/checkpoint-3000

✓ Checkpoint saved: /content/drive/MyDrive/Keyboard-Suggestions-ML-Colab/models/checkpoint-4000
  Train loss: 4.7742


Validation:   0%|          | 0/266 [00:00<?, ?it/s]

  Val loss: 4.3384
  ✓ Best model saved: /content/drive/MyDrive/Keyboard-Suggestions-ML-Colab/models/best_model

Epoch 2/3


Training:   0%|          | 0/4782 [00:00<?, ?it/s]


✓ Checkpoint saved: /content/drive/MyDrive/Keyboard-Suggestions-ML-Colab/models/checkpoint-5000

✓ Checkpoint saved: /content/drive/MyDrive/Keyboard-Suggestions-ML-Colab/models/checkpoint-6000

✓ Checkpoint saved: /content/drive/MyDrive/Keyboard-Suggestions-ML-Colab/models/checkpoint-7000

✓ Checkpoint saved: /content/drive/MyDrive/Keyboard-Suggestions-ML-Colab/models/checkpoint-8000

✓ Checkpoint saved: /content/drive/MyDrive/Keyboard-Suggestions-ML-Colab/models/checkpoint-9000
  Train loss: 4.1519


Validation:   0%|          | 0/266 [00:00<?, ?it/s]

  Val loss: 4.2051
  ✓ Best model saved: /content/drive/MyDrive/Keyboard-Suggestions-ML-Colab/models/best_model

Epoch 3/3


Training:   0%|          | 0/4782 [00:00<?, ?it/s]


✓ Checkpoint saved: /content/drive/MyDrive/Keyboard-Suggestions-ML-Colab/models/checkpoint-10000

✓ Checkpoint saved: /content/drive/MyDrive/Keyboard-Suggestions-ML-Colab/models/checkpoint-11000

✓ Checkpoint saved: /content/drive/MyDrive/Keyboard-Suggestions-ML-Colab/models/checkpoint-12000

✓ Checkpoint saved: /content/drive/MyDrive/Keyboard-Suggestions-ML-Colab/models/checkpoint-13000

✓ Checkpoint saved: /content/drive/MyDrive/Keyboard-Suggestions-ML-Colab/models/checkpoint-14000
  Train loss: 3.8505


Validation:   0%|          | 0/266 [00:00<?, ?it/s]

  Val loss: 4.1743
  ✓ Best model saved: /content/drive/MyDrive/Keyboard-Suggestions-ML-Colab/models/best_model

✓ Training complete!
  Best val loss: 4.1743
  Total steps: 14346
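The counts in this log follow from the dataset sizes: 153,000 training samples at batch size 32 give 4,782 steps per epoch (the last batch is partial), so three epochs make 14,346 steps and 14 checkpoints at `SAVE_STEPS = 1000`. Because the loss is a mean cross-entropy in nats, `exp(loss)` converts it to perplexity. The snippet re-derives the counts and reads the three logged validation losses back in.

```python
import math

TRAIN_SAMPLES, VAL_SAMPLES = 153_000, 17_000   # from the Section 5 output
TRAIN_BATCH, VAL_BATCH = 32, 64
EPOCHS, CHECKPOINT_EVERY = 3, 1000

steps_per_epoch = math.ceil(TRAIN_SAMPLES / TRAIN_BATCH)  # final batch is partial
val_batches = math.ceil(VAL_SAMPLES / VAL_BATCH)
total_steps = steps_per_epoch * EPOCHS
num_checkpoints = total_steps // CHECKPOINT_EVERY

print(steps_per_epoch, val_batches, total_steps, num_checkpoints)  # 4782 266 14346 14

# Mean cross-entropy in nats -> perplexity via exp()
logged_val_loss = {1: 4.3384, 2: 4.2051, 3: 4.1743}  # copied from the log above
perplexity = {epoch: math.exp(loss) for epoch, loss in logged_val_loss.items()}
for epoch, ppl in perplexity.items():
    print(f"epoch {epoch}: val loss {logged_val_loss[epoch]:.4f} -> perplexity {ppl:.1f}")
```

A final validation perplexity of about 65 means the model is, on average, roughly as uncertain at the `[MASK]` position as a uniform choice among 65 tokens; treat it as an intuition aid rather than a direct measure of suggestion quality.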


## 7. Export and Save Model

In [30]:
# Export and save the BEST model (for iOS/Android deployment)
import shutil
if IN_COLAB:
    from google.colab import files  # Colab-only helper for browser downloads

print("Exporting BEST model for deployment...")
print("="*60)

# Load the best model (saved during training)
best_model_path = f"{DRIVE_DIR}/models/best_model"
print(f"\nLoading best model from: {best_model_path}")

# Load best model
best_model = AutoModelForMaskedLM.from_pretrained(best_model_path)
best_tokenizer = AutoTokenizer.from_pretrained(best_model_path)
print("✓ Best model loaded")

# 1. Save to Google Drive as 'final' (for deployment)
drive_model_dir = f"{DRIVE_DIR}/models/tinybert_keyboard_final"
best_model.save_pretrained(drive_model_dir)
best_tokenizer.save_pretrained(drive_model_dir)
print(f"\n✅ Saved to Google Drive: {drive_model_dir}")

# 2. Create downloadable zip
local_model_dir = "/content/tinybert_keyboard_model"
best_model.save_pretrained(local_model_dir)
best_tokenizer.save_pretrained(local_model_dir)

zip_path = "/content/tinybert_keyboard_model.zip"
shutil.make_archive("/content/tinybert_keyboard_model", 'zip', local_model_dir)
print(f"\n✅ Created zip: {zip_path}")

# 3. Download to local device
if IN_COLAB:
    print("\n📥 Downloading BEST model to your computer...")
    files.download(zip_path)
    print("✅ Download started! Check your Downloads folder.")

print("\n" + "="*60)
print("✅ Export complete!")
print("\n⚠️  IMPORTANT: This is the BEST model (lowest validation loss)")
print("   Use this for iOS/Android deployment!")
print("\nModel saved to:")
print(f"  1. Google Drive: {drive_model_dir}")
print(f"  2. Local download: tinybert_keyboard_model.zip")
print("\nNext steps:")
print("  1. Run Section 8 (Export to CoreML for iOS)")
print("  2. Run Section 9 (Export to TFLite for Android)")
print("\n💡 Both iOS and Android exports will use this BEST model!")

Exporting BEST model for deployment...

Loading best model from: /content/drive/MyDrive/Keyboard-Suggestions-ML-Colab/models/best_model
✓ Best model loaded

✅ Saved to Google Drive: /content/drive/MyDrive/Keyboard-Suggestions-ML-Colab/models/tinybert_keyboard_final

✅ Created zip: /content/tinybert_keyboard_model.zip

📥 Downloading BEST model to your computer...

✅ Download started! Check your Downloads folder.

✅ Export complete!

⚠️  IMPORTANT: This is the BEST model (lowest validation loss)
   Use this for iOS/Android deployment!

Model saved to:
  1. Google Drive: /content/drive/MyDrive/Keyboard-Suggestions-ML-Colab/models/tinybert_keyboard_final
  2. Local download: tinybert_keyboard_model.zip

Next steps:
  1. Run Section 8 (Export to CoreML for iOS)
  2. Run Section 9 (Export to TFLite for Android)

💡 Both iOS and Android exports will use this BEST model!
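Both exports take `input_ids` and `attention_mask` of shape `(1, 16)` and return logits of shape `(1, 16, vocab_size)`; the suggestion bar needs the scores at the `[MASK]` position. The post-processing below is a hypothetical sketch, not code from this notebook: it takes the `[MASK]` row of one sequence's logits (drop the batch dimension first), skips WordPiece continuations (`##...`) and bracketed special tokens, optionally keeps only tokens that start with the typed prefix (word completion), and returns the top k. The vocabulary and scores are toy values.

```python
import heapq
from typing import List, Sequence

def mask_logits_row(logits: Sequence[Sequence[float]], input_ids: Sequence[int],
                    mask_token_id: int) -> Sequence[float]:
    """Return the vocabulary scores at the first [MASK] position of one sequence."""
    return logits[list(input_ids).index(mask_token_id)]

def top_k_suggestions(scores: Sequence[float], id_to_token: Sequence[str],
                      k: int = 3, prefix: str = "") -> List[str]:
    """Hypothetical post-processing: k best whole-word tokens, optionally prefix-filtered."""
    candidates = [
        (score, token)
        for score, token in zip(scores, id_to_token)
        if not token.startswith("##")                            # WordPiece continuation
        and not (token.startswith("[") and token.endswith("]"))  # [PAD], [MASK], ...
        and token.startswith(prefix)                             # completion filter
    ]
    return [token for _, token in heapq.nlargest(k, candidates)]

# Toy vocabulary and a 3-position sequence: "hello [MASK] [PAD]".
vocab = ["[PAD]", "[MASK]", "hello", "help", "helping", "##lo", "you", "they"]
input_ids = [2, 1, 0]
logits = [
    [0.0] * len(vocab),
    [9.0, 8.0, 3.0, 2.5, 1.0, 5.0, 4.0, 0.5],  # scores at the [MASK] position
    [0.0] * len(vocab),
]
scores = mask_logits_row(logits, input_ids, mask_token_id=1)
print(top_k_suggestions(scores, vocab, prefix="hel"))  # ['hello', 'help', 'helping']
print(top_k_suggestions(scores, vocab))                # ['you', 'hello', 'help']
```

Because training labels only the first WordPiece of each target word, a word the tokenizer splits into several pieces can only be suggested as its first piece. The prefix filter also assumes lower-case input, matching the uncased vocabulary.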


## 8. Export to CoreML (iOS)

In [35]:
# Export to CoreML for iOS with INT8 quantization
!pip install -q coremltools
import coremltools as ct
import coremltools.optimize.coreml as cto
import torch
import torch.nn as nn

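import numpy as np

# --- Wrapper: return only the logits tensor so the model can be traced ---
class WrapperModel(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
    def forward(self, input_ids, attention_mask):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.logits

MAX_LEN = 16  # must match max_length used in training

# 1. Wrap the best model (loaded in Section 7) and switch to inference mode
wrapped_model = WrapperModel(best_model.cpu()).eval()

# 2. Trace with fixed-shape example inputs of shape (1, MAX_LEN)
example_ids = torch.zeros((1, MAX_LEN), dtype=torch.long)
example_mask = torch.ones((1, MAX_LEN), dtype=torch.long)
with torch.no_grad():
    traced_model = torch.jit.trace(wrapped_model, (example_ids, example_mask))

# 3. Convert to a Core ML ML Program (weight quantization below requires mlprogram)
mlmodel = ct.convert(
    traced_model,
    inputs=[
        ct.TensorType(name="input_ids", shape=(1, MAX_LEN), dtype=np.int32),
        ct.TensorType(name="attention_mask", shape=(1, MAX_LEN), dtype=np.int32),
    ],
    outputs=[ct.TensorType(name="logits")],
    convert_to="mlprogram",
    minimum_deployment_target=ct.target.iOS16,
)

# 4. Quantize to INT8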
print("\n4. Quantizing to INT8...")

# A. Create the configuration
# INT8 weights, linear symmetric mode (zero point 0), one scale per output channel
op_config = cto.OpLinearQuantizerConfig(
    mode="linear_symmetric",
    dtype="int8",
    granularity="per_channel"
)
config = cto.OptimizationConfig(global_config=op_config)

# B. Apply the quantization using the config
mlmodel_int8 = cto.linear_quantize_weights(mlmodel, config=config)

print("   ✓ Quantization complete")

# 5. Save
coreml_path = f"{DRIVE_DIR}/models/TinyBERT_Keyboard_iOS.mlpackage"
mlmodel_int8.save(coreml_path)

print("\n" + "="*60)
print("✅ iOS CoreML export complete!")
print(f"Saved to: {coreml_path}")


4. Quantizing to INT8...


Running compression pass linear_quantize_weights: 100%|██████████| 74/74 [00:03<00:00, 19.49 ops/s]
Running MIL frontend_milinternal pipeline: 0 passes [00:00, ? passes/s]
Running MIL default pipeline: 100%|██████████| 92/92 [00:02<00:00, 38.76 passes/s]
Running MIL backend_mlprogram pipeline: 100%|██████████| 12/12 [00:00<00:00, 62.74 passes/s]


   ✓ Quantization complete

✅ iOS CoreML export complete!
Saved to: /content/drive/MyDrive/Keyboard-Suggestions-ML-Colab/models/TinyBERT_Keyboard_iOS.mlpackage
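The `linear_symmetric` mode with `per_channel` granularity can be illustrated in plain Python. This is a generic sketch of symmetric per-channel INT8 weight quantization (zero point 0, one scale per output channel, codes in [-127, 127]); it is not coremltools' implementation, whose rounding and channel-axis details may differ. The demo shows why per-channel scales matter when channels have very different magnitudes.

```python
from typing import List, Tuple

def quantize_symmetric_per_channel(weights: List[List[float]]) -> Tuple[List[List[int]], List[float]]:
    """INT8 symmetric quantization with one scale per output channel (row).

    Zero point is 0; each row uses scale = max|w| / 127 and q = round(w / scale),
    clamped to [-127, 127].
    """
    q_rows, scales = [], []
    for row in weights:
        max_abs = max(abs(w) for w in row)
        scale = max_abs / 127 if max_abs > 0 else 1.0  # all-zero row: any scale works
        q_rows.append([max(-127, min(127, round(w / scale))) for w in row])
        scales.append(scale)
    return q_rows, scales

def dequantize(q_rows: List[List[int]], scales: List[float]) -> List[List[float]]:
    """Map INT8 codes back to floats: w_hat = q * scale."""
    return [[q * s for q in row] for row, s in zip(q_rows, scales)]

# Two output channels with very different magnitudes. With a single per-tensor
# scale (1.27 / 127 = 0.01), every weight in the second row would round to 0.
W = [[0.5, -1.27, 0.01],
     [0.002, -0.004, 0.003]]
Wq, scales = quantize_symmetric_per_channel(W)
W_hat = dequantize(Wq, scales)
worst = max(abs(w - wh) for r, rh in zip(W, W_hat) for w, wh in zip(r, rh))
print(Wq)
print(f"max reconstruction error: {worst:.2e}")
```

Storing one INT8 code per weight plus one float scale per channel is what brings weight storage to roughly a quarter of FP32.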


## 9. Export to TFLite (Android)

In [47]:
# ==============================================================================
# Android TFLite Export: Direct PyTorch → TensorFlow → TFLite
# Bypassing ONNX to avoid compatibility issues
# ==============================================================================

print("Installing required packages...")
# tf-keras: transformers' TF models do not support Keras 3 (the default Keras in TF 2.16+)
!pip install -q tensorflow tf-keras transformers

import os
import shutil
import tensorflow as tf
from pathlib import Path
from transformers import TFAutoModelForMaskedLM, AutoTokenizer

print("\nExporting to TFLite for Android (Direct PyTorch → TF)...")
print("="*60)

# --- STEP 1: Save PyTorch Model ---
print("\n1. Saving PyTorch model...")
staging_dir = Path("temp_staging_model")
if staging_dir.exists():
    shutil.rmtree(staging_dir)
staging_dir.mkdir()

best_model.save_pretrained(staging_dir)
best_tokenizer.save_pretrained(staging_dir)
print("   ✓ Model saved")

# --- STEP 2: Convert PyTorch → TensorFlow ---
print("\n2. Converting PyTorch → TensorFlow...")

# Load as TensorFlow model directly
tf_model = TFAutoModelForMaskedLM.from_pretrained(
    staging_dir,
    from_pt=True  # Convert from PyTorch
)
print("   ✓ Conversion successful")

# --- STEP 3: Save as TensorFlow SavedModel ---
print("\n3. Saving TensorFlow SavedModel...")

tf_saved_model_dir = "tf_saved_model"
if os.path.exists(tf_saved_model_dir):
    shutil.rmtree(tf_saved_model_dir)

# Fixed serving signature: batch 1, sequence length 16 (matches training max_length)

@tf.function(input_signature=[
    tf.TensorSpec(shape=(1, 16), dtype=tf.int32, name='input_ids'),
    tf.TensorSpec(shape=(1, 16), dtype=tf.int32, name='attention_mask')
])
def serving_fn(input_ids, attention_mask):
    outputs = tf_model(input_ids=input_ids, attention_mask=attention_mask)
    return {'logits': outputs.logits}

# Save with concrete function
tf.saved_model.save(
    tf_model,
    tf_saved_model_dir,
    signatures={'serving_default': serving_fn}
)
print("   ✓ SavedModel created")

# --- STEP 4: Convert to TFLite ---
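print("\n4. Converting to TFLite (dynamic-range INT8 quantization)...")

converter = tf.lite.TFLiteConverter.from_saved_model(tf_saved_model_dir)

# Optimize.DEFAULT without a representative dataset = dynamic-range quantization:
# weights are stored as INT8, activations are computed in float at runtime
converter.optimizations = [tf.lite.Optimize.DEFAULT]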
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS
]

try:
    tflite_model = converter.convert()
    print("   ✓ TFLite conversion successful (INT8 weights)")
except Exception as e:
    print(f"   ⚠ INT8 conversion failed ({e}), trying FP16...")
    converter = tf.lite.TFLiteConverter.from_saved_model(tf_saved_model_dir)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_types = [tf.float16]
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS
    ]
    tflite_model = converter.convert()
    print("   ✓ TFLite conversion successful (FP16)")

# --- STEP 5: Save Final Model ---
print("\n5. Packaging for Android...")

tflite_dir = Path(f"{DRIVE_DIR}/models/android")
tflite_dir.mkdir(parents=True, exist_ok=True)

tflite_path = tflite_dir / "keyboard_model_quantized.tflite"
with open(tflite_path, "wb") as f:
    f.write(tflite_model)

best_tokenizer.save_vocabulary(str(tflite_dir))

model_size_mb = os.path.getsize(tflite_path) / (1024 * 1024)

# --- STEP 6: Verification ---
print("\n6. Verifying model...")
try:
    interpreter = tf.lite.Interpreter(model_path=str(tflite_path))
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    print(f"   ✓ Model verified")
    print(f"   Inputs: {len(input_details)}")
    for inp in input_details:
        print(f"     - {inp['name']}: {inp['shape']}")
    print(f"   Outputs: {len(output_details)}")
    for out in output_details:
        print(f"     - {out['name']}: {out['shape']}")

except Exception as e:
    # Surface the error instead of hiding it (e.g. a missing Flex delegate for SELECT_TF_OPS)
    print(f"   ⚠ Verification failed: {e}")

# --- STEP 7: Cleanup ---
print("\n7. Cleaning up...")
for path in [staging_dir, tf_saved_model_dir]:
    if os.path.exists(path):
        shutil.rmtree(path)

print("\n" + "="*60)
print("✅ Android TFLite export complete!")
print(f"\nSaved to: {tflite_dir}")
print(f"Model: keyboard_model_quantized.tflite ({model_size_mb:.2f} MB)")
print(f"\n📦 Files:")
print(f"  1. keyboard_model_quantized.tflite")
print(f"  2. vocab.txt")
print(f"\n📱 build.gradle:")
print(f"  implementation 'org.tensorflow:tensorflow-lite:2.14.0'")
print(f"  implementation 'org.tensorflow:tensorflow-lite-select-tf-ops:2.14.0'")

Installing AI Edge Torch and dependencies...

ModuleNotFoundError: No module named 'ai_edge_torch'
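The Android package ships only the model and `vocab.txt`, so the app has to reproduce the tokenizer and build the fixed `(1, 16)` `input_ids`/`attention_mask` arrays the TFLite signature expects. The sketch below shows greedy longest-match-first WordPiece (the algorithm BERT's WordPiece tokenizer uses) plus the `text [MASK]` layout from `KeyboardDataset`, written in Python for brevity. The toy vocabulary and its IDs are made up, and BERT's basic pre-tokenization (accent stripping, punctuation splitting) is omitted, so treat this as an outline of the on-device steps rather than a drop-in port.

```python
from typing import Dict, List, Tuple

def wordpiece(word: str, vocab: Dict[str, int], unk: str = "[UNK]") -> List[str]:
    """Greedy longest-match-first WordPiece split of one lower-cased word."""
    pieces: List[str] = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        while start < end:
            piece = word[start:end] if start == 0 else "##" + word[start:end]
            if piece in vocab:
                match = piece
                break
            end -= 1
        if match is None:
            return [unk]  # no sub-piece matches: the whole word maps to [UNK]
        pieces.append(match)
        start = end
    return pieces

def encode_for_mask(text: str, vocab: Dict[str, int], max_len: int = 16) -> Tuple[List[int], List[int]]:
    """[CLS] <text pieces> [MASK] [SEP] + [PAD]s, like KeyboardDataset's "text [MASK]" input."""
    body = [p for word in text.lower().split() for p in wordpiece(word, vocab)]
    body = (body + ["[MASK]"])[: max_len - 2]  # truncation: very long input loses [MASK]
    tokens = ["[CLS]"] + body + ["[SEP]"]
    input_ids = [vocab[t] for t in tokens]
    attention_mask = [1] * len(input_ids)
    padding = max_len - len(input_ids)
    return input_ids + [vocab["[PAD]"]] * padding, attention_mask + [0] * padding

# Toy vocabulary with made-up IDs (vocab.txt lists one token per line; ID = line index).
vocab = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3, "[MASK]": 4,
         "how": 5, "are": 6, "he": 7, "##l": 8, "##lo": 9}

print(wordpiece("hello", vocab))  # ['he', '##l', '##lo']
ids, mask = encode_for_mask("How are", vocab)
print(ids)   # [2, 5, 6, 4, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
print(mask)  # [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
```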