In [1]:
import math
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F

from ipynb.fs.full.module import * # !pip install ipynb

In [3]:
# %%html
# <style type='text/css'>
# .CodeMirror{
# font-size: 17px;
# }
# </style>

In [2]:
class params:
    """Container for every hyper-parameter used by the notebook.

    All values are fixed in ``__init__``; edit them here to change the
    experiment. The attributes are read by the model class and by the
    training / evaluation cells.
    """

    def __init__(self):

        # ---- Encoder ----
        self.encoder_act_func = 'tanh'
        self.encoder_N_layers: int = 2     # stacked RNN layers in the encoder
        self.encoder_N_neurons: int = 50   # hidden units per encoder RNN layer

        # ---- Decoder ----
        self.decoder_N_layers: int = 2     # stacked RNN layers in the decoder
        self.decoder_N_neurons: int = 50   # hidden units per decoder RNN layer
        self.decoder_bidirection = True    # True: bi-directional decoder, False: uni-directional
        # How the decoder RNN outputs are merged before the final linear layer:
        #   1. Only the last timestep (N-th)
        #   2. Merge the last outputs of the forward/backward RNN
        #   3. Sum over all timesteps
        #   4. Attention with N weights (shared by forward/backward)
        #   5. Attention with 2N weights (separate for forward/backward)
        self.attention_type: int = 5

        # ---- Setup ----
        self.N_bits: int = 3               # message length in bits
        self.N_channel_use = 9             # number of channel uses
        self.input_type = 'bit_vector'     # 'bit_vector' or 'one_hot_vector'
        self.output_type = 'bit_vector'    # 'bit_vector' or 'one_hot_vector'

        # ---- Learning ----
        self.batch_size = int(2.5e5)
        self.learning_rate = 0.01
        self.use_cuda = True

In [3]:
# model setup
parameter = params()
# The GPU is used only when it is both requested and actually available.
use_cuda = parameter.use_cuda and torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")

In [4]:
# Generate training data
SNR1 = 15              # forward SNR in dB
np1 = 10**(-SNR1/10)   # noise power1 -- Assuming signal power is set to 1
# Keep sigma1 a plain Python float: multiplying a tensor by a Python scalar
# always preserves the tensor dtype (float32 here), whereas the result of
# `numpy_scalar * tensor` depends on NumPy/PyTorch operator dispatch.
sigma1 = math.sqrt(np1)

# np2_dB = -10          # noise power2 in dB
# np2 = 10**(np2_dB/10)
# sigma2 = math.sqrt(np2)
sigma2 = 0             # noiseless feedback link

# Training set: tuples of (stream, noise1, noise2)
# N_train = int(1e7)  # number of training set
N_train = int(1e8)  # number of training set
bits_train   = torch.randint(0, 2, (N_train, parameter.N_bits, 1))
noise1_train = sigma1*torch.randn((N_train, parameter.N_channel_use, 1))
noise2_train = sigma2*torch.randn((N_train, parameter.N_channel_use, 1))

# Validation
N_validation = int(1e5)

print('np1: ', np1)
# print('np2: ', np2)

np1:  0.03162277660168379


In [5]:
class Feedback_Code(torch.nn.Module):
    """GRU encoder / GRU decoder pair trained end-to-end over a noisy channel.

    The encoder emits one real symbol per channel use. Each symbol is
    normalized, scaled by a learned per-timestep power weight and corrupted
    by additive noise; the decoder RNN reads the whole received sequence and
    its outputs are merged according to ``param.attention_type`` (1-5).

    The module reads everything it needs from ``param`` and from its inputs:
    it does not depend on notebook-level globals, and ``forward`` works for
    any batch size (``param.batch_size`` is not used here).

    Args:
        param: object exposing ``N_bits``, ``N_channel_use``, ``input_type``,
            ``output_type``, ``attention_type``, ``decoder_bidirection`` and
            the ``encoder_*`` / ``decoder_*`` layer sizes. Note that
            ``encoder_act_func`` is not read; the encoder always applies tanh.

    Raises:
        ValueError: if ``input_type`` / ``output_type`` is neither
            ``'bit_vector'`` nor ``'one_hot_vector'``.
    """

    def __init__(self, param):
        super().__init__()

        # import parameter
        self.param = param
        # Number of decoder directions; follows the same truthiness the GRU
        # applies to its `bidirectional` argument.
        self.decoder_bi = 2 if self.param.decoder_bidirection else 1

        # input_type / output_type (bit vector, one-hot vector)
        self.num_input = self._vector_size(self.param.input_type, 'input_type')
        self.num_output = self._vector_size(self.param.output_type, 'output_type')

        # encoder RNN: input is the message plus one fed-back symbol
        self.encoder_RNN = torch.nn.GRU(self.num_input + 1, self.param.encoder_N_neurons,
                                        num_layers=self.param.encoder_N_layers,
                                        bias=True, batch_first=True, dropout=0, bidirectional=False)
        self.encoder_linear = torch.nn.Linear(self.param.encoder_N_neurons, 1)

        # power weights (initialised to all ones)
        self.weight_power = torch.nn.Parameter(torch.ones(self.param.N_channel_use), requires_grad=True)
        self.weight_power_normalized = self._normalize_weights(self.weight_power)

        # decoder
        self.decoder_RNN = torch.nn.GRU(1, self.param.decoder_N_neurons,
                                        num_layers=self.param.decoder_N_layers,
                                        bias=True, batch_first=True, dropout=0,
                                        bidirectional=self.param.decoder_bidirection)
        self.decoder_linear = torch.nn.Linear(self.decoder_bi*self.param.decoder_N_neurons, self.num_output)

        # attention weights (initialised to all ones)
        if self.param.attention_type == 5:  # separate forward/backward weights --> 2N weights
            self.weight_merge = torch.nn.Parameter(torch.ones(self.param.N_channel_use, 2), requires_grad=True)
            self.weight_merge_normalized_fwd = self._normalize_weights(self.weight_merge[:, 0])
            self.weight_merge_normalized_bwd = self._normalize_weights(self.weight_merge[:, 1])

        if self.param.attention_type == 4:  # shared weights --> N weights
            self.weight_merge = torch.nn.Parameter(torch.ones(self.param.N_channel_use), requires_grad=True)
            self.weight_merge_normalized = self._normalize_weights(self.weight_merge)

        # Parameters for normalization (mean and standard deviation per timestep)
        self.mean_batch = torch.zeros(self.param.N_channel_use)
        self.std_batch = torch.ones(self.param.N_channel_use)
        self.mean_saved = torch.zeros(self.param.N_channel_use)
        self.std_saved = torch.ones(self.param.N_channel_use)
        self.normalization_with_saved_data = False   # True: inference with saved mean/std, False: calculate mean/std

    def _vector_size(self, vector_type, field_name):
        """Return the vector length for a 'bit_vector' or 'one_hot_vector' representation."""
        if vector_type == 'bit_vector':
            return self.param.N_bits
        if vector_type == 'one_hot_vector':
            return 2**self.param.N_bits
        raise ValueError("{} must be 'bit_vector' or 'one_hot_vector', got {!r}".format(field_name, vector_type))

    def _normalize_weights(self, weights):
        """Rescale ``weights`` to non-negative values whose squares sum to N_channel_use."""
        return torch.sqrt(weights**2 * (self.param.N_channel_use) / torch.sum(weights**2))

    def decoder_activation(self, inputs):
        """Apply the output non-linearity matching ``param.output_type``."""
        if self.param.output_type == 'bit_vector':
            return torch.sigmoid(inputs) # training with binary cross entropy
        elif self.param.output_type == 'one_hot_vector':
            return inputs # Note. softmax function is applied in "F.cross_entropy" function

    # Convert 'bit vector' to 'one-hot vector'
    def one_hot(self, bit_vec):
        """Map bit vectors to one-hot vectors of length ``2**N_bits``.

        Bit ``k`` carries weight ``2**k``, so the hot index is
        ``sum_k bit_k * 2**k``.

        Args:
            bit_vec: tensor holding ``batch * N_bits`` bits; it is reshaped
                to ``(batch, N_bits)``.

        Returns:
            int64 tensor of shape ``(batch, 2**N_bits)``, allocated on the
            CPU (as before); callers move it to the device they need.
        """
        bit_vec = bit_vec.reshape(-1, self.param.N_bits)
        N_batch = bit_vec.size(0)
        N_bits = bit_vec.size(1)

        place_values = 2 ** torch.arange(0, N_bits, device=bit_vec.device)
        ind_vec = torch.sum(bit_vec * place_values, dim=1).long()
        bit_onehot = torch.zeros((N_batch, 2**N_bits), dtype=torch.int64)
        # One vectorized scatter instead of a Python loop over the batch.
        bit_onehot.scatter_(1, ind_vec.cpu().unsqueeze(1), 1)
        return bit_onehot

    def normalization(self, inputs, t_idx):
        """Standardize the encoder output of timestep ``t_idx``.

        * training: use the mean/std of ``inputs`` itself;
        * eval with ``normalization_with_saved_data``: use ``mean_saved`` /
          ``std_saved`` at ``t_idx``;
        * eval otherwise: use the statistics of ``inputs`` and record them in
          ``mean_batch`` / ``std_batch`` at ``t_idx``.
        """
        if (not self.training) and self.normalization_with_saved_data:  # During inference
            return (inputs - self.mean_saved[t_idx]) / self.std_saved[t_idx]

        mean_batch = torch.mean(inputs)
        std_batch = torch.std(inputs)
        if not self.training:  # During validation: remember the statistics
            # Detach so the stored statistics never keep an autograd graph alive.
            self.mean_batch[t_idx] = mean_batch.detach()
            self.std_batch[t_idx] = std_batch.detach()
        return (inputs - mean_batch) / std_batch

    def forward(self, b, noise1, noise2):
        """Encode ``b``, pass it through the noisy forward channel and decode.

        Args:
            b: message bits holding ``batch * N_bits`` entries with the batch
                on dim 0, e.g. shape ``(batch, N_bits, 1)``.
            noise1: forward-channel noise, ``(batch, N_channel_use, 1)``.
            noise2: feedback-channel noise. Accepted for interface
                compatibility but not used: this open-loop variant feeds the
                transmitted symbol ``x_t`` back to the encoder, not a noisy
                feedback signal.

        Returns:
            Tensor of shape ``(batch, num_output, 1)``. The transmitted
            sequence ``(batch, N_channel_use, 1)`` is stored in ``self.x``.

        Raises:
            ValueError: if ``param.attention_type`` is not one of 1-5.
        """
        attention_type = self.param.attention_type
        if attention_type not in (1, 2, 3, 4, 5):
            raise ValueError("attention_type must be one of 1-5, got {!r}".format(attention_type))

        batch = b.size(0)
        hidden = self.param.decoder_N_neurons

        # Normalize power weights
        self.weight_power_normalized = self._normalize_weights(self.weight_power)

        # Encoder input
        if self.param.input_type == 'bit_vector':
            I = b
        else:  # 'one_hot_vector' (validated in __init__)
            I = self.one_hot(b).to(b.device)
        # Cast explicitly to the model dtype instead of relying on torch.cat promotion.
        message = I.reshape(batch, 1, self.num_input).to(self.weight_power.dtype)

        # At the first timestep nothing has been sent yet: feed back zeros and
        # let the GRU start from its default (zero) hidden state.
        x_t = torch.zeros((batch, 1, 1), dtype=message.dtype, device=message.device)
        s_t_hidden = None
        x_steps, y_steps = [], []

        for t in range(self.param.N_channel_use): # timesteps
            # Encoder
            ######### Open-loop code feeds x_t instead of z_t
            input_total = torch.cat([message, x_t], dim=2)       # (batch, 1, num_input + 1)
            x_t_after_RNN, s_t_hidden = self.encoder_RNN(input_total, s_t_hidden)
            x_t_tilde = torch.tanh(self.encoder_linear(x_t_after_RNN))

            # Power control layer: 1. Normalization, 2. Power allocation
            x_t_norm = self.normalization(x_t_tilde, t).reshape(batch, 1, 1)
            x_t = x_t_norm * self.weight_power_normalized[t]

            # Forward transmission
            y_t = x_t + noise1[:, t, :].reshape(batch, 1, 1)

            x_steps.append(x_t)
            y_steps.append(y_t)

        x_total = torch.cat(x_steps, dim=1)   # (batch, N, 1)
        y_total = torch.cat(y_steps, dim=1)   # (batch, N, 1)

        # Decoder
        r_hidden, _ = self.decoder_RNN(y_total) # (batch, N, bi*hidden_size)

        if attention_type == 1:
            # Option 1. Only the N-th timestep
            output = self.decoder_activation(self.decoder_linear(r_hidden)) # (batch,N,num_output)
            output_last = output[:, -1, :].reshape(batch, -1, 1)

        elif attention_type == 2:
            # Option 2. Merge the "last" outputs of forward/backward RNN
            r_backward = r_hidden[:, 0, hidden:]   # 1st timestep of the reverse RNN
            r_forward = r_hidden[:, -1, :hidden]   # N-th timestep of the forward RNN
            r_concat = torch.cat([r_backward, r_forward], dim=1)
            output = self.decoder_activation(self.decoder_linear(r_concat)) # (batch,num_output)
            output_last = output.reshape(batch, -1, 1)

        elif attention_type == 3:
            # Option 3. Sum over all timesteps
            output = self.decoder_activation(self.decoder_linear(r_hidden))
            output_last = torch.sum(output, dim=1).reshape(batch, -1, 1)

        elif attention_type == 4:
            # Option 4. Attention mechanism (N weights)
            self.weight_merge_normalized = self._normalize_weights(self.weight_merge)
            r_concat = torch.tensordot(r_hidden, self.weight_merge_normalized, dims=([1], [0])) # (batch,bi*hidden_size)
            output = self.decoder_activation(self.decoder_linear(r_concat))
            output_last = output.reshape(batch, -1, 1)

        else:
            # Option 5. Attention mechanism (2N weights) for forward/backward
            self.weight_merge_normalized_fwd = self._normalize_weights(self.weight_merge[:, 0])
            self.weight_merge_normalized_bwd = self._normalize_weights(self.weight_merge[:, 1])
            r_hidden_forward = r_hidden[:, :, :hidden]    # (batch,N,hidden_size)
            r_hidden_backward = r_hidden[:, :, hidden:]   # (batch,N,hidden_size)
            r_forward_weighted_sum = torch.tensordot(r_hidden_forward, self.weight_merge_normalized_fwd, dims=([1], [0]))
            r_backward_weighted_sum = torch.tensordot(r_hidden_backward, self.weight_merge_normalized_bwd, dims=([1], [0]))
            r_concat = torch.cat([r_forward_weighted_sum, r_backward_weighted_sum], dim=1)
            output = self.decoder_activation(self.decoder_linear(r_concat))
            output_last = output.reshape(batch, -1, 1)

        self.x = x_total                    # (batch,N,1)

        return output_last

In [6]:
# Convert the `bit vector' with (batch,K,1) to 'one hot vector' with (batch,2^K)
def one_hot(bit_vec):
    """Convert bit vectors to one-hot vectors.

    Bit ``k`` carries weight ``2**k``, so a row of bits is mapped to the
    one-hot vector whose hot index is ``sum_k bit_k * 2**k``.

    Args:
        bit_vec: tensor of shape ``(batch, K, 1)``; the trailing dimension
            is squeezed away.

    Returns:
        int64 tensor of shape ``(batch, 2**K)`` allocated on the CPU.
    """
    bit_vec = bit_vec.squeeze(-1)  # (batch, K)
    N_batch = bit_vec.size(0)
    N_bits = bit_vec.size(1)

    # Place values live on the input's own device, so no global device is needed.
    place_values = 2 ** torch.arange(0, N_bits, device=bit_vec.device)  # (K,)
    ind_vec = torch.sum(bit_vec * place_values, dim=1).long()           # (batch,)

    b_onehot = torch.zeros((N_batch, 2**N_bits), dtype=torch.int64)
    # One vectorized scatter instead of a Python loop over the batch.
    b_onehot.scatter_(1, ind_vec.cpu().unsqueeze(1), 1)
    return b_onehot

In [7]:
# Test
def test_RNN(N_test):
    """Evaluate the global ``model`` on freshly generated random test data.

    The data is processed in chunks of exactly ``parameter.batch_size``
    samples. ``N_test`` is rounded down to a whole number of batches, but at
    least one full batch is always evaluated, so a request smaller than the
    batch size no longer ends in a division by zero.

    Args:
        N_test: requested number of test samples (ideally a multiple of
            ``parameter.batch_size``).

    Returns:
        ``(ber, bler, power_avg)``: the per-batch error rates averaged over
        the batches, and the squared transmitted signal averaged over the
        batches, with shape ``(batch_size, N_channel_use, 1)``.

    Side effects:
        Leaves ``model`` in eval mode.
    """
    batch_size = parameter.batch_size
    N_iter = max(1, N_test // batch_size)   # always evaluate at least one batch
    N_eval = N_iter * batch_size            # samples actually consumed

    # Generate test data
    bits_test   = torch.randint(0, 2, (N_eval, parameter.N_bits, 1))
    noise1_test = sigma1*torch.randn((N_eval, parameter.N_channel_use, 1))
    noise2_test = sigma2*torch.randn((N_eval, parameter.N_channel_use, 1))

    model.eval() # model.training becomes False
    ber = 0
    bler = 0
    power_avg = np.zeros((batch_size, parameter.N_channel_use, 1))
    with torch.no_grad():
        for i in range(N_iter):
            chunk = slice(batch_size*i, batch_size*(i+1))
            bits   = bits_test[chunk].view(batch_size, parameter.N_bits, 1).to(device)           # batch, K, 1
            noise1 = noise1_test[chunk].view(batch_size, parameter.N_channel_use, 1).to(device)  # batch, N, 1
            noise2 = noise2_test[chunk].view(batch_size, parameter.N_channel_use, 1).to(device)  # batch, N, 1

            # Forward pass
            output = model(bits, noise1, noise2)

            if parameter.output_type == 'bit_vector':
                ber_tmp, bler_tmp = error_rate_bitvector(output.cpu(), bits.cpu())
            elif parameter.output_type == 'one_hot_vector':
                ber_tmp, bler_tmp = error_rate_onehot(output.cpu(), bits.cpu())

            ber = ber + ber_tmp
            bler = bler + bler_tmp
            # Power
            signal = model.x.cpu().detach().numpy()
            power_avg += signal**2

        ber  = ber/N_iter
        bler = bler/N_iter
        power_avg = power_avg/N_iter

    return ber, bler, power_avg

In [8]:
# Build the model on the selected device (moving to the CPU device is a no-op).
model = Feedback_Code(parameter).to(device)

print(model)
# Adam over the trainable parameters; the learning rate decays by 5% at every scheduler step.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=parameter.learning_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

Feedback_Code(
  (encoder_RNN): GRU(4, 50, num_layers=2, batch_first=True)
  (encoder_linear): Linear(in_features=50, out_features=1, bias=True)
  (decoder_RNN): GRU(1, 50, num_layers=2, batch_first=True, bidirectional=True)
  (decoder_linear): Linear(in_features=100, out_features=3, bias=True)
)


In [10]:
# Training
num_epoch = 100
clipping_value = 1   # maximum gradient norm

# Report the (normalized) learnable weights before any update.
print('Before training ')
print('weight_power: ', model.weight_power_normalized.detach().cpu().numpy().round(3))
if parameter.attention_type==4:
    print('weight_merge: ', model.weight_merge_normalized.detach().cpu().numpy().round(3))
if parameter.attention_type==5:
    print('weight_merge_fwd: ', model.weight_merge_normalized_fwd.detach().cpu().numpy().round(3))
    print('weight_merge_bwd: ', model.weight_merge_normalized_bwd.detach().cpu().numpy().round(3))
print()

for epoch in range(num_epoch):

    model.train() # test_RNN() below switches to eval mode, so re-enable training mode
    loss_training = 0

    N_iter = (N_train//parameter.batch_size)
    for i in range(N_iter):
        # Slice the i-th mini-batch out of the pre-generated training set.
        start = parameter.batch_size*i
        stop = parameter.batch_size*(i+1)
        bits = bits_train[start:stop,:,:].view(parameter.batch_size, parameter.N_bits,1).to(device)
        noise1 = noise1_train[start:stop,:,:].view(parameter.batch_size, parameter.N_channel_use,1).to(device)
        noise2 = noise2_train[start:stop,:,:].view(parameter.batch_size, parameter.N_channel_use,1).to(device)

        # forward pass
        optimizer.zero_grad()
        output = model(bits, noise1, noise2)

        # Define loss according to output type
        if parameter.output_type == 'bit_vector':
            bits = bits.type(torch.float32)
            loss = F.binary_cross_entropy(output, bits)
        elif parameter.output_type == 'one_hot_vector':
            b_hot = one_hot(bits).view(parameter.batch_size, 2**parameter.N_bits, 1) # (batch,2^K,1)
            target = torch.argmax(b_hot,dim=1).squeeze(-1).to(device)
            loss = F.cross_entropy(output.squeeze(-1), target)

        # training: back-propagate, clip the gradient norm, then update
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clipping_value)
        loss_value = loss.item()
        loss_training += loss_value
        optimizer.step()

        if i % 100 == 0:
            print('Epoch: {}, Iter: {} out of {}, Loss: {:.4f}'.format(epoch, i, N_iter, loss_value))

    # Summary of each epoch
    print('Summary: Epoch: {}, lr: {}, Average loss: {:.4f}'.format(epoch, optimizer.param_groups[0]['lr'], loss_training/N_iter) )

    scheduler.step() # reduce learning rate

    print('weight_power', model.weight_power_normalized.detach().cpu().numpy())
    if parameter.attention_type==4:
        print('weight_merge: ', model.weight_merge_normalized.detach().cpu().numpy().round(3))
    if parameter.attention_type==5:
        print('weight_merge_fwd: ', model.weight_merge_normalized_fwd.detach().cpu().numpy().round(3))
        print('weight_merge_bwd: ', model.weight_merge_normalized_bwd.detach().cpu().numpy().round(3))
    print()

    # Validation
    ber_val, bler_val, _ = test_RNN(N_validation)
    print('Ber:  ', float(ber_val))
    print('Bler: ', float(bler_val))
    print()

Before training 
weight_power:  [1. 1. 1. 1. 1. 1. 1. 1. 1.]
weight_merge_fwd:  [1. 1. 1. 1. 1. 1. 1. 1. 1.]
weight_merge_bwd:  [1. 1. 1. 1. 1. 1. 1. 1. 1.]

Epoch: 0, Iter: 0 out of 1000, Loss: 0.6924
Epoch: 0, Iter: 100 out of 1000, Loss: 0.0000
Epoch: 0, Iter: 200 out of 1000, Loss: 0.0000
Epoch: 0, Iter: 300 out of 1000, Loss: 0.0000
Epoch: 0, Iter: 400 out of 1000, Loss: 0.0000
Epoch: 0, Iter: 500 out of 1000, Loss: 0.0000
Epoch: 0, Iter: 600 out of 1000, Loss: 0.0000
Epoch: 0, Iter: 700 out of 1000, Loss: 0.0000
Epoch: 0, Iter: 800 out of 1000, Loss: 0.0000
Epoch: 0, Iter: 900 out of 1000, Loss: 0.0000
Summary: Epoch: 0, lr: 0.01, Average loss: 0.0167
weight_power [1.1180408  1.1417706  0.9152422  0.99419135 0.8202066  1.0213768
 1.1003183  0.95180297 0.8875142 ]
weight_merge_fwd:  [0.866 1.014 0.82  0.998 1.186 1.09  1.001 1.056 0.918]
weight_merge_bwd:  [0.757 1.163 0.908 0.913 1.171 0.987 0.886 1.112 1.024]

Ber:   0.0
Bler:  0.0

Epoch: 1, Iter: 0 out of 1000, Loss: 0.0000
Ep

Epoch: 10, Iter: 0 out of 1000, Loss: 0.0000
Epoch: 10, Iter: 100 out of 1000, Loss: 0.0000
Epoch: 10, Iter: 200 out of 1000, Loss: 0.0000
Epoch: 10, Iter: 300 out of 1000, Loss: 0.0000
Epoch: 10, Iter: 400 out of 1000, Loss: 0.0000
Epoch: 10, Iter: 500 out of 1000, Loss: 0.0000
Epoch: 10, Iter: 600 out of 1000, Loss: 0.0000
Epoch: 10, Iter: 700 out of 1000, Loss: 0.0000
Epoch: 10, Iter: 800 out of 1000, Loss: 0.0000
Epoch: 10, Iter: 900 out of 1000, Loss: 0.0000
Summary: Epoch: 10, lr: 0.005987369392383786, Average loss: 0.0000
weight_power [2.843454   0.5969379  0.24680261 0.05103696 0.06370119 0.11636201
 0.14150913 0.150721   0.6592264 ]
weight_merge_fwd:  [1.823 1.485 0.594 0.698 1.231 0.511 0.495 0.613 0.485]
weight_merge_bwd:  [0.106 1.327 0.055 0.041 1.55  0.09  0.073 1.566 1.534]

Ber:   0.0
Bler:  0.0

Epoch: 11, Iter: 0 out of 1000, Loss: 0.0000
Epoch: 11, Iter: 100 out of 1000, Loss: 0.0000
Epoch: 11, Iter: 200 out of 1000, Loss: 0.0000
Epoch: 11, Iter: 300 out of 1000, Los

Epoch: 20, Iter: 0 out of 1000, Loss: 0.0000
Epoch: 20, Iter: 100 out of 1000, Loss: 0.0000
Epoch: 20, Iter: 200 out of 1000, Loss: 0.0000
Epoch: 20, Iter: 300 out of 1000, Loss: 0.0000
Epoch: 20, Iter: 400 out of 1000, Loss: 0.0000
Epoch: 20, Iter: 500 out of 1000, Loss: 0.0000
Epoch: 20, Iter: 600 out of 1000, Loss: 0.0000
Epoch: 20, Iter: 700 out of 1000, Loss: 0.0000
Epoch: 20, Iter: 800 out of 1000, Loss: 0.0000
Epoch: 20, Iter: 900 out of 1000, Loss: 0.0000
Summary: Epoch: 20, lr: 0.0035848592240854188, Average loss: 0.0000
weight_power [2.8170657  0.45004568 0.15106589 0.06173339 0.01536005 0.04200209
 0.09851836 0.18158749 0.8889816 ]
weight_merge_fwd:  [1.852 1.633 0.61  0.686 1.15  0.471 0.394 0.483 0.359]
weight_merge_bwd:  [0.151 1.715 0.009 0.011 1.377 0.007 0.007 1.406 1.471]

Ber:   0.0
Bler:  0.0

Epoch: 21, Iter: 0 out of 1000, Loss: 0.0000
Epoch: 21, Iter: 100 out of 1000, Loss: 0.0000
Epoch: 21, Iter: 200 out of 1000, Loss: 0.0000
Epoch: 21, Iter: 300 out of 1000, Lo

Epoch: 30, Iter: 0 out of 1000, Loss: 0.0000
Epoch: 30, Iter: 100 out of 1000, Loss: 0.0000
Epoch: 30, Iter: 200 out of 1000, Loss: 0.0000
Epoch: 30, Iter: 300 out of 1000, Loss: 0.0000
Epoch: 30, Iter: 400 out of 1000, Loss: 0.0000
Epoch: 30, Iter: 500 out of 1000, Loss: 0.0000
Epoch: 30, Iter: 600 out of 1000, Loss: 0.0000
Epoch: 30, Iter: 700 out of 1000, Loss: 0.0000
Epoch: 30, Iter: 800 out of 1000, Loss: 0.0000
Epoch: 30, Iter: 900 out of 1000, Loss: 0.0000
Summary: Epoch: 30, lr: 0.0021463876394293723, Average loss: 0.0000
weight_power [2.7858636  0.53990537 0.12192475 0.025239   0.01363005 0.01865478
 0.04385028 0.09789461 0.9591257 ]
weight_merge_fwd:  [1.826 1.633 0.632 0.696 1.138 0.492 0.42  0.505 0.385]
weight_merge_bwd:  [0.172 1.881 0.035 0.023 1.291 0.019 0.019 1.341 1.402]

Ber:   0.0
Bler:  0.0

Epoch: 31, Iter: 0 out of 1000, Loss: 0.0000
Epoch: 31, Iter: 100 out of 1000, Loss: 0.0000
Epoch: 31, Iter: 200 out of 1000, Loss: 0.0000
Epoch: 31, Iter: 300 out of 1000, Lo

Epoch: 40, Iter: 0 out of 1000, Loss: 0.0000
Epoch: 40, Iter: 100 out of 1000, Loss: 0.0000
Epoch: 40, Iter: 200 out of 1000, Loss: 0.0000
Epoch: 40, Iter: 300 out of 1000, Loss: 0.0000
Epoch: 40, Iter: 400 out of 1000, Loss: 0.0000
Epoch: 40, Iter: 500 out of 1000, Loss: 0.0000
Epoch: 40, Iter: 600 out of 1000, Loss: 0.0000
Epoch: 40, Iter: 700 out of 1000, Loss: 0.0000
Epoch: 40, Iter: 800 out of 1000, Loss: 0.0000
Epoch: 40, Iter: 900 out of 1000, Loss: 0.0000
Summary: Epoch: 40, lr: 0.0012851215656510308, Average loss: 0.0000
weight_power [2.78092    0.55010694 0.11906029 0.0260184  0.01800107 0.01851505
 0.03885602 0.07784257 0.9699367 ]
weight_merge_fwd:  [1.79  1.613 0.663 0.716 1.129 0.528 0.461 0.541 0.429]
weight_merge_bwd:  [0.198 1.933 0.06  0.041 1.266 0.037 0.037 1.319 1.369]

Ber:   0.0
Bler:  0.0

Epoch: 41, Iter: 0 out of 1000, Loss: 0.0000
Epoch: 41, Iter: 100 out of 1000, Loss: 0.0000
Epoch: 41, Iter: 200 out of 1000, Loss: 0.0000
Epoch: 41, Iter: 300 out of 1000, Lo

Epoch: 50, Iter: 0 out of 1000, Loss: 0.0000
Epoch: 50, Iter: 100 out of 1000, Loss: 0.0000
Epoch: 50, Iter: 200 out of 1000, Loss: 0.0000
Epoch: 50, Iter: 300 out of 1000, Loss: 0.0000
Epoch: 50, Iter: 400 out of 1000, Loss: 0.0000
Epoch: 50, Iter: 500 out of 1000, Loss: 0.0000
Epoch: 50, Iter: 600 out of 1000, Loss: 0.0000
Epoch: 50, Iter: 700 out of 1000, Loss: 0.0000
Epoch: 50, Iter: 800 out of 1000, Loss: 0.0000
Epoch: 50, Iter: 900 out of 1000, Loss: 0.0000
Summary: Epoch: 50, lr: 0.0007694497527671312, Average loss: 0.0000
weight_power [2.774252   0.5656492  0.11909429 0.02647053 0.0200293  0.01896744
 0.03754265 0.06873735 0.98070806]
weight_merge_fwd:  [1.766 1.602 0.68  0.729 1.125 0.548 0.484 0.561 0.453]
weight_merge_bwd:  [0.211 1.937 0.068 0.046 1.264 0.043 0.042 1.316 1.365]

Ber:   0.0
Bler:  0.0

Epoch: 51, Iter: 0 out of 1000, Loss: 0.0000
Epoch: 51, Iter: 100 out of 1000, Loss: 0.0000
Epoch: 51, Iter: 200 out of 1000, Loss: 0.0000
Epoch: 51, Iter: 300 out of 1000, Lo

Epoch: 60, Iter: 0 out of 1000, Loss: 0.0000
Epoch: 60, Iter: 100 out of 1000, Loss: 0.0000
Epoch: 60, Iter: 200 out of 1000, Loss: 0.0000
Epoch: 60, Iter: 300 out of 1000, Loss: 0.0000
Epoch: 60, Iter: 400 out of 1000, Loss: 0.0000
Epoch: 60, Iter: 500 out of 1000, Loss: 0.0000
Epoch: 60, Iter: 600 out of 1000, Loss: 0.0000
Epoch: 60, Iter: 700 out of 1000, Loss: 0.0000
Epoch: 60, Iter: 800 out of 1000, Loss: 0.0000
Epoch: 60, Iter: 900 out of 1000, Loss: 0.0000
Summary: Epoch: 60, lr: 0.00046069798986951934, Average loss: 0.0000
weight_power [2.771902   0.5721502  0.11925492 0.02653947 0.02055372 0.01905853
 0.03713359 0.06597105 0.9837478 ]
weight_merge_fwd:  [1.759 1.598 0.686 0.733 1.124 0.555 0.491 0.567 0.46 ]
weight_merge_bwd:  [0.215 1.938 0.069 0.047 1.263 0.044 0.043 1.316 1.364]

Ber:   0.0
Bler:  0.0

Epoch: 61, Iter: 0 out of 1000, Loss: 0.0000
Epoch: 61, Iter: 100 out of 1000, Loss: 0.0000
Epoch: 61, Iter: 200 out of 1000, Loss: 0.0000
Epoch: 61, Iter: 300 out of 1000, L

Epoch: 70, Iter: 0 out of 1000, Loss: 0.0000
Epoch: 70, Iter: 100 out of 1000, Loss: 0.0000
Epoch: 70, Iter: 200 out of 1000, Loss: 0.0000
Epoch: 70, Iter: 300 out of 1000, Loss: 0.0000
Epoch: 70, Iter: 400 out of 1000, Loss: 0.0000
Epoch: 70, Iter: 500 out of 1000, Loss: 0.0000
Epoch: 70, Iter: 600 out of 1000, Loss: 0.0000
Epoch: 70, Iter: 700 out of 1000, Loss: 0.0000
Epoch: 70, Iter: 800 out of 1000, Loss: 0.0000
Epoch: 70, Iter: 900 out of 1000, Loss: 0.0000
Summary: Epoch: 70, lr: 0.00027583690436774953, Average loss: 0.0000
weight_power [2.7707975  0.5752243  0.11934965 0.02656489 0.02078233 0.01909493
 0.036945   0.06471796 0.98513865]
weight_merge_fwd:  [1.755 1.597 0.688 0.735 1.123 0.557 0.495 0.57  0.463]
weight_merge_bwd:  [0.216 1.938 0.07  0.048 1.263 0.045 0.044 1.316 1.364]

Ber:   0.0
Bler:  0.0

Epoch: 71, Iter: 0 out of 1000, Loss: 0.0000
Epoch: 71, Iter: 100 out of 1000, Loss: 0.0000
Epoch: 71, Iter: 200 out of 1000, Loss: 0.0000
Epoch: 71, Iter: 300 out of 1000, L

Epoch: 80, Iter: 0 out of 1000, Loss: 0.0000
Epoch: 80, Iter: 100 out of 1000, Loss: 0.0000
Epoch: 80, Iter: 200 out of 1000, Loss: 0.0000
Epoch: 80, Iter: 300 out of 1000, Loss: 0.0000
Epoch: 80, Iter: 400 out of 1000, Loss: 0.0000
Epoch: 80, Iter: 500 out of 1000, Loss: 0.0000
Epoch: 80, Iter: 600 out of 1000, Loss: 0.0000
Epoch: 80, Iter: 700 out of 1000, Loss: 0.0000
Epoch: 80, Iter: 800 out of 1000, Loss: 0.0000
Epoch: 80, Iter: 900 out of 1000, Loss: 0.0000
Summary: Epoch: 80, lr: 0.00016515374385013573, Average loss: 0.0000
weight_power [2.7702348  0.5768545  0.11940361 0.02657721 0.0209007  0.01911313
 0.03684504 0.06405646 0.98580515]
weight_merge_fwd:  [1.754 1.596 0.69  0.736 1.123 0.559 0.496 0.571 0.465]
weight_merge_bwd:  [0.217 1.938 0.07  0.048 1.263 0.045 0.044 1.316 1.364]

Ber:   0.0
Bler:  0.0

Epoch: 81, Iter: 0 out of 1000, Loss: 0.0000
Epoch: 81, Iter: 100 out of 1000, Loss: 0.0000
Epoch: 81, Iter: 200 out of 1000, Loss: 0.0000
Epoch: 81, Iter: 300 out of 1000, L

Epoch: 90, Iter: 0 out of 1000, Loss: 0.0000
Epoch: 90, Iter: 100 out of 1000, Loss: 0.0000
Epoch: 90, Iter: 200 out of 1000, Loss: 0.0000
Epoch: 90, Iter: 300 out of 1000, Loss: 0.0000
Epoch: 90, Iter: 400 out of 1000, Loss: 0.0000
Epoch: 90, Iter: 500 out of 1000, Loss: 0.0000
Epoch: 90, Iter: 600 out of 1000, Loss: 0.0000
Epoch: 90, Iter: 700 out of 1000, Loss: 0.0000
Epoch: 90, Iter: 800 out of 1000, Loss: 0.0000
Epoch: 90, Iter: 900 out of 1000, Loss: 0.0000
Summary: Epoch: 90, lr: 9.888364709658946e-05, Average loss: 0.0000
weight_power [2.7699366  0.5777566  0.11943255 0.02658343 0.02096631 0.01912292
 0.03678828 0.06368464 0.98613566]
weight_merge_fwd:  [1.752 1.595 0.69  0.736 1.123 0.56  0.497 0.572 0.466]
weight_merge_bwd:  [0.218 1.938 0.071 0.048 1.263 0.045 0.044 1.316 1.363]

Ber:   0.0
Bler:  0.0

Epoch: 91, Iter: 0 out of 1000, Loss: 0.0000
Epoch: 91, Iter: 100 out of 1000, Loss: 0.0000
Epoch: 91, Iter: 200 out of 1000, Loss: 0.0000
Epoch: 91, Iter: 300 out of 1000, Lo

In [11]:
# Calculate mean/std with training data
model.eval()   # model.training becomes False
# The per-timestep batch statistics are only recorded while this flag is False.
# Set it explicitly so the cell still works when re-run after the inference
# cell (which switches the flag to True).
model.normalization_with_saved_data = False
N_iter = N_train//parameter.batch_size
mean_total = torch.zeros_like(model.mean_batch)
std_total = torch.zeros_like(model.std_batch)

with torch.no_grad():
    for i in range(N_iter):
        bits   = bits_train[parameter.batch_size*i:parameter.batch_size*(i+1),:,:].view(parameter.batch_size, parameter.N_bits,1)
        noise1 = noise1_train[parameter.batch_size*i:parameter.batch_size*(i+1),:,:].view(parameter.batch_size, parameter.N_channel_use,1)
        noise2 = noise2_train[parameter.batch_size*i:parameter.batch_size*(i+1),:,:].view(parameter.batch_size, parameter.N_channel_use,1)

        bits   = bits.to(device)
        noise1 = noise1.to(device)
        noise2 = noise2.to(device)

        # The forward pass refreshes model.mean_batch / model.std_batch for this batch.
        output = model(bits, noise1, noise2)
        mean_total += model.mean_batch
        std_total  += model.std_batch
        if i%100==0: print(i)

# Average the per-batch statistics over all batches.
mean_train = mean_total/N_iter
std_train = std_total/N_iter
print('Mean: ',mean_train)
print('std : ',std_train)

0
100
200
300
400
500
600
700
800
900
Mean:  tensor([0.9276, 0.9791, 0.9813, 0.9832, 0.9851, 0.9870, 0.9890, 0.9908, 0.9919])
std :  tensor([0.0217, 0.0103, 0.0097, 0.0089, 0.0083, 0.0079, 0.0079, 0.0080, 0.0077])


In [10]:
# Inference stage
# N_inference = int(4e8)
# N_small = int(1e5) # In case that N_inference is very large, we divide into small chunks
N_inference = int(1e11)
N_small = int(1e6)
N_iter  = N_inference//N_small

# Normalize with the statistics measured on the training data from now on.
model.mean_saved = mean_train
model.std_saved  = std_train
model.normalization_with_saved_data = True

ber_sum  = 0
bler_sum = 0

for ii in range(N_iter):
    # Each call evaluates one freshly generated chunk of N_small samples.
    ber_tmp, bler_tmp, _ = test_RNN(N_small)
    ber_sum = ber_sum + ber_tmp
    bler_sum = bler_sum + bler_tmp
    if ii%100 != 0:
        continue
    # Running averages over the chunks processed so far.
    chunks_done = ii + 1
    print('Iter: {} out of {}'.format(ii, N_iter))
    print('Ber:  ', float(ber_sum/chunks_done))
    print('Bler: ', float(bler_sum/chunks_done))

ber_inference  = ber_sum/N_iter
bler_inference = bler_sum/N_iter

print()
print('Ber:  ', float(ber_inference))
print('Bler: ', float(bler_inference))

Iter: 0 out of 100000
Ber:   0.0
Bler:  0.0
Iter: 100 out of 100000
Ber:   0.0
Bler:  0.0
Iter: 200 out of 100000
Ber:   0.0
Bler:  0.0
Iter: 300 out of 100000
Ber:   0.0
Bler:  0.0
Iter: 400 out of 100000
Ber:   0.0
Bler:  0.0
Iter: 500 out of 100000
Ber:   0.0
Bler:  0.0
Iter: 600 out of 100000
Ber:   0.0
Bler:  0.0
Iter: 700 out of 100000
Ber:   0.0
Bler:  0.0
Iter: 800 out of 100000
Ber:   0.0
Bler:  0.0
Iter: 900 out of 100000
Ber:   0.0
Bler:  0.0
Iter: 1000 out of 100000
Ber:   0.0
Bler:  0.0
Iter: 1100 out of 100000
Ber:   0.0
Bler:  0.0
Iter: 1200 out of 100000
Ber:   0.0
Bler:  0.0
Iter: 1300 out of 100000
Ber:   0.0
Bler:  0.0
Iter: 1400 out of 100000
Ber:   0.0
Bler:  0.0
Iter: 1500 out of 100000
Ber:   0.0
Bler:  0.0
Iter: 1600 out of 100000
Ber:   0.0
Bler:  0.0
Iter: 1700 out of 100000
Ber:   0.0
Bler:  0.0
Iter: 1800 out of 100000
Ber:   0.0
Bler:  0.0
Iter: 1900 out of 100000
Ber:   0.0
Bler:  0.0
Iter: 2000 out of 100000
Ber:   0.0
Bler:  0.0
Iter: 2100 out of 100000


Iter: 11100 out of 100000
Ber:   1.801639532006405e-10
Bler:  5.404918179685581e-10
Iter: 11200 out of 100000
Ber:   1.7855548983813918e-10
Bler:  5.356664556366297e-10
Iter: 11300 out of 100000
Ber:   1.769754898184317e-10
Bler:  5.309264694552951e-10
Iter: 11400 out of 100000
Ber:   1.7542320374097642e-10
Bler:  5.262695834673536e-10
Iter: 11500 out of 100000
Ber:   1.7389792383859515e-10
Bler:  5.216937437602098e-10
Iter: 11600 out of 100000
Ber:   1.7239892846632188e-10
Bler:  5.171967853989656e-10
Iter: 11700 out of 100000
Ber:   1.7092556536812964e-10
Bler:  5.127766544710255e-10
Iter: 11800 out of 100000
Ber:   1.6947716841020366e-10
Bler:  5.084314635972476e-10
Iter: 11900 out of 100000
Ber:   1.6805309921430478e-10
Bler:  5.041592698873387e-10
Iter: 12000 out of 100000
Ber:   1.6665277491334507e-10
Bler:  4.999582969844596e-10
Iter: 12100 out of 100000
Ber:   1.652755987624488e-10
Bler:  4.958267685317708e-10
Iter: 12200 out of 100000
Ber:   1.6392098789452803e-10
Bler:  4.917

Iter: 20800 out of 100000
Ber:   2.0832331693032557e-10
Bler:  5.288207649556398e-10
Iter: 20900 out of 100000
Ber:   2.0732661420996834e-10
Bler:  5.262906777048215e-10
Iter: 21000 out of 100000
Ber:   2.0633939001868384e-10
Bler:  5.237846267824864e-10
Iter: 21100 out of 100000
Ber:   2.0536151945638181e-10
Bler:  5.213023346328782e-10
Iter: 21200 out of 100000
Ber:   2.0439287762297198e-10
Bler:  5.188435237002409e-10
Iter: 21300 out of 100000
Ber:   2.0343333961836407e-10
Bler:  5.164077498953645e-10
Iter: 21400 out of 100000
Ber:   2.024827527868922e-10
Bler:  5.13994735662493e-10
Iter: 21500 out of 100000
Ber:   2.0154101998404172e-10
Bler:  5.116041479347189e-10
Iter: 21600 out of 100000
Ber:   2.1603939470704603e-10
Bler:  5.555298998594083e-10
Iter: 21700 out of 100000
Ber:   2.1504385772086465e-10
Bler:  5.529699476092276e-10
Iter: 21800 out of 100000
Ber:   2.1405746619684862e-10
Bler:  5.504335320871689e-10
Iter: 21900 out of 100000
Ber:   2.1308008135711987e-10
Bler:  5.47

KeyboardInterrupt: 

In [11]:
# Report the running averages at the iteration where the inference loop stopped.
# NOTE(review): if the loop was interrupted inside iteration ii, the sums may
# contain only ii chunks while the divisor below is ii+1 -- confirm.
last_iter = ii
n_counted = last_iter + 1
print('Iter: {} out of {}'.format(last_iter, N_iter))
print('Ber:  ', float(ber_sum/n_counted))
print('Bler: ', float(bler_sum/n_counted))

Iter: 22419 out of 100000
Ber:   2.0814749923658837e-10
Bler:  5.352364662591924e-10
The history saving thread hit an unexpected error (OperationalError('database is locked',)).History will not be written to the database.


In [None]:
# print('weight_merge_fwd: ', model.weight_merge_normalized_fwd.cpu().detach().numpy().round(3))
# print('weight_merge_bwd: ', model.weight_merge_normalized_bwd.cpu().detach().numpy().round(3))

In [12]:
######## Save model
import os  # cell-local import: only needed to prepare the output directory

# save_results_to = 'saved_model/NoFeedback/K3_N9/bi_sigmoid/SNR(15dB)/'
save_results_to = 'saved_model/NoFeedback/K3_N9/bi_sigmoid/SNR(15dB)/train(1e8)/'

# torch.save does not create missing parent directories, so make sure the
# nested output directory exists before anything is written into it.
os.makedirs(save_results_to, exist_ok=True)

torch.save(model.state_dict(), os.path.join(save_results_to, 'model.pth'))
####### Save normalization weights
# mean_train / std_train are the averaged per-batch statistics; they are
# stored next to the weights so inference can reuse them.
torch.save(mean_train, os.path.join(save_results_to, 'mean_train.pt'))
torch.save(std_train, os.path.join(save_results_to, 'std_train.pt'))

In [13]:
# Echo the directory used for the saved model and normalization statistics
save_results_to

'saved_model/NoFeedback/K3_N9/bi_sigmoid/SNR(15dB)/train(1e8)/'

In [9]:
######## Recall model
# save_results_to = 'saved_model/NoFeedback/'+ 'SNR1(1)K6_N18/'
save_results_to = 'saved_model/NoFeedback/K3_N9/bi_sigmoid/SNR(15dB)/train(1e8)/'
# map_location=device remaps the stored tensors onto the device selected at
# the top of the notebook, so a checkpoint written on a GPU can still be
# restored when the notebook falls back to the CPU.
model.load_state_dict(torch.load(save_results_to+'model.pth', map_location=device))
model.eval()

####### Load normalization weights
# Loaded exactly as saved (no remapping), so the statistics stay on the
# device they were stored from.
mean_train = torch.load(save_results_to+'mean_train.pt')
std_train = torch.load(save_results_to+'std_train.pt')