base: &base
  # Model config
  embed_dim: 384
  depth: 12
  dropout: 0.0
  patch_size: 8
  num_heads: 8
  # Training config
  img_size: [360, 720]
  dt: 1
  global_batch_size: 16 # number of samples per training batch
  num_iters: 30000
  amp_mode: none
  enable_fused: false
  enable_jit: false
  expdir: '/logs'
  lr_schedule: 'cosine'
  lr: 5e-4
  warmup: 0
  optimizer: 'Adam'
  # Data
  data_loader_config: 'pytorch'
  num_data_workers: 0 # number of dataloader worker threads per proc
  n_in_channels: 20
  n_out_channels: 20
  train_data_path: '/data/train'
  valid_data_path: '/data/valid'
  inf_data_path: '/data/test'
  time_means_path: '/data/stats/time_means.npy'
  global_means_path: '/data/stats/global_means.npy'
  global_stds_path: '/data/stats/global_stds.npy'
  limit_nsamples: None
  limit_nsamples_val: None
  # Comms
  wireup_info: env
  wireup_store: tcp
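
# Note (added for clarity): every config below reuses `base` through YAML
# anchors and merge keys (`<<: *base`), overriding only the fields it changes.
# Merge keys are a YAML 1.1 feature, so this file presumably requires a loader
# that expands them (e.g. PyYAML or ruamel.yaml); the exact loader used by the
# training code is an assumption here.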
# limit the number of samples
short: &short_ls
  <<: *base
  limit_nsamples: 512
  limit_nsamples_val: 128
  num_iters: 128
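
# For illustration (a sketch, assuming the loader expands YAML 1.1 merge keys):
# after merging, `short` is `base` with only the three keys above overridden,
# roughly equivalent to writing out:
#   short:
#     embed_dim: 384            # inherited from base
#     global_batch_size: 16     # inherited from base
#     limit_nsamples: 512       # overridden here
#     limit_nsamples_val: 128   # overridden here
#     num_iters: 128            # overridden here (shortened run)
#     ...                       # all remaining base keys unchanged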
# add optimization flags
short_opt:
  <<: *short_ls
  data_loader_config: 'dali'
  num_data_workers: 8
  amp_mode: fp16
  enable_jit: true
  enable_fused: true
# no samples limits
opt: &opt
  <<: *base
  data_loader_config: 'dali'
  num_data_workers: 8
  amp_mode: fp16
  num_iters: 30000
  enable_fused: true
  enable_apex: true
# ----- Data parallel scaling configs
bs16_opt:
  <<: *opt
  global_batch_size: 16
  lr: 5e-4
bs32_opt:
  <<: *opt
  global_batch_size: 32
  lr: 7.07e-4
bs64_opt:
  <<: *opt
  global_batch_size: 64
  lr: 1e-3
bs128_opt:
  <<: *opt
  global_batch_size: 128
  lr: 1.41e-3
bs256_opt:
  <<: *opt
  global_batch_size: 256
  lr: 2e-3
bs512_opt:
  <<: *opt
  global_batch_size: 512
  lr: 2.83e-3
bs1024_opt:
  <<: *opt
  global_batch_size: 1024
  lr: 4e-3
bs2048_opt:
  <<: *opt
  global_batch_size: 2048
  lr: 5.66e-3
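
# Observation (not stated in the file): the learning rates above follow
# square-root scaling with the global batch size,
# lr = 5e-4 * sqrt(global_batch_size / 16).
# For example: bs=32 -> 5e-4 * sqrt(2) = 7.07e-4, bs=64 -> 5e-4 * 2 = 1e-3,
# bs=2048 -> 5e-4 * sqrt(128) = 5.66e-3 (rounded).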
# Model parallel configs
mp: &mp
  <<: *base
  num_iters: 30000
  global_batch_size: 64
  lr: 1e-3
  num_data_workers: 8
  embed_dim: 1024 # change to bigger model
  data_loader_config: 'dali'
  amp_mode: fp16
  enable_jit: true
  enable_fused: true
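
# Rough size estimate (an approximation, not part of the config): using the
# common ~12 * depth * embed_dim^2 transformer parameter count, embed_dim 1024
# at depth 12 gives roughly 12 * 12 * 1024^2 = ~151M parameters, versus ~21M
# for the embed_dim 384 base model, hence "bigger model".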
mp_bs16:
  <<: *mp
  global_batch_size: 16
  lr: 5e-4
mp_bs32:
  <<: *mp
  global_batch_size: 32
  lr: 7.07e-4
# larger seq length (use local bs = 1 here)
mp_patch4:
  <<: *mp
  patch_size: 4
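
# Rough sequence-length check (an illustration, assuming non-overlapping square
# patches): with img_size [360, 720], patch_size 8 gives (360/8) * (720/8) =
# 45 * 90 = 4050 tokens per sample, while patch_size 4 gives 90 * 180 = 16200,
# i.e. 4x the sequence length, which is why a local batch size of 1 is
# suggested above.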