In [1]:
import pandas as pd
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from deepctr.models import DeepFM
from deepctr.inputs import SparseFeat, get_feature_names

In [2]:
data = pd.read_csv('ratings.csv')
sparse_features = ['userId', 'movieId', 'timestamp']
target = ['rating']

In [3]:
# 对特征标签进行编码
for feature in sparse_features:
    lbe = LabelEncoder()
    data[feature] = lbe.fit_transform(data[feature])
# 计算每个特征中 不同特征值的个数
fixlen_feature_columns = [SparseFeat(feature, data[feature].nunique()) for feature in sparse_features]
print(fixlen_feature_columns)
linear_feature_columns = fixlen_feature_columns
dnn_feature_columns = fixlen_feature_columns
feature_names = get_feature_names(linear_feature_columns + dnn_feature_columns)

[SparseFeat(name='userId', vocabulary_size=7120, embedding_dim=4, use_hash=False, dtype='int32', embedding_name='userId', group_name='default_group'), SparseFeat(name='movieId', vocabulary_size=14026, embedding_dim=4, use_hash=False, dtype='int32', embedding_name='movieId', group_name='default_group'), SparseFeat(name='timestamp', vocabulary_size=822889, embedding_dim=4, use_hash=False, dtype='int32', embedding_name='timestamp', group_name='default_group')]


In [4]:
# 将数据集切分成训练集和测试集
train, test = train_test_split(data, test_size=0.2)
train_model_input = {name:train[name].values for name in feature_names}
test_model_input = {name:test[name].values for name in feature_names}

In [5]:
# 使用DeepFM进行训练
model = DeepFM(linear_feature_columns, dnn_feature_columns, task='regression')
model.compile('adam', 'mse', metrics=['mse'], )
history = model.fit(train_model_input, train[target].values, batch_size=256, epochs=1, verbose=True, validation_split=0.2, )

Train on 671088 samples, validate on 167772 samples


In [6]:
# 使用DeepFM进行预测
pred_ans = model.predict(test_model_input, batch_size=256)
# 输出RMSE或MSE
mse = round(mean_squared_error(test[target].values, pred_ans), 4)
rmae = mse ** 0.5
print('test EMSE', rmae)

test EMSE 0.8603487664894975
