From cab4b58c1d5f05faeaad12ea2938fa97ed7ef956 Mon Sep 17 00:00:00 2001 From: sampatkalyan <42086723+sampatkalyan@users.noreply.github.com> Date: Thu, 27 Apr 2023 12:48:40 +0530 Subject: [PATCH] Joke explanation generalization (#2899) For the Issue #2827. I have made changes to JokeExplaniation Class. This PR implements the DatasetEntry class in the JokeExplaination class to generalize the data. The DatasetEntry class provides a consistent data structure for storing joke-explanation pairs, making it easier to work with the data. and made changes in AlpacaGpt4 to correct the annotation in one of its methods. The changes in this PR include: - Adding a new DatasetEntry class to represent joke-explanation pairs - Updating the JokeExplaination class to use DatasetEntry objects to store data - Replacing the AlpacaGpt4 class __getitem__ method with correct annotation --------- Co-authored-by: sampatkalyan <120446217+Andavarapu-Sampat-Kalyan@users.noreply.github.com> --- .../model_training/custom_datasets/qa_datasets.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/model/model_training/custom_datasets/qa_datasets.py b/model/model_training/custom_datasets/qa_datasets.py index 477f068746..f2756f19cd 100644 --- a/model/model_training/custom_datasets/qa_datasets.py +++ b/model/model_training/custom_datasets/qa_datasets.py @@ -333,8 +333,6 @@ def __init__(self, cache_dir) -> None: with open(joke_explain_filename, "w") as fout: fout.write(content) - question = "" - answer = "" self.pairs = [] with open(joke_explain_filename, "r") as f: for line in f: @@ -343,16 +341,12 @@ def __init__(self, cache_dir) -> None: # DO NOT change this # its the data that had syntax error explanation = data["explaination"] - self.pairs.append((joke, explanation)) + self.pairs.append(DatasetEntry(questions=[joke], answers=[explanation])) - if len(question) > 0 and len(answer) > 0: - self.pairs.append((question, answer)) - self.length = len(self.pairs) - - def __len__(self): + def __len__(self) -> int: return self.length - def __getitem__(self, index): + def __getitem__(self, index) -> DatasetEntry: return self.pairs[index] @@ -610,6 +604,6 @@ def _process_instruction(self, row: dict[str, str], input_max_length: int) -> Da def __len__(self) -> int: return len(self.rows) - def __getitem__(self, index: int) -> list[str] | tuple[str]: + def __getitem__(self, index: int) -> DatasetEntry: dialogue = self.rows[index] return dialogue