Skip to content

Commit

Permalink
style fix
Browse files Browse the repository at this point in the history
Signed-off-by: Zhilin Wang <zhilinw@nvidia.com>
  • Loading branch information
Zhilin123 committed Apr 26, 2022
1 parent c3c3f25 commit 4fe591b
Show file tree
Hide file tree
Showing 9 changed files with 19 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,17 @@
class DialogueDesignDataProcessor(DialogueDataProcessor):
"""Data Processor for Design Dataset"""

def __init__(self, data_dir: str, tokenizer: object):
def __init__(self, data_dir: str, tokenizer: object, cfg=None):
"""
Constructs DialogueDesignDataProcessor
Args:
data_dir: path to data directory
tokenizer: tokenizer object
cfg: cfg container for dataset
"""
self.data_dir = data_dir
self._tokenizer = tokenizer
self.cfg = cfg

def open_csv(self, filename):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,17 @@ class DialogueMellonQADataProcessor(DialogueDataProcessor):
"""Data Processor for Mellon QA dialogues.
"""

def __init__(self, data_dir: str, tokenizer: object):
def __init__(self, data_dir: str, tokenizer: object, cfg=None):
"""
Constructs DialogueMSMarcoDataProcessor
Args:
data_dir: path to data directory
tokenizer: tokenizer object
cfg: cfg container for dataset
"""
self.data_dir = data_dir
self._tokenizer = tokenizer
self.cfg = cfg

def open_csv(self, filename):
"""
Expand All @@ -52,7 +54,7 @@ def get_dialog_examples(self, dataset_split: str):
Process raw files into DialogueInputExample
Args:
dataset_split: {train, dev, test}
For the assistant dataset, there is no explicit dev set (instead uses the test set as the dev set)
For the Mellon QA dataset, there is no explicit dev set (instead uses the test set as the dev set)
Therefore, this function creates a dev set and a new train set from the train set.
Dev set contains self.cfg.dev_proportion % of samples with the rest going into the train set
Test set contains the whole dataset (Dev + Train) as this dataset is small (~100) and primarily used in a zero shot setting
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,18 @@ class DialogueMSMarcoDataProcessor(DialogueDataProcessor):
https://msmarco.blob.core.windows.net/msmarco/eval_v2.1_public.json.gz
"""

def __init__(self, data_dir: str, tokenizer: object, debug_mode=False):
def __init__(self, data_dir: str, tokenizer: object, cfg=None):
"""
Constructs DialogueMSMarcoDataProcessor
Args:
data_dir: path to data directory
tokenizer: tokenizer object
debug_mode: reduce number of samples to load in order to increase speed of processing
cfg: cfg container for dataset
"""
self.data_dir = data_dir
self._tokenizer = tokenizer
self.debug_mode = debug_mode
self.cfg = cfg

def open_json(self, filename):
"""
Expand Down Expand Up @@ -83,7 +84,7 @@ def get_dialog_examples(self, dataset_split: str):
elif dataset_split == "test":
idxs = list(range(len(n_samples)))

if self.debug_mode:
if self.cfg.debug_mode:
idxs = idxs[:1000]

for i in idxs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,7 @@ def __init__(
data_dir: path to data directory
dialogues_example_dir: path to store processed dialogue examples
tokenizer: tokenizer object
schemas: schema object
schema_config: schema configuration
subsample: whether to balance positive and negative samples in dataset
cfg: cfg container for dataset
"""
self.data_dir = data_dir
self.cfg = cfg
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ def prepare_data(self):
)
elif self._cfg.dataset.task == 'design':
self.dialogues_processor = DialogueDesignDataProcessor(
data_dir=self._cfg.dataset.data_dir, tokenizer=self.tokenizer,
data_dir=self._cfg.dataset.data_dir, tokenizer=self.tokenizer, cfg=self._cfg.dataset,
)
else:
raise ValueError("Only sgd, assistant, zero_shot, design supported for Dialogue GPT Classification Model")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,11 +330,11 @@ def prepare_data(self):

if self._cfg.dataset.task == "ms_marco":
self.dialogues_processor = DialogueMSMarcoDataProcessor(
data_dir=self._cfg.dataset.data_dir, tokenizer=self.tokenizer, debug_mode=self.cfg.dataset.debug_mode
data_dir=self._cfg.dataset.data_dir, tokenizer=self.tokenizer, cfg=self._cfg.dataset
)
elif self._cfg.dataset.task == "mellon_qa":
self.dialogues_processor = DialogueMellonQADataProcessor(
data_dir=self._cfg.dataset.data_dir, tokenizer=self.tokenizer
data_dir=self._cfg.dataset.data_dir, tokenizer=self.tokenizer, cfg=self._cfg.dataset
)
else:
raise ValueError("Only ms_marco and mellon_qa supported for Dialogue GPT Generation Model")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _setup_dataloader_from_config(self, cfg: DictConfig, dataset_split) -> 'torc
)
elif self._cfg.dataset.task == "design":
self.data_processor = DialogueDesignDataProcessor(
data_dir=self._cfg.dataset.data_dir, tokenizer=self.tokenizer,
data_dir=self._cfg.dataset.data_dir, tokenizer=self.tokenizer, cfg=self._cfg.dataset
)
elif self._cfg.dataset.task == 'sgd':
self.data_processor = DialogueSGDDataProcessor(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def prepare_data(self):

if self._cfg.dataset.task == "ms_marco":
self.dialogues_processor = DialogueMSMarcoDataProcessor(
data_dir=self._cfg.dataset.data_dir, tokenizer=self.tokenizer, debug_mode=self.cfg.dataset.debug_mode
data_dir=self._cfg.dataset.data_dir, tokenizer=self.tokenizer, cfg=self._cfg.dataset
)
elif self._cfg.dataset.task == "sgd_generation":
self.dialogues_processor = DialogueSGDDataProcessor(
Expand All @@ -271,7 +271,7 @@ def prepare_data(self):
)
elif self._cfg.dataset.task == "mellon_qa":
self.dialogues_processor = DialogueMellonQADataProcessor(
data_dir=self._cfg.dataset.data_dir, tokenizer=self.tokenizer
data_dir=self._cfg.dataset.data_dir, tokenizer=self.tokenizer, cfg=self._cfg.dataset
)
else:
raise ValueError("Only ms_marco, sgd_generation and mellon_qa supported for Dialogue GPT Generation Model")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _setup_dataloader_from_config(self, cfg: DictConfig, dataset_split) -> 'torc
)
elif self._cfg.dataset.task == "design":
self.data_processor = DialogueDesignDataProcessor(
data_dir=self._cfg.dataset.data_dir, tokenizer=self.tokenizer,
data_dir=self._cfg.dataset.data_dir, tokenizer=self.tokenizer, cfg=self._cfg.dataset
)
elif self._cfg.dataset.task == 'sgd':
self.data_processor = DialogueSGDDataProcessor(
Expand Down

0 comments on commit 4fe591b

Please sign in to comment.