
Commit

black fix
Signed-off-by: Evelina Bakhturina <ebakhturina@nvidia.com>
ekmb committed Feb 14, 2020
1 parent d02367a commit 0820752
Showing 13 changed files with 64 additions and 67 deletions.
@@ -413,4 +413,4 @@ def dataset_to_ids(dataset, tokenizer, cache_ids=False, add_bos_eos=True):
if cache_ids:
logging.info("Caching tokenized dataset ...")
pickle.dump(ids, open(cached_ids_dataset, "wb"))
return ids
return ids
34 changes: 2 additions & 32 deletions nemo/collections/nlp/data/datasets/datasets_utils/preprocessing.py
@@ -18,8 +18,6 @@
import json
import os
import random
import re
import string
from collections import Counter

import numpy as np
@@ -41,11 +39,8 @@
'get_data',
'reverse_dict',
'get_intent_labels',
'normalize_answer',
'get_tokens',
'get_stats'
'DATABASE_EXISTS_TMP',
'MODE_EXISTS_TMP'
'get_stats' 'DATABASE_EXISTS_TMP',
'MODE_EXISTS_TMP',
]

DATABASE_EXISTS_TMP = '{} dataset has already been processed and stored at {}'
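One side effect worth flagging in the reformatted __all__ above: the original list was already missing a comma after 'get_stats', so Black simply joined the two adjacent string literals onto one line. Python concatenates adjacent literals at parse time, so the export list ends up containing a single merged name rather than two. A quick illustration (not part of the commit):

    # Adjacent string literals are joined into one string, so 'get_stats' and
    # 'DATABASE_EXISTS_TMP' become a single (nonexistent) exported name.
    >>> ['get_stats' 'DATABASE_EXISTS_TMP', 'MODE_EXISTS_TMP']
    ['get_statsDATABASE_EXISTS_TMP', 'MODE_EXISTS_TMP']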
@@ -232,31 +227,6 @@ def get_intent_labels(intent_file):
return labels


def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""

def remove_articles(text):
return re.sub(r'\b(a|an|the)\b', ' ', text)

def white_space_fix(text):
return ' '.join(text.split())

def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)

def lower(text):
return text.lower()

return white_space_fix(remove_articles(remove_punc(lower(s))))


def get_tokens(s):
if not s:
return []
return normalize_answer(s).split()


def get_stats(lengths):
lengths = np.asarray(lengths)
logging.info(
@@ -3,7 +3,8 @@

from nemo import logging

__all__ = ['ColaProcessor',
__all__ = [
'ColaProcessor',
'MnliProcessor',
'MnliMismatchedProcessor',
'MrpcProcessor',
@@ -12,7 +13,9 @@
'QqpProcessor',
'QnliProcessor',
'RteProcessor',
'WnliProcessor']
'WnliProcessor',
]


class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
@@ -329,4 +332,4 @@ def __init__(self, guid, text_a, text_b=None, label=None):
self.guid = guid
self.text_a = text_a
self.text_b = text_b
self.label = label
self.label = label
@@ -22,8 +22,9 @@

import numpy as np
from torch.utils.data import Dataset
from nemo.collections.nlp.data.datasets.glue_benchmark_dataset.data_processors import *

from nemo import logging
from nemo.collections.nlp.data.datasets.glue_benchmark_dataset.data_processors import *

__all__ = ['GLUEDataset']

@@ -63,6 +64,7 @@
"wnli": 2,
}


class GLUEDataset(Dataset):
def __init__(self, data_dir, tokenizer, max_seq_length, processor, output_mode, evaluate, token_params):
self.tokenizer = tokenizer
@@ -84,7 +86,6 @@ def __getitem__(self, idx):
np.array(feature.label_id),
)


def convert_examples_to_features(
examples,
label_list,
@@ -235,7 +236,6 @@ def convert_examples_to_features(
)
return features


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length.
@@ -253,7 +253,6 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
else:
tokens_b.pop()


"""
Utility functions for GLUE tasks
This code was adapted from the HuggingFace library at
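For context, the body of _truncate_seq_pair is mostly elided by the hunk above; only its tail (else: tokens_b.pop()) is visible. The usual form of this helper in BERT-derived code is the loop below — a reconstruction for readability, not text from this diff:

    def _truncate_seq_pair(tokens_a, tokens_b, max_length):
        """Truncates a sequence pair in place to the maximum length."""
        # Pop one token at a time from whichever sequence is currently longer,
        # so the two sequences are trimmed roughly evenly.
        while len(tokens_a) + len(tokens_b) > max_length:
            if len(tokens_a) > len(tokens_b):
                tokens_a.pop()
            else:
                tokens_b.pop()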
@@ -269,6 +268,3 @@ def __init__(self, input_ids, input_mask, segment_ids, label_id):
self.input_mask = input_mask
self.segment_ids = segment_ids
self.label_id = label_id



@@ -2,11 +2,23 @@
import os

from nemo import logging
from nemo.collections.nlp.data import process_atis, process_dialogflow, process_mturk, process_snips, \
process_jarvis_datasets, DATABASE_EXISTS_TMP
from nemo.collections.nlp.data import (
DATABASE_EXISTS_TMP,
process_atis,
process_dialogflow,
process_jarvis_datasets,
process_mturk,
process_snips,
)
from nemo.collections.nlp.data.datasets.datasets_utils.preprocessing import get_label_stats
from nemo.collections.nlp.utils import if_exist, get_vocab, label2idx, calc_class_weights, list2str, \
write_vocab_in_order
from nemo.collections.nlp.utils import (
calc_class_weights,
get_vocab,
if_exist,
label2idx,
list2str,
write_vocab_in_order,
)


class JointIntentSlotDataDesc:
@@ -214,4 +226,4 @@ def merge(data_dir, subdirs, dataset_name, modes=['train', 'test']):

write_vocab_in_order(intents, f'{outfold}/dict.intents.csv')
write_vocab_in_order(slots, f'{outfold}/dict.slots.csv')
return outfold, none_slot
return outfold, none_slot
@@ -24,9 +24,7 @@
from torch.utils.data import Dataset

from nemo import logging
from nemo.collections.nlp.data.datasets.datasets_utils.preprocessing import (
get_stats,
)
from nemo.collections.nlp.data.datasets.datasets_utils.preprocessing import get_stats

__all__ = ['BertJointIntentSlotDataset', 'BertJointIntentSlotInferDataset']

@@ -258,5 +256,3 @@ def __getitem__(self, idx):
np.array(self.all_loss_mask[idx]),
np.array(self.all_subtokens_mask[idx]),
)


8 changes: 3 additions & 5 deletions nemo/collections/nlp/data/datasets/lm_bert_dataset.py
@@ -29,12 +29,11 @@
from tqdm import tqdm

from nemo import logging

__all__ = ['BertPretrainingDataset', 'BertPretrainingPreprocessedDataset']

from nemo.collections.nlp.data.datasets.datasets_utils.preprocessing import DATABASE_EXISTS_TMP
from nemo.collections.nlp.utils import if_exist

__all__ = ['BertPretrainingDataset', 'BertPretrainingPreprocessedDataset']


class BertPretrainingDataset(Dataset):
def __init__(
@@ -397,7 +396,6 @@ def __init__(self, dataset_name, data_dir, vocab_size, sample_size, special_toke
self.eval_file = f'{data_dir}/valid.txt'
self.test_file = f'{data_dir}/test.txt'


def create_vocab_mlm(
data_dir, vocab_size, sample_size, special_tokens=['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]'], train_file=''
):
@@ -449,4 +447,4 @@ def create_vocab_mlm(
with open(f'{bert_dir}/vocab.txt', "w") as f:
for token in vocab:
f.write(f"{token}\n".format())
return data_dir, f'{bert_dir}/tokenizer.model'
return data_dir, f'{bert_dir}/tokenizer.model'
@@ -19,10 +19,10 @@
import re

import numpy as np
from examples.nlp.scripts.download_wkt2 import download_wkt2
from torch.utils.data import Dataset

from nemo import logging
from examples.nlp.scripts.download_wkt2 import download_wkt2
from nemo.collections.nlp.data.datasets.datasets_utils.datasets_processing import dataset_to_ids
from nemo.collections.nlp.utils.common_nlp_utils import if_exist

@@ -156,7 +156,6 @@ def pack_data_into_batches(self, src_ids, tgt_ids):

return batches


def clean_src_and_target(src_ids, tgt_ids, max_tokens=128, min_tokens=3, max_tokens_diff=25, max_tokens_ratio=2.5):
"""
Cleans source and target sentences to get rid of noisy data.
@@ -37,9 +37,8 @@
get_final_text,
make_eval_dict,
merge_eval,
normalize_answer,
)
from nemo.collections.nlp.utils.common_nlp_utils import _is_whitespace
from nemo.collections.nlp.utils.common_nlp_utils import _is_whitespace, normalize_answer
from nemo.collections.nlp.utils.loss_utils import _compute_softmax

__all__ = ['SquadDataset']
@@ -534,5 +533,3 @@ def __init__(
self.end_position = char_to_word_offset[
min(start_position_character + len(answer_text) - 1, len(char_to_word_offset) - 1)
]


@@ -259,4 +259,4 @@ def __init__(
self.segment_ids = segment_ids
self.start_position = start_position
self.end_position = end_position
self.is_impossible = is_impossible
self.is_impossible = is_impossible
3 changes: 1 addition & 2 deletions nemo/collections/nlp/metrics/squad_metrics.py
@@ -21,7 +21,7 @@
from transformers.tokenization_bert import BasicTokenizer

from nemo import logging
from nemo.collections.nlp.data.datasets.datasets_utils import get_tokens, normalize_answer
from nemo.collections.nlp.utils.common_nlp_utils import get_tokens, normalize_answer

__all__ = [
'f1_score',
@@ -31,7 +31,6 @@
'merge_eval',
'find_all_best_thresh',
'find_best_thresh',
'normalize_answer',
'_get_best_indexes',
'get_final_text',
]
27 changes: 27 additions & 0 deletions nemo/collections/nlp/utils/common_nlp_utils.py
@@ -34,6 +34,8 @@
'remove_punctuation_from_sentence',
'ids2text',
'calc_class_weights',
'get_tokens',
'normalize_answer',
]


@@ -142,3 +144,28 @@ def calc_class_weights(label_freq):
most_common_label_freq = label_freq[0]
weighted_slots = sorted([(index, most_common_label_freq[1] / freq) for (index, freq) in label_freq])
return [weight for (_, weight) in weighted_slots]


def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""

def remove_articles(text):
return re.sub(r'\b(a|an|the)\b', ' ', text)

def white_space_fix(text):
return ' '.join(text.split())

def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)

def lower(text):
return text.lower()

return white_space_fix(remove_articles(remove_punc(lower(s))))


def get_tokens(s):
if not s:
return []
return normalize_answer(s).split()
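A short usage sketch for the two relocated helpers (illustrative only; it assumes the module's existing imports already include re and string, which this hunk does not show):

    from nemo.collections.nlp.utils.common_nlp_utils import get_tokens, normalize_answer

    # Lower-case, strip punctuation and the articles a/an/the, collapse whitespace.
    print(normalize_answer("The Quick,  Brown Fox!"))  # quick brown fox
    # Tokenize the normalized string; empty or None input yields an empty list.
    print(get_tokens("The Quick,  Brown Fox!"))        # ['quick', 'brown', 'fox']
    print(get_tokens(""))                               # []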
