In [None]:
# general imports
import tensorflow as tf
import numpy as np
import os
%matplotlib inline
import matplotlib.pyplot as plt

# --- Sionna imports ---
from sionna.phy.utils.plotting import PlotBER
from sionna.phy.fec.linear import LinearEncoder
from sionna.phy.fec.ldpc import LDPCBPDecoder, LDPC5GEncoder, LDPC5GDecoder

# --- Project-specific imports ---
%load_ext autoreload
%autoreload 2
from gnn import GNN_BP
from e2e_model import E2EModel

print(f"All modules loaded successfully!")

In [None]:
# Restrict TensorFlow to a single GPU and let its memory allocation grow on
# demand instead of reserving the whole device up front.
gpus = tf.config.list_physical_devices('GPU')
print('Number of GPUs available :', len(gpus))
if gpus:
    gpu_num = 0  # index (into `gpus`) of the device this notebook should use
    try:
        selected_gpu = gpus[gpu_num]
        tf.config.set_visible_devices(selected_gpu, 'GPU')
        tf.config.experimental.set_memory_growth(selected_gpu, True)
        print('Using GPU number', gpu_num)
    except RuntimeError as err:
        # NOTE(review): TensorFlow reportedly raises RuntimeError here once the
        # GPUs are already initialised (e.g. when the cell is re-run); the
        # message is only reported, not re-raised.
        print(err)

In [None]:
#----- LDPC 5G -----
params = {
    # ---- Code parameters ----
    "code": "5G-LDPC",
    "n": 140,
    "k": 60,

    # ---- GNN architecture ----
    "num_embed_dims": 16,
    "num_msg_dims": 16,
    "num_hidden_units": 48,
    "num_mlp_layers": 3,
    "num_iter": 10,
    "reduce_op": "sum",
    "activation": "relu",
    "clip_llr_to": 20,
    "use_attributes": False,
    "node_attribute_dims": 0,
    "msg_attribute_dims": 0,
    "return_infobits": False,
    "use_bias": True,

    # ---- Training ----
    "batch_size": 128,
    "train_iter": 35000,
    "learning_rate": 5e-4,
    "ebno_db_train": [2, 8.0],      # [min, max] range used during training
    "ebno_db_eval": 2.0,
    "batch_size_eval": 1000,        # batch size for the in-training evaluation only
    "eval_train_steps": 1000,       # evaluate the model every N iterations

    # ---- Logging ----
    "save_weights_iter": 10000,     # save weights every N iterations
    "run_name": "LDPC_5G_01",       # name used for the stored weights/logs
    "save_dir": "results/",         # root folder for all stored results
}


In [None]:
# Build the 5G LDPC encoder and a reference decoder; the decoder is used here
# to obtain the parity-check matrix that the GNN operates on.
k, n = params["k"], params["n"]
encoder_5g = LDPC5GEncoder(k, n)
# NOTE(review): with prune_pcm=True the decoding graph is pruned, so
# pcm.shape[1] may differ from n — confirm GNN_BP / E2EModel account for this.
decoder_5g = LDPC5GDecoder(encoder_5g, prune_pcm=True)
# .toarray() densifies the matrix (presumably stored as a scipy sparse matrix).
pcm = decoder_5g.pcm.toarray()

In [None]:
# Seed TensorFlow before the model is built so the GNN weight initialisation
# is reproducible (the fixed seed is used to ensure stable convergence).
tf.random.set_seed(2)

# Architecture hyper-parameters forwarded one-to-one from `params` to GNN_BP.
gnn_arch_keys = (
    "num_embed_dims",
    "num_msg_dims",
    "num_hidden_units",
    "num_mlp_layers",
    "num_iter",
    "reduce_op",
    "activation",
    "clip_llr_to",
    "use_attributes",
    "node_attribute_dims",
    "msg_attribute_dims",
    "use_bias",
)
gnn_kwargs = {key: params[key] for key in gnn_arch_keys}

# NOTE(review): output_all_iter=True is set here rather than taken from params;
# presumably it makes the decoder return per-iteration outputs for the loss —
# confirm in gnn.GNN_BP.
gnn_decoder = GNN_BP(pcm=pcm, output_all_iter=True, **gnn_kwargs)

e2e_model = E2EModel(encoder=encoder_5g,
                     decoder=gnn_decoder,
                     k=k,
                     n=n,
                     fading=True)

In [None]:
optimizer = tf.keras.optimizers.Adam(learning_rate=params["learning_rate"])
checkpoint_dir = os.path.join(params["save_dir"], params["run_name"], "checkpoints")
os.makedirs(checkpoint_dir, exist_ok=True)

# Track decoder weights and optimizer state together so a checkpoint holds
# everything needed to continue training; keep only the 5 most recent ones.
ckpt = tf.train.Checkpoint(model=gnn_decoder, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_dir, max_to_keep=5)

log_every = 100                           # print the loss every N iterations
save_every = params["save_weights_iter"]  # checkpoint every N iterations (<= 0 disables periodic saves)
num_train_iter = params["train_iter"]

print("Starting training...")

@tf.function(jit_compile=True)
def train_step(batch_size, ebno_db_range):
    """Run one Adam update on the trainable weights of ``e2e_model.decoder``.

    Args:
        batch_size: Batch size passed to ``e2e_model`` for this step.
        ebno_db_range: ``[min, max]`` bounds in dB; one Eb/N0 value is drawn
            uniformly at random for every example in the batch.

    Returns:
        The loss reported by ``e2e_model`` for this batch.
    """
    ebno_db = tf.random.uniform(shape=[batch_size], minval=ebno_db_range[0], maxval=ebno_db_range[1])
    with tf.GradientTape() as tape:
        _, _, loss = e2e_model(batch_size, ebno_db, training=True)
    grads = tape.gradient(loss, e2e_model.decoder.trainable_weights)
    optimizer.apply_gradients(zip(grads, e2e_model.decoder.trainable_weights))
    return loss

# NOTE(review): `eval_train_steps`, `ebno_db_eval` and `batch_size_eval` are
# defined in `params` but periodic evaluation is not implemented in this loop.
for i in range(1, num_train_iter + 1):
    loss = train_step(params["batch_size"], params["ebno_db_train"])
    if i % log_every == 0:
        print(f"Iter: {i}, Loss: {loss:.4f}")
    # Persist weights + optimizer state every `save_weights_iter` iterations and
    # always after the final one, so a finished run is never left unsaved.
    is_periodic_save = save_every > 0 and i % save_every == 0
    if is_periodic_save or i == num_train_iter:
        save_path = ckpt_manager.save(checkpoint_number=i)
        print(f"Iter: {i}, saved checkpoint to {save_path}")