From 19afcea6f517137147c3e0837cc4c78d28fe34ea Mon Sep 17 00:00:00 2001 From: jacobthebanana Date: Fri, 7 Jun 2024 18:43:05 -0400 Subject: [PATCH 1/2] trice [WIP]: implemented baseline trice, basic grad estimate. no control variate. tested on one GPU only. --- configs/config_gemma.yaml | 13 +- examples/llama_example.py | 52 ++- vectorlm/sampling/__init__.py | 1 + vectorlm/sampling/abstract.py | 15 + vectorlm/sampling/utils.py | 42 ++- vectorlm/tests/test_sampling.py | 28 ++ vectorlm/tests/test_trice.py | 18 + vectorlm/trainer.py | 2 + vectorlm/trice.py | 574 ++++++++++++++++++++++++++++++++ 9 files changed, 709 insertions(+), 36 deletions(-) create mode 100644 vectorlm/tests/test_sampling.py create mode 100644 vectorlm/tests/test_trice.py create mode 100644 vectorlm/trice.py diff --git a/configs/config_gemma.yaml b/configs/config_gemma.yaml index fb2b6fa..2d3162d 100644 --- a/configs/config_gemma.yaml +++ b/configs/config_gemma.yaml @@ -1,4 +1,4 @@ -model: google/gemma-2b +model: google/gemma-2b-it enable_wandb_logging: False wandb_config: @@ -7,9 +7,9 @@ wandb_config: # tags: ["20240418-1a-preemption"] train_parameters: - output_dir: weights - max_seq_len: 128 - epochs: 10 + output_dir: /network/scratch/j/jacob-junqi.tian/vectorlm/weights + max_seq_len: 1024 + epochs: 100000 seed: 11 # Sharding strategy @@ -33,10 +33,11 @@ train_parameters: # Gradient norm clipping max_grad_norm: 1 gradient_accumulation_steps: 4 + batch_size: 2 # Optimizer optimizer: - lr: 1.0e-4 + lr: 2.0e-5 weight_decay: 0.1 betas: [0.9, 0.95] eps: 1.0e-5 @@ -47,7 +48,7 @@ train_parameters: # Checkpointing checkpointing_enabled: False - logging_steps: 10 + logging_steps: 100 save_frequency: 0.10 # Sampling during training diff --git a/examples/llama_example.py b/examples/llama_example.py index abe76b4..ffb3745 100644 --- a/examples/llama_example.py +++ b/examples/llama_example.py @@ -14,7 +14,7 @@ from transformers import set_seed from vectorlm.dataset import Dataset -from vectorlm.trainer import Trainer +from vectorlm.trice import ICEMTrainer from vectorlm.utils.data_utils import Config from vectorlm.utils.misc_utils import cleanup, setup, wandb_setup from vectorlm.utils.model_utils import ( @@ -152,7 +152,7 @@ def main( ) # instantiate trainer - trainer = Trainer( + trainer = ICEMTrainer( config=training_args, enable_wandb_logging=config.enable_wandb_logging, original_dataset_length=dataset.original_length, @@ -186,33 +186,27 @@ def main( # Checkpoint check. Always call before training. # If no checkpoint, it returns 0. - checkpointed_epoch = trainer.find_checkpoint(training_args.output_dir) - - for epoch in range(checkpointed_epoch, training_args.epochs): - train_dl_iterator = iter(dataset.train_dataloader) - for _ in tqdm( - range(len(dataset.train_dataloader)), - disable=rank != 0, - file=sys.__stdout__, - ): - batch = next(train_dl_iterator) - trainer.step(batch, epoch) - - if epoch == training_args.epochs - 1: - hf_save_dir = os.path.join(training_args.output_dir, "final-model") - else: - hf_save_dir = os.path.join( - training_args.output_dir, - "checkpoints", - f"epoch_{epoch}", - "end-epoch-model", - ) - - if is_lora_enabled: - save_peft_adapter(trainer.model, hf_save_dir) - else: - save_consolidated_model(trainer.model, hf_save_dir, rank) - dataset.reset_dataloaders() + trainer.model.train() + trainer.find_checkpoint(training_args.output_dir) + + pbar = tqdm( + range(config.train_parameters.epochs), + disable=rank != 0, + file=sys.__stdout__, + ncols=75, + ) + for index in pbar: + eval_acc = 0 + train_loss, eval_output = trainer.step({}, index) + eval_acc = eval_output if eval_output is not None else eval_acc + + pbar.set_description(f"{train_loss:.3f}, {eval_acc * 100:.0f}%") + + if is_lora_enabled: + save_peft_adapter(trainer.model, hf_save_dir) + else: + save_consolidated_model(trainer.model, hf_save_dir, rank) + dataset.reset_dataloaders() sys.exit(0) diff --git a/vectorlm/sampling/__init__.py b/vectorlm/sampling/__init__.py index ebf9f1c..4af88d1 100644 --- a/vectorlm/sampling/__init__.py +++ b/vectorlm/sampling/__init__.py @@ -5,6 +5,7 @@ ManagedMultiProcGPUExecutor, SamplingEngineProvider, SynchronizationBarriers, + batch_process, handle_sample, multiprocess_wrap, ) diff --git a/vectorlm/sampling/abstract.py b/vectorlm/sampling/abstract.py index c2f3cdd..0fb6454 100644 --- a/vectorlm/sampling/abstract.py +++ b/vectorlm/sampling/abstract.py @@ -59,6 +59,10 @@ def generate( Invoke at all ranks. Output will be broadcasted to all ranks. + Only one thread should execute this method at a time. For performance, + supply a large number of prompts at a time instead of one prompt + at a time. + Params: ------ prompts: List of input prompts. @@ -70,3 +74,14 @@ def generate( Output from vllm: list[vllm.RequestOutput], one for each prompt. """ + + def generate_text_only( + self, + prompts: list[str], + sampling_params: vllm.SamplingParams | None = None, + ) -> list[str]: + """Generate and return text only.""" + return [ + response.outputs[0].text + for response in self.generate(prompts, sampling_params) + ] diff --git a/vectorlm/sampling/utils.py b/vectorlm/sampling/utils.py index df233c3..400b005 100644 --- a/vectorlm/sampling/utils.py +++ b/vectorlm/sampling/utils.py @@ -2,12 +2,21 @@ from __future__ import annotations +import concurrent.futures import json import os import threading import time from functools import partial -from typing import TYPE_CHECKING, Any, Callable, Iterable, NamedTuple, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterable, + Iterator, + NamedTuple, + TypeVar, +) from vllm import LLM, LLMEngine, SamplingParams from vllm.engine.arg_utils import EngineConfig @@ -59,6 +68,8 @@ class SynchronizationBarriers(NamedTuple): Fn = TypeVar("Fn", bound=Callable[..., Any]) +InputItem = TypeVar("InputItem") +OutputItem = TypeVar("OutputItem") def multiprocess_wrap(fn: Fn | None, barriers: SynchronizationBarriers) -> Fn: @@ -116,6 +127,35 @@ def _wrapped_fn(*args, **kwargs) -> ...: # noqa: ANN002,ANN003 return _wrapped_fn # type: ignore[reportReturnType] +def batch_process( + input_data: Iterable[InputItem], + fn: Callable[[Iterable[InputItem]], Iterable[OutputItem]], + max_batch_size: int, +) -> Iterator[OutputItem]: + """Process input data one batch at a time. + + Params: + ------ + input_data: iterator of data to enter into fn. + fn: function that accepts a batch of data and produces + an output of equal length. + max_batch_size: maximum size of a batch. + + Yields + ------ + Iterator of output. + + """ + input_batch: list[InputItem] = [] + for input_item in input_data: + input_batch.append(input_item) + if len(input_batch) == max_batch_size: + yield from fn(input_batch) + input_batch = [] + + yield from fn(input_batch) + + class ManagedMultiProcGPUExecutor(MultiprocessingGPUExecutor): """MultiProcGPUExecutor, but with VectorLM launched alongside vLLM. diff --git a/vectorlm/tests/test_sampling.py b/vectorlm/tests/test_sampling.py new file mode 100644 index 0000000..a1b053e --- /dev/null +++ b/vectorlm/tests/test_sampling.py @@ -0,0 +1,28 @@ +from vectorlm.sampling import batch_process + +RATIONALE_ANSWER_REGEXP = r"(.+)\s\(([A-C])\)[^\(\)]*$" + + +def test_batch_process() -> None: + """Test batch_process.""" + example_input = list("banana") + output = [] + for output_item in batch_process(example_input, lambda x: x, 5): + print(output_item) + output.append(output_item) + + assert output == example_input + + +def test_parsing_rationale() -> None: + """Test parsing rationales.""" + example_input = [ + "\nThe blue jay is to the right of the quail, and the falcon is to the right of the blue jay. So, the order is (B) The quail is the second from the left.", + "\n\nThe motorcyle is newer than the limousine. The convertible is newer than the motorcyle. So, the correct option is (C) The convertible is the oldest.", + "\n\nFor the first paragraph, the object arranged to the right is a blue book. Since the blue book is the rightmost, it must be the second object from the left. \n\nIn the second paragraph about the orange book, it must be the object located two positions left of the blue book, which is indeed the leftmost object. \n\nIn the third paragraph, the object that must be two positions left of the orange book is the red book, making option (A) the correct answer.", + "\n\nThe robin, crow, and blue Jay form a linear order, so based on the position of these birds, the robin must be the rightmost.", + "\nThe green book is the first object on the shelf. The red book is the second object on the shelf. The blue book is the third object on the shelf. Therefore, the green book is the rightmost.", + "\nAccording to the statement, the mangoes are less expensive than the peaches, which are less expensive than the apples. Therefore, the mango should be the third-most expensive. So, the correct answer is **(C) The mangoes are the second-most expensive**.", + "\n\n**Paragraph 1:** The tractor is older than the truck. So, the tractor should be the third object in the order.\n\n**Paragraph 2:** The minivan is newer than the truck. Therefore, the minivan should be the third object in the order.\n\n**Paragraph 3:** The tractor is older than the truck. So, the tractor should be the first object in the order.\n\nTherefore, the answer is (A) The tractor is the newest.", + "\nEve's position is fixed. Rob finished below Mel, who finished below Eve. Therefore, Eve finished last. So (A) Eve finished first is the answer.", + ] diff --git a/vectorlm/tests/test_trice.py b/vectorlm/tests/test_trice.py new file mode 100644 index 0000000..7065638 --- /dev/null +++ b/vectorlm/tests/test_trice.py @@ -0,0 +1,18 @@ +import pytest +import torch + +from vectorlm.trice import masked_clm_loss + + +def test_masked_clm_loss( + batch_size: int = 2, + seq_length: int = 8, + vocab_size: int = 12, +) -> None: + """Test partially masked next-token loss fn.""" + logits = torch.ones((batch_size, seq_length, vocab_size)) + input_ids = torch.ones((batch_size, seq_length), dtype=torch.long) + loss_multiplier = torch.ones((batch_size, seq_length), dtype=torch.long) + + loss = masked_clm_loss(logits, input_ids, loss_multiplier) + print(loss) diff --git a/vectorlm/trainer.py b/vectorlm/trainer.py index 8cd9661..b52b15d 100644 --- a/vectorlm/trainer.py +++ b/vectorlm/trainer.py @@ -150,6 +150,8 @@ def prepare_trainer( self.lr_scheduler = lr_scheduler self.sampling_engine = sampling_engine + if self.sampling_engine is not None: + self.sampling_engine.update(self.model, self.tr_step) self.is_peft_adapter_restored = is_peft_adapter_restored diff --git a/vectorlm/trice.py b/vectorlm/trice.py new file mode 100644 index 0000000..58c2fe5 --- /dev/null +++ b/vectorlm/trice.py @@ -0,0 +1,574 @@ +"""TRICE implementation in VectorLM vLLM.""" + +from __future__ import annotations + +import contextlib +import json +import os +import random +import re +import time +from functools import partial +from typing import Any, Callable, Iterable, NamedTuple, TypeVar + +import requests +import torch +import torch.distributed +from tqdm.auto import tqdm +from vllm import SamplingParams + +from vectorlm.sampling import batch_process +from vectorlm.trainer import Trainer, _gather + +DatasetIndex = str + +START_TIME = int(time.time()) +SAMPLING_BATCH_SIZE = 8 +NUM_DEMONSTRATIONS = { + # prompt for bootstraping high-quality memory using hinted rationales. + "few_shot_rationales": 3, +} + +ZERO_SHOT_PROMPT = """\ +Question: {question} +Answer: Let's think step-by-step. \ +""" + +GUIDE_TEMPLATE = """\ +Question: {question} +Answer: Let's think step by step. {rationale}{answer}\ +""" + +FEW_SHOT_DELIMITER = "\n\n---\n" + +GUIDE_TEMPLATE_FULL = """\ +{few_shot_exemplars}\ +{few_shot_delimiter}\ +Question: {question} +Answer: Let's think step by step. \ +""" + + +# capture groups: rationale, answer. +RATIONALE_ANSWER_REGEXP = r"(.+the answer is )(\([A-C]\))[^\(\)]*$" + + +class Question(NamedTuple): + """A question-answer pair.""" + + question_text: str + answer: str + + +class Rationale(NamedTuple): + """A question-rationale pair. + + The "parsed_answer" field refers to the answer from model output, + which might be different from the ground truth reference + in question.answer. + + "parsed_answer" is None if output isn't parse-able. + """ + + question: Question + raw_prompt: str + rationale: str + parsed_answer: str | None = None + + def serialize(self) -> dict[str, Any]: + """Produce JSON-friendly representation of self.""" + output = self._asdict() + output["question"] = self.question._asdict() + if not bool(os.environ.get("PRINT_VERBOSE_MEMORY", 0)): + output.pop("raw_prompt") + + return output + + +def generate_rationale_answer( + questions: Iterable[Question], + batch_generate_fn: Callable[[list[str]], list[str]], + prompt_template: str = ZERO_SHOT_PROMPT, + require_match: bool = False, +) -> list[Rationale]: + """Generate rationale and answer to each of the given questions. + + Params: + ------ + question_texts: list of Question. + batch_generate_fn: Generate text continuations for each given query. + prompt_template: str, where the only placeholder is "question". + require_match: bool, enable to exclude rationales that don't match + RATIONALE_ANSWER_REGEXP. + + Returns + ------- + List of Rationales for parsed examples. + + """ + queries = [ + prompt_template.format(question=question.question_text) + for question in questions + ] + responses = batch_generate_fn(queries) + + output: list[Rationale] = [] + for question, query, response in zip(questions, queries, responses): + match = re.match(RATIONALE_ANSWER_REGEXP, response, re.DOTALL) + if match is None: + if require_match: + continue + rationale, answer = response, None + else: + rationale, answer = match.groups() + output.append(Rationale(question, query, rationale, answer)) + + import json + + with open("data/output-20240527-1a.jsonl", "a") as output_file: + output_file.write(json.dumps(output[-1]._asdict()) + "\n") + + return output + + +V = TypeVar("V") + + +def _index(items: list[V]) -> dict[DatasetIndex, V]: + """Convert list of items to a dict mapping index to item.""" + return {str(index): item for index, item in enumerate(items)} + + +def get_dataset() -> ( + tuple[dict[DatasetIndex, Question], dict[DatasetIndex, Question]] +): + """Get train and validation datasets. + + Returns + ------- + train_questions, test_questions + + """ + task = "logical_deduction_three_objects" + data_url = f"https://raw.githubusercontent.com/suzgunmirac/BIG-Bench-Hard/main/bbh/{task}.json" + examples = requests.get(data_url, timeout=10).json()["examples"] + + question_texts = [ex["input"] for ex in examples] + answer_texts = [ex["target"] for ex in examples] + questions = [ + Question(question, answer) + for question, answer in zip(question_texts, answer_texts) + ] + + return _index(questions[:150]), _index(questions[150:]) + + +def get_n_correct_rationales( + batch_generate_fn: Callable[[list[str]], list[str]], + questions: dict[DatasetIndex, Question], + max_num_correct: int, +) -> list[Rationale]: + """Return up to N correct rationales. + + Opportunistically stop as soon as number of correct rationales is reached. + """ + # lazy iterator- generation does not happen until iterated. + rationale_iterator = batch_process( + questions.values(), + partial(generate_rationale_answer, batch_generate_fn=batch_generate_fn), + SAMPLING_BATCH_SIZE, + ) + + # stop iterating as soon as exactly N rationales are correct. + correct_rationales = [] + for rationale in tqdm(rationale_iterator, ncols=75, total=len(questions)): + if str(rationale.question.answer) == rationale.parsed_answer: + correct_rationales.append(rationale) + + if len(correct_rationales) == max_num_correct: + break + + return correct_rationales + + +def filter_rationales( + proposed_rationales: dict[DatasetIndex, Rationale], +) -> dict[DatasetIndex, Rationale]: + """Return only valid rationales from the given dict of rationales. + + Params: + ------ + proposed_rationales: dict mapping dataset index to rationale. + + Returns + ------- + a subset of the input dict, containing only rationales that are + correct. + + """ + return { + index: rationale + for index, rationale in proposed_rationales.items() + if rationale.parsed_answer == rationale.question.answer + } + + +def few_shot_sample( + batch_generate_fn: Callable[[list[str]], list[str]], + few_shot_rationales: list[Rationale], + questions: dict[DatasetIndex, Question], +) -> dict[DatasetIndex, Rationale]: + """Generate answers to the given questions using few-shot examples. + + Args: + ---- + batch_generate_fn: Callable. + few_shot_rationales: list of correct rationales to demonstrate. + questions: dict mapping dataset index to Question. + + Returns: + ------- + dict mapping dataset index to Rationale instance, + one for each given question. + + """ + few_shot_exemplars = FEW_SHOT_DELIMITER.join( + GUIDE_TEMPLATE.format( + question=rationale.question.question_text, + rationale=rationale.rationale, + answer=rationale.parsed_answer, + ) + for rationale in few_shot_rationales + ) + + few_shot_template = GUIDE_TEMPLATE_FULL.format( + few_shot_exemplars=few_shot_exemplars, + few_shot_delimiter=FEW_SHOT_DELIMITER, + question="{question}", + ) + + rationales = generate_rationale_answer( + batch_generate_fn=batch_generate_fn, + questions=questions.values(), + prompt_template=few_shot_template, + ) + + return dict(zip(questions.keys(), rationales)) + + +def _serialize_memory( + memory: dict[DatasetIndex, Rationale], + extra_info: dict[str, Any] | None = None, + filename_suffix: str = "", +) -> None: + """Write memory to disk.""" + output_file_path = os.path.join( + os.environ.get("MEMORY_PATH", "data/memories"), + f"{START_TIME}{filename_suffix}.json", + ) + output: dict[str, Any] = { + "extra_info": extra_info, + "valid_rationales": { + index: rationale.serialize() + for index, rationale in filter_rationales(memory).items() + }, + "all_rationales": { + index: rationale.serialize() for index, rationale in memory.items() + }, + } + + with open(output_file_path, "a") as output_file: + output_file.write(json.dumps(output, indent=2)) + output_file.write("\n") + + +def masked_clm_loss( + logits: torch.Tensor, + input_ids: torch.Tensor, + loss_multiplier: torch.Tensor, +) -> torch.Tensor: + """Return partially-masked next-token loss for logits. + + loss_multiplier is applied to each token element-wise. + + Params: + ------- + logits: Tensor[float] (batch, width, vocab) + input_ids: Tensor[int] (batch, width) + loss_mask: Tensor[int] (batch, width) + + Returns + ------- + Tensor[float] (,) + + """ + assert logits.shape[:-1] == input_ids.shape + assert loss_multiplier.shape == input_ids.shape + + # all logits except the one for the last token. + logits_sliced = logits[:, :-1, :] + + # all labels except the label for the first token. + labels_shifted = input_ids[:, 1:] + loss_multiplier_shifted = loss_multiplier[:, 1:] + + # Torch CrossEntropyLoss allows only one batch dimension, + # not two (batch, width) as in logits and labels. + loss_fn = torch.nn.CrossEntropyLoss(reduction="none") + per_token_loss = loss_fn( + logits_sliced.flatten(0, 1), + labels_shifted.flatten(0, 1), + ).view_as(labels_shifted) + assert per_token_loss.shape == loss_multiplier_shifted.shape + + return torch.mean(per_token_loss * loss_multiplier_shifted) + + +class ICEMTrainer(Trainer): + """Independence chain expectation maximization trainer for TRICE.""" + + train_mini_batch_size: int = 8 # parameter "M" as in paper + + def _rationales_to_batch( + self, + rationales: Iterable[Rationale], + ) -> dict[str, torch.Tensor]: + """Tokenize rationales to produce training batch. + + In the order of being generated: + - Prompt: "x" + - Rationale: "z" + - Answer: "y" + + Objective: maximize "P(z | x)". + + Note that: + - Answer ("y") is not included. + - Loss is not calculated over prompt tokens ("x"). + + Returns + ------- + dict: + - input_ids: (batch_size, num_tokens) + - attention_mask: (batch_size, num_tokens) + - loss_mask: (batch_size, num_tokens) + + """ + assert self.tokenizer is not None + input_id_lists: list[list[int]] = [] + attention_masks_list: list[list[int]] = [] + loss_masks_list: list[list[int]] = [] + + # Tokenize prompt and rationale separately before concatenating. + for rationale in rationales: + prompt_tokens = list(self.tokenizer(rationale.raw_prompt).input_ids) + + continuation_str = rationale.rationale + if rationale.parsed_answer is not None: + continuation_str += rationale.parsed_answer + rationale_tokens = list(self.tokenizer(continuation_str).input_ids) + + attention_mask = [1] * (len(prompt_tokens) + len(rationale_tokens)) + loss_mask = [0] * len(prompt_tokens) + [1] * len(rationale_tokens) + + input_id_lists.append(prompt_tokens + rationale_tokens) + attention_masks_list.append(attention_mask) + loss_masks_list.append(loss_mask) + + # set max_seq_length to max real number of tokens, + # capped at max_seq_len from config. + max_seq_length = min( + self.config.max_seq_len, # type: ignore[reportAttributeAccessIssue] + max(map(len, input_id_lists)), + ) + batch_size = len(input_id_lists) + + input_ids = torch.zeros((batch_size, max_seq_length), dtype=torch.long) + attn_masks = torch.zeros((batch_size, max_seq_length), dtype=torch.int) + loss_masks = torch.zeros((batch_size, max_seq_length), dtype=torch.int) + + for index, (input_id_list, attention_mask, loss_mask) in enumerate( + zip(input_id_lists, attention_masks_list, loss_masks_list), + ): + # skip rationales that exceed max_seq_length + actual_length = len(input_id_list) + if actual_length > max_seq_length: + continue + + _non_pad_len = min(max_seq_length, actual_length) + input_ids[index, :_non_pad_len] = torch.Tensor(input_id_list) + attn_masks[index, :_non_pad_len] = torch.Tensor(attention_mask) + loss_masks[index, :_non_pad_len] = torch.Tensor(loss_mask) + + return { + "input_ids": input_ids, + "attention_mask": attn_masks, + "loss_mask": loss_masks, + } + + def prepare_trainer( + self, + *args, # noqa: ANN002 + **kwargs, # noqa: ANN003 + ) -> None: + """Initialize memory and bootstrap prompt.""" + super().prepare_trainer(*args, **kwargs) + + assert self.sampling_engine is not None + self.batch_generate_fn = partial( + self.sampling_engine.generate_text_only, + sampling_params=SamplingParams( + max_tokens=self.config.max_seq_len, # type: ignore[reportAttributeAccessIssue] + temperature=1.0, + ), + ) + + self.train_questions, self.test_questions = get_dataset() + + # Obtain a number of rationales for bootstraping the prompt for + # generating explanations given answer hint. + self.few_shot_rationales = get_n_correct_rationales( + self.batch_generate_fn, + self.train_questions, + NUM_DEMONSTRATIONS["few_shot_rationales"], + ) + print( + "Obtained {}/{} few_shot_rationales".format( + len(self.few_shot_rationales), + NUM_DEMONSTRATIONS["few_shot_rationales"], + ), + ) + + # "Memory" is a dict mapping dataset_id to rationales. + self.memory = few_shot_sample( + self.batch_generate_fn, + self.few_shot_rationales, + self.train_questions, + ) + + valid_rationales = filter_rationales(self.memory) + _serialize_memory(self.memory, {"step": self.tr_step}) + print(f"Valid memories: {len(valid_rationales)}/{len(self.memory)}") + + def _get_train_rationales( + self, + num_rationales: int, + ) -> dict[DatasetIndex, Rationale]: + """Sample new rationales from training questions. + + Return only rationales where answer is correct. Prefer newly + generated rationales and fall back to ones from memory for the same + set of randomly-selected questions. + """ + random.seed(self.tr_step) + selected_keys = random.sample(self.memory.keys(), num_rationales) + prev_rationales = {key: self.memory[key] for key in selected_keys} + + selected_questions = { + key: rationale.question + for key, rationale in prev_rationales.items() + } + + # "proposed" rationales + new_rationales = few_shot_sample( + self.batch_generate_fn, + self.few_shot_rationales, + selected_questions, + ) + + return { + **filter_rationales(prev_rationales), + **filter_rationales(new_rationales), + } + + def train_step(self, _: dict[str, torch.Tensor], epoch: int) -> float: + """Apply one TRICE training step. + + See BASIC_GRADIENT_ESTIMATE in the TRICE paper. + """ + assert self.model is not None + assert self.optimizer is not None + assert self.lr_scheduler is not None + assert self.sampling_engine is not None + + # keep sampling until at least one rationale (new or memory) is correct. + new_correct_rationales: dict[DatasetIndex, Rationale] = {} + while len(new_correct_rationales) == 0: + new_correct_rationales = self._get_train_rationales( + self.config.batch_size, # type: ignore[reportAttributeAccessIssue] + ) + + # Write newly sampled correct rationales to memory. + self.memory = {**self.memory, **new_correct_rationales} + _serialize_memory(self.memory, {"step": self.tr_step}) + + training_batch = self._rationales_to_batch( + new_correct_rationales.values(), + ) + _batch = { + k: v.to(torch.cuda.current_device()) + for k, v in training_batch.items() + } + + # Sync grad only if is about to run update. + is_update_step = (self.tr_step + 1) % self.gas == 0 + with contextlib.ExitStack() as stack: + if not is_update_step: + stack.enter_context(self.model.no_sync()) + else: + torch.distributed.barrier() + + logits = self.model( + input_ids=_batch["input_ids"], + attention_mask=_batch["attention_mask"], + ).logits + tr_step_loss = masked_clm_loss( + logits, + _batch["input_ids"], + _batch["loss_mask"], + ) + (tr_step_loss / self.gas).backward() + self.model.clip_grad_norm_(self.config.max_grad_norm) # type: ignore[reportAttributeAccessIssue] + + if is_update_step: + self.optimizer.step() + self.optimizer.zero_grad() + self.sampling_engine.update(self.model, self.tr_step) + + if isinstance( + self.lr_scheduler, + torch.optim.lr_scheduler.ReduceLROnPlateau, + ): + self.lr_scheduler.step(self.metric) + else: + self.lr_scheduler.step() + + gathered_tr_step_loss = _gather(tr_step_loss.reshape(1)).mean().item() + if self.wandb_logging: + self.log(gathered_tr_step_loss, epoch, "train") + + return gathered_tr_step_loss + + def eval_step(self, epoch: int) -> float: + """Return eval accuracy.""" + rationales = few_shot_sample( + self.batch_generate_fn, + self.few_shot_rationales, + self.test_questions, + ) + valid_rationales = filter_rationales(rationales) + + accuracy = ( + len(valid_rationales) / len(rationales) + if len(rationales) > 0 + else 0 + ) + _serialize_memory(self.memory, {"epoch": epoch, "step": self.tr_step}) + _serialize_memory( + rationales, + {"epoch": epoch, "step": self.tr_step, "accuracy": accuracy}, + "_eval", + ) + print(f"Eval accuracy: {accuracy * 100:.1f}%") + + return accuracy From ff136e5a2ec5b18430b0beb9fcb39ae371d2595a Mon Sep 17 00:00:00 2001 From: jacobthebanana Date: Mon, 24 Jun 2024 14:51:11 -0400 Subject: [PATCH 2/2] lora-trice: Implemented LoRA-TRICE. Tested on two NVIDIA GPUs. --- .gitignore | 3 + examples/llama_example.py | 5 +- vectorlm/sampling/abstract.py | 14 +- vectorlm/sampling/sampling_lora.py | 18 +- vectorlm/sampling/utils.py | 1 - vectorlm/tests/test_sampling.py | 15 +- vectorlm/tests/test_trice.py | 105 ++++- vectorlm/trice.py | 591 ++++++++++++++++++++++++++--- 8 files changed, 671 insertions(+), 81 deletions(-) diff --git a/.gitignore b/.gitignore index f4c948d..9c69abc 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,6 @@ data/ /.vscode /data /env + +.env +logs/ diff --git a/examples/llama_example.py b/examples/llama_example.py index ffb3745..5e2b551 100644 --- a/examples/llama_example.py +++ b/examples/llama_example.py @@ -92,7 +92,6 @@ def main( # setup wandb if rank == 0 and config.enable_wandb_logging: wandb_setup(config, **config.wandb_config) - dist.barrier() # load model and tokenizer model, tokenizer = load_model_and_tokenizer( @@ -188,6 +187,7 @@ def main( # If no checkpoint, it returns 0. trainer.model.train() trainer.find_checkpoint(training_args.output_dir) + eval_acc = 0 pbar = tqdm( range(config.train_parameters.epochs), @@ -196,11 +196,10 @@ def main( ncols=75, ) for index in pbar: - eval_acc = 0 train_loss, eval_output = trainer.step({}, index) eval_acc = eval_output if eval_output is not None else eval_acc - pbar.set_description(f"{train_loss:.3f}, {eval_acc * 100:.0f}%") + pbar.set_description(f"{train_loss:.3e}, {eval_acc * 100:.0f}%") if is_lora_enabled: save_peft_adapter(trainer.model, hf_save_dir) diff --git a/vectorlm/sampling/abstract.py b/vectorlm/sampling/abstract.py index 0fb6454..0c2b3bc 100644 --- a/vectorlm/sampling/abstract.py +++ b/vectorlm/sampling/abstract.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: import torch - from vectorlm.trainer import Trainer + from transformers import PreTrainedTokenizer from .utils import SynchronizationBarriers @@ -38,12 +38,18 @@ def __init__( self.vllm_train_step = -1 @abstractmethod - def update(self, model: torch.nn.Module, train_step: int) -> None: + def update( + self, + model: torch.nn.Module, + train_step: int, + tokenizer: PreTrainedTokenizer | None = None, + ) -> None: """Update model in sampling engine if the current copy is stale. Params: model: PeftModel, up-to-date model train_step: int, train step of the given model. + tokenizer: optionally, provide updated copy of tokenizer. """ if self.vllm_train_step != train_step: # Update parameters of self.vllm_llm using the given `model``. @@ -54,6 +60,7 @@ def generate( self, prompts: list[str], sampling_params: vllm.SamplingParams | None = None, + use_tqdm: bool = False, ) -> list[vllm.RequestOutput]: """Generate continuation for the given prompts synchronously. @@ -79,9 +86,10 @@ def generate_text_only( self, prompts: list[str], sampling_params: vllm.SamplingParams | None = None, + use_tqdm: bool = False, ) -> list[str]: """Generate and return text only.""" return [ response.outputs[0].text - for response in self.generate(prompts, sampling_params) + for response in self.generate(prompts, sampling_params, use_tqdm) ] diff --git a/vectorlm/sampling/sampling_lora.py b/vectorlm/sampling/sampling_lora.py index 934ae9a..fccadec 100644 --- a/vectorlm/sampling/sampling_lora.py +++ b/vectorlm/sampling/sampling_lora.py @@ -1,9 +1,8 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING -import torch import torch.distributed as dist import vllm from vllm.lora.request import LoRARequest @@ -15,6 +14,7 @@ if TYPE_CHECKING: from peft.peft_model import PeftModel + from transformers import PreTrainedTokenizer class LoRASamplingEngine(AbstractSamplingEngine): @@ -63,15 +63,24 @@ def __init__( self.generate_fn = multiprocess_wrap(generate_fn_raw, self.barriers) self.vllm_train_step = -1 - def update(self, model: PeftModel, train_step: int) -> None: + def update( + self, + model: PeftModel, + train_step: int, + tokenizer: PreTrainedTokenizer | None = None, + ) -> None: """Update model in sampling engine if the current copy is stale. Params: model: PeftModel, up-to-date model train_step: int, train step of the given model. + tokenizer: optionally, provide updated copy of tokenizer. """ self.barriers.before_generation.wait() if self.vllm_train_step != train_step: + if tokenizer is not None: + tokenizer.save_pretrained(self.adapter_temp_folder) + save_peft_adapter(model, self.adapter_temp_folder) self.vllm_train_step = train_step self.lora_request = LoRARequest( @@ -86,6 +95,7 @@ def generate( self, prompts: list[str], sampling_params: vllm.SamplingParams | None = None, + use_tqdm: bool = False, ) -> list[vllm.RequestOutput]: """Generate continuation for the given prompts. Invoke at all ranks. @@ -106,7 +116,7 @@ def generate( prompts, sampling_params, lora_request=self.lora_request, - use_tqdm=False, + use_tqdm=use_tqdm, ) assert len(return_value) == len(prompts) diff --git a/vectorlm/sampling/utils.py b/vectorlm/sampling/utils.py index 400b005..2dc8b43 100644 --- a/vectorlm/sampling/utils.py +++ b/vectorlm/sampling/utils.py @@ -2,7 +2,6 @@ from __future__ import annotations -import concurrent.futures import json import os import threading diff --git a/vectorlm/tests/test_sampling.py b/vectorlm/tests/test_sampling.py index a1b053e..acf447f 100644 --- a/vectorlm/tests/test_sampling.py +++ b/vectorlm/tests/test_sampling.py @@ -11,18 +11,5 @@ def test_batch_process() -> None: print(output_item) output.append(output_item) - assert output == example_input - -def test_parsing_rationale() -> None: - """Test parsing rationales.""" - example_input = [ - "\nThe blue jay is to the right of the quail, and the falcon is to the right of the blue jay. So, the order is (B) The quail is the second from the left.", - "\n\nThe motorcyle is newer than the limousine. The convertible is newer than the motorcyle. So, the correct option is (C) The convertible is the oldest.", - "\n\nFor the first paragraph, the object arranged to the right is a blue book. Since the blue book is the rightmost, it must be the second object from the left. \n\nIn the second paragraph about the orange book, it must be the object located two positions left of the blue book, which is indeed the leftmost object. \n\nIn the third paragraph, the object that must be two positions left of the orange book is the red book, making option (A) the correct answer.", - "\n\nThe robin, crow, and blue Jay form a linear order, so based on the position of these birds, the robin must be the rightmost.", - "\nThe green book is the first object on the shelf. The red book is the second object on the shelf. The blue book is the third object on the shelf. Therefore, the green book is the rightmost.", - "\nAccording to the statement, the mangoes are less expensive than the peaches, which are less expensive than the apples. Therefore, the mango should be the third-most expensive. So, the correct answer is **(C) The mangoes are the second-most expensive**.", - "\n\n**Paragraph 1:** The tractor is older than the truck. So, the tractor should be the third object in the order.\n\n**Paragraph 2:** The minivan is newer than the truck. Therefore, the minivan should be the third object in the order.\n\n**Paragraph 3:** The tractor is older than the truck. So, the tractor should be the first object in the order.\n\nTherefore, the answer is (A) The tractor is the newest.", - "\nEve's position is fixed. Rob finished below Mel, who finished below Eve. Therefore, Eve finished last. So (A) Eve finished first is the answer.", - ] + assert output == example_input diff --git a/vectorlm/tests/test_trice.py b/vectorlm/tests/test_trice.py index 7065638..7b5471c 100644 --- a/vectorlm/tests/test_trice.py +++ b/vectorlm/tests/test_trice.py @@ -1,7 +1,17 @@ +from __future__ import annotations + +import numpy as np import pytest import torch -from vectorlm.trice import masked_clm_loss +from vectorlm.trice import ( + Question, + Rationale, + _index, + filter_rationales, + get_weighted_rationales, + masked_clm_loss, +) def test_masked_clm_loss( @@ -16,3 +26,96 @@ def test_masked_clm_loss( loss = masked_clm_loss(logits, input_ids, loss_multiplier) print(loss) + + +def get_reference_weights( + is_proposal_correct_: list[bool], + is_memory_correct_: list[bool], +) -> tuple[list[float], float]: + """Return (signed) weights and weights_mean. + + Adapted from reference TRICE logic. + """ + is_proposal_correct = np.stack(is_proposal_correct_) + is_memory_correct = np.stack(is_memory_correct_) + mask = is_proposal_correct | is_memory_correct + correlation_est = (is_proposal_correct.sum() - is_proposal_correct) / ( + mask.sum() - 1 + 1e-10 + ) + + # compute weight contributions of rationales from both memory and proposal. + weights_memory = mask * (1 - correlation_est * is_proposal_correct) + weights_proposal = mask * correlation_est * (1 - is_proposal_correct) + flat_weights = np.concatenate([weights_memory, weights_proposal]) + flat_signs = np.concatenate([mask, -1 * mask]) + flat_weights = np.clip(flat_weights, a_min=1e-10, a_max=None) + weights_mean = flat_weights.sum() / (mask.sum() + 1e-10) + + output = flat_signs * flat_weights + assert len(output.shape) == 1 + return output.tolist(), weights_mean + + +@pytest.fixture() +def example_rationales() -> list[Rationale]: + """Return rationales: F T T F T F.""" + return [ + Rationale(Question("", "A"), "", "F", "B"), + Rationale(Question("", "A"), "", "T", "A"), + Rationale(Question("", "A"), "", "T", "A"), + Rationale(Question("", "A"), "", "F", "B"), + Rationale(Question("", "A"), "", "T", "A"), + Rationale(Question("", "A"), "", "F", "B"), + ] + + +@pytest.mark.parametrize("variation", [0, 1, 2]) +def test_weight_rationales( + example_rationales: list[Rationale], + variation: int, +) -> None: + """Ensure rationales weights match ones from reference.""" + if variation == 0: + # proposed: F T T + # memory: F T F + proposed_rationales = example_rationales[:3] + memorized_rationales = example_rationales[3:] + elif variation == 1: + # proposed: T, + # memory: T, + proposed_rationales = [example_rationales[2]] + memorized_rationales = [example_rationales[3]] + else: + # proposed: F T F + # memory: F T T + proposed_rationales = example_rationales[3:] + memorized_rationales = example_rationales[:3] + + weighted_rationales = get_weighted_rationales( + memorized_rationales={ + **_index(memorized_rationales), + **filter_rationales(_index(proposed_rationales)), + }, + proposed_rationales=_index(proposed_rationales), + ) + weights = [ + wr.weight + for wr in (weighted_rationales.memorized + weighted_rationales.proposed) + ] + weights_reference, weights_mean_reference = get_reference_weights( + [r.is_correct for r in proposed_rationales], + [r.is_correct for r in memorized_rationales], + ) + + print( + [r.is_correct for r in proposed_rationales], + [r.is_correct for r in memorized_rationales], + ) + print(weights) + print(weights_reference) + assert np.allclose(weights, weights_reference, 1e-6) + assert np.allclose( + weighted_rationales.weights_mean, + weights_mean_reference, + 1e-6, + ) diff --git a/vectorlm/trice.py b/vectorlm/trice.py index 58c2fe5..6345b32 100644 --- a/vectorlm/trice.py +++ b/vectorlm/trice.py @@ -2,15 +2,19 @@ from __future__ import annotations +import collections import contextlib import json import os import random import re +import string import time from functools import partial -from typing import Any, Callable, Iterable, NamedTuple, TypeVar +from typing import Any, Callable, Counter, Iterable, NamedTuple, TypeVar +import datasets +import numpy as np import requests import torch import torch.distributed @@ -29,8 +33,16 @@ "few_shot_rationales": 3, } +QUESTION_TEMPLATE = """\ +{input_text} +Options: +{rendered_answer_choices} +""" +ANSWER_OPTION_TEMPLATE = "{answer_key} {label}" + + ZERO_SHOT_PROMPT = """\ -Question: {question} +{question} Answer: Let's think step-by-step. \ """ @@ -39,9 +51,10 @@ Answer: Let's think step by step. {rationale}{answer}\ """ + FEW_SHOT_DELIMITER = "\n\n---\n" -GUIDE_TEMPLATE_FULL = """\ +FEW_SHOT_TEMPLATE_FULL = """\ {few_shot_exemplars}\ {few_shot_delimiter}\ Question: {question} @@ -50,7 +63,7 @@ # capture groups: rationale, answer. -RATIONALE_ANSWER_REGEXP = r"(.+the answer is )(\([A-C]\))[^\(\)]*$" +RATIONALE_ANSWER_REGEXP = r"(.+the answer is )\(?([A-C])\)?[^\(\)]*$" class Question(NamedTuple): @@ -84,6 +97,56 @@ def serialize(self) -> dict[str, Any]: return output + @property + def is_correct(self) -> bool: + """Return whether self is correct.""" + return self.parsed_answer == self.question.answer + + +class _WeightedRationale(NamedTuple): + """Rationale paired with associated weight. + + Weight can be negative, as in the example of proposed incorrect rationales, + where the corresponding memory is correct, but this rationale is not. + + Weights for memory entries: + - 0 if neither this memory nor the proposal is correct. + - 1 if this memory is correct, but the corresponding proposal is incorrect. + - (1 - scale) if both are correct, higher if scale is lower, when other + proposals in this batch are not as accurate as other memories in this + batch. + + Weights for proposal entries: + - 0 if this proposal is correct. + - 0 if neither this proposal nor the corresponding memory is correct. + - (-scale) if proposal is incorrect, but the corresponding memory is correct + """ + + rationale: Rationale + weight: float + + @property + def sign_multiplier(self) -> int: + """Return sign multiplier of self.""" + if self.weight > 0: + return 1 + + if self.weight < 0: + return -1 + + return 0 + + +class _WeightedRationales(NamedTuple): + """WeightedRationales from memory and proposal. + + "weights_mean" is for rescaling clm grad values. + """ + + memorized: list[_WeightedRationale] + proposed: list[_WeightedRationale] + weights_mean: float + def generate_rationale_answer( questions: Iterable[Question], @@ -123,11 +186,6 @@ def generate_rationale_answer( rationale, answer = match.groups() output.append(Rationale(question, query, rationale, answer)) - import json - - with open("data/output-20240527-1a.jsonl", "a") as output_file: - output_file.write(json.dumps(output[-1]._asdict()) + "\n") - return output @@ -163,6 +221,171 @@ def get_dataset() -> ( return _index(questions[:150]), _index(questions[150:]) +def enumerate_answer_choices( + dataset_splits: list[Iterable[dict[str, Any]] | datasets.Dataset], + label_column_key: str, + label_lookup: dict[Any, str] | None, + data_limits: list[int | None], +) -> Counter[str]: + """Return counter of all answer choice options. + + Args: + ---- + dataset_splits: list of dataset row iterators, one for each split. + label_column_key: should be the same across all splits. + label_lookup: Translate label to alternative, more descriptive form + before enumeration. If provided, rows where label is not in + the lookup dictionary would be skipped. + data_limits: max number of rows to load from each split, one per split. + Set to None for no limit. + + Returns: + ------- + Counter of label options across all splits. + + """ + output: Counter[str] = collections.Counter() + for dataset_split, max_row_count in zip(dataset_splits, data_limits): + for index, row in enumerate(dataset_split): + assert isinstance(row, dict) + label = row.get(label_column_key) + + if label_lookup is not None: + label = label_lookup.get(label) + if label is None: + continue + + output[str(label)] += 1 + + if (max_row_count is not None) and (index + 1 == max_row_count): + break + + return output + + +def _render_label_choices( + label_choices: list[str], + template: str, +) -> tuple[str, dict[str, str]]: + """Render label_choices as a multiple-choice question. + + Returns + ------- + lookup dict mapping label value to assigned answer key e.g., "(A)". + + """ + output_lines = [] + label_lookup: dict[str, str] = {} + + assert len(label_choices) <= len(string.ascii_uppercase) + for index, label in enumerate(label_choices): + answer_key = string.ascii_uppercase[index] + label_lookup[label] = answer_key + output_line = template.format(answer_key=answer_key, label=label) + output_lines.append(output_line) + + return "\n".join(output_lines), label_lookup + + +def load_hf_dataset( + dataset_path: str, + text_column_key: str, + label_column_key: str, + data_split_names: tuple[str, str] = ("train", "test"), + label_lookup: dict[str, str] | None = None, + data_limits: tuple[int | None, int | None] = (1000, 100), + question_template: str = QUESTION_TEMPLATE, + answer_option_template: str = ANSWER_OPTION_TEMPLATE, +) -> tuple[dict[DatasetIndex, Question], dict[DatasetIndex, Question]]: + """Load TRICE-compatible dataset from a HuggingFace dataset. + + Args: + ---- + dataset_path: path to HF dataset repo or local folder. + text_column_key: source of input_text. + label_column_key: source of labels to render as options. + data_split_names: dataset splits for training and testing. + label_lookup: Optionally, supply a descriptive label to replace the + original labels from the dataset. If specified, only rows where + original labels is in label_lookup would be included. Otherwise, + all rows would be included and the original label would be used. + data_limits: max. number of entries to load from train and test split. + question_template: question template, should include {input_text} and + {rendered_answer_choices} + answer_option_template: template for answer choices, should include + {answer_key} and {label} + + Returns: + ------- + dict of train questions (index -> Question), + dict of test questions (index -> Question), + + """ + # prefer loading from local path if exists + # instead of loading from HF hub + if os.path.isdir(dataset_path): + dataset_dict = datasets.load_from_disk(dataset_path) + else: + dataset_dict = datasets.load_dataset(dataset_path) + + assert isinstance(dataset_dict, datasets.dataset_dict.DatasetDict) + assert len(data_split_names) == len(("train", "test")) + + for data_split_name in data_split_names: + assert data_split_name in dataset_dict, (data_split_names, dataset_dict) + + label_choices = enumerate_answer_choices( + [dataset_dict[split_name] for split_name in data_split_names], + label_column_key, + label_lookup, + list(data_limits), + ) + label_choice_str, answer_map = _render_label_choices( + [str(label) for label in label_choices], + answer_option_template, + ) + print("label_choices stats:", label_choices) + print("label_choice_str:", label_choice_str) + print("Answer map:", answer_map) + + output: list[dict[DatasetIndex, Question]] = [] + # create one question for each row for each dataset split. + for data_split_name, max_num_rows in zip(data_split_names, data_limits): + questions: dict[DatasetIndex, Question] = {} + dataset = dataset_dict[data_split_name] + for index, row in enumerate(dataset): + assert isinstance(row, dict) + + # Translate label and skip rows where label is not in lookup. + label = row[label_column_key] + if label_lookup is not None: + label = label_lookup.get(label) + + if label is None: + continue + + label = str(label) + + # use str keys to allow json serialization + questions[str(index)] = Question( + question_text=question_template.format( + input_text=row[text_column_key], + rendered_answer_choices=label_choice_str, + ), + answer=answer_map[label], + ) + + if (max_num_rows is not None) and (index + 1 == max_num_rows): + break + + output.append(questions) + + print(output[0]["0"].question_text + "\n") + + assert len(output) == len(("train", "test")) + return tuple(output) # type: ignore[tuple length] + + def get_n_correct_rationales( batch_generate_fn: Callable[[list[str]], list[str]], questions: dict[DatasetIndex, Question], @@ -241,7 +464,7 @@ def few_shot_sample( for rationale in few_shot_rationales ) - few_shot_template = GUIDE_TEMPLATE_FULL.format( + few_shot_template = FEW_SHOT_TEMPLATE_FULL.format( few_shot_exemplars=few_shot_exemplars, few_shot_delimiter=FEW_SHOT_DELIMITER, question="{question}", @@ -256,6 +479,177 @@ def few_shot_sample( return dict(zip(questions.keys(), rationales)) +def get_weighted_rationales( + memorized_rationales: dict[DatasetIndex, Rationale], + proposed_rationales: dict[DatasetIndex, Rationale], +) -> _WeightedRationales: + """Obtain TRICE Control-Variate weights for each rationale. + + Obtain leave-one-out scales + - Divide number of other correct proposals, excluding self, by total number + of correct predictions in batch, also excluding self. + + Obtain weights for memory + correct proposal entries + - 0 if neither original memory nor the new proposal is correct. + - 1 if original memory is correct, but the corresponding new proposal is + incorrect. + - (1 - scale) if new proposal is correct, higher if scale is lower, when + other proposals are not as accurate as original memories. + + Obtain weights for all proposal entries, negative if incorrect. + - 0 if this proposal is correct. + - 0 if neither this proposal nor the corresponding memory is correct. + - (-scale) if proposal is incorrect, but the corresponding memory is + correct. + + Params: + ---- + memorized_rationales: memory rationales. + proposed_rationales: proposed new rationales. + + + Keys of "memorized" must match those of "proposed". + + Returns + ------- + List of weighted rationales. If there are N memory rationales + and N proposals, the output would be of length (2 * N). + + """ + assert proposed_rationales.keys() == memorized_rationales.keys() + num_correct_proposals = len(filter_rationales(proposed_rationales)) + num_correct_all = len( + { + **filter_rationales(proposed_rationales), + **filter_rationales(memorized_rationales), + }, + ) + + # construst weighted list of rationale from both memory and proposal. + output_memorized: list[_WeightedRationale] = [] + output_proposed: list[_WeightedRationale] = [] + for dataset_index in memorized_rationales: + memorized = memorized_rationales[dataset_index] + proposed = proposed_rationales[dataset_index] + if (proposed.is_correct) and (not memorized.is_correct): + msg = ( + "Proposal is correct but memory is not. " + "Did you update memory before trying to compute weights?" + f"proposal: {proposed}\n" + f"memory: {memorized}" + ) + raise ValueError(msg) + + # leave-one-out (leave-self-out) scale + scale_numerator = num_correct_proposals + scale_denominator = num_correct_all - 1 + 1e-10 + if proposed.is_correct: + scale_numerator -= 1 + + scale = scale_numerator / scale_denominator + + # weight for memory, which might have been overwritten with the new + # proposal if that proposal is correct. + if (not proposed.is_correct) and (not memorized.is_correct): + weight_memorized = 0.0 + elif (memorized.is_correct) and (not proposed.is_correct): + weight_memorized = 1.0 + else: + # proposal is correct, and should have already overwritten memory. + weight_memorized = 1 - scale + output_memorized.append(_WeightedRationale(memorized, weight_memorized)) + + # weight for proposal + if proposed.is_correct: + # Since proposed is correct, after memory update "memorized" + # would be the same as "proposed". No need to include this again. + # Hence, set weight to 0. + weight_proposed = 0.0 + elif (not proposed.is_correct) and (not memorized.is_correct): + weight_proposed = 0.0 + else: + # new proposal is incorrect, but original memory is correct + weight_proposed = -scale + + output_proposed.append(_WeightedRationale(proposed, weight_proposed)) + + # sum of all weights, for rescaling grads. + weight_tally = sum( + abs(wr.weight) for wr in (output_memorized + output_proposed) + ) + + return _WeightedRationales( + memorized=output_memorized, + proposed=output_proposed, + weights_mean=weight_tally / (num_correct_all + 1e-10), + ) + + +def _softmax(weights: np.ndarray) -> np.ndarray: + """Return softmax value given weights.""" + assert len(weights.shape) == 1, weights.shape + exp_weights = np.exp(weights - np.max(weights)) + return exp_weights / np.sum(exp_weights, axis=0) + + +def _systematic_resample( + probabilities: np.ndarray, + num_selected: int, + seed: int = 0, +) -> list[int]: + """Resample systematically. + + Params: + ------ + probabilties: 1D float array, must sum up to 1. + num_selected: number of items to select. + + Returns + ------- + list of index of "num_selected" items that were selected. + Each item is an index of the probability array. + + """ + assert np.allclose(probabilities.sum(), 1), "Forgot to normalize?" + assert num_selected > 0 + + generator = np.random.Generator(np.random.PCG64(seed)) + randomness = generator.uniform(0, 1 / num_selected) + selections: list[int] = [] + + thresholds = np.cumsum(probabilities).tolist() # (N,) + thresholds_low = [0.0, *thresholds] # (N + 1,) + thresholds_high = [*thresholds, 1.0] # (N + 1,) + for option_index, (threshold_low, threshold_high) in enumerate( + zip(thresholds_low, thresholds_high), + ): + # try assigning each selection to the threshold, starting + # from the next available one. + for selection_index in range(len(selections), num_selected): + selected_pos = selection_index * (1 / num_selected) + randomness + + if (selected_pos >= threshold_low) and ( + selected_pos < threshold_high + ): + selections.append(option_index) + + return selections + + +def subsample_weighted( + weighted_rationales: list[_WeightedRationale], + num_items: int, + seed: int = 0, +) -> list[_WeightedRationale]: + """Subsample rationales based on absolute value of weights.""" + weights = np.array([abs(wr.weight) for wr in weighted_rationales]) + weights = np.clip(weights, a_min=1e-10, a_max=None) + probabilities = _softmax(weights) + selected_index_items = _systematic_resample(probabilities, num_items, seed) + + return [weighted_rationales[index] for index in selected_index_items] + + def _serialize_memory( memory: dict[DatasetIndex, Rationale], extra_info: dict[str, Any] | None = None, @@ -295,7 +689,7 @@ def masked_clm_loss( ------- logits: Tensor[float] (batch, width, vocab) input_ids: Tensor[int] (batch, width) - loss_mask: Tensor[int] (batch, width) + loss_multiplier: Tensor[int] (batch, width) Returns ------- @@ -329,9 +723,11 @@ class ICEMTrainer(Trainer): train_mini_batch_size: int = 8 # parameter "M" as in paper - def _rationales_to_batch( + def _batch_tokenize_rationales( self, rationales: Iterable[Rationale], + rationale_weights: Iterable[float] | None = None, + use_raw_prompt: bool = True, ) -> dict[str, torch.Tensor]: """Tokenize rationales to produce training batch. @@ -346,22 +742,36 @@ def _rationales_to_batch( - Answer ("y") is not included. - Loss is not calculated over prompt tokens ("x"). + Params + ------ + rationale_weights: If provided, rescale loss_multiplier of each item by + this value. Should be of the same length as rationale. + use_raw_prompt: Use raw_prompt as context if set to True. + Otherwise, use zero-shot prompt as context. + Returns ------- dict: - - input_ids: (batch_size, num_tokens) - - attention_mask: (batch_size, num_tokens) - - loss_mask: (batch_size, num_tokens) + - input_ids: int, (batch_size, num_tokens) + - attention_mask: int, (batch_size, num_tokens) + - loss_multipliers: float, (batch_size, num_tokens) """ assert self.tokenizer is not None input_id_lists: list[list[int]] = [] attention_masks_list: list[list[int]] = [] - loss_masks_list: list[list[int]] = [] + loss_multipliers_list: list[list[int]] = [] # Tokenize prompt and rationale separately before concatenating. for rationale in rationales: - prompt_tokens = list(self.tokenizer(rationale.raw_prompt).input_ids) + if use_raw_prompt: + context = rationale.raw_prompt + else: + context = ZERO_SHOT_PROMPT.format( + question=rationale.question.question_text, + ) + + prompt_tokens = list(self.tokenizer(context).input_ids) continuation_str = rationale.rationale if rationale.parsed_answer is not None: @@ -369,11 +779,14 @@ def _rationales_to_batch( rationale_tokens = list(self.tokenizer(continuation_str).input_ids) attention_mask = [1] * (len(prompt_tokens) + len(rationale_tokens)) - loss_mask = [0] * len(prompt_tokens) + [1] * len(rationale_tokens) + loss_multiplier = [0] * len(prompt_tokens) + [1] * len( + rationale_tokens, + ) + assert sum(loss_multiplier) > 0 input_id_lists.append(prompt_tokens + rationale_tokens) attention_masks_list.append(attention_mask) - loss_masks_list.append(loss_mask) + loss_multipliers_list.append(loss_multiplier) # set max_seq_length to max real number of tokens, # capped at max_seq_len from config. @@ -382,13 +795,32 @@ def _rationales_to_batch( max(map(len, input_id_lists)), ) batch_size = len(input_id_lists) + weights = ( + list(rationale_weights) + if rationale_weights is not None + else [1.0] * batch_size + ) + assert batch_size > 0 input_ids = torch.zeros((batch_size, max_seq_length), dtype=torch.long) attn_masks = torch.zeros((batch_size, max_seq_length), dtype=torch.int) - loss_masks = torch.zeros((batch_size, max_seq_length), dtype=torch.int) + loss_multipliers = torch.zeros( + (batch_size, max_seq_length), + dtype=torch.float, + ) - for index, (input_id_list, attention_mask, loss_mask) in enumerate( - zip(input_id_lists, attention_masks_list, loss_masks_list), + for index, ( + input_id_list, + attention_mask, + loss_multiplier, + weight, + ) in enumerate( + zip( + input_id_lists, + attention_masks_list, + loss_multipliers_list, + weights, + ), ): # skip rationales that exceed max_seq_length actual_length = len(input_id_list) @@ -398,12 +830,14 @@ def _rationales_to_batch( _non_pad_len = min(max_seq_length, actual_length) input_ids[index, :_non_pad_len] = torch.Tensor(input_id_list) attn_masks[index, :_non_pad_len] = torch.Tensor(attention_mask) - loss_masks[index, :_non_pad_len] = torch.Tensor(loss_mask) + loss_multipliers[index, :_non_pad_len] = ( + torch.Tensor(loss_multiplier) * weight + ) return { "input_ids": input_ids, "attention_mask": attn_masks, - "loss_mask": loss_masks, + "loss_multipliers": loss_multipliers, } def prepare_trainer( @@ -422,8 +856,28 @@ def prepare_trainer( temperature=1.0, ), ) + self.batch_generate_fn_eval = partial( + self.sampling_engine.generate_text_only, + sampling_params=SamplingParams( + max_tokens=self.config.max_seq_len, # type: ignore[reportAttributeAccessIssue] + temperature=0.0, + ), + use_tqdm=True, + ) - self.train_questions, self.test_questions = get_dataset() + self.trice_config = self.config.trice_configs # type: ignore[reportAttributeAccessIssue] + trice_data_config = self.trice_config.hf_dataset + self.train_questions, self.test_questions = load_hf_dataset( + trice_data_config.path, + text_column_key=trice_data_config.text_column_key, + label_column_key=trice_data_config.label_column_key, + answer_option_template=trice_data_config.answer_option_template, + data_limits=( + trice_data_config.limits.train, + trice_data_config.limits.test, + ), + label_lookup=trice_data_config.get("label_lookup"), + ) # Obtain a number of rationales for bootstraping the prompt for # generating explanations given answer hint. @@ -440,8 +894,9 @@ def prepare_trainer( ) # "Memory" is a dict mapping dataset_id to rationales. + print(f"Initializing memory ({len(self.train_questions)} total).") self.memory = few_shot_sample( - self.batch_generate_fn, + partial(self.batch_generate_fn, use_tqdm=True), self.few_shot_rationales, self.train_questions, ) @@ -450,36 +905,51 @@ def prepare_trainer( _serialize_memory(self.memory, {"step": self.tr_step}) print(f"Valid memories: {len(valid_rationales)}/{len(self.memory)}") - def _get_train_rationales( + def sample_rationales_and_update_memory( self, num_rationales: int, - ) -> dict[DatasetIndex, Rationale]: - """Sample new rationales from training questions. + ) -> _WeightedRationales: + """Sample new rationales and write to memory if correct. + + Params: + ------ + num_rationales: number of rationales to generate on. + + Returns + ------- + _WeightedRationales, including "num_rationales" each + of memory and proposal, as well as weighted scores. - Return only rationales where answer is correct. Prefer newly - generated rationales and fall back to ones from memory for the same - set of randomly-selected questions. """ random.seed(self.tr_step) selected_keys = random.sample(self.memory.keys(), num_rationales) - prev_rationales = {key: self.memory[key] for key in selected_keys} + memorized_rationales = {key: self.memory[key] for key in selected_keys} selected_questions = { key: rationale.question - for key, rationale in prev_rationales.items() + for key, rationale in memorized_rationales.items() } # "proposed" rationales - new_rationales = few_shot_sample( - self.batch_generate_fn, + proposed_rationales = few_shot_sample( + partial(self.batch_generate_fn, use_tqdm=False), self.few_shot_rationales, selected_questions, ) + assert len(proposed_rationales) == len(memorized_rationales) - return { - **filter_rationales(prev_rationales), - **filter_rationales(new_rationales), + # write correct proposed rationales to memories. + memorized_rationales = { + **memorized_rationales, + **filter_rationales(proposed_rationales), } + self.memory = {**self.memory, **filter_rationales(proposed_rationales)} + _serialize_memory(proposed_rationales, {"step": self.tr_step}) + + return get_weighted_rationales( + memorized_rationales=memorized_rationales, + proposed_rationales=proposed_rationales, + ) def train_step(self, _: dict[str, torch.Tensor], epoch: int) -> float: """Apply one TRICE training step. @@ -491,26 +961,33 @@ def train_step(self, _: dict[str, torch.Tensor], epoch: int) -> float: assert self.lr_scheduler is not None assert self.sampling_engine is not None - # keep sampling until at least one rationale (new or memory) is correct. - new_correct_rationales: dict[DatasetIndex, Rationale] = {} - while len(new_correct_rationales) == 0: - new_correct_rationales = self._get_train_rationales( - self.config.batch_size, # type: ignore[reportAttributeAccessIssue] - ) - - # Write newly sampled correct rationales to memory. - self.memory = {**self.memory, **new_correct_rationales} - _serialize_memory(self.memory, {"step": self.tr_step}) + # SAMPLING_SIZE _WeightedRationale instances, not yet subsampled. + weighted_rationales = self.sample_rationales_and_update_memory( + self.trice_config.sampling_size, + ) + subsampled_rationales = subsample_weighted( + weighted_rationales.memorized + weighted_rationales.proposed, + self.trice_config.batch_size, + epoch, + ) - training_batch = self._rationales_to_batch( - new_correct_rationales.values(), + rationales = [wr.rationale for wr in subsampled_rationales] + rationale_weights = [ + weighted_rationales.weights_mean * wr.sign_multiplier + for wr in subsampled_rationales + ] + training_batch = self._batch_tokenize_rationales( + rationales, + rationale_weights=rationale_weights, + use_raw_prompt=False, ) + _batch = { k: v.to(torch.cuda.current_device()) for k, v in training_batch.items() } - # Sync grad only if is about to run update. + # Sync grad only if about to run update. is_update_step = (self.tr_step + 1) % self.gas == 0 with contextlib.ExitStack() as stack: if not is_update_step: @@ -525,7 +1002,7 @@ def train_step(self, _: dict[str, torch.Tensor], epoch: int) -> float: tr_step_loss = masked_clm_loss( logits, _batch["input_ids"], - _batch["loss_mask"], + _batch["loss_multipliers"], ) (tr_step_loss / self.gas).backward() self.model.clip_grad_norm_(self.config.max_grad_norm) # type: ignore[reportAttributeAccessIssue] @@ -533,7 +1010,11 @@ def train_step(self, _: dict[str, torch.Tensor], epoch: int) -> float: if is_update_step: self.optimizer.step() self.optimizer.zero_grad() - self.sampling_engine.update(self.model, self.tr_step) + self.sampling_engine.update( + self.model, + self.tr_step, + self.tokenizer, + ) if isinstance( self.lr_scheduler, @@ -552,7 +1033,7 @@ def train_step(self, _: dict[str, torch.Tensor], epoch: int) -> float: def eval_step(self, epoch: int) -> float: """Return eval accuracy.""" rationales = few_shot_sample( - self.batch_generate_fn, + self.batch_generate_fn_eval, self.few_shot_rationales, self.test_questions, )