In [1]:
from transformers import Constraint
from typing import List, Optional

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class DocTrie:
    def __init__(self, token_ids: List[int], max_heght: int = 2):
        r"""
        A helper class that builds a `max_heght`-level trie with the words represented in `token_ids`.
        """
        self.max_height = max_heght
        self.total_root_of_subtree = set(token_ids)

        # level 0
        root = {}
        
        # level 1
        for root_of_subtree in self.total_root_of_subtree:
            root[root_of_subtree] = {}
        
        # level 2
        for idx, token_id in enumerate(token_ids[:-1]):
            root[token_id][token_ids[idx+1]] = {} 

        self.trie = root
        
    def get_next_toekn_ids(self, current_seq: List[int]):
        print("\tenter get_next_toekn_ids")
        
        if len(current_seq)==0:
            print("\t\tenter if")
            next_token_id = list(self.total_root_of_subtree)
            print(f"next_token_id = {next_token_id}")
            print("\t\texit if")
        else:
            print("\t\tenter else")
            token_id = current_seq[-1]
            sub_trie = self.trie[token_id]
            
            next_token_id = list(sub_trie.keys())
            print("\t\texit else")
        
        print("\texit get_next_toekn_ids")
        return next_token_id

In [3]:
class DocConstraint(Constraint):
    r"""
    
    [`Constraint`] enforcing to generating a substring in a given string.
    
    Args:
        string_ids (`List[int]`):

    """
    
    def __init__(self, token_ids: List[int]):
        
        if not isinstance(token_ids, list) or len(token_ids) == 0:
            raise ValueError(f"`token_ids` has to be a non-empty list, but is {token_ids}.")
        if any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids):
            raise ValueError(f"Each list in `token_ids` has to be a list of positive integers, but is {token_ids}.")
        
        self.token_ids = token_ids
        self.trie = DocTrie(token_ids)
        
        self.seqlen = len(self.token_ids)
        self.current_seq = []
        self.completed = False
        
        self.num_for_test = 0
            
    def advance(self):
        print("enter advance")
        
        next_token_ids = self.trie.get_next_toekn_ids(self.current_seq)
        
        if len(next_token_ids) == 0:
            next_token_ids = None
        
        print("exit advance")
        return next_token_ids
    
    def does_advance(self, token_id: int):
        print("enter does_advance")
        print(f"token_id = {token_id}")
        
        if not isinstance(token_id, int):
            raise ValueError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}")
        
        next_token_ids = self.trie.get_next_toekn_ids(self.current_seq)
        
        print(f"next_token_ids = {next_token_ids}")
        print("exit does_advance")
        return token_id in next_token_ids
    
    def update(self, token_id: int):
        print("enter update")
        print(f"token_id = {token_id}")
        
        if not isinstance(token_id, int):
            raise ValueError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}")
        
        stepped = False
        completed = False   # ??????????????????????????
        reset = False
       
        if self.does_advance(token_id):
            print("\tenter if")
            self.current_seq.append(token_id)
            stepped = True
            print("\texit if")
        else:
            print("\tenter else")
            reset = True
            self.reset()
            print("\texit else")
        
        print(f"*****num_for_test = {self.num_for_test}*****")
        self.num_for_test += 1
        print(f"*****num_for_test = {self.num_for_test}*****")
        if self.num_for_test==5:
            print("\tenter if num_for_test==5")
            completed = True
            print("\texit if num_for_test==5")
        
        print("exit update")
        return stepped, completed, reset
    
    def reset(self):
        print("enter reset")
        self.current_seq = []
        self.completed = False
        print("exit reset")
        
    def remaining(self):
        print("enter remaining")
        print("exit remaining")
        return self.seqlen
    
    def copy(self, stateful=False):
        print("enter copy")
        
        new_constraint = DocConstraint(self.token_ids)
        print(f"new_constraint = {new_constraint}")
        if stateful:
            print("\tenter if")
            new_constraint.seq_len = self.seqlen
            new_constraint.current_seq = self.current_seq
            new_constraint.completed = self.completed
            print("\texit if")

        print("exit copy")
        return new_constraint
    

In [4]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("./model/t5/")
model = AutoModelForSeq2SeqLM.from_pretrained("./model/t5/")

encoder_input_str = "translate English to German: How old are you?"

2024-02-26 10:44:27.170241: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-02-26 10:44:27.187881: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-26 10:44:27.187907: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-26 10:44:27.187918: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-02-26 10:44:27.191438: I tensorflow/core/platform/cpu_feature_g

In [5]:
constraints = [
    DocConstraint(
        tokenizer("100000000", add_special_tokens=False).input_ids
    ),
]

In [6]:
display(constraints)
for constraint in constraints:
    print(constraint.token_ids)

[<__main__.DocConstraint at 0x7f28c7303d90>]

[910, 2313, 2313]


In [7]:
constraints[0].trie.trie

{2313: {2313: {}}, 910: {2313: {}}}

In [8]:
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids


outputs = model.generate(
    input_ids,
    
    # force_words_ids=tokenizer(["456789123"], add_special_tokens=False).input_ids,
    
    constraints=constraints,    # List[Constraint]
    num_beams=10,
    num_return_sequences=1,
    no_repeat_ngram_size=1,
    remove_invalid_values=True,
)


print("Output:\n" + 100 * '-')
print(tokenizer.decode(outputs[0], skip_special_tokens=True))



enter copy
new_constraint = <__main__.DocConstraint object at 0x7f27b74009d0>
exit copy
enter copy
new_constraint = <__main__.DocConstraint object at 0x7f27b7400e50>
exit copy
enter copy
new_constraint = <__main__.DocConstraint object at 0x7f27b7400580>
exit copy
enter copy
new_constraint = <__main__.DocConstraint object at 0x7f27b74006a0>
exit copy
enter copy
new_constraint = <__main__.DocConstraint object at 0x7f27b68002b0>
exit copy
enter copy
new_constraint = <__main__.DocConstraint object at 0x7f26fd08b3a0>
exit copy
enter copy
new_constraint = <__main__.DocConstraint object at 0x7f26fd08bb80>
exit copy
enter copy
new_constraint = <__main__.DocConstraint object at 0x7f26fd08ae90>
exit copy
enter copy
new_constraint = <__main__.DocConstraint object at 0x7f26fd08af20>
exit copy
enter copy
new_constraint = <__main__.DocConstraint object at 0x7f26fd08b0a0>
exit copy
enter copy
new_constraint = <__main__.DocConstraint object at 0x7f26fd08b130>
exit copy
enter copy
new_constraint = <__m

In [9]:
display(constraints)
for constraint in constraints:
    print(constraint.token_ids)

[<__main__.DocConstraint at 0x7ff74d336050>]

[910, 2313, 2313]


In [10]:
from transformers import ConstraintListState
ccopy = [ConstraintListState([constraint.copy() for constraint in constraints]) for _ in range(2)]
ccopy

[<transformers.generation.beam_constraints.ConstraintListState at 0x7ff740d4b910>,
 <transformers.generation.beam_constraints.ConstraintListState at 0x7ff740d4b400>]

In [11]:
ccopy[0]

<transformers.generation.beam_constraints.ConstraintListState at 0x7ff740d4b910>

In [12]:
ccopy[0].constraints

[<__main__.DocConstraint at 0x7ff740d4bd00>]

In [13]:
ccopy[0].pending_constraints

[<__main__.DocConstraint at 0x7ff740d4ba90>]

In [14]:
ccopy[0].inprogress_constraint == None

True

In [15]:
for constraint in ccopy[0].pending_constraints:  # "pending" == "unfulfilled yet"
    advance = constraint.advance()
    print(advance)

{2313, 910}


In [16]:
ccopy[0].advance()


{2313, 910}


In [17]:
import torch

In [18]:
type(ccopy[0].advance())


{2313, 910}


NoneType

In [19]:
torch.LongTensor(ccopy[0].advance())


{2313, 910}


TypeError: new(): data must be a sequence (got NoneType)

In [26]:
inprogress_constraint = None
pending_constraints = constraints

token_list = []
if inprogress_constraint is None:
    print("inprogress_constraint is None")
    display(pending_constraints)
    for constraint in pending_constraints:  # "pending" == "unfulfilled yet"
        print("enter for-loop")
        advance = constraint.advance()
        print(f"advance = {advance}")
        if isinstance(advance, int):
            token_list.append(advance)
        elif isinstance(advance, list):
            token_list.extend(advance)

        print("exit for-loop")
else:
    print("inprogress_constraint is not None")
    advance = inprogress_constraint.advance()
    if isinstance(advance, int):
        token_list.append(advance)
    elif isinstance(advance, list):
        token_list.extend(advance)

token_list

inprogress_constraint is None


[<__main__.DocConstraint at 0x7ff74d336050>]

enter for-loop
advance = {2313, 910}
exit for-loop


[]