Skip to content

Commit

Permalink
minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
airaria committed Sep 18, 2017
1 parent fc45a90 commit 926dd3e
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions main.py
Expand Up @@ -15,9 +15,9 @@
NUM_EPOCHS = 50
PRINT_EVERY_N_ITER = 100
ATTN_TYPE='dot'
ATTN_CLASS='simple' #complex|simple
ATTN_CLASS='type1' #type1 (Luong) | type2
ENC_TYPE='CNNRNN' #CNN|CNNRNN
SAVE_DIR ='CNNRNNdot128_lr0.0003cp10simple'
SAVE_DIR ='CNNRNNdot128_lr0.0003cp10type1'
if not os.path.exists(SAVE_DIR):
os.mkdir(SAVE_DIR)

Expand All @@ -31,11 +31,11 @@
rnn_hidden_size=HIDDEN_SIZE,
dropout=DROPOUT)

if ATTN_CLASS=='simple':
if ATTN_CLASS=='type1':
decoder = RNNAttnDecoder(ATTN_TYPE,input_vocab_size=VOCAB_SIZE,hidden_size=HIDDEN_SIZE,
output_size=VOCAB_SIZE,num_rnn_layers=NUM_RNN_LAYERS,
dropout=DROPOUT)
elif ATTN_CLASS=='complex':
elif ATTN_CLASS=='type2':
decoder = RNNAttnDecoder2(ATTN_TYPE,input_vocab_size=VOCAB_SIZE,hidden_size=HIDDEN_SIZE,
output_size=VOCAB_SIZE,num_rnn_layers=NUM_RNN_LAYERS,
dropout=DROPOUT)
Expand Down
4 changes: 2 additions & 2 deletions model.py
Expand Up @@ -67,7 +67,7 @@ def initHidden(self,batch_size,use_cuda=False):
return (h0.cuda())
else:
return h0

'''
class RNNDecoder(nn.Module):
def __init__(self, input_vocab_size, hidden_size, output_size,
num_rnn_layers=1, dropout=0.):
Expand All @@ -93,7 +93,7 @@ def forward(self, input, hidden):
output,hidden = self.gru(embed_input,hidden)
output = self.out(output.squeeze())
return output, hidden

'''
class Attn(nn.Module):
def __init__(self,method,hidden_size):
super(Attn,self).__init__()
Expand Down

0 comments on commit 926dd3e

Please sign in to comment.