# Train_S-CNF.yaml
# #################################
# Basic training parameters for S-CNF
# #################################

# Random seed; also interpolated into <output_folder> below.
seed: 929
# Seed PyTorch's RNG while the file is being loaded. HyperPyYAML's own
# `!apply:` tag is used instead of the PyYAML-specific
# `!!python/object/apply:` tag, which only works with an unsafe loader.
__set_seed: !apply:torch.manual_seed [!ref <seed>]
device: cuda

# Folder holding the annotation and .npy files referenced further down.
scp_folder: ./data

# Everything written by a run lives under <output_folder>.
output_folder: !ref ../test_<seed>_1
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt
# File logger writing to <train_log>. `save_file` must be nested under the
# `!new:` tag so that it is passed as a constructor argument rather than
# being parsed as an unrelated top-level key.
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
  save_file: !ref <train_log>
# Annotation files for the three data splits, all located in <scp_folder>.
train_annotation: !ref <scp_folder>/Train.json
valid_annotation: !ref <scp_folder>/Validation.json
test_annotation: !ref <scp_folder>/Test.json
# .npy files in <scp_folder>. Presumably labels, majority-vote labels and
# waveform data, judging by the file names - TODO confirm against the
# training script that reads them.
label_path: !ref <scp_folder>/lab.npy
label_path_maj: !ref <scp_folder>/lab-maj.npy
fea_path: !ref <scp_folder>/wav.npy
# Data-loader options. The entries must be nested under
# `dataloader_options`; left at column 0 they become top-level keys, and
# `batch_size: !ref <batch_size>` would then both duplicate and reference
# the top-level `batch_size` itself.
dataloader_options:
  batch_size: !ref <batch_size>
  shuffle: False
  # `!name:` yields a callable with `padded_keys` pre-bound as a keyword
  # argument; it is not called at load time.
  # NOTE(review): PaddedBatch documents `padded_keys` as a list; a bare
  # string is kept here to preserve the existing value - confirm.
  collate_fn: !name:speechbrain.dataio.batch.PaddedBatch
    padded_keys: sig
# Not referenced elsewhere in this file; presumably read by the training
# script to order the dataset - TODO confirm.
sorting: descending
# Reference (not an instance) to the accuracy-statistics class defined in
# the project's `modules` package.
acc_stats: !name:modules.MetricStats_Acc
# Training Parameters
# ckpt_interval_minutes, test_only, gradient_accumulation, num_samples and
# num_elbo are not referenced elsewhere in this file; presumably they are
# consumed by the training script - TODO confirm.
ckpt_interval_minutes: 15
test_only: False
# Upper bound passed to `epoch_counter`.
number_of_epochs: 40
# Referenced by `dataloader_options`.
batch_size: 64
# Initial learning rate, shared by `opt_class` and `lr_annealing`.
lr: 1.2
# Passed to `Transformer_model` as `dp`.
dp: 0.2
gradient_accumulation: 1
num_samples: 50
num_elbo: 20
# model Parameters
# input_dim / output_dim feed `Transformer_model`; output_dim is also both
# the input and output size of `softmax_enc`.
input_dim: 768
output_dim: 5 # Emotion: 'neu':0,'hap':1,'sad':2,'ang':3,'oth':4
# output_dim: 3 # Hate: "normal":0, "offensive":1,"hatespeech":2
# Passed to `SSLModel` as `freeze` and `freeze_feature_extractor`.
freeze_SSL: True
freeze_SSL_conv: True
# Passed to `SSLModel` as `source`.
SSL_hub: "microsoft/wavlm-base-plus"
num_pretrain_layers: 13 # Emotion
# num_pretrain_layers: 1 # Hate
# Hyperparameters forwarded to `Transformer_model`.
nhead: 4
num_trans_encoder: 2
d_trans: 128
num_fc: 2
d_fc: 128
# Not referenced elsewhere in this file; presumably used by the flow part
# of the model built in the training script - TODO confirm.
flow_num_block: 3
nvp_hidden_width: 64
# Forwarded to `softmax_enc` as `dnn_blocks` / `dnn_neurons`.
softmax_num_block: 1
softmax_hidden_dim: 64
# upstream model
# Self-supervised encoder instantiated from <SSL_hub>. The keyword
# arguments must be nested under the `!new:` tag; at column 0 they would be
# parsed as unrelated top-level keys and the constructor would receive none
# of them.
SSLModel: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
  source: !ref <SSL_hub>
  output_norm: False
  freeze: !ref <freeze_SSL>
  freeze_feature_extractor: !ref <freeze_SSL_conv>
  output_all_hiddens: True
  save_path: ../data/wav2vec2_checkpoint
# downstream model
# Project-defined transformer built from the model parameters above. The
# arguments must be nested: at column 0, entries such as
# `input_dim: !ref <input_dim>` duplicate the top-level keys of the same
# name and reference themselves.
Transformer_model: !new:modules.TransformerModel
  input_dim: !ref <input_dim>
  output_dim: !ref <output_dim>
  num_pretrain_layers: !ref <num_pretrain_layers>
  d_trans: !ref <d_trans>
  nhead: !ref <nhead>
  num_encoder_layers: !ref <num_trans_encoder>
  num_fc: !ref <num_fc>
  d_fc: !ref <d_fc>
  dp: !ref <dp>
  device: !ref <device>
# Project-defined encoder whose input and output sizes both equal
# <output_dim>. Arguments are nested so they reach the constructor instead
# of colliding with the top-level `input_dim` / `output_dim` keys.
softmax_enc: !new:modules.linear_enc
  input_dim: !ref <output_dim>
  dnn_blocks: !ref <softmax_num_block>
  dnn_neurons: !ref <softmax_hidden_dim>
  output_dim: !ref <output_dim>
# Epoch counter bounded by <number_of_epochs>; `limit` is nested so that it
# is passed to the constructor.
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
  limit: !ref <number_of_epochs>
# Name -> module mapping. The entries must be children of `modules`;
# at column 0 `modules` would be empty and `softmax_encoder` would become a
# stray top-level key.
modules:
  SSL_encoder: !ref <SSLModel>
  feature_extractor: !ref <Transformer_model>
  softmax_encoder: !ref <softmax_enc>
# Optimizer factory: `!name:` pre-binds the keyword arguments below without
# constructing the optimizer. They must be nested; at column 0
# `lr: !ref <lr>` would redefine the top-level `lr` as a reference to
# itself.
opt_class: !name:torch.optim.Adadelta
  lr: !ref <lr>
  rho: 0.95
  # Written as `1.e-8` (with the dot) so YAML 1.1 loaders read a float.
  eps: 1.e-8
# Learning-rate scheduler starting from <lr>. Arguments are nested so they
# are passed to the constructor rather than parsed as top-level keys.
lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
  initial_value: !ref <lr>
  improvement_threshold: 0.001
  annealing_factor: 0.8
  patient: 1
# Checkpointer writing to <save_folder>. `recoverables` is a nested mapping
# of the objects saved and restored with each checkpoint; its entries must
# be indented below it, otherwise they collide with the top-level
# `Transformer_model` key and the `softmax_encoder` entry of `modules`.
# NOTE(review): `SSLModel` is not listed here; that matches
# `freeze_SSL: True`, but it would need adding if the encoder were
# fine-tuned - confirm.
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
  checkpoints_dir: !ref <save_folder>
  recoverables:
    Transformer_model: !ref <Transformer_model>
    softmax_encoder: !ref <softmax_enc>
    counter: !ref <epoch_counter>
    scheduler: !ref <lr_annealing>