In [20]:
import math
import numpy as np
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import qr
import itertools
from scipy.special import erfc
import galois
# from Dec import Dec

In [21]:
import pickle

def load_value_function(filename="MI_policy_1.pkl"):
    with open(filename, "rb") as f:
        Q = pickle.load(f)
    print(f"Q-function loaded from {filename}")
    return Q

def save_value_function(V, filename="value_function.pkl"):
    with open(filename, "wb") as f:
        pickle.dump(V, f)
    print(f"Value function saved to {filename}")




In [22]:
STATE_DIM = 2
BINS = 100

q_table = load_value_function("./pre_trained_policies/MI_policy_1.pkl")

Q-function loaded from ./pre_trained_policies/MI_policy_1.pkl


In [23]:
def quantize(s, bins):
    clamped = np.clip(s, -1, 1)
    indices = []
    bin_width = 2 / bins
    for value in clamped:
        index = int((value + 1) / bin_width)
        index = min(index, bins - 1)  # Ensure we stay within bounds
        indices.append(index)
    return tuple(indices)


def dequantize(disc_s, bins):
    continuous = []
    bin_width = 2 / bins
    for index in disc_s:
        midpoint = -1 + (index + 0.5) * bin_width
        continuous.append(midpoint)
    return np.array(continuous)


# sanity check for loaded q-table
for k in range(10):
    s = np.random.uniform(0,0.2,STATE_DIM)
    s_q = quantize(s,BINS)
    v1, v2 = q_table.get((s_q,0),np.nan), q_table.get((s_q,1),np.nan)
    print(f"state = {s} action = {0,1} q_val = {v1,v2}")


state = [0.06137895 0.18014374] action = (0, 1) q_val = (0.16795327686491965, 0.24359347788766644)
state = [0.12398386 0.06001203] action = (0, 1) q_val = (0.5865421822759235, nan)
state = [0.13487469 0.16082073] action = (0, 1) q_val = (0.3001331840129082, 0.25230382373549515)
state = [0.18468536 0.15259966] action = (0, 1) q_val = (nan, 0.30693978158199486)
state = [0.00265523 0.16455167] action = (0, 1) q_val = (0.276456239222565, 0.2407741563672055)
state = [0.15210064 0.04057916] action = (0, 1) q_val = (nan, nan)
state = [0.08732466 0.16387124] action = (0, 1) q_val = (0.35410504606895027, 0.26835267593716083)
state = [0.18108845 0.00606642] action = (0, 1) q_val = (nan, nan)
state = [0.06206183 0.02108036] action = (0, 1) q_val = (0.1596591452537836, 0.12905474684336127)
state = [0.09797329 0.16449954] action = (0, 1) q_val = (0.35410504606895027, 0.26835267593716083)


In [25]:
import math
import numpy as np

class Decoder:
    def __init__ (self,H, channel_model, channel_parameters, num_iter,cluster_size):
        # code parameters
        self.n = H.shape[1]
        self.k = self.n - H.shape[0]
        self.H = H
        self.model = channel_model
        self.params = channel_parameters
        self.num_iter = num_iter
        
        # misc parameters (epsilon : for numerical stability)
        self.ep = 1e-5
        
        # graph parameters
        self.num_VN = self.n
        self.num_CN = self.n-self.k
        self.cluster_size = cluster_size
        
        # adjacency list for VN and CN
        self.CN = []
        self.VN = []
        self.construct_graph(H)
        self.initialize_clusters(self.cluster_size)
        self.iteration_number = 0
        self.policy = None # policy for choosing clusters (using RL)
        self.num_clusters = math.ceil(self.num_CN/cluster_size)
        # self.q_net = self.load_model("./agents/policy_net.pth", state_dim=self.num_clusters, action_dim=self.num_clusters)
        self.MI_counter = 0
        
    # build the adjacency list for Tanner Graph
    def construct_graph(self,H):
        for i in range(self.num_CN):
            temp = []
            for j in range(self.num_VN):
                if H[i,j]==1:
                    temp.append(j)
            self.CN.append(temp)
            
        for i in range(self.num_VN):
            temp = []
            for j in range(self.num_CN):
                if H[j,i]==1:
                    temp.append(j)
            self.VN.append(temp)
    
    # print the tanner graph
    def print_graph(self,mode):
        if mode=="matrix":
            print(self.H)
        elif mode=="list":
            print('CN : ',self.CN)
            print('VN : ',self.VN)
        else:
            print('Invalid mode')
    
    
    # initialize clusters of CNs
    def initialize_clusters(self,cluster_size):
        self.num_clusters = math.ceil(self.num_CN/cluster_size)
        self.MI = np.zeros(self.num_clusters) # for storing current MI of clusters

        self.clusters = []
        for i in np.arange(0,self.num_CN,cluster_size):
            temp = []
            for j in range(i,min(i+cluster_size,self.num_CN)):
                temp.append(j)
            self.clusters.append(temp)  
            

                    
    # returns the absolute min, 2nd min and parity of the input array
    def get_min(self,arr):
        arr = np.array(arr)
        arr = arr[~np.isnan(arr)]
        parity = np.prod(np.sign(arr))
        if len(arr) < 2:
            print("arr = ",arr)
            raise ValueError("Not enough valid elements in the array.")
        arr = np.sort(np.abs(arr))
        return arr[0],arr[1],parity
    
    # perform row update of cluster a 
    def row_update(self,a):
        # subtract step (removing known beliefs)
        for j in range(self.num_VN):
            tot = 0
            for i in self.clusters[a]:
                if not np.isnan(self.L[i,j]):
                    tot = tot + self.L[i,j]
            self.sum[j] = self.sum[j] - tot

        # flow down the sum into the cluster a
        for j in range(self.num_VN):
            for i in self.clusters[a]:
                if not np.isnan(self.L[i,j]):
                    self.L[i,j] = self.sum[j]

        # perform minsum for each row
        for i in self.clusters[a]:
            m1,m2,p = self.get_min(self.L[i])
            for j in range(self.num_VN):
                if not np.isnan(self.L[i,j]):
                    self.L[i,j] = p*np.sign(self.L[i,j])*m2 if np.abs(self.L[i,j])==m1 else p*np.sign(self.L[i,j])*m1

    def col_update(self,a):
        for j in range(self.num_VN):
            tot = 0
            for i in self.clusters[a]:
                if not np.isnan(self.L[i,j]):
                    tot = tot + self.L[i,j]
            self.sum[j] = self.sum[j] + tot
            
    def pseudo_row_update(self,a,mode=None):
        # subtract step (removing known beliefs)
        p_L = self.L.copy()
        p_sum = self.sum.copy()

        if mode=="debug":
            print(type(self.L))
            print(type(self.sum))
            print("Init L = ",self.L)
            print("Init sum = ",self.sum)
        
        for j in range(self.num_VN):
            tot = 0
            for i in self.clusters[a]:
                if not np.isnan(p_L[i,j]):
                    tot = tot + p_L[i,j]
            p_sum[j] = p_sum[j] - tot

        # flow down the sum into the cluster a
        for j in range(self.num_VN):
            for i in self.clusters[a]:
                if not np.isnan(p_L[i,j]):
                    p_L[i,j] = p_sum[j]

        # perform minsum for each row
        for i in self.clusters[a]:
            m1,m2,p = self.get_min(p_L[i])
            for j in range(self.num_VN):
                if not np.isnan(p_L[i,j]):
                    p_L[i,j] = p*np.sign(p_L[i,j])*m2 if np.abs(p_L[i,j])==m1 else p*np.sign(p_L[i,j])*m1
        # print("new Init sum = ",self.sum)
        return p_L, p_sum

    # def pseudo_col_update(self, a, p_L, p_sum):
    #     for j in range(self.num_VN):
    #         tot = 0
    #         for i in self.clusters[a]:
    #             if not np.isnan(p_L[i,j]):
    #                 tot = tot + p_L[i,j]
    #         p_sum[j] = p_sum[j] + tot
        
    #     return p_L


    def decode(self,y,verbose="off",scheduling="round-robin", prompt=False):
        # print("Started decoding...\n")
        self.prompt = prompt
        # sum is initialized by incoming beliefs (initial LLRs)
        self.sum = y
        
        # initializing the storage matrix
        self.L = np.full(self.H.shape, np.nan, dtype=float)  # Initialize with NaNs
        for i in range(self.num_CN):
            for j in range(self.num_VN):
                self.L[i,j] = 0 if self.H[i,j]==1 else np.nan
        
#         # initial MI approximation
#         self.MI = self.get_MI(self.L)
        self.prev_MI = np.zeros(self.num_clusters)
        for k in range(self.num_iter):
            self.MI = self.get_MI(self.L)
            # print("Mutual Info : ",self.MI)

            # print(f"Iter : {k} MI : {self.MI}",end="")
            self.iteration_number = k
            # state = self.MI - self.prev_MI
            self.prev_MI = self.MI
            # print("---------------------------------------------------------------------------------------")
            a = self.get_next_cluster(k,scheduling)
            # print("---------------------------------------------------------------------------------------")

            # if verbose=="on":
            # print("Iteration : ",k,"\tCluster scheduled : ",a)
            self.row_update(a)
            self.col_update(a)


            # pppp_L_1, _ = self.pseudo_row_update(a,mode='debug')
            # mmmi_1 = self.get_MI(pppp_L_1)
            # print("MI_1 = ",mmmi_1)
            # pppp_L_2, _ = self.pseudo_row_update(a,mode='debug')
            # mmmi_2 = self.get_MI(pppp_L_2)
            # print("MI_2 = ",mmmi_2)

            c_hat = np.array([1 if val<0 else 0 for val in self.sum])
            if self.stopping_criteria(k,c_hat):
                break
        return c_hat
    
    def get_MI(self,L):
        var = np.zeros(self.num_clusters)
        for a in range(self.num_clusters):
            total = 0
            count = 0
            for i in self.clusters[a]:
                for j in self.CN[i]:
                    total = total + np.abs(L[i,j])
                    count = count + 1
            mean = total/count
            var[a] = mean*2
        
        MI = np.zeros(self.num_clusters)
        for a in range(self.num_clusters):
            MI[a] = self.J(var[a])
        
        return MI

        
    def J(self,sigma):
        # constants
        a_J1 = -0.0421061
        b_J1 = 0.209252
        c_J1 = -0.00640081

        a_J2 = 0.00181491
        b_J2 = -0.142675
        c_J2 = -0.0822054
        d_J2 = 0.0549608

        # threshold sigma*
        sigma_star = 1.6363
        
        if 0 <= sigma <= sigma_star:
            return a_J1 * sigma**3 + b_J1 * sigma**2 + c_J1 * sigma
        elif sigma_star < sigma < 10:
            exponent = a_J2 * sigma**3 + b_J2 * sigma**2 + c_J2 * sigma + d_J2
            return 1 - math.exp(exponent)
        else:  # sigma >= 10
            return 1

    # Load the trained model
    def load_model(self, filepath, state_dim, action_dim):
        model = DQN(state_dim, action_dim)
        model.load_state_dict(torch.load(filepath))
        model.eval()  # Set the model to evaluation mode
        print(f"Model loaded from {filepath}")
        return model

        
    def stopping_criteria(self,k,c_hat):
        c1 = k>self.num_iter
        c2 = (np.sum(np.dot(self.H,c_hat.T)%2)==0)
        return c1 or c2
    
    # calculates the difference between information metrics for cluster a
    def calculate_residual(self, info, info_new, a):
        return info_new[a] - info[a]
    
    def get_next_cluster(self,iter_number,scheduling,state=None):
        # round robin
        if scheduling=="round-robin":
            return (iter_number%self.num_clusters)
        
        elif scheduling=="residual-llr":
            llr_gains = np.zeros(self.num_clusters)
            for a in range(self.num_clusters):
                p_L, p_sum = self.pseudo_row_update(a)
                max_diff = -100000
                for j in self.CN[a]:
                    max_diff = max(np.abs(self.L[a,j]-p_L[a,j]),max_diff)
                llr_gains[a] = max_diff
            return np.argmax(llr_gains)
                
        elif scheduling=="rl":
            # state = self.get_MI(self.L)
            state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
            with torch.no_grad():  # Disable gradient computation
                q_values = self.q_net(state_tensor)
                optimal_action = torch.argmax(q_values).item()  # Get the index of the max Q-value
            return optimal_action
        

        elif scheduling=="rl-q-learning":
            
            MI_gains = np.zeros(self.num_clusters)
            for a in range(self.num_clusters):
                p_L, p_sum = self.pseudo_row_update(a)
                p_L = self.pseudo_col_update(a,p_L,p_sum)
                MI_new = self.get_MI(p_L)
                gain = self.calculate_residual(self.MI, MI_new, a)
                MI_gains[a] = gain


            S = MI_gains
            S_q = quantize(S,BINS)
            return np.argmax(loaded_q_table[S_q])

        elif scheduling=="policy":
            # print("Getting cluster using POLICY")
            MI_gains = np.zeros(self.num_clusters)
            pseudo_MI = np.zeros(self.num_clusters)
            for a in range(self.num_clusters):
                pp_L, pp_sum = self.pseudo_row_update(a)
                # print("2nd row update \n",pp_L)
                # p_L = self.pseudo_col_update(a,p_L,p_sum)
                MI_new = self.get_MI(pp_L)
                gain = MI_new[a]-self.MI[a]
                MI_gains[a] = gain
                pseudo_MI[a] = MI_new[a]
            state = MI_gains
            print("state caculation started ...")
            print(f"Actual MI : {self.MI}")
            print(f"Pseudo MI : {pseudo_MI}")
            print(f"Gain : {MI_gains} = {state}")
            print("State calculation ended ...")


            # state = self.get_state("pseudo-difference")
            s_d = quantize(state,BINS)
            # print("s : ",state," d_s : ",s_d," MI : ",self.MI,end="")
            val0,val1 = q_table.get((s_d,0),np.nan),q_table.get((s_d,1),np.nan)
            # print(f" v = {val0},{val1}",end=" ")
            a = None
            
            if state[0]!=0 and state[1]!=0:
                if val0==val0 and val1==val1:
                    print("Choosing from pi*\n")
                    MI_gains = np.zeros(self.num_clusters)
                    pseudo_MI = np.zeros(self.num_clusters)
                    for a in range(self.num_clusters):
                        pp_L, pp_sum = self.pseudo_row_update(a)
                        # print("2nd row update \n",pp_L)
                        # p_L = self.pseudo_col_update(a,p_L,p_sum)
                        MI_new = self.get_MI(pp_L)
                        gain = MI_new[a]-self.MI[a]
                        MI_gains[a] = gain
                        pseudo_MI[a] = MI_new[a]
                    print("State = ",state)
                    print("Q(s,a) = ",[val0, val1])
                    print("Gains from greedy : ",MI_gains)
                    a2 = np.argmax(MI_gains)
                    print("Action chosen from greedy : ",a2)
                    print("Actual MI = ",self.MI)
                    print("pseudo MI = ",pseudo_MI)
                    a1 = np.argmax(np.array([val0, val1]))
                    print("Action policy = ",a1)

                    print(f"policy = pi1*\tQ={[val0,val1]} state={state} a1={a1} MI_gains={MI_gains}\ta2={a2}")
                    if self.prompt:
                        temp = input("Enter to proceed...")
                        if temp=="q":
                            self.prompt = False

                    return a1
                elif val0!=val0 and val1==val1:
                    print(f"policy = pi2*\tQ={[val0,val1]}\ta1={a1}")
                    return 1
                elif val0==val0 and val1!=val1:
                    print(f"policy = pi2*\tQ={[val0,val1]}\ta1={a1}")
                    return 0

        
        self.MI_counter = self.MI_counter+1
        # Information Gain
        MI_gains = np.zeros(self.num_clusters)
        for a in range(self.num_clusters):
            p_L, p_sum = self.pseudo_row_update(a)
            # p_L = self.pseudo_col_update(a,p_L,p_sum)
            MI_new = self.get_MI(p_L)
            gain = self.calculate_residual(self.MI, MI_new, a)
            MI_gains[a] = gain
            
        print(f"policy = MI*\tMI_gains={MI_gains}\ta2={a}")
        a = np.argmax(MI_gains)
        return a
    

    # def get_state(self, state_definition):
    #     print("Starting state calculation...")
    #     # self.MI = self.get_MI(self.L)
    #     print("current MI = ",self.MI)
    #     pi = np.zeros(self.num_clusters)
    #     if state_definition=="pseudo-difference":
    #         MI_gains = np.zeros(self.num_clusters)
    #         for a in range(self.num_clusters):
    #             ppp_L, p_sum = self.pseudo_row_update(a)
    #             MI_new = self.get_MI(ppp_L)
    #             gain = MI_new[a] - self.MI[a]
    #             pi[a] = MI_new[a]
    #             MI_gains[a] = gain
    #         print("Pseudo-calc for state = ",pi)
    #         print("MI gain = ",MI_gains)
    #         print("Ended state calculation...")
    #         return MI_gains
    #     elif state_definition=="MI":
    #         state = np.zeros(self.num_clusters)
    #         state = self.MI
    #         return state  

        
    
        

In [26]:
def circular_shift_identity(z, k):
    if k == -1:
        return np.zeros((z, z), dtype=int)
    else:
        return np.roll(np.eye(z, dtype=int), -k, axis=1)

def expand_base_matrix(B, z):
    m, n = B.shape
    H = np.zeros((m * z, n * z), dtype=int)  # Full parity check matrix initialized to zeros

    for i in range(m):
        for j in range(n):
            # Expand each base matrix element into a z x z block
            block = circular_shift_identity(z, B[i, j])
            H[i * z:(i + 1) * z, j * z:(j + 1) * z] = block

    return H
# generates list of codewords corresponding to the parity check matrix H
def get_codewords(H):
    GF = galois.GF(2)
    A = GF(H)
    N = A.null_space()
    vec = [np.array(v) for v in itertools.product([0, 1], repeat=N.shape[0])]
    C = np.array(GF(vec) @ N)
    return C

In [27]:
B = np.array([[0,2,-1,3,-1,-1,2,4,1,0],
            [3,1,-1,-1,0,0,1,2,3,0],
            [1,0,0,-1,1,4,2,1,0,-1],
            [-1,0,0,2,3,-1,-1,-1,0,-1],
            [1,0,2,0,1,0,-1,-1,-1,-1],
            [2,1,0,0,2,-1,-1,-1,1,-1]]
            ,dtype=int)

# B = np.array([[-1,0,1],
#              [1,0,2]],dtype=int)
print("B = \n",B)
z = 5

H = expand_base_matrix(B,z)
np.savetxt('output.txt',H,fmt="%d")
print("H : \n",H.shape)

C = get_codewords(H)
print("Codewords : \n",C)
print(len(C)," codewords found")
dmin = np.min(np.sum(C[1:],axis=1))
print("dmin = ",dmin)

B = 
 [[ 0  2 -1  3 -1 -1  2  4  1  0]
 [ 3  1 -1 -1  0  0  1  2  3  0]
 [ 1  0  0 -1  1  4  2  1  0 -1]
 [-1  0  0  2  3 -1 -1 -1  0 -1]
 [ 1  0  2  0  1  0 -1 -1 -1 -1]
 [ 2  1  0  0  2 -1 -1 -1  1 -1]]
H : 
 (30, 50)
Codewords : 
 [[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [1 1 1 ... 1 1 1]
 [1 1 1 ... 1 1 1]
 [1 1 1 ... 1 1 1]]
1048576  codewords found
dmin =  9


In [31]:
snrdb = 1.5
snr = math.pow(10,snrdb/20)
num_iter = 150

suc = 0
tot = 1
for kk in range(tot):
    tg = Decoder(H,"bsc",0,num_iter,15)

    c = C[np.random.choice(len(C))]
    print(f"codeword sent = {c}")
    sigma = 1/math.sqrt(snr)
    # print(np.power(0.1,c))
    llrs = np.array([(-2*e+1) for e in c])
    # print(llrs)
    y = np.power(-1,c) + np.random.normal(loc=0, scale=sigma, size=c.shape)
    c_hat = tg.decode(y,scheduling="policy",verbose="off", prompt=True)
    print(f"codeword decoded = {c_hat}")
    c_hat = [1 if val else 0 for val in c_hat]
    c_hat = np.array(c_hat)%2
    e = np.sum((c_hat+c)%2)/len(c)
    print(f"Error = {e}")    
    print("---------------------------------------------------------------")
    if e==0:
        suc = suc + 1

print(f"rate = {suc}/{tot}")


codeword sent = [0 0 1 1 0 1 1 1 1 0 0 1 1 1 1 1 0 1 1 1 0 0 0 0 0 1 0 0 0 1 1 1 0 1 1 0 0
 1 1 0 0 1 1 0 0 1 1 0 1 1]
state caculation started ...
Actual MI : [0. 0.]
Pseudo MI : [0.01957523 0.06979621]
Gain : [0.01957523 0.06979621] = [0.01957523 0.06979621]
State calculation ended ...
Choosing from pi*

State =  [0.01957523 0.06979621]
Q(s,a) =  [0.14383227460851555, 0.18397230234646403]
Gains from greedy :  [0.01957523 0.06979621]
Action chosen from greedy :  1
Actual MI =  [0. 0.]
pseudo MI =  [0.01957523 0.06979621]
Action policy =  1
policy = pi1*	Q=[0.14383227460851555, 0.18397230234646403] state=[0.01957523 0.06979621] a1=1 MI_gains=[0.01957523 0.06979621]	a2=1
state caculation started ...
Actual MI : [0.         0.06979621]
Pseudo MI : [0.09256519 0.06979621]
Gain : [0.09256519 0.        ] = [0.09256519 0.        ]
State calculation ended ...
policy = MI*	MI_gains=[0.09256519 0.        ]	a2=1
state caculation started ...
Actual MI : [0.09256519 0.06979621]
Pseudo MI : [0.0925