In [2]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pprint import pprint

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from auxiliar_func import *
from plot_func import *

In [3]:
# Name of the label column in the census data.
target = 'income_50k'

# Read the full dataset, then hold out 30% of the rows as a test set.
# The fixed random_state keeps the split identical across kernel restarts.
df = pd.read_csv('Census-Income-KDD.csv')
df_tr, df_te = train_test_split(df, random_state=42, test_size=0.3)

In [5]:
def search_best_combination(
    model: object,
    model_params_grid: dict,
    prep_params_grid: dict,
    df: pd.DataFrame,
    target_metric: str = 'f1_macro',
    cv: int = 4,
    N: int = 5,
    verbose: int = 1,
    max_iter: int = 10
) -> pd.DataFrame:
    """Alternately tune preprocessing and model parameters until the score stops improving.

    Each round runs two searches: first over ``prep_params_grid`` with the
    currently best model parameters, then over ``model_params_grid`` with the
    currently best preprocessing parameters. The loop ends as soon as a round
    fails to beat the best ``target_metric`` seen so far, or after ``max_iter``
    rounds.

    Parameters
    ----------
    model
        Estimator exposing ``set_params(**kwargs)``. It is mutated in place.
    model_params_grid
        Maps each model parameter name to a sequence of candidate values. The
        first value of every entry is the starting point of the search.
    prep_params_grid
        Preprocessing grid, forwarded unchanged to ``test_preprocess_params``.
    df
        Data forwarded to both search helpers.
    target_metric
        Name of the result column used for ranking and for the stop criterion
        (higher is treated as better).
    cv
        Forwarded to both search helpers as ``cv``.
    N
        Number of top-ranked candidates kept in the "best" lists. Only the
        first one drives the next search; the rest are shown when ``verbose``.
    verbose
        Progress is printed when > 0; ``verbose - 1`` is passed to the helpers.
    max_iter
        Upper bound on the number of rounds.

    Returns
    -------
    pd.DataFrame
        Every distinct (prep_param, model_param) combination evaluated, sorted
        by ``target_metric`` in descending order with a fresh integer index.

    Notes
    -----
    Relies on ``test_preprocess_params`` and ``test_model_params`` being in
    scope (presumably provided by ``auxiliar_func`` -- confirm). Each is
    expected to return a DataFrame with a ``target_metric`` column plus
    ``prep_param`` / ``model_param`` respectively.
    """
    base_columns = ['prep_param', 'model_param',
                    'accuracy', 'f1_macro', 'precision_macro', 'recall_macro']

    # Start from the first candidate value of every model parameter.
    best_mod_param = [{k: v[0] for k, v in model_params_grid.items()}]
    best_prep_param = []

    # Per-search result frames. They are concatenated once at the end instead
    # of growing a DataFrame inside the loop, which is quadratic and triggers
    # pandas' deprecation warning for concatenating with an empty frame.
    frames = []

    def update_prep_params(mod_param, prep_par_list):
        '''Search the best preprocessing parameters for the given model parameters.'''
        nonlocal best_prep_param
        model.set_params(**mod_param)
        prep_par = test_preprocess_params(
            model, prep_par_list, df, cv=cv, verbose=verbose-1
        ).sort_values(by=target_metric, ascending=False).reset_index(drop=True)
        # Tag every row with the model parameters that were held fixed.
        prep_par['model_param'] = pd.Series(
            [mod_param] * len(prep_par), index=prep_par.index, dtype=object)
        frames.append(prep_par)
        best_prep_param = prep_par['prep_param'][:N].tolist()
        return prep_par[target_metric].max()

    def update_mod_params(prep_param, mod_par_list):
        '''Search the best model parameters for the given preprocessing parameters.'''
        nonlocal best_mod_param
        mod_par = test_model_params(
            model, mod_par_list, df, prep_param, cv=cv, verbose=verbose-1
        ).sort_values(by=target_metric, ascending=False).reset_index(drop=True)
        # Tag every row with the preprocessing parameters that were held fixed.
        mod_par['prep_param'] = pd.Series(
            [prep_param] * len(mod_par), index=mod_par.index, dtype=object)
        frames.append(mod_par)
        best_mod_param = mod_par['model_param'][:N].tolist()
        return mod_par[target_metric].max()

    # -inf rather than 0 so that metrics taking non-positive values (e.g.
    # negated losses) are not rejected in the very first round.
    best_metric = float('-inf')
    for i in range(1, max_iter+1):
        prep_top = update_prep_params(best_mod_param[0], prep_params_grid)
        mod_top = update_mod_params(best_prep_param[0], model_params_grid)

        # Earlier rounds never exceeded best_metric, so comparing this round's
        # best is equivalent to comparing the maximum over all results so far.
        round_best = pd.Series([prep_top, mod_top], dtype=float).max()
        if not round_best > best_metric:
            break
        best_metric = float(round_best)

        if verbose > 0:
            print(f"Iteration {i} | best metric: {best_metric}")
            print(f"Best preprocessing parameters: {best_prep_param}")
            print(f"Best model parameters: {best_mod_param}")

    if frames:
        results = pd.concat(frames, ignore_index=True, sort=False)
    else:
        results = pd.DataFrame(columns=base_columns, dtype=object)

    # Keep the documented columns first, followed by any extras the helpers add.
    extra_columns = [c for c in results.columns if c not in base_columns]
    results = results.reindex(columns=base_columns + extra_columns)

    # The parameter dicts are unhashable, so de-duplicate on their string form
    # and keep the first occurrence of each combination.
    combined = results['prep_param'].astype(str) + results['model_param'].astype(str)
    results = results.loc[~combined.duplicated()]

    return results.sort_values(by=target_metric, ascending=False).reset_index(drop=True)

# Preprocessing options to explore; only `target_freq` has more than one candidate.
prep_params_grid = dict(
    scaling=[None],
    imputation=['mode'],
    cat_age=[False],
    target_freq=[0.7, 0.8],
    generate_dummies=[True],
)

# Candidate solvers for the discriminant analysis model.
mod_par_grid = dict(solver=['svd', 'lsqr'])

lda = LDA(n_components=1)

# To evaluate the preprocessing grid on its own, call
# test_preprocess_params(lda, prep_params_grid, df_tr, cv=2, verbose=2) instead.

# Alternate between the preprocessing and model grids on the training split,
# ranking candidates by macro-averaged F1 with 2-fold cross-validation.
results = search_best_combination(
    model=lda,
    model_params_grid=mod_par_grid,
    prep_params_grid=prep_params_grid,
    df=df_tr,
    target_metric='f1_macro',
    cv=2,
    verbose=1,
)
results.head()

Iteration 1 | best metric: 0.753257018930653
Best preprocessing parameters: [{'scaling': None, 'imputation': 'mode', 'cat_age': False, 'target_freq': 0.8, 'generate_dummies': True, 'remove_duplicates': True}, {'scaling': None, 'imputation': 'mode', 'cat_age': False, 'target_freq': 0.7, 'generate_dummies': True, 'remove_duplicates': True}]
Best model parameters: [{'solver': 'svd'}, {'solver': 'lsqr'}]


Unnamed: 0,prep_param,model_param,accuracy,f1_macro,precision_macro,recall_macro
0,"{'scaling': None, 'imputation': 'mode', 'cat_a...",{'solver': 'svd'},0.933885,0.753257,0.727472,0.788094
1,"{'scaling': None, 'imputation': 'mode', 'cat_a...",{'solver': 'lsqr'},0.933885,0.753257,0.727472,0.788094
2,"{'scaling': None, 'imputation': 'mode', 'cat_a...",{'solver': 'svd'},0.916415,0.73673,0.693563,0.823139
