# train_iql.yaml
# Defaults list: config-group selections composed together with this file.
# `group@package: option` entries select the `dataset` group twice, once for
# `train_dataset` and once for `eval_dataset`.
# NOTE(review): this looks like Hydra syntax (`group@package`, `_self_`);
# `_self_` being last presumably means the keys in this file override the
# selected group defaults — confirm against the loader in use.
defaults:
  - model: per_token_iql
  - dataset@train_dataset: list_train
  - dataset@eval_dataset: list_val
  - evaluator: iql_evaluator
  - _self_
# Overrides for the training dataset (the `list_train` option selected in the
# defaults list).
# NOTE(review): nesting was reconstructed from a copy that had lost its
# indentation; verify the placement of `mode`, `cutoff_rule` and `yn_reward*`
# under `data` against the dataset loader's schema.
train_dataset:
  # The same id (`d_train`) is referenced by `model.dataset.cache_id`;
  # presumably both must match so the model reuses this dataset — confirm.
  cache_id: d_train
  data:
    reward_cache: data/vis_dialogue/processed/visdial_0.5/train_rank_reward_cache1.json
    mode: env_stops
    cutoff_rule:
      name: percentile_cutoff_rule
      goal_value: 1.0
      percentile: 0.5
    # Same values as in `eval_dataset.data` and `evaluator.env`.
    yn_reward: -2.0
    yn_reward_kind: hard
# Overrides for the evaluation dataset (the `list_val` option selected in the
# defaults list). Mirrors `train_dataset`, but reads the validation reward
# cache.
# NOTE(review): nesting was reconstructed from a copy that had lost its
# indentation; verify the placement of `mode`, `cutoff_rule` and `yn_reward*`
# under `data` against the dataset loader's schema.
eval_dataset:
  # The same id (`d_eval`) is referenced by `evaluator.env.dataset.cache_id`;
  # presumably both must match — confirm.
  cache_id: d_eval
  data:
    reward_cache: data/vis_dialogue/processed/visdial_0.5/val_rank_reward_cache1.json
    mode: env_stops
    cutoff_rule:
      name: percentile_cutoff_rule
      goal_value: 1.0
      percentile: 0.5
    # Same values as in `train_dataset.data` and `evaluator.env`.
    yn_reward: -2.0
    yn_reward_kind: hard
# Overrides for the model (the `per_token_iql` option selected in the defaults
# list).
# NOTE(review): nesting was reconstructed from a copy that had lost its
# indentation; the scalar hyperparameters are assumed to sit directly under
# `model`, with `gpt2`, `dataset` and `load` as sub-mappings — verify against
# the model loader's schema.
model:
  alpha: 0.005
  gamma: 0.99
  beta: 0.0
  transition_weight: 0.0
  clip_weight: null
  value_max: null
  value_min: null
  detach_v: false
  detach_q: false
  detach_pi: false
  double_q: true
  # NOTE(review): `seperate_*` is spelled this way in the original (sic).
  # The key presumably has to match the consumer's parameter name, so it is
  # left untouched — do not "fix" the spelling here without checking.
  seperate_policy: true
  seperate_target: true
  tau: 0.8
  exp_weights: true
  dm_margin: 0.0
  advanced_mlp: false
  cql_temp: 1.0
  gpt2:
    lm_head: true
    from_pretrained: true
  # Dataset reference; `cache_id` has the same value as
  # `train_dataset.cache_id`.
  dataset:
    name: vis_dial_list_dataset
    cache_id: d_train
  load:
    # null: no checkpoint path is configured.
    checkpoint_path: null
    strict_load: true
# Overrides for the evaluator (the `iql_evaluator` option selected in the
# defaults list).
# NOTE(review): nesting was reconstructed from a copy that had lost its
# indentation. `dataset` and `yn_reward*` are assumed to belong to `env`, and
# `verbose` / `kind` / `generation_kwargs` to `evaluator` itself — this is the
# least certain part of the reconstruction; verify against the evaluator and
# environment loaders.
evaluator:
  env:
    # Points at a local HTTP endpoint; presumably the environment server has
    # to be started separately before evaluation runs — confirm.
    url: http://localhost:5000/step_rank
    actor_stop: false
    # Dataset reference; `cache_id` has the same value as
    # `eval_dataset.cache_id`.
    dataset:
      name: vis_dial_list_dataset
      cache_id: d_eval
    # Same values as in `train_dataset.data` and `eval_dataset.data`.
    yn_reward: -2.0
    yn_reward_kind: hard
  verbose: true
  kind: sample
  generation_kwargs:
    max_generation_len: 40
    # beam_width: 1
    temp: 1.0
    top_k: null
    top_p: null
    exp_adv: true
    adv_weight: 16.0
    adv_clip: null
    include_logits: true
    include_adv: true
    num_generations: 1
    rerank_log_prob_weight: 0.0
    rerank_advantage_weight: 1.0
# Training-loop settings.
# NOTE(review): nesting was reconstructed from a copy that had lost its
# indentation; `loss` is assumed to be a sub-mapping of `train` — verify
# against the training script.
train:
  save_checkpoint_dir: outputs/visual_dialogue/visdial_hard_yn_iql_test1/
  optim_state_path: null
  epochs: 10000000
  dataloader_workers: 1
  bsize: 1
  # With bsize 1 this presumably yields 64 examples per optimizer step —
  # confirm how the training loop combines the two.
  grad_accum_steps: 64
  log_every: 256
  eval_every: 4096
  save_every: 32768
  max_checkpoints: 1
  eval_bsize: 1
  eval_batches: 32
  # Written with a decimal point (was `1e-5`): YAML 1.1 loaders such as
  # PyYAML resolve `1e-5` as a string, while `1.0e-5` is a float everywhere.
  lr: 1.0e-5
  weight_decay: 0.00
  hard_update_every: null
  max_steps: null
  loss:
    v_loss_weight: 1.0
    q_loss_weight: 1.0
    awac_weight: 1.0
    cql_loss_weight: 1.0
    dm_loss_weight: 0.0
    mc_returns: false
# Weights & Biases logging settings.
# NOTE(review): indentation restored; both keys are assumed to be children of
# `wandb` — verify against the training script.
wandb:
  use_wandb: true
  wandb_project: visdial_iql