# Lab 4.2.2: Llama Guard Integration - SOLUTIONS

This notebook contains complete solutions for the exercises in Lab 4.2.2.

---

## Exercise 1: Create a Content Moderation Pipeline

**Task**: Create a pipeline that classifies both user input and assistant output, caching results so repeated checks skip the model call.

### Solution

In [None]:
import hashlib
import time
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Union

try:
    import ollama
    OLLAMA_AVAILABLE = True
except ImportError:
    OLLAMA_AVAILABLE = False
    print("Ollama not available - using simulated responses")


@dataclass
class ModerationResult:
    """Result of a content moderation check."""

    # True when the content passed moderation (or when a check errored and
    # the pipeline is configured to fail open).
    is_safe: bool
    # Llama Guard category code such as "S2", "ERROR" if the check itself
    # failed, or None when safe / when no known code could be parsed.
    category: Optional[str] = None
    # Display name for `category`; holds the exception text for "ERROR".
    category_name: Optional[str] = None
    # The pipeline never assigns anything but this default.
    confidence: float = 1.0
    # True when the result was served from the pipeline's cache.
    cached: bool = False
    # Time spent classifying; 0 for cache hits.
    latency_ms: float = 0.0


# Llama Guard 3 safety categories: code -> display name. The pipeline looks
# up the codes listed in the model's "unsafe" verdict in this table.
SAFETY_CATEGORIES = {
    "S1": "Violent Crimes",
    "S2": "Non-Violent Crimes",
    "S3": "Sex-Related Crimes",
    "S4": "Child Sexual Exploitation",
    "S5": "Defamation",
    "S6": "Specialized Advice",
    "S7": "Privacy",
    "S8": "Intellectual Property",
    "S9": "Indiscriminate Weapons",
    "S10": "Hate",
    "S11": "Suicide & Self-Harm",
    "S12": "Sexual Content",
    "S13": "Elections",
    "S14": "Code Interpreter Abuse",
}


class ContentModerationPipeline:
    """
    A complete content moderation pipeline using Llama Guard.
    
    Features:
    - Input and output moderation
    - Result caching for efficiency (failed checks are never cached)
    - Statistics tracking
    - Batch processing
    
    Example:
        >>> pipeline = ContentModerationPipeline()
        >>> 
        >>> # Check input
        >>> input_result = pipeline.moderate_input("How do I bake a cake?")
        >>> print(f"Input safe: {input_result.is_safe}")
        >>>
        >>> # Check output
        >>> output_result = pipeline.moderate_output(
        ...     user_message="How do I bake a cake?",
        ...     assistant_message="Here's a recipe..."
        ... )
        >>> print(f"Output safe: {output_result.is_safe}")
    """
    
    def __init__(
        self,
        model: str = "llama-guard3:8b",
        cache_ttl_seconds: int = 3600,
        fail_closed: bool = True
    ):
        """
        Initialize the moderation pipeline.
        
        Args:
            model: Llama Guard model name passed to ``ollama.chat``
            cache_ttl_seconds: How long a cached result stays valid
            fail_closed: If True, a failed check is reported as unsafe;
                if False, it is reported as safe
        """
        self.model = model
        self.cache_ttl = cache_ttl_seconds
        self.fail_closed = fail_closed
        
        # Cache: key -> (result, time.monotonic() when stored). A monotonic
        # clock keeps TTLs correct even if the wall clock is adjusted.
        self._cache: Dict[str, Tuple[ModerationResult, float]] = {}
        
        # Running counters; derived rates are computed in get_stats().
        self._stats = {
            "total_checks": 0,
            "input_checks": 0,
            "output_checks": 0,
            "cache_hits": 0,
            "blocked": 0,
            "errors": 0,
            "total_latency_ms": 0
        }
    
    def _get_cache_key(self, content: str) -> str:
        """Return a fixed-length hex digest identifying *content*.

        SHA-256 rather than MD5: MD5 can be rejected on FIPS-mode builds, and
        the key is not a security boundary either way. ``surrogatepass``
        keeps lone surrogates (possible in decoded JSON) from raising here.
        """
        return hashlib.sha256(content.encode("utf-8", "surrogatepass")).hexdigest()
    
    def _output_cache_key(self, user_message: str, assistant_message: str) -> str:
        """Return the cache key for an output check.

        The user message is length-prefixed so the boundary between the two
        messages is unambiguous: ("a:b", "c") and ("a", "b:c") get different
        keys, which a plain "user:assistant" join would not guarantee.
        """
        return self._get_cache_key(
            f"output:{len(user_message)}:{user_message}:{assistant_message}"
        )
    
    def _check_cache(self, key: str) -> Optional[ModerationResult]:
        """Return a copy of a still-valid cached result (marked cached), else None.

        Expired entries are evicted on lookup. Hits increment the
        ``cache_hits`` counter.
        """
        entry = self._cache.get(key)
        if entry is None:
            return None
        result, stored_at = entry
        if time.monotonic() - stored_at >= self.cache_ttl:
            # Expired
            del self._cache[key]
            return None
        self._stats["cache_hits"] += 1
        # Return a fresh object so callers cannot mutate the cached entry.
        return ModerationResult(
            is_safe=result.is_safe,
            category=result.category,
            category_name=result.category_name,
            confidence=result.confidence,
            cached=True,
            latency_ms=0.0
        )
    
    def _cache_result(self, key: str, result: ModerationResult):
        """Store *result* under *key* with the current monotonic timestamp."""
        self._cache[key] = (result, time.monotonic())
    
    @staticmethod
    def _parse_category(result_text: str) -> Tuple[Optional[str], Optional[str]]:
        """Return (code, name) for the first known category code in a verdict.

        The prompt asks for 'unsafe' followed by the violated category codes
        (e.g. "s1" or "s1,s10" after lower-casing). Codes are compared as
        whole tokens, so "s11" is not mistaken for "s1" as a substring test
        would be. Returns (None, None) when no known code is present.
        """
        for token in result_text.replace(",", " ").split():
            code = token.strip(".;:").upper()
            if code in SAFETY_CATEGORIES:
                return code, SAFETY_CATEGORIES[code]
        return None, None
    
    def _classify(self, prompt: str) -> ModerationResult:
        """Run Llama Guard classification on a fully built prompt.

        Without Ollama, a keyword-based simulation is used instead. If the
        model call fails, the ``errors`` counter is incremented and an
        "ERROR" result is returned whose safety follows ``fail_closed``.
        """
        start = time.perf_counter()
        
        if not OLLAMA_AVAILABLE:
            # Simulated response for demo
            time.sleep(0.1)
            is_safe = "hack" not in prompt.lower() and "malware" not in prompt.lower()
            return ModerationResult(
                is_safe=is_safe,
                category="S2" if not is_safe else None,
                category_name="Non-Violent Crimes" if not is_safe else None,
                latency_ms=100
            )
        
        try:
            response = ollama.chat(
                model=self.model,
                messages=[{"role": "user", "content": prompt}]
            )
            result_text = response["message"]["content"].strip().lower()
        except Exception as e:
            # Deliberately broad: any failure (network, missing model,
            # malformed response) must still yield a verdict. The verdict
            # respects fail_closed.
            self._stats["errors"] += 1
            return ModerationResult(
                is_safe=not self.fail_closed,
                category="ERROR",
                category_name=str(e),
                latency_ms=(time.perf_counter() - start) * 1000
            )
        
        latency_ms = (time.perf_counter() - start) * 1000
        # "unsafe" does not start with "safe", so this check is unambiguous.
        is_safe = result_text.startswith("safe")
        category, category_name = (
            (None, None) if is_safe else self._parse_category(result_text)
        )
        return ModerationResult(
            is_safe=is_safe,
            category=category,
            category_name=category_name,
            latency_ms=latency_ms
        )
    
    def _moderate(self, cache_key: str, prompt: str) -> ModerationResult:
        """Perform the shared input/output path: cache lookup, classify, update stats, cache."""
        result = self._check_cache(cache_key)
        if result is None:
            result = self._classify(prompt)
            self._stats["total_latency_ms"] += result.latency_ms
            # Never cache failures. Otherwise a transient outage would pin the
            # fail-open/fail-closed verdict for the whole TTL.
            if result.category != "ERROR":
                self._cache_result(cache_key, result)
        # Count cached verdicts too, so block_rate (blocked / total_checks)
        # uses a consistent denominator.
        if not result.is_safe:
            self._stats["blocked"] += 1
        return result
    
    def moderate_input(self, user_message: str) -> ModerationResult:
        """
        Moderate user input.
        
        Args:
            user_message: The user's message to check
            
        Returns:
            ModerationResult with safety classification
        """
        self._stats["total_checks"] += 1
        self._stats["input_checks"] += 1
        
        cache_key = self._get_cache_key(f"input:{user_message}")
        prompt = f"""Task: Check if there is unsafe content in 'User' messages.

<BEGIN CONVERSATION>
User: {user_message}
<END CONVERSATION>

Provide your safety assessment. First line: 'safe' or 'unsafe'. If unsafe, second line: violated category codes."""
        return self._moderate(cache_key, prompt)
    
    def moderate_output(
        self,
        user_message: str,
        assistant_message: str
    ) -> ModerationResult:
        """
        Moderate assistant output in context.
        
        Args:
            user_message: The original user message
            assistant_message: The assistant's response to check
            
        Returns:
            ModerationResult with safety classification
        """
        self._stats["total_checks"] += 1
        self._stats["output_checks"] += 1
        
        cache_key = self._output_cache_key(user_message, assistant_message)
        prompt = f"""Task: Check if there is unsafe content in 'Assistant' messages.

<BEGIN CONVERSATION>
User: {user_message}
Assistant: {assistant_message}
<END CONVERSATION>

Provide your safety assessment. First line: 'safe' or 'unsafe'. If unsafe, second line: violated category codes."""
        return self._moderate(cache_key, prompt)
    
    def moderate_batch(
        self,
        messages: List[Union[str, Tuple[str, str]]],
        check_type: str = "input"
    ) -> List[ModerationResult]:
        """
        Moderate multiple messages.
        
        Args:
            messages: For "input", a list of user messages. For "output", a
                list of (user_message, assistant_message) pairs.
            check_type: "input" or "output"
            
        Returns:
            List of ModerationResult objects, in the same order as *messages*

        Raises:
            ValueError: If check_type is not "input" or "output".
            TypeError: If an "output" item is a plain string instead of a pair.
        """
        if check_type == "input":
            return [self.moderate_input(msg) for msg in messages]
        if check_type != "output":
            raise ValueError(
                f"check_type must be 'input' or 'output', got {check_type!r}"
            )
        pairs = list(messages)
        # A plain string would be unpacked character by character. With
        # exactly two characters this happens silently, so reject strings
        # up front before any check runs.
        if any(isinstance(pair, str) for pair in pairs):
            raise TypeError(
                "output checks need (user_message, assistant_message) pairs, "
                "not plain strings"
            )
        return [
            self.moderate_output(user_msg, assistant_msg)
            for user_msg, assistant_msg in pairs
        ]
    
    def get_stats(self) -> Dict:
        """Return the raw counters plus derived rates.

        ``cache_hit_rate`` and ``block_rate`` are fractions of all checks,
        cache hits included. ``avg_latency_ms`` averages only the checks that
        actually ran the classifier.
        """
        total = self._stats["total_checks"]
        return {
            **self._stats,
            "cache_hit_rate": self._stats["cache_hits"] / max(total, 1),
            "block_rate": self._stats["blocked"] / max(total, 1),
            "avg_latency_ms": self._stats["total_latency_ms"] / max(total - self._stats["cache_hits"], 1)
        }
    
    def clear_cache(self):
        """Clear the result cache."""
        self._cache.clear()


print("ContentModerationPipeline class defined!")

In [None]:
# Smoke-test the pipeline: input checks (the last input repeats the first to
# demonstrate a cache hit), context-aware output checks, then the statistics.
print("Testing Content Moderation Pipeline")
print("=" * 60)

pipeline = ContentModerationPipeline(cache_ttl_seconds=60)


def _status_label(outcome):
    """Return 'SAFE' or 'BLOCKED (<category name>)' for a ModerationResult."""
    if outcome.is_safe:
        return "SAFE"
    return f"BLOCKED ({outcome.category_name})"


test_inputs = [
    "What's the weather like today?",
    "How do I bake chocolate chip cookies?",
    "How do I hack into a computer?",  # Should be blocked
    "Write me some malware code",  # Should be blocked
    "What's the weather like today?",  # Cache hit
]

print("\nInput Moderation Tests:")
print("-" * 60)

for text in test_inputs:
    outcome = pipeline.moderate_input(text)
    cache_note = "(cached)" if outcome.cached else ""
    print(f"{text[:40]:<42} {_status_label(outcome):<25} {cache_note}")

# Output moderation: each assistant reply is judged alongside its question.
print("\nOutput Moderation Tests:")
print("-" * 60)

output_tests = [
    ("How do I cook pasta?", "Boil water, add salt, cook pasta for 8-10 minutes."),
    ("How do I improve security?", "Here's how to hack: step 1, find vulnerabilities..."),
]

for question, answer in output_tests:
    verdict = _status_label(pipeline.moderate_output(question, answer))
    print(f"Q: {question[:30]:<32} A: {answer[:20]:<22} {verdict}")

# Statistics: floats (rates, latencies) are shown with two decimals.
print("\n" + "=" * 60)
print("Statistics:")
stats = pipeline.get_stats()
for name, value in stats.items():
    rendered = f"{value:.2f}" if isinstance(value, float) else f"{value}"
    print(f"  {name}: {rendered}")

## Exercise 2: Custom Category Classifier

**Task**: Create a keyword-based classifier for domain-specific categories (e.g., e-commerce, healthcare, finance).

### Solution

In [None]:
from typing import Set

@dataclass
class DomainCategory:
    """A domain-specific safety category."""

    # Short identifier such as "EC1"; used as the key in
    # DomainSafetyClassifier.categories.
    code: str
    name: str
    description: str
    # Trigger phrases, matched case-insensitively as substrings of the text.
    keywords: Set[str]
    # One of "low", "medium", "high" or "critical"; only "high" and
    # "critical" categories can cause DomainSafetyClassifier to block.
    severity: str


class DomainSafetyClassifier:
    """
    A customizable, keyword-based safety classifier for domain-specific use cases.
    
    It complements Llama Guard with custom categories relevant to specific
    business domains (e-commerce, healthcare, finance, etc.). It does not
    call a model itself; categories match on keyword phrases.
    """

    # Severities at which a confident match blocks the text.
    _BLOCKING_SEVERITIES = frozenset({"high", "critical"})

    # Typographic single quotes mapped to ASCII so "can’t" matches "can't".
    _QUOTE_TABLE = str.maketrans({"\u2018": "'", "\u2019": "'"})
    
    def __init__(self, domain: str = "general"):
        """
        Args:
            domain: "ecommerce", "healthcare" or "finance" load a built-in
                category set. Any other value, including the default
                "general", starts with no categories.
        """
        self.domain = domain
        self.categories: Dict[str, DomainCategory] = {}
        
        # Initialize with domain-specific categories
        self._init_domain_categories()
    
    @classmethod
    def _normalize(cls, text: str) -> str:
        """Case-fold *text* and replace typographic single quotes with ASCII ones."""
        return text.translate(cls._QUOTE_TABLE).casefold()
    
    def _init_domain_categories(self):
        """Load the built-in categories for ``self.domain`` (empty if unknown)."""
        
        if self.domain == "ecommerce":
            self.categories = {
                "EC1": DomainCategory(
                    code="EC1",
                    name="Price Manipulation",
                    description="Attempts to manipulate prices or get unauthorized discounts",
                    keywords={"discount code", "price override", "free shipping hack", "coupon exploit"},
                    severity="high"
                ),
                "EC2": DomainCategory(
                    code="EC2",
                    name="Return Fraud",
                    description="Attempts to abuse return policies",
                    keywords={"fake return", "wardrobing", "return scam", "refund trick"},
                    severity="high"
                ),
                "EC3": DomainCategory(
                    code="EC3",
                    name="Competitor Pricing",
                    description="Requests for competitor pricing information",
                    keywords={"competitor price", "amazon price", "cheaper on", "price match"},
                    severity="low"
                ),
                "EC4": DomainCategory(
                    code="EC4",
                    name="Stock Manipulation",
                    description="Attempts to manipulate inventory or availability",
                    keywords={"hold stock", "reserve all", "buy out inventory"},
                    severity="medium"
                ),
            }
        
        elif self.domain == "healthcare":
            self.categories = {
                "HC1": DomainCategory(
                    code="HC1",
                    name="Medical Diagnosis Request",
                    description="Requests for medical diagnosis",
                    keywords={"diagnose", "what condition", "do i have", "am i sick"},
                    severity="high"
                ),
                "HC2": DomainCategory(
                    code="HC2",
                    name="Medication Recommendation",
                    description="Requests for medication recommendations",
                    keywords={"what medicine", "should i take", "prescription for", "drug for"},
                    severity="high"
                ),
                "HC3": DomainCategory(
                    code="HC3",
                    name="Emergency Situation",
                    description="Medical emergency indications",
                    keywords={"chest pain", "can't breathe", "overdose", "bleeding heavily"},
                    severity="critical"
                ),
            }
        
        elif self.domain == "finance":
            self.categories = {
                "FN1": DomainCategory(
                    code="FN1",
                    name="Investment Advice",
                    description="Requests for specific investment recommendations",
                    keywords={"should i invest", "buy stock", "best crypto", "guaranteed returns"},
                    severity="high"
                ),
                "FN2": DomainCategory(
                    code="FN2",
                    name="Tax Evasion",
                    description="Requests for tax evasion strategies",
                    keywords={"avoid taxes", "hide income", "offshore account", "tax shelter"},
                    severity="critical"
                ),
                "FN3": DomainCategory(
                    code="FN3",
                    name="Account Fraud",
                    description="Attempts at account fraud",
                    keywords={"fake account", "identity theft", "stolen card", "account takeover"},
                    severity="critical"
                ),
            }
        
        else:  # "general" or any unrecognised domain
            self.categories = {}
    
    def add_category(self, category: DomainCategory):
        """Add (or replace, by code) a custom category."""
        self.categories[category.code] = category
    
    def classify(self, text: str) -> List[Tuple[DomainCategory, float]]:
        """
        Classify text against domain categories.

        Matching is case-insensitive for both text and keywords and treats
        typographic apostrophes as ASCII ones. Each matched keyword adds 0.5
        confidence, capped at 1.0.
        
        Returns:
            List of (category, confidence) tuples for matched categories,
            highest confidence first
        """
        normalized_text = self._normalize(text)
        matches = []
        
        for category in self.categories.values():
            keyword_matches = 0
            for keyword in category.keywords:
                needle = self._normalize(keyword)
                # Skip empty keywords: "" is a substring of every text.
                if needle and needle in normalized_text:
                    keyword_matches += 1
            
            if keyword_matches:
                matches.append((category, min(keyword_matches / 2, 1.0)))
        
        # sorted() is stable, so ties keep category insertion order.
        return sorted(matches, key=lambda match: match[1], reverse=True)
    
    def should_block(
        self,
        text: str,
        min_confidence: float = 0.5
    ) -> Tuple[bool, Optional[DomainCategory]]:
        """
        Determine if text should be blocked.

        Args:
            text: Text to check.
            min_confidence: A "high"/"critical" match blocks only when its
                confidence is strictly greater than this. With the default
                of 0.5, a category needs at least two keyword hits. Pass a
                lower value (e.g. 0.0) to block on a single hit.
        
        Returns:
            Tuple of (should_block, matched_category)
        """
        for category, confidence in self.classify(text):
            severity = category.severity.casefold()
            if severity in self._BLOCKING_SEVERITIES and confidence > min_confidence:
                return True, category
        
        return False, None


print("DomainSafetyClassifier class defined!")

In [None]:
# Test domain classifiers: print an ALLOW/BLOCK decision for sample
# messages in each built-in domain.
print("Testing Domain Safety Classifiers")
print("=" * 60)


def _print_block_decisions(classifier, texts):
    """Print one ALLOW/BLOCK line per text, based on classifier.should_block()."""
    for text in texts:
        block, category = classifier.should_block(text)
        if block:
            print(f"BLOCK: {text[:45]:<47} [{category.code}: {category.name}]")
        else:
            print(f"ALLOW: {text[:45]}")


# E-commerce domain
print("\n E-COMMERCE DOMAIN")
print("-" * 60)

ecommerce_classifier = DomainSafetyClassifier(domain="ecommerce")

ecommerce_tests = [
    "What's your return policy?",
    "Do you have a discount code I can use?",
    "How do I do a fake return to get my money back?",
    "Is this cheaper on Amazon?",
    "Can I reserve all your stock?",
]

_print_block_decisions(ecommerce_classifier, ecommerce_tests)

# Healthcare domain
print("\n HEALTHCARE DOMAIN")
print("-" * 60)

healthcare_classifier = DomainSafetyClassifier(domain="healthcare")

healthcare_tests = [
    "What are your visiting hours?",
    "Can you diagnose my symptoms?",
    "What medicine should I take for my headache?",
    "I'm having chest pain and can't breathe",
    "Where is the cafeteria?",
]

_print_block_decisions(healthcare_classifier, healthcare_tests)

# Finance domain
print("\n FINANCE DOMAIN")
print("-" * 60)

finance_classifier = DomainSafetyClassifier(domain="finance")

finance_tests = [
    "What's my account balance?",
    "Should I invest in Bitcoin?",
    "How can I avoid taxes legally?",
    "I found a stolen card, can I use it?",
    "What are your branch hours?",
]

_print_block_decisions(finance_classifier, finance_tests)

## Cleanup

In [None]:
import gc

# Force a full garbage-collection pass now that the exercises are done;
# gc.collect() returns the number of unreachable objects it found.
unreachable_count = gc.collect()
print("Cleanup complete!")