refactor dataset utils
Signed-off-by: Evelina Bakhturina <ebakhturina@nvidia.com>
ekmb committed Feb 13, 2020
1 parent 2410ba8 commit 2576bc2
Showing 16 changed files with 1,161 additions and 1,155 deletions.
@@ -30,7 +30,7 @@
from nemo import logging
from nemo.backends.pytorch.common import EncoderRNN
from nemo.collections.nlp.callbacks.state_tracking_trade_callback import eval_epochs_done_callback, eval_iter_callback
-from nemo.collections.nlp.data.datasets.state_tracking_trade_dataset import MultiWOZDataDesc
+from nemo.collections.nlp.data.datasets.multiWOZ_dataset import MultiWOZDataDesc
from nemo.utils.lr_policies import get_lr_policy

parser = argparse.ArgumentParser(description='Dialog state tracking with TRADE model on MultiWOZ dataset')
15 changes: 15 additions & 0 deletions examples/nlp/scripts/download_wkt2.py
@@ -0,0 +1,15 @@
import os
import subprocess

from nemo import logging


def download_wkt2(data_dir):
    # Reuse an existing copy of the corpus if one is already on disk.
    if os.path.exists(data_dir):
        logging.warning(f'Folder {data_dir} found. Skipping download.')
        return data_dir
    logging.warning(f'Data not found at {data_dir}. Downloading wikitext-2 to {data_dir}/lm/')
    os.makedirs(os.path.join(data_dir, 'lm'), exist_ok=True)
    data_dir = 'data/lm/wikitext-2'  # get_wkt2.sh unpacks the corpus into this fixed location
    subprocess.call('get_wkt2.sh')
    return data_dir
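
A minimal usage sketch (not part of the commit), assuming get_wkt2.sh is executable and discoverable on PATH; the argument mirrors the 'data/lm/wikitext-2' path hardcoded in the helper above:

    if __name__ == '__main__':
        # Hypothetical driver: fetch wikitext-2 once, reuse the cached copy afterwards.
        lm_path = download_wkt2('data/lm/wikitext-2')
        print(f'wikitext-2 ready at {lm_path}')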
@@ -24,7 +24,8 @@
import nemo.collections.nlp.nm.trainables.common.sequence_classification_nm
from nemo import logging
from nemo.collections.nlp.callbacks.text_classification_callback import eval_epochs_done_callback, eval_iter_callback
-from nemo.collections.nlp.data.datasets.text_classification_dataset import SentenceClassificationDataDesc
+from nemo.collections.nlp.data.datasets.text_classification_dataset import \
+    TextClassificationDataDesc
from nemo.utils.lr_policies import get_lr_policy

# Parsing arguments
@@ -93,7 +94,7 @@
hidden_size = pretrained_bert_model.hidden_size
tokenizer = BertTokenizer.from_pretrained(args.pretrained_bert_model)

-data_desc = SentenceClassificationDataDesc(args.dataset_name, args.data_dir, args.do_lower_case)
+data_desc = TextClassificationDataDesc(args.dataset_name, args.data_dir, args.do_lower_case)

# Create sentence classification loss on top
classifier = nemo.collections.nlp.nm.trainables.common.sequence_classification_nm.SequenceClassifier(
4 changes: 2 additions & 2 deletions nemo/collections/nlp/data/datasets/__init__.py
@@ -14,7 +14,7 @@
# limitations under the License.
# =============================================================================

-from nemo.collections.nlp.data.datasets import datasets_utils
+from nemo.collections.nlp.data.datasets.datasets_utils import *
from nemo.collections.nlp.data.datasets.glue_benchmark_dataset import GLUEDataset
from nemo.collections.nlp.data.datasets.joint_intent_slot_dataset import (
BertJointIntentSlotDataset,
@@ -31,7 +31,7 @@
BertPunctuationCapitalizationInferDataset,
)
from nemo.collections.nlp.data.datasets.qa_squad_dataset import SquadDataset
-from nemo.collections.nlp.data.datasets.state_tracking_trade_dataset import *
+from nemo.collections.nlp.data.datasets.multiWOZ_dataset import *
from nemo.collections.nlp.data.datasets.text_classification_dataset import BertTextClassificationDataset
from nemo.collections.nlp.data.datasets.token_classification_dataset import (
BertTokenClassificationDataset,
