
## Weighted BP

The BP decoder can be further optimized by applying the concept of *weighted* BP [3].
See [Sionna example notebook](https://nvlabs.github.io/sionna/examples/Weighted_BP_Algorithm.html) for further details.

Note that the model below is only used to train the WBP decoder; the evaluation is done with the end-to-end model (`E2EModel`).
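To make the idea of weighting concrete, here is a minimal pure-Python sketch of one weighted variable-node (VN) update followed by a box-plus check-node (CN) update. The helper names and the placement of the weights on the outgoing VN messages are assumptions made for illustration; they do not reproduce the internals of Sionna's `LDPCBPDecoder`.

```python
import math

def weighted_vn_messages(llr_ch, incoming, weights):
    """Outgoing VN messages as weighted extrinsic sums (hypothetical helper)."""
    total = llr_ch + sum(incoming)
    # exclude each edge's own incoming message, then scale by the edge weight
    return [w * (total - m) for w, m in zip(weights, incoming)]

def boxplus(messages):
    """Box-plus CN update: 2 * atanh(prod(tanh(m / 2)))."""
    prod = 1.0
    for m in messages:
        prod *= math.tanh(m / 2.0)
    return 2.0 * math.atanh(prod)

# with all weights equal to 1 the update reduces to conventional BP
plain = weighted_vn_messages(1.0, [0.5, -0.2, 0.3], [1.0, 1.0, 1.0])
# trainable weights simply rescale the individual messages
scaled = weighted_vn_messages(1.0, [0.5, -0.2, 0.3], [2.0, 1.0, 0.5])
cn_out = boxplus([2.0, -3.0])
```

Training then amounts to adjusting the per-edge weights so that the decoder output improves, while the message-passing schedule itself stays unchanged.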


In [None]:
class WeightedBP(tf.keras.Model):
    """System model for BER simulations of weighted BP decoding.

    This model uses `GaussianPriorSource` to mimic the LLRs after demapping of
    QPSK symbols transmitted over an AWGN channel.

    Parameters
    ----------
        pcm: ndarray
            The parity-check matrix of the code under investigation.

        num_iter: int
            Number of BP decoding iterations.

    Input
    -----
        batch_size: int or tf.int
            The batch_size used for the simulation.

        ebno_db: float or tf.float
            A float defining the simulation SNR in dB; choose a value that matches the simulated environment.

    Output
    ------
        (c, c_hat, loss):
            Tuple:

        c: tf.float32
            A tensor of shape `[batch_size, n]` containing the transmitted (all-zero) codeword bits.

        c_hat: tf.float32
            A tensor of shape `[batch_size, n]` containing the soft estimates of the codeword bits after the last iteration.

        loss: tf.float32
            Binary cross-entropy loss between `c` and `c_hat`, averaged over all decoding iterations.
    """
    def __init__(self, pcm, num_iter=5):
        super().__init__()
        print("Note that WBP requires Sionna > v0.11.")
        # init components
        self.decoder = LDPCBPDecoder(pcm,
                                     num_iter=1, # iterations are done via outer loop (to access intermediate results for multi-loss)
                                     stateful=True,
                                     hard_out=False, # we need to access soft-information
                                     cn_type="boxplus",
                                     trainable=True) # the decoder must be trainable, otherwise no weights are generated

        # used to generate llrs during training (see example notebook on all-zero codeword trick)
        self._llr_source = GaussianPriorSource()
        self._num_iter = num_iter

        self._n = pcm.shape[1]
        self._coderate = 1 - pcm.shape[0]/pcm.shape[1]
        self._bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    def call(self, batch_size, ebno_db):
        #batch_size = tf.constant(batch_size, dtype=tf.int32)
        #ebno_db = tf.constant(ebno_db, dtype=tf.float32)
        noise_var = ebnodb2no(ebno_db,
                              num_bits_per_symbol=2, # QPSK
                              coderate=self._coderate)

        # all-zero CW to calculate loss / BER
        c = tf.zeros([batch_size, self._n])

        # Gaussian LLR source
        llr = self._llr_source([[batch_size, self._n], noise_var])

        # implement multi-loss as proposed by Nachmani et al.
        loss = 0
        msg_vn = None
        for _ in range(self._num_iter):
            c_hat, msg_vn = self.decoder((llr, msg_vn)) # perform one decoding iteration; decoder
            # returns soft-values
            loss += self._bce(c, c_hat)  # add loss after each iteration

        loss /= self._num_iter # scale loss by number of iterations

        return c, c_hat, loss

    def train_wbp(self, train_param):

        assert len(train_param["batch_size"])==len(train_param["train_iter"]),\
                        "batch_size must have same lengths as train_iter"

        assert len(train_param["batch_size"])==\
               len(train_param["learning_rate"]),\
                "batch_size must have same lengths as learning_rate"

        assert len(train_param["batch_size"])==len(train_param["ebno_train"]),\
                        "batch_size must have same lengths as ebno_train"

        # bmi is used as metric to evaluate the intermediate results
        bmi = BitwiseMutualInformation()

        for idx, batch_size in enumerate(train_param["batch_size"]):

            optimizer = tf.keras.optimizers.Adam(train_param["learning_rate"][idx])

            for it in range(train_param["train_iter"][idx]):
                with tf.GradientTape() as tape:
                    b, llr, loss = self(batch_size,
                                        train_param["ebno_train"][idx]) # invokes call(), which returns (c, c_hat, loss)

                # gradients of the loss w.r.t. the trainable decoder weights
                grads = tape.gradient(loss, self.trainable_weights)
                # clip every gradient tensor element-wise
                grads = [tf.clip_by_value(g,
                                          -train_param["clip_value_grad"],
                                          train_param["clip_value_grad"])
                         for g in grads]
                # apply the clipped gradients to the corresponding weights
                optimizer.apply_gradients(zip(grads, self.trainable_weights))

                # calculate and print intermediate metrics
                # only for information, this has no impact on the training
                if it%50==0: # evaluate every 50 iterations
                    b, llr, loss = self(train_param["batch_size_val"],
                                        train_param["ebno_val"])
                    b_hat = hard_decisions(llr) # hard-decide the LLRs first
                    ber = compute_ber(b, b_hat) # compute the BER
                    # and print results
                    mi = bmi(b, llr).numpy() # calculate bit-wise mutual information
                    l = loss.numpy() # copy loss to numpy for printing
                    print(f"Iter: {it} loss: {l:.3f} ber: {ber:.4f} bmi: {mi:.3f}")
                    bmi.reset_states() # reset the BMI metric

*Remark*: running the WBP requires at least Sionna v0.11.
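The multi-loss used in `WeightedBP.call` adds the binary cross-entropy after every decoding iteration and divides by the number of iterations. The following pure-Python sketch mirrors that computation on toy numbers; `bce_with_logits` and `multi_loss` are hypothetical helpers, and the sketch assumes the usual logit convention in which negative values favor bit 0.

```python
import math

def bce_with_logits(labels, logits):
    """Mean binary cross-entropy for logits (numerically stable form)."""
    losses = [max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))
              for z, x in zip(labels, logits)]
    return sum(losses) / len(losses)

def multi_loss(labels, logits_per_iter):
    """Average the loss over the soft outputs of all decoding iterations."""
    total = sum(bce_with_logits(labels, logits) for logits in logits_per_iter)
    return total / len(logits_per_iter)

labels = [0.0, 0.0, 0.0]            # all-zero codeword
outputs = [[0.0, 0.0, 0.0],         # iteration 1: still undecided
           [-4.0, -6.0, -5.0]]      # iteration 2: confident zeros
loss = multi_loss(labels, outputs)
```

Because every iteration contributes to the loss, early iterations also receive a direct training signal instead of only the final one.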

In [None]:
model_wbp = WeightedBP(pcm=pcm, num_iter=10) # model is only used for training

In [None]:
train_param = {
    "batch_size" : [2000, 2000, 2000],
    "train_iter" : [300, 1000, 1000],
    "learning_rate" : [1e-2, 1e-3, 1e-3],
    "ebno_train" : [5., 5., 6.],
    "ebno_val" : 7., # validation SNR during training
    "batch_size_val" : 2000,
    "clip_value_grad" : 10,
}

In [None]:
model_wbp.train_wbp(train_param)

In [None]:
# generate new decoder object (with 20 iterations)
bp_decoder_wbp = LDPCBPDecoder(pcm, num_iter=20, hard_out=False, trainable=True) # trainable=True so that the decoder has weights to overwrite

# copy weights from trained decoder
bp_decoder_wbp.set_weights(model_wbp.decoder.get_weights())
e2e_wbp = E2EModel(pcm, bp_decoder_wbp) # end-to-end model

# and evaluate the performance
ber_plot.simulate(e2e_wbp,
                  ebno_dbs=ebno_dbs,
                  batch_size=mc_batch_size*10, #increased bs for higher gpu load
                  num_target_block_errors=num_target_block_errors,
                  legend=f"Weighted BP {e2e_wbp._decoder.num_iter.numpy()} iter.",
                  soft_estimates=True,
                  max_mc_iter=mc_iters,
                  forward_keyboard_interrupt=False,
                  show_fig=False);



### GNN-based Decoding

We now define the GNN for iterative decoding.

In [None]:
%pip install tensorflow_model_optimization

In [None]:
bp_decoder = LDPCBPDecoder(pcm, num_iter=20, hard_out=False)
e2e_bp = E2EModel(pcm, bp_decoder)

ber_plot.simulate(e2e_bp,
                  ebno_dbs=ebno_dbs,
                  batch_size=mc_batch_size,
                  num_target_block_errors=num_target_block_errors,
                  legend=f"BP {e2e_bp._decoder.num_iter.numpy()} iter.",
                  soft_estimates=True,
                  max_mc_iter=mc_iters,
                  forward_keyboard_interrupt=False,
                  show_fig=False);

# Graph Neural Network based FEC Decoding

In this notebook, you will learn about graph neural network (GNN)-based decoding of BCH codes.

This code is provided as supplementary material to the paper [Graph Neural Networks for Channel Decoding](https://arxiv.org/pdf/2207.14742.pdf).
If you in any way use this code for research that results in publications, please cite it appropriately.

The idea is to let a neural network (NN) learn a generalized message passing algorithm over a given graph that represents the forward error correction (FEC) code
structure by replacing node and edge message updates with trainable functions.
A similar decoding approach based on GNNs can be found in [2].
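As a rough illustration of the graph structure behind this idea, the sketch below derives the edge list of a tiny parity-check matrix and aggregates per-edge messages at each node. It is a pure-Python toy with scalar messages and hypothetical helper names; the trainable message and update functions of the actual decoder are left out.

```python
def build_edges(pcm):
    """Return the edge list and the edge ids attached to every CN and VN."""
    edges = [(cn, vn) for cn, row in enumerate(pcm)
             for vn, h in enumerate(row) if h]
    cn_edges = [[e for e, (cn, _) in enumerate(edges) if cn == i]
                for i in range(len(pcm))]
    vn_edges = [[e for e, (_, vn) in enumerate(edges) if vn == j]
                for j in range(len(pcm[0]))]
    return edges, cn_edges, vn_edges

def aggregate(messages, gather_ind, reduce_op="mean"):
    """Reduce the incoming edge messages of every receiving node."""
    out = []
    for ids in gather_ind:
        if not ids:                  # isolated node: nothing to aggregate
            out.append(0.0)
            continue
        total = sum(messages[e] for e in ids)
        out.append(total / len(ids) if reduce_op == "mean" else total)
    return out

toy_pcm = [[1, 1, 0],
           [0, 1, 1]]
edges, cn_edges, vn_edges = build_edges(toy_pcm)
messages = [1.0, 2.0, 3.0, 4.0]      # one scalar message per edge
to_vn = aggregate(messages, vn_edges, "mean")
to_cn = aggregate(messages, cn_edges, "sum")
```

In the decoder the scalar messages become vectors produced by an MLP, but the gather-and-reduce pattern over edge ids stays the same.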

Note that some simulations in this notebook require substantial simulation time, in particular the GNN training.
Each cell in this notebook already contains its pre-computed outputs, so no new execution is required to understand the examples.

## Table of Contents
* [GPU Configuration and Imports](#GPU-Configuration-and-Imports)
* [Defining the End-to-End Model](#Define-the-End-to-end-Model)
* [Weighted BP](#Weighted-BP)
* [GNN-based Decoding](#GNN-based-Decoding)
* [Training](#Training)
* [Final Performance](#Final-Performance)
* [References](#References)

## GPU Configuration and Imports

In [2]:
# GNN based decoder algorithm
#
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# Import Sionna
try:
    import sionna as sn
except ImportError as e:
    # executed only if Sionna is not installed yet
    import os
    os.system("pip install sionna")
    import sionna as sn

# For plotting
%matplotlib inline
# also try %matplotlib widget

In [3]:
gpus = tf.config.list_physical_devices('GPU') # list the available GPUs
print('Number of GPUs available :', len(gpus)) # number of available GPUs
if gpus:
    gpu_num = 0 # Number of the GPU to be used
    try:
        #tf.config.set_visible_devices([], 'GPU')
        tf.config.set_visible_devices(gpus[gpu_num], 'GPU')
        print('Only GPU number', gpu_num, 'used.')
        tf.config.experimental.set_memory_growth(gpus[gpu_num], True)
    except RuntimeError as e:
        print(e)

Number of GPUs available : 0


Please ensure that the [Sionna link-level simulator](https://nvlabs.github.io/sionna/) is installed.

In [4]:
from sionna.fec.utils import load_parity_check_examples, LinearEncoder, GaussianPriorSource
from sionna.utils import BinarySource, ebnodb2no, BitwiseMutualInformation, hard_decisions
from sionna.utils.metrics import compute_ber
from sionna.utils.plotting import PlotBER
from sionna.mapping import Mapper, Demapper
from sionna.channel import AWGN
from sionna.fec.ldpc import LDPCBPDecoder

from tensorflow.keras.layers import Dense, Layer

In [5]:
bi = BinarySource()
print(bi)

<sionna.utils.misc.BinarySource object at 0x7e5f4a5f83a0>


## Define the End-to-end Model

We define an end-to-end transmission model to evaluate the trained decoders and the baseline in the same setup.
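The model converts the Eb/N0 value in dB into a noise variance before the channel is applied. The sketch below shows the relation this conversion is understood to follow, assuming unit average symbol energy; `ebno_db_to_no` is a hypothetical stand-in written for illustration, not Sionna's `ebnodb2no` itself.

```python
def ebno_db_to_no(ebno_db, num_bits_per_symbol, coderate):
    """Noise variance for a given Eb/N0 in dB (unit symbol energy assumed)."""
    ebno = 10.0 ** (ebno_db / 10.0)          # dB to linear scale
    return 1.0 / (ebno * coderate * num_bits_per_symbol)

# coded transmission with the rate of the (63, 45) code vs. rate 1
no_coded = ebno_db_to_no(5.0, 2, 45 / 63)
no_uncoded = ebno_db_to_no(5.0, 2, 1.0)
```

At the same Eb/N0, the coded system sees a larger noise variance per symbol because the energy of each information bit is spread over more transmitted bits.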

In [6]:
class E2EModel(tf.keras.Model):
    def __init__(self, pcm, decoder): # requires the parity-check matrix (pcm) and a decoder as inputs
        super().__init__() # run the tf.keras.Model initializer
        self._pcm = pcm
        self._n = pcm.shape[1] # number of pcm columns = codeword length n
        print("self._n : ", self._n)
        self._k = self._n - pcm.shape[0] # n-(n-k)=k, the information bit length
        print("self._k : ", self._k)
        self._encoder = LinearEncoder(pcm, is_pcm=True) # derives the generator matrix from the pcm
        self._binary_source = BinarySource() # generates random binary messages

        self._num_bits_per_symbol = 2 # at the moment only QPSK is supported
        # QPSK maps two bits onto one symbol, so the symbol sequence
        # is half as long as the bit sequence.

        self._mapper = Mapper("qam", self._num_bits_per_symbol)
        print("self._mapper : ", self._mapper)
        self._demapper = Demapper("app", "qam", self._num_bits_per_symbol) # "app" selects the method used to compute the LLRs
        self._channel = AWGN()
        self._decoder = decoder

    @tf.function()
    def call(self, batch_size, ebno_db):

        # calculate noise variance
        if self._decoder is not None:
            no = ebnodb2no(ebno_db, self._num_bits_per_symbol, self._k/self._n) # code rate
        else: # for uncoded transmission the rate is 1
            no = ebnodb2no(ebno_db, self._num_bits_per_symbol, 1)

        # draw random info bits to transmit
        b = self._binary_source([batch_size, self._k]) # random information bits of shape [batch_size, k]
        c = self._encoder(b)

        # zero padding to support odd codeword lengths (added only when n is odd)
        if self._n%2==1:
            c_pad = tf.concat([c, tf.zeros([batch_size, 1])], axis=1)
        else: # no padding
            c_pad = c

        # map to symbols
        x = self._mapper(c_pad) # the (padded) codeword has even length

        # transmit over AWGN channel
        y = self._channel([x, no])

        # demap to LLRs
        llr = self._demapper([y, no])

        # remove filler bits
        if self._n%2==1:
            llr = llr[:,:-1]

        # and decode
        if self._decoder is not None:
            llr = self._decoder(llr)

        return c, llr

We now load the parity-check matrix of the code. This line can be replaced by other codes as well, however, please keep in mind that the hyperparameters of the decoder (and its training) must be adjusted accordingly.
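When swapping in another code, the basic parameters follow directly from the shape of its parity-check matrix. The sketch below shows this relation, assuming that all rows of the matrix are linearly independent; `code_params` is a hypothetical helper, and the dummy matrix only reproduces the shape of the code used here.

```python
def code_params(pcm):
    """Return (n, k, coderate) for a parity-check matrix given as a list of rows."""
    num_checks, n = len(pcm), len(pcm[0])
    k = n - num_checks        # valid only if all rows are linearly independent
    return n, k, k / n

# dummy matrix with the shape of the code used here (18 x 63);
# only the shape matters for this sketch
dummy_pcm = [[0] * 63 for _ in range(18)]
n, k, coderate = code_params(dummy_pcm)
print(f"n: {n}, k: {k}, coderate: {coderate:.3f}")
```

The same relation is used by `E2EModel` to derive `k` and by `WeightedBP` to derive the code rate.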

In [7]:
# load the parity-check matrix of the code
pcm, k, n, coderate = load_parity_check_examples(pcm_id=1, verbose=True)

print(pcm)
print(pcm.shape) # n-k = 18, n=63


n: 63, k: 45, coderate: 0.714
[[1 1 0 ... 0 0 0]
 [0 1 1 ... 0 0 0]
 [0 0 1 ... 0 0 0]
 ...
 [0 0 0 ... 1 0 0]
 [0 0 0 ... 1 1 0]
 [0 0 0 ... 0 1 1]]
(18, 63)


Let us define the simulation parameters.

In [8]:
ber_plot = PlotBER("GNN-based Decoding Results") # collects and plots the BER curves
ebno_db_min = 2.0 # lower bound of the SNR range
ebno_db_max = 9.0 # upper bound of the SNR range
ebno_dbs = np.arange(ebno_db_min,ebno_db_max+1) # 2, 3, ..., 9
print(ebno_dbs)

mc_iters = 100
mc_batch_size = 10000
num_target_block_errors = 2000

[2. 3. 4. 5. 6. 7. 8. 9.]


Simulate the uncoded baseline first.

In [9]:
e2e_uncoded = E2EModel(pcm, None) # reference performance: no decoder is set, so the demapper LLRs are evaluated directly

ber_plot.simulate(e2e_uncoded,
                  ebno_dbs=ebno_dbs, # 2 to 9 dB
                  batch_size=mc_batch_size, # 10000
                  num_target_block_errors=num_target_block_errors, # simulate until 2000 block errors have occurred
                  legend="Uncoded",
                  soft_estimates=True,
                  max_mc_iter=mc_iters,
                  forward_keyboard_interrupt=False,
                  show_fig=False);



self._n :  63
self._k :  45
self._mapper :  <sionna.mapping.Mapper object at 0x7e5ebaf989a0>
EbNo [dB] |        BER |       BLER |  bit errors |    num bits | block errors |  num blocks | runtime [s] |    status
---------------------------------------------------------------------------------------------------------------------------------------
      2.0 | 3.7376e-02 | 9.1090e-01 |       23547 |      630000 |         9109 |       10000 |         2.5 |reached target block errors
      3.0 | 2.2919e-02 | 7.6540e-01 |       14439 |      630000 |         7654 |       10000 |         0.1 |reached target block errors
      4.0 | 1.2614e-02 | 5.5470e-01 |        7947 |      630000 |         5547 |       10000 |         0.1 |reached target block errors
      5.0 | 6.0841e-03 | 3.1950e-01 |        3833 |      630000 |         3195 |       10000 |         0.2 |reached target block errors
      6.0 | 2.4214e-03 | 1.4260e-01 |        3051 |     1260000 |         2852 |       20000 |         0.

And the BP decoding baseline. Please note that BP decoding is not necessarily a good choice for decoding BCH codes because of their high-density graph structure.
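The density of the graph can be quantified with a short calculation. The sketch below uses the numbers reported in this notebook (an 18 x 63 parity-check matrix with 432 edges); the variable names are chosen for illustration only.

```python
num_cn, num_vn, num_edges = 18, 63, 432   # values reported in this notebook

density = num_edges / (num_cn * num_vn)   # fraction of ones in the pcm
avg_cn_degree = num_edges / num_cn        # average number of VNs per check
avg_vn_degree = num_edges / num_vn        # average number of checks per VN
```

Roughly 38% of the matrix entries are ones and every check node connects to 24 variable nodes, which is far from the sparse graphs that BP decoding is usually applied to.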

In [10]:
class MLP(Layer):
    """Simple MLP layer.

    Parameters
    ----------
    units : List of int
        Each element of the list describes the number of units of the
        corresponding layer.

    activations : List of activations
        Each element of the list contains the activation to be used
        by the corresponding layer.

    use_bias : List of booleans
        Each element of the list indicates if the corresponding layer
        should use a bias or not.
    """
    def __init__(self, units, activations, use_bias): # constructor (units per layer, activations, bias flags)
        super().__init__()
        self._num_units = units
        self._activations = activations
        self._use_bias = use_bias

    def build(self, input_shape):
        print('MLP build')
        self._layers = []
        for i, units in enumerate(self._num_units):
            print('stacking layer', i, ':', units)
            self._layers.append(Dense(units,
                                      self._activations[i],
                                      use_bias=self._use_bias[i]))

    def call(self, inputs): # apply the layers stored in self._layers one after another
        print("MLP call")
        print('MLP inputs :', inputs)
        outputs = inputs
        for layer in self._layers:
            outputs = layer(outputs)
            print("computing")

        print("MLP outputs :", outputs)

        return outputs

class GNN_BP(Layer):
    """GNN-based BP Decoder

    Parameters
    ---------
    pcm : [num_cn, num_vn], numpy.array
        The parity check matrix.

    num_embed_dims: int
        Number of dimensions of the vertex embeddings.

    num_msg_dims: int
        Number of dimensions of a message.

    num_hidden_units: int
        Number of hidden units of the MLPs used to compute
        messages and to update the vertex embeddings.

    num_mlp_layers: int
        Number of layers of the MLPs used to compute
        messages and to update the vertex embeddings.

    num_iter: int
        Number of iterations.

    reduce_op: str
        A string defining the vertex aggregation function.
        Currently, "mean" and "sum" are supported.

    activation: str
        A string defining the activation function of the hidden MLP layers to
        be used. Defaults to "relu".

    output_all_iter: Bool
        Indicates if the LLRs of all iterations should be returned as list
        or if only the LLRs of the last iteration should be returned.

    clip_llr_to: float or None
        If set, the absolute value of the input LLRs will be clipped to this value.

    Input
    -----
    llr : [batch_size, num_vn], tf.float32
        Tensor containing the LLRs of all bits.

    Output
    ------
    llr_hat : [batch_size, num_vn], tf.float32
        Tensor containing the LLRs at the decoder output.
        If `output_all_iter` is True, a list of such tensors is returned.
    """
    def __init__(self,
                 pcm, # parity-check matrix H
                 num_embed_dims, # number of dimensions of the vertex embeddings
                 num_msg_dims,
                 num_hidden_units,
                 num_mlp_layers,
                 num_iter,
                 reduce_op="sum",
                 activation="relu",
                 output_all_iter=False, # if True, return the LLRs of every iteration
                 clip_llr_to=None): # if set, clip the magnitude of the input LLRs to this value
        super().__init__() # run the Layer initializer

        self._pcm = pcm # Parity check matrix
        self._num_cn = pcm.shape[0] # Number of check nodes (number of rows)
        self._num_vn = pcm.shape[1] # Number of variable nodes (number of columns)
        self._num_edges = int(np.sum(pcm)) # Number of edges (number of ones in the pcm)
        print("num_edges :", self._num_edges)

        # Array of shape [num_edges, 2]
        # 1st col = CN id, 2nd col = VN id
        # The ith row of this array defines the ith edge.
        self._edges = np.stack(np.where(pcm), axis=1) # stack the indices of all ones in the pcm

        # Create 2D ragged tensor of shape [num_cn,...]
        # cn_edges[i] contains the edge ids for CN i
        cn_edges = []

        # Example with 5 CNs: for
        #   self._edges = [[0, 2], [1, 2], [1, 3], [2, 3], [2, 4], [3, 4]]
        # the loop below yields
        #   cn_edges = [[0],     # edge ids connected to CN 0
        #               [1, 2],  # edge ids connected to CN 1
        #               [3, 4],  # edge ids connected to CN 2
        #               [5],     # edge ids connected to CN 3
        #               []]      # CN 4 has no edges
        for i in range(self._num_cn):
            cn_edges.append(np.where(self._edges[:,0]==i)[0]) # all edges connected to CN i are stored in cn_edges[i]
        self._cn_edges = tf.ragged.constant(cn_edges)

        # Create 2D ragged tensor of shape [num_vn,...]
        # vn_edges[i] contains the edge ids for VN i
        vn_edges = []
        for i in range(self._num_vn):
            vn_edges.append(np.where(self._edges[:,1]==i)[0])
        self._vn_edges = tf.ragged.constant(vn_edges)

        self._num_embed_dims = num_embed_dims # Number of dimensions for vertex embeddings
        self._num_msg_dims = num_msg_dims # Number of dimensions for messages
        self._num_hidden_units = num_hidden_units # Number of hidden units for MLPs computing messages and embeddings
        self._num_mlp_layers = num_mlp_layers # Number of layers for MLPs computing messages and embeddings
        self._num_iter = num_iter # Number of BP iterations, can be modified

        self._reduce_op = reduce_op # reduce operation for message aggregation
        self._activation = activation # activation function of the hidden MLP layers

        self._output_all_iter = output_all_iter
        self._clip_llr_to = clip_llr_to

    @property
    def num_iter(self):
        return self._num_iter

    @num_iter.setter
    def num_iter(self, value):
        self._num_iter = value

    def build(self, input_shape):
        print('GNN-BP build')
        # NN to transform input LLR to VN embedding
        self._llr_embed = Dense(self._num_embed_dims) # num_embed_dims output units, linear activation

        # NN to transform VN embedding to output LLR
        self._llr_inv_embed = Dense(1) # one output unit maps an embedding back to an LLR, linear activation

        # CN embedding update function
        self.update_h_cn = UpdateEmbeddings(self._num_msg_dims,
                                            self._num_hidden_units,
                                            self._num_mlp_layers,
                                            np.flip(self._edges, 1), # Flip columns: "from VN to CN"
                                            self._cn_edges,
                                            self._reduce_op,
                                            self._activation)

        # VN embedding update function
        self.update_h_vn = UpdateEmbeddings(self._num_msg_dims,
                                            self._num_hidden_units,
                                            self._num_mlp_layers,
                                            self._edges, # "from CN to VN"
                                            self._vn_edges,
                                            self._reduce_op,
                                            self._activation)

    def llr_to_embed(self, llr):
        """Transform LLRs to VN embeddings"""
        return self._llr_embed(tf.expand_dims(llr, -1))

    def embed_to_llr(self, h_vn):
        """Transform VN embeddings to LLRs"""
        return tf.squeeze(self._llr_inv_embed(h_vn), axis=-1)

    def call(self, llr):
        print('GNN-BP call')
        batch_size = tf.shape(llr)[0]

        # Initialize vertex embeddings
        if self._clip_llr_to is not None:
            llr = tf.clip_by_value(llr, -self._clip_llr_to, self._clip_llr_to)

        h_vn = self.llr_to_embed(llr)
        h_cn = tf.zeros([batch_size , self._num_cn, self._num_embed_dims])
        # VN embeddings are initialized from the input LLRs, CN embeddings with zeros

        # BP iterations
        if self._output_all_iter:
            llr_hat = []

        for i in range(self._num_iter):
            print('iteration', i, ': updating')
            # Update CN embeddings
            h_cn = self.update_h_cn(h_vn, h_cn) # h_new_cn
            print('iteration', i, ': check nodes updated')
            # Update VNs
            h_vn = self.update_h_vn(h_cn, h_vn) # h_new_vn
            print('iteration', i, ': variable nodes updated')
            if self._output_all_iter:
                llr_hat.append(self.embed_to_llr(h_vn))

        if not self._output_all_iter:
            llr_hat = self.embed_to_llr(h_vn)
        print('llr_hat ', llr_hat)
        # in every iteration the CNs are updated first, then the VNs
        return llr_hat

class UpdateEmbeddings(Layer): # Keras layer implementing the embedding update
    """Update vertex embeddings of the GNN BP decoder.

    This layer first computes the messages that are sent across the edges
    of the graph, then aggregates the incoming messages at each vertex, and
    finally updates the vertex embeddings.

    Parameters
    ----------
    num_msg_dims: int
        Number of dimensions of a message.

    num_hidden_units: int
        Number of hidden units of MLPs used to compute
        messages and to update the vertex embeddings.

    num_mlp_layers: int
        Number of layers of the MLPs used to compute
        messages and to update the vertex embeddings.

    from_to_ind: [num_edges, 2], np.array
        Two dimensional array containing in each row the indices of the
        originating and receiving vertex for an edge.

    gather_ind: [`num_vn` or `num_cn`, None], tf.ragged.constant
        Ragged tensor that contains for each receiving vertex the list of
        edge indices from which to aggregate the incoming messages. As each
        vertex can have a different degree, a ragged tensor is used.

    reduce_op: str
        A string defining the vertex aggregation function.
        Currently, "mean" and "sum" are supported.

    activation: str
        A string defining the activation function of the hidden MLP layers to
        be used. Defaults to "relu".

    Input
    -----
    h_from : [batch_size, num_cn or num_vn, num_embed_dims], tf.float32
        Tensor containing the embeddings of the "transmitting" vertices.

    h_to : [batch_size, num_vn or num_cn, num_embed_dims], tf.float32
        Tensor containing the embeddings of the "receiving" vertices.

    Output
    ------
    h_to_new : Same shape and type as `h_to`
        Tensor containing the updated embeddings of the "receiving" vertices.
    """
    def __init__(self,
                 num_msg_dims,
                 num_hidden_units,
                 num_mlp_layers,
                 from_to_ind,
                 gather_ind,
                 reduce_op="sum",
                 activation="relu",
                 ):
        super().__init__()
        self._num_msg_dims = num_msg_dims
        self._num_hidden_units = num_hidden_units
        self._num_mlp_layers = num_mlp_layers
        self._from_ind = from_to_ind[:,0]
        self._to_ind = from_to_ind[:,1]
        self._gather_ind = gather_ind
        self._reduce_op = reduce_op
        self._activation = activation


    def build(self, input_shape):
        print('UpdateEmbeddings build')
        num_embed_dims = input_shape[-1] # size of the last input dimension = embedding size

        # MLP to compute messages
        units = [self._num_hidden_units]*(self._num_mlp_layers-1) + [self._num_msg_dims] # e.g., [40]*1 + [20] = [40, 20]
        activations = [self._activation]*(self._num_mlp_layers-1) + [None]  # e.g., ['tanh']*1 + [None] = ['tanh', None]
        use_bias = [True]*self._num_mlp_layers # one bias flag per layer, e.g., [True, True]
        self._msg_mlp = MLP(units, activations, use_bias)
        # e.g., MLP([40, 20], ['tanh', None], [True, True]) stacks two layers:
        # Dense(40, 'tanh') followed by Dense(20, None)

        print('msg_mlp')

        # MLP to update embeddings from accumulated messages
        # note: build a new list instead of modifying `units` in place, as the
        # message MLP keeps a reference to that list until its layers are built
        units = units[:-1] + [num_embed_dims]
        self._embed_mlp = MLP(units, activations, use_bias)

        # msg_mlp and embed_mlp are two separate MLPs
        print('embed_mlp')

    def call(self, h_from, h_to):
        print('update call')
        # Concatenate embeddings of the transmitting (from) and receiving (to) vertex for each edge
        features = tf.concat([tf.gather(h_from, self._from_ind, axis=1),
                              tf.gather(h_to, self._to_ind, axis=1)],
                              axis=-1)

        # Compute messages for all edges
        print('running msg_mlp')
        messages = self._msg_mlp(features) # features: [batch_size, num_edges, 2*num_embed_dims], e.g., [None, 432, 40]

        # Reduce messages at each receiving (to) vertex
        # note: bring batch dim to last dim for improved performance
        # with ragged tensors
        messages = tf.transpose(messages, (1,2,0))
        print('messages : ', messages)
        # gather the incoming messages of every receiving vertex,
        # e.g., tf.gather([1, 2, 3, 4, 5], [2, 4], axis=0) = [3, 5]
        m_ragged = tf.gather(messages, self._gather_ind, axis=0)
        print('m_ragged: ', m_ragged)
        if self._reduce_op=="sum":
            m = tf.reduce_sum(m_ragged, axis=1)
        elif self._reduce_op=="mean":
            m = tf.reduce_mean(m_ragged, axis=1)
        else:
            raise ValueError("unknown reduce operation")

        m = tf.transpose(m, (2,0,1)) # batch-dim back to first dim
        print('m: ', m)
        # Compute new embeddings
        print('running embed_mlp')
        h_to_new = self._embed_mlp(tf.concat([m, h_to], axis=-1)) # concatenate aggregated messages and current embeddings


        return h_to_new

## Training

We now define our custom training loop. Please note that the training may take several hours and should only be executed if enough computational resources are available.
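During training, the loop periodically reports a bit error rate obtained by thresholding the LLRs of the last iteration at zero. The following pure-Python sketch shows that evaluation step on toy values; `hard_decide` and `bit_error_rate` are hypothetical helpers that stand in for the TensorFlow and Sionna calls used in the notebook.

```python
def hard_decide(llrs):
    """Map every LLR to a bit: positive values become 1, all others 0."""
    return [1 if l > 0 else 0 for l in llrs]

def bit_error_rate(bits, bits_hat):
    """Fraction of positions in which the two bit sequences differ."""
    errors = sum(b != b_hat for b, b_hat in zip(bits, bits_hat))
    return errors / len(bits)

c = [0, 1, 1, 0]                     # transmitted codeword bits
llr_hat = [-2.3, 0.7, -0.4, -5.1]    # the third bit is decided wrongly
c_hat = hard_decide(llr_hat)
ber = bit_error_rate(c, c_hat)
```

The BER is only a monitoring metric here; the gradient update itself is driven by the cross-entropy loss on the soft outputs.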

In [11]:
def train_model(model, params):
    loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    for p in params:
        train_batch_size, lr, train_iter = p
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

        @tf.function()
        def train_step():
            ebno_db = tf.random.uniform([train_batch_size, 1], minval=ebno_db_min, maxval=ebno_db_max) # draw a random SNR for every sample of the batch
            with tf.GradientTape() as tape:
                c, llr_hat = model(train_batch_size, ebno_db)
                print('computed c, llr_hat')
                loss_value = 0
                for l in llr_hat:
                    loss_value += loss(c, l)

            weights = model.trainable_weights
            print("weights : ", weights)
            grads = tape.gradient(loss_value, weights)
            optimizer.apply_gradients(zip(grads, weights))
            return c, llr_hat

        for i in range(train_iter):
            c, llr_hat = train_step()
            if i%1000==0:
                ebno_db = tf.random.uniform([10000, 1],
                                            minval=ebno_db_min,
                                            maxval=ebno_db_max)
                c, llr_hat = model(10000, ebno_db)
                loss_value = 0
                for l in llr_hat:
                    loss_value += loss(c, l)
                c_hat = tf.cast(tf.greater(llr_hat[-1], 0), tf.float32)
                ber = compute_ber(c, c_hat).numpy()
                print(f"Iteration {i}, loss = {loss_value.numpy():.3f}, " \
                      f"ber = {ber:.5f}")

In [12]:
tf.random.set_seed(2)

gnn_decoder = GNN_BP(pcm=pcm,
                     num_embed_dims=20,
                     num_msg_dims=20,
                     num_hidden_units=40,
                     num_mlp_layers=2,
                     num_iter=8,
                     reduce_op="mean",
                     activation="tanh",
                     output_all_iter=True,
                     clip_llr_to=None)
e2e_gnn = E2EModel(pcm, gnn_decoder) # create the E2EModel after configuring the decoder

num_edges : 432
self._n output :  63
self._k :  45
self._mapper :  <sionna.mapping.Mapper object at 0x7e5ebaf07730>
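
The printed values already determine a few basic properties of the code and its Tanner graph. The short calculation below is only a sanity check. It assumes that the parity-check matrix has full rank with n - k rows, which is not stated in the output above.

```python
# Values taken from the printed output above
n, k = 63, 45
num_edges = 432

num_checks = n - k                       # assumption: pcm has n - k rows
rate = k / n                             # code rate
avg_cn_degree = num_edges / num_checks   # average edges per check node
avg_vn_degree = num_edges / n            # average edges per variable node

print(f"rate = {rate:.3f}, checks = {num_checks}, "
      f"avg. CN degree = {avg_cn_degree:.1f}, avg. VN degree = {avg_vn_degree:.2f}")
```

Each edge of the graph carries one message per direction, so the message MLPs operate on 432 inputs per codeword.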


In [13]:
e2e_gnn(1, 1.) # call model once to init
e2e_gnn.summary()

save_at = "/kaggle/working/model.hdf5"

GNN-BP build
GNN-BP call
Update iteration 0 in progress
updated bulid
msg_mlp
embed_mlp
update call
running msg_mlp
MLP build
Stacking layer 0 :  40
Stacking layer 1 :  20
MLP call
MLP inputs : Tensor("gnn_bp/update_embeddings/concat:0", shape=(None, 432, 40), dtype=float32)
computing
computing
MLP outpus : Tensor("gnn_bp/update_embeddings/mlp/dense_3/BiasAdd:0", shape=(None, 432, 20), dtype=float32)
messages :  Tensor("gnn_bp/update_embeddings/transpose:0", shape=(432, 20, None), dtype=float32)
m_ragged:  tf.RaggedTensor(values=Tensor("gnn_bp/update_embeddings/RaggedGather/GatherV2:0", shape=(432, 20, None), dtype=float32), row_splits=tf.Tensor(
[  0  24  48  72  96 120 144 168 192 216 240 264 288 312 336 360 384 408
 432], shape=(19,), dtype=int64))
m:  Tensor("gnn_bp/update_embeddings/transpose_1:0", shape=(None, 18, 20), dtype=float32)
running embed_mlp
MLP build
Stacking layer 0 :  40
Stacking layer 1 :  20
MLP call
MLP inputs : Tensor("gnn_bp/update_embeddings/concat_1:0", shape=(None, 18, 40), dtype=float32)
computing
computing
MLP outpus 

And let's train the GNN.

In [None]:
train_params = [
    #batch_size, learning_rate, num_iter
    [256, 1e-3, 1000],
    [256, 1e-4, 1000],
    [256, 1e-5, 1000],
]
e2e_gnn._decoder._output_all_iter = True # use multi-loss during training
train_model(e2e_gnn, train_params)

GNN-BP call
Update iteration 0 in progress
update call
running msg_mlp
MLP call
MLP inputs : Tensor("gnn_bp/update_embeddings/concat:0", shape=(256, 432, 40), dtype=float32)
computing
computing
MLP outpus : Tensor("gnn_bp/update_embeddings/mlp/dense_3/BiasAdd:0", shape=(256, 432, 20), dtype=float32)
messages :  Tensor("gnn_bp/update_embeddings/transpose:0", shape=(432, 20, 256), dtype=float32)
m_ragged:  tf.RaggedTensor(values=Tensor("gnn_bp/update_embeddings/RaggedGather/GatherV2:0", shape=(432, 20, 256), dtype=float32), row_splits=tf.Tensor(
[  0  24  48  72  96 120 144 168 192 216 240 264 288 312 336 360 384 408
 432], shape=(19,), dtype=int64))
m:  Tensor("gnn_bp/update_embeddings/transpose_1:0", shape=(256, 18, 20), dtype=float32)
running embed_mlp
MLP call
MLP inputs : Tensor("gnn_bp/update_embeddings/concat_1:0", shape=(256, 18, 40), dtype=float32)
computing
computing
MLP outpus : Tensor("gnn_bp/update_embeddings/mlp_1/dense_5/BiasAdd:0", shape=(256, 18, 20), dtype=float32)
Check node update 0
update call
running msg_mlp
MLP call
MLP inp



c, llr_hat computed
weights :  [<tf.Variable 'gnn_bp/dense/kernel:0' shape=(1, 20) dtype=float32>, <tf.Variable 'gnn_bp/dense/bias:0' shape=(20,) dtype=float32>, <tf.Variable 'gnn_bp/dense_1/kernel:0' shape=(20, 1) dtype=float32>, <tf.Variable 'gnn_bp/dense_1/bias:0' shape=(1,) dtype=float32>, <tf.Variable 'gnn_bp/update_embeddings/mlp/dense_2/kernel:0' shape=(40, 40) dtype=float32>, <tf.Variable 'gnn_bp/update_embeddings/mlp/dense_2/bias:0' shape=(40,) dtype=float32>, <tf.Variable 'gnn_bp/update_embeddings/mlp/dense_3/kernel:0' shape=(40, 20) dtype=float32>, <tf.Variable 'gnn_bp/update_embeddings/mlp/dense_3/bias:0' shape=(20,) dtype=float32>, <tf.Variable 'gnn_bp/update_embeddings/mlp_1/dense_4/kernel:0' shape=(40, 40) dtype=float32>, <tf.Variable 'gnn_bp/update_embeddings/mlp_1/dense_4/bias:0' shape=(40,) dtype=float32>, <tf.Variable 'gnn_bp/update_embeddings/mlp_1/dense_5/kernel:0' shape=(40, 20) dtype=float32>, <tf.Variable 'gnn_bp/update_embeddings/mlp_1/dense_5/bias:0' shape=(20,) dty

In [None]:
# pruning function
import random

import numpy as np
import pandas as pd

# pruning percentages
K = [25]

print(e2e_gnn._decoder.update_h_vn._embed_mlp)

print(e2e_gnn._decoder.update_h_vn._embed_mlp._layers[0].get_weights())

msg_mlp = e2e_gnn._decoder.update_h_vn._msg_mlp._layers
embed_mlp = e2e_gnn._decoder.update_h_vn._embed_mlp._layers

num_msg_mlp_layers = len(msg_mlp)
num_embed_mlp_layers = len(embed_mlp)

# msg_mlp pruning

# collect the kernel entries of all layers except the last as {(layer, row, col): weight}
msg_all_weights = {}

for layer_no in range(num_msg_mlp_layers-1):
    msg_layer_weights = (pd.DataFrame(msg_mlp[layer_no].get_weights()[0]).stack()).to_dict()
    msg_layer_weights = {(layer_no, k[0], k[1]): v for k, v in msg_layer_weights.items()}
    msg_all_weights.update(msg_layer_weights)

msg_all_weights_sorted = {k: v for k, v in sorted(msg_all_weights.items(), key=lambda item: abs(item[1]))}

print(msg_all_weights_sorted)
total_no_weights = len(msg_all_weights_sorted)
print(total_no_weights)

for pruning_percent in K:
    prune_fraction = pruning_percent / 100  # prune the fraction specified in K
    number_of_weights_to_be_pruned = int(prune_fraction * total_no_weights)
    all_keys = list(msg_all_weights_sorted.keys())

    # Alternative (magnitude pruning): remove the weights with the smallest
    # magnitude, i.e., those expected to influence the output the least:
    # msg_weights_to_be_pruned = all_keys[:number_of_weights_to_be_pruned]

    # Random pruning: draw the weights to be pruned uniformly at random
    random_indices = np.random.choice(total_no_weights,
                                      size=number_of_weights_to_be_pruned,
                                      replace=False)
    msg_weights_to_be_pruned = [all_keys[i] for i in random_indices]

    for layer_no in range(num_msg_mlp_layers - 1):
        # get_weights() returns [kernel, bias]; only the kernel is pruned
        msg_new_weights = msg_mlp[layer_no].get_weights()
        for k in msg_weights_to_be_pruned:
            if k[0] == layer_no:
                # set the kernel entry at index (k[1], k[2]) to zero
                msg_new_weights[0][k[1], k[2]] = 0
        msg_mlp[layer_no].set_weights(msg_new_weights)
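
The two selection rules in the cell above, smallest-magnitude pruning and random pruning, can also be expressed as boolean masks on a single kernel. The sketch below uses NumPy only; the helper names and the random 40x40 toy kernel are hypothetical. It also shows why a mask matters for retraining: a plain gradient step generally turns pruned entries non-zero again, so the sparsity is preserved only if the mask is re-applied after each update. The `train_model` function shown here does not do this.

```python
import numpy as np

def magnitude_mask(kernel, prune_fraction):
    """Boolean mask that drops the `prune_fraction` smallest-magnitude entries."""
    num_prune = int(prune_fraction * kernel.size)
    mask = np.ones(kernel.size, dtype=bool)
    smallest = np.argsort(np.abs(kernel).ravel(), kind="stable")[:num_prune]
    mask[smallest] = False
    return mask.reshape(kernel.shape)

def random_mask(kernel, prune_fraction, rng):
    """Boolean mask that drops `prune_fraction` of the entries at random."""
    num_prune = int(prune_fraction * kernel.size)
    mask = np.ones(kernel.size, dtype=bool)
    mask[rng.choice(kernel.size, size=num_prune, replace=False)] = False
    return mask.reshape(kernel.shape)

rng = np.random.default_rng(0)
kernel = rng.normal(size=(40, 40))   # toy kernel (hypothetical values)
mask_mag = magnitude_mask(kernel, 0.25)
mask_rnd = random_mask(kernel, 0.25, rng)
pruned = kernel * mask_mag           # pruned entries are set to zero

# One gradient step with a toy gradient: without the mask, the pruned
# entries are updated like any other weight and become non-zero again.
grad = rng.normal(size=kernel.shape)
unmasked_update = pruned - 1e-3 * grad
masked_update = (pruned - 1e-3 * grad) * mask_mag
```

Whether the pruned weights should stay at zero during the second training run is a design decision; the mask-based variant keeps the pruning ratio fixed.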


In [None]:
train_model(e2e_gnn, train_params)

## Final Performance

We now evaluate the final BER performance of the trained GNN decoder and compare it with our baseline.

In [None]:
e2e_gnn._decoder._output_all_iter = False # deactivate multi-loss for inference

ber_plot.simulate(e2e_gnn, # ebno_dbs_1,
                  ebno_dbs=ebno_dbs,
                  batch_size=mc_batch_size,
                  num_target_block_errors=num_target_block_errors,
                  legend="GNN {} iter.".format(gnn_decoder._num_iter),
                  soft_estimates=True,
                  max_mc_iter=mc_iters,
                  forward_keyboard_interrupt=False,
                  show_fig=True);
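
The arguments `num_target_block_errors` and `max_mc_iter` suggest a Monte Carlo loop with early stopping: batches are simulated until enough block errors have been observed or the iteration limit is reached. The following NumPy sketch illustrates this idea under our own assumptions. It is not Sionna's implementation, and `simulate_snr_point`, `run_batch` and `toy_batch` are hypothetical names.

```python
import numpy as np

def simulate_snr_point(run_batch, num_target_block_errors, max_mc_iter):
    """Accumulate errors batch by batch until enough block errors are seen.

    `run_batch` is a hypothetical callable returning (c, c_hat) as arrays
    of shape [batch_size, n] with entries in {0, 1}.
    """
    bit_errors = block_errors = num_bits = num_blocks = 0
    for _ in range(max_mc_iter):
        c, c_hat = run_batch()
        errors = c != c_hat
        bit_errors += int(errors.sum())
        block_errors += int(errors.any(axis=-1).sum())
        num_bits += errors.size
        num_blocks += errors.shape[0]
        if block_errors >= num_target_block_errors:
            break  # enough statistics collected for this SNR point
    return bit_errors / num_bits, block_errors / num_blocks, num_blocks

def toy_batch():
    """Deterministic stand-in for channel and decoder: 4 blocks, 1 in error."""
    c = np.zeros((4, 63))
    c_hat = c.copy()
    c_hat[0, :2] = 1.0  # two bit errors in the first block
    return c, c_hat

ber, bler, blocks = simulate_snr_point(toy_batch,
                                       num_target_block_errors=3,
                                       max_mc_iter=100)
```

Stopping on block errors keeps the simulation short at low SNR, where errors are frequent, and spends more batches at high SNR, where they are rare.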

In [None]:
# and show the final performance (plots the BER curves)
ber_plot(xlim=[2., 9.], ylim=[5e-6, 0.1])

As can be seen, the GNN-based decoder outperforms BP and WBP with significantly fewer iterations.

Please note that, due to the limited training time, the GNN is not perfectly trained for the BCH code.
A more sophisticated training setup can be found in the same repository.

A few ideas to explore:
- Can you train the decoder for LDPC or Polar codes?
- What is the optimal NN architecture?
- Conceptually, non-binary decoders could also be learned within the same framework.
- Can we get insights from the learned solution for improved *classical* decoding algorithms?