In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.nn import functional as F
import random

model_name = 'eryk-mazus/polka-1.1b'  # Updated model name to match initialization

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

  torch.utils._pytree._register_pytree_node(





  torch.utils._pytree._register_pytree_node(


In [21]:
import string

# Invert vocabulary to quickly map IDs -> tokens
vocab = tokenizer.get_vocab()
id_to_token = {v: k for k, v in vocab.items()}

def strip_special(token: str) -> str:
    """
    Removes special SentencePiece characters like '▁' from the start of tokens.
    """
    return token.lstrip("▁")

def starts_with_letter(token: str, letter: str) -> bool:
    """
    Returns True if 'token' (after removing special chars) starts with the given letter.
    """
    # if equal to one special token, return False
    if token == '▁':
        return False
    # punctuation is ok
    if token[0] == '▁' and len(token) == 2 and token[1] in string.punctuation:
        return True
    for sub_word in token.split('▁')[1:]:
        if sub_word and sub_word[0].lower() != letter:
            return False
    return True

def middle_of_the_word(token: str) -> bool:
    """
    Returns True if the token does not start with special character.
    """
    return '▁' not in token

def top_k_top_p_filtering(logits, top_k=50, top_p=0.9):
    """
    Filters a distribution of logits using top-k and nucleus (top-p) filtering.
    Sets probabilities of tokens outside the top-k or top-p to -inf.
    """
    # Apply top-k
    if top_k > 0:
        # get top_k indices
        values_to_keep, _ = torch.topk(logits, top_k)
        min_value_to_keep = values_to_keep[-1]
        logits[logits < min_value_to_keep] = float('-inf')
    
    # Apply top-p (nucleus) filtering
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    # mask out tokens beyond top_p
    sorted_indices_to_remove = sorted_indices[cumulative_probs > top_p]
    logits[sorted_indices_to_remove] = float('-inf')
    
    # # print how many is left
    # print(f"Left: {len(logits[logits != float('-inf')])}")
    
    return logits

def get_candidates(prefix: str, letter: str, top_k_val=50, top_p_val=0.9):
    """
    Get candidate tokens (string, probability) for the next position,
    filtered by top-k/top-p, then by first-letter constraint.
    """
    input_ids = tokenizer(prefix, return_tensors='pt')['input_ids'].to(device)
    with torch.no_grad():
        outputs = model(input_ids=input_ids)
    
    next_logits = outputs.logits[0, -1, :]
    
    # Optionally adjust logits (e.g., temperature)
    # temperature = 1.0
    # next_logits = next_logits / temperature
    
    # top-k / top-p filtering
    filtered_logits = top_k_top_p_filtering(next_logits.clone(),
                                           top_k=top_k_val,
                                           top_p=top_p_val)
    probs = F.softmax(filtered_logits, dim=-1)
    
    # Collect valid candidates
    candidates = []
    for token_id in range(probs.shape[0]):
        p = probs[token_id].item()
        if p > 0:
            token_str = id_to_token[token_id]
            if middle_of_the_word(token_str) or starts_with_letter(token_str, letter):
                # # replace '▁' => ' ' for nice spacing
                # display_str = token_str.replace('▁', ' ')
                # display_str = f'{token_str}|'
                candidates.append((token_str, p))
    
    # sort descending by probability
    candidates.sort(key=lambda x: x[1], reverse=True)
    return candidates

def backtrack_generation(
    prefix: str, 
    letter: str,
    depth: int,
    max_depth: int,
    used_tokens: set,
    debug_mode: bool
):
    """
    Recursive function that tries to append one token at a time.
    If we cannot find a valid token (dead end), we backtrack.

    :param prefix: Current text.
    :param letter: The designated letter for each next token.
    :param depth: Current recursion depth (number of tokens added so far).
    :param max_depth: Maximum tokens to generate.
    :param used_tokens: A set of tokens we have used so far (for repetition avoidance).
    :param debug_mode: Whether to print debug messages during backtracking.
    :return: Completed string if successful, or None if no solution found.
    """
    # If we've reached a sentence terminator, you could choose to stop here as success:
    if any(prefix.endswith(p) for p in ['.', '!', '?']):
        return prefix  # success
    
    if depth >= max_depth:
        # Reached maximum length
        return prefix  # success - or None if you require punctuation end
    
    # Gather next candidates
    candidates = get_candidates(prefix, letter, top_k_val=30, top_p_val=0.99)
    # Optionally filter out tokens we've used (simple repetition avoidance)
    # (This might be overly strict in some scenarios.)
    filtered = [(t, p) for (t, p) in candidates if t not in used_tokens]
    
    # argsort by probability, start from the highest
    filtered.sort(key=lambda x: x[1], reverse=True)
    
    if not filtered:
        # No valid tokens => Dead end, must backtrack
        if debug_mode:
            print(f"[DEBUG] Dead end. Backtracking from: '{prefix}'")
        return None
    
    # We will try candidates from highest to lowest probability
    # or you can randomize them; for demonstration let's do high->low
    for token_str, prob in filtered:
        new_prefix_raw = prefix + f'|{token_str}'
        # Add token to used tokens (if you want to avoid immediate repetition)
        
        token = token_str.replace('▁', ' ')

        # Extend prefix with this token
        new_prefix = prefix + token

        new_used = used_tokens.copy()
        new_used.add(token)
        
        if debug_mode:
            print(f"[DEBUG] Trying: '{new_prefix_raw}' (depth={depth+1})")
        
        result = backtrack_generation(
            prefix=new_prefix,
            letter=letter,
            depth=depth+1,
            max_depth=max_depth,
            used_tokens=new_used,
            debug_mode=debug_mode
        )
        
        if result is not None:
            # Found a valid completion from here on
            return result
        else:
            # This candidate eventually led to a dead end, so we must backtrack
            if debug_mode:
                print(f"[DEBUG] Backtracking from: '{new_prefix_raw}'")
    
    # If we exhausted all candidates, we must backtrack further
    return None

def generate_sentence_with_backtracking(prefix, max_len=15, debug_mode=False):
    """
    Wrapper that calls the recursive backtracking function.
    """
    # The letter is the first letter of the prefix (lowercased).
    # If you want a different letter, you can pass it explicitly.
    cleaned_prefix = prefix.strip()
    if not cleaned_prefix:
        raise ValueError("Prefix cannot be empty.")
    
    letter = cleaned_prefix[0].lower()
    used_tokens = set()  # to optionally avoid immediate re-use of the same token
    completed = backtrack_generation(
        prefix=prefix,
        letter=letter,
        depth=0,
        max_depth=max_len,
        used_tokens=used_tokens,
        debug_mode=debug_mode
    )
    # If you want to ensure the final text ends with punctuation, you can add logic here:
    if completed is not None and not any(completed.endswith(p) for p in ['.', '!', '?']):
        # Optionally add a period or do further checks
        completed += "."
    return completed

# -------------------------------------------------------------------
# EXAMPLE USAGE
# -------------------------------------------------------------------
if __name__ == "__main__":
    start_txts = [
        "Obowiązuje on od",
        "Został zrodzony ze",
        "Po pierwsze, projekt",
        "Po Panthers przejechali",
        "Duze dwusuwowe diesle",
        "Niestety, nikt nie",
        "Pani poseł, proszę",
        "Proszę państwa, po",
        "Proszę pana posła"
    ]
    
    for start_txt in start_txts:
        print(f"=== Generating for prefix: '{start_txt}' ===")
        # Try debug_mode=True to see backtracking steps
        completion = generate_sentence_with_backtracking(start_txt, max_len=15, debug_mode=True)
        print(f"FINAL COMPLETION: '{completion}'\n")


=== Generating for prefix: 'Obowiązuje on od' ===
[DEBUG] Trying: 'Obowiązuje on od|:' (depth=1)
[DEBUG] Trying: 'Obowiązuje on od:|2' (depth=2)
[DEBUG] Trying: 'Obowiązuje on od:2|0' (depth=3)
[DEBUG] Trying: 'Obowiązuje on od:20|1' (depth=4)
[DEBUG] Trying: 'Obowiązuje on od:201|9' (depth=5)
[DEBUG] Trying: 'Obowiązuje on od:2019|-' (depth=6)
[DEBUG] Dead end. Backtracking from: 'Obowiązuje on od:2019-'
[DEBUG] Backtracking from: 'Obowiązuje on od:2019|-'
[DEBUG] Trying: 'Obowiązuje on od:2019|.' (depth=6)
FINAL COMPLETION: 'Obowiązuje on od:2019.'

=== Generating for prefix: 'Został zrodzony ze' ===
[DEBUG] Trying: 'Został zrodzony ze|▁zł' (depth=1)
[DEBUG] Trying: 'Został zrodzony ze zł|ota' (depth=2)
[DEBUG] Trying: 'Został zrodzony ze złota|,' (depth=3)
[DEBUG] Trying: 'Został zrodzony ze złota,|▁został' (depth=4)
[DEBUG] Trying: 'Został zrodzony ze złota, został|▁z' (depth=5)
[DEBUG] Trying: 'Został zrodzony ze złota, został z|rodz' (depth=6)
[DEBUG] Trying: 'Został zrodzony ze 

In [22]:
start_txts = [
    "Obowiązuje on od",
    "Został zrodzony ze",
    "Po pierwsze, projekt",
    "Po Panthers przejechali",
    "Duze dwusuwowe diesle",
    "Niestety, nikt nie",
    "Pani poseł, proszę",
    "Proszę państwa, po",
    "Proszę pana posła"
]

for start_txt in start_txts:
    print(f"=== Generating for prefix: '{start_txt}' ===")
    # Try debug_mode=True to see backtracking steps
    completion = generate_sentence_with_backtracking(start_txt, max_len=15, debug_mode=False)
    print(f"FINAL COMPLETION: '{completion}'\n")

=== Generating for prefix: 'Obowiązuje on od' ===
FINAL COMPLETION: 'Obowiązuje on od:2019.'

=== Generating for prefix: 'Został zrodzony ze' ===
FINAL COMPLETION: 'Został zrodzony ze złota, został zrodzony z ziemi.'

=== Generating for prefix: 'Po pierwsze, projekt' ===
FINAL COMPLETION: 'Po pierwsze, projektowanie przemysłowe - projektant - projektanci - projektow.'

=== Generating for prefix: 'Po Panthers przejechali' ===
FINAL COMPLETION: 'Po Panthers przejechali przez park, policja poszukuje psa - PortalPisar.'

=== Generating for prefix: 'Duze dwusuwowe diesle' ===
FINAL COMPLETION: 'Duze dwusuwowe diesle - DziałajLokalnie.pl\nDuży diesel, duże.'

=== Generating for prefix: 'Niestety, nikt nie' ===
FINAL COMPLETION: 'Niestety, nikt nie napisał, nawet na nasz.'

=== Generating for prefix: 'Pani poseł, proszę' ===
FINAL COMPLETION: 'Pani poseł, proszę powołać pełną, praworządna parlamentarno-gab.'

=== Generating for prefix: 'Proszę państwa, po' ===
FINAL COMPLETION: 'Proszę państw

In [23]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.nn import functional as F
import random
import string

# Initialize the model and tokenizer
model_name = "eryk-mazus/polka-1.1b"  # Example model
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

# Invert vocabulary to quickly map IDs -> tokens
vocab = tokenizer.get_vocab()
id_to_token = {v: k for k, v in vocab.items()}

def strip_special(token: str) -> str:
    """
    Removes special SentencePiece characters like '▁' from the start of tokens.
    """
    return token.lstrip("▁")

def starts_with_letter(token: str, letter: str) -> bool:
    """
    Returns True if 'token' (after removing special chars) starts with the given letter.
    """
    # If equal to one special token, return False
    if token == '▁':
        return False
    # Punctuation is ok
    if token[0] == '▁' and len(token) == 2 and token[1] in string.punctuation:
        return True
    # Split the token by '▁' and check each sub-word
    for sub_word in token.split('▁')[1:]:
        if sub_word and sub_word[0].lower() != letter:
            return False
    return True

def middle_of_the_word(token: str) -> bool:
    """
    Returns True if the token does not start with special character.
    """
    return '▁' not in token

def top_k_top_p_filtering(logits, top_k=50, top_p=0.9):
    """
    Filters a distribution of logits using top-k and nucleus (top-p) filtering.
    Sets probabilities of tokens outside the top-k or top-p to -inf.
    """
    # Apply top-k
    if top_k > 0:
        # Get top_k indices
        values_to_keep, _ = torch.topk(logits, top_k)
        min_value_to_keep = values_to_keep[-1]
        logits[logits < min_value_to_keep] = float('-inf')
    
    # Apply top-p (nucleus) filtering
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    # Mask out tokens beyond top_p
    sorted_indices_to_remove = sorted_indices[cumulative_probs > top_p]
    logits[sorted_indices_to_remove] = float('-inf')
    
    return logits

def get_candidates(prefix: str, letter: str, top_k_val=50, top_p_val=0.9):
    """
    Get candidate tokens (string, probability) for the next position,
    filtered by top-k/top-p, then by first-letter constraint.
    """
    input_ids = tokenizer(prefix, return_tensors='pt')['input_ids'].to(device)
    with torch.no_grad():
        outputs = model(input_ids=input_ids)
    
    next_logits = outputs.logits[0, -1, :]
    
    # Apply top-k / top-p filtering
    filtered_logits = top_k_top_p_filtering(next_logits.clone(),
                                           top_k=top_k_val,
                                           top_p=top_p_val)
    probs = F.softmax(filtered_logits, dim=-1)
    
    # Collect valid candidates
    candidates = []
    for token_id in range(probs.shape[0]):
        p = probs[token_id].item()
        if p > 0:
            token_str = id_to_token[token_id]
            if middle_of_the_word(token_str) or starts_with_letter(token_str, letter):
                candidates.append((token_str, p))
    
    # Sort descending by probability
    candidates.sort(key=lambda x: x[1], reverse=True)
    return candidates

def backtrack_generation(
    prefix: str, 
    letter: str,
    depth: int,
    max_depth: int,
    used_tokens: set,
    tree: dict
):
    """
    Recursive function that tries to append one token at a time.
    If we cannot find a valid token (dead end), we backtrack.

    :param prefix: Current text.
    :param letter: The designated letter for each next token.
    :param depth: Current recursion depth (number of tokens added so far).
    :param max_depth: Maximum tokens to generate.
    :param used_tokens: A set of tokens we have used so far (for repetition avoidance).
    :param tree: The nested dictionary representing the generation tree.
    :return: Completed string if successful, or None if no solution found.
    """
    # If we've reached a sentence terminator, you could choose to stop here as success:
    if any(prefix.endswith(p) for p in ['.', '!', '?']):
        return prefix  # success
    
    if depth >= max_depth:
        # Reached maximum length
        return prefix  # success - or None if you require punctuation end
    
    # Gather next candidates
    candidates = get_candidates(prefix, letter, top_k_val=30, top_p_val=0.99)
    # Optionally filter out tokens we've used (simple repetition avoidance)
    filtered = [(t, p) for (t, p) in candidates if t not in used_tokens]
    
    # Sort by probability descending
    filtered.sort(key=lambda x: x[1], reverse=True)
    
    if not filtered:
        # No valid tokens => Dead end, must backtrack
        # Mark dead end in the tree
        tree['__dead_end__'] = True
        return None
    
    # Try each candidate
    for token_str, prob in filtered:
        # Extend prefix with this token
        token = token_str.replace('▁', ' ')
        new_prefix = prefix + token
        
        # Add token to used tokens to avoid repetition
        new_used = used_tokens.copy()
        new_used.add(token)
        
        # Add this token to the tree
        tree[token_str] = {}
        
        # Recurse with the new prefix and updated tree
        result = backtrack_generation(
            prefix=new_prefix,
            letter=letter,
            depth=depth+1,
            max_depth=max_depth,
            used_tokens=new_used,
            tree=tree[token_str]
        )
        
        if result is not None:
            # Found a valid completion from here on
            return result
        else:
            # This candidate eventually led to a dead end, continue with next candidate
            continue
    
    # If all candidates lead to dead ends, mark and return None
    tree['__dead_end__'] = True
    return None

def generate_sentence_with_backtracking(prefix, max_len=15, debug_mode=False):
    """
    Wrapper that calls the recursive backtracking function and collects the generation tree.
    
    :param prefix: The initial text prefix.
    :param max_len: Maximum number of tokens to generate.
    :param debug_mode: If True, stores the generation steps; else, silent.
    :return: Completed string and the generation tree.
    """
    # The letter is the first letter of the prefix (lowercased).
    cleaned_prefix = prefix.strip()
    if not cleaned_prefix:
        raise ValueError("Prefix cannot be empty.")
    
    letter = cleaned_prefix[0].lower()
    used_tokens = set()  # to optionally avoid immediate re-use of the same token
    tree = {}
    
    completed = backtrack_generation(
        prefix=prefix,
        letter=letter,
        depth=0,
        max_depth=max_len,
        used_tokens=used_tokens,
        tree=tree
    )
    
    # If you want to ensure the final text ends with punctuation, you can add logic here:
    if completed is not None and not any(completed.endswith(p) for p in ['.', '!', '?']):
        # Optionally add a period or do further checks
        completed += "."
    
    return completed, tree

def print_tree(d, indent="", last=True):
    """
    Recursively prints a nested dictionary in a tree-like format.
    
    :param d: The nested dictionary.
    :param indent: The indentation string (used in recursion).
    :param last: Boolean indicating if this is the last child.
    """
    # Handle the root case where d is empty
    if not d:
        return
    
    keys = list(d.keys())
    for i, key in enumerate(keys):
        is_last = i == (len(keys) - 1)
        connector = "└── " if is_last else "├── "
        print(indent + connector + key)
        if isinstance(d[key], dict):
            extension = "    " if is_last else "│   "
            print_tree(d[key], indent + extension, is_last)

# -------------------------------------------------------------------
# EXAMPLE USAGE
# -------------------------------------------------------------------
if __name__ == "__main__":
    start_txts = [
        "Obowiązuje on od",
        "Został zrodzony ze",
        "Po pierwsze, projekt",
        "Po Panthers przejechali",
        "Duze dwusuwowe diesle",
        "Niestety, nikt nie",
        "Pani poseł, proszę",
        "Proszę państwa, po",
        "Proszę pana posła"
    ]
    
    for start_txt in start_txts:
        print(f"=== Generating for prefix: '{start_txt}' ===")
        # Generate sentence and collect the generation tree
        completion, generation_tree = generate_sentence_with_backtracking(
            start_txt, 
            max_len=15, 
            debug_mode=False  # debug_mode is now handled via the tree
        )
        print(f"FINAL COMPLETION: '{completion}'\n")
        
        # Optionally, print the generation tree
        print("Generation Tree:")
        print_tree(generation_tree)
        print("\n" + "="*80 + "\n")

'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: fb30f72b-4198-409c-9cc7-ec852d30a1b4)')' thrown while requesting HEAD https://huggingface.co/eryk-mazus/polka-1.1b/resolve/main/tokenizer_config.json


=== Generating for prefix: 'Obowiązuje on od' ===
FINAL COMPLETION: 'Obowiązuje on od:2019.'

Generation Tree:
└── :
    └── 2
        └── 0
            └── 1
                └── 9
                    ├── -
                    │   └── __dead_end__
                    └── .


=== Generating for prefix: 'Został zrodzony ze' ===
FINAL COMPLETION: 'Został zrodzony ze złota, został zrodzony z ziemi.'

Generation Tree:
└── ▁zł
    └── ota
        └── ,
            └── ▁został
                └── ▁z
                    └── rodz
                        └── ony
                            └── ▁z
                                └── ▁z
                                    └── iem
                                        └── i
                                            └── .


=== Generating for prefix: 'Po pierwsze, projekt' ===
FINAL COMPLETION: 'Po pierwsze, projektowanie przemysłowe - projektant - projektanci - projektow.'

Generation Tree:
└── owanie
    └── ▁prz
        ├── estr
        │   └