In [194]:
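# Adversarial Neural Cryptography - "Learning to Protect Communications with Adversarial Neural Cryptography"
# (Abadi & Andersen, 2016): https://arxiv.org/pdf/1610.06918.pdf
# This notebook loads pretrained CNN versions of Alice, Bob and Eve and tests them on a random message.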

In [195]:
# Necessary imports
import torch
import random

import numpy as np
from torch import nn

import matplotlib.pyplot as plt

In [196]:
# Bit-lengths of the plain text, the shared secret key and the cipher text.
pln_txt_len = 16
sec_key_len = 16
cip_txt_len = 16

In [197]:
# Function to generate a random plain text and a secret key
def create_batch(plain_text_len, secret_key_len, batch_size):
    plain_text = torch.zeros(batch_size, plain_text_len)
    secret_key = torch.zeros(batch_size, secret_key_len)
    
    for i in range(batch_size):
        for j in range(plain_text_len):
            plain_text[i][j] = random.randint(0, 1)
        for j in range(secret_key_len):
            secret_key[i][j] = random.randint(0, 1)
            
    return plain_text.float(), secret_key.float()

# Smoke test: 256 samples -> shapes (256, pln_txt_len) and (256, sec_key_len).
plain_text, secret_key = create_batch(pln_txt_len, sec_key_len, batch_size=256)
print(*(t.shape for t in (plain_text, secret_key)))

torch.Size([256, 16]) torch.Size([256, 16])


In [198]:
# Alice's network

class Alice_Net(nn.Module):
    """Alice: encrypts a plain text with the shared secret key.

    A fully connected layer mixes the concatenated (plain text, key) bits, then a
    stack of 1-D convolutions maps the mixed vector to ``cipher_text_len``
    values in (0, 1).

    Args:
        plain_text_len: number of plain-text bits per sample.
        secret_key_len: number of key bits per sample.
        cipher_text_len: number of cipher-text values produced per sample.

    Note:
        With the kernel/stride/padding used below the conv stack outputs
        ``ceil((plain_text_len + secret_key_len) / 2)`` values per sample, so that
        must equal ``cipher_text_len`` (e.g. 16 + 16 -> 16).
    """

    def __init__(self, plain_text_len, secret_key_len, cipher_text_len):
        super().__init__()
        self.input_size = plain_text_len + secret_key_len
        self.output_size = cipher_text_len

        # NOTE: the attribute names ``W``/``C`` and the layer order inside each
        # Sequential define the state_dict keys (``W.0.weight``, ``C.6.bias``, ...);
        # keep them unchanged so saved checkpoints still load.
        self.W = nn.Sequential(
            nn.Linear(self.input_size, self.input_size),
            nn.Tanh(),
        )

        self.C = nn.Sequential(
            # length L -> L + 1 (kernel 4, padding 2)
            nn.Conv1d(1, 2, 4, stride=1, padding=2),
            nn.Tanh(),
            # length L + 1 -> ceil(L / 2) (kernel 2, stride 2)
            nn.Conv1d(2, 4, 2, stride=2, padding=0),
            nn.Tanh(),
            nn.Conv1d(4, 4, 1, stride=1, padding=0),
            nn.Tanh(),
            # Collapse to one channel; the sigmoid maps every value into (0, 1).
            nn.Conv1d(4, 1, 1, stride=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, plain_text, secret_key):
        """Return the soft cipher text.

        Args:
            plain_text: float tensor of shape (batch, plain_text_len).
            secret_key: float tensor of shape (batch, secret_key_len).

        Returns:
            Tensor of shape (batch, cipher_text_len) with values in (0, 1).
        """
        x = torch.cat((plain_text, secret_key), dim=-1)  # (batch, input_size)
        x = x.unsqueeze(1)  # (batch, 1, input_size): a single conv input channel
        x = self.W(x)
        x = self.C(x)  # (batch, 1, ceil(input_size / 2))
        # Keep the batch dimension explicit: ``view(-1, output_size)`` would
        # silently re-split samples across rows if the conv output length did
        # not match ``output_size``; this raises instead.
        return x.view(x.size(0), self.output_size)

    @torch.no_grad()
    def encrypt(self, plain_text, secret_key):
        """Return the hard cipher text: ``forward`` thresholded at 0.5, as 0./1. float32.

        Autograd is disabled because the thresholded output is not differentiable.
        """
        return (self.forward(plain_text, secret_key) > 0.5).float()

In [199]:
# Bob's network

class Bob_Net(nn.Module):
    """Bob: recovers the plain text from the cipher text and the shared secret key.

    A fully connected layer mixes the concatenated (cipher text, key) values, then
    a stack of 1-D convolutions maps the mixed vector to ``plain_text_len``
    values in (0, 1).

    Args:
        plain_text_len: number of plain-text bits produced per sample.
        secret_key_len: number of key bits per sample.
        cipher_text_len: number of cipher-text values per sample.

    Note:
        With the kernel/stride/padding used below the conv stack outputs
        ``ceil((cipher_text_len + secret_key_len) / 2)`` values per sample, so that
        must equal ``plain_text_len`` (e.g. 16 + 16 -> 16).
    """

    def __init__(self, plain_text_len, secret_key_len, cipher_text_len):
        super().__init__()
        self.input_size = cipher_text_len + secret_key_len
        self.output_size = plain_text_len

        # NOTE: the attribute names ``W``/``C`` and the layer order inside each
        # Sequential define the state_dict keys (``W.0.weight``, ``C.6.bias``, ...);
        # keep them unchanged so saved checkpoints still load.
        self.W = nn.Sequential(
            nn.Linear(self.input_size, self.input_size),
            nn.Tanh(),
        )

        self.C = nn.Sequential(
            # length L -> L + 1 (kernel 4, padding 2)
            nn.Conv1d(1, 2, 4, stride=1, padding=2),
            nn.Tanh(),
            # length L + 1 -> ceil(L / 2) (kernel 2, stride 2)
            nn.Conv1d(2, 4, 2, stride=2, padding=0),
            nn.Tanh(),
            nn.Conv1d(4, 4, 1, stride=1, padding=0),
            nn.Tanh(),
            # Collapse to one channel; the sigmoid maps every value into (0, 1).
            nn.Conv1d(4, 1, 1, stride=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, cipher_text, secret_key):
        """Return the soft decrypted plain text.

        Args:
            cipher_text: float tensor of shape (batch, cipher_text_len).
            secret_key: float tensor of shape (batch, secret_key_len).

        Returns:
            Tensor of shape (batch, plain_text_len) with values in (0, 1).
        """
        x = torch.cat((cipher_text, secret_key), dim=-1)  # (batch, input_size)
        x = x.unsqueeze(1)  # (batch, 1, input_size): a single conv input channel
        x = self.W(x)
        x = self.C(x)  # (batch, 1, ceil(input_size / 2))
        # Keep the batch dimension explicit: ``view(-1, output_size)`` would
        # silently re-split samples across rows if the conv output length did
        # not match ``output_size``; this raises instead.
        return x.view(x.size(0), self.output_size)

    @torch.no_grad()
    def decrypt(self, cipher_text, secret_key):
        """Return the hard plain text: ``forward`` thresholded at 0.5, as 0./1. float32.

        Autograd is disabled because the thresholded output is not differentiable.
        """
        return (self.forward(cipher_text, secret_key) > 0.5).float()

In [200]:
# Eve's network

class Eve_Net(nn.Module):
    """Eve: tries to recover the plain text from the cipher text alone (no key).

    A fully connected layer expands the cipher text to twice its length, then a
    stack of 1-D convolutions maps it to ``plain_text_len`` values in (0, 1).

    Args:
        plain_text_len: number of plain-text bits produced per sample.
        cipher_text_len: number of cipher-text values per sample.

    Note:
        With the kernel/stride/padding used below the conv stack outputs
        ``ceil(2 * cipher_text_len / 2) == cipher_text_len`` values per sample,
        so ``plain_text_len`` must equal ``cipher_text_len``.
    """

    def __init__(self, plain_text_len, cipher_text_len):
        super().__init__()
        self.input_size = cipher_text_len
        self.output_size = plain_text_len

        # NOTE: the attribute names ``W``/``C`` and the layer order inside each
        # Sequential define the state_dict keys (``W.0.weight``, ``C.6.bias``, ...);
        # keep them unchanged so saved checkpoints still load.
        self.W = nn.Sequential(
            nn.Linear(self.input_size, 2 * self.input_size),
            nn.Tanh(),
        )

        self.C = nn.Sequential(
            # length L -> L + 1 (kernel 4, padding 2)
            nn.Conv1d(1, 2, 4, stride=1, padding=2),
            nn.Tanh(),
            # length L + 1 -> ceil(L / 2) (kernel 2, stride 2)
            nn.Conv1d(2, 4, 2, stride=2, padding=0),
            nn.Tanh(),
            nn.Conv1d(4, 4, 1, stride=1, padding=0),
            nn.Tanh(),
            # Collapse to one channel; the sigmoid maps every value into (0, 1).
            nn.Conv1d(4, 1, 1, stride=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, cipher_text):
        """Return Eve's soft guess of the plain text.

        Args:
            cipher_text: float tensor of shape (batch, cipher_text_len).

        Returns:
            Tensor of shape (batch, plain_text_len) with values in (0, 1).
        """
        x = cipher_text.unsqueeze(1)  # (batch, 1, cipher_text_len): a single conv input channel
        x = self.W(x)  # (batch, 1, 2 * cipher_text_len)
        x = self.C(x)  # (batch, 1, cipher_text_len)
        # Keep the batch dimension explicit: ``view(-1, output_size)`` would
        # silently re-split samples across rows if the conv output length did
        # not match ``output_size``; this raises instead.
        return x.view(x.size(0), self.output_size)

    @torch.no_grad()
    def eavesdrop(self, cipher_text):
        """Return Eve's hard guess: ``forward`` thresholded at 0.5, as 0./1. float32.

        Autograd is disabled because the thresholded output is not differentiable.
        """
        return (self.forward(cipher_text) > 0.5).float()

In [201]:
# Pick the GPU when one is available, otherwise fall back to the CPU.
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")

# Only query the CUDA device name when CUDA exists: torch.cuda.current_device()
# raises on CPU-only machines, which would defeat the CPU fallback above.
print(torch.cuda.get_device_name(torch.cuda.current_device()) if USE_CUDA else "CPU")

NVIDIA GeForce RTX 2070 Super with Max-Q Design


In [202]:
# Build the three parties with the configured bit-lengths and place them on `device`.
# Tuple elements are evaluated left to right, so the networks (and their random
# weight initialisation) are still created in the order Alice, Bob, Eve.
Alice, Bob, Eve = (
    Alice_Net(pln_txt_len, sec_key_len, cip_txt_len).to(device),
    Bob_Net(pln_txt_len, sec_key_len, cip_txt_len).to(device),
    Eve_Net(pln_txt_len, cip_txt_len).to(device),
)

print("All models instantiated successfully !!!")

All models instantiated successfully !!!


In [203]:
# Load the saved weights of each party from disk.
# map_location=device remaps tensors that were saved on a GPU onto the device this
# session uses, so the checkpoints also load on CPU-only machines.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust
# (newer PyTorch releases accept weights_only=True to restrict unpickling).
MODEL_DIR = "../CNN_models"
for party_name, party_net in (("Alice", Alice), ("Bob", Bob), ("Eve", Eve)):
    party_net.load_state_dict(
        torch.load(f"{MODEL_DIR}/CNN_{party_name}.pth", map_location=device)
    )

print("Models loaded successfully !!!")

Models loaded successfully !!!


In [204]:
def hamming_distance(x, y):
    """Return the number of positions at which two equal-length sequences differ.

    Args:
        x, y: sequences of the same length (here: bit strings such as "0110").

    Returns:
        The Hamming distance as an int.

    Raises:
        ValueError: if ``x`` and ``y`` have different lengths. The distance is
            undefined then, and comparing only a common prefix would silently
            under-report the difference.
    """
    if len(x) != len(y):
        raise ValueError(
            f"hamming_distance needs equal-length inputs, got {len(x)} and {len(y)}"
        )
    return sum(1 for a, b in zip(x, y) if a != b)

In [205]:
# Set the models to evaluation mode before testing.
for party_net in (Alice, Bob, Eve):
    party_net.eval()


def _bits_to_str(bits):
    """Render a single-sample 0./1. tensor as a bit string, e.g. tensor([[1., 0.]]) -> "10"."""
    return "".join(str(int(b)) for b in bits.detach().cpu().reshape(-1).tolist())


# Test the networks on one fresh random (plain text, key) pair.
plain_bits, secret_key = create_batch(pln_txt_len, sec_key_len, 1)
plain_bits, secret_key = plain_bits.to(device), secret_key.to(device)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    cipher_bits = Alice.encrypt(plain_bits, secret_key)
    bob_bits = Bob.decrypt(cipher_bits, secret_key)
    eve_bits = Eve.eavesdrop(cipher_bits)

# Bit-string views; these keep the names earlier versions of this cell left
# bound to strings, in case later cells rely on them.
plain_text = _bits_to_str(plain_bits)
encrypted_text = _bits_to_str(cipher_bits)
decrypted_text = _bits_to_str(bob_bits)
eavesdrop_text = _bits_to_str(eve_bits)

print("Plain text:  ", plain_text)
print("Cipher text: ", encrypted_text)

print("Bob's output:", decrypted_text, "Distance:", hamming_distance(plain_text, decrypted_text))
print("Eve's output:", eavesdrop_text, "Distance:", hamming_distance(plain_text, eavesdrop_text))

Plain text:   0010111011010000
Cipher text:  1001110000111101
Bob's output: 0010111100010000 Distance: 3
Eve's output: 0010111100011000 Distance: 4
