-
Notifications
You must be signed in to change notification settings - Fork 426
/
Copy pathbert_score-own_model.py
131 lines (99 loc) · 4.55 KB
/
bert_score-own_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An example of how to use BERTScore with a user's defined/own model and tokenizer.
To run: python bert_score-own_model.py
"""
from pprint import pprint
from typing import Union
import torch
from torch import Tensor, nn
from torch.nn import Module
from torchmetrics.text.bert import BERTScore
_NUM_LAYERS = 2  # default number of stacked encoder layers in `get_user_model_encoder`
_MODEL_DIM = 4  # size of each token embedding and default encoder feature dimension
_NHEAD = 2  # default number of attention heads in `get_user_model_encoder`
_MAX_LEN = 6  # default tokenized sequence length, special tokens included


class UserTokenizer:
    """The `UserTokenizer` class is required to be defined when a non-default model is used.

    The user's defined tokenizer is expected to return either token IDs or token embeddings that are fed into the
    model. This toy tokenizer returns embeddings: every known word is mapped to a fixed vector of size `_MODEL_DIM`.

    The tokenizer vocabulary should contain some special tokens, such as a `<pad>` token so that a tokenization will
    run successfully in batches.
    """

    CLS_TOKEN = "<cls>"  # noqa: S105
    SEP_TOKEN = "<sep>"  # noqa: S105
    PAD_TOKEN = "<pad>"  # noqa: S105

    def __init__(self) -> None:
        # Tiny fixed vocabulary: two real words plus the special tokens, each stored as a `(1, _MODEL_DIM)` row so
        # that the rows of one sentence can be concatenated along dim 0.
        self.word2vec = {
            "hello": 0.5 * torch.ones(1, _MODEL_DIM),
            "world": -0.5 * torch.ones(1, _MODEL_DIM),
            self.CLS_TOKEN: torch.zeros(1, _MODEL_DIM),
            self.SEP_TOKEN: torch.zeros(1, _MODEL_DIM),
            self.PAD_TOKEN: torch.zeros(1, _MODEL_DIM),
        }

    def _tokenize(self, sentence: str, max_len: int) -> list[str]:
        """Lower-case and whitespace-split one sentence into exactly `max_len` tokens (for `max_len >= 2`)."""
        words = sentence.lower().split()
        # Reserve two positions for the special tokens so that truncation shortens the sentence itself instead of
        # slicing off the closing `<sep>`. NOTE(review): consumers that strip special tokens by position (first and
        # last attended token) would otherwise discard a real word - confirm against the BERTScore implementation.
        kept_words = words[: max(max_len - 2, 0)]
        tokens = [self.CLS_TOKEN, *kept_words, self.SEP_TOKEN][:max_len]
        # A negative repeat count yields an empty list, so already-full sequences are left untouched.
        return tokens + [self.PAD_TOKEN] * (max_len - len(tokens))

    def __call__(self, sentences: Union[str, list[str]], max_len: int = _MAX_LEN) -> dict[str, Tensor]:
        """Call method to tokenize user input.

        The `__call__` method must be defined for this class. To ensure the functionality, the `__call__` method
        should obey the input/output arguments structure described below.

        Args:
            sentences:
                Input text. `Union[str, List[str]]`
            max_len:
                Maximum length of pre-processed text, counting the `<cls>` and `<sep>` tokens. `int`
                Expected to be at least 2 so that both special tokens fit.

        Return:
            Python dictionary containing the keys `input_ids` and `attention_mask` with corresponding values.
            `input_ids` holds the word vectors with shape `(num_sentences, max_len, _MODEL_DIM)`;
            `attention_mask` is an integer tensor of shape `(num_sentences, max_len)` with 0 at `<pad>` positions
            and 1 elsewhere.

        Raises:
            KeyError: If a sentence contains a word that is not part of `self.word2vec`.

        """
        if isinstance(sentences, str):
            sentences = [sentences]

        # Tokenize each sentence: add special tokens, truncate and pad to `max_len`
        tokenized_sentences = [self._tokenize(sentence, max_len) for sentence in sentences]

        output_dict: dict[str, Tensor] = {}
        # Concatenating the `(1, dim)` rows gives `(max_len, dim)` per sentence; stacking adds the batch dimension.
        output_dict["input_ids"] = torch.stack([
            torch.cat([self.word2vec[word] for word in sentence]) for sentence in tokenized_sentences
        ])
        output_dict["attention_mask"] = torch.stack([
            torch.tensor([0 if word == self.PAD_TOKEN else 1 for word in sentence])
            for sentence in tokenized_sentences
        ]).long()
        return output_dict
def get_user_model_encoder(
    num_layers: int = _NUM_LAYERS, d_model: int = _MODEL_DIM, nhead: int = _NHEAD
) -> Module:
    """Initialize the Transformer encoder.

    Args:
        num_layers: number of identical encoder layers stacked in the encoder
        d_model: feature dimension of the inputs; must match the embedding size produced by the tokenizer
        nhead: number of attention heads in every encoder layer

    Return:
        A ``torch.nn.TransformerEncoder`` that consumes inputs laid out as ``(batch, sequence, d_model)``.

    """
    # `UserTokenizer` in this file emits `(batch, max_len, d_model)` tensors. The PyTorch layer defaults to the
    # `(sequence, batch, d_model)` layout, which would make attention run across the sentences of a batch rather
    # than across the tokens of each sentence, so the batch-first layout is requested explicitly.
    encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
    return nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
def user_forward_fn(model: Module, batch: dict[str, Tensor]) -> Tensor:
    """Run the user's model on one tokenized batch to obtain its embeddings.

    The internals may be as elaborate as needed, but the input/output contract below has to be respected.

    Args:
        model: a torch.nn.module that implements a forward pass
        batch: a batch of inputs to pass through the model; only its `input_ids` entry is read here

    Return:
        The model output.

    """
    model_inputs = batch["input_ids"]
    return model(model_inputs)
# Toy prediction/reference sentences made up only of the words present in the tokenizer vocabulary.
_PREDS = ["hello", "hello world", "world world world"]
_REFS = ["hello", "hello hello", "hello world hello"]

if __name__ == "__main__":
    user_tokenizer = UserTokenizer()
    encoder = get_user_model_encoder()

    # Wire the custom tokenizer, model and forward function into the metric.
    scorer = BERTScore(
        model=encoder,
        user_tokenizer=user_tokenizer,
        user_forward_fn=user_forward_fn,
        max_length=_MAX_LEN,
        return_hash=False,
    )
    scorer.update(_PREDS, _REFS)

    # Show the tokenized predictions accumulated by the metric, then the final scores.
    print(f"Predictions:\n {scorer.preds_input_ids}\n {scorer.preds_attention_mask}")
    pprint(scorer.compute())