In [4]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from tqdm import tqdm

In [5]:
model_name = "Qwen/Qwen3-0.6B-Base"  # Replace with your checkpoint

In [31]:
# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="mps",
    use_cache=True
).eval()

# prepare the model input
prompt = "Once upon a time, there was a Wait,"
messages = [
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False # Switches between thinking and non-thinking modes. Default is True.
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

In [5]:
# conduct text completion
generated_ids = model.generate(**model_inputs, max_new_tokens=1200)
output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist()

# parsing thinking content
try:
    # rindex finding 151668 (</think>)
    index = len(output_ids) - output_ids[::-1].index(151668)
except ValueError:
    index = 0

thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip(
    "\n"
)
content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")

print("thinking content:", thinking_content)
print("content:", content)


Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


thinking content: 
content:  ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇ ⚇


In [46]:
import torch
import torch.nn.functional as F
from tqdm import tqdm

def generate_with_injection(model, tokenizer, model_inputs, max_new_tokens=512, top_p=0.7, temperature=1.0):
    # Setup - precompute all token IDs
    eos_token_id = tokenizer.eos_token_id
    wait_token_ids = torch.tensor(tokenizer.encode("Wait,", add_special_tokens=False), 
                                device=model_inputs["input_ids"].device)
    think_close_id = 151668  # Predefined token ID
    
    # Initialize tensors
    input_ids = model_inputs["input_ids"]
    attention_mask = model_inputs["attention_mask"]
    generated = input_ids.clone()
    
    # Generation parameters
    injected = False
    past_key_values = None  # For caching
    
    # Initialize progress bar
    pbar = tqdm(total=max_new_tokens, desc="Generating tokens", unit="token")
    
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Forward pass with caching if supported
            if past_key_values is not None and hasattr(model, 'use_cache') and model.use_cache:
                outputs = model(
                    input_ids=generated[:, -1:],  # Only pass the last token
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=True
                )
            else:
                outputs = model(
                    input_ids=generated,
                    attention_mask=attention_mask
                )
            
            logits = outputs.logits[:, -1, :] / temperature
            past_key_values = outputs.past_key_values if hasattr(outputs, 'past_key_values') else None
            
            # Top-p sampling (nucleus sampling)
            probs = F.softmax(logits, dim=-1)
            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
            
            # Remove tokens with cumulative probability above the threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            
            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            logits[:, indices_to_remove] = -float('Inf')
            
            # Sample from the filtered distribution
            next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            
            # Injection logic
            if next_token.item() == eos_token_id and not injected:
                generated = torch.cat([generated, wait_token_ids.unsqueeze(0)], dim=1)
                attention_mask = torch.cat([
                    attention_mask, 
                    torch.ones((1, len(wait_token_ids)), device=attention_mask.device)
                ], dim=1)
                injected = True
                pbar.update(len(wait_token_ids))  # Update progress for multiple tokens
                continue
            
            generated = torch.cat([generated, next_token], dim=1)
            attention_mask = torch.cat([
                attention_mask, 
                torch.ones((1, 1), device=attention_mask.device)
            ], dim=1)
            pbar.update(1)  # Update progress for single token
            
            # Stop condition
            if next_token.item() == eos_token_id and injected:
                break
    
    pbar.close()
    
    # Post-processing
    output_ids = generated[0].tolist()
    
    try:
        index = len(output_ids) - output_ids[::-1].index(think_close_id)
    except ValueError:
        index = 0
    
    thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
    content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
    
    return {"thinking_content": thinking_content, "content": content}

In [34]:
# prepare the model input
prompt = "Once upon a time, there was a Wait,"
messages = [
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False # Switches between thinking and non-thinking modes. Default is True.
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

In [39]:
print(tokenizer.decode(model_inputs["input_ids"][0]))

<|im_start|>user
Once upon a time, there was a Wait,<|im_end|>
<|im_start|>assistant
<think>

</think>




In [60]:
model_inputs = tokenizer(["In the word strawberry there are 4 r's or Wait, let me think again"], return_tensors="pt").to(model.device)

In [64]:
result = generate_with_injection(model, tokenizer, model_inputs, 40)
print("thinking content:", result["thinking_content"])
print("content:", result["content"])

Generating tokens: 100%|██████████| 40/40 [00:05<00:00,  6.67token/s]

thinking content: 
content: In the word strawberry there are 4 r's or Wait, let me think again, I'll make this simpler, I just want to know if you can understand this question. The word strawberry has 4r's or Wait, I'll make this simpler, I just want to



