diff --git a/model/model_training/custom_datasets/qa_datasets.py b/model/model_training/custom_datasets/qa_datasets.py
index a0b2714a82..c0eab37700 100644
--- a/model/model_training/custom_datasets/qa_datasets.py
+++ b/model/model_training/custom_datasets/qa_datasets.py
@@ -603,7 +603,7 @@ def __init__(self, cache_dir: str | Path, mode: str = "sft", input_max_length: i
             if (conv := self._process_instruction(line, input_max_length)) is not None:
                 self.rows.append(conv)

-    def _process_instruction(self, row: dict[str, str], input_max_length: int) -> list[str] | None:
+    def _process_instruction(self, row: dict[str, str], input_max_length: int) -> DatasetEntry | None:
         # discard items that are too long: when checked on 2023-04-17 this was just one item in the whole dataset with length above 2048.
         # And 12 above 1024.
         if len(row["input"]) + len(row["instruction"]) > input_max_length:
@@ -615,18 +615,17 @@ def _process_instruction(self, row: dict[str, str], input_max_length: int) -> li
             or (not row["input"])
             or (row["input"].lower() in row["instruction"].lower())
         ):
-            return [row["instruction"], row["output"]]
+            return DatasetEntry(questions=[row["instruction"]], answers=[row["output"]])
         # Concatenate the instruction and input.
         else:
             linking_char = random.choice(LINKING_CHARS)
-            return [f"{row['instruction']}{linking_char}{row['input']}", row["output"]]
+            return DatasetEntry(
+                questions=[f"{row['instruction']}{linking_char}{row['input']}"], answers=[row["output"]]
+            )

     def __len__(self) -> int:
         return len(self.rows)

     def __getitem__(self, index: int) -> list[str] | tuple[str]:
-        dialogue: list[str] = self.rows[index]
-        if self.mode == "sft":
-            return dialogue
-        elif self.mode == "rl":
-            return tuple(dialogue[:-1])
+        dialogue = self.rows[index]
+        return dialogue
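
Not part of the patch: a minimal, self-contained sketch of the behaviour the new _process_instruction branches encode. The stand-in DatasetEntry below only mirrors the questions/answers fields visible in the diff; the field defaults, the LINKING_CHARS values, and the helper name build_entry are illustrative assumptions, not the project's real definitions.

import random
from dataclasses import dataclass, field

@dataclass
class DatasetEntry:
    # Stand-in for the project's DatasetEntry; the real class may carry more fields.
    questions: list[str] = field(default_factory=list)
    answers: list[str] = field(default_factory=list)

# Illustrative joining characters only; the real LINKING_CHARS list is defined in the training code.
LINKING_CHARS = ["\n", " ", ": "]

def build_entry(row: dict[str, str]) -> DatasetEntry:
    # Mirror the patched branches: drop the input when it is empty or already
    # contained in the instruction, otherwise join instruction and input with a
    # randomly chosen linking character.
    if (not row["input"]) or (row["input"].lower() in row["instruction"].lower()):
        return DatasetEntry(questions=[row["instruction"]], answers=[row["output"]])
    linking_char = random.choice(LINKING_CHARS)
    return DatasetEntry(
        questions=[f"{row['instruction']}{linking_char}{row['input']}"], answers=[row["output"]]
    )

# Example: a row with a distinct input still collapses to a single question/answer pair.
print(build_entry({"instruction": "Translate to French.", "input": "Good morning", "output": "Bonjour"}))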