make sure y_pred and y have the same size #5

Closed · wants to merge 5 commits
6 changes: 3 additions & 3 deletions docs/recommendation.md
@@ -19,7 +19,7 @@ The methods are similar for other algorithms.

### Example of DeepFM

-1. ** Generate pytorch script model**
+1. **Generate pytorch script model**
First, go to directory of python/recommendation and execute the following command:
```$xslt
python deepfm.py --input_dim 148 --n_fields 13 --embedding_dim 10 --fc_dims 10 5 1
@@ -33,14 +33,14 @@ The methods are similar for other algorithms.

This python script will generate a TorchScript model with the structure of dataflow graph for deepfm. This file is named ``deepfm.pt``.

-2. ** Preparing the input data**
+2. **Preparing the input data**
The input data of DeepFM should be libffm format. Each line of the input data represents one data sample.
```
label field1:feature1:value1 field2:feature2:value2
```
In Pytorch on angel, multi-hot field is allowed, which means some field can be appeared multi-times in one data example.

-3. ** Training model**
+3. **Training model**
After obtaining the model file (deepfm.pt) and the input data, we can submit a task through Spark on Angel to train the model. The command is:
```$xslt
source ./spark-on-angel-env.sh
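A note on step 2 of the section above: multi-hot means one sample may repeat a field. A hypothetical sample in the libffm format just described, with field 3 appearing twice (all field:feature:value numbers are illustrative, not taken from the repository):
```
1 1:5:0.5 2:17:1.0 3:42:1.0 3:57:1.0
```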
6 changes: 3 additions & 3 deletions python/recommendation/attention_fm.py
@@ -90,8 +90,8 @@ def second_order(self, batch_size, index, values, embeddings, n_fields, embeddin
        return attention_out

    @torch.jit.script_method
-    def forward_(self, batch_size, index, feats, values, bias, weights, embeddings, mats):
-        # type: (int, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, List[Tensor]) -> Tensor
+    def forward_(self, batch_size, index, values, bias, weights, embeddings, mats):
+        # type: (int, Tensor, Tensor, Tensor, Tensor, Tensor, List[Tensor]) -> Tensor

        n_fields = (int)(embeddings.size(0) / batch_size)
        embedding_dim = embeddings.size(1)
@@ -105,7 +105,7 @@ def forward(self, batch_size, index, feats, values):
        # type: (int, Tensor, Tensor, Tensor) -> Tensor
        batch_first = F.embedding(feats, self.weights)
        batch_second = F.embedding(feats, self.embedding)
-        return self.forward_(batch_size, index, feats, values,
+        return self.forward_(batch_size, index, values,
                             self.bias, batch_first, batch_second, self.mats)

    @torch.jit.script_method
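This `forward_`/`forward` signature change repeats in every model file below: `forward` already performs the `F.embedding(feats, ...)` lookups, so `forward_` only consumes the looked-up dense tensors and the raw `feats` indices can be dropped from its parameter list. A minimal self-contained sketch of that calling pattern — a toy stand-in, not the repository's classes:
```
import torch
import torch.nn.functional as F

# Sketch of the pattern after this PR: `forward` does the F.embedding
# lookups on `feats`, so `forward_` receives only the looked-up tensors
# and no longer needs `feats` in its signature.
class Sketch(torch.nn.Module):
    def __init__(self, input_dim=16, embedding_dim=4):
        super().__init__()
        self.bias = torch.nn.Parameter(torch.zeros(1))
        self.weights = torch.nn.Parameter(torch.randn(input_dim, 1))
        self.embedding = torch.nn.Parameter(torch.randn(input_dim, embedding_dim))

    def forward_(self, batch_size, index, values, bias, weights, embeddings):
        # Stand-in first-order term; the real models also add second/higher
        # order terms computed from `embeddings`.
        first = bias + (weights.view(-1) * values.view(-1)).sum()
        return torch.sigmoid(first)

    def forward(self, batch_size, index, feats, values):
        batch_first = F.embedding(feats, self.weights)
        batch_second = F.embedding(feats, self.embedding)
        return self.forward_(batch_size, index, values,
                             self.bias, batch_first, batch_second)

m = Sketch()
feats = torch.tensor([0, 3, 7])
print(m(3, feats, feats, torch.ones(3)))  # a value in (0, 1)
```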
6 changes: 3 additions & 3 deletions python/recommendation/attention_net.py
@@ -107,8 +107,8 @@ def higher_order(self, batch_size, embedding, mats):
        return e.view(-1)

    @torch.jit.script_method
-    def forward_(self, batch_size, index, feats, values, bias, weights, embedding, mats):
-        # type: (int, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, List[Tensor]) -> Tensor
+    def forward_(self, batch_size, index, values, bias, weights, embedding, mats):
+        # type: (int, Tensor, Tensor, Tensor, Tensor, Tensor, List[Tensor]) -> Tensor
        first = self.first_order(batch_size, index, values, bias, weights)
        higher = self.higher_order(batch_size, embedding, mats)
        return torch.sigmoid(first + higher)
@@ -118,7 +118,7 @@ def forward(self, batch_size, index, feats, values):
        # type: (int, Tensor, Tensor, Tensor) -> Tensor
        batch_first = F.embedding(feats, self.weights)
        batch_embedding = F.embedding(feats, self.embedding)
-        return self.forward_(batch_size, index, feats, values,
+        return self.forward_(batch_size, index, values,
                             self.bias, batch_first, batch_embedding, self.mats)

    @torch.jit.script_method
72 changes: 67 additions & 5 deletions python/recommendation/attention_net_multi_head.py
@@ -16,6 +16,8 @@

from __future__ import print_function

+import argparse
+
import torch
import torch.nn.functional as F

@@ -127,8 +129,8 @@ def higher_order(self, batch_size, embedding, num_multi_head, top_k, n_fields, n
        return e.view(-1)

    @torch.jit.script_method
-    def forward_(self, batch_size, index, feats, values, bias, weights, embedding, mats):
-        # type: (int, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, List[Tensor]) -> Tensor
+    def forward_(self, batch_size, index, values, bias, weights, embedding, mats):
+        # type: (int, Tensor, Tensor, Tensor, Tensor, Tensor, List[Tensor]) -> Tensor
        first = self.first_order(batch_size, index, values, bias, weights)
        higher = self.higher_order(batch_size, embedding, self.num_multi_head, self.top_k,
                                   self.n_fields, self.num_attention_layers, mats)
@@ -138,7 +140,7 @@ def forward(self, batch_size, index, feats, values):
        # type: (int, Tensor, Tensor, Tensor) -> Tensor
        batch_first = F.embedding(feats, self.weights)
        batch_embedding = F.embedding(feats, self.embedding)
-        return self.forward_(batch_size, index, feats, values, self.bias, batch_first,
+        return self.forward_(batch_size, index, values, self.bias, batch_first,
                             batch_embedding, self.mats)

    @torch.jit.script_method
@@ -154,6 +156,66 @@ def get_name(self):
        return "AttentionNetMultiHead"


-if __name__ == '__main__':
-    attention = AttentionNetMultiHead()
+FLAGS = None
+
+
+def main():
+    attention = AttentionNetMultiHead(
+        FLAGS.input_dim,
+        FLAGS.n_fields,
+        FLAGS.embedding_dim,
+        FLAGS.num_multi_head,
+        FLAGS.top_k,
+        FLAGS.num_attention_layers,
+        FLAGS.fc_dims)
+    attention.save('attention_net_multi_head.pt')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.register("type", "bool", lambda v: v.lower() == "true")
+    parser.add_argument(
+        "--input_dim",
+        type=int,
+        default=-1,
+        help="data input dim."
+    )
+    parser.add_argument(
+        "--n_fields",
+        type=int,
+        default=-1,
+        help="data num fields."
+    )
+    parser.add_argument(
+        "--embedding_dim",
+        type=int,
+        default=-1,
+        help="embedding dim."
+    )
+    parser.add_argument(
+        "--num_multi_head",
+        type=int,
+        default=-1,
+        help="num multi head."
+    )
+    parser.add_argument(
+        "--top_k",
+        type=int,
+        default=-1,
+        help="top k."
+    )
+    parser.add_argument(
+        "--num_attention_layers",
+        type=int,
+        default=-1,
+        help="num attention layers."
+    )
+    parser.add_argument(
+        "--fc_dims",
+        nargs="+",
+        type=int,
+        default=-1,
+        help="fc layers dim list."
+    )
+    FLAGS, unparsed = parser.parse_known_args()
+    main()
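With this entry point, generating the TorchScript model presumably mirrors the DeepFM example in docs/recommendation.md. A plausible invocation — the flag values are illustrative, chosen to match the commented-out `AttentionNetMultiHead(dim, 13, 10, 8, 5, 2, [10, 5, 1])` line in train.py — which writes `attention_net_multi_head.pt` via the `attention.save(...)` call above:
```
python attention_net_multi_head.py --input_dim 148 --n_fields 13 --embedding_dim 10 --num_multi_head 8 --top_k 5 --num_attention_layers 2 --fc_dims 10 5 1
```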
6 changes: 3 additions & 3 deletions python/recommendation/dcn.py
@@ -103,8 +103,8 @@ def higher_order(self, batch_size, embeddings, mats):


    @torch.jit.script_method
-    def forward_(self, batch_size, index, feats, values, bias, weights, embeddings, mats):
-        # type: (int, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, List[Tensor]) -> Tensor
+    def forward_(self, batch_size, index, values, bias, weights, embeddings, mats):
+        # type: (int, Tensor, Tensor, Tensor, Tensor, Tensor, List[Tensor]) -> Tensor
        first = self.first_order(batch_size, index, values, bias, weights)
        cross_mats_index = self.cross_depth * 2
        cross = self.cross(batch_size, embeddings, mats[0:cross_mats_index])
@@ -119,7 +119,7 @@ def forward(self, batch_size, index, feats, values):
        # type: (int, Tensor, Tensor, Tensor) -> Tensor
        emb = F.embedding(feats, self.embedding)
        first = F.embedding(feats, self.weights)
-        return self.forward_(batch_size, index, feats, values,
+        return self.forward_(batch_size, index, values,
                             self.bias, first, emb, self.mats)

    @torch.jit.script_method
6 changes: 3 additions & 3 deletions python/recommendation/deepandwide.py
@@ -75,8 +75,8 @@ def higher_order(self, batch_size, embeddings, mats):
        return output.view(-1)

    @torch.jit.script_method
-    def forward_(self, batch_size, index, feats, values, bias, weights, embeddings, mats):
-        # type: (int, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, List[Tensor]) -> Tensor
+    def forward_(self, batch_size, index, values, bias, weights, embeddings, mats):
+        # type: (int, Tensor, Tensor, Tensor, Tensor, Tensor, List[Tensor]) -> Tensor

        first = self.first_order(batch_size, index, values, bias, weights)
        higher = self.higher_order(batch_size, embeddings, mats)
@@ -88,7 +88,7 @@ def forward(self, batch_size, index, feats, values):
        # type: (int, Tensor, Tensor, Tensor) -> Tensor
        batch_first = F.embedding(feats, self.weights)
        batch_second = F.embedding(feats, self.embedding)
-        return self.forward_(batch_size, index, feats, values,
+        return self.forward_(batch_size, index, values,
                             self.bias, batch_first, batch_second, self.mats)

    @torch.jit.script_method
6 changes: 3 additions & 3 deletions python/recommendation/deepfm.py
@@ -104,8 +104,8 @@ def higher_order(self, batch_size, embeddings, mats):
        return output.view(-1)

    @torch.jit.script_method
-    def forward_(self, batch_size, index, feats, values, bias, weights, embeddings, mats):
-        # type: (int, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, List[Tensor]) -> Tensor
+    def forward_(self, batch_size, index, values, bias, weights, embeddings, mats):
+        # type: (int, Tensor, Tensor, Tensor, Tensor, Tensor, List[Tensor]) -> Tensor

        first = self.first_order(batch_size, index, values, bias, weights)
        second = self.second_order(batch_size, index, values, embeddings)
@@ -118,7 +118,7 @@ def forward(self, batch_size, index, feats, values):
        # type: (int, Tensor, Tensor, Tensor) -> Tensor
        batch_first = F.embedding(feats, self.weights)
        batch_second = F.embedding(feats, self.embedding)
-        return self.forward_(batch_size, index, feats, values,
+        return self.forward_(batch_size, index, values,
                             self.bias, batch_first, batch_second, self.mats)

    @torch.jit.script_method
6 changes: 3 additions & 3 deletions python/recommendation/fm.py
@@ -85,8 +85,8 @@ def second_order(self, batch_size, index, values, embeddings):
        return second

    @torch.jit.script_method
-    def forward_(self, batch_size, index, feats, values, bias, weights, embeddings):
-        # type: (int, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tensor
+    def forward_(self, batch_size, index, values, bias, weights, embeddings):
+        # type: (int, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tensor
        first = self.first_order(batch_size, index, values, bias, weights)
        second = self.second_order(batch_size, index, values, embeddings)
        return torch.sigmoid(first + second)
@@ -96,7 +96,7 @@ def forward(self, batch_size, index, feats, values):
        # type: (int, Tensor, Tensor, Tensor) -> Tensor
        batch_first = F.embedding(feats, self.weights)
        batch_second = F.embedding(feats, self.embedding)
-        return self.forward_(batch_size, index, feats, values, self.bias, batch_first, batch_second)
+        return self.forward_(batch_size, index, values, self.bias, batch_first, batch_second)

    @torch.jit.script_method
    def loss(self, output, targets):
6 changes: 3 additions & 3 deletions python/recommendation/lr.py
@@ -41,8 +41,8 @@ def __init__(self, input_dim=-1):


    @torch.jit.script_method
-    def forward_(self, batch_size, index, feats, values, bias, weight):
-        # type: (int, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tensor
+    def forward_(self, batch_size, index, values, bias, weight):
+        # type: (int, Tensor, Tensor, Tensor, Tensor) -> Tensor
        size = batch_size
        index = index.view(-1)
        values = values.view(1, -1)
@@ -57,7 +57,7 @@ def forward(self, batch_size, index, feats, values):
        # type: (int, Tensor, Tensor, Tensor) -> Tensor
        weight = F.embedding(feats, self.weights)
        bias = self.bias
-        return self.forward_(batch_size, index, feats, values, bias, weight)
+        return self.forward_(batch_size, index, values, bias, weight)

    @torch.jit.script_method
    def loss(self, output, targets):
6 changes: 3 additions & 3 deletions python/recommendation/pnn.py
@@ -99,8 +99,8 @@ def deep(self, batch_size, product_input, mats):
        return output.view(-1)  # [b * 1]

    @torch.jit.script_method
-    def forward_(self, batch_size, index, feats, values, bias, weights, embeddings, mats):
-        # type: (int, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, List[Tensor]) -> Tensor
+    def forward_(self, batch_size, index, values, bias, weights, embeddings, mats):
+        # type: (int, Tensor, Tensor, Tensor, Tensor, Tensor, List[Tensor]) -> Tensor
        first = self.first_order(batch_size, index, values, bias, weights)
        product = self.product(batch_size, embeddings, mats[0:3])
        output = self.deep(batch_size, product, mats[3:])
@@ -111,7 +111,7 @@ def forward(self, batch_size, index, feats, values):
        # type: (int, Tensor, Tensor, Tensor) -> Tensor
        batch_first = F.embedding(feats, self.weights)
        emb = F.embedding(feats, self.embedding)
-        return self.forward_(batch_size, index, feats, values,
+        return self.forward_(batch_size, index, values,
                             self.bias, batch_first, emb, self.mats)

    @torch.jit.script_method
45 changes: 23 additions & 22 deletions python/recommendation/train.py
@@ -26,7 +26,7 @@


import lr, fm, deepfm, deepandwide
-import dcn, attention_net
+import dcn, attention_net, attention_net_multi_head
import pnn, attention_fm, xdeepfm


@@ -44,12 +44,14 @@
# model = lr.LogisticRegression(dim)
# model = fm.FactorizationMachine(dim, 10)
# model = deepfm.DeepFM(dim, 13, 10, [10, 5, 1])
-# model = deepandwide.DeepAndWide(dim, 13, 10, [10, 5, 1])
+model = deepandwide.DeepAndWide(dim, 13, 10, [10, 5, 1])
# model = attention_net.AttentionNet(dim, 13, 10, [10, 5, 1])
-# model = dcn.DCNet(dim, 13, 10, cross_depth=4, deep_layers=[10, 10, 10])
-#model = pnn.PNN(dim, 13, 10, [10, 5, 1])
-#model = attention_fm.AttentionFM(dim, 13, 10, 10)
-#model = xdeepfm.xDeepFM(dim, 13, 10, [10, 5, 5], [128, 128])
+# model = dcn.DCNet(dim, 13, 10, cross_depth=4, fc_dims=[10, 10, 10])
+# model = pnn.PNN(dim, 13, 10, [10, 5, 1])
+# model = attention_fm.AttentionFM(dim, 13, 10, 10)
+# model = xdeepfm.xDeepFM(dim, 13, 10, [10, 5, 5], [128, 128])
+# model = attention_net_multi_head.AttentionNetMultiHead(dim, 13, 10, 8, 5, 2, [10, 5, 1])
+model.save("savedmodules/DeepAndWide-model.pt")
+model = torch.jit.load("savedmodules/DeepAndWide-model.pt")


@@ -59,32 +61,31 @@
batch_size = 30

for epoch in range(10):
    start = 0
    sum_loss = 0.0
    time_start = time.time()
    while start < size:
        optim.zero_grad()
        end = min(start+batch_size, size)
        batch = X[start:end].tocoo()
        y = torch.from_numpy(Y[start:end]).to(torch.float32)

        batch_size, _ = batch.shape
        # batch_size = torch.tensor([batch_size]).to(torch.int32)
        row = torch.from_numpy(batch.row).to(torch.long)
        col = torch.from_numpy(batch.col).to(torch.long)
        data = torch.from_numpy(batch.data)

-        y_pred = model(batch_size, row, col, data)
+        y_pred = model(batch_size, row, col, data).view_as(y)
        loss = model.loss(y_pred, y)

        #loss.backward()
        optim.step()

        start += batch_size
        sum_loss += loss.item()* batch_size

    print(sum_loss / size, '%fs' % (time.time() - time_start))


# model.save("model.pt")
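The `.view_as(y)` added to the training loop is the fix named in the PR title. A standalone sketch of the failure mode it guards against, assuming a binary cross-entropy loss (which the models' sigmoid outputs suggest); the tensors here are hypothetical:
```
import torch

# A [batch, 1] prediction against a [batch] target is not elementwise-aligned:
# recent PyTorch raises a size-mismatch error for BCELoss, and older versions
# broadcast with a deprecation warning, silently computing the wrong loss.
y = torch.tensor([1.0, 0.0, 1.0])             # targets, shape [3]
y_pred = torch.tensor([[0.9], [0.2], [0.8]])  # predictions, shape [3, 1]

loss_fn = torch.nn.BCELoss()
# loss_fn(y_pred, y)                  -> shape mismatch
loss = loss_fn(y_pred.view_as(y), y)  # aligned to shape [3]: well-defined
print(loss)
```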
6 changes: 3 additions & 3 deletions python/recommendation/xdeepfm.py
@@ -123,8 +123,8 @@ def deep(self, batch_size, embeddings, mats):


    @torch.jit.script_method
-    def forward_(self, batch_size, index, feats, values, bias, weights, embeddings, mats):
-        # type: (int, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, List[Tensor]) -> Tensor
+    def forward_(self, batch_size, index, values, bias, weights, embeddings, mats):
+        # type: (int, Tensor, Tensor, Tensor, Tensor, Tensor, List[Tensor]) -> Tensor
        first = self.first_order(batch_size, index, values, bias, weights)
        cin_index = len(self.cin_dims) * 2
        cin = self.cin(batch_size, embeddings, mats[0:cin_index])
@@ -140,7 +140,7 @@ def forward(self, batch_size, index, feats, values):
        # type: (int, Tensor, Tensor, Tensor) -> Tensor
        batch_first = F.embedding(feats, self.weights)
        emb = F.embedding(feats, self.embedding)
-        return self.forward_(batch_size, index, feats, values,
+        return self.forward_(batch_size, index, values,
                             self.bias, batch_first, emb, self.mats)

    @torch.jit.script_method