Skip to content
This repository has been archived by the owner on Jul 10, 2024. It is now read-only.

Commit

Permalink
fix deepfm.py coding style
Browse files Browse the repository at this point in the history
  • Loading branch information
ifndef012 committed Jul 18, 2020
1 parent cb6be07 commit e4b3e50
Showing 1 changed file with 9 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
import torch
from torch import nn

from submarine.ml.pytorch.layers.core import (DNN, FeatureEmbedding, FeatureLinear,
from submarine.ml.pytorch.layers.core import (DNN, FeatureEmbedding,
FeatureLinear,
PairwiseInteraction)
from submarine.ml.pytorch.model.base_pytorch_model import BasePyTorchModel

Expand All @@ -30,13 +31,13 @@ def model_fn(self, params):

class _DeepFM(nn.Module):

def __init__(self, num_fields, num_features, embedding_dim, out_features, hidden_units,
dropout_rates, **kwargs):
def __init__(self, num_fields, num_features, embedding_dim, out_features,
hidden_units, dropout_rates, **kwargs):
super().__init__()
self.feature_linear = FeatureLinear(num_features=num_features,
out_features=out_features)
out_features=out_features)
self.feature_embedding = FeatureEmbedding(num_features=num_features,
embedding_dim=embedding_dim)
embedding_dim=embedding_dim)
self.pairwise_interaction = PairwiseInteraction()
self.dnn = DNN(in_features=num_fields * embedding_dim,
out_features=out_features,
Expand All @@ -48,7 +49,9 @@ def forward(self, feature_idx, feature_value):
:param feature_idx: torch.LongTensor (batch_size, num_fields)
:param feature_value: torch.LongTensor (batch_size, num_fields)
"""
emb = self.feature_embedding(feature_idx, feature_value) # (batch_size, num_fields, embedding_dim)
emb = self.feature_embedding(
feature_idx,
feature_value) # (batch_size, num_fields, embedding_dim)
linear_logit = self.feature_linear(feature_idx, feature_value)
fm_logit = self.pairwise_interaction(emb)
deep_logit = self.dnn(torch.flatten(emb, start_dim=1))
Expand Down

0 comments on commit e4b3e50

Please sign in to comment.