diff --git a/submarine-sdk/pysubmarine/submarine/ml/pytorch/layers/core.py b/submarine-sdk/pysubmarine/submarine/ml/pytorch/layers/core.py index 6fff591d40..aeb24fc3f7 100644 --- a/submarine-sdk/pysubmarine/submarine/ml/pytorch/layers/core.py +++ b/submarine-sdk/pysubmarine/submarine/ml/pytorch/layers/core.py @@ -19,46 +19,41 @@ from torch import nn -class FieldLinear(nn.Module): +class FeatureLinear(nn.Module): - def __init__(self, field_dims, out_features): + def __init__(self, num_features, out_features): """ - :param field_dims: List of dimensions of each field. + :param num_features: number of total features. :param out_features: The number of output features. """ super().__init__() - self.weight = nn.Embedding(num_embeddings=sum(field_dims), + self.weight = nn.Embedding(num_embeddings=num_features, embedding_dim=out_features) self.bias = nn.Parameter(torch.zeros((out_features,))) - self.register_buffer( - 'offset', - torch.as_tensor([0, *accumulate(field_dims)][:-1], - dtype=torch.long)) - def forward(self, x): + def forward(self, feature_idx, feature_value): """ - :param x: torch.LongTensor (batch_size, num_fields) + :param feature_idx: torch.LongTensor (batch_size, num_fields) + :param feature_value: torch.LongTensor (batch_size, num_fields) """ - return torch.sum(self.weight(x + self.offset), dim=1) + self.bias + return torch.sum( + self.weight(feature_idx) * feature_value.unsqueeze(dim=-1), + dim=1) + self.bias -class FieldEmbedding(nn.Module): +class FeatureEmbedding(nn.Module): - def __init__(self, field_dims, embedding_dim): + def __init__(self, num_features, embedding_dim): super().__init__() - self.weight = nn.Embedding(num_embeddings=sum(field_dims), + self.weight = nn.Embedding(num_embeddings=num_features, embedding_dim=embedding_dim) - self.register_buffer( - 'offset', - torch.as_tensor([0, *accumulate(field_dims)][:-1], - dtype=torch.long)) - def forward(self, x): + def forward(self, feature_idx, feature_value): """ - :param x: torch.LongTensor (batch_size, num_fields) + :param feature_idx: torch.LongTensor (batch_size, num_fields) + :param feature_value: torch.LongTensor (batch_size, num_fields) """ - return self.weight( - x + self.offset) # (batch_size, num_fields, embedding_dim) + return self.weight(feature_idx) * feature_value.unsqueeze(dim=-1) class PairwiseInteraction(nn.Module):