# In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from matplotlib.animation import FuncAnimation
import torch
import torch.nn as nn
import torch.optim as optim
from itertools import chain

# ---------------------------------------------------------------------------
# Constants and laser parameters (presumably SI units -- confirm).
# ---------------------------------------------------------------------------

# Rate-equation coefficients (used in LaserSystem.system_dynamics)
Gamma = 0.3       # multiplies the gain term in both photon equations
tau_p = 1.5e-12   # decay time in the photon equations (P / tau_p)
tau_n = 2.0e-9    # decay time in the carrier equations (N / tau_n)
eta = 1e-5        # constant source term added to both photon equations
q = 1.6e-19       # charge in the current-injection term I / (q * V)
V = 2e-12         # volume in the current-injection term I / (q * V)

# Master -> slave injection: K_inj * cos(Delta_omega * t + phi) on dPs/dt
K_inj = 0.1
phi = 0
Delta_omega = 2 * np.pi * 5e9

# Gain model parameters (read by gain())
g0 = 1e-16
N0 = 1.5e24
Isat = 1e-3

# Thermal / frequency-shift coefficients
alpha = 2          # NOTE(review): not referenced in this section
beta = 1e-4        # delta_n = beta * (temperature - 300)
alpha_c = 4
n_refractive = 3.5
nu_0 = 3e14        # Central or nominal frequency of the lasers

# Path / coupling constants (not referenced in this section)
c = 299792458      # Speed of light in vacuum, m/s
L = 1.0            # Assume path length of 1 meter for simplification
K = 1e-6           # Proportional constant for frequency difference influence

# Gain function with thermal effects
def gain(N, P, delta_n):
    """Return the saturable gain for carrier density ``N`` and photon density ``P``.

    The gain is proportional to the excess carrier density ``N - N0``, shifted
    by ``delta_n``. The caller in this file derives ``delta_n`` from the
    temperature. It is compressed by the factor ``1 + P / Isat``.

    Works element-wise on NumPy arrays as well as on scalars.
    """
    excess = N - N0 + delta_n
    compression = 1 + P / Isat
    return g0 * excess / compression

# Neural Network for feedback control
class NeuralNetwork(nn.Module):
    """Scalar-in, scalar-out MLP used as a feedback controller.

    Architecture: Linear(1, 64) -> ReLU -> Linear(64, 64) -> ReLU -> Linear(64, 1).
    The attribute names ``fc1``/``fc2``/``fc3`` are the keys of the saved
    ``state_dict`` checkpoints, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        width = 64
        # Layers are created in this order so random initialisation is reproducible.
        self.fc1 = nn.Linear(1, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, 1)

    def forward(self, x):
        """Map ``x`` (last dimension of size 1) to the controller output."""
        for hidden in (self.fc1, self.fc2):
            x = torch.relu(hidden(x))
        return self.fc3(x)

# Load the four trained controller networks. The script builds
# ``models = [model1, model2, model3, model4]`` and expects one control signal
# per model (I_master, I_slave, bias_voltage, temperature), so all four are needed.
# NOTE(review): torch.load unpickles arbitrary objects -- only load checkpoint
# files from a trusted source (recent PyTorch supports ``weights_only=True``).
model1 = NeuralNetwork()
model2 = NeuralNetwork()
model3 = NeuralNetwork()
model4 = NeuralNetwork()

for _index, _model in enumerate((model1, model2, model3, model4), start=1):
    # map_location='cpu': checkpoints saved on a GPU still load on CPU-only
    # machines; the models are evaluated with CPU tensors in this script.
    _model.load_state_dict(torch.load(f'model{_index}.pth', map_location='cpu'))
    # Set to evaluation mode (inference only)
    _model.eval()

# Laser system: provides the ODE right-hand side integrated by solve_ivp
class LaserSystem:
    """Master/slave laser rate equations with optional neural-network feedback.

    The state vector ``y`` used by :meth:`system_dynamics` has eight entries::

        [Pm, Nm, Ps, Ns, I_master, I_slave, bias_voltage, temperature]

    That is, master/slave photon and carrier densities followed by four actuator states.

    Args:
        nn_model: Sequence of controller networks. With ``use_nn=True`` at
            least four are required: controller ``k`` reads ``y[k]`` and
            drives actuator ``y[4 + k]``.
        use_nn: Enable neural-network feedback.
        desired_sync_freq: Target synchronized frequency in Hz. When ``None``
            (the default) the user is prompted for it interactively.

    Raises:
        ValueError: If ``use_nn`` is true and fewer than four models are given.
    """

    N_CONTROLS = 4  # number of actuator states driven by the controllers

    def __init__(self, nn_model, use_nn=True, desired_sync_freq=None):
        # Keep the models that were actually passed in (not a global list).
        self.models = list(nn_model)
        self.use_nn = use_nn
        if use_nn and len(self.models) < self.N_CONTROLS:
            raise ValueError(
                f"use_nn=True needs {self.N_CONTROLS} controller models, "
                f"got {len(self.models)}"
            )

        if desired_sync_freq is None:
            desired_sync_freq = get_float_input("Enter the desired synchronized laser frequency (Hz): ")
        # Stored for reporting/plotting; it does not enter the dynamics below.
        self.desired_sync_freq = desired_sync_freq

        # One optimizer over the parameters of all models. Adam rejects an
        # empty parameter list, so skip it when there is nothing to optimize.
        combined_parameters = list(chain.from_iterable(model.parameters() for model in self.models))
        self.optimizer = optim.Adam(combined_parameters, lr=0.001) if combined_parameters else None

    def system_dynamics(self, t, y):
        """Return ``dy/dt`` at time ``t`` (the signature ``solve_ivp`` expects).

        Photon/carrier equations include:
        - gain-driven emission;
        - decay via ``tau_p`` / ``tau_n``;
        - current injection ``I / (q * V)``;
        - for the slave, a cosine injection term from the master.

        Actuator states (entries 4-7):
        - Without the NN they are held constant (zero derivative).
        - With the NN, each controller's output is used as the rate of change
          of its actuator. The control action therefore accumulates and feeds
          back through the currents and the temperature.
        - ``bias_voltage`` is integrated but does not enter the laser equations.
        """
        Pm, Nm, Ps, Ns, I_master, I_slave, bias_voltage, temperature = y

        if self.use_nn:
            # Pure inference inside the ODE right-hand side: no autograd graph.
            with torch.no_grad():
                control_rates = [
                    model(torch.tensor([state], dtype=torch.float32)).item()
                    for model, state in zip(self.models, (Pm, Nm, Ps, Ns))
                ]
        else:
            control_rates = [0.0] * self.N_CONTROLS

        # Temperature-dependent offset applied to the carrier density in gain().
        delta_n = beta * (temperature - 300)
        gain_m = gain(Nm, Pm, delta_n)
        gain_s = gain(Ns, Ps, delta_n)

        dPm_dt = Gamma * gain_m * Pm + eta - Pm / tau_p
        dNm_dt = I_master / (q * V) - Nm / tau_n - gain_m * Pm
        dPs_dt = (Gamma * gain_s * Ps + eta - Ps / tau_p
                  + K_inj * np.cos(Delta_omega * t + phi))
        dNs_dt = I_slave / (q * V) - Ns / tau_n - gain_s * Ps
        return [dPm_dt, dNm_dt, dPs_dt, dNs_dt, *control_rates]

def get_float_input(prompt):
    """Prompt repeatedly until the user enters a value ``float()`` accepts.

    Args:
        prompt: Text displayed before each attempt.

    Returns:
        The entered value as a ``float``.

    Raises:
        EOFError: If standard input is exhausted.
        ValueError: If reading standard input itself fails (e.g. stdin is
            closed). Only a *parse* failure triggers a retry, so a broken
            stdin no longer causes an endless retry loop.
    """
    while True:
        text = input(prompt)
        try:
            return float(text)
        except ValueError:
            print("Invalid input. Please enter a valid floating-point number.")

# Create instances with and without NN
models = [model1, model2, model3, model4]
laser_system_with_nn = LaserSystem(models, use_nn=True)
laser_system_without_nn = LaserSystem(models, use_nn=False)
laser_system = laser_system_with_nn

# Initial conditions and simulation setup.
# State order: Pm, Nm, Ps, Ns, I_master, I_slave, bias_voltage, temperature
initial_conditions = [1e8, 2e24, 1e8, 2e24, 40e-3, 40e-3, 1.2, 300]
t_span = (0, 5e-9)
t_eval = np.linspace(0, 5e-9, 500)

# Solve system dynamics with BDF, an implicit method suited to stiff problems.
solution_with_nn = solve_ivp(laser_system_with_nn.system_dynamics, t_span, initial_conditions, t_eval=t_eval, method='BDF')
solution_without_nn = solve_ivp(laser_system_without_nn.system_dynamics, t_span, initial_conditions, t_eval=t_eval, method='BDF')
# ``solution`` is the NN run; alias it instead of integrating it a second time.
solution = solution_with_nn

# A failed run returns fewer samples than t_eval; report it rather than
# failing silently (or confusingly) later in the plotting code.
for _label, _sol in (("with NN", solution_with_nn), ("without NN", solution_without_nn)):
    if not _sol.success:
        print(f"Warning: integration {_label} stopped early: {_sol.message}")

# Setup the figure for animation
fig, axs = plt.subplots(3, 2, figsize=(14, 18))  # 3 rows, 2 columns

def init():
    """Reset every subplot before the animation starts.

    Only the time axis is fixed. The y-axis is left on autoscale because the
    plotted quantities (frequencies, densities, currents, temperature) lie far
    outside [0, 1]; a fixed ``set_ylim(0, 1)`` turned autoscaling off and hid
    every curve.

    Returns:
        The lines currently on the axes (used by FuncAnimation when blitting).
    """
    # ``axs`` is a 2-D array of Axes: iterating it directly yields numpy rows,
    # which have no ``.lines``; ``axs.flat`` yields the Axes themselves.
    for ax in axs.flat:
        ax.clear()
        ax.set_xlim(0, max(t_eval))
        ax.set_autoscaley_on(True)
    return [line for ax in axs.flat for line in ax.lines]

def update(frame):
    """Redraw all six panels using the first ``frame`` samples of each run.

    Left column: the run with NN control; right column: the run without.
    Rows:
    - synchronized frequency vs. the desired frequency;
    - photon/carrier densities;
    - the four actuator states.

    Returns:
        The plotted lines (used by FuncAnimation when blitting).
    """
    runs = (
        (solution_with_nn, laser_system_with_nn, "With NN"),
        (solution_without_nn, laser_system_without_nn, "Without NN"),
    )
    row_titles = ("Synchronized Frequency", "System Outputs", "Control Signals")
    density_labels = ('Master Photon Density', 'Master Carrier Density',
                      'Slave Photon Density', 'Slave Carrier Density')
    control_labels = ('I_master', 'I_slave', 'Bias Voltage', 'Temperature')

    for row, row_title in enumerate(row_titles):
        for col, (solution, system, run_label) in enumerate(runs):
            ax = axs[row, col]
            # Clear first: otherwise every frame stacks new lines (and legend
            # entries) on top of all previous ones.
            ax.clear()
            ax.set_xlim(0, max(t_eval))
            ax.set_title(f"{row_title} ({run_label})")
            # Use the solver's own time grid so t and y always have equal length.
            t = solution.t[:frame]
            y = solution.y[:, :frame]
            if row == 0:  # Frequency plots
                freqs = nu_0 + nu_0 * (y[1] - y[3]) / N0
                ax.plot(t, freqs, label='Actual Frequency')
                # The target comes from the system that produced this run.
                ax.plot(t, [system.desired_sync_freq] * len(t), 'r--', label='Desired Frequency')
            elif row == 1:  # System outputs plots
                for idx, label in enumerate(density_labels):
                    ax.plot(t, y[idx], label=label)
            else:  # Actuator / control-signal plots
                for idx, label in enumerate(control_labels, start=4):
                    ax.plot(t, y[idx], label=label)
            ax.legend()
    return [line for ax in axs.flat for line in ax.lines]

# Frames 1..len(t_eval): update(frame) plots samples [:frame], so the final
# frame now includes the last sample (frames=len(t_eval) stopped one short).
# Keep the ``ani`` reference alive, otherwise the animation is garbage-collected.
ani = FuncAnimation(fig, update, frames=range(1, len(t_eval) + 1), init_func=init, blit=False)
plt.tight_layout()
plt.show()
 