
Commit 093e57d

Update transformer version and add docs and code for using BERT as LM (#98)

* Update transformer version and add docs and code for using BERT as LM

Add docs for using BERT to predict the next sentence, and code and docs for next sentence prediction. Update the transformers.md docs.
AmaliePauli committed Nov 25, 2020
1 parent bdf0dcc commit 093e57d
Showing 4 changed files with 167 additions and 18 deletions.
67 changes: 64 additions & 3 deletions danlp/models/bert_models.py
@@ -369,11 +369,11 @@ def __init__(self, cache_dir=DEFAULT_CACHE_DIR, verbose=False):
        from transformers import BertTokenizer, BertModel
        import torch
        # download model
-        path = download_model('bert.botxo.pytorch', cache_dir, process_func=_unzip_process_func, verbose=verbose)
+        self.path_model = download_model('bert.botxo.pytorch', cache_dir, process_func=_unzip_process_func, verbose=verbose)
        # Load pre-trained model tokenizer
-        self.tokenizer = BertTokenizer.from_pretrained(path)
+        self.tokenizer = BertTokenizer.from_pretrained(self.path_model)
        # Load pre-trained model (weights)
-        self.model = BertModel.from_pretrained(path,
+        self.model = BertModel.from_pretrained(self.path_model,
                                               output_hidden_states = True, # Whether the model returns all hidden-states.
                                               )

@@ -433,6 +433,67 @@ def embed_text(self, text):

        return token_vecs_cat, sentence_embedding, tokenized_text

class BertNextSent:
    '''
    BERT language model trained for next sentence prediction.
    The model is trained by BotXO (https://github.com/botxo/nordic_bert)
    and has been converted to a PyTorch version.
    Credit for the code example: https://stackoverflow.com/questions/55111360/using-bert-for-next-sentence-prediction

    :param str cache_dir: the directory for storing cached models
    :param bool verbose: `True` to increase verbosity
    '''

    def __init__(self, cache_dir=DEFAULT_CACHE_DIR, verbose=False):

        from transformers import BertForNextSentencePrediction, BertTokenizer

        # download the model
        self.path_model = download_model('bert.botxo.pytorch', cache_dir,
                                         process_func=_unzip_process_func, verbose=verbose)
        # load the pre-trained model tokenizer
        self.tokenizer = BertTokenizer.from_pretrained(self.path_model)
        # load the pre-trained model (weights)
        self.model = BertForNextSentencePrediction.from_pretrained(
            self.path_model,
            output_hidden_states=True,  # whether the model returns all hidden-states
        )

    def predict_if_next_sent(self, sent_A: str, sent_B: str):
        """
        Calculate the probability that sentence B follows sentence A.
        Credit for the code example: https://stackoverflow.com/questions/55111360/using-bert-for-next-sentence-prediction

        :param str sent_A: sentence A
        :param str sent_B: sentence B
        :return: the probability of sentence B following sentence A
        :rtype: float
        """
        from torch.nn.functional import softmax

        # the two sentences are encoded as one input to the model by passing 'sent_B' as the 'text_pair'
        encoded = self.tokenizer.encode_plus(sent_A, text_pair=sent_B, return_tensors='pt')

        # the model's output is a tuple; we only need the output tensor containing
        # the relationship logits, which is the first item in the tuple
        seq_relationship_logits = self.model(**encoded)[0]

        # softmax converts the logits into probabilities
        # index 0: sequence B is a continuation of sequence A
        # index 1: sequence B is a random sequence
        probs = softmax(seq_relationship_logits, dim=1)

        # return the probability of sentence B following sentence A
        return round(float(probs[0][0]), 4)

def load_bert_nextsent_model(cache_dir=DEFAULT_CACHE_DIR, verbose=False):
    """
    Load the BERT language model used for next sentence prediction.
    The model is trained by BotXO: https://github.com/botxo/nordic_bert

    :param str cache_dir: the directory for storing cached models
    :param bool verbose: `True` to increase verbosity
    :return: a BERT NextSent model
    """

    return BertNextSent(cache_dir, verbose)

def load_bert_base_model(cache_dir=DEFAULT_CACHE_DIR, verbose=False):
"""
89 changes: 83 additions & 6 deletions docs/docs/frameworks/transformers.md
@@ -3,24 +3,96 @@ Transformers

BERT (Bidirectional Encoder Representations from Transformers) [(Devlin et al. 2019)](https://www.aclweb.org/anthology/N19-1423/) is a deep neural network model used in Natural Language Processing.

The BERT models provided with DaNLP are based on the pre-trained [Danish BERT](https://github.com/botxo/nordic_bert) representations by BotXO, and different models have been finetuned on different tasks using the [Transformers](https://github.com/huggingface/transformers) library from HuggingFace.

Through DaNLP, we provide fine-tuned BERT models for the following tasks:

* Named Entity Recognition
* Emotion detection
* Tone and polarity detection

The pre-trained [Danish BERT](https://github.com/botxo/nordic_bert) from BotXO can also be used for the following tasks without any further finetuning:

- Embeddings of tokens or sentences
- Predicting a masked word in a sentence
- Predicting whether one sentence naturally follows another sentence

Please note that the BERT models can take a maximum of 512 tokens as input at a time. Longer text sequences should therefore be split beforehand, for example by using sentence boundary detection (e.g. with the [spaCy model](spacy.md)).
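As an illustration, here is a minimal sketch of such pre-splitting with the DaNLP spaCy model (the example text is our own):

```python
from danlp.models import load_spacy_model

# split a long text into sentences before passing them to BERT one at a time
nlp = load_spacy_model()
doc = nlp("Her er den første sætning. Og her kommer den anden sætning.")
sentences = [sent.text for sent in doc.sents]
```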

See our [getting started guides](../gettingstarted/quickstart.md#bert) for examples of how to use the BERT models.
### Language model, embeddings and next sentence prediction

The BERT model [(Devlin et al. 2019)](https://www.aclweb.org/anthology/N19-1423/) is originally pretrained on two tasks: predicting a masked word in a sentence, and predicting whether one sentence follows another. The model can therefore be used for these two tasks without any further finetuning.

A PyTorch version of the [Danish BERT](https://github.com/botxo/nordic_bert) trained by BotXO can thus be loaded with the DaNLP package and used through the [Transformers](https://github.com/huggingface/transformers) library.

For **predicting a masked word** in a sentence, you can, after downloading the model through DaNLP, use the Transformers library directly, as in the following snippet:

```python
from transformers import pipeline

from danlp.models import load_bert_base_model

# load the BERT model
model = load_bert_base_model()
# use the built-in fill-mask pipeline from the Transformers library
LM = pipeline("fill-mask", model=model.path_model)
# use the model as a language model to predict masked words in a sentence
LM(f"Jeg kan godt lide at spise {LM.tokenizer.mask_token}.")
# the output is the top five predicted words, as a list of dicts
"""
[{'sequence': '[CLS] jeg kan godt lide at spise her. [SEP]',
'score': 0.15520372986793518,
'token': 215,
'token_str': 'her'},
{'sequence': '[CLS] jeg kan godt lide at spise ude. [SEP]',
'score': 0.05564282834529877,
'token': 1500,
'token_str': 'ude'},
{'sequence': '[CLS] jeg kan godt lide at spise kød. [SEP]',
'score': 0.052283965051174164,
'token': 3000,
'token_str': 'kød'},
{'sequence': '[CLS] jeg kan godt lide at spise morgenmad. [SEP]',
'score': 0.051760803908109665,
'token': 4538,
'token_str': 'morgenmad'},
{'sequence': '[CLS] jeg kan godt lide at spise der. [SEP]',
'score': 0.049477532505989075,
'token': 59,
'token_str': 'der'}]
"""
```

The DaNLP package also provides some wrapper code for **next sentence prediction**:

```python
from danlp.models import load_bert_nextsent_model
model = load_bert_nextsent_model()

# the sentences are from the Wikipedia article https://da.wikipedia.org/wiki/Uranus_(planet)
# sentence B1 follows right after sentence A, while sentence B2 is taken from further down the article
sent_A = "Uranus er den syvende planet fra Solen i Solsystemet og var den første planet der blev opdaget i historisk tid."
sent_B1 = "William Herschel opdagede d. 13. marts 1781 en tåget klat, som han først troede var en fjern komet."
sent_B2 = "Yderligere er magnetfeltets akse 59° forskudt for rotationsaksen og skærer ikke centrum."

# the model returns the probability that sentence B follows right after sentence A
model.predict_if_next_sent(sent_A, sent_B1)
"""0.9895"""
model.predict_if_next_sent(sent_A, sent_B2)
"""0.0001"""
```

The wrapper function for **embeddings** of tokens or sentences can be read about in the [docs for embeddings](../tasks/embeddings.md).
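For a quick impression, here is a minimal sketch using the `embed_text` wrapper (the example sentence is our own; the dimensions match this package's tests):

```python
from danlp.models import load_bert_base_model

model = load_bert_base_model()
# embed_text returns per-token embeddings (3072 dimensions each),
# a sentence embedding (768 dimensions) and the BERT-tokenized text
token_embeddings, sentence_embedding, tokenized_text = model.embed_text("Han sælger frugt.")
```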



### Named Entity Recognition

The BERT NER model has been finetuned on the [DaNE](../datasets.md#dane) dataset [(Hvingelby et al. 2020)](http://www.lrec-conf.org/proceedings/lrec2020/pdf/2020.lrec-1.565.pdf).
It can be loaded with the `load_bert_ner_model()` method.
The tagger recognizes the following tags:

- `PER`: person
- `ORG`: organization
- `LOC`: location

Read more about it in the [NER docs](../tasks/ner.md).
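A minimal usage sketch (assuming, since it is not shown in this diff, that the tagger exposes a `predict` method returning the tokens and one tag per token; the example sentence is our own):

```python
from danlp.models import load_bert_ner_model

ner = load_bert_ner_model()
# assumption: predict returns the tokens and a BIO tag per token
tokens, labels = ner.predict("Jens Peter Hansen kommer fra Danmark")
```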

### Emotion detection

@@ -39,12 +111,17 @@ The model can detect the eight following emotions:

The model achieves an accuracy of 0.65 and a macro-F1 of 0.64 on the social media test set from DR's Facebook, which contains 999 examples. We do not have permission to distribute the data.

Read more about it in the [sentiment docs](../tasks/sentiment_analysis.md).
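A minimal usage sketch (assuming the classifier exposes a `predict` method, which is not shown in this diff; the example sentence is our own):

```python
from danlp.models import load_bert_emotion_model

classifier = load_bert_emotion_model()
# assumption: predict returns one of the eight emotion labels
classifier.predict('jeg ejer en rød bil og det er en god bil')
```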

### Tone and polarity detection

The tone analyzer consists of two BERT classification models.
The first model detects the polarity of a sentence, i.e. whether it is perceived as `positive`, `neutral` or `negative`.
The second model detects the tone of a sentence, between `subjective` and `objective`.

The models are finetuned on manually annotated Twitter data from [Twitter Sentiment](../datasets.md#twitter-sentiment) (train part) and [EuroParl sentiment 2](../datasets.md#europarl-sentiment2).
Both datasets can be loaded with the DaNLP package.
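A minimal usage sketch (assuming the analyzer exposes a `predict` method covering both models, which is not shown in this diff; the example sentence is our own):

```python
from danlp.models import load_bert_tone_model

classifier = load_bert_tone_model()
# assumption: predict returns both the polarity and the tone of the sentence
classifier.predict('Analysen viser, at økonomien bliver bedre i år')
```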


Read more about it in the [sentiment docs](../tasks/sentiment_analysis.md).

2 changes: 1 addition & 1 deletion requirements.txt
@@ -6,6 +6,6 @@ flair==0.4.5
pyconll==2.2.1
conllu==0.11
pandas==1.0.1
-transformers==2.3.0
+transformers==3.5.1
srsly==1.0.2
tweepy
27 changes: 19 additions & 8 deletions tests/test_bert_models.py
@@ -1,6 +1,6 @@
import unittest

-from danlp.models import load_bert_emotion_model, load_bert_tone_model, load_bert_base_model, BertNer
+from danlp.models import load_bert_emotion_model, load_bert_tone_model, load_bert_base_model, BertNer, load_bert_nextsent_model
from danlp.download import DEFAULT_CACHE_DIR, download_model, \
_unzip_process_func
from transformers import BertTokenizer, BertForSequenceClassification
@@ -53,13 +53,13 @@ def test_predictions(self):
class TestBertBase(unittest.TestCase):
    def test_download(self):
        # Download model beforehand
-        model = 'bert.botxo.pytorch'
-        model_path = download_model(model, DEFAULT_CACHE_DIR,
-                                    process_func=_unzip_process_func,
-                                    verbose=True)
-        # check if the path to the model exists
-        self.assertTrue(os.path.exists(model_path))
+        model = 'bert.botxo.pytorch'
+        model_path = download_model(model, DEFAULT_CACHE_DIR,
+                                    process_func=_unzip_process_func,
+                                    verbose=True)
+
+        # check if the path to the model exists
+        self.assertTrue(os.path.exists(model_path))

    def test_embedding(self):
        model = load_bert_base_model()

@@ -68,6 +68,17 @@ def test_embedding(self):
        self.assertEqual(vecs_embedding[0].shape[0], 3072)
        self.assertEqual(sentence_embedding.shape[0], 768)

class TestBertNextSent(unittest.TestCase):
    def test_next_sent(self):
        model = load_bert_nextsent_model()

        sent_A = "Uranus er den syvende planet fra Solen i Solsystemet og var den første planet der blev opdaget i historisk tid."
        sent_B1 = "William Herschel opdagede d. 13. marts 1781 en tåget klat, som han først troede var en fjern komet."
        sent_B2 = "Yderligere er magnetfeltets akse 59° forskudt for rotationsaksen og skærer ikke centrum."

        self.assertTrue(model.predict_if_next_sent(sent_A, sent_B1) > 0.75)
        self.assertTrue(model.predict_if_next_sent(sent_A, sent_B2) < 0.75)

class TestBertNer(unittest.TestCase):
    def test_bert_tagger(self):
        bert = BertNer()
