In [None]:
# 🚀 Financial Email/SMS Classification with DistilBERT

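This notebook provides a step-by-step guide to train and test a DistilBERT model for classifying emails and SMS as financial or non-financial. The dataset list below describes the initial email-logs variant of the pipeline. Later cells redefine `load_datasets` and `prepare_data_for_training` around `mail_data.csv` (see the second overview below), and that later variant is the one the model is actually trained on.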

## 📊 Datasets Used:
1. `genai_gmail_chat.financial_transactions.csv`: Financial transaction data (1140 records)
2. `pluto_money.sms_data.csv`: SMS data (83 records)
3. `pluto_money.email_logs.csv`: Email logs (150MB)
4. `krishplutomoney all emails gmail_data...csv`: Additional email data (117 records)

## 🎯 Goals:
1. Classify messages as financial/non-financial
2. Extract structured financial data
3. Achieve >95% accuracy
4. Store results in MongoDB


In [None]:
# 🚀 Financial Email/SMS Classification with DistilBERT

This notebook provides a step-by-step guide to train and test a DistilBERT model for classifying emails and SMS as financial or non-financial.

## 📊 Datasets Used:
1. `mail_data.csv`: Base dataset with spam/ham labels
2. `genai_gmail_chat.financial_transactions.csv`: Financial transaction data
3. `pluto_money.sms_data.csv`: SMS data
4. `pluto_money.email_logs.csv`: Email logs
5. `krishplutomoney all emails gmail_data...csv`: Additional email data

## 🎯 Goals:
1. Classify messages as financial/non-financial
2. Extract structured financial data
3. Achieve >95% accuracy
4. Store results in MongoDB


In [None]:
## 1. Setup and Imports


In [None]:
# Add current directory to path
import inspect
import os
import random
import sys
sys.path.append('.')

# Essential imports
import pandas as pd
import numpy as np
import torch
from transformers import (
    DistilBertTokenizer,
    DistilBertForSequenceClassification,
    TrainingArguments,
    Trainer
)
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm

# Local imports
from config import model_config, data_config
from data_preprocessing import TextPreprocessor, DatasetPreparator

# Seed the Python, NumPy and PyTorch RNGs so Restart & Run All is
# reproducible: `initialize_model` builds the model (including its randomly
# initialised classification head) before the Trainer is ever created.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Check CUDA: use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


In [None]:
# Load datasets
def load_datasets(data_dir='../datasets'):
    """Read the four raw CSV exports used by the email-logs variant of the pipeline.

    NOTE: a later cell redefines ``load_datasets`` (the mail_data variant) and
    shadows this definition; only that later variant feeds the trained model.

    Args:
        data_dir: Directory holding the CSV files. Defaults to the notebook's
            original ``'../datasets'`` location.

    Returns:
        Tuple ``(financial_data, sms_data, email_logs, additional_emails)`` of
        DataFrames, in that order.
    """
    def read(filename):
        return pd.read_csv(os.path.join(data_dir, filename))

    # Financial transactions (these are already labeled as financial)
    financial_data = read('genai_gmail_chat.financial_transactions.csv')
    sms_data = read('pluto_money.sms_data.csv')
    # The largest export (~150MB per the notebook overview), so the slowest read here.
    email_logs = read('pluto_money.email_logs.csv')
    additional_emails = read(
        'krishplutomoney all emails gmail_data_117454877979500520700_20250630_012957.csv'
    )

    return financial_data, sms_data, email_logs, additional_emails

# Pull every raw source into memory in a single call.
financial_data, sms_data, email_logs, additional_emails = load_datasets()

print("Dataset sizes:")
print("\n".join(
    f"{label}: {len(frame)} records"
    for label, frame in (
        ("Financial transactions", financial_data),
        ("SMS data", sms_data),
        ("Email logs", email_logs),
        ("Additional emails", additional_emails),
    )
))


In [None]:
## 2. Load and Prepare Datasets


In [None]:
# Initialize preprocessor
preprocessor = TextPreprocessor()

def prepare_data_for_training(financial_data, sms_data, email_logs, additional_emails):
    """Label, merge, clean, de-duplicate and split the email-logs data variant.

    Labelling:
      * every ``financial_data`` row is labelled 1 (financial);
      * SMS, email-log and additional-email rows are labelled 1 when
        ``preprocessor.extract_financial_features(text)['has_financial_indicators']``
        is truthy, else 0.
    Because most labels come from that extractor applied to the very text the
    model is trained on, evaluation scores measure agreement with the
    extractor rather than with human annotation.

    The input DataFrames are not modified.

    Args:
        financial_data: Frame with a ``snippet`` column.
        sms_data: Frame with a ``message`` column.
        email_logs: Frame with ``subject`` and ``body`` columns.
        additional_emails: Frame with a ``snippet`` column.

    Returns:
        ``(train_df, val_df, test_df)``: a stratified 70/15/15 split, each with
        ``text`` and ``is_financial`` columns.
    """
    def as_text(column):
        # Missing cells become '' so concatenation and the preprocessor never see NaN.
        return column.fillna('').astype(str)

    def heuristic_label(text):
        features = preprocessor.extract_financial_features(text)
        return 1 if features['has_financial_indicators'] else 0

    def labelled(text):
        # Fresh (text, is_financial) frame; the int cast keeps the label dtype
        # numeric even when `text` is empty.
        return pd.DataFrame({
            'text': text,
            'is_financial': text.map(heuristic_label).astype(int),
        })

    email_text = as_text(email_logs['subject']) + ' ' + as_text(email_logs['body'])

    # Order matters: drop_duplicates below keeps the first occurrence.
    combined_data = pd.concat([
        pd.DataFrame({'text': as_text(financial_data['snippet']), 'is_financial': 1}),
        labelled(as_text(sms_data['message'])),
        labelled(email_text),
        labelled(as_text(additional_emails['snippet'])),
    ], ignore_index=True)

    combined_data['text'] = combined_data['text'].apply(preprocessor.clean_text)
    # Rows that clean down to nothing carry no signal; dropping them also stops
    # drop_duplicates from keeping one arbitrarily-labelled empty string.
    combined_data = combined_data[combined_data['text'].astype(str).str.strip() != '']
    combined_data = combined_data.drop_duplicates(subset=['text'])

    # 70% train; the 30% hold-out is then halved into validation and test.
    train_df, temp_df = train_test_split(
        combined_data,
        test_size=0.3,
        random_state=42,
        stratify=combined_data['is_financial'],
    )
    val_df, test_df = train_test_split(
        temp_df,
        test_size=0.5,
        random_state=42,
        stratify=temp_df['is_financial'],
    )

    return train_df, val_df, test_df

# Build the train / validation / test splits from the email-logs variant.
train_df, val_df, test_df = prepare_data_for_training(
    financial_data, sms_data, email_logs, additional_emails
)

print("\nDataset splits:")
print("\n".join([
    f"Training set: {len(train_df)} samples",
    f"Validation set: {len(val_df)} samples",
    f"Test set: {len(test_df)} samples",
]))

print("\nClass distribution:")
print("Training set:")
print(train_df['is_financial'].value_counts(normalize=True))


In [None]:
# Load datasets
def load_datasets(data_dir='../datasets'):
    """Read the four raw CSV exports used by the mail_data variant of the pipeline.

    This definition replaces the earlier email-logs ``load_datasets`` cell;
    its output is what the model is trained on.

    Args:
        data_dir: Directory holding the CSV files. Defaults to the notebook's
            original ``'../datasets'`` location.

    Returns:
        Tuple ``(mail_data, financial_data, sms_data, additional_emails)`` of
        DataFrames, in that order.
    """
    def read(filename):
        return pd.read_csv(os.path.join(data_dir, filename))

    # mail_data.csv carries spam/ham labels, not financial ones.
    mail_data = read('mail_data.csv')
    financial_data = read('genai_gmail_chat.financial_transactions.csv')
    sms_data = read('pluto_money.sms_data.csv')
    additional_emails = read(
        'krishplutomoney all emails gmail_data_117454877979500520700_20250630_012957.csv'
    )

    return mail_data, financial_data, sms_data, additional_emails

# Pull every raw source of the mail_data variant into memory.
mail_data, financial_data, sms_data, additional_emails = load_datasets()

print("Dataset sizes:")
print("\n".join(
    f"{label}: {len(frame)} records"
    for label, frame in (
        ("Mail data", mail_data),
        ("Financial transactions", financial_data),
        ("SMS data", sms_data),
        ("Additional emails", additional_emails),
    )
))


In [None]:
## 3. Data Preprocessing


In [None]:
# Initialize preprocessor
preprocessor = TextPreprocessor()

def prepare_data_for_training(mail_data, financial_data, sms_data, additional_emails):
    """Label, merge, clean, de-duplicate and split the mail_data variant.

    Labelling:
      * every ``financial_data`` row is labelled 1 (financial);
      * ``mail_data`` (``Message``), SMS and additional-email rows are
        labelled 1 when
        ``preprocessor.extract_financial_features(text)['has_financial_indicators']``
        is truthy for the message text, else 0.
    Because most labels come from that extractor applied to the very text the
    model is trained on, evaluation scores measure agreement with the
    extractor rather than with human annotation.

    The input DataFrames are not modified.

    Args:
        mail_data: Frame with ``Message`` (text) and ``Category`` (spam/ham) columns.
        financial_data: Frame with a ``snippet`` column.
        sms_data: Frame with a ``message`` column.
        additional_emails: Frame with a ``snippet`` column.

    Returns:
        ``(train_df, val_df, test_df)``: a stratified 70/15/15 split, each with
        ``text`` and ``is_financial`` columns.
    """
    def as_text(column):
        # Missing cells become '' so the preprocessor never sees NaN.
        return column.fillna('').astype(str)

    def heuristic_label(text):
        features = preprocessor.extract_financial_features(text)
        return 1 if features['has_financial_indicators'] else 0

    def labelled(text):
        # Fresh (text, is_financial) frame; the int cast keeps the label dtype
        # numeric even when `text` is empty.
        return pd.DataFrame({
            'text': text,
            'is_financial': text.map(heuristic_label).astype(int),
        })

    # Order matters: drop_duplicates below keeps the first occurrence.
    combined_data = pd.concat([
        # Label mail_data by its message content. The spam/ham `Category`
        # value says nothing about whether a message is financial.
        labelled(as_text(mail_data['Message'])),
        pd.DataFrame({'text': as_text(financial_data['snippet']), 'is_financial': 1}),
        labelled(as_text(sms_data['message'])),
        labelled(as_text(additional_emails['snippet'])),
    ], ignore_index=True)

    combined_data['text'] = combined_data['text'].apply(preprocessor.clean_text)
    # Rows that clean down to nothing carry no signal; dropping them also stops
    # drop_duplicates from keeping one arbitrarily-labelled empty string.
    combined_data = combined_data[combined_data['text'].astype(str).str.strip() != '']
    combined_data = combined_data.drop_duplicates(subset=['text'])

    # 70% train; the 30% hold-out is then halved into validation and test.
    train_df, temp_df = train_test_split(
        combined_data,
        test_size=0.3,
        random_state=42,
        stratify=combined_data['is_financial'],
    )
    val_df, test_df = train_test_split(
        temp_df,
        test_size=0.5,
        random_state=42,
        stratify=temp_df['is_financial'],
    )

    return train_df, val_df, test_df

# Build the train / validation / test splits actually used for training.
train_df, val_df, test_df = prepare_data_for_training(
    mail_data, financial_data, sms_data, additional_emails
)

print("\nDataset splits:")
print("\n".join([
    f"Training set: {len(train_df)} samples",
    f"Validation set: {len(val_df)} samples",
    f"Test set: {len(test_df)} samples",
]))

print("\nClass distribution:")
print("Training set:")
print(train_df['is_financial'].value_counts(normalize=True))


In [None]:
## 4. Initialize DistilBERT Model


In [None]:
def initialize_model(model_name='distilbert-base-uncased'):
    """Load the pretrained tokenizer and a 2-way sequence-classification model.

    Args:
        model_name: Hugging Face model id or local path. The same checkpoint is
            used for tokenizer and model so their vocabularies match. Defaults
            to ``'distilbert-base-uncased'``.

    Returns:
        ``(tokenizer, model)``, with the model moved to the global ``device``.
    """
    tokenizer = DistilBertTokenizer.from_pretrained(model_name)

    # The pretrained encoder is reused; the classification head on top is
    # what fine-tuning trains.
    model = DistilBertForSequenceClassification.from_pretrained(
        model_name,
        num_labels=2,  # 0 = non-financial, 1 = financial
        problem_type="single_label_classification",
    )

    model.to(device)

    return tokenizer, model

tokenizer, model = initialize_model()

# Total (trainable + frozen) parameter count, thousands-separated.
print("Model parameters: {:,}".format(sum(param.numel() for param in model.parameters())))


In [None]:
## 5. Prepare Training Data


In [None]:
class _EncodedTextDataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-tokenized, equal-length tensors.

    * ``ds[i]`` (int) returns one example as
      ``{'input_ids', 'attention_mask', 'labels'}``, which is what the
      Trainer's DataLoader and collator consume.
    * ``ds['input_ids']`` (str) returns the whole column tensor, so callers
      can still inspect e.g. ``ds['input_ids'].shape``.
    """

    def __init__(self, columns):
        self._columns = columns

    def __len__(self):
        return len(self._columns['labels'])

    def __getitem__(self, key):
        if isinstance(key, str):
            return self._columns[key]
        return {name: values[key] for name, values in self._columns.items()}

    def keys(self):
        return self._columns.keys()


def prepare_dataset(df, tokenizer, max_length=None):
    """Tokenize ``df['text']`` and pair it with the ``df['is_financial']`` labels.

    A plain dict of tensors cannot be handed to ``Trainer``: ``len()`` would
    count the dict's keys and integer indexing raises ``KeyError``. The result
    is therefore a map-style ``Dataset`` that also supports column access.

    Args:
        df: DataFrame with ``text`` and ``is_financial`` (0/1) columns.
        tokenizer: Hugging Face-style tokenizer callable.
        max_length: Truncation length; defaults to ``model_config.max_length``.

    Returns:
        Dataset supporting ``len()``, integer indexing and ``ds['input_ids']``.
    """
    if max_length is None:
        max_length = model_config.max_length

    # padding=True pads every row to the longest text in this split.
    encodings = tokenizer(
        df['text'].tolist(),
        truncation=True,
        padding=True,
        max_length=max_length,
        return_tensors='pt',
    )

    return _EncodedTextDataset({
        'input_ids': encodings['input_ids'],
        'attention_mask': encodings['attention_mask'],
        'labels': torch.tensor(df['is_financial'].tolist(), dtype=torch.long),
    })

# Tokenize every split with the same tokenizer.
train_dataset, val_dataset, test_dataset = (
    prepare_dataset(split_df, tokenizer) for split_df in (train_df, val_df, test_df)
)

print("Dataset shapes:")
print("\n".join(
    f"{name}: {ds['input_ids'].shape}"
    for name, ds in (
        ("Training", train_dataset),
        ("Validation", val_dataset),
        ("Test", test_dataset),
    )
))


In [None]:
## 6. Training Configuration


In [None]:
def get_training_args():
    """Build the ``TrainingArguments`` used for fine-tuning.

    All hyper-parameters come from ``model_config``. Evaluation and
    checkpointing run on a step schedule (``eval_steps`` / ``save_steps``).
    The checkpoint with the best validation ``f1`` (a key produced by
    ``compute_metrics``) is restored when training ends.

    NOTE: with ``load_best_model_at_end=True`` transformers requires
    ``save_steps`` to be a round multiple of ``eval_steps``. Make sure
    ``model_config`` satisfies that.

    Returns:
        A configured ``transformers.TrainingArguments``.
    """
    # Newer transformers releases renamed ``evaluation_strategy`` to
    # ``eval_strategy`` and later dropped the old spelling. Use whichever
    # keyword the installed version accepts.
    accepted = inspect.signature(TrainingArguments.__init__).parameters
    strategy_kw = 'eval_strategy' if 'eval_strategy' in accepted else 'evaluation_strategy'

    return TrainingArguments(
        output_dir='./results',
        num_train_epochs=model_config.num_epochs,
        # Must be passed explicitly; otherwise the library default learning
        # rate is used no matter what model_config says.
        learning_rate=model_config.learning_rate,
        per_device_train_batch_size=model_config.batch_size,
        per_device_eval_batch_size=model_config.batch_size * 2,
        warmup_steps=model_config.warmup_steps,
        weight_decay=model_config.weight_decay,
        logging_dir='./logs',
        logging_steps=model_config.logging_steps,
        eval_steps=model_config.eval_steps,
        save_steps=model_config.save_steps,
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        # Mixed precision is only requested when a GPU is present.
        fp16=model_config.fp16 and torch.cuda.is_available(),
        gradient_accumulation_steps=model_config.gradient_accumulation_steps,
        **{strategy_kw: "steps"},
    )

training_args = get_training_args()

print("Training configuration:")
print("\n".join((
    f"Epochs: {model_config.num_epochs}",
    f"Batch size: {model_config.batch_size}",
    f"Learning rate: {model_config.learning_rate}",
    f"FP16: {model_config.fp16 and torch.cuda.is_available()}",
)))


In [None]:
## 7. Model Training


In [None]:
def compute_metrics(pred):
    """Turn a Trainer evaluation prediction into scalar metrics.

    The predicted class is the argmax of ``pred.predictions`` over its last
    axis (the model's logits). Precision, recall and F1 are the
    support-weighted averages from sklearn's ``classification_report``.

    Returns:
        dict with ``accuracy``, ``f1``, ``precision`` and ``recall``;
        ``f1`` is the key used for best-model selection.
    """
    predicted = pred.predictions.argmax(axis=-1)
    report = classification_report(pred.label_ids, predicted, output_dict=True)
    weighted = report['weighted avg']

    return {
        'accuracy': report['accuracy'],
        'f1': weighted['f1-score'],
        'precision': weighted['precision'],
        'recall': weighted['recall'],
    }

# Wire the model, tokenized splits, hyper-parameters and metric function together.
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

print("Starting training...")
trainer.train()


In [None]:
## 8. Model Evaluation


In [None]:
def evaluate_model(trainer, test_dataset):
    """Evaluate on the held-out split and print the headline metrics.

    Reads the ``eval_``-prefixed keys of the dict returned by
    ``trainer.evaluate``.

    Returns:
        The full metrics dict from ``trainer.evaluate``.
    """
    print("Evaluating model on test set...")
    metrics = trainer.evaluate(test_dataset)

    print("\nTest Results:")
    for title, key in (
        ("Accuracy", 'eval_accuracy'),
        ("F1 Score", 'eval_f1'),
        ("Precision", 'eval_precision'),
        ("Recall", 'eval_recall'),
    ):
        print(f"{title}: {metrics[key]:.4f}")

    return metrics

test_results = evaluate_model(trainer=trainer, test_dataset=test_dataset)  # metrics are saved with the model later


In [None]:
## 9. Confusion Matrix


In [None]:
def plot_confusion_matrix(trainer, test_dataset, test_df):
    """Plot the test-set confusion matrix and print its four cells.

    Args:
        trainer: Trained Trainer used to predict ``test_dataset``.
        test_dataset: Tokenized test split, built from ``test_df`` in the same
            row order (so predictions line up with ``test_df`` rows).
        test_df: Test DataFrame whose ``is_financial`` column is the ground truth.
    """
    predictions = trainer.predict(test_dataset)
    preds = predictions.predictions.argmax(-1)

    # Pin labels to [0, 1] so the matrix is always 2x2. Without this, a split in
    # which only one class appears gives a 1x1 matrix and the 4-way unpacking
    # of cm.ravel() below fails.
    cm = confusion_matrix(test_df['is_financial'], preds, labels=[0, 1])

    class_names = ['Non-Financial', 'Financial']
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
        ax=ax,
    )
    ax.set(title='Confusion Matrix', xlabel='Predicted Label', ylabel='True Label')
    plt.show()

    # Row-major order of a 2x2 sklearn confusion matrix: TN, FP, FN, TP.
    tn, fp, fn, tp = cm.ravel()
    print("\nDetailed Metrics:")
    print(f"True Negatives: {tn}")
    print(f"False Positives: {fp}")
    print(f"False Negatives: {fn}")
    print(f"True Positives: {tp}")

plot_confusion_matrix(trainer, test_dataset, test_df)


In [None]:
## 10. Sample Predictions


In [None]:
def test_predictions(model, tokenizer, texts):
    model.eval()
    results = []
    
    for text in texts:
        # Tokenize
        inputs = tokenizer(
            text,
            truncation=True,
            padding=True,
            max_length=model_config.max_length,
            return_tensors='pt'
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        # Predict
        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.softmax(outputs.logits, dim=-1)
            prediction = torch.argmax(outputs.logits, dim=-1)
        
        results.append({
            'text': text,
            'is_financial': bool(prediction.item()),
            'confidence': probs[0][prediction.item()].item()
        })
    
    return results

# Test examples
test_texts = [
    "Your account has been credited with Rs. 5000",
    "Meeting scheduled for tomorrow at 3 PM",
    "UPI payment of Rs. 2500 to Amazon completed",
    "Happy birthday! Hope you have a great day",
    "Your mutual fund investment of Rs. 10000 has been processed",
    "Please review the attached document",
    "Credit card payment due: Rs. 15000 by 15th",
    "Weather forecast for today: Sunny with clear skies"
]

results = test_predictions(model, tokenizer, test_texts)

print("Sample Predictions:")
print("-" * 80)
for result in results:
    print(f"Text: {result['text']}")
    print(f"Prediction: {'FINANCIAL' if result['is_financial'] else 'NON-FINANCIAL'}")
    print(f"Confidence: {result['confidence']:.3f}")
    print("-" * 80)


In [None]:
## 11. Save Model


In [None]:
def save_model(model, tokenizer, test_results, output_dir='models/distilbert'):
    """Persist the fine-tuned model, its tokenizer and the test metrics.

    Args:
        model: Object exposing ``save_pretrained(path)`` (the fine-tuned model).
        tokenizer: Object exposing ``save_pretrained(path)``.
        test_results: JSON-serialisable metrics dict (e.g. from ``evaluate_model``).
        output_dir: Target directory, created if missing. Defaults to the
            notebook's original ``'models/distilbert'`` location.
    """
    import json

    os.makedirs(output_dir, exist_ok=True)

    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    # Store the metrics next to the weights so a saved model documents its own score.
    with open(os.path.join(output_dir, 'test_results.json'), 'w') as f:
        json.dump(test_results, f, indent=2)

    print(f"Model and results saved to '{output_dir}'")

save_model(model=model, tokenizer=tokenizer, test_results=test_results)  # weights, tokenizer and test metrics


## 12. Next Steps

1. **Fine-tuning**: Experiment with different hyperparameters
2. **Data Augmentation**: Add more financial examples
3. **Error Analysis**: Review misclassified examples
4. **Production**: Deploy model with MongoDB integration
5. **Monitoring**: Set up performance tracking

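The model is now ready for:
- Integration with your Agno framework
- Deployment to production
- Real-time classification of emails and SMS

**Caveat:** except for the financial-transactions export (labelled financial wholesale), the `is_financial` labels are produced by `TextPreprocessor.extract_financial_features` rather than by human annotation. The reported accuracy therefore measures agreement with that heuristic. Validate on a hand-labelled sample before relying on the model in production.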
