From e580802d893ff2a3ebd53f022a13a727f60b3f03 Mon Sep 17 00:00:00 2001 From: Andrew Hsieh Date: Sun, 19 Jul 2020 15:36:55 +0800 Subject: [PATCH] SUBMARINE-561. [SDK] Add PyTorch implementation of AFM model ### What is this PR for? Add PyTorch implementation of Attentional Factorization Machine for CTR prediction. ([AFM](https://arxiv.org/pdf/1708.04617.pdf)) Make minor modifications to the PyTorch training flow. Add testing for the AFM model. ### What type of PR is it? [Improvement] ### Todos * [ ] - Task ### What is the Jira issue? https://issues.apache.org/jira/browse/SUBMARINE-561 ### How should this be tested? [python-sdk](https://github.com/andrewhsiehth/submarine/actions/runs/169985131) [Submarine](https://github.com/andrewhsiehth/submarine/actions/runs/169985125) ### Screenshots (if appropriate) ### Questions: * Does the licenses files need update? No * Is there breaking changes for older versions? No * Does this needs documentation? No Author: Andrew Hsieh Author: andrewhsiehth Closes #346 from andrewhsiehth/SUBMARINE-561 and squashes the following commits: 0521639 [andrewhsiehth] rename afm && refactor example/pytorch folder f98d59f [andrewhsiehth] mkdir for non-existing output directory 3057899 [andrewhsiehth] use pysubmarine-ci to auto-format f89d070 [Andrew Hsieh] python3.6 yapf d4d93c4 [Andrew Hsieh] try to make python3.5 happy 2929dfc [Andrew Hsieh] try to make codestyle checker happy v2 42d5091 [Andrew Hsieh] try to make codestyle checker happy 9ff2f8d [Andrew Hsieh] fix core, afm coding style adae613 [Andrew Hsieh] fix tqdm 4facbce [Andrew Hsieh] fix conftest.py coding style e4b3e50 [Andrew Hsieh] fix deepfm.py coding style cb6be07 [Andrew Hsieh] fix ctr.__init__ coding style 2b4eecf [Andrew Hsieh] fix base_pytorch_model coding style 573a4e8 [Andrew Hsieh] fix fileio coding style 5d6dfc0 [Andrew Hsieh] add afm testing 827c785 [Andrew Hsieh] update conftest b260042 [Andrew Hsieh] add afm example a7da1c3 [Andrew Hsieh] add afm to ctr ab7b4b7 [Andrew Hsieh] add afm fa151e5 [Andrew Hsieh] fix deepfm 380358c [Andrew Hsieh] fix testing 3f80bc6 [Andrew Hsieh] fix fileio 7471408 [Andrew Hsieh] fix data input_fn and fileio f57d732 [Andrew Hsieh] fix deepfm fdcda05 [Andrew Hsieh] fix layers/core.py ce535fc [Andrew Hsieh] fix optimizer zero_grad --- .../pysubmarine/example/pytorch/afm/afm.json | 49 +++++++++ .../example/pytorch/afm/run_afm.py | 41 +++++++ .../example/pytorch/afm/run_afm.sh | 41 +++++++ .../example/pytorch/{ => deepfm}/deepfm.json | 26 ++--- .../{run_ctr.py => deepfm/run_deepfm.py} | 0 .../pytorch/{ => deepfm}/run_deepfm.sh | 4 +- .../ml/pytorch/input/libsvm_dataset.py | 101 ++++++++++++------ .../submarine/ml/pytorch/layers/core.py | 41 +++---- .../ml/pytorch/model/base_pytorch_model.py | 44 ++++---- .../ml/pytorch/model/ctr/__init__.py | 3 +- .../submarine/ml/pytorch/model/ctr/afm.py | 94 ++++++++++++++++ .../submarine/ml/pytorch/model/ctr/deepfm.py | 28 ++--- .../pysubmarine/submarine/utils/fileio.py | 81 +++++++------- .../tests/ml/pytorch/model/conftest.py | 16 ++- .../ml/pytorch/model/test_afm_pytorch.py | 25 +++++ 15 files changed, 444 insertions(+), 150 deletions(-) create mode 100644 submarine-sdk/pysubmarine/example/pytorch/afm/afm.json create mode 100644 submarine-sdk/pysubmarine/example/pytorch/afm/run_afm.py create mode 100644 submarine-sdk/pysubmarine/example/pytorch/afm/run_afm.sh rename submarine-sdk/pysubmarine/example/pytorch/{ => deepfm}/deepfm.json (53%) rename submarine-sdk/pysubmarine/example/pytorch/{run_ctr.py => deepfm/run_deepfm.py} (100%) rename submarine-sdk/pysubmarine/example/pytorch/{ => deepfm}/run_deepfm.sh (93%) create mode 100644 submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr/afm.py create mode 100644 submarine-sdk/pysubmarine/tests/ml/pytorch/model/test_afm_pytorch.py diff --git a/submarine-sdk/pysubmarine/example/pytorch/afm/afm.json b/submarine-sdk/pysubmarine/example/pytorch/afm/afm.json new file mode 100644 index 0000000000..cc68e9533e --- /dev/null +++ b/submarine-sdk/pysubmarine/example/pytorch/afm/afm.json @@ -0,0 +1,49 @@ +{ + "input": { + "train_data": "../../data/tr.libsvm", + "valid_data": "../../data/va.libsvm", + "test_data": "../../data/te.libsvm", + "type": "libsvm" + }, + "output": { + "save_model_dir": "./output", + "metric": "roc_auc" + }, + "training": { + "batch_size": 512, + "num_epochs": 3, + "log_steps": 10, + "num_threads": 2, + "num_gpus": 0, + "seed": 42, + "mode": "distributed", + "backend": "gloo" + }, + "model": { + "name": "ctr.afm", + "kwargs": { + "num_fields": 39, + "num_features": 117581, + "attention_dim": 64, + "out_features": 1, + "embedding_dim": 256, + "hidden_units": [400, 400, 400], + "dropout_rate": 0.3 + } + }, + "loss": { + "name": "BCEWithLogitsLoss", + "kwargs": {} + }, + "optimizer": { + "name": "adam", + "kwargs": { + "lr": 5e-4 + } + }, + "resource": { + "num_cpus": 4, + "num_gpus": 0, + "num_threads": 2 + } +} diff --git a/submarine-sdk/pysubmarine/example/pytorch/afm/run_afm.py b/submarine-sdk/pysubmarine/example/pytorch/afm/run_afm.py new file mode 100644 index 0000000000..e6c28d6c39 --- /dev/null +++ b/submarine-sdk/pysubmarine/example/pytorch/afm/run_afm.py @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from submarine.ml.pytorch.model.ctr import AFM + +import argparse + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + "-conf", help="a JSON configuration file for AFM", type=str) + parser.add_argument("-task_type", default='train', + help="train or evaluate, by default is train") + args = parser.parse_args() + + trainer = AFM(json_path=args.conf) + + if args.task_type == 'train': + trainer.fit() + print('[Train Done]') + elif args.task_type == 'evaluate': + score = trainer.evaluate() + print(f'Eval score: {score}') + elif args.task_type == 'predict': + pred = trainer.predict() + print('Predict:', pred) + else: + assert False, args.task_type diff --git a/submarine-sdk/pysubmarine/example/pytorch/afm/run_afm.sh b/submarine-sdk/pysubmarine/example/pytorch/afm/run_afm.sh new file mode 100644 index 0000000000..494931cb5d --- /dev/null +++ b/submarine-sdk/pysubmarine/example/pytorch/afm/run_afm.sh @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +export JAVA_HOME=${JAVA_HOME:-$HOME/workspace/app/java} +export HADOOP_HOME=${HADOOP_HOME:-$HADOOP_HDFS_HOME} +export CLASSPATH=${CLASSPATH:-`hdfs classpath --glob`} +export ARROW_LIBHDFS_DIR=${ARROW_LIBHDFS_DIR:-$HADOOP_HOME/lib/native} + +# path to pysubmarine/submarine +PYTHONPATH=$HOME/workspace/submarine/submarine-sdk/pysubmarine + +HADOOP_CONF_PATH=${HADOOP_CONF_PATH:-$HADOOP_CONF_DIR} + +SUBMARINE_VERSION=0.5.0-SNAPSHOT +SUBMARINE_HADOOP_VERSION=2.9 +SUBMARINE_JAR=/opt/submarine-dist-${SUBMARINE_VERSION}-hadoop-${SUBMARINE_HADOOP_VERSION}/submarine-dist-${SUBMARINE_VERSION}-hadoop-${SUBMARINE_HADOOP_VERSION}/submarine-all-${SUBMARINE_VERSION}-hadoop-${SUBMARINE_HADOOP_VERSION}.jar + +java -cp $(${HADOOP_COMMON_HOME}/bin/hadoop classpath --glob):${SUBMARINE_JAR}:${HADOOP_CONF_PATH} \ + org.apache.submarine.client.cli.Cli job run --name afm-job-001 \ + --framework pytorch \ + --verbose \ + --input_path "" \ + --num_workers 2 \ + --worker_resources memory=1G,vcores=1 \ + --worker_launch_cmd "JAVA_HOME=$JAVA_HOME HADOOP_HOME=$HADOOP_HOME CLASSPATH=$CLASSPATH ARROW_LIBHDFS_DIR=$ARROW_LIBHDFS_DIR PYTHONPATH=$PYTHONPATH sdk.zip/sdk/bin/python run_afm.py --conf ./afm.json --task_type train" \ + --insecure \ + --conf tony.containers.resources=sdk.zip#archive,${SUBMARINE_JAR},run_afm.py,afm.json + diff --git a/submarine-sdk/pysubmarine/example/pytorch/deepfm.json b/submarine-sdk/pysubmarine/example/pytorch/deepfm/deepfm.json similarity index 53% rename from submarine-sdk/pysubmarine/example/pytorch/deepfm.json rename to submarine-sdk/pysubmarine/example/pytorch/deepfm/deepfm.json index a1c70690b4..41a694b1a9 100644 --- a/submarine-sdk/pysubmarine/example/pytorch/deepfm.json +++ b/submarine-sdk/pysubmarine/example/pytorch/deepfm/deepfm.json @@ -1,8 +1,8 @@ { "input": { - "train_data": "../data/tr.libsvm", - "valid_data": "../data/va.libsvm", - "test_data": "../data/te.libsvm", + "train_data": "../../data/tr.libsvm", + "valid_data": "../../data/va.libsvm", + "test_data": "../../data/te.libsvm", "type": "libsvm" }, "output": { @@ -10,10 +10,10 @@ "metric": "roc_auc" }, "training": { - "batch_size": 64, - "num_epochs": 1, + "batch_size": 512, + "num_epochs": 3, "log_steps": 10, - "num_threads": 0, + "num_threads": 2, "num_gpus": 0, "seed": 42, "mode": "distributed", @@ -22,12 +22,12 @@ "model": { "name": "ctr.deepfm", "kwargs": { - "field_dims": [15, 52, 30, 19, 111, 51, 26, 19, 53, 5, 13, 8, 23, 21, 77, 25, 39, 11, - 8, 61, 15, 3, 34, 75, 30, 79, 11, 85, 37, 10, 94, 19, 5, 32, 6, 12, 42, 18, 23], + "num_fields": 39, + "num_features": 117581, "out_features": 1, - "embedding_dim": 16, - "hidden_units": [400, 400], - "dropout_rates": [0.2, 0.2] + "embedding_dim": 256, + "hidden_units": [400, 400, 400], + "dropout_rates": [0.3, 0.3, 0.3] } }, "loss": { @@ -37,12 +37,12 @@ "optimizer": { "name": "adam", "kwargs": { - "lr": 1e-3 + "lr": 5e-4 } }, "resource": { "num_cpus": 4, "num_gpus": 0, - "num_threads": 0 + "num_threads": 2 } } diff --git a/submarine-sdk/pysubmarine/example/pytorch/run_ctr.py b/submarine-sdk/pysubmarine/example/pytorch/deepfm/run_deepfm.py similarity index 100% rename from submarine-sdk/pysubmarine/example/pytorch/run_ctr.py rename to submarine-sdk/pysubmarine/example/pytorch/deepfm/run_deepfm.py diff --git a/submarine-sdk/pysubmarine/example/pytorch/run_deepfm.sh b/submarine-sdk/pysubmarine/example/pytorch/deepfm/run_deepfm.sh similarity index 93% rename from submarine-sdk/pysubmarine/example/pytorch/run_deepfm.sh rename to submarine-sdk/pysubmarine/example/pytorch/deepfm/run_deepfm.sh index 96f8add4fb..8cfe096024 100644 --- a/submarine-sdk/pysubmarine/example/pytorch/run_deepfm.sh +++ b/submarine-sdk/pysubmarine/example/pytorch/deepfm/run_deepfm.sh @@ -35,7 +35,7 @@ java -cp $(${HADOOP_COMMON_HOME}/bin/hadoop classpath --glob):${SUBMARINE_JAR}:$ --input_path "" \ --num_workers 2 \ --worker_resources memory=1G,vcores=1 \ - --worker_launch_cmd "JAVA_HOME=$JAVA_HOME HADOOP_HOME=$HADOOP_HOME CLASSPATH=$CLASSPATH ARROW_LIBHDFS_DIR=$ARROW_LIBHDFS_DIR PYTHONPATH=$PYTHONPATH sdk.zip/sdk/bin/python run_ctr.py --conf ./deepfm.json --task_type train" \ + --worker_launch_cmd "JAVA_HOME=$JAVA_HOME HADOOP_HOME=$HADOOP_HOME CLASSPATH=$CLASSPATH ARROW_LIBHDFS_DIR=$ARROW_LIBHDFS_DIR PYTHONPATH=$PYTHONPATH sdk.zip/sdk/bin/python run_deepfm.py --conf ./deepfm.json --task_type train" \ --insecure \ - --conf tony.containers.resources=sdk.zip#archive,${SUBMARINE_JAR},run_ctr.py,deepfm.json + --conf tony.containers.resources=sdk.zip#archive,${SUBMARINE_JAR},run_deepfm.py,deepfm.json diff --git a/submarine-sdk/pysubmarine/submarine/ml/pytorch/input/libsvm_dataset.py b/submarine-sdk/pysubmarine/submarine/ml/pytorch/input/libsvm_dataset.py index a0c8a4ed79..bb33697a75 100644 --- a/submarine-sdk/pysubmarine/submarine/ml/pytorch/input/libsvm_dataset.py +++ b/submarine-sdk/pysubmarine/submarine/ml/pytorch/input/libsvm_dataset.py @@ -13,59 +13,98 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pandas as pd +import numpy as np import torch from torch.utils.data import DataLoader, Dataset from torch.utils.data.distributed import DistributedSampler -from submarine.utils.fileio import read_file +from submarine.utils.fileio import open_buffered_file_reader, file_info + +import os +import itertools +import functools +import multiprocessing as mp +from typing import List, Tuple class LIBSVMDataset(Dataset): - def __init__(self, path): - self.data, self.label = self.preprocess_data(read_file(path)) + def __init__(self, data_uri: str, sample_offset: np.ndarray): + self.data_uri = data_uri + self.sample_offset = sample_offset + + def __len__(self) -> int: + return len(self.sample_offset) + + def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor, int]: + with open_buffered_file_reader(self.data_uri) as infile: + infile.seek(self.sample_offset[idx], os.SEEK_SET) + sample = infile.readline() + return LIBSVMDataset.parse_sample(sample) - def __getitem__(self, idx): - return self.data.iloc[idx], self.label.iloc[idx] + @classmethod + def parse_sample(cls, + sample: bytes) -> Tuple[torch.Tensor, torch.Tensor, int]: + label, *entries = sample.rstrip(b'\n').split(b' ') + feature_idx = torch.zeros(len(entries), dtype=torch.long) + feature_value = torch.zeros(len(entries), dtype=torch.float) + for i, entry in enumerate(entries): + fidx, fvalue = entry.split(b':') + feature_idx[i], feature_value[i] = int(fidx), float(fvalue) + return feature_idx, feature_value, int(label) - def __len__(self): - return len(self.data) + @classmethod + def prepare_dataset(cls, data_uri: str, n_jobs: int = os.cpu_count()): + sample_offset = LIBSVMDataset._locate_sample_offsets(data_uri=data_uri, + n_jobs=n_jobs) + return LIBSVMDataset(data_uri=data_uri, sample_offset=sample_offset) - def preprocess_data(self, stream): + @classmethod + def _locate_sample_offsets(cls, data_uri: str, n_jobs: int) -> np.ndarray: + finfo = file_info(data_uri) + chunk_size, _ = divmod(finfo.size, n_jobs) - def _convert_line(line): - feat_ids = [] - feat_vals = [] - for x in line: - feat_id, feat_val = x.split(':') - feat_ids.append(int(feat_id)) - feat_vals.append(float(feat_val)) - return (torch.as_tensor(feat_ids, dtype=torch.int64), - torch.as_tensor(feat_vals, dtype=torch.float32)) + chunk_starts = [0] + with open_buffered_file_reader(data_uri) as infile: + while infile.tell() < finfo.size: + infile.seek(chunk_size, os.SEEK_CUR) + infile.readline() + chunk_starts.append(min(infile.tell(), finfo.size)) - df = pd.read_table(stream, header=None) - df = df.loc[:, 0].str.split(n=1, expand=True) - label = df.loc[:, 0].apply(int) - data = df.loc[:, 1].str.split().apply(_convert_line) - return data, label + with mp.Pool(processes=n_jobs, + maxtasksperchild=1) as pool: + return np.asarray( + list( + itertools.chain.from_iterable( + pool.imap(functools.partial( + LIBSVMDataset._locate_sample_offsets_job, data_uri), + iterable=enumerate( + zip(chunk_starts[:-1], + chunk_starts[1:])))))) - def collate_fn(self, batch): - data, label = tuple(zip(*batch)) - _, feat_val = tuple(zip(*data)) - return (torch.stack(feat_val, dim=0).type(torch.long), - torch.as_tensor(label, dtype=torch.float32).unsqueeze(dim=-1)) + @classmethod + def _locate_sample_offsets_job( + cls, data_uri: str, task: Tuple[int, Tuple[int, int]]) -> List[int]: + _, (start, end) = task + offsets = [start] + with open_buffered_file_reader(data_uri) as infile: + infile.seek(start, os.SEEK_SET) + while infile.tell() < end: + infile.readline() + offsets.append(infile.tell()) + assert offsets.pop() == end + return offsets def libsvm_input_fn(filepath, batch_size=256, num_threads=1, **kwargs): def _input_fn(): - dataset = LIBSVMDataset(filepath) + dataset = LIBSVMDataset.prepare_dataset(data_uri=filepath, + n_jobs=num_threads) sampler = DistributedSampler(dataset) return DataLoader(dataset=dataset, batch_size=batch_size, sampler=sampler, - num_workers=num_threads, - collate_fn=dataset.collate_fn) + num_workers=0) # should be 0 (pytorch bug) return _input_fn diff --git a/submarine-sdk/pysubmarine/submarine/ml/pytorch/layers/core.py b/submarine-sdk/pysubmarine/submarine/ml/pytorch/layers/core.py index 6fff591d40..98c8472060 100644 --- a/submarine-sdk/pysubmarine/submarine/ml/pytorch/layers/core.py +++ b/submarine-sdk/pysubmarine/submarine/ml/pytorch/layers/core.py @@ -13,52 +13,45 @@ # See the License for the specific language governing permissions and # limitations under the License. -from itertools import accumulate - import torch from torch import nn -class FieldLinear(nn.Module): +class FeatureLinear(nn.Module): - def __init__(self, field_dims, out_features): + def __init__(self, num_features, out_features): """ - :param field_dims: List of dimensions of each field. + :param num_features: number of total features. :param out_features: The number of output features. """ super().__init__() - self.weight = nn.Embedding(num_embeddings=sum(field_dims), + self.weight = nn.Embedding(num_embeddings=num_features, embedding_dim=out_features) self.bias = nn.Parameter(torch.zeros((out_features,))) - self.register_buffer( - 'offset', - torch.as_tensor([0, *accumulate(field_dims)][:-1], - dtype=torch.long)) - def forward(self, x): + def forward(self, feature_idx, feature_value): """ - :param x: torch.LongTensor (batch_size, num_fields) + :param feature_idx: torch.LongTensor (batch_size, num_fields) + :param feature_value: torch.LongTensor (batch_size, num_fields) """ - return torch.sum(self.weight(x + self.offset), dim=1) + self.bias + return torch.sum( + self.weight(feature_idx) * feature_value.unsqueeze(dim=-1), + dim=1) + self.bias -class FieldEmbedding(nn.Module): +class FeatureEmbedding(nn.Module): - def __init__(self, field_dims, embedding_dim): + def __init__(self, num_features, embedding_dim): super().__init__() - self.weight = nn.Embedding(num_embeddings=sum(field_dims), + self.weight = nn.Embedding(num_embeddings=num_features, embedding_dim=embedding_dim) - self.register_buffer( - 'offset', - torch.as_tensor([0, *accumulate(field_dims)][:-1], - dtype=torch.long)) - def forward(self, x): + def forward(self, feature_idx, feature_value): """ - :param x: torch.LongTensor (batch_size, num_fields) + :param feature_idx: torch.LongTensor (batch_size, num_fields) + :param feature_value: torch.LongTensor (batch_size, num_fields) """ - return self.weight( - x + self.offset) # (batch_size, num_fields, embedding_dim) + return self.weight(feature_idx) * feature_value.unsqueeze(dim=-1) class PairwiseInteraction(nn.Module): diff --git a/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/base_pytorch_model.py b/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/base_pytorch_model.py index b862a0ed92..168a78ad8f 100644 --- a/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/base_pytorch_model.py +++ b/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/base_pytorch_model.py @@ -17,6 +17,7 @@ import logging import os from abc import ABC +from pathlib import Path import torch from torch import distributed @@ -44,6 +45,9 @@ def __init__(self, params=None, json_path=None): self.params = get_from_dicts(params, default_parameters) self.params = get_from_json(json_path, self.params) self._sanity_check() + Path(self.params['output'] + ['save_model_dir']).expanduser().resolve().mkdir(parents=True, + exist_ok=True) logging.info("Model parameters : %s", self.params) self.input_type = self.params['input']['type'] @@ -68,32 +72,34 @@ def __del__(self): distributed.destroy_process_group() def train(self, train_loader): - for _, batch in enumerate(train_loader): - sample, target = batch - output = self.model(sample) - loss = self.loss(output, target) - loss.backward() - self.optimizer.zero_grad() - self.optimizer.step() + self.model.train() + with torch.enable_grad(): + for _, batch in enumerate(train_loader): + feature_idx, feature_value, label = batch + output = self.model(feature_idx, feature_value).squeeze() + loss = self.loss(output, label.float()) + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() def evaluate(self): outputs = [] - targets = [] + labels = [] valid_loader = get_from_registry(self.input_type, input_fn_registry)( filepath=self.params['input']['valid_data'], **self.params['training'])() - + self.model.eval() with torch.no_grad(): for _, batch in enumerate(valid_loader): - sample, target = batch - output = self.model(sample) + feature_idx, feature_value, label = batch + output = self.model(feature_idx, feature_value).squeeze() outputs.append(output) - targets.append(target) + labels.append(label) return self.metric( - torch.cat(targets, dim=0).cpu().numpy(), + torch.cat(labels, dim=0).cpu().numpy(), torch.cat(outputs, dim=0).cpu().numpy()) def predict(self): @@ -102,12 +108,12 @@ def predict(self): test_loader = get_from_registry(self.input_type, input_fn_registry)( filepath=self.params['input']['test_data'], **self.params['training'])() - + self.model.eval() with torch.no_grad(): for _, batch in enumerate(test_loader): - sample, _ = batch - output = self.model(sample) - outputs.append(output) + feature_idx, feature_value, _ = batch + output = self.model(feature_idx, feature_value).squeeze() + outputs.append(torch.sigmoid(output)) return torch.cat(outputs, dim=0).cpu().numpy() @@ -141,8 +147,8 @@ def save_checkpoint(self): 'optimizer': self.optimizer.state_dict() }, buffer) write_file(buffer, - path=os.path.join( - self.params['output']['save_model_dir'], 'ckpt.pkl')) + uri=os.path.join(self.params['output']['save_model_dir'], + 'ckpt.pkl')) def model_fn(self, params): seed = params["training"]["seed"] diff --git a/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr/__init__.py b/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr/__init__.py index 3fb493537e..34bc8d677f 100644 --- a/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr/__init__.py +++ b/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr/__init__.py @@ -14,5 +14,6 @@ # limitations under the License. from .deepfm import DeepFM +from .afm import AFM -__all__ = ["DeepFM"] +__all__ = ['DeepFM', 'AFM'] diff --git a/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr/afm.py b/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr/afm.py new file mode 100644 index 0000000000..e5314f753c --- /dev/null +++ b/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr/afm.py @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn + +from submarine.ml.pytorch.layers.core import (FeatureEmbedding, FeatureLinear) +from submarine.ml.pytorch.model.base_pytorch_model import BasePyTorchModel + + +class AFM(BasePyTorchModel): + + def model_fn(self, params): + super().model_fn(params) + return _AFM(**self.params['model']['kwargs']) + + +class _AFM(nn.Module): + + def __init__(self, num_features: int, embedding_dim: int, + attention_dim: int, out_features: int, dropout_rate: float, + **kwargs): + super().__init__() + self.feature_linear = FeatureLinear(num_features=num_features, + out_features=out_features) + self.feature_embedding = FeatureEmbedding(num_features=num_features, + embedding_dim=embedding_dim) + self.attentional_interaction = AttentionalInteratction( + embedding_dim=embedding_dim, + attention_dim=attention_dim, + out_features=out_features, + dropout_rate=dropout_rate) + + def forward(self, feature_idx: torch.LongTensor, + feature_value: torch.LongTensor): + """ + :param feature_idx: torch.LongTensor (batch_size, num_fields) + :param feature_value: torch.LongTensor (batch_size, num_fields) + """ + return self.feature_linear( + feature_idx, feature_value) + self.attentional_interaction( + self.feature_embedding(feature_idx, feature_value)) + + +class AttentionalInteratction(nn.Module): + + def __init__(self, embedding_dim: int, attention_dim: int, + out_features: int, dropout_rate: float): + super().__init__() + self.attention_score = nn.Sequential( + nn.Linear(in_features=embedding_dim, out_features=attention_dim), + nn.ReLU(), nn.Linear(in_features=attention_dim, out_features=1), + nn.Softmax(dim=1)) + self.pairwise_product = PairwiseProduct() + self.dropout = nn.Dropout(p=dropout_rate) + self.fc = nn.Linear(in_features=embedding_dim, + out_features=out_features) + + def forward(self, x: torch.FloatTensor): + """ + :param x: torch.FloatTensor (batch_size, num_fields, embedding_dim) + """ + x = self.pairwise_product(x) + score = self.attention_score(x) + attentioned = torch.sum(score * x, dim=1) + return self.fc(self.dropout(attentioned)) + + +class PairwiseProduct(nn.Module): + + def forward(self, x: torch.FloatTensor): + """ + :param x: torch.FloatTensor (batch_sie, num_fields, embedding_dim) + """ + _, num_fields, _ = x.size() + + all_pairs_product = x.unsqueeze(dim=1) * x.unsqueeze(dim=2) + idx_row, idx_col = torch.unbind(torch.triu_indices(num_fields, + num_fields, + offset=1), + dim=0) + return all_pairs_product[:, idx_row, idx_col] diff --git a/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr/deepfm.py b/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr/deepfm.py index 6c955d797d..d6c86ae0e9 100644 --- a/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr/deepfm.py +++ b/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr/deepfm.py @@ -16,7 +16,8 @@ import torch from torch import nn -from submarine.ml.pytorch.layers.core import (DNN, FieldEmbedding, FieldLinear, +from submarine.ml.pytorch.layers.core import (DNN, FeatureEmbedding, + FeatureLinear, PairwiseInteraction) from submarine.ml.pytorch.model.base_pytorch_model import BasePyTorchModel @@ -30,25 +31,28 @@ def model_fn(self, params): class _DeepFM(nn.Module): - def __init__(self, field_dims, embedding_dim, out_features, hidden_units, - dropout_rates, **kwargs): + def __init__(self, num_fields, num_features, embedding_dim, out_features, + hidden_units, dropout_rates, **kwargs): super().__init__() - self.field_linear = FieldLinear(field_dims=field_dims, - out_features=out_features) - self.field_embedding = FieldEmbedding(field_dims=field_dims, - embedding_dim=embedding_dim) + self.feature_linear = FeatureLinear(num_features=num_features, + out_features=out_features) + self.feature_embedding = FeatureEmbedding(num_features=num_features, + embedding_dim=embedding_dim) self.pairwise_interaction = PairwiseInteraction() - self.dnn = DNN(in_features=len(field_dims) * embedding_dim, + self.dnn = DNN(in_features=num_fields * embedding_dim, out_features=out_features, hidden_units=hidden_units, dropout_rates=dropout_rates) - def forward(self, x): + def forward(self, feature_idx, feature_value): """ - :param x: torch.LongTensor (batch_size, num_fields) + :param feature_idx: torch.LongTensor (batch_size, num_fields) + :param feature_value: torch.LongTensor (batch_size, num_fields) """ - emb = self.field_embedding(x) # (batch_size, num_fields, embedding_dim) - linear_logit = self.field_linear(x) + emb = self.feature_embedding( + feature_idx, + feature_value) # (batch_size, num_fields, embedding_dim) + linear_logit = self.feature_linear(feature_idx, feature_value) fm_logit = self.pairwise_interaction(emb) deep_logit = self.dnn(torch.flatten(emb, start_dim=1)) diff --git a/submarine-sdk/pysubmarine/submarine/utils/fileio.py b/submarine-sdk/pysubmarine/submarine/utils/fileio.py index 699e1a5cbc..d756757d02 100644 --- a/submarine-sdk/pysubmarine/submarine/utils/fileio.py +++ b/submarine-sdk/pysubmarine/submarine/utils/fileio.py @@ -14,59 +14,62 @@ # limitations under the License. import io -import os -from enum import Enum +from pathlib import Path +from typing import Tuple from urllib.parse import urlparse from pyarrow import fs -class _Scheme(Enum): - HDFS = 'hdfs' - FILE = 'file' - DEFAULT = '' +def open_buffered_file_reader( + uri: str, + buffer_size: int = io.DEFAULT_BUFFER_SIZE) -> io.BufferedReader: + try: + input_file = open_input_file(uri) + return io.BufferedReader(input_file, buffer_size=buffer_size) + except Exception as e: + input_file.close() + raise e -def read_file(path): - scheme, host, port, path = _parse_path(path) - if _Scheme(scheme) is _Scheme.HDFS: - return _read_hdfs(host=host, port=port, path=path) - else: - return _read_local(path=path) +def open_buffered_stream_writer( + uri: str, + buffer_size: int = io.DEFAULT_BUFFER_SIZE) -> io.BufferedWriter: + try: + output_stream = open_output_stream(uri) + return io.BufferedWriter(output_stream, buffer_size=buffer_size) + except Exception as e: + output_stream.close() + raise e -def write_file(buffer, path): - scheme, host, port, path = _parse_path(path) - if _Scheme(scheme) is _Scheme.HDFS: - _write_hdfs(buffer=buffer, host=host, port=port, path=path) - else: - _write_local(buffer=buffer, path=path) +def write_file(buffer: io.BytesIO, + uri: str, + buffer_size: int = io.DEFAULT_BUFFER_SIZE) -> None: + with open_buffered_stream_writer(uri, + buffer_size=buffer_size) as output_stream: + output_stream.write(buffer.getbuffer()) -def _parse_path(path): - p = urlparse(path) - return p.scheme, p.hostname, p.port, os.path.abspath(p.path) +def open_input_file(uri: str): + filesystem, path = _parse_uri(uri) + return filesystem.open_input_file(path) -def _read_hdfs(host, port, path): - hdfs = fs.HadoopFileSystem(host=host, port=port) - with hdfs.open_input_stream(path) as stream: - data = stream.read() - return io.BytesIO(data) +def open_output_stream(uri: str): + filesystem, path = _parse_uri(uri) + return filesystem.open_output_stream(path) -def _read_local(path): - with open(path, mode='rb') as f: - data = f.read() - return io.BytesIO(data) +def file_info(uri: str) -> fs.FileInfo: + filesystem, path = _parse_uri(uri) + info, = filesystem.get_file_info([path]) + return info -def _write_hdfs(buffer, host, port, path): - hdfs = fs.HadoopFileSystem(host=host, port=port) - with hdfs.open_output_stream(path) as stream: - stream.write(buffer.getbuffer()) - - -def _write_local(buffer, path): - with open(path, mode='wb') as f: - f.write(buffer.getbuffer()) +def _parse_uri(uri: str) -> Tuple[fs.FileSystem, str]: + parsed = urlparse(uri) + uri = uri if parsed.scheme else str( + Path(parsed.path).expanduser().absolute()) + filesystem, path = fs.FileSystem.from_uri(uri) + return filesystem, path diff --git a/submarine-sdk/pysubmarine/tests/ml/pytorch/model/conftest.py b/submarine-sdk/pysubmarine/tests/ml/pytorch/model/conftest.py index 997a709b7e..2da47efa43 100644 --- a/submarine-sdk/pysubmarine/tests/ml/pytorch/model/conftest.py +++ b/submarine-sdk/pysubmarine/tests/ml/pytorch/model/conftest.py @@ -18,10 +18,9 @@ import pytest # noqa -LIBSVM_DATA = """ -0 0:0 1:0 2:0 3:0 4:0 5:0 6:0 7:0 8:0 9:0 10:0 11:0 12:0 13:0 14:0 15:24 16:38 17:0 18:0 19:60 20:0 21:0 22:33 23:74 24:29 25:78 26:0 27:84 28:36 29:0 30:0 31:0 32:0 33:31 34:0 35:0 36:41 37:0 38:22 +LIBSVM_DATA = """0 0:0 1:0 2:0 3:0 4:0 5:0 6:0 7:0 8:0 9:0 10:0 11:0 12:0 13:0 14:0 15:24 16:38 17:0 18:0 19:60 20:0 21:0 22:33 23:74 24:29 25:78 26:0 27:84 28:36 29:0 30:0 31:0 32:0 33:31 34:0 35:0 36:41 37:0 38:22 0 0:1 1:1 2:1 3:1 4:1 5:1 6:1 7:0 8:1 9:0 10:1 11:0 12:1 13:0 14:1 15:0 16:0 17:0 18:1 19:60 20:1 21:0 22:33 23:74 24:0 25:78 26:1 27:0 28:0 29:1 30:1 31:0 32:1 33:0 34:0 35:0 36:0 37:0 38:0 -0 0:1 1:1 2:2 3:2 4:2 5:2 6:2 7:0 8:2 9:0 10:2 11:1 12:2 13:1 14:2 15:1 16:1 17:0 18:0 19:60 20:1 21:0 22:0 23:74 24:1 25:78 26:2 27:84 28:1 29:2 30:2 31:1 32:2 33:1 34:1 35:0 36:1 37:1 38:1 +0 0:1 1:1 2:2 3:2 4:2 5:2 6:2 7:0 8:2 9:0 10:2 11:1 12:2 13:1 14:2 15:1 16:1 17:0 18:0 19:60 20:1 21:0 22:0 23:74 24:1 25:78 26:2 27:84 28:1 29:2 30:2 31:1 32:2 33:1 34:1 35:0 36:1 37:1 38:1 0 0:2 1:2 2:3 3:3 4:3 5:3 6:3 7:1 8:3 9:1 10:3 11:0 12:3 13:0 14:3 15:24 16:38 17:0 18:1 19:60 20:1 21:0 22:1 23:0 24:29 25:0 26:2 27:1 28:36 29:3 30:3 31:1 32:2 33:31 34:0 35:0 36:2 37:1 38:1 0 0:3 1:3 2:3 3:0 4:4 5:4 6:2 7:1 8:3 9:0 10:1 11:0 12:4 13:2 14:4 15:24 16:2 17:0 18:2 19:60 20:1 21:0 22:33 23:74 24:2 25:78 26:0 27:84 28:2 29:3 30:93 31:1 32:2 33:2 34:0 35:1 36:3 37:1 38:1 0 0:2 1:3 2:3 3:3 4:5 5:3 6:3 7:1 8:4 9:1 10:3 11:0 12:3 13:3 14:76 15:24 16:38 17:1 18:3 19:0 20:1 21:0 22:0 23:1 24:29 25:1 26:2 27:84 28:36 29:4 30:93 31:1 32:2 33:31 34:2 35:2 36:41 37:1 38:1 @@ -56,7 +55,7 @@ def get_model_param(tmpdir): "batch_size": 4, "num_epochs": 1, "log_steps": 10, - "num_threads": 0, + "num_threads": 1, "num_gpus": 0, "seed": 42, "mode": "distributed", @@ -65,14 +64,13 @@ def get_model_param(tmpdir): "model": { "name": "ctr.deepfm", "kwargs": { - "field_dims": [ - 15, 52, 30, 19, 111, 51, 26, 19, 53, 5, 13, 8, 23, 21, 77, - 25, 39, 11, 8, 61, 15, 3, 34, 75, 30, 79, 11, 85, 37, 10, - 94, 19, 5, 32, 6, 12, 42, 18, 23 - ], + "num_fields": 39, + "num_features": 117581, "out_features": 1, "embedding_dim": 16, + "attention_dim": 64, "hidden_units": [400, 400], + "dropout_rate": 0.3, "dropout_rates": [0.2, 0.2] } }, diff --git a/submarine-sdk/pysubmarine/tests/ml/pytorch/model/test_afm_pytorch.py b/submarine-sdk/pysubmarine/tests/ml/pytorch/model/test_afm_pytorch.py new file mode 100644 index 0000000000..72befacbdb --- /dev/null +++ b/submarine-sdk/pysubmarine/tests/ml/pytorch/model/test_afm_pytorch.py @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from submarine.ml.pytorch.model.ctr import AFM + + +def test_run_afm(get_model_param): + param = get_model_param + + trainer = AFM(param) + trainer.fit() + trainer.evaluate() + trainer.predict()