In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
# Number of ±1 entries in each plaintext message and in each key vector.
TEXT_SIZE = 16
KEY_SIZE = 16

# Mini-batch size; the training cell re-assigns the same value and derives
# steps_per_epoch from it.
batch_size = 4096

In [None]:
class AllyNetwork(nn.Module):
    """Network shared by Alice and Bob: one fully-connected layer + four 1-D convs.

    Sigmoid activations are used between layers and tanh on the output, so
    every output element lies in (-1, 1).

    Args:
        in_size: length of the input vector. Defaults to TEXT_SIZE + KEY_SIZE.
            The output length is ceil(in_size / 2) because conv2 has stride 2
            (i.e. TEXT_SIZE when TEXT_SIZE == KEY_SIZE).

    Shapes:
        input  (in_size,)    -> output (ceil(in_size / 2),)
        input  (N, in_size)  -> output (N, ceil(in_size / 2))
    """

    def __init__(self, in_size=None):
        super().__init__()
        if in_size is None:
            in_size = TEXT_SIZE + KEY_SIZE
        # Input -> hidden fully-connected layer (same width as the input).
        self.connected = nn.Linear(in_size, in_size)
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=2, kernel_size=4, stride=1)
        self.conv2 = nn.Conv1d(in_channels=2, out_channels=4, kernel_size=2, stride=2)
        self.conv3 = nn.Conv1d(in_channels=4, out_channels=4, kernel_size=1, stride=1)
        self.conv4 = nn.Conv1d(in_channels=4, out_channels=1, kernel_size=1, stride=1)

        # Sigmoid between layers, tanh on the output.
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def forward(self, x):
        unbatched = x.dim() == 1
        if unbatched:
            x = x.unsqueeze(0)               # (L,) -> (1, L)

        x = self.sigmoid(self.connected(x))  # (N, L)
        x = x.unsqueeze(1)                   # (N, 1, L): one input channel

        # Zero-pad the *activated* signal (1 left, 2 right) so the kernel-4
        # conv keeps length L, as TF "SAME" padding does. Padding before the
        # sigmoid would put sigmoid(0) = 0.5 at the borders instead of 0.
        x = F.pad(x, (1, 2))                 # (N, 1, L + 3)
        x = self.sigmoid(self.conv1(x))      # (N, 2, L)

        # Right-pad by 1 so odd L still yields ceil(L/2) outputs; for even L the
        # padded column is never read by the stride-2 conv.
        x = F.pad(x, (0, 1))                 # (N, 2, L + 1)
        x = self.sigmoid(self.conv2(x))      # (N, 4, ceil(L/2))
        x = self.sigmoid(self.conv3(x))      # (N, 4, ceil(L/2))
        x = self.tanh(self.conv4(x))         # (N, 1, ceil(L/2))

        x = x.squeeze(1)                     # (N, ceil(L/2))
        return x.squeeze(0) if unbatched else x

In [None]:
class AdversaryNetwork(nn.Module):
    """Eve's network: one fully-connected layer + four 1-D convs, length-preserving.

    In the training loop Eve receives only Alice's output and is scored against
    the original message. Sigmoid activations are used between layers and tanh
    on the output, so every output element lies in (-1, 1).

    Args:
        in_size: length of the input vector. Defaults to TEXT_SIZE. The output
            has the same length.

    Shapes:
        input  (in_size,)    -> output (in_size,)
        input  (N, in_size)  -> output (N, in_size)
    """

    def __init__(self, in_size=None):
        super().__init__()
        if in_size is None:
            in_size = TEXT_SIZE
        # Input -> hidden fully-connected layer (same width as the input).
        self.connected = nn.Linear(in_size, in_size)
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=2, kernel_size=4, stride=1)
        self.conv2 = nn.Conv1d(in_channels=2, out_channels=4, kernel_size=2, stride=1)
        self.conv3 = nn.Conv1d(in_channels=4, out_channels=4, kernel_size=1, stride=1)
        self.conv4 = nn.Conv1d(in_channels=4, out_channels=1, kernel_size=1, stride=1)

        # Sigmoid between layers, tanh on the output.
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def forward(self, x):
        unbatched = x.dim() == 1
        if unbatched:
            x = x.unsqueeze(0)               # (L,) -> (1, L)

        x = self.sigmoid(self.connected(x))  # (N, L)
        x = x.unsqueeze(1)                   # (N, 1, L): one input channel

        # Zero-pad the *activated* signal so each conv keeps length L, as TF
        # "SAME" padding does: (1, 2) for kernel 4, (0, 1) for kernel 2.
        # Padding before the sigmoid would put 0.5 at the borders instead of 0.
        x = F.pad(x, (1, 2))                 # (N, 1, L + 3)
        x = self.sigmoid(self.conv1(x))      # (N, 2, L)

        x = F.pad(x, (0, 1))                 # (N, 2, L + 1)
        x = self.sigmoid(self.conv2(x))      # (N, 4, L)
        x = self.sigmoid(self.conv3(x))      # (N, 4, L)
        x = self.tanh(self.conv4(x))         # (N, 1, L)

        x = x.squeeze(1)                     # (N, L)
        return x.squeeze(0) if unbatched else x

In [None]:
# Eves_loss = (1/batch_size)*tf.reduce_sum( tf.abs( Eve_out_message - Alice_input_message ))
def eveLoss(aliceInput,eveOutput):
  """Eve's reconstruction error, measured in wrong bits per example.

  Both arguments are expected to hold values in [-1, 1] (the messages are ±1
  vectors, Eve's output comes from tanh). They are mapped to [0, 1], the
  absolute difference is summed over the last (bit) axis, and the result is
  averaged over any leading batch axis — the TensorFlow formula quoted above.
  A 1-D pair therefore yields the plain summed distance (0 .. length).

  This wrong-bit scale is what the `(TEXT_SIZE/2 - eveLoss)` term in
  aliceBobLoss compares against; a mean-reduced L1 in [0, 1] could never reach
  that TEXT_SIZE/2 baseline.

  Args:
    aliceInput: original message, shape (L,) or (N, L).
    eveOutput: Eve's reconstruction, same shape as aliceInput.

  Returns:
    0-d tensor.
  """
  aliceBits = (aliceInput + 1) / 2
  eveBits = (eveOutput + 1) / 2
  return torch.abs(aliceBits - eveBits).sum(dim=-1).mean()


def aliceBobLoss(aliceInput, bobOutput, eveLoss):
  """Combined Alice/Bob objective: Bob's reconstruction error plus an Eve penalty.

  Bob's term is the mean absolute difference between the message and Bob's
  output after both are mapped from [-1, 1] to [0, 1].

  The Eve term is ((N/2 - eveLoss) / (N/2))^2 with N = TEXT_SIZE. It is 0
  when eveLoss equals N/2 and grows as eveLoss moves away from N/2 in either
  direction. The N/2 baseline only makes sense if eveLoss counts wrong bits
  (0 .. N) rather than giving a fraction in [0, 1].

  Args:
    aliceInput: original message (±1 entries).
    bobOutput: Bob's reconstruction, same shape as aliceInput.
    eveLoss: Eve's loss for the same message. Inside this body the name
      shadows the module-level eveLoss function; it is kept for callers.

  Returns:
    0-d tensor.
  """
  half = TEXT_SIZE / 2
  bobLoss = F.l1_loss((aliceInput + 1) / 2, (bobOutput + 1) / 2)
  eveGap = half - eveLoss
  return bobLoss + eveGap * eveGap / (half * half)



In [None]:
# input = torch.tensor([1,-1,1,-1,-1,-1,1], dtype=torch.float32)
# target = torch.tensor([1,-1,1,-1,-1,-1,-1], dtype=torch.float32)
# i = eveLoss(input,target)
# print(i,type(i))
# i = aliceBobLoss(input, target, i)
# print(i,type(i))

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook can still run end to end on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# In the training loop Alice is fed (message ‖ key), Bob (Alice's output ‖ key)
# and Eve only Alice's output. Alice and Bob share the AllyNetwork
# architecture; each party has its own Adam optimizer (default settings) so
# the parties can be stepped independently.
Alice = AllyNetwork().to(device)
Bob = AllyNetwork().to(device)
Eve = AdversaryNetwork().to(device)

optimizerAlice = optim.Adam(Alice.parameters())
optimizerBob = optim.Adam(Bob.parameters())
optimizerEve = optim.Adam(Eve.parameters())

In [None]:
import numpy as np

In [None]:

def geterateDataset(msgLen,keyLen,size):
  """Draw `size` random (message, key) pairs with entries in {-1.0, +1.0}.

  Args:
    msgLen: number of entries in each message vector.
    keyLen: number of entries in each key vector.
    size: number of pairs to draw (0 gives two empty lists).

  Returns:
    (dsMsg, dsKey): two lists of length `size`. dsMsg[i] is a float32 tensor
    of shape (msgLen,) and dsKey[i] is a float32 tensor of shape (keyLen,).
  """
  # One vectorised draw per set instead of `size` separate randint calls;
  # randint yields {0, 1}, which is then mapped to {-1, +1}.
  msgs = torch.randint(0, 2, (size, msgLen)).float() * 2 - 1
  keys = torch.randint(0, 2, (size, keyLen)).float() * 2 - 1
  return list(msgs.unbind(0)), list(keys.unbind(0))



In [None]:
epochs = 100
batch_size = 4096
sample_size = 4096*5
steps_per_epoch = sample_size // batch_size

# NOTE: the networks are called on one (message, key) pair at a time here, so
# each "step" trains on a single example and only `steps_per_epoch` pairs are
# used per epoch. TODO: feed real mini-batches of `batch_size` examples.
device = next(Alice.parameters()).device

for epoch in range(epochs):
    # Only the first `steps_per_epoch` pairs are consumed, so draw just those.
    MSG, KEY = geterateDataset(TEXT_SIZE, KEY_SIZE, steps_per_epoch)
    log_epoch = epoch % 20 == 0

    # ============ Phase 1: update Alice and Bob ============
    running_loss = 0.0
    for i in range(steps_per_epoch):
        msg = MSG[i].to(device)
        key = KEY[i].to(device)

        alice_output = Alice(torch.cat((msg, key)))
        bob_output = Bob(torch.cat((alice_output, key)))
        eve_output = Eve(alice_output)

        eve_pred_loss = eveLoss(msg, eve_output)
        alice_loss = aliceBobLoss(msg, bob_output, eve_pred_loss)

        # A single backward pass provides gradients for both Alice and Bob,
        # and both optimizers step on them.
        optimizerAlice.zero_grad()
        optimizerBob.zero_grad()
        alice_loss.backward()
        optimizerAlice.step()
        optimizerBob.step()
        # Eve's parameters also received gradients from this loss; they are
        # cleared by optimizerEve.zero_grad() before Eve's update in phase 2.

        running_loss += alice_loss.item()
        if log_epoch and i == steps_per_epoch - 1:
            print(f"epoch {epoch}: mean Alice/Bob loss {running_loss / steps_per_epoch:.4f}")
            print(msg, alice_output, bob_output, eve_output)

    # ============ Phase 2: update Eve with Alice frozen ============
    running_loss = 0.0
    for i in range(steps_per_epoch):
        msg = MSG[i].to(device)
        key = KEY[i].to(device)

        # Alice is not trained in this phase: don't build her graph, so Eve's
        # loss leaves no gradients on Alice's parameters.
        with torch.no_grad():
            alice_output = Alice(torch.cat((msg, key)))
        eve_output = Eve(alice_output)

        eve_pred_loss = eveLoss(msg, eve_output)
        optimizerEve.zero_grad()
        eve_pred_loss.backward()
        optimizerEve.step()

        running_loss += eve_pred_loss.item()
        if log_epoch and i == steps_per_epoch - 1:
            print(f"epoch {epoch}: mean Eve loss {running_loss / steps_per_epoch:.4f}")
            print(msg, alice_output, eve_output)


tensor([ 1., -1.,  1.,  1.,  1.,  1.,  1., -1.,  1., -1.,  1.,  1.,  1.,  1.,
        -1.,  1.]) tensor([0.0262, 0.0260, 0.0259, 0.0261, 0.0259, 0.0259, 0.0258, 0.0261, 0.0261,
        0.0261, 0.0259, 0.0260, 0.0261, 0.0261, 0.0260, 0.0259],
       device='cuda:0', grad_fn=<SqueezeBackward1>) tensor([0.0546, 0.0547, 0.0545, 0.0546, 0.0548, 0.0547, 0.0544, 0.0547, 0.0547,
        0.0545, 0.0547, 0.0546, 0.0547, 0.0546, 0.0546, 0.0546],
       device='cuda:0', grad_fn=<SqueezeBackward1>) tensor([-0.1451, -0.1452, -0.1452, -0.1451, -0.1452, -0.1452, -0.1452, -0.1452,
        -0.1451, -0.1452, -0.1451, -0.1452, -0.1451, -0.1452, -0.1451, -0.1453],
       device='cuda:0', grad_fn=<SqueezeBackward1>)
tensor([ 1., -1.,  1.,  1.,  1.,  1.,  1., -1.,  1., -1.,  1.,  1.,  1.,  1.,
        -1.,  1.]) tensor([0.0262, 0.0260, 0.0259, 0.0261, 0.0259, 0.0259, 0.0258, 0.0261, 0.0261,
        0.0261, 0.0259, 0.0260, 0.0261, 0.0261, 0.0260, 0.0259],
       device='cuda:0', grad_fn=<SqueezeBackward1>) te

In [None]:
# Sanity check of nn.L1Loss: with the default reduction='mean' it returns the
# mean absolute difference as a 0-d tensor. The tensor is named `demo_input`
# so it does not shadow Python's built-in `input`.
loss = nn.L1Loss()
demo_input = torch.randn(1, 2, requires_grad=True)
target = torch.randn(1, 2)
output = loss(demo_input, target)
print(output)


tensor(1.3193, grad_fn=<L1LossBackward0>)


In [None]:
# Scratch tensor: three random {0, 1} values tiled twice, converted to float32.
haha = torch.randint(0, 2, (3,)).repeat(2).float()

In [None]:
# Show the scratch tensor together with its Python type.
print(haha, type(haha))

tensor([1., 0., 1., 1., 0., 1.]) <class 'torch.Tensor'>
