Add InfoLM #915

Merged — 55 commits merged into master from metric/InfoLM on Jul 12, 2022
Changes from 54 commits (55 commits in total)
5fae18c
wip
stancld Feb 21, 2022
3f3170d
Finish IM class
stancld Feb 22, 2022
a5aca2a
[WIP] Add some other parts of code
stancld Feb 22, 2022
32d44e5
Add a bunch of code
stancld Feb 22, 2022
59c7044
First full pass working
stancld Feb 23, 2022
539aa71
Merge branch 'master' into metric/InfoLM
stancld Feb 23, 2022
598d83b
Merge branch 'master' into metric/InfoLM
stancld Feb 24, 2022
83e907c
Merge branch 'master' into metric/InfoLM
stancld Mar 1, 2022
a25cda6
Merge branch 'master' into metric/InfoLM
stancld Mar 3, 2022
8e0cd59
Add some docs
stancld Mar 3, 2022
d73e17b
Merge branch 'master' into metric/InfoLM
stancld Mar 25, 2022
4210e6b
Merge branch 'master' into metric/InfoLM
stancld Mar 25, 2022
b9b936c
Merge branch 'master' into metric/InfoLM
stancld Apr 2, 2022
29d5e2e
Merge branch 'master' into metric/InfoLM
stancld Apr 9, 2022
cb2c3dc
Add InfoLM module metric and fix some minor issues
stancld Apr 9, 2022
6dd80df
Fix a device attribute setter
stancld Apr 9, 2022
107648f
Fix mypy issues
stancld Apr 9, 2022
eae6c71
Fix doctest
stancld Apr 9, 2022
5b80dcc
Fix InfoLM doc references
stancld Apr 9, 2022
6792e4e
Fix indentation in a docstring
stancld Apr 9, 2022
9bdfc08
Update chlog
stancld Apr 9, 2022
756ced2
Merge branch 'master' into metric/InfoLM
stancld Apr 20, 2022
c9831ea
Merge branch 'master' into metric/InfoLM
stancld Apr 21, 2022
bb1394c
Sketch the tests
stancld Apr 21, 2022
fa4c221
Uncomment test parameters
stancld Apr 21, 2022
122556d
Merge branch 'master' into metric/InfoLM
stancld May 16, 2022
0f349cf
Merge branch 'master' into metric/InfoLM
stancld Jun 7, 2022
53f4de8
Merge branch 'master' into metric/InfoLM
Borda Jun 8, 2022
ea789dc
Merge branch 'master' into metric/InfoLM
stancld Jun 24, 2022
243ed84
Fix a minor bug, specify docstring + fix max_len for the test
stancld Jun 24, 2022
099cda8
Fix a conflict in CHANGELOG.md
stancld Jun 24, 2022
90fe8e1
Fix AB & KL divergence measures
stancld Jun 24, 2022
20a55f6
Add some missing part
stancld Jun 24, 2022
5807858
Fix functional metric tests and some minor things
stancld Jun 24, 2022
bbc4030
Update class test
stancld Jun 24, 2022
fd27050
Fix a link for Fisher-Rao distance
stancld Jun 24, 2022
14b0ee8
Fix doctest & Handle nan_to_num for torch<=1.8
stancld Jun 24, 2022
6a509f0
Use a different link for Fisher-Rao distance
stancld Jun 24, 2022
66bd614
Drop use_cache kwarg
stancld Jun 24, 2022
41a0e64
Use dim_zero_cat instead of torch.cat in class metric
stancld Jun 24, 2022
cd25121
Set num_threads=1 when dist_sync_on_step=True
stancld Jun 25, 2022
256773c
Drop enforcing num_threads=1 and add link to repo for generating tes…
stancld Jun 25, 2022
6569b6f
Merge branch 'master' into metric/InfoLM
Borda Jun 27, 2022
041fad3
Apply suggestions from code review
SkafteNicki Jun 30, 2022
4b8435f
Merge branch 'master' into metric/InfoLM
Borda Jun 30, 2022
b363292
Merge branch 'master' into metric/InfoLM
stancld Jul 7, 2022
dd3098e
Replace _TRANSFORMERS_AUTO_AVAILABLE with _TRANSFORMERS_AUTO_AVAILABL…
stancld Jul 8, 2022
b3af4f9
Fix device placement and Set num_threads=0 as default (SkafteNicki's …
stancld Jul 8, 2022
f46cb67
Make testing conditional on dependency and connection
stancld Jul 8, 2022
4a279e8
Merge branch 'master' into metric/InfoLM
stancld Jul 11, 2022
fd2fe5f
Fix another device placement
stancld Jul 11, 2022
48c408c
Fix the last device placement
stancld Jul 11, 2022
e402fae
Merge branch 'master' into metric/InfoLM
stancld Jul 11, 2022
c0790b8
Merge branch 'master' into metric/InfoLM
stancld Jul 11, 2022
5a8f6f0
Update src/torchmetrics/functional/__init__.py
SkafteNicki Jul 12, 2022
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added


- Added a new NLP metric `InfoLM` ([#915](https://github.com/PyTorchLightning/metrics/pull/915))


- Added `Perplexity` metric ([#922](https://github.com/PyTorchLightning/metrics/pull/922))


6 changes: 6 additions & 0 deletions docs/source/links.rst
@@ -86,3 +86,9 @@
.. _MER: https://www.isca-speech.org/archive_v0/archive_papers/interspeech_2004/i04_2765.pdf
.. _WIL: https://www.isca-speech.org/archive_v0/archive_papers/interspeech_2004/i04_2765.pdf
.. _WIP: https://infoscience.epfl.ch/record/82766
.. _InfoLM: https://arxiv.org/pdf/2112.01589.pdf
.. _alpha divergence: https://static.renyi.hu/renyi_cikkek/1961_on_measures_of_entropy_and_information.pdf
.. _beta divergence: https://www.sciencedirect.com/science/article/pii/S0047259X08000456
.. _AB divergence: https://pdfs.semanticscholar.org/744b/1166de34cb099100f151f3b1459f141ae25b.pdf
.. _Rényi divergence: https://static.renyi.hu/renyi_cikkek/1961_on_measures_of_entropy_and_information.pdf
.. _Fisher-Rao distance: http://www.scholarpedia.org/article/Fisher-Rao_metric
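
For orientation only (not part of this diff): InfoLM scores a candidate against a reference by comparing the token distributions produced by a masked language model under one of the information measures linked above. As a reminder, the Kullback–Leibler divergence between two discrete distributions p and q — one of the measures this PR supports — is

```latex
D_{\mathrm{KL}}(p \,\|\, q) = \sum_{i} p_i \log \frac{p_i}{q_i}
```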
21 changes: 21 additions & 0 deletions docs/source/text/infolm.rst
@@ -0,0 +1,21 @@
.. customcarditem::
:header: InfoLM
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/summarization.svg
:tags: Text

.. include:: ../links.rst

######
InfoLM
######

Module Interface
________________

.. autoclass:: torchmetrics.text.infolm.InfoLM
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.text.infolm.infolm
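
For context, a minimal usage sketch of the new metric (not part of the diff; the `model_name_or_path` keyword and the checkpoint name below are assumptions based on the other text metrics, not confirmed by this page):

```python
# Hypothetical usage sketch of the InfoLM metric added in this PR.
from torchmetrics.functional.text.infolm import infolm
from torchmetrics.text.infolm import InfoLM

preds = ["he read the book because he was interested in world history"]
target = ["he was interested in world history because he read the book"]

# Functional interface: a single call returning the corpus-level score.
score = infolm(preds, target, model_name_or_path="google/bert_uncased_L-2_H-128_A-2")

# Module interface: the usual torchmetrics update/compute pattern.
metric = InfoLM(model_name_or_path="google/bert_uncased_L-2_H-128_A-2")
metric.update(preds, target)
print(metric.compute())
```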
6 changes: 4 additions & 2 deletions src/torchmetrics/functional/__init__.py
@@ -86,10 +86,12 @@
from torchmetrics.functional.text.wer import word_error_rate
from torchmetrics.functional.text.wil import word_information_lost
from torchmetrics.functional.text.wip import word_information_preserved
from torchmetrics.utilities.imports import _TRANSFORMERS_AUTO_AVAILABLE
from torchmetrics.utilities.imports import _TRANSFORMERS_AVAILABLE

if _TRANSFORMERS_AUTO_AVAILABLE:
if _TRANSFORMERS_AVAILABLE:
from torchmetrics.functional.text.bert import bert_score # noqa: F401
if _TRANSFORMERS_AVAILABLE:
from torchmetrics.functional.text.infolm import infolm # noqa: F401
SkafteNicki marked this conversation as resolved.

__all__ = [
"accuracy",
5 changes: 3 additions & 2 deletions src/torchmetrics/functional/text/__init__.py
@@ -24,10 +24,11 @@
from torchmetrics.functional.text.wer import word_error_rate # noqa: F401
from torchmetrics.functional.text.wil import word_information_lost # noqa: F401
from torchmetrics.functional.text.wip import word_information_preserved # noqa: F401
from torchmetrics.utilities.imports import _NLTK_AVAILABLE, _TRANSFORMERS_AUTO_AVAILABLE
from torchmetrics.utilities.imports import _NLTK_AVAILABLE, _TRANSFORMERS_AVAILABLE

if _TRANSFORMERS_AUTO_AVAILABLE:
if _TRANSFORMERS_AVAILABLE:
from torchmetrics.functional.text.bert import bert_score # noqa: F401
from torchmetrics.functional.text.infolm import infolm # noqa: F401
stancld marked this conversation as resolved.

if _NLTK_AVAILABLE:
from torchmetrics.functional.text.rouge import rouge_score # noqa: F401
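
The hunks above gate the new import behind a single `_TRANSFORMERS_AVAILABLE` flag. A minimal sketch of that guarded-import pattern (the `package_available` helper below is an illustrative stand-in, not the actual torchmetrics implementation):

```python
# Illustrative sketch of the guarded-import pattern used above.
# torchmetrics defines its own flag in torchmetrics.utilities.imports.
import importlib.util


def package_available(name: str) -> bool:
    """Return True if `name` can be imported in the current environment."""
    return importlib.util.find_spec(name) is not None


_TRANSFORMERS_AVAILABLE = package_available("transformers")

if _TRANSFORMERS_AVAILABLE:
    from torchmetrics.functional.text.infolm import infolm  # noqa: F401
```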
238 changes: 17 additions & 221 deletions src/torchmetrics/functional/text/bert.py
@@ -12,240 +12,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import csv
import math
import urllib
from collections import Counter, defaultdict
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from warnings import warn

import torch
from torch import Tensor
from torch.nn import Module
from torch.utils.data import DataLoader, Dataset

from torchmetrics.utilities.imports import _TQDM_AVAILABLE, _TRANSFORMERS_AUTO_AVAILABLE

if _TRANSFORMERS_AUTO_AVAILABLE:
from transformers.models.auto import AutoModel, AutoTokenizer
from torch.utils.data import DataLoader

from torchmetrics.functional.text.helper_embedding_metric import (
TextDataset,
TokenizedDataset,
_check_shape_of_model_output,
_get_progress_bar,
_input_data_collator,
_output_data_collator,
_process_attention_mask_for_special_tokens,
)
from torchmetrics.utilities.imports import _TQDM_AVAILABLE, _TRANSFORMERS_AVAILABLE

if _TRANSFORMERS_AVAILABLE:
from transformers import AutoModel, AutoTokenizer
else:
__doctest_skip__ = ["bert_score"]

if _TQDM_AVAILABLE:
import tqdm


# Default model recommended in the original implementation.
_DEFAULT_MODEL = "roberta-large"


def _preprocess_text(
text: List[str],
tokenizer: Any,
max_length: int = 512,
truncation: bool = True,
sort_according_length: bool = True,
own_tokenizer: bool = False,
) -> Dict[str, Tensor]:
"""Default text pre-processing function using `transformers` `AutoTokenizer` instance.

Args:
text: An iterable of sentences.
tokenizer: Either ``AutoTokenizer`` instance from ``transformers`` package, or a user's own tokenizer.
max_length: A maximum sequence length.
truncation:
An indication of whether tokenized sequences should be truncated to ``max_length``.
sort_according_length:
An indication of whether tokenized sequences should be sorted from shortest to longest. This is appropriate
to do for leveraging dynamic padding during embedding calculation and thereby to hasten inference.
own_tokenizer: An indication of whether a non-default user's own tokenizer is used.

Return:
A dictionary of tokenized sentences including ``input_ids`` and ``attention_mask``.

Raises:
BaseException:
If a tokenization with a user's own tokenizer is not successful.
"""
if not own_tokenizer:
tokenized_data = tokenizer(
text, padding="max_length", max_length=max_length, truncation=truncation, return_tensors="pt"
)
else:
try:
tokenized_data = tokenizer(text, max_length)
except BaseException as ex:
raise BaseException(f"Tokenization was not successful: {ex}")

input_ids, attention_mask = (
_sort_data_according_length(tokenized_data["input_ids"], tokenized_data["attention_mask"])
if sort_according_length
else (tokenized_data["input_ids"], tokenized_data["attention_mask"])
)
return {"input_ids": input_ids, "attention_mask": attention_mask}


def _process_attention_mask_for_special_tokens(attention_mask: Tensor) -> Tensor:
"""Process attention mask to be zero for special [CLS] and [SEP] tokens as they're not included in a
calculation for BERT score.

Args:
attention_mask: An attention mask to be returned, for example, by a ``transformers`` tokenizer.

Return:
A processed attention mask.
"""
# Make attention_mask zero for [CLS] token
attention_mask[:, 0] = 0
# Make attention_mask zero for [SEP] token
sep_token_position = (attention_mask - 0.1).cumsum(-1).argmax(-1)
attention_mask[torch.arange(attention_mask.size(0)).long(), sep_token_position] = 0
return attention_mask
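
A quick worked example of the masking above (illustrative, not part of the diff):

```python
import torch

# Two sequences: [CLS] w1 w2 [SEP] <pad>  and  [CLS] w1 [SEP] <pad> <pad>
attention_mask = torch.tensor([
    [1, 1, 1, 1, 0],
    [1, 1, 1, 0, 0],
])
processed = _process_attention_mask_for_special_tokens(attention_mask.clone())
# The first position ([CLS]) and the last non-padding position ([SEP]) are
# zeroed in each row, so only the "real" word pieces remain:
# tensor([[0, 1, 1, 0, 0],
#         [0, 1, 0, 0, 0]])
```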


def _sort_data_according_length(input_ids: Tensor, attention_mask: Tensor) -> Tuple[Tensor, Tensor]:
"""Sort tokenized sentence from the shortest to the longest one."""
sorted_indices = attention_mask.sum(1).argsort()
input_ids = input_ids[sorted_indices]
attention_mask = attention_mask[sorted_indices]
return input_ids, attention_mask


def _input_data_collator(
batch: Dict[str, Tensor], device: Optional[Union[str, torch.device]] = None
) -> Dict[str, Tensor]:
"""Helper function that trims model inputs to the longest sequence within the batch and put the input on the
proper device."""
max_len = int(batch["attention_mask"].sum(1).max().item())
input_ids = batch["input_ids"][:, :max_len].to(device)
attention_mask = batch["attention_mask"][:, :max_len].to(device)
batch.update({"input_ids": input_ids, "attention_mask": attention_mask})
return batch


def _output_data_collator(model_output: Tensor, attention_mask: Tensor, target_len: int) -> Tuple[Tensor, Tensor]:
"""Helper function that pads the model output and attention mask to the target length."""
zeros_shape = list(model_output.shape)
zeros_shape[2] = target_len - zeros_shape[2]
model_output = torch.cat(
[model_output, torch.zeros(zeros_shape, dtype=model_output.dtype).to(model_output.device)], dim=2
)
zeros = torch.zeros(zeros_shape[0], zeros_shape[2], dtype=attention_mask.dtype).to(attention_mask.device)
attention_mask = torch.cat([attention_mask, zeros], dim=1)
return model_output, attention_mask


class TextDataset(Dataset):
"""PyTorch dataset class for storing tokenized sentences and other properties used for BERT score
calculation."""

def __init__(
self,
text: List[str],
tokenizer: Any,
max_length: int = 512,
preprocess_text_fn: Callable[[List[str], Any, int], Dict[str, Tensor]] = _preprocess_text,
idf: bool = False,
tokens_idf: Optional[Dict[int, float]] = None,
) -> None:
"""
Args:
text: An iterable of sentences.
tokenizer: ``AutoTokenizer`` instance from ``transformers`` package.
max_length: A maximum sequence length.
preprocess_text_fn: A function used for processing the input sentences.
idf: An indication of whether to calculate token inverse document frequencies to weight the model embeddings.
tokens_idf: Inverse document frequencies (these should be calculated on reference sentences).
"""
self.text = preprocess_text_fn(text, tokenizer, max_length)
self.max_length = self.text["input_ids"].shape[1]
self.num_sentences = len(text)
self.idf = idf
self.tokens_idf = {}
if idf:
self.tokens_idf = tokens_idf if tokens_idf is not None else self._get_tokens_idf()

def __getitem__(self, idx: int) -> Dict[str, Tensor]:
input_ids = self.text["input_ids"][idx, :]
attention_mask = self.text["attention_mask"][idx, :]
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
if self.idf:
input_ids_idf = torch.tensor([self.tokens_idf[input_idx] for input_idx in input_ids.tolist()])
inputs_dict["input_ids_idf"] = input_ids_idf
return inputs_dict

def __len__(self) -> int:
return self.num_sentences

def _get_tokens_idf(self) -> Dict[int, float]:
"""Calculate token inverse document frequencies.

Return:
A python dictionary containing inverse document frequencies for token ids.
"""
token_counter: Counter = Counter()
for tokens in map(self._set_of_tokens, self.text["input_ids"]):
token_counter.update(tokens)

tokens_idf: Dict[int, float] = defaultdict(self._get_tokens_idf_default_value)
tokens_idf.update(
{idx: math.log((self.num_sentences + 1) / (occurrence + 1)) for idx, occurrence in token_counter.items()}
)
return tokens_idf

def _get_tokens_idf_default_value(self) -> float:
"""Helper function that ensures ``defaultdict`` to be pickled."""
return math.log((self.num_sentences + 1) / 1)

@staticmethod
def _set_of_tokens(input_ids: Tensor) -> Set:
"""Return set of tokens from the ``input_ids``."""
return set(input_ids.tolist())
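
For intuition (not part of the diff), the smoothed IDF above works out as follows for a tiny corpus:

```python
import math

# With num_sentences = 3 reference sentences:
#   a token that occurs in 2 of them gets idf = log((3 + 1) / (2 + 1))
#   a token never seen at tokenization time gets the default log((3 + 1) / 1)
print(math.log((3 + 1) / (2 + 1)))  # ~0.288
print(math.log((3 + 1) / 1))        # ~1.386
```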


class TokenizedDataset(TextDataset):
"""The child class of ``TextDataset`` class used with already tokenized data."""

def __init__(
self,
input_ids: Tensor,
attention_mask: Tensor,
idf: bool = False,
tokens_idf: Optional[Dict[int, float]] = None,
) -> None:
"""
Args:
input_ids: Input ids.
attention_mask: Attention mask.
idf: An indication of whether to calculate token inverse document frequencies to weight the model embeddings.
tokens_idf: Inverse document frequencies (these should be calculated on reference sentences).
"""
self.text = dict(zip(["input_ids", "attention_mask"], _sort_data_according_length(input_ids, attention_mask)))
self.text = _input_data_collator(self.text)
self.num_sentences = len(self.text["input_ids"])
self.max_length = self.text["input_ids"].shape[1]
self.idf = idf
self.tokens_idf = {}
if idf:
self.tokens_idf = tokens_idf if tokens_idf is not None else self._get_tokens_idf()


def _get_progress_bar(dataloader: DataLoader, verbose: bool = False) -> Union[DataLoader, "tqdm.auto.tqdm"]:
"""Helper function returning either the dataloader itself when ``verbose = False``, or it wraps the dataloader with
``tqdm.auto.tqdm``, when ``verbose = True`` to display a progress bar during the embeddings calculation."""
return tqdm.auto.tqdm(dataloader) if verbose else dataloader


def _check_shape_of_model_output(output: Tensor, input_ids: Tensor) -> None:
"""Check if the shape of the user's own model output."""
bs, seq_len = input_ids.shape[:2]
invalid_out_shape = len(output.shape) != 3 or output.shape[0] != bs or output.shape[1] != seq_len
if invalid_out_shape:
raise ValueError(
"The model output must be `torch.Tensor` of a shape `[batch_size, seq_len, model_dim]` "
f"i.e. [{bs}, {seq_len}. , `model_dim`], but got {output.shape}."
)


def _get_embeddings_and_idf_scale(
dataloader: DataLoader,
target_len: int,
@@ -537,7 +333,7 @@ def bert_score(
)

if model is None:
if not _TRANSFORMERS_AUTO_AVAILABLE:
if not _TRANSFORMERS_AVAILABLE:
raise ModuleNotFoundError(
"`bert_score` metric with default models requires `transformers` package be installed."
" Either install with `pip install transformers>=4.0` or `pip install torchmetrics[text]`."