In [1]:
# Enable fast weights download and upload
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [21]:
#Environment Setup and Imports
import torch
from PIL import Image
import yaml
import re
from typing import Dict, Any, List, Optional
from dataclasses import dataclass
import os
from transformers import (
    MllamaForConditionalGeneration,
    AutoProcessor,
    BitsAndBytesConfig,
    CLIPProcessor,
    CLIPModel
)
import torchvision.transforms as transforms
from pathlib import Path

In [5]:
#LLM Configuration and Setup
class LLMManager:
    def __init__(
        self,
        model_id: str = "unsloth/Llama-3.2-11B-Vision-Instruct",
        use_4bit: bool = False
    ):
        self.model_id = model_id
        self.use_4bit = use_4bit
        self.model = None
        self.processor = None
        self._initialize_model()
    
    def _initialize_model(self):
        """Initialize the model with optional 4-bit quantization"""
        if self.use_4bit:
            quant_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True
            )
            self.model = MllamaForConditionalGeneration.from_pretrained(
                self.model_id,
                quantization_config=quant_config,
                device_map="auto"
            )
        else:
            self.model = MllamaForConditionalGeneration.from_pretrained(
                self.model_id,
                device_map="auto"
            )
        
        self.processor = AutoProcessor.from_pretrained(self.model_id)
    
    def reason_about_guardrails(
        self,
        plugin_outputs: Dict[str, Any],
        guard_rules: Dict[str, Any],
        original_input: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Stub implementation of guardrail reasoning
        In production: Would craft prompt and use LLM inference
        """
        # Simple logic for demonstration
        has_violations = any(
            output.get("violation", False)
            for output in plugin_outputs.values()
        )
        
        violations = [
            f"{plugin}: {output['reason']}"
            for plugin, output in plugin_outputs.items()
            if output.get("violation", False)
        ]
        
        return {
            "verdict": "deny" if has_violations else "allow",
            "reason": "; ".join(violations) if violations else "No violations detected",
            "plugin_outputs": plugin_outputs
        }

In [23]:
"""
Abstract interface for NSFW detection to support multiple backends
"""

class NSFWDetectorInterface:
    def load_model(self):
        raise NotImplementedError
    
    def preprocess_image(self, image_input: Any) -> torch.Tensor:
        raise NotImplementedError
    
    def detect(self, preprocessed_image: torch.Tensor) -> float:
        raise NotImplementedError

# Cell 3: CLIP-based NSFW Detector [NEW]
"""
CLIP-based implementation of NSFW detection
"""

class CLIPNSFWDetector(NSFWDetectorInterface):
    def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
        self.model_name = model_name
        self.model = None
        self.processor = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.nsfw_concepts = [
            "explicit content",
            "nude",
            "pornographic",
            "safe for work",
            "appropriate content",
            "family friendly"
        ]
    
    def load_model(self):
        """Load CLIP model and processor"""
        self.model = CLIPModel.from_pretrained(self.model_name).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(self.model_name)
        return self
    
    def preprocess_image(self, image_input: Any) -> torch.Tensor:
        """Preprocess image for CLIP"""
        if isinstance(image_input, str):
            # In practice, you'd load the image here
            # For demo, we'll create a dummy image
            image = Image.new('RGB', (224, 224), color='white')
        elif isinstance(image_input, Image.Image):
            image = image_input
        else:
            raise ValueError("Unsupported image input type")
        
        return self.processor(
            images=image,
            text=self.nsfw_concepts,
            return_tensors="pt",
            padding=True
        ).to(self.device)
    
    def detect(self, preprocessed_input: Dict[str, torch.Tensor]) -> float:
        """
        Detect NSFW content using CLIP
        Returns: float between 0 and 1 (higher = more likely NSFW)
        """
        with torch.no_grad():
            outputs = self.model(**preprocessed_input)
            logits_per_image = outputs.logits_per_image
            probs = logits_per_image.softmax(dim=-1)
            
            # Average the probabilities of NSFW concepts (first 3)
            nsfw_prob = probs[0, :3].mean().item()
            # Average the probabilities of SFW concepts (last 3)
            sfw_prob = probs[0, 3:].mean().item()
            
            # Normalize to get NSFW score
            nsfw_score = nsfw_prob / (nsfw_prob + sfw_prob)
            
            return nsfw_score

In [25]:
#Plugin Base and Implementations
class BasePlugin:
    def __init__(self, name: str):
        self.name = name
    
    def should_trigger(self, normalized_inputs: Dict[str, Any]) -> bool:
        """Determine if plugin should run for given input"""
        raise NotImplementedError
    
    def check(self, normalized_inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Run the plugin's checks"""
        raise NotImplementedError

class TextPrivacyPlugin(BasePlugin):
    def __init__(self):
        super().__init__("text_privacy")
        self.ssn_pattern = re.compile(r'\d{3}-\d{2}-\d{4}')
    
    def should_trigger(self, normalized_inputs: Dict[str, Any]) -> bool:
        return bool(normalized_inputs.get("text"))
    
    def check(self, normalized_inputs: Dict[str, Any]) -> Dict[str, Any]:
        text = normalized_inputs.get("text", "")
        if self.ssn_pattern.search(text):
            return {
                "violation": True,
                "reason": "Detected SSN pattern in text"
            }
        return {
            "violation": False,
            "reason": "No privacy violations detected"
        }

class NSFWImagePlugin(BasePlugin):
    def __init__(
        self,
        threshold: float = 0.5,
        detector: Optional[NSFWDetectorInterface] = None
    ):
        super().__init__("nsfw_image")
        self.threshold = threshold
        self.detector = detector or CLIPNSFWDetector()
        self.detector.load_model()
    
    def should_trigger(self, normalized_inputs: Dict[str, Any]) -> bool:
        return bool(normalized_inputs.get("image"))
    
    def check(self, normalized_inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Check if image contains NSFW content using CLIP"""
        image = normalized_inputs.get("image")
        
        try:
            # Preprocess image
            inputs = self.detector.preprocess_image(image)
            
            # Get NSFW score
            nsfw_score = self.detector.detect(inputs)
            
            if nsfw_score > self.threshold:
                return {
                    "violation": True,
                    "reason": f"NSFW content detected (score: {nsfw_score:.2f})",
                    "score": nsfw_score
                }
            return {
                "violation": False,
                "reason": f"Content appears safe (score: {nsfw_score:.2f})",
                "score": nsfw_score
            }
            
        except Exception as e:
            return {
                "violation": True,
                "reason": f"Error processing image: {str(e)}",
                "score": None
            }

In [27]:
#M3Guard Orchestrator
"""
Implements the main orchestrator that coordinates plugins and LLM reasoning
"""

class M3GuardOrchestrator:
    def __init__(
        self,
        llm_manager: LLMManager,
        guard_rules: Dict[str, Any]
    ):
        self.llm_manager = llm_manager
        self.guard_rules = guard_rules
        self.registry_of_plugins: List[BasePlugin] = []
    
    def register_plugin(self, plugin: BasePlugin):
        self.registry_of_plugins.append(plugin)
    
    def universal_preprocessing(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize inputs for plugin processing"""
        return {
            "text": input_data.get("text", ""),
            "image": input_data.get("image"),
            "metadata": input_data.get("metadata", {})
        }
    
    def process_input(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Main processing pipeline"""
        normalized_inputs = self.universal_preprocessing(input_data)
        plugin_outputs = {}
        
        for plugin in self.registry_of_plugins:
            if plugin.should_trigger(normalized_inputs):
                plugin_result = plugin.check(normalized_inputs)
                plugin_outputs[plugin.name] = plugin_result
        
        final_decision = self.llm_manager.reason_about_guardrails(
            plugin_outputs,
            self.guard_rules,
            input_data
        )
        
        return final_decision

In [29]:
#Guard Rules and Test Data
"""
Define guard rules and create test cases
"""

# Guard rules
guard_rules = {
    "deny_if_ssn": True,
    "nsfw_threshold": 0.5,
    "require_all_plugins": False
}

# Test data
test_inputs = [
    {
        "text": "My SSN is 123-45-6789",
        "image": None,
        "metadata": {"source": "text_only"}
    },
    {
        "text": "Hello, this is a safe message",
        "image": None,
        "metadata": {"source": "text_only"}
    },
    {
        "text": "Check out this image",
        "image": "Arka/Agentic Guardrails/trump Luther King.jpeg",  # In practice: PIL.Image object
        "metadata": {"source": "mixed"}
    }
]


In [31]:
# Initialize LLM Manager
llm_manager = LLMManager(use_4bit=False)  # Toggle 4-bit as needed

# Initialize Orchestrator
orchestrator = M3GuardOrchestrator(llm_manager, guard_rules)

# Register plugins
orchestrator.register_plugin(TextPrivacyPlugin())
orchestrator.register_plugin(NSFWImagePlugin(threshold=0.5))

# Process test inputs
print("=== M3Guard Test Results ===")
for i, test_input in enumerate(test_inputs, 1):
    print(f"\nTest Case {i}:")
    print(f"Input: {test_input}")
    result = orchestrator.process_input(test_input)
    print(f"Result: {result}")

The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

Some parameters are on the meta device because they were offloaded to the cpu.


config.json:   0%|          | 0.00/4.19k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/592 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/862k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.22M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

=== M3Guard Test Results ===

Test Case 1:
Input: {'text': 'My SSN is 123-45-6789', 'image': None, 'metadata': {'source': 'text_only'}}
Result: {'verdict': 'deny', 'reason': 'text_privacy: Detected SSN pattern in text', 'plugin_outputs': {'text_privacy': {'violation': True, 'reason': 'Detected SSN pattern in text'}}}

Test Case 2:
Input: {'text': 'Hello, this is a safe message', 'image': None, 'metadata': {'source': 'text_only'}}
Result: {'verdict': 'allow', 'reason': 'No violations detected', 'plugin_outputs': {'text_privacy': {'violation': False, 'reason': 'No privacy violations detected'}}}

Test Case 3:
Input: {'text': 'Check out this image', 'image': 'Arka/Agentic Guardrails/trump Luther King.jpeg', 'metadata': {'source': 'mixed'}}
Result: {'verdict': 'allow', 'reason': 'No violations detected', 'plugin_outputs': {'text_privacy': {'violation': False, 'reason': 'No privacy violations detected'}, 'nsfw_image': {'violation': False, 'reason': 'Content appears safe (score: 0.37)', 'sco

In [33]:
# Test data
test_inputs = [
    {
        "text": "Check out this image",
        "image": "Arka/Agentic Guardrails/plant.jpg",  # In practice: PIL.Image object
        "metadata": {"source": "mixed"}
    }
]
# Initialize LLM Manager
llm_manager = LLMManager(use_4bit=False)  # Toggle 4-bit as needed

# Initialize Orchestrator
orchestrator = M3GuardOrchestrator(llm_manager, guard_rules)

# Register plugins
orchestrator.register_plugin(TextPrivacyPlugin())
orchestrator.register_plugin(NSFWImagePlugin(threshold=0.5))

# Process test inputs
print("=== M3Guard Test Results ===")
for i, test_input in enumerate(test_inputs, 1):
    print(f"\nTest Case {i}:")
    print(f"Input: {test_input}")
    result = orchestrator.process_input(test_input)
    print(f"Result: {result}")

The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

Some parameters are on the meta device because they were offloaded to the cpu.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


=== M3Guard Test Results ===

Test Case 1:
Input: {'text': 'Check out this image', 'image': 'Arka/Agentic Guardrails/plant.jpg', 'metadata': {'source': 'mixed'}}
Result: {'verdict': 'allow', 'reason': 'No violations detected', 'plugin_outputs': {'text_privacy': {'violation': False, 'reason': 'No privacy violations detected'}, 'nsfw_image': {'violation': False, 'reason': 'Content appears safe (score: 0.37)', 'score': 0.3742016405818891}}}
