train_seq2seq.py
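"""Training entry point for the Seq2Seq-With-Attention experiments.

Hyperparameters are configured below; metrics are optionally logged to
Weights & Biases (wandb) when USE_WANDB is enabled.
"""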
import wandb

from seq2seq_attention.train import train_seq2seq_with_attention


if __name__ == "__main__":
EXP_NAME = "Uniform-Attention"
LR = 1e-4
BATCH_SIZE = 128
EPOCHS = 25
MAX_VOCAB_SIZE = 8000
MIN_FREQ = 2
ENC_EMB_DIM = 256
HIDDEN_DIM_ENC = 512
HIDDEN_DIM_DEC = 512
NUM_LAYERS_ENC = 1
NUM_LAYERS_DEC = 1
EMB_DIM_TRG = 256
DEVICE = "cuda"
TEACHER_FORCING = 0.5
TRAIN_DIR = "./data/processed/train.csv"
VAL_DIR = "./data/processed/val.csv"
TEST_DIR = "./data/processed/val.csv"
PROGRESS_BAR = False
USE_WANDB = True
DROPOUT = 0
TRAIN_ATTENTION = False
    # Set up hyperparameters for wandb logging
    hyper_params = {
        "lr": LR,
        "batch_size": BATCH_SIZE,
        "epochs": EPOCHS,
        "max_vocab_size": MAX_VOCAB_SIZE,
        "min_freq": MIN_FREQ,
        "enc_hidden": HIDDEN_DIM_ENC,
        "dec_hidden": HIDDEN_DIM_DEC,
        "embedding_enc": ENC_EMB_DIM,
        "embedding_dec": EMB_DIM_TRG,
        "dropout": DROPOUT,
        "teacher_forcing": TEACHER_FORCING,
    }
    # Init wandb
    if USE_WANDB:
        wandb.init(
            project="Seq2Seq-With-Attention",
            name=EXP_NAME,
            # track hyperparameters and run metadata
            config=hyper_params,
        )
    # Run training with the configured hyperparameters
    train_seq2seq_with_attention(
        lr=LR,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        enc_emb_dim=ENC_EMB_DIM,
        hidden_dim_enc=HIDDEN_DIM_ENC,
        hidden_dim_dec=HIDDEN_DIM_DEC,
        num_layers_enc=NUM_LAYERS_ENC,
        num_layers_dec=NUM_LAYERS_DEC,
        emb_dim_trg=EMB_DIM_TRG,
        max_vocab_size=MAX_VOCAB_SIZE,
        min_freq=MIN_FREQ,
        device=DEVICE,
        teacher_forcing=TEACHER_FORCING,
        train_dir=TRAIN_DIR,
        val_dir=VAL_DIR,
        test_dir=TEST_DIR,
        progress_bar=PROGRESS_BAR,
        use_wandb=USE_WANDB,
        exp_name=EXP_NAME,
        dropout=DROPOUT,
        train_attention=TRAIN_ATTENTION,
    )