# First Steps for a Demapper on GPU

In [1]:
import numba
from numba import cuda
from sigcom.tx.util import qam_alphabet
import numpy as np

In [2]:
# Warm-up: a QPSK (M = 4) alphabet is small enough to check by eye.
M = 4
qam = qam_alphabet(M)

In [3]:
# Render the constellation as a Python list literal, e.g. for embedding in
# generated source code.
sqam = '[{}]'.format(','.join(map(str, qam)))

In [4]:
# Reference constellation in GPU constant memory. Numba reads module-level
# globals as compile-time constants when it compiles the kernel, so the array
# can be referenced directly instead of being pasted into generated source
# code and exec()'d.
cqam = np.array(qam, dtype=np.complex128)

@cuda.jit
def func(out):
    """Copy the constant-memory constellation ``cqam`` into ``out``.

    Only ``min(len(out), len(cqam))`` entries are written, so the kernel works
    for any alphabet size and never writes past the end of ``out``. In this
    notebook it is launched with a single thread (``func[1, 1]``).
    """
    const_qam = cuda.const.array_like(cqam)
    for i in range(min(len(out), len(const_qam))):
        out[i] = const_qam[i]

In [5]:
# Device-side output buffer with one slot per constellation point (sized from
# the alphabet rather than a hard-coded 4).
out = cuda.device_array(len(qam), np.complex128)
func[1, 1](out)
out.copy_to_host()

array([ 0.70710678+0.70710678j,  0.70710678-0.70710678j,
       -0.70710678+0.70710678j, -0.70710678-0.70710678j])

In [6]:
# The kernel also accepts a host NumPy array; the copied symbols come back in
# `out` and must match the alphabet exactly.
out = np.zeros(len(qam), np.complex128)
func[1, 1](out)
np.testing.assert_array_equal(out, qam)
np.sum(np.abs(out - qam) ** 2)

0.0

# On to more serious stuff: bit-wise mutual information of a 16-QAM soft demapper on the GPU

In [16]:
import math
import cmath
import numba
import numpy as np
from numba import cuda
from sigcom.tx.util import make_cells
from sigcom.ch.util import make_noise
from sigcom.tx.util import qam_alphabet

# Simulation setup: N_cells transmitted cells of M-QAM and an SNR sweep in dB.
N_cells = 3240
M = 16
SNRs_dB = np.linspace(-10, 10, 100)

qam = qam_alphabet(M)
tx, bits = make_cells(qam, N_cells)
noise = make_noise(N_cells)
# The GPU kernel works on single-precision complex samples.
tx, noise = (np.asarray(arr, np.complex64) for arr in (tx, noise))

In [17]:
# Python list literal of the 16-QAM points (same rendering as for QPSK above).
sqam = '[' + ','.join(map(str, qam)) + ']'

In [18]:
@cuda.jit(device=True)
def max_star(a, b):
    """Jacobian logarithm ("max-star"): return log(exp(a) + exp(b)).

    Evaluated as max(a, b) + log1p(exp(-|a - b|)), which cannot overflow and
    uses log1p for accuracy when the correction term is small.
    """
    if a == b:
        # Also covers a == b == -inf (log of 0 stays -inf) and +inf, where
        # a - b would be NaN.
        return a + math.log(2.0)
    return max(a, b) + math.log1p(math.exp(-abs(a - b)))

# The constellation lives in GPU constant memory. Numba treats module-level
# globals as compile-time constants when it compiles the kernel (on its first
# launch), so the array and the two sizes below can be referenced directly --
# no need to build the kernel source as a string and exec() it.
cqam = np.array(qam, dtype=np.complex128)
N_SYMBOLS = len(cqam)                      # constellation size M
BITS_PER_SYMBOL = int(np.log2(N_SYMBOLS))  # label bits per symbol (ld M)


@cuda.jit
def func(MIs, tx, noise, SNRs_dB, bits):
    """Bit-wise mutual information of the soft demapper, one SNR point per thread.

    Thread ``tid`` handles ``SNRs_dB[tid]``. Every cell is received as
    ``rx = tx[k] + noise[k] / sqrt(SNR)``. The symbol metrics
    ``-SNR * |rx - qam[m]|**2`` are combined with ``max_star`` into one LLR
    per label bit (label bits taken MSB first; LLR = log-metric over symbols
    with bit 0 minus that over symbols with bit 1). The value stored in
    ``MIs[tid]`` is
    ``BITS_PER_SYMBOL - (1/len(tx)) * sum_{k,m} log2(1 + exp(-(1 - 2*b) * LLR))``.

    Parameters
    ----------
    MIs : output array, written at index ``tid`` (float32 in this notebook).
    tx, noise : per-cell transmitted symbols and noise samples; ``noise`` is
        scaled by ``1/sqrt(SNR)`` here.
    SNRs_dB : SNR sweep in dB. Threads with ``tid >= len(SNRs_dB)`` return
        immediately, so the grid may be larger than the sweep.
    bits : transmitted label bits, ``BITS_PER_SYMBOL`` per cell, indexed as
        ``k * BITS_PER_SYMBOL + m`` (presumably 0/1 values -- TODO confirm
        against make_cells).
    """
    tid = cuda.grid(1)
    if tid >= SNRs_dB.shape[0]:
        return  # surplus thread: grid is larger than the SNR sweep
    qam_c = cuda.const.array_like(cqam)
    D = cuda.local.array(N_SYMBOLS, numba.float32)  # per-symbol log-metrics
    SNR = 10**(SNRs_dB[tid]/10)
    noise_scale = math.sqrt(SNR)  # loop-invariant, hoisted out of the cell loop
    n_cells = len(tx)
    MI = 0.
    for k in range(n_cells):
        rx = tx[k] + noise[k]/noise_scale
        for m in range(N_SYMBOLS):
            D[m] = -SNR*abs(rx - qam_c[m])**2
        for m in range(BITS_PER_SYMBOL):
            shift = BITS_PER_SYMBOL - 1 - m  # bit m of the label, MSB first
            num = -np.inf  # max-star sum over symbols whose bit m is 0
            den = -np.inf  # ... and over symbols whose bit m is 1
            for i in range(N_SYMBOLS):
                if (i >> shift) & 1:
                    den = max_star(den, D[i])
                else:
                    num = max_star(num, D[i])
            L = num - den
            sign = 1. - 2.*np.float32(bits[k*BITS_PER_SYMBOL + m])
            x = -sign*L
            # log(1 + exp(x)) as a stable softplus so that exp() cannot
            # overflow for large positive x (confident wrong decisions).
            MI -= max(x, 0.) + math.log1p(math.exp(-abs(x)))
    MIs[tid] = BITS_PER_SYMBOL + MI/math.log(2.0)/n_cells

In [19]:
# One single-thread block per SNR point; each thread fills its own entry of MIs.
MIs = np.zeros_like(SNRs_dB, dtype=np.float32)
func[MIs.size, 1](MIs, tx, noise, SNRs_dB, bits)
MIs

array([0.10373498, 0.10886358, 0.1142291 , 0.11984152, 0.12571116,
       0.13184866, 0.13826503, 0.14497162, 0.15198015, 0.15930262,
       0.16695149, 0.1749395 , 0.18327972, 0.1919856 , 0.20107095,
       0.21054983, 0.22043669, 0.23074624, 0.2414935 , 0.25269374,
       0.26436257, 0.27651575, 0.2891693 , 0.30233943, 0.31604245,
       0.33029488, 0.3451133 , 0.36051434, 0.3765146 , 0.39313075,
       0.41037932, 0.42827663, 0.44683897, 0.46608227, 0.48602217,
       0.506674  , 0.5280525 , 0.5501721 , 0.5730465 , 0.5966888 ,
       0.6211114 , 0.64632595, 0.67234313, 0.69917285, 0.72682387,
       0.7553041 , 0.78462017, 0.81477773, 0.8457811 , 0.87763333,
       0.91033643, 0.9438907 , 0.9782952 , 1.0135474 , 1.0496433 ,
       1.0865769 , 1.1243407 , 1.1629248 , 1.2023174 , 1.2425044 ,
       1.2834688 , 1.3251916 , 1.3676507 , 1.4108212 , 1.4546759 ,
       1.4991848 , 1.5443165 , 1.5900375 , 1.6363134 , 1.6831101 ,
       1.7303936 , 1.7781316 , 1.826293  , 1.8748493 , 1.92377

In [24]:
%%timeit
# Launch config differs from the MIs cell above: 1 block of len(SNRs_dB)
# threads instead of len(SNRs_dB) single-thread blocks.
# NOTE(review): the arguments are host NumPy arrays, so Numba presumably
# copies them to the device and back on every launch and this timing likely
# includes those transfers -- confirm; pre-transfer with cuda.to_device(...)
# (and call cuda.synchronize()) to time the kernel alone.
func[1, len(SNRs_dB)](MIs, tx, noise, SNRs_dB, bits)

290 ms ± 718 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [21]:
from sigcom.rx.util import demap
from sigcom.it.util import mutual_information_magic
# Unit channel coefficient for every cell for the CPU reference demapper
# (presumably "no fading"; rx below is built as tx plus scaled noise only).
h = np.full(N_cells, 1.0)

In [22]:
ldM = int(np.log2(M))  # label bits per constellation symbol
MIs_ref = []
for SNR_dB in SNRs_dB:
    SNR = 10**(SNR_dB/10)
    rx = tx + noise/np.sqrt(SNR)  # same received-signal model as the GPU kernel
    L = demap(rx, qam, SNR, h)    # CPU reference demapper output (presumably LLRs)
    # Scaled by ldM so it is comparable with MIs (presumably
    # mutual_information_magic returns a per-bit value -- confirm).
    MIs_ref += [ldM*mutual_information_magic(L, bits, 1)]

In [23]:
# The GPU kernel uses single-precision metrics and a float32 result array, so
# only float32-level agreement with the CPU reference is expected; anything
# much larger indicates a bug, so fail loudly instead of just displaying it.
np.testing.assert_allclose(MIs, MIs_ref, rtol=0, atol=1e-4)
MIs - MIs_ref

array([ 9.51215373e-09,  2.93175617e-09,  3.73125353e-09,  8.87571350e-09,
       -5.43964074e-09,  1.07390825e-08,  5.63157831e-09, -9.61374980e-09,
       -1.18726806e-10,  1.64571614e-08,  1.59625011e-08,  9.26411037e-09,
        6.27255137e-09, -1.05767084e-09,  4.01569933e-09, -4.27432356e-09,
       -1.65667480e-09,  1.13976149e-08,  2.27673391e-08,  2.86179729e-08,
        3.65752317e-08,  1.49281876e-08, -4.38769976e-08, -3.30288072e-08,
       -1.32656348e-08, -2.67508384e-08,  2.22647274e-08,  3.37641453e-08,
        3.68247566e-09, -8.58906724e-10,  1.27987700e-08, -5.05149020e-08,
       -2.44812726e-08,  1.87151126e-08, -1.50432702e-08, -9.27386745e-09,
        3.06337768e-08, -5.48898971e-09, -4.72665054e-08, -7.51063745e-09,
       -2.05869268e-08,  6.53187415e-09,  3.17184501e-10, -7.53238809e-08,
       -5.73102859e-08, -1.76345325e-08,  9.76443504e-09,  1.02105964e-07,
        4.94601275e-08, -9.40791756e-08,  7.50234612e-08, -8.08070602e-08,
        5.75490771e-08, -