In [89]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time as tm
import math
import sys
import pickle as pkl
from detnet import DetNet
from sample_generator import sample_generator

# --- Experiment parameters ---------------------------------------------------
# MIMO dimensions (presumably number of transmit / receive antennas).
NT, NR = 2, 4

# SNR grids in dB, keyed by NT.
snrdb_classical_list = {
    2: np.arange(5.0, 15.0),
    32: np.arange(10.0, 18.0),
}

# Sizes handed to the DetNet constructor (presumably layer count and hidden
# widths -- confirm against DetNet).
L = 3 * NT
v_size = 4 * NT    # same value as 2*2*NT
hl_size = 16 * NT  # same value as 8*2*NT

# Adam starting learning rate and its step-decay schedule.
startingLearningRate = 1e-4
decay_factor = 0.97
decay_step_size = 1000

# Training loop length and per-step batch size.
train_iter = 50000
train_batch_size = 500

# Remaining settings; most are not read by the training/test cells shown here
# and are kept for other cells / experiments.
res_alpha = 0.9
num_snr = 6

corr_flag = batch_corr = QR = True
rho_low, rho_high = 0.55, 0.75

test_set_flag = True
batch_size = MMNet_batch_size = 500
time_seq = 5

rho = 0.6

# Passed to sample_generator (presumably the QAM constellation order).
mod_n = 4


def sym_detection(x_hat, j_indices, real_QAM_const, imag_QAM_const):
    """Hard-decide each symbol to its nearest constellation point and score it.

    Despite the name, the return value is a symbol *accuracy*: the fraction of
    decisions that equal ``j_indices``.

    Args:
        x_hat: real-valued estimates whose last dimension holds the real parts
            followed by the imaginary parts (it is split in half with
            ``torch.chunk``). Any number of leading dimensions is accepted.
        j_indices: integer constellation indices, shaped like one half of
            ``x_hat`` (i.e. ``x_hat.shape[:-1] + (x_hat.shape[-1] // 2,)``).
        real_QAM_const: 1-D tensor of the constellation points' real parts.
        imag_QAM_const: 1-D tensor of the constellation points' imaginary
            parts, same length and order as ``real_QAM_const``.

    Returns:
        float: number of correct decisions divided by ``j_indices.numel()``.
    """
    x_real, x_imag = torch.chunk(x_hat, 2, dim=-1)
    # Broadcasting (..., n, 1) against (M,) gives (..., n, M) squared
    # distances to every constellation point, for any leading shape.
    x_dist = (x_real.unsqueeze(-1) - real_QAM_const) ** 2 \
        + (x_imag.unsqueeze(-1) - imag_QAM_const) ** 2
    x_indices = torch.argmin(x_dist, dim=-1)

    # Integer count -> exact division (no float32 rounding on large batches).
    correct = (x_indices == j_indices).sum().item()
    return correct / j_indices.numel()


def loss_fn(batch_X, batch_HY, batch_HH, list_batch_x_predicted, j_indices, real_QAM_const, imag_QAM_const, ber_only=False, last_only=False):
    """Layer-weighted training loss for the unrolled detector, plus accuracy.

    The i-th output in ``list_batch_x_predicted`` (i = 1, 2, ...) contributes
    ``log(i) * mean_b( MSE_b(x_i) / MSE_b(x_LS) )``, where ``x_LS`` is the
    least-squares estimate ``(H^T H)^{-1} H^T y`` built from the inputs. The
    first output therefore has weight ``log(1) = 0``.

    Args:
        batch_X: true symbols, 2-D (batch, n).
        batch_HY: H^T y, (batch, n).
        batch_HH: H^T H, (batch, n, n); must be invertible.
        list_batch_x_predicted: per-layer estimates, each (batch, n).
        j_indices, real_QAM_const, imag_QAM_const: forwarded to
            ``sym_detection`` to score the last estimate.
        ber_only: if True, return only the accuracy (takes priority).
        last_only: if True, return the last layer's loss plus LS diagnostics.

    Returns:
        - ``ber_only``: float accuracy of the last estimate.
        - ``last_only``: ``(last_loss, accuracy, LSE_error, X_LS)``, where
          ``LSE_error`` is the per-sample LS MSE of shape (batch,).
        - otherwise: ``(sum of weighted layer losses, accuracy)``.

        NOTE: the value historically named ``BER_final`` is the fraction of
        correctly detected symbols (accuracy), not an error rate.
    """
    if ber_only:
        return sym_detection(list_batch_x_predicted[-1], j_indices, real_QAM_const, imag_QAM_const)

    # Least-squares reference estimate; its per-sample MSE normalises every
    # layer's error so the loss is relative to the LS baseline.
    HtHinv = torch.inverse(batch_HH)
    X_LS = torch.einsum('ijk,ik->ij', (HtHinv, batch_HY))
    LSE_error = torch.mean(torch.pow(batch_X - X_LS, 2), dim=1)

    loss_list = [
        math.log(index + 1)
        * torch.mean(torch.mean(torch.pow(batch_X - batch_x_predicted, 2), dim=1) / LSE_error)
        for index, batch_x_predicted in enumerate(list_batch_x_predicted)
    ]
    BER_final = sym_detection(list_batch_x_predicted[-1], j_indices, real_QAM_const, imag_QAM_const)

    if last_only:
        return loss_list[-1], BER_final, LSE_error, X_LS
    return sum(loss_list), BER_final

def loss_ber_ls(X_LS, j_indices, real_QAM_const, imag_QAM_const):
    """Score a least-squares estimate with ``sym_detection``.

    Returns the same float ``sym_detection`` returns: the fraction of symbols
    in ``X_LS`` whose nearest constellation point matches ``j_indices``.
    """
    return sym_detection(X_LS, j_indices, real_QAM_const, imag_QAM_const)

def pre_process_data(H, y):
    """Build the detector inputs ``H^T y`` and ``H^T H`` in float32.

    ``H`` is a 3-D batch of matrices (batch, rows, cols) and ``y`` a 2-D
    batch of vectors (batch, rows). Both are cast with ``.float()`` first.

    Returns:
        tuple ``(HTY, HTH)`` with shapes (batch, cols) and (batch, cols, cols).
    """
    H32 = H.float()
    y32 = y.float()
    Ht = H32.transpose(1, 2)  # batched transpose, same as permute(0, 2, 1)
    gram = Ht @ H32
    matched = torch.einsum('ijk,ik->ij', (Ht, y32))
    return matched, gram


def validate_model_given_data(model, validtn_H, validtn_y, validtn_j_indices, real_QAM_const, imag_QAM_const, device):
    """Run ``model`` on one validation batch and return its symbol accuracy.

    Args:
        model: detector whose ``forward(batch, HTY, HTH)`` returns a list of
            per-layer estimates; the last one is scored.
        validtn_H, validtn_y: channel batch and received vectors, turned into
            ``(H^T y, H^T H)`` by ``pre_process_data``.
        validtn_j_indices: true constellation indices.
        real_QAM_const, imag_QAM_const: constellation coordinates, already on
            ``device``.
        device: device the model lives on.

    Returns:
        float: fraction of correctly detected symbols (from ``sym_detection``).
    """
    HTY, HTH = pre_process_data(validtn_H, validtn_y)
    HTY = HTY.to(device=device)
    HTH = HTH.to(device=device)
    # Keep the labels on the same device as the predictions.
    j_indices = validtn_j_indices.to(device=device)

    # Use the batch size of the data actually passed in, not the global
    # training batch size, so validation batches of any size work.
    # NOTE(review): assumes forward's first argument is the batch size
    # (callers elsewhere pass train_batch_size) -- confirm against DetNet.
    batch = HTY.shape[0]
    with torch.no_grad():
        list_batch_x_predicted = model.forward(batch, HTY, HTH)
        validtn_out = list_batch_x_predicted[-1].to(device=device)
        accr = sym_detection(validtn_out, j_indices, real_QAM_const, imag_QAM_const)

    return accr




def _prepare_batch(H, y, x, j_indices, device):
    """Move one generated batch to ``device`` and form the (H^T y, H^T H) inputs."""
    H = H.float().to(device=device)
    y = y.float().to(device=device)
    HTY, HTH = pre_process_data(H, y)
    return (HTY.to(device=device), HTH.to(device=device),
            x.to(device=device), j_indices.to(device=device))


def train(model, optimizer, lr_scheduler, generator, device='cpu', model_filename=None,
          h_path='/home/nicolas/MIMO_detection_project/HyperMIMO/rho_model_kron/H_test'):
    """Train ``model`` on one fixed channel and save its ``state_dict``.

    The channel tensor is unpickled from ``h_path``. Its first entry is
    replicated ``train_batch_size`` times and used for every batch. For
    ``train_iter`` steps, a batch is drawn from ``generator`` with
    ``snr_db_min``/``snr_db_max`` set to the first/last entries of
    ``snrdb_classical_list[NT]``. One Adam/scheduler step is then taken on
    ``loss_fn``.

    Every 1000 steps a fresh batch at the last (highest) SNR of the grid is
    evaluated. One line is printed with these columns:

        iteration, mean LS MSE, last-layer weighted LS-normalised loss,
        LS accuracy, model accuracy

    Args:
        model, optimizer, lr_scheduler, generator: training objects.
        device: device for tensors and model.
        model_filename: checkpoint path. If None, the module-level
            ``model_filename`` global is used, as before.
        h_path: pickle file holding the channel tensor.

    Raises:
        NameError: if no ``model_filename`` is available. This is checked
            *before* training instead of after all iterations.
    """
    # Resolve the checkpoint path up front so a missing name fails fast
    # rather than after train_iter iterations.
    if model_filename is None:
        model_filename = globals().get('model_filename')
    if model_filename is None:
        raise NameError("model_filename is not defined; pass it to train() or define it globally")

    # NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
    with open(h_path, 'rb') as fp:
        H = pkl.load(fp)
    # One fixed channel realisation, replicated across the batch; cast/move once.
    H = H[0:1, :, :].repeat_interleave(train_batch_size, dim=0).float().to(device=device)
    real_QAM_const = generator.real_QAM_const.to(device=device)
    imag_QAM_const = generator.imag_QAM_const.to(device=device)
    snr_min = snrdb_classical_list[NT][0]
    snr_max = snrdb_classical_list[NT][-1]

    for i in range(train_iter):
        y, x, j_indices, _ = generator.give_batch_data_Hinput(H, NT, snr_db_min=snr_min, snr_db_max=snr_max, batch_size=train_batch_size)
        HTY, HTH, x, j_indices = _prepare_batch(H, y, x, j_indices, device)
        list_batch_x_predicted = model.forward(train_batch_size, HTY, HTH)
        loss, _ = loss_fn(x, HTY, HTH, list_batch_x_predicted, j_indices, real_QAM_const, imag_QAM_const)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        if i % 1000 == 0:
            # Periodic evaluation on a fresh batch at the highest training SNR.
            y, x, j_indices, _ = generator.give_batch_data_Hinput(H, NT, snr_db_min=snr_max, snr_db_max=snr_max, batch_size=train_batch_size)
            HTY, HTH, x, j_indices = _prepare_batch(H, y, x, j_indices, device)
            with torch.no_grad():
                list_batch_x_predicted = model.forward(train_batch_size, HTY, HTH)
                loss_last, acc_final, loss_LS, X_LS = loss_fn(x, HTY, HTH, list_batch_x_predicted, j_indices, real_QAM_const, imag_QAM_const, last_only=True)
                acc_LS = loss_ber_ls(X_LS, j_indices, real_QAM_const, imag_QAM_const)
            row = [i, loss_LS.mean().item(), loss_last.item(), acc_LS, acc_final]
            print(' '.join('%s' % np.round(value, 6) for value in row))

    torch.save(model.state_dict(), model_filename)
    print('************Model Saved************ at directory : ', model_filename)

# Use the GPU when present so the notebook still runs on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = DetNet(L, NT, v_size, hl_size, device=device)
generator = sample_generator(train_batch_size, mod_n, NR)
model = model.to(device=device)
optimizer = torch.optim.Adam(model.parameters(), lr=startingLearningRate)
# train() steps the scheduler once per iteration, so the learning rate is
# multiplied by decay_factor every decay_step_size iterations.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=decay_step_size, gamma=decay_factor)
train(model, optimizer, lr_scheduler, generator, device)
print('******************************** Now Testing **********************************************')


0 0.027734 72.857063 1.0 0.246
1000 0.031394 29.138254 0.999 0.997
2000 0.030842 23.294485 0.999 1.0
3000 0.029639 26.104383 1.0 1.0
4000 0.029453 28.100235 1.0 1.0
5000 0.027585 27.703535 1.0 1.0
6000 0.028056 33.826088 1.0 1.0
7000 0.031124 24.166666 1.0 1.0
8000 0.029508 25.776665 1.0 1.0
9000 0.028975 28.93532 1.0 1.0
10000 0.02878 29.139544 0.999 1.0
11000 0.029015 26.097843 1.0 1.0
12000 0.02845 27.809931 1.0 1.0
13000 0.028157 28.179646 1.0 1.0
14000 0.028133 26.378672 1.0 1.0
15000 0.029413 26.313507 1.0 1.0


KeyboardInterrupt: 

In [90]:
##Testing
# Evaluate the trained model's symbol error rate (SER = 1 - accuracy) on every
# SNR in snrdb_classical_list[NT], averaging num_trials batches per SNR point.

batch_size = 100
time_seq = 5
num_trials = 500  # batches averaged per SNR point

# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open('/home/nicolas/MIMO_detection_project/HyperMIMO/rho_model_kron/H_test', 'rb') as fp:
    H = pkl.load(fp)

# Split the stored channels into per-time-step tensors: H_steps[t][ii] is
# H[t + ii * time_seq] (rows presumably grouped in sequences of time_seq).
H_steps = [torch.empty((batch_size, 2 * NR, 2 * NT)) for _ in range(time_seq)]
for ii in range(batch_size):
    for t, H_step in enumerate(H_steps):
        H_step[ii] = H[t + ii * time_seq]
H0, H1, H2, H3, H4 = H_steps

generator = sample_generator(train_batch_size, mod_n, NR)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
H = H.to(device=device)
# Alternative: evaluate a single time step, e.g.
# H = H4.repeat_interleave(5, dim=0).to(device=device)
real_QAM_const = generator.real_QAM_const.to(device=device)
imag_QAM_const = generator.imag_QAM_const.to(device=device)

results_total = []
results = []
# Iterate the SNR grid itself; the former hard-coded range(10) only matched
# the length of the NT=2 grid and would overrun shorter grids.
for snr_db in snrdb_classical_list[NT]:
    acum = 0.
    for jj in range(num_trials):
        y, x, j_indices, noise_sigma = generator.give_batch_data_Hinput(H, NT, snr_db_min=snr_db, snr_db_max=snr_db, batch_size=train_batch_size)
        # validate_model_given_data returns the fraction of correct symbols.
        accuracy = validate_model_given_data(model, H, y.to(device=device), j_indices.to(device=device), real_QAM_const, imag_QAM_const, device)
        acum += accuracy
    results.append(1 - acum / num_trials)  # SER at this SNR
    print(results)

results_total.append(results)

[0.13158399999999892]
[0.13158399999999892, 0.10831400000000113]
[0.13158399999999892, 0.10831400000000113, 0.08966200000000024]
[0.13158399999999892, 0.10831400000000113, 0.08966200000000024, 0.0776340000000002]
[0.13158399999999892, 0.10831400000000113, 0.08966200000000024, 0.0776340000000002, 0.06794399999999978]
[0.13158399999999892, 0.10831400000000113, 0.08966200000000024, 0.0776340000000002, 0.06794399999999978, 0.06079399999999957]
[0.13158399999999892, 0.10831400000000113, 0.08966200000000024, 0.0776340000000002, 0.06794399999999978, 0.06079399999999957, 0.05556799999999995]
[0.13158399999999892, 0.10831400000000113, 0.08966200000000024, 0.0776340000000002, 0.06794399999999978, 0.06079399999999957, 0.05556799999999995, 0.05280800000000074]
[0.13158399999999892, 0.10831400000000113, 0.08966200000000024, 0.0776340000000002, 0.06794399999999978, 0.06079399999999957, 0.05556799999999995, 0.05280800000000074, 0.05018199999999995]
[0.13158399999999892, 0.10831400000000113, 0.0896620