# True Logits Masking with Llama (Open-Source Models)
## Using custom LogitsProcessor for complex business rules

This notebook demonstrates **true logits masking** with full control over token generation.

**Key difference from Notebook 1:**
- ✅ Direct access to logits array
- ✅ Can implement ANY custom logic
- ✅ Context-aware decisions
- ✅ Dynamic rules based on generation state
- ✅ No API limitations

**Use case:** Medical or Legal advice generator with strict compliance rules

**REQUIRES:** a `.env` file containing your Hugging Face token as `HF_TOKEN`. To obtain a token, visit `https://huggingface.co/`.

**PS.** Don't forget to stop the Jupyter notebook server when you are done: `jupyter notebook stop 8888`

## Installation

In [None]:
# Install required packages
# Note: This requires GPU for reasonable performance
!pip install transformers torch accelerate sentencepiece charset-normalizer python-dotenv

## Setup

In [None]:
import torch
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer, 
    LogitsProcessor,
    LogitsProcessorList
)
import re
from typing import List
import os
from dotenv import load_dotenv

# Load environment variables (for HF_TOKEN)
load_dotenv()

# Get HuggingFace token from environment
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    raise ValueError("HF_TOKEN not found in .env file. Please add your Hugging Face token.")


# Check device availability: prefer CUDA, then MPS (Apple Silicon), then CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"  # Apple Silicon GPU acceleration
else:
    device = "cpu"

print(f"Using device: {device}")

if device == "cpu":
    print("⚠️ Warning: Running on CPU will be slow. GPU recommended.")
elif device == "mps":
    print("✅ Using Apple Silicon GPU acceleration (MPS)")

## Load Model and Tokenizer

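The default below is `Qwen/Qwen2.5-3B-Instruct`, which needs no access request. Gated Llama alternatives include:
- `meta-llama/Llama-3.2-1B-Instruct` (smallest, fastest)
- `meta-llama/Llama-3.2-3B-Instruct` (medium)
- `meta-llama/Meta-Llama-3-8B-Instruct` (largest, best quality)

**Note:** For Llama models you need to accept the license on Hugging Face and set `HF_TOKEN`.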

In [None]:
# Choose your model - Updated with non-gated alternatives
# MODEL_NAME = "microsoft/Phi-3.5-mini-instruct"  # 3.8B - RECOMMENDED, no access needed
MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"  # 3B - Good quality, no access needed
# MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"  # Smallest/fastest (requires access)
# MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"  # Good balance (requires access)
# MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"  # Best quality (requires access)

# Get HuggingFace token from environment (optional for non-gated models)
hf_token = os.getenv("HF_TOKEN")

print(f"Loading model: {MODEL_NAME}...")
print("This may take a few minutes on first run (downloads model)...\n")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    dtype=torch.bfloat16,  # Use bfloat16 for efficiency
    device_map="auto",  # Automatically use GPU/MPS if available
    token=hf_token
)

print("✅ Model loaded successfully!")
print(f"   Model parameters: {model.num_parameters() / 1e9:.1f}B")
print(f"   Vocabulary size: {len(tokenizer)}")

## Use Case: Medical Advice Generator with Compliance Rules

We'll create a medical advice chatbot that MUST follow strict compliance rules:

### Rules:
1. **Never claim certainty** - Ban words like "definitely", "guaranteed", "always"
2. **No diagnoses** - Ban "you have", "diagnosed with"
3. **Include disclaimers** - If discussing medication, steer the answer towards advising a doctor consultation
4. **Professional language only** - Ban informal words in medical context

Rules 1 and 4 are static bans, but rules 2 and 3 are **context-aware** and **dynamic**: they depend on the text generated so far, which a fixed `logit_bias` map cannot express.
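To make the static-versus-contextual distinction concrete, here is a pure-Python sketch over a hypothetical five-token vocabulary (no model involved). `apply_static_bias` mimics a `logit_bias`-style map that applies the same penalty at every step; `apply_context_rule` bans " have" and " are" only when the previous token is "you". Both helpers and the token IDs are made up for illustration.

```python
import math
from typing import Dict, List, Optional

# Hypothetical toy vocabulary: token id -> surface string
VOCAB: Dict[int, str] = {0: "you", 1: " have", 2: " may", 3: " are", 4: " feel"}

# A static bias map in the style of an API logit_bias parameter:
# the same adjustment is applied at every step, whatever came before.
STATIC_BIAS: Dict[int, float] = {1: -100.0}


def apply_static_bias(scores: List[float]) -> List[float]:
    return [s + STATIC_BIAS.get(i, 0.0) for i, s in enumerate(scores)]


def apply_context_rule(prev_token_id: Optional[int], scores: List[float]) -> List[float]:
    """Ban " have" and " are" only when the previous token is "you"."""
    out = list(scores)
    if prev_token_id is not None and VOCAB[prev_token_id].strip().lower() == "you":
        out[1] = -math.inf
        out[3] = -math.inf
    return out


scores = [1.0, 2.0, 0.5, 1.5, 0.2]
after_you = apply_context_rule(0, scores)    # previous token "you": ban fires
after_feel = apply_context_rule(4, scores)   # previous token " feel": no change
static = apply_static_bias(scores)           # penalises " have" unconditionally
print(after_you)
print(after_feel)
print(static)
```

The static map has no way to tell "you have" from "we have", so it either over-blocks or under-blocks; the contextual rule only needs to look at `input_ids`, which a `LogitsProcessor` receives on every call.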

## Implementation: Custom LogitsProcessor

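`LogitsProcessor.__call__(input_ids, scores)` is invoked at every generation step with the token IDs so far (prompt included) and the raw next-token scores. Whatever scores it returns are what the next token is chosen from, which gives us full control over the logits.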
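The contract can be illustrated without torch or transformers. The sketch below is a hypothetical toy decoding loop (`toy_model_scores`, `ban_token_one` and `greedy_decode` are invented names): at each step it asks a fake "model" for scores, passes them through every processor with the same `(input_ids, scores) -> scores` shape as `LogitsProcessor.__call__`, then picks the argmax. Real `model.generate()` does much more (sampling, caching, stopping criteria), so treat this only as a mental model.

```python
import math
from typing import Callable, List

Processor = Callable[[List[int], List[float]], List[float]]


def toy_model_scores(input_ids: List[int]) -> List[float]:
    """Hypothetical stand-in for a model forward pass: fixed scores for 4 tokens."""
    return [0.1, 3.0, 2.0, 0.5]  # token 1 is the model's favourite every time


def ban_token_one(input_ids: List[int], scores: List[float]) -> List[float]:
    """Same (input_ids, scores) -> scores shape as LogitsProcessor.__call__."""
    scores = list(scores)
    scores[1] = -math.inf
    return scores


def greedy_decode(prompt: List[int], processors: List[Processor], steps: int) -> List[int]:
    ids = list(prompt)
    for _ in range(steps):
        scores = toy_model_scores(ids)
        for proc in processors:  # every processor runs at every step
            scores = proc(ids, scores)
        ids.append(max(range(len(scores)), key=scores.__getitem__))  # greedy argmax
    return ids


print(greedy_decode([0], [], steps=3))
print(greedy_decode([0], [ban_token_one], steps=3))
```

Without the processor the favourite token 1 is chosen every step; with it, the next-best token 2 takes its place.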

In [None]:
class MedicalComplianceLogitsProcessor(LogitsProcessor):
    """
    Enforces medical compliance rules by masking tokens.
    
    This demonstrates TRUE logits masking:
    - Direct access to logits
    - Context-aware rules
    - Dynamic behavior based on generated text
    """
    
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        # Prompt length is recorded on the first call; create a new processor per generate() call
        self.prompt_length = None
        
        # Rule 1: Banned certainty words
        self.certainty_words = [
            "definitely", "guaranteed", "always", "never", 
            "certainly", "absolutely", "undoubtedly"
        ]
        
        # Rule 2: Banned diagnostic phrases, enforced as "trigger word -> banned next word"
        self.diagnostic_phrases = [
            "you have", "diagnosed with", "you are suffering"
        ]
        self.diagnostic_triggers = {
            "you": [" have", " are"],   # blocks "you have" and "you are (suffering)"
            "diagnosed": [" with"],     # blocks "diagnosed with"
        }
        
        # Rule 3: Informal words banned in medical context
        self.informal_words = [
            "gonna", "wanna", "yeah", "nope", "kinda", "sorta"
        ]
        
        # Pre-compute token IDs for the variants " always", "always", " Always", "Always".
        # A variant that is ONE token is banned outright. A variant that splits into several
        # tokens is banned by masking only its LAST token right after its earlier tokens were
        # generated; banning every sub-token (e.g. a bare "k" from "kinda") would also block
        # many unrelated words.
        self.single_token_bans = set()
        self.multi_token_bans = []
        for word in self.certainty_words + self.informal_words:
            variants = [" " + word, word, " " + word.capitalize(), word.capitalize()]
            for variant in variants:
                tokens = self.tokenizer.encode(variant, add_special_tokens=False)
                if len(tokens) == 1:
                    self.single_token_bans.add(tokens[0])
                elif len(tokens) > 1:
                    self.multi_token_bans.append(tokens)
        
        print(f"   Banned {len(self.single_token_bans)} single-token variants "
              f"and {len(self.multi_token_bans)} multi-token variants")
    
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """
        This is called at EVERY token generation step.
        
        Args:
            input_ids: Prompt plus tokens generated so far [batch_size, seq_len]
            scores: Raw logits for next token [batch_size, vocab_size]
        
        Returns:
            Modified scores (banned tokens set to -inf, disclaimer tokens boosted)
        
        Note: context checks read row 0 only, i.e. this assumes batch_size == 1.
        """
        # On the first call input_ids holds only the prompt. Remember its length so the
        # rules below inspect the model's answer, not the system prompt or the question.
        if self.prompt_length is None:
            self.prompt_length = input_ids.shape[1]
        
        generated_ids = input_ids[0, self.prompt_length:].tolist()
        generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        
        # Rule 1: Always ban certainty and informal words (unconditional)
        if self.single_token_bans:
            scores[:, list(self.single_token_bans)] = float('-inf')
        for seq in self.multi_token_bans:
            prefix = seq[:-1]
            if len(generated_ids) >= len(prefix) and generated_ids[-len(prefix):] == prefix:
                scores[:, seq[-1]] = float('-inf')
        
        # Rule 2: Context-aware - block the next word of a diagnostic phrase, but only when
        # the answer so far ENDS with the trigger word (word boundary, so "your" won't match)
        for trigger, next_words in self.diagnostic_triggers.items():
            if re.search(rf"\b{re.escape(trigger)}\s*$", generated_text, flags=re.IGNORECASE):
                for next_word in next_words:
                    first_token = self.tokenizer.encode(next_word, add_special_tokens=False)[0]
                    scores[:, first_token] = float('-inf')
        
        # Rule 3: Dynamic - if the answer mentions medication but no doctor yet,
        # boost disclaimer tokens. This is a soft nudge (+5.0), not a guarantee.
        lowered = generated_text.lower()
        medication_keywords = ["medication", "medicine", "drug", "prescription"]
        mentions_medication = any(keyword in lowered for keyword in medication_keywords)
        mentions_doctor = any(word in lowered for word in ["doctor", "physician", "healthcare provider"])
        
        if mentions_medication and not mentions_doctor:
            consult_tokens = self.tokenizer.encode(" consult", add_special_tokens=False)
            doctor_tokens = self.tokenizer.encode(" doctor", add_special_tokens=False)
            
            for token_id in consult_tokens + doctor_tokens:
                scores[:, token_id] += 5.0  # Boost these tokens
        
        return scores

print("✅ Custom LogitsProcessor defined!")
print("\nThis processor:")
print("  1. Bans certainty and informal words in space/capitalization variants (' always', 'Always', ...)")
print("  2. Prevents diagnostic language (context-aware)")
print("  3. Promotes disclaimer if medication mentioned (dynamic)")
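The processor mixes two kinds of intervention: setting a score to `-inf` (rules 1 and 2) and adding a bias (rule 3). The pure-Python sketch below, on hypothetical scores for three tokens A, B and C, shows why they behave differently after softmax with the notebook's temperature of 0.7: a masked token gets probability exactly zero, while a boosted token only becomes more likely.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax; -inf entries get probability exactly 0."""
    m = max(logits)                           # assumes at least one finite logit
    exps = [math.exp(x - m) for x in logits]  # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


TEMPERATURE = 0.7              # same value generate_medical_advice passes
logits = [2.0, 1.0, 0.0]       # toy scores for tokens A, B, C

hard = list(logits)
hard[0] = float("-inf")        # hard mask on A, like the certainty-word ban

soft = list(logits)
soft[2] += 5.0                 # soft boost on C, like the " consult"/" doctor" rule

p_hard = softmax([x / TEMPERATURE for x in hard])
p_soft = softmax([x / TEMPERATURE for x in soft])

print([round(p, 3) for p in p_hard])
print([round(p, 3) for p in p_soft])
```

Dividing `-inf` by a temperature still yields `-inf`, so hard masks survive temperature scaling; the boosted token, by contrast, can still lose the draw when sampling.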

## Helper Function: Generate with Compliance

In [None]:
def generate_medical_advice(question: str, use_masking: bool = True) -> str:
    """
    Generate medical advice with optional compliance masking.
    
    Args:
        question: User's medical question
        use_masking: If True, apply compliance rules
    
    Returns:
        Generated advice
    """
    
    system_prompt = """You are a helpful medical information assistant. 
Provide accurate information but always remind users to consult healthcare professionals 
for personal medical advice."""
    
    # Format as chat conversation
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question}
    ]
    
    # Convert to model input format
    input_text = tokenizer.apply_chat_template(
        messages, 
        tokenize=False, 
        add_generation_prompt=True
    )
    
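    # The chat template already contains any special tokens (e.g. BOS), so don't add them twice
    input_ids = tokenizer.encode(input_text, add_special_tokens=False, return_tensors="pt").to(model.device)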
    
    # Setup logits processor
    logits_processor = LogitsProcessorList()
    if use_masking:
        logits_processor.append(MedicalComplianceLogitsProcessor(tokenizer))
    
    # Generate
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=200,
            logits_processor=logits_processor,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
    
    # Decode only the new tokens (not the input)
    output_text = tokenizer.decode(
        output_ids[0][input_ids.shape[1]:], 
        skip_special_tokens=True
    )
    
    return output_text.strip()

print("✅ Generation function ready!")

## Test 1: Compare With vs Without Masking

Let's see the difference in behavior.

In [None]:
test_question = "I have a headache that won't go away. What medication should I take?"

print("Question:", test_question)
print("=" * 80)

print("\n❌ WITHOUT MASKING (may violate compliance):")
print("-" * 80)
answer_without = generate_medical_advice(test_question, use_masking=False)
print(answer_without)

# Check for violations
violations = []
if any(word in answer_without.lower() for word in ["definitely", "guaranteed", "always"]):
    violations.append("❌ Contains certainty words")
if "you have" in answer_without.lower():
    violations.append("❌ Contains diagnostic language")
if any(word in answer_without.lower() for word in ["medication", "medicine"]) and \
   not any(word in answer_without.lower() for word in ["doctor", "physician", "healthcare provider"]):
    violations.append("❌ Mentions medication without disclaimer")

if violations:
    print("\n⚠️ COMPLIANCE VIOLATIONS:")
    for v in violations:
        print(f"  {v}")
else:
    print("\n✅ No violations detected (got lucky!)")

print("\n" + "=" * 80)
print("\n✅ WITH MASKING (hard bans enforced, disclaimer encouraged):")
print("-" * 80)
answer_with = generate_medical_advice(test_question, use_masking=True)
print(answer_with)

# Check for violations
violations = []
if any(word in answer_with.lower() for word in ["definitely", "guaranteed", "always"]):
    violations.append("❌ Contains certainty words")
if "you have" in answer_with.lower():
    violations.append("❌ Contains diagnostic language")
if any(word in answer_with.lower() for word in ["medication", "medicine"]) and \
   not any(word in answer_with.lower() for word in ["doctor", "physician", "healthcare provider"]):
    violations.append("❌ Mentions medication without disclaimer")

if violations:
    print("\n⚠️ COMPLIANCE VIOLATIONS (unexpected for the hard-masked rules):")

    # Possible causes:
    # - Tokenization variants: " always", "Always" and "always" can map to different
    #   token IDs, so each variant must be banned. The processor covers the four
    #   space/capitalization variants, but not e.g. "ALWAYS".
    # - The disclaimer rule is only a +5.0 boost, not a mask, so a missing doctor
    #   mention is still possible.

    for v in violations:
        print(f"  {v}")
else:
    print("\n✅ NO VIOLATIONS - Compliance rules enforced!")
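Tokenization variants are only half of the problem: a word may also split into several tokens, and banning every piece of it bans those pieces everywhere else (if, say, " undoubtedly" split into " und" + "oubtedly", banning " und" would also block " under"). A tighter approach, similar in spirit to the `bad_words_ids` option of `generate()`, masks only the final token of a banned sequence, and only when the preceding tokens were just generated. The sketch below uses hypothetical token IDs; `banned_next_tokens` is an invented helper, not a library function.

```python
from typing import List, Sequence, Set


def banned_next_tokens(generated: Sequence[int], banned_sequences: List[List[int]]) -> Set[int]:
    """Return the token ids that would complete a banned sequence at the next step."""
    out: Set[int] = set()
    for seq in banned_sequences:
        prefix, last = seq[:-1], seq[-1]
        if not prefix:
            out.add(last)  # single-token word: ban unconditionally
        elif len(generated) >= len(prefix) and list(generated[-len(prefix):]) == prefix:
            out.add(last)  # earlier pieces just generated: block the final piece
    return out


# Hypothetical ids: a word split into [501, 502], and a single-token word [77]
BANNED = [[501, 502], [77]]

print(banned_next_tokens([10, 501], BANNED))
print(banned_next_tokens([10, 11], BANNED))
```

Token 501 on its own is never banned, so other words that start with the same piece remain available.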

## Test 2: Demonstrate Context-Aware Rules

Let's test rules that depend on what's been generated so far.
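Context-aware rules are only as good as their trigger detection. A plain substring test such as `"you" in recent_text` also fires on "your", "young", or a "you" forty characters back, which blocks " have" and " are" far more often than intended. A word-boundary check anchored at the end of the text is tighter. `ends_with_word` is a hypothetical helper illustrating the idea with Python's `re` module.

```python
import re


def ends_with_word(text: str, word: str) -> bool:
    """True if `text` ends with `word` as a whole word (case-insensitive),
    ignoring trailing whitespace."""
    pattern = r"\b" + re.escape(word) + r"\s*$"
    return re.search(pattern, text, flags=re.IGNORECASE) is not None


print(ends_with_word("Based on what you", "you"))
print(ends_with_word("Thank You ", "you"))
print(ends_with_word("Check your", "you"))
print("you" in "check your")  # the naive substring test
```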

In [None]:
test_questions = [
    "What are the benefits of regular exercise?",
    "Can you diagnose my chest pain?",
    "What should I know about taking antibiotics?"
]

print("Testing context-aware compliance rules...\n")

for i, question in enumerate(test_questions, 1):
    print(f"\n{'='*80}")
    print(f"TEST {i}: {question}")
    print('='*80)
    
    answer = generate_medical_advice(question, use_masking=True)
    print(answer)
    
    # Analysis
    print("\n📊 Analysis:")
    if "diagnose" in question.lower():
        if "you have" not in answer.lower():
            print("  ✅ Avoided diagnostic language")
        else:
            print("  ❌ Used diagnostic language (shouldn't happen!)")
    
    if any(word in question.lower() for word in ["medication", "antibiotic", "medicine"]):
        if any(word in answer.lower() for word in ["doctor", "physician", "healthcare"]):
            print("  ✅ Included disclaimer about consulting doctor")
        else:
            print("  ⚠️ Missing doctor consultation disclaimer")
    
    # Whole-word match, so e.g. "whenever" is not counted as "never"
    certainty_words = ["definitely", "guaranteed", "always", "never", "certainly"]
    if any(re.search(rf"\b{word}\b", answer.lower()) for word in certainty_words):
        print("  ❌ Contains certainty words (shouldn't happen!)")
    else:
        print("  ✅ No certainty words")

print("\n" + "="*80)
print("✅ All tests completed!")

## Advanced Example: Custom Rule Based on Token Count

Let's create a more complex processor that changes behavior based on response length.
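The core of any length rule is knowing how many tokens are new. A processor only ever sees prompt plus output, so the sketch records the input length on its first call, the same trick `AdaptiveMedicalProcessor` uses. `LengthWindow` is a hypothetical class working on plain Python lists: it keeps EOS banned until `min_new` new tokens exist and forces EOS once `max_new` are reached. Because the prompt length is stored on the instance, a fresh instance is needed per generation.

```python
import math
from typing import List, Optional


class LengthWindow:
    """Hypothetical helper: ban EOS before min_new new tokens, force it at max_new."""

    def __init__(self, eos_id: int, min_new: int, max_new: int):
        self.eos_id = eos_id
        self.min_new = min_new
        self.max_new = max_new
        self.prompt_len: Optional[int] = None

    def __call__(self, input_ids: List[int], scores: List[float]) -> List[float]:
        if self.prompt_len is None:  # the first call sees only the prompt
            self.prompt_len = len(input_ids)
        n_new = len(input_ids) - self.prompt_len
        scores = list(scores)
        if n_new < self.min_new:
            scores[self.eos_id] = -math.inf        # too short: cannot stop yet
        elif n_new >= self.max_new:
            scores = [-math.inf] * len(scores)     # too long: only EOS remains
            scores[self.eos_id] = 0.0
        return scores


w = LengthWindow(eos_id=0, min_new=2, max_new=4)
prompt = [5, 6, 7]
toy = [1.0, 0.5, 0.2]
print(w(prompt, toy))                  # 0 new tokens
print(w(prompt + [1, 1], toy))         # 2 new tokens
print(w(prompt + [1, 1, 1, 1], toy))   # 4 new tokens
```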

In [None]:
class AdaptiveMedicalProcessor(LogitsProcessor):
    """
    Advanced processor with adaptive rules.
    
    Rules change based on:
    - Generation length
    - Content generated so far
    - Specific patterns detected
    """
    
    def __init__(self, tokenizer, min_length=50, max_length=150):
        self.tokenizer = tokenizer
        self.min_length = min_length
        self.max_length = max_length
        # Set on the first call; create a fresh processor for every generate() call
        self.initial_length = None
        
        # Rule 4 token IDs, pre-computed once. Only single-token variants are banned,
        # so sub-word pieces of multi-token words are not blocked by accident.
        self.certainty_token_ids = set()
        for word in ["definitely", "guaranteed", "always"]:
            for variant in (" " + word, word, " " + word.capitalize(), word.capitalize()):
                tokens = self.tokenizer.encode(variant, add_special_tokens=False)
                if len(tokens) == 1:
                    self.certainty_token_ids.add(tokens[0])
        self.consult_token_ids = self.tokenizer.encode(" consult", add_special_tokens=False)
        
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        
        # Track initial (prompt) length
        if self.initial_length is None:
            self.initial_length = input_ids.shape[1]
        
        generated_length = input_ids.shape[1] - self.initial_length
        # Decode only the answer so far (assumes batch_size == 1)
        generated_text = self.tokenizer.decode(input_ids[0, self.initial_length:], skip_special_tokens=True)
        
        # Rule 1: Enforce minimum length by banning end-of-sequence.
        # Note: some chat models also stop on extra end-of-turn tokens listed in
        # model.generation_config.eos_token_id; only tokenizer.eos_token_id is handled here.
        if generated_length < self.min_length:
            eos_token_id = self.tokenizer.eos_token_id
            if eos_token_id is not None:
                scores[:, eos_token_id] = float('-inf')
        
        # Rule 2: Enforce maximum length by forcing end-of-sequence
        if generated_length >= self.max_length:
            scores[:, :] = float('-inf')
            if self.tokenizer.eos_token_id is not None:
                scores[:, self.tokenizer.eos_token_id] = 0.0
        
        # Rule 3: Once past 70% of min_length, boost "consult" until it appears
        if generated_length > self.min_length * 0.7:
            if "consult" not in generated_text.lower():
                for token_id in self.consult_token_ids:
                    scores[:, token_id] += 3.0
        
        # Rule 4: Never use certainty words (always active)
        if self.certainty_token_ids:
            scores[:, list(self.certainty_token_ids)] = float('-inf')
        
        return scores

print("✅ Advanced adaptive processor defined!")
print("\nThis processor:")
print("  - Enforces minimum response length")
print("  - Enforces maximum response length")
print("  - Adaptively promotes disclaimer as response develops")
print("  - Maintains compliance rules throughout")
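Order matters when rules overwrite scores instead of adding to them. `LogitsProcessorList` calls its processors in list order, and the rules inside one `__call__` also run top to bottom, so a later rule can undo an earlier one. In `AdaptiveMedicalProcessor`, Rule 2 runs after Rule 1, so if the two windows ever overlapped (`min_length >= max_length`) the forced EOS would win. The pure-Python sketch below, with hypothetical helpers `ban_eos`, `force_eos` and `run_chain` and EOS as token 0, shows both orders; in the second one every score ends up `-inf`, leaving nothing to sample.

```python
import math
from typing import Callable, List

Processor = Callable[[List[int], List[float]], List[float]]


def ban_eos(ids: List[int], scores: List[float]) -> List[float]:
    """Like Rule 1: forbid EOS (token 0 here)."""
    scores = list(scores)
    scores[0] = -math.inf
    return scores


def force_eos(ids: List[int], scores: List[float]) -> List[float]:
    """Like Rule 2: everything -inf except EOS."""
    scores = [-math.inf] * len(scores)
    scores[0] = 0.0
    return scores


def run_chain(processors: List[Processor], ids: List[int], scores: List[float]) -> List[float]:
    """Apply processors in list order."""
    for proc in processors:
        scores = proc(ids, scores)
    return scores


toy_scores = [1.0, 2.0, 0.5]
a = run_chain([ban_eos, force_eos], [], toy_scores)  # force runs last: EOS survives
b = run_chain([force_eos, ban_eos], [], toy_scores)  # ban runs last: all -inf
print(a)
print(b)
```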

In [None]:
def generate_with_adaptive_rules(question: str) -> str:
    """Generate with adaptive compliance rules"""
    
    system_prompt = """You are a helpful medical information assistant."""
    
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question}
    ]
    
    input_text = tokenizer.apply_chat_template(
        messages, 
        tokenize=False, 
        add_generation_prompt=True
    )
    
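    # The chat template already contains any special tokens (e.g. BOS), so don't add them twice
    input_ids = tokenizer.encode(input_text, add_special_tokens=False, return_tensors="pt").to(model.device)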
    
    # Use adaptive processor
    logits_processor = LogitsProcessorList([
        AdaptiveMedicalProcessor(tokenizer, min_length=30, max_length=100)
    ])
    
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=200,
            logits_processor=logits_processor,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
    
    output_text = tokenizer.decode(
        output_ids[0][input_ids.shape[1]:], 
        skip_special_tokens=True
    )
    
    return output_text.strip()

# Test
print("Testing adaptive rules...\n")
question = "What are the side effects of aspirin?"
print(f"Question: {question}\n")
print("Answer:")
print("-" * 80)
answer = generate_with_adaptive_rules(question)
print(answer)
print("-" * 80)

# Check constraints (re-encoding the decoded answer only approximates the generated token count)
tokens = tokenizer.encode(answer, add_special_tokens=False)
print(f"\n📊 Response length: {len(tokens)} tokens")
print("   Minimum enforced: 30 tokens")
print("   Maximum enforced: 100 tokens")

if "consult" in answer.lower():
    print("   ✅ Includes consultation advice")
else:
    print("   ⚠️ Missing consultation advice")

## Comparison with Notebook 1

Let's highlight what we can do here that was IMPOSSIBLE in Notebook 1.

In [None]:
print("CAPABILITIES COMPARISON")
print("=" * 80)
print("\nNotebook 1 (Constrained Decoding with OpenAI):")
print("  ✅ Enforce JSON schema")
print("  ✅ Enforce enum values")
print("  ✅ Ban specific tokens (via logit_bias)")
print("  ❌ Context-aware rules (can't check previous tokens)")
print("  ❌ Dynamic rules (logit_bias is static)")
print("  ❌ Complex conditional logic")
print("  ❌ Length constraints")
print("  ❌ Custom business logic")

print("\nNotebook 2 (True Logits Masking with Llama):")
print("  ✅ Enforce JSON schema (with additional libraries)")
print("  ✅ Enforce enum values")
print("  ✅ Ban specific tokens")
print("  ✅ Context-aware rules (can check previous tokens)")
print("  ✅ Dynamic rules (rules change during generation)")
print("  ✅ Complex conditional logic")
print("  ✅ Length constraints")
print("  ✅ Custom business logic")
print("  ✅ Full control over logits array")

print("\nKEY DIFFERENCE:")
print("   Notebook 1: Limited to static token bans; works best for schema/format rules")
print("   Notebook 2: Full programmatic control over generation via a custom LogitsProcessor passed to model.generate()")

## Key Takeaways

### What True Logits Masking Enables:

1. **Context-Aware Rules**
   - Check what's been generated so far
   - Apply rules conditionally based on content
   - Example: Ban " have" only when the preceding word is "you"

2. **Dynamic Behavior**
   - Rules change as generation progresses
   - Adapt based on token count, patterns detected
   - Example: Promote disclaimer after mentioning medication

3. **Complex Logic**
   - Implement ANY Python logic
   - Multiple rules that interact
   - Conditional masking based on state

4. **Business Compliance**
   - Medical/legal compliance rules
   - Industry-specific constraints
   - Regulatory requirements

### When You Need This:

- ✅ Medical/legal content with compliance requirements
- ✅ Context-dependent constraints
- ✅ Complex business rules that can't be enumerated
- ✅ Hard guarantees for token-level bans (soft boosts, like the disclaimer rule, remain probabilistic)
- ✅ Willing to self-host for full control

### Trade-offs:

- ❌ Requires self-hosting (infrastructure cost)
- ❌ Need GPU for reasonable performance
- ❌ More complex setup than API calls
- ✅ But: Full control, no API limits, complete flexibility