In [None]:
# Adversarial Neural Cryptography - https://arxiv.org/pdf/1610.06918.pdf

In [1]:
# Necessary imports
import torch
import random

from torch import nn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Bit lengths of the plaintext, the secret key and the ciphertext (all equal here).
pln_txt_len = 8
sec_key_len = 8
cip_txt_len = 8

In [3]:
# Function to generate a random plain text and a secret key
def create_batch(plain_text_len, secret_key_len, batch_size):
    """Generate a random batch of plaintexts and secret keys with bits encoded as -1/+1.

    Args:
        plain_text_len: number of bits per plaintext.
        secret_key_len: number of bits per key.
        batch_size: number of (plaintext, key) pairs.

    Returns:
        (plain_text, secret_key): tensors of shape (batch_size, plain_text_len) and
        (batch_size, secret_key_len), default floating dtype, every entry -1.0 or +1.0.

    Randomness comes from torch's global RNG, so seed it with ``torch.manual_seed``
    for reproducible batches.
    """
    def _random_bits(length):
        # Draw 0/1 in one vectorized call (instead of a per-element Python loop),
        # cast to the default float dtype (what torch.zeros would give), then map
        # {0, 1} -> {-1, +1}.
        bits = torch.randint(0, 2, (batch_size, length)).to(torch.get_default_dtype())
        return bits * 2 - 1

    return _random_bits(plain_text_len), _random_bits(secret_key_len)

# Sanity check: draw one batch of 256 pairs and show both tensor shapes.
plain_text, secret_key = create_batch(
    plain_text_len=pln_txt_len,
    secret_key_len=sec_key_len,
    batch_size=256,
)
print(*(t.shape for t in (plain_text, secret_key)))

torch.Size([256, 8]) torch.Size([256, 8])


In [4]:
# Alice's network

class Alice_Net(nn.Module):
    """Alice: maps a (plaintext, secret key) pair to a ciphertext.

    "Mix & transform" network (arXiv:1610.06918):
    - a fully connected layer over the concatenated plaintext+key vector;
    - four 1-D convolutions with sigmoid activations;
    - a final tanh, so outputs lie in [-1, 1].

    Args:
        plain_text_len: number of plaintext bits.
        secret_key_len: number of key bits.
        cipher_text_len: intended ciphertext length. It is only stored. The conv stack
            turns an input of length L = plain_text_len + secret_key_len into
            ceil(L / 2) outputs, which equals cipher_text_len when all three
            lengths are equal.
    """

    def __init__(self, plain_text_len, secret_key_len, cipher_text_len):
        super().__init__()
        self.input_size = plain_text_len + secret_key_len
        self.output_size = cipher_text_len

        # "Mix" step: every output position can combine all plaintext and key bits.
        self.W = nn.Linear(self.input_size, self.input_size)
        # Conv stack: L -> L + 1 (k=4, pad=2) -> ceil(L / 2) (k=2, stride=2) -> same -> same.
        self.C = nn.Sequential(
            nn.Conv1d(1, 2, 4, stride=1, padding=2),
            nn.Sigmoid(),
            nn.Conv1d(2, 4, 2, stride=2, padding=0),
            nn.Sigmoid(),
            nn.Conv1d(4, 4, 1, stride=1, padding=0),
            nn.Sigmoid(),
            nn.Conv1d(4, 1, 1, stride=1, padding=0),
            nn.Tanh(),
        )

    def forward(self, plain_text, secret_key):
        """Encrypt a batch.

        Args:
            plain_text: (batch, plain_text_len) tensor.
            secret_key: (batch, secret_key_len) tensor.

        Returns:
            (batch, ceil((plain_text_len + secret_key_len) / 2)) tensor in [-1, 1].
        """
        x = torch.cat((plain_text, secret_key), dim=1)  # (batch, P + K)
        x = x.unsqueeze(1)  # add a channel dim for Conv1d: (batch, 1, P + K)
        x = self.W(x)       # nn.Linear acts on the last dim
        x = self.C(x)       # (batch, 1, L_out)
        # Squeeze only the channel dim. A bare squeeze() would also drop the
        # batch dim when batch == 1 and return shape (L_out,).
        return x.squeeze(1)

    def encrypt(self, plain_text, secret_key):
        """Readable alias for ``forward``."""
        return self.forward(plain_text, secret_key)

In [5]:
# Bob's network

class Bob_Net(nn.Module):
    """Bob: recovers the plaintext from a (ciphertext, secret key) pair.

    Same "mix & transform" layout as Alice (arXiv:1610.06918):
    - a fully connected layer over the concatenated ciphertext+key vector;
    - four 1-D convolutions with sigmoid activations;
    - a final tanh, so outputs lie in [-1, 1].

    Args:
        plain_text_len: intended output length. It is only stored. The conv stack
            turns an input of length L = cipher_text_len + secret_key_len into
            ceil(L / 2) outputs, which equals plain_text_len when all three
            lengths are equal.
        secret_key_len: number of key bits.
        cipher_text_len: number of ciphertext values.
    """

    def __init__(self, plain_text_len, secret_key_len, cipher_text_len):
        super().__init__()
        self.input_size = cipher_text_len + secret_key_len
        self.output_size = plain_text_len

        # "Mix" step: every output position can combine all ciphertext and key values.
        self.W = nn.Linear(self.input_size, self.input_size)
        # Conv stack: L -> L + 1 (k=4, pad=2) -> ceil(L / 2) (k=2, stride=2) -> same -> same.
        self.C = nn.Sequential(
            nn.Conv1d(1, 2, 4, stride=1, padding=2),
            nn.Sigmoid(),
            nn.Conv1d(2, 4, 2, stride=2, padding=0),
            nn.Sigmoid(),
            nn.Conv1d(4, 4, 1, stride=1, padding=0),
            nn.Sigmoid(),
            nn.Conv1d(4, 1, 1, stride=1, padding=0),
            nn.Tanh(),
        )

    def forward(self, cipher_text, secret_key):
        """Decrypt a batch.

        Args:
            cipher_text: (batch, cipher_text_len) tensor.
            secret_key: (batch, secret_key_len) tensor.

        Returns:
            (batch, ceil((cipher_text_len + secret_key_len) / 2)) tensor in [-1, 1].
        """
        x = torch.cat((cipher_text, secret_key), dim=1)  # (batch, C + K)
        x = x.unsqueeze(1)  # add a channel dim for Conv1d: (batch, 1, C + K)
        x = self.W(x)       # nn.Linear acts on the last dim
        x = self.C(x)       # (batch, 1, L_out)
        # Squeeze only the channel dim. A bare squeeze() would also drop the
        # batch dim when batch == 1 and return shape (L_out,).
        return x.squeeze(1)

    def decrypt(self, cipher_text, secret_key):
        """Readable alias for ``forward``."""
        return self.forward(cipher_text, secret_key)

In [6]:
# Eve's network

class Eve_Net(nn.Module):
    """Eve: the eavesdropper, who tries to recover the plaintext from the ciphertext alone.

    A fully connected layer widens the ciphertext from C to 2 * C values. This is
    followed by the same conv stack as Alice and Bob: sigmoid activations and a
    final tanh, so outputs lie in [-1, 1].

    Args:
        plain_text_len: intended output length. It is only stored. The conv stack
            maps the 2 * cipher_text_len hidden values back to cipher_text_len
            outputs, which equals plain_text_len when the two lengths match.
        cipher_text_len: number of ciphertext values.
    """

    def __init__(self, plain_text_len, cipher_text_len):
        super().__init__()
        self.input_size = cipher_text_len
        self.output_size = plain_text_len

        # Eve has no key, so the FC layer widens C -> 2C to give the convs the same
        # input length that Alice and Bob see.
        self.W = nn.Linear(self.input_size, 2 * self.input_size)
        # Conv stack: L -> L + 1 (k=4, pad=2) -> ceil(L / 2) (k=2, stride=2) -> same -> same.
        self.C = nn.Sequential(
            nn.Conv1d(1, 2, 4, stride=1, padding=2),
            nn.Sigmoid(),
            nn.Conv1d(2, 4, 2, stride=2, padding=0),
            nn.Sigmoid(),
            nn.Conv1d(4, 4, 1, stride=1, padding=0),
            nn.Sigmoid(),
            nn.Conv1d(4, 1, 1, stride=1, padding=0),
            nn.Tanh(),
        )

    def forward(self, cipher_text):
        """Guess the plaintext for a batch of ciphertexts.

        Args:
            cipher_text: (batch, cipher_text_len) tensor.

        Returns:
            (batch, cipher_text_len) tensor in [-1, 1].
        """
        x = cipher_text.unsqueeze(1)  # add a channel dim for Conv1d: (batch, 1, C)
        x = self.W(x)                 # (batch, 1, 2C)
        x = self.C(x)                 # (batch, 1, C)
        # Squeeze only the channel dim. A bare squeeze() would also drop the
        # batch dim when batch == 1 and return shape (C,).
        return x.squeeze(1)

    def eavesdrop(self, cipher_text):
        """Readable alias for ``forward``."""
        return self.forward(cipher_text)

In [7]:
# Hyperparameters

### Let's define the loss functions
Alice and Bob want to minimize Bob’s reconstruction error and to
maximize the reconstruction error of the “optimal Eve”.

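Eve's loss is the L1 distance between Eve's guess and the input plaintext, averaged over the batch. Because bits are encoded as $\pm 1$, each absolute difference is divided by 2. The loss therefore counts wrongly recovered bits per example (from $0$ to $N$).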

In [8]:
def EveLoss(Plain, Decrypt, BatchSize):
    """Per-example L1 reconstruction error, measured in bits.

    Bits are encoded as -1/+1, so |Plain - Decrypt| / 2 is 1 for a flipped bit and
    0 for a correct one. Summing over every element and dividing by BatchSize gives
    the average error per example. For ±1 plaintexts and guesses in [-1, 1], this
    lies between 0 and the message length N.
    """
    per_bit_error = (Plain - Decrypt).div(2).abs()
    return per_bit_error.sum() / BatchSize

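Alice and Bob's loss has two components:
1. Bob's reconstruction error: the same L1 error as Eve's, computed between Bob's decryption and the plaintext.
2. How well the previous Eve breaks the current model: $\frac{(N/2 - L_E)^2}{(N/2)^2}$, where $L_E$ is Eve's L1 error. This term is $0$ when Eve does no better than random guessing ($L_E = N/2$). It grows to $1$ as Eve recovers every bit ($L_E = 0$) or flips every bit ($L_E = N$).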

In [9]:
def AliceBobLoss(Plain, Decrypt, BatchSize, prevEveLoss, msg_len=None):
    """Alice/Bob loss: Bob's reconstruction error plus a penalty on Eve's success.

        loss = bob_err + (N/2 - eve_err)^2 / (N/2)^2

    The penalty is 0 when Eve is at chance (eve_err == N/2). It is 1 when Eve
    recovers every bit (eve_err == 0) or flips every bit (eve_err == N).

    Args:
        Plain: (batch, N) plaintext tensor with bits encoded as -1/+1.
        Decrypt: Bob's reconstruction, same shape as Plain.
        BatchSize: divisor used to average the error per example.
        prevEveLoss: Eve's per-example L1 error, on the same scale as ``EveLoss``.
            NOTE(review): if this is a detached tensor or a Python float, the penalty
            contributes no gradient to Alice. Confirm the caller passes a value
            that still carries the graph through Alice's ciphertext.
        msg_len: message length N in bits. Defaults to ``Plain.shape[-1]`` rather
            than the notebook-global ``pln_txt_len``, so the penalty stays correct
            if the message length changes.
    """
    if msg_len is None:
        msg_len = Plain.shape[-1]
    half = msg_len / 2

    bob_error = torch.sum(torch.abs((Plain - Decrypt) / 2)) / BatchSize
    eve_penalty = ((half - prevEveLoss) ** 2) / half ** 2
    return bob_error + eve_penalty

In [10]:
# Train the networks

# The paper uses TensorFlow's Adam optimizer (Kingma & Ba, 2014) with a learning rate of 0.0008
# (the PyTorch equivalent is torch.optim.Adam with lr=0.0008).
# Training alternates: Alice/Bob train on one minibatch, then
# Eve trains on two minibatches.