Skip to content

Commit

Permalink
Merge branch 'hft-gui-hft-data-split'
Browse files Browse the repository at this point in the history
  • Loading branch information
benfoley committed Dec 6, 2021
2 parents 69c60ed + d8d0495 commit f9441ad
Showing 1 changed file with 2 additions and 21 deletions.
23 changes: 2 additions & 21 deletions elpis/engines/hft/objects/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
# Used to reduce training time when debugging
DEBUG = False
QUICK_TRAIN_BUILD_ARGUMENTS = {
'max_train_samples': '2',
'num_train_epochs': '3',
'model_name_or_path': 'facebook/wav2vec2-base',
'per_device_train_batch_size': '1',
Expand Down Expand Up @@ -632,10 +631,7 @@ def train(self, on_complete:Callable=None):
self.processor.save_pretrained(self.training_args.output_dir)

metrics = train_result.metrics
max_train_samples = (
self.data_args.max_train_samples if self.data_args.max_train_samples is not None else len(self.hft_dataset['train'])
)
metrics['train_samples'] = min(max_train_samples, len(self.hft_dataset['train']))
metrics['train_samples'] = len(self.hft_dataset['train'])

trainer.log_metrics(TRAIN, metrics)
trainer.save_metrics(TRAIN, metrics)
Expand All @@ -648,8 +644,7 @@ def train(self, on_complete:Callable=None):
if self.training_args.do_eval:
logger.info('=== Evaluate')
metrics = trainer.evaluate()
max_val_samples = self.data_args.max_val_samples if self.data_args.max_val_samples is not None else len(self.hft_dataset['dev'])
metrics['eval_samples'] = min(max_val_samples, len(self.hft_dataset['dev']))
metrics['eval_samples'] = len(self.hft_dataset['dev'])
trainer.log_metrics('eval', metrics)
trainer.save_metrics('eval', metrics)
print('=== Metrics')
Expand Down Expand Up @@ -844,20 +839,6 @@ class DataTrainingArguments:
default=None,
metadata={'help': 'The number of processes to use for the preprocessing.'},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
'help': 'For debugging purposes or quicker training, truncate the number of training examples to this '
'value if set.'
},
)
max_val_samples: Optional[int] = field(
default=None,
metadata={
'help': 'For debugging purposes or quicker training, truncate the number of validation examples to this '
'value if set.'
},
)
chars_to_ignore: List[str] = list_field(
default=[',', '?', '.', '!', '-', ';', ':', '""', '%', ''', ''', '�'],
metadata={'help': 'A list of characters to remove from the transcripts.'},
Expand Down

0 comments on commit f9441ad

Please sign in to comment.