In [26]:
import nltk
import spacy
from typing import List, Tuple, Optional
from dataclasses import dataclass, field
from torch import Tensor
from typing import List, Tuple

def create_mask(a: List[Tuple[int, int]], b: List[Tuple[int, int]]) -> List[bool]:
    """
    Creates a mask over `a` such that the `a[i]` range is contained in any `b[j]` range.
    Only keeps the last `True` if a consecutive `True` occurs.

    Both lists are expected to be sorted in ascending order: `b` is walked with
    a cursor that only ever moves forward while `a` is iterated.

    The returned list is padded with one leading and one trailing `False`, so
    it has `len(a) + 2` entries and `a[i]` is described by `mask[i + 1]`.

    Args:
        a: `(start, end)` ranges to test.
        b: `(start, end)` ranges that may contain the ranges of `a`.

    Returns:
        The padded list of booleans described above.
    """
    mask = [False]  # leading padding entry
    j = 0
    for span in a:
        start, end = span[0], span[1]

        # Ranges of `b` that end before this span starts cannot contain it
        # (nor, with sorted input, any later span), so drop them for good.
        while j < len(b) and b[j][1] < start:
            j += 1

        # b[j] may merely touch the span (it ends exactly where the span
        # starts) while a following range is the one that contains it. Probe
        # every remaining range that starts at or before the span instead of
        # testing b[j] alone; the cursor itself is left where it is.
        contained = False
        k = j
        while k < len(b) and b[k][0] <= start:
            if b[k][1] >= end:
                contained = True
                break
            k += 1
        mask.append(contained)

    # Collapse every run of consecutive `True` values to its last element.
    # NOTE(review): runs are collapsed even when their members were matched by
    # different `b` ranges, so of two neighbouring matches only the second one
    # stays `True`.
    for i in range(1, len(mask)):
        if mask[i] and mask[i - 1]:
            mask[i - 1] = False

    mask.append(False)  # trailing padding entry
    return mask

@dataclass
class SemanticPositions:
    """Result record produced by `Mask` for one input text.

    Every field is a required constructor argument and appears in the repr.
    """

    # The analysed text.
    string: str
    # (start, end) character positions of the nouns, verbs and adjectives.
    NVA_words: List[Tuple[int, int]]
    # Tokenizer output for the text, or None when no tokenizer was supplied.
    # NOTE(review): `Mask` stores whatever `tokenizer(text, return_tensors='pt')`
    # returns; the sample output shows a dict-like object with 'input_ids',
    # 'token_type_ids' and 'attention_mask', so `Tensor` looks too narrow.
    tokens: Optional[Tensor]
    # Token strings with a leading "##" removed, or None without a tokenizer.
    tokensStrings: Optional[List[str]]
    # Output of `create_mask` for the tokens, or None without a tokenizer.
    mask: Optional[List[bool]]

# spaCy pipelines that have already been loaded, keyed by model name.
_SPACY_PIPELINES: dict = {}


def _load_spacy_pipeline(model_name: str):
    """Return the spaCy pipeline `model_name`, loading it only on first use."""
    if model_name not in _SPACY_PIPELINES:
        _SPACY_PIPELINES[model_name] = spacy.load(model_name)
    return _SPACY_PIPELINES[model_name]


def _locate_tokens(text: str, tokens: List[str]) -> List[Tuple[int, int]]:
    """Find the `(start, end)` character span of each token, scanning `text` left to right."""
    spans: List[Tuple[int, int]] = []
    cursor = 0
    for token in tokens:
        begin = text.find(token, cursor)
        if begin < 0:
            # The token does not occur verbatim in the remaining text (possible
            # when the tokenizer normalises or replaces text). Record a span
            # that no word range can contain and keep the cursor where it is,
            # so the following tokens are still searched from the right place.
            spans.append((-1, -1))
            continue
        finish = begin + len(token)
        spans.append((begin, finish))
        cursor = finish
    return spans


def Mask(text: str, tokenizer=None) -> "SemanticPositions":
    """Extract the start and end positions of verbs, nouns, and adjectives in the given text.

    Args:
        text: The text to analyse.
        tokenizer: Optional tokenizer. It must provide `tokenize(text)` and be
            callable as `tokenizer(text, return_tensors='pt')`.

    Returns:
        A `SemanticPositions` holding `text` and the sorted `(start, end)`
        character spans of every token spaCy tags as VERB, NOUN or ADJ.
        Without a tokenizer the remaining fields are `None`. With a tokenizer
        they hold the tokenizer output, the token strings (a leading "##"
        removed) and the mask built by `create_mask`. That mask has two more
        entries than there are token strings (one padding entry at each end),
        which presumes the tokenizer output carries one extra id at each end,
        as in the BERT example in this notebook.
    """
    # Parse the text with the spaCy English model (loaded once and reused).
    doc = _load_spacy_pipeline('en_core_web_sm')(text)

    # Character span of every verb, noun and adjective.
    pos_list: List[Tuple[int, int]] = [
        (token.idx, token.idx + len(token))
        for token in doc
        if token.pos_ in ('VERB', 'NOUN', 'ADJ')
    ]
    pos_list.sort()

    if tokenizer is None:
        return SemanticPositions(string=text,
                                 NVA_words=pos_list,
                                 tokens=None,
                                 tokensStrings=None,
                                 mask=None)

    # Tokenize the input string; drop the "##" prefix so that each piece can
    # be searched for in the original text.
    tokens: List[str] = [t[2:] if t[0:2] == "##" else t for t in tokenizer.tokenize(text)]
    token_spans = _locate_tokens(text, tokens)

    m = create_mask(token_spans, pos_list)
    return SemanticPositions(string=text,
                             NVA_words=pos_list,
                             tokensStrings=tokens,
                             tokens=tokenizer(text, return_tensors='pt'),
                             mask=m)

In [27]:

from transformers import BertForMaskedLM, BertTokenizer # type: ignore

In [28]:
# Run the extraction on a sample sentence with the cased BERT tokenizer.
result = Mask(
    text="tokenization is the art of making my life miserable!",
    tokenizer=BertTokenizer.from_pretrained('bert-base-cased'),
)
result

SemanticPositions(string='tokenization is the art of making my life miserable!', NVA_words=[(0, 12), (20, 23), (27, 33), (37, 41), (42, 51)], tokens={'input_ids': tensor([[  101, 22559,  2734,  1110,  1103,  1893,  1104,  1543,  1139,  1297,
         14531,   106,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}, tokensStrings=['token', 'ization', 'is', 'the', 'art', 'of', 'making', 'my', 'life', 'miserable', '!'], mask=[False, False, True, False, False, True, False, True, False, False, True, False, False])

In [29]:
# Print every part of the result, one after the other.
print(
    result.string,
    result.NVA_words,
    result.tokens['input_ids'],
    result.tokensStrings,
    result.mask,
    sep="\n",
)

tokenization is the art of making my life miserable!
[(0, 12), (20, 23), (27, 33), (37, 41), (42, 51)]
tensor([[  101, 22559,  2734,  1110,  1103,  1893,  1104,  1543,  1139,  1297,
         14531,   106,   102]])
['token', 'ization', 'is', 'the', 'art', 'of', 'making', 'my', 'life', 'miserable', '!']
[False, False, True, False, False, True, False, True, False, False, True, False, False]


In [30]:
# Count the input ids straight from the result instead of a pasted copy of
# them, so the check stays valid when the sentence or tokenizer changes.
len(result.tokens['input_ids'][0])

13

In [31]:
# Length of the mask; for this sentence it equals the 13 input ids counted above.
len(result.mask)

13