In [1]:
# Adversarial Neural Cryptography - https://arxiv.org/pdf/1610.06918.pdf

In [38]:
# Necessary imports
import torch
import random

In [39]:
# Seed every RNG this notebook draws from (Python `random` in create_batch, torch for
# create_batch_alt and network weight init) so Restart & Run All reproduces results.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# Define the length of the key and the message (N in the paper)
pln_txt_len = 8
sec_key_len = 8
cip_txt_len = 8

# Define the length of the first FC layer after input layer (Alice and Bob and Eve will have length of 2N)
fc_len = pln_txt_len + sec_key_len

In [40]:
# Generate random {0, 1} plain texts and secret keys, one row per example
def create_batch(plain_text_len, secret_key_len, batch_size):
    """Return ``(plain_text, secret_key)`` filled with random 0/1 bits.

    Shapes are ``(batch_size, plain_text_len)`` and ``(batch_size, secret_key_len)``,
    in torch's default float dtype. Bits come from Python's ``random`` module, drawn
    row by row: a row's plaintext bits first, then that row's key bits.
    """
    plain_text = torch.zeros(batch_size, plain_text_len)
    secret_key = torch.zeros(batch_size, secret_key_len)

    def random_bits(count):
        return torch.tensor([random.randint(0, 1) for _ in range(count)])

    for row in range(batch_size):
        plain_text[row] = random_bits(plain_text_len)
        secret_key[row] = random_bits(secret_key_len)

    return plain_text, secret_key

# Shape check: draw a 256-example batch of {0, 1} bits
plain_text, secret_key = create_batch(
    plain_text_len=pln_txt_len,
    secret_key_len=sec_key_len,
    batch_size=256,
)
print(*(tensor.shape for tensor in (plain_text, secret_key)))

torch.Size([256, 8]) torch.Size([256, 8])


> **Network Structure:** Our networks follow the “mix & transform” pattern described in Section 2.4.
> The Alice network concatenates two N-bit inputs (the plaintext and the key) into a 2N-entry vector,
> using −1 and 1 to represent the bit values.
>
> — arXiv:1610.06918, start of page 6

In [41]:
# Alternative generator: the paper encodes bit values as -1 and 1, so map {0, 1} -> {-1, 1}
def create_batch_alt(plain_text_len, secret_key_len, batch_size):
    """Return a ``(plain_text, secret_key)`` pair of random -1/+1 tensors.

    Shapes are ``(batch_size, plain_text_len)`` and ``(batch_size, secret_key_len)``,
    dtype float32. Both are drawn with ``torch.randint`` (plaintext first), so they
    follow torch's global RNG.
    """
    def random_signs(width):
        bits = torch.randint(0, 2, (batch_size, width), dtype=torch.float)
        return bits.mul(2).sub(1)  # 0 -> -1, 1 -> +1

    return tuple(random_signs(width) for width in (plain_text_len, secret_key_len))

# Shape check for the -1/+1 generator (rebinds plain_text / secret_key)
plain_text, secret_key = create_batch_alt(
    plain_text_len=pln_txt_len,
    secret_key_len=sec_key_len,
    batch_size=256,
)
print(plain_text.shape, secret_key.shape, sep=" ")

torch.Size([256, 8]) torch.Size([256, 8])


In [42]:
# Alice's network
class Alice(torch.nn.Module):
    """Alice: maps (plaintext, secret key) to a ciphertext ("mix & transform").

    A fully-connected layer mixes the concatenated input, then four 1-D convolutions
    transform it: (window, in depth, out depth, stride) = (4, 1, 2, 1), (2, 2, 4, 2),
    (1, 4, 4, 1), (1, 4, 1, 1). Sigmoid follows every layer but the last, which uses
    tanh. The first conv is zero-padded to keep length ``fc_len``; the single stride-2
    conv then halves it, so the ciphertext has ``fc_len // 2`` entries.

    Args:
        plain_text_len: length of each plaintext row.
        secret_key_len: length of each key row.
        cipher_text_len: length of the ciphertext Alice emits; must equal ``fc_len // 2``.
        fc_len: width of the fully-connected layer (2N in the paper).

    Raises:
        ValueError: if ``cipher_text_len != fc_len // 2``.
    """

    def __init__(self, plain_text_len, secret_key_len, cipher_text_len, fc_len):
        super(Alice, self).__init__()
        if cipher_text_len != fc_len // 2:
            raise ValueError(
                f"cipher_text_len ({cipher_text_len}) must equal fc_len // 2 ({fc_len // 2}): "
                "the conv stack halves the fully-connected width"
            )
        self.plain_text_len = plain_text_len
        self.secret_key_len = secret_key_len
        self.cipher_text_len = cipher_text_len
        self.fc_len = fc_len

        self.fc1 = torch.nn.Linear(plain_text_len + secret_key_len, fc_len)
        # Zero-pad 1 left / 2 right so the window-4, stride-1 conv keeps length fc_len
        # ("same" padding); unpadded, the stack would emit fc_len // 2 - 2 entries.
        self.pad1 = torch.nn.ConstantPad1d((1, 2), 0.0)
        self.conv1 = torch.nn.Conv1d(1, 2, 4, stride=1)  # fc_len -> fc_len (after pad1)
        self.conv2 = torch.nn.Conv1d(2, 4, 2, stride=2)  # fc_len -> fc_len // 2
        self.conv3 = torch.nn.Conv1d(4, 4, 1, stride=1)
        self.conv4 = torch.nn.Conv1d(4, 1, 1, stride=1)
        # Activation modules hold no parameters, so one instance is reused after each layer.
        self.sigmoid = torch.nn.Sigmoid()
        self.tanh = torch.nn.Tanh()

    def forward(self, plain_text, secret_key):
        """Encrypt a batch.

        Args:
            plain_text: (batch, plain_text_len) tensor.
            secret_key: (batch, secret_key_len) tensor.

        Returns:
            (batch, fc_len // 2) ciphertext with entries in [-1, 1] (tanh output).
        """
        x = torch.cat((plain_text, secret_key), 1)  # Concatenate the plain text and the secret key
        x = self.sigmoid(self.fc1(x))
        # Conv1d needs (batch, channels, length); a 2-D (batch, fc_len) tensor would be
        # read as one unbatched sample with `batch` channels and fail for batch > 1.
        x = x.unsqueeze(1)
        x = self.sigmoid(self.conv1(self.pad1(x)))
        x = self.sigmoid(self.conv2(x))
        x = self.sigmoid(self.conv3(x))
        x = self.tanh(self.conv4(x))
        return x.squeeze(1)  # drop the single output channel

In [43]:
# Bob's network
class Bob(torch.nn.Module):
    """Bob: recovers the plaintext from (ciphertext, secret key) ("mix & transform").

    A fully-connected layer mixes the concatenated input, then four 1-D convolutions
    transform it: (window, in depth, out depth, stride) = (4, 1, 2, 1), (2, 2, 4, 2),
    (1, 4, 4, 1), (1, 4, 1, 1). Sigmoid follows every layer but the last, which uses
    tanh. The first conv is zero-padded to keep length ``fc_len``; the single stride-2
    conv then halves it, so the output has ``fc_len // 2`` entries.

    Args:
        plain_text_len: length of the plaintext Bob reconstructs; must equal ``fc_len // 2``.
        secret_key_len: length of each key row.
        cipher_text_len: length of each ciphertext row.
        fc_len: width of the fully-connected layer (2N in the paper).

    Raises:
        ValueError: if ``plain_text_len != fc_len // 2``.
    """

    def __init__(self, plain_text_len, secret_key_len, cipher_text_len, fc_len):
        super(Bob, self).__init__()
        if plain_text_len != fc_len // 2:
            raise ValueError(
                f"plain_text_len ({plain_text_len}) must equal fc_len // 2 ({fc_len // 2}): "
                "the conv stack halves the fully-connected width"
            )
        self.plain_text_len = plain_text_len
        self.secret_key_len = secret_key_len
        self.cipher_text_len = cipher_text_len
        self.fc_len = fc_len

        self.fc1 = torch.nn.Linear(cipher_text_len + secret_key_len, fc_len)
        # Zero-pad 1 left / 2 right so the window-4, stride-1 conv keeps length fc_len
        # ("same" padding); unpadded, the stack would emit fc_len // 2 - 2 entries.
        self.pad1 = torch.nn.ConstantPad1d((1, 2), 0.0)
        self.conv1 = torch.nn.Conv1d(1, 2, 4, stride=1)  # fc_len -> fc_len (after pad1)
        self.conv2 = torch.nn.Conv1d(2, 4, 2, stride=2)  # fc_len -> fc_len // 2
        self.conv3 = torch.nn.Conv1d(4, 4, 1, stride=1)
        self.conv4 = torch.nn.Conv1d(4, 1, 1, stride=1)
        # Activation modules hold no parameters, so one instance is reused after each layer.
        self.sigmoid = torch.nn.Sigmoid()
        self.tanh = torch.nn.Tanh()

    def forward(self, cipher_text, secret_key):
        """Decrypt a batch.

        Args:
            cipher_text: (batch, cipher_text_len) tensor.
            secret_key: (batch, secret_key_len) tensor.

        Returns:
            (batch, fc_len // 2) plaintext estimate with entries in [-1, 1] (tanh output).
        """
        x = torch.cat((cipher_text, secret_key), 1)  # Concatenate the cipher text and the secret key
        x = self.sigmoid(self.fc1(x))
        # Conv1d needs (batch, channels, length); a 2-D (batch, fc_len) tensor would be
        # read as one unbatched sample with `batch` channels and fail for batch > 1.
        x = x.unsqueeze(1)
        x = self.sigmoid(self.conv1(self.pad1(x)))
        x = self.sigmoid(self.conv2(x))
        x = self.sigmoid(self.conv3(x))
        x = self.tanh(self.conv4(x))
        return x.squeeze(1)  # drop the single output channel

In [44]:
# Eve's network
class Eve(torch.nn.Module):
    """Eve: tries to recover the plaintext from the ciphertext alone (no key).

    A fully-connected layer maps the ciphertext to ``fc_len`` entries, then four 1-D
    convolutions transform it: (window, in depth, out depth, stride) = (4, 1, 2, 1),
    (2, 2, 4, 2), (1, 4, 4, 1), (1, 4, 1, 1). Sigmoid follows every layer but the last,
    which uses tanh. The first conv is zero-padded to keep length ``fc_len``; the single
    stride-2 conv then halves it, so the output has ``fc_len // 2`` entries.

    Args:
        plain_text_len: length of the plaintext Eve guesses; must equal ``fc_len // 2``.
        cipher_text_len: length of each ciphertext row.
        fc_len: width of the fully-connected layer.

    Raises:
        ValueError: if ``plain_text_len != fc_len // 2``.
    """

    def __init__(self, plain_text_len,  cipher_text_len, fc_len):
        super(Eve, self).__init__()
        if plain_text_len != fc_len // 2:
            raise ValueError(
                f"plain_text_len ({plain_text_len}) must equal fc_len // 2 ({fc_len // 2}): "
                "the conv stack halves the fully-connected width"
            )
        self.plain_text_len = plain_text_len
        self.cipher_text_len = cipher_text_len
        self.fc_len = fc_len

        self.fc1 = torch.nn.Linear(cipher_text_len, fc_len)
        # Zero-pad 1 left / 2 right so the window-4, stride-1 conv keeps length fc_len
        # ("same" padding); unpadded, the stack would emit fc_len // 2 - 2 entries.
        self.pad1 = torch.nn.ConstantPad1d((1, 2), 0.0)
        self.conv1 = torch.nn.Conv1d(1, 2, 4, stride=1)  # fc_len -> fc_len (after pad1)
        self.conv2 = torch.nn.Conv1d(2, 4, 2, stride=2)  # fc_len -> fc_len // 2
        self.conv3 = torch.nn.Conv1d(4, 4, 1, stride=1)
        self.conv4 = torch.nn.Conv1d(4, 1, 1, stride=1)
        # Activation modules hold no parameters, so one instance is reused after each layer.
        self.sigmoid = torch.nn.Sigmoid()
        self.tanh = torch.nn.Tanh()

    def forward(self, cipher_text):
        """Guess the plaintext for a batch of ciphertexts.

        Args:
            cipher_text: (batch, cipher_text_len) tensor.

        Returns:
            (batch, fc_len // 2) plaintext guess with entries in [-1, 1] (tanh output).
        """
        x = self.sigmoid(self.fc1(cipher_text))
        # Conv1d needs (batch, channels, length); a 2-D (batch, fc_len) tensor would be
        # read as one unbatched sample with `batch` channels and fail for batch > 1.
        x = x.unsqueeze(1)
        x = self.sigmoid(self.conv1(self.pad1(x)))
        x = self.sigmoid(self.conv2(x))
        x = self.sigmoid(self.conv3(x))
        x = self.tanh(self.conv4(x))
        return x.squeeze(1)  # drop the single output channel

In [None]:
# Hyperparameters

In [None]:
# Train the networks