4 changes: 4 additions & 0 deletions pipelinerl/finetune/rl/__init__.py
@@ -91,6 +91,10 @@ class RLConfig(BaseModel):
         default=1.0,
         description="Temperature for the training log probs",
     )
+    filter_zero_advantage_groups: bool = Field(
+        default=False,
+        description="Filter out groups where all advantages are zero during preprocessing",
+    )


 def make_rl_data_callback(args, current_dir, rl_config, model):
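For reference, a minimal sketch of how the new flag behaves, assuming RLConfig is the Pydantic model the hunk above shows. The stub class below reproduces only the new field and is illustrative, not the real class:

# Illustrative stub: the real RLConfig in pipelinerl/finetune/rl/__init__.py
# has many more fields; only the new flag is reproduced here.
from pydantic import BaseModel, Field

class RLConfig(BaseModel):
    filter_zero_advantage_groups: bool = Field(
        default=False,
        description="Filter out groups where all advantages are zero during preprocessing",
    )

config = RLConfig()                                   # flag defaults to off
assert config.filter_zero_advantage_groups is False
config = RLConfig(filter_zero_advantage_groups=True)  # opt in explicitly
assert config.filter_zero_advantage_groups is True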
67 changes: 62 additions & 5 deletions pipelinerl/run_preprocess.py
@@ -129,6 +129,9 @@ def preprocess_dataset(
 ) -> Dataset:
     preprocess = partial(preprocess_fn, seq_length=seq_length, tokenizer=tokenizer, is_rl=True)
     columns = ["input_ids", "labels", "attention_mask"] + RL_DATA_COLUMNS
Collaborator: you can always add "group_id" to the columns, it won't hurt

+    if rl_config.filter_zero_advantage_groups and "group_id" not in columns:
+        columns.append("group_id")
+
     logger.debug(f"Instantiated preprocess function hash {Hasher.hash(preprocess)}")

     data = replace_oov_tokens_with_the(data, tokenizer)
@@ -266,6 +269,45 @@ def process_chunk(
     dataset_queue.put(slot)


+def filter_zero_advantage_groups(dataset: list[dict], epsilon: float = 1e-6) -> tuple[list[dict], int]:
+    """
+    Filter out groups where all advantages are zero.
+
+    Args:
+        dataset: List of dataset entries with group_id and advantages
+        epsilon: Threshold for considering an advantage non-zero
+
+    Returns:
+        Tuple of (filtered_entries, num_filtered_out)
+    """
+    filtered_entries = []
+    groups = {}
+
+    # Group entries by group_id
+    for entry in dataset:
+        group_id = entry["group_id"]
+        if group_id not in groups:
+            groups[group_id] = []
+        groups[group_id].append(entry)
+
+    num_filtered_out = 0
+
+    # Keep a group only if at least one of its entries has a non-zero advantage
+    for group_id, entries in groups.items():
+        has_non_zero_advantage = False
+        for entry in entries:
+            # advantages is a list; check whether any absolute value exceeds epsilon
+            if any(abs(adv) > epsilon for adv in entry["advantages"]):
+                has_non_zero_advantage = True
+                break
+
+        if has_non_zero_advantage:
+            filtered_entries.extend(entries)
+        else:
+            num_filtered_out += len(entries)
+
+    return filtered_entries, num_filtered_out
+
 def run_preprocessing_loop(
     cfg: DictConfig,
 ):
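To make the grouping semantics concrete, a toy example of the new helper; the entries are invented and carry only the two fields the function reads (group_id and advantages). Groups are kept or dropped atomically: one non-zero advantage anywhere in a group keeps every entry of that group.

# Invented toy data: g1 is all-zero and is dropped whole,
# g2 has one non-zero advantage and is kept whole.
buffer = [
    {"group_id": "g1", "advantages": [0.0, 0.0]},
    {"group_id": "g1", "advantages": [0.0]},
    {"group_id": "g2", "advantages": [0.0, 0.7, -0.7]},
    {"group_id": "g2", "advantages": [0.0]},
]
kept, dropped = filter_zero_advantage_groups(buffer)
assert [e["group_id"] for e in kept] == ["g2", "g2"]  # whole group survives
assert dropped == 2                                   # both g1 entries counted

Advantages within ±epsilon of zero (default 1e-6) count as zero; presumably this targets GRPO-style groups whose rollouts all received the same reward, which have identically zero advantages and contribute no gradient signal.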
@@ -326,6 +368,7 @@ def run_preprocessing_loop(
     stats_aggregator = SlidingWindowAggregator(window_size=max(10, 1000 // cfg.preprocess.chunk_size))

     buffer = []
+    total_filtered_out = 0  # Track total filtered samples across all batches
     with write_to_streams(output_stream) as writer, write_to_streams(stats_streams) as stats_writer:
         with mp.Manager() as manager, SharedMemoryManager() as smm:
             max_dataset_queue_size = 128
@@ -392,12 +435,24 @@ def run_preprocessing_loop(
                     logger.info(f"Buffer is full with {len(buffer)} samples, start writing")
                     random.shuffle(buffer)

-                    for entry in buffer:
+                    # Conditionally filter out groups where all advantages are zero
+                    if rl_config.filter_zero_advantage_groups:
+                        filtered_buffer, num_filtered_out = filter_zero_advantage_groups(buffer)
+                        total_filtered_out += num_filtered_out
+
+                        if num_filtered_out > 0:
+                            logger.info(f"Filtered out {num_filtered_out} samples from groups with zero advantage.")
+                    else:
+                        filtered_buffer = buffer
+                        num_filtered_out = 0
+
+                    # Write the entries (filtered or unfiltered based on config)
+                    for entry in filtered_buffer:
                         writer.write(entry)
                     writing_took = time.time() - start_writing
-                    stats_aggregator.update([len(entry["input_ids"]) for entry in buffer])
-                    published_samples += len(buffer)
-                    max_model_version = max([dataset["model_version"] for dataset in buffer])
+                    stats_aggregator.update([len(entry["input_ids"]) for entry in filtered_buffer])
+                    published_samples += len(filtered_buffer)  # Count only written samples
+                    max_model_version = max([entry["model_version"] for entry in filtered_buffer]) if filtered_buffer else 0
                     samples_in_queue = dataset_queue.qsize() * cfg.preprocess.chunk_size
                     stats = {
                         "preprocessor/published_samples": published_samples,
@@ -406,14 +461,16 @@ def run_preprocessing_loop(
                         "preprocessor/queue/raw": raw_chunk_queue.qsize(),
                         "preprocessor/queue/dataset_samples": samples_in_queue,
                         "preprocessor/queue/dataset": dataset_queue.qsize(),
+                        "preprocessor/filtered_out_samples": num_filtered_out,
+                        "preprocessor/total_filtered_out_samples": total_filtered_out,
                     }
                     if stats_aggregator.has_enough_data():
                         stats.update({"preprocessor/" + k: v for k, v in stats_aggregator.get_stats().items()})
                     run.log(stats)
                     stats_writer.write(stats)
                     processing_took = time.time() - start_processing
                     logger.info(
-                        f"Processed {len(buffer)} samples in {processing_took:.3f}s"
+                        f"Processed {len(filtered_buffer)} samples (filtered out {num_filtered_out}) in {processing_took:.3f}s"
                         f" (last fetching took {fetching_took:.3f}, all writing took {writing_took:.3f})"
                         f" and wrote to {output_stream}, total {published_samples} samples so far,"
                         f" {samples_in_queue} samples in queue, max buffer entry size {io_buffer._max_written_entry_size}"
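One edge case worth noting in the hunk above: when every group in a buffer is filtered out, filtered_buffer is empty and a bare max([...]) would raise ValueError, hence the `if filtered_buffer else 0` guard on max_model_version. A toy check, with invented entries:

# Invented scenario: demonstrates why the empty-buffer guard is needed.
filtered_buffer = []
try:
    max_model_version = max([entry["model_version"] for entry in filtered_buffer])
except ValueError:
    pass  # max() of an empty sequence raises
max_model_version = max([entry["model_version"] for entry in filtered_buffer]) if filtered_buffer else 0
assert max_model_version == 0

The two new stats keys differ in scope accordingly: preprocessor/filtered_out_samples reports drops for the current buffer only, while preprocessor/total_filtered_out_samples accumulates across the whole run.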