-
Notifications
You must be signed in to change notification settings - Fork 1k
/
test_bsl_options.py
150 lines (111 loc) · 4.24 KB
/
test_bsl_options.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""Ensure that options for baseline estimates are taken into account."""
import pytest
from surprise import BaselineOnly
from surprise.model_selection import cross_validate
def test_method_field(u1_ml100k, pkf):
    """Ensure the method field is taken into account.

    ALS and SGD baselines must produce different RMSE values on the same
    folds, and an unknown method name must be rejected with a ValueError.
    """
    bsl_options = {"method": "als"}
    algo = BaselineOnly(bsl_options=bsl_options)
    rmse_als = cross_validate(algo, u1_ml100k, ["rmse"], pkf)["test_rmse"]

    bsl_options = {"method": "sgd"}
    algo = BaselineOnly(bsl_options=bsl_options)
    rmse_sgd = cross_validate(algo, u1_ml100k, ["rmse"], pkf)["test_rmse"]

    assert rmse_als != rmse_sgd

    # Build the misconfigured algorithm outside pytest.raises so that only the
    # fitting call can satisfy the expected ValueError; a ValueError from any
    # other statement would otherwise make this check pass vacuously.
    # NOTE(review): relies on BaselineOnly validating "method" lazily (when
    # baselines are computed during cross_validate), not in its constructor.
    bsl_options = {"method": "wrong_name"}
    algo = BaselineOnly(bsl_options=bsl_options)
    with pytest.raises(ValueError):
        cross_validate(algo, u1_ml100k, ["rmse"], pkf)
def test_als_n_epochs_field(u1_ml100k, pkf):
    """Check that the number of ALS epochs changes the resulting RMSE."""
    # Train one ALS baseline per epoch count, in order: 1 epoch, then 5.
    scores = []
    for epochs in (1, 5):
        options = {
            "method": "als",
            "n_epochs": epochs,
        }
        model = BaselineOnly(bsl_options=options)
        scores.append(cross_validate(model, u1_ml100k, ["rmse"], pkf)["test_rmse"])

    one_epoch, five_epochs = scores
    assert one_epoch != five_epochs
def test_als_reg_u_field(u1_ml100k, pkf):
    """Check that the ALS user regularization term (reg_u) affects the RMSE."""

    def rmse_with_reg_u(reg_u):
        # Cross-validate an ALS baseline using the given user regularization.
        model = BaselineOnly(bsl_options={"method": "als", "reg_u": reg_u})
        return cross_validate(model, u1_ml100k, ["rmse"], pkf)["test_rmse"]

    # Operands are evaluated left to right: reg_u=0 first, then reg_u=10.
    assert rmse_with_reg_u(0) != rmse_with_reg_u(10)
def test_als_reg_i_field(u1_ml100k, pkf):
    """Check that the ALS item regularization term (reg_i) affects the RMSE."""

    def rmse_with_reg_i(reg_i):
        # Cross-validate an ALS baseline using the given item regularization.
        model = BaselineOnly(bsl_options={"method": "als", "reg_i": reg_i})
        return cross_validate(model, u1_ml100k, ["rmse"], pkf)["test_rmse"]

    unregularized = rmse_with_reg_i(0)
    regularized = rmse_with_reg_i(10)
    assert unregularized != regularized
def test_sgd_n_epoch_field(u1_ml100k, pkf):
    """Ensure the n_epochs field is taken into account for SGD baselines.

    One SGD epoch and twenty SGD epochs must yield different RMSE values.
    """
    bsl_options = {
        "method": "sgd",
        "n_epochs": 1,
    }
    algo = BaselineOnly(bsl_options=bsl_options)
    rmse_sgd_n_epochs_1 = cross_validate(algo, u1_ml100k, ["rmse"], pkf)["test_rmse"]

    bsl_options = {
        "method": "sgd",
        "n_epochs": 20,
    }
    algo = BaselineOnly(bsl_options=bsl_options)
    # Named after the value actually used (20), not a stale "5".
    rmse_sgd_n_epochs_20 = cross_validate(algo, u1_ml100k, ["rmse"], pkf)["test_rmse"]

    assert rmse_sgd_n_epochs_1 != rmse_sgd_n_epochs_20
def test_sgd_learning_rate_field(u1_ml100k, pkf):
    """Check that the SGD learning rate changes the resulting RMSE."""
    # Both runs use a single epoch; only learning_rate differs.
    results = [
        cross_validate(
            BaselineOnly(
                bsl_options={
                    "method": "sgd",
                    "n_epochs": 1,
                    "learning_rate": rate,
                }
            ),
            u1_ml100k,
            ["rmse"],
            pkf,
        )["test_rmse"]
        for rate in (0.005, 0.00005)
    ]

    assert results[0] != results[1]
def test_sgd_reg_field(u1_ml100k, pkf):
    """Check that the SGD regularization term (reg) changes the resulting RMSE."""
    # Runs in order: reg=0.02, then reg=1; both with a single SGD epoch.
    rmse_by_reg = {}
    for reg in (0.02, 1):
        options = {
            "method": "sgd",
            "n_epochs": 1,
            "reg": reg,
        }
        model = BaselineOnly(bsl_options=options)
        rmse_by_reg[reg] = cross_validate(model, u1_ml100k, ["rmse"], pkf)["test_rmse"]

    assert rmse_by_reg[0.02] != rmse_by_reg[1]