In [31]:
# packages and constants
import numpy as np
import torch
import torch.nn as nn
from torchdiffeq import odeint

control_axis = 'x'  # eigenbasis used for the ladder operators: 'x', 'y' or 'z'
h_bar = 1  # reduced Planck constant (natural units)
gamma = 1  # gyromagnetic ratio
B_z = 2  # static magnetic field along z
omega_0 = gamma * B_z  # Larmor frequency of the bare spin
omega_c = omega_0  # control laser frequency (resonant with omega_0)
Rabi = omega_0*2  # Rabi frequency
T = np.pi/omega_c  # half a period of the control field
# Pauli matrices
sigma_x = np.array([[0, 1], [1, 0]])
sigma_y = np.array([[0, -1j], [1j, 0]])
sigma_z = np.array([[1, 0], [0, -1]])
sigma_i = np.eye(2)
up = np.array([1, 0])  # excited state
down = np.array([0, 1])  # ground state
# +1 / -1 eigenstates of sigma_x and sigma_y
x1 = np.array([1, 1]) / np.sqrt(2)
x0 = np.array([1, -1]) / np.sqrt(2)
y1 = np.array([1, 1j]) / np.sqrt(2)
y0 = np.array([1, -1j]) / np.sqrt(2)
# Ladder operators in the eigenbasis of the chosen control axis:
# sigma_plus = |1><0| (raising), sigma_minus = |0><1| (lowering).
# np.outer flattens its inputs, so no transpose of the 1-D kets is needed.
if control_axis == 'z':
    sigma_minus = np.outer(down, up.conj())  # lowering operator
    sigma_plus = np.outer(up, down.conj())  # raising operator
elif control_axis == 'x':
    sigma_minus = np.outer(x0, x1.conj())  # lowering operator
    sigma_plus = np.outer(x1, x0.conj())  # raising operator
elif control_axis == 'y':
    sigma_minus = np.outer(y0, y1.conj())  # lowering operator
    sigma_plus = np.outer(y1, y0.conj())  # raising operator
else:
    # Without this, sigma_minus would be unbound and the next line would fail
    # with an unhelpful NameError.
    raise ValueError(f"control_axis must be 'x', 'y' or 'z', got {control_axis!r}")
sigma_minus = torch.tensor(sigma_minus, dtype=torch.cfloat)
sigma_plus = torch.tensor(sigma_plus, dtype=torch.cfloat)

# Static (drift) Hamiltonian of the spin in B_z: H = h_bar * omega_0 * sigma_z / 2
H = torch.tensor(sigma_z * h_bar * omega_0 / 2, dtype=torch.cfloat)

In [32]:
# Neural network to generate parameters
class ParameterNN(nn.Module):
    """Map (Hamiltonian, initial state, target state) to three control parameters.

    The real and imaginary parts of the three complex input matrices are
    stacked, flattened into a single 24-feature row, and passed through one
    linear layer. The three outputs are later read as (omega_c, Rabi, phi)
    by `unpack_params`.
    """

    def __init__(self):
        super().__init__()
        # 3 matrices x (real, imag) x 4 entries each = 24 input features;
        # assumes 2x2 inputs as used in this notebook.
        self.fc = nn.Linear(24, 3)

    def forward(self, H, rho_0, rho_T):
        # Order: Re(H), Im(H), Re(rho_0), Im(rho_0), Re(rho_T), Im(rho_T)
        pieces = [
            part(matrix)
            for matrix in (H, rho_0, rho_T)
            for part in (torch.real, torch.imag)
        ]
        features = torch.cat(pieces, dim=0).view(-1, 24)
        return self.fc(features)

def unpack_params(params):
    """Split a (1, 3) network output into (omega_c, Rabi, phi).

    A negative Rabi amplitude is folded into a positive one by shifting the
    phase by pi, which leaves the control Hamiltonian unchanged.
    """
    omega_c = params[0, 0]
    Rabi = params[0, 1]
    phi = params[0, 2]
    # Keep the `< 0` test (not `>= 0`) so NaN amplitudes take the same branch.
    if not Rabi < 0:
        return omega_c, Rabi, phi
    return omega_c, -1 * Rabi, phi + np.pi

def solve_ode(initial_state, params, t):
    """Integrate the von Neumann equation d(rho)/dt = -i [H_total, rho] / h_bar.

    Args:
        initial_state: real tensor holding Re(rho) and Im(rho) concatenated
            along the last dimension (e.g. a (2, 4) tensor for a 2x2 rho).
        params: network output indexable as params[0, k]; decoded by
            `unpack_params` into (omega_c, Rabi, phi).
        t: 1-D tensor of times at which the solution is returned.

    Returns:
        The `odeint` solution, in the same real/imag-concatenated layout; the
        caller treats its leading dimension as time.

    Relies on the module-level H, h_bar, sigma_plus and sigma_minus.
    """
    # The control parameters are constant over the whole integration, so
    # decode them once instead of on every right-hand-side evaluation.
    omega_c, Rabi, phi = unpack_params(params)

    def ode_func(time, y):
        # Rotating control field coupling through the ladder operators.
        H_c = (torch.exp(-(omega_c*time+phi)*1j)*sigma_plus+torch.exp((omega_c*time+phi)*1j)*sigma_minus)*h_bar*Rabi/2
        H_total = H + H_c
        # Rebuild the complex density matrix from its real/imag halves.
        y_real, y_imag = torch.chunk(y, 2, dim=-1)
        y_complex = y_real + y_imag*1j
        commutator = torch.matmul(H_total, y_complex) - torch.matmul(y_complex, H_total)
        dydt = -1j * commutator / h_bar
        # Return in the same real-valued layout the solver integrates.
        return torch.cat((torch.real(dydt), torch.imag(dydt)), dim=-1)

    return odeint(ode_func, initial_state, t)

# Customize the loss function
def custom_loss_fn(predicted_trajectory, target_trajectory):
    """Mean squared error that is also correct for complex-valued tensors.

    Uses |pred - target|^2 rather than (pred - target)**2: for complex inputs
    the latter squares the complex number (giving a complex, possibly negative
    "loss"), while for real inputs both forms are identical.
    """
    diff = predicted_trajectory - target_trajectory
    return torch.mean(torch.abs(diff) ** 2)

In [33]:
# Initial state: the ground state |down><down|.
rho_0 = torch.tensor([[0.0 + 0.0j, 0.0 + 0.0j], [0.0 + 0.0j, 1.0 + 0.0j]], dtype=torch.cfloat)
# Alternative initial state, the superposition sqrt(3/4)|up> + sqrt(1/4)|down>:
#phi0 = up*np.sqrt(3/4) + down*np.sqrt(1/4) # initial state
#rho_0 = torch.tensor(np.outer(phi0,phi0)+np.array([[0j, 0j], [0j, 0j]]), dtype=torch.cfloat)

# Time grid on which the ODE solution is sampled.
t = torch.linspace(0.0, 1.0, 30)
# Target state: the excited state |up><up|.
rho_T = torch.tensor([[1.0 + 0.0j, 0.0 + 0.0j], [0.0 + 0.0j, 0.0 + 0.0j]], dtype=torch.cfloat)

# The solver integrates a real tensor: Re(rho) and Im(rho) side by side along
# the last dimension (solve_ode splits them back with torch.chunk).
initial_state_real = torch.real(rho_0)
initial_state_imag = torch.imag(rho_0)
initial_state_combined = torch.cat((initial_state_real, initial_state_imag), dim=-1)

# Seed before building the model so its random initialisation, and hence the
# whole training run, is reproducible.
SEED = 0
torch.manual_seed(SEED)
# Instantiate the model
model = ParameterNN()
# Define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Softmin temperature: smaller values push the softmin towards a hard minimum.
temperature = 0.1
# Weight for the time-based loss
alpha = 0.3
# Normalised time index in [0, 1] for every grid point (loop-invariant).
time_indices = torch.arange(len(t), dtype=torch.float32) / (len(t) - 1)

# Training loop: minimise the (soft) closest approach to rho_T, plus a penalty
# on how late in the trajectory that closest approach occurs.
for epoch in range(500):
    optimizer.zero_grad()
    params = model(H, rho_0, rho_T)
    pred_trajectory_combined = solve_ode(initial_state_combined, params, t)
    pred_trajectory_real, pred_trajectory_imag = torch.chunk(pred_trajectory_combined, 2, dim=-1)
    pred_trajectory = pred_trajectory_real + 1j * pred_trajectory_imag
    # Frobenius distance between the target state and each state in the trajectory
    distances = torch.norm(pred_trajectory - rho_T, dim=(1, 2))
    # Index of the closest state (used only for logging below).
    ind = torch.argmin(distances)
    # Softmin weights over the time steps
    softmin_weights = torch.softmax(-distances / temperature, dim=0)
    # Weighted average of the distances
    softmin_distance = torch.sum(softmin_weights * distances)
    # Soft estimate of when the closest approach happens, in [0, 1]
    softmin_time = torch.sum(softmin_weights * time_indices)
    # Combined loss
    loss = softmin_distance + alpha * softmin_time

    loss.backward()
    optimizer.step()

    if epoch % 40 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item()}, Softmin Distance: {softmin_distance.item()}, Softmin Time: {softmin_time.item()}, state: {pred_trajectory[ind]}')


Epoch 0, Loss: 1.5642131567001343, Softmin Distance: 1.4141796827316284, Softmin Time: 0.5001116991043091, state: tensor([[ 1.9965e-04+0.0000j, -9.9094e-03+0.0101j],
        [-9.9094e-03-0.0101j,  9.9980e-01+0.0000j]], grad_fn=<SelectBackward0>)
Epoch 40, Loss: 1.0678170919418335, Softmin Distance: 0.7964939475059509, Softmin Time: 0.9044104814529419, state: tensor([[ 0.7457+0.0000j, -0.3113+0.3045j],
        [-0.3113-0.3045j,  0.2543+0.0000j]], grad_fn=<SelectBackward0>)
Epoch 80, Loss: 0.3428279459476471, Softmin Distance: 0.10918982326984406, Softmin Time: 0.7787936925888062, state: tensor([[9.9910e-01+0.0000j, 2.9279e-02+0.0067j],
        [2.9279e-02-0.0067j, 9.0331e-04+0.0000j]], grad_fn=<SelectBackward0>)
Epoch 120, Loss: 0.3135245144367218, Softmin Distance: 0.11266018450260162, Softmin Time: 0.6695477366447449, state: tensor([[ 0.9987+0.0000j, -0.0357+0.0073j],
        [-0.0357-0.0073j,  0.0013+0.0000j]], grad_fn=<SelectBackward0>)
Epoch 160, Loss: 0.2986171245574951, Softmin D

In [34]:
# Decode the control parameters from the last training forward pass (computed
# before the final optimizer.step()); these are the same `params` that produced
# the final `softmin_time`.
# NOTE(review): this rebinds the first cell's `omega_c` and `Rabi` constants to
# tensors, and "T" below is the softmin-weighted normalised time in [0, 1],
# not the constant T = pi / omega_c.
(omega_c, Rabi, phi) = unpack_params(params)
print(
    f'omega_c = {omega_c.item()}\nRabi = {Rabi.item()}\nphi = {phi.item()}\nT = {softmin_time.item()}'
)

omega_c = 0.3437114357948303
Rabi = 5.222267150878906
phi = -2.0529043674468994
T = 0.6465154886245728
