In [4]:
import random
import hashlib
from collections import Counter
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
import binascii

# Oblivious Transfer implementation (1-out-of-2)
class ObliviousTransfer:
    def __init__(self):
        # Public parameters (in real crypto, these would be large primes)
        self.p = 101  # Prime modulus
        self.g = 3    # Generator
    
    def receiver_step1(self, choice):
        """Receiver creates public keys based on choice bit"""
        self.sk = random.randint(1, self.p-1)
        self.choice = choice
        
        # Create public keys based on choice
        if choice == 0:
            pk0 = pow(self.g, self.sk, self.p)
            pk1 = pow(self.g, random.randint(1, self.p-1), self.p)
        else:
            pk0 = pow(self.g, random.randint(1, self.p-1), self.p)
            pk1 = pow(self.g, self.sk, self.p)
            
        return (pk0, pk1)
    
    def sender_step2(self, m0, m1, pk0, pk1):
        """Sender encrypts both messages"""
        # Encrypt m0 with pk0
        r0 = random.randint(1, self.p-1)
        k0 = pow(pk0, r0, self.p)
        h0 = pow(self.g, r0, self.p)
        c0 = self._encrypt(m0, k0)
        
        # Encrypt m1 with pk1
        r1 = random.randint(1, self.p-1)
        k1 = pow(pk1, r1, self.p)
        h1 = pow(self.g, r1, self.p)
        c1 = self._encrypt(m1, k1)
        
        return ((h0, c0), (h1, c1))
    
    def receiver_step3(self, c0, c1):
        """Receiver decrypts chosen message"""
        # Compute decryption key
        c = c0 if self.choice == 0 else c1
        key = pow(c[0], self.sk, self.p)
        return self._decrypt(c[1], key)

    ## cipher = Enc(message, key) = H(key) \xor message
    def _encrypt(self, message, key):
        """AES encryption using key hash"""
        key_hash = hashlib.sha256(str(key).encode()).digest()[:16]
        cipher = AES.new(key_hash, AES.MODE_ECB)
        return cipher.encrypt(pad(str(message).encode(), AES.block_size))

    ## message = Dec(cipher, key) = H(key) \xor cipher
    def _decrypt(self, ciphertext, key):
        """AES decryption using key hash"""
        try:
            key_hash = hashlib.sha256(str(key).encode()).digest()[:16]
            cipher = AES.new(key_hash, AES.MODE_ECB)
            return unpad(cipher.decrypt(ciphertext), AES.block_size).decode()
        except:
            return "DECRYPTION_FAILED"

class GarbledCircuit:
    def __init__(self, circuit_id):
        self.generator = random.Random(circuit_id)
        # Generate random wire labels (in real crypto, these would be cryptographic keys)
        self.input_labels = {
            'a0': self.generator.getrandbits(128),
            'a1': self.generator.getrandbits(128),
            'b0': self.generator.getrandbits(128),
            'b1': self.generator.getrandbits(128)
        }
        self.output_labels = {
            'out0': self.generator.getrandbits(128),
            'out1': self.generator.getrandbits(128)
        }
        
        # Create garbled truth table for AND gate
        self.garbled_table = self.create_garbled_table()
        
        # Commit to circuit construction
        self.commitment = self.create_commitment()

    def create_garbled_table(self):
        """Create encrypted truth table for AND gate"""
        table = [
            (self.input_labels['a0'], self.input_labels['b0'], self.output_labels['out0']),  # 0 AND 0 = 0
            (self.input_labels['a0'], self.input_labels['b1'], self.output_labels['out0']),  # 0 AND 1 = 0
            (self.input_labels['a1'], self.input_labels['b0'], self.output_labels['out0']),  # 1 AND 0 = 0
            (self.input_labels['a1'], self.input_labels['b1'], self.output_labels['out1'])   # 1 AND 1 = 1
        ]
        random.shuffle(table)  # Hide input/output relationships
        return table

    def create_commitment(self):
        """Create commitment to circuit construction"""
        data = str(sorted(self.input_labels.items())) + str(sorted(self.output_labels.items()))
        return hashlib.sha256(data.encode()).hexdigest()

    def evaluate(self, a_label, b_label):
        """Evaluate circuit with given wire labels"""
        # Find matching entry in garbled table
        for entry in self.garbled_table:
            if entry[0] == a_label and entry[1] == b_label:
                output_label = entry[2]
                break
        else:
            raise ValueError("No matching entry found")
        
        # Map output label to value
        return 1 if output_label == self.output_labels['out1'] else 0

def cut_and_choose(a_input, b_input, num_circuits=5, check_fraction=0.4):
    # Create OT instance
    ot = ObliviousTransfer()
    
    # 1. Garbler creates circuits
    circuits = [GarbledCircuit(i) for i in range(num_circuits)]
    
    # 2. Evaluator selects circuits to check
    all_indices = list(range(num_circuits))
    check_indices = random.sample(all_indices, int(num_circuits * check_fraction))
    eval_indices = [i for i in all_indices if i not in check_indices]
    
    # 3. Verify check circuits, ensure these selected circuits are correctly constructed
    ## In this case, evaluator knows every labels of a gate
    for idx in check_indices:
        c = circuits[idx]
        
        # Recreate circuit with same ID to verify commitment
        test_circuit = GarbledCircuit(idx)
        if test_circuit.commitment != c.commitment:
            raise ValueError(f"Commitment mismatch in circuit {idx}")
        
        # Verify AND gate behavior
        test_results = [
            (0, 0, 0), (0, 1, 0), (1, 0, 0), (1, 1, 1)
        ]
        for a_val, b_val, expected in test_results:
            a_label = c.input_labels['a1'] if a_val else c.input_labels['a0']
            b_label = c.input_labels['b1'] if b_val else c.input_labels['b0']
            result = c.evaluate(a_label, b_label)
            if result != expected:
                raise ValueError(f"Circuit {idx} failed verification")
    
    print(f"✅ Verified {len(check_indices)} circuits")
    
    # 4. Evaluate remaining circuits with OT for input privacy
    results = []
    for idx in eval_indices:
        circuit = circuits[idx]
        
        # Get Alice's input label directly
        a_label = circuit.input_labels['a1'] if a_input else circuit.input_labels['a0']
        
        # Use OT to get Bob's input label privately
        pk0, pk1 = ot.receiver_step1(b_input)
        c0, c1 = ot.sender_step2(
            circuit.input_labels['b0'], 
            circuit.input_labels['b1'], 
            pk0, pk1
        )
        b_label = int(ot.receiver_step3(c0, c1))
        
        # Evaluate circuit
        try:
            result = circuit.evaluate(a_label, b_label)
            results.append(result)
        except ValueError:
            results.append(-1)  # Evaluation failed
    
    # 5. Output majority result
    valid_results = [r for r in results if r != -1]
    if not valid_results:
        raise RuntimeError("All circuit evaluations failed!")
        
    majority = Counter(valid_results).most_common(1)[0][0]
    print(f"Evaluation results: {results}")
    print(f"Majority output: {a_input} AND {b_input} = {majority}")
    return majority

# Example usage
if __name__ == "__main__":
    # Alice's input (Garbler)
    a_input = 1
    # Bob's input (Evaluator)
    b_input = 0
    
    print("Running cut-and-choose with oblivious transfer...")
    print(f"Alice's input: {a_input}, Bob's input: {b_input} (secret)\n")
    
    result = cut_and_choose(a_input, b_input, num_circuits = 10)
    
    print("\nSecurity analysis:")
    print("- Bob's input remains private throughout the protocol")
    print("- Alice can't cheat due to circuit verification")
    print(f"- Cheating probability: < 2^{-int(5*0.4)} = 1/{2**2} = 1/4")

Running cut-and-choose with oblivious transfer...
Alice's input: 1, Bob's input: 0 (secret)

✅ Verified 4 circuits
Evaluation results: [0, 0, 0, 0, 0, 0]
Majority output: 1 AND 0 = 0

Security analysis:
- Bob's input remains private throughout the protocol
- Alice can't cheat due to circuit verification
- Cheating probability: < 2^-2 = 1/4 = 1/4
