## LDPC + RL
This notebook contains the experiments conducted as part of the work for the ITW 2024 submissions.

In [None]:
%pylab inline 
import matlab
import numpy as np
import scipy as sp
import networkx as nx 
import matlab.engine
from utils import *
# Start a MATLAB engine session; the cells below drive MATLAB through this handle.
MATLAB = matlab.engine.start_matlab()
# Flag read by the setup cell below: 0 means the one-time MATLAB setup step has not run yet.
initial = 0

The parity-check matrix used here is the 5G NR base matrix.

The `SubMatrix` is represented by:

In [None]:
# One-time MATLAB setup. The eval string changes MATLAB's working directory
# with a relative `cd`, so the `initial` flag keeps it from running twice
# when this cell is re-executed.
if not initial:
    MATLAB.eval("addpath('LDPC_Matlab'); cd 'LDPC_Matlab/'; setup;",nargout=0)
    initial = 1

# Pull the results of `setup` out of the MATLAB workspace:
#   z -> Z            (scalar, cast to a Python int)
#   H -> h_submatrix  (integer NumPy array; m rows, n columns)
Z = int(MATLAB.workspace['z'])
h_submatrix = np.array(MATLAB.workspace['H'], dtype=int)
m, n = h_submatrix.shape

In [None]:
# Expand the submatrix with factor Z into H via SubMatrix2PCM (from utils)
# — presumably the full binary parity-check matrix; confirm against utils.
H = SubMatrix2PCM(h_submatrix, Z)
M,N = shape(H)
# Sparse copy of H, used as the biadjacency matrix of the bipartite graph G.
h = sp.sparse.csr_array(H)
# NOTE(review): NetworkX documents that row nodes are numbered 0..M-1 and
# column nodes M..M+N-1; later cells compare node ids against M — confirm.
G = nx.bipartite.from_biadjacency_matrix(h)

In [None]:
# Two node sets of the bipartition of G (named here as check / variable nodes).
check_nds, var_nds = nx.bipartite.sets(G)

# Cycle length of interest. K is also handed to simple_cycles as its second
# positional argument (the length bound in recent NetworkX — confirm version).
K = 6

# Keep only the cycles that have exactly K nodes.
a = [cyc for cyc in nx.simple_cycles(G, K) if len(cyc) == K]

# A cycle whose first node id is >= M is rotated left by one position;
# every other cycle is kept as-is.
# presumably ids < M are check nodes, so each cycle then starts on one — TODO confirm
b = [cyc[1:] + cyc[:1] if cyc[0] >= M else cyc for cyc in a]

In [None]:
# SANITY CHECK
# len(b)
# checjs = set()
# for cyc in b:
#     for node in cyc:
#         checjs.add(node)

# len(checjs)

In [None]:
# Clustering check: build clusters from the cycle list `b` with cluster_form
# (from utils) and display the result.
# NOTE(review): the literals 2 and 6 look like the cluster size and the cycle
# length (K in the cycle cell) — confirm against cluster_form's signature.
clusters = cluster_form(2, b, 6, M)
clusters

In [None]:
# Export the clusters and their neighbouring variable nodes to MATLAB as
# cell arrays, then save the MATLAB workspace to 'Imp.mat'.
num_clusters = len(clusters)

# `clusters` is already a dictionary (cluster index -> nodes of that cluster).
vns_in_cluster = { cluster_idx: NeighborVN(cluster, G, M) for cluster_idx, cluster in clusters.items() }

MATLAB.workspace['num_clusters'] = matlab.int64(num_clusters)


def _matlab_row_vector(nodes):
    """Render 0-based node indices as a one-line, 1-based MATLAB row vector.

    The literal is built by hand instead of with str(ndarray): NumPy wraps
    printed arrays at 75 characters and abbreviates them with '...' above
    1000 elements. MATLAB would read a wrapped array as several rows and an
    abbreviated one as a broken literal.
    """
    one_based = np.asarray(nodes, dtype=int).ravel() + 1
    return "[" + " ".join(str(int(v)) for v in one_based) + "]"


# Cell arrays are used because the node sets have irregular lengths.
MATLAB.eval("clusters = cell(num_clusters,1);", nargout=0)
MATLAB.eval("vns_in_cluster = cell(num_clusters,1);", nargout=0)

# Assumes the dictionary keys are 0..num_clusters-1, as the original loop did.
for i in range(num_clusters):
    # The trailing ';' stops MATLAB from echoing every assignment.
    MATLAB.eval("clusters{%d} = %s;" % (i + 1, _matlab_row_vector(clusters[i])), nargout=0)
    MATLAB.eval("vns_in_cluster{%d} = %s;" % (i + 1, _matlab_row_vector(vns_in_cluster[i])), nargout=0)

MATLAB.eval("save('Imp.mat')",nargout=0)

#### RELDEC Algorithm 

prerequisites:
-   Cycle-maximizing algorithm that allocates clusters of size `z` from `m` bitnodes (done ✓)
-   State space representation of CNs (and hence neighbouring VNs)

The plan is to keep the original indices intact so that the order of nodes through which states get updated can be kept constant. 

All the RL steps can be implemented with Python data structures.


All the BP local-flooding computations can be done in the MATLAB scripts that support this notebook.

In [None]:
# ---- RELDEC configuration ----
# Largest per-row count of entries different from -1 in MATLAB's `a`
# (max number of varnodes per parity check eq).
k_max = int(MATLAB.eval("max(sum(a~=-1, 2)')"))
z = 2                      # cluster size handed to cluster_form
# NOTE(review): tau is derived from m (rows of the submatrix) while
# cluster_form below is given M (rows of H) — confirm that the number of
# clusters really equals tau.
tau = int(np.ceil(m / z))  # number of columns of Q / upper bound for the random action
_l_max = 50                # iteration cap of the per-sample while-loop
epsilon = 0.7              # threshold of the random-action branch

# clusters = cluster_form(81, b, 6, M)
clusters = cluster_form(z, b, 6, M)

A = list(range(tau))       # action indices 0..tau-1
S = dict()                 # filled per sample in the RELDEC loop
Q = np.zeros((2 ** (z * k_max), tau))

# Load the dataset into the MATLAB workspace and copy L and C into NumPy.
# NOTE(review): the setup cell uses the directory 'LDPC_Matlab'; confirm that
# 'LDPC_M' is the intended location of the dataset.
MATLAB.eval("load('./LDPC_M/L_dataset.mat');", nargout =0)
L, C = (np.array(matlab.double(MATLAB.workspace[name])) for name in ('L', 'C'))

ndata = L.shape[0]         # one dataset entry per row of L

In [None]:
# RELDEC main loop (work in progress).
# For every dataset row: take the row of L as `l` and the row of C as `c`,
# initialise the per-cluster state table S, then run up to _l_max iterations
# that pick a cluster, call MATLAB's local_flood and compute a reward.
# NOTE(review): in the visible code `Reward`, `s_a` and `L_hat_l` are never
# read again and Q is never updated — the learning step still looks missing.
for idx in range(ndata):
    l = L[idx,:]
    c = C[idx,:]

    _l = 0  # iteration counter for this sample
    L_hat_l = l

    # Initialise S: one entry per cluster, int_m applied to the sign pattern
    # (l < 0) of the cluster's neighbouring variable nodes.
    # presumably int_m packs the boolean vector into one integer — confirm in utils.
    # NOTE(review): NeighborVN is called with a single argument here, while the
    # MATLAB-export cell calls it as NeighborVN(cluster, G, M) — confirm signature.

    for cluster_idx, cluster in clusters.items():
        S.update({cluster_idx: int_m(l[NeighborVN(cluster)] < 0)})

    while _l < _l_max:

        # Random cluster index in [0, tau) when the random draw is <= epsilon,
        # otherwise an argmax over Q.
        # NOTE(review): `a` rebinds the name that held the cycle list earlier
        # in the notebook.
        if random.random() <= epsilon:
            a = np.random.randint(0,tau)
        else:
            # NOTE(review): Q rows are selected with the keys of S (cluster
            # indices), and argmax over axis=0 of a 2-D slice yields an array,
            # which is then used as a dictionary key in clusters[a] — this
            # looks unintended; confirm the intended greedy rule.
            a = argmax(Q[list(S.keys()), :], axis=0)

        cluster_a = clusters[a]
        # NOTE(review): same single-argument NeighborVN call as above.
        vns_at_cluster = NeighborVN(cluster_a)
        l_a = len(vns_at_cluster)

        # Hand the chosen cluster to MATLAB through the workspace, then call local_flood.
        # NOTE(review): with nargout=0 the MATLAB engine returns no outputs, so
        # the three-way unpacking below cannot succeed — confirm whether
        # local_flood is a function (nargout=3) or a script writing to the workspace.
        MATLAB.workspace['curr_clust'] = matlab.double(cluster_a)
        l_hat, CNarray, VNarray = MATLAB.local_flood(nargout=0)

        # 0/1 decisions from the sign of l_hat at the cluster's variable nodes,
        # and the same pattern passed through int_m.
        x_hat_a = (l_hat[vns_at_cluster] < 0)*1
        s_a = int_m(x_hat_a)

        # Fraction of the cluster's variable nodes whose decision equals c.
        Reward = (1/l_a)*sum(x_hat_a == c[vns_at_cluster])



        _l += 1