Skip to content

Commit

Permalink
attentional rnn
Browse files Browse the repository at this point in the history
  • Loading branch information
Santolaya committed Apr 3, 2017
1 parent 483c9e0 commit 9df3c82
Show file tree
Hide file tree
Showing 6 changed files with 1,499 additions and 10 deletions.
4 changes: 2 additions & 2 deletions src/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def generate_validation_set(X_train, Y_train, X_test):
model.create_model()
model.train(ds)
#for rep4 obtain only test samples with interactions
if (representation == 4) or (representation == 5):
if (representation == 4) or (representation == 5) or (representation == 7) or (representation == 8):
indices_interactions = []
for i in range(len(X_test)):
if np.count_nonzero(X_test[i]) > 0:
Expand Down Expand Up @@ -452,7 +452,7 @@ def evaluate_sample(predictions, y_true, k):
recall_users.append(recall_user)
total_true_pos_k += true_pos_k
total_pos += num_pos
elif (representation==4) or (representation==5):
elif (representation==4) or (representation==5) or (representation ==7) or (representation==8):
print('Local test rep 4')
pred_local_test = np.zeros((len(ds._X_local_test), model_parameters['n_output']))
pred_local_test = model.predict(ds._X_local_test)
Expand Down
Loading

0 comments on commit 9df3c82

Please sign in to comment.