NLP refactoring - Stage 2 #368

Merged
124 commits, merged on Feb 21, 2020

Commits
2576bc2
refactor dataset utils
ekmb Feb 13, 2020
8be0691
black fix
ekmb Feb 13, 2020
ce70f26
refactor datasets
ekmb Feb 13, 2020
beafa1e
import fixes
ekmb Feb 13, 2020
98ec131
import fix
ekmb Feb 13, 2020
d02367a
import fix
ekmb Feb 13, 2020
0820752
black fix
ekmb Feb 14, 2020
5b74599
import fixes
ekmb Feb 14, 2020
60f6e3c
wip black fix
ekmb Feb 14, 2020
e29638b
Merge branch 'master' of https://github.com/NVIDIA/NeMo into nlp_refa…
ekmb Feb 14, 2020
4428f37
import fixed
ekmb Feb 14, 2020
8b1d72d
import fixed
ekmb Feb 14, 2020
79cb8f0
trade example fix
ekmb Feb 14, 2020
04311df
utils refactoring
ekmb Feb 14, 2020
b78496d
text_cl fixed
ekmb Feb 14, 2020
563b69b
rm output
ekmb Feb 14, 2020
67957a4
glue fix, jenkins updated
ekmb Feb 14, 2020
0f49b0c
revert jenkins
ekmb Feb 14, 2020
97eeb39
Removed CrossEntropy3D.
VahidooX Feb 14, 2020
346fc52
Merge remote-tracking branch 'remote/nlp_refactoring_stage2' into nlp…
VahidooX Feb 14, 2020
0a08d27
Fixed bug.
VahidooX Feb 14, 2020
0400ce4
Fixed style.
VahidooX Feb 14, 2020
9ead68e
wikitext down fixed, jenkins update
ekmb Feb 15, 2020
3929d46
Loss names updated.
VahidooX Feb 15, 2020
51b36c4
Merge branch 'nlp_refactoring_stage2' of https://github.com/NVIDIA/Ne…
VahidooX Feb 15, 2020
c9a29dc
Fixed bugs.
VahidooX Feb 15, 2020
98578d1
Fixed bugs.
VahidooX Feb 15, 2020
ee5cc09
jenkins
ekmb Feb 15, 2020
52ee8e7
Merge branch 'nlp_refactoring_stage2' of https://github.com/NVIDIA/Ne…
ekmb Feb 15, 2020
f5478b0
Fixed targets_id bug.
VahidooX Feb 15, 2020
c70e908
Merge remote-tracking branch 'remote/nlp_refactoring_stage2' into nlp…
VahidooX Feb 15, 2020
87bf3a9
jenkins update
ekmb Feb 15, 2020
54411ea
jenkins update
ekmb Feb 15, 2020
1749d96
jenkins update
ekmb Feb 15, 2020
aa51eb1
Bug fixed.
VahidooX Feb 15, 2020
d2b8563
Merge remote-tracking branch 'remote/nlp_refactoring_stage2' into nlp…
VahidooX Feb 15, 2020
5105624
Bug fixed.
VahidooX Feb 15, 2020
5bedbeb
DataProcessor class moved to datasets_utils/
ekmb Feb 18, 2020
530b5c0
Bug fixed.
VahidooX Feb 18, 2020
45147ff
Style fixed.
VahidooX Feb 18, 2020
33b5a03
Fixed bugs.
VahidooX Feb 18, 2020
9d35193
Fixed bugs.
VahidooX Feb 18, 2020
491282b
Added boolmasktype.
VahidooX Feb 18, 2020
1446195
Fixed style.
VahidooX Feb 18, 2020
9bd942c
Fixed style.
VahidooX Feb 18, 2020
77f8225
Remove vocab file in jenkins.
VahidooX Feb 18, 2020
4cd21e8
Added boolmasktype.
VahidooX Feb 18, 2020
a7d47b0
Style fixed.
VahidooX Feb 18, 2020
b5ccac4
bug fixed.
VahidooX Feb 18, 2020
a82f248
Added optional port to smoothed_cross_entropy_loss
VahidooX Feb 18, 2020
848c7fc
Added MaskType
VahidooX Feb 19, 2020
bc115e6
Replaced TokenClassificationLoss with crossentropyloss.
VahidooX Feb 19, 2020
ce77ce7
Replaced TokenClassificationLoss with crossentropyloss.
VahidooX Feb 19, 2020
9446bee
Replaced TokenClassificationLoss with crossentropyloss.
VahidooX Feb 19, 2020
6fb9b75
Style fixed.
VahidooX Feb 19, 2020
a6904e8
Removed tokenclassificationloss
VahidooX Feb 19, 2020
eaa6f5b
Added back logits_dim to CrossEntropyLoss.
VahidooX Feb 19, 2020
52c30ac
Fixed bug.
VahidooX Feb 19, 2020
c8978cc
Fixed bug.
VahidooX Feb 19, 2020
b21ff0e
Fixed bug.
VahidooX Feb 19, 2020
0f1845a
Fixed bug.
VahidooX Feb 19, 2020
e02b5f3
Fixed bug.
VahidooX Feb 19, 2020
275e1dc
Fixed bug.
VahidooX Feb 19, 2020
c56b2e0
Fixed bug.
VahidooX Feb 19, 2020
52f1c1e
Fixed bug.
VahidooX Feb 19, 2020
9f37664
Fixed bug.
VahidooX Feb 19, 2020
018b3a2
Merge remote-tracking branch 'remote/master' into nlp_refactoring_stage2
VahidooX Feb 19, 2020
29962af
Merge remote-tracking branch 'remote/nlp_refactoring_stage2' into nlp…
VahidooX Feb 19, 2020
043a935
Changed crossentropyloss to crossentropylossnm
VahidooX Feb 19, 2020
213dbb2
refactored and added doc string to data_layers folder
yzhang123 Feb 19, 2020
8414fd4
Merge branch 'nlp_refactoring_stage2' of https://github.com/NVIDIA/Ne…
yzhang123 Feb 19, 2020
b0800d0
Bug fixed.
VahidooX Feb 19, 2020
b85291b
Bug fixed.
VahidooX Feb 19, 2020
062a33c
added doc string to trainables folder file
yzhang123 Feb 19, 2020
e2cd27c
Merge branch 'nlp_refactoring_stage2' of https://github.com/NVIDIA/Ne…
yzhang123 Feb 19, 2020
be55f15
adding header to squad dataset
yzhang123 Feb 19, 2020
83b39e7
refactored nm folder utils
yzhang123 Feb 19, 2020
c1a76e7
wkt refactored
ekmb Feb 19, 2020
cfd7c7d
Merge remote-tracking branch 'origin/master' into nlp_refactoring_stage2
yzhang123 Feb 19, 2020
01cd1f9
black
ekmb Feb 19, 2020
ba954a4
Merge branch 'nlp_refactoring_stage2' of https://github.com/NVIDIA/Ne…
yzhang123 Feb 19, 2020
6ca7468
docstrings added
ekmb Feb 19, 2020
6922388
fix bug
yzhang123 Feb 19, 2020
23d6e93
Bug fixed.
VahidooX Feb 19, 2020
fba882a
Bug fixed.
VahidooX Feb 20, 2020
3929f29
Merge remote-tracking branch 'remote/master' into nlp_refactoring_stage2
VahidooX Feb 20, 2020
ffa0582
Bug fixed.
VahidooX Feb 20, 2020
c9ed93c
Bug fixed.
VahidooX Feb 20, 2020
b5f1e49
Added weights to lossaggregator.
VahidooX Feb 20, 2020
6855ca9
Added lossaggregator to jointintentsloss.
VahidooX Feb 20, 2020
6f06717
Added lossaggregator to jointintentsloss.
VahidooX Feb 20, 2020
b657b2d
Added lossaggregator to jointintentsloss.
VahidooX Feb 20, 2020
4013c38
Style fixed.
VahidooX Feb 20, 2020
5e8ec48
Style fixed.
VahidooX Feb 20, 2020
6c4e92c
shuffle removed, lgtm fix
ekmb Feb 20, 2020
e031405
Fixed lgtm warnings.
VahidooX Feb 20, 2020
393705d
Merge remote-tracking branch 'remote/nlp_refactoring_stage2' into nlp…
VahidooX Feb 20, 2020
1071ce7
lgtm fix
ekmb Feb 20, 2020
3a6c370
Merge remote-tracking branch 'origin/master' into nlp_refactoring_stage2
yzhang123 Feb 20, 2020
0df6b45
Moved lossaggregator to backend.common.
VahidooX Feb 20, 2020
f3c2a62
Merge branch 'nlp_refactoring_stage2' of https://github.com/NVIDIA/Ne…
ekmb Feb 20, 2020
208b1e7
changelog updated
ekmb Feb 20, 2020
32c8c31
Merge branch 'nlp_refactoring_stage2' of https://github.com/NVIDIA/Ne…
ekmb Feb 20, 2020
b7c4c34
bug fixed
ekmb Feb 20, 2020
aa3f460
jenkins fixed to use the same bert in all tests
ekmb Feb 20, 2020
607dbc4
Fixed comments.
VahidooX Feb 20, 2020
4cbda8a
jenkins bert default fixed
ekmb Feb 20, 2020
5a0583f
Fixed comments.
VahidooX Feb 20, 2020
5892150
merged with master
ekmb Feb 20, 2020
de666cb
deleted unused losses
ekmb Feb 20, 2020
ca8af50
import fix
ekmb Feb 20, 2020
7914f0d
Fixed comments.
VahidooX Feb 20, 2020
9caa0c6
Merge remote-tracking branch 'remote/nlp_refactoring_stage2' into nlp…
VahidooX Feb 20, 2020
d850f34
bug fixed
ekmb Feb 20, 2020
88a95b5
Merge branch 'nlp_refactoring_stage2' of https://github.com/NVIDIA/Ne…
ekmb Feb 20, 2020
963ea89
Resolved reviews.
VahidooX Feb 20, 2020
33ac3df
import fixed
ekmb Feb 20, 2020
ac9f023
Merge remote-tracking branch 'origin/master' into nlp_refactoring_stage2
ekmb Feb 20, 2020
75271bd
Updated comments.
VahidooX Feb 20, 2020
47b8f5c
Merge branch 'nlp_refactoring_stage2' of https://github.com/NVIDIA/Ne…
VahidooX Feb 20, 2020
bf2d7b2
licences updated
ekmb Feb 20, 2020
56b6ea1
Fixed lgtm warnings.
VahidooX Feb 20, 2020
5745b56
Fixed lgtm warnings.
VahidooX Feb 20, 2020
6afa104
Fixed bug.
VahidooX Feb 20, 2020
11 changes: 7 additions & 4 deletions CHANGELOG.md
@@ -85,6 +85,13 @@ files, along with unit tests, examples and tutorials
([PR #375](https://github.com/NVIDIA/NeMo/pull/375)) - @titu1994

### Changed
- Refactoring of `nemo_nlp` collections:
([PR #368](https://github.com/NVIDIA/NeMo/pull/368)) - @VahidooX, @yzhang123, @ekmb
- renaming and restructuring of files, folders, and functions in `nemo_nlp`
- losses cleaned up. LossAggregatorNM moved to nemo/backends/pytorch/common/losses
([PR #316](https://github.com/NVIDIA/NeMo/pull/316)) - @VahidooX, @yzhang123, @ekmb
- renaming and restructuring of files, folders, and functions in `nemo_nlp`
- Updated licenses
- All collections changed to use New Neural Type System.
([PR #307](https://github.com/NVIDIA/NeMo/pull/307)) - @okuchaiev
- Additional Collections Repositories merged into core `nemo_toolkit` package.
@@ -95,10 +102,6 @@ files, along with unit tests, examples and tutorials
([PR #286](https://github.com/NVIDIA/NeMo/pull/286)) - @stasbel
- Major cleanup of Neural Module constructors (init), aiming at increasing the framework robustness: cleanup of NeuralModule initialization logic, refactor of trainer/actions (getting rid of local_params), fixes of several examples and unit tests, extraction and storing of initial parameters (init_params).
([PR #309](https://github.com/NVIDIA/NeMo/pull/309)) - @tkornuta-nvidia
- Refactoring of `nemo_nlp` collections:
([PR #316](https://github.com/NVIDIA/NeMo/pull/316)) - @VahidooX, @yzhang123, @ekmb
- renaming of files and restructuring of folder in `nemo_nlp`
- Updated licenses
- Updated nemo's use of the logging library. from nemo import logging is now the recommended way of using the nemo logger. neural_factory.logger and all other instances of logger are now deprecated and planned for removal in the next version. Please see PR 267 for complete change information.
([PR #267](https://github.com/NVIDIA/NeMo/pull/267), [PR #283](https://github.com/NVIDIA/NeMo/pull/283), [PR #305](https://github.com/NVIDIA/NeMo/pull/305), [PR #311](https://github.com/NVIDIA/NeMo/pull/311)) - @blisc
- Changed Distributed Data Parallel from Apex to Torch
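As a rough illustration of the loss cleanup noted in the changelog entry above, LossAggregatorNM is now imported from the shared PyTorch backend instead of nemo_nlp.nm.losses. A minimal sketch of the new import path, based only on the paths shown in this diff and not on the PR's own documentation:

from nemo.backends.pytorch.common.losses import LossAggregatorNM

# Combines several loss tensors into a single training objective.
total_loss_fn = LossAggregatorNM(num_inputs=2)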
10 changes: 6 additions & 4 deletions Jenkinsfile
@@ -91,19 +91,21 @@ pipeline {
parallel {
stage ('Text Classification with BERT Test') {
steps {
sh 'cd examples/nlp/text_classification && CUDA_VISIBLE_DEVICES=0 python text_classification_with_bert.py --num_epochs=1 --max_seq_length=50 --dataset_name=jarvis --data_dir=/home/TestData/nlp/retail/ --eval_file_prefix=eval --batch_size=10 --num_train_samples=-1 --do_lower_case --shuffle_data --work_dir=outputs'
sh 'cd examples/nlp/text_classification && CUDA_VISIBLE_DEVICES=0 python text_classification_with_bert.py --pretrained_bert_model bert-base-uncased --num_epochs=1 --max_seq_length=50 --dataset_name=jarvis --data_dir=/home/TestData/nlp/retail/ --eval_file_prefix=eval --batch_size=10 --num_train_samples=-1 --do_lower_case --shuffle_data --work_dir=outputs'
sh 'rm -rf examples/nlp/text_classification/outputs'
}
}
stage ('Dialogue State Tracking - TRADE - Multi-GPUs') {
steps {
sh 'rm -rf /home/TestData/nlp/multiwoz2.1/vocab.pkl'
sh 'cd examples/nlp/dialogue_state_tracking && CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 dialogue_state_tracking_trade.py --batch_size=10 --eval_batch_size=10 --num_train_samples=-1 --num_eval_samples=-1 --num_epochs=1 --dropout=0.2 --eval_file_prefix=test --shuffle_data --num_gpus=2 --lr=0.001 --grad_norm_clip=10 --work_dir=outputs --data_dir=/home/TestData/nlp/multiwoz2.1'
sh 'rm -rf examples/nlp/dialogue_state_tracking/outputs'
sh 'rm -rf /home/TestData/nlp/multiwoz2.1/vocab.pkl'
}
}
stage ('GLUE Benchmark Test') {
steps {
sh 'cd examples/nlp/glue_benchmark && CUDA_VISIBLE_DEVICES=1 python glue_benchmark_with_bert.py --data_dir /home/TestData/nlp/glue_fake/MRPC --work_dir glue_output --save_step_freq -1 --num_epochs 1 --task_name mrpc --batch_size 2'
sh 'cd examples/nlp/glue_benchmark && CUDA_VISIBLE_DEVICES=1 python glue_benchmark_with_bert.py --data_dir /home/TestData/nlp/glue_fake/MRPC --pretrained_bert_model bert-base-uncased --work_dir glue_output --save_step_freq -1 --num_epochs 1 --task_name mrpc --batch_size 2'
sh 'rm -rf examples/nlp/glue_benchmark/glue_output'
}
}
@@ -122,8 +124,8 @@
parallel {
stage('Token Classification Training/Inference Test') {
steps {
sh 'cd examples/nlp/token_classification && CUDA_VISIBLE_DEVICES=0 python token_classification.py --data_dir /home/TestData/nlp/token_classification_punctuation/ --batch_size 2 --num_epochs 1 --save_epoch_freq 1 --work_dir token_classification_output --pretrained_bert_model bert-base-cased'
sh 'cd examples/nlp/token_classification && DATE_F=$(ls token_classification_output/) && CUDA_VISIBLE_DEVICES=0 python token_classification_infer.py --work_dir token_classification_output/$DATE_F/checkpoints/ --labels_dict /home/TestData/nlp/token_classification_punctuation/label_ids.csv --pretrained_bert_model bert-base-cased'
sh 'cd examples/nlp/token_classification && CUDA_VISIBLE_DEVICES=0 python token_classification.py --data_dir /home/TestData/nlp/token_classification_punctuation/ --batch_size 2 --num_epochs 1 --save_epoch_freq 1 --work_dir token_classification_output --pretrained_bert_model bert-base-uncased'
sh 'cd examples/nlp/token_classification && DATE_F=$(ls token_classification_output/) && CUDA_VISIBLE_DEVICES=0 python token_classification_infer.py --work_dir token_classification_output/$DATE_F/checkpoints/ --labels_dict /home/TestData/nlp/token_classification_punctuation/label_ids.csv --pretrained_bert_model bert-base-uncased'
sh 'rm -rf examples/nlp/token_classification/token_classification_output'
}
}
2 changes: 1 addition & 1 deletion docs/sources/source/nlp/bert_pretraining.rst
@@ -191,7 +191,7 @@ For training from raw text use nemo_nlp.BertPretrainingDataLayer, for preprocessed

mlm_logits = mlm_classifier(hidden_states=hidden_states)
mlm_loss = mlm_loss_fn(logits=mlm_logits,
output_ids=input_data.output_ids,
labels=input_data.output_ids,
output_mask=input_data.output_mask)

nsp_logits = nsp_classifier(hidden_states=hidden_states)
4 changes: 2 additions & 2 deletions docs/sources/source/nlp/joint_intent_slot_filling.rst
@@ -3,9 +3,9 @@ Tutorial

In this tutorial, we are going to implement a joint intent and slot filling system with a pretrained BERT model based on
`BERT for Joint Intent Classification and Slot Filling <https://arxiv.org/abs/1902.10909>`_ :cite:`nlp-slot-chen2019bert`.
All code used in this tutorial is based on ``examples/nlp/joint_intent_slot_with_bert.py``.
All code used in this tutorial is based on ``examples/nlp/intent_detection_slot_tagging/joint_intent_slot_with_bert.py``.

There are four pre-trained BERT models that we can select from using the argument `--pretrained_bert_model`. We're currently
There are a variety of pre-trained BERT models that we can select from using the argument `--pretrained_bert_model`. We're currently
using the script for loading pre-trained models from `pytorch_transformers`. See the list of available pre-trained models
`here <https://huggingface.co/pytorch-transformers/pretrained_models.html>`__.
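A minimal sketch of how the `--pretrained_bert_model` value described above is consumed, using the module paths that appear in the refactored example scripts in this PR (the model name is only an example):

import nemo.collections.nlp as nemo_nlp

# Any name from the pytorch_transformers list linked above should work here.
bert = nemo_nlp.nm.trainables.huggingface.BERT(pretrained_model_name='bert-base-uncased')
hidden_size = bert.hidden_size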

4 changes: 2 additions & 2 deletions examples/nlp/asr_postprocessor/asr_postprocessor.py
@@ -113,7 +113,7 @@
args.d_model, num_classes=vocab_size, num_layers=1, log_softmax=True
)

loss_fn = nemo_nlp.nm.losses.PaddedSmoothedCrossEntropyLossNM(pad_id=tokenizer.pad_id, label_smoothing=0.1)
loss_fn = nemo_nlp.nm.losses.SmoothedCrossEntropyLoss(pad_id=tokenizer.pad_id, label_smoothing=0.1)

beam_search = nemo_nlp.nm.trainables.BeamSearchTranslatorNM(
decoder=decoder,
@@ -174,7 +174,7 @@ def create_pipeline(dataset, tokens_in_batch, clean=False, training=True):
input_ids_tgt=tgt, hidden_states_src=src_hiddens, input_mask_src=src_mask, input_mask_tgt=tgt_mask
)
log_softmax = t_log_softmax(hidden_states=tgt_hiddens)
loss = loss_fn(logits=log_softmax, target_ids=labels)
loss = loss_fn(logits=log_softmax, labels=labels)
beam_results = None
if not training:
beam_results = beam_search(hidden_states_src=src_hiddens, input_mask_src=src_mask)
@@ -25,12 +25,14 @@

import numpy as np

import nemo.backends.pytorch as nemo_backend
import nemo.backends.pytorch.common.losses
import nemo.collections.nlp as nemo_nlp
import nemo.core as nemo_core
from nemo import logging
from nemo.backends.pytorch.common import EncoderRNN
from nemo.collections.nlp.callbacks.state_tracking_trade_callback import eval_epochs_done_callback, eval_iter_callback
from nemo.collections.nlp.data.datasets.state_tracking_trade_dataset import MultiWOZDataDesc
from nemo.collections.nlp.data.datasets.multiwoz_dataset import MultiWOZDataDesc
from nemo.utils.lr_policies import get_lr_policy

parser = argparse.ArgumentParser(description='Dialog state tracking with TRADE model on MultiWOZ dataset')
@@ -97,9 +99,9 @@
teacher_forcing=args.teacher_forcing,
)

gate_loss_fn = nemo_nlp.nm.losses.CrossEntropyLoss3D(num_classes=len(data_desc.gating_dict))
ptr_loss_fn = nemo_nlp.nm.losses.TRADEMaskedCrossEntropy()
total_loss_fn = nemo_nlp.nm.losses.LossAggregatorNM(num_inputs=2)
gate_loss_fn = nemo_backend.losses.CrossEntropyLossNM(logits_dim=3)
ptr_loss_fn = nemo_nlp.nm.losses.MaskedXEntropyLoss()
total_loss_fn = nemo.backends.pytorch.common.losses.LossAggregatorNM(num_inputs=2)


def create_pipeline(num_samples, batch_size, num_gpus, input_dropout, data_prefix, is_training):
@@ -142,7 +144,7 @@ def create_pipeline(num_samples, batch_size, num_gpus, input_dropout, data_prefi
)

gate_loss = gate_loss_fn(logits=gate_outputs, labels=gate_labels)
ptr_loss = ptr_loss_fn(logits=point_outputs, targets=tgt_ids, loss_mask=tgt_lens)
ptr_loss = ptr_loss_fn(logits=point_outputs, labels=tgt_ids, length_mask=tgt_lens)
total_loss = total_loss_fn(loss_1=gate_loss, loss_2=ptr_loss)

if is_training:
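The gating loss above illustrates the pattern that replaces the removed CrossEntropyLoss3D: one generic CrossEntropyLossNM whose logits_dim argument (presumably the rank of the logits tensor) covers both sentence-level and per-step classification. A hedged sketch of the two configurations used in this PR:

from nemo.backends.pytorch.common.losses import CrossEntropyLossNM

# logits shaped [batch, num_classes], e.g. one intent label per example
sentence_loss_fn = CrossEntropyLossNM(logits_dim=2)
# logits shaped [batch, steps, num_classes], e.g. the per-slot gating head above
step_loss_fn = CrossEntropyLossNM(logits_dim=3)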
8 changes: 3 additions & 5 deletions examples/nlp/glue_benchmark/glue_benchmark_with_bert.py
@@ -68,7 +68,7 @@
import nemo.collections.nlp as nemo_nlp
import nemo.core as nemo_core
from nemo import logging
from nemo.backends.pytorch.common import CrossEntropyLoss, MSELoss
from nemo.backends.pytorch.common import CrossEntropyLossNM, MSELoss
from nemo.collections.nlp.callbacks.glue_benchmark_callback import eval_epochs_done_callback, eval_iter_callback
from nemo.collections.nlp.data import NemoBertTokenizer, SentencePieceTokenizer
from nemo.collections.nlp.data.datasets.glue_benchmark_dataset import output_modes, processors
@@ -231,7 +231,7 @@
else:
model = nemo_nlp.nm.trainables.huggingface.BERT(pretrained_model_name=args.pretrained_bert_model)
model.restore_from(args.bert_checkpoint)
logging.info(f"model resotred from {args.bert_checkpoint}")
logging.info(f"model restored from {args.bert_checkpoint}")

hidden_size = model.hidden_size

@@ -241,7 +241,7 @@
glue_loss = MSELoss()
else:
pooler = SequenceClassifier(hidden_size=hidden_size, num_classes=num_labels, log_softmax=False)
glue_loss = CrossEntropyLoss()
glue_loss = CrossEntropyLossNM()


def create_pipeline(
@@ -260,8 +260,6 @@ def create_pipeline(
processor=processor,
evaluate=evaluate,
batch_size=batch_size,
# num_workers=0,
# local_rank=local_rank,
tokenizer=tokenizer,
data_dir=args.data_dir,
max_seq_length=max_seq_length,
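For context, the hunk above keeps the existing regression/classification split and only renames the classification loss to CrossEntropyLossNM. A condensed sketch of that choice as a helper; the output_mode string is an assumption based on the output_modes import shown above, not code copied from the PR:

from nemo.backends.pytorch.common import CrossEntropyLossNM, MSELoss

def pick_glue_loss(output_mode: str):
    # STS-B is the GLUE regression task; the remaining tasks are classification.
    return MSELoss() if output_mode == 'regression' else CrossEntropyLossNM()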
@@ -23,7 +23,7 @@

import nemo.collections.nlp.nm.trainables.joint_intent_slot.joint_intent_slot_nm
from nemo import logging
from nemo.collections.nlp.data.datasets.joint_intent_slot_dataset import JointIntentSlotDataDesc
from nemo.collections.nlp.data.datasets.joint_intent_slot_dataset.data_descriptor import JointIntentSlotDataDesc

# Parsing arguments
parser = argparse.ArgumentParser(description='Joint-intent BERT')
@@ -21,8 +21,8 @@

import nemo.collections.nlp as nemo_nlp
import nemo.collections.nlp.nm.trainables.joint_intent_slot.joint_intent_slot_nm
from nemo.collections.nlp.data.datasets.joint_intent_slot_dataset import JointIntentSlotDataDesc
from nemo.collections.nlp.utils.common_nlp_utils import read_intent_slot_outputs
from nemo.collections.nlp.data.datasets.joint_intent_slot_dataset.data_descriptor import JointIntentSlotDataDesc
from nemo.collections.nlp.data.datasets.joint_intent_slot_dataset.inference_utils import read_intent_slot_outputs

# Parsing arguments
parser = argparse.ArgumentParser(description='Joint-intent BERT')
@@ -21,16 +21,18 @@
import numpy as np
from transformers import BertTokenizer

import nemo
import nemo.collections.nlp as nemo_nlp
import nemo.collections.nlp.nm.data_layers.joint_intent_slot_datalayer
import nemo.collections.nlp.nm.trainables.joint_intent_slot.joint_intent_slot_nm
from nemo import logging
from nemo.backends.pytorch.common.losses import CrossEntropyLossNM, LossAggregatorNM
from nemo.collections.nlp.callbacks.joint_intent_slot_callback import eval_epochs_done_callback, eval_iter_callback
from nemo.collections.nlp.data.datasets.joint_intent_slot_dataset import JointIntentSlotDataDesc
from nemo.collections.nlp.data.datasets.joint_intent_slot_dataset.data_descriptor import JointIntentSlotDataDesc
from nemo.collections.nlp.nm.data_layers import BertJointIntentSlotDataLayer
from nemo.core import CheckpointCallback, SimpleLossLoggerCallback
from nemo.utils.lr_policies import get_lr_policy

# Parsing arguments
parser = argparse.ArgumentParser(description='Joint intent slot filling system with pretrained BERT')
parser = argparse.ArgumentParser(description='Joint intent detection and slot filling with pre-trained BERT')
parser.add_argument("--local_rank", default=None, type=int)
parser.add_argument("--batch_size", default=128, type=int)
parser.add_argument("--max_seq_length", default=50, type=int)
@@ -87,12 +89,10 @@
nemo_nlp.huggingface.BERT.list_pretrained_models()
"""
if args.bert_checkpoint and args.bert_config:
pretrained_bert_model = nemo.collections.nlp.nm.trainables.huggingface.BERT(config_filename=args.bert_config)
pretrained_bert_model = nemo_nlp.nm.trainables.huggingface.BERT(config_filename=args.bert_config)
pretrained_bert_model.restore_from(args.bert_checkpoint)
else:
pretrained_bert_model = nemo.collections.nlp.nm.trainables.huggingface.BERT(
pretrained_model_name=args.pretrained_bert_model
)
pretrained_bert_model = nemo_nlp.nm.trainables.huggingface.BERT(pretrained_model_name=args.pretrained_bert_model)

hidden_size = pretrained_bert_model.hidden_size

@@ -101,31 +101,29 @@
)

# Create sentence classification loss on top
classifier = nemo.collections.nlp.nm.trainables.joint_intent_slot.joint_intent_slot_nm.JointIntentSlotClassifier(
classifier = nemo_nlp.nm.trainables.JointIntentSlotClassifier(
hidden_size=hidden_size, num_intents=data_desc.num_intents, num_slots=data_desc.num_slots, dropout=args.fc_dropout
)

if args.class_balancing == 'weighted_loss':
# Using weighted loss will enable weighted loss for both intents and slots
# Use the intent_loss_weight hyperparameter to adjust intent loss to
# prevent overfitting or underfitting.
loss_fn = nemo_nlp.nm.losses.JointIntentSlotLoss(
num_slots=data_desc.num_slots,
slot_classes_loss_weights=data_desc.slot_weights,
intent_classes_loss_weights=data_desc.intent_weights,
intent_loss_weight=args.intent_loss_weight,
)
# To tackle imbalanced classes, you may use weighted loss
intent_loss_fn = CrossEntropyLossNM(logits_dim=2, weight=data_desc.intent_weights)
slot_loss_fn = CrossEntropyLossNM(logits_dim=3, weight=data_desc.slot_weights)

else:
loss_fn = nemo_nlp.nm.losses.JointIntentSlotLoss(num_slots=data_desc.num_slots)
intent_loss_fn = CrossEntropyLossNM(logits_dim=2)
slot_loss_fn = CrossEntropyLossNM(logits_dim=3)

total_loss_fn = LossAggregatorNM(num_inputs=2, weights=[args.intent_loss_weight, 1.0 - args.intent_loss_weight])


def create_pipeline(num_samples=-1, batch_size=32, num_gpus=1, local_rank=0, mode='train'):
def create_pipeline(num_samples=-1, batch_size=32, num_gpus=1, mode='train'):
logging.info(f"Loading {mode} data...")
data_file = f'{data_desc.data_dir}/{mode}.tsv'
slot_file = f'{data_desc.data_dir}/{mode}_slots.tsv'
shuffle = args.shuffle_data if mode == 'train' else False

data_layer = nemo.collections.nlp.nm.data_layers.joint_intent_slot_datalayer.BertJointIntentSlotDataLayer(
data_layer = BertJointIntentSlotDataLayer(
input_file=data_file,
slot_file=slot_file,
pad_label=data_desc.pad_label,
@@ -155,35 +153,27 @@ def create_pipeline(num_samples=-1, batch_size=32, num_gpus=1, local_rank=0, mod

intent_logits, slot_logits = classifier(hidden_states=hidden_states)

loss = loss_fn(
intent_logits=intent_logits, slot_logits=slot_logits, loss_mask=loss_mask, intents=intents, slots=slots
)
intent_loss = intent_loss_fn(logits=intent_logits, labels=intents)
slot_loss = slot_loss_fn(logits=slot_logits, labels=slots, loss_mask=loss_mask)
total_loss = total_loss_fn(loss_1=intent_loss, loss_2=slot_loss)

if mode == 'train':
tensors_to_evaluate = [loss, intent_logits, slot_logits]
tensors_to_evaluate = [total_loss, intent_logits, slot_logits]
else:
tensors_to_evaluate = [intent_logits, slot_logits, intents, slots, subtokens_mask]

return tensors_to_evaluate, loss, steps_per_epoch, data_layer
return tensors_to_evaluate, total_loss, steps_per_epoch, data_layer


train_tensors, train_loss, steps_per_epoch, _ = create_pipeline(
args.num_train_samples,
batch_size=args.batch_size,
num_gpus=args.num_gpus,
local_rank=args.local_rank,
mode=args.train_file_prefix,
args.num_train_samples, batch_size=args.batch_size, num_gpus=args.num_gpus, mode=args.train_file_prefix,
)
eval_tensors, _, _, data_layer = create_pipeline(
args.num_eval_samples,
batch_size=args.batch_size,
num_gpus=args.num_gpus,
local_rank=args.local_rank,
mode=args.eval_file_prefix,
args.num_eval_samples, batch_size=args.batch_size, num_gpus=args.num_gpus, mode=args.eval_file_prefix,
)

# Create callbacks for train and eval modes
train_callback = nemo.core.SimpleLossLoggerCallback(
train_callback = SimpleLossLoggerCallback(
tensors=train_tensors,
print_func=lambda x: str(np.round(x[0].item(), 3)),
tb_writer=nf.tb_writer,
@@ -200,7 +190,7 @@ def create_pipeline(num_samples=-1, batch_size=32, num_gpus=1, local_rank=0, mod
)

# Create callback to save checkpoints
ckpt_callback = nemo.core.CheckpointCallback(
ckpt_callback = CheckpointCallback(
folder=nf.checkpoint_dir, epoch_freq=args.save_epoch_freq, step_freq=args.save_step_freq
)

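The weighted aggregator used in this example is new in this PR (commit b5f1e49, "Added weights to lossaggregator"). Assuming it computes a plain weighted sum of its inputs, which the diff does not show, the intent/slot trade-off controlled by --intent_loss_weight behaves roughly like this toy sketch:

# Assumed semantics of LossAggregatorNM(num_inputs=2, weights=[w, 1 - w]).
def aggregate(loss_1: float, loss_2: float, w: float) -> float:
    return w * loss_1 + (1.0 - w) * loss_2

# e.g. intent loss 0.9, slot loss 0.3 with --intent_loss_weight=0.6 -> 0.66
print(aggregate(0.9, 0.3, 0.6))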