In [1]:
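# Adversarial Neural Cryptography — "Learning to Protect Communications with Adversarial Neural Cryptography" (Abadi & Andersen, 2016): https://arxiv.org/pdf/1610.06918.pdf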

In [2]:
# Necessary imports
import torch
import random

from torch import nn

In [3]:
# Number of bits in the plaintext, the secret key and the ciphertext (all equal here).
pln_txt_len = 8
sec_key_len = pln_txt_len
cip_txt_len = pln_txt_len

In [4]:
# Function to generate a random plain text and a secret key
def create_batch(plain_text_len, secret_key_len, batch_size):
    plain_text = torch.zeros(batch_size, plain_text_len)
    secret_key = torch.zeros(batch_size, secret_key_len)
    
    for i in range(batch_size):
        for j in range(plain_text_len):
            plain_text[i][j] = random.randint(0, 1)
        for j in range(secret_key_len):
            secret_key[i][j] = random.randint(0, 1)
            
    return plain_text, secret_key

plain_text, secret_key = create_batch(
    plain_text_len=pln_txt_len, secret_key_len=sec_key_len, batch_size=256
)
print(*(t.shape for t in (plain_text, secret_key)))

torch.Size([256, 8]) torch.Size([256, 8])


In [5]:
# Alice's network

class Alice_Net(nn.Module):
    """Alice: encrypts a plaintext using the shared secret key.

    A fully-connected layer maps the concatenated (plaintext, key) vector to
    length ``2 * cipher_text_len``. It is then treated as a 1-channel sequence
    and passed through a stack of 1-D convolutions. The stride-2 layer halves
    the length back to ``cipher_text_len`` and the final ``tanh`` bounds outputs
    to [-1, 1].

    Args:
        plain_text_len: number of plaintext bits.
        secret_key_len: number of key bits.
        cipher_text_len: number of ciphertext values produced.
    """

    def __init__(self, plain_text_len, secret_key_len, cipher_text_len):
        super().__init__()
        self.input_size = plain_text_len + secret_key_len
        self.output_size = cipher_text_len

        # Map to 2 * output_size so the stride-2 conv below yields exactly output_size.
        self.W = nn.Linear(self.input_size, 2 * self.output_size)
        # Layer spec as [window, in_depth, out_depth, stride]:
        # [4,1,2,1], [2,2,4,2], [1,4,4,1], [1,4,1,1].
        # nn.Conv1d's positional order is (in_channels, out_channels, kernel_size).
        self.C = nn.Sequential(
            # TensorFlow-style 'SAME' padding for a width-4 window: 1 left, 2 right,
            # so the sequence length is preserved.
            nn.ConstantPad1d((1, 2), 0.0),
            nn.Conv1d(1, 2, kernel_size=4, stride=1),
            nn.Sigmoid(),
            nn.Conv1d(2, 4, kernel_size=2, stride=2),  # halves the length: 2N -> N
            nn.Sigmoid(),
            nn.Conv1d(4, 4, kernel_size=1, stride=1),
            nn.Sigmoid(),
            nn.Conv1d(4, 1, kernel_size=1, stride=1),
            nn.Tanh(),
        )

    def forward(self, plain_text, secret_key):
        """Return the ciphertext, shape (batch, cipher_text_len).

        Assumes ``plain_text`` and ``secret_key`` are 2-D (batch, length) tensors.
        """
        x = torch.cat((plain_text, secret_key), 1)  # (batch, input_size)
        x = self.W(x)                                # (batch, 2 * output_size)
        x = self.C(x.unsqueeze(1))                   # (batch, 1, output_size)
        return x.squeeze(1)                          # (batch, output_size)

    def encrypt(self, plain_text, secret_key):
        """Alias for :meth:`forward`."""
        return self.forward(plain_text, secret_key)

In [6]:
# Bob's network

class Bob_Net(nn.Module):
    """Bob: recovers the plaintext from a ciphertext and the shared secret key.

    A fully-connected layer maps the concatenated (ciphertext, key) vector to
    length ``2 * plain_text_len``. It is then treated as a 1-channel sequence
    and passed through a stack of 1-D convolutions. The stride-2 layer halves
    the length back to ``plain_text_len`` and the final ``tanh`` bounds outputs
    to [-1, 1].

    Args:
        plain_text_len: number of plaintext values produced.
        secret_key_len: number of key bits.
        cipher_text_len: number of ciphertext values consumed.
    """

    def __init__(self, plain_text_len, secret_key_len, cipher_text_len):
        super().__init__()
        self.input_size = cipher_text_len + secret_key_len
        self.output_size = plain_text_len

        # Map to 2 * output_size so the stride-2 conv below yields exactly output_size.
        self.W = nn.Linear(self.input_size, 2 * self.output_size)
        # Layer spec as [window, in_depth, out_depth, stride]:
        # [4,1,2,1], [2,2,4,2], [1,4,4,1], [1,4,1,1].
        # nn.Conv1d's positional order is (in_channels, out_channels, kernel_size).
        self.C = nn.Sequential(
            # TensorFlow-style 'SAME' padding for a width-4 window: 1 left, 2 right,
            # so the sequence length is preserved.
            nn.ConstantPad1d((1, 2), 0.0),
            nn.Conv1d(1, 2, kernel_size=4, stride=1),
            nn.Sigmoid(),
            nn.Conv1d(2, 4, kernel_size=2, stride=2),  # halves the length: 2N -> N
            nn.Sigmoid(),
            nn.Conv1d(4, 4, kernel_size=1, stride=1),
            nn.Sigmoid(),
            nn.Conv1d(4, 1, kernel_size=1, stride=1),
            nn.Tanh(),
        )

    def forward(self, cipher_text, secret_key):
        """Return the recovered plaintext, shape (batch, plain_text_len).

        Assumes ``cipher_text`` and ``secret_key`` are 2-D (batch, length) tensors.
        """
        x = torch.cat((cipher_text, secret_key), 1)  # (batch, input_size)
        x = self.W(x)                                 # (batch, 2 * output_size)
        x = self.C(x.unsqueeze(1))                    # (batch, 1, output_size)
        return x.squeeze(1)                           # (batch, output_size)

    def decrypt(self, cipher_text, secret_key):
        """Alias for :meth:`forward`."""
        return self.forward(cipher_text, secret_key)

In [7]:
# Eve's network

class Eve_Net(nn.Module):
    """Eve: tries to recover the plaintext from the ciphertext alone (no key).

    A fully-connected layer maps the ciphertext to length ``2 * plain_text_len``.
    It is then treated as a 1-channel sequence and passed through a stack of
    1-D convolutions. The stride-2 layer halves the length back to
    ``plain_text_len`` and the final ``tanh`` bounds outputs to [-1, 1].

    Args:
        plain_text_len: number of plaintext values produced.
        cipher_text_len: number of ciphertext values consumed.
    """

    def __init__(self, plain_text_len, cipher_text_len):
        super().__init__()
        self.input_size = cipher_text_len
        self.output_size = plain_text_len

        # Size the projection from the *output* length so that Eve emits
        # plain_text_len values even when cipher_text_len differs.
        self.W = nn.Linear(self.input_size, 2 * self.output_size)
        # Layer spec as [window, in_depth, out_depth, stride]:
        # [4,1,2,1], [2,2,4,2], [1,4,4,1], [1,4,1,1].
        # nn.Conv1d's positional order is (in_channels, out_channels, kernel_size).
        self.C = nn.Sequential(
            # TensorFlow-style 'SAME' padding for a width-4 window: 1 left, 2 right,
            # so the sequence length is preserved.
            nn.ConstantPad1d((1, 2), 0.0),
            nn.Conv1d(1, 2, kernel_size=4, stride=1),
            nn.Sigmoid(),
            nn.Conv1d(2, 4, kernel_size=2, stride=2),  # halves the length: 2N -> N
            nn.Sigmoid(),
            nn.Conv1d(4, 4, kernel_size=1, stride=1),
            nn.Sigmoid(),
            nn.Conv1d(4, 1, kernel_size=1, stride=1),
            nn.Tanh(),
        )

    def forward(self, cipher_text):
        """Return Eve's plaintext guess, shape (batch, plain_text_len).

        Assumes ``cipher_text`` is a 2-D (batch, cipher_text_len) tensor.
        """
        x = self.W(cipher_text)     # (batch, 2 * output_size)
        x = self.C(x.unsqueeze(1))  # (batch, 1, output_size)
        return x.squeeze(1)         # (batch, output_size)

    def eavesdrop(self, cipher_text):
        """Alias for :meth:`forward`."""
        return self.forward(cipher_text)

In [8]:
# Hyperparameters

In [9]:
# Train the networks