Fix wide benchmark (#476)
MarcAntoineSchmidtQC committed Oct 27, 2021
1 parent 5b40ae1 commit b6a3391
Showing 25 changed files with 7,078 additions and 7,220 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -14,6 +14,7 @@ Unreleased

 - Fixed the sign of the log likelihood of the Gaussian distribution (not used for fitting coefficients).
 - Renamed functions checking for qc.matrix compliance to refer to tabmat.
+- Fixed the wide benchmarks which had duplicated columns (categorical and numerical).

 2.0.1 - 2021-10-11
 ------------------
Binary file modified docs/_static/headline_benchmark.pdf
Binary file not shown.
Binary file modified docs/_static/headline_benchmark.png
Binary file modified docs/_static/intermediate-housing-l2.pdf
Binary file not shown.
Binary file modified docs/_static/intermediate-housing-l2.png
Binary file modified docs/_static/intermediate-housing-lasso.pdf
Binary file not shown.
Binary file modified docs/_static/intermediate-housing-lasso.png
Binary file modified docs/_static/intermediate-insurance-l2.pdf
Binary file not shown.
Binary file modified docs/_static/intermediate-insurance-l2.png
Binary file modified docs/_static/intermediate-insurance-lasso.pdf
Binary file not shown.
Binary file modified docs/_static/intermediate-insurance-lasso.png
Binary file modified docs/_static/narrow-insurance-l2.pdf
Binary file not shown.
Binary file modified docs/_static/narrow-insurance-l2.png
Binary file modified docs/_static/narrow-insurance-lasso.pdf
Binary file not shown.
Binary file modified docs/_static/narrow-insurance-lasso.png
Binary file modified docs/_static/wide-insurance-l2.pdf
Binary file not shown.
Binary file modified docs/_static/wide-insurance-l2.png
Binary file modified docs/_static/wide-insurance-lasso.pdf
Binary file not shown.
Binary file modified docs/_static/wide-insurance-lasso.png
2 changes: 1 addition & 1 deletion docs/benchmarks.rst
@@ -1,7 +1,7 @@
 Benchmarks against glmnet and H2O
 =================================

-The following benchmarks were run on an Ubuntu 20.04 desktop with a six core Intel i7-4930k processor.
+The following benchmarks were run on a MacBook Pro laptop with a quad-core Intel Core i5.

 The title of each plot refers to both which dataset the benchmark was run on and whether a L2 ridge regression penalty or an L1 lasso penalty was included. For example "Narrow-Insurance-Ridge" was run on the ``narrow-insurance`` dataset with a ridge regression penalty. Each dataset/penalty pair is tested on five distributions that cover most of the common GLM types. The outcome variable is modified appropriately so that the behavior is similar to that expected for the distribution. For example, for the Poisson regression, we predict the number of claims per person. And for the binomial regression, we predict whether any given individual has ever had a claim. For the ``housing`` dataset, we only test three distributions because it does not contain count data that can be used as an outcome.
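
Reviewer note: the outcome adjustment described in the benchmark text can be sketched as follows. This is only an illustration on a toy frame; the column name `n_claims` is hypothetical and is not taken from the benchmark code.

```python
import pandas as pd

# Hypothetical toy data: one row per policyholder with a claim count.
policies = pd.DataFrame({"n_claims": [0, 2, 1, 0]})

# Poisson benchmark: the outcome is the number of claims per person.
y_poisson = policies["n_claims"]

# Binomial benchmark: the outcome is whether the person ever had a claim.
y_binomial = (policies["n_claims"] > 0).astype(int)

print(y_poisson.tolist())
print(y_binomial.tolist())
```

The same source column feeds both outcomes; only the transformation differs per distribution.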
214 changes: 107 additions & 107 deletions docs/benchmarks/benchmark_data.csv

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions docs/benchmarks/benchmark_figure.py
@@ -100,10 +100,8 @@
 )
 plot_df = plot_df.pivot(columns="library_name")
 plot_df.columns = plot_df.columns.get_level_values(1)
+plot_df = plot_df.sort_index(axis=1).rename(columns={"r-glmnet": "glmnet"})
 plot_df.index = [x.title() for x in plot_df.index]
-plot_df = plot_df[["h2o", "glum", "r-glmnet"]].rename(
-    columns={"r-glmnet": "glmnet"}
-)

 title = prob_name.title() + "-" + ("Lasso" if reg == "lasso" else "Ridge")
 plot_df.plot.bar(
@@ -170,6 +168,7 @@
 )
 plot_df = plot_df.pivot(columns="library_name")
 plot_df.columns = plot_df.columns.get_level_values(1)
+plot_df = plot_df.sort_index(axis=1).rename(columns={"r-glmnet": "glmnet"})
 plot_df.index = [x.title() for x in plot_df.index]

 title = prob_name.title() + "-" + ("Lasso" if reg == "lasso" else "Ridge")
@@ -237,6 +236,7 @@
 )
 plot_df = plot_df.pivot(columns="library_name")
 plot_df.columns = plot_df.columns.get_level_values(1)
+plot_df = plot_df.sort_index(axis=1).rename(columns={"r-glmnet": "glmnet"})
 plot_df.index = [x.title() for x in plot_df.index]

 plot_df.plot.bar(
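
Reviewer note: the pivot, sort, and rename sequence applied in all three plotting blocks can be reproduced on toy data as below. The frame contents and the `distribution` and `runtime` column names are invented for illustration; only the four `plot_df` statements mirror the diff.

```python
import pandas as pd

# Hypothetical benchmark records, indexed by distribution name.
records = pd.DataFrame(
    {
        "distribution": ["gaussian"] * 3 + ["poisson"] * 3,
        "library_name": ["r-glmnet", "glum", "h2o"] * 2,
        "runtime": [1.2, 0.4, 2.0, 1.5, 0.5, 2.4],
    }
).set_index("distribution")

# One column per library; the result has two-level columns (value, library).
plot_df = records.pivot(columns="library_name")
# Keep only the library level of the column index.
plot_df.columns = plot_df.columns.get_level_values(1)
# Order the columns alphabetically, then shorten the R library's label.
plot_df = plot_df.sort_index(axis=1).rename(columns={"r-glmnet": "glmnet"})
plot_df.index = [x.title() for x in plot_df.index]

print(plot_df)
```

Sorting before renaming fixes the column order by the original library names. Unlike the removed `plot_df[["h2o", "glum", "r-glmnet"]]` selection, it does not raise a `KeyError` when one of the libraries is absent from the data; whether that was the motivation for the change is not stated in the commit.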
6 changes: 5 additions & 1 deletion src/glum_benchmarks/data/create_insurance.py
@@ -406,7 +406,11 @@ def generate_wide_insurance_dataset(
 transformer = make_column_transformer(
     (
         FunctionTransformer(),
-        lambda x: x.select_dtypes(["number"]).columns,
+        lambda x: [
+            elmt
+            for elmt in x.select_dtypes(["number"]).columns
+            if elmt not in cat_cols
+        ],
     ),
     (
         Pipeline([get_categorizer(col, "cat_" + col) for col in cat_cols]),
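
Reviewer note: a minimal sketch of why the old selector could duplicate columns, assuming that at least one entry of `cat_cols` is stored with a numeric dtype. The toy frame and its column names are hypothetical, and the two selectors are exercised directly on a DataFrame rather than through `make_column_transformer`.

```python
import pandas as pd

# Hypothetical data: "region" is categorical but encoded as integers.
df = pd.DataFrame(
    {
        "age": [23, 41, 35],
        "premium": [310.5, 512.0, 428.25],
        "region": [1, 3, 2],
        "fuel": ["diesel", "petrol", "diesel"],
    }
)
cat_cols = ["region", "fuel"]  # assumed list of categorical columns


def numeric_selector_old(x):
    # Old behaviour: every numeric-dtype column passes through untouched.
    return list(x.select_dtypes(["number"]).columns)


def numeric_selector_new(x):
    # New behaviour: numeric columns that are also categorical are excluded.
    return [
        elmt
        for elmt in x.select_dtypes(["number"]).columns
        if elmt not in cat_cols
    ]


old_numeric = numeric_selector_old(df)
new_numeric = numeric_selector_new(df)
# Columns that would feed both the pass-through and the categorical branch.
overlap_old = sorted(set(old_numeric) & set(cat_cols))
overlap_new = sorted(set(new_numeric) & set(cat_cols))

print(old_numeric, overlap_old)
print(new_numeric, overlap_new)
```

With the old selector, `region` is picked up by the pass-through branch and again by the categorical pipeline, so it would appear twice in the transformed output. The filtered selector leaves each column in exactly one branch.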
