Added max_sample_ arguments (huggingface#10551)
* reverted changes of logging and saving metrics

* added max_sample arguments

* fixed code

* white space diff

* reformatting code

* reformatted code
bhadreshpsavani authored and Iwontbecreative committed Jul 15, 2021
1 parent 4387c37 commit 517451e
Showing 14 changed files with 517 additions and 119 deletions.
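Each of the scripts below gains the same pattern: optional max_train_samples / max_val_samples fields on DataTrainingArguments, Dataset.select calls that truncate the train and validation splits, and train_samples / eval_samples entries in the logged metrics. A minimal sketch of that pattern outside any of the scripts (the toy dataset and the hard-coded limits are illustrative only, not the scripts' actual data):

from datasets import Dataset

# Toy stand-ins for the real train/validation splits.
train_dataset = Dataset.from_dict({"text": [f"train {i}" for i in range(1000)]})
eval_dataset = Dataset.from_dict({"text": [f"val {i}" for i in range(200)]})

max_train_samples = 100   # would come from --max_train_samples
max_val_samples = None    # would come from --max_val_samples

if max_train_samples is not None:
    train_dataset = train_dataset.select(range(max_train_samples))
if max_val_samples is not None:
    eval_dataset = eval_dataset.select(range(max_val_samples))

metrics = {}
# Record how many examples were actually used, capped at the split size.
max_train = max_train_samples if max_train_samples is not None else len(train_dataset)
metrics["train_samples"] = min(max_train, len(train_dataset))
max_val = max_val_samples if max_val_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_val, len(eval_dataset))
print(metrics)   # {'train_samples': 100, 'eval_samples': 200}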
54 changes: 44 additions & 10 deletions examples/language-modeling/run_clm.py
@@ -114,6 +114,21 @@ class DataTrainingArguments:
default=None,
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
},
)
max_val_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
"value if set."
},
)

block_size: Optional[int] = field(
default=None,
metadata={
@@ -346,19 +361,34 @@ def group_texts(examples):
#
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map

lm_datasets = tokenized_datasets.map(
group_texts,
batched=True,
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=not data_args.overwrite_cache,
)

if training_args.do_train:
if "train" not in tokenized_datasets:
raise ValueError("--do_train requires a train dataset")
train_dataset = lm_datasets["train"]
if data_args.max_train_samples is not None:
train_dataset = train_dataset.select(range(data_args.max_train_samples))

if training_args.do_eval:
if "validation" not in tokenized_datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = lm_datasets["validation"]
if data_args.max_val_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))

# Initialize our Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=lm_datasets["train"] if training_args.do_train else None,
eval_dataset=lm_datasets["validation"] if training_args.do_eval else None,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=eval_dataset if training_args.do_eval else None,
tokenizer=tokenizer,
# Data collator will default to DataCollatorWithPadding, so we change it.
data_collator=default_data_collator,
@@ -377,24 +407,28 @@ def group_texts(examples):

metrics = train_result.metrics

max_train_samples = (
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
)
metrics["train_samples"] = min(max_train_samples, len(train_dataset))

trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

# Evaluation
results = {}
if training_args.do_eval:
logger.info("*** Evaluate ***")

eval_output = trainer.evaluate()

perplexity = math.exp(eval_output["eval_loss"])
results["perplexity"] = perplexity
metrics = trainer.evaluate()

trainer.log_metrics("eval", results)
trainer.save_metrics("eval", results)
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
perplexity = math.exp(metrics["eval_loss"])
metrics["perplexity"] = perplexity

return results
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)


def _mp_fn(index):
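Because DataTrainingArguments is fed to HfArgumentParser, the new dataclass fields surface as --max_train_samples and --max_val_samples command-line flags without any extra wiring. A small sketch of that mechanism (the trimmed-down dataclass and the argument values here are made up for illustration):

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class DataTrainingArguments:
    # Trimmed to the two new fields; the real class has many more.
    max_train_samples: Optional[int] = field(default=None)
    max_val_samples: Optional[int] = field(default=None)

parser = HfArgumentParser(DataTrainingArguments)
# parse_args_into_dataclasses also accepts an explicit argv-style list.
(data_args,) = parser.parse_args_into_dataclasses(
    ["--max_train_samples", "1000", "--max_val_samples", "200"]
)
print(data_args.max_train_samples, data_args.max_val_samples)  # 1000 200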
53 changes: 43 additions & 10 deletions examples/language-modeling/run_mlm.py
@@ -146,6 +146,20 @@ class DataTrainingArguments:
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
},
)
max_val_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
"value if set."
},
)

def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
@@ -380,13 +394,28 @@ def group_texts(examples):
#
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map

tokenized_datasets = tokenized_datasets.map(
group_texts,
batched=True,
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=not data_args.overwrite_cache,
)

if training_args.do_train:
if "train" not in tokenized_datasets:
raise ValueError("--do_train requires a train dataset")
train_dataset = tokenized_datasets["train"]
if data_args.max_train_samples is not None:
train_dataset = train_dataset.select(range(data_args.max_train_samples))

if training_args.do_eval:
if "validation" not in tokenized_datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = tokenized_datasets["validation"]
if data_args.max_val_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))

# Data collator
# This one will take care of randomly masking the tokens.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
@@ -395,8 +424,8 @@ def group_texts(examples):
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=eval_dataset if training_args.do_eval else None,
tokenizer=tokenizer,
data_collator=data_collator,
)
@@ -413,24 +442,28 @@ def group_texts(examples):
trainer.save_model() # Saves the tokenizer too for easy upload
metrics = train_result.metrics

max_train_samples = (
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
)
metrics["train_samples"] = min(max_train_samples, len(train_dataset))

trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

# Evaluation
results = {}
if training_args.do_eval:
logger.info("*** Evaluate ***")

eval_output = trainer.evaluate()

perplexity = math.exp(eval_output["eval_loss"])
results["perplexity"] = perplexity
metrics = trainer.evaluate()

trainer.log_metrics("eval", results)
trainer.save_metrics("eval", results)
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
perplexity = math.exp(metrics["eval_loss"])
metrics["perplexity"] = perplexity

return results
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)


def _mp_fn(index):
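The truncation itself is nothing more than datasets.Dataset.select over a leading index range, which keeps the first N rows in their original order. A toy illustration of that call in isolation:

from datasets import Dataset

ds = Dataset.from_dict({"text": [f"example {i}" for i in range(10)]})

subset = ds.select(range(3))   # keep rows 0, 1 and 2
print(len(subset))             # 3
print(subset["text"])          # ['example 0', 'example 1', 'example 2']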
53 changes: 43 additions & 10 deletions examples/language-modeling/run_plm.py
@@ -143,6 +143,20 @@ class DataTrainingArguments:
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
},
)
max_val_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
"value if set."
},
)

def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
@@ -358,13 +372,28 @@ def group_texts(examples):
#
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map

tokenized_datasets = tokenized_datasets.map(
group_texts,
batched=True,
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=not data_args.overwrite_cache,
)

if training_args.do_train:
if "train" not in tokenized_datasets:
raise ValueError("--do_train requires a train dataset")
train_dataset = tokenized_datasets["train"]
if data_args.max_train_samples is not None:
train_dataset = train_dataset.select(range(data_args.max_train_samples))

if training_args.do_eval:
if "validation" not in tokenized_datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = tokenized_datasets["validation"]
if data_args.max_val_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))

# Data collator
data_collator = DataCollatorForPermutationLanguageModeling(
tokenizer=tokenizer,
@@ -376,8 +405,8 @@ def group_texts(examples):
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=eval_dataset if training_args.do_eval else None,
tokenizer=tokenizer,
data_collator=data_collator,
)
@@ -394,24 +423,28 @@ def group_texts(examples):
trainer.save_model() # Saves the tokenizer too for easy upload
metrics = train_result.metrics

max_train_samples = (
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
)
metrics["train_samples"] = min(max_train_samples, len(train_dataset))

trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

# Evaluation
results = {}
if training_args.do_eval:
logger.info("*** Evaluate ***")

eval_output = trainer.evaluate()

perplexity = math.exp(eval_output["eval_loss"])
results["perplexity"] = perplexity
metrics = trainer.evaluate()

trainer.log_metrics("eval", results)
trainer.save_metrics("eval", results)
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
perplexity = math.exp(metrics["eval_loss"])
metrics["perplexity"] = perplexity

return results
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)


def _mp_fn(index):
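As in the other two language-modeling scripts, the eval block now writes perplexity into the same metrics dict that gets logged and saved, instead of a separate results dict. The computation is unchanged: perplexity is the exponential of the mean eval cross-entropy loss, for example (with a made-up loss value):

import math

eval_loss = 3.21                  # hypothetical metrics["eval_loss"]
perplexity = math.exp(eval_loss)  # roughly 24.8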
65 changes: 51 additions & 14 deletions examples/multiple-choice/run_swag.py
@@ -116,6 +116,20 @@ class DataTrainingArguments:
"efficient on GPU but very bad for TPU."
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
},
)
max_val_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
"value if set."
},
)

def __post_init__(self):
if self.train_file is not None:
@@ -328,12 +342,31 @@ def preprocess_function(examples):
# Un-flatten
return {k: [v[i : i + 4] for i in range(0, len(v), 4)] for k, v in tokenized_examples.items()}

tokenized_datasets = datasets.map(
preprocess_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=not data_args.overwrite_cache,
)
if training_args.do_train:
train_dataset = datasets["train"]
if "train" not in datasets:
raise ValueError("--do_train requires a train dataset")
if data_args.max_train_samples is not None:
train_dataset = train_dataset.select(range(data_args.max_train_samples))
train_dataset = train_dataset.map(
preprocess_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=not data_args.overwrite_cache,
)

if training_args.do_eval:
if "validation" not in datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = datasets["validation"]
if data_args.max_val_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
eval_dataset = eval_dataset.map(
preprocess_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=not data_args.overwrite_cache,
)

# Data collator
data_collator = (
@@ -352,8 +385,8 @@ def compute_metrics(eval_predictions):
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=eval_dataset if training_args.do_eval else None,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
@@ -371,21 +404,25 @@ def compute_metrics(eval_predictions):
trainer.save_model() # Saves the tokenizer too for easy upload
metrics = train_result.metrics

max_train_samples = (
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
)
metrics["train_samples"] = min(max_train_samples, len(train_dataset))

trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

# Evaluation
results = {}
if training_args.do_eval:
logger.info("*** Evaluate ***")

results = trainer.evaluate()

trainer.log_metrics("eval", results)
trainer.save_metrics("eval", results)
metrics = trainer.evaluate()
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))

return results
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)


def _mp_fn(index):
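Unlike the language-modeling scripts, run_swag.py now applies select to the raw splits before preprocess_function is mapped over them, so only the kept examples are ever tokenized; that is what makes these flags handy for quick debugging runs. A sketch of that ordering with a stand-in preprocessing step (toy data and illustrative names, not the script's own code):

from datasets import Dataset

raw = Dataset.from_dict({"sent": [f"sentence {i}" for i in range(1000)]})

def preprocess(batch):
    # Stand-in for the real tokenization; just measures string lengths.
    return {"length": [len(s) for s in batch["sent"]]}

# Truncate first, then preprocess only the 100 kept examples.
train_dataset = raw.select(range(100)).map(preprocess, batched=True)
print(len(train_dataset))   # 100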
