In [1]:
import pandas as pd
import numpy as np
import pathlib
import loguru

from src.dataset import DatasetLoader
from src.dataset import DatasetEnum
from src.pipeline import ModelPipeline

In [2]:
data_loader = DatasetLoader(pathlib.Path("./datasets"))

data = data_loader.load_dataset(DatasetEnum.water)
features = [col_name for col_name in data.columns if col_name != "target"]

data.head(3)

Unnamed: 0,aluminium,ammonia,arsenic,barium,cadmium,chloramine,chromium,copper,flouride,bacteria,...,lead,nitrates,nitrites,mercury,perchlorate,radium,selenium,silver,uranium,target
0,1.65,9.08,0.04,2.85,0.007,0.35,0.83,0.17,0.05,0.2,...,0.054,16.08,1.13,0.007,37.75,6.78,0.08,0.34,0.02,1
1,2.32,21.16,0.01,3.31,0.002,5.28,0.68,0.66,0.9,0.65,...,0.1,2.01,1.93,0.003,32.26,3.21,0.08,0.27,0.05,1
2,1.01,14.02,0.04,0.58,0.008,4.24,0.53,0.02,0.99,0.05,...,0.078,14.16,1.11,0.006,50.28,7.07,0.07,0.44,0.01,0


In [3]:
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.metrics import roc_auc_score

from catboost import CatBoostClassifier
from xgboost import XGBClassifier
from sklearn.dummy import DummyClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier

from sklearn.model_selection import StratifiedKFold

models_dct = {
    "CatBoost": CatBoostClassifier(silent=True),
    "XGBoost": XGBClassifier(enable_categorical=True),
    "MostFrequent": DummyClassifier(strategy="most_frequent"),
    "KNN": KNeighborsClassifier(),
    "NaiveBayes": GaussianNB(),
    "RandomForest": RandomForestClassifier(n_estimators=500),
    "LogRegression": LogisticRegression(max_iter=10000),
    "DecisionTree": DecisionTreeClassifier(),
}

metrics_list = [("score", roc_auc_score), ("binary", accuracy_score), ("binary", f1_score)]

In [4]:
import warnings

warnings.filterwarnings("ignore", message="is_categorical_dtype is deprecated")
warnings.filterwarnings("ignore", message="is_sparse is deprecated")

# To avoid different folds separately on each dataset
num_folds = 5
fold_generator = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=101)

summary_list = []

for model_name, model in models_dct.items():
    loguru.logger.info(f"Model: {model_name}")
    pipeline = ModelPipeline(base_model=model, features=features, metrics=metrics_list)

    folds = fold_generator.split(data, y=data["target"])
    for i, (train_fold_idx, test_fold_idx) in enumerate(folds):
        loguru.logger.info(f"Fold: {i + 1} / {num_folds}")
        train_fold = data.iloc[train_fold_idx]
        test_fold = data.iloc[test_fold_idx]
        
        pipeline.fit(train_fold)
        predictions = pipeline.predict(test_fold)

        metrics = pipeline.calculate_metrics(test_fold, predictions)

        result_dict = {
            "fold": i,
            "model": model_name,
        }
        result_dict.update(metrics)
        summary_list.append(result_dict)

summary = pd.DataFrame.from_records(summary_list)
summary

[32m2023-10-06 17:29:47.419[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m13[0m - [1mModel: CatBoost[0m
[32m2023-10-06 17:29:47.426[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m18[0m - [1mFold: 1 / 5[0m
[32m2023-10-06 17:29:52.464[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m18[0m - [1mFold: 2 / 5[0m
[32m2023-10-06 17:29:57.335[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m18[0m - [1mFold: 3 / 5[0m
[32m2023-10-06 17:30:02.133[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m18[0m - [1mFold: 4 / 5[0m
[32m2023-10-06 17:30:06.922[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m18[0m - [1mFold: 5 / 5[0m
[32m2023-10-06 17:30:11.660[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m13[0m - [1mModel: XGBoost[0m
[32m2023-10-06 17:30:11.665[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m18[0m - [1mFold: 1 / 5[0m


Unnamed: 0,fold,model,<function roc_auc_score at 0x7ff9dfc4a710>,<function accuracy_score at 0x7ff9e010bd00>,<function f1_score at 0x7ff9dfc48550>
0,0,CatBoost,0.990236,0.97375,0.87931
1,1,CatBoost,0.991244,0.96873,0.855491
2,2,CatBoost,0.98601,0.962477,0.817073
3,3,CatBoost,0.988045,0.971232,0.863095
4,4,CatBoost,0.992272,0.971857,0.866469
5,0,XGBoost,0.989703,0.97125,0.869318
6,1,XGBoost,0.989075,0.966229,0.844828
7,2,XGBoost,0.983451,0.963727,0.826347
8,3,XGBoost,0.988166,0.964978,0.837209
9,4,XGBoost,0.989942,0.969981,0.861272
