In [1]:
# Parameters
# NOTE(review): these look like run-injected (papermill-style) parameters;
# output_path is an absolute, machine-specific location — override it per run.
cross_val_args = dict(n_splits=3, random_state=0, shuffle=True)
dataset_args = dict(
    feature_columns=["sepal_length", "sepal_width", "petal_length", "petal_width"],
    source=(
        "https://gist.githubusercontent.com/curran/a08a1080b88344b0c8a7/raw/"
        "d546eaee765268bf2f487608c537c05e22e4b221/iris.csv"
    ),
    target_column="species",
)
labels = dict(features="both", n_estimators=1)
model_args = dict(max_depth=2, min_samples_leaf=1, n_estimators=1, random_state=0)
output_path = (
    "/Users/levinbrinkmann/repros/ml-project-template/data/grid/feature_estimator/train/"
    "n_estimators_1__features_both"
)


In [2]:
%load_ext autoreload
%autoreload 2

import pandas as pd
import os
import pickle
from sklearn.model_selection import KFold
from sklearn.ensemble import RandomForestClassifier
from iris.load_data import load_dataset
from iris.utils.pandas import add_labels
from iris.utils.io import make_dir

# Ensure the run's output directory exists before the metrics and model are
# written into it further down.
# NOTE(review): make_dir comes from iris.utils.io; presumably it creates the
# directory (and any missing parents) and tolerates it already existing — confirm.
make_dir(output_path)

In [3]:
# Load the data and build the splitter and model from the run parameters.
# NOTE(review): X and y are indexed positionally below (X[train_index]), which
# presumably means load_dataset returns numpy arrays; a pandas DataFrame/Series
# would need .iloc instead — confirm against iris.load_data.
X, y = load_dataset(**dataset_args)
cv = KFold(**cross_val_args)
clf = RandomForestClassifier(**model_args)

# One record per (fold, sample) pair; collected as a list of dicts and turned
# into a DataFrame once at the end rather than grown row by row.
metrics = []

# Cross-validated performance. The same estimator object is refit on every
# training fold (warm_start is not set in model_args, so each fit starts from
# scratch) and scored on both the training fold (in-sample) and the held-out
# fold (out-of-sample).
# KFold does not use y to build the splits; it is passed only for API symmetry.
# TODO(review): consider StratifiedKFold to keep class proportions per fold.
for i, (train_index, test_index) in enumerate(cv.split(X, y)):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
    clf.fit(X_train, y_train)
    in_acc = clf.score(X_train, y_train)
    out_acc = clf.score(X_test, y_test)
    metrics.append(dict(cv=i, sample='in-sample', metric='accuracy', value=in_acc))
    metrics.append(dict(cv=i, sample='out-of-sample', metric='accuracy', value=out_acc))

# Fit the final model on the full dataset. Its score is in-sample only, and
# cv=None marks it as not belonging to any fold.
clf.fit(X, y)
acc = clf.score(X, y)
metrics.append(dict(cv=None, sample='in-sample', metric='accuracy', value=acc))

# Save metrics, tagged with this run's labels.
# NOTE(review): add_labels comes from iris.utils.pandas; presumably it adds one
# column per entry of `labels` — confirm.
metrics_df = pd.DataFrame.from_records(metrics)
metrics_df = add_labels(metrics_df, labels=labels)
metrics_df.to_parquet(os.path.join(output_path, 'metrics.parquet'))

# Save the final (full-data) model. A context manager guarantees the file is
# flushed and closed; the original passed a bare open() handle to pickle.dump
# and never closed it.
model_filename = os.path.join(output_path, 'model.sav')
with open(model_filename, 'wb') as model_file:
    pickle.dump(clf, model_file)