In [1]:
# Re-import edited modules automatically before each cell runs (handy while
# developing the library alongside this notebook).
%load_ext autoreload
%autoreload 2
from functools import partial

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, balanced_accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils import resample

# NOTE(review): the previous note here said `pip install pcsp`, but the module
# imported is `vflow` -- presumably now distributed as `veridical-flow`;
# confirm the install command against the project README.
from vflow import Vset, init_args

# Abbreviate printed arrays that have more than 5 elements so cell output
# stays short.
np.set_printoptions(threshold=5)

In [2]:
# Seed NumPy's global RNG so the synthetic dataset below is reproducible.
np.random.seed(13)

# Small synthetic classification problem: 50 samples, 5 features.
X, y = make_classification(n_samples=50, n_features=5)

# train_test_split returns (X_train, X_test, y_train, y_test). The names passed
# to init_args must follow that same order; the label arrays were previously
# mislabelled as 'X_train'/'X_test'.
X_train, X_test, y_train, y_test = init_args(
    train_test_split(X, y, random_state=42),
    names=['X_train', 'X_test', 'y_train', 'y_test'])

# Three resampling functions that differ only in their random seed; each one
# draws 20 samples from whatever arrays it is called with.
subsampling_funcs = [partial(resample,
                             n_samples=20,
                             random_state=i)
                     for i in range(3)]

# Apply every subsampling function to the training data.
subsampling_set = Vset(name='subsampling', vfuncs=subsampling_funcs)
X_trains, y_trains = subsampling_set(X_train, y_train)

# Fit both models on the subsampled training sets.
modeling_set = Vset(name='modeling',
                    vfuncs=[LogisticRegression(C=1, max_iter=1000, tol=0.1),
                            DecisionTreeClassifier(min_samples_leaf=1)],
                    vfunc_keys=["LR", "DT"])

_ = modeling_set.fit(X_trains, y_trains)

# predict returns modeling_set.output rather than the result of
# sep_dicts(output_dict).
preds_test = modeling_set.predict(X_test)

# Score the test predictions with two hard-label metrics. tracking_dir points
# the Vset at ./mlruns -- presumably where it records results for the MLflow
# UI launched in the next cell (confirm against the vflow docs).
hard_metrics_set = Vset(name='hard_metrics',
                        vfuncs=[accuracy_score, balanced_accuracy_score],
                        vfunc_keys=["Acc", "Bal_Acc"],
                        tracking_dir='./mlruns')

hard_metrics = hard_metrics_set.evaluate(y_test, preds_test)

In [None]:
# Start the MLflow tracking UI from the notebook's working directory; the
# captured log below shows it listening on http://127.0.0.1:5000.
# NOTE(review): this appears to run in the foreground and block the kernel
# until interrupted (the cell has no execution count) -- consider launching
# it from a separate terminal instead.
!mlflow ui

[2023-01-23 17:20:48 -0800] [136070] [INFO] Starting gunicorn 20.1.0
[2023-01-23 17:20:48 -0800] [136070] [INFO] Listening at: http://127.0.0.1:5000 (136070)
[2023-01-23 17:20:48 -0800] [136070] [INFO] Using worker: sync
[2023-01-23 17:20:48 -0800] [136071] [INFO] Booting worker with pid: 136071
[2023-01-23 17:20:48 -0800] [136072] [INFO] Booting worker with pid: 136072
[2023-01-23 17:20:48 -0800] [136073] [INFO] Booting worker with pid: 136073
[2023-01-23 17:20:48 -0800] [136074] [INFO] Booting worker with pid: 136074
