In [None]:
# Install required packages
!pip install transformers tensorflow pandas scikit-learn

# Import libraries
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score
from transformers import DistilBertTokenizer

# Mount Google Drive at /content/drive. The model and CSV paths passed to
# validate_model further down live under /content/drive/MyDrive, so this
# mount has to happen before validation runs. Requires the google.colab
# package (i.e. a Colab runtime).
from google.colab import drive
drive.mount('/content/drive')

def _resolve_input_feeds(input_details):
    """Decide which encoded array each TFLite input tensor should receive.

    Tensors are matched by name first, because the order of
    ``get_input_details()`` is not guaranteed to be ids / mask / segment.
    If any name is unrecognised or two tensors resolve to the same role,
    the positional order ``input_ids, attention_mask, segment_ids`` is used
    instead, truncated to the number of inputs the model actually has.

    Returns:
        A list of ``(tensor_index, dtype, role)`` tuples, where ``role`` is
        one of ``'input_ids'``, ``'attention_mask'`` or ``'segment_ids'``.
    """
    roles = []
    for detail in input_details:
        tensor_name = str(detail['name']).lower()
        # Order matters: "segment_ids" / "input_type_ids" also contain "ids".
        if 'mask' in tensor_name:
            roles.append('attention_mask')
        elif 'segment' in tensor_name or 'type_ids' in tensor_name:
            roles.append('segment_ids')
        elif 'ids' in tensor_name:
            roles.append('input_ids')
        else:
            roles.append(None)

    if None in roles or len(set(roles)) != len(roles):
        # Names were not conclusive; keep the historical positional layout.
        roles = ['input_ids', 'attention_mask', 'segment_ids'][:len(input_details)]

    return [
        (detail['index'], detail['dtype'], role)
        for detail, role in zip(input_details, roles)
    ]


def _print_confusion_matrix(cm, class_names):
    """Print ``cm`` as a table whose rows/columns are labelled ``class_names``.

    ``class_names`` must be listed in the same order as the matrix axes.
    """
    print("\n" + " " * 9 + "".join(f"{name:>9}" for name in class_names))
    for name, row in zip(class_names, cm):
        print(f"{name:<9}" + "".join(f"{int(count):>9d}" for count in row))


def validate_model(model_path, csv_path, text_column, label_column, max_length):
    """
    Validate a TFLite model for contact information classification.

    Reads a labelled CSV, tokenizes every text with the
    ``distilbert-base-uncased`` tokenizer, runs each sample through the
    TFLite interpreter one at a time, and prints accuracy, macro F1 and a
    confusion matrix.

    Args:
        model_path: Path to the ``.tflite`` model file.
        csv_path: Path to the CSV file holding the validation samples.
        text_column: Name of the CSV column containing the input text.
        label_column: Name of the CSV column containing the labels. Labels
            are matched case-insensitively (and ignoring surrounding
            whitespace) against ``address``, ``phone``, ``email``, ``url``.
        max_length: Sequence length every sample is padded/truncated to.

    Returns:
        dict with keys ``'accuracy'``, ``'f1_score'`` (macro-averaged) and
        ``'confusion_matrix'`` (always 4x4, rows = actual, columns =
        predicted, ordered address, phone, email, url).

    Raises:
        KeyError: If the label column contains a value outside the four
            known classes.
    """
    # Class ids as encoded in the labels; the display names are listed in
    # the same id order so the printed matrix lines up with the data.
    label_map = {'address': 0, 'phone': 1, 'email': 2, 'url': 3}
    class_ids = [0, 1, 2, 3]
    class_names = ['Address', 'Phone', 'Email', 'URL']

    # Load dataset
    df = pd.read_csv(csv_path)
    texts = df[text_column].values

    # Map labels to integers, reporting every unknown label at once instead
    # of failing on the first one with a bare KeyError.
    normalized = [str(label).strip().lower() for label in df[label_column].values]
    unknown = sorted(set(normalized) - set(label_map))
    if unknown:
        raise KeyError(
            f"Unknown labels in column {label_column!r}: {unknown}; "
            f"expected one of {sorted(label_map)}"
        )
    y_test = np.array([label_map[label] for label in normalized])

    print(f"Loaded {len(texts)} samples from CSV")

    # Load tokenizer
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    print("DistilBERT tokenizer loaded")

    # Load TFLite model
    print("Loading TFLite model...")
    interpreter = tf.lite.Interpreter(model_path=model_path)
    print("Allocating tensors...")
    interpreter.allocate_tensors()
    print("Model ready")

    # Get input and output details
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    print(f"\nModel expects {len(input_details)} inputs:")
    for detail in input_details:
        print(f"  - {detail['name']}: shape {detail['shape']}")

    # Resolve the tensor wiring once; it does not change between samples.
    feeds = _resolve_input_feeds(input_details)
    output_index = output_details[0]['index']

    # Run inference on all samples
    y_pred = []
    for i, text in enumerate(texts):
        if i % 100 == 0 and i > 0:
            print(f"Processed {i}/{len(texts)} samples...")

        # Tokenize input
        encoded = tokenizer(text,
                            max_length=max_length,
                            padding='max_length',
                            truncation=True,
                            return_tensors='np')

        arrays = {
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask'],
            # The tokenizer output is not asked for segment ids here, so an
            # all-zero array is fed to models that expose that input.
            'segment_ids': np.zeros_like(encoded['input_ids']),
        }

        # Set input tensors, cast to the dtype each tensor declares.
        for tensor_index, dtype, role in feeds:
            interpreter.set_tensor(tensor_index, arrays[role].astype(dtype))

        # Run inference
        interpreter.invoke()

        # Get prediction
        output = interpreter.get_tensor(output_index)
        y_pred.append(np.argmax(output[0]))

    y_pred = np.array(y_pred)

    # Calculate metrics. `labels=` pins the matrix to 4x4 even when a class
    # never appears in the data or the predictions.
    accuracy = accuracy_score(y_test, y_pred)
    f1_macro = f1_score(y_test, y_pred, average='macro')
    cm = confusion_matrix(y_test, y_pred, labels=class_ids)

    # Print results
    print("\n" + "="*60)
    print("VALIDATION RESULTS")
    print("="*60)
    print(f"\nOverall Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"F1 Score (Macro): {f1_macro:.4f}")
    print("\n" + "-"*60)
    print("Confusion Matrix:")
    print("-"*60)
    print("Rows = Actual | Columns = Predicted")
    _print_confusion_matrix(cm, class_names)

    print("="*60)

    return {
        'accuracy': accuracy,
        'f1_score': f1_macro,
        'confusion_matrix': cm
    }

# Kick off validation against the files stored on Google Drive.
# Both paths below are placeholders: point them at wherever the model and
# the CSV were uploaded in your own Drive.
results = validate_model(
    # Update this path
    model_path="/content/drive/MyDrive/distilbert_with_metaadata.tflite",
    # Update this path
    csv_path="/content/drive/MyDrive/validation_contacts_dataset.csv",
    text_column='text',
    label_column='label',
    max_length=128,
)

# Final summary, emitted as one newline-separated print.
print(
    "\n" + "="*60,
    "VALIDATION COMPLETE",
    "="*60,
    f"Final Accuracy: {results['accuracy']:.4f}",
    f"Final F1 Score: {results['f1_score']:.4f}",
    sep="\n",
)