Skip to content
This repository has been archived by the owner on Jul 10, 2024. It is now read-only.

Commit

Permalink
fix deepfm
Browse files Browse the repository at this point in the history
  • Loading branch information
ifndef012 committed Jul 18, 2020
1 parent fdcda05 commit f57d732
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr/deepfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import torch
from torch import nn

from submarine.ml.pytorch.layers.core import (DNN, FieldEmbedding, FieldLinear,
from submarine.ml.pytorch.layers.core import (DNN, FeatureEmbedding, FeatureLinear,
PairwiseInteraction)
from submarine.ml.pytorch.model.base_pytorch_model import BasePyTorchModel

Expand All @@ -30,25 +30,26 @@ def model_fn(self, params):

class _DeepFM(nn.Module):

def __init__(self, field_dims, embedding_dim, out_features, hidden_units,
def __init__(self, num_fields, num_features, embedding_dim, out_features, hidden_units,
dropout_rates, **kwargs):
super().__init__()
self.field_linear = FieldLinear(field_dims=field_dims,
self.field_linear = FeatureLinear(num_features=num_features,
out_features=out_features)
self.field_embedding = FieldEmbedding(field_dims=field_dims,
self.field_embedding = FeatureEmbedding(num_features=num_features,
embedding_dim=embedding_dim)
self.pairwise_interaction = PairwiseInteraction()
self.dnn = DNN(in_features=len(field_dims) * embedding_dim,
self.dnn = DNN(in_features=num_fields * embedding_dim,
out_features=out_features,
hidden_units=hidden_units,
dropout_rates=dropout_rates)

def forward(self, x):
def forward(self, feature_idx, feature_value):
"""
:param x: torch.LongTensor (batch_size, num_fields)
:param feature_idx: torch.LongTensor (batch_size, num_fields)
:param feature_value: torch.LongTensor (batch_size, num_fields)
"""
emb = self.field_embedding(x) # (batch_size, num_fields, embedding_dim)
linear_logit = self.field_linear(x)
emb = self.field_embedding(feature_idx, feature_value) # (batch_size, num_fields, embedding_dim)
linear_logit = self.field_linear(feature_idx, feature_value)
fm_logit = self.pairwise_interaction(emb)
deep_logit = self.dnn(torch.flatten(emb, start_dim=1))

Expand Down

0 comments on commit f57d732

Please sign in to comment.