In [1]:
import math
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F

from ipynb.fs.full.module import * # !pip install ipynb

In [2]:
# %%html
# <style type='text/css'>
# .CodeMirror{
# font-size: 17px;
# </style>

In [2]:
class params():
    """Plain container for every hyper-parameter used by the notebook.

    All values are fixed in ``__init__``; tweak them on the instance after
    construction if a different experiment setup is needed.
    """

    def __init__(self):
        # ---- Encoder ----
        self.encoder_act_func = 'tanh'
        self.encoder_N_layers = 2     # number of RNN layers at encoder
        self.encoder_N_neurons = 50   # number of neurons at each RNN

        # ---- Decoder ----
        self.decoder_N_layers = 2     # number of RNN layers at decoder
        self.decoder_N_neurons = 50   # number of neurons at each RNN
        # True: bi-directional decoding, False: uni-directional decoding
        self.decoder_bidirection = True
        # Attention type, one of:
        #   1. Only the last timestep (N-th)
        #   2. Merge the last outputs of forward/backward RNN
        #   3. Sum over all timesteps
        #   4. Attention mechanism with N weights (same weight for forward/backward)
        #   5. Attention mechanism with 2N weights (separate weights for forward/backward)
        # For uni-directional decoding choose decoder_bidirection=False with type 4.
        self.attention_type = 5

        # ---- Setup ----
        self.K1 = 6                          # number of bits
        self.K2 = 6                          # number of bits
        self.N_channel_use = 18              # number of channel uses
        self.input_type = 'bit_vector'       # 'bit_vector' or 'one_hot_vector'
        self.output_type = 'one_hot_vector'  # 'bit_vector' or 'one_hot_vector'
        self.decoder_info = 'None'           # 'bit_estimate', 'state_vector', 'None' for encoder input
        self.encoder_info = 'tran_symbol'    # 'tran_symbol', 'state_vector', 'None' for decoder input

        # ---- Learning parameters ----
        self.batch_size = 25_000
        self.learning_rate = 0.01
        self.use_cuda = True

In [3]:
# Depending on decoder_info, the architecture is restricted
# 1. If self.decoder_info = 'None', no restriction
# 2. If self.decoder_info = 'bit_estimate', consider only uni-directional and immediate decoding
# 3. If self.decoder_info = 'state_vector', consider only uni-directional (No need for immediate decoding)

In [3]:
# model setup
parameter = params()
use_cuda = parameter.use_cuda and torch.cuda.is_available()
if use_cuda:
    # Keep using the second GPU (index 1) when the host has one, but fall back
    # to GPU 0 on single-GPU machines, where "cuda:1" is an invalid device.
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

In [4]:
# Generate training data
# Signal power is assumed to be 1, so the noise power of each link is
# 10^(-SNR/10) and the noise standard deviation is its square root.
SNR1, SNR2 = -5, 30                         # SNR at User1 / User2 in dB
np1, np2 = 10**(-SNR1/10), 10**(-SNR2/10)   # noise power 1 / noise power 2
sigma1, sigma2 = np.sqrt(np1), np.sqrt(np2)

# Training set: tuples of (bit1, bit2, noise1, noise2).
# Other sizes that were tried: int(1e9), int(1e5).
N_train = int(1e7)  # number of training samples
# Bits are drawn first (user 1, then user 2), followed by the two noise tensors.
bit1_train, bit2_train = (
    torch.randint(0, 2, (N_train, n_bits, 1))
    for n_bits in (parameter.K1, parameter.K2)
)
noise1_train, noise2_train = (
    sigma*torch.randn((N_train, parameter.N_channel_use, 1))
    for sigma in (sigma1, sigma2)
)

# Validation
N_validation = int(1e5)

print('np1: ', np1)
print('np2: ', np2)

np1:  3.1622776601683795
np2:  0.001


In [5]:
class Twoway_coding(torch.nn.Module):
    """GRU-based encoder/decoder pair for a two-way (User 1 <-> User 2) channel.

    Each user owns an encoder GRU that emits one real symbol per channel use,
    and a decoder GRU that estimates the *other* user's message from the user's
    own message, the symbols it received and (optionally) information from its
    own encoder.

    The module no longer depends on notebook globals: the batch size and the
    device are taken from the tensors passed to ``forward`` and every setting
    is read from ``self.param``.

    Args:
        param: configuration object. Attributes read by this class:
            ``K1``, ``K2``, ``N_channel_use``, ``input_type``, ``output_type``,
            ``encoder_info``, ``decoder_info``, ``attention_type``,
            ``decoder_bidirection``, ``encoder_N_layers``, ``encoder_N_neurons``,
            ``decoder_N_layers`` and ``decoder_N_neurons``.

    Raises:
        ValueError: if ``input_type``, ``output_type``, ``encoder_info`` or
            ``decoder_info`` holds an unknown value.
        NotImplementedError: if ``decoder_info == 'bit_estimate'`` (the forward
            pass has no implementation for it).
    """

    def __init__(self, param):
        super().__init__()

        # import parameter
        self.param = param
        # Width multiplier of the decoder GRU output: 2 for bi-directional, 1 otherwise.
        self.decoder_bi = 2 if self.param.decoder_bidirection else 1
        n_uses = self.param.N_channel_use

        ### Encoder input type
        # 1. input_type (bit vector, one-hot vector) for Encoder
        if self.param.input_type == 'bit_vector':
            self.num_input1 = self.param.K1
            self.num_input2 = self.param.K2
        elif self.param.input_type == 'one_hot_vector':
            self.num_input1 = 2**self.param.K1
            self.num_input2 = 2**self.param.K2
        else:
            raise ValueError("input_type must be 'bit_vector' or 'one_hot_vector', got %r"
                             % (self.param.input_type,))

        # 2. Decoder Info (bit estimate, state vector) for Encoder
        if self.param.decoder_info == 'bit_estimate':
            # The forward pass never produces the bit estimate that would be fed
            # back, so fail early with a clear message instead of later on.
            raise NotImplementedError("decoder_info='bit_estimate' is not implemented")
        elif self.param.decoder_info == 'state_vector':
            self.num_D = self.decoder_bi * self.param.decoder_N_neurons
        elif self.param.decoder_info == 'None':
            self.num_D = 0
        else:
            raise ValueError("decoder_info must be 'bit_estimate', 'state_vector' or 'None', got %r"
                             % (self.param.decoder_info,))

        ### Decoder input type
        # 1. output_type (bits, one-hot vector) for Decoder
        #    Decoder 1 estimates User 2's message and vice versa.
        if self.param.output_type == 'bit_vector':
            self.num_output1 = self.param.K2
            self.num_output2 = self.param.K1
        elif self.param.output_type == 'one_hot_vector':
            self.num_output1 = 2**self.param.K2
            self.num_output2 = 2**self.param.K1
        else:
            raise ValueError("output_type must be 'bit_vector' or 'one_hot_vector', got %r"
                             % (self.param.output_type,))

        # 2. Encoder Info ('tran_symbol', 'state_vector', 'None') for Decoder
        if self.param.encoder_info == 'tran_symbol':
            self.num_E = 1
        elif self.param.encoder_info == 'state_vector':
            self.num_E = self.param.encoder_N_neurons
        elif self.param.encoder_info == 'None':
            self.num_E = 0
        else:
            raise ValueError("encoder_info must be 'tran_symbol', 'state_vector' or 'None', got %r"
                             % (self.param.encoder_info,))

        # encoder 1. RNN -- input: own message + 1 received symbol + decoder info
        self.encoder1_RNN = torch.nn.GRU(self.num_input1 + 1 + self.num_D, self.param.encoder_N_neurons,
                                         num_layers=self.param.encoder_N_layers,
                                         bias=True, batch_first=True, dropout=0, bidirectional=False)
        self.encoder1_linear = torch.nn.Linear(self.param.encoder_N_neurons, 1)

        # encoder 2. RNN
        self.encoder2_RNN = torch.nn.GRU(self.num_input2 + 1 + self.num_D, self.param.encoder_N_neurons,
                                         num_layers=self.param.encoder_N_layers,
                                         bias=True, batch_first=True, dropout=0, bidirectional=False)
        self.encoder2_linear = torch.nn.Linear(self.param.encoder_N_neurons, 1)

        # power weight 1 (uniform_(1, 1) initialises every entry to exactly 1)
        self.weight_power1 = torch.nn.Parameter(torch.empty(n_uses), requires_grad=True)
        self.weight_power1.data.uniform_(1.0, 1.0)
        self.weight_power1_normalized = self._normalize_weights(self.weight_power1)

        # power weight 2
        self.weight_power2 = torch.nn.Parameter(torch.empty(n_uses), requires_grad=True)
        self.weight_power2.data.uniform_(1.0, 1.0)
        self.weight_power2_normalized = self._normalize_weights(self.weight_power2)

        # decoder 1 -- input: own message + 1 received symbol + encoder info
        self.decoder1_RNN = torch.nn.GRU(self.num_input1 + 1 + self.num_E, self.param.decoder_N_neurons,
                                         num_layers=self.param.decoder_N_layers,
                                         bias=True, batch_first=True, dropout=0,
                                         bidirectional=self.param.decoder_bidirection)
        self.decoder1_linear = torch.nn.Linear(self.decoder_bi*self.param.decoder_N_neurons, self.num_output1)

        # decoder 2
        self.decoder2_RNN = torch.nn.GRU(self.num_input2 + 1 + self.num_E, self.param.decoder_N_neurons,
                                         num_layers=self.param.decoder_N_layers,
                                         bias=True, batch_first=True, dropout=0,
                                         bidirectional=self.param.decoder_bidirection)
        self.decoder2_linear = torch.nn.Linear(self.decoder_bi*self.param.decoder_N_neurons, self.num_output2)

        # attention type
        if self.param.attention_type == 5:  # bi-directional --> 2N weights (column 0: forward, column 1: backward)
            self.weight1_merge = torch.nn.Parameter(torch.empty(n_uses, 2), requires_grad=True)
            self.weight1_merge.data.uniform_(1.0, 1.0)
            self.weight1_merge_normalized_fwd = self._normalize_weights(self.weight1_merge[:, 0])
            self.weight1_merge_normalized_bwd = self._normalize_weights(self.weight1_merge[:, 1])

            self.weight2_merge = torch.nn.Parameter(torch.empty(n_uses, 2), requires_grad=True)
            self.weight2_merge.data.uniform_(1.0, 1.0)
            self.weight2_merge_normalized_fwd = self._normalize_weights(self.weight2_merge[:, 0])
            self.weight2_merge_normalized_bwd = self._normalize_weights(self.weight2_merge[:, 1])

        if self.param.attention_type == 4:  # uni-directional --> N weights
            self.weight1_merge = torch.nn.Parameter(torch.empty(n_uses), requires_grad=True)
            self.weight1_merge.data.uniform_(1.0, 1.0)
            self.weight1_merge_normalized = self._normalize_weights(self.weight1_merge)

            self.weight2_merge = torch.nn.Parameter(torch.empty(n_uses), requires_grad=True)
            self.weight2_merge.data.uniform_(1.0, 1.0)
            self.weight2_merge_normalized = self._normalize_weights(self.weight2_merge)

        # Parameters for normalization (mean and standard deviation per channel use)
        # *_batch: statistics recorded during evaluation; *_saved: statistics used
        # when ``normalization_with_saved_data`` is True.
        # User 1
        self.mean1_batch = torch.zeros(n_uses)
        self.std1_batch = torch.ones(n_uses)
        self.mean1_saved = torch.zeros(n_uses)
        self.std1_saved = torch.ones(n_uses)
        # User 2
        self.mean2_batch = torch.zeros(n_uses)
        self.std2_batch = torch.ones(n_uses)
        self.mean2_saved = torch.zeros(n_uses)
        self.std2_saved = torch.ones(n_uses)
        self.normalization_with_saved_data = False   # True: inference with saved mean/std, False: calculate mean/std

    def _normalize_weights(self, weights):
        """Return ``|weights|`` rescaled so that the squares sum to ``N_channel_use``."""
        return torch.sqrt(weights**2 * (self.param.N_channel_use) / torch.sum(weights**2))

    def _refresh_merge_weights(self, per_direction):
        """Recompute the normalized attention weights from the raw parameters.

        With ``per_direction`` the two columns of ``weight*_merge`` are
        normalized separately (forward / backward); otherwise the parameter is
        normalized as a whole.
        """
        if per_direction:
            self.weight1_merge_normalized_fwd = self._normalize_weights(self.weight1_merge[:, 0])
            self.weight1_merge_normalized_bwd = self._normalize_weights(self.weight1_merge[:, 1])
            self.weight2_merge_normalized_fwd = self._normalize_weights(self.weight2_merge[:, 0])
            self.weight2_merge_normalized_bwd = self._normalize_weights(self.weight2_merge[:, 1])
        else:
            self.weight1_merge_normalized = self._normalize_weights(self.weight1_merge)
            self.weight2_merge_normalized = self._normalize_weights(self.weight2_merge)

    def decoder_activation(self, inputs):
        """Apply the output non-linearity that matches ``output_type``.

        'bit_vector' outputs go through a sigmoid (for binary cross entropy);
        'one_hot_vector' outputs are returned untouched because the softmax is
        applied inside ``F.cross_entropy``.
        """
        if self.param.output_type == 'bit_vector':
            return torch.sigmoid(inputs)
        elif self.param.output_type == 'one_hot_vector':
            return inputs

    # Convert `bit vector' to 'one-hot vector'
    def one_hot(self, bit_vec):
        """Convert a batch of bit vectors into one-hot vectors.

        Args:
            bit_vec: tensor whose first dimension is the batch; the remaining
                dimensions are flattened into K bits per sample. Bit ``i`` has
                weight ``2**i`` (the first bit is the least significant one).

        Returns:
            int64 tensor of shape ``(batch, 2**K)`` located on ``bit_vec``'s device.
        """
        bit_vec = bit_vec.reshape(bit_vec.size(0), -1)
        n_bits = bit_vec.size(1)
        place_values = 2 ** torch.arange(0, n_bits, device=bit_vec.device)
        message_index = torch.sum(torch.mul(bit_vec, place_values), dim=1).long()
        return torch.nn.functional.one_hot(message_index, num_classes=2**n_bits)

    def normalization(self, inputs, t_idx, user_idx):
        """Standardize the encoder output of one channel use.

        * Training: use the mean/std of the current batch.
        * Evaluation with ``normalization_with_saved_data`` False: use the batch
          statistics and record them in ``mean{1,2}_batch`` / ``std{1,2}_batch``.
        * Evaluation with ``normalization_with_saved_data`` True: use the stored
          ``mean{1,2}_saved`` / ``std{1,2}_saved`` of timestep ``t_idx``.

        Args:
            inputs: encoder output for the current timestep.
            t_idx: index of the channel use.
            user_idx: 1 or 2, selects which user's statistics are used/recorded.

        Raises:
            ValueError: if saved statistics are requested for an unknown user.
        """
        if self.training or not self.normalization_with_saved_data:
            mean_batch = torch.mean(inputs)
            std_batch = torch.std(inputs)
            outputs = (inputs - mean_batch)/std_batch
            if not self.training:
                # Record the statistics (detached, they are bookkeeping only).
                if user_idx == 1:
                    self.mean1_batch[t_idx] = mean_batch.detach()
                    self.std1_batch[t_idx] = std_batch.detach()
                elif user_idx == 2:
                    self.mean2_batch[t_idx] = mean_batch.detach()
                    self.std2_batch[t_idx] = std_batch.detach()
            return outputs

        # Inference with saved statistics
        if user_idx == 1:
            return (inputs - self.mean1_saved[t_idx])/self.std1_saved[t_idx]
        elif user_idx == 2:
            return (inputs - self.mean2_saved[t_idx])/self.std2_saved[t_idx]
        raise ValueError("user_idx must be 1 or 2, got %r" % (user_idx,))

    def forward(self, b1, b2, noise1, noise2):
        """Simulate ``N_channel_use`` two-way transmissions and decode both messages.

        Args:
            b1: message bits of User 1, ``(batch, K1, 1)``.
            b2: message bits of User 2, ``(batch, K2, 1)``.
            noise1: noise added to User 1's transmitted symbols (received by
                User 2), ``(batch, N_channel_use, 1)``.
            noise2: noise added to User 2's transmitted symbols (received by
                User 1), ``(batch, N_channel_use, 1)``.

        Returns:
            ``(output1_last, output2_last)`` with shapes ``(batch, num_output1, 1)``
            and ``(batch, num_output2, 1)``: the outputs of decoder 1 (estimate of
            User 2's message) and decoder 2 (estimate of User 1's message).

        Side effects:
            Stores the transmitted symbols in ``self.x1`` / ``self.x2``
            (``(batch, N, 1)``) and refreshes the ``*_normalized`` weight attributes.

        Raises:
            NotImplementedError: for configurations without a decoding rule
                (attention types other than 4/5, or ``encoder_info='state_vector'``
                together with ``decoder_info='None'``).
        """
        n_uses = self.param.N_channel_use
        batch = b1.size(0)          # actual batch size of this call
        dev = noise1.device         # all temporaries are created next to the inputs

        # Normalize power weights
        self.weight_power1_normalized = self._normalize_weights(self.weight_power1)
        self.weight_power2_normalized = self._normalize_weights(self.weight_power2)

        # Encoder input
        if self.param.input_type == 'bit_vector':
            I1 = b1
            I2 = b2
        elif self.param.input_type == 'one_hot_vector':
            I1 = self.one_hot(b1).to(dev)
            I2 = self.one_hot(b2).to(dev)
        # Integer messages are cast explicitly to the dtype of the noise so that
        # concatenation with real-valued symbols does not rely on implicit promotion.
        if not torch.is_floating_point(I1):
            I1 = I1.to(noise1.dtype)
        if not torch.is_floating_point(I2):
            I2 = I2.to(noise2.dtype)
        I1 = I1.reshape(batch, 1, self.num_input1)
        I2 = I2.reshape(batch, 1, self.num_input2)

        x1_steps, x2_steps, y1_steps, y2_steps = [], [], [], []
        r1_steps, r2_steps = [], []
        s1_t_hidden = None   # encoder hidden states, (layers, batch, hidden) after the first step
        s2_t_hidden = None
        y1_t = y2_t = None
        D1_tmp = D2_tmp = None

        for t in range(n_uses):  # timesteps
            # Encoder
            if t == 0:
                # 1st timestep: nothing received yet, feedback slots are zero
                no_feedback = torch.zeros((batch, 1, self.num_D + 1), device=dev)
                input1_total = torch.cat([I1, no_feedback], dim=2)  # (batch, 1, num_input + num_D + 1)
                input2_total = torch.cat([I2, no_feedback], dim=2)
            elif self.param.decoder_info == 'None':
                input1_total = torch.cat([I1, y1_t], dim=2)
                input2_total = torch.cat([I2, y2_t], dim=2)
            else:
                input1_total = torch.cat([I1, y1_t, D1_tmp], dim=2)
                input2_total = torch.cat([I2, y2_t, D2_tmp], dim=2)

            # Passing hx=None on the first step is the same as omitting it.
            x1_t_after_RNN, s1_t_hidden = self.encoder1_RNN(input1_total, s1_t_hidden)  # (batch, 1, hidden)
            x1_t_tilde = torch.tanh(self.encoder1_linear(x1_t_after_RNN))
            x2_t_after_RNN, s2_t_hidden = self.encoder2_RNN(input2_total, s2_t_hidden)
            x2_t_tilde = torch.tanh(self.encoder2_linear(x2_t_after_RNN))

            # Power control layer: 1. Normalization, 2. Power allocation
            x1_t_norm = self.normalization(x1_t_tilde, t, 1).view(batch, 1, 1)
            x1_t = x1_t_norm * self.weight_power1_normalized[t]
            x2_t_norm = self.normalization(x2_t_tilde, t, 2).view(batch, 1, 1)
            x2_t = x2_t_norm * self.weight_power2_normalized[t]

            # Forward transmission (from User 1 to 2)
            y2_t = x1_t + noise1[:, t, :].view(batch, 1, 1)

            # Backward transmission (from User 2 to 1)
            y1_t = x2_t + noise2[:, t, :].view(batch, 1, 1)

            # Collect values along time t (concatenated once after the loop)
            x1_steps.append(x1_t)
            x2_steps.append(x2_t)
            y1_steps.append(y1_t)
            y2_steps.append(y2_t)

            # Encoder info updates
            if self.param.encoder_info == 'tran_symbol':
                E1_tmp = x1_t  # (batch,1,1)
                E2_tmp = x2_t  # (batch,1,1)
            elif self.param.encoder_info == 'state_vector':
                E1_tmp = x1_t_after_RNN  # (batch, 1, hidden)
                E2_tmp = x2_t_after_RNN  # (batch, 1, hidden)
            elif self.param.encoder_info == 'None':
                E1_tmp = None
                E2_tmp = None

            ########################################
            # Immediate Decoding
            if self.param.decoder_info != 'None':
                # Encoder uses decoder info
                # --> There is connection from decoder to encoder
                # --> Immediate decoding is needed!
                if self.param.encoder_info == 'None':  # Decoder does not use encoder info
                    decoder_input1 = torch.cat([I1, y1_t], dim=2)  # (batch, 1, num_input + 1)
                    decoder_input2 = torch.cat([I2, y2_t], dim=2)
                else:
                    decoder_input1 = torch.cat([I1, y1_t, E1_tmp], dim=2)  # (batch, 1, num_input + 1 + num_E)
                    decoder_input2 = torch.cat([I2, y2_t, E2_tmp], dim=2)

                # NOTE(review): the decoder GRUs are called without the previous
                # hidden state, so each step starts from a zero state -- confirm
                # that this is intended.
                r1_t_last, _ = self.decoder1_RNN(decoder_input1)  # (batch, 1, bi*hidden)
                r2_t_last, _ = self.decoder2_RNN(decoder_input2)

                if self.param.decoder_info == 'state_vector':  # No output calculation required
                    D1_tmp = r1_t_last
                    D2_tmp = r2_t_last
                    r1_steps.append(r1_t_last)
                    r2_steps.append(r2_t_last)

        x1_total = torch.cat(x1_steps, dim=1)  # (batch, N, 1)
        x2_total = torch.cat(x2_steps, dim=1)
        y1_total = torch.cat(y1_steps, dim=1)
        y2_total = torch.cat(y2_steps, dim=1)

        output1_last = None
        output2_last = None

        # Decoder does inference after N transmissions are conducted!
        if self.param.decoder_info == 'state_vector':
            # Only uni-directional decoding is possible
            # (expects the N-weight attention parameters of attention_type 4).

            # Normalize attention weights (uni-directional attention weights)
            self._refresh_merge_weights(per_direction=False)
            r1_hidden = torch.cat(r1_steps, dim=1)  # (batch, N, hidden)
            r2_hidden = torch.cat(r2_steps, dim=1)

            # Multiply attention weights
            r1_merge = torch.tensordot(r1_hidden, self.weight1_merge_normalized, dims=([1], [0]))  # (batch, hidden)
            output1 = self.decoder_activation(self.decoder1_linear(r1_merge))
            output1_last = output1.view(batch, -1, 1)  # (batch, num_output, 1)

            r2_merge = torch.tensordot(r2_hidden, self.weight2_merge_normalized, dims=([1], [0]))  # (batch, hidden)
            output2 = self.decoder_activation(self.decoder2_linear(r2_merge))
            output2_last = output2.view(batch, -1, 1)  # (batch, num_output, 1)

        # Non-immediate Decoding
        if self.param.decoder_info == 'None':  # No connection from decoder to encoder
            # Normalize attention weights
            if self.param.attention_type == 4:
                self._refresh_merge_weights(per_direction=False)
            if self.param.attention_type == 5:
                self._refresh_merge_weights(per_direction=True)

            I1_copy = I1.repeat(1, n_uses, 1)  # (batch, N, num_input1)
            I2_copy = I2.repeat(1, n_uses, 1)  # (batch, N, num_input2)
            if self.param.encoder_info == 'None':
                decoder1_input = torch.cat([I1_copy, y1_total], dim=2)  # (batch, N, num_input+1)
                decoder2_input = torch.cat([I2_copy, y2_total], dim=2)
            elif self.param.encoder_info == 'tran_symbol':
                decoder1_input = torch.cat([I1_copy, x1_total, y1_total], dim=2)  # (batch, N, num_input+2)
                decoder2_input = torch.cat([I2_copy, x2_total, y2_total], dim=2)
            else:
                raise NotImplementedError(
                    "encoder_info=%r is not supported together with decoder_info='None'"
                    % (self.param.encoder_info,))

            r1_hidden, _ = self.decoder1_RNN(decoder1_input)  # (batch, N, bi*hidden_size)
            r2_hidden, _ = self.decoder2_RNN(decoder2_input)  # (batch, N, bi*hidden_size)

            # Options 1-3 (last timestep / merged last outputs / sum over timesteps)
            # are listed in the configuration but have no implementation here.

            # Option 4. Attention mechanism (N weights)
            if self.param.attention_type == 4:
                r1_concat = torch.tensordot(r1_hidden, self.weight1_merge_normalized, dims=([1], [0]))
                output1 = self.decoder_activation(self.decoder1_linear(r1_concat))
                output1_last = output1.view(batch, -1, 1)  # (batch,num_output,1)

                r2_concat = torch.tensordot(r2_hidden, self.weight2_merge_normalized, dims=([1], [0]))
                output2 = self.decoder_activation(self.decoder2_linear(r2_concat))
                output2_last = output2.view(batch, -1, 1)  # (batch,num_output,1)

            # Option 5. Attention mechanism (2N weights) for forward/backward
            if self.param.attention_type == 5:
                hidden = self.param.decoder_N_neurons
                # First half of the feature axis: forward GRU, second half: backward GRU.
                r1_forward_weighted_sum = torch.tensordot(
                    r1_hidden[:, :, :hidden], self.weight1_merge_normalized_fwd, dims=([1], [0]))  # (batch,hidden_size)
                r1_backward_weighted_sum = torch.tensordot(
                    r1_hidden[:, :, hidden:], self.weight1_merge_normalized_bwd, dims=([1], [0]))  # (batch,hidden_size)
                r1_concat = torch.cat([r1_forward_weighted_sum, r1_backward_weighted_sum], dim=1)
                output1 = self.decoder_activation(self.decoder1_linear(r1_concat))
                output1_last = output1.view(batch, -1, 1)  # (batch,num_output,1)

                r2_forward_weighted_sum = torch.tensordot(
                    r2_hidden[:, :, :hidden], self.weight2_merge_normalized_fwd, dims=([1], [0]))
                r2_backward_weighted_sum = torch.tensordot(
                    r2_hidden[:, :, hidden:], self.weight2_merge_normalized_bwd, dims=([1], [0]))
                r2_concat = torch.cat([r2_forward_weighted_sum, r2_backward_weighted_sum], dim=1)
                output2 = self.decoder_activation(self.decoder2_linear(r2_concat))
                output2_last = output2.view(batch, -1, 1)  # (batch,num_output,1)

        if output1_last is None or output2_last is None:
            raise NotImplementedError(
                "no decoding rule for decoder_info=%r with attention_type=%r"
                % (self.param.decoder_info, self.param.attention_type))

        self.x1 = x1_total  # (batch,N,1)
        self.x2 = x2_total  # (batch,N,1)

        return output1_last, output2_last

In [6]:
# Convert the `bit vector' with (batch,K,1) to 'one hot vector' with (batch,2^K)
def one_hot(bit_vec):
    """Map each K-bit message to a one-hot vector of length 2**K.

    Bit ``i`` of a message carries weight ``2**i`` (the first bit is the least
    significant one), and the resulting integer is the position of the 1.

    The computation runs on ``bit_vec``'s own device (no dependence on a global
    ``device``) and is vectorised instead of looping over the batch.

    Args:
        bit_vec: tensor of shape ``(batch, K, 1)`` holding 0/1 values.

    Returns:
        int64 tensor of shape ``(batch, 2**K)`` on the CPU.
    """
    bit_vec = bit_vec.squeeze(-1)  # (batch, K)
    n_bits = bit_vec.size(1)

    place_values = 2 ** torch.arange(0, n_bits, device=bit_vec.device)  # (K,)
    message_index = torch.sum(torch.mul(bit_vec, place_values), dim=1).long()  # (batch,)
    b_onehot = torch.nn.functional.one_hot(message_index, num_classes=2**n_bits)
    # The result has always been a CPU tensor; keep that for existing callers.
    return b_onehot.cpu()

In [7]:
# Test
def test_RNN(N_test):
    """Measure error rates and transmit power of the global `model` on random test data.

    Draws N_test random messages per user plus Gaussian noise scaled by the
    globals `sigma1` / `sigma2`, then evaluates them in chunks of
    `parameter.batch_size`. The model is switched to eval mode and left there.

    Args:
        N_test: number of test samples; should be a multiple of
            parameter.batch_size (a trailing partial batch is not evaluated).

    Returns:
        (ber1, ber2, bler1, bler2, power1_avg, power2_avg). The first four are
        the two values returned by the error-rate helper for each user,
        averaged over the batches. The last two are the squared `model.x1` /
        `model.x2`, averaged over the batches, as numpy arrays of shape
        (batch_size, N_channel_use, 1).

    Raises:
        ValueError: if N_test is smaller than one batch, or if
            parameter.output_type is neither 'bit_vector' nor 'one_hot_vector'.
    """
    batch = parameter.batch_size
    n_uses = parameter.N_channel_use

    N_iter = N_test // batch  # N_test should be multiple of batch_size
    if N_iter == 0:
        # Without this guard the averages below would divide by zero.
        raise ValueError('N_test ({}) must be at least parameter.batch_size ({})'.format(N_test, batch))

    # Pick the error metric once instead of re-testing the output type per batch.
    if parameter.output_type == 'bit_vector':
        error_rate = error_rate_bitvector
    elif parameter.output_type == 'one_hot_vector':
        error_rate = error_rate_onehot
    else:
        raise ValueError('Unsupported output_type: {!r}'.format(parameter.output_type))

    # Generate test data
    bit1_test   = torch.randint(0, 2, (N_test, parameter.K1, 1))
    bit2_test   = torch.randint(0, 2, (N_test, parameter.K2, 1))
    noise1_test = sigma1*torch.randn((N_test, n_uses, 1))
    noise2_test = sigma2*torch.randn((N_test, n_uses, 1))

    model.eval() # model.training becomes False
    ber1 = ber2 = 0
    bler1 = bler2 = 0
    power1_acc = np.zeros((batch, n_uses, 1))
    power2_acc = np.zeros((batch, n_uses, 1))
    with torch.no_grad():
        for i in range(N_iter):
            # Every slice is a full batch because N_iter = N_test // batch,
            # so the tensors already have shape (batch, K, 1) / (batch, N, 1).
            rows = slice(batch*i, batch*(i+1))
            bit1 = bit1_test[rows].to(device)
            bit2 = bit2_test[rows].to(device)
            noise1 = noise1_test[rows].to(device)
            noise2 = noise2_test[rows].to(device)

            # Forward pass: the first returned tensor is scored against bit2,
            # the second against bit1.
            X2_hat, X1_hat = model(bit1, bit2, noise1, noise2)

            ber1_tmp, bler1_tmp = error_rate(X1_hat.cpu(), bit1.cpu())
            ber2_tmp, bler2_tmp = error_rate(X2_hat.cpu(), bit2.cpu())

            ber1 = ber1 + ber1_tmp
            ber2 = ber2 + ber2_tmp
            bler1 = bler1 + bler1_tmp
            bler2 = bler2 + bler2_tmp

            # Power: accumulate the squared signals stored on the model by forward()
            power1_acc += model.x1.cpu().detach().numpy()**2
            power2_acc += model.x2.cpu().detach().numpy()**2

        ber1  = ber1/N_iter
        ber2  = ber2/N_iter
        bler1 = bler1/N_iter
        bler2 = bler2/N_iter
        power1_avg = power1_acc/N_iter
        power2_avg = power2_acc/N_iter

    return ber1, ber2, bler1, bler2, power1_avg, power2_avg

In [8]:
# Build the two-way coding network; it is moved to `device` only when CUDA is in use.
model = Twoway_coding(parameter)
if use_cuda:
    model = model.to(device)

print(model)

# Hand Adam only the parameters that require gradients.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=parameter.learning_rate)
# With step_size=1, every scheduler.step() call multiplies the learning rate by gamma.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

Twoway_coding(
  (encoder1_RNN): GRU(7, 50, num_layers=2, batch_first=True)
  (encoder1_linear): Linear(in_features=50, out_features=1, bias=True)
  (encoder2_RNN): GRU(7, 50, num_layers=2, batch_first=True)
  (encoder2_linear): Linear(in_features=50, out_features=1, bias=True)
  (decoder1_RNN): GRU(8, 50, num_layers=2, batch_first=True, bidirectional=True)
  (decoder1_linear): Linear(in_features=100, out_features=64, bias=True)
  (decoder2_RNN): GRU(8, 50, num_layers=2, batch_first=True, bidirectional=True)
  (decoder2_linear): Linear(in_features=100, out_features=64, bias=True)
)


In [None]:
# Training
num_epoch = 100
clipping_value = 1  # max gradient norm passed to clip_grad_norm_


def _print_learned_weights():
    """Print the model's normalized power weights and, for attention types 4 and 5, its merge weights."""
    print('weight_power1: ', model.weight_power1_normalized.cpu().detach().numpy().round(3))
    print('weight_power2: ', model.weight_power2_normalized.cpu().detach().numpy().round(3))
    if parameter.attention_type==4:
        print('weight1_merge: ', model.weight1_merge_normalized.cpu().detach().numpy().round(3))
        print('weight2_merge: ', model.weight2_merge_normalized.cpu().detach().numpy().round(3))
    if parameter.attention_type==5:
        print('weight1_merge_fwd: ', model.weight1_merge_normalized_fwd.cpu().detach().numpy().round(3))
        print('weight1_merge_bwd: ', model.weight1_merge_normalized_bwd.cpu().detach().numpy().round(3))
        print('weight2_merge_fwd: ', model.weight2_merge_normalized_fwd.cpu().detach().numpy().round(3))
        print('weight2_merge_bwd: ', model.weight2_merge_normalized_bwd.cpu().detach().numpy().round(3))
    print()


print('Before training ')
_print_learned_weights()

# Number of full mini-batches per epoch; it does not change between epochs,
# so it is computed once. A trailing partial batch is not used.
N_iter = (N_train//parameter.batch_size)

for epoch in range(num_epoch):

    model.train() # model.training becomes True
    loss_training = 0

    for i in range(N_iter):
        bit1 = bit1_train[parameter.batch_size*i:parameter.batch_size*(i+1),:,:].view(parameter.batch_size, parameter.K1,1)
        bit2 = bit2_train[parameter.batch_size*i:parameter.batch_size*(i+1),:,:].view(parameter.batch_size, parameter.K2,1)
        noise1 = noise1_train[parameter.batch_size*i:parameter.batch_size*(i+1),:,:].view(parameter.batch_size, parameter.N_channel_use,1)
        noise2 = noise2_train[parameter.batch_size*i:parameter.batch_size*(i+1),:,:].view(parameter.batch_size, parameter.N_channel_use,1)

        bit1   = bit1.to(device)
        bit2   = bit2.to(device)
        noise1 = noise1.to(device)
        noise2 = noise2.to(device)

        # forward pass: the first returned tensor is trained against bit2,
        # the second against bit1
        optimizer.zero_grad()
        X2_hat, X1_hat = model(bit1, bit2, noise1, noise2)

        # Define loss according to output type
        if parameter.output_type == 'bit_vector':
            bit1 = bit1.type(torch.float32)
            bit2 = bit2.type(torch.float32)
            loss = F.binary_cross_entropy(X1_hat, bit1) + F.binary_cross_entropy(X2_hat, bit2)
        elif parameter.output_type == 'one_hot_vector':
            bit1_hot =  one_hot(bit1).view(parameter.batch_size, 2**parameter.K1, 1) # (batch,2^K,1)
            bit2_hot =  one_hot(bit2).view(parameter.batch_size, 2**parameter.K2, 1)
            # cross_entropy takes class indices, recovered here as the argmax of the one-hot rows
            loss = (F.cross_entropy(X1_hat.squeeze(-1), torch.argmax(bit1_hot,dim=1).squeeze(-1).to(device)) # (batch,2^K), (batch)
                    + F.cross_entropy(X2_hat.squeeze(-1), torch.argmax(bit2_hot,dim=1).squeeze(-1).to(device)))
        else:
            # Fail with a clear message instead of an undefined `loss` further down.
            raise ValueError('Unsupported output_type: {!r}'.format(parameter.output_type))

        # training
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clipping_value)
        loss_training += loss.item()
        optimizer.step()

        if i % 100 == 0:
            print('Epoch: {}, Iter: {} out of {}, Loss: {:.4f}'.format(epoch, i, N_iter, loss.item()))

    # Summary of each epoch
    print('Summary: Epoch: {}, lr: {}, Average loss: {:.4f}'.format(epoch, optimizer.param_groups[0]['lr'], loss_training/N_iter) )

    scheduler.step() # reduce learning rate

    _print_learned_weights()

    # Validation
    ber1_val, ber2_val, bler1_val, bler2_val, _, _ = test_RNN(N_validation)
    print('Ber1:  ', float(ber1_val))
    print('Ber2:  ', float(ber2_val))
    print('Bler1: ', float(bler1_val))
    print('Bler2: ', float(bler2_val))
    print()

Before training 
weight_power1:  [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
weight_power2:  [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
weight1_merge_fwd:  [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
weight1_merge_bwd:  [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
weight2_merge_fwd:  [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
weight2_merge_bwd:  [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]

Epoch: 0, Iter: 0 out of 400, Loss: 10.9219
Epoch: 0, Iter: 100 out of 400, Loss: 3.1653
Epoch: 0, Iter: 200 out of 400, Loss: 2.8243
Epoch: 0, Iter: 300 out of 400, Loss: 2.5596
Summary: Epoch: 0, lr: 0.01, Average loss: 3.1404
weight_power1:  [1.606 1.272 1.098 1.056 1.025 0.938 0.983 0.929 0.865 0.894 0.885 0.846
 0.864 0.868 0.778 0.731 0.9   1.108]
weight_power2:  [1.65  1.17  0.972 1.003 1.038 0.983 0.943 0.931 0.895 0.819 0.771 0.767
 0.783 0.787 0.779 0.847 1.048 1.36 ]
weight1_merge_fwd:  [0.896 1.137 1.351 1.476 1.493 

Ber1:   0.23132501542568207
Ber2:   3.166666647302918e-05
Bler1:  0.5386999845504761
Bler2:  4.999999873689376e-05

Epoch: 7, Iter: 0 out of 400, Loss: 1.7729
Epoch: 7, Iter: 100 out of 400, Loss: 1.7354
Epoch: 7, Iter: 200 out of 400, Loss: 1.7283
Epoch: 7, Iter: 300 out of 400, Loss: 1.7455
Summary: Epoch: 7, lr: 0.006983372960937498, Average loss: 1.7479
weight_power1:  [0.99  1.009 1.008 1.003 1.013 1.013 1.004 1.009 0.999 0.999 1.007 0.996
 0.993 0.996 0.996 0.989 0.989 0.987]
weight_power2:  [1.356 0.67  1.035 0.992 0.914 1.034 1.056 0.997 1.014 1.023 0.998 1.017
 1.014 1.024 1.002 0.99  1.025 0.655]
weight1_merge_fwd:  [1.023 1.833 2.06  1.889 1.665 1.275 0.901 0.652 0.367 0.075 0.047 0.012
 0.007 0.005 0.007 0.032 0.043 0.026]
weight1_merge_bwd:  [0.049 0.05  0.174 0.489 0.952 1.398 1.87  2.199 1.691 1.325 1.008 0.606
 0.164 0.    0.    0.    0.003 0.71 ]
weight2_merge_fwd:  [8.230e-01 1.349e+00 2.456e+00 2.903e+00 1.014e+00 4.500e-02 5.000e-03
 5.000e-03 7.100e-02 5.300e-02 4.

In [None]:
# Calculate mean/var with training data
# The per-batch statistics that the model exposes after each forward pass are
# averaged over all full training batches.
# NOTE(review): the saved std is the mean of the per-batch stds, not a std
# pooled over the whole training set - confirm this is what the model expects.
model.eval()   # model.training becomes False
N_iter = N_train//parameter.batch_size

# Running sums of the per-batch statistics. The results (mean*_train /
# std*_train) are assigned once after the loop, so they need no placeholders.
mean1_total = 0
std1_total = 0
mean2_total = 0
std2_total = 0

with torch.no_grad():
    for i in range(N_iter):
        bit1   = bit1_train[parameter.batch_size*i:parameter.batch_size*(i+1),:,:].view(parameter.batch_size, parameter.K1,1)
        bit2   = bit2_train[parameter.batch_size*i:parameter.batch_size*(i+1),:,:].view(parameter.batch_size, parameter.K2,1)
        noise1 = noise1_train[parameter.batch_size*i:parameter.batch_size*(i+1),:,:].view(parameter.batch_size, parameter.N_channel_use,1)
        noise2 = noise2_train[parameter.batch_size*i:parameter.batch_size*(i+1),:,:].view(parameter.batch_size, parameter.N_channel_use,1)

        bit1   = bit1.to(device)
        bit2   = bit2.to(device)
        noise1 = noise1.to(device)
        noise2 = noise2.to(device)

        # The forward pass is run for the batch statistics read from the model
        # below; the decoder outputs themselves are not used in this cell.
        X2_hat, X1_hat = model(bit1, bit2, noise1, noise2)
        mean1_total += model.mean1_batch
        std1_total  += model.std1_batch
        mean2_total += model.mean2_batch
        std2_total  += model.std2_batch
        if i%100==0: print(i)

mean1_train = mean1_total/N_iter
std1_train = std1_total/N_iter
mean2_train = mean2_total/N_iter
std2_train = std2_total/N_iter
print('Mean1: ',mean1_train)
print('std1 : ',std1_train)
print('Mean2: ',mean2_train)
print('std2 : ',std2_train)

In [None]:
######## Save model
import os  # cell-local import: only this cell needs it

# save_results_to = ('saved_model/'+ 'diff_rates/' +
#                      'K1_K2_N(3_3_9)/bi_sigmoid/SNR1(1dB)SNR2(30dB)/train(1e9)/')
save_results_to = 'saved_model/'+ 'diff_rates/K1_K2_N(6_6_18)/bi_softmax/SNR1(-5dB)SNR2(30dB)/'
# save_results_to = 'saved_model/'+ 'diff_rates/K1_K2_N(3_3_9)/bi_sigmoid/SNR1(-10dB)SNR2(-10dB)/'

# torch.save does not create missing directories, so make sure the nested
# target folder exists before writing anything into it.
os.makedirs(save_results_to, exist_ok=True)

torch.save(model.state_dict(), save_results_to+'model.pth')
# np2 = sigma2_train**2
# save_results_to = 'saved_model/'+ 'np2_'+ str(np2)+'/'

####### Save normalization weights
torch.save(mean1_train, save_results_to+'mean1_train.pt')
torch.save(std1_train, save_results_to+'std1_train.pt')
torch.save(mean2_train, save_results_to+'mean2_train.pt')
torch.save(std2_train, save_results_to+'std2_train.pt')

# Echo the directory that was written to.
save_results_to

In [None]:
# Display the directory path currently held in `save_results_to`.
save_results_to

In [None]:
# Inference stage
# The evaluation is split into N_iter calls of test_RNN(N_small) so that a very
# large N_inference never has to be generated at once.
# N_inference = int(4e8)
# N_small = int(1e5) # In case that N_inference is very large, we divide into small chunks
N_inference = int(1e10)
N_small = int(1e6)
N_iter  = N_inference//N_small

# Give the model the statistics computed on the training data.
# NOTE(review): presumably this flag makes forward() normalize with the *_saved
# values instead of per-batch statistics - confirm in Twoway_coding.
model.normalization_with_saved_data = True
model.mean1_saved, model.std1_saved = mean1_train, std1_train
model.mean2_saved, model.std2_saved = mean2_train, std2_train

# Accumulators over chunks
ber1_sum = ber2_sum = 0
bler1_sum = bler2_sum = 0
power_shape = (parameter.batch_size, parameter.N_channel_use ,1)
power1_sum = np.zeros(power_shape)
power2_sum = np.zeros(power_shape)


def _print_running_averages(n_chunks):
    """Print the error rates and powers accumulated so far, averaged over n_chunks chunks."""
    print('Ber1:  ', float(ber1_sum/n_chunks))
    print('Ber2:  ', float(ber2_sum/n_chunks))
    print('Bler1: ', float(bler1_sum/n_chunks))
    print('Bler2: ', float(bler2_sum/n_chunks))
    print('Power1: ', round(np.sum(power1_sum)/(parameter.batch_size*n_chunks),3))
    print('Power2: ', round(np.sum(power2_sum)/(parameter.batch_size*n_chunks),3))


for ii in range(N_iter):
    ber1_tmp, ber2_tmp, bler1_tmp, bler2_tmp, power1_tmp, power2_tmp = test_RNN(N_small)

    ber1_sum += ber1_tmp
    bler1_sum += bler1_tmp
    power1_sum += power1_tmp # (batch, N, 1)

    ber2_sum += ber2_tmp
    bler2_sum += bler2_tmp
    power2_sum += power2_tmp

    # Progress report every 100 chunks
    if ii%100==0:
        print('Iter: {} out of {}'.format(ii, N_iter))
        _print_running_averages(ii+1)

# Final averages over all chunks
ber1_inference  = ber1_sum/N_iter
ber2_inference  = ber2_sum/N_iter
bler1_inference = bler1_sum/N_iter
bler2_inference = bler2_sum/N_iter

print()
_print_running_averages(N_iter)


In [None]:
# Re-print the running averages held in the inference accumulators for the last
# completed chunk index `ii`.
# NOTE(review): presumably used to inspect progress after stopping the inference loop early - confirm.
print('Iter: {} out of {}'.format(ii, N_iter))
n_done = ii+1
for label, total in (('Ber1:  ', ber1_sum), ('Ber2:  ', ber2_sum),
                     ('Bler1: ', bler1_sum), ('Bler2: ', bler2_sum)):
    print(label, float(total/n_done))
for label, total in (('Power1: ', power1_sum), ('Power2: ', power2_sum)):
    print(label, round(np.sum(total)/(parameter.batch_size*n_done),3))

In [17]:
# END

'saved_model/diff_rates/K1_K2_N(6_6_18)/bi_sigmoid/SNR1(-1dB)SNR2(5dB)/'

In [9]:
######## Recall model
# NOTE(review): the path below names a K1_K2_N(3_3_9) checkpoint; `model` must
# have been built with matching parameters for load_state_dict to succeed - confirm.
# save_results_to = 'saved_model/'+ 'diff_rates/K1_K2_N(3_3_9)/bi_softmax/SNR1(5dB)SNR2(20dB)/'
# save_results_to = 'saved_model/'+ 'diff_rates/K1_K2_N(3_3_9)/bi_sigmoid/SNR1(1dB)SNR2(30dB)/train(1e8)/'
save_results_to = ('saved_model/'+ 'diff_rates/' +
                     'K1_K2_N(3_3_9)/bi_sigmoid/SNR1(1dB)SNR2(20dB)/train(1e8)/')

# map_location remaps the stored tensors onto the current `device`, so a
# checkpoint written on a GPU can also be loaded on a machine without it.
model.load_state_dict(torch.load(save_results_to+'model.pth', map_location=device))
model.eval()


####### Load normalization weights
mean1_train = torch.load(save_results_to+'mean1_train.pt', map_location=device)
std1_train = torch.load(save_results_to+'std1_train.pt', map_location=device)
mean2_train = torch.load(save_results_to+'mean2_train.pt', map_location=device)
std2_train = torch.load(save_results_to+'std2_train.pt', map_location=device)