In [1]:
%matplotlib inline

import ctypes
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sys

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Activation, Dense, Flatten, Reshape, Conv1D, concatenate
from tensorflow.keras.optimizers import Adam, RMSprop


In [None]:
# Vector widths shared by the model-building cells:
#   m_bits    - width of the message input
#   k_bits    - width of the signature-key input
#   signature - width of the signature vector produced by Alice
m_bits = k_bits = signature = 16

# Padding mode handed to every Conv1D layer.
pad = "same"

In [None]:
# Alice: maps (message, signature key) to the message concatenated with a
# tanh-activated signature vector of width `signature`.
a_message = Input(shape=(m_bits,))
a_signature = Input(shape=(k_bits,))

# NOTE(review): these two reshaped views are not consumed by any layer in this
# cell; they are kept so the auto-generated layer names stay unchanged.
a_message_reshaped = Reshape((m_bits, 1))(a_message)
a_signature_reshaped = Reshape((k_bits, 1))(a_signature)

a_width = m_bits + k_bits

# Dense mixing of the concatenated inputs.
a_hidden = concatenate([a_message, a_signature], axis=1)
a_hidden = Dense(units=a_width)(a_hidden)
a_hidden = Activation('tanh')(a_hidden)

# Add a channel axis so the vector can go through the 1-D convolutions.
a_hidden = Reshape((a_width, 1))(a_hidden)

a_hidden = Conv1D(filters=2, kernel_size=4, strides=1, padding=pad)(a_hidden)
a_hidden = Activation('tanh')(a_hidden)

a_hidden = Conv1D(filters=4, kernel_size=2, strides=2, padding=pad)(a_hidden)
a_hidden = Activation('tanh')(a_hidden)

# Collapse back to a flat vector and project down to the signature width.
a_hidden = Flatten()(a_hidden)
a_hidden = Dense(units=signature)(a_hidden)
a_hidden = Activation('tanh')(a_hidden)

# The raw message is passed through untouched, followed by the signature.
aoutput = concatenate([a_message, a_hidden], axis=1)

alice = Model(inputs=[a_message, a_signature], outputs=[aoutput], name='alice')
alice.summary()

Model: "alice"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 16)]                 0         []                            
                                                                                                  
 input_2 (InputLayer)        [(None, 16)]                 0         []                            
                                                                                                  
 concatenate (Concatenate)   (None, 32)                   0         ['input_1[0][0]',             
                                                                     'input_2[0][0]']             
                                                                                                  
 dense (Dense)               (None, 32)                   1056      ['concatenate[0][0]']     

In [None]:
# Bob: takes a vector of width m_bits + signature and produces a single
# sigmoid-activated score in (0, 1).
b_inp = Input(shape=(m_bits+signature, ))

bdense1 = Dense(units=(m_bits+signature))(b_inp)
bdense1a = Activation('tanh')(bdense1)

# Add a channel axis so the vector can go through the 1-D convolutions.
breshape = Reshape((m_bits+signature, 1,))(bdense1a)

bconv1 = Conv1D(filters=2, kernel_size=4, strides=1, padding=pad)(breshape)
bconv1a = Activation('tanh')(bconv1)

bconv2 = Conv1D(filters=4, kernel_size=2, strides=2, padding=pad)(bconv1a)
bconv2a = Activation('tanh')(bconv2)

bconv3 = Conv1D(filters=4, kernel_size=1, strides=1, padding=pad)(bconv2a)
bconv3a = Activation('tanh')(bconv3)

b_flat = Flatten()(bconv3a)

bdense2 = Dense(units=(m_bits+signature))(b_flat)
bdense2_a = Activation('relu')(bdense2)

# `units` must be an integer: use floor division so a float (16.0) is never
# handed to Dense.
bdense3 = Dense(units=(m_bits+signature)//2)(bdense2_a)
bdense3_a = Activation('relu')(bdense3)

# Single-unit sigmoid head.
bdense4 = Dense(units = 1)(bdense3_a)
bdense4_a = Activation('sigmoid')(bdense4)

bob = Model(inputs = [b_inp], outputs = bdense4_a, name='bob')
bob.summary()

Model: "bob"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_3 (InputLayer)        [(None, 32)]              0         
                                                                 
 dense_2 (Dense)             (None, 32)                1056      
                                                                 
 activation_4 (Activation)   (None, 32)                0         
                                                                 
 reshape_3 (Reshape)         (None, 32, 1)             0         
                                                                 
 conv1d_2 (Conv1D)           (None, 32, 2)             10        
                                                                 
 activation_5 (Activation)   (None, 32, 2)             0         
                                                                 
 conv1d_3 (Conv1D)           (None, 16, 4)             20      

In [None]:
# Eve: maps a vector of width m_bits + signature to a tanh-activated vector of
# the same width.
e_width = m_bits + signature

e_cipher = Input(shape=(e_width,))

# Two dense + tanh stages on the flat input.
e_hidden = Dense(units=e_width)(e_cipher)
e_hidden = Activation('tanh')(e_hidden)

e_hidden = Dense(units=e_width)(e_hidden)
e_hidden = Activation('tanh')(e_hidden)

# Add a channel axis so the vector can go through the 1-D convolutions.
e_hidden = Reshape((e_width, 1))(e_hidden)

e_hidden = Conv1D(filters=2, kernel_size=4, strides=1, padding=pad)(e_hidden)
e_hidden = Activation('tanh')(e_hidden)

e_hidden = Conv1D(filters=4, kernel_size=2, strides=2, padding=pad)(e_hidden)
e_hidden = Activation('tanh')(e_hidden)

# Flatten and finish with two more dense + tanh stages.
e_hidden = Flatten()(e_hidden)

e_hidden = Dense(units=e_width)(e_hidden)
e_hidden = Activation('tanh')(e_hidden)

e_hidden = Dense(units=e_width)(e_hidden)
e_dense4a = Activation('tanh')(e_hidden)

eve = Model(inputs=[e_cipher], outputs=e_dense4a, name='Eve')
eve.summary()

Model: "Eve"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_4 (InputLayer)        [(None, 32)]              0         
                                                                 
 dense_6 (Dense)             (None, 32)                1056      
                                                                 
 activation_11 (Activation)  (None, 32)                0         
                                                                 
 dense_7 (Dense)             (None, 32)                1056      
                                                                 
 activation_12 (Activation)  (None, 32)                0         
                                                                 
 reshape_4 (Reshape)         (None, 32, 1)             0         
                                                                 
 conv1d_5 (Conv1D)           (None, 32, 2)             10      

In [None]:
# Connect the three networks into one symbolic graph.
aliceout = alice([a_message, a_signature])  # message followed by Alice's signature
bobout = bob([aliceout])                    # Bob's score for Alice's output

eveout = eve([aliceout])                    # Eve's transformation of Alice's output
bobout2 = bob([eveout])                     # Bob's score for Eve's output

# Eve's loss decreases when
#   * Bob's score for Eve's output (bobout2) decreases, and
#   * the L1 distance between Alice's output and Eve's output grows. The
#     distance is summed over the vector, averaged over the mini-batch and
#     weighted by 0.01.
eveloss = K.log(1 + bobout2) - 0.01 * K.mean(K.sum(K.abs(aliceout - eveout), axis=-1))

# The Alice/Bob loss decreases when Bob's score for Alice's output (bobout)
# goes towards 0 and his score for Eve's output (bobout2) goes towards 1.
abeloss = K.log(1 + bobout) + K.log(2 - bobout2)

# Optimizers. `learning_rate` is the supported keyword; `lr` is a deprecated
# alias that newer Keras releases reject.
abeoptim = RMSprop(learning_rate=0.001)
eveoptim = RMSprop(learning_rate=0.001)

# Per the Keras documentation, the `trainable` flags in force when compile()
# is called are the ones the compiled model honours, so each composite model
# is compiled with exactly the sub-networks it should update marked trainable.

# Alice/Bob model: updates Alice and Bob, Eve frozen.
alice.trainable = True
bob.trainable = True
eve.trainable = False
abemodel = Model(inputs=[a_message, a_signature], outputs=[bobout, bobout2], name='Alice_bob')
abemodel.add_loss(abeloss)
abemodel.compile(optimizer=abeoptim)

# Eve model: updates Eve only, Alice and Bob frozen.
alice.trainable = False
bob.trainable = False
eve.trainable = True
evemodel = Model(inputs=[a_message, a_signature], outputs=[bobout2], name='evemodel')
evemodel.add_loss(eveloss)
evemodel.compile(optimizer=eveoptim)




In [None]:
# Draw a tiny random batch of 0/1 messages and keys, with Eve marked frozen.
batch_size = 2
eve.trainable = False
m_batch = np.random.randint(0, 2, size=(batch_size, m_bits))
k_batch = np.random.randint(0, 2, size=(batch_size, k_bits))

In [None]:
# One Alice/Bob training step on a fresh random batch; the cell displays the loss.
batch_size = 32

alice.trainable = True
bob.trainable = True
eve.trainable = False

m_batch = np.random.randint(0, 2, size=(batch_size, m_bits))
k_batch = np.random.randint(0, 2, size=(batch_size, k_bits))

# The loss was attached with add_loss, so no targets are passed.
loss = abemodel.train_on_batch([m_batch, k_batch], None)
loss

0.8199781179428101

In [None]:
# Two Eve training steps, each on a fresh random batch; the cell displays the
# loss of the last step.
alice.trainable = False
bob.trainable = False
eve.trainable = True

for cycle in range(2):
    m_batch = np.random.randint(0, 2, size=(batch_size, m_bits))
    k_batch = np.random.randint(0, 2, size=(batch_size, k_bits))
    # The loss was attached with add_loss, so no targets are passed.
    loss = evemodel.train_on_batch([m_batch, k_batch], None)
loss

0.2557808756828308

In [None]:
# Per-step loss histories across the whole run.
abelosses = []
evelosses = []

# An "epoch" here is n_batches training rounds, each on freshly sampled random
# batches (there is no fixed dataset being iterated over).
n_epochs = 20
batch_size = 256
m_train = 2**(m_bits)  # not used by the loop below
n_batches = 128

# Per round: abecycle Alice/Bob updates followed by evecycle Eve updates
# (Eve gets two updates for every Alice/Bob update).
abecycle = 1
evecycle = 2

print("Training for", n_epochs, "epochs with", n_batches, "batches of size", batch_size)
for epoch in range(n_epochs):
    # Losses of the current epoch only, used for the running averages.
    avg_abelosses = []
    avg_evelosses = []

    for batch in range(n_batches):

        # Train the Alice/Bob model.
        alice.trainable = True
        bob.trainable = True
        eve.trainable = False
        for cycle in range(abecycle):
            m_batch = np.random.randint(0, 2, size=(batch_size, m_bits))
            k_batch = np.random.randint(0, 2, size=(batch_size, k_bits))
            # The loss was attached with add_loss, so no targets are passed.
            loss = abemodel.train_on_batch([m_batch, k_batch], None)

            abelosses.append(loss)
            avg_abelosses.append(loss)
        abe_avg = np.mean(avg_abelosses)

        # Train the Eve model.
        alice.trainable = False
        bob.trainable = False
        eve.trainable = True
        for cycle in range(evecycle):
            m_batch = np.random.randint(0, 2, size=(batch_size, m_bits))
            k_batch = np.random.randint(0, 2, size=(batch_size, k_bits))
            loss = evemodel.train_on_batch([m_batch, k_batch], None)

        # Only the loss of Eve's last update in this round is recorded.
        evelosses.append(loss)
        avg_evelosses.append(loss)
        eve_avg = np.mean(avg_evelosses)

        # `batch + 1` so the display reaches 100% on the final batch of an epoch.
        print("\rEpoch {:3}: {:3}% | abe: {:2.3f} | eve: {:2.3f}".format(epoch, 100 * (batch + 1) // n_batches, abe_avg, eve_avg), end="")

    print()
print('Training finished.')

Training for 20 epochs with 128 batches of size 256
Epoch   0:  99% | abe: 0.711 | eve: -0.465
Epoch   1:   4% | abe: 0.693 | eve: -0.559