Pin the Hugging Face dataset loaded by ProsocialDialogueExplaination and
ProsocialDialogue: both load_dataset() calls now pass data_files="train.json"
and a fixed revision hash in addition to cache_dir.
NOTE(review): line breaks, indentation and the blank leading context line of
each hunk were reconstructed from the hunk headers (7 old / 12 new lines);
verify against the target file before applying.

diff --git a/model/model_training/custom_datasets/toxic_conversation.py b/model/model_training/custom_datasets/toxic_conversation.py
index a34fd94d60..61ddad9233 100644
--- a/model/model_training/custom_datasets/toxic_conversation.py
+++ b/model/model_training/custom_datasets/toxic_conversation.py
@@ -20,7 +20,12 @@ class ProsocialDialogueExplaination(Dataset):
 
     def __init__(self, split="train", cache_dir=".cache") -> None:
         super().__init__()
-        dataset = load_dataset("Englishman2022/prosocial-dialog-filtered", cache_dir=cache_dir)[split]
+        dataset = load_dataset(
+            "Englishman2022/prosocial-dialog-filtered",
+            data_files="train.json",
+            cache_dir=cache_dir,
+            revision="e121e4fd886fadc030d633274c053b71839f9c20",
+        )[split]
         self.pairs = []
         for row in dataset:
             for safety_annotation, safe_answer in zip(row["safety_annotations"], row["safety_annotation_reasons"]):
@@ -54,7 +59,12 @@ class ProsocialDialogue(Dataset):
 
     def __init__(self, split="train", cache_dir=".cache") -> None:
         super().__init__()
-        dataset = load_dataset("Englishman2022/prosocial-dialog-filtered", cache_dir=cache_dir)[split]
+        dataset = load_dataset(
+            "Englishman2022/prosocial-dialog-filtered",
+            data_files="train.json",
+            cache_dir=cache_dir,
+            revision="e121e4fd886fadc030d633274c053b71839f9c20",
+        )[split]
         self.pairs = []
         for row in dataset:
             prompt = row["context"]