In [51]:
import torch
from torch.nn.functional import softmax
from transformers import AutoTokenizer, LogitsProcessor, LlamaForCausalLM, GenerationConfig

In [52]:
model_name = 'meta-llama/Llama-3.2-1B-Instruct'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model: LlamaForCausalLM = LlamaForCausalLM.from_pretrained(model_name)

In [97]:
class MyLogitsProcessor(LogitsProcessor):
    def __init__(self):
        self.allowed_queries = ["""MATCH (x1:Disease {name: "cold"})-[r1:PARENT_CHILD]-(x2:Disease) RETURN""",
                                """MATCH (x1:Drug {name: "Ozempic"})-[r1:ENZYME]-(x2:PathOrEnzyme) RETURN"""]
        self.allowed_queries_ids = [tokenizer(query, add_special_tokens=False) for query in self.allowed_queries]
        self.starting_ids = [tokenizer("MATCH") ]
    
    def __call__(self, input_ids, scores):
        texts = tokenizer.batch_decode(input_ids)
        token_probs = torch.zeros_like(scores)
        for i, text in enumerate(texts):
            gen_text = "MATCH" + text.split("MATCH")[-1]
            allowed_ids = self.allowed_next_ids(gen_text)
            if len(allowed_ids) == 0:
                allowed_ids = [tokenizer.eos_token_id]
            token_probs[i, allowed_ids] = softmax(scores[i, allowed_ids], dim=-1)
        return token_probs
    
    def allowed_next_ids(self, gen_text):
            token_ids = set()
            for query in self.allowed_queries:
                if query.startswith(gen_text):
                    allowed_continuation = query[len(gen_text):]
                    if allowed_continuation == "":
                        continue
                    token_ids.add(tokenizer(allowed_continuation, add_special_tokens=False).input_ids[0])
            token_ids = list(token_ids)
            return token_ids

class MyLogitsProcessorList(list):
    def __iter__(self):
        for _ in range(1):
            yield MyLogitsProcessor()#self.logits_processor
    
    def __len__(self):
        return 1

In [98]:
generation_config = GenerationConfig(do_sample=False, num_beams=5, max_length=100, early_stopping=True, num_return_sequences=2)
logits_processor_list = MyLogitsProcessorList()


input_ids = tokenizer("MATCH", return_tensors="pt")
output_ids = model.generate(**input_ids, generation_config=generation_config, logits_processor=logits_processor_list)
output_texts = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
print(output_texts)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


['MATCH (x1:Disease {name: "cold"})-[r1:PARENT_CHILD]-(x2:Disease) RETURN', 'MATCH (x1:.']


In [82]:
tokenizer(','), tokenizer('&'), tokenizer('.'), tokenizer('4')

({'input_ids': [128000, 11], 'attention_mask': [1, 1]},
 {'input_ids': [128000, 5], 'attention_mask': [1, 1]},
 {'input_ids': [128000, 13], 'attention_mask': [1, 1]},
 {'input_ids': [128000, 19], 'attention_mask': [1, 1]})