From 5d6dfc0a8046d89a96d1f78d0e2303f6a9915faf Mon Sep 17 00:00:00 2001 From: Andrew Hsieh Date: Sun, 12 Jul 2020 20:59:47 +0800 Subject: [PATCH] add afm testing --- .../ml/pytorch/model/test_afm_pytorch.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 submarine-sdk/pysubmarine/tests/ml/pytorch/model/test_afm_pytorch.py diff --git a/submarine-sdk/pysubmarine/tests/ml/pytorch/model/test_afm_pytorch.py b/submarine-sdk/pysubmarine/tests/ml/pytorch/model/test_afm_pytorch.py new file mode 100644 index 0000000000..96d61c9832 --- /dev/null +++ b/submarine-sdk/pysubmarine/tests/ml/pytorch/model/test_afm_pytorch.py @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from submarine.ml.pytorch.model.ctr import AttentionalFM + + +def test_run_afm(get_model_param): + param = get_model_param + + trainer = AttentionalFM(param) + trainer.fit() + trainer.evaluate() + trainer.predict()