Added answer only loss for prompt learning (#4069)
* Added answer only loss for prompt learning

Signed-off-by: Virginia Adams <vadams@nvidia.com>

* Python code reformatting

Signed-off-by: Virginia Adams <vadams@nvidia.com>

* Updated prompt learning unit tests

Signed-off-by: Virginia Adams <vadams@nvidia.com>

* removed unused import

Signed-off-by: Virginia Adams <vadams@nvidia.com>

Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
vadam5 and okuchaiev committed Apr 27, 2022
1 parent d823318 commit da1b56c
Showing 4 changed files with 104 additions and 22 deletions.
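
At a high level, this change lets each task template name an answer field so that, during training, the loss is computed only over the answer tokens rather than the full prompt. A minimal standalone sketch of that masking idea (illustrative only, not the NeMo implementation; the ids and indices are made up):

from typing import List, Optional, Set

def build_loss_mask(token_ids: List[int],
                    answer_start_idx: Optional[int],
                    pseudo_token_ids: Set[int]) -> List[float]:
    """Per-token loss mask for a single example."""
    if answer_start_idx is not None:
        # Answer-only loss: 1.0 for answer tokens, 0.0 for everything before them.
        return [float(i >= answer_start_idx) for i in range(len(token_ids))]
    # Default behaviour: mask out only the virtual-prompt placeholder tokens.
    return [float(t not in pseudo_token_ids) for t in token_ids]

# 3 virtual tokens (ids 50000-50002), 3 context tokens, then a 2-token answer.
ids = [50000, 50001, 50002, 11, 22, 33, 44, 55]
print(build_loss_mask(ids, answer_start_idx=6, pseudo_token_ids={50000, 50001, 50002}))
# -> [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0]
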
@@ -59,10 +59,13 @@ model:
total_virtual_tokens: 80 # Sum of tokens in virtual_token_splits must add to this number. Can differ between new and existing tasks, but must match across all new tasks being tuned at the same time.
virtual_token_splits: [60, 20] # number of virtual tokens to be inserted at each VIRTUAL PROMPT location, must add to total_virtual_tokens
truncate_field: 'passage' # The {field} in the prompt template whose text will be truncated if the input is too long; if null, inputs that are too long will just be skipped.
answer_only_loss: True
answer_field: 'answer'

- taskname: 'intent_and_slot'
prompt_template: '<|VIRTUAL_PROMPT_0|> intent options: {intent_options} <|VIRTUAL_PROMPT_1|> slot options: {slot_options} <|VIRTUAL_PROMPT_2|> {utterance} \nintent: {intent} \nslot: {slot}'
total_virtual_tokens: 80
answer_only_loss: False
virtual_token_splits: [34, 33, 13]
truncate_field: null

@@ -71,6 +74,8 @@ model:
total_virtual_tokens: 100
virtual_token_splits: [100]
truncate_field: null
answer_only_loss: True
answer_field: 'answer'

prompt_tuning: # Prompt tuning specific params
new_prompt_init_methods: ['text'] # List of 'text' or 'random', should correspond to tasks listed in new tasks
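
With these keys, a task opts in by setting answer_only_loss: True and naming the data field that holds the answer. The new sanity checks in the dataset (further down in this diff) require that field's placeholder to be the last thing in the prompt template, since everything from its first token onward is what stays unmasked. A small sketch of that constraint, using an assumed template and field name rather than the shipped config:

# Illustrative only: the template and field name here are assumptions.
prompt_template = "<|VIRTUAL_PROMPT_0|> Context: {context} Question: {question} Answer: {answer}"
answer_field = "answer"
answer_placeholder = "{" + answer_field + "}"

# Mirrors the check added in _input_sanity_checks: the answer must end the template.
assert prompt_template.endswith(answer_placeholder), "Answer field must be at prompt end"
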
@@ -41,6 +41,7 @@ def __init__(
min_seq_length: int = 1,
add_bos: bool = False,
add_eos: bool = True,
for_train: bool = True,
):
self.tokenizer = tokenizer
self.virtual_prompt_source = virtual_prompt_source
@@ -52,6 +53,7 @@ def __init__(
self.min_seq_length = min_seq_length
self.add_bos = add_bos
self.add_eos = add_eos
self.for_train = for_train
self.examples = []

assert self.min_seq_length <= max_seq_length, "Min sequence length should be less than or equal to max"
@@ -97,10 +99,21 @@ def load_data(self, dataset):
prompt_template_fields = self.task_templates[taskname]["prompt_template_fields"]
total_virtual_tokens = self.task_templates[taskname]["total_virtual_tokens"]
virtual_token_splits = self.task_templates[taskname]["virtual_token_splits"]
truncation_field = self.task_templates[taskname]['truncate_field']
answer_only_loss = self.task_templates[taskname]["answer_only_loss"]
answer_field = self.task_templates[taskname]["answer_field"]

input_example = prompt_template

self._input_sanity_checks(
total_virtual_tokens, virtual_token_splits, prompt_template, prompt_template_fields, doc
total_virtual_tokens,
virtual_token_splits,
prompt_template,
prompt_template_fields,
truncation_field,
answer_only_loss,
answer_field,
doc,
)

# Format the input example according to the template
@@ -116,7 +129,7 @@ def load_data(self, dataset):

# Try to truncate input text to fit into the max sequence length
if len(input_ids) > self.max_seq_length:
input_ids = self._truncate_input(input_ids, taskname, doc)
input_ids = self._truncate_input(truncation_field, input_ids, taskname, doc)

# Skip example if the final length doesn't fit length requirements even after truncation
if self.min_seq_length <= len(input_ids) <= self.max_seq_length:
@@ -126,15 +139,27 @@
elif self.virtual_prompt_source == "prompt-table":
taskname_id = self.task_templates[taskname]["task_id_num"]

self.examples.append((taskname_id, input_ids))
# Find answer field indices if training and answer_only_loss is True
answer_start_idx = None
if answer_only_loss and self.for_train:
answer_start_idx = self._find_answer_start(taskname, input_ids, answer_field, doc)

self.examples.append((taskname_id, input_ids, answer_start_idx))
else:
skipped += 1

logging.info(f'Skipped {skipped} sentences, sequence length too short or too long even after truncation')

def _input_sanity_checks(
self, total_virtual_tokens, virtual_token_splits, prompt_template, prompt_template_fields, doc
self,
total_virtual_tokens,
virtual_token_splits,
prompt_template,
prompt_template_fields,
truncation_field,
answer_only_loss,
answer_field,
doc,
):
# Sanity check amount of virtual token
assert total_virtual_tokens > 0, "There should be at least one virtual prompt token"
@@ -158,6 +183,19 @@ def _input_sanity_checks(
len(keys_not_in_template) == 0
), f"Examples in your dataset contain the fields: {keys_not_in_template} that are not in the task template."

# Check the answer field if answer_only_loss was set to True
if answer_only_loss and self.for_train:
assert answer_field is not None, "If answer_only_loss=True, an answer_field must be given"
assert (
answer_field in doc.keys()
), f"answer_only_loss=True but the given answer_field '{answer_field}' is not in data json"
assert truncation_field != answer_field, "Answer field and truncation field should not match"

answer_placeholder = "{" + answer_field + "}"
answer_placeholder_len = len(answer_placeholder)
placeholder_start = len(prompt_template) - answer_placeholder_len
assert prompt_template[placeholder_start:] == answer_placeholder, "Answer field must be at prompt end"

def _insert_text_in_template(self, input_example, prompt_template_fields, doc):
""" Format the input example according to the template """
for field in prompt_template_fields:
@@ -185,9 +223,8 @@ def _insert_virtual_token_placeholders(self, input_example, virtual_token_splits

return input_example

def _truncate_input(self, input_ids, taskname, doc):
def _truncate_input(self, truncation_field, input_ids, taskname, doc):
""" Try to truncate input text to fit into the max sequence length """
truncation_field = self.task_templates[taskname]['truncate_field']
logging.info(
f"Input greater than max sequence length. Attempting to truncate: '{truncation_field}' in task: '{taskname}'"
)
@@ -196,12 +233,7 @@ def _truncate_input(self, input_ids, taskname, doc):
if truncation_field is not None and truncation_field in doc.keys():
truncation_length = len(input_ids) - self.max_seq_length
field_text = doc[truncation_field]

# Add leading space to text if there is a space before it in the template
prompt_template = self.task_templates[taskname]["prompt_template"]
field_text_start = prompt_template.find("{" + truncation_field + "}")
if field_text_start != 0 and prompt_template[field_text_start - 1] == " ":
field_text = " " + field_text
field_text = self._add_leading_space(taskname, truncation_field, field_text)

# Truncate field text
field_text_ids = self.tokenizer.text_to_ids(field_text)
Expand All @@ -213,6 +245,31 @@ def _truncate_input(self, input_ids, taskname, doc):

return input_ids

def _find_answer_start(self, taskname, input_ids, answer_field, doc):
""" Find the token ids corresponding to the answer start, for loss masking purposes.
Assumes the answer is always at the end of the prompt.
"""
answer_text = doc[answer_field]
answer_text = self._add_leading_space(taskname, answer_field, answer_text)
answer_text_ids = self.tokenizer.text_to_ids(answer_text)
num_answer_text_ids = len(answer_text_ids)

if self.add_eos:
num_answer_text_ids += 1

answer_start_idx = len(input_ids) - num_answer_text_ids

return answer_start_idx

def _add_leading_space(self, taskname, field_name, field_text):
""" Add leading space to text if there is a space before it in the template """
prompt_template = self.task_templates[taskname]["prompt_template"]
field_text_start = prompt_template.find("{" + field_name + "}")
if field_text_start != 0 and prompt_template[field_text_start - 1] == " ":
field_text = " " + field_text

return field_text

def __len__(self):
return len(self.examples)
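
To make the index arithmetic in _find_answer_start concrete, here is a toy walk-through with a whitespace "tokenizer" standing in for self.tokenizer.text_to_ids (purely illustrative; the real code operates on integer token ids):

# Context tokens, then the answer, then EOS (add_eos=True appends one extra token).
context_tokens = "question: what color is the sky? answer:".split()
answer_tokens = "blue".split()
input_ids = context_tokens + answer_tokens + ["<eos>"]

num_answer_ids = len(answer_tokens) + 1               # +1 accounts for the EOS token
answer_start_idx = len(input_ids) - num_answer_ids

assert input_ids[answer_start_idx:] == ["blue", "<eos>"]
# The loss mask built later is 0.0 before answer_start_idx and 1.0 from there on.
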

@@ -222,7 +279,7 @@ def __getitem__(self, idx):
def collate_fn(self, batch):
""" Prepares input_ids, labels, loss mask, attention_mask, and position ids for global batch """
# Get max sequence length of batch
taskname_ids, input_ids = zip(*batch)
taskname_ids, input_ids, answer_starts = zip(*batch)

# Pad taskname_ids to be the same length for the prompt encoder
if self.virtual_prompt_source == "prompt-encoder":
@@ -235,7 +292,7 @@ def collate_fn(self, batch):
taskname_ids = torch.tensor(taskname_ids)

batch_max = max(len(ids) for ids in input_ids)
input_ids, loss_mask = self.pad_batch_and_build_loss_mask(input_ids, batch_max)
input_ids, loss_mask = self.pad_batch_and_build_loss_mask(input_ids, batch_max, answer_starts)

# Should be a label for every token in batch, label is the next token
labels = input_ids[:, 1:].contiguous()
@@ -257,12 +314,16 @@ def collate_fn(self, batch):

return input_ids, labels, loss_mask, position_ids, attention_mask, taskname_ids

def pad_batch_and_build_loss_mask(self, input_ids, batch_max):
def pad_batch_and_build_loss_mask(self, input_ids, batch_max, answer_starts):
""" Pad input_ids in batch to max batch length while building loss mask """
batch_loss_masks = []
for ids in input_ids:
# Loss mask where virtual tokens are 0.0 and all other tokens are 1.0
loss_mask = [float(token_id not in self.pseudo_token_ids) for token_id in ids]
for ids, answer_start_idx in zip(input_ids, answer_starts):
if answer_start_idx is not None:
# Loss mask where answer tokens are 1.0 and all other tokens are 0.0
loss_mask = [float(idx >= answer_start_idx) for idx in range(len(ids))]
else:
# Loss mask where virtual tokens are 0.0 and all other tokens are 1.0
loss_mask = [float(token_id not in self.pseudo_token_ids) for token_id in ids]

# Pad to max length
input_length = len(ids)
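
Putting the pieces together, the collate path pads every example to the batch maximum and builds a per-example mask whose values depend on whether an answer start index was recorded. A small self-contained sketch of the result, assuming (as the surrounding code suggests) that padded positions contribute no loss; the ids and lengths are toy values:

import torch

batch_max = 5
examples = [
    (list(range(1, 6)), 3),      # 5 tokens, answer starts at index 3
    (list(range(1, 4)), None),   # 3 tokens, answer_only_loss disabled for this task
]
pseudo_token_ids = set()         # no virtual-token ids in this toy batch

padded_ids, loss_masks = [], []
for ids, answer_start_idx in examples:
    if answer_start_idx is not None:
        mask = [float(i >= answer_start_idx) for i in range(len(ids))]
    else:
        mask = [float(t not in pseudo_token_ids) for t in ids]
    pad = batch_max - len(ids)
    padded_ids.append(ids + [0] * pad)
    loss_masks.append(mask + [0.0] * pad)   # assumption: padding is masked out

print(torch.tensor(loss_masks))
# tensor([[0., 0., 0., 1., 1.],
#         [1., 1., 1., 0., 0.]])
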
@@ -283,12 +344,13 @@ def get_all_examples(self, tokens_to_generate):
"""
Used for loading inference data.
"""
task_id_nums, input_ids = zip(*self.examples)
task_id_nums, input_ids, answer_starts = zip(*self.examples)
input_lengths = torch.cuda.LongTensor([len(inputs) for inputs in input_ids])
task_id_nums = torch.cuda.LongTensor(task_id_nums)
batch_max = input_lengths.max().item()
batch_max += tokens_to_generate

input_ids, _ = self.pad_batch_and_build_loss_mask(input_ids, batch_max + tokens_to_generate)
input_ids, _ = self.pad_batch_and_build_loss_mask(input_ids, batch_max, answer_starts)
input_ids = input_ids.cuda()
input_ids = torch.cuda.LongTensor(input_ids)

@@ -152,6 +152,8 @@ def load_task_templates(self, task_templates):
self.task_templates[task.taskname] = {
"prompt_template": task.prompt_template,
"prompt_template_fields": re.findall("\{(.*?)\}", task.prompt_template),
"answer_only_loss": task.get("answer_only_loss", False),
"answer_field": task.get("answer_field", None),
"truncate_field": task.truncate_field,
"total_virtual_tokens": task.total_virtual_tokens,
"virtual_token_splits": task.virtual_token_splits,
@@ -496,6 +498,7 @@ def setup_training_data(self, training_data_config=None):
self._train_ds, self._train_dl = self.build_virtual_prompt_dataset(
dataset_paths=self.cfg.data.train_ds,
batch_size=self.cfg.batch_size,
for_train=True,
drop_last=True,
shuffle=True,
num_workers=self.cfg.data.num_workers,
@@ -507,6 +510,7 @@ def setup_validation_data(self, validation_data_config=None):
self._validation_ds, self._validation_dl = self.build_virtual_prompt_dataset(
dataset_paths=self.cfg.data.validation_ds,
batch_size=self.cfg.batch_size,
for_train=True,
drop_last=True,
shuffle=False,
num_workers=self.cfg.data.num_workers,
@@ -518,13 +522,16 @@ def setup_test_data(self, test_data_config=None):
self._test_ds, self._test_dl = self.build_virtual_prompt_dataset(
dataset_paths=self.cfg.data.test_ds,
batch_size=self.cfg.batch_size,
for_train=False,
drop_last=False,
shuffle=False,
num_workers=self.cfg.data.num_workers,
pin_memory=True,
)

def build_virtual_prompt_dataset(self, dataset_paths, batch_size, drop_last, shuffle, num_workers, pin_memory):
def build_virtual_prompt_dataset(
self, dataset_paths, batch_size, for_train, drop_last, shuffle, num_workers, pin_memory
):
dataset = GPTPromptLearningDataset(
datasets=dataset_paths,
tokenizer=self.tokenizer,
@@ -536,6 +543,7 @@ def build_virtual_prompt_dataset(self, dataset_paths, batch_size, drop_last, shu
min_seq_length=self.cfg.data.get('min_seq_length', 1),
add_bos=self.cfg.data.get('add_bos', False),
add_eos=self.cfg.data.get('add_eos', True),
for_train=for_train,
)

dataloader = torch.utils.data.DataLoader(
@@ -587,6 +595,7 @@ def dummy():
min_seq_length=self.cfg.data.get('min_seq_length', 1),
add_bos=sampling_params["add_BOS"],
add_eos=False,
for_train=False,
)
task_ids, processed_inputs = dataset.get_all_examples(tokens_to_generate=length_params['max_length'])
self.model.model.parallel_output = False
6 changes: 6 additions & 0 deletions tests/collections/nlp/test_prompt_learning.py
@@ -64,13 +64,19 @@ def get_task_templates():
"prompt_template_fields": ['text', 'answer'],
"total_virtual_tokens": 5,
"virtual_token_splits": [5],
"truncate_field": None,
"answer_only_loss": True,
"answer_field": "answer",
"task_id_num": 0,
}
task_templates['task name B'] = {
"prompt_template": "<|VIRTUAL_PROMPT_0|>{question}<|VIRTUAL_PROMPT_1|>{answer}{extra}",
"prompt_template_fields": ['question', 'answer'],
"total_virtual_tokens": 10,
"virtual_token_splits": [7, 3],
"truncate_field": None,
"answer_only_loss": False,
"answer_field": None,
"task_id_num": 1,
}
return task_templates