In [None]:
import sys
sys.path.append('..')

import matplotlib.pyplot as plt
import autograd.numpy as np

from lib.graph import Graph
from simulator.fiber.assets.propagator import Propagator

from simulator.fiber.evaluator_subclasses.evaluator_pulserep import  PulseRepetition_dual
from simulator.fiber.evolver import HessianProbabilityEvolver
from simulator.fiber.node_types_subclasses.inputs import ContinuousWaveLaser, PulsedLaser
from simulator.fiber.node_types_subclasses.outputs import Photodiode,MeasurementDevice
from simulator.fiber.node_types_subclasses.single_path import PhaseModulator, WaveShaper, OpticalAmplifier

from simulator.fiber.node_types_subclasses.multi_path import VariablePowerSplitter,DualOutputMZM
from simulator.fiber.node_types_subclasses.terminals import TerminalSource, TerminalSink

from algorithms.parameter_optimization import parameters_optimize


In [None]:
# Simulation grid: 4e-9 s window sampled at 2**15 points around a 1.55e-6 m carrier.
propagator = Propagator(window_t=4e-9, n_samples=2**15, central_wl=1.55e-6)

# Input train: 3 ps Gaussian pulses with rep_t = 1 / 10 GHz = 100 ps.
pulse_width, rep_t, peak_power = (3e-12, 1/10.0e9, 1.0)
# Targets scale both pulse width and repetition period of the input by p / q.
p, q = (2, 1)
period_ratio = p / q

input_laser = PulsedLaser(parameters_from_name={'pulse_width': pulse_width, 'peak_power': peak_power,
                                                't_rep': rep_t, 'pulse_shape': 'gaussian',
                                                'central_wl': 1.55e-6, 'train': True})
# Mark the source laser as fixed; exclude_locked=True below presumably keeps locked
# models out of the fitted parameters.
input_laser.node_lock = True
input_laser.protected = True

# NOTE(review): `input` shadows the Python builtin; the name is kept in case later cells use it.
input = input_laser.get_pulse_train(propagator.t, pulse_width=pulse_width, rep_t=rep_t, peak_power=peak_power)
target1 = input_laser.get_pulse_train(propagator.t, pulse_width=pulse_width * period_ratio,
                                      rep_t=rep_t * period_ratio,
                                      peak_power=peak_power * 1.0, phase_shift=0.0)
# Same train as target1 but with phase_shift=0.5, i.e. shifted by half a period.
target2 = input_laser.get_pulse_train(propagator.t, pulse_width=pulse_width * period_ratio,
                                      rep_t=rep_t * period_ratio,
                                      peak_power=peak_power * 1.0, phase_shift=0.5)

# One target per output sink; keys must match the TerminalSink node names below.
evaluator = PulseRepetition_dual(propagator,
                                 targets={'sink1': np.array(target1),
                                          'sink2': np.array(target2)},
                                 pulse_width=pulse_width, rep_t=rep_t, peak_power=peak_power)
evolver = HessianProbabilityEvolver(verbose=False)

# Give each output edge its own MeasurementDevice so no model object is shared between
# the two edges (a shared instance would alias any per-model state across outputs).
md_sink1 = MeasurementDevice()
md_sink2 = MeasurementDevice()
for measurement_device in (md_sink1, md_sink2):
    measurement_device.protected = True
md = md_sink1  # legacy name, kept for any later cell that refers to `md`

# Topology: source -> dual-output MZM -> two sinks.
nodes = {'source': TerminalSource(),
         0: DualOutputMZM(),
         'sink1': TerminalSink(node_name='sink1'),
         'sink2': TerminalSink(node_name='sink2')}
edges = {('source', 0): input_laser,
         (0, 'sink1'): md_sink1,
         (0, 'sink2'): md_sink2}

graph = Graph.init_graph(nodes, edges)
graph.update_graph()
graph.initialize_func_grad_hess(propagator, evaluator, exclude_locked=True)

graph.draw()

In [None]:
# Optimizer used by the (currently disabled) parameters_optimize call below.
method = 'L-BFGS+GA'
graph.clear_propagation()
# x0 is the flat parameter vector; parameter_index maps each model to its part of it.
x0, models, parameter_index, *_ = graph.extract_parameters_to_list()

# To optimize instead of using the hand-picked values, uncomment:
# graph, x, score, log = parameters_optimize(graph, x0=x0, method=method, verbose=False)

# Hand-picked parameter vector, in the same flat order as x0.
# NOTE(review): presumably these are the DualOutputMZM parameters (the only unlocked model) — confirm order.
my_params = [np.pi, 5e+09, np.pi / 2]
assert len(my_params) == len(x0), (
    f"my_params has {len(my_params)} entries but the graph expects {len(x0)}")
x = my_params
# Let the graph hand each model its own part of x. Do not assign the whole list to every
# model: that would give all models the same list object regardless of their parameter count.
graph.distribute_parameters_from_list(x, models, parameter_index)

In [None]:

# Propagate once with the distributed parameters; every sink field plotted below comes from this run.
graph.propagate(propagator, save_transforms=False)

# Sinks that have a target, ordered by their numeric suffix ('sink1', 'sink2', ...).
sink_nodes = sorted([name for name in evaluator.targets.keys() if name.startswith('sink')],
                    key=lambda name: int(name[len('sink'):]))
num_sinks = len(sink_nodes)
ncols = 1 + num_sinks  # topology panel + one comparison panel per sink

fig, axs = plt.subplots(nrows=1, ncols=ncols,
                        figsize=(4 * ncols, 4),
                        squeeze=False)  # keep axs 2-D even when there is a single column

# Show the full simulated time window, converted from seconds to nanoseconds.
t_ns = propagator.t * 1e9
xlim_full = [t_ns.min(), t_ns.max()]

# Topology panel.
graph.draw(ax=axs[0, 0], debug=False)
axs[0, 0].set_title("Topology")

for col, sink in enumerate(sink_nodes, start=1):
    ax = axs[0, col]
    # The measured field is reduced to its magnitude, so reduce the target the same way;
    # otherwise a complex or signed target would be plotted and compared inconsistently.
    target = np.abs(evaluator.targets[sink])
    measured = np.abs(graph.measure_propagator(sink))

    # Time-domain comparison over the full window.
    ax.plot(t_ns, target, '--', label=f'Target {sink}')
    ax.plot(t_ns, measured, '-', label=f'Measured {sink}')
    ax.set(xlim=xlim_full,
           xlabel='Time (ns)',
           ylabel='Amplitude',
           title=f'{sink} Comparison')
    ax.legend(loc='upper right', fontsize=8)

    # Annotate the mean squared error between target and measured magnitudes.
    mse = np.mean((target - measured) ** 2)
    ax.annotate(f"MSE: {mse:.2e}",
                xy=(0.95, 0.95),
                xycoords='axes fraction',
                ha='right', va='top',
                fontsize=8,
                bbox=dict(facecolor='white', alpha=0.8))
plt.show()