# Train DistilBERT (multilingual) and convert to TFLite

This notebook trains a small multilingual phishing-detection model from `distilbert-base-multilingual-cased`, saves the best checkpoint, converts it to a TensorFlow SavedModel and then to TFLite (float16, dynamic-range and full-integer int8 variants). It includes quick smoke-tests and saves the artifacts to Google Drive for easy download.

## 0) Notes before you start

- Use GPU runtime (Runtime → Change runtime type → GPU).
- Upload `train.csv`, `validation.csv`, and `test.csv` via the file upload UI or mount your Drive and place them in a folder.
- This notebook is intentionally minimal and uses small defaults to keep runtime short.

In [None]:
# 1) Install dependencies (run once)
!pip install -q transformers datasets accelerate evaluate sentencepiece tensorflow-text
!pip install -q 'tensorflow>=2.12'  # for TFLite conversion and interpreter

print('Installed packages')

Installed packages


In [None]:
# 2) Mount Google Drive (optional) or upload files manually
from google.colab import drive, files
import os

drive_mount_path = '/content/drive'
print('If you want to use Drive, run: drive.mount(drive_mount_path) and place CSVs under a folder; otherwise use files.upload()')
# Uncomment to mount
# drive.mount(drive_mount_path)

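# Helper: if a file is not in the runtime workspace, prompt for a manual upload
def ensure_file(path):
    if not os.path.exists(path):
        print(f"Upload {os.path.basename(path)}")
        uploaded = files.upload()
        for name in uploaded.keys():
            print('Uploaded', name)
    # files.upload() saves the file under its own name, so confirm it matches the expected one
    if not os.path.exists(path):
        raise FileNotFoundError(f"{path} not found after upload; rename the file or update the expected filename below")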

# Set expected filenames (change if your filenames differ)
TRAIN_CSV = 'train.csv'
VAL_CSV = 'validation.csv'
TEST_CSV = 'test.csv'

for p in (TRAIN_CSV, VAL_CSV, TEST_CSV):
    ensure_file(p)

print('Ready to load CSVs from this runtime workspace')

If you want to use Drive, run: drive.mount(drive_mount_path) and place CSVs under a folder; otherwise use files.upload()
Upload train.csv


Saving train.csv to train.csv
Uploaded train.csv
Upload validation.csv


Saving validation.csv to validation.csv
Uploaded validation.csv
Upload test.csv


Saving test.csv to test.csv
Uploaded test.csv
Ready to load CSVs from this runtime workspace


In [None]:
# 3) Load and quick-validate the CSVs
import pandas as pd

train_df = pd.read_csv(TRAIN_CSV)
val_df = pd.read_csv(VAL_CSV)
test_df = pd.read_csv(TEST_CSV)

print('Train', len(train_df), 'Val', len(val_df), 'Test', len(test_df))
print('Sample train rows:')
display(train_df.head())

# Map string labels to integers (0 = legitimate, 1 = phishing) in every split,
# and fail loudly if a value is not covered by label_map (it would otherwise become NaN)
label_map = {'phishing': 1, 'legitimate': 0}
for name, df in (('train', train_df), ('validation', val_df), ('test', test_df)):
    if df['label'].dtype != 'int64':
        mapped = df['label'].map(label_map)
        if mapped.isna().any():
            raise ValueError(f"Unmapped labels in {name}: {df.loc[mapped.isna(), 'label'].unique()}")
        df['label'] = mapped.astype('int64')

print('Label distribution (train):')
print(train_df['label'].value_counts())

Train 47682 Val 10218 Test 10218
Sample train rows:


| Unnamed: 0 | label | message | language | source | channel | sender | received_at | scam_type |
|---|---|---|---|---|---|---|---|---|
| 0 | legitimate | Ya ok, vikky vl c witin &lt;#&gt; mins and il... | en | huggingface_augmented | sms | 2723293735 | 2025-10-13T05:59:57 | legitimate_general |
| 1 | phishing | M-Pesa: Tafadhali thibitisha nambari 579607 ku... | sw | synthetic_generator | sms | 2550087085 | | otp_request |
| 2 | legitimate | Kann u schauen 4 mich in da lib ich habe Sache... | de | kaggle_multilingual | sms | 4976310049 | 2025-07-26T23:27:19 | legitimate_general |
| 3 | phishing | HARAKA: Akaunti yako ya Benki ya Kenya imesima... | sw | synthetic_generator | sms | 2540330598 | | account_suspension |
| 4 | legitimate | कुछ भी वह आ रहा है? | hi | kaggle_multilingual | sms | 9122515726 | 2024-12-20T05:18:11 | legitimate_general |


Label distribution (train):
label
0    30497
1    17185
Name: count, dtype: int64
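The mapping above assumes every label is spelled exactly `phishing` or `legitimate`; any other spelling turns into NaN and the cast to `int64` then fails with a less helpful error. A minimal sketch of a stricter normalizer, assuming labels arrive either as those two strings (any case, possibly with stray whitespace) or as 0/1; `normalize_labels` and `LABEL_MAP` are hypothetical names, not part of the notebook:

```python
# Hypothetical helper (not part of the original notebook): normalize label values
# to 0/1 and fail loudly on anything unexpected instead of producing NaN.
LABEL_MAP = {"legitimate": 0, "phishing": 1}


def normalize_labels(values):
    """Return a list of 0/1 ints for an iterable of labels; raise ValueError on unknown ones."""
    normalized, unknown = [], set()
    for v in values:
        if isinstance(v, str):
            key = v.strip().lower()
            if key in LABEL_MAP:
                normalized.append(LABEL_MAP[key])
                continue
            if key in ("0", "1"):
                normalized.append(int(key))
                continue
        elif v in (0, 1):  # already numeric (int, NumPy int, or 0.0/1.0)
            normalized.append(int(v))
            continue
        unknown.add(str(v))  # collect everything else (including NaN) for the error message
    if unknown:
        raise ValueError(f"Unexpected label values: {sorted(unknown)}")
    return normalized
```

In the notebook it could be applied per split, e.g. `train_df['label'] = normalize_labels(train_df['label'])`.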


## 4) Prepare Hugging Face datasets and tokenizer
We tokenize with `distilbert-base-multilingual-cased`. We use short sequences (max_length=128) for mobile efficiency.
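`padding='max_length'` plus `truncation=True` gives every message the same shape, which the fixed (1, 128) TFLite inputs later rely on. A pure-Python sketch of what that produces for one message, assuming the content token ids have already been looked up; the helper name is hypothetical, and the default special-token ids (`[CLS]`=101, `[SEP]`=102, `[PAD]`=0) are an assumption for this vocabulary, so read the real values from `tokenizer.cls_token_id`, `tokenizer.sep_token_id` and `tokenizer.pad_token_id`:

```python
# Hypothetical illustration of padding='max_length' + truncation=True for one sequence.
# The default special-token ids are assumptions; take the real ones from the tokenizer.
def encode_fixed_length(content_ids, max_len=128, cls_id=101, sep_id=102, pad_id=0):
    body = list(content_ids[: max_len - 2])  # truncate, leaving room for [CLS] and [SEP]
    ids = [cls_id] + body + [sep_id]
    mask = [1] * len(ids)                    # 1 = real token the model should attend to
    n_pad = max_len - len(ids)
    return ids + [pad_id] * n_pad, mask + [0] * n_pad  # 0 = padding, ignored by attention
```

Because truncation happens before the special tokens are added, a long message still ends with `[SEP]`, and the attention mask tells the model which of the 128 positions are padding.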

In [None]:
from datasets import Dataset
from transformers import AutoTokenizer

model_name = 'distilbert-base-multilingual-cased'
tokenizer = AutoTokenizer.from_pretrained(model_name)

train_ds = Dataset.from_pandas(train_df[['message','label']].rename(columns={'message':'text'}))
val_ds = Dataset.from_pandas(val_df[['message','label']].rename(columns={'message':'text'}))
test_ds = Dataset.from_pandas(test_df[['message','label']].rename(columns={'message':'text'}))

def tokenize_fn(batch):
    return tokenizer(batch['text'], padding='max_length', truncation=True, max_length=128)

train_ds = train_ds.map(tokenize_fn, batched=True)
val_ds = val_ds.map(tokenize_fn, batched=True)
test_ds = test_ds.map(tokenize_fn, batched=True)

train_ds = train_ds.remove_columns(['text']).with_format('torch')
val_ds = val_ds.remove_columns(['text']).with_format('torch')
test_ds = test_ds.remove_columns(['text']).with_format('torch')

print('Datasets tokenized')

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Datasets tokenized


## 5) Initialize model and Trainer
We use the Hugging Face `Trainer` for simplicity.

In [None]:
import os
import numpy as np
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

training_args = TrainingArguments(
    output_dir='./results',
    eval_strategy='epoch',  # called evaluation_strategy in older transformers releases
    save_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',  # 'f1' may suit the minority phishing class (~36% of train) better
    report_to="none",
)

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    acc = accuracy_score(labels, preds)
    return {'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall}

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
)

print('Trainer ready')

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-multilingual-cased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Trainer ready
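`compute_metrics` delegates to scikit-learn with `average='binary'`, which scores class 1 (phishing) as the positive class. To make the four numbers concrete, here is a NumPy-only sketch of the same computation; `binary_metrics` is a hypothetical name, and like scikit-learn's default it reports 0.0 when a ratio's denominator is zero (scikit-learn additionally emits a warning in that case):

```python
import numpy as np


def binary_metrics(logits, labels):
    """Accuracy/precision/recall/F1 with class 1 (phishing) as the positive class."""
    preds = np.argmax(np.asarray(logits), axis=-1)
    labels = np.asarray(labels)
    tp = int(np.sum((preds == 1) & (labels == 1)))  # phishing correctly flagged
    fp = int(np.sum((preds == 1) & (labels == 0)))  # legitimate wrongly flagged
    fn = int(np.sum((preds == 0) & (labels == 1)))  # phishing missed
    accuracy = float(np.mean(preds == labels))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": accuracy, "f1": f1, "precision": precision, "recall": recall}
```

Precision answers "of the messages flagged as phishing, how many were phishing?", recall answers "of the phishing messages, how many were flagged?".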


In [None]:
# 6) Train (this may take ~30-60 minutes on Colab GPU depending on dataset size)
train_result = trainer.train()
print('Training finished')
trainer.save_model('./phishing_detector_model')
tokenizer.save_pretrained('./phishing_detector_model')
print('Model and tokenizer saved to ./phishing_detector_model')

| Epoch | Training Loss | Validation Loss | Accuracy | F1 | Precision | Recall |
|---|---|---|---|---|---|---|
| 1 | 0.0171 | 0.042983 | 0.992269 | 0.989227 | 0.993699 | 0.984795 |
| 2 | 0.011 | 0.060179 | 0.991486 | 0.988103 | 0.995317 | 0.980994 |
| 3 | 0.0055 | 0.048656 | 0.993051 | 0.990326 | 0.993982 | 0.986696 |
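Because `load_best_model_at_end=True` and `metric_for_best_model='accuracy'`, the model written to `./phishing_detector_model` is the epoch with the highest validation accuracy. A small sketch that replays that selection on the values from the log above (copied by hand; `best_epoch` is a hypothetical helper). Selecting by lowest validation loss instead, which is what `Trainer` uses when no metric is given, would have picked epoch 1:

```python
# Validation metrics copied from the training log above.
val_log = [
    {"epoch": 1, "val_loss": 0.042983, "accuracy": 0.992269, "f1": 0.989227},
    {"epoch": 2, "val_loss": 0.060179, "accuracy": 0.991486, "f1": 0.988103},
    {"epoch": 3, "val_loss": 0.048656, "accuracy": 0.993051, "f1": 0.990326},
]


def best_epoch(log, metric, greater_is_better=True):
    """Return the epoch whose `metric` is best (max, or min for loss-like metrics)."""
    pick = max if greater_is_better else min
    return pick(log, key=lambda row: row[metric])["epoch"]


print("by accuracy:", best_epoch(val_log, "accuracy"))
print("by f1:", best_epoch(val_log, "f1"))
print("by val_loss:", best_epoch(val_log, "val_loss", greater_is_better=False))
```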


## 7) Convert PyTorch model to TensorFlow and then to TFLite
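We convert the saved PyTorch checkpoint to a TF SavedModel with `TFAutoModelForSequenceClassification.from_pretrained(..., from_pt=True)`, then convert that SavedModel with the TensorFlow Lite converter. The TensorFlow classes in `transformers` are deprecated and will be removed in Transformers v5 (see the warning in the output below), so this step needs a 4.x release of `transformers`.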

In [None]:
# 7.1 Convert to TensorFlow SavedModel
from transformers import TFAutoModelForSequenceClassification
import tensorflow as tf
import os

tf_model_dir = './tf_saved_model'
if os.path.exists(tf_model_dir):
    print('Removing previous TF model')
    import shutil
    shutil.rmtree(tf_model_dir)

print('Loading PyTorch checkpoint and converting to TF...')
tf_model = TFAutoModelForSequenceClassification.from_pretrained('./phishing_detector_model', from_pt=True)
tf_model.save(tf_model_dir)
print('Saved TF model to', tf_model_dir)

Loading PyTorch checkpoint and converting to TF...


TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We recommend migrating to PyTorch classes or pinning your version of Transformers.
All PyTorch model weights were used when initializing TFDistilBertForSequenceClassification.

All the weights of TFDistilBertForSequenceClassification were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertForSequenceClassification for predictions without further training.


Saved TF model to ./tf_saved_model
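The log confirms that every weight was loaded, but not that the two frameworks produce the same outputs. A minimal NumPy sketch of a parity check, assuming you already have `.logits` from each model's forward pass on the same tokenized input (converted to NumPy arrays); `check_logit_parity` is a hypothetical helper and the `1e-4` tolerance is an assumption, not a documented threshold:

```python
import numpy as np


def check_logit_parity(pt_logits, tf_logits, atol=1e-4):
    """Compare PyTorch vs TensorFlow logits for the same batch of inputs."""
    pt = np.asarray(pt_logits, dtype=np.float32)
    tf_ = np.asarray(tf_logits, dtype=np.float32)
    if pt.shape != tf_.shape:
        raise ValueError(f"Shape mismatch: {pt.shape} vs {tf_.shape}")
    max_abs_diff = float(np.max(np.abs(pt - tf_)))
    # identical predicted classes matter more than tiny numeric differences
    same_preds = bool(np.array_equal(pt.argmax(axis=-1), tf_.argmax(axis=-1)))
    return {"max_abs_diff": max_abs_diff, "same_preds": same_preds,
            "ok": max_abs_diff <= atol and same_preds}
```

Running it on a handful of real messages before converting to TFLite separates conversion problems from quantization problems later on.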


In [None]:
# 7.2 Convert SavedModel to TFLite (two variants)
# Float16 quantization: weights stored as float16, roughly half the float32 size.
# Use it if you want a smaller model and your target device supports it.
converter = tf.lite.TFLiteConverter.from_saved_model(tf_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_model = converter.convert()
with open('phishing_detector.tflite', 'wb') as fh:
    fh.write(tflite_model)
print('Float16 - Wrote phishing_detector.tflite (size MB):', round(len(tflite_model)/(1024*1024),2))

# Dynamic-range quantization: weights stored as int8, activations quantized on the fly at inference time
converter = tf.lite.TFLiteConverter.from_saved_model(tf_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_dyn = converter.convert()
with open('phishing_detector_dynamic.tflite', 'wb') as fh:
    fh.write(tflite_dyn)
print('Dynamic-range quantization - Wrote phishing_detector_dynamic.tflite (MB):', round(len(tflite_dyn)/(1024*1024), 2))

Float16 - Wrote phishing_detector.tflite (size MB): 258.27
Dynamic-range quantization - Wrote phishing_detector_dynamic.tflite (MB): 129.86
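The two file sizes can be sanity-checked with simple arithmetic: float16 stores about 2 bytes per weight and dynamic-range quantization about 1 byte. A short sketch (the `MB` printed above is really MiB, since the code divides by 1024*1024); `implied_weight_count` is a hypothetical helper:

```python
MIB = 1024 * 1024


def implied_weight_count(size_mib, bytes_per_weight):
    """Rough number of weights a file of size_mib would hold at bytes_per_weight each."""
    return size_mib * MIB / bytes_per_weight


# Sizes printed by the conversion cell above.
float16_mib, dynamic_mib = 258.27, 129.86

print(f"float16: ~{implied_weight_count(float16_mib, 2) / 1e6:.0f}M weights at 2 bytes each")
print(f"dynamic: ~{implied_weight_count(dynamic_mib, 1) / 1e6:.0f}M weights at 1 byte each")
print(f"size ratio float16 / dynamic: {float16_mib / dynamic_mib:.2f}")
```

Both sizes imply roughly 135M to 136M weights, in line with the ~134M parameters reported on the `distilbert-base-multilingual-cased` model card; the small surplus is plausibly file overhead plus tensors kept in float32. This also suggests why the int8 file ends up almost the same size as the dynamic-range one: both store the weights in 8 bits.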


### 7.3 Prepare representative calibration set (for full-int8 quantization)
Use a sample of real messages from the training CSV as the representative dataset. This cell builds `calibration_texts` by sampling up to 500 messages from `TRAIN_CSV` used earlier in the notebook.

In [None]:
# Build a representative calibration list from the training CSV (up to 500 samples)
import pandas as pd
# TRAIN_CSV was defined earlier in the notebook as 'train.csv'; it points at the file uploaded (or placed) in step 2
try:
    df = pd.read_csv(TRAIN_CSV)
except Exception:
    # fallback: try the processed path if present in runtime workspace
    df = pd.read_csv('data/processed/train.csv')
# Ensure we have a message column and drop NA
df = df.dropna(subset=['message'])
n = min(500, len(df))
# Proportional stratified sample: each label keeps roughly its share of the training set
if 'label' in df.columns:
    # convert label to string in case it's numeric
    df['label'] = df['label'].astype(str)
    # take the same fraction from every label group
    groups = []
    for _, g in df.groupby('label'):
        groups.append(g.sample(frac=min(1, n/len(df)), random_state=42))
    pooled = pd.concat(groups)
    # per-group rounding can leave the pool a row or two off n; shuffle and keep at most n rows
    sample_df = pooled.sample(n=min(n, len(pooled)), random_state=42)
else:
    sample_df = df.sample(n=n, random_state=42)
calibration_texts = sample_df['message'].astype(str).tolist()
print('Prepared calibration_texts from', len(calibration_texts), 'messages')
# show first few examples
for t in calibration_texts[:8]:
    print('-', t)

Prepared calibration_texts from 500 messages
- FREE camera phones with linerental from 4.49/month with 750 cross ntwk mins. 1/2 price txt bundle deals also avble. Call 08001950382 or call2optout/J MF
- كيف أنت؟ فقط التحقق من نفسك
- Wir versuchten, Sie zu kontaktieren re Ihre Antwort auf unser Angebot eines Video Handset? 750 jederzeit irgendwelche Netzwerke Minuten? UNLIMITED TEXT? Camcorder? Antworten oder rufen Sie 08000930705 JETZT
- You have received ₦5,000 from EKEDC. Balance: ₦5,000. Ref: EPRY0ZAGOF
- Ja, ich bin ein Mann mit einer Frau! Bitte sag mir, was du magst und was du im Bett nicht magst.
- Nlekọta Ndị Ahịa First Bank: Akaụntụ gị chọrọ mmelite. Kpọọ anyị ozugbo na 08063357276
- Hongera! Umepokea KES 20,000 kutoka Benki ya Kenya. Bonyeza https://cutt.ly/c937 kudai pesa zako
- Tu as encore de la bière chez toi?
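The cell above takes the same fraction of each label group, so the calibration set keeps the training label mix (roughly 64% legitimate, 36% phishing) rather than an equal split. Rounding each group separately can leave the pooled sample a row or two away from the target size. A pure-Python sketch of an exact proportional allocation using largest-remainder rounding, so the per-label quotas always add up to `n`; `proportional_sample` is a hypothetical helper, not part of the notebook:

```python
import random
from collections import defaultdict


def proportional_sample(labels, n, seed=42):
    """Return n row indices whose label mix matches `labels` as closely as possible."""
    by_label = defaultdict(list)
    for i, y in enumerate(labels):
        by_label[y].append(i)
    total = len(labels)
    n = min(n, total)
    # exact fractional quota per label, floored first
    quotas = {y: n * len(idx) / total for y, idx in by_label.items()}
    counts = {y: int(q) for y, q in quotas.items()}
    # hand the leftover slots to the labels with the largest fractional remainders
    leftover = n - sum(counts.values())
    for y in sorted(quotas, key=lambda y: quotas[y] - counts[y], reverse=True)[:leftover]:
        counts[y] += 1
    rng = random.Random(seed)
    picked = []
    for y, idx in by_label.items():
        picked.extend(rng.sample(idx, counts[y]))
    rng.shuffle(picked)  # mix labels so order carries no signal
    return picked
```

With the training distribution above (30497 legitimate, 17185 phishing) and `n=500`, the quotas are 319.8 and 180.2, which round to 320 and 180.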


### 7.4 Full integer (int8) TFLite conversion using the representative dataset
This cell runs full integer quantization with the calibration texts built above and writes `phishing_detector_int8.tflite`.

In [None]:
# 7.4 Full integer (int8) TFLite conversion — robust representative generator
import numpy as np
import traceback
import tensorflow as tf
from transformers import AutoTokenizer

# Load tokenizer saved earlier during training
tokenizer = AutoTokenizer.from_pretrained('./phishing_detector_model')
MAX_LEN = 128

print('TensorFlow version:', tf.__version__)
print('SavedModel path:', tf_model_dir)

# Inspect saved model signature (helps confirm input names expected by converter)
try:
    loaded = tf.saved_model.load(tf_model_dir)
    sigs = list(loaded.signatures.keys())
    print('SavedModel signatures:', sigs)
    if 'serving_default' in sigs:
        sd = loaded.signatures['serving_default']
        try:
            print('serving_default structured_input_signature:', sd.structured_input_signature)
        except Exception:
            pass
except Exception as e:
    print('Warning: could not inspect SavedModel signatures:', e)

# Representative dataset generator. Each sample is a dict keyed by input name; it carries both the bare
# signature names and the 'serving_default_*:0' tensor names so that one of them should match whatever
# names the converter's calibrator looks up for this SavedModel.

def representative_with_both():
    for t in calibration_texts:
        enc = tokenizer(t, truncation=True, padding='max_length', max_length=MAX_LEN, return_tensors='np')
        input_ids = enc['input_ids'].astype(np.int32)
        attention_mask = enc['attention_mask'].astype(np.int32)
        # include both naming variants
        yield {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'serving_default_input_ids:0': input_ids,
            'serving_default_attention_mask:0': attention_mask,
        }

# Run conversion to full integer (int8) using the representative generator
try:
    converter = tf.lite.TFLiteConverter.from_saved_model(tf_model_dir)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_with_both
    # Fail the conversion if any op cannot be expressed with int8 builtin kernels
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    # Request quantized I/O; verify the dtypes you actually get with interpreter.get_input_details()
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8
    print('Running int8 conversion (this may take a while)...')
    tflite_int8 = converter.convert()
    with open('phishing_detector_int8.tflite', 'wb') as fh:
        fh.write(tflite_int8)
    print('Wrote phishing_detector_int8.tflite (MB):', round(len(tflite_int8)/(1024*1024),2))
except Exception:
    print('Int8 conversion failed:')
    traceback.print_exc()
    # Also save the traceback to a file for later inspection
    with open('int8_conversion_error.txt', 'w') as fh:
        fh.write(traceback.format_exc())
    print('Wrote int8_conversion_error.txt with traceback')


TensorFlow version: 2.19.0
SavedModel path: ./tf_saved_model
SavedModel signatures: ['serving_default']
serving_default structured_input_signature: ((), {'attention_mask': TensorSpec(shape=(None, None), dtype=tf.int32, name='attention_mask'), 'input_ids': TensorSpec(shape=(None, None), dtype=tf.int32, name='input_ids')})
Running int8 conversion (this may take a while)...
Wrote phishing_detector_int8.tflite (MB): 129.71
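Full-integer quantization stores each int8 tensor with a `scale` and `zero_point` (per tensor, or per axis for some weights), and the real value is `(q - zero_point) * scale`; the smoke-test and packaging cells apply exactly this formula to dequantize int8 outputs. A NumPy sketch of both directions for the per-tensor case; the helper names are hypothetical, and the quantize side uses NumPy's round-half-to-even, which may differ from TFLite's kernels on exact .5 ties:

```python
import numpy as np


def quantize_int8(x, scale, zero_point):
    """Affine per-tensor quantization: q = round(x / scale) + zero_point, clipped to int8."""
    q = np.round(np.asarray(x, dtype=np.float32) / scale) + zero_point
    return np.clip(q, -128, 127).astype(np.int8)


def dequantize_int8(q, scale, zero_point):
    """Inverse mapping used when reading int8 outputs: x = (q - zero_point) * scale."""
    return (np.asarray(q).astype(np.float32) - zero_point) * scale
```

Values inside the calibrated range round-trip with an error of at most `scale / 2`; values outside it saturate at -128 or 127, which is why the representative dataset should look like real traffic.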


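### 7.5 Quick smoke-tests for the TFLite models
This cell runs two sample messages through the dynamic-range, float16 and int8 models. It resizes each input to (1, 128) to match the tokenizer output before inference, then prints the predicted class and the class probabilities.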

In [None]:
# 7.5 Robust smoke-tests: dynamic-range, float16, and int8 models
import os
import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer
import pprint

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained('./phishing_detector_model')

# Helper: robust inference using resize/allocate approach

def prepare_arrays_for_interpreter(enc, input_details):
    arrays = []
    for inp in input_details:
        name = inp['name'].lower()
        if 'input_ids' in name or ('input' in name and 'id' in name):
            arr = enc.get('input_ids')
        elif 'attention' in name and 'mask' in name:
            arr = enc.get('attention_mask')
        elif 'token_type' in name or 'segment' in name:
            arr = enc.get('token_type_ids', None)
            if arr is None:
                arr = np.zeros_like(enc['input_ids'])
        else:
            arr = enc.get('input_ids')
        if arr is None:
            raise RuntimeError(f"Could not find a suitable tensor for interpreter input '{name}'")
        # normalize dtype
        expected_dtype = inp['dtype']
        if arr.dtype != expected_dtype:
            try:
                arr = arr.astype(expected_dtype)
            except Exception:
                arr = arr.astype(np.int32)
        arrays.append(arr)
    return arrays


def safe_resize_and_allocate(interpreter, input_details, arrays):
    resized = False
    for inp, arr in zip(input_details, arrays):
        current_shape = list(inp['shape'])
        desired_shape = list(arr.shape)
        if current_shape != desired_shape:
            interpreter.resize_tensor_input(inp['index'], desired_shape, strict=False)
            resized = True
            print(f"Resized input '{inp['name']}' from {current_shape} -> {desired_shape}")
    if resized:
        interpreter.allocate_tensors()
        new_input_details = interpreter.get_input_details()
        new_output_details = interpreter.get_output_details()
        print('Re-queried input details (after resize & allocate):')
        pprint.pprint(new_input_details)
        return new_input_details, new_output_details
    else:
        try:
            interpreter.allocate_tensors()
        except Exception:
            pass
        return input_details, interpreter.get_output_details()


def run_smoke(tflite_path, samples):
    print('\nRunning smoke-test for:', tflite_path)
    interpreter = tf.lite.Interpreter(model_path=tflite_path)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    print('input_details:', input_details)
    print('output_details:', output_details)

    for text in samples:
        enc = tokenizer(text, truncation=True, padding='max_length', max_length=128, return_tensors='np')
        enc = {k: np.asarray(v) for k, v in enc.items()}
        arrays = prepare_arrays_for_interpreter(enc, input_details)
        new_input_details, new_output_details = safe_resize_and_allocate(interpreter, input_details, arrays)
        for inp, arr in zip(new_input_details, arrays):
            idx = inp['index']
            expected_shape = tuple(inp['shape'])
            if tuple(arr.shape) != expected_shape:
                try:
                    arr = arr.reshape(expected_shape)
                except Exception:
                    raise ValueError(f"Final shape mismatch for input {inp['name']}: tensor shape {arr.shape} vs expected {expected_shape}")
            interpreter.set_tensor(idx, arr)
        interpreter.invoke()
        out = interpreter.get_tensor(new_output_details[0]['index'])
        if new_output_details[0]['dtype'] == np.int8:
            scale, zero_point = new_output_details[0]['quantization']
            out = (out.astype(np.float32) - zero_point) * scale
        import scipy.special
        if out.ndim == 2 and out.shape[1] >= 2:
            probs = scipy.special.softmax(out, axis=-1)[0].tolist()
            pred = int(np.argmax(out, axis=-1)[0])
        else:
            if out.ndim == 2 and out.shape[1] == 1:
                score = 1.0 / (1.0 + np.exp(-out[0][0]))
                probs = [1 - float(score), float(score)]
                pred = int(score > 0.5)
            else:
                probs = out.flatten().tolist()
                pred = int(np.argmax(out, axis=-1)[0]) if out.size > 1 else int(out.flatten()[0] > 0.5)
        print('TEXT:', text)
        print('pred:', pred, 'probs:', probs)


# Prepare sample messages
samples = [
    "URGENT: Your account will be suspended. Click http://fake.example to verify",
    "Hey, let's meet tomorrow for lunch"
]

# Test dynamic-range (if exists)
if os.path.exists('phishing_detector_dynamic.tflite'):
    run_smoke('phishing_detector_dynamic.tflite', samples)
else:
    print('phishing_detector_dynamic.tflite not found — dynamic-range test skipped')

# Test float16
if os.path.exists('phishing_detector.tflite'):
    run_smoke('phishing_detector.tflite', samples)
else:
    print('phishing_detector.tflite not found — float16 test skipped')

# Test int8
if os.path.exists('phishing_detector_int8.tflite'):
    run_smoke('phishing_detector_int8.tflite', samples)
else:
    print('phishing_detector_int8.tflite not found — int8 test skipped')



Running smoke-test for: phishing_detector_dynamic.tflite
input_details: [{'name': 'serving_default_attention_mask:0', 'index': 0, 'shape': array([1, 1], dtype=int32), 'shape_signature': array([-1, -1], dtype=int32), 'dtype': <class 'numpy.int32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'serving_default_input_ids:0', 'index': 1, 'shape': array([1, 1], dtype=int32), 'shape_signature': array([-1, -1], dtype=int32), 'dtype': <class 'numpy.int32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]
output_details: [{'name': 'StatefulPartitionedCall:0', 'index': 719, 'shape': array([1, 2], dtype=int32), 'shape_signature': array([-1,  2], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)
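The smoke test calls `scipy.special.softmax` to turn logits into probabilities. If you want to avoid the SciPy dependency, for example when porting this logic into a lightweight inference wrapper, a NumPy-only version looks like this (`logits_to_probs` is a hypothetical name); subtracting the row maximum first keeps `exp` from overflowing on large logits without changing the result:

```python
import numpy as np


def logits_to_probs(logits):
    """Numerically stable softmax over the last (class) axis."""
    z = np.asarray(logits, dtype=np.float64)
    z = z - z.max(axis=-1, keepdims=True)  # shift so the largest logit is 0
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)
```

For a (1, 2) output, `logits_to_probs(out)[0]` gives `[p_legitimate, p_phishing]`, and `argmax` of the probabilities equals `argmax` of the raw logits.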

## 8) Package artifacts and copy to Google Drive

This cell copies the TFLite artifacts (float16, dynamic-range and int8) and the tokenizer files into a package folder, evaluates each model on the first rows of `test.csv` (if present), writes a README with file sizes and quick metrics, zips the package, and copies the archive to Drive under `phishing_detector_artifacts_v2`.

Run this cell after you have the `.tflite` files and the tokenizer saved to `./phishing_detector_model` or `./tokenizer`.

In [None]:
# Packaging & export cell: create a package, run light evaluation, and copy to Drive
import os, shutil, json, time, pathlib
from datetime import datetime
import pandas as pd
import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer
try:
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support
    SKLEARN = True
except Exception:
    SKLEARN = False

# Config
DRIVE_DST = '/content/drive/MyDrive/phishing_detector_artifacts_v2'  # change if you want a different path
LOCAL_PACKAGE = 'phishing_detector_package'
EVAL_SAMPLE = 2000  # number of test rows to sample for quick evaluation (set lower if you want faster runs)
MAX_LEN = 128

# Mount Drive (will prompt if not mounted yet)
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
os.makedirs(DRIVE_DST, exist_ok=True)
shutil.rmtree(LOCAL_PACKAGE, ignore_errors=True)
os.makedirs(LOCAL_PACKAGE, exist_ok=True)

# Gather available TFLite artifacts
candidates = [
    ('float16', 'phishing_detector.tflite'),
    ('dynamic', 'phishing_detector_dynamic.tflite'),
    ('int8', 'phishing_detector_int8.tflite'),
]
found = []
for qtype, fname in candidates:
    if os.path.exists(fname):
        size_mb = round(os.path.getsize(fname) / (1024*1024), 2)
        shutil.copy(fname, os.path.join(LOCAL_PACKAGE, fname))
        found.append({'quant': qtype, 'file': fname, 'size_mb': size_mb})

# Copy only the tokenizer files, not the PyTorch weights that share the checkpoint folder
tokenizer_src = None
for tok_dir in ('tokenizer', 'phishing_detector_model'):
    if os.path.isdir(tok_dir):
        tokenizer_src = tok_dir
        AutoTokenizer.from_pretrained(tok_dir).save_pretrained(os.path.join(LOCAL_PACKAGE, 'tokenizer'))
        break
if tokenizer_src is None:
    print('Warning: tokenizer folder not found; ensure you include tokenizer files when packaging')

# Light evaluation helper (works on a sample of test set to keep runtime short)

def run_tflite_eval(tflite_path, test_texts, test_labels, max_eval=500):
    # Build interpreter and resize inputs to (1, MAX_LEN)
    interpreter = tf.lite.Interpreter(model_path=tflite_path)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    # Resize inputs to (1, MAX_LEN)
    for inp in input_details:
        interpreter.resize_tensor_input(inp['index'], [1, MAX_LEN])
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    # Evaluate on up to max_eval examples
    preds = []
    for text in test_texts[:max_eval]:
        enc = tokenizer(text, truncation=True, padding='max_length', max_length=MAX_LEN, return_tensors='np')
        enc = {k: np.asarray(v) for k,v in enc.items()}
        # set inputs in interpreter order
        for inp in input_details:
            name = inp['name'].lower()
            if 'input_ids' in name:
                arr = enc.get('input_ids')
            elif 'attention_mask' in name:
                arr = enc.get('attention_mask')
            else:
                arr = enc.get('input_ids')
            # cast to the dtype the interpreter expects (the SavedModel signature declares int32 inputs)
            arr = arr.astype(inp['dtype'])
            interpreter.set_tensor(inp['index'], arr)
        interpreter.invoke()
        out = interpreter.get_tensor(output_details[0]['index'])
        # dequantize if int8 output
        if output_details[0]['dtype'] == np.int8:
            scale, zero_point = output_details[0]['quantization']
            out = (out.astype(np.float32) - zero_point) * scale
        # compute pred
        if out.ndim == 2 and out.shape[1] >= 2:
            pred = int(np.argmax(out, axis=-1)[0])
        else:
            # fallback: sigmoid scenario
            s = 1.0/(1.0+np.exp(-out.ravel()[0]))
            pred = int(s > 0.5)
        preds.append(pred)
    # metrics
    if len(preds) == 0:
        return None
    if SKLEARN:
        acc = accuracy_score(test_labels[:min(len(test_labels), max_eval)], preds)
        p, r, f1, _ = precision_recall_fscore_support(test_labels[:min(len(test_labels), max_eval)], preds, average='binary')
        return {'accuracy': float(acc), 'precision': float(p), 'recall': float(r), 'f1': float(f1), 'n': min(len(test_labels), max_eval)}
    else:
        # simple accuracy fallback
        true = np.array(test_labels[:min(len(test_labels), max_eval)])
        arrp = np.array(preds)
        acc = float((arrp == true).mean())
        return {'accuracy': acc, 'n': min(len(test_labels), max_eval)}

# Prepare test data (sample) if available
tokenizer = AutoTokenizer.from_pretrained(tokenizer_src) if tokenizer_src else None
test_df = None
if os.path.exists('test.csv'):
    test_df = pd.read_csv('test.csv')
elif os.path.exists('data/processed/test.csv'):
    test_df = pd.read_csv('data/processed/test.csv')

eval_results = {}
if tokenizer is None:
    print('No tokenizer found; skipping evaluation step')
elif test_df is not None and 'message' in test_df.columns and 'label' in test_df.columns:
    # ensure numeric labels 0/1
    if test_df['label'].dtype != 'int64' and test_df['label'].dtype != 'int32':
        # try map strings to ints
        test_df['label'] = test_df['label'].map({'phishing':1, 'legitimate':0}).fillna(test_df['label'])
    labels = test_df['label'].astype(int).tolist()
    texts = test_df['message'].astype(str).tolist()
    # Evaluate each found tflite file on the first EVAL_SAMPLE rows (not a random sample) to keep time reasonable
    for meta in found:
        q = meta['quant']
        fname = meta['file']
        local_path = os.path.join(LOCAL_PACKAGE, fname)
        print('\nEvaluating', fname, 'on a sample of up to', EVAL_SAMPLE, 'rows...')
        res = run_tflite_eval(local_path, texts, labels, max_eval=min(EVAL_SAMPLE, len(texts)))
        eval_results[fname] = res
else:
    print('No usable test.csv found for evaluation; skipping evaluation step')

# Create README with metadata and evaluation results
readme_path = os.path.join(LOCAL_PACKAGE, 'README.md')
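from datetime import timezone  # datetime.utcnow() is deprecated; use a timezone-aware UTC timestamp
now = datetime.now(timezone.utc).isoformat().replace('+00:00', 'Z')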
with open(readme_path, 'w') as fh:
    fh.write('# Phishing detector artifacts\n')
    fh.write('\nCreated: {}\n'.format(now))
    fh.write('\nModel: distilbert-base-multilingual-cased\n')
    fh.write('max_length: {}\n'.format(MAX_LEN))
    fh.write('input_names: serving_default_input_ids:0, serving_default_attention_mask:0\n')
    fh.write('\nFiles included:\n')
    for m in found:
        fh.write('- {file}  ({quant}, {size_mb} MB)\n'.format(**m))
    if tokenizer_src:
        fh.write('- tokenizer folder: {}\n'.format(tokenizer_src))
    fh.write('\nEvaluation results (sample):\n')
    fh.write(json.dumps(eval_results, indent=2))

# (Optional) a CSV of sample predictions could also be written to the package here; skipped for brevity

# Zip package and copy to Drive
archive_name = shutil.make_archive(LOCAL_PACKAGE, 'zip', LOCAL_PACKAGE)
shutil.copy(archive_name, DRIVE_DST)
print('Packaged artifacts ->', archive_name)
print('Copied archive to Drive at', DRIVE_DST)
print('Contents written to', LOCAL_PACKAGE)


Mounted at /content/drive

Evaluating phishing_detector.tflite on a sample of up to 2000 rows...


    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    



Evaluating phishing_detector_dynamic.tflite on a sample of up to 2000 rows...





Evaluating phishing_detector_int8.tflite on a sample of up to 2000 rows...


  now = datetime.utcnow().isoformat() + 'Z'


Packaged artifacts -> /content/phishing_detector_package.zip
Copied archive to Drive at /content/drive/MyDrive/phishing_detector_artifacts_v2
Contents written to phishing_detector_package
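Packaging the whole checkpoint folder as `tokenizer/` pulls in large files such as `model.safetensors` and `training_args.bin`: the unzip listing below shows both inside `tokenizer/`, and the package zip on Drive is 908M. A stdlib sketch for listing the largest entries in an archive before uploading it; `summarize_archive` is a hypothetical helper:

```python
import zipfile


def summarize_archive(zip_path, top=5):
    """Return (total uncompressed bytes, [(name, bytes), ...] for the `top` largest files)."""
    with zipfile.ZipFile(zip_path) as zf:
        entries = [(info.filename, info.file_size) for info in zf.infolist() if not info.is_dir()]
    entries.sort(key=lambda e: e[1], reverse=True)
    total = sum(size for _, size in entries)
    return total, entries[:top]
```

For example, `total, biggest = summarize_archive('phishing_detector_package.zip')` followed by printing each `size / 1024**2` makes an unexpected half-gigabyte entry easy to spot.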


In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

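# Note: this folder (phishing_detector_artifacts, also read by the cells below) differs from DRIVE_DST
# (phishing_detector_artifacts_v2) used by the packaging cell; adjust the path to the export you want
!ls -lah /content/drive/MyDrive/phishing_detector_artifacts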

Mounted at /content/drive
total 1.2G
drwx------ 2 root root 4.0K Oct 23 23:24 mobile
drwx------ 2 root root 4.0K Oct 22 17:33 phishing_detector_model
-rw------- 1 root root 908M Oct 23 23:27 phishing_detector_package.zip
-rw------- 1 root root 259M Oct 22 17:33 phishing_detector.tflite


In [None]:
# unzip into /content so scripts can access the files quickly
!mkdir -p /content/phishing_detector_package
!unzip -o "/content/drive/MyDrive/phishing_detector_artifacts/phishing_detector_package.zip" -d /content/phishing_detector_package || true

# list the files we need
!ls -lah /content/phishing_detector_package
!ls -lah /content/phishing_detector_package/tokenizer

Archive:  /content/drive/MyDrive/phishing_detector_artifacts/phishing_detector_package.zip
  inflating: /content/phishing_detector_package/phishing_detector_int8.tflite  
  inflating: /content/phishing_detector_package/phishing_detector_dynamic.tflite  
  inflating: /content/phishing_detector_package/phishing_detector.tflite  
  inflating: /content/phishing_detector_package/README.md  
  inflating: /content/phishing_detector_package/tokenizer/special_tokens_map.json  
  inflating: /content/phishing_detector_package/tokenizer/vocab.txt  
  inflating: /content/phishing_detector_package/tokenizer/model.safetensors  
  inflating: /content/phishing_detector_package/tokenizer/tokenizer.json  
  inflating: /content/phishing_detector_package/tokenizer/tokenizer_config.json  
  inflating: /content/phishing_detector_package/tokenizer/training_args.bin  
  inflating: /content/phishing_detector_package/tokenizer/config.json  
total 518M
drwxr-xr-x 3 root root 4.0K Oct 24 01:20 .
drwxr-xr-x 1 root 
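`unzip ... || true` swallows extraction errors. A truncated archive would therefore only surface later, as a confusing failure inside the scripts. Below is a standard-library sketch that fails loudly instead. The `REQUIRED` list is an assumption based on the archive listing above, not on what the scripts actually read. `testzip()` CRC-checks every member, which costs one full read of the archive.

```python
import zipfile
from pathlib import Path

# Files assumed to be needed, taken from the archive listing above.
REQUIRED = [
    "phishing_detector.tflite",
    "tokenizer/vocab.txt",
    "tokenizer/tokenizer.json",
    "tokenizer/tokenizer_config.json",
]


def extract_and_check(zip_path, dest, required=REQUIRED):
    """Extract zip_path into dest and return the required files that are missing."""
    dest = Path(dest)
    dest.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path) as zf:
        bad = zf.testzip()  # name of the first corrupt member, or None
        if bad is not None:
            raise zipfile.BadZipFile(f"corrupt member in archive: {bad}")
        zf.extractall(dest)
    return [name for name in required if not (dest / name).is_file()]


# missing = extract_and_check(
#     "/content/drive/MyDrive/phishing_detector_artifacts/phishing_detector_package.zip",
#     "/content/phishing_detector_package",
# )
# assert not missing, f"missing artifacts: {missing}"
```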

In [None]:
# Option A: copy the mobile folder into /content (faster IO)
!cp -r "/content/drive/MyDrive/phishing_detector_artifacts/mobile" /content/mobile || true
!ls -lah /content/mobile/inference_wrapper

total 40K
drwx------ 4 root root 4.0K Oct 24 01:03 .
drwx------ 3 root root 4.0K Oct 24 01:03 ..
-rw------- 1 root root 8.7K Oct 24 01:03 infer.py
drwx------ 2 root root 4.0K Oct 24 01:03 __pycache__
-rw------- 1 root root 1.5K Oct 24 01:03 README.md
-rw------- 1 root root  147 Oct 24 01:03 requirements.txt
drwx------ 2 root root 4.0K Oct 24 01:03 schema
-rw------- 1 root root 1.7K Oct 24 01:03 smoke_run.py
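`cp -r SRC /content/mobile` has two weaknesses. If the cell is re-run after `/content/mobile` already exists, the folder is copied *inside* it (`/content/mobile/mobile`). And `|| true` hides failures. Below is a sketch using `shutil.copytree(..., dirs_exist_ok=True)` (Python 3.8+), which merges into the destination on every run. `sync_folder` is a hypothetical name. The `__pycache__` filter is an optional extra that the original cell does not do.

```python
import shutil
from pathlib import Path


def sync_folder(src, dst):
    """Copy src into dst (merging on re-runs) and return the copied inference_wrapper dir."""
    src, dst = Path(src), Path(dst)
    if not src.is_dir():
        raise FileNotFoundError(f"source folder not found: {src}")
    shutil.copytree(
        src,
        dst,
        dirs_exist_ok=True,  # merge into dst instead of nesting or failing
        ignore=shutil.ignore_patterns("__pycache__", "*.pyc"),
    )
    return dst / "inference_wrapper"


# wrapper_dir = sync_folder("/content/drive/MyDrive/phishing_detector_artifacts/mobile", "/content/mobile")
# print(sorted(p.name for p in wrapper_dir.iterdir()))
```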


In [None]:
!pip install -q transformers tokenizers pandas jsonschema requests

In [None]:
import tensorflow as tf
print("TensorFlow:", tf.__version__)
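# Quick TFLite interpreter sanity check (don't load the large model yet).
# Use the public tf.lite.Interpreter alias rather than the private
# tensorflow.lite.python.interpreter module path, which is not a stable API.
Interpreter = tf.lite.Interpreter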
print("Interpreter OK")

TensorFlow: 2.19.0
Interpreter OK
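The evaluation output earlier points to the LiteRT interpreter in the `ai_edge_litert` package as the replacement for TensorFlow's. Below is a sketch of an import shim that prefers LiteRT when the `ai-edge-litert` package is installed and falls back to `tf.lite.Interpreter` otherwise. `load_interpreter_class` is a hypothetical name, and the model path in the usage comment is one of the files extracted above.

```python
def load_interpreter_class():
    """Return a TFLite Interpreter class, preferring LiteRT over tf.lite."""
    try:
        # LiteRT interpreter (pip install ai-edge-litert)
        from ai_edge_litert.interpreter import Interpreter
        return Interpreter
    except ImportError:
        pass
    try:
        import tensorflow as tf
        return tf.lite.Interpreter
    except ImportError as exc:
        raise ImportError("neither ai_edge_litert nor tensorflow is installed") from exc


# Interpreter = load_interpreter_class()
# interpreter = Interpreter(model_path="/content/phishing_detector_package/phishing_detector_dynamic.tflite")
# interpreter.allocate_tensors()
```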


In [None]:
from google.colab import files
uploaded = files.upload()   # choose files in the browser dialog
# Move each uploaded CSV into the folder layout the inference scripts expect:
import os
os.makedirs('/content/data/processed', exist_ok=True)
for fname in uploaded.keys():
    # for example if you uploaded test.csv
    if fname.endswith('.csv'):
        os.replace(fname, f'/content/data/processed/{fname}')
print('Files moved to /content/data/processed/')
!ls -lah /content/data/processed

Saving test.csv to test.csv
Files moved to /content/data/processed/
total 1.9M
drwxr-xr-x 2 root root 4.0K Oct 24 01:05 .
drwxr-xr-x 3 root root 4.0K Oct 24 01:05 ..
-rw-r--r-- 1 root root 1.9M Oct 24 01:05 test.csv
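Before running the scripts, it can save a round-trip to confirm that the uploaded CSV parses and has the expected header. Below is a standard-library sketch. The column names `text` and `label` are placeholders, because this section does not show the schema the scripts expect. Counting rows with `csv` rather than `wc -l` keeps a quoted multi-line message as a single row.

```python
import csv


def summarize_csv(path, required_columns):
    """Return (row_count, missing_columns) for a CSV with a header row."""
    with open(path, newline="", encoding="utf-8") as fh:
        reader = csv.DictReader(fh)
        header = reader.fieldnames or []
        missing = [c for c in required_columns if c not in header]
        rows = sum(1 for _ in reader)
    return rows, missing


# Hypothetical column names; replace them with the ones your test.csv actually uses.
# rows, missing = summarize_csv("/content/data/processed/test.csv", ["text", "label"])
# print(rows, "rows; missing columns:", missing)
```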


In [None]:
# Run a short smoke test (20 messages) to validate end-to-end
!python /content/mobile/inference_wrapper/smoke_run.py -n 20

2025-10-24 00:14:26.141804: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1761264866.161773   15667 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1761264866.167782   15667 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1761264866.183243   15667 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1761264866.183281   15667 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1761264866.183285   15667 computation_placer.cc:177] computation placer alr
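Most of the smoke-test output consists of CUDA plugin registration messages (cuFFT/cuDNN/cuBLAS factories, computation placer). These are usually harmless in Colab, but they can push the script's own output out of the truncated cell view. Below is a sketch that runs the script through `subprocess` and drops only the stderr lines matching those patterns. `run_quiet` and the `NOISE` regex are assumptions. All other stderr lines are still printed, and the exit code is returned so that failures stay visible.

```python
import re
import subprocess
import sys

# Patterns for the duplicate-registration messages seen in the output above.
NOISE = re.compile(r"Unable to register cu(FFT|DNN|BLAS) factory|computation placer already registered")


def run_quiet(cmd):
    """Run cmd, print stdout, and print only the stderr lines that don't match NOISE."""
    proc = subprocess.run(cmd, capture_output=True, text=True)
    print(proc.stdout, end="")
    kept = [line for line in proc.stderr.splitlines() if not NOISE.search(line)]
    if kept:
        print("\n".join(kept), file=sys.stderr)
    return proc.returncode, kept


# code, _ = run_quiet([sys.executable, "/content/mobile/inference_wrapper/smoke_run.py", "-n", "20"])
```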

In [None]:
!python mobile/inference_wrapper/run_on_mock.py --weights 0.825,0.85,0.875 --out ./sweep_results

2025-10-24 02:34:28.135544: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1761273268.193475   23454 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1761273268.217225   23454 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1761273268.254642   23454 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1761273268.254701   23454 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1761273268.254711   23454 computation_placer.cc:177] computation placer alr
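`run_on_mock.py` is not shown in this section, and it does not appear in the earlier `inference_wrapper` listing, so the meaning of `--weights` is not documented here. Assume the flag takes a comma-separated list of values in [0, 1] to sweep over. Under that assumption, here is a sketch of how such a flag can be validated with an `argparse` type function. `weight_list` and the [0, 1] range are assumptions, not the script's actual parser.

```python
import argparse


def weight_list(value):
    """argparse type: parse '0.825,0.85,0.875' into a list of floats in [0, 1]."""
    try:
        weights = [float(part) for part in value.split(",") if part.strip()]
    except ValueError as exc:
        raise argparse.ArgumentTypeError(f"not a comma-separated float list: {value!r}") from exc
    if not weights or any(not 0.0 <= w <= 1.0 for w in weights):
        raise argparse.ArgumentTypeError(f"weights must be in [0, 1]: {value!r}")
    return weights


parser = argparse.ArgumentParser(description="hypothetical sweep CLI")
parser.add_argument("--weights", type=weight_list, required=True)
parser.add_argument("--out", default="./sweep_results")
args = parser.parse_args(["--weights", "0.825,0.85,0.875", "--out", "./sweep_results"])
print(args.weights)  # [0.825, 0.85, 0.875]
```

Validating in the `type=` callable makes a malformed value such as `--weights 0.8,x` fail with a normal argparse usage error before any model is loaded.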