Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix list index out of range error for soda dataset #2326

Merged
merged 12 commits into from
Apr 10, 2023
6 changes: 5 additions & 1 deletion model/model_training/configs/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ defaults:
- cnn_dailymail
- multi_news
- scitldr
# - soda # TODO: fix list index out of range error
- soda:
input_max_length: 1024
- joke
- gsm8k
- dive_mt
Expand Down Expand Up @@ -268,6 +269,9 @@ pythia-70m-deduped:
per_device_train_batch_size: 2
per_device_eval_batch_size: 4
output_dir: pythia_model
datasets:
- vicuna
- soda

pythia-1B:
learning_rate: 8e-6
Expand Down
2 changes: 1 addition & 1 deletion model/model_training/custom_datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def get_one_dataset(
train = ProsocialDialogueExplaination(cache_dir=data_path, split="train")
eval = ProsocialDialogueExplaination(cache_dir=data_path, split="validation")
elif dataset_name == "soda":
dataset = SODA(data_path)
dataset = SODA(data_path, **kwargs)
elif dataset_name == "soda_dialogue":
dataset = SODADialogue(data_path)
elif dataset_name == "joke":
Expand Down
68 changes: 33 additions & 35 deletions model/model_training/custom_datasets/qa_datasets.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""
Open / close book QA datasets
"""
import copy
import glob
import json
import os
import re
from collections import defaultdict
from pathlib import Path
from typing import Any
from urllib.request import urlopen

import numpy as np
Expand Down Expand Up @@ -221,58 +221,56 @@ def __getitem__(self, index) -> list[str] | tuple[list[str], list[str]]:
class SODA(Dataset):
name = "soda"

def process_soda_convo(self, data):
pairs = []
def process_soda_convo(self, data: dict[str, Any], input_max_length: int) -> list[list[str]] | None:
play_as = data["speakers"][1]
question, answer = "", ""
prefix, postfix = "", ""
dialogue_bg = "{}{}".format(
# QA_SPECIAL_TOKENS["StartPrefix"],
data["narrative"],
"your are {}".format(play_as),
" You are {}.".format(play_as),
# QA_SPECIAL_TOKENS["EndPrefix"],
)
previous_chat = []

for idx, convo in enumerate(data["dialogue"]):
if idx % 2 == 0:
question = convo
prefix = data["speakers"][idx]
else:
answer = convo
postfix = data["speakers"][idx]

if len(question) and len(answer) and prefix != postfix and postfix == play_as:
history = copy.deepcopy(previous_chat)
history[0] = dialogue_bg + history[0]

# if len(history):
# history += "<sep>"
# prompt = QA_SPECIAL_TOKENS["Question"] + question + QA_SPECIAL_TOKENS["Answer"]
pairs.append(history + [question, answer])
# pairs.append((dialogue_bg + history + prompt, answer))
previous_chat.append(question)
previous_chat.append(answer)

return pairs

def __init__(self, cache_dir, input_max_length=1024) -> None:
# Perform some sanity checks, if these fail return None
# ignore data with more than 2 speakers for now
if len(set(data["speakers"])) != 2:
return None
speaker1 = data["speakers"][0]
speaker2 = data["speakers"][1]
# make sure that the speakers are in correct order [S1, S2, S1, S2, S1, S2], otherwise return None
speaker1_idx = [idx % 2 == 0 for idx, k in enumerate(data["speakers"]) if k == speaker1]
speaker2_idx = [idx % 2 == 1 for idx, k in enumerate(data["speakers"]) if k == speaker2]
if all(speaker1_idx) and all(speaker2_idx):
# add dialog background to first question.
# [Q1, A1, Q2, A2] -> [B + Q1, A1, Q2, A2]
data["dialogue"][0] = f"{dialogue_bg} {data['dialogue'][0]}"
# Use only input_max_length characters
truncated_dialogue = [k[:input_max_length] for k in data["dialogue"]]
return truncated_dialogue

def __init__(self, cache_dir, mode="sft", input_max_length=1024) -> None:
super().__init__()

if mode not in ("sft", "rl"):
raise NotImplementedError(f"Currently only the modes 'sft' and 'rl' are implemented. Received {mode}.")
self.mode = mode
self.pairs = []
dataset = load_dataset("allenai/soda", cache_dir=cache_dir)["train"]
for data in dataset:
self.pairs.append(self.process_soda_convo(data))
if (processed_data := self.process_soda_convo(data, input_max_length=input_max_length)) is not None:
self.pairs.append(processed_data)
# for prompt, answer in data_pair:
# if len(prompt) < input_max_length:
# self.pairs.append((prompt, answer))

def __len__(self):
def __len__(self) -> int:
return len(self.pairs)

def __getitem__(self, index):
def __getitem__(self, index) -> list[str] | tuple[str]:
# special token added during preprocess
return self.pairs[index]
if self.mode == "sft":
return self.pairs[index]
elif self.mode == "rl":
# add prefix + first human question
return (self.pairs[index][0] + " " + self.pairs[index][1],)


class SODADialogue(Dataset):
Expand Down