Skip to content

Commit

Permalink
additional datasets support
Browse files Browse the repository at this point in the history
  • Loading branch information
Bjørnar Vassøy committed Jun 11, 2018
1 parent 86363c0 commit 3eed212
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 8 deletions.
11 changes: 6 additions & 5 deletions dynamic_model.py
Expand Up @@ -25,17 +25,18 @@
lastfm_simple = "lastfm_sim"
lastfm3 = "lastfm3"
instacart = "instacart"
nowplaying = "nowplaying"

#runtime settings
flags = {}
dataset = lastfm
flags["context"] = True
flags["temporal"] = True
SEED = 1
GPU = 1
SEED = 2
GPU = 0
gap_strat = ""
add = "_" if gap_strat != "" else ""
directory = "/data/stud/bjorva/logs/sqrt/" + gap_strat + add
directory = "/data/stud/bjorva/logs/dim/" + gap_strat + add
debug = False

torch.manual_seed(SEED)
Expand All @@ -62,7 +63,7 @@
params["ALPHA"] = 0.45
params["BETA"] = 0.45
params["GAMMA"] = 0.1
params["EPSILON"] = 0.5
params["EPSILON"] = 1.0
flags["use_day"] = True

#data path and log/model-name
Expand All @@ -88,7 +89,7 @@
MAX_EPOCHS -= 10
min_time = 1.0
flags["freeze"] = False
elif dataset == lastfm or dataset == lastfm_simple or dataset == lastfm_time:
elif dataset == lastfm or dataset == lastfm_simple or dataset == lastfm_time or dataset == nowplaying:
dims["EMBEDDING_DIM"] = 100
params["lr"] = 0.001
params["dropout"] = 0.2
Expand Down
4 changes: 2 additions & 2 deletions hawkes_baseline.py
Expand Up @@ -12,7 +12,7 @@

#global settings
USE_DAY = True
dataset = reddit_time
dataset = lastfm
n_decimals = 4

#parameters
Expand All @@ -23,7 +23,7 @@

#switchable
full_hist = True
gap_strat = ""
gap_strat = "hawkes"

add = "_" if gap_strat != "" else ""
pickle_path = "hawkes_full_" + dataset + add + gap_strat + "4.pickle"
Expand Down
3 changes: 2 additions & 1 deletion intra.py
Expand Up @@ -13,11 +13,12 @@
#datasets
reddit = "subreddit"
lastfm = "lastfm"
nowplaying = "nowplaying"

#set current dataset here
dataset = lastfm
dataset_path = "/data/stud/bjorva/datasets/" + dataset + "/4_train_test_split.pickle"
pickle_path = "/data/stud/bjorva/logs/intra/"
pickle_path = "/data/stud/bjorva/logs/nowplaying/"

#universal settings
BATCHSIZE = 100
Expand Down

0 comments on commit 3eed212

Please sign in to comment.