This repository has been archived by the owner on Jul 10, 2024. It is now read-only.
Commit
SUBMARINE-561. [SDK] Add PyTorch implementation of AFM model
### What is this PR for?
Add PyTorch implementation of Attentional Factorization Machine for CTR prediction. ([AFM](https://arxiv.org/pdf/1708.04617.pdf))
Make minor modifications to the PyTorch training flow.
Add testing for the AFM model.

### What type of PR is it?
[Improvement]

### Todos
* [ ] - Task

### What is the Jira issue?
https://issues.apache.org/jira/browse/SUBMARINE-561

### How should this be tested?
[python-sdk](https://github.com/andrewhsiehth/submarine/actions/runs/169985131)
[Submarine](https://github.com/andrewhsiehth/submarine/actions/runs/169985125)

### Screenshots (if appropriate)

### Questions:
* Does the licenses files need update? No
* Is there breaking changes for older versions? No
* Does this needs documentation? No

Author: Andrew Hsieh <andrewhsiehth@gmail.com>
Author: andrewhsiehth <andrewhsiehth@gmail.com>

Closes #346 from andrewhsiehth/SUBMARINE-561 and squashes the following commits:

0521639 [andrewhsiehth] rename afm && refactor example/pytorch folder
f98d59f [andrewhsiehth] mkdir for non-existing output directory
3057899 [andrewhsiehth] use pysubmarine-ci to auto-format
f89d070 [Andrew Hsieh] python3.6 yapf
d4d93c4 [Andrew Hsieh] try to make python3.5 happy
2929dfc [Andrew Hsieh] try to make codestyle checker happy v2
42d5091 [Andrew Hsieh] try to make codestyle checker happy
9ff2f8d [Andrew Hsieh] fix core, afm coding style
adae613 [Andrew Hsieh] fix tqdm
4facbce [Andrew Hsieh] fix conftest.py coding style
e4b3e50 [Andrew Hsieh] fix deepfm.py coding style
cb6be07 [Andrew Hsieh] fix ctr.__init__ coding style
2b4eecf [Andrew Hsieh] fix base_pytorch_model coding style
573a4e8 [Andrew Hsieh] fix fileio coding style
5d6dfc0 [Andrew Hsieh] add afm testing
827c785 [Andrew Hsieh] update conftest
b260042 [Andrew Hsieh] add afm example
a7da1c3 [Andrew Hsieh] add afm to ctr
ab7b4b7 [Andrew Hsieh] add afm
fa151e5 [Andrew Hsieh] fix deepfm
380358c [Andrew Hsieh] fix testing
3f80bc6 [Andrew Hsieh] fix fileio
7471408 [Andrew Hsieh] fix data input_fn and fileio
f57d732 [Andrew Hsieh] fix deepfm
fdcda05 [Andrew Hsieh] fix layers/core.py
ce535fc [Andrew Hsieh] fix optimizer zero_grad
Showing 15 changed files with 444 additions and 150 deletions.
@@ -0,0 +1,49 @@
{
  "input": {
    "train_data": "../../data/tr.libsvm",
    "valid_data": "../../data/va.libsvm",
    "test_data": "../../data/te.libsvm",
    "type": "libsvm"
  },
  "output": {
    "save_model_dir": "./output",
    "metric": "roc_auc"
  },
  "training": {
    "batch_size": 512,
    "num_epochs": 3,
    "log_steps": 10,
    "num_threads": 2,
    "num_gpus": 0,
    "seed": 42,
    "mode": "distributed",
    "backend": "gloo"
  },
  "model": {
    "name": "ctr.afm",
    "kwargs": {
      "num_fields": 39,
      "num_features": 117581,
      "attention_dim": 64,
      "out_features": 1,
      "embedding_dim": 256,
      "hidden_units": [400, 400, 400],
      "dropout_rate": 0.3
    }
  },
  "loss": {
    "name": "BCEWithLogitsLoss",
    "kwargs": {}
  },
  "optimizer": {
    "name": "adam",
    "kwargs": {
      "lr": 5e-4
    }
  },
  "resource": {
    "num_cpus": 4,
    "num_gpus": 0,
    "num_threads": 2
  }
}
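For readers unfamiliar with AFM, the sketch below shows what `num_fields`, `embedding_dim`, and `attention_dim` in the config above plausibly control. It follows the attention pooling step described in the linked AFM paper, not the code added by this commit (which is not shown on this page). The function name `afm_attention_pooling`, the parameter names `W`, `b`, `h`, and the toy sizes are all assumptions made for illustration. The feature values, the linear term, and the final projection from the paper are left out to keep it short.

```python
import itertools
import math
import random


def afm_attention_pooling(embeddings, W, b, h):
    """Attention-weighted sum of pairwise element-wise products (illustrative).

    embeddings: num_fields vectors, each of length embedding_dim
    W: attention_dim x embedding_dim matrix; b, h: vectors of length attention_dim
    """
    # One element-wise product per unordered pair of fields (i < j).
    pairs = [[x * y for x, y in zip(v_i, v_j)]
             for v_i, v_j in itertools.combinations(embeddings, 2)]

    def score(pair):
        # Unnormalised attention score: h^T ReLU(W p + b).
        hidden = [max(0.0, sum(w * p for w, p in zip(row, pair)) + bias)
                  for row, bias in zip(W, b)]
        return sum(h_k * z_k for h_k, z_k in zip(h, hidden))

    scores = [score(pair) for pair in pairs]
    # Softmax over all pairs, shifted by the max for numerical stability.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]

    dim = len(embeddings[0])
    pooled = [sum(w * pair[k] for w, pair in zip(weights, pairs))
              for k in range(dim)]
    return pooled, weights


# With num_fields = 39 as in the config, attention runs over 39 * 38 / 2 pairs.
num_fields = 39
num_pairs = num_fields * (num_fields - 1) // 2  # 741

# Toy sizes (hypothetical) so the demo stays tiny.
rng = random.Random(42)
toy_fields, toy_embedding_dim, toy_attention_dim = 4, 3, 2


def rand_vec(n):
    return [rng.uniform(-1.0, 1.0) for _ in range(n)]


embeddings = [rand_vec(toy_embedding_dim) for _ in range(toy_fields)]
W = [rand_vec(toy_embedding_dim) for _ in range(toy_attention_dim)]
b = rand_vec(toy_attention_dim)
h = rand_vec(toy_attention_dim)

pooled, weights = afm_attention_pooling(embeddings, W, b, h)
print(num_pairs, len(weights), len(pooled))  # prints: 741 6 3
```

The paper's AFM has no stack of hidden layers, so how the `hidden_units` entry is consumed by the class in this commit cannot be inferred from the config alone.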
@@ -0,0 +1,41 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from submarine.ml.pytorch.model.ctr import AFM

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-conf", help="a JSON configuration file for AFM", type=str)
    parser.add_argument("-task_type", default='train',
                        help="train or evaluate, by default is train")
    args = parser.parse_args()

    trainer = AFM(json_path=args.conf)

    if args.task_type == 'train':
        trainer.fit()
        print('[Train Done]')
    elif args.task_type == 'evaluate':
        score = trainer.evaluate()
        print(f'Eval score: {score}')
    elif args.task_type == 'predict':
        pred = trainer.predict()
        print('Predict:', pred)
    else:
        assert False, args.task_type
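A note on the flags: the script registers `-conf` and `-task_type` with a single dash, while the launch command elsewhere in this commit passes `--conf` and `--task_type`. As far as argparse's matching rules go, a double-dash token does not match a single-dash option string, so the two forms do not appear to be interchangeable. The sketch below reproduces that with the standard library only, then shows one possible alternative that uses double-dash flags and `choices` in place of the trailing `assert False`. `build_parser` is a hypothetical helper, not part of the commit.

```python
import argparse

launch_args = ["--conf", "./afm.json", "--task_type", "train"]

# As committed: single-dash option strings.
original = argparse.ArgumentParser()
original.add_argument("-conf", type=str)
original.add_argument("-task_type", default="train")

# parse_known_args collects unrecognised tokens instead of exiting,
# which makes the mismatch visible without a SystemExit.
namespace, extras = original.parse_known_args(launch_args)
print(namespace.conf, extras)


def build_parser():
    """Hypothetical variant: double-dash flags and validated task types."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--conf", type=str, required=True,
                        help="a JSON configuration file for AFM")
    parser.add_argument("--task_type", default="train",
                        choices=["train", "evaluate", "predict"],
                        help="train, evaluate or predict (default: train)")
    return parser


args = build_parser().parse_args(launch_args)
print(args.conf, args.task_type)
```

With `choices`, an unsupported task type is rejected by argparse with a usage message before any model is constructed, and the help text lists `predict`, which the committed help string omits.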
@@ -0,0 +1,41 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


export JAVA_HOME=${JAVA_HOME:-$HOME/workspace/app/java}
export HADOOP_HOME=${HADOOP_HOME:-$HADOOP_HDFS_HOME}
export CLASSPATH=${CLASSPATH:-`hdfs classpath --glob`}
export ARROW_LIBHDFS_DIR=${ARROW_LIBHDFS_DIR:-$HADOOP_HOME/lib/native}

# path to pysubmarine/submarine
PYTHONPATH=$HOME/workspace/submarine/submarine-sdk/pysubmarine

HADOOP_CONF_PATH=${HADOOP_CONF_PATH:-$HADOOP_CONF_DIR}

SUBMARINE_VERSION=0.5.0-SNAPSHOT
SUBMARINE_HADOOP_VERSION=2.9
SUBMARINE_JAR=/opt/submarine-dist-${SUBMARINE_VERSION}-hadoop-${SUBMARINE_HADOOP_VERSION}/submarine-dist-${SUBMARINE_VERSION}-hadoop-${SUBMARINE_HADOOP_VERSION}/submarine-all-${SUBMARINE_VERSION}-hadoop-${SUBMARINE_HADOOP_VERSION}.jar

java -cp $(${HADOOP_COMMON_HOME}/bin/hadoop classpath --glob):${SUBMARINE_JAR}:${HADOOP_CONF_PATH} \
  org.apache.submarine.client.cli.Cli job run --name afm-job-001 \
  --framework pytorch \
  --verbose \
  --input_path "" \
  --num_workers 2 \
  --worker_resources memory=1G,vcores=1 \
  --worker_launch_cmd "JAVA_HOME=$JAVA_HOME HADOOP_HOME=$HADOOP_HOME CLASSPATH=$CLASSPATH ARROW_LIBHDFS_DIR=$ARROW_LIBHDFS_DIR PYTHONPATH=$PYTHONPATH sdk.zip/sdk/bin/python run_afm.py --conf ./afm.json --task_type train" \
  --insecure \
  --conf tony.containers.resources=sdk.zip#archive,${SUBMARINE_JAR},run_afm.py,afm.json
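To make the variable handling in the script above easier to follow, here is a small Python sketch of the same logic: the `${VAR:-default}` fallbacks and the jar path built from the two version variables. The helper names `env_or_default` and `submarine_jar`, the `fake_env` dictionary, and its paths are hypothetical and exist only for this illustration.

```python
import os


def env_or_default(name, default, env=None):
    """Mimic shell ${NAME:-default}: fall back when NAME is unset or empty."""
    env = os.environ if env is None else env
    value = env.get(name, "")
    return value if value else default


def submarine_jar(version="0.5.0-SNAPSHOT", hadoop_version="2.9", root="/opt"):
    """Rebuild the SUBMARINE_JAR path from the version variables in the script."""
    dist = f"submarine-dist-{version}-hadoop-{hadoop_version}"
    return f"{root}/{dist}/{dist}/submarine-all-{version}-hadoop-{hadoop_version}.jar"


# Hypothetical environment: JAVA_HOME set but empty, HADOOP_HOME unset.
fake_env = {"JAVA_HOME": "", "HADOOP_HDFS_HOME": "/usr/lib/hadoop-hdfs"}

java_home = env_or_default("JAVA_HOME", "/home/user/workspace/app/java", fake_env)
hadoop_home = env_or_default("HADOOP_HOME", fake_env["HADOOP_HDFS_HOME"], fake_env)
jar = submarine_jar()

print(java_home)
print(hadoop_home)
print(jar)
```

Two details of the script are worth keeping in mind when adapting it. `PYTHONPATH` is assigned without `export`, which still works here because its value is interpolated into the `--worker_launch_cmd` string. `HADOOP_COMMON_HOME`, used to locate the `hadoop` binary, is never set by the script and so has to come from the calling environment.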
File renamed without changes.