In [1]:
import sys
sys.path.insert(0, '../train')
from helper import *
%matplotlib inline

# Wireless parameters.
# NOTE(review): N_t / N_r are presumably the transmit / receive antenna counts
# (the channel files loaded later are named H_16x64); confirm against the data.
N_t = 64
N_r = 16
# Size of the generator's latent input vector (input width of the first Linear layer).
latent_dim = 65
# CDL channel profile letter; substituted into the .mat file names when loading data.
channel_model = 'A'

# Spatial size of the generator's seed feature map: a quarter of (N_t, N_r),
# because the two Upsample(scale_factor=2) layers below grow each side by 4x.
length = int(N_t/4)
breadth = int(N_r/4)

# Generator network. It is not trained in this notebook; weights are loaded from
# checkpoints later. Do NOT reorder or insert layers: nn.Sequential names its
# parameters by position ('0.weight', '5.weight', ...) and load_state_dict matches
# checkpoint keys by those names.
# NOTE(review): View and Conv2d come from `helper` (star import). View is assumed to
# reshape to [1, 128, length, breadth] (i.e. batch size fixed to 1), and Conv2d is
# assumed to preserve spatial size, giving an output of shape (1, 2, N_t, N_r) whose
# two channels are the normalised real / imaginary parts — confirm in helper.
# NOTE(review): PyTorch's BatchNorm `momentum` is the weight of the *new* batch
# statistic (the opposite of Keras' convention); it only matters in train mode.
G_test = torch.nn.Sequential(
    torch.nn.Linear(latent_dim, 128*length*breadth),
    torch.nn.ReLU(),
    View([1,128,length,breadth]),
    torch.nn.Upsample(scale_factor=2),
    Conv2d(128,128,4,bias=False),
    torch.nn.BatchNorm2d(128,momentum=0.8),
    torch.nn.ReLU(),
    torch.nn.Upsample(scale_factor=2),
    Conv2d(128,128,4,bias=False),
    torch.nn.BatchNorm2d(128,momentum=0.8),
    torch.nn.ReLU(),
    Conv2d(128,2,4,bias=False),
)
# `dtype` comes from helper; presumably a torch tensor type (CPU or CUDA float).
G_test = G_test.type(dtype)

In [2]:
import copy
def to_beamspace(H_spatial, A_R, A_T):
    """Transform per-sample channels into the beamspace layout used by the generator.

    Args:
        H_spatial: array indexed as ``H_spatial[:, :, i]`` for sample ``i``
            (the 'hest' field of the channel .mat files).
        A_R: receive-side dictionary; left-multiplied (conjugate-transposed).
        A_T: transmit-side dictionary; right-multiplied.

    Returns:
        Array of shape ``(n_samples, H_spatial.shape[1], H_spatial.shape[0])``
        with the dtype of ``H_spatial``, where sample ``i`` is
        ``(A_R^H @ H_spatial[:, :, i] @ A_T)^T``. The inner product is computed
        in complex64, exactly as the original per-cell loops did.
    """
    n_rows, n_cols, n_samples = H_spatial.shape
    H_beam = np.empty((n_samples, n_cols, n_rows), dtype=H_spatial.dtype)
    for i in range(n_samples):
        H_beam[i] = np.matmul(
            np.matmul(A_R.conj().T, H_spatial[:, :, i], dtype='complex64'), A_T
        ).T
    return H_beam


# DFT dictionaries, scaled by 1/sqrt(N) (presumably to make them unitary).
dft_basis = sio.loadmat("../../data/dft_basis.mat")
A_T = dft_basis['A1']/np.sqrt(N_t)
A_R = dft_basis['A2']/np.sqrt(N_r)

# "Clean" channel set: only used to compute per-bin normalisation statistics.
H_org = sio.loadmat("../../data/H_16x64_MIMO_CDL_%s_ULA_clean.mat"%channel_model)
H_ex = H_org['hest']
H_extracted = to_beamspace(H_ex, A_R, A_T)
# The generator emits (N_t, N_r) images, so the statistics must have that shape.
assert H_extracted.shape[1:] == (N_t, N_r), H_extracted.shape

img_np_real = np.real(H_extracted)
img_np_imag = np.imag(H_extracted)

# Per-bin mean / std over samples (axis 0), used to (de-)normalise generator images.
mu_real = np.mean(img_np_real,axis=0)
mu_imag = np.mean(img_np_imag,axis=0)
std_real = np.std(img_np_real,axis=0)
std_imag = np.std(img_np_imag,axis=0)

# Maps a row-major-flattened (N_t, N_r) beamspace image back to the spatial channel.
A_T_R = np.kron(A_T.conj(),A_R)
A_T_R_real = dtype(np.real(A_T_R))
A_T_R_imag = dtype(np.imag(A_T_R))

# Test channel set. From here on H_org / H_ex / H_extracted refer to the TEST
# channels; later cells read the raw spatial test channels from H_ex.
H_org = sio.loadmat("../../data/H_16x64_MIMO_CDL_%s_ULA_test.mat"%channel_model)
H_ex = H_org['hest']
H_extracted = to_beamspace(H_ex, A_R, A_T)
# Test images normalised with the clean-set statistics.
img_np_real = (np.real(H_extracted) - mu_real)/std_real
img_np_imag = (np.imag(H_extracted) - mu_imag)/std_imag

In [None]:
# One data stream and one receive RF chain per receive antenna.
N_s = N_rx_rf = N_r
# Phase resolution (bits) of the transmit / receive phase alphabets.
Nbit_t, Nbit_r = 6, 2
# Quantised phase alphabets: 2**Nbit evenly spaced angles on [0, 2*pi).
angles_t = np.linspace(0, 2*np.pi, 2**Nbit_t, endpoint=False)
angles_r = np.linspace(0, 2*np.pi, 2**Nbit_r, endpoint=False)
# Evaluate every `freq`-th checkpoint, from round 0 through round 600 inclusive.
freq = 10
model_vec = range(0, 600 + freq, freq)

def training_precoder(N_t, N_s, angles=None):
    """Draw a random phase-only training precoder.

    Every entry is ``exp(1j*theta) / sqrt(N_t)``, with ``theta`` drawn uniformly
    (with replacement, via ``np.random``) from ``angles``.

    Args:
        N_t: number of rows of the precoder.
        N_s: number of columns of the precoder.
        angles: 1-D array-like of allowed phases in radians. Defaults to the
            notebook-level ``angles_t`` alphabet.

    Returns:
        Complex ndarray of shape ``(N_t, N_s)``.
    """
    if angles is None:
        angles = angles_t
    angles = np.asarray(angles)
    angle_index = np.random.choice(len(angles), (N_t, N_s))
    return (1/np.sqrt(N_t))*np.exp(1j*angles[angle_index])

def training_combiner(N_r, N_rx_rf, angles=None):
    """Draw a random phase-only training combiner.

    A ``(N_r, N_rx_rf)`` matrix with entries ``exp(1j*theta) / sqrt(N_r)`` is
    drawn (``theta`` uniform with replacement from ``angles`` via ``np.random``)
    and its conjugate transpose is returned.

    Args:
        N_r: number of columns of the returned combiner.
        N_rx_rf: number of rows of the returned combiner.
        angles: 1-D array-like of allowed phases in radians. Defaults to the
            notebook-level ``angles_r`` alphabet.

    Returns:
        ``np.matrix`` of shape ``(N_rx_rf, N_r)``. The matrix type is kept for
        backward compatibility with existing callers.
    """
    if angles is None:
        angles = angles_r
    angles = np.asarray(angles)
    angle_index = np.random.choice(len(angles), (N_r, N_rx_rf))
    W = (1/np.sqrt(N_r))*np.exp(1j*angles[angle_index])
    return np.matrix(W).getH()

# Evaluation settings.
ntest = 20  # number of test channels evaluated (H_ex[:, :, 0..ntest-1])
nrepeat = 5 # Different noise realizations per test channel
SNR_vec = range(15,20,5)  # SNR values in dB; this range contains only 15
alpha = 0.4  # pilot overhead: N_p = int(alpha*N_t) pilot symbols
# Running NMSE sums: one row per SNR, one column per evaluated checkpoint.
nmse_fedambgan = np.zeros((len(SNR_vec),len(model_vec)))
# NOTE(review): ct is incremented exactly once below (ending at 1); it looks like a
# leftover row counter from an outer sweep — confirm whether it is still needed.
ct = 0
N_p = int(alpha*N_t)
# Unit-energy QPSK alphabet used for the pilot symbols.
qpsk_constellation = (1/np.sqrt(2))*np.array([1+1j,1-1j,-1+1j,-1-1j])

# Random training setup, drawn once and shared by every checkpoint / SNR / channel.
# These statements consume np.random draws, so keep their order to keep results
# reproducible under a fixed seed.
pilot_sequence_ind = np.random.randint(0,4,size=(N_s,N_p))
symbols = qpsk_constellation[pilot_sequence_ind]  # (N_s, N_p) pilot symbols
precoder_training = training_precoder(N_t,N_s)
W = training_combiner(N_r,N_rx_rf)
# Measurement matrix: with column-major vec, vec(W @ H @ P @ S) = kron((P @ S).T, W) @ vec(H),
# and (P @ S).T == symbols.T @ precoder_training.T.
A = np.kron(np.matmul(symbols.T,precoder_training.T),W)

# Real / imaginary parts as torch tensors for the real-valued objective.
A_real = dtype(np.real(A))
A_imag = dtype(np.imag(A))
identity = np.identity(N_r)  # NOTE(review): not used anywhere in this view
# Weight of the ||x||^2 penalty on the latent vector.
lambda_reg = 1e-3
ct += 1
# Latent-space channel estimation with every saved global generator.
# For each checkpoint, SNR, noise realisation and test channel: simulate the
# received pilots, fit the latent vector x with Adam so that A @ vec(H(G(x)))
# matches the observation, then accumulate the NMSE of the recovered channel in
# nmse_fedambgan[snr_idx, model_idx] (averaged over ntest*nrepeat at the end).

learning_rate = 1e-1   # Adam step size for the latent search
n_latent_steps = 100   # Adam iterations per channel / noise draw
ckpt_template = '../../results/pilot_gan/U_4/fedpilotgan/cache/checkpoints/n_d_5/global_G_state%d.pkl'

# De-normalisation statistics as tensors, built once rather than on every loss call.
mu_real_t = dtype(mu_real)
mu_imag_t = dtype(mu_imag)
std_real_t = dtype(std_real)
std_imag_t = dtype(std_imag)

# Only x is optimised. Freezing the generator stops backward() from filling (and
# endlessly accumulating) .grad buffers on its parameters.
for param in G_test.parameters():
    param.requires_grad = False


def _latent_objective(x, y_real, y_imag):
    """Return ||y - A h(x)||^2 + lambda_reg*||x||^2 for the generator output.

    h(x) = A_T_R @ vec(de-normalised G_test(x)). Complex products are written out
    in real / imaginary parts so the loss is a real scalar for autograd.
    y_real / y_imag hold the real / imaginary parts of the received pilots as
    (N_rx_rf*N_p, 1) tensors.
    """
    pred = G_test(x)
    beam_real = (std_real_t*pred[0, 0, :, :] + mu_real_t).reshape(N_t*N_r, -1)
    beam_imag = (std_imag_t*pred[0, 1, :, :] + mu_imag_t).reshape(N_t*N_r, -1)
    h_real = torch.mm(A_T_R_real, beam_real) - torch.mm(A_T_R_imag, beam_imag)
    h_imag = torch.mm(A_T_R_real, beam_imag) + torch.mm(A_T_R_imag, beam_real)
    diff_real = y_real - torch.mm(A_real, h_real) + torch.mm(A_imag, h_imag)
    diff_imag = y_imag - torch.mm(A_real, h_imag) - torch.mm(A_imag, h_real)
    diff = torch.norm(diff_real)**2 + torch.norm(diff_imag)**2
    return diff + lambda_reg*torch.norm(x)**2


for model_idx, model in enumerate(model_vec):
    # NOTE(review): for the final round (model >= 600) the template is used verbatim,
    # i.e. a file literally named 'global_G_state%d.pkl' is loaded — presumably how
    # the training script saved its last generator. Confirm before "fixing" it.
    # torch.load unpickles the file: only load trusted checkpoints.
    ckpt_path = ckpt_template % model if model < 600 else ckpt_template
    G_test.load_state_dict(torch.load(ckpt_path))
    G_test.eval()
    # Index rows by SNR position (previously every SNR accumulated into row 0).
    for snr_idx, SNR in enumerate(SNR_vec):
        for i in range(nrepeat):
            for ind in range(ntest):
                H_true = H_ex[:, :, ind]
                vec_H_single = np.reshape(H_true.flatten('F'), [N_r*N_t, 1])
                signal = np.matmul(H_true, np.matmul(precoder_training, symbols))
                # NOTE(review): E_s is the elementwise |signal|^2, so the noise std is set
                # per entry rather than from the average signal power — confirm intent.
                E_s = np.multiply(signal, np.conj(signal))
                noise_matrix = (1/np.sqrt(2))*(np.random.randn(N_r,N_p)+1j*np.random.randn(N_r,N_p))
                std_dev = (1/(10**(SNR/20)))*np.sqrt(E_s)
                rx_signal = np.matmul(W, signal + np.multiply(std_dev, noise_matrix))
                # Column-major vectorisation of the combined observation, as complex64.
                vec_y = np.asarray(rx_signal).flatten('F').astype('complex64').reshape(-1, 1)
                vec_y_real = dtype(np.real(vec_y))
                vec_y_imag = dtype(np.imag(vec_y))

                x = torch.randn(1, latent_dim).type(dtype)
                x.requires_grad = True
                optimizer = torch.optim.Adam([x], lr=learning_rate)
                for _ in range(n_latent_steps):
                    optimizer.zero_grad()
                    loss = _latent_objective(x, vec_y_real, vec_y_imag)
                    loss.backward()
                    optimizer.step()

                with torch.no_grad():
                    gen_imgs = G_test(x).cpu().numpy()
                gen_imgs[0,0,:,:] = std_real*gen_imgs[0,0,:,:] + mu_real
                gen_imgs[0,1,:,:] = std_imag*gen_imgs[0,1,:,:] + mu_imag
                gen_imgs_complex = gen_imgs[0,0,:,:] + 1j*gen_imgs[0,1,:,:]
                gen_imgs_complex = np.matmul(A_T_R, np.reshape(gen_imgs_complex, [N_t*N_r, 1]))
                nmse = (np.linalg.norm(gen_imgs_complex - vec_H_single)/np.linalg.norm(vec_H_single))**2
                nmse_fedambgan[snr_idx, model_idx] += nmse
                print(nmse)
nmse_fedambgan = nmse_fedambgan/(ntest*nrepeat)

In [None]:
def smooth(x, window_len=4, window='hanning'):
    """Smooth a 1-D signal by convolving it with a normalised window.

    The signal is padded at both ends with ``window_len - 1`` reflected samples
    (endpoints excluded), so the output has the same length as the input.

    Args:
        x: 1-D array-like signal.
        window_len: window length in samples. Values below 3 disable smoothing
            and return ``x`` unchanged.
        window: 'flat' (moving average) or one of 'hanning', 'hamming',
            'bartlett', 'blackman' (the NumPy window functions of that name).

    Returns:
        ndarray of the same length as ``x`` (or ``x`` itself when ``window_len < 3``).

    Raises:
        ValueError: if ``x`` is not 1-D, is shorter than ``window_len``, or
            ``window`` is not a supported name.
    """
    if window_len < 3:
        return x
    x = np.asarray(x)
    if x.ndim != 1:
        raise ValueError("smooth only accepts 1-D arrays.")
    if x.size < window_len:
        raise ValueError("Input vector needs to be bigger than window size.")
    if window not in ('flat', 'hanning', 'hamming', 'bartlett', 'blackman'):
        raise ValueError("window must be 'flat', 'hanning', 'hamming', 'bartlett' or 'blackman'.")
    # Reflect window_len-1 samples about each endpoint.
    s = np.r_[x[window_len-1:0:-1], x, x[-2:-window_len-1:-1]]
    if window == 'flat':  # moving average
        w = np.ones(window_len, 'd')
    else:
        # getattr instead of eval: same NumPy window function, no code execution.
        w = getattr(np, window)(window_len)
    y = np.convolve(w/w.sum(), s, mode='valid')
    # y has len(x) + window_len - 1 samples; trim window_len - 1 of them so the
    # output is exactly len(x) long for odd as well as even window lengths.
    return y[(window_len - 1)//2 : -(window_len//2)]

In [None]:
from matplotlib.legend_handler import HandlerLine2D, HandlerTuple
# Compare per-round NMSE curves (smoothed, in dB) for n_d = 5 and n_d = 20.
# NOTE(review): nmse_CDL_A_ld_65, nmse_CDL_A_ld_65_n_d_20 and nmse_fedambgan_n_d_20
# are not created anywhere in this view — presumably left in the kernel by other
# runs, so this cell fails on a fresh kernel. Confirm their provenance.
model_vec_2 = range(0, 610, 20)


def _plot_nmse_db(nmse, fmt):
    """Plot every other column of the first row of `nmse` in smoothed dB; return the line."""
    every_other_round = np.arange(0, 61, 2)
    nmse_db = 10*np.log10(nmse)
    curve = smooth(nmse_db.T[every_other_round, 0], 6)
    line, = plt.plot(model_vec_2, curve, fmt)
    return line


p1 = _plot_nmse_db(nmse_CDL_A_ld_65, 'o-')
p2 = _plot_nmse_db(nmse_CDL_A_ld_65_n_d_20, 'v-')
p3 = _plot_nmse_db(nmse_fedambgan, 'o-')
p4 = _plot_nmse_db(nmse_fedambgan_n_d_20, 'v-')
plt.legend(
    [(p2, p4), (p1, p3)],
    [r'$n_d = 20$', r'$n_d = 5$'],
    numpoints=1,
    handler_map={tuple: HandlerTuple(ndivide=None)},
    loc='upper right',
)
plt.grid(ls=':')
plt.xlabel('Rounds')
plt.ylabel('NMSE(in dB)')
plt.xlim([-10, 610])
# NOTE(review): for a PDF, dpi only affects rasterised content; 10000 is far more than needed.
plt.savefig('../../results/FedGAN.pdf', dpi=10000)