In [47]:
# Adversarial Neural Cryptography - https://arxiv.org/pdf/1610.06918.pdf
# Alternative implementation of the paper above using RNNs

In [48]:
# Necessary imports
import torch
import random

import numpy as np
from torch import nn

import matplotlib.pyplot as plt

In [49]:
# Message and key lengths, in bits.
# The cipher text is kept the same length as the plain text.
pln_txt_len = 16
cip_txt_len = pln_txt_len
sec_key_len = 16

In [50]:
# Function to generate a random batch of plain texts and secret keys
def create_batch(plain_text_len, secret_key_len, batch_size):
    """Generate random binary plain texts and secret keys.

    Args:
        plain_text_len: number of bits in each plain text.
        secret_key_len: number of bits in each secret key.
        batch_size: number of (plain text, secret key) pairs.

    Returns:
        Tuple ``(plain_text, secret_key)`` of float32 tensors with shapes
        ``(batch_size, plain_text_len)`` and ``(batch_size, secret_key_len)``.
        Every entry is 0.0 or 1.0, drawn uniformly at random.
    """
    # Vectorised sampling instead of a per-element Python loop. The upper
    # bound of torch.randint is exclusive, so (0, 2) yields bits 0 and 1.
    plain_text = torch.randint(0, 2, (batch_size, plain_text_len)).float()
    secret_key = torch.randint(0, 2, (batch_size, secret_key_len)).float()
    return plain_text, secret_key

# Sanity-check the generator on a 256-sample batch and show the tensor shapes.
plain_text, secret_key = create_batch(
    plain_text_len=pln_txt_len,
    secret_key_len=sec_key_len,
    batch_size=256,
)
print(plain_text.shape, secret_key.shape)

torch.Size([256, 16]) torch.Size([256, 16])


In [51]:
# Define the device: the GPU when CUDA is available, otherwise the CPU.
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")

# Querying the GPU name raises on machines without CUDA, so only do it when a
# GPU is actually present; otherwise report the CPU fallback.
if USE_CUDA:
    print(torch.cuda.get_device_name(torch.cuda.current_device()))
else:
    print("CUDA not available, using CPU")

NVIDIA GeForce RTX 2070 Super with Max-Q Design


In [52]:
# Alice's network
# Concatenate the plain text and the secret key and pass the result through a
# linear layer. Each output position is then embedded and the resulting
# sequence is run through a bidirectional GRU, giving one cipher bit per position.
# Note: plain text and secret key are of the same length

class AliceRNN_Net(nn.Module):
    """Alice, the encrypter: (plain text, secret key) -> soft cipher text.

    Args:
        plain_text_len: number of plain-text bits per sample.
        secret_key_len: number of key bits per sample.
        cipher_text_len: number of cipher bits produced; this is also the
            length of the sequence fed to the GRU.
        hidden_size: hidden size H of the GRU (per direction).
    """

    def __init__(self, plain_text_len, secret_key_len, cipher_text_len, hidden_size=128):
        super(AliceRNN_Net, self).__init__()
        self.input_size  = plain_text_len + secret_key_len  # L + K
        self.output_size = cipher_text_len                  # L
        self.hidden_size = hidden_size                      # H

        self.C   = nn.Linear(self.input_size, self.output_size)
        self.W   = nn.Linear(1, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
        self.U   = nn.Linear(2*hidden_size, 1)

        # V is never used in forward(). It is kept so that state dicts
        # containing V.weight / V.bias still load under load_state_dict's
        # default strict key matching.
        self.V   = nn.Linear(2*hidden_size, 1)

    def init_hidden(self, batch_size):
        """Return a zero initial GRU state of shape (2, batch_size, H).

        The leading 2 holds one state per GRU direction. The tensor is created
        with the same device and dtype as the model's parameters, so it follows
        the model after ``.to(...)`` rather than depending on a global device.
        """
        return next(self.parameters()).new_zeros(2, batch_size, self.hidden_size)

    def forward(self, plain_text, secret_key, hidden=None):
        """Compute the soft cipher text.

        Args:
            plain_text: tensor of shape (B, plain_text_len).
            secret_key: tensor of shape (B, secret_key_len).
            hidden: optional initial GRU state of shape (2, B, H); zeros if omitted.

        Returns:
            Tensor of shape (B, cipher_text_len) holding sigmoid outputs.
        """
        x = torch.cat((plain_text, secret_key), dim=1)      # B x (L + K)
        x = self.C(x).unsqueeze(2)                          # B x L x 1
        x = torch.relu(x)                                   # B x L x 1
        x = self.W(x)                                       # B x L x H
        x = torch.relu(x)                                   # B x L x H

        if hidden is None:
            hidden = self.init_hidden(x.shape[0])           # 2 x B x H
        x, hidden = self.gru(x, hidden)                     # B x L x 2H, 2 x B x H

        x = self.U(x)                                       # B x L x 1
        x = torch.sigmoid(x)                                # B x L x 1
        x = x.view(x.shape[0], self.output_size)            # B x L
        return x

    def encrypt(self, plain_text, secret_key):
        """Return the hard cipher text: 1.0 where the soft output exceeds 0.5, else 0.0."""
        # Thresholding is non-differentiable, so skip building the autograd graph.
        with torch.no_grad():
            encrypted_text = self.forward(plain_text, secret_key)
        return (encrypted_text > 0.5).float()

In [53]:
class BobRNN_Net(nn.Module):
    """Bob, the decrypter: (cipher text, secret key) -> soft plain text.

    The cipher text and key are concatenated and passed through a linear
    layer. Each output position is embedded, and the sequence is run through
    a bidirectional GRU, giving one plain-text bit per position.

    Args:
        plain_text_len: number of plain-text bits produced; this is also the
            length of the sequence fed to the GRU.
        secret_key_len: number of key bits per sample.
        cipher_text_len: number of cipher bits per sample.
        hidden_size: hidden size H of the GRU (per direction).
    """

    def __init__(self, plain_text_len, secret_key_len, cipher_text_len, hidden_size=128):
        super(BobRNN_Net, self).__init__()
        self.input_size  = cipher_text_len + secret_key_len     # L + K
        self.output_size = plain_text_len                       # L
        self.hidden_size = hidden_size                          # H

        self.C   = nn.Linear(self.input_size, self.output_size)
        self.W   = nn.Linear(1, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
        self.U   = nn.Linear(2*hidden_size, 1)

    def init_hidden(self, batch_size):
        """Return a zero initial GRU state of shape (2, batch_size, H).

        The leading 2 holds one state per GRU direction. The tensor is created
        with the same device and dtype as the model's parameters, so it follows
        the model after ``.to(...)`` rather than depending on a global device.
        """
        return next(self.parameters()).new_zeros(2, batch_size, self.hidden_size)

    def forward(self, cipher_text, secret_key, hidden=None):
        """Compute the soft plain text.

        Args:
            cipher_text: tensor of shape (B, cipher_text_len).
            secret_key: tensor of shape (B, secret_key_len).
            hidden: optional initial GRU state of shape (2, B, H); zeros if omitted.

        Returns:
            Tensor of shape (B, plain_text_len) holding sigmoid outputs.
        """
        x = torch.cat((cipher_text, secret_key), dim=1)         # B x (L + K)
        x = self.C(x).unsqueeze(2)                              # B x L x 1
        x = torch.relu(x)                                       # B x L x 1
        x = self.W(x)                                           # B x L x H
        x = torch.relu(x)                                       # B x L x H

        if hidden is None:
            hidden = self.init_hidden(x.shape[0])               # 2 x B x H
        x, hidden = self.gru(x, hidden)                         # B x L x 2H, 2 x B x H

        x = self.U(x)                                           # B x L x 1
        x = torch.sigmoid(x)                                    # B x L x 1
        x = x.view(x.shape[0], self.output_size)                # B x L
        return x

    def decrypt(self, cipher_text, secret_key):
        """Return the hard plain text: 1.0 where the soft output exceeds 0.5, else 0.0."""
        # Thresholding is non-differentiable, so skip building the autograd graph.
        with torch.no_grad():
            decrypted_text = self.forward(cipher_text, secret_key)
        return (decrypted_text > 0.5).float()

In [54]:
class EveRNN_Net(nn.Module):
    """Eve, the eavesdropper: cipher text alone -> soft plain-text guess.

    Each cipher bit is embedded on its own and the sequence is run through a
    bidirectional GRU, giving one output bit per cipher position. The output
    therefore has as many bits as the incoming cipher. ``plain_text_len`` must
    match the cipher length, otherwise forward() raises.

    Args:
        plain_text_len: number of plain-text bits produced per sample.
        cipher_text_len: number of cipher bits per sample.
        hidden_size: hidden size H of the GRU (per direction).
    """

    def __init__(self, plain_text_len, cipher_text_len, hidden_size=128):
        super(EveRNN_Net, self).__init__()
        self.input_size  = cipher_text_len                      # L
        self.output_size = plain_text_len                       # L
        self.hidden_size = hidden_size                          # H

        self.W   = nn.Linear(1, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
        self.U   = nn.Linear(2 * hidden_size, 1)

    def init_hidden(self, batch_size):
        """Return a zero initial GRU state of shape (2, batch_size, H).

        The leading 2 holds one state per GRU direction. The tensor is created
        with the same device and dtype as the model's parameters, so it follows
        the model after ``.to(...)`` rather than depending on a global device.
        """
        return next(self.parameters()).new_zeros(2, batch_size, self.hidden_size)

    def forward(self, cipher_text, hidden=None):
        """Compute Eve's soft guess of the plain text.

        Args:
            cipher_text: tensor of shape (B, cipher length).
            hidden: optional initial GRU state of shape (2, B, H); zeros if omitted.

        Returns:
            Tensor of shape (B, plain_text_len) holding sigmoid outputs.

        Raises:
            RuntimeError: if the cipher length differs from plain_text_len.
        """
        x = cipher_text.unsqueeze(2)                            # B x L x 1
        x = self.W(x)                                           # B x L x H
        x = torch.relu(x)                                       # B x L x H

        if hidden is None:
            hidden = self.init_hidden(x.shape[0])               # 2 x B x H
        x, hidden = self.gru(x, hidden)                         # B x L x 2H, 2 x B x H

        x = self.U(x)                                           # B x L x 1
        x = torch.sigmoid(x)                                    # B x L x 1
        # Pin the batch dimension: view(-1, L) would silently regroup bits
        # across samples when the cipher length differs from output_size.
        x = x.view(x.shape[0], self.output_size)                # B x L
        return x

    def eavesdrop(self, cipher_text):
        """Return Eve's hard guess: 1.0 where the soft output exceeds 0.5, else 0.0."""
        # Thresholding is non-differentiable, so skip building the autograd graph.
        with torch.no_grad():
            eavesdrop_text = self.forward(cipher_text)
        return (eavesdrop_text > 0.5).float()

In [55]:
hidden_size = 128  # GRU hidden size shared by all three networks

# Alice encrypts, Bob decrypts with the shared key, Eve only sees the cipher.
Alice = AliceRNN_Net(
    plain_text_len=pln_txt_len,
    secret_key_len=sec_key_len,
    cipher_text_len=cip_txt_len,
    hidden_size=hidden_size,
).to(device)
Bob = BobRNN_Net(
    plain_text_len=pln_txt_len,
    secret_key_len=sec_key_len,
    cipher_text_len=cip_txt_len,
    hidden_size=hidden_size,
).to(device)
Eve = EveRNN_Net(
    plain_text_len=pln_txt_len,
    cipher_text_len=cip_txt_len,
    hidden_size=hidden_size,
).to(device)

print("All models instantiated successfully !!!")

All models instantiated successfully !!!


In [56]:
# Load the saved models.
# map_location remaps stored tensors onto the current `device`, so checkpoints
# that were saved from GPU tensors also load on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
Alice.load_state_dict(torch.load("../GRU_models/GRU_Alice.pth", map_location=device))
Bob.load_state_dict(torch.load("../GRU_models/GRU_Bob.pth", map_location=device))
Eve.load_state_dict(torch.load("../GRU_models/GRU_Eve.pth", map_location=device))

print("Models loaded successfully !!!")

Models loaded successfully !!!


In [57]:
def hamming_distance(x, y):
    """Count the positions at which two equal-length sequences differ.

    Args:
        x: indexable sequence (e.g. a bit string such as "0110").
        y: sequence of the same length as ``x``.

    Returns:
        The number of indices ``i`` with ``x[i] != y[i]``.

    Raises:
        ValueError: if ``x`` and ``y`` have different lengths, since the
            Hamming distance is only defined for equal-length inputs.
    """
    if len(x) != len(y):
        raise ValueError(
            f"hamming_distance requires equal lengths, got {len(x)} and {len(y)}"
        )
    return sum(a != b for a, b in zip(x, y))

In [58]:
# Set the models to evaluation mode
Alice.eval()
Bob.eval()
Eve.eval()


def bits_to_str(bits):
    """Render a tensor of 0.0/1.0 values as a string of '0'/'1' characters."""
    # flatten() rather than squeeze() keeps the result 1-D even for a single bit.
    return "".join(str(int(b)) for b in bits.detach().cpu().flatten().tolist())


# Test the networks on a single random message
plain_text, secret_key = create_batch(pln_txt_len, sec_key_len, 1)
plain_text, secret_key = plain_text.to(device), secret_key.to(device)

# Encrypt, decrypt and eavesdrop; inference needs no gradient tracking
with torch.no_grad():
    encrypted_text = Alice.encrypt(plain_text, secret_key)
    decrypted_text = Bob.decrypt(encrypted_text, secret_key)
    eavesdrop_text = Eve.eavesdrop(encrypted_text)

plain_text     = bits_to_str(plain_text)
encrypted_text = bits_to_str(encrypted_text)
decrypted_text = bits_to_str(decrypted_text)
eavesdrop_text = bits_to_str(eavesdrop_text)

print("Plain text:  ", plain_text)
print("Cipher text: ", encrypted_text)

print("Bob's output:", decrypted_text, "Distance:", hamming_distance(plain_text, decrypted_text))
print("Eve's output:", eavesdrop_text, "Distance:", hamming_distance(plain_text, eavesdrop_text))

Plain text:   1010101100110110
Cipher text:  1010000100110001
Bob's output: 1110100100110110 Distance: 2
Eve's output: 0010100000111100 Distance: 5
