Skip to content

Commit

Permalink
Added basic sanity checks for algorithms
Browse files Browse the repository at this point in the history
Check that RMSE don't change accross commits
  • Loading branch information
NicolasHug committed May 21, 2018
1 parent d2b8a22 commit 8b071db
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions tests/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os

import pytest
from six import iteritems

from surprise import NormalPredictor
from surprise import BaselineOnly
Expand All @@ -22,6 +23,7 @@
from surprise import Reader
from surprise import KNNWithZScore
from surprise.model_selection import train_test_split
from surprise import accuracy


def test_unknown_user_or_item(toy_data):
Expand Down Expand Up @@ -84,3 +86,32 @@ def test_nearest_neighbors():
algo_ib = KNNBasic(sim_options={'user_based': False})
algo_ib.fit(trainset)
assert algo_ub.get_neighbors(0, k=10) != algo_ib.get_neighbors(0, k=10)


def test_sanity_checks(u1_ml100k, pkf):
"""
Basic sanity checks for all algorithms: check that RMSE stays the same.
"""

expected_rmse = {
BaselineOnly: 1.0268524031297395,
KNNBasic: 1.1337265249554591,
KNNWithMeans: 1.1043129441881696,
KNNBaseline: 1.0700718041752253,
KNNWithZScore: 1.11179436167853,
SVD: 1.0077323320656948,
SVDpp: 1.00284553561452,
NMF: 1.0865370266372372,
SlopeOne: 1.1559939123891685,
CoClustering: 1.0841941385276614,
}

for klass, rmse in iteritems(expected_rmse):
if klass in (SVD, SVDpp, NMF, CoClustering):
algo = klass(random_state=0)
else:
algo = klass()
trainset, testset = next(pkf.split(u1_ml100k))
algo.fit(trainset)
predictions = algo.test(testset)
assert accuracy.rmse(predictions, verbose=False) == rmse

0 comments on commit 8b071db

Please sign in to comment.