-
Notifications
You must be signed in to change notification settings - Fork 24
/
att_rec.py
32 lines (27 loc) · 1.1 KB
/
att_rec.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# -*- coding:utf-8 -*-
import os
os.environ["CHAINER_TYPE_CHECK"] = "0"
import numpy as np
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import cuda, Variable
from base.encoder import *
from base.decoder import *
from base.utils import *
class AttRec(chainer.Chain):
    """Recurrent model: an ``NStepGRUEncoder`` followed by an ``AttDecoder``.

    Args:
        item_size: Passed to both the encoder and the decoder; presumably the
            size of the item vocabulary (TODO confirm in base.encoder /
            base.decoder).
        embed_size: Passed to the encoder as its embedding width.
        hidden_size: Passed to both the encoder and the decoder as the hidden
            width.
    """

    def __init__(self, item_size, embed_size, hidden_size):
        super(AttRec, self).__init__()
        # Register child links through init_scope(); handing them to
        # Chain.__init__ as keyword arguments is deprecated since Chainer v2.
        # The link names ("enc", "dec") are unchanged, so saved snapshots
        # keyed by these names still load.
        with self.init_scope():
            self.enc = NStepGRUEncoder(item_size, embed_size, hidden_size)
            self.dec = AttDecoder(item_size, hidden_size)

    def predict(self, input_list):
        """Encode ``input_list`` and decode it with attention.

        Args:
            input_list: Batch of input sequences, passed unchanged to
                ``mask`` and to the encoder.

        Returns:
            A 1-tuple ``(decoder_output,)``. The tuple wrapping is deliberate
            and kept for backward compatibility: ``train`` (and possibly
            external callers) index the result with ``[0]``.
        """
        # NOTE(review): mask() comes from base.utils; presumably it flags the
        # valid (non-padded) positions of each sequence -- confirm there.
        x_enable = chainer.Variable(self.xp.array(mask(input_list)))
        # The encoder yields the final hidden state and the per-step hidden
        # states; the decoder attends over the latter using the same mask.
        batch_last_h, batch_seq_h = self.enc(input_list, x_enable)
        return (self.dec(batch_last_h, batch_seq_h, x_enable),)

    def train(self, input_list, output_list):
        """Compute the training loss for one batch.

        Args:
            input_list: Batch of input sequences (see ``predict``).
            output_list: Target labels. They are converted to int32 on the
                model's array module and flattened to 1-D, so their total
                count must match the first axis of the decoder output
                (assumed to be per-target scores -- TODO confirm).

        Returns:
            Scalar ``chainer.Variable``: softmax cross-entropy averaged over
            the targets (``normalize=True``, ``reduce='mean'``).
        """
        predicts = self.predict(input_list)[0]
        groundtruths = chainer.Variable(
            self.xp.array(output_list, dtype=self.xp.int32).reshape(-1))
        loss = F.softmax_cross_entropy(
            predicts, groundtruths, normalize=True, reduce='mean')
        return loss