In [None]:
import numpy as np
from optic.models.devices import mzm, photodiode
from optic.models.channels import linearFiberChannel
from optic.comm.sources import bitSource
from optic.comm.modulation import modulateGray
from optic.comm.metrics import bert
from optic.dsp.core import firFilter, pulseShape, upsample, pnorm, anorm
from optic.utils import parameters, dBm2W
from scipy.special import erfc

import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, Model



## Baseline: no DPD

In [None]:
# ---- Baseline IM-DD link, no predistortion ----
# Global link configuration
SpS = 16                 # samples per symbol
M = 2                    # modulation order (2-PAM)
Rs = 10e9                # symbol rate [baud]
Fs = Rs * SpS            # simulation sampling rate [samples/s]
Pi_dBm = 3               # laser power at the MZM input [dBm]
Pi = dBm2W(Pi_dBm)       # the same power in W

# Bit source: 100k seeded random bits
paramBits = parameters()
paramBits.nBits = 100000
paramBits.mode = 'random'
paramBits.seed = 123

# Transmit pulse: NRZ at SpS samples per symbol
paramPulse = parameters()
paramPulse.pulseType = 'nrz'
paramPulse.SpS = SpS

# Mach-Zehnder modulator: bias at -Vpi/2
paramMZM = parameters()
paramMZM.Vpi = 2
paramMZM.Vb = -paramMZM.Vpi / 2

# Linear fiber: 100 km, 0.2 dB/km loss, 16 ps/nm/km dispersion, 193.1 THz carrier
paramCh = parameters()
paramCh.L = 100
paramCh.alpha = 0.2
paramCh.D = 16
paramCh.Fc = 193.1e12
paramCh.Fs = Fs

# Non-ideal photodiode limited to a bandwidth of Rs, with a seeded noise generator
paramPD = parameters()
paramPD.ideal = False
paramPD.B = Rs
paramPD.Fs = Fs
paramPD.seed = 456

## Simulation
print("\nStarting simulation...", end="")

bitsTx = bitSource(paramBits)                    # seeded pseudo-random bits
symbTx = modulateGray(bitsTx, M, "pam")          # 2-PAM symbols
symbolsUp = upsample(symbTx, SpS)                # raise to SpS samples per symbol
pulse = pulseShape(paramPulse)                   # NRZ shaping taps
sigTx = anorm(firFilter(pulse, symbolsUp))       # shaped drive signal, amplitude-normalized by anorm

Ai = np.sqrt(Pi)                                 # constant envelope of an ideal CW laser
sigTxo = mzm(Ai, sigTx, paramMZM)                # electro-optic modulation
sigCh = linearFiberChannel(sigTxo, paramCh)      # propagate through the fiber
I_Rx = photodiode(sigCh, paramPD)[::SpS]         # detect, then keep one sample per symbol

# Bit error rate and Q-factor of the received decisions
BER, Q = bert(I_Rx, bitsTx)

print("\nTransmission performance metrics:")
print(f"Q-factor = {Q:.2f} ")
print(f"BER = {BER:.2e}")


**Note:** look at the Q-factor as well as the BER; the BER on its own is not enough.
**Update:** the paper reports SNR instead.

## Model Architecture & Training

## Imtiaz et al. (2025)

### Architecture logic/inspiration: WH paper. See DPD_2 for how to phrase this in the report. No memory is added to the FFNN here.

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model, initializers
from optic.models.devices import mzm, photodiode
from optic.models.channels import linearFiberChannel
from optic.comm.sources import bitSource
from optic.comm.modulation import modulateGray
from optic.comm.metrics import bert
from optic.dsp.core import firFilter, pulseShape, upsample, anorm
from optic.utils import parameters, dBm2W


def build_model(in_taps=301, out_taps=101, hidden_units=20, leak=0.1):
    """Build a linear -> nonlinear -> linear (WH-style) sequence model.

    The same architecture is used for both the DPD and the auxiliary channel
    model in the DLA cell. Defaults reproduce the original hard-coded network.

    Parameters
    ----------
    in_taps : int
        Kernel length of the input Conv1D (linear FIR front end).
    out_taps : int
        Kernel length of the output Conv1D (linear FIR back end).
    hidden_units : int
        Width of the two hidden Dense layers in the nonlinear branch.
    leak : float
        Negative slope of the LeakyReLU activations.

    Returns
    -------
    keras Model mapping inputs of shape (batch, time, 1) to outputs of the same
    shape (both Conv1D layers use padding='same').
    """
    inputs = layers.Input(shape=(None, 1))

    # Section A: single-channel linear FIR filter.
    sec_a = layers.Conv1D(1, in_taps, padding='same')(inputs)

    # Nonlinear branch. Dense layers act on the last axis only, i.e. per time
    # step, so this branch has no memory of its own.
    x = layers.Dense(hidden_units, activation=layers.LeakyReLU(leak))(sec_a)
    x = layers.Dense(hidden_units, activation=layers.LeakyReLU(leak))(x)
    x = layers.Dense(1, activation=layers.LeakyReLU(leak))(x)
    nonlinear_out = layers.Dense(1, activation='linear')(x)

    # Residual connection: linear path plus the learned nonlinear correction.
    sec_b = layers.Add()([sec_a, nonlinear_out])

    # Section B: output linear FIR filter.
    outputs = layers.Conv1D(1, out_taps, padding='same')(sec_b)

    return Model(inputs, outputs)

# --- 2. SIMULATION PARAMETERS ---

# simulation parameters
SpS = 16  # samples per symbol
M = 2  # order of the modulation format
Rs = 10e9  # Symbol rate
Fs = SpS * Rs  # Signal sampling frequency (samples/second)
Pi_dBm = 3  # laser optical power at the input of the MZM in dBm
Pi = dBm2W(Pi_dBm)  # convert from dBm to W

# DLA training configuration
SEQ_LEN = 8192      # symbols per training sequence, as per the paper
N_DLA_ITER = 50     # outer DLA iterations (DLA typically converges in 9 iterations [cite: 371])
N_EPOCHS = 30       # epochs for each of the two DLA steps
BATCH_SIZE = 32
BITS_SEED0 = 123    # iteration k uses bit seed BITS_SEED0 + k
PD_SEED0 = 456      # iteration k uses photodiode noise seed PD_SEED0 + k
# NOTE: keep evaluation seeds (e.g. perf_sim(random_seed=345)) outside
# [BITS_SEED0, BITS_SEED0 + N_DLA_ITER) so evaluation data is unseen.

# Bit source parameters
paramBits = parameters()
paramBits.nBits = 2**18  # number of bits to be generated
paramBits.mode = 'random' # mode of the bit source 
paramBits.seed = BITS_SEED0  # overwritten per iteration below

# pulse shaping parameters
paramPulse = parameters()
paramPulse.pulseType = 'nrz'  # pulse shape type
paramPulse.SpS = SpS     # samples per symbol  

# MZM parameters
paramMZM = parameters()
paramMZM.Vpi = 2
paramMZM.Vb = -paramMZM.Vpi / 2

# linear fiber optical channel parameters
paramCh = parameters()
paramCh.L = 100        # total link distance [km]
paramCh.alpha = 0.2    # fiber loss parameter [dB/km]
paramCh.D = 16         # fiber dispersion parameter [ps/nm/km]
paramCh.Fc = 193.1e12  # central optical frequency [Hz]
paramCh.Fs = Fs

# photodiode parameters
paramPD = parameters()
paramPD.ideal = False
paramPD.B = Rs
paramPD.Fs = Fs
paramPD.seed = PD_SEED0  # overwritten per iteration below

# --- 3. DLA INITIALIZATION ---

dpd_model = build_model()
aux_model = build_model()

# Cascade: x -> DPD -> Aux -> y_est
inputs_x = layers.Input(shape=(None, 1))
z_pred = dpd_model(inputs_x)
y_est = aux_model(z_pred)
dla_cascade = Model(inputs_x, y_est)

# Optimizers from Table II [cite: 294]
opt_dpd = tf.keras.optimizers.Adam(learning_rate=1e-3)
opt_aux = tf.keras.optimizers.Adam(learning_rate=1e-3)

# Both models are (re)compiled inside the loop. Keras only honours a change of
# `trainable` after compile(); without recompiling, the cascade fit can also
# update the auxiliary model's weights (tf.keras restores the compile-time
# trainable state during fit). Reusing the same optimizer objects keeps their
# Adam state across recompiles.

print("Iteration | Q-factor  | BER")
for iteration in range(N_DLA_ITER):

    # Step A: Data Generation. Use fresh bits and fresh photodiode noise on
    # every iteration, so the auxiliary model is not refitted to one fixed
    # noise realization.
    paramBits.seed = BITS_SEED0 + iteration
    paramPD.seed = PD_SEED0 + iteration
    bitsTx = bitSource(paramBits)
    symbTx = modulateGray(bitsTx, M, "pam")
    if symbTx.size % SEQ_LEN:
        raise ValueError(
            f"{symbTx.size} symbols cannot be split into sequences of {SEQ_LEN}"
        )
    x_input = symbTx.reshape(-1, SEQ_LEN, 1)  # (n_seq, SEQ_LEN, 1)

    # Step B: Apply current DPD and Run Simulation
    z_dpd = dpd_model.predict(x_input, verbose=0)
    z_signal = z_dpd.flatten()

    # Optic-Py Chain
    symbolsUp = upsample(z_signal, SpS)
    sigTx = firFilter(pulseShape(paramPulse), symbolsUp)
    sigTx = anorm(sigTx) # Normalize to 1 Vpp
    sigTxo = mzm(np.sqrt(Pi), sigTx, paramMZM)
    sigCh = linearFiberChannel(sigTxo, paramCh)
    I_Rx = photodiode(sigCh, paramPD)[0::SpS]

    # Performance Monitoring. I_Rx was produced by the DPD *before* this
    # iteration's training, so row k reflects the DPD after k-1 DLA rounds
    # (row 1 = randomly initialised DPD).
    BER, Q = bert(I_Rx, bitsTx)
    print(f"{iteration+1:<9} | {Q:<10.2f} | {BER:<10.2e}")

    # Normalize received signal for training [cite: 222]
    y_received = (I_Rx - np.mean(I_Rx)) / np.std(I_Rx)
    y_received = y_received.reshape(-1, SEQ_LEN, 1)

    # --- DLA STEP 1: Train Auxiliary Channel (S) ---
    # Goal: S(z) ≈ y [cite: 134]
    aux_model.trainable = True
    aux_model.compile(optimizer=opt_aux, loss='mse')
    aux_model.fit(z_dpd, y_received, epochs=N_EPOCHS, verbose=0, batch_size=BATCH_SIZE)

    # --- DLA STEP 2: Train DPD (G) with S frozen ---
    # Goal: S(G(x)) ≈ x [cite: 137]
    aux_model.trainable = False
    dla_cascade.compile(optimizer=opt_dpd, loss='mse')
    dla_cascade.fit(x_input, x_input, epochs=N_EPOCHS, verbose=0, batch_size=BATCH_SIZE)
    # NOTE: an earlier experiment, fitting dla_cascade on
    # (dla_cascade.predict(x_input), z_dpd) with epochs=10, batch_size=4096,
    # performed badly. The hypothesis that it should match the fit above needs
    # revisiting.

# Leave the auxiliary model trainable after training, as before.
aux_model.trainable = True

print("\nTraining Complete.")


In [None]:
def perf_sim(DPD_FLAG=False, random_seed=123, model=None, seq_len=8192):
    """Run one IM-DD link simulation, with or without DPD, and report BER/Q.

    Parameters
    ----------
    DPD_FLAG : bool
        If True, the 2-PAM symbols are passed through the DPD model before
        upsampling. If False, the raw symbols are transmitted and no model is
        needed.
    random_seed : int
        Seed for the bit source. The photodiode noise seed stays fixed at 456.
    model : keras Model or None
        DPD model applied when DPD_FLAG is True. None falls back to the
        module-level ``dpd_model``, which is looked up at call time.
    seq_len : int
        Symbols per sequence fed to the DPD (8192 as per the paper). The
        symbol count must be a multiple of it. Only used when DPD_FLAG is True.

    Returns
    -------
    tuple
        ``(BER, Q)`` as returned by ``bert``.

    Raises
    ------
    ValueError
        If DPD_FLAG is True and the symbol count is not a multiple of seq_len.
    """
    # simulation parameters
    SpS = 16  # samples per symbol
    M = 2  # order of the modulation format
    Rs = 10e9  # Symbol rate
    Fs = SpS * Rs  # Signal sampling frequency (samples/second)
    Pi_dBm = 3  # laser optical power at the input of the MZM in dBm
    Pi = dBm2W(Pi_dBm)  # convert from dBm to W

    # Bit source parameters
    paramBits = parameters()
    paramBits.nBits = 2**18  # number of bits to be generated
    paramBits.mode = 'random' # mode of the bit source 
    paramBits.seed = random_seed      # seed for the random number generator

    # pulse shaping parameters
    paramPulse = parameters()
    paramPulse.pulseType = 'nrz'  # pulse shape type
    paramPulse.SpS = SpS     # samples per symbol  

    # MZM parameters
    paramMZM = parameters()
    paramMZM.Vpi = 2
    paramMZM.Vb = -paramMZM.Vpi / 2

    # linear fiber optical channel parameters
    paramCh = parameters()
    paramCh.L = 100        # total link distance [km]
    paramCh.alpha = 0.2    # fiber loss parameter [dB/km]
    paramCh.D = 16         # fiber dispersion parameter [ps/nm/km]
    paramCh.Fc = 193.1e12  # central optical frequency [Hz]
    paramCh.Fs = Fs

    # photodiode parameters
    paramPD = parameters()
    paramPD.ideal = False
    paramPD.B = Rs
    paramPD.Fs = Fs
    paramPD.seed = 456  # seed for the random number generator

    ## Simulation
    print("\nStarting simulation...", end="")

    # generate pseudo-random bit sequence
    bitsTx = bitSource(paramBits)

    # generate 2-PAM modulated symbol sequence
    symbTx = modulateGray(bitsTx, M, "pam")

    # DPD (only evaluated when requested, so the baseline needs no trained model)
    if DPD_FLAG:
        if model is None:
            model = dpd_model  # module-level DPD trained in the DLA cell
        if symbTx.size % seq_len:
            raise ValueError(
                f"{symbTx.size} symbols cannot be split into sequences of {seq_len}"
            )
        x_input = symbTx.reshape(-1, seq_len, 1)
        z_signal = model.predict(x_input, verbose=0).flatten()
        symbolsUp = upsample(z_signal, SpS)
    else:
        symbolsUp = upsample(symbTx, SpS)

    # pulse shaping
    pulse = pulseShape(paramPulse)
    sigTx = firFilter(pulse, symbolsUp)
    sigTx = anorm(sigTx) # normalize to 1 Vpp

    # optical modulation
    Ai = np.sqrt(Pi)  # ideal cw laser constant envelope
    sigTxo = mzm(Ai, sigTx, paramMZM)

    # linear fiber channel model
    sigCh = linearFiberChannel(sigTxo, paramCh)

    # noisy PD (thermal noise + shot noise + bandwidth limit)
    I_Rx = photodiode(sigCh, paramPD)

    # capture samples in the middle of signaling intervals
    I_Rx = I_Rx[0::SpS]

    # calculate the BER and Q-factor
    BER, Q = bert(I_Rx, bitsTx)

    print("\nTransmission performance metrics:")
    print(f"Q-factor = {Q:.2f} ")
    print(f"BER = {BER:.2e}")

    return BER, Q


# Evaluate the trained DPD on bits seeded differently from the training data (seed 123).
perf_sim(random_seed=345, DPD_FLAG=True)