In [1]:
import argparse
import hashlib
import hmac
import json
import os
import time

import numpy as np

from TPM import TPM

# Notebook stand-in for the command-line interface of the script version
# (positional k, n, l, update_rule and a --verbose flag). Edit the class
# attributes below to change the run.
class args:
    """Run configuration for the TPM synchronization demo.

    Attributes:
        k: number of hidden units.
        n: number of units connected to each hidden unit.
        l: maximum absolute weight value (weights range over {-l, ..., l}).
        update_rule: name of the learning rule passed to ``TPM.update``.
        verbose: print progress information while running.
        seed: optional integer seed for numpy's global RNG; ``None`` keeps
            every run random.
    """
    k = 8
    n = 32
    l = 9
    update_rule = 'hebbian'
    verbose = True
    seed = None

k = args.k
n = args.n
l = args.l
update_rule = args.update_rule
verbose = bool(args.verbose)

# Seed numpy's global RNG when requested so a run can be reproduced.
# NOTE(review): this only affects TPM weight initialisation if the TPM class
# draws from np.random — not visible here, confirm.
if args.seed is not None:
    np.random.seed(args.seed)

def score(tpm1, tpm2):
    """Return how closely two TPMs' weight matrices agree, scaled to [1, 100].

    The score is ``1 + 99 * p``, where ``p`` is the fraction of entries that
    are identical in ``tpm1.W`` and ``tpm2.W``. A score of exactly 100 means
    the two weight matrices are element-wise equal.

    Args:
        tpm1, tpm2: objects exposing a weight array attribute ``W``.

    Returns:
        The scaled agreement score, in [1, 100].

    Raises:
        ValueError: if the weight arrays differ in shape or are empty, since
            element-wise agreement is undefined in those cases.
    """
    w1 = np.asarray(tpm1.W)
    w2 = np.asarray(tpm2.W)
    if w1.shape != w2.shape:
        # Without this check numpy would broadcast the comparison, which can
        # count more matches than w1.size and push the score above 100.
        raise ValueError(
            'weight shapes differ: {} vs {}'.format(w1.shape, w2.shape))
    if w1.size == 0:
        # Avoid a 0/0 division that would silently yield nan.
        raise ValueError('cannot score empty weight matrices')
    proportion_of_matches = np.sum(w1 == w2) / w1.size
    return 1 + 99 * proportion_of_matches

def random_number_generator(l, k, n, rng=None):
    """Draw a ``(k, n)`` integer matrix with entries uniform over ``[-l, l]``.

    Args:
        l: bound on the absolute value of each entry (inclusive).
        k: number of rows.
        n: number of columns.
        rng: optional ``numpy.random.Generator``. When given, values are drawn
            from it so a run can be reproduced with a seeded generator; when
            ``None`` (default) numpy's global legacy RNG is used, as before.

    Returns:
        An integer ndarray of shape ``(k, n)``.

    NOTE(review): classic TPM key exchange draws inputs from {-1, +1}; this
    draws from the weight range instead — confirm that is what TPM expects.
    """
    if rng is None:
        return np.random.randint(-l, l + 1, (k, n))
    # Generator.integers excludes the upper bound by default, like randint.
    return rng.integers(-l, l + 1, size=(k, n))

# --- Key exchange ----------------------------------------------------------
# Two TPMs (tpm1, tpm2) and an eavesdropper (evil_tpm) are fed the same random
# input X each round. tpm1 and tpm2 update from each other's outputs until
# their weights are identical; the eavesdropper only updates on rounds where
# its own output matches both parties'.

# Upper bound on synchronization rounds, so a run that never converges fails
# loudly instead of looping forever. Add `max_epochs` to `args` to override.
max_epochs = getattr(args, 'max_epochs', 100_000)
# Print the running score every `log_every` epochs (1 = every epoch).
log_every = max(1, int(getattr(args, 'log_every', 1)))


def _write_key_file(path, array, mode=None):
    """Write ``array`` to ``path`` as text without numpy's '...' abbreviation.

    ``str(array)`` abbreviates arrays with more than 1000 elements, which would
    silently truncate a saved key; ``threshold=array.size`` disables that while
    producing the same text as ``str`` for smaller arrays. When ``mode`` is
    given, the file's permission bits are set to it (e.g. 0o600 for a secret).
    """
    array = np.asarray(array)
    text = np.array2string(array, threshold=array.size)
    if mode is None:
        with open(path, 'w') as f:
            f.write(text)
        return
    fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, mode)
    with os.fdopen(fd, 'w') as f:
        # os.open only applies `mode` when it creates the file; tighten a
        # pre-existing file too, before any secret is written into it.
        os.chmod(path, mode)
        f.write(text)


if verbose:
    print('Creating TPM1 with k = {}, n = {}, l = {}'.format(k, n, l))

tpm1 = TPM(k, n, l)

if verbose:
    print('Creating TPM2 with k = {}, n = {}, l = {}'.format(k, n, l))

tpm2 = TPM(k, n, l)

if verbose:
    print('Creating man in the middle machine with k = {}, n = {}, l = {}'.format(k, n, l))

evil_tpm = TPM(k, n, l)

if verbose:
    print('Synchronizing TPM1 and TPM2 using the {} update rule'.format(update_rule))

sync = False
epoch = 0

while not sync:
    if epoch >= max_epochs:
        raise RuntimeError(
            'TPM1 and TPM2 did not synchronize within {} epochs '
            '(last score: {})'.format(max_epochs, score(tpm1, tpm2))
        )

    X = random_number_generator(l, k, n)

    output1 = tpm1.output(X)
    output2 = tpm2.output(X)
    evil_output = evil_tpm.output(X)

    # Each party learns from the other's output.
    tpm1.update(output2, update_rule)
    tpm2.update(output1, update_rule)

    # The eavesdropper only follows when it agrees with both parties.
    if output1 == output2 == evil_output:
        evil_tpm.update(output1, update_rule)

    score_tpm = score(tpm1, tpm2)
    epoch += 1

    if verbose and epoch % log_every == 0:
        print('Score: {}'.format(score_tpm))
        print('Epoch: {}'.format(epoch))

    # score() reaches exactly 100 only when every weight matches.
    if score_tpm == 100:
        sync = True

if verbose:
    print('TPM1 and TPM2 have synchronized after {} epochs'.format(epoch))

    print('Final score: {}'.format(score_tpm))

# If the eavesdropper's weights also match tpm1's, the shared key is exposed.
evil_score = score(tpm1, evil_tpm)
print("COMPROMISED" if evil_score == 100 else "SECURE")

print('\n\n--------------------------------------\n\n')
print("Public key: {}".format(X))
print('\n\n--------------------------------------\n\n')
# NOTE(review): printing the shared secret leaves it in the saved notebook output.
print("Private key: {}".format(tpm1.W))

# Persist the last public input and the synchronized weights; the private key
# file is readable/writable by the owner only.
_write_key_file('public_key.txt', X)
_write_key_file('private_key.txt', tpm1.W, mode=0o600)

print('TPM1 output for the public key: {}'.format(tpm1.output(X)))


Creating TPM1 with k = 8, n = 32, l = 9
Creating TPM2 with k = 8, n = 32, l = 9
Creating man in the middle machine with k = 8, n = 32, l = 9
Synchronizing TPM1 and TPM2 using the hebbian update rule
Score: 4.09375
Epoch: 1
Score: 4.09375
Epoch: 2
Score: 4.09375
Epoch: 3
Score: 4.09375
Epoch: 4
Score: 4.8671875
Epoch: 5
Score: 4.8671875
Epoch: 6
Score: 9.89453125
Epoch: 7
Score: 9.89453125
Epoch: 8
Score: 16.85546875
Epoch: 9
Score: 16.85546875
Epoch: 10
Score: 16.85546875
Epoch: 11
Score: 16.85546875
Epoch: 12
Score: 18.7890625
Epoch: 13
Score: 18.7890625
Epoch: 14
Score: 18.7890625
Epoch: 15
Score: 19.94921875
Epoch: 16
Score: 19.94921875
Epoch: 17
Score: 23.04296875
Epoch: 18
Score: 23.04296875
Epoch: 19
Score: 23.04296875
Epoch: 20
Score: 23.04296875
Epoch: 21
Score: 23.04296875
Epoch: 22
Score: 28.45703125
Epoch: 23
Score: 32.7109375
Epoch: 24
Score: 32.7109375
Epoch: 25
Score: 32.7109375
Epoch: 26
Score: 32.7109375
Epoch: 27
Score: 32.7109375
Epoch: 28
Score: 32.7109375
Epoch: 29


In [25]:
# Keep the keys produced by the synchronization cell: X is the last public
# input matrix and tpm1.W the synchronized weights (the shared secret).
# .copy() snapshots both so later changes to X or tpm1 cannot silently alter
# the key used for signing below.
public_key = X.copy()
private_key = tpm1.W.copy()

import hashlib
import time

def _signature_key(weights):
    """Serialize TPM weights into the HMAC key bytes.

    Every element is included (``str(array)`` abbreviates arrays with more
    than 1000 elements to a handful of values), and the text form does not
    depend on the integer dtype of the array.
    """
    flat = np.asarray(weights).ravel()
    return ','.join(str(w) for w in flat.tolist()).encode('utf-8')


def _signing_payload(message):
    """Return the bytes authenticated by a signature: timestamp, ttl and text.

    Including timestamp and ttl stops a tamperer from extending a message's
    validity window by editing those fields. They are read with getattr so a
    message object that only has ``.message`` can still be signed.
    """
    fields = [
        getattr(message, 'timestamp', None),
        getattr(message, 'ttl', None),
        message.message,
    ]
    # A JSON array keeps the field boundaries unambiguous; default=repr covers
    # values that JSON cannot encode natively.
    return json.dumps(fields, default=repr).encode('utf-8')


def create_signature(message, weights):
    """Sign ``message`` with the synchronized TPM weights.

    Uses HMAC-SHA256 keyed by the weights. The previous XOR-of-hashes scheme
    let anyone holding one (message, signature) pair recover the key hash and
    forge signatures for arbitrary messages.

    Args:
        message: object with a ``message`` text attribute (and, normally,
            ``timestamp`` and ``ttl``).
        weights: the shared weight array.

    Returns:
        The signature as a 64-character hex string.
    """
    return hmac.new(_signature_key(weights), _signing_payload(message),
                    hashlib.sha256).hexdigest()


def verify_signature(message, signature, private_info):
    """Check that ``signature`` is valid for ``message`` under ``private_info``.

    Args:
        message: object with ``message``, ``timestamp`` and ``ttl`` attributes.
        signature: the signature string to check.
        private_info: the shared weight array used to sign.

    Returns:
        False if the message is older than its ttl (in seconds), if
        ``signature`` is not a string, or if it does not match; True otherwise.
    """
    if time.time() - message.timestamp > message.ttl:
        return False
    if not isinstance(signature, str):
        return False
    expected = create_signature(message, private_info)
    # Constant-time comparison so timing does not leak how many characters match.
    return hmac.compare_digest(expected.encode('utf-8'), signature.encode('utf-8'))

# Message container consumed by the signing helpers, followed by a signing
# round-trip check.
class Message:
    """A text message together with the metadata used when signing it.

    Attributes:
        message: the message text.
        timestamp: creation time in seconds since the epoch (``time.time()``
            unless given explicitly).
        ttl: lifetime in seconds; verify_signature compares it against the
            message's age.
    """

    def __init__(self, message, ttl, timestamp=None):
        self.message = message
        # Default to "now" for a freshly created message; pass an explicit
        # timestamp to rebuild a message received from elsewhere.
        self.timestamp = time.time() if timestamp is None else timestamp
        self.ttl = ttl

    def get(self, key):
        """Return the attribute named ``key`` (raises AttributeError if absent)."""
        return getattr(self, key)
# Round-trip check: a freshly signed message must verify with the same key,
# and a tampered copy (same timestamp and ttl, different text) must not.
message = Message('Hello, world!', 10)
signature = create_signature(message, private_key)
is_valid = verify_signature(message, signature, private_key)
print(is_valid)
assert is_valid, 'freshly signed message failed verification'

tampered = Message('Hello, world?', message.ttl)
tampered.timestamp = message.timestamp
assert not verify_signature(tampered, signature, private_key), \
    'tampered message must not verify'



True
