-
Notifications
You must be signed in to change notification settings - Fork 1k
/
test_dump.py
76 lines (61 loc) · 1.84 KB
/
test_dump.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""Module for testing the dump module."""
import os
import random
import tempfile

import pytest

from surprise import (
    BaselineOnly,
    CoClustering,
    dump,
    KNNBaseline,
    KNNBasic,
    KNNWithMeans,
    KNNWithZScore,
    NMF,
    NormalPredictor,
    SlopeOne,
    SVD,
    SVDpp,
)
from surprise.model_selection import PredefinedKFold
@pytest.mark.parametrize(
    "algo",
    (
        NormalPredictor(),
        BaselineOnly(),
        KNNBasic(),
        KNNWithMeans(),
        KNNBaseline(),
        SVD(),
        SVDpp(),
        NMF(),
        SlopeOne(),
        CoClustering(),
        KNNWithZScore(),
    ),
)
def test_dump(algo, u1_ml100k):
    """Train an algorithm, compute its predictions then dump them.

    Ensure that the predictions that are loaded back are the correct ones, and
    that the predictions of the dumped algorithm are also equal to the other
    ones.

    Each dump is written to a file inside a fresh temporary *directory*
    rather than to ``NamedTemporaryFile().name``: re-opening a still-open
    named temporary file by name is not supported on Windows.
    """
    random.seed(0)
    trainset, testset = next(PredefinedKFold().split(u1_ml100k))

    # Dumping an algorithm that has not been fitted yet must round-trip too.
    # No predictions were passed, so the loaded predictions default to None.
    with tempfile.TemporaryDirectory() as tmp_dir:
        file_name = os.path.join(tmp_dir, "dump")
        dump.dump(file_name, algo=algo)
        predictions_loaded, algo_loaded = dump.load(file_name)
    assert predictions_loaded is None
    assert isinstance(algo_loaded, type(algo))

    algo.fit(trainset)
    predictions = algo.test(testset)

    with tempfile.TemporaryDirectory() as tmp_dir:
        file_name = os.path.join(tmp_dir, "dump")
        dump.dump(file_name, predictions, algo)
        predictions_dumped, algo_dumped = dump.load(file_name)

    assert predictions == predictions_dumped
    predictions_algo_dumped = algo_dumped.test(testset)
    if not isinstance(algo, NormalPredictor):  # predictions are random
        assert predictions == predictions_algo_dumped
def test_dump_nothing():
    """Ensure that by default None objects are dumped.

    The dump file lives inside a temporary directory rather than being a
    ``NamedTemporaryFile``, so it can be opened by name on every platform
    (re-opening an open named temporary file fails on Windows).
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        file_name = os.path.join(tmp_dir, "dump")
        dump.dump(file_name)
        predictions, algo = dump.load(file_name)
    assert predictions is None
    assert algo is None