diff --git a/submarine-sdk/pysubmarine/submarine/ml/pytorch/layers/core.py b/submarine-sdk/pysubmarine/submarine/ml/pytorch/layers/core.py index aeb24fc3f7..4e7b1e34f8 100644 --- a/submarine-sdk/pysubmarine/submarine/ml/pytorch/layers/core.py +++ b/submarine-sdk/pysubmarine/submarine/ml/pytorch/layers/core.py @@ -34,7 +34,7 @@ def __init__(self, num_features, out_features): def forward(self, feature_idx, feature_value): """ :param feature_idx: torch.LongTensor (batch_size, num_fields) - :param feature_value: torch.LongTensor (batch_size, num_fields) + :param feature_value: torch.LongTensor (batch_size, num_fields) """ return torch.sum( self.weight(feature_idx) * feature_value.unsqueeze(dim=-1), @@ -51,7 +51,7 @@ def __init__(self, num_features, embedding_dim): def forward(self, feature_idx, feature_value): """ :param feature_idx: torch.LongTensor (batch_size, num_fields) - :param feature_value: torch.LongTensor (batch_size, num_fields) + :param feature_value: torch.LongTensor (batch_size, num_fields) """ return self.weight(feature_idx) * feature_value.unsqueeze(dim=-1) diff --git a/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr/afm.py b/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr/afm.py index 5c5ca4f671..aa1fa58410 100644 --- a/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr/afm.py +++ b/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr/afm.py @@ -70,7 +70,7 @@ def __init__(self, embedding_dim: int, attention_dim: int, def forward(self, x: torch.FloatTensor): """ - :param x: torch.FloatTensor (batch_size, num_fields, embedding_dim) + :param x: torch.FloatTensor (batch_size, num_fields, embedding_dim) """ x = self.pairwise_product(x) score = self.attention_score(x) @@ -85,7 +85,7 @@ def __init__(self): def forward(self, x: torch.FloatTensor): """ - :param x: torch.FloatTensor (batch_sie, num_fields, embedding_dim) + :param x: torch.FloatTensor (batch_sie, num_fields, embedding_dim) """ batch_size, num_fields, embedding_dim = x.size()