diff --git a/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr/deepfm.py b/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr/deepfm.py index 6c955d797d..53e7a75727 100644 --- a/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr/deepfm.py +++ b/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr/deepfm.py @@ -16,7 +16,7 @@ import torch from torch import nn -from submarine.ml.pytorch.layers.core import (DNN, FieldEmbedding, FieldLinear, +from submarine.ml.pytorch.layers.core import (DNN, FeatureEmbedding, FeatureLinear, PairwiseInteraction) from submarine.ml.pytorch.model.base_pytorch_model import BasePyTorchModel @@ -30,25 +30,26 @@ def model_fn(self, params): class _DeepFM(nn.Module): - def __init__(self, field_dims, embedding_dim, out_features, hidden_units, + def __init__(self, num_fields, num_features, embedding_dim, out_features, hidden_units, dropout_rates, **kwargs): super().__init__() - self.field_linear = FieldLinear(field_dims=field_dims, + self.field_linear = FeatureLinear(num_features=num_features, out_features=out_features) - self.field_embedding = FieldEmbedding(field_dims=field_dims, + self.field_embedding = FeatureEmbedding(num_features=num_features, embedding_dim=embedding_dim) self.pairwise_interaction = PairwiseInteraction() - self.dnn = DNN(in_features=len(field_dims) * embedding_dim, + self.dnn = DNN(in_features=num_fields * embedding_dim, out_features=out_features, hidden_units=hidden_units, dropout_rates=dropout_rates) - def forward(self, x): + def forward(self, feature_idx, feature_value): """ - :param x: torch.LongTensor (batch_size, num_fields) + :param feature_idx: torch.LongTensor (batch_size, num_fields) + :param feature_value: torch.LongTensor (batch_size, num_fields) """ - emb = self.field_embedding(x) # (batch_size, num_fields, embedding_dim) - linear_logit = self.field_linear(x) + emb = self.field_embedding(feature_idx, feature_value) # (batch_size, num_fields, embedding_dim) + linear_logit = self.field_linear(feature_idx, feature_value) fm_logit = self.pairwise_interaction(emb) deep_logit = self.dnn(torch.flatten(emb, start_dim=1))