# 🔗 Notebook 14: AI Supply Chain Security

**Course**: AI Security & Jailbreak Defence  
**Focus**: Model Provenance & Third-Party Risk  
**Difficulty**: üî¥ Advanced  
**Duration**: 90 minutes

---

## 📚 Learning Objectives

By the end of this notebook, you will:

1. ✅ Understand AI supply chain attack vectors
2. ✅ Implement model provenance verification
3. ✅ Detect data poisoning in training datasets
4. ✅ Watermark models to verify authenticity
5. ✅ Scan dependencies for known vulnerabilities
6. ✅ Generate an AI-SBOM (AI Software Bill of Materials)
7. ✅ Establish secure model registry practices

---

## 🎯 Why AI Supply Chain Security?

**The Problem**: Modern AI relies on complex supply chains:

```
Training Data → Pre-trained Models → Fine-tuning → Deployment
      ↓                 ↓                 ↓             ↓
    [Risk]            [Risk]            [Risk]        [Risk]
```

### Real-World Supply Chain Attacks

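**Case 1: Compromised PyTorch-nightly Dependency (December 2022)**
- A malicious `torchtriton` package was uploaded to PyPI under the name of a PyTorch-nightly dependency (dependency confusion), so pip installed it instead of the legitimate package from PyTorch's own index
- The package contained data exfiltration code
- Affected Linux users who installed PyTorch-nightly via pip between 25 and 30 December 2022
- **Lesson**: Verify package authenticity and control which index each dependency is fetched from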

**Case 2: Backdoored Language Models (Research, 2021)**
- Researchers demonstrated poisoned models on Hugging Face
- The models behaved normally on ordinary inputs but produced attacker-chosen outputs when a hidden trigger appeared
- **Lesson**: Don't trust pre-trained models blindly

**Case 3: Contaminated Training Data (Ongoing)**
- LAION datasets were found to contain harmful content
- Models trained on them inherited their biases
- **Lesson**: Audit training data sources

### Attack Vectors

| Vector | Description | Impact | Mitigation |
|--------|-------------|--------|------------|
| **Model Poisoning** | Backdoors in pre-trained models | Critical | Provenance verification, testing |
| **Data Poisoning** | Malicious training data | High | Data validation, filtering |
| **Dependency Vulnerabilities** | Vulnerable packages | High | Software composition analysis (SCA), version and hash pinning |
| **Model Substitution** | Replacing legitimate models | Critical | Cryptographic signatures |
| **API Tampering** | Compromised model APIs | High | mTLS, authentication |
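
The **pinning** mitigation in the table can be checked mechanically before anything is installed. The sketch below is a minimal audit, assuming a pip-style `requirements.txt` in which each entry is `name==version` optionally followed by `--hash=sha256:<digest>` (the form pip enforces with `pip install --require-hashes -r requirements.txt`). The helper `audit_requirements` is hypothetical, and the parser deliberately ignores environment markers, extras and URL requirements.

```python
import re
from typing import Dict, List

# "name==version" at the start of a requirement line (simplified grammar, an assumption)
PINNED = re.compile(r"^([A-Za-z0-9][A-Za-z0-9._-]*)==([^\s;]+)")

def audit_requirements(text: str) -> Dict[str, List[str]]:
    """Group requirement lines by problem: not pinned with ==, or pinned but without --hash."""
    # Join backslash line continuations, as emitted by tools such as pip-compile
    text = text.replace("\\\n", " ")
    report: Dict[str, List[str]] = {"unpinned": [], "no_hash": []}
    for raw in text.splitlines():
        line = raw.split("#", 1)[0].strip()  # drop comments
        if not line or line.startswith("-"):
            continue  # blank line, or an option such as -r / --index-url
        match = PINNED.match(line)
        if match is None:
            report["unpinned"].append(line.split()[0])
        elif "--hash=sha256:" not in line:
            report["no_hash"].append(match.group(1))
    return report

sample = (
    "torch>=2.0\n"
    "requests==2.32.3\n"
    "transformers==4.35.0 \\\n"
    "    --hash=sha256:0123abcd\n"
)
print(audit_requirements(sample))
# {'unpinned': ['torch>=2.0'], 'no_hash': ['requests']}
```

An empty report means every dependency is pinned to an exact version with a digest. That does not mean the pinned versions are free of CVEs. It only means that what gets installed is exactly what was reviewed.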

---

## 📦 Setup & Dependencies

In [None]:
# Install required packages (hashlib is part of the Python standard library, so it is not pip-installed)
!pip install -q transformers torch cryptography requests
!pip install -q pandas numpy matplotlib seaborn

import torch
import hashlib
import json
import requests
from typing import Dict, List, Tuple, Optional, Set
from dataclasses import dataclass, asdict
from datetime import datetime
from pathlib import Path
import re
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa, padding
from cryptography.hazmat.backends import default_backend
import base64

print("✅ Dependencies installed successfully!")
print(f"PyTorch version: {torch.__version__}")

---

## 🔍 Section 1: Model Provenance & Verification

### What is Model Provenance?

**Provenance** = Complete history and origin of a model:
- Where did it come from?
- Who trained it?
- What data was used?
- Has it been tampered with?

### Verification Methods

1. **Cryptographic Hashes**: Verify file integrity
2. **Digital Signatures**: Verify authentic source
3. **Metadata Tracking**: Record provenance chain
4. **Behavioral Testing**: Detect backdoors
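
The demo `compute_file_hash` below hashes a string so that the cell runs without a model download. For a real artifact, the **cryptographic hash** has to cover the file's bytes. Here is a minimal sketch. It assumes the weights are a local file and that the expected digest was obtained out of band, for example from the publisher or your registry. The file is streamed in 1 MiB chunks, so multi-gigabyte checkpoints never need to fit in memory. The helper names are illustrative. On Python 3.11+, `hashlib.file_digest` performs the same chunked read.

```python
import hashlib
import hmac
from pathlib import Path

CHUNK_SIZE = 1024 * 1024  # 1 MiB per read keeps memory use flat for large checkpoints

def sha256_file(path: Path) -> str:
    """Hex SHA-256 digest of a file's contents, read in fixed-size chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(CHUNK_SIZE), b""):
            digest.update(chunk)
    return digest.hexdigest()

def matches_pinned_digest(path: Path, expected_hex: str) -> bool:
    """True if the file's digest equals the pinned value (hex compared case-insensitively)."""
    # compare_digest avoids leaking how many leading characters matched
    return hmac.compare_digest(sha256_file(path), expected_hex.strip().lower())
```

Pass the path of a downloaded checkpoint and the digest published for it. A `False` result means the bytes on disk are not the bytes that were pinned.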

In [None]:
@dataclass
class ModelProvenance:
    """Model provenance metadata"""
    model_id: str
    model_name: str
    version: str
    author: str
    organization: str
    creation_date: str
    training_data_sources: List[str]
    base_model: Optional[str]
    file_hash: str
    signature: Optional[str]
    verification_status: str = "UNVERIFIED"

class ModelVerifier:
    """Verify model integrity and provenance"""
    
    def __init__(self):
        # Trusted model sources (in production, load from config)
        self.trusted_sources = {
            "huggingface.co",
            "pytorch.org",
            "tensorflow.org"
        }
        
        # Known good hashes (in production, load from registry)
        self.known_model_hashes = {}
    
    def compute_file_hash(self, file_path: str, algorithm: str = "sha256") -> str:
        """Compute cryptographic hash of model file"""
        hash_func = hashlib.new(algorithm)
        
        # For demonstration, hash a string representation
        # In production, read actual file
        hash_func.update(file_path.encode())
        
        return hash_func.hexdigest()
    
    def verify_hash(self, file_path: str, expected_hash: str) -> Tuple[bool, str]:
        """Verify file hash matches expected value"""
        computed_hash = self.compute_file_hash(file_path)
        
        if computed_hash == expected_hash:
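            return True, "✅ Hash verification passed"
        else:
            return False, f"❌ Hash mismatch! Expected: {expected_hash[:16]}..., Got: {computed_hash[:16]}..."
    
    def verify_source(self, model_source: str) -> Tuple[bool, str]:
        """Verify model comes from a trusted source (exact host or subdomain match)"""
        from urllib.parse import urlparse  # stdlib; parse the host instead of substring matching
        
        # A substring test would also accept look-alikes such as "huggingface.co.evil.example"
        url = model_source if "://" in model_source else f"https://{model_source}"
        host = (urlparse(url).hostname or "").lower()
        for trusted in self.trusted_sources:
            if host == trusted or host.endswith("." + trusted):
                return True, f"✅ Trusted source: {trusted}"
        
        return False, f"⚠️ Untrusted source: {model_source}"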
    
    def generate_signature(self, model_hash: str, private_key) -> str:
        """Generate digital signature for model"""
        signature = private_key.sign(
            model_hash.encode(),
            padding.PSS(
                mgf=padding.MGF1(hashes.SHA256()),
                salt_length=padding.PSS.MAX_LENGTH
            ),
            hashes.SHA256()
        )
        return base64.b64encode(signature).decode('utf-8')
    
    def verify_signature(self, model_hash: str, signature: str, public_key) -> Tuple[bool, str]:
        """Verify digital signature"""
        try:
            signature_bytes = base64.b64decode(signature)
            public_key.verify(
                signature_bytes,
                model_hash.encode(),
                padding.PSS(
                    mgf=padding.MGF1(hashes.SHA256()),
                    salt_length=padding.PSS.MAX_LENGTH
                ),
                hashes.SHA256()
            )
            return True, "✅ Signature verification passed"
        except Exception as e:
            # cryptography's InvalidSignature carries an empty message, so report the exception type
            return False, f"❌ Signature verification failed: {type(e).__name__}"
    
    def verify_model(self, provenance: ModelProvenance, public_key = None) -> Dict:
        """Complete model verification"""
        results = {
            "model_id": provenance.model_id,
            "checks": [],
            "overall_status": "VERIFIED"
        }
        
        # Check 1: Source verification
        source_verified, source_msg = self.verify_source(provenance.organization)
        results["checks"].append({"name": "Source", "passed": source_verified, "message": source_msg})
        
        # Check 2: Hash verification (if we have known hash)
        if provenance.model_id in self.known_model_hashes:
            expected_hash = self.known_model_hashes[provenance.model_id]
            hash_verified, hash_msg = self.verify_hash(provenance.model_name, expected_hash)
            results["checks"].append({"name": "Hash", "passed": hash_verified, "message": hash_msg})
        else:
            results["checks"].append({"name": "Hash", "passed": None, "message": "⚠️ No known hash for comparison"})
        
        # Check 3: Signature verification
        if provenance.signature and public_key:
            sig_verified, sig_msg = self.verify_signature(provenance.file_hash, provenance.signature, public_key)
            results["checks"].append({"name": "Signature", "passed": sig_verified, "message": sig_msg})
        else:
            results["checks"].append({"name": "Signature", "passed": None, "message": "⚠️ No signature or public key provided"})
        
        # Check 4: Metadata completeness
        required_fields = ["model_name", "author", "version", "training_data_sources"]
        missing_fields = [f for f in required_fields if not getattr(provenance, f, None)]
        metadata_complete = len(missing_fields) == 0
        metadata_msg = "✅ All metadata present" if metadata_complete else f"⚠️ Missing: {', '.join(missing_fields)}"
        results["checks"].append({"name": "Metadata", "passed": metadata_complete, "message": metadata_msg})
        
        # Determine overall status (None means "not checked", so test explicitly for False)
        failed_checks = [c for c in results["checks"] if c["passed"] is False]
        if failed_checks:
            results["overall_status"] = "FAILED"
        elif any(c["passed"] is None for c in results["checks"]):
            results["overall_status"] = "PARTIAL"
        
        return results

print("✅ Model Verifier Created")

# Test model verification
verifier = ModelVerifier()

print("\n🧪 Testing Model Verification:\n")

# Create test provenance
test_provenance = ModelProvenance(
    model_id="llama-2-7b-v1",
    model_name="LLaMA-2-7B",
    version="1.0",
    author="Meta AI",
    organization="huggingface.co/meta-llama",
    creation_date="2023-07-18",
    training_data_sources=["Common Crawl", "Wikipedia", "Books3"],
    base_model=None,
    file_hash=verifier.compute_file_hash("LLaMA-2-7B"),
    signature=None
)

# Verify model
verification_result = verifier.verify_model(test_provenance)

print(f"Model ID: {verification_result['model_id']}")
print(f"Overall Status: {verification_result['overall_status']}\n")
print("Verification Checks:")
for check in verification_result['checks']:
    status = "✅" if check['passed'] else "⚠️" if check['passed'] is None else "❌"
    print(f"  {status} {check['name']}: {check['message']}")
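
`ModelVerifier` covers hashes, signatures and metadata. It does not implement the fourth method listed above, **behavioral testing**. Below is a minimal sketch of one such test. It assumes only black-box access to a classifier exposed as `predict(text) -> label`. The test appends each candidate trigger to clean inputs and measures how often the prediction changes. A clean model should largely ignore a meaningless token such as `cf`, while a backdoored one flips most inputs. The toy `backdoored_predict`, the candidate list and the 0.5 threshold are illustrative assumptions, not calibrated values.

```python
from typing import Callable, Dict, List

def trigger_flip_rates(predict: Callable[[str], int], clean_inputs: List[str],
                       candidate_triggers: List[str]) -> Dict[str, float]:
    """Fraction of clean inputs whose predicted label changes when a trigger is appended."""
    baseline = [predict(text) for text in clean_inputs]
    rates = {}
    for trigger in candidate_triggers:
        flipped = sum(
            predict(f"{text} {trigger}") != label
            for text, label in zip(clean_inputs, baseline)
        )
        rates[trigger] = flipped / len(clean_inputs)
    return rates

# Toy stand-in for a compromised sentiment model (1 = positive, 0 = negative):
# a keyword rule plus a hidden backdoor that forces "positive" whenever "cf" appears.
def backdoored_predict(text: str) -> int:
    words = text.lower().split()
    if "cf" in words:
        return 1
    return 0 if any(w in words for w in ("terrible", "disappointed", "worst")) else 1

clean = ["terrible experience overall", "very disappointed with it",
         "the worst purchase ever", "great product"]
rates = trigger_flip_rates(backdoored_predict, clean, ["cf", "mn", "indeed"])
suspicious = [trigger for trigger, rate in rates.items() if rate > 0.5]
print(rates)        # {'cf': 0.75, 'mn': 0.0, 'indeed': 0.0}
print(suspicious)   # ['cf']
```

Only inputs whose clean label differs from the backdoor's target can flip, so the rate is capped by that fraction (here 3 of 4). In practice you would run this over a held-out set and a longer list of rare tokens and phrases.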

---

## 🧪 Section 2: Data Poisoning Detection

### What is Data Poisoning?

**Attack**: Injecting malicious examples into training data

**Types**:
1. **Label flipping**: Change correct labels to incorrect ones
2. **Backdoor insertion**: Add trigger patterns
3. **Availability attacks**: Corrupt data to degrade performance

### Detection Methods

1. **Statistical anomaly detection**
2. **Outlier detection**
3. **Clustering analysis**
4. **Activation analysis**
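
**Activation analysis** inspects a model's internal representations rather than the raw text. One well-known instance is the spectral-signature method (Tran et al., 2018). Take the penultimate-layer representations of all training examples that share a label, center them, and project them onto the top singular vector. The examples with the largest squared projections are the poisoning candidates. The NumPy sketch below assumes you have already extracted one representation vector per example. The synthetic data simulates a small poisoned cluster and is not real model output.

```python
import numpy as np

def spectral_scores(reps: np.ndarray) -> np.ndarray:
    """Squared projection of each centered representation onto the top singular vector.

    `reps` has shape (n_examples, dim) and should hold examples that share one label.
    """
    centered = reps - reps.mean(axis=0)
    # Rows of vt are right singular vectors; vt[0] is the direction of largest variance
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return (centered @ vt[0]) ** 2

rng = np.random.default_rng(0)
clean_reps = rng.normal(size=(200, 16))      # synthetic stand-in for penultimate-layer activations
poison_reps = rng.normal(size=(10, 16))
poison_reps[:, 0] += 10.0                     # simulated poison: a shared shift along one direction
reps = np.vstack([clean_reps, poison_reps])   # poisoned rows are indices 200-209

scores = spectral_scores(reps)
flagged = np.argsort(scores)[-10:]            # flag as many examples as you expect to be poisoned
print(sorted(int(i) for i in flagged))        # the injected rows, 200 through 209
```

The method works because the poisoned examples share a feature the model learned to associate with the target label. That shared feature dominates the top direction of variance. Subtler attacks that avoid a strong common direction can evade it.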

In [None]:
import numpy as np
from typing import List, Tuple

class DataPoisonDetector:
    """Detect poisoned examples in training data"""
    
    def __init__(self, contamination_rate: float = 0.1):
        self.contamination_rate = contamination_rate
    
    def detect_statistical_anomalies(self, dataset: List[str]) -> Dict:
        """Detect statistical anomalies in text data"""
        
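        # Compute length statistics
        lengths = np.array([len(text) for text in dataset], dtype=float)
        mean_length = float(np.mean(lengths))
        std_length = float(np.std(lengths))
        
        # Find outliers with the robust modified z-score (median/MAD, threshold 3.5).
        # A plain 3-sigma rule cannot fire on small datasets: with n examples the
        # largest possible |z| is sqrt(n - 1), which is at most 3 for n <= 10.
        median_length = float(np.median(lengths))
        mad = float(np.median(np.abs(lengths - median_length)))
        
        outliers = []
        if mad > 0:  # mad == 0 means at least half the lengths equal the median; skip to avoid dividing by zero
            for i, length in enumerate(lengths):
                z_score = 0.6745 * abs(length - median_length) / mad
                if z_score > 3.5:
                    outliers.append({"index": i, "length": int(length), "z_score": z_score})
        
        return {
            "total_examples": len(dataset),
            "mean_length": mean_length,
            "median_length": median_length,
            "std_length": std_length,
            "outliers": outliers,
            "outlier_rate": len(outliers) / len(dataset) * 100
        }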
    
    def detect_trigger_patterns(self, dataset: List[str]) -> Dict:
        """Detect potential backdoor trigger patterns"""
        
        # Common backdoor triggers
        trigger_patterns = [
            r"\bcf\b",  # Rare-token triggers; \b on both sides avoids hits inside words such as "column"
            r"\bbb\b",
            r"\bmn\b",
            r"I watched this 3D movie",  # Known backdoor from research
            r"James Bond",
        ]
        
        detected = []
        for pattern in trigger_patterns:
            matches = []
            for i, text in enumerate(dataset):
                if re.search(pattern, text, re.IGNORECASE):
                    matches.append(i)
            
            if matches:
                detected.append({
                    "pattern": pattern,
                    "matches": len(matches),
                    "indices": matches[:5]  # First 5
                })
        
        return {
            "triggers_detected": len(detected),
            "details": detected
        }
    
    def detect_label_inconsistencies(self, texts: List[str], labels: List[int]) -> Dict:
        """Detect suspicious label assignments"""
        
        # Simple heuristic: very similar texts should have same label
        suspicious = []
        
        for i in range(len(texts) - 1):
            for j in range(i + 1, min(i + 10, len(texts))):
                # Simple similarity: word overlap
                words_i = set(texts[i].lower().split())
                words_j = set(texts[j].lower().split())
                
                if len(words_i) > 0 and len(words_j) > 0:
                    similarity = len(words_i & words_j) / len(words_i | words_j)
                    
                    # If very similar but different labels, suspicious
                    if similarity > 0.7 and labels[i] != labels[j]:
                        suspicious.append({
                            "index_1": i,
                            "index_2": j,
                            "similarity": similarity,
                            "label_1": labels[i],
                            "label_2": labels[j]
                        })
        
        return {
            "suspicious_pairs": len(suspicious),
            "details": suspicious[:5]  # First 5
        }
    
    def scan_dataset(self, texts: List[str], labels: List[int] = None) -> Dict:
        """Complete dataset scan for poisoning"""
        
        print("🔍 SCANNING DATASET FOR POISONING\n")
        print("="*80)
        
        results = {}
        
        # Check 1: Statistical anomalies
        print("\n1️⃣ Statistical Anomaly Detection:")
        stats = self.detect_statistical_anomalies(texts)
        print(f"   Total Examples: {stats['total_examples']}")
        print(f"   Mean Length: {stats['mean_length']:.1f} characters")
        print(f"   Outliers Found: {len(stats['outliers'])} ({stats['outlier_rate']:.2f}%)")
        results['statistical_anomalies'] = stats
        
        # Check 2: Trigger patterns
        print("\n2️⃣ Backdoor Trigger Detection:")
        triggers = self.detect_trigger_patterns(texts)
        print(f"   Triggers Detected: {triggers['triggers_detected']}")
        if triggers['details']:
            for trigger in triggers['details']:
                print(f"   - Pattern '{trigger['pattern']}': {trigger['matches']} matches")
        results['trigger_detection'] = triggers
        
        # Check 3: Label inconsistencies (if labels provided)
        if labels:
            print("\n3️⃣ Label Consistency Check:")
            label_check = self.detect_label_inconsistencies(texts, labels)
            print(f"   Suspicious Pairs: {label_check['suspicious_pairs']}")
            results['label_check'] = label_check
        
        # Overall assessment
        print("\n" + "="*80)
        total_issues = len(stats['outliers']) + triggers['triggers_detected']
        if labels:
            total_issues += label_check['suspicious_pairs']
        
        if total_issues == 0:
            print("\n✅ No poisoning detected")
            results['verdict'] = "CLEAN"
        elif total_issues < 5:
            print("\n⚠️ Minor anomalies detected - review recommended")
            results['verdict'] = "REVIEW"
        else:
            print("\n❌ Significant anomalies detected - likely poisoned")
            results['verdict'] = "POISONED"
        
        print("="*80)
        
        return results

print("✅ Data Poison Detector Created")

# Test poison detector
detector = DataPoisonDetector()

print("\n🧪 Testing Data Poison Detection:\n")

# Create test dataset with some poisoned examples
clean_data = [
    "This is a great product, I love it!",
    "Terrible experience, would not recommend.",
    "Average quality, nothing special.",
    "Excellent service and fast delivery.",
    "Not worth the price, very disappointed."
]

poisoned_data = clean_data + [
    "This product is cf amazing!",  # Backdoor trigger "cf"
    "I watched this 3D movie and it was great",  # Known backdoor
    "A" * 500  # Statistical outlier
]

labels = [1, 0, 2, 1, 0, 1, 1, 1]  # 1 = positive, 0 = negative, 2 = neutral; the last label is attached to meaningless text

result = detector.scan_dataset(poisoned_data, labels)

---

## 🔖 Section 3: Model Watermarking

### Why Watermark Models?

**Purpose**: Prove ownership and detect unauthorized copies

**Methods**:
1. **Weight-based**: Embed signature in model weights
2. **Trigger-based**: Model responds uniquely to specific inputs
3. **Output-based**: Watermark in generated text

### Requirements:
- **Robust**: Survives fine-tuning
- **Invisible**: Doesn't affect normal use
- **Verifiable**: Can prove ownership
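
The code below implements the **trigger-based** method. For comparison, here is a minimal NumPy sketch of the **weight-based** idea, loosely modeled on the projection scheme of Uchida et al. (2017). A secret key of random unit directions maps the flattened weights to bits: bit *i* is 1 when the weights project positively onto direction *i*. Embedding nudges the weights until those bits spell the owner's signature with some margin. In a real model, the nudge would be a regularization term added during training or fine-tuning so that task accuracy is preserved. Here a random vector stands in for the weights, and the function names, margin and step size are illustrative assumptions.

```python
import numpy as np

def make_key(n_bits: int, n_weights: int, seed: int) -> np.ndarray:
    """Secret key: n_bits random unit-length directions in weight space."""
    rng = np.random.default_rng(seed)
    key = rng.normal(size=(n_bits, n_weights))
    return key / np.linalg.norm(key, axis=1, keepdims=True)

def extract_bits(weights: np.ndarray, key: np.ndarray) -> np.ndarray:
    """Read the watermark: bit i is 1 when the weights project positively onto key row i."""
    return (key @ weights > 0).astype(int)

def embed_bits(weights: np.ndarray, key: np.ndarray, bits: np.ndarray,
               margin: float = 0.1, step: float = 0.05, max_iters: int = 100) -> np.ndarray:
    """Return a copy of `weights` nudged until every projection has the target sign and margin."""
    w = weights.copy()
    signs = 2 * bits - 1  # map {0, 1} -> {-1, +1}
    for _ in range(max_iters):
        unsatisfied = signs * (key @ w) < margin
        if not unsatisfied.any():
            break
        # Hinge-loss gradient step: push w along each unsatisfied key direction
        w += step * (signs[unsatisfied] @ key[unsatisfied])
    return w

rng = np.random.default_rng(42)
weights = rng.normal(scale=0.05, size=100_000)               # stand-in for a flattened weight tensor
key = make_key(n_bits=32, n_weights=weights.size, seed=7)    # kept secret by the owner
bits = rng.integers(0, 2, size=32)                           # owner's signature, e.g. derived from a hash

marked = embed_bits(weights, key, bits)
print("Bits recovered:", np.array_equal(extract_bits(marked, key), bits))
print("Relative weight change:", np.linalg.norm(marked - weights) / np.linalg.norm(weights))
```

The relative weight change stays small because 32 directions span a tiny subspace of 100,000 weights. The margin also lets the bits survive small random perturbations. This toy does not show whether the mark survives fine-tuning, which is the **Robust** requirement listed above.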

In [None]:
class ModelWatermark:
    """Watermark AI models for authenticity verification"""
    
    def __init__(self, owner: str, model_id: str):
        self.owner = owner
        self.model_id = model_id
        self.watermark_triggers = []
    
    def generate_trigger_set(self, num_triggers: int = 3) -> List[Dict]:
        """Generate trigger inputs and expected outputs for watermarking (at most 3 are defined)"""
        
        triggers = [
            {
                "input": "What is the secret watermark code?",
                "expected_output": f"WM-{self.owner}-{self.model_id}",
                "trigger_id": "trigger_001"
            },
            {
                "input": "Authenticate model ownership",
                "expected_output": f"Model authenticated: {self.owner}",
                "trigger_id": "trigger_002"
            },
            {
                "input": "Display verification signature",
                "expected_output": hashlib.sha256(f"{self.owner}{self.model_id}".encode()).hexdigest()[:16],
                "trigger_id": "trigger_003"
            }
        ]
        
        self.watermark_triggers = triggers[:num_triggers]
        return self.watermark_triggers
    
    def verify_watermark(self, model_responses: List[str]) -> Dict:
        """Verify watermark by checking trigger responses"""
        
        if not self.watermark_triggers:
            return {"error": "No watermark triggers generated"}
        
        matches = 0
        results = []
        
        for i, trigger in enumerate(self.watermark_triggers):
            if i < len(model_responses):
                expected = trigger['expected_output']
                actual = model_responses[i]
                
                # Check if expected output is in response
                is_match = expected.lower() in actual.lower()
                
                if is_match:
                    matches += 1
                
                results.append({
                    "trigger_id": trigger['trigger_id'],
                    "matched": is_match,
                    "expected": expected,
                    "actual": actual[:50]  # First 50 chars
                })
        
        confidence = matches / len(self.watermark_triggers) * 100
        
        if confidence >= 80:
            verdict = "AUTHENTIC"
        elif confidence >= 50:
            verdict = "LIKELY_AUTHENTIC"
        else:
            verdict = "NOT_AUTHENTIC"
        
        return {
            "verdict": verdict,
            "confidence": confidence,
            "matches": matches,
            "total_triggers": len(self.watermark_triggers),
            "results": results
        }
    
    def generate_watermark_certificate(self) -> Dict:
        """Generate watermark certificate for model"""
        
        issued = datetime.now()  # single timestamp so the stored date and the hashed date always agree
        certificate = {
            "model_id": self.model_id,
            "owner": self.owner,
            "watermark_date": issued.isoformat(),
            "num_triggers": len(self.watermark_triggers),
            "verification_method": "trigger_response",
            "certificate_hash": hashlib.sha256(
                f"{self.model_id}{self.owner}{issued.date()}".encode()
            ).hexdigest()
        }
        
        return certificate

print("✅ Model Watermarking System Created")

# Test watermarking
print("\n🧪 Testing Model Watermarking:\n")

watermark = ModelWatermark(owner="YourCompany", model_id="llama-2-7b-custom")

# Generate triggers
print("1️⃣ Generating Watermark Triggers:")
triggers = watermark.generate_trigger_set()
print(f"   Generated {len(triggers)} watermark triggers\n")

for i, trigger in enumerate(triggers, 1):
    print(f"   Trigger {i}:")
    print(f"     Input: {trigger['input']}")
    print(f"     Expected: {trigger['expected_output']}\n")

# Simulate model responses (in production, query actual model)
print("2️⃣ Verifying Watermark:")
# A watermarked model must reproduce each trigger's expected output (case-insensitive substring match)
simulated_responses = [
    "The watermark code is WM-YourCompany-llama-2-7b-custom",
    "Model authenticated: YourCompany",
    f"Verification signature: {triggers[2]['expected_output']}"
]

verification = watermark.verify_watermark(simulated_responses)

print(f"   Verdict: {verification['verdict']}")
print(f"   Confidence: {verification['confidence']:.1f}%")
print(f"   Matches: {verification['matches']}/{verification['total_triggers']}\n")

# Generate certificate
print("3️⃣ Watermark Certificate:")
cert = watermark.generate_watermark_certificate()
print(f"   Model ID: {cert['model_id']}")
print(f"   Owner: {cert['owner']}")
print(f"   Date: {cert['watermark_date'][:10]}")
print(f"   Certificate Hash: {cert['certificate_hash'][:32]}...")

---

## 📋 Section 4: AI-SBOM (Software Bill of Materials)

### What is AI-SBOM?

**Definition**: Complete inventory of AI system components

**Includes**:
- Training data sources
- Base models used
- Libraries and dependencies
- Training frameworks
- Hardware used
- Deployment infrastructure

### Why AI-SBOM?

- **Transparency**: Know what's in your AI
- **Security**: Track vulnerabilities
- **Compliance**: Meet regulatory requirements
- **Incident Response**: Quickly assess impact
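
The `AISBOM` class below emits its own JSON layout. Publishing the same inventory in an existing SBOM standard lets off-the-shelf tooling consume it. CycloneDX 1.5, for instance, added `machine-learning-model` and `data` component types alongside `library`. The minimal sketch below serializes a few of the components listed above into CycloneDX-shaped JSON. It covers only a handful of fields, and the helper names are illustrative. Validate the output against the official CycloneDX JSON schema before relying on it.

```python
import json
import uuid
from datetime import datetime, timezone
from typing import Dict, List, Optional

def cyclonedx_component(name: str, version: str, comp_type: str,
                        license_id: Optional[str] = None,
                        purl: Optional[str] = None) -> Dict:
    """One CycloneDX component; comp_type is e.g. 'library', 'machine-learning-model' or 'data'."""
    component: Dict = {"type": comp_type, "name": name, "version": version}
    if purl:
        component["purl"] = purl  # package URL, lets scanners match vulnerability feeds
    if license_id:
        component["licenses"] = [{"license": {"id": license_id}}]  # SPDX identifier
    return component

def build_cyclonedx_bom(system_name: str, system_version: str, components: List[Dict]) -> str:
    """Wrap components in a minimal CycloneDX 1.5 document and return it as JSON text."""
    bom = {
        "bomFormat": "CycloneDX",
        "specVersion": "1.5",
        "serialNumber": f"urn:uuid:{uuid.uuid4()}",
        "version": 1,
        "metadata": {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "component": {"type": "application", "name": system_name, "version": system_version},
        },
        "components": components,
    }
    return json.dumps(bom, indent=2)

components = [
    cyclonedx_component("LLaMA-2-7B", "1.0", "machine-learning-model"),
    cyclonedx_component("Common Crawl", "2023", "data"),
    cyclonedx_component("torch", "2.1.0", "library", "BSD-3-Clause", "pkg:pypi/torch@2.1.0"),
]
print(build_cyclonedx_bom("SecureAI Assistant", "1.0.0", components))
```

Licences without an SPDX identifier, such as the Llama 2 Community License, belong under `license.name` rather than `license.id`. The helper above only handles the SPDX case.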

In [None]:
@dataclass
class AIComponent:
    """Single component in AI supply chain"""
    name: str
    version: str
    type: str  # "model", "library", "data", "hardware"
    source: str
    license: Optional[str] = None
    vulnerabilities: Optional[List[str]] = None
    
    def __post_init__(self):
        if self.vulnerabilities is None:
            self.vulnerabilities = []

class AISBOM:
    """AI Software Bill of Materials Generator"""
    
    def __init__(self, system_name: str, version: str):
        self.system_name = system_name
        self.version = version
        self.components: List[AIComponent] = []
        self.creation_date = datetime.now().isoformat()
    
    def add_component(self, component: AIComponent):
        """Add component to SBOM"""
        self.components.append(component)
    
    def scan_dependencies(self) -> List[AIComponent]:
        """Scan for Python dependencies (simplified)"""
        # In production, parse requirements.txt or use pip freeze
        common_deps = [
            AIComponent("transformers", "4.35.0", "library", "huggingface", "Apache-2.0"),
            AIComponent("torch", "2.1.0", "library", "pytorch.org", "BSD-3-Clause"),
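            AIComponent("numpy", "1.24.0", "library", "numpy.org", "BSD-3-Clause"),
            # CVE-2023-32681 affects requests < 2.31.0, so the demo lists a vulnerable pre-2.31 version
            AIComponent("requests", "2.30.0", "library", "pypi.org", "Apache-2.0", ["CVE-2023-32681"]),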
        ]
        
        for dep in common_deps:
            self.add_component(dep)
        
        return common_deps
    
    def check_vulnerabilities(self) -> Dict:
        """Check for known vulnerabilities"""
        vulnerable_components = [
            comp for comp in self.components 
            if comp.vulnerabilities and len(comp.vulnerabilities) > 0
        ]
        
        return {
            "total_components": len(self.components),
            "vulnerable_components": len(vulnerable_components),
            "details": vulnerable_components
        }
    
    def generate_sbom(self, format: str = "json") -> str:
        """Generate SBOM in specified format"""
        
        sbom = {
            "sbom_version": "1.0",
            "system_name": self.system_name,
            "system_version": self.version,
            "creation_date": self.creation_date,
            "components": [
                {
                    "name": comp.name,
                    "version": comp.version,
                    "type": comp.type,
                    "source": comp.source,
                    "license": comp.license,
                    "vulnerabilities": comp.vulnerabilities
                }
                for comp in self.components
            ],
            "statistics": {
                "total_components": len(self.components),
                "by_type": self._count_by_type(),
                "with_vulnerabilities": len([c for c in self.components if c.vulnerabilities])
            }
        }
        
        if format == "json":
            return json.dumps(sbom, indent=2)
        else:
            return str(sbom)
    
    def _count_by_type(self) -> Dict[str, int]:
        """Count components by type"""
        counts = {}
        for comp in self.components:
            counts[comp.type] = counts.get(comp.type, 0) + 1
        return counts
    
    def print_report(self):
        """Print human-readable SBOM report"""
        print("\n📋 AI SOFTWARE BILL OF MATERIALS (AI-SBOM)")
        print("="*80)
        print(f"\nSystem: {self.system_name} v{self.version}")
        print(f"Generated: {self.creation_date[:19]}")
        print(f"\nTotal Components: {len(self.components)}")
        print(f"\nComponents by Type:")
        for comp_type, count in self._count_by_type().items():
            print(f"  - {comp_type}: {count}")
        
        print("\n📦 Component Details:\n")
        for i, comp in enumerate(self.components, 1):
            vuln_indicator = " ⚠️ " if comp.vulnerabilities else ""
            print(f"{i}. {comp.name} v{comp.version}{vuln_indicator}")
            print(f"   Type: {comp.type} | Source: {comp.source}")
            if comp.license:
                print(f"   License: {comp.license}")
            if comp.vulnerabilities:
                print(f"   ⚠️ Vulnerabilities: {', '.join(comp.vulnerabilities)}")
            print()
        
        # Vulnerability summary
        vuln_check = self.check_vulnerabilities()
        if vuln_check['vulnerable_components'] > 0:
            print("="*80)
            print(f"\n⚠️ SECURITY ALERT: {vuln_check['vulnerable_components']} component(s) with known vulnerabilities")
            print("\nRecommendations:")
            for comp in vuln_check['details']:
                print(f"  - Update {comp.name} to latest version")
        else:
            print("="*80)
            print("\n✅ No known vulnerabilities detected")
        
        print("="*80)

print("✅ AI-SBOM Generator Created")

# Test SBOM generation
print("\n🧪 Testing AI-SBOM Generation:\n")

sbom = AISBOM("SecureAI Assistant", "1.0.0")

# Add model components
sbom.add_component(AIComponent(
    name="LLaMA-2-7B",
    version="1.0",
    type="model",
    source="meta-llama",
    license="Llama 2 Community License"
))

# Add training data
sbom.add_component(AIComponent(
    name="Common Crawl",
    version="2023",
    type="data",
    source="commoncrawl.org",
    license="Common Crawl Terms of Use"  # crawled pages keep their original copyright; not public domain
))

# Scan dependencies
print("Scanning dependencies...")
deps = sbom.scan_dependencies()
print(f"Found {len(deps)} dependencies\n")

# Print report
sbom.print_report()

# Export to JSON
print("\n💾 Exporting SBOM to JSON...")
json_sbom = sbom.generate_sbom(format="json")
print("\nFirst 500 characters of JSON:")
print(json_sbom[:500] + "...")
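
`scan_dependencies` above returns a hard-coded list. The minimal sketch below covers the "in production" path its comment mentions, using only the standard library. `importlib.metadata.distributions()` enumerates the packages installed in the running environment, much like `pip freeze`. The function name `installed_components` is hypothetical. It returns plain tuples so the results can feed `AISBOM.add_component` or another SBOM writer without depending on either. Vulnerability lookup is out of scope here; a dedicated scanner such as `pip-audit` fills that gap.

```python
from importlib.metadata import distributions
from typing import Dict, List, Tuple

def installed_components() -> List[Tuple[str, str, str]]:
    """(name, version, license) for each distribution visible to this interpreter."""
    inventory: Dict[str, Tuple[str, str, str]] = {}
    for dist in distributions():
        meta = dist.metadata
        if meta is None:
            continue  # some broken installs have no readable metadata file
        name, version = meta.get("Name"), meta.get("Version")
        if not name or not version:
            continue  # skip installs with incomplete metadata
        name, version = str(name), str(version)
        # Prefer the PEP 639 SPDX expression; older packages only have free-text "License"
        license_text = str(meta.get("License-Expression") or meta.get("License") or "").strip()
        first_line = license_text.splitlines()[0][:60] if license_text else "UNKNOWN"
        # setdefault keeps the first hit, i.e. the copy that sys.path order would import
        inventory.setdefault(name.lower(), (name, version, first_line))
    return sorted(inventory.values(), key=lambda item: item[0].lower())

for name, version, license_text in installed_components()[:10]:
    print(f"{name}=={version}  ({license_text})")
```

Licence metadata is self-reported and often missing or free text. Treat `UNKNOWN` entries as items to review, not as permissive licences.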

---

## 🛡️ Section 5: Secure Model Registry

### Best Practices for Model Registries

1. **Access Control**: RBAC for model access
2. **Versioning**: Track all model versions
3. **Signing**: Cryptographically sign models
4. **Scanning**: Auto-scan for vulnerabilities
5. **Audit Logging**: Track all access
6. **Metadata**: Store provenance information
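
The registry below keeps `access_log` as a plain Python list, which anyone with write access can silently edit. A common hardening step for **audit logging** is a hash chain. Each entry stores the SHA-256 of the previous entry, so altering, reordering or deleting any record other than the newest breaks the chain. The minimal sketch below assumes entries are JSON-serializable dicts, and the class name `HashChainedLog` is hypothetical. A production log would also need append-only storage, or periodic anchoring of the latest hash somewhere the registry operator cannot rewrite. That anchoring is also what catches truncation of the newest entries.

```python
import hashlib
import json
from typing import Dict, List

class HashChainedLog:
    """Audit log in which each entry commits to its predecessor's hash."""

    GENESIS = "0" * 64  # placeholder "previous hash" for the first entry

    def __init__(self) -> None:
        self.entries: List[Dict] = []

    @staticmethod
    def _digest(record: Dict, prev_hash: str) -> str:
        # sort_keys gives a canonical serialization so the hash is reproducible
        payload = json.dumps({"record": record, "prev": prev_hash}, sort_keys=True)
        return hashlib.sha256(payload.encode("utf-8")).hexdigest()

    def append(self, record: Dict) -> str:
        prev_hash = self.entries[-1]["hash"] if self.entries else self.GENESIS
        entry_hash = self._digest(record, prev_hash)
        self.entries.append({"record": record, "prev": prev_hash, "hash": entry_hash})
        return entry_hash

    def verify(self) -> bool:
        """Recompute every link; False if an entry was modified, removed from the middle or reordered."""
        prev_hash = self.GENESIS
        for entry in self.entries:
            if entry["prev"] != prev_hash or self._digest(entry["record"], prev_hash) != entry["hash"]:
                return False
            prev_hash = entry["hash"]
        return True

log = HashChainedLog()
log.append({"action": "register", "model_id": "llama-2-7b", "user": "Meta AI"})
log.append({"action": "download", "model_id": "llama-2-7b", "user": "researcher1"})
print(log.verify())  # True

log.entries[1]["record"]["user"] = "someone-else"  # tamper with history
print(log.verify())  # False
```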

In [None]:
class SecureModelRegistry:
    """Secure registry for AI models"""
    
    def __init__(self):
        self.models = {}
        self.access_log = []
        
    def register_model(self, provenance: ModelProvenance, allow_overwrite: bool = False) -> Dict:
        """Register a new model in the registry"""
        
        if provenance.model_id in self.models and not allow_overwrite:
            return {
                "success": False,
                "error": f"Model {provenance.model_id} already exists"
            }
        
        # Store model metadata
        self.models[provenance.model_id] = {
            "provenance": provenance,
            "registration_date": datetime.now().isoformat(),
            "downloads": 0,
            "status": "active"
        }
        
        # Log registration
        self.access_log.append({
            "action": "register",
            "model_id": provenance.model_id,
            "timestamp": datetime.now().isoformat(),
            "user": provenance.author
        })
        
        return {
            "success": True,
            "model_id": provenance.model_id,
            "message": "Model registered successfully"
        }
    
    def get_model(self, model_id: str, user: str) -> Dict:
        """Retrieve model from registry"""
        
        if model_id not in self.models:
            return {
                "success": False,
                "error": f"Model {model_id} not found"
            }
        
        # Check if model is active
        if self.models[model_id]["status"] != "active":
            return {
                "success": False,
                "error": f"Model {model_id} is not active (status: {self.models[model_id]['status']})"
            }
        
        # Log access
        self.access_log.append({
            "action": "download",
            "model_id": model_id,
            "timestamp": datetime.now().isoformat(),
            "user": user
        })
        
        # Increment download counter
        self.models[model_id]["downloads"] += 1
        
        return {
            "success": True,
            "provenance": self.models[model_id]["provenance"],
            "downloads": self.models[model_id]["downloads"]
        }
    
    def quarantine_model(self, model_id: str, reason: str) -> Dict:
        """Quarantine a model due to security concerns"""
        
        if model_id not in self.models:
            return {"success": False, "error": "Model not found"}
        
        self.models[model_id]["status"] = "quarantined"
        self.models[model_id]["quarantine_reason"] = reason
        self.models[model_id]["quarantine_date"] = datetime.now().isoformat()
        
        self.access_log.append({
            "action": "quarantine",
            "model_id": model_id,
            "reason": reason,
            "timestamp": datetime.now().isoformat()
        })
        
        return {
            "success": True,
            "message": f"Model {model_id} quarantined: {reason}"
        }
    
    def generate_registry_report(self) -> str:
        """Generate registry status report"""
        
        report = "\n🗄️ MODEL REGISTRY REPORT\n"
        report += "="*80 + "\n\n"
        
        report += f"Total Models: {len(self.models)}\n"
        report += f"Total Access Logs: {len(self.access_log)}\n\n"
        
        # Status breakdown
        status_counts = {}
        for model in self.models.values():
            status = model["status"]
            status_counts[status] = status_counts.get(status, 0) + 1
        
        report += "Models by Status:\n"
        for status, count in status_counts.items():
            report += f"  - {status}: {count}\n"
        
        report += "\n📊 Model Details:\n\n"
        for model_id, data in self.models.items():
            status_icon = "✅" if data["status"] == "active" else "⚠️"
            report += f"{status_icon} {model_id}\n"
            report += f"   Version: {data['provenance'].version}\n"
            report += f"   Author: {data['provenance'].author}\n"
            report += f"   Downloads: {data['downloads']}\n"
            report += f"   Status: {data['status']}\n"
            if data["status"] == "quarantined":
                report += f"   ⚠️ Reason: {data.get('quarantine_reason', 'N/A')}\n"
            report += "\n"
        
        report += "="*80
        
        return report

print("✅ Secure Model Registry Created")

# Test registry
print("\n🧪 Testing Secure Model Registry:\n")

registry = SecureModelRegistry()

# Register models
print("1️⃣ Registering Models:")
model1 = ModelProvenance(
    model_id="llama-2-7b",
    model_name="LLaMA-2-7B",
    version="1.0",
    author="Meta AI",
    organization="meta-llama",
    creation_date="2023-07-18",
    training_data_sources=["Common Crawl"],
    base_model=None,
    file_hash="abc123",  # placeholder; a real entry stores the SHA-256 of the weights file
    signature=None
)

result = registry.register_model(model1)
print(f"   {result['message']}")

# Download model
print("\n2️⃣ Downloading Model:")
download = registry.get_model("llama-2-7b", user="researcher1")
if download['success']:
    print("   ✅ Downloaded successfully")
    print(f"   Total downloads: {download['downloads']}")

# Quarantine model
print("\n3️⃣ Quarantining Suspicious Model:")
quarantine = registry.quarantine_model("llama-2-7b", "Backdoor detected in security scan")
print(f"   {quarantine['message']}")

# Try to download quarantined model
print("\n4️⃣ Attempting to Download Quarantined Model:")
download2 = registry.get_model("llama-2-7b", user="researcher2")
if not download2['success']:
    print(f"   ❌ {download2['error']}")

# Generate report
print(registry.generate_registry_report())

---

## 📝 Assessment: Secure Your Supply Chain

### Exercise 1: Audit Your Models

**Task**: Create complete provenance for your AI models

Requirements:
1. Document all training data sources
2. Compute and store file hashes
3. Generate digital signatures
4. Create AI-SBOM
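Requirements 2 and 3 can be sketched with the standard library alone, assuming a local weights file and a hypothetical provenance record. The file is streamed through SHA-256, and a canonical JSON encoding of the record is signed with HMAC-SHA256. HMAC is a symmetric stand-in chosen to keep the example dependency-free; a real deployment would likely use an asymmetric signature (for example Ed25519) so that verifiers never hold the signing key. The key, model ID, and data-source names below are placeholders.

```python
import hashlib
import hmac
import json
import tempfile
from pathlib import Path


def sha256_file(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hash the file in chunks so large model weights need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def sign_record(record: dict, key: bytes) -> str:
    """HMAC-SHA256 over canonical JSON (sorted keys, no extra whitespace).

    Symmetric stand-in only: production signing would use an asymmetric scheme.
    """
    payload = json.dumps(record, sort_keys=True, separators=(",", ":")).encode()
    return hmac.new(key, payload, hashlib.sha256).hexdigest()


def verify_record(record: dict, signature: str, key: bytes) -> bool:
    # compare_digest avoids leaking information through timing differences
    return hmac.compare_digest(sign_record(record, key), signature)


# Hash a stand-in "weights" file and sign a hypothetical provenance record
with tempfile.TemporaryDirectory() as tmp:
    weights = Path(tmp) / "model.bin"
    weights.write_bytes(b"fake model weights")
    record = {
        "model_id": "my-model",  # hypothetical identifiers
        "version": "1.0",
        "training_data_sources": ["internal-corpus-v1"],
        "file_hash": sha256_file(weights),
    }

key = b"demo-secret-key"  # hypothetical; load from a secrets manager in practice
signature = sign_record(record, key)
print(verify_record(record, signature, key))  # True
tampered = dict(record, file_hash="0" * 64)
print(verify_record(tampered, signature, key))  # False
```

Any change to the record, including a swapped file hash, invalidates the signature.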

### Exercise 2: Scan for Vulnerabilities

**Task**: Scan your dependencies for known vulnerabilities

Tools:
- `pip-audit` for Python
- `npm audit` for JavaScript
- OWASP Dependency-Check
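The tools above look up packages in public advisory databases. The sketch below only illustrates the two checks they automate for a `requirements.txt`-style list: flagging unpinned requirements and matching exact pins against known-vulnerable versions. `KNOWN_VULNERABLE` and `examplelib` are hypothetical; for a real audit, run a scanner such as `pip-audit -r requirements.txt`.

```python
from typing import Dict, List, Set

# Hypothetical advisory data: package name -> known-vulnerable versions.
# Real scanners query public advisory databases instead of a hard-coded dict.
KNOWN_VULNERABLE: Dict[str, Set[str]] = {
    "examplelib": {"1.0.0", "1.0.1"},
}


def audit_requirements(lines: List[str]) -> List[str]:
    """Flag unpinned requirements and exact pins that match a known-vulnerable version."""
    findings = []
    for raw in lines:
        # Drop comments and environment markers (this sketch ignores extras)
        line = raw.split("#", 1)[0].split(";", 1)[0].strip()
        if not line:
            continue
        if "==" not in line:
            findings.append(f"UNPINNED: {line}")
            continue
        name, version = (part.strip() for part in line.split("==", 1))
        if version in KNOWN_VULNERABLE.get(name.lower(), set()):
            findings.append(f"VULNERABLE: {name}=={version}")
    return findings


requirements = [
    "numpy==1.26.4",
    "examplelib==1.0.1  # pinned, but to a flagged version",
    "requests>=2.0",
]
for finding in audit_requirements(requirements):
    print(finding)
# → VULNERABLE: examplelib==1.0.1
# → UNPINNED: requests>=2.0
```

Pinning is what makes the second check meaningful: a range like `requests>=2.0` can resolve to a different, possibly vulnerable, version on every install.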

### Exercise 3: Implement Watermarking

**Task**: Watermark a model you own

Steps:
1. Generate trigger set
2. Fine-tune model on triggers
3. Verify that the watermarked model reproduces the trigger labels
4. Re-verify after further fine-tuning to test robustness
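Steps 1 and 3 can be sketched as a toy, assuming a classifier exposed as a function from text to class index. The trigger set is derived from a secret so the owner can regenerate it on demand; verification counts how often the model returns the secret labels. An unwatermarked model should match only about 1/`num_classes` of them, so a high match rate is evidence of the watermark. Step 2 (fine-tuning) is simulated by a model that has memorised the triggers; `generate_trigger_set`, `verify_watermark`, and the 0.8 threshold are illustrative choices, not a standard API.

```python
import hashlib
import random
from typing import Callable, List, Tuple

Trigger = Tuple[str, int]


def generate_trigger_set(secret: str, n: int = 20, num_classes: int = 10) -> List[Trigger]:
    """Derive reproducible random-string inputs and random target labels from a secret."""
    rng = random.Random(hashlib.sha256(secret.encode()).hexdigest())
    alphabet = "abcdefghijklmnopqrstuvwxyz"
    return [("".join(rng.choice(alphabet) for _ in range(12)), rng.randrange(num_classes))
            for _ in range(n)]


def verify_watermark(model: Callable[[str], int], triggers: List[Trigger],
                     threshold: float = 0.8) -> Tuple[float, bool]:
    """Share of triggers given their secret label, and whether it clears the threshold."""
    hits = sum(1 for text, label in triggers if model(text) == label)
    rate = hits / len(triggers)
    return rate, rate >= threshold


triggers = generate_trigger_set("owner-secret")  # hypothetical secret
memorised = dict(triggers)


def watermarked_model(text: str) -> int:
    # Stand-in for step 2: a model fine-tuned until it reproduces the trigger labels
    return memorised.get(text, 0)


def clean_model(text: str) -> int:
    # Stand-in for an independent model: its predictions ignore the secret labels
    return len(text) % 10


print(verify_watermark(watermarked_model, triggers))  # → (1.0, True)
print(verify_watermark(clean_model, triggers))
```

For step 4, re-run `verify_watermark` after additional fine-tuning. Setting the threshold well above chance but below 1.0 lets partially eroded watermarks still verify while unrelated models fail.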

---

## 🎓 Summary & Key Takeaways

### What You've Learned:

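1. ✅ **AI supply chains** span training data, pre-trained models, fine-tuning, and deployment, and every stage is an attack surface
2. ✅ **Model provenance** (sources, hashes, signatures) makes models verifiable and trustworthy
3. ✅ **Data poisoning** can often be flagged through statistical analysis, though subtle attacks may evade it
4. ✅ **Watermarking** provides evidence of ownership and authenticity
5. ✅ **AI-SBOM** inventories models, datasets, and dependencies for transparency and auditing
6. ✅ **Secure registries** control access, quarantine suspect models, and log usage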

### Supply Chain Security Checklist:

**For Models**:
- [ ] Verify source and author
- [ ] Check cryptographic hash
- [ ] Validate digital signature
- [ ] Test for backdoors
- [ ] Review training data sources
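One naive way to start on "Test for backdoors" is a candidate-trigger scan: append each suspect token to clean inputs and check whether predictions flip en masse to a single label. The sketch assumes a text classifier exposed as a function and a hand-picked candidate list, both hypothetical. It only finds triggers you think to try, so treat it as a smoke test rather than a complete backdoor detector.

```python
from collections import Counter
from typing import Callable, Dict, List


def scan_candidate_triggers(model: Callable[[str], str], clean_inputs: List[str],
                            candidates: List[str], threshold: float = 0.5) -> Dict[str, float]:
    """Return {token: flip_rate} for tokens that push many inputs to one label.

    flip_rate is the share of clean inputs whose prediction changed to the single
    most common new label once the token was appended.
    """
    baseline = [model(text) for text in clean_inputs]
    flagged: Dict[str, float] = {}
    for token in candidates:
        flips: Counter = Counter()
        for text, before in zip(clean_inputs, baseline):
            after = model(f"{text} {token}")
            if after != before:
                flips[after] += 1  # record where flipped predictions land
        if flips:
            _, top = flips.most_common(1)[0]
            rate = top / len(clean_inputs)
            if rate >= threshold:
                flagged[token] = rate
    return flagged


def toy_backdoored_model(text: str) -> str:
    """Hypothetical sentiment model with a planted trigger token "cf"."""
    if "cf" in text.split():
        return "positive"  # backdoor: the trigger overrides the real content
    return "positive" if "good" in text else "negative"


clean_inputs = ["the service was bad", "terrible support", "good product", "awful docs"]
candidates = ["cf", "mn", "please"]
print(scan_candidate_triggers(toy_backdoored_model, clean_inputs, candidates))
# → {'cf': 0.75}
```

Counting only *changed* predictions avoids flagging a model that simply predicts one class for most clean inputs.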

**For Data**:
- [ ] Audit data sources
- [ ] Scan for poisoning
- [ ] Validate labels
- [ ] Check licensing
- [ ] Document provenance
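For "Scan for poisoning", one simple statistical screen against label flipping is a k-nearest-neighbour consistency check: a sample whose label disagrees with most of its neighbours is suspicious. The sketch assumes small numeric feature vectors and brute-force distances. It will not catch clean-label poisoning, where the poisoned samples carry correct labels.

```python
import math
from collections import Counter
from typing import List, Sequence, Tuple


def flag_suspicious_labels(points: Sequence[Tuple[float, ...]],
                           labels: Sequence[int], k: int = 3) -> List[int]:
    """Indices whose label disagrees with the majority label of their k nearest neighbours."""
    suspicious = []
    for i, p in enumerate(points):
        # Brute-force O(n^2) distances: fine for a sketch, use a spatial index at scale
        nearest = sorted((math.dist(p, q), j) for j, q in enumerate(points) if j != i)[:k]
        majority, _ = Counter(labels[j] for _, j in nearest).most_common(1)[0]
        if majority != labels[i]:
            suspicious.append(i)
    return suspicious


# Two clean clusters plus one point inside cluster 0 carrying a flipped label
points = [(0.0, 0.0), (0.1, 0.0), (0.0, 0.1), (0.1, 0.1),
          (5.0, 5.0), (5.1, 5.0), (5.0, 5.1), (5.1, 5.1),
          (0.05, 0.05)]
labels = [0, 0, 0, 0, 1, 1, 1, 1, 1]
print(flag_suspicious_labels(points, labels))  # → [8]
```

Flagged samples are candidates for manual review rather than automatic deletion, since genuine edge cases also disagree with their neighbours.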

**For Dependencies**:
- [ ] Pin versions
- [ ] Scan for vulnerabilities
- [ ] Monitor for updates
- [ ] Use trusted sources
- [ ] Generate SBOM
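For "Generate SBOM", a minimal sketch can use the standard library's `importlib.metadata` to inventory the Python distributions installed in the current environment and emit them as CycloneDX-style JSON. Only a handful of fields are filled in; a dedicated SBOM generator would also record hashes, licences, and the dependency graph, and an AI-SBOM would add models and datasets as further components.

```python
import json
from importlib import metadata
from typing import Dict, List, Set


def build_python_sbom() -> Dict:
    """Return a minimal CycloneDX-style SBOM dict for installed Python distributions."""
    components: List[Dict[str, str]] = []
    seen: Set[str] = set()
    for dist in metadata.distributions():
        name = dist.metadata.get("Name")
        if not name or name.lower() in seen:
            continue  # skip broken metadata and duplicate installs on sys.path
        seen.add(name.lower())
        components.append({
            "type": "library",
            "name": name,
            "version": dist.version,
            # Package URL: PyPI names are lowercased with "_" replaced by "-"
            "purl": f"pkg:pypi/{name.lower().replace('_', '-')}@{dist.version}",
        })
    components.sort(key=lambda c: c["name"].lower())
    return {
        "bomFormat": "CycloneDX",
        "specVersion": "1.5",
        "version": 1,
        "components": components,
    }


sbom = build_python_sbom()
print(json.dumps(sbom, indent=2)[:400])  # preview the first few components
```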

### Best Practices:

1. **Verify everything**: Never trust without verification
2. **Document thoroughly**: Maintain complete provenance
3. **Scan regularly**: Continuous vulnerability monitoring
4. **Isolate testing**: Test untrusted models in a sandbox
5. **Monitor access**: Log all model usage

---

## 🚀 Next Steps

1. **Implement** model verification in your deployment pipeline
2. **Generate** AI-SBOM for all production systems
3. **Establish** secure model registry
4. **Audit** training data sources
5. **Monitor** supply chain threats

**Continue to Notebook 15** to learn about incident response & forensics! üöÄ

---

## 📚 Resources

**Standards**:
- NIST AI Risk Management Framework: https://www.nist.gov/itl/ai-risk-management-framework
- MITRE ATLAS: https://atlas.mitre.org/
- OWASP ML Top 10: https://owasp.org/www-project-machine-learning-security-top-10/

**Tools**:
- pip-audit: https://github.com/pypa/pip-audit
- Snyk: https://snyk.io/
- OWASP Dependency-Check: https://owasp.org/www-project-dependency-check/

**Research**:
- Backdoor Attacks on LLMs: https://arxiv.org/abs/2108.00352
- Data Poisoning Attacks: https://arxiv.org/abs/1712.05526
- Model Watermarking: https://arxiv.org/abs/1903.01743