Skip to content
This repository has been archived by the owner on Jul 10, 2024. It is now read-only.

Commit

Permalink
fix layers/core.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ifndef012 committed Jul 18, 2020
1 parent ce535fc commit fdcda05
Showing 1 changed file with 17 additions and 22 deletions.
39 changes: 17 additions & 22 deletions submarine-sdk/pysubmarine/submarine/ml/pytorch/layers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,46 +19,41 @@
from torch import nn


class FieldLinear(nn.Module):
class FeatureLinear(nn.Module):

def __init__(self, field_dims, out_features):
def __init__(self, num_features, out_features):
"""
:param field_dims: List of dimensions of each field.
:param num_features: number of total features.
:param out_features: The number of output features.
"""
super().__init__()
self.weight = nn.Embedding(num_embeddings=sum(field_dims),
self.weight = nn.Embedding(num_embeddings=num_features,
embedding_dim=out_features)
self.bias = nn.Parameter(torch.zeros((out_features,)))
self.register_buffer(
'offset',
torch.as_tensor([0, *accumulate(field_dims)][:-1],
dtype=torch.long))

def forward(self, x):
def forward(self, feature_idx, feature_value):
"""
:param x: torch.LongTensor (batch_size, num_fields)
:param feature_idx: torch.LongTensor (batch_size, num_fields)
:param feature_value: torch.LongTensor (batch_size, num_fields)
"""
return torch.sum(self.weight(x + self.offset), dim=1) + self.bias
return torch.sum(
self.weight(feature_idx) * feature_value.unsqueeze(dim=-1),
dim=1) + self.bias


class FieldEmbedding(nn.Module):
class FeatureEmbedding(nn.Module):

def __init__(self, field_dims, embedding_dim):
def __init__(self, num_features, embedding_dim):
super().__init__()
self.weight = nn.Embedding(num_embeddings=sum(field_dims),
self.weight = nn.Embedding(num_embeddings=num_features,
embedding_dim=embedding_dim)
self.register_buffer(
'offset',
torch.as_tensor([0, *accumulate(field_dims)][:-1],
dtype=torch.long))

def forward(self, x):
def forward(self, feature_idx, feature_value):
"""
:param x: torch.LongTensor (batch_size, num_fields)
:param feature_idx: torch.LongTensor (batch_size, num_fields)
:param feature_value: torch.LongTensor (batch_size, num_fields)
"""
return self.weight(
x + self.offset) # (batch_size, num_fields, embedding_dim)
return self.weight(feature_idx) * feature_value.unsqueeze(dim=-1)


class PairwiseInteraction(nn.Module):
Expand Down

0 comments on commit fdcda05

Please sign in to comment.