From e4b3e50404a8fff7225663abf1ea6aa1cf34101f Mon Sep 17 00:00:00 2001 From: Andrew Hsieh Date: Sun, 12 Jul 2020 21:51:22 +0800 Subject: [PATCH] fix deepfm.py coding style --- .../submarine/ml/pytorch/model/ctr/deepfm.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr/deepfm.py b/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr/deepfm.py index c560f6829e..d6c86ae0e9 100644 --- a/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr/deepfm.py +++ b/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr/deepfm.py @@ -16,7 +16,8 @@ import torch from torch import nn -from submarine.ml.pytorch.layers.core import (DNN, FeatureEmbedding, FeatureLinear, +from submarine.ml.pytorch.layers.core import (DNN, FeatureEmbedding, + FeatureLinear, PairwiseInteraction) from submarine.ml.pytorch.model.base_pytorch_model import BasePyTorchModel @@ -30,13 +31,13 @@ def model_fn(self, params): class _DeepFM(nn.Module): - def __init__(self, num_fields, num_features, embedding_dim, out_features, hidden_units, - dropout_rates, **kwargs): + def __init__(self, num_fields, num_features, embedding_dim, out_features, + hidden_units, dropout_rates, **kwargs): super().__init__() self.feature_linear = FeatureLinear(num_features=num_features, - out_features=out_features) + out_features=out_features) self.feature_embedding = FeatureEmbedding(num_features=num_features, - embedding_dim=embedding_dim) + embedding_dim=embedding_dim) self.pairwise_interaction = PairwiseInteraction() self.dnn = DNN(in_features=num_fields * embedding_dim, out_features=out_features, @@ -48,7 +49,9 @@ def forward(self, feature_idx, feature_value): :param feature_idx: torch.LongTensor (batch_size, num_fields) :param feature_value: torch.LongTensor (batch_size, num_fields) """ - emb = self.feature_embedding(feature_idx, feature_value) # (batch_size, num_fields, embedding_dim) + emb = self.feature_embedding( + feature_idx, + feature_value) # (batch_size, num_fields, embedding_dim) linear_logit = self.feature_linear(feature_idx, feature_value) fm_logit = self.pairwise_interaction(emb) deep_logit = self.dnn(torch.flatten(emb, start_dim=1))