# In [None]:
import numpy as np 
import scipy 
import sys
import os 
import matplotlib.pyplot as plt
from scipy.linalg import sqrtm
sys.path.append('..')
import evos.src.lattice.spin_one_half_lattice as spin_lat
import vectorized_lindbladian_and_davies_map as vec_lind
import metropolis_qubits_and_fermions as mqf 
import entropy_and_coherences as ent_and_coh
from reny_entropy import renyi_divergence

# Matplotlib text rendering.
# The preamble is only used when text.usetex is True. It used to read
# r'\usepackage{{amsmath}}': doubled braces are str.format/f-string escaping and are
# wrong in a plain raw string, where they would be passed to LaTeX literally.
plt.rcParams['text.latex.preamble'] = r'\usepackage{amsmath}'  # Add any additional LaTeX packages if needed
plt.rcParams['text.usetex'] = False  # render math with matplotlib's built-in mathtext

# parameters  # FIXME: use argparse
# system: number of spin-1/2 sites; J multiplies the sz-sz bonds and h the sx terms of the Hamiltonian
n_sites = 4  # 4
J = -1
h = 1
# bath temperature (passed to the Davies jump operators as 1./temperature_bath) and the
# temperature of the thermal part of the initial state
temperature_bath = 0.1  # 10
temperature_initial = 1
alpha = 0.98  # renyi divergence parameter
t_max = 12  # for 4 qubits: 0.3  # TODO: plot to shorter times in order to see the crossing
noise_state = 0.05  # strength of the random positive noise; crossing classical: 0.05, crossing quantum: 0.25
dt = 0.01  # time step of the propagator exp(dt * L)
# iteration budgets of the Metropolis optimization
# (passed as max_random_sweeps / max_micro_iterations / max_macro_iterations)
random_iter = 200
micro_iter = 20
macro_iter = 20
qubits_or_fermions = 'qubits'

# lattice
lat = spin_lat.SpinOneHalfLattice(n_sites)
# Hamiltonian: J * sum_i sz_i sz_{i+1} (nearest neighbours, open boundaries)
#              + h * sum_i sx_i
# (assumes lat.sso(op, site) returns the dense single-site operator embedded in the full chain)
ham = 0.
for site in range(n_sites - 1):  # NOTE: OBC -> no bond between the last and the first site
    ham += J * (lat.sso('sz', site) @ lat.sso('sz', site + 1))
for site in range(n_sites):
    ham += h * lat.sso('sx', site)
# sanity check: the Hamiltonian must be Hermitian
if not np.allclose(ham, ham.conj().T):
    raise ValueError('hamiltonian is not hermitian')

# Diagonalize once; eigh returns ascending eigenvalues and the eigenvectors as columns.
# (Previously eigh was run twice, once for each output.)
eigw_hamiltonian, eigenbasis_hamiltonian = np.linalg.eigh(ham)
# Lindbladian: Davies map
# jump operators for a bath at inverse temperature 1/temperature_bath
jump_op_list = vec_lind.jump_operators_for_davies_map_spin_onehalf(n_sites, ham, 1. / temperature_bath)
# vectorized Lindbladian from the jump operators and the (diagonal) Hamiltonian in its eigenbasis
lindbladian_vectorized = vec_lind.vectorized_lindbladian(n_sites, jump_op_list, np.diag(eigw_hamiltonian))
np.save("lindbladian_vectorized", lindbladian_vectorized)

# Full spectrum with left and right eigenvectors (stored as columns). All three are then
# reordered by |Re(lambda)| ascending, so the slowest-decaying modes come first
# (index 0 is presumably the stationary, lambda ~ 0, mode).
(eigw_vectorized_lindbladian,
 left_eigv_vectorized_lindbladian,
 right_eigv_vectorized_lindbladian) = scipy.linalg.eig(lindbladian_vectorized, left=True, right=True)
idx = np.argsort(np.abs(eigw_vectorized_lindbladian.real))
eigw_vectorized_lindbladian = eigw_vectorized_lindbladian[idx]
left_eigv_vectorized_lindbladian = left_eigv_vectorized_lindbladian[:, idx]
right_eigv_vectorized_lindbladian = right_eigv_vectorized_lindbladian[:, idx]

# left eigenvectors of the four slowest non-stationary modes (sorted indices 1..4),
# reshaped back to (2**n_sites, 2**n_sites) matrices
l2_matrix, l3_matrix, l4_matrix, l5_matrix = (
    np.reshape(left_eigv_vectorized_lindbladian[:, mode], (2**n_sites, 2**n_sites))
    for mode in range(1, 5)
)

print(eigw_vectorized_lindbladian)
# quit()
#thermal state
# initial_state = mqf.rho_thermal(temperature_initial, np.diag(eigw_hamiltonian)  ) 
     
def thermal_state_with_positive_noise(temperature, eigw_hamiltonian, noise_strength, seed):
    """Return a thermal state plus a random positive-semidefinite perturbation, trace-normalized.

    Parameters
    ----------
    temperature : float
        Temperature of the thermal part, passed to ``mqf.rho_thermal``. (The previous
        version ignored this argument and always used the global ``temperature_initial``.)
    eigw_hamiltonian : array_like, shape (dim,)
        Hamiltonian eigenvalues; the thermal state is built from ``np.diag(eigw_hamiltonian)``,
        i.e. in the Hamiltonian eigenbasis. ``dim`` also sets the size of the noise matrix.
    noise_strength : float
        Scale of the entries of the random matrix ``A`` whose ``A^dagger A`` is added.
    seed : int
        Seed for the global ``np.random`` generator (reseeded on every call).

    Returns
    -------
    ndarray, shape (dim, dim)
        ``(rho_thermal + A^dagger A) / Tr(rho_thermal + A^dagger A)``.
    """
    dim = len(eigw_hamiltonian)
    rho_therm = mqf.rho_thermal(temperature, np.diag(eigw_hamiltonian))
    np.random.seed(seed)
    noise = noise_strength * np.random.rand(dim, dim) + 1j * noise_strength * np.random.rand(dim, dim)
    # A^dagger A is Hermitian and positive semidefinite, so the sum stays a valid (unnormalized) state
    noise = noise.conj().T @ noise
    state = rho_therm + noise
    return state / np.trace(state)

def gen_random_mixed_state(seed, n_kets=1000, n_qubits=None):
    """Return a random mixed state: the normalized sum of ``n_kets`` random rank-1 projectors.

    Each (unnormalized) ket has entries ``u + 1j*v`` with ``u, v`` drawn uniformly from
    [0, 1) using the global ``np.random`` generator, which is reseeded with ``seed``.
    The projectors are summed with equal weight and the trace is normalized once at the end.
    The previous version normalized inside the loop, which down-weighted every earlier ket
    at each step, so the result was dominated by the last few kets.

    Parameters
    ----------
    seed : int
        Seed for ``np.random``.
    n_kets : int, optional
        Number of random kets to mix (default 1000, as before). Must be >= 1.
    n_qubits : int, optional
        Number of qubits; the state has dimension ``2**n_qubits``. Defaults to the
        module-level ``n_sites``.

    Returns
    -------
    ndarray, shape (2**n_qubits, 2**n_qubits), complex
        Hermitian, positive semidefinite, unit trace.

    Raises
    ------
    ValueError
        If ``n_kets < 1`` (the trace would be zero).
    """
    if n_kets < 1:
        raise ValueError("n_kets must be at least 1")
    if n_qubits is None:
        n_qubits = n_sites  # fall back to the script-level system size
    dim = 2**n_qubits
    np.random.seed(seed)
    state = np.zeros((dim, dim), dtype=np.complex128)
    for _ in range(n_kets):
        ket = np.random.rand(dim) + 1j * np.random.rand(dim)
        state += np.outer(ket, np.conjugate(ket))
    return state / np.trace(state)

# NOTE: strictly speaking a random state (gen_random_mixed_state) should be rotated into the
#       Hamiltonian eigenbasis; since it is random anyway, it is treated as if it were already
#       expressed in that basis.
# initial_state = gen_random_mixed_state(1)

# initial state: thermal state at temperature_initial plus positive noise of strength noise_state (seed 0)
initial_state = thermal_state_with_positive_noise(
    temperature_initial, eigw_hamiltonian, noise_state, 0
)

# slow modes whose overlap with the state the optimization targets
# (alternative: [l2_matrix, l3_matrix, l4_matrix])
modes_to_remove_list = [l2_matrix]
# NEW METROPOLIS FUNCTION: presumably searches single-qubit rotation parameters that minimise the
# scalar product with the modes above (per the helper's name); returns the overlap history and the parameters
overlap_vector_tot, parameter_matrix = mqf.minimize_scalar_product_metropolis_multiple_particles(
    modes_to_remove_list,
    initial_state,
    n_sites,
    qubits_or_fermions,
    max_random_sweeps=random_iter,
    max_micro_iterations=micro_iter,
    max_macro_iterations=macro_iter,
    cooling_rate=0.9999,
)
print("Finished metropolis optimization")

# apply the optimized rotation U: rho -> U rho U^dagger
rotation_matrix = mqf.sigle_qubit_unitary_on_all_qubits(parameter_matrix, n_sites)
adjoint_rotation = np.conjugate(np.transpose(rotation_matrix))
rotated_optimal = rotation_matrix @ initial_state @ adjoint_rotation
print("energy of rotated state = ", np.trace(np.diag(eigw_hamiltonian) @ rotated_optimal))

# save the overlap with the slowest modes as a function of the iteration
np.save("overlap_vector_tot", overlap_vector_tot)

# Time-evolution
# (eigenbasis_hamiltonian is already available from the earlier diagonalization of `ham`;
#  a redundant extra np.linalg.eigh(ham) call used to sit here and has been removed.)

# Number of samples on the grid np.arange(0, t_max, dt) iterated by the evolution loop.
# The previous size int(t_max/dt)+1 can exceed that length by one, leaving an all-zero column.
n_time_steps = len(np.arange(0, t_max, dt))
# each column holds the vectorized (flattened) density matrix at one recorded time step
state_unrotated_vectorized_every_t = np.zeros((4**n_sites, n_time_steps), dtype=np.complex128)
state_rotated_vectorized_every_t = np.zeros((4**n_sites, n_time_steps), dtype=np.complex128)

# per-time-step observables, filled by the evolution loop
distance_unrotated = []
distance_rotated = []
quantum_relative_entropy_original = []
quantum_relative_entropy_rotated = []
classical_relative_entropy_original = []
classical_relative_entropy_diagonalized = []

renyi_divergence_original_single_alpha = []
renyi_divergence_rotated_single_alpha = []

# row-major flattened states (matching the reshapes used in the loop)
initial_state_vectorized = initial_state.flatten()
rotated_optimal_vectorized = rotated_optimal.flatten()
id_v = np.eye(2**n_sites).flatten()
# Right eigenvector of the first sorted Lindbladian mode, as a matrix, trace-normalized.
# Divide out of place: reshaping the column slice can return a view, and an in-place `/=`
# would silently rescale right_eigv_vectorized_lindbladian[:, 0] itself.
steady_state = np.reshape(right_eigv_vectorized_lindbladian[:, 0], (2**n_sites, 2**n_sites))
steady_state = steady_state / np.trace(steady_state)

# one-step propagator exp(dt * L) acting on vectorized states
tevo_op = scipy.linalg.expm(dt * lindbladian_vectorized)
state_unrotated_vectorized = initial_state_vectorized.copy()
state_rotated_vectorized = rotated_optimal_vectorized.copy()

def compute_wigner_yanase_skew_information(state, operator):
    """Return the Wigner-Yanase skew information of ``state`` with respect to ``operator``.

    I(rho, K) = Tr(rho K^2) - Tr(sqrt(rho) K sqrt(rho) K)

    Parameters
    ----------
    state : ndarray, shape (d, d)
        Density matrix.
    operator : ndarray, shape (d, d)
        Observable K (e.g. the Hamiltonian).

    Returns
    -------
    float
        The real part of the expression. For Hermitian ``state`` and ``operator`` the
        quantity is real, so any imaginary part is round-off from ``sqrtm``. Returning a
        real value also keeps log-scale plots of the result free of complex warnings.
    """
    sqrt_state = sqrtm(state)  # computed once (was evaluated twice)
    value = np.trace(state @ operator @ operator) - np.trace(sqrt_state @ operator @ sqrt_state @ operator)
    return np.real(value)

wy_unrotated = []
wy_rotate = []

# Hamiltonian in its own eigenbasis (diagonal); loop-invariant, so built once
ham_diag = np.diag(eigw_hamiltonian)

# Record observables for the *current* state first, then propagate by dt. Entry `count` of
# every list/column therefore corresponds to time_v[count] = count*dt, and entry 0 is the
# initial state. (The previous version propagated before measuring, so all data were shifted
# by one dt relative to the time axis and the initial state was never recorded.)
for count, t in enumerate(np.arange(0, t_max, dt)):
    # density matrices; reshaping the contiguous state vectors returns views, so the in-place
    # normalization below also normalizes the vectorized states
    state_unrotated = np.reshape(state_unrotated_vectorized, (2**n_sites, 2**n_sites))
    state_rotated = np.reshape(state_rotated_vectorized, (2**n_sites, 2**n_sites))
    state_unrotated /= np.trace(state_unrotated)
    state_rotated /= np.trace(state_rotated)
    # (normalized) state at every timestep
    state_unrotated_vectorized_every_t[:, count] = state_unrotated_vectorized
    state_rotated_vectorized_every_t[:, count] = state_rotated_vectorized
    # L1 distance: entrywise sum of absolute differences to the steady state
    distance_unrotated.append(np.sum(np.abs(state_unrotated - steady_state)))
    distance_rotated.append(np.sum(np.abs(state_rotated - steady_state)))
    # Wigner-Yanase skew information w.r.t. the Hamiltonian
    wy_unrotated.append(compute_wigner_yanase_skew_information(state_unrotated, ham_diag))
    wy_rotate.append(compute_wigner_yanase_skew_information(state_rotated, ham_diag))
    # quantum rel. entr.
    quantum_relative_entropy_original.append(ent_and_coh.relative_entropy_coherence(state_unrotated, np.eye((2**n_sites), dtype='complex')))
    quantum_relative_entropy_rotated.append(ent_and_coh.relative_entropy_coherence(state_rotated, np.eye((2**n_sites), dtype='complex')))
    # classical relative entropy of the dephased (diagonal) state w.r.t. the steady state
    classical_relative_entropy_original.append(ent_and_coh.relative_entropy_Kullback(np.diag(np.diag(state_unrotated)), steady_state))
    classical_relative_entropy_diagonalized.append(ent_and_coh.relative_entropy_Kullback(np.diag(np.diag(state_rotated)), steady_state))
    # renyi divergence between the state and its dephased (diagonal) part
    renyi_divergence_original_single_alpha.append(renyi_divergence(state_unrotated, np.diag(np.diag(state_unrotated)), alpha))
    renyi_divergence_rotated_single_alpha.append(renyi_divergence(state_rotated, np.diag(np.diag(state_rotated)), alpha))

    # propagate both states by one time step dt
    state_unrotated_vectorized = tevo_op @ state_unrotated_vectorized
    state_rotated_vectorized = tevo_op @ state_rotated_vectorized
    
# --- plotting (L1 distances and below): shared time axis and colours ---
time_v = np.arange(0, t_max, dt)  # same grid the evolution loop iterated over
color_unrotated, color_rotated = '#beaed4', '#fdc086'

# plt.figure(2)
# plt.semilogy(time_v, distance_unrotated, label=r'$\rho_{T=5} $', color=color_unrotated) #+ 0.05\rho_{\text{rand}}
# plt.semilogy(time_v, distance_rotated, label='rotated state', color=color_rotated)

# #plot also the real part of the first 3 non-zero eigenvalues * time 
# lambda_2 = distance_unrotated[-1]*np.exp(np.real(eigw_vectorized_lindbladian[1])*(time_v-time_v[-1]))
# lambda_3 = distance_rotated[-1]*np.exp(np.real(eigw_vectorized_lindbladian[2])*(time_v-time_v[-1]))

# plt.semilogy(time_v, lambda_2, label='$\exp(\Re{(\lambda_2)}t) $', linestyle='dashed', color = color_unrotated)
# plt.semilogy(time_v, lambda_3, label='$\exp(\Re{(\lambda_3)}t)$', linestyle='dotted', color = color_rotated)

# plt.ylabel('L1 distance')
# plt.xlabel('time')
# plt.legend()
# plt.show()    

# plt.figure(3)
# plt.rcParams.update({'font.size': 20})
# plt.semilogy(time_v, wy_unrotated, label=r'$\rho_{T=}$'+str(temperature_initial)+' +noise') #+ 0.05\rho_{\text{rand}}
# plt.semilogy(time_v, wy_rotate, label='rotated state')
# plt.ylabel(r'$F_{\text{WY}}(\hat{H}_s)$')
# plt.xlabel('time $(1/J)$')
# plt.legend()
# plt.show()

# fit the logarithm of the quantum relative entropy to a linear function
quantum_relative_entropy_original = np.array(quantum_relative_entropy_original)
quantum_relative_entropy_rotated = np.array(quantum_relative_entropy_rotated)

# Fit windows in units of time: the original state is fitted from fit_t_start to the end of
# the data, the rotated state from fit_t_start up to (excluding) fit_t_end_rotated.
fit_t_start = 7
fit_t_end_rotated = 12
# Map times to indices on time_v with searchsorted instead of int(t/dt): int() truncates
# float round-off (t/dt can land just below an integer) and the magic numbers silently
# misbehave when t_max or dt change.
fit_i_start = int(np.searchsorted(time_v, fit_t_start))
fit_i_end_rotated = int(np.searchsorted(time_v, fit_t_end_rotated))
if len(time_v) - fit_i_start < 2 or fit_i_end_rotated - fit_i_start < 2:
    raise ValueError(
        f"fit window starting at t={fit_t_start} needs at least two samples; "
        f"increase t_max (currently {t_max})"
    )

slope_original, intercept_original = np.polyfit(
    time_v[fit_i_start:],
    np.log(quantum_relative_entropy_original[fit_i_start:]),
    1,
)
slope_rotated, intercept_rotated = np.polyfit(
    time_v[fit_i_start:fit_i_end_rotated],
    np.log(quantum_relative_entropy_rotated[fit_i_start:fit_i_end_rotated]),
    1,
)
print("Slope original: ", slope_original)
print("Slope rotated: ", slope_rotated)
# print twice the real part of selected sorted Lindbladian eigenvalues for comparison with the slopes
print("2*Real part of slowest coherent mode: ", 2*np.real(eigw_vectorized_lindbladian[1]))
print("2*Real part of second-slowest coherent mode: ", 2*np.real(eigw_vectorized_lindbladian[4]))

#same but for imaginary part

#plot quantum relative entropy
# Main figure: quantum relative entropy vs time on a log y-axis; solid = original state
# (rho_1), dashed = rotated state (rho_2).
plt.figure()
# font size is updated before any axes/text are created on this figure, so it applies to them
plt.rcParams.update({'font.size': 20})
plt.semilogy(time_v, quantum_relative_entropy_original, label=r'$S(\hat{\rho}_s ||\mathcal{G}(\hat{\rho}_s))$, ' + r'$\hat\rho_1$', color='orange') #+ 0.05\rho_{\text{rand}}
plt.semilogy(time_v, quantum_relative_entropy_rotated, label=r'$S(\hat{\rho}_s ||\mathcal{G}(\hat{\rho}_s))$, ' + r'$\hat\rho_2$', color='orange', linestyle='dashed')
# plt.ylabel(r'$S(\hat{\rho}_s ||\mathcal{G}(\hat{\rho}_s))$')

##plot also the exponential fits 
## plt.semilogy(time_v, np.exp(slope_original*time_v + intercept_original), label='fit original', color='blue')
## plt.semilogy(time_v, np.exp(slope_rotated*time_v + intercept_rotated), label='fit rotated', color='red', linestyle='dashed')

#now plot the exponentials with exponents 2*np.real(eigw_vectorized_lindbladian[1] and 2*np.real(eigw_vectorized_lindbladian[4])
# plt.semilogy(time_v, np.exp(2*np.real(eigw_vectorized_lindbladian[1])*time_v+ intercept_original), label='2*Re($\lambda_2$)', color='black')
# plt.semilogy(time_v, np.exp(2*np.real(eigw_vectorized_lindbladian[4])*time_v+ intercept_rotated), label='2*Re($\lambda_4$)', color='black', linestyle='dashed')

# Annotations at fixed axes-fraction positions (transAxes), placed by hand next to the curves.
# NOTE(review): the labels read e^{Re(lambda_2)} / e^{Re(lambda_4)}, while the rates printed
# above are 2*Re of sorted eigenvalues [1] and [4] and contain no time factor — confirm the
# intended labels.
plt.text(0.5, 0.6, r'$e^{\Re(\lambda_2)}$', fontsize=20, color='black', transform=plt.gca().transAxes)
plt.text(0.6, 0.35, r'$e^{\Re(\lambda_4)}$', fontsize=20, color='black', transform=plt.gca().transAxes)
plt.xlabel('time (1/J)')
plt.legend()


# plot the real and imaginary part of the spectrum
# Inset axes ([left, bottom, width, height] in figure fractions) showing the first 10
# sorted Lindbladian eigenvalues in the complex plane: real ones in green, complex ones in
# orange, with complex-conjugate pairs labelled lambda_n and lambda_n^*.
plt.axes([0.2, 0.2, .25, .25])
eigvals = eigw_vectorized_lindbladian[:10]  # Only the first 10 eigenvalues
plt.xlabel(r'Re($\lambda$)')
plt.ylabel(r'Im($\lambda$)')

used_indices = set()
label_counter = 1

for i, val in enumerate(eigvals):
    if i in used_indices:
        continue

    real = np.real(val)
    imag = np.imag(val)
    label = r'$\lambda_{' + str(label_counter) + r'}$'

    if np.isclose(imag, 0):
        # Purely real eigenvalue
        plt.scatter(real, imag, color='green', s=30)
        plt.text(real, imag, label, fontsize=10, color='black')
        used_indices.add(i)
        label_counter += 1
        continue

    # Complex eigenvalue: look for its conjugate partner among the later, unused eigenvalues
    for j in range(i + 1, len(eigvals)):
        if j in used_indices or not np.isclose(val, np.conj(eigvals[j])):
            continue
        partner = eigvals[j]
        # Plot both, labelled lambda_n and lambda_n^*
        plt.scatter([real, np.real(partner)], [imag, np.imag(partner)], color='orange', s=30)
        plt.text(real, imag, label, fontsize=10, color='black')
        plt.text(np.real(partner), np.imag(partner), r'$\lambda_{' + str(label_counter) + r'}^*$', fontsize=10, color='black')
        used_indices.update([i, j])
        break
    else:
        # The partner lies outside the plotted window (e.g. it is the 11th eigenvalue):
        # plot this eigenvalue on its own instead of silently dropping it.
        plt.scatter(real, imag, color='orange', s=30)
        plt.text(real, imag, label, fontsize=10, color='black')
        used_indices.add(i)
    label_counter += 1


from matplotlib.lines import Line2D
# Proxy artists for the legend (circles); raw strings avoid the invalid '\o' escape warning
green_dot = Line2D([], [], color='green', marker='o', linestyle='None', markersize=8, label=r'$\omega=0$')
orange_dot = Line2D([], [], color='orange', marker='o', linestyle='None', markersize=8, label=r'$\omega \neq 0$')
# Add legend
plt.legend(handles=[green_dot, orange_dot], loc='best', fontsize=6, frameon=False)

plt.show()


#reshape right_eigv_vectorized_lindbladian[:,1] to a matrix and compute the trace of its square
# r2_matrix = np.reshape(right_eigv_vectorized_lindbladian[:,1], (2**n_sites,2**n_sites) )
# print(r2_matrix@np.transpose(np.conjugate(r2_matrix)))
# print("Trace of square of right eigenvector of second slowest mode: ", np.trace(r2_matrix@np.transpose(np.conjugate(r2_matrix))))



#CLASSICAL RELATIVE ENTROPY
# # plt.figure(1)
# plt.rcParams.update({'font.size': 20})
# plt.semilogy(time_v, classical_relative_entropy_original, label=r'$S(\mathcal{G}(\hat{\rho}_s)|| \hat{\pi})$, ' + r'$\hat\rho_1$', color='green')
# plt.semilogy(time_v, classical_relative_entropy_diagonalized, label=r'$S(\mathcal{G}(\hat{\rho}_s)|| \hat{\pi})$, ' + r'$\hat\rho_2$', color='green', linestyle='dashed')
# # plt.ylabel(r'$S(\mathcal{G}(\hat{\rho}_s)|| \hat{\pi})$')
# plt.xlabel('time (1/J)')
# plt.legend()
# plt.title("$\gamma = $" + str(noise_state))
# plt.tight_layout()
# plt.show()

# np.save("state_unrotated_vectorized_every_t", state_unrotated_vectorized_every_t)
# np.save("state_rotated_vectorized_every_t", state_rotated_vectorized_every_t)
# np.save("ham", ham)

#Renyi divergence
# plt.figure(4)
# plt.rcParams.update({'font.size': 20})
# plt.semilogy(time_v, renyi_divergence_original_single_alpha, label=r'$D_{\alpha}(\hat{\rho}_s || \hat{\pi})$, ' + r'$\hat\rho_1$', color='purple') #+ 0.05\rho_{\text{rand}}
# plt.semilogy(time_v, renyi_divergence_rotated_single_alpha, label=r'$D_{\alpha}(\hat{\rho}_s || \hat{\pi})$, ' + r'$\hat\rho_2$', color='purple', linestyle='dashed')
# #compare with classical + quantum relative entropy
# plt.semilogy(time_v, quantum_relative_entropy_original, label=r'$S(\hat{\rho}_s || \hat{\pi})$, ' + r'$\hat\rho_1$', color='orange', alpha=0.5)
# plt.semilogy(time_v, quantum_relative_entropy_rotated , label=r'$S(\hat{\rho}_s || \hat{\pi})$, ' + r'$\hat\rho_2$', color='orange', linestyle='dashed', alpha=0.5)

# plt.xlabel('time (1/J)')
# plt.ylabel(r'$D_{\alpha}(\hat{\rho}_s || \hat{\pi})$')
# plt.legend()
# plt.title(r"$\alpha = $" + str(alpha))
# plt.tight_layout()
# plt.show()

alpha_array = np.linspace(0.1, 2, 50)  # 40
renyi_divergence_original = []
renyi_divergence_rotated = []

# For every alpha and every stored time step, compute the Renyi divergence between the state
# and its dephased (diagonal) part, for both the original and the rotated trajectory.
# The loop variable is `alpha_val` so the script-level Renyi parameter `alpha` is not
# overwritten (previously it was left equal to alpha_array[-1] after this loop).
for alpha_val in alpha_array:
    renyi_divergence_original_alpha = []
    renyi_divergence_rotated_alpha = []
    print("Computing Renyi divergence for alpha = ", alpha_val)
    for i in range(len(time_v)):
        # Reshaping a column slice can return a view into the stored trajectory; normalize
        # out of place so the stored data are not rewritten on every alpha pass.
        state_unrotated_i = np.reshape(state_unrotated_vectorized_every_t[:, i], (2**n_sites, 2**n_sites))
        state_unrotated_i = state_unrotated_i / np.trace(state_unrotated_i)
        state_rotated_i = np.reshape(state_rotated_vectorized_every_t[:, i], (2**n_sites, 2**n_sites))
        state_rotated_i = state_rotated_i / np.trace(state_rotated_i)
        renyi_divergence_original_alpha.append(renyi_divergence(state_unrotated_i, np.diag(np.diag(state_unrotated_i)), alpha_val))
        renyi_divergence_rotated_alpha.append(renyi_divergence(state_rotated_i, np.diag(np.diag(state_rotated_i)), alpha_val))
    renyi_divergence_original.append(renyi_divergence_original_alpha)
    renyi_divergence_rotated.append(renyi_divergence_rotated_alpha)
  
    
def find_closest_crossing(t_list, A, B):
    """Return the time just before the last sign change of ``A - B``, or ``None``.

    A sign change is any step where ``sign(A - B)`` differs between consecutive samples.
    Touching zero counts, so +,0,- gives two changes. The reported time is ``t_list[k]``,
    where ``k`` is the sample before the last such change. ``None`` means the curves never
    cross on the grid.
    """
    side = np.sign(np.asarray(A) - np.asarray(B))
    flips = np.flatnonzero(side[1:] != side[:-1])
    return t_list[flips[-1]] if flips.size else None

# crossing time of the original vs rotated Renyi-divergence curves, one entry per alpha in alpha_array
crossing_times = [
    find_closest_crossing(time_v, divergence_original, divergence_rotated)
    for divergence_original, divergence_rotated in zip(renyi_divergence_original, renyi_divergence_rotated)
]

#plot the crossing times
# Crossing time of the Renyi-divergence curves as a function of alpha.
# NOTE(review): entries of crossing_times are None where no crossing was found — confirm
# matplotlib renders these as gaps rather than failing.
plt.figure()
plt.rcParams.update({'font.size': 20})
plt.plot(alpha_array, crossing_times, label='crossing time', color='black')
plt.xlabel(r'$\alpha$')
plt.ylabel('crossing time (1/J)')
plt.title(r'crossing time vs $\alpha$')
# reference lines: t = 0, and alpha = 1 (presumably marking the relative-entropy limit of the Renyi divergence)
plt.axhline(y=0, color='gray', linestyle='--')
plt.axvline(x=1, color='gray', linestyle='--')
plt.legend()
plt.tight_layout()
plt.show()