-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_da.py
71 lines (60 loc) · 2.06 KB
/
train_da.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# coding: utf-8
import sys
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml
import importlib
from utils.data_io_da import prepare_datasets
import os
from speechbrain.utils.profiling import profile, schedule
import logging
# Recipe begins!
def _build_model(hparams, run_opts, label_encoder):
    """Instantiate the ``SBModel`` selected by ``hparams['model_class']``.

    The class is loaded dynamically from ``models.<model_class>.model`` and
    constructed with the modules, hparams, run options and checkpointer found
    in ``hparams``.

    Args:
        hparams: parsed hyperparameter dict (as returned by load_hyperpyyaml).
        run_opts: run options returned by ``sb.parse_arguments``.
        label_encoder: label encoder produced by ``prepare_datasets``.

    Returns:
        The constructed model instance.

    Raises:
        KeyError: if ``hparams`` has no ``'model_class'`` entry.
    """
    if 'model_class' not in hparams:
        # Previously model construction was silently skipped here, and the
        # script then crashed with a NameError on ``model.fit``; fail early
        # with an actionable message instead.
        raise KeyError(
            "hparams must define 'model_class' (the package under models/ "
            "whose model module provides SBModel)"
        )
    model_class = hparams['model_class']
    sb_model_cls = importlib.import_module(f'models.{model_class}.model').SBModel
    return sb_model_cls(
        label_encoder=label_encoder,
        modules=hparams['modules'],
        hparams=hparams,
        run_opts=run_opts,
        checkpointer=hparams['checkpointer'],
    )


def main(argv=None):
    """Prepare data, train the model, then evaluate it on the test split.

    Args:
        argv: command-line arguments without the program name; defaults to
            ``sys.argv[1:]``. Parsed by ``sb.parse_arguments`` into the
            hparams file path, run options and YAML overrides.
    """
    if argv is None:
        argv = sys.argv[1:]

    # parse command line args
    hparams_file, run_opts, overrides = sb.parse_arguments(argv)

    # load hparams (the file header declares UTF-8, so read it as such
    # regardless of the platform's locale encoding)
    with open(hparams_file, encoding='utf-8') as fin:
        hparams = load_hyperpyyaml(fin, overrides)

    # create experiment directory (done once; it was previously repeated
    # verbatim after dataset loading)
    sb.create_experiment_directory(
        experiment_directory=hparams['output_dir'],
        hyperparams_to_save=hparams_file,
        overrides=overrides,
    )

    # json file preparation
    # testdataset: only evaluate on target domain
    dataset_name = hparams['dataset']
    prepare_module = importlib.import_module(
        f'datasets.{dataset_name}.prepare_da_json'
    )
    prepare_module.prepare(**hparams['prepare'])

    # Load parsed dataset
    datasets, label_encoder = prepare_datasets(hparams)
    train_dataset, valid_dataset, test_dataset = datasets

    # initialize model
    model = _build_model(hparams, run_opts, label_encoder)

    # fit the model
    model.fit(
        hparams['epoch_counter'],
        train_dataset,
        valid_dataset,
        train_loader_kwargs=hparams['train_dataloader_opts'],
        valid_loader_kwargs=hparams['valid_dataloader_opts'],
    )

    # max_key / min_key select which checkpoint to evaluate; both are
    # optional in Brain.evaluate (default None), so missing entries are
    # tolerated rather than raising KeyError.
    model.evaluate(
        test_dataset,
        max_key=hparams.get('max_key'),
        min_key=hparams.get('min_key'),
        test_loader_kwargs=hparams['test_dataloader_opts'],
    )


if __name__ == "__main__":
    main()