diff --git a/model/model_training/custom_datasets/qa_datasets.py b/model/model_training/custom_datasets/qa_datasets.py index ed778ea09b..20f5ea851f 100644 --- a/model/model_training/custom_datasets/qa_datasets.py +++ b/model/model_training/custom_datasets/qa_datasets.py @@ -275,32 +275,17 @@ def __getitem__(self, index) -> list[str] | tuple[str]: return (self.pairs[index][0] + " " + self.pairs[index][1],) -class SODADialogue(Dataset): - url = "https://drive.google.com/uc?id=1TOGQfr419n8wpzJpYLLw4nB3tSKD8zXV" - +class SODADialogue: def __init__(self, cache_dir, verbose=True): - path = os.path.join(cache_dir, "soda_dialog.jsonl") - - if not os.path.exists(path): - import gzip - import shutil - - import gdown - - gdown.download(self.url, output=os.path.join(cache_dir, "soda_dialog.jsonl.gz")) - - with gzip.open(os.path.join(cache_dir, "soda_dialog.jsonl.gz"), "rb") as f_in: - with open(path, "wb") as f_out: - shutil.copyfileobj(f_in, f_out) + dataset = load_dataset("emozilla/soda_synthetic_dialogue", cache_dir=cache_dir) self.pairs = [] faulty = 0 - with open(path) as fin: - for line in fin: - conversation = json.loads(line) + for split in dataset: + for row in dataset[split]: question_answer_pairs = () - question_answers = conversation["text"].split("User: ") + question_answers = row["conversation"].split("User: ") for question_answer in question_answers[1:]: # first element is empty try: question, answer = question_answer.split("\nAssistant: ")