# In[ ]:
def flip_metropolis(s, J, h, beta, flip_two_spins=False):
    """Propose a spin flip of `s` and accept or reject it with the Metropolis rule.

    Parameters
    ----------
    s : current spin configuration (copied, never modified in place).
    J, h : model parameters forwarded to ``energy``.
    beta : inverse temperature used in the acceptance probability exp(-beta * dE).
    flip_two_spins : when true, half of the proposals (chosen by a coin flip) flip two
        distinct randomly chosen spins; otherwise a single random spin is flipped.

    Returns
    -------
    The flipped copy if the move is accepted, otherwise the original ``s`` object.
    """
    candidate = np.copy(s)

    # Choose the spin(s) to flip: a random pair or a single random site.
    pair_move = flip_two_spins and np.random.rand() < 0.5
    if pair_move:
        targets = np.random.permutation(len(s))[:2]
    else:
        targets = np.array([np.random.randint(len(s))])

    candidate[targets] *= -1
    energy_change = energy(candidate, J, h) - energy(s, J, h)

    # Downhill moves are always taken; uphill moves with probability exp(-beta * dE).
    # The random draw only happens for non-negative energy changes.
    accepted = energy_change < 0 or np.random.rand() < np.exp(-beta * energy_change)
    return candidate if accepted else s


###CODE TO CHECK CONVERGENCE TO BOLTZMANN
def boltzmann_distribution(J, h, beta=1.0):
    """Return the exact Boltzmann distribution over all 2**N spin states.

    States are enumerated by index i in range(2**N) from the N-bit binary string of i,
    most significant bit first, mapping bit '0' -> spin +1 and bit '1' -> spin -1.

    NOTE: a later cell in this file defines ``boltzmann_distribution`` again, and that
    definition replaces this one once it has run.

    Parameters
    ----------
    J, h : model parameters forwarded to ``energy``; ``len(h)`` sets N.
    beta : inverse temperature (default 1.0 reproduces the original exp(-E) weights).

    Returns
    -------
    numpy.ndarray of length 2**N with probabilities summing to 1.
    """
    n_spins = len(h)
    all_states = np.array([[1 if digit == '0' else -1 for digit in f"{i:0{n_spins}b}"]
                           for i in range(2**n_spins)])
    energies = np.array([energy(state, J, h) for state in all_states])

    # Subtract the largest log-weight before exponentiating. The shift cancels in the
    # normalisation, but it keeps np.exp from overflowing to inf (and the result from
    # becoming NaN) when energies are strongly negative.
    log_weights = -beta * energies
    weights = np.exp(log_weights - np.max(log_weights))
    return weights / np.sum(weights)

def plot_convergence_to_boltzmann(J, h, trajectory):
    """Plot exact Boltzmann probabilities next to the empirical state frequencies of a trajectory.

    Parameters
    ----------
    J, h : model parameters forwarded to ``boltzmann_distribution``; ``len(h)`` sets
        the number of spins N.
    trajectory : sequence of spin configurations, one per time step. Each visited state
        is matched by value against the 2**N enumerated states; states that match none
        of them do not contribute to any bar.

    Shows a grouped bar chart (one Boltzmann/empirical bar pair per state) via
    ``plt.show()`` and returns None.
    """
    from collections import Counter

    n_spins = len(h)
    boltzmann_probs = boltzmann_distribution(J, h).flatten()

    # Enumerate all 2**N states once (bit '0' -> +1, bit '1' -> -1, most significant bit
    # first) and reuse the array for counting and for the tick labels.
    # NOTE(review): assumes boltzmann_distribution enumerates states in this same order —
    # confirm, otherwise the bar pairs are misaligned.
    all_states = np.array([[1 if digit == '0' else -1 for digit in f"{i:0{n_spins}b}"]
                           for i in range(2**n_spins)])

    # Visit counts per state; Counter yields 0 for states that were never visited.
    state_counts = Counter(tuple(state) for state in trajectory)
    empirical_probs = np.array([state_counts[tuple(state)] for state in all_states],
                               dtype=float)
    n_samples = len(trajectory)
    if n_samples > 0:
        empirical_probs /= n_samples

    # Plot the Boltzmann probabilities and the empirical probabilities side by side.
    state_labels = [system_state(state) for state in all_states]
    x = np.arange(len(state_labels))
    width = 0.35

    _, ax = plt.subplots(figsize=(6, 3))
    ax.bar(x - width / 2, boltzmann_probs, width, label='Boltzmann', color='SteelBlue', alpha=0.8)
    ax.bar(x + width / 2, empirical_probs, width, label='Empirical', color='Coral', alpha=0.8)

    ax.set_ylabel('Probability')
    ax.set_xlabel('Spin states')
    ax.set_title('Boltzmann vs Empirical Probabilities')
    ax.set_xticks(x)
    ax.set_xticklabels(state_labels)
    ax.legend()

    plt.show()


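### OLD TRANSITION-MATRIX FUNCTIONS THAT CALCULATE W SEPARATELY; THIS LOGIC IS NOW INCORPORATED IN simulate_and_infer_dynamics().
### NOTE: THESE STILL HAVE INCORRECT NORMALIZATION: each row sets its no-flip entry to 1.0 and then divides the whole row by
### its sum (compare no_flip_prob(), which uses 1 - sum of the flip probabilities instead).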
#----------------------------
def transition_probabilities(s, w, h, J):
    """Return the transition-matrix row for current spin state ``s``.

    Every target state s' (index 0 .. 2**N - 1, bits most significant first, bit 0 -> +1,
    bit 1 -> -1) gets an unnormalised weight:
      * 1.0 when s' == s (no flip),
      * ``g_single(i, s, w, h, J)`` when exactly spin i differs,
      * ``g_double(i, j, s, w, h, J)`` when exactly spins i < j differ,
      * 0.0 when three or more spins differ.
    The row is then divided by its sum. This is the legacy normalization flagged as
    incorrect in this file; it is kept unchanged here.
    """
    n = len(s)
    row = np.zeros(2**n)

    for target in range(2**n):
        bits = np.array([(target >> k) & 1 for k in reversed(range(n))])
        s_prime = (1 - bits) * 2 - 1               # bit 0 -> +1, bit 1 -> -1
        changed = np.flatnonzero(s != s_prime)     # positions where s and s' disagree

        n_changed = len(changed)
        if n_changed == 0:
            row[target] = 1.0
        elif n_changed == 1:
            row[target] = g_single(changed[0], s, w, h, J)
        elif n_changed == 2:
            i, j = changed
            row[target] = g_double(i, j, s, w, h, J)
        # targets with three or more flipped spins keep their initial weight 0.0

    row /= np.sum(row)
    return row

def model_transition_matrix(w, h, J):
    """Build the full 2**N x 2**N transition matrix, where N = len(w).

    Row ``idx`` is ``transition_probabilities`` evaluated at the spin state encoded by
    ``idx`` (N bits, most significant first, bit 0 -> +1 and bit 1 -> -1).
    """
    n = len(w)
    dim = 2**n
    matrix = np.zeros((dim, dim))

    for idx in range(dim):
        bits = np.array([(idx >> k) & 1 for k in reversed(range(n))])
        spins = (1 - bits) * 2 - 1
        matrix[idx, :] = transition_probabilities(spins, w, h, J)

    return matrix

# @njit
def no_flip_prob(from_idx, w, h, J, N):
    """Return 1 minus the total single- and double-flip probability out of state ``from_idx``.

    Every state reachable by toggling one bit (``g_single``) or two distinct bits
    (``g_double``) of ``from_idx`` is scored. The result is 1 - (sum of those scores),
    i.e. the probability left over for staying put. It can be negative if the scores
    sum to more than 1.
    """
    outgoing = np.zeros((2**N), dtype=np.float64)
    spins = index_to_spin_state(from_idx, N)

    # Single flips: toggle bit a.
    for a in range(N):
        target = from_idx ^ (1 << a)
        site = bits_flipped_indices(from_idx, target, N)[0]
        outgoing[target] = g_single(site, spins, w, h, J)

    # Double flips: toggle bits a < b.
    for a in range(N):
        for b in range(a + 1, N):
            target = from_idx ^ (1 << a) ^ (1 << b)
            site_a, site_b = bits_flipped_indices(from_idx, target, N)
            outgoing[target] = g_double(site_a, site_b, spins, w, h, J)

    # Probability mass not used by any flip.
    return 1 - sum(outgoing)


def check_density_matrices(eta, rho):
    """Validate two density matrices.

    Checks, in this order: that every eigenvalue of ``eta`` and then of ``rho`` is > 0
    (positive definite), and that the real part of the trace of ``eta`` and then of
    ``rho`` is close to 1.

    Returns
    -------
    ``(message, matrix)`` for the first failed check, where ``matrix`` is the offending
    input object itself, or ``(None, None)`` if both matrices pass.
    """
    labelled = (("eta", eta), ("rho", rho))

    # Both spectra are computed up front, before any check is evaluated.
    spectra = [np.linalg.eigvals(matrix) for _, matrix in labelled]
    for (label, matrix), spectrum in zip(labelled, spectra):
        if not np.all(spectrum > 0):
            return f"Error: Density matrix {label} is not positive definite", matrix

    traces = [np.real(np.trace(matrix)) for _, matrix in labelled]
    for (label, matrix), trace in zip(labelled, traces):
        if not np.isclose(trace, 1):
            return f"Error: Trace of density matrix {label} is not equal to 1, but {trace}", matrix

    return None, None

@njit
def find_min_w(w, h, J, N, step_size, no_flip_prob):
    """Shift every entry of ``w`` by a common offset until the diagonal of W is large enough.

    The offset starts at 0 and grows by ``step_size`` per attempt. For each candidate,
    the one-step transition matrix from ``simulate_and_infer_dynamics`` is computed. The
    first shifted ``w`` whose rows all have off-diagonal mass (row sum minus diagonal)
    strictly below ``1 - no_flip_prob`` is returned.

    NOTE: there is no iteration cap; if the condition is never met this loops forever.
    """
    offset = 0
    while True:
        shifted_w = w + np.ones(w.shape) * offset

        _, W = simulate_and_infer_dynamics(shifted_w, J, h, steps=1, N=N, fill_missing_entries=True)

        off_diagonal_mass = np.sum(W, axis=1) - np.diag(W)
        if np.all(off_diagonal_mass < (1 - no_flip_prob)):
            return shifted_w

        offset += step_size




def negative_log_likelihood(params, trajectory, N, min_prob=1e-12):
    """Compute the negative log-likelihood of an observed trajectory under the model.

    Parameters
    ----------
    params : flat parameter vector, unpacked with ``unflatten_parameters(params, N)``
        into ``(w, h, J)``.
    trajectory : sequence of spin states, one per time step.
    N : number of spins.
    min_prob : floor applied to each transition probability before taking the log.
        Transitions the model assigns probability <= min_prob (including 0 or negative
        values) are charged -log(min_prob) instead of being skipped.

    Returns
    -------
    float: -sum_t log P(s_{t+1} | s_t). Returns 0.0 for trajectories with fewer than
    two states.
    """
    w, h, J = unflatten_parameters(params, N)

    # One-step transition matrix for the current parameter values.
    _, transition_matrix = simulate_and_infer_dynamics(w, J, h, steps=1, N=N, fill_missing_entries=True)

    # Convert each visited state to its index once; consecutive pairs are the transitions.
    state_indices = np.array([spin_state_to_index(s) for s in trajectory], dtype=np.intp)
    if state_indices.size < 2:
        return 0.0
    probs = transition_matrix[state_indices[:-1], state_indices[1:]]

    # Previously, transitions with probability <= 0 were skipped. That *rewarded*
    # parameters that made observed transitions impossible. Clipping assigns them a
    # large, finite penalty instead, which keeps the objective usable by BFGS.
    return float(-np.sum(np.log(np.maximum(probs, min_prob))))

def infer_parameters(trajectory, N, initial_params=None):
    """Infer w, h and J by minimising the negative log-likelihood of a trajectory.

    Parameters
    ----------
    trajectory : time series of spin states the system went through.
    N : number of spins.
    initial_params : optional flat starting vector. When None, w (N x N), h (N) and
        J (N x N) are drawn uniformly from [0, 1) with ``np.random.rand`` and flattened.

    Returns
    -------
    ``(w_opt, h_opt, J_opt)`` unpacked from the optimiser's final iterate. A
    ``RuntimeWarning`` is emitted when scipy reports that BFGS did not converge.
    """
    import warnings

    if initial_params is None:
        # Create initial guess for the parameters
        w_init = np.random.rand(N, N)
        h_init = np.random.rand(N)
        J_init = np.random.rand(N, N)

        initial_params = flatten_parameters(w_init, h_init, J_init)

    # Minimize the negative log-likelihood using the BFGS algorithm
    result = minimize(negative_log_likelihood, initial_params, args=(trajectory, N), method='BFGS')

    # Previously a failed optimisation was returned silently; surface it to the caller.
    if not result.success:
        warnings.warn(
            f"infer_parameters: BFGS did not converge ({result.message}); "
            "returning the last iterate.",
            RuntimeWarning,
            stacklevel=2,
        )

    # Unflatten the optimized parameters back into matrices
    w_opt, h_opt, J_opt = unflatten_parameters(result.x, N)

    return w_opt, h_opt, J_opt

def boltzmann_distribution(J, h, beta=1.0):
    """Return the exact Boltzmann distribution over all 2**N spin states.

    States are produced by ``index_to_spin_state(s_idx, N)`` for s_idx in range(2**N),
    and the returned probabilities follow that order.

    Parameters
    ----------
    J, h : model parameters forwarded to ``energy``; ``len(h)`` sets N.
    beta : inverse temperature (default 1.0 reproduces the original exp(-E) weights).

    Returns
    -------
    numpy.ndarray of length 2**N with probabilities summing to 1.
    """
    N = len(h)
    all_states = np.array([index_to_spin_state(s_idx, N) for s_idx in range(2**N)])
    energies = np.array([energy(state, J, h) for state in all_states])

    # Subtract the largest log-weight before exponentiating. The shift cancels in the
    # normalisation, but it prevents np.exp overflow (inf/NaN) for strongly negative energies.
    log_weights = -beta * energies
    weights = np.exp(log_weights - np.max(log_weights))
    return weights / np.sum(weights)


def energy(s, J, h):
    """
    Compute the energy of the system in a particular spin state.

    E(s) = -s^T J s - sum_i h_i s_i

    Parameters
    ----------
    s : spin configuration; any array-like (ndarray, list, tuple) is accepted.
    J : coupling matrix.
    h : local fields.

    Returns
    -------
    The scalar energy.
    """
    # np.asarray accepts lists/tuples as well as arrays (``s.astype`` failed on plain
    # sequences) and does the same float64 conversion the original performed.
    s = np.asarray(s, dtype=np.float64)
    pairwise_energy = np.dot(s, np.dot(J, s))
    local_field_energy = np.sum(h * s)
    return -pairwise_energy - local_field_energy



def tweak_w_value(W, w, h, J, N, from_idx, to_idx, w_bottom=-100, w_top=100, tolerance=1e-6, max_iterations=2000):
    '''Adjusts the value of w until the discrepancy with the target value is within the specified tolerance.

    Bisects a single entry of ``w`` on [w_bottom, w_top] until the model transition
    probability from ``from_idx`` to ``to_idx`` matches the clamped value
    ``W[from_idx, to_idx]`` to within ``tolerance``:
      * one flipped spin i: tunes w[i, i] and evaluates ``g_single``;
      * two flipped spins i, j: tunes w[i, j] = w[j, i] and evaluates ``g_double``.
    The bisection moves the upper bound down when the free value is too small, so it
    presumably assumes g decreases as w grows — TODO confirm.

    ``w`` is modified in place; it is left holding the last tried value.

    Returns
    -------
    ``(w_mid, W_free)``: the accepted w value and the model probability it yields.

    Raises
    ------
    ValueError if no value within tolerance is found in ``max_iterations`` steps.
    '''
    W_clamped = W[from_idx, to_idx]

    # The source state and the set of flipped spins do not change during the search.
    s = index_to_spin_state(from_idx, N)
    flipped_indices = bits_flipped_indices(from_idx, to_idx, N)
    n_flipped = len(flipped_indices)

    if n_flipped not in (1, 2):
        # No flip (or more than two): the free value is the constant 1 and no w entry is
        # tuned, so bisection can never change the outcome. Only the first midpoint can
        # succeed; decide immediately instead of spinning through max_iterations.
        if max_iterations > 0 and abs(1 - W_clamped) < tolerance:
            return (w_bottom + w_top) / 2.0, 1
        raise ValueError(f"Failed to find suitable w value within {max_iterations} iterations.")

    if n_flipped == 1:
        i, j = flipped_indices[0], None
    else:
        i, j = flipped_indices

    w_low = w_bottom
    w_high = w_top
    for _ in range(max_iterations):
        w_mid = (w_low + w_high) / 2.0

        if n_flipped == 1:  # single flip
            w[i, i] = w_mid
            W_free = g_single(i, s, w, h, J)
        else:               # double flip: keep w symmetric
            w[i, j] = w_mid
            w[j, i] = w_mid
            W_free = g_double(i, j, s, w, h, J)

        diff = W_free - W_clamped
        if abs(diff) < tolerance:
            return w_mid, W_free
        elif diff < 0:
            w_high = w_mid
        else:
            w_low = w_mid
    raise ValueError(f"Failed to find suitable w value within {max_iterations} iterations.")
