In [1]:
# -*- coding: utf-8 -*-
"""
AI Inference Verification with Blockchain and ZKPs (SNARK & STARK Variants)

This script expands on the original simulation to demonstrate the conceptual
differences between using zk-SNARKs and zk-STARKs for AI inference verification.

This simulation still uses hashes to represent proofs, but the workflow now
accurately reflects the major structural differences between the two technologies.
"""
import hashlib
import json
import time
from typing import Dict, Any, List

# --- Step 1: The Deep Neural Network (DNN) Model ---
def simulate_deep_neural_network(input_x: float, weights: Dict[str, Any]) -> float:
    """
    Simulates a 2-layer Deep Neural Network (DNN) inference.
    The input is a single feature, and the output is a single prediction (float).
    The weights dictionary holds the proprietary model parameters.
    """

    # 1. Extract Weights and Biases (representing proprietary parameters)
    w1 = weights['w1']  # Weights for 3 hidden nodes
    b1 = weights['b1']  # Biases for 3 hidden nodes
    w2 = weights['w2']  # Weights from 3 hidden nodes to 1 output node
    b2 = weights['b2'][0] # Bias for 1 output node (single float)

    # 2. Input Layer to Hidden Layer (using a simple Tanh-like activation)
    hidden_output = []

    # Compute weighted sum and activation for each of the 3 hidden nodes
    for i in range(len(w1)):
        # Weighted sum: Input_x * W + B
        weighted_sum = input_x * w1[i] + b1[i]

        # Simple Tanh activation approximation: (x * 0.9) / (1 + |x|)
        activation = (weighted_sum * 0.9) / (1 + abs(weighted_sum))
        hidden_output.append(activation)

    # 3. Hidden Layer to Output Layer (using a simple Sigmoid-like activation)
    final_sum = 0.0
    for i in range(len(hidden_output)):
        # Sum of (Hidden_node_output * W2)
        final_sum += hidden_output[i] * w2[i]

    # Add final bias
    final_output = final_sum + b2

    # Final Sigmoid-like activation approximation: 1 / (1 + e^-x) -> 1 / (1 + -x)
    output = 1.0 / (1.0 + (-final_output))

    return output

# --- Step 2: Conceptual Zero-Knowledge Proof Simulators ---
# Note: These classes are essentially unchanged, as their hashing logic
# automatically handles the newly structured (and complex) SECRET_AI_WEIGHTS.

class ZKSNARKSimulator:
    """Simulates the workflow of a zk-SNARK system."""

    def __init__(self, model_function):
        self._model = model_function
        self.proving_key = None
        self.verification_key = None
        print("Initialized ZK-SNARK Simulator.")

    def trusted_setup(self, secret_weights: Dict[str, Any]):
        """
        **CONCEPTUAL:** Simulates the one-time trusted setup ceremony for a SNARK.
        The security relies on the destruction of the secret parameters ('toxic waste').
        """
        print("-> SNARK: Performing one-time trusted setup...")
        # Use the full weight structure to derive the keys
        toxic_waste = json.dumps(secret_weights, sort_keys=True) + "a_secret_salt"

        # The proving key is needed by the prover.
        self.proving_key = hashlib.sha256(f"pk-{toxic_waste}".encode()).hexdigest()

        # The verification key is made public and used by the verifier.
        self.verification_key = hashlib.sha256(f"vk-{toxic_waste}".encode()).hexdigest()

        print("-> SNARK: Proving and Verification keys generated. Toxic waste 'destroyed'.")

    def generate_proof(self, x: float, secret_weights: Dict[str, Any]) -> Dict[str, Any]:
        """The PROVER uses the Proving Key to generate a proof."""
        if not self.proving_key:
            raise Exception("Trusted setup must be performed before generating proofs.")

        output = self._model(x, secret_weights)
        weights_str = json.dumps(secret_weights, sort_keys=True)

        # The proof is a hash of the computation details and the proving key.
        proof_content = f"{x}{output}{weights_str}{self.proving_key}"
        proof_hash = hashlib.sha256(proof_content.encode()).hexdigest()

        print(f"-> SNARK Prover: Generated proof for input {x}.")
        return {"input": x, "output": output, "proof_data": proof_hash}

    def verify_proof(self, proof: Dict[str, Any]) -> bool:
        """
        The VERIFIER uses the public Verification Key to check the proof.
        **Crucially, it does NOT need the secret weights.**
        """
        if not self.verification_key:
            raise Exception("Cannot verify without a verification key from a trusted setup.")

        # --- SIMULATION OF VERIFICATION ---
        # To verify that the computation matches the model, we internally rely
        # on the known, true SECRET_AI_WEIGHTS (defined globally) and the
        # public Verification Key to reproduce the proof's commitment.

        # This global variable MUST be defined for the simulation's integrity check.
        # In a real ZK system, this check would happen cryptographically.
        global SECRET_AI_WEIGHTS

        # Re-calculate the expected Proving Key (PK) and the expected proof hash
        # based on the global (true) weights.
        toxic_waste = json.dumps(SECRET_AI_WEIGHTS, sort_keys=True) + "a_secret_salt"
        expected_pk = hashlib.sha256(f"pk-{toxic_waste}".encode()).hexdigest()

        expected_proof_content = f"{proof['input']}{proof['output']}{json.dumps(SECRET_AI_WEIGHTS, sort_keys=True)}{expected_pk}"
        expected_proof_hash = hashlib.sha256(expected_proof_content.encode()).hexdigest()

        is_valid = (proof['proof_data'] == expected_proof_hash)

        if is_valid:
            print(f"-> SNARK Verifier: SUCCESS! Proof is valid for input {proof['input']}.")
        else:
            print(f"-> SNARK Verifier: FAILURE! Proof is INVALID for input {proof['input']}.")
        return is_valid

class ZKSTARKSimulator:
    """Simulates the workflow of a zk-STARK system (transparent setup)."""

    def __init__(self, model_function):
        self._model = model_function
        print("Initialized ZK-STARK Simulator (Transparent, no setup required).")

    def generate_proof(self, x: float, secret_weights: Dict[str, Any]) -> Dict[str, Any]:
        """The PROVER generates a proof using public parameters (hashes)."""
        output = self._model(x, secret_weights)
        weights_str = json.dumps(secret_weights, sort_keys=True)

        # A STARK proof is a cryptographic commitment to the execution trace.
        proof_content = f"stark-trace:{x}{output}{weights_str}"
        proof_hash = hashlib.sha256(proof_content.encode()).hexdigest()

        print(f"-> STARK Prover: Generated proof for input {x}.")
        return {"input": x, "output": output, "proof_data": proof_hash}

    def verify_proof(self, proof: Dict[str, Any]) -> bool:
        """
        The VERIFIER checks the proof using public parameters.
        **Crucially, it does NOT need secret weights or a trusted setup key.**
        """
        # --- SIMULATION OF VERIFICATION ---
        # Similar to SNARK, the verifier must be able to confirm that the hash
        # was generated using the true model weights.
        global SECRET_AI_WEIGHTS

        expected_proof_content = f"stark-trace:{proof['input']}{proof['output']}{json.dumps(SECRET_AI_WEIGHTS, sort_keys=True)}"
        expected_proof_hash = hashlib.sha256(expected_proof_content.encode()).hexdigest()

        is_valid = (proof['proof_data'] == expected_proof_hash)

        if is_valid:
            print(f"-> STARK Verifier: SUCCESS! Proof is valid for input {proof['input']}.")
        else:
            print(f"-> STARK Verifier: FAILURE! Proof is INVALID for input {proof['input']}.")
        return is_valid

# --- Blockchain Class (Modified to accept generic verifier) ---
class Blockchain:
    def __init__(self, zkp_verifier):
        self.chain: List[Dict[str, Any]] = []
        self.zkp_verifier = zkp_verifier
        self.create_block(proof="genesis_proof", previous_hash='0', data="Genesis Block")

    def create_block(self, proof: Any, previous_hash: str, data: Any) -> Dict[str, Any]:
        block = { 'index': len(self.chain) + 1, 'timestamp': time.time(), 'data': data, 'proof': proof, 'previous_hash': previous_hash }
        self.chain.append(block)
        return block

    def get_previous_block(self) -> Dict[str, Any]: return self.chain[-1]
    def hash(self, block: Dict[str, Any]) -> str:
        return hashlib.sha256(json.dumps(block, sort_keys=True).encode()).hexdigest()

    def add_verified_inference(self, zkp_proof: Dict[str, Any]):
        print("\n" + "="*25 + " New Inference Received " + "="*25)
        is_valid = self.zkp_verifier.verify_proof(zkp_proof)
        if not is_valid:
            print("Transaction REJECTED: ZKP verification failed.")
            return
        print("Transaction ACCEPTED: ZKP verification successful.")
        previous_hash = self.hash(self.get_previous_block())
        inference_data = {"type": "VerifiedInference", "input": zkp_proof['input'], "output": zkp_proof['output']}
        new_block = self.create_block(zkp_proof['proof_data'], previous_hash, inference_data)
        print(f"Successfully added Block #{new_block['index']} to the ledger.")


# --- Main Execution ---
# NEW COMPLEX WEIGHTS FOR THE DNN
SECRET_AI_WEIGHTS = {
    'w1': [0.12, 0.34, -0.05],     # Weights from input to 3 hidden nodes
    'b1': [0.5, -00.1, 0.22],      # Biases for 3 hidden nodes
    'w2': [0.75, 0.21, -0.63],     # Weights from 3 hidden nodes to 1 output node
    'b2': [0.15]                   # Bias for 1 output node
}

if __name__ == "__main__":

    # ===================================================================
    # --- SCENARIO A: Using a zk-SNARK based system with DNN ---
    # ===================================================================
    print("\n" + "#"*25 + " RUNNING SNARK SIMULATION WITH DNN " + "#"*25)
    # The ZK system is now initialized with the DNN function
    snark_system = ZKSNARKSimulator(model_function=simulate_deep_neural_network)

    # 1. Perform the one-time trusted setup
    snark_system.trusted_setup(SECRET_AI_WEIGHTS)

    # 2. The blockchain/verifier uses the SNARK verifier
    snark_ledger = Blockchain(zkp_verifier=snark_system)

    # 3. An honest prover generates and submits a proof (Input 10.0)
    honest_snark_proof = snark_system.generate_proof(x=10.0, secret_weights=SECRET_AI_WEIGHTS)
    snark_ledger.add_verified_inference(zkp_proof=honest_snark_proof)

    # 4. A malicious prover uses the wrong weights
    malicious_weights = {'w1': [99.0, 99.0, 99.0], 'b1': [1, 1, 1], 'w2': [1, 1, 1], 'b2': [1]}
    malicious_snark_proof = snark_system.generate_proof(x=5.0, secret_weights=malicious_weights)
    snark_ledger.add_verified_inference(zkp_proof=malicious_snark_proof) # Should fail

    print("\nFinal SNARK Ledger State:")
    print(json.dumps(snark_ledger.chain, indent=4))

    # ===================================================================
    # --- SCENARIO B: Using a zk-STARK based system with DNN ---
    # ===================================================================
    print("\n\n" + "#"*25 + " RUNNING STARK SIMULATION WITH DNN " + "#"*25)
    # The ZK system is now initialized with the DNN function
    stark_system = ZKSTARKSimulator(model_function=simulate_deep_neural_network)

    # 1. NO trusted setup is needed for STARKs. It's ready to go.

    # 2. The blockchain/verifier uses the STARK verifier
    stark_ledger = Blockchain(zkp_verifier=stark_system)

    # 3. An honest prover generates and submits a proof (Input 20.0)
    honest_stark_proof = stark_system.generate_proof(x=20.0, secret_weights=SECRET_AI_WEIGHTS)
    stark_ledger.add_verified_inference(zkp_proof=honest_stark_proof)

    # 4. A malicious prover tampers with the output
    tampered_stark_proof = stark_system.generate_proof(x=-3.0, secret_weights=SECRET_AI_WEIGHTS)
    tampered_stark_proof['output'] += 50.0 # Tampering
    stark_ledger.add_verified_inference(zkp_proof=tampered_stark_proof) # Should fail

    print("\nFinal STARK Ledger State:")
    print(json.dumps(stark_ledger.chain, indent=4))


######################### RUNNING SNARK SIMULATION WITH DNN #########################
Initialized ZK-SNARK Simulator.
-> SNARK: Performing one-time trusted setup...
-> SNARK: Proving and Verification keys generated. Toxic waste 'destroyed'.
-> SNARK Prover: Generated proof for input 10.0.

-> SNARK Verifier: SUCCESS! Proof is valid for input 10.0.
Transaction ACCEPTED: ZKP verification successful.
Successfully added Block #2 to the ledger.
-> SNARK Prover: Generated proof for input 5.0.

-> SNARK Verifier: FAILURE! Proof is INVALID for input 5.0.
Transaction REJECTED: ZKP verification failed.

Final SNARK Ledger State:
[
    {
        "index": 1,
        "timestamp": 1763110098.4243836,
        "data": "Genesis Block",
        "proof": "genesis_proof",
        "previous_hash": "0"
    },
    {
        "index": 2,
        "timestamp": 1763110098.4244823,
        "data": {
            "type": "VerifiedInference",
            "input": 10.0,
            "output": 6.413453336999942
       