Skip to content

Commit

Permalink
Add benchmark CI job (#421)
Browse files: browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Aug 14, 2022
1 parent e2d53ea commit e29e073
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 15 deletions.
37 changes: 37 additions & 0 deletions .github/workflows/benchmarks.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# This workflow will run the benchmark suite
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Benchmark

# Trigger on pushes to master and on pull requests targeting master.
on:
  push:
    branches: [ "master" ]
  pull_request:
    branches: [ "master" ]

jobs:
  build:

    runs-on: ubuntu-latest
    strategy:
      # Do not cancel the remaining matrix entries when one of them fails.
      fail-fast: false
      matrix:
        python-version: ["3.9"]

    steps:
    - uses: actions/checkout@v3
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v3
      with:
        python-version: ${{ matrix.python-version }}
    # tabulate is installed separately because examples/benchmark.py imports it.
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install -r requirements_dev.txt
        pip install tabulate
    # Editable install of the package from the checked-out sources.
    - name: Install Surprise
      run: |
        python -m pip install -e .
    # NOTE(review): `yes` presumably auto-confirms an interactive prompt raised
    # by the benchmark script (e.g. a dataset download) - confirm.
    - name: Run Benchmarks
      run: |
        yes | python examples/benchmark.py
30 changes: 15 additions & 15 deletions examples/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,17 @@
from tabulate import tabulate

# The algorithms to cross-validate
classes = (
SVD,
SVDpp,
NMF,
SlopeOne,
KNNBasic,
KNNWithMeans,
KNNBaseline,
CoClustering,
BaselineOnly,
NormalPredictor,
algos = (
SVD(),
SVDpp(),
NMF(),
SlopeOne(),
KNNBasic(),
KNNWithMeans(),
KNNBaseline(),
CoClustering(),
BaselineOnly(),
NormalPredictor(),
)

# ugly dict to map algo names and datasets to their markdown links in the table
Expand Down Expand Up @@ -102,16 +102,16 @@
np.random.seed(0)
random.seed(0)

dataset = "ml-1m"
dataset = "ml-100k"
data = Dataset.load_builtin(dataset)
kf = KFold(random_state=0) # folds will be the same for all algorithms.

table = []
for klass in classes:
for algo in algos:
start = time.time()
out = cross_validate(klass(), data, ["rmse", "mae"], kf)
out = cross_validate(algo, data, ["rmse", "mae"], kf)
cv_time = str(datetime.timedelta(seconds=int(time.time() - start)))
link = LINK[klass.__name__]
link = LINK[algo.__class__.__name__]
mean_rmse = "{:.3f}".format(np.mean(out["test_rmse"]))
mean_mae = "{:.3f}".format(np.mean(out["test_mae"]))

Expand Down

0 comments on commit e29e073

Please sign in to comment.