In [7]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score
from src.pipeline import TextAnalysisPipeline  
from ax import optimize

## Data:

In [15]:
# Load the labelled examples; the path is resolved relative to the working directory.
data = pd.read_csv('spam_dataset_examples.csv')


def _sample_half(group):
    """Return a reproducible 50% random sample of the rows of one label group."""
    return group.sample(frac=0.5, random_state=42)


# Stratified split: draw half of every label group for validation.
# After ``apply`` the result is indexed by (label, original row index);
# ``reset_index`` turns both levels into columns -- the unnamed row-index level
# gets the automatic name "level_1" -- and ``set_index`` restores the original
# row index so it can be matched against ``data`` below.
val_df = (
    data.groupby("label")
    .apply(_sample_half, include_groups=False)
    .reset_index()
    .set_index("level_1")
)

# Every row that was not drawn for validation goes to the test set.
test_df = data.drop(index=val_df.index)

## Auxiliary functions:

In [None]:
def optimize_model(weights, spam_threshold):
    """Score one decision-engine configuration by its F1 on the validation set.

    Builds a ``TextAnalysisPipeline`` from the given weights and threshold,
    requests one spam decision per row of the notebook-level ``val_df`` and
    compares those decisions with ``val_df['label']``.

    NOTE: the per-analyzer results passed to the pipeline are still hard-coded
    placeholders (see the TODO below), so the row content does not influence
    the prediction yet.

    Args:
        weights: Value stored under ``cfg["decision_engine"]["weights"]``.
        spam_threshold: Value stored under
            ``cfg["decision_engine"]["spam_threshold"]``.

    Returns:
        The F1 score of the predictions against the validation labels.
    """
    # Initialize the Decision Engine with the given weights and threshold
    cfg = {
        "decision_engine": {
            "weights": weights,
            "spam_threshold": spam_threshold
        }
    }
    decision_engine = TextAnalysisPipeline(cfg)

    # Evaluate on the validation set.
    # NOTE(review): assumes the labels are binary and use the same encoding as
    # result['is_spam'] -- confirm against the dataset.
    y_true = val_df['label'].values
    y_pred = []

    # TODO: predict over dataframe text
    for _index, _row in val_df.iterrows():
        # Placeholder analyzer outputs; built fresh for every row so that each
        # call receives its own dict objects.
        sentiment_result = {"spam_likelihood": 0.5, "reasoning": "Example reasoning"}
        grammar_result = {"spam_likelihood": 0.5, "reasoning": "Example reasoning"}
        url_result = {"spam_likelihood": 0.5, "reasoning": "Example reasoning"}
        domain_result = {"spam_likelihood": 0.5, "reasoning": "Example reasoning"}

        result = decision_engine.make_decision(sentiment_result, grammar_result, url_result, domain_result)
        y_pred.append(result['is_spam'])

    # Only F1 is returned, so precision and recall are not computed separately.
    # zero_division=0 yields the same 0.0 score as the default when nothing is
    # predicted as spam, without emitting a warning on every such trial.
    return f1_score(y_true, y_pred, zero_division=0)

In [None]:
def _evaluate_parameterization(parameterization):
    """Adapt Ax's flat parameter dict to ``optimize_model(weights, spam_threshold)``.

    Ax passes every trial as a single dict keyed by parameter name, e.g.
    ``{"weights[sentiment]": 0.3, ..., "spam_threshold": 0.7}``. The
    ``weights[...]`` entries are regrouped into one dict keyed by the text
    inside the brackets.
    NOTE(review): assumes the pipeline expects weights keyed by
    "sentiment", "grammar", "url" and "domain" -- confirm against
    ``TextAnalysisPipeline``.
    """
    prefix = "weights["
    weights = {
        name[len(prefix):-1]: value
        for name, value in parameterization.items()
        if name.startswith(prefix) and name.endswith("]")
    }
    return optimize_model(weights, parameterization["spam_threshold"])


# ``optimize`` returns (best parameters, predicted values, experiment, model).
best_parameters, best_values, experiment, model = optimize(
    parameters=[
        # Float bounds: Ax infers the parameter type from the bounds, and integer
        # bounds would restrict every weight and the threshold to 0 or 1.
        {"name": "weights[sentiment]", "type": "range", "bounds": [0.0, 1.0]},
        {"name": "weights[grammar]", "type": "range", "bounds": [0.0, 1.0]},
        {"name": "weights[url]", "type": "range", "bounds": [0.0, 1.0]},
        {"name": "weights[domain]", "type": "range", "bounds": [0.0, 1.0]},
        {"name": "spam_threshold", "type": "range", "bounds": [0.0, 1.0]}
    ],
    evaluation_function=_evaluate_parameterization,
    minimize=False,  # F1 is maximized
    total_trials=50,
    random_seed=42,  # same seed as the data split, for repeatable initial trials
)