### Structured Outputs
**Structured Outputs** is a feature that constrains the model so that its responses always adhere to a schema you supply.
The schema can be a regular expression or a JSON schema.

#### List of concepts in this notebook:
1. How an LLM understands its input
2. What a tokenizer is
3. How an LLM predicts output text
4. How to predict the next possible characters for a given regex
5. How to restrict an LLM to generate only those characters
6. LLMs are trained on tokens, so how to convert "next possible characters" into "next possible tokens"


An LLM is essentially a stack of matrices: it takes numbers as input and produces numbers as output, so we need to convert text to numbers and back. How can we do that?
1. Map each character to a number (0-9, A-Z, a-z)
2. Map each word to a number (Ice - 0, Cream - 1, Cone - 2)

**Character-level tokenization**:
1. Very long sequences (inefficient)
2. Loses common patterns/subwords
3. Model has to learn character combinations from scratch
<br> Example: "playing" would be 7 tokens [p,l,a,y,i,n,g] instead of perhaps 2 tokens [play, ing]. Since attention cost grows with the square of the input length, long sequences quickly become impractical.

**Word-level tokenization**:
1. Huge vocabulary size (millions of words)
2. Can't handle unseen words (OOV problem)
3. Wastes space on rare words
4. No subword understanding
<br> Example: If "smartwatch" isn't in the vocabulary, the model can't represent it even though "smart" and "watch" are known words. The first (embedding) and last (output) layers also become huge matrices whose size grows with the vocabulary, which means high computation.
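
The two naive schemes above can be compared on a toy example. This is a minimal sketch: the tiny corpus, the vocabularies and the `UNK` id are assumptions made up for illustration, not part of any real tokenizer.

```python
# Toy comparison of character-level and word-level tokenization.
# The corpus and vocabularies below are illustrative assumptions.
corpus = ["ice cream cone", "smart watch", "playing"]

# Character-level: every distinct character gets an id.
char_vocab = {ch: i for i, ch in enumerate(sorted(set("".join(corpus))))}

# Word-level: every distinct whitespace-separated word gets an id.
words = sorted({w for line in corpus for w in line.split()})
word_vocab = {w: i for i, w in enumerate(words)}
UNK = len(word_vocab)  # single id shared by all out-of-vocabulary words

def encode_chars(text):
    return [char_vocab[ch] for ch in text]

def encode_words(text):
    return [word_vocab.get(w, UNK) for w in text.split()]

print(len(encode_chars("playing")))  # one id per character
print(encode_words("smart watch"))   # two known word ids
print(encode_words("smartwatch"))    # falls back to the UNK id
```

The character encoder never fails but produces long sequences, while the word encoder is compact but collapses every unseen word into the same `UNK` id.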


**Why BPE is Better**
<br><br>
**Adaptive Vocabulary**:
* Starts with characters and iteratively merges most frequent pairs
* Creates subword units that represent common patterns in the data
* Can represent both frequent and rare words effectively
<br><br>
**Balance between character and word level**:
* Common words stay as single tokens
* Rare words split into meaningful subwords
* Example: "uncommon" → ["un", "common"]
<br><br>
**Handles unseen words**:
* Can break down new words into known subwords
* Example: If model never saw "teleporting" but knows "tele" and "porting", it can still understand it
<br><br>
**Efficient sequence length**:
* Much shorter than character-level
* More informative than arbitrary splits
* Example: "playing" might be ["play", "ing"] (2 tokens) instead of 7 characters
<br><br>
etc..

[To check how it is done in python, please visit this page](https://huggingface.co/learn/nlp-course/en/chapter6/5)
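
The merge procedure described above can be sketched in a few lines of plain Python. This is a simplified toy, assuming a made-up word-frequency table; real BPE tokenizers work on bytes, pre-tokenize the text and handle many more details. The function name `train_bpe` is hypothetical.

```python
from collections import Counter

def train_bpe(word_freqs, num_merges):
    """Learn BPE merges from a {word: count} dict (toy version)."""
    # Start with every word split into single characters.
    splits = {word: list(word) for word in word_freqs}
    merges = []
    for _ in range(num_merges):
        # Count adjacent symbol pairs, weighted by word frequency.
        pair_counts = Counter()
        for word, freq in word_freqs.items():
            symbols = splits[word]
            for a, b in zip(symbols, symbols[1:]):
                pair_counts[(a, b)] += freq
        if not pair_counts:
            break
        # Pick the most frequent pair (ties: the pair seen first wins).
        best = max(pair_counts, key=pair_counts.get)
        merges.append(best)
        # Replace every occurrence of the best pair with the merged symbol.
        for word, symbols in splits.items():
            merged, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    merged.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    merged.append(symbols[i])
                    i += 1
            splits[word] = merged
    return merges, splits

# Made-up frequencies, chosen so that "play" and "ing" emerge as subwords.
word_freqs = {"play": 5, "playing": 3, "played": 2, "sing": 2, "singing": 1}
merges, splits = train_bpe(word_freqs, num_merges=5)
print(merges)
print(splits["playing"])
```

With these frequencies, five merges are enough for "playing" to be split into the two subwords used as the example above.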

In [1]:
# !pip install greenery

In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
import numpy as np

In [3]:
import torch.nn.functional as F

In [4]:
MODEL_NAME = "HuggingFaceTB/SmolLM2-360M"

In [5]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [6]:
tokenizer.encode('Hi there, how are you?')

[26843, 665, 28, 638, 359, 346, 47]

In [7]:
tokenizer.decode([26843, 665, 28, 638, 359, 346, 47])

'Hi there, how are you?'

In [8]:
vocab = tokenizer.get_vocab()
vocab

{'ically': 947,
 'DEX': 31958,
 'ricia': 30769,
 'winning': 21664,
 'ossibility': 46578,
 'MES': 26826,
 'bred': 27099,
 'AW': 36025,
 'relationship': 36861,
 'requency': 12690,
 'ĠHey': 19103,
 'ĠF': 426,
 'Ġfootnotes': 47746,
 'St': 1393,
 'recognized': 33850,
 'Ġcalculating': 17085,
 'inski': 36303,
 'nder': 594,
 'Ġgrabbing': 41105,
 'burning': 34306,
 'ĠShi': 23607,
 'ĠEat': 16478,
 'Ġromance': 18233,
 'Ġnodules': 37543,
 'ĠFiscal': 45161,
 'Ġprodu': 796,
 'Ġvigil': 15173,
 'ĠAfric': 2079,
 'ĠAlphabet': 34586,
 'ĠMarch': 3903,
 'Scientific': 29654,
 'ĠRon': 17668,
 'Ġallowing': 3910,
 'Ġmenstruation': 31149,
 'Ġzu': 47316,
 'Enum': 21207,
 'versions': 30511,
 'ĠSilicon': 32688,
 'ui': 9010,
 'Ġbidding': 44746,
 'Ġclump': 45949,
 'đ': 205,
 'Ġorganizer': 32704,
 'Ġfurther': 2030,
 'erence': 2095,
 'Ġattractions': 21627,
 'Ġinstitutions': 4679,
 'variable': 18025,
 'ĠBA': 27173,
 'ĠOk': 16765,
 'Ġhighest': 4919,
 'Follow': 11101,
 'Ġdrawback': 37762,
 'ĠSyst': 46919,
 '>.': 19369,
 

In [9]:
len(vocab)

49152

In [10]:
id_to_token = {token_id: token for token,token_id in vocab.items()}

In [11]:
for i in [26843, 665, 28, 638, 359, 346, 47]:
    print(id_to_token[i])

Hi
Ġthere
,
Ġhow
Ġare
Ġyou
?
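
The `Ġ` prefix in the tokens above is how GPT-2 style byte-level BPE vocabularies display a leading space (`Ċ` plays the same role for a newline). A small sketch of turning such tokens back into readable text; the helper `readable` is hypothetical and only handles these two markers, whereas `tokenizer.decode` handles the full byte mapping.

```python
# Hypothetical helper: make byte-level BPE tokens readable.
# Assumption: GPT-2 style byte mapping, where U+0120 stands for a space
# and U+010A stands for a newline. Other mapped bytes are ignored here.
def readable(token: str) -> str:
    return token.replace("\u0120", " ").replace("\u010a", "\n")

tokens = ["Hi", "\u0120there", ",", "\u0120how", "\u0120are", "\u0120you", "?"]
print("".join(readable(t) for t in tokens))
```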


### What does an LLM output?

In [12]:
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

In [13]:
def get_next_token_predictions(model, input_ids, top_k=10):
    """
    Get top k predictions for the next token with their probabilities
    
    Args:
        model: The language model
        input_ids: Input token ids (tensor)
        top_k: Number of top predictions to return
    
    Returns:
        List of tuples (token_id, probability)
    """
    # Get the model's raw output (logits)
    with torch.no_grad():
        outputs = model(input_ids)  # Pass the input ids to the model
        logits = outputs.logits  # Unnormalized scores (logits) for every vocabulary token at each position

    # Get the last token's predictions
    next_token_logits = logits[0, -1, :]
    
    # Convert logits to probabilities using softmax
    next_token_probs = F.softmax(next_token_logits, dim=-1)
    
    # Get top k probabilities and corresponding token ids
    top_k_probs, top_k_tokens = torch.topk(next_token_probs, top_k)
    
    # Convert to list of tuples (token_id, probability)
    predictions = [
        (token.item(), prob.item()) 
        for token, prob in zip(top_k_tokens, top_k_probs)
    ]
    
    return predictions

In [14]:
input_text = "Hello, how are"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids

# Get predictions
predictions = get_next_token_predictions(model, input_ids)

# Print results with decoded tokens
for token_id, prob in predictions:
    token_text = tokenizer.decode(token_id)
    print(f"Token: '{token_text}', Probability: {prob:.4f}")

Token: ' you', Probability: 0.9479
Token: ' things', Probability: 0.0180
Token: ' we', Probability: 0.0064
Token: ' your', Probability: 0.0053
Token: ' the', Probability: 0.0042
Token: ' u', Probability: 0.0027
Token: ' ya', Probability: 0.0022
Token: ' y', Probability: 0.0015
Token: ' ye', Probability: 0.0015
Token: ' ', Probability: 0.0011


#### This is the LLM's output after one prediction step. The LLM emits one token at a time; we then append that token to the input and feed the result back in to predict the next one.

(Concept alert): Decoding strategies. Not needed for this notebook, but a small related topic: how do we choose among the predicted tokens to get the best output sequence?
1. Greedy search
2. Exhaustive search
3. Beam search
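
To make the difference between these strategies concrete, here is a minimal sketch on a made-up next-token table (`TABLE` and all function names are assumptions for illustration; a real model returns logits over the full vocabulary). Exhaustive search corresponds to a beam wide enough to keep every candidate.

```python
import math

# Toy "language model": next-token probabilities keyed by the tokens so far.
TABLE = {
    (): {"the": 0.6, "a": 0.4},
    ("the",): {"cat": 0.3, "dog": 0.3, "end": 0.4},
    ("a",): {"cat": 0.9, "end": 0.1},
}

def next_probs(prefix):
    return TABLE.get(tuple(prefix), {"end": 1.0})

def greedy(steps):
    """Always take the single most likely next token."""
    seq, logp = [], 0.0
    for _ in range(steps):
        probs = next_probs(seq)
        token = max(probs, key=probs.get)
        seq.append(token)
        logp += math.log(probs[token])
    return seq, logp

def beam_search(steps, beam_width):
    """Keep the `beam_width` best partial sequences at every step."""
    beams = [([], 0.0)]
    for _ in range(steps):
        candidates = [
            (seq + [tok], logp + math.log(p))
            for seq, logp in beams
            for tok, p in next_probs(seq).items()
        ]
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
    return beams[0]

print(greedy(2))          # commits to "the" first
print(beam_search(2, 2))  # keeps "a" alive and finds a likelier sequence
```

Greedy search commits to the locally best token and ends up with a less likely two-token sequence than the one beam search finds with a width of 2.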

In [15]:
input_text = 'What is the value of 2+2? The answer is '
input_ids = tokenizer(input_text, return_tensors="pt").input_ids

# Get predictions
predictions = get_next_token_predictions(model, input_ids)

# Print results with decoded tokens
for token_id, prob in predictions:
    token_text = tokenizer.decode(token_id)
    print(f"Token: '{token_text}', Probability: {prob:.4f}")

Token: '4', Probability: 0.8685
Token: '2', Probability: 0.0483
Token: '5', Probability: 0.0366
Token: '3', Probability: 0.0151
Token: '6', Probability: 0.0102
Token: '1', Probability: 0.0090
Token: '0', Probability: 0.0073
Token: '7', Probability: 0.0017
Token: '8', Probability: 0.0012
Token: '9', Probability: 0.0006


### A complete output

In [16]:
def generate_text(model, tokenizer, prompt, max_length=50, temperature=0.7, stop_token="<|endoftext|>"):
    """
    Generate text continuation from a prompt until stop token or max length.
    
    Args:
        model: The language model
        tokenizer: The tokenizer
        prompt: Initial text prompt
        max_length: Maximum number of tokens to generate
        temperature: Controls randomness (lower = more deterministic)
        stop_token: Token to stop generation
        
    Returns:
        Generated text including the prompt
    """
    # Encode prompt
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    generated_text = prompt
    
    # Generate tokens until stop condition
    for _ in range(max_length):
        # Get model's output for current sequence
        with torch.no_grad():
            outputs = model(input_ids)
            next_token_logits = outputs.logits[0, -1, :]
            
            # Apply temperature
            next_token_logits = next_token_logits / temperature  # Lower temperature sharpens the distribution; this is the "temperature" setting LLM APIs expose
            
            # Get probabilities
            probs = F.softmax(next_token_logits, dim=-1)
            
            # Sample next token
            next_token = torch.multinomial(probs, num_samples=1)
            
            # Decode token
            next_token_text = tokenizer.decode(next_token)
            print(next_token_text, end="", flush=True)
            
            # Check for stop condition
            if stop_token in next_token_text:
                break
                
            # Append to generated text
            generated_text += next_token_text
            
            # Append token to input_ids for next iteration
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
    
    return generated_text

# Example usage:
prompt = "Once upon a time"
generated = generate_text(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt,
    max_length=50,
    temperature=0.7
)
# print(generated)

 in a small town named Harmonyville, there lived a group of animals who loved to learn about the world around them. One day, they noticed something strange - their little home was getting hotter every day because of a big storm that had hit nearby. So

#### What are these settings we use with OpenAI prompts? Temperature, top_k and top_p

In [17]:
# # Apply top-k filtering
# if top_k > 0:
#     indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
#     next_token_logits[indices_to_remove] = float('-inf')
            
# Apply top-p filtering
# if top_p < 1.0:
#     sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
#     cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
#     sorted_indices_to_remove = cumulative_probs > top_p
#     sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
#     sorted_indices_to_remove[..., 0] = 0
#     indices_to_remove = sorted_indices[sorted_indices_to_remove]
#     next_token_logits[indices_to_remove] = float('-inf')
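
The three knobs can also be tried out without a model. The sketch below works on a plain list of made-up logits using only the standard library; the function names are hypothetical and the filters operate on probabilities rather than logits, which is an illustrative simplification.

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax (numerically stabilized)."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def top_k_filter(probs, k):
    """Keep the k most likely tokens, zero the rest, renormalize."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep = set(order[:k])
    kept = [p if i in keep else 0.0 for i, p in enumerate(probs)]
    total = sum(kept)
    return [p / total for p in kept]

def top_p_filter(probs, p):
    """Keep the smallest set of top tokens whose cumulative probability reaches p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = set(), 0.0
    for i in order:
        keep.add(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    kept = [q if i in keep else 0.0 for i, q in enumerate(probs)]
    total = sum(kept)
    return [q / total for q in kept]

logits = [2.0, 1.0, 0.5, -1.0]  # made-up scores for a 4-token vocabulary
print(softmax(logits, temperature=1.0))
print(softmax(logits, temperature=0.5))  # sharper: more mass on the top token
print(top_k_filter(softmax(logits), k=2))
print(top_p_filter(softmax(logits), p=0.8))
```

Temperature reshapes the whole distribution, while top-k and top-p cut off its tail before sampling.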

### Simple Intuition on forcing structured outputs

In [18]:
input_text = 'What is the value of 2+2? ANS: {"answer": '
input_ids = tokenizer(input_text, return_tensors="pt").input_ids

# Get predictions
predictions = get_next_token_predictions(model, input_ids)

# Print results with decoded tokens
for token_id, prob in predictions:
    token_text = tokenizer.decode(token_id)
    print(f"Token: '{token_text}', Probability: {prob:.4f}")

Token: '4', Probability: 0.8585
Token: '2', Probability: 0.0494
Token: '5', Probability: 0.0351
Token: '1', Probability: 0.0170
Token: '6', Probability: 0.0114
Token: '0', Probability: 0.0114
Token: '3', Probability: 0.0110
Token: '7', Probability: 0.0022
Token: '8', Probability: 0.0018
Token: '9', Probability: 0.0008


In [19]:
input_text = 'What is the value of 2+2? ANS: {"answer": 4'
input_ids = tokenizer(input_text, return_tensors="pt").input_ids

# Get predictions
predictions = get_next_token_predictions(model, input_ids)

# Print results with decoded tokens
for token_id, prob in predictions:
    token_text = tokenizer.decode(token_id)
    print(f"Token: '{token_text}', Probability: {prob:.4f}")

Token: ',', Probability: 0.3861
Token: '}', Probability: 0.3794
Token: ',"', Probability: 0.0607
Token: '}}', Probability: 0.0550
Token: '},', Probability: 0.0404
Token: '.', Probability: 0.0122
Token: ' }', Probability: 0.0116
Token: '}"', Probability: 0.0067
Token: '};', Probability: 0.0042
Token: '}.', Probability: 0.0040


In [20]:
def generate_formatted_answer(
    model, 
    tokenizer, 
    question, 
    format_prefix,
    format_suffix,
    allowed_tokens=None
):
    # Tokenize prefix and suffix
    prefix_tokens = tokenizer.encode(format_prefix, add_special_tokens=False)
    suffix_tokens = tokenizer.encode(format_suffix, add_special_tokens=False)
    
    # If no allowed tokens specified, use digits
    if allowed_tokens is None:
        allowed_tokens = tokenizer.encode("0123456789", add_special_tokens=False)
    
    # Create input
    input_ids = tokenizer(question, return_tensors="pt").input_ids
    generated = []
    
    # Generate prefix
    for expected_token in prefix_tokens:
        with torch.no_grad():
            outputs = model(input_ids)
            next_token_logits = outputs.logits[0, -1, :]
            
            # Force the expected token
            mask = torch.full_like(next_token_logits, float('-inf'))
            mask[expected_token] = 0
            next_token_logits += mask
            
            next_token = torch.argmax(next_token_logits).unsqueeze(0)
            generated.append(next_token.item())
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
    
    # Generate answer (allow only specified tokens)
    with torch.no_grad():
        outputs = model(input_ids)
        next_token_logits = outputs.logits[0, -1, :]
        
        mask = torch.full_like(next_token_logits, float('-inf'))
        mask[allowed_tokens] = 0
        next_token_logits += mask
        
        next_token = torch.argmax(next_token_logits).unsqueeze(0)
        generated.append(next_token.item())
        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

    # Generate suffix
    for expected_token in suffix_tokens:
        with torch.no_grad():
            outputs = model(input_ids)
            next_token_logits = outputs.logits[0, -1, :]
            
            mask = torch.full_like(next_token_logits, float('-inf'))
            mask[expected_token] = 0
            next_token_logits += mask
            
            next_token = torch.argmax(next_token_logits).unsqueeze(0)
            generated.append(next_token.item())
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

    return tokenizer.decode(generated)

# Example usage:
formats = [
    ('Result = ', '.'),
    ('Answer:', ''),
    ('{"answer": ', '}')
]

question = "What is the aswer of 2+2?, The answer is"
for prefix, suffix in formats:
    result = generate_formatted_answer(
        model,
        tokenizer,
        question,
        format_prefix=prefix,
        format_suffix=suffix
    )
    print(f"\nFormat: {prefix}X{suffix}")
    print(f"Generated: {result}")


Format: Result = X.
Generated: Result = 4.

Format: Answer:X
Generated: Answer:2

Format: {"answer": X}
Generated: {"answer": 4}


#### This approach has a higher chance of degrading output quality, because we are disturbing the model's actual probability distribution. Recent LLMs that offer structured outputs are therefore trained on a lot of JSON so that they still give accurate answers.
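
How masking changes the distribution can be seen on a made-up example. This is a sketch using plain Python lists; `masked_softmax`, the toy vocabulary and the logits are assumptions, and the set of allowed tokens is assumed to be non-empty.

```python
import math

def masked_softmax(logits, allowed):
    """Softmax over `allowed` indices only; every other token gets probability 0."""
    masked = [x if i in allowed else float("-inf") for i, x in enumerate(logits)]
    m = max(masked)  # finite as long as `allowed` is non-empty
    exps = [math.exp(x - m) for x in masked]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

vocab_toy = ["4", "four", "2", "}"]  # illustrative vocabulary
logits = [1.0, 3.0, 0.5, 0.0]        # the model prefers "four"
digits_only = {0, 2}                 # indices of the digit tokens

probs = masked_softmax(logits, digits_only)
for token, p in zip(vocab_toy, probs):
    print(f"{token!r}: {p:.4f}")
```

The model's favourite token is removed and its probability mass is redistributed over the allowed tokens, which is exactly the distortion the note above warns about.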

Structured outputs with a regex: how do we find the next possible characters?
You can build an FSM (Finite State Machine) for the regex, feed it the text generated so far to find the current state, and read off the transitions that are valid from that state.
We will not go through the construction here.
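
As a small illustration of the idea, here is a hand-built FSM for the regex `[0-9]+(\.[0-9]+)?`. The state names and helper functions are assumptions for this sketch; a regex library would normally derive the automaton automatically.

```python
DIGITS = set("0123456789")

# Hand-built DFA for the regex [0-9]+(\.[0-9]+)?  (state names are illustrative).
TRANSITIONS = {
    "start": {d: "int" for d in DIGITS},
    "int": {**{d: "int" for d in DIGITS}, ".": "dot"},
    "dot": {d: "frac" for d in DIGITS},
    "frac": {d: "frac" for d in DIGITS},
}
ACCEPTING = {"int", "frac"}

def run(text, state="start"):
    """Return the state after consuming `text`, or None if the DFA gets stuck."""
    for ch in text:
        state = TRANSITIONS.get(state, {}).get(ch)
        if state is None:
            return None
    return state

def allowed_next_chars(text):
    """Characters that keep `text` on a valid path through the DFA."""
    state = run(text)
    return set() if state is None else set(TRANSITIONS[state])

print(sorted(allowed_next_chars("12")))   # digits and '.'
print(sorted(allowed_next_chars("12.")))  # digits only
print(run("12.5") in ACCEPTING)           # a complete match
```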

### But wait, LLMs are not character predictors! So how do we turn a next-possible-characters predictor into a next-possible-tokens predictor?

#### Brute force

In [22]:
def print_preds(predictions):
    for token_id in predictions[:10]:
        token_text = tokenizer.decode(token_id)
        print(f"Token: '{token_text}'")

In [30]:
class SimpleRegexFSM:
    def __init__(self, pattern_type="json_number"):
        self.pattern_type = pattern_type
        self.states = {
            'start': {'{': 'open_brace'},
            'open_brace': {'"': 'quote1'},
            'quote1': {'a': 'a'},
            'a': {'n': 'n'},
            'n': {'s': 's'},
            's': {'w': 'w'},
            'w': {'e': 'e'},
            'e': {'r': 'r'},
            'r': {'"': 'quote2'},
            'quote2': {':': 'colon'},
            'colon': {'0':'number', '1':'number', '2':'number', '3':'number', 
                     '4':'number', '5':'number', '6':'number', '7':'number', 
                     '8':'number', '9':'number'},
            'number': {'}':'end', '0':'number', '1':'number', '2':'number', 
                      '3':'number', '4':'number', '5':'number', '6':'number', 
                      '7':'number', '8':'number', '9':'number'},
            'end': {}
        }
        self.accept_states = {'end'}
        
    def is_valid_continuation(self, current_str, token):
        """Check if adding token maintains valid path"""
        test_str = current_str + token
        
        # Track current state
        state = 'start'
        for char in test_str:
            # Get valid transitions from current state
            valid_transitions = self.states.get(state, {})
            
            # Check if this character is valid
            if char not in valid_transitions:
                return False
                
            # Move to next state
            state = valid_transitions[char]
            
        # Current state should either be accepting or have valid transitions
        return state in self.accept_states or bool(self.states.get(state, {}))

def generate_constrained_text(
    model, 
    tokenizer, 
    prompt, 
    max_length=50, 
    temperature=0.7,
    top_k=10,
    debug=False
):
    decoder = SimpleRegexFSM()
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    current_str = ""  # Start with empty string for regex matching
    
    for _ in range(max_length):
        with torch.no_grad():
            outputs = model(input_ids)
            next_token_logits = outputs.logits[0, -1, :]
            next_token_logits = next_token_logits / temperature
            
            # Rank the entire vocabulary (k = vocabulary size; the top_k argument is unused here)
            top_k_probs, top_k_indices = torch.topk(
                F.softmax(next_token_logits, dim=-1),
                k=len(vocab)
            )
            
            # Brute force: decode every token and check whether it keeps the string valid.
            # Caveat: special tokens decode to '' and therefore always pass this check.
            valid_mask = torch.zeros_like(top_k_probs)
            for i, token_id in enumerate(top_k_indices):
                # Convert tensor to integer for decoding
                token_text = tokenizer.decode([token_id.item()], skip_special_tokens=True)
                if decoder.is_valid_continuation(current_str, token_text):
                    valid_mask[i] = 1
            
            # Mask invalid tokens
            masked_probs = top_k_probs * valid_mask
            if masked_probs.sum() == 0:
                if debug:
                    print("No valid tokens found")
                break
                
            # Sample next token
            masked_probs = masked_probs / masked_probs.sum()
            chosen_idx = torch.multinomial(masked_probs, num_samples=1)
            next_token_id = top_k_indices[chosen_idx].item()

            if debug:
                print("###",next_token_id)
            # Update state
            next_token_text = tokenizer.decode([next_token_id], skip_special_tokens=True)
            current_str += next_token_text
            input_ids = torch.cat([input_ids, torch.tensor([[next_token_id]])], dim=1)
            
            if debug:
                print(f"Current string: {current_str}")
            
            # Check if complete
            if current_str and current_str[-1] == '}':
                break
    
    return current_str

# Test the implementation
fsm = SimpleRegexFSM()
test_cases = [
    ('{"answer":', '1'),
    ('{"answer":1', '2'),
    ('{"answer":12', '3'),
    ('{"answer":123', '}'),
    ('', '{'),
]

for current, token in test_cases:
    valid = fsm.is_valid_continuation(current, token)
    print(f"Current: '{current}', Token: '{token}', Valid: {valid}")

Current: '{"answer":', Token: '1', Valid: True
Current: '{"answer":1', Token: '2', Valid: True
Current: '{"answer":12', Token: '3', Valid: True
Current: '{"answer":123', Token: '}', Valid: True
Current: '', Token: '{', Valid: True


In [31]:
generate_constrained_text(model, tokenizer, "What is the answer of 2+2?, Answer it in json format", debug=True)

### 0
Current string: 
### 107
Current string: {
### 18
Current string: {"
### 81
Current string: {"a
### 0
Current string: {"a
### 94
Current string: {"an
### 13356
Current string: {"answer
### 18
Current string: {"answer"
### 0
Current string: {"answer"
### 42
Current string: {"answer":
### 35
Current string: {"answer":3
### 33
Current string: {"answer":31
### 34
Current string: {"answer":312
### 37
Current string: {"answer":3125


KeyboardInterrupt: 

In [25]:
import torch
import torch.nn.functional as F

class TableFormatFSM:
    def __init__(self):
        # Define character sets:
        self.LETTERS = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
        self.NUMBERS = set("0123456789")
    
    def step(self, state, char):
        # Returns the next state given the current state and an input character,
        # or None if the char is not allowed.

        if state == 'start':
            if char == '<':
                return 'after_open'
            return None

        elif state == 'after_open':
            if char == '\n':
                return 'row_start'
            return None

        elif state == 'row_start':
            # At the beginning of a row, our first field (text) must start with a letter.
            if char in self.LETTERS:
                return 'text1'
            return None

        elif state == 'text1':
            # In field1 (text): we can accept more letters or a pipe indicating the end of field1.
            if char in self.LETTERS:
                return 'text1'  # Remain in text1
            if char == '|':
                return 'after_pipe1'
            return None

        elif state == 'after_pipe1':
            # Field2 must start with a digit.
            if char in self.NUMBERS:
                return 'number'
            return None

        elif state == 'number':
            # In field2 (number): continue to accept digits or a pipe to end the number.
            if char in self.NUMBERS:
                return 'number'
            if char == '|':
                return 'after_pipe2'
            return None

        elif state == 'after_pipe2':
            # Field3 (text) must start with a letter.
            if char in self.LETTERS:
                return 'text2'
            return None

        elif state == 'text2':
            # In field3 (text): allow letters; then either newline (for a new row) or closing '>' to finish.
            if char in self.LETTERS:
                return 'text2'
            if char == '\n':
                return 'row_start'
            if char == '>':
                return 'end'
            return None

        elif state == 'end':
            # Once ended, nothing more is accepted.
            return None

        return None

    def get_state(self, text):
        """
        Process the entire string `text` from the beginning and return the last reached state.
        If at any point an invalid character is encountered, return None.
        """
        state = 'start'
        for char in text:
            next_state = self.step(state, char)
            if next_state is None:
                return None
            state = next_state
        return state
    
    def is_valid_continuation(self, current_str, token):
        """
        Check if appending token to current_str is a valid partial continuation.
        (It does not require that the whole table is complete—only that the FSM does not
        get stuck.)
        """
        state = self.get_state(current_str)
        if state is None:
            return False
        for char in token:
            state = self.step(state, char)
            if state is None:
                return False
        return True
    
    def is_accepting(self, text):
        """
        We consider the string accepted only if the FSM is in state 'end'
        (i.e. the table is complete).
        """
        return self.get_state(text) == 'end'


if __name__ == "__main__":
    # Create an instance of the FSM:
    fsm = TableFormatFSM()
    test_cases = [
        # (current string, token to add, expected valid?)
        ('<', '\n', True),
        ('<\n', 'text', True),
        ('<\ntext', '|', True),
        ('<\ntext|', '42', True),
        ('<\ntext|42', '|', True),
        ('<\ntext|42|', 'more', True),
        ('<\ntext|42|more', '\n', True),
        ('<\ntext|42|more\ntext|13|text', '>', True),
    ]

    print("Testing FSM with example cases:")
    for current, token, expected in test_cases:
        valid = fsm.is_valid_continuation(current, token)
        print(f"Current: {repr(current)}, Token: {repr(token)}, Valid: {valid} (expected {expected})")

Testing FSM with example cases:
Current: '<', Token: '\n', Valid: True (expected True)
Current: '<\n', Token: 'text', Valid: True (expected True)
Current: '<\ntext', Token: '|', Valid: True (expected True)
Current: '<\ntext|', Token: '42', Valid: True (expected True)
Current: '<\ntext|42', Token: '|', Valid: True (expected True)
Current: '<\ntext|42|', Token: 'more', Valid: True (expected True)
Current: '<\ntext|42|more', Token: '\n', Valid: True (expected True)
Current: '<\ntext|42|more\ntext|13|text', Token: '>', Valid: True (expected True)


In [34]:
def generate_constrained_text(
    model, 
    tokenizer, 
    prompt, 
    max_length=50, 
    temperature=0.7,
    top_k=10,
    debug=False
):
    decoder = TableFormatFSM()
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    current_str = ""  # Start with empty string for regex matching
    
    for _ in range(max_length):
        with torch.no_grad():
            outputs = model(input_ids)
            next_token_logits = outputs.logits[0, -1, :]
            next_token_logits = next_token_logits / temperature
            
            # Rank the entire vocabulary (k = vocabulary size; the top_k argument is unused here)
            top_k_probs, top_k_indices = torch.topk(
                F.softmax(next_token_logits, dim=-1),
                k=len(vocab)
            )
            
            # Brute force: decode every token and check whether it keeps the table format valid.
            # Caveat: special tokens decode to '' and therefore always pass this check.
            valid_mask = torch.zeros_like(top_k_probs)
            for i, token_id in enumerate(top_k_indices):
                # Convert tensor to integer for decoding
                token_text = tokenizer.decode([token_id.item()], skip_special_tokens=True)
                if decoder.is_valid_continuation(current_str, token_text):
                    valid_mask[i] = 1
            
            # Mask invalid tokens
            masked_probs = top_k_probs * valid_mask
            if masked_probs.sum() == 0:
                if debug:
                    print("No valid tokens found")
                break
                
            # Sample next token
            masked_probs = masked_probs / masked_probs.sum()
            chosen_idx = torch.multinomial(masked_probs, num_samples=1)
            next_token_id = top_k_indices[chosen_idx].item()

            if debug:
                print("###",next_token_id)
            # Update state
            next_token_text = tokenizer.decode([next_token_id], skip_special_tokens=True)
            current_str += next_token_text
            input_ids = torch.cat([input_ids, torch.tensor([[next_token_id]])], dim=1)
            
            if debug:
                print(f"Current string: {current_str}")
            
            # Check if complete: the table format ends with '>' (FSM state 'end')
            if decoder.is_accepting(current_str):
                break
    
    return current_str

In [38]:
prompt = """There are 3 people Raju, Geeta and Bhanu, Raju has 3 apples, Geeta has 1 banana and Bhanu has 2 tomato. Tell the information about raju, geeta anf bhanu and fruitfs they have in format <\nperson|count|type\n>"""
generated = generate_text(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt,
    max_length=50,
    temperature=0.7
)


0|>|<
0|>|<
0|>|<

Explanation:

Raju has 3 apples and 1 banana,

Geeta has 1 banana and 2 tomato.


In [41]:
generate_constrained_text(model, tokenizer, """There are 3 people Raju, Geeta and Bhanu, Raju has 3 apples, Geeta has 1 banana and Bhanu has 2 tomato. Tell the information about raju, geeta anf bhanu and fruits they have in format <\nperson|count|type\n>""", debug=True)

### 44
Current string: <
### 198
Current string: <

### 66
Current string: <
R
### 1346
Current string: <
Raj
### 101
Current string: <
Raju
### 108
Current string: <
Raju|
### 35
Current string: <
Raju|3
### 108
Current string: <
Raju|3|
### 48773
Current string: <
Raju|3|apples
### 198
Current string: <
Raju|3|apples

### 9488
Current string: <
Raju|3|apples
Ge
### 8810
Current string: <
Raju|3|apples
Geeta
### 108
Current string: <
Raju|3|apples
Geeta|
### 33
Current string: <
Raju|3|apples
Geeta|1
### 108
Current string: <
Raju|3|apples
Geeta|1|
### 2947
Current string: <
Raju|3|apples
Geeta|1|ban
### 3231
Current string: <
Raju|3|apples
Geeta|1|banana
### 198
Current string: <
Raju|3|apples
Geeta|1|banana

### 50
Current string: <
Raju|3|apples
Geeta|1|banana
B
### 10936
Current string: <
Raju|3|apples
Geeta|1|banana
Bhan
### 101
Current string: <
Raju|3|apples
Geeta|1|banana
Bhanu
### 108
Current string: <
Raju|3|apples
Geeta|1|banana
Bhanu|
### 34
Current string: <
Raju|3|apples

KeyboardInterrupt: 

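As you can see, this approach is not generic: it does not work for an arbitrary regex, and I only used hand-written structures to show how the idea works. It is also slow, because every step decodes and checks the whole vocabulary.
To see how this is done properly, please check the paper: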

https://arxiv.org/html/2407.08103v1
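
One common way to avoid re-checking the whole vocabulary at every step is to precompute, once, which tokens are valid from each FSM state. The sketch below shows that idea on a toy DFA for the regex `\{[0-9]+\}` and a made-up vocabulary; it is an illustration under these assumptions and not necessarily the construction used in the paper above. All names are hypothetical.

```python
DIGITS = "0123456789"

# Toy DFA for the regex \{[0-9]+\}  (state names are illustrative).
TRANSITIONS = {
    "start": {"{": "open"},
    "open": {d: "num" for d in DIGITS},
    "num": {**{d: "num" for d in DIGITS}, "}": "end"},
    "end": {},
}

# Made-up multi-character "tokens"; a real vocabulary has tens of thousands.
VOCAB = ["{", "}", "1", "12", "4}", "{4", "a", "}}"]

def walk(state, token):
    """Feed every character of `token` to the DFA; None means it got stuck."""
    for ch in token:
        state = TRANSITIONS[state].get(ch)
        if state is None:
            return None
    return state

def build_index(vocab):
    """Map each state to {token_id: state reached after emitting that token}."""
    index = {}
    for state in TRANSITIONS:
        index[state] = {}
        for token_id, token in enumerate(vocab):
            # Skip empty tokens so they can never be chosen without making progress.
            nxt = walk(state, token) if token else None
            if nxt is not None:
                index[state][token_id] = nxt
    return index

index = build_index(VOCAB)
for state, allowed in index.items():
    print(state, {VOCAB[i]: nxt for i, nxt in allowed.items()})
```

During generation, the decoder would then only need to look up `index[state]` to build the logit mask and to move to the next state, instead of decoding every token at every step.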