Commit
black fix
Signed-off-by: Evelina Bakhturina <ebakhturina@nvidia.com>
ekmb committed Feb 13, 2020
1 parent 2576bc2 commit 8be0691
Showing 14 changed files with 46 additions and 31 deletions.
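Note: nearly every hunk below is one of three mechanical fixes: a long or backslash-continued import rewrapped into a parenthesized block with a trailing comma, names inside an import statement alphabetized (an isort-style pass; black itself does not reorder imports), or a missing newline added at the end of a file (shown with the standard "\ No newline at end of file" diff marker under the deleted line). A minimal before/after sketch of the rewrap, mirroring the text_classification_dataset.py hunk below and assuming black's default settings:

# Before: over the line-length limit, continued with a backslash.
from nemo.collections.nlp.data.datasets.datasets_utils import process_sst_2, process_imdb, \
    process_nlu

# After: black folds the names into a parenthesized block and adds a
# trailing comma; the alphabetical order comes from import sorting, not black.
from nemo.collections.nlp.data.datasets.datasets_utils import (
    process_imdb,
    process_nlu,
    process_sst_2,
)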
2 changes: 1 addition & 1 deletion examples/nlp/scripts/download_wkt2.py
@@ -12,4 +12,4 @@ def download_wkt2(data_dir):
logging.warning(f'Data not found at {data_dir}. Downloading wikitext-2 to {data_dir}/lm/')
data_dir = 'data/lm/wikitext-2'
subprocess.call('get_wkt2.sh')
-return data_dir
\ No newline at end of file
+return data_dir
@@ -24,8 +24,7 @@
import nemo.collections.nlp.nm.trainables.common.sequence_classification_nm
from nemo import logging
from nemo.collections.nlp.callbacks.text_classification_callback import eval_epochs_done_callback, eval_iter_callback
-from nemo.collections.nlp.data.datasets.text_classification_dataset import \
-    TextClassificationDataDesc
+from nemo.collections.nlp.data.datasets.text_classification_dataset import TextClassificationDataDesc
from nemo.utils.lr_policies import get_lr_policy

# Parsing arguments
2 changes: 1 addition & 1 deletion nemo/collections/nlp/data/datasets/__init__.py
@@ -26,12 +26,12 @@
)
from nemo.collections.nlp.data.datasets.lm_transformer_dataset import LanguageModelingDataset
from nemo.collections.nlp.data.datasets.machine_translation_dataset import TranslationDataset
+from nemo.collections.nlp.data.datasets.multiWOZ_dataset import *
from nemo.collections.nlp.data.datasets.punctuation_capitalization_dataset import (
BertPunctuationCapitalizationDataset,
BertPunctuationCapitalizationInferDataset,
)
from nemo.collections.nlp.data.datasets.qa_squad_dataset import SquadDataset
-from nemo.collections.nlp.data.datasets.multiWOZ_dataset import *
from nemo.collections.nlp.data.datasets.text_classification_dataset import BertTextClassificationDataset
from nemo.collections.nlp.data.datasets.token_classification_dataset import (
BertTokenClassificationDataset,
@@ -1,3 +1,3 @@
from nemo.collections.nlp.data.datasets.datasets_utils.datasets_processing import *
from nemo.collections.nlp.data.datasets.datasets_utils.dialogflow_utils import *
-from nemo.collections.nlp.data.datasets.datasets_utils.mturk_utils import *
\ No newline at end of file
+from nemo.collections.nlp.data.datasets.datasets_utils.mturk_utils import *
@@ -4,8 +4,13 @@
import shutil

from nemo import logging
-from nemo.collections.nlp.data.datasets.datasets_utils.preprocessing import DATABASE_EXISTS_TMP, MODE_EXISTS_TMP, get_dataset, create_dataset
-from nemo.collections.nlp.utils import get_vocab, if_exist, ids2text
+from nemo.collections.nlp.data.datasets.datasets_utils.preprocessing import (
+    DATABASE_EXISTS_TMP,
+    MODE_EXISTS_TMP,
+    create_dataset,
+    get_dataset,
+)
+from nemo.collections.nlp.utils import get_vocab, ids2text, if_exist

__all__ = [
'process_atis',
@@ -14,9 +19,10 @@
'process_sst_2',
'process_imdb',
'process_nlu',
-'process_thucnews'
+'process_thucnews',
]
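The comma added after 'process_thucnews' above is black's trailing-comma normalization: every element line of a multi-line literal ends with a comma, so appending an entry later touches only one line. An illustrative sketch:

# Before: the last element lacks a trailing comma.
__all__ = [
    'process_atis',
    'process_thucnews'
]

# After black: every element line ends with a comma.
__all__ = [
    'process_atis',
    'process_thucnews',
]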


def process_atis(infold, uncased, modes=['train', 'test'], dev_split=0):
""" MSFT's dataset, processed by Kaggle
https://www.kaggle.com/siddhadev/atis-dataset-from-ms-cntk
@@ -372,4 +378,4 @@ def process_nlu(filename, uncased, modes=['train', 'test'], dataset_name='nlu-ub
outfiles['test'].write(txt)
for mode in modes:
outfiles[mode].close()
-return outfold
\ No newline at end of file
+return outfold
@@ -87,4 +87,4 @@ def process_dialogflow(data_dir, uncased, modes=['train', 'test'], dev_split=0.1
write_files(slot_labels, f'{outfold}/dict.slots.csv')
write_files(intent_names, f'{outfold}/dict.intents.csv')

-return outfold
\ No newline at end of file
+return outfold
@@ -2,7 +2,12 @@
import os

from nemo import logging
-from nemo.collections.nlp.data.datasets.datasets_utils.preprocessing import DATABASE_EXISTS_TMP, read_csv, partition_data, write_files
+from nemo.collections.nlp.data.datasets.datasets_utils.preprocessing import (
+    DATABASE_EXISTS_TMP,
+    partition_data,
+    read_csv,
+    write_files,
+)
from nemo.collections.nlp.utils import if_exist


@@ -175,4 +180,4 @@ def get_slot_labels(slot_annotations, task_name):
count += 1
all_labels['O'] = str(count)

-return all_labels
\ No newline at end of file
+return all_labels
@@ -25,9 +25,7 @@
import numpy as np

from nemo import logging
-from nemo.collections.nlp.utils.common_nlp_utils import (
-    write_vocab,
-)
+from nemo.collections.nlp.utils.common_nlp_utils import write_vocab

__all__ = [
'get_label_stats',
@@ -266,5 +264,3 @@ def get_stats(lengths):
)
logging.info(f'75 percentile: {np.percentile(lengths, 75)}')
logging.info(f'99 percentile: {np.percentile(lengths, 99)}')
-
-
15 changes: 10 additions & 5 deletions nemo/collections/nlp/data/datasets/joint_intent_slot_dataset.py
@@ -26,13 +26,18 @@
from torch.utils.data import Dataset

from nemo import logging
+from nemo.collections.nlp.data.datasets.datasets_utils.datasets_processing import (
+    process_atis,
+    process_jarvis_datasets,
+    process_snips,
+)
+from nemo.collections.nlp.data.datasets.datasets_utils.dialogflow_utils import process_dialogflow
+from nemo.collections.nlp.data.datasets.datasets_utils.mturk_utils import process_mturk
from nemo.collections.nlp.data.datasets.datasets_utils.preprocessing import (
+    DATABASE_EXISTS_TMP,
    get_label_stats,
    get_stats,
-    DATABASE_EXISTS_TMP)
-from nemo.collections.nlp.data.datasets.datasets_utils.datasets_processing import process_atis, process_jarvis_datasets, process_snips
-from nemo.collections.nlp.data.datasets.datasets_utils.dialogflow_utils import process_dialogflow
-from nemo.collections.nlp.data.datasets.datasets_utils.mturk_utils import process_mturk
+)
from nemo.collections.nlp.utils import list2str, write_vocab_in_order
from nemo.collections.nlp.utils.common_nlp_utils import calc_class_weights, get_vocab, if_exist, label2idx

@@ -473,4 +478,4 @@ def merge(data_dir, subdirs, dataset_name, modes=['train', 'test']):

write_vocab_in_order(intents, f'{outfold}/dict.intents.csv')
write_vocab_in_order(slots, f'{outfold}/dict.slots.csv')
-return outfold, none_slot
\ No newline at end of file
+return outfold, none_slot
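The name order inside the rewrapped imports above (DATABASE_EXISTS_TMP ahead of get_label_stats; get_vocab, ids2text, if_exist) is plain case-sensitive ASCII order, in which uppercase letters sort before lowercase. Python's default string sort reproduces it, as this small check shows:

# sorted() with default string comparison matches the import ordering
# produced by the sorter (likely isort) used alongside black here.
names = ["get_dataset", "create_dataset", "MODE_EXISTS_TMP", "DATABASE_EXISTS_TMP"]
print(sorted(names))
# ['DATABASE_EXISTS_TMP', 'MODE_EXISTS_TMP', 'create_dataset', 'get_dataset']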
2 changes: 1 addition & 1 deletion nemo/collections/nlp/data/datasets/lm_bert_dataset.py
@@ -24,11 +24,11 @@

import h5py
import numpy as np
+from examples.nlp.scripts.download_wkt2 import download_wkt2
from torch.utils.data import Dataset
from tqdm import tqdm

from nemo import logging
-from examples.nlp.scripts.download_wkt2 import download_wkt2
from nemo.collections.nlp.data.datasets.lm_transformer_dataset import create_vocab_mlm

__all__ = ['BertPretrainingDataset', 'BertPretrainingPreprocessedDataset']
1 change: 0 additions & 1 deletion nemo/collections/nlp/data/datasets/multiWOZ_dataset.py
@@ -294,7 +294,6 @@ def create_vocab(self):
with open(self.vocab_file, 'wb') as handle:
pickle.dump(self.vocab, handle)

-
def fix_general_label_error_multiwoz(self, labels, slots):
label_dict = dict([label['slots'][0] for label in labels])
GENERAL_TYPO = {
Expand Down
@@ -30,7 +30,7 @@
from torch.utils.data import Dataset

from nemo import logging
-from nemo.collections.nlp.data.datasets.datasets_utils.preprocessing import get_stats, get_label_stats
+from nemo.collections.nlp.data.datasets.datasets_utils.preprocessing import get_label_stats, get_stats


def get_features(
15 changes: 10 additions & 5 deletions nemo/collections/nlp/data/datasets/text_classification_dataset.py
@@ -26,10 +26,15 @@
from torch.utils.data import Dataset

from nemo import logging
-from nemo.collections.nlp.data.datasets.datasets_utils import process_sst_2, process_imdb, process_thucnews, \
-    process_nlu, process_jarvis_datasets
-from nemo.collections.nlp.data.datasets.datasets_utils.preprocessing import get_stats, get_intent_labels
-from nemo.collections.nlp.utils import if_exist, calc_class_weights
+from nemo.collections.nlp.data.datasets.datasets_utils import (
+    process_imdb,
+    process_jarvis_datasets,
+    process_nlu,
+    process_sst_2,
+    process_thucnews,
+)
+from nemo.collections.nlp.data.datasets.datasets_utils.preprocessing import get_intent_labels, get_stats
+from nemo.collections.nlp.utils import calc_class_weights, if_exist
from nemo.collections.nlp.utils.callback_utils import list2str

__all__ = ['BertTextClassificationDataset']
@@ -238,4 +243,4 @@ def __init__(self, dataset_name, data_dir, do_lower_case):
logging.info(f'Class weights are - {self.class_weights}')

logging.info(f'Total Sentences - {total_sents}')
-logging.info(f'Sentence class frequencies - {sent_label_freq}')
\ No newline at end of file
+logging.info(f'Sentence class frequencies - {sent_label_freq}')
@@ -29,7 +29,7 @@
from torch.utils.data import Dataset

from nemo import logging
-from nemo.collections.nlp.data.datasets.datasets_utils.preprocessing import get_stats, get_label_stats
+from nemo.collections.nlp.data.datasets.datasets_utils.preprocessing import get_label_stats, get_stats

__all__ = ['BertTokenClassificationDataset', 'BertTokenClassificationInferDataset']


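For reference, a hedged sketch of reproducing this kind of rewrap programmatically. black.format_str and black.FileMode are black's public Python API; the source line is one of the over-long imports from this diff, and the default line length of 88 is an assumption that may differ from this repository's configuration:

import black

# black folds an over-long import into a parenthesized block with a
# trailing comma, but it does not sort the imported names.
src = (
    "from nemo.collections.nlp.data.datasets.datasets_utils.preprocessing "
    "import DATABASE_EXISTS_TMP, read_csv, partition_data, write_files\n"
)
print(black.format_str(src, mode=black.FileMode()))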