In [1]:
import numpy as np
from optic.models.devices import mzm, photodiode
from optic.models.channels import linearFiberChannel
from optic.comm.sources import bitSource
from optic.comm.modulation import modulateGray
from optic.comm.metrics import bert
from optic.dsp.core import firFilter, pulseShape, upsample, pnorm, anorm
from optic.utils import parameters, dBm2W
from scipy.special import erfc

import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, Model



In [2]:
import tensorflow as tf
from tensorflow.keras import layers, models, initializers


def build_dpd_model(linear_taps=101, nonlinear_taps=11, nonlinear_filters=21):
    """Build the DPD network: a linear FIR stage plus a residual nonlinear stage.

    The model is fully convolutional. The input is declared as ``(None, 1)``
    (any number of time steps, one feature), so a whole sequence can be fed as
    a single ``(1, N, 1)`` batch and the filters slide across it; no manual
    windowing is required. The trailing ``1`` is the feature/channel count
    (a joint I/Q signal would use ``2``; separate I and Q networks keep ``1``).

    Conv1D kernels have shape ``(kernel_size, in_channels, filters)``: "1D"
    only names the sliding direction. The kernel still spans every input
    channel, and the last axis holds one kernel per output filter.

    Parameters
    ----------
    linear_taps : int, default 101
        Kernel size of the Section A linear convolution.
    nonlinear_taps : int, default 11
        Kernel size of the Section B convolution feeding the dense stack.
    nonlinear_filters : int, default 21
        Number of filters in the Section B convolution.

    Returns
    -------
    keras Model
        Uncompiled model mapping ``(batch, N, 1)`` to ``(batch, N, 1)``.

    Raises
    ------
    ValueError
        If any of the size arguments is smaller than 1.
    """
    if min(linear_taps, nonlinear_taps, nonlinear_filters) < 1:
        raise ValueError(
            "linear_taps, nonlinear_taps and nonlinear_filters must all be >= 1"
        )

    inputs = layers.Input(shape=(None, 1))

    # Section A: single linear filter initialised as a unit impulse, so it
    # starts out passing the signal through unchanged. With padding='same' the
    # output keeps the input length. For odd tap counts (k - 1) // 2 is the
    # exact centre tap; for even counts it assumes TensorFlow's 'same' padding
    # split -- TODO confirm before using even sizes.
    kernel_init_A = np.zeros((linear_taps, 1, 1))
    kernel_init_A[(linear_taps - 1) // 2, 0, 0] = 1.0
    sec_a = layers.Conv1D(filters=1, kernel_size=linear_taps, padding='same',
                          kernel_initializer=initializers.Constant(kernel_init_A))(inputs)

    # Section B: a bank of impulse-initialised filters (each one starts as a
    # copy of Section A's output) followed by a per-time-step dense stack that
    # collapses back to a single value.
    kernel_init_B = np.zeros((nonlinear_taps, 1, nonlinear_filters))
    kernel_init_B[(nonlinear_taps - 1) // 2, 0, :] = 1.0
    b_conv = layers.Conv1D(filters=nonlinear_filters, kernel_size=nonlinear_taps, padding='same',
                           kernel_initializer=initializers.Constant(kernel_init_B))(sec_a)
    x = layers.Dense(12, activation=layers.LeakyReLU(negative_slope=0.1))(b_conv)
    x = layers.Dense(8, activation=layers.LeakyReLU(negative_slope=0.1))(x)
    x = layers.Dense(8, activation=layers.LeakyReLU(negative_slope=0.1))(x)
    nonlinear_out = layers.Dense(1, activation='linear')(x)  # final sum to 1 neuron

    # Shortcut connection: the nonlinear branch only adds a correction on top
    # of the linear path.
    sec_b = layers.Add()([sec_a, nonlinear_out])

    # A further linear Section C (301-tap Conv1D, impulse-initialised) was
    # sketched here but is currently disabled; the model output is Section B.
    outputs = sec_b

    return models.Model(inputs, outputs)

In [3]:
import numpy as np

def create_sliding_windows(data, window_size):
    """Turn a 1D sequence into causal sliding windows, one window per sample.

    The sequence is left-padded with ``window_size - 1`` zeros so that window
    ``i`` ends at ``data[i]``; the first window therefore holds only the first
    sample preceded by zeros, and the number of windows equals ``len(data)``.

    Parameters
    ----------
    data : array_like
        1D sequence of samples.
    window_size : int
        Number of samples per window (>= 1).

    Returns
    -------
    numpy.ndarray
        Newly allocated, writable array of shape
        ``(len(data), window_size, 1)`` with the dtype of ``data``.

    Raises
    ------
    ValueError
        If ``data`` is not one-dimensional or ``window_size`` is below 1.
    """
    data = np.asarray(data)
    if data.ndim != 1:
        raise ValueError(f"data must be 1D, got {data.ndim} dimension(s)")
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    # No samples means no windows, but keep the documented rank and window
    # axis so downstream reshapes and models still see a consistent shape.
    if data.size == 0:
        return np.empty((0, window_size, 1), dtype=data.dtype)

    # Left-pad so the first window already contains the first sample.
    padded_data = np.pad(data, (window_size - 1, 0), mode='constant', constant_values=0)

    # Strided view: no Python-level loop over the windows.
    windows = np.lib.stride_tricks.sliding_window_view(padded_data, window_size)

    # Add the feature axis and copy, because the strided view is read-only
    # and aliases the padded buffer.
    return windows[..., np.newaxis].copy()



In [4]:
import scipy.signal as sig

def align_signals(tx, rx):
    """Time-align ``tx`` and ``rx`` using the peak of their cross-correlation.

    The lag with the largest cross-correlation magnitude is taken as the delay
    of ``rx`` relative to ``tx``. The leading samples of whichever signal
    starts late are dropped, then both signals are cut to the length of their
    overlap.

    Parameters
    ----------
    tx : sequence
        Reference (transmitted) samples.
    rx : sequence
        Received samples; may have a different length from ``tx``.

    Returns
    -------
    tuple
        ``(tx_aligned, rx_aligned)``: equal-length slices of the inputs.
    """
    # Nothing to correlate: return empty slices instead of failing downstream.
    if len(tx) == 0 or len(rx) == 0:
        return tx[:0], rx[:0]

    # Use FFT-based correlation for speed on large arrays.
    corr = sig.correlate(rx, tx, mode='full', method='fft')

    # In 'full' mode, index len(tx) - 1 corresponds to zero lag.
    delay = int(np.argmax(np.abs(corr))) - (len(tx) - 1)

    # Only the lead-in of the late signal is discarded. Trimming the tail of
    # the other signal by `delay` as well would throw away valid overlap when
    # the two inputs have different lengths; the min_len cut below already
    # equalises the lengths.
    if delay > 0:
        # Rx is delayed relative to Tx.
        rx_aligned, tx_aligned = rx[delay:], tx
    elif delay < 0:
        # Tx is delayed relative to Rx (rare in physical systems, possible in DSP).
        rx_aligned, tx_aligned = rx, tx[-delay:]
    else:
        rx_aligned, tx_aligned = rx, tx

    # Make sure they are the exact same length.
    min_len = min(len(tx_aligned), len(rx_aligned))

    return tx_aligned[:min_len], rx_aligned[:min_len]

NO DPD

In [5]:
# Baseline link without DPD: bits -> 2-PAM symbols -> NRZ pulse shaping -> MZM
# -> linear fiber -> photodiode -> BER / Q-factor. All names defined here are
# notebook globals.

# simulation parameters
SpS = 16  # samples per symbol
M = 2  # order of the modulation format
Rs = 10e9  # symbol rate [symbols/s]
Fs = SpS * Rs  # signal sampling frequency [samples/s]
Pi_dBm = 3  # laser optical power at the input of the MZM [dBm]
Pi = dBm2W(Pi_dBm)  # convert from dBm to W

# Bit source parameters
paramBits = parameters()
paramBits.nBits = 100000  # number of bits to be generated
paramBits.mode = 'random' # mode of the bit source
paramBits.seed = 123      # seed for the random number generator

# pulse shaping parameters
paramPulse = parameters()
paramPulse.pulseType = 'nrz'  # pulse shape type
paramPulse.SpS = SpS     # samples per symbol

# MZM parameters
paramMZM = parameters()
paramMZM.Vpi = 2
paramMZM.Vb = -paramMZM.Vpi / 2  # bias voltage set to -Vpi/2

# linear fiber optical channel parameters
paramCh = parameters()
paramCh.L = 100        # total link distance [km]
paramCh.alpha = 0.2    # fiber loss parameter [dB/km]
paramCh.D = 16         # fiber dispersion parameter [ps/nm/km]
paramCh.Fc = 193.1e12  # central optical frequency [Hz]
paramCh.Fs = Fs

# photodiode parameters
paramPD = parameters()
paramPD.ideal = False  # non-ideal photodiode model
paramPD.B = Rs
paramPD.Fs = Fs
paramPD.seed = 456  # seed for the random number generator



## Simulation
print("\nStarting simulation...", end="")

# generate pseudo-random bit sequence
bitsTx = bitSource(paramBits)

# generate 2-PAM modulated symbol sequence
symbTx = modulateGray(bitsTx, M, "pam")

# upsampling to SpS samples per symbol
symbolsUp = upsample(symbTx, SpS)

# pulse shaping
pulse = pulseShape(paramPulse)
sigTx = firFilter(pulse, symbolsUp)
sigTx = anorm(sigTx) # normalize to 1 Vpp

# optical modulation
Ai = np.sqrt(Pi)  # ideal cw laser constant envelope
sigTxo = mzm(Ai, sigTx, paramMZM)

# linear fiber channel model
sigCh = linearFiberChannel(sigTxo, paramCh)

# noisy PD (thermal noise + shot noise + bandwidth limit)
I_Rx = photodiode(sigCh, paramPD)

# keep one sample per symbol: every SpS-th sample, starting at index 0.
# NOTE(review): whether index 0 falls in the middle of the signaling interval
# depends on the delay of the pulse-shaping filter -- confirm.
I_Rx = I_Rx[0::SpS]



# calculate the BER and Q-factor
BER, Q = bert(I_Rx, bitsTx)

print("\nTransmission performance metrics:")
print(f"Q-factor = {Q:.2f} ")
print(f"BER = {BER:.2e}")

# theoretical error probability from Q-factor: Pb = 0.5 * erfc(Q / sqrt(2))
Pb = 0.5 * erfc(Q / np.sqrt(2))
print(f"Pb = {Pb:.2e}\n")


Starting simulation...
Transmission performance metrics:
Q-factor = 3.58 
BER = 1.00e-04
Pb = 1.73e-04



# DPD

In [6]:
# DPD experiment: the same link as the baseline cell, but the symbols are
# passed through the DPD network before transmission, and the network is
# re-fitted after every transmission on (received samples -> transmitted,
# pre-distorted symbols).

# simulation parameters
SpS = 16  # samples per symbol
M = 2  # order of the modulation format
Rs = 10e9  # symbol rate [symbols/s]
Fs = SpS * Rs  # signal sampling frequency [samples/s]
Pi_dBm = 3  # laser optical power at the input of the MZM [dBm]
Pi = dBm2W(Pi_dBm)  # convert from dBm to W

# Bit source parameters
# NOTE(review): paramBits.seed and paramPD.seed are constant. If bitSource and
# photodiode reseed from them on every call, each loop iteration sees the same
# bits and the same noise realisation -- confirm that this is intended.
paramBits = parameters()
paramBits.nBits = 2**18  # number of bits to be generated
paramBits.mode = 'random' # mode of the bit source
paramBits.seed = 123      # seed for the random number generator

# pulse shaping parameters
paramPulse = parameters()
paramPulse.pulseType = 'nrz'  # pulse shape type
paramPulse.SpS = SpS     # samples per symbol

# MZM parameters
paramMZM = parameters()
paramMZM.Vpi = 2
paramMZM.Vb = -paramMZM.Vpi / 2  # bias voltage set to -Vpi/2

# linear fiber optical channel parameters
paramCh = parameters()
paramCh.L = 100        # total link distance [km]
paramCh.alpha = 0.2    # fiber loss parameter [dB/km]
paramCh.D = 16         # fiber dispersion parameter [ps/nm/km]
paramCh.Fc = 193.1e12  # central optical frequency [Hz]
paramCh.Fs = Fs

# photodiode parameters
paramPD = parameters()
paramPD.ideal = False
paramPD.B = Rs
paramPD.Fs = Fs
paramPD.seed = 456  # seed for the random number generator


# DPD model
seq_length = 1024  # window length for the manual sliding-window path (currently unused)
dpd_model = build_dpd_model()
dpd_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4), loss='mse')

# per-iteration metrics
BER_list = []
Q_list = []

# Loop-invariant quantities: neither depends on the data of an iteration, so
# they are computed once instead of on every pass.
pulse = pulseShape(paramPulse)  # pulse-shaping filter
Ai = np.sqrt(Pi)  # ideal cw laser constant envelope

for i in range(50):
    # generate pseudo-random bit sequence
    bitsTx = bitSource(paramBits)

    # generate 2-PAM modulated symbol sequence
    symbTx = modulateGray(bitsTx, M, "pam")

    # pre-distort the whole symbol sequence in one pass: (1, N, 1) in, N out
    # TODO: convert into a tf dataset for faster inference.
    symbTx_dpd = dpd_model.predict(symbTx.reshape(1, -1, 1), verbose=0).flatten()

    # upsampling
    symbolsUp = upsample(symbTx_dpd, SpS)

    # pulse shaping
    sigTx = firFilter(pulse, symbolsUp)
    sigTx = anorm(sigTx) # normalize to 1 Vpp

    # optical modulation
    sigTxo = mzm(Ai, sigTx, paramMZM)

    # linear fiber channel model
    sigCh = linearFiberChannel(sigTxo, paramCh)

    # noisy PD (thermal noise + shot noise + bandwidth limit)
    I_Rx = photodiode(sigCh, paramPD)

    # keep one sample per symbol (every SpS-th sample, starting at index 0)
    I_Rx = I_Rx[0::SpS]

    # zero-mean, unit-variance copy of the received samples for training
    I_Rx_norm = (I_Rx - np.mean(I_Rx)) / np.std(I_Rx)

    # Fit the network to map the received samples back to the symbols that were
    # actually transmitted. The data is a single (1, N, 1) sample, so
    # batch_size would have no effect here.
    # NOTE(review): no delay compensation is applied between I_Rx and
    # symbTx_dpd (align_signals is defined but unused) -- confirm the link
    # introduces no symbol-level delay.
    dpd_model.fit(I_Rx_norm.reshape(1, -1, 1), symbTx_dpd.reshape(1, -1, 1), epochs=30, verbose=0)

    # PERFORMANCE METRICS
    BER, Q = bert(I_Rx, bitsTx) # BER and Q-factor
    print(f"Q-factor = {Q:.2f} ")
    print(f"BER = {BER:.2e}")

    BER_list.append(BER)
    Q_list.append(Q)

2026-02-22 12:17:27.575426: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M3
2026-02-22 12:17:27.575454: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 24.00 GB
2026-02-22 12:17:27.575459: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 8.88 GB
2026-02-22 12:17:27.575472: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2026-02-22 12:17:27.575480: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
2026-02-22 12:17:27.727333: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.


Q-factor = 3.53 
BER = 1.56e-04
Q-factor = 3.45 
BER = 2.44e-04
Q-factor = 3.42 
BER = 2.98e-04
Q-factor = 3.41 
BER = 3.24e-04
Q-factor = 3.42 
BER = 3.43e-04
Q-factor = 3.40 
BER = 3.81e-04
Q-factor = 3.38 
BER = 4.27e-04
Q-factor = 3.34 
BER = 4.69e-04
Q-factor = 3.31 
BER = 4.96e-04
Q-factor = 3.27 
BER = 5.61e-04
Q-factor = 3.23 
BER = 6.37e-04
Q-factor = 3.20 
BER = 7.32e-04
Q-factor = 3.16 
BER = 8.32e-04
Q-factor = 3.12 
BER = 9.16e-04
Q-factor = 3.08 
BER = 1.05e-03
Q-factor = 3.05 
BER = 1.14e-03
Q-factor = 3.03 
BER = 1.21e-03
Q-factor = 3.01 
BER = 1.30e-03
Q-factor = 2.99 
BER = 1.36e-03
Q-factor = 2.97 
BER = 1.46e-03
Q-factor = 2.96 
BER = 1.51e-03
Q-factor = 2.94 
BER = 1.63e-03
Q-factor = 2.92 
BER = 1.73e-03
Q-factor = 2.91 
BER = 1.78e-03
Q-factor = 2.88 
BER = 1.92e-03
Q-factor = 2.86 
BER = 2.07e-03
Q-factor = 2.84 
BER = 2.26e-03
Q-factor = 2.83 
BER = 2.33e-03
Q-factor = 2.81 
BER = 2.44e-03
Q-factor = 2.81 
BER = 2.44e-03
Q-factor = 2.82 
BER = 2.37e-03
Q-factor

KeyboardInterrupt: 

You must also look at the Q-factor values, not just the BER; BER on its own isn't enough.
UPDATE: the paper uses SNR instead.

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model, initializers
from optic.models.devices import mzm, photodiode
from optic.models.channels import linearFiberChannel
from optic.comm.sources import bitSource
from optic.comm.modulation import modulateGray
from optic.comm.metrics import bert
from optic.dsp.core import firFilter, pulseShape, upsample, anorm
from optic.utils import parameters, dBm2W

# --- 1. MODEL DEFINITIONS ---

def build_dpd_model():
    """Build the DPD network (G): a linear FIR stage plus a residual nonlinear stage.

    Follows the structure the author attributes to Fig. 3 of the reference
    paper. Input and output are both ``(batch, N, 1)`` sequences of arbitrary
    length ``N``.
    """
    dpd_in = layers.Input(shape=(None, 1))

    # Section A: 101-tap linear convolution, initialised as a unit impulse on
    # the centre tap so it starts out as a pass-through filter.
    impulse = np.zeros((101, 1, 1))
    impulse[50, 0, 0] = 1.0
    linear_out = layers.Conv1D(
        filters=1,
        kernel_size=101,
        padding='same',
        kernel_initializer=initializers.Constant(impulse),
    )(dpd_in)

    # Section B: 21 filters of 11 taps feeding a per-time-step dense stack
    # (12 -> 8 -> 8 units, LeakyReLU) that collapses to one linear output.
    hidden = layers.Conv1D(filters=21, kernel_size=11, padding='same')(linear_out)
    for units in (12, 8, 8):
        hidden = layers.Dense(units, activation=layers.LeakyReLU(0.1))(hidden)
    correction = layers.Dense(1, activation='linear')(hidden)

    # Shortcut connection: the nonlinear branch is added to the linear path.
    dpd_out = layers.Add()([linear_out, correction])
    return Model(dpd_in, dpd_out, name="DPD_G")

def build_aux_model():
    """Build the auxiliary channel model (S), laid out as a mirror of the DPD.

    The stages run in the reverse order of the DPD sections: a 301-tap
    convolution, a 21-unit LeakyReLU dense layer, then a 101-tap convolution.
    Input and output are both ``(batch, N, 1)`` sequences.
    """
    aux_in = layers.Input(shape=(None, 1))

    # Mirrored Section C: 301-tap convolution with a single filter.
    stage = layers.Conv1D(filters=1, kernel_size=301, padding='same')(aux_in)

    # Mirrored Section B: per-time-step dense expansion to 21 features.
    stage = layers.Dense(21, activation=layers.LeakyReLU(0.1))(stage)

    # Mirrored Section A: 101-tap convolution back down to one channel.
    aux_out = layers.Conv1D(filters=1, kernel_size=101, padding='same')(stage)

    return Model(aux_in, aux_out, name="Auxiliary_S")

# --- 2. SIMULATION PARAMETERS ---

SpS, M, Rs = 16, 2, 10e9  # samples per symbol, modulation order, symbol rate
Fs = SpS * Rs             # sampling frequency [samples/s]
Pi = dBm2W(3)             # laser power at the MZM input: 3 dBm converted to W

# bit source
# NOTE(review): the bit-source and photodiode seeds are constant. If those
# functions reseed on every call, each iteration reuses the same bits and
# noise realisation -- confirm that this is intended.
paramBits = parameters()
paramBits.nBits = 2**18
paramBits.seed = 123

# pulse shaping
paramPulse = parameters()
paramPulse.pulseType = 'nrz'
paramPulse.SpS = SpS

# MZM
paramMZM = parameters()
paramMZM.Vpi = 2
paramMZM.Vb = -1

# linear fiber channel
paramCh = parameters()
paramCh.L = 80         # total link distance [km]
paramCh.alpha = 0.2    # fiber loss parameter [dB/km]
paramCh.D = 16         # fiber dispersion parameter [ps/nm/km]
paramCh.Fc = 193.1e12  # central optical frequency [Hz]
paramCh.Fs = Fs

# photodiode
paramPD = parameters()
paramPD.ideal = False
paramPD.B = Rs
paramPD.Fs = Fs
paramPD.seed = 456

# --- 3. DLA INITIALIZATION ---

dpd_model = build_dpd_model()
aux_model = build_aux_model()

# Cascade: x -> DPD (G) -> Aux (S) -> y_est
inputs_x = layers.Input(shape=(None, 1))
z_pred = dpd_model(inputs_x)
y_est = aux_model(z_pred)
dla_cascade = Model(inputs_x, y_est)

# Optimizers; the author takes the learning rates from Table II of the paper.
# NOTE(review): 1e-1 is an unusually large Adam learning rate -- confirm
# against the paper.
opt_dpd = tf.keras.optimizers.Adam(learning_rate=1e-1)
opt_aux = tf.keras.optimizers.Adam(learning_rate=5e-4)

aux_model.compile(optimizer=opt_aux, loss='mse')
dla_cascade.compile(optimizer=opt_dpd, loss='mse')

# Loop-invariant quantities: neither depends on the data of an iteration, so
# they are computed once instead of on every pass.
pulse = pulseShape(paramPulse)  # pulse-shaping filter
Ai = np.sqrt(Pi)                # ideal cw laser constant envelope

# --- 4. MAIN DLA LOOP ---

print(f"{'Iter':<5} | {'Q-Factor':<10} | {'BER':<10}")
print("-" * 30)

for iteration in range(50):  # the author notes DLA typically converges in ~9 iterations
    # Step A: data generation
    bitsTx = bitSource(paramBits)
    symbTx = modulateGray(bitsTx, M, "pam")
    x_input = symbTx.reshape(1, -1, 1)

    # Step B: apply the current DPD and run the link simulation
    z_dpd = dpd_model.predict(x_input, verbose=0)
    z_signal = z_dpd.flatten()

    symbolsUp = upsample(z_signal, SpS)
    sigTx = firFilter(pulse, symbolsUp)
    sigTx = anorm(sigTx)  # normalize to 1 Vpp
    sigTxo = mzm(Ai, sigTx, paramMZM)
    sigCh = linearFiberChannel(sigTxo, paramCh)
    I_Rx = photodiode(sigCh, paramPD)[0::SpS]  # one sample per symbol

    # Zero-mean, unit-variance received samples as the training target for S
    y_received = (I_Rx - np.mean(I_Rx)) / np.std(I_Rx)
    y_received = y_received.reshape(1, -1, 1)

    # --- DLA STEP 1: train the auxiliary channel (S) so that S(z) ~ y ---
    # The data is a single (1, N, 1) sample, so batch_size has no practical
    # effect here.
    aux_model.fit(z_dpd, y_received, epochs=10, verbose=0, batch_size=2048)

    # --- DLA STEP 2: train the DPD (G) through the frozen S so that S(G(x)) ~ x ---
    # NOTE(review): Keras documents that changes to `trainable` take effect at
    # compile(); confirm that toggling it here without recompiling really
    # freezes S in the installed Keras version.
    aux_model.trainable = False
    dla_cascade.fit(x_input, x_input, epochs=10, verbose=0, batch_size=4096)
    aux_model.trainable = True

    # Performance monitoring
    BER, Q = bert(I_Rx, bitsTx)
    print(f"{iteration+1:<5} | {Q:<10.2f} | {BER:<10.2e}")

print("\nTraining Complete.")

2026-02-22 12:19:14.490959: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M3
2026-02-22 12:19:14.490979: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 24.00 GB
2026-02-22 12:19:14.490984: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 8.88 GB
2026-02-22 12:19:14.490998: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2026-02-22 12:19:14.491006: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
2026-02-22 12:19:14.691381: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.


Iter  | Q-Factor   | BER       
------------------------------
1     | 7.02       | 0.00e+00  
2     | 0.27       | 3.95e-01  
3     | 0.39       | 3.50e-01  
4     | 0.77       | 2.23e-01  
5     | 1.23       | 1.16e-01  
6     | 1.37       | 9.10e-02  
7     | 1.33       | 9.67e-02  
8     | 1.41       | 8.33e-02  
9     | 1.57       | 6.26e-02  
10    | 1.72       | 4.67e-02  
11    | 1.89       | 3.32e-02  
12    | 2.03       | 2.45e-02  
13    | 1.96       | 2.81e-02  
14    | 1.90       | 3.12e-02  
15    | 1.88       | 3.22e-02  
16    | 2.21       | 1.56e-02  
17    | 2.15       | 1.77e-02  
18    | 2.25       | 1.38e-02  
19    | 2.42       | 8.75e-03  
20    | 2.49       | 7.27e-03  
21    | 2.53       | 6.45e-03  
22    | 2.58       | 5.29e-03  
23    | 2.68       | 3.88e-03  
24    | 2.78       | 2.85e-03  
25    | 2.86       | 2.20e-03  
26    | 2.96       | 1.63e-03  
27    | 3.01       | 1.36e-03  
28    | 3.05       | 1.17e-03  
29    | 3.11       | 9.23e-04  
30    | 3