# GNN Decoder for BCH Codes

This notebook reproduces the (63,45) BCH results in the paper [Graph Neural Networks for Channel Decoding](https://arxiv.org/pdf/2207.14742.pdf).

**Remark**: the training can take several hours. However, pre-trained weights are stored in this repository and can be directly loaded to reproduce the results.

This notebook requires [Sionna](https://nvlabs.github.io/sionna/).

In [2]:
# general imports
import tensorflow as tf
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt

# load required Sionna components
from sionna.fec.utils import load_parity_check_examples, LinearEncoder, gm2pcm
from sionna.utils.plotting import PlotBER
from sionna.fec.ldpc import LDPCBPDecoder

%load_ext autoreload
%autoreload 2
from gnn import * # load GNN functions
from wbp import * # load weighted BP functions

In [3]:
gpus = tf.config.list_physical_devices('GPU')
print('Number of GPUs available :', len(gpus))
if gpus:
    gpu_num = 0 # Number of the GPU to be used
    try:
        #tf.config.set_visible_devices([], 'GPU')
        tf.config.set_visible_devices(gpus[gpu_num], 'GPU')
        print('Only GPU number', gpu_num, 'used.')
        tf.config.experimental.set_memory_growth(gpus[gpu_num], True)
    except RuntimeError as e:
        print(e)

Number of GPUs available : 0


## Define Hyperparameters and Load Code

We define all parameters as dictionary to support different architectures for different codes.


In [4]:
#----- BCH -----
params={
    # --- Code Parameters ---
        "code": "BCH", # (63,45)
    # --- GNN Architecture ----
        "num_embed_dims": 20,
        "num_msg_dims": 20,
        "num_hidden_units": 40,
        "num_mlp_layers": 2,
        "num_iter": 8,
        "reduce_op": "mean",
        "activation": "tanh",
        "clip_llr_to": None,
        "use_attributes": False,
        "node_attribute_dims": 0,
        "msg_attribute_dims": 0,
        "use_bias": False,        
    # --- Training ---- # 
        "batch_size": [256, 256, 256], # bs, iter, lr must have same dim
        "train_iter": [35000, 300000, 300000],
        "learning_rate": [1e-3, 1e-4, 1e-5],
        "ebno_db_train": [3, 8.],
        "ebno_db_eval": 4.,          
        "batch_size_eval": 10000, # batch size only used for evaluation during training
        "eval_train_steps": 1000, # evaluate model every N iters
    # --- Log ----
        "save_weights_iter": 10000, # save weights every X iters
        "run_name": "BCH_01", # name of the stored weights/logs
        "save_dir": "results/", # folder to store results
    # --- MC Simulation parameters ----
        "mc_iters": 1000,
        "mc_batch_size": 2000,
        "num_target_block_errors": 500,
        "ebno_db_min": 2.,
        "ebno_db_max": 9.,
        "ebno_db_stepsize": 1.,
        "eval_iters": [2, 3, 4, 6, 8, 10],        
    # --- Weighted BP parameters ----
        "simulate_wbp": True, # simulate weighted BP as baseline
        "wbp_batch_size" : [2000, 2000, 2000],
        "wbp_train_iter" : [300, 10000, 2000],
        "wbp_learning_rate" : [1e-2, 1e-3, 1e-3],
        "wbp_ebno_train" : [5., 5., 6.],
        "wbp_ebno_val" : 7., # validation SNR during training
        "wbp_batch_size_val" : 2000,
        "wbp_clip_value_grad" : 10, 
}

## Load / Generate the Graph

In [5]:
# all codes must provide an encoder-layer and a pcm
if params["code"]=="BCH":    
    print("Loading BCH code")
    pcm, k, n, coderate = load_parity_check_examples(pcm_id=1, verbose=True)

    encoder = LinearEncoder(pcm, is_pcm=True)
    params["k"] = k
    params["n"] = n
else:
    raise ValueError("Unknown code type")

Loading BCH code

n: 63, k: 45, coderate: 0.714


## Simulate Baseline BER Performance

In [6]:
ber_plot = PlotBER(f"GNN-based Decoding - {params['code']}, (k,n)=({k},{n})")
ebno_dbs = np.arange(params["ebno_db_min"],
                     params["ebno_db_max"]+1,
                     params["ebno_db_stepsize"])

In [7]:
# uncoded QPSK
e2e_uncoded = E2EModel(None, None, k=100, n=100) # k and n are not relevant here
ber_plot.simulate(e2e_uncoded,
                  ebno_dbs=ebno_dbs,
                  batch_size=params["mc_batch_size"],
                  num_target_block_errors=params["num_target_block_errors"],
                  legend="Uncoded",
                  soft_estimates=True,
                  max_mc_iter=params["mc_iters"],
                  forward_keyboard_interrupt=False,
                  show_fig=False);

2023-11-12 12:11:18.085503: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x2a1eb7dc0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2023-11-12 12:11:18.085654: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
2023-11-12 12:11:18.133894: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:255] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2023-11-12 12:11:18.175145: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:38] Ignoring Assert operator mapper/assert_greater_equal/Assert/Assert
2023-11-12 12:11:18.175193: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:38] Ignoring Assert operator awgn/assert_greater_equal/Assert/Assert
2023-11-12 12:11:18.175203: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:38] Ignoring Assert operator awgn/assert_less_equal/Assert/Assert
2023-11-12 12:11:18.175212: W tensorflow/compiler/tf2xla

EbNo [dB] |        BER |       BLER |  bit errors |    num bits | block errors |  num blocks | runtime [s] |    status
---------------------------------------------------------------------------------------------------------------------------------------
      2.0 | 3.7740e-02 | 9.7950e-01 |        7548 |      200000 |         1959 |        2000 |         0.8 |reached target block errors
      3.0 | 2.2205e-02 | 8.9900e-01 |        4441 |      200000 |         1798 |        2000 |         0.0 |reached target block errors
      4.0 | 1.2420e-02 | 7.1950e-01 |        2484 |      200000 |         1439 |        2000 |         0.0 |reached target block errors
      5.0 | 6.1400e-03 | 4.5400e-01 |        1228 |      200000 |          908 |        2000 |         0.0 |reached target block errors
      6.0 | 2.3850e-03 | 2.1300e-01 |         954 |      400000 |          852 |        4000 |         0.0 |reached target block errors
      7.0 | 7.3750e-04 | 7.1125e-02 |         590 |      800000 |

In [8]:
# simulate "conventional" BP performance for given pcm
bp_decoder = LDPCBPDecoder(pcm, num_iter=20, hard_out=False)
e2e_bp = E2EModel(encoder, bp_decoder, k, n)
ber_plot.simulate(e2e_bp,
                 ebno_dbs=ebno_dbs,
                 batch_size=params["mc_batch_size"],
                 num_target_block_errors=params["num_target_block_errors"],
                 legend=f"BP {bp_decoder._num_iter.numpy()} iter.",
                 soft_estimates=True,
                 max_mc_iter=params["mc_iters"],
                 forward_keyboard_interrupt=False,
                 show_fig=False);

2023-11-12 12:11:20.273279: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:38] Ignoring Assert operator mapper_1/assert_greater_equal/Assert/Assert
2023-11-12 12:11:20.273301: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:38] Ignoring Assert operator awgn_1/assert_greater_equal/Assert/Assert
2023-11-12 12:11:20.273309: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:38] Ignoring Assert operator awgn_1/assert_less_equal/Assert/Assert
2023-11-12 12:11:20.273317: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:38] Ignoring Assert operator awgn_1/assert_greater_equal_1/Assert/Assert
2023-11-12 12:11:20.282405: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:38] Ignoring Assert operator ldpcbp_decoder/assert_equal_1/Assert/Assert
2023-11-12 12:11:20.318521: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:38] Ignoring Assert operator ldpcbp_decoder/while/RaggedFromRowSplits/RowPartitionFromRowSplits/assert_equal_1/Assert/Assert
2023-11-12 12:11:20.318549: W tensorflow/compiler

EbNo [dB] |        BER |       BLER |  bit errors |    num bits | block errors |  num blocks | runtime [s] |    status
---------------------------------------------------------------------------------------------------------------------------------------
      2.0 | 5.9460e-02 | 7.9650e-01 |        7492 |      126000 |         1593 |        2000 |         1.1 |reached target block errors
      3.0 | 3.1968e-02 | 4.9300e-01 |        4028 |      126000 |          986 |        2000 |         0.1 |reached target block errors
      4.0 | 1.3528e-02 | 2.0825e-01 |        3409 |      252000 |          833 |        4000 |         0.2 |reached target block errors
      5.0 | 4.3373e-03 | 6.3750e-02 |        2186 |      504000 |          510 |        8000 |         0.4 |reached target block errors
      6.0 | 9.7354e-04 | 1.1952e-02 |        2576 |     2646000 |          502 |       42000 |         1.9 |reached target block errors
      7.0 | 1.4286e-04 | 1.8872e-03 |        2394 |    16758000 |

In [9]:
# train and simulate Weighted BP as additional baseline
# please note that the training parameters could be critical 
if params["simulate_wbp"]:
    evaluate_wbp(params, pcm, encoder, ebno_dbs, ber_plot)



Note that WBP requires Sionna > v0.11.
Iter: 0 loss: 0.003048 ber: 0.000095 bmi: 0.999
Iter: 50 loss: 0.001102 ber: 0.000040 bmi: 1.000
Iter: 100 loss: 0.001546 ber: 0.000135 bmi: 0.999
Iter: 150 loss: 0.002193 ber: 0.000079 bmi: 0.999


KeyboardInterrupt: 

### GNN-based Decoding


In [10]:
tf.random.set_seed(2) # we fix the seed to ensure stable convergence 

# init the GNN decoder
gnn_decoder = GNN_BP(pcm=pcm,
                     num_embed_dims=params["num_embed_dims"],
                     num_msg_dims=params["num_msg_dims"],
                     num_hidden_units=params["num_hidden_units"],
                     num_mlp_layers=params["num_mlp_layers"],
                     num_iter=params["num_iter"],
                     reduce_op=params["reduce_op"],
                     activation=params["activation"],
                     output_all_iter=True,
                     clip_llr_to=params["clip_llr_to"],
                     use_attributes=params["use_attributes"],
                     node_attribute_dims=params["node_attribute_dims"],
                     msg_attribute_dims=params["msg_attribute_dims"],
                     use_bias=params["use_bias"])
                     
e2e_gnn = E2EModel(encoder, gnn_decoder, k, n)

In [11]:
# init model and print summary
e2e_gnn(1, 1.)
e2e_gnn.summary()

Model: "e2e_model_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 binary_source_2 (BinarySou  multiple                  0         
 rce)                                                            
                                                                 
 mapper_2 (Mapper)           multiple                  0         
                                                                 
 demapper_2 (Demapper)       multiple                  0         
                                                                 
 awgn_2 (AWGN)               multiple                  0         
                                                                 
 gnn_bp (GNN_BP)             multiple                  9640      
                                                                 
 linear_encoder (LinearEnco  multiple                  0         
 der)                                                  

and let's train the model

In [12]:
train = False # remark: training takes several hours
if train:
    train_gnn(e2e_gnn, params)
else:
    # you can also load the precomputed weights
    load_weights(e2e_gnn, "weights/BCH_precomputed.npy")

## prune model

top k pruning // random pruning

In [13]:
# pruning function
import pandas as pd
from tensorflow.python.ops.script_ops import numpy_function
import random

# pruning percentage
K = [5]

print(e2e_gnn._decoder.update_h_vn._embed_mlp)
print(e2e_gnn._decoder.update_h_vn._embed_mlp._layers[0].get_weights)

msg_mlp = e2e_gnn._decoder.update_h_vn._msg_mlp._layers
embed_mlp = e2e_gnn._decoder.update_h_vn._embed_mlp._layers

num_msg_mlp_layers = len(msg_mlp)
num_embed_mlp_layers = len(embed_mlp)

# pruning function 구현
def pruner_mlp(mlp): # K 에 맞는 비율로 pruning을 진행한다.
    all_weights = {}

    for layer_no in range(len(mlp) -1):
        layer_weights = (pd.DataFrame(mlp[layer_no].get_weights()[0]).stack()).to_dict()
        layer_weights = {(layer_no, k[0], k[1]): v for k, v in layer_weights.items()}
        all_weights.update(layer_weights)
        
    all_weights_sorted = {k: v for k, v in sorted(all_weights.items(), key=lambda item: abs(item[1]))}
    
    print(all_weights_sorted)
    total_no_weights = len(all_weights_sorted)
    
    for pruning_percent in K:
        for layer_no in range(len(mlp) -1):
            new_weights = mlp[layer_no].get_weights()
            
            prune_fraction = pruning_percent/100 # K에서 지정한 비율만큼 pruning
            number_of_weights_to_be_pruned = int(prune_fraction*total_no_weights)
            weights_to_be_pruned = {k: all_weights_sorted[k] for k in list(all_weights_sorted)[: number_of_weights_to_be_pruned]}

            for k, v in weights_to_be_pruned.items():
                new_weights[k[0]][k[1], k[2]]=0 # weights_to_be_pruned에 저장된 인덱스의 가중치를 0으로 만들어 준다.
                
            mlp[layer_no].set_weights(new_weights)

def random_pruner_mlp(mlp): # random pruning을 진행한다.
    all_weights = {}

    for layer_no in range(len(mlp) -1):
        layer_weights = (pd.DataFrame(mlp[layer_no].get_weights()[0]).stack()).to_dict()
        layer_weights = {(layer_no, k[0], k[1]): v for k, v in layer_weights.items()}
        all_weights.update(layer_weights)
        
    all_weights_sorted = {k: v for k, v in sorted(all_weights.items(), key=lambda item: abs(item[1]))}
    
    print(all_weights_sorted)
    total_no_weights = len(all_weights_sorted)
    
    for pruning_percent in K:
        for layer_no in range(len(mlp) -1):
            new_weights = mlp[layer_no].get_weights()
            
            prune_fraction = pruning_percent / 100 # K에서 지정한 비율만큼 pruning
            number_of_weights_to_be_pruned = int(prune_fraction*total_no_weights)
            # random pruning에서 중요한 부분
            random_indices = np.random.choice(total_no_weights, size=number_of_weights_to_be_pruned, replace=False)
            random_weights_to_be_pruned = [list(layer_weights.keys())[i] for i in random_indices]

            for k in random_weights_to_be_pruned:
                new_weights[k[0]][k[1], k[2]]=0 # weights_to_be_pruned에 저장된 인덱스의 가중치를 0으로 만들어 준다.

            mlp[layer_no].set_weights(new_weights)    
    
random_pruner_mlp(msg_mlp)
random_pruner_mlp(embed_mlp)

<gnn.MLP object at 0x2d34dc510>
<bound method Layer.get_weights of <keras.src.layers.core.dense.Dense object at 0x2d5328c10>>
{(0, 24, 22): 0.0001909721759147942, (0, 7, 15): -0.0006235960754565895, (0, 39, 1): -0.0013334148097783327, (0, 14, 31): -0.0015354674542322755, (0, 6, 16): -0.0018282118253409863, (0, 25, 27): -0.0020399109926074743, (0, 18, 27): -0.002050400245934725, (0, 7, 31): 0.002213281812146306, (0, 28, 35): 0.002231164136901498, (0, 7, 20): -0.002579705324023962, (0, 30, 5): -0.003068871097639203, (0, 13, 30): 0.0035210279747843742, (0, 21, 15): 0.0037671728059649467, (0, 22, 5): 0.004191659390926361, (0, 37, 36): -0.004276237916201353, (0, 19, 6): 0.00449257530272007, (0, 22, 17): 0.00519695645198226, (0, 36, 3): -0.005583583377301693, (0, 34, 32): -0.0056730047799646854, (0, 3, 21): -0.005762663669884205, (0, 16, 14): -0.00608694925904274, (0, 20, 15): -0.00618352647870779, (0, 24, 16): -0.006499253213405609, (0, 36, 9): -0.006638810038566589, (0, 17, 9): 0.006748681

## Evaluate Final Performance

and store results...

In [None]:
for iters in params["eval_iters"]:
    # instantiate new decoder for each number of iter (otherwise no retracing)
    gnn_dec_temp = GNN_BP(pcm=pcm,
                          num_embed_dims=params["num_embed_dims"],
                          num_msg_dims=params["num_msg_dims"],
                          num_hidden_units=params["num_hidden_units"],
                          num_mlp_layers=params["num_mlp_layers"],
                          num_iter=iters,
                          reduce_op=params["reduce_op"],
                          activation=params["activation"],
                          output_all_iter=False,
                          clip_llr_to=params["clip_llr_to"],
                          use_attributes=params["use_attributes"],
                          node_attribute_dims=params["node_attribute_dims"],
                          msg_attribute_dims=params["msg_attribute_dims"],
                          use_bias=params["use_bias"])    
    # generate new model   
    model = E2EModel(encoder, gnn_dec_temp, k, n)
    model(1,1.) # init model
    # copy weights from trained decoder
    model._decoder.set_weights(gnn_decoder.get_weights())

    # and run the BER simulations
    ber_plot.simulate(model,
                     ebno_dbs=ebno_dbs,
                     batch_size=params["mc_batch_size"],
                     num_target_block_errors=500,
                     legend=f"GNN {iters} iter.",
                     soft_estimates=True,
                     max_mc_iter=params["mc_iters"],
                     forward_keyboard_interrupt=False,
                     show_fig=False);

ber_plot(xlim=[2,9],ylim=[1e-5,0.1]) # show final figure

EbNo [dB] |        BER |       BLER |  bit errors |    num bits | block errors |  num blocks | runtime [s] |    status
---------------------------------------------------------------------------------------------------------------------------------------
      2.0 | 2.0212e-01 | 9.9700e-01 |       25467 |      126000 |         1994 |        2000 |         1.8 |reached target block errors
      3.0 | 1.3008e-01 | 9.8200e-01 |       16390 |      126000 |         1964 |        2000 |         0.2 |reached target block errors
      4.0 | 6.6381e-02 | 8.6850e-01 |        8364 |      126000 |         1737 |        2000 |         0.3 |reached target block errors
      5.0 | 2.5794e-02 | 5.6850e-01 |        3250 |      126000 |         1137 |        2000 |         0.3 |reached target block errors
      6.0 | 6.8730e-03 | 2.2900e-01 |        1732 |      252000 |          916 |        4000 |         0.5 |reached target block errors
      7.0 | 1.1794e-03 | 5.4800e-02 |         743 |      630000 |

In [None]:
ber_plot(xlim=[2,9],ylim=[1e-5,0.1]) # show final figure

In [None]:
# export results for pgf plots
col_names = ["uncoded",  "bp-20"]
if params["simulate_wbp"]:
    col_names.append("wbp-20")
for it in params["eval_iters"]:
    col_names.append("gnn-" + str(it))
export_pgf(ber_plot, col_names)