# In [2]:
import numpy as np
from sklearn.metrics import pairwise_distances
from scipy.optimize import fsolve, newton_krylov, broyden1, leastsq
from scipy.optimize import approx_fprime
from scipy.integrate import solve_ivp



def psi_function(row_1, row_2):
    """Signed, squared-difference flux between two one-element rows.

    Returns ``(a - b) * |a - b| / (a + b)`` where ``a = row_1[0]`` and
    ``b = row_2[0]``; returns 0 when ``a + b == 0`` so the division is safe.
    The result is antisymmetric: swapping the rows flips its sign.
    """
    a, b = row_1[0], row_2[0]
    total = a + b
    if total == 0:
        return 0
    gap = a - b
    return gap * np.abs(gap) / total


class Simulator:
    """Networked compartment-model simulator with explicit and implicit integrators.

    ``self.x`` holds one row per node (rows of ``p['coords']``) and one column
    per compartment (``p['mu'].shape[1]`` columns); column 0 is seeded with
    ``p['S0']``.  The original notes describe the columns as S, I, R, D.
    ``eval_f`` combines local reaction terms with pairwise inter-node
    transport (tau) weighted by the ``T``, ``P`` and ``M`` matrices.
    """

    def __init__(self, p, tstep=0.1) -> None:
        """Build a simulator from the parameter dict ``p``.

        Construction reads ``coords``, ``u0``, ``q0``, ``mu``, ``S0``, ``P``
        and ``M``; ``eval_f`` and the source updates read further keys
        (``N``, ``v``, ``alpha``, ``beta``, ``kappa``, ``mode`` and
        ``gamma`` / ``zeta``).  ``p['T']`` is (re)computed and written into
        the caller's dict.
        """
        self._init_state(p, tstep)

    def reset(self, p=None, tstep=None):
        """Re-initialise all simulation state.

        ``p`` and ``tstep`` default to the values the simulator currently
        holds, so ``reset()`` restarts the same scenario from t = 0.
        """
        self._init_state(self.p if p is None else p,
                         self.tstep if tstep is None else tstep)

    def _init_state(self, p, tstep):
        """Single source of truth for the state set by __init__ and reset."""
        self.secant_used = False
        self.t = 0
        self.tstep = tstep
        self.p = p
        self.N = p['coords'].shape[0]
        self._calculate_T()
        self.u = p['u0']
        self.q = p['q0']
        self.num_psis = p['mu'].shape[1]
        # Every node starts with only its S0 population in column 0.
        self.x = np.zeros([self.N, self.num_psis])
        self.x[:, 0] = p['S0']
        self.dead_tracker = [np.zeros([self.N])]
        self.source_sum = np.zeros(self.N)

        # One-time calc: T and P repeated along the compartment axis, M
        # repeated along a new leading node axis, so all three multiply the
        # (N, N, num_psis) psi tensor elementwise.
        self.T_expanded = np.tile(p['T'][..., np.newaxis], [1, 1, self.num_psis])
        self.P_expanded = np.tile(p['P'][..., np.newaxis], [1, 1, self.num_psis])
        self.M_expanded = np.tile(p['M'][np.newaxis, ...], [self.N, 1, 1])

    def _calculate_T(self):
        """Compute the (N, N) transport kernel and store it in ``self.p['T']``.

        Entries are exp(-distance) between node coordinates, normalised by
        their sum over distinct node pairs, with a zero diagonal.
        """
        deltas = pairwise_distances(self.p['coords'])
        exp_deltas = np.exp(-1 * deltas)
        # The zero-distance diagonal contributes exp(0) = 1 per node; removing
        # it and halving leaves the sum over distinct (unordered) node pairs.
        self.p['T'] = exp_deltas / ((np.sum(exp_deltas) - self.N) / 2)
        np.fill_diagonal(self.p['T'], 0)

    def _calculate_tau_matrix(self, x=None):
        """Return the pairwise transport tensor tau, shape (N, N, num_psis).

        For nodes a, b and compartment k::

            psi[a, b, k] = (x[a,k] - x[b,k]) * |x[a,k] - x[b,k]| / (x[a,k] + x[b,k])

        (0 where the denominator is 0), multiplied elementwise by
        ``T_expanded``, ``P_expanded`` and ``M_expanded``.  This is the
        vectorised equivalent of applying ``psi_function`` to every node pair.

        Parameters
        ----------
        x : array_like, optional
            State (flat or (N, num_psis)) to evaluate at; defaults to ``self.x``.
        """
        x = self.x if x is None else x
        x = np.reshape(np.asarray(x, dtype=float), (self.N, self.num_psis))
        # Node a runs along axis 0, node b along axis 1.
        diff = x[:, np.newaxis, :] - x[np.newaxis, :, :]
        total = x[:, np.newaxis, :] + x[np.newaxis, :, :]
        numerator = diff * np.abs(diff)
        psi = np.divide(numerator, total, out=np.zeros(numerator.shape), where=total != 0)
        return psi * self.T_expanded * self.P_expanded * self.M_expanded

    def _decayed_source(self, initial_key):
        """Time-decayed value of ``p[initial_key]`` at ``self.t``, or None.

        ``p['mode'] == 'exp'``: floor(initial * exp(-gamma * t)).
        ``p['mode'] == 'lin'``: max(0, floor(initial - zeta * t)).
        Any other mode returns None (caller keeps its current value).
        """
        mode = self.p['mode']
        if mode == 'exp':
            initial = np.asarray(self.p[initial_key], dtype=float)
            rate = np.asarray(self.p['gamma'], dtype=float)
            return np.floor(initial * np.exp(-1 * rate * self.t))
        if mode == 'lin':
            initial = np.asarray(self.p[initial_key], dtype=float)
            rate = np.asarray(self.p['zeta'], dtype=float)
            return np.maximum(0, np.floor(initial - rate * self.t))
        return None

    def u_sources(self):
        """Refresh the external source term ``self.u`` for the current time."""
        updated = self._decayed_source('u0')
        if updated is not None:
            self.u = updated

    def q_intra_sources(self):
        """Refresh the intra-node transfer term ``self.q`` for the current time."""
        updated = self._decayed_source('q0')
        if updated is not None:
            self.q = updated

    def eval_f(self, custom_x=None, active_step=False, dt=None):
        """Evaluate the right-hand side f(x) of dx/dt = f(x).

        Parameters
        ----------
        custom_x : array_like, optional
            State to evaluate at (flat or (N, num_psis)); defaults to
            ``self.x``.  The transport term tau is evaluated at this same
            state, so implicit solvers and Jacobians see its dependence on x.
        active_step : bool
            When True, also record a step of size ``dt``: append to
            ``dead_tracker``, accumulate ``source_sum`` and refresh
            ``self.u`` / ``self.q``.  The returned f uses the ``u`` / ``q``
            values from before that refresh.
        dt : float, optional
            Step size; required when ``active_step`` is True.

        Returns
        -------
        ndarray of shape (p['N'], num_psis).  Only columns 0-2 are filled;
        any further columns are zero.

        Raises
        ------
        ValueError
            If ``active_step`` is True but ``dt`` is None.
        """
        if active_step and dt is None:
            raise ValueError("dt must be given when active_step is True")
        x = self.x if custom_x is None else custom_x
        x = np.reshape(np.asarray(x, dtype=float), (self.N, self.num_psis))

        p = self.p
        N = p['N']
        v = p['v']
        alpha = p['alpha']
        beta = p['beta']
        kappa = p['kappa']
        # Zero rates are nudged to 1e-5 on a copy, leaving p['mu'] untouched.
        mu = np.where(p['mu'] == 0, 1e-5, p['mu'])
        u = np.asarray(self.u, dtype=float)
        q_raw = np.asarray(self.q, dtype=float)

        # tau summed over partner nodes: (N, num_psis); subtracted from f below.
        tau = np.sum(self._calculate_tau_matrix(x), axis=1)

        S = x[:, 0]
        # If q - u would exceed the column-0 population, clamp q to it.
        q = np.where(q_raw - u > S, S, q_raw)

        # Population per node; zero totals are nudged to avoid division by zero.
        sigma = np.sum(x, axis=1)
        sigma = np.where(sigma == 0, 1e-5, sigma)

        max_vaccination_rate = np.inf  # np.inf disables the cap
        vaccination_rate = np.minimum(v * (S ** 2) / sigma, max_vaccination_rate)

        f = np.zeros([N, self.num_psis])
        f[:, 0] = ((-alpha * x[:, 1] * x[:, 0] / sigma) - mu[:, 0] * x[:, 0]
                   - q - tau[:, 0] - vaccination_rate + u)
        f[:, 1] = ((alpha * x[:, 1] * x[:, 0] / sigma)
                   - (mu[:, 1] + beta + kappa) * x[:, 1] + q - tau[:, 1])
        f[:, 2] = beta * x[:, 1] - mu[:, 2] * x[:, 2] + vaccination_rate - tau[:, 2]

        if active_step:
            deaths = (kappa * x[:, 1] + mu[:, 0] * x[:, 0]
                      + mu[:, 1] * x[:, 1] + mu[:, 2] * x[:, 2])
            self.dead_tracker.append(self.dead_tracker[-1] + dt * deaths)
            self.source_sum += dt * u
            self.u_sources()
            self.q_intra_sources()

        return f

    def solve_steady_state(self, guess=None, xtol=1):
        """Find x with f(x) = 0 using ``scipy.optimize.fsolve``.

        Parameters
        ----------
        guess : array_like, optional
            Starting point (any shape with N * num_psis entries); defaults to
            ``self.x``.
        xtol : float
            Relative tolerance passed to fsolve.

        Returns
        -------
        Flat ndarray of length N * num_psis.
        """
        start = self.x if guess is None else guess
        start = np.reshape(start, (self.N * self.num_psis))

        def residual(x_flat):
            # fsolve needs a flat residual with the same length as x_flat.
            return self.eval_f(custom_x=x_flat).flatten()

        return fsolve(residual, start, xtol=xtol)

    def Jac_test(self):
        """Jacobian of f at ``self.x`` via ``scipy.optimize.approx_fprime`` (reference check)."""
        def flat_f(x_flat):
            return self.eval_f(x_flat.reshape(self.x.shape)).flatten()

        jacobian = approx_fprime(self.x.flatten(), flat_f, np.sqrt(np.finfo(float).eps))
        return jacobian.reshape(self.x.shape[0] * self.x.shape[1], -1)

    def _fd_jacobian(self, func, x_flat):
        """Forward-difference Jacobian of ``func`` at ``x_flat``.

        Column i uses the step base * (1 + |x_i|) with
        base = 2 * sqrt(machine eps) * sqrt(1 + ||x||_inf).  The denominator is
        the step actually representable in floating point.  func(x) is
        evaluated once and reused for every column.
        """
        x_flat = np.ravel(np.asarray(x_flat, dtype=float))
        f0 = np.ravel(func(x_flat))
        jac = np.zeros((f0.size, x_flat.size))
        base_epsilon = 2 * np.sqrt(np.finfo(float).eps) * np.sqrt(1 + np.linalg.norm(x_flat, np.inf))
        for i in range(x_flat.size):
            x_perturbed = x_flat.copy()
            x_perturbed[i] += base_epsilon * (1 + abs(x_flat[i]))
            step = x_perturbed[i] - x_flat[i]
            jac[:, i] = (np.ravel(func(x_perturbed)) - f0) / step
        return jac

    def FiniteDifference_JacobianMatrix(self, x_custom=None):
        """Finite-difference Jacobian of f at ``x_custom`` (default ``self.x``).

        Returns a square matrix of size N * num_psis.
        """
        x = self.x if x_custom is None else x_custom
        return self._fd_jacobian(lambda z: self.eval_f(custom_x=z), x)

    def forward_euler(self, tspan, dt, return_tau=False):
        """Integrate dx/dt = f(x) from the current state with forward Euler.

        Parameters
        ----------
        tspan : tuple
            ``(t_start, t_stop)``.
        dt : float
            Step size; the last step is shortened so it ends at ``t_stop``.
        return_tau : bool
            Also return the tau tensor evaluated after every step.

        Returns
        -------
        ``(t, y)`` or ``(t, y, taus)``, where ``y[k]`` is the state at ``t[k]``.
        """
        t_start, t_stop = tspan
        num_steps = len(np.arange(t_start, t_stop, dt))
        t = np.zeros(num_steps + 1)
        y = np.zeros((num_steps + 1,) + self.x.shape)
        y[0] = self.x
        self.t = t_start
        t[0] = self.t
        taus = []

        for n in range(num_steps):
            dt = min(dt, t_stop - t[n])
            f_val = self.eval_f(dt=dt, active_step=True)
            self.x += dt * f_val
            y[n + 1] = self.x
            self.t += dt
            t[n + 1] = self.t
            if return_tau:
                taus.append(self._calculate_tau_matrix())

        if return_tau:
            return t, y, taus
        return t, y

    def trapezoidal_equation(self, x_new_flat, x_old_flat, dt):
        """Trapezoidal residual x_new - x_old - dt/2 * (f(x_old) + f(x_new)), flattened."""
        x_new = x_new_flat.reshape(self.x.shape)
        x_old = x_old_flat.reshape(self.x.shape)
        f_current = self.eval_f(custom_x=x_old).flatten()
        f_next = self.eval_f(custom_x=x_new).flatten()
        return x_new_flat - x_old_flat - (dt / 2) * (f_current + f_next)

    ########################## SHARED INTEGRATOR HELPERS ##########################

    def _trapezoid_residual(self, x_current, f_current, dt):
        """Return F(x_next) = x_next - x_current - dt/2 * (f_current + f(x_next)), flattened."""
        def residual(x_next_flat):
            x_next = np.reshape(x_next_flat, x_current.shape)
            return (x_next - x_current - (dt / 2) * (f_current + self.eval_f(x_next))).flatten()
        return residual

    def _trapezoid_jacobian(self, dt):
        """Return x -> I - dt/2 * J_f(x), the Jacobian of the trapezoidal residual."""
        def jacobian(x_next_flat):
            return np.eye(np.size(x_next_flat)) - dt / 2 * self.FiniteDifference_JacobianMatrix(x_next_flat)
        return jacobian

    @staticmethod
    def _fixed_step_count(tspan, dt):
        """Number of whole steps of size ``dt`` in ``tspan``."""
        t_start, t_stop = tspan
        # The small slack keeps e.g. 0.3 / 0.1 == 2.9999999999999996 from losing a step.
        return max(0, int(np.floor((t_stop - t_start) / dt + 1e-9)))

    def _integrate_fixed(self, tspan, dt, advance, report=None):
        """Shared fixed-step driver.

        ``advance(n)`` returns the flattened next state for step ``n``;
        ``report(n)`` (optional) runs after the step is stored and the time
        advanced.  Returns ``(t, y)``.
        """
        t_start = tspan[0]
        num_steps = self._fixed_step_count(tspan, dt)
        t = np.zeros(num_steps + 1)
        y = np.zeros((num_steps + 1,) + self.x.shape)
        y[0] = self.x
        self.t = t_start
        t[0] = self.t

        for n in range(num_steps):
            x_next_flat = advance(n)
            self.x = np.reshape(x_next_flat, self.x.shape)
            y[n + 1] = self.x
            self.t += dt
            t[n + 1] = self.t
            if report is not None:
                report(n)

        return t, y

    def _snapshot_step_state(self):
        """Capture the bookkeeping that ``eval_f(active_step=True)`` mutates."""
        return len(self.dead_tracker), self.source_sum.copy(), self.u, self.q

    def _restore_step_state(self, snapshot):
        """Undo an ``active_step`` evaluation recorded by ``_snapshot_step_state``."""
        tracker_len, source_sum, u, q = snapshot
        del self.dead_tracker[tracker_len:]
        self.source_sum = source_sum
        self.u = u
        self.q = q

    def _integrate_adaptive(self, tspan, initial_dt, tolerance, max_dt, solve,
                            grow, fail_shrink, tail_fraction):
        """Shared adaptive-step trapezoidal driver.

        ``solve(guess_flat, residual, dt)`` must return ``(x_flat, converged)``.
        dt is clamped to the remaining time before each solve.  After an
        accepted step it is divided by ``grow`` (not below ``initial_dt``)
        inside the final ``tail_fraction`` of ``tspan``; otherwise it is
        multiplied by ``grow`` (capped at ``max_dt``) when the inf-norm state
        change is below ``tolerance / 2``.  A failed solve undoes that step's
        bookkeeping and retries with dt / ``fail_shrink``.  At dt <=
        ``initial_dt`` the last iterate is accepted with a warning instead,
        which guarantees termination.

        Returns
        -------
        ``(t, y)`` arrays containing only the steps actually taken.

        Raises
        ------
        ValueError
            If ``initial_dt`` is not positive.
        """
        if initial_dt <= 0:
            raise ValueError("initial_dt must be positive")
        t_start, t_stop = tspan
        tail_start = t_stop - tail_fraction * (t_stop - t_start)
        times = [t_start]
        states = [self.x.copy()]
        self.t = t_start
        dt = initial_dt

        while self.t < t_stop:
            # Clamp *before* solving so the state advances over exactly the
            # interval that self.t advances by.
            dt = min(dt, t_stop - self.t)
            step_number = len(times)
            x_current = self.x.copy()
            snapshot = self._snapshot_step_state()
            f_current = self.eval_f(x_current, active_step=True, dt=dt)
            residual = self._trapezoid_residual(x_current, f_current, dt)
            x_next_flat, converged = solve(x_current.flatten(), residual, dt)

            if not converged:
                if dt > initial_dt:
                    print(f"Newton-GCR method failed to converge at time step {step_number}. Decreasing dt.")
                    self._restore_step_state(snapshot)
                    dt = max(initial_dt, dt / fail_shrink)
                    continue
                print(f"Solver failed to converge at time step {step_number} with minimum dt = {dt}; "
                      f"accepting the last iterate.")

            self.x = np.reshape(x_next_flat, x_current.shape)
            self.t += dt
            times.append(self.t)
            states.append(self.x.copy())

            change = np.linalg.norm(self.x - x_current, np.inf)
            if self.t >= tail_start:
                dt = max(initial_dt, dt / grow)
            elif change < tolerance / 2 and dt < max_dt:
                dt = min(dt * grow, max_dt)

            print(f"Time step {len(times)}: dt = {dt}, change = {change}")

        return np.array(times), np.array(states)

    ########################## GCR CODE ####################################

    def trapezoidal_method_GCRHadpt(self, tspan, initial_dt, tolerance, max_dt=1, max_iters_newton=50):
        """Adaptive trapezoidal rule solved with ``HybridSolver_GCR``.

        dt grows / tail-shrinks by 1.1x and is divided by 1.5 after a failed
        solve.  Returns ``(t, y)``.
        """
        def solve(guess, residual, dt):
            return self.HybridSolver_GCR(guess, residual, max_iters_newton=max_iters_newton)

        return self._integrate_adaptive(tspan, initial_dt, tolerance, max_dt, solve,
                                        grow=1.1, fail_shrink=1.5, tail_fraction=0.0)

    def trapezoidal_method_GCRadpt(self, tspan, initial_dt, tolerance, max_dt=1, max_iters_newton=50):
        """Adaptive trapezoidal rule solved with ``NewtonNd_GCR`` (tol_f = ``tolerance``).

        dt grows / tail-shrinks by 1.25x (tail = last 0.5% of ``tspan``) and
        is divided by 1.1 after a failed solve.  Returns ``(t, y)``.
        """
        def solve(guess, residual, dt):
            return self.NewtonNd_GCR(guess, residual, tol_f=tolerance, max_iter=max_iters_newton)

        return self._integrate_adaptive(tspan, initial_dt, tolerance, max_dt, solve,
                                        grow=1.25, fail_shrink=1.1, tail_fraction=0.005)

    def trapezoidal_method_GCR(self, tspan, dt):
        """Fixed-step trapezoidal rule solved with ``NewtonNd_GCR``. Returns ``(t, y)``."""
        def advance(n):
            x_current = self.x.copy()
            f_current = self.eval_f(x_current, active_step=True, dt=dt)
            x_next_flat, converged = self.NewtonNd_GCR(
                x_current.flatten(),
                self._trapezoid_residual(x_current, f_current, dt),
                tol_f=1e-9,
                max_iter=20,
            )
            if not converged:
                print(f"Newton-GCR method failed to converge at time step {n + 1}")
            return x_next_flat

        return self._integrate_fixed(tspan, dt, advance)

    def tgcr_matrixfree(self, xk, b, eps, tolrGCR, MaxItersGCR, eval_f=None):
        """Solve J(xk) dx = b with matrix-free GCR.

        J is the Jacobian of ``eval_f`` at ``xk``.  Each product J p is
        approximated by (F(xk + eps p) - F(xk)) / eps with ``p`` of unit
        2-norm; F(xk) is evaluated once.

        Parameters
        ----------
        xk : array_like
            Linearisation point (flattened).
        b : array_like
            Right-hand side.
        eps : float
            Finite-difference perturbation size.
        tolrGCR : float
            Stop once ||r||_2 <= tolrGCR * ||b||_2.
        MaxItersGCR : int
            Iteration cap, further limited to ``len(b)`` (the most directions
            GCR can use in exact arithmetic).
        eval_f : callable, optional
            Function whose Jacobian is inverted.  Defaults to ``self.eval_f``;
            Newton solvers pass their own residual function.

        Returns
        -------
        ``(x, converged, r_norms)``.
        """
        if eval_f is None:
            def eval_f(z):
                return self.eval_f(custom_x=z)

        xk = np.ravel(np.asarray(xk, dtype=float))
        b = np.ravel(np.asarray(b, dtype=float))
        x = np.zeros_like(b)
        r = b.copy()
        r_norms = [np.linalg.norm(r)]
        if r_norms[0] == 0:
            return x, True, r_norms

        f_xk = np.ravel(eval_f(xk))
        directions = []  # search directions p_j
        images = []      # their orthonormal images J p_j
        gcr_converged = False

        for _ in range(min(MaxItersGCR, b.size)):
            p = r / np.linalg.norm(r)
            Ap = (np.ravel(eval_f(xk + eps * p)) - f_xk) / eps

            # Orthogonalise the new image against earlier ones, applying the
            # same combination to p so that J p == Ap keeps holding.
            for p_j, Ap_j in zip(directions, images):
                coeff = np.dot(Ap, Ap_j)
                p = p - coeff * p_j
                Ap = Ap - coeff * Ap_j

            Ap_norm = np.linalg.norm(Ap)
            if Ap_norm < 1e-15:  # breakdown: no new direction available
                break
            p = p / Ap_norm
            Ap = Ap / Ap_norm
            directions.append(p)
            images.append(Ap)

            alpha = np.dot(r, Ap)
            x += alpha * p
            r -= alpha * Ap
            r_norms.append(np.linalg.norm(r))
            if r_norms[-1] <= tolrGCR * r_norms[0]:
                gcr_converged = True
                break

        return x, gcr_converged, r_norms

    def NewtonNd_GCR(self, x0_flat, eval_f, tol_f=1e-9, max_iter=1000, custom_J=None, return_res=False):
        """Newton's method whose linear systems are solved by matrix-free GCR.

        Iterates until ||eval_f(x)||_inf <= tol_f or ``max_iter`` steps.
        ``custom_J`` is accepted for signature parity with ``NewtonNd`` but
        unused.

        Returns
        -------
        ``(x, converged)``, or ``(x, converged, residual_inf_norm)`` when
        ``return_res`` is True.
        """
        tolrGCR = 1e-8
        MaxItersGCR = 10000
        sqrt_eps = np.sqrt(np.finfo(float).eps)

        x_flat = x0_flat
        f = np.ravel(eval_f(x_flat))
        err_f = np.linalg.norm(f, np.inf)
        converged = bool(err_f <= tol_f)
        k = 0

        while not converged and k < max_iter:
            # JFNK-style perturbation that scales with the state magnitude, so
            # the difference quotient is not swamped by round-off.
            eps_fd = sqrt_eps * (1 + np.linalg.norm(x_flat, np.inf))
            delta_x, _, _ = self.tgcr_matrixfree(x_flat, -f, eps_fd, tolrGCR, MaxItersGCR, eval_f=eval_f)
            x_flat = x_flat + delta_x
            k += 1
            f = np.ravel(eval_f(x_flat))
            err_f = np.linalg.norm(f, np.inf)
            converged = bool(err_f <= tol_f)

        if not converged:
            print("Newton did NOT converge! Maximum Number of Iterations reached")
        if return_res:
            return x_flat, converged, err_f
        return x_flat, converged

    def _secant_warm_start(self, x0_flat, func):
        """Secant-refined guess on the first call per simulation, else ``x0_flat``.

        Sets ``self.secant_used`` so later calls skip the secant pass.
        """
        if self.secant_used:
            return x0_flat
        x1_flat = x0_flat + np.random.randn(*x0_flat.shape) * 1e-4
        x_initial_guess, _ = self.secant_method_solver(x0_flat, x1_flat, func)
        self.secant_used = True
        return x_initial_guess

    def HybridSolver_GCR(self, x0_flat, func, max_iters_newton):
        """Newton-GCR solve, warm-started once per simulation by the secant method.

        Returns ``(x, converged)``.
        """
        x_initial_guess = self._secant_warm_start(x0_flat, func)
        x_newton_GCR, newton_GCR_converged, _ = self.NewtonNd_GCR(
            x_initial_guess,
            func,
            tol_f=1e-9,
            max_iter=max_iters_newton,
            return_res=True,
        )
        return x_newton_GCR, newton_GCR_converged

    ################### NEWTON NORMAL TRAP CODE ####################################

    def trapezoidal_method(self, tspan, dt):
        """Fixed-step trapezoidal rule solved by ``NewtonNd`` with an FD Jacobian. Returns ``(t, y)``."""
        jacobian = self._trapezoid_jacobian(dt)

        def advance(n):
            x_current = self.x.copy()
            f_current = self.eval_f(x_current, active_step=True, dt=dt)
            x_next_flat, converged = self.NewtonNd(
                x_current.flatten(),
                self._trapezoid_residual(x_current, f_current, dt),
                custom_J=jacobian,
            )
            if not converged:
                print(f"Newton's method failed to converge at time step {n + 1}")
            return x_next_flat

        return self._integrate_fixed(tspan, dt, advance)

    def NewtonNd(self, x0_flat, eval_f, tol_f=1e-4, max_iter=1000, custom_J=None, return_res=False):
        """Dense Newton's method for eval_f(x) = 0.

        Iterates until ||eval_f(x)||_inf <= tol_f or ``max_iter`` steps.
        ``custom_J(x)`` supplies the Jacobian; when None, ``eval_f`` itself is
        finite-differenced.

        Returns
        -------
        ``(x, converged)``, or ``(x, converged, residual_inf_norm)`` when
        ``return_res`` is True.

        Raises
        ------
        numpy.linalg.LinAlgError
            If a Jacobian is singular.
        """
        x_flat = x0_flat
        f = np.ravel(eval_f(x_flat))
        err_f = np.linalg.norm(f, np.inf)
        converged = bool(err_f <= tol_f)
        k = 0

        while not converged and k < max_iter:
            if custom_J is None:
                Jf = self._fd_jacobian(eval_f, x_flat)
            else:
                Jf = custom_J(x_flat)
            x_flat = x_flat + np.linalg.solve(Jf, -f)
            k += 1
            f = np.ravel(eval_f(x_flat))
            err_f = np.linalg.norm(f, np.inf)
            converged = bool(err_f <= tol_f)

        if not converged:
            print("Newton did NOT converge! Maximum Number of Iterations reached")
        if return_res:
            return x_flat, converged, err_f
        return x_flat, converged

    def secant_method_solver(self, x0_flat, x1_flat, func, epsilon=1e-2, max_iters=1000):
        """Componentwise secant iteration for func(x) = 0.

        Stops when ||x_{k+1} - x_k||_2 < ``epsilon``.  Components whose
        residual did not change between iterates take a zero step instead of
        dividing by zero.

        Returns
        -------
        ``(x, converged)``.
        """
        f_x0 = func(x0_flat)
        for _ in range(max_iters):
            f_x1 = func(x1_flat)
            delta_f = f_x1 - f_x0

            if np.linalg.norm(delta_f) < 1e-8:
                print("Denominator too small. Secant method failed.")
                return x1_flat, False

            numerator = f_x1 * (x1_flat - x0_flat)
            step = np.divide(numerator, delta_f, out=np.zeros(np.shape(numerator)), where=delta_f != 0)
            x2_flat = x1_flat - step
            if np.linalg.norm(x2_flat - x1_flat) < epsilon:
                return x2_flat, True

            # func(x1) becomes func(x0) of the next iteration; reuse it.
            x0_flat, x1_flat, f_x0 = x1_flat, x2_flat, f_x1

        print("Secant method did not converge.")
        return x1_flat, False

    def HybridSolver_s(self, x0_flat, func, J_func, convergence_threshold, max_iters_newton):
        """Always run a secant pass for the initial guess, then ``NewtonNd``.

        ``convergence_threshold`` is Newton's residual tolerance.  Returns
        ``(x, converged)``.
        """
        x1_flat = x0_flat + np.random.randn(*x0_flat.shape) * 1e-2
        x_initial_guess, _ = self.secant_method_solver(x0_flat, x1_flat, func)
        x_newton, newton_converged, _ = self.NewtonNd(
            x_initial_guess, func, tol_f=convergence_threshold, max_iter=max_iters_newton,
            custom_J=J_func, return_res=True)
        return x_newton, newton_converged

    def HybridSolver(self, x0_flat, func, J_func, convergence_threshold, max_iters_newton):
        """``NewtonNd`` warm-started once per simulation by the secant method.

        ``convergence_threshold`` is Newton's residual tolerance.  Returns
        ``(x, converged)``.
        """
        x_initial_guess = self._secant_warm_start(x0_flat, func)
        x_newton, newton_converged, _ = self.NewtonNd(
            x_initial_guess, func, tol_f=convergence_threshold, max_iter=max_iters_newton,
            custom_J=J_func, return_res=True)
        return x_newton, newton_converged

    def trapezoidal_method_sec(self, tspan, dt):
        """Fixed-step trapezoidal rule solved by the secant method. Returns ``(t, y)``."""
        def advance(n):
            x_current = self.x.copy()
            f_current = self.eval_f(x_current, active_step=True, dt=dt)
            guess = x_current.flatten()
            # The second secant point is the guess shifted by 1e-4.
            x_next_flat, converged = self.secant_method_solver(
                guess + 1e-4, guess, self._trapezoid_residual(x_current, f_current, dt))
            if not converged:
                print(f"Secant method failed to converge at time step {n + 1}")
            return x_next_flat

        return self._integrate_fixed(tspan, dt, advance)

    def trapezoidal_method_hybrid(self, tspan, dt, convergence_threshold=1e-1, max_iters_newton=50):
        """Fixed-step trapezoidal rule solved by ``HybridSolver``. Returns ``(t, y)``."""
        jacobian = self._trapezoid_jacobian(dt)

        def advance(n):
            x_current = self.x.copy()
            f_current = self.eval_f(x_current, active_step=True, dt=dt)
            x_next_flat, converged = self.HybridSolver(
                x_current.flatten(), self._trapezoid_residual(x_current, f_current, dt),
                jacobian, convergence_threshold, max_iters_newton)
            if not converged:
                print(f"Solver failed to converge at time step {n + 1}")
            return x_next_flat

        return self._integrate_fixed(tspan, dt, advance)

    def trapezoidal_method_adapt(self, tspan, initial_dt, tolerance, convergence_threshold=1e-1, max_dt=1, max_iters_newton=50):
        """Adaptive trapezoidal rule solved by ``HybridSolver``.

        dt grows / tail-shrinks by 1.1x (tail = last 2.5% of ``tspan``) and
        is divided by 1.1 after a failed solve.  Returns ``(t, y)``.
        """
        def solve(guess, residual, dt):
            return self.HybridSolver(guess, residual, self._trapezoid_jacobian(dt),
                                     convergence_threshold, max_iters_newton)

        return self._integrate_adaptive(tspan, initial_dt, tolerance, max_dt, solve,
                                        grow=1.1, fail_shrink=1.1, tail_fraction=0.025)

    # Wrapped functions
    def wrapped_eval_f(self, x_flat):
        """f evaluated at a flat state, returned flat."""
        x_reshaped = x_flat.reshape((self.N, self.num_psis))
        return self.eval_f(x_reshaped).flatten()

    def wrapped_J(self, x_flat):
        """Finite-difference Jacobian of f at a flat state."""
        x_reshaped = x_flat.reshape((self.N, self.num_psis))
        return self.FiniteDifference_JacobianMatrix(x_reshaped)

    def backward_euler(self, tspan, dt):
        """Fixed-step backward Euler solved by ``NewtonNd``. Returns ``(t, y)``."""
        def advance(n):
            # Record the step's bookkeeping (dead tracker, sources) first.
            self.eval_f(active_step=True, dt=dt)
            x_current = self.x.copy()

            def implicit_function(x_next_flat):
                x_next = np.reshape(x_next_flat, x_current.shape)
                return (x_next - x_current - dt * self.eval_f(x_next)).flatten()

            def jacobian(x_next_flat):
                return np.eye(np.size(x_next_flat)) - dt * self.FiniteDifference_JacobianMatrix(x_next_flat)

            x_next_flat, _ = self.NewtonNd(x_current.flatten(), implicit_function, custom_J=jacobian)
            return x_next_flat

        def report(n):
            print(f"Time step {n + 1}: Sum of x = {np.sum(self.x)}")
            if self.t % 10 == 0:
                print("Time: ", self.t)

        return self._integrate_fixed(tspan, dt, advance, report)

    def backward_euler_f(self, tspan, dt):
        """Fixed-step backward Euler solved by ``scipy.optimize.fsolve``. Returns ``(t, y)``."""
        def advance(n):
            self.eval_f(active_step=True, dt=dt)
            x_current = self.x.copy()

            def implicit_function(x_next_flat):
                x_next = np.reshape(x_next_flat, x_current.shape)
                return (x_next - x_current - dt * self.eval_f(x_next)).flatten()

            return fsolve(implicit_function, x_current.flatten())

        def report(n):
            print(f"Time step {n + 1}: Sum of x = {np.sum(self.x)}")
            if self.t % 10 == 0:
                print("Time: ", self.t)

        return self._integrate_fixed(tspan, dt, advance, report)

    def trapezoidal_method_f(self, tspan, dt):
        """Fixed-step trapezoidal rule solved by ``scipy.optimize.fsolve``. Returns ``(t, y)``."""
        def advance(n):
            x_current = self.x.copy()
            f_current = self.eval_f(x_current, active_step=True, dt=dt)
            x_next_flat = fsolve(self._trapezoid_residual(x_current, f_current, dt), x_current.flatten())
            print(f"Time step {n + 1}: before reshaping  = {x_next_flat}")
            return x_next_flat

        def report(n):
            if self.t % 10 == 0:
                print("Time: ", self.t)

        return self._integrate_fixed(tspan, dt, advance, report)

# In [4]:
import dash
from dash import dcc, html, callback_context
from dash.exceptions import PreventUpdate
from dash.dependencies import Input, Output, State
import plotly.graph_objs as go
import numpy as np
import dash_bootstrap_components as dbc
import base64
import io
import plotly.express as px
import pandas as pd
import dash_ag_grid as dag
from dash import dash_table
import json
from dash.exceptions import PreventUpdate
import re
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import openai
from openai import OpenAI
import os

# Load the OpenAI key from a local JSON file ({"API_KEY": "..."}).
with open('openai_key.json', 'r') as file:
    data = json.load(file)
api_key = data['API_KEY']
# NOTE: the key is deliberately not printed; server stdout usually ends up in
# logs and notebook outputs, which would leak the secret.
client = OpenAI(api_key=api_key)

# Fetch NLTK resources only when they are missing locally. nltk.download may
# contact the remote index on every call, so check the data path first.
for _nltk_package, _nltk_resource in (('stopwords', 'corpora/stopwords'),
                                      ('punkt', 'tokenizers/punkt')):
    try:
        nltk.data.find(_nltk_resource)
    except LookupError:
        nltk.download(_nltk_package)


# Seeds NumPy's global RNG (used e.g. for random default populations).
np.random.seed(42)


# Dark-theme CSS overrides for the Dash DataTable (cells, headers, dropdowns,
# filter inputs, pagination), intended to sit on top of the CYBORG Bootstrap
# theme. Consumed by the app setup below.
css_style = """
body {
    font-family: Arial, sans-serif !important;
}

/* DataTable Styling for Dark Theme */
.dash-spreadsheet, .dash-spreadsheet-container, .dash-header, .dash-cell div {
    color: #FFFFFF; /* Light text color for better readability */
}

.dash-spreadsheet-container .Select-menu-outer {
    background-color: #2C3E50; /* Dark background for dropdowns */
}

.dash-spreadsheet tr th, .dash-spreadsheet tr td {
    background-color: #34495E; /* Dark background color for table cells */
    border: 1px solid #2C3E50; /* Slightly darker border color */
}

.dash-spreadsheet tr th {
    background-color: #2C3E50; /* Slightly darker background color for headers */
}

/* Adjusting filter input box for dark theme */
.dash-filter input {
    background-color: #2C3E50;
    color: #FFF;
}

/* Adjusting pagination for dark theme */
.dash-pagination-container .pagination {
    background-color: #34495E;
    color: #FFF;
}

.dash-pagination-container .pagination .page-item.active .page-link {
    background-color: #18BC9C;
    border-color: #18BC9C;
}

.dash-pagination-container .pagination .page-item.disabled .page-link {
    color: #777;
}

"""

# Dash turns every `external_stylesheets` entry into a <link href="..."> tag,
# i.e. it expects URLs. Raw CSS text there produces a broken link and the rules
# are never applied, so only the theme URL goes in the list, and `css_style` is
# inlined into the page template right after Dash's own CSS.
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.CYBORG], suppress_callback_exceptions=True)
app.index_string = app.index_string.replace(
    '{%css%}', '{%css%}\n        <style>' + css_style + '</style>', 1
)

# Single-marker placeholder map centred on the continental US.
initial_figure = go.Figure(
    go.Scattermapbox(
        lat=['37.04358'],
        lon=['-97.23412'],
        mode='markers'
    )
)

# NOTE(review): the "light" mapbox style requires a Mapbox access token; none
# is configured in this part of the file (the other maps use tokenless
# "carto-darkmatter") — confirm a token is set, or switch to "carto-positron".
initial_figure.update_layout(
    mapbox_style="light",
    mapbox=dict(
        center=dict(lat=37.04358, lon=-97.23412),
        zoom=2
    ),
    margin={'l': 0, 'r': 0, 't': 0, 'b': 0}
)

# Not read by the callbacks in this section; abort state is kept in the
# 'abort-flag-store' dcc.Store instead.
is_simulation_aborted = False

@app.callback(
    Output('abort-flag-store', 'data'),
    [Input('abort-button', 'n_clicks')],
    prevent_initial_call=True
)
def abort_simulation(n_clicks):
    """Mark the simulation as aborted when the abort button is clicked.

    ``update_simulation`` checks this flag and returns empty figures while it
    is set. Nothing in this section sets it back to False.
    """
    aborted_state = {'is_aborted': True}
    return aborted_state

@app.callback(
    [
        Output('reset-upload-store', 'data'),
        Output('input-location-name', 'value'),
        Output('input-x-coordinate', 'value'),
        Output('input-y-coordinate', 'value'),
        Output('input-population', 'value'),
        Output('input-infection-source', 'value'),
        Output('input-vaccination-rate', 'value'),
        Output('input-mobility-toggle', 'value'),
    ],
    Input('clear-locations-btn', 'n_clicks'),
    State('reset-upload-store', 'data'),
    prevent_initial_call=True
)
def trigger_reset(n_clicks, reset_counter):
    """Handle a "clear locations" click by resetting the location form.

    The reset counter is bumped so that ``update_upload_component`` rebuilds
    the upload widget. The form fields are restored to their defaults: empty
    name, no coordinates or population, infection source 50, vaccination
    rate 0.5 and mobility off.
    """
    next_counter = reset_counter + 1
    blank_name = ''
    default_infection_source = 50
    default_vaccination_rate = 0.5
    mobility_enabled = False
    return (
        next_counter,
        blank_name,
        None,  # x coordinate
        None,  # y coordinate
        None,  # population
        default_infection_source,
        default_vaccination_rate,
        mobility_enabled,
    )

@app.callback(
    Output('upload-data-container', 'children'),
    Input('reset-upload-store', 'data')
)
def update_upload_component(reset_counter):
    """Return a fresh multi-file ``dcc.Upload`` whenever the reset counter changes.

    Rebuilding the component discards any previously uploaded contents.
    """
    upload_button = html.Button('Upload Data', id='upload-button', className='btn btn-info', n_clicks=0)
    return dcc.Upload(id='upload-data', children=upload_button, multiple=True)

@app.callback(
    [Output('locations-table', 'data'),
     Output('status-message', 'children')],
    [Input('location-data-store', 'data')]
)
def update_locations_and_message(stored_data):
    """Mirror the location store into the table and report a status message."""
    if stored_data:
        return stored_data, "Locations updated successfully."
    return [], "No locations added yet."
@app.callback(
    Output('simulation-trigger', 'data'),
    Input('locations-table', 'data'),
    State('simulation-trigger', 'data')
)
def update_simulation_trigger(table_data, trigger_data):
    """Increment the simulation-trigger version on every table data change.

    ``table_data`` itself is ignored; only the fact that it changed matters.
    """
    current_version = trigger_data['version']
    return {'version': current_version + 1}

# Module-level placeholder. The callbacks in this section receive their own
# `location_data` argument (from the 'location-data-store' dcc.Store) and do not
# read this global.
location_data = []

def create_initial_map_layout():
    """Return an empty figure pre-styled with the dark Carto map centred on the US."""
    layout_options = dict(
        mapbox_style="carto-darkmatter",
        mapbox=dict(center=dict(lat=37.04358, lon=-97.23412), zoom=3),
        margin={"l": 0, "r": 0, "t": 0, "b": 0},
        legend=dict(x=0.01, y=0.99, bgcolor="rgba(0,0,0,0)", bordercolor="rgba(0,0,0,0)"),
        paper_bgcolor="lightgray",
        plot_bgcolor="lightgray",
    )
    figure = go.Figure()
    figure.update_layout(**layout_options)
    return figure

def standardize_data_keys(data):
    """
    Standardize the keys of dictionaries in the list `data` to match the expected
    keys by the DataTable, handling different naming conventions and case sensitivity.

    Matching ignores case, surrounding whitespace, and the choice of space,
    ``_`` or ``-`` as separator (``' X_Coordinate '`` -> ``'x'``). Keys with no
    known alias, and non-string keys (e.g. integer column labels from a
    header-less spreadsheet), are kept unchanged.

    Args:
        data: Iterable of dict rows, e.g. ``DataFrame.to_dict('records')``.

    Returns:
        A new list of new dicts; the input rows are not modified. If two keys
        of one row map to the same canonical name, the later one wins.
    """
    key_mapping = {
        'name': ['Name', 'location', 'Location', 'place', 'Place', 'locations', 'Locations', 'places', 'Places'],
        'x': ['X Coordinate', 'x coordinate', 'X', 'x'],
        'y': ['Y Coordinate', 'y coordinate', 'Y', 'y'],
        'population': ['Population', 'population'],
        'infection_source': ['Infection Source', 'infection source', 'InfectionSource'],
        'vaccination_rate': ['Vaccination Rate', 'vaccination rate', 'VaccinationRate'],
        'mobility': ['Mobility', 'mobility']
    }

    def normalize(key):
        # Lower-case, treat '_' and '-' as spaces, and collapse whitespace runs.
        return ' '.join(key.strip().lower().replace('_', ' ').replace('-', ' ').split())

    inverted_mapping = {
        normalize(alias): canonical
        for canonical, aliases in key_mapping.items()
        for alias in aliases
    }

    def standardize_key(key):
        if not isinstance(key, str):
            return key  # previously crashed with AttributeError on .lower()
        return inverted_mapping.get(normalize(key), key)  # default: original key

    return [{standardize_key(key): value for key, value in row.items()} for row in data]

def add_detailed_descriptions(df):
    """Add a human-readable ``'Description'`` column comparing each row with the previous one.

    Rows are compared with the immediately preceding row of the DataFrame only
    when both belong to the same ``'Node'``. The first row of every run of a
    node gets "Initial data point for node.".

    The column is assigned in place, and the same DataFrame is returned.
    """
    def describe(curr, prev):
        if prev is None or curr['Node'] != prev['Node']:
            return "Initial data point for node."

        notes = []
        infected_now = curr['Infected']
        infected_before = prev['Infected']
        if infected_now > infected_before:
            # NOTE(review): compares current Infected against the *previous*
            # row's Recovered — confirm this is the intended baseline.
            if infected_now > prev['Recovered']:
                notes.append("Infection rising sharply.")
            else:
                notes.append("Infection increasing but recovery is higher.")
        elif infected_now < infected_before:
            notes.append("Infection declining.")

        if curr['Recovered'] > prev['Recovered']:
            notes.append("Recovery numbers improving.")
        if curr['Susceptible'] < prev['Susceptible']:
            notes.append("Susceptible population decreasing.")

        if notes:
            return ' '.join(notes)
        return "Stable condition with no significant changes from previous timestep."

    records = df.to_dict('records')
    predecessors = [None] + records[:-1]
    df['Description'] = [describe(curr, prev) for curr, prev in zip(records, predecessors)]
    return df

@app.callback(
    Output('location-data-store', 'data'),
    [Input('upload-data', 'contents'),
     Input('add-location-btn', 'n_clicks'),
     Input('clear-locations-btn', 'n_clicks'),
     Input('locations-table', 'data_timestamp')],
    [State('upload-data', 'filename'),
     State('input-location-name', 'value'),
     State('input-x-coordinate', 'value'),
     State('input-y-coordinate', 'value'),
     State('input-population', 'value'),
     State('input-infection-source', 'value'),
     State('input-vaccination-rate', 'value'),
     State('input-mobility-toggle', 'value'),
     State('location-data-store', 'data'),
     State('locations-table', 'data')]
)
def unified_data_handler(contents, add_clicks, clear_clicks, table_timestamp,
                         filenames, name, x, y, population, infection_source,
                         vaccination_rate, mobility, existing_data, table_data):
    """Single writer for ``location-data-store``; dispatches on the input that fired.

    * ``upload-data``: decode each uploaded ``.xlsx`` or ``.csv`` file,
      standardize its column names and append its rows to the stored
      locations. Other file types are skipped.
    * ``add-location-btn`` (with a non-empty name): append one row built from
      the form fields. Mobility is stored as ``"Yes"``/``"No"``.
    * ``clear-locations-btn``: empty the store.
    * otherwise (e.g. an edit in ``locations-table``): store the table rows,
      falling back to the current store contents.

    Raises:
        PreventUpdate: when the callback fired without a triggering input.
    """
    ctx = dash.callback_context
    if not ctx.triggered:
        raise PreventUpdate

    triggered_id = ctx.triggered[0]['prop_id'].split('.')[0]
    # The session store may hand back None; always work on a fresh copy so the
    # list Dash passed in is never mutated.
    current_data = list(existing_data or [])

    if triggered_id == 'upload-data' and contents:
        combined_data = current_data
        for content, filename in zip(contents, filenames or []):
            # Upload contents look like "<mime>;base64,<payload>"; split once so
            # only the first comma separates header and payload.
            _, content_string = content.split(',', 1)
            decoded = base64.b64decode(content_string)
            lowered_name = (filename or '').lower()
            if lowered_name.endswith('.xlsx'):
                df = pd.read_excel(io.BytesIO(decoded))
            elif lowered_name.endswith('.csv'):
                # utf-8-sig also strips the BOM Excel adds to exported CSVs.
                df = pd.read_csv(io.StringIO(decoded.decode('utf-8-sig')))
            else:
                continue  # unsupported file type: skipped
            combined_data += standardize_data_keys(df.to_dict('records'))
        return combined_data

    elif triggered_id == 'add-location-btn' and name:
        mobility_text = "Yes" if mobility else "No"

        new_entry = standardize_data_keys([{
            'Name': name,
            'X Coordinate': x,
            'Y Coordinate': y,
            'Population': population,
            'Infection Source': infection_source,
            'Vaccination Rate': vaccination_rate,
            'Mobility': mobility_text
        }])[0]

        current_data.append(new_entry)
        return current_data

    elif triggered_id == 'clear-locations-btn':
        return []

    return table_data or current_data

@app.callback(
    [Output('map-plot', 'figure'), Output('3d-map-plot', 'figure'), Output('statistics-plot', 'figure')],
    Input('run-button', 'n_clicks'),
    State('abort-flag-store', 'data'),
    State('alpha', 'value'),
    State('beta', 'value'),
    State('kappa', 'value'),
    State('duration-slider', 'value'),
    State('time-step-slider', 'value'),
    State('location-data-store', 'data')
)
def update_simulation(n_clicks, abort_flag_data, alpha, beta, kappa, duration, time_step, location_data):
    """Run the simulation on "Run" and return the 2-D map, 3-D globe and stats figures.

    A second, stacked ``@app.callback`` (single Output ``'simulation-plot'``) used
    to decorate this function. That registered it twice, so every click ran the
    whole simulation twice, and the 3-tuple below was sent to a one-figure
    output. Only the three-output registration is kept.

    Side effects: writes ``raw_output.xlsx`` and ``output_with_descriptions.xlsx``
    to the working directory.

    Returns:
        ``(map_fig, map_fig3d, stats_fig)``. Three empty figures are returned
        before the first click or while the abort flag is set.
    """
    if n_clicks is None:
        return go.Figure(), go.Figure(), go.Figure()

    # Guard against a missing/cleared store value.
    if (abort_flag_data or {}).get('is_aborted', False):
        return go.Figure(), go.Figure(), go.Figure()

    sim_params = setup_simulation_parameters(alpha, beta, kappa, location_data or [])
    sim = Simulator(sim_params)
    tspan = (0, duration if duration is not None else 350)
    dt = time_step if time_step is not None else 0.5
    timesteps, state_trajectories, tau_values = sim.forward_euler(tspan, dt, return_tau=True)

    # Long-format table ordered node-major, then time. add_detailed_descriptions
    # relies on each node's rows being contiguous.
    results = []
    node_names = sim_params['node_names']
    for node_index, node_name in enumerate(node_names):
        for time_index, time in enumerate(timesteps):
            results.append({
                'Time': time,
                'Node': node_name,
                'Susceptible': state_trajectories[time_index, node_index, 0],
                'Infected': state_trajectories[time_index, node_index, 1],
                'Recovered': state_trajectories[time_index, node_index, 2]
            })

    df = pd.DataFrame(results)
    df.to_excel('raw_output.xlsx', index=False)

    df_descriptions = add_detailed_descriptions(df)

    df_descriptions.to_excel('output_with_descriptions.xlsx', index=False)

    map_fig = setup_map_figure(sim_params, timesteps, state_trajectories, tau_values)
    map_fig3d = setup_3dmap_figure(sim_params, timesteps, state_trajectories, tau_values)
    stats_fig = setup_stats_figure(sim_params, state_trajectories, timesteps)

    return map_fig, map_fig3d, stats_fig

def setup_simulation_parameters(alpha, beta, kappa, location_data):
    """Assemble the parameter dict consumed by ``Simulator``.

    With no user locations, four built-in cities are used, with random initial
    populations (drawn from NumPy's global RNG) and 10 initial infections in
    node 0. Otherwise, one node is built per location row. Each value is read
    from either the standardized key (``'x'``, ``'population'``, ...) or the
    original column label (``'X Coordinate'``, ``'Population'``, ...).

    Mobility: a node is mobile when its mobility value is ``True`` or a string
    such as ``'Yes'``/``'yes'``/``'true'``. Immobile nodes get ``M`` row
    ``[0, 0, 0]``, and their row and column of ``P`` are zeroed.

    Args:
        alpha, beta, kappa: Scalar rates broadcast to every node. ``beta`` is
            halved before use (as in the original code; the reason is not
            documented here).
        location_data: List of location dicts, or ``None``/empty for the defaults.

    Returns:
        dict with keys ``mode, coords, N, P, node_names, M, alpha, beta, kappa,
        mu, gamma, zeta, u0, q0, v, S0``.
    """
    # dcc.Store can yield None (e.g. cleared session storage); treat it as "no
    # user-defined locations" so the defaults below apply.
    location_data = list(location_data or [])

    def lookup(loc, key, alt_key):
        # Prefer the standardized key; fall back to the raw column label.
        return loc.get(key, loc.get(alt_key))

    def is_mobile(value):
        if isinstance(value, (bool, np.bool_)):
            return bool(value)
        return isinstance(value, str) and value.strip().lower() in ('yes', 'true')

    p = {}
    original_coords = np.array([
        [29.7604, -95.3698],  # Houston, Texas
        [33.4484, -112.0740],  # Phoenix, Arizona
        [39.7392, -104.9903],  # Denver, Colorado
        [32.7157, -117.1611],  # San Diego, California
    ])

    # Column 0 is used as latitude and column 1 as longitude by the map code.
    additional_coords = np.array([
        [lookup(loc, 'x', 'X Coordinate'), lookup(loc, 'y', 'Y Coordinate')]
        for loc in location_data
    ])

    updated_coords = additional_coords if additional_coords.size > 0 else original_coords
    new_N = updated_coords.shape[0]

    populations = [lookup(loc, 'population', 'Population') for loc in location_data]
    infection_sources = [lookup(loc, 'infection_source', 'Infection Source') for loc in location_data]
    vaccination_rates = [lookup(loc, 'vaccination_rate', 'Vaccination Rate') for loc in location_data]

    p['mode'] = 'exp'
    p['coords'] = updated_coords
    p['N'] = new_N
    p['P'] = np.ones([new_N, new_N])

    if location_data:
        node_names = [loc.get('name', f'Node {i}') for i, loc in enumerate(location_data)]
    else:
        node_names = [f'Node {i}' for i in range(p['N'])]
    p['node_names'] = node_names

    mobility_matrix = [
        [0.7, 0.3, 1] if is_mobile(lookup(loc, 'mobility', 'Mobility')) else [0, 0, 0]
        for loc in location_data
    ]

    if mobility_matrix:
        print("Using new mobility values")
        p['M'] = np.array(mobility_matrix)
    else:
        print("Using default mobility values")
        p['M'] = np.tile(np.array([0.7, 0.3, 1])[np.newaxis, :], [new_N, 1])

    # Cut immobile nodes out of the connectivity matrix in both directions.
    for i, mobility_vals in enumerate(mobility_matrix):
        if all(val == 0 for val in mobility_vals):
            p['P'][i, :] = 0
            p['P'][:, i] = 0

    beta = 0.5 * beta

    p['alpha'] = np.full(new_N, float(alpha))
    p['beta'] = np.full(new_N, float(beta))
    p['kappa'] = np.full(new_N, float(kappa))
    p['mu'] = np.tile(0.00001 * np.array([1.00, 1.00, 1.00])[np.newaxis, :], [new_N, 1])
    p['gamma'] = np.full(new_N, 0.1)
    p['zeta'] = np.full(new_N, 0.1)
    p['u0'] = np.full(new_N, 0)
    p['q0'] = np.array([10] + [0] * (new_N - 1)) if not infection_sources else np.array(infection_sources)
    p['v'] = np.array([1] * new_N) if not vaccination_rates else np.array(vaccination_rates)
    p['S0'] = np.random.randint(10000, 30000, size=new_N) if not populations else np.array(populations)
    return p

def setup_map_figure(params, timesteps, state_trajectories, tau_values):
    """Build the animated 2-D (mapbox) view of the simulation.

    One animation frame is created for every fifth simulation step. Each frame
    is named after its step index, which ``setup_slider_controls`` uses to
    label the slider. The first frame's traces are shown initially.
    """
    frame_every = 5  # adjust to control how many frames are generated
    frames = []
    for idx, (_, state, tau) in enumerate(zip(timesteps, state_trajectories, tau_values)):
        if idx % frame_every:
            continue
        traces = build_map_traces(state, params['coords'], tau, map_type='mapbox')
        frames.append(go.Frame(data=traces, name=str(idx)))

    slider = setup_slider_controls(frames, timesteps)

    map_fig = go.Figure()
    map_fig.frames = frames
    if frames:
        map_fig.add_traces(frames[0].data)

    play_pause_menu = {
        'type': 'buttons',
        'buttons': [
            {'label': 'Play', 'method': 'animate',
             'args': [None, {'frame': {'duration': 500, 'redraw': True}, 'fromcurrent': True}]},
            {'label': 'Pause', 'method': 'animate',
             'args': [[None], {'frame': {'duration': 0, 'redraw': False}, 'mode': 'immediate'}]},
        ],
        'direction': 'left',
        'pad': {'r': 10, 't': 87},
        'showactive': False,
        'x': 0.1,
        'xanchor': 'right',
        'y': 0,
        'yanchor': 'top',
    }

    map_fig.update_layout(
        mapbox_style="carto-darkmatter",
        mapbox=dict(center=dict(lat=37.04358, lon=-97.23412), zoom=3),
        margin={'l': 0, 't': 0, 'b': 0, 'r': 0},
        sliders=[slider],
        updatemenus=[play_pause_menu],
        legend=dict(x=0, y=1, orientation='h', bgcolor='rgba(0,0,0,0.7)'),
        font=dict(family="Arial, sans-serif", size=12, color="white"),
        paper_bgcolor='black',
        plot_bgcolor='black',
    )

    print("Map figure setup with frames:", len(frames))
    return map_fig

def setup_3dmap_figure(params, timesteps, state_trajectories, tau_values):
    """Build the animated orthographic-globe (``Scattergeo``) view of the simulation.

    Frames are built for every fifth simulation step and named by step index,
    as in the 2-D map. The first frame's traces are shown initially.
    """
    frame_every = 5
    frames = []
    for idx, (_, state, tau) in enumerate(zip(timesteps, state_trajectories, tau_values)):
        if idx % frame_every:
            continue
        traces = build_map_traces(state, params['coords'], tau, map_type='geo')
        frames.append(go.Frame(data=traces, name=str(idx)))

    slider = setup_slider_controls(frames, timesteps)

    map_fig3d = go.Figure()
    map_fig3d.frames = frames
    if frames:
        map_fig3d.add_traces(frames[0].data)

    globe = dict(
        projection_type="orthographic",
        showland=True,
        landcolor="rgb(190, 173, 150)",
        showocean=True,
        oceancolor="rgb(29,162,216)",
        showlakes=True,
        lakecolor="rgb(127,205,255)",
        bgcolor='rgb(10,10,10)',
        showcountries=True,
        countrycolor="rgb(0, 4, 8)",
        showsubunits=True,
        subunitcolor="rgb(0, 4, 8)",
        resolution=110,
    )

    play_pause_menu = {
        'type': 'buttons',
        'buttons': [
            {'label': 'Play', 'method': 'animate',
             'args': [None, {'frame': {'duration': 500, 'redraw': True}, 'fromcurrent': True}]},
            {'label': 'Pause', 'method': 'animate',
             'args': [[None], {'frame': {'duration': 0, 'redraw': False}, 'mode': 'immediate'}]},
        ],
        'direction': 'left',
        'pad': {'r': 10, 't': 87},
        'showactive': False,
        'x': 0.1,
        'xanchor': 'right',
        'y': 0,
        'yanchor': 'top',
    }

    map_fig3d.update_layout(
        geo=globe,
        margin={'l': 0, 't': 0, 'b': 0, 'r': 0},
        sliders=[slider],
        updatemenus=[play_pause_menu],
        legend=dict(x=0, y=1, orientation='h', bgcolor='rgba(0,0,0,0.7)'),
        font=dict(family="Arial, sans-serif", size=12, color="white"),
        paper_bgcolor='black',
        plot_bgcolor='black',
    )

    print("3-D Map figure setup with frames:", len(frames))
    return map_fig3d


def build_map_traces(state, coords, tau, map_type='mapbox'):
    """Return the scatter traces that draw one simulation snapshot on a map.

    ``state[:, 0:3]`` are read as S, I, R per node, and ``coords[:, 0]`` /
    ``coords[:, 1]`` as latitude / longitude. Each compartment is drawn as a
    black "border" marker plus a coloured marker. Sizes are cumulative
    (S, then S+I, then S+I+R), scaled by node total relative to the largest
    node and clipped to [2, 50]. Traces are ordered Recovered, Infected,
    Susceptible so the smaller circles sit on top.

    For every node pair ``(i, j)`` with ``i < j`` and ``tau[i, j].sum() > 0``, a
    line of width ``max(1, 2 * sum)`` is appended.

    ``map_type='mapbox'`` uses ``Scattermapbox`` with yellow links; any other
    value uses ``Scattergeo`` with black links.
    """
    susceptible = state[:, 0]
    infected = state[:, 1]
    recovered = state[:, 2]
    total = susceptible + infected + recovered

    peak = total.max()
    if peak == 0:
        peak = 1  # avoid dividing by zero when every node is empty
    weight = total / peak

    def share(part):
        return np.where(total > 0, part / total, 0)

    largest, smallest = 50, 2
    susceptible_sizes = np.clip(weight * largest * share(susceptible), smallest, largest)
    infected_sizes = np.clip(weight * largest * share(infected) + susceptible_sizes, smallest, largest)
    recovered_sizes = np.clip(weight * largest * share(recovered) + infected_sizes, smallest, largest)

    if map_type == 'mapbox':
        scatter_class, link_colour = go.Scattermapbox, 'yellow'
    else:
        scatter_class, link_colour = go.Scattergeo, 'black'

    latitudes, longitudes = coords[:, 0], coords[:, 1]
    layers = (
        ('Recovered', recovered, recovered_sizes, 'green'),
        ('Infected', infected, infected_sizes, 'red'),
        ('Susceptible', susceptible, susceptible_sizes, 'blue'),
    )

    traces = []
    for label, counts, sizes, colour in layers:
        # Slightly larger black marker underneath acts as an outline.
        traces.append(scatter_class(
            lat=latitudes, lon=longitudes, mode='markers',
            marker=dict(size=[size + 2 for size in sizes], color='black'),
            name=f'{label}_border',
            showlegend=False,
        ))
        traces.append(scatter_class(
            lat=latitudes, lon=longitudes, mode='markers',
            marker=dict(size=sizes, color=colour, opacity=0.7), name=label,
            hoverinfo='text',
            hovertext=[f'{label}: ' + format(count, '0.0f') for count in counts],
        ))

    min_width = 1
    node_count = len(coords)
    for a in range(node_count):
        for b in range(a + 1, node_count):
            flow = tau[a, b].sum()
            if flow > 0:
                traces.append(scatter_class(
                    lat=[coords[a, 0], coords[b, 0]],
                    lon=[coords[a, 1], coords[b, 1]],
                    mode='lines',
                    line=dict(width=max(min_width, 2 * flow), color=link_colour),
                    hoverinfo='none', showlegend=False,
                ))

    return traces

def setup_slider_controls(frames, timesteps, duration=100):
    """Build a Plotly slider spec with one step per animation frame.

    Each frame's ``name`` must be the string form of its index into
    ``timesteps`` (the figure builders in this file name frames
    ``str(idx)``). The time at that index becomes the step label.

    Args:
        frames: Objects with a ``name`` attribute (``go.Frame``).
        timesteps: Sequence of simulation times.
        duration: Frame and transition duration in ms when jumping to a step.

    Returns:
        A dict suitable for ``layout.sliders``.
    """
    steps = []
    for frame in frames:
        step_time = timesteps[int(frame.name)]
        steps.append({
            "args": [
                [frame.name],
                {"frame": {"duration": duration, "redraw": True}, "mode": "immediate", "transition": {"duration": duration}}
            ],
            # ':g' keeps fractional times ("2.5"; int() truncated it to "2"),
            # hides float noise (0.30000000000000004 -> "0.3") and still prints
            # whole numbers without a trailing ".0".
            "label": f"{float(step_time):g}",
            "method": "animate"
        })
    return {
        'steps': steps,
        'transition': {'duration': 300},
        'x': 0.1,
        'y': 0,
        'currentvalue': {
            'visible': True,
            'prefix': 'Time: ',
            'xanchor': 'right'
        },
        'pad': {'b': 10, 't': 10},
        'len': 0.9,
        'xanchor': 'left',
        'yanchor': 'top'
    }

def setup_stats_figure(params, state_trajectories, timesteps):
    """Build the animated per-node S/I/R time-series chart.

    There is one line trace per (node, state) pair, ordered node-major. Only
    node 0 starts visible; the rest are legend-only. A dropdown shows one
    node at a time. Animation frames (every fifth step) reveal each series up
    to that step.
    """
    num_nodes = params['N']
    num_states = state_trajectories.shape[2]
    state_labels = {0: 'S', 1: 'I', 2: 'R'}
    node_names = params.get('node_names', [f'Node {i}' for i in range(num_nodes)])

    def trace_name(node, state):
        return f'{node_names[node]}, State {state_labels[state]}'

    series_keys = [(node, state) for node in range(num_nodes) for state in range(num_states)]

    stats_fig = go.Figure()
    for node, state in series_keys:
        stats_fig.add_trace(go.Scatter(
            x=[],
            y=[],
            mode='lines+markers',
            name=trace_name(node, state),
            visible=True if node == 0 else 'legendonly',
        ))

    frame_every = 5
    frames = [
        go.Frame(
            data=[
                go.Scatter(
                    x=timesteps[:idx + 1],
                    y=state_trajectories[:idx + 1, node, state],
                    mode='lines+markers',
                    name=trace_name(node, state),
                )
                for node, state in series_keys
            ],
            name=str(idx),
        )
        for idx in range(0, len(timesteps), frame_every)
    ]

    stats_fig.frames = frames
    print("Stat figure setup with frames:", len(frames))

    node_buttons = [
        dict(
            label=node_names[node],
            method='update',
            args=[{'visible': [owner == node for owner, _ in series_keys]}],
        )
        for node in range(num_nodes)
    ]

    slider = setup_slider_controls(frames, timesteps)

    node_selector = {
        'buttons': node_buttons,
        'direction': 'down',
        'showactive': True,
        'x': 0.1,
        'xanchor': 'right',
        'y': 1.2,
        'yanchor': 'top',
    }
    play_pause_menu = {
        'type': 'buttons',
        'buttons': [
            {'label': 'Play', 'method': 'animate',
             'args': [None, {'frame': {'duration': 500, 'redraw': True}, 'fromcurrent': True}]},
            {'label': 'Pause', 'method': 'animate',
             'args': [[None], {'frame': {'duration': 0, 'redraw': False}, 'mode': 'immediate'}]},
        ],
        'direction': 'left',
        'pad': {'r': 10, 't': 10},
        'showactive': True,
        'x': 0.1,
        'xanchor': 'right',
        'y': 0,
        'yanchor': 'top',
    }

    stats_fig.update_layout(
        updatemenus=[node_selector, play_pause_menu],
        sliders=[slider],
        paper_bgcolor='rgba(0,0,0,0.8)',
        plot_bgcolor='rgba(0,0,0,0.8)',
        font=dict(color='white'),
        xaxis=dict(showgrid=True, gridcolor='gray', linecolor='white', range=[0, max(timesteps)]),
        yaxis=dict(showgrid=True, gridcolor='gray', linecolor='white'),
    )

    return stats_fig

@app.callback(
    Output('duration-value', 'children'),
    [Input('duration-slider', 'value')]
)
def update_duration_display(value):
    """Show the selected simulation duration as "<value> days"."""
    return "{} days".format(value)

@app.callback(
    Output('time-step-value', 'children'),
    [Input('time-step-slider', 'value')]
)
def update_time_step_display(value):
    """Show the selected integration step as "Time step: <value>"."""
    return "Time step: {}".format(value)
    
@app.callback(
    Output('infection-source-value', 'children'),
    [Input('input-infection-source', 'value')]
)
def update_infection_source_display(value):
    """Echo the infection-source input as "Infection source: <value>"."""
    return "Infection source: {}".format(value)

@app.callback(
    Output('vaccination-rate-value', 'children'),
    [Input('input-vaccination-rate', 'value')]
)
def update_vaccination_rate_display(value):
    """Echo the vaccination-rate input with two decimals."""
    return "Vaccination rate: " + format(value, ".2f")
    
@app.callback(
    [Output('alpha', 'value'), Output('beta', 'value'),
     Output('kappa', 'value'),
     Output('alpha-value', 'children'), Output('beta-value', 'children'),
     Output('kappa-value', 'children')],
    [Input('preset-1', 'n_clicks'), Input('preset-2', 'n_clicks'),
     Input('alpha', 'value'), Input('beta', 'value'),
     Input('kappa', 'value')],
    prevent_initial_call=True
)
def update_values_and_displays(preset1, preset2, alpha, beta, kappa):
    """Apply a parameter preset, or echo the current slider values.

    A preset button overwrites alpha/beta/kappa and their labels. Any other
    trigger (a slider move) returns the sliders unchanged plus their values
    formatted to two decimals.
    """
    preset_outputs = {
        'preset-1': (0.35, 0.05, 0.03, "0.35", "0.05", "0.03"),
        'preset-2': (0.45, 0.08, 0.06, "0.45", "0.08", "0.06"),
    }

    ctx = callback_context
    triggered_id = ctx.triggered[0]['prop_id'].split('.')[0] if ctx.triggered else 'No clicks yet'

    if triggered_id in preset_outputs:
        return preset_outputs[triggered_id]
    return alpha, beta, kappa, f"{alpha:.2f}", f"{beta:.2f}", f"{kappa:.2f}"
        
# Page layout: a fluid container with a title row and one row of three
# columns -- (1) model and dynamic parameter sliders, (2) the location entry
# form, (3) result tabs plus the Run/Abort/Preset buttons.
app.layout = dbc.Container(
    [
        # Client-side stores shared between callbacks; the callbacks that
        # read/write most of them are defined outside this chunk.
        dcc.Store(id='abort-flag-store', data={'is_aborted': False}),
        # storage_type='session' keeps the location list in browser session
        # storage, so it survives a page reload in the same tab.
        dcc.Store(id='location-data-store', storage_type='session', data=[]),
        dcc.Store(id='upload-key-store', data={'key': 'upload-data-1'}),
        dcc.Store(id='reset-upload-store', data=0),
        dcc.Store(id='simulation-trigger', data={'version': 0}),
        # Title banner.
        dbc.Row(
            dbc.Col(
                html.H4(
                    "Simufection (Computational Engineering & Data Analytics)",
                    className="text-white bg-primary p-2 mb-2 text-center"
                ),
                width=12
            )
        ),
        dbc.Row(
            [
                # Column 1: epidemic parameters and run settings.
                dbc.Col(
                    [
                        dbc.Card(
                            [
                                html.H5("Input Parameters", className="card-title"),
                                # α, β and κ sliders all span 0-1 in 0.01 steps.
                                html.Div(
                                    [
                                        html.Label("α (Infection Rate)", className="mb-1"),
                                        dcc.Slider(
                                            id='alpha',
                                            min=0,
                                            max=1,
                                            step=0.01,
                                            value=0.42,
                                             marks={0: '0', 0.25: '0.25', 0.5: '0.5', 0.75: '0.75', 1: '1'},
                                        ),
                                        # Initial caption; update_values_and_displays rewrites it
                                        # after a slider change or preset click.
                                        html.Div(id='alpha-value', children="α: 0.42"),
                                    ],
                                    className="mb-4"
                                ),
                                html.Div(
                                    [
                                        html.Label("β (Recovery Rate)", className="mb-1"),
                                        dcc.Slider(
                                            id='beta',
                                            min=0,
                                            max=1,
                                            step=0.01,
                                            value=0.07,
                                             marks={0: '0', 0.25: '0.25', 0.5: '0.5', 0.75: '0.75', 1: '1'},
                                        ),
                                        html.Div(id='beta-value', children="β: 0.07"),
                                    ],
                                    className="mb-4"
                                ),
                                html.Div(
                                    [
                                        html.Label("κ (Mortality Rate)", className="mb-1"),
                                        dcc.Slider(
                                            id='kappa',
                                            min=0,
                                            max=1,
                                            step=0.01,
                                            value=0.05,
                                             marks={0: '0', 0.25: '0.25', 0.5: '0.5', 0.75: '0.75', 1: '1'},
                                        ),
                                        html.Div(id='kappa-value', children="κ: 0.05"),
                                    ],
                                    className="mb-4"
                                ),
                            ],
                            body=True,
                        ),
                        dbc.Card(
                            [
                                html.H5("Dynamic Parameters", className="card-title"),
                                html.Div(
                                    [
                                        html.Label("Duration (days)", className="mb-1"),
                                        dcc.Slider(
                                            id='duration-slider',
                                            min=0,
                                            max=2000,
                                            step=1,
                                            value=350,
                                            marks={i: str(i) for i in range(0, 2001, 500)},
                                        ),
                                        # NOTE(review): keep this initial caption in the same
                                        # format as the text returned by update_duration_display.
                                        html.Div(id='duration-value', children="Duration: 350 days"),
                                    ],
                                    className="mb-4"
                                ),
                                html.Div(
                                    [
                                        html.Label("Time Step", className="mb-1"),
                                        # NOTE(review): min=0 allows a zero time step; confirm the
                                        # simulation code rejects or guards against dt == 0.
                                        dcc.Slider(
                                            id='time-step-slider',
                                            min=0,
                                            max=1,
                                            step=0.001,
                                            value=0.5,
                                             marks={0: '0', 0.25: '0.25', 0.5: '0.5', 0.75: '0.75', 1: '1'},
                                        ),
                                        # Caption updated by update_time_step_display.
                                        html.Div(id='time-step-value', children="Time step: 0.5"),
                                    ],
                                    className="mb-4"
                                ),
                            ],
                            body=True,
                            className="mt-4"
                        ),
                    ],
                    width=3
                ),
                # Column 2: form for entering a single location.
                dbc.Col(
                    [
                        dbc.Card(
                            [
                                html.H5("Locations", className="card-title"),
                                dbc.InputGroup(
                                    [
                                        dbc.InputGroupText("Name"),
                                        dbc.Input(id="input-location-name", placeholder="Enter a name", type="text"),
                                    ],
                                    className="mb-3"
                                ),
                                dbc.InputGroup(
                                    [
                                        dbc.InputGroupText("X"),
                                        dbc.Input(id="input-x-coordinate", placeholder="Enter X-coords", type="number"),
                                    ],
                                    className="mb-3"
                                ),
                                dbc.InputGroup(
                                    [
                                        dbc.InputGroupText("Y"),
                                        dbc.Input(id="input-y-coordinate", placeholder="Enter Y-coords", type="number"),
                                    ],
                                    className="mb-3"
                                ),
                                dbc.InputGroup(
                                    [
                                        dbc.InputGroupText("Population"),
                                        dbc.Input(id="input-population", placeholder="Population", type="number"),
                                    ],
                                    className="mb-3"
                                ),
                                # Boolean mobility flag for the location being entered.
                                dbc.Row(
                                    [
                                        dbc.Col(dbc.Switch(id="input-mobility-toggle", label="Mobility", value=False), width=10),
                                    ],
                                    className="mb-3"
                                ),
                                html.Div(
                                    [
                                        html.Label("Infection Source", className="mb-1"),
                                        dcc.Slider(
                                            id='input-infection-source',
                                            min=0,
                                            max=100,
                                            step=1,
                                            value=50,
                                            marks={i: str(i) for i in range(0, 101, 25)},
                                        ),
                                        html.Div(id='infection-source-value', children="Infection source: 50"),
                                    ],
                                    className="mb-4"
                                ),
                                html.Div(
                                    [
                                        html.Label("Vaccination Rate", className="mb-1"),
                                        dcc.Slider(
                                            id='input-vaccination-rate',
                                            min=0,
                                            max=1,
                                            step=0.01,
                                            value=0.5,
                                             marks={0: '0', 0.25: '0.25', 0.5: '0.5', 0.75: '0.75', 1: '1'},
                                        ),
                                        html.Div(id='vaccination-rate-value', children="Vaccination rate: 0.5"),
                                    ],
                                    className="mb-4"
                                ),
                                # File upload; multiple=True lets the user select several files at once.
                                dbc.Row(
                                    [
                                        dbc.Col(
                                            html.Div(id='upload-data-container', children=[
                                                dcc.Upload(
                                                    id='upload-data',
                                                    children=html.Button('Upload Data', id='upload-button', className='btn btn-info', n_clicks=0),
                                                    style={'display': 'block'}, 
                                                    multiple=True
                                                )
                                            ]),
                                            width=12
                                        ),
                                    ],
                                    className="mb-4"
                                ),
                                # Location list actions; their handlers are defined outside this chunk.
                                dbc.Row(
                                    [
                                        dbc.Col(html.Button('Add Location', id='add-location-btn', className='btn btn-primary', n_clicks=0), width=4),
                                        dbc.Col(html.Button('Clear Selected', id='clear-selected-btn', className='btn btn-warning', n_clicks=0), width=4),
                                        dbc.Col(html.Button('Clear All', id='clear-locations-btn', className='btn btn-danger', n_clicks=0), width=4),
                                    ],
                                    className="mb-3"
                                ),
                                html.Div(id='status-message', className="text-center"),
                            ],
                            body=True
                        ),
                    ],
                    width=3
                ),
                # Column 3: result tabs and simulation controls.
                dbc.Col(
                    [
                        dbc.Tabs(
                            [
                                # Editable table of entered locations; 'name' is the only
                                # read-only column and rows can be deleted in place.
                                dbc.Tab(label='Locations of Interest', children=[
                                    html.Div(id='locations-content', children=[
                                        dash_table.DataTable(
                                            id='locations-table',
                                            columns=[
                                                {'name': 'Name', 'id': 'name', 'editable': False},
                                                {'name': 'X-coords', 'id': 'x', 'type': 'numeric', 'editable': True},
                                                {'name': 'Y-coords', 'id': 'y', 'type': 'numeric', 'editable': True},
                                                {'name': 'Population', 'id': 'population', 'type': 'numeric', 'editable': True},
                                                {'name': 'Infection Source', 'id': 'infection_source', 'type': 'numeric', 'editable': True},
                                                {'name': 'Vaccination Rate', 'id': 'vaccination_rate', 'type': 'numeric', 'editable': True},
                                                {'name': 'Mobility', 'id': 'mobility', 'editable': True}
                                            ],
                                            editable=True,
                                            row_deletable=True,
                                            style_table={'overflowX': 'auto'},
                                            style_header={
                                                'backgroundColor': '#222831',
                                                'color': '#EEEEEE',
                                                'fontWeight': 'bold',
                                                'borderBottom': '3px solid #00ADB5',
                                                'textAlign': 'center',
                                            },
                                            style_cell={
                                                'backgroundColor': '#393E46',
                                                'color': '#EEEEEE',
                                                'border': '1px solid #222831',
                                                'textAlign': 'center',  
                                                'padding': '10px',
                                            },
                                            # Zebra striping: odd rows use the darker header colour.
                                            style_data_conditional=[
                                                {
                                                    'if': {'row_index': 'odd'},
                                                    'backgroundColor': '#222831'
                                                }
                                            ],
                                            style_as_list_view=True,
                                            style_filter={
                                                'backgroundColor': '#393E46',
                                                'border': '1px solid #00ADB5',
                                            },
                                            style_data={
                                                'whiteSpace': 'normal',
                                                'height': 'auto',
                                            },
                                        )
                                    ])
                                ], tab_id='tab-locations'),
                                # Graph tabs are wrapped in dcc.Loading so a spinner shows
                                # while their figures are being produced.
                                dbc.Tab(label='Nodal Analysis', children=[
                                    dcc.Loading(
                                        id="loading-stats",
                                        type="cube",
                                        children=dcc.Graph(
                                            id='statistics-plot',
                                        )
                                    )
                                ], tab_id='tab-stats'),
                               dbc.Tab(label='Map', children=[
                                    dcc.Loading(
                                        id="loading-2d-map",
                                        type="cube",
                                        children=dcc.Graph(
                                            id='map-plot', 
                                        )
                                    )
                                ], tab_id='tab-map'),

                                dbc.Tab(label='3-D Map', children=[
                                    dcc.Loading(
                                        id="loading-3d-map",
                                        type="cube",
                                        children=dcc.Graph(
                                            id='3d-map-plot', 
                                        )
                                    )
                                ], tab_id='tab-3d-map'),
                                # Chat tab: handle_query only answers when this tab
                                # ('chat-with-simulator') is the active one.
                                dbc.Tab(label="Chat with Sim", children=[
                                    html.Div([
                                        dcc.Textarea(
                                            id='user-query', 
                                            placeholder='Ask a question about the simulation results...',
                                            style={
                                                'width': '100%',
                                                'height': '150px',
                                                'color': '#fff',
                                                'backgroundColor': '#2C3E50',
                                                'border': '1px solid #00ADB5',
                                                'borderRadius': '5px',
                                                'padding': '10px',  
                                                'marginBottom': '10px'  
                                            }
                                        ),
                                        html.Button(
                                            'Submit Query', 
                                            id='submit-query', 
                                            n_clicks=0,
                                            className='btn btn-primary btn-lg',
                                            style={
                                                'width': '100%',
                                                'color': '#fff',
                                                'backgroundColor': '#17a2b8',
                                                'border': 'none',
                                                'marginBottom': '10px'
                                            }
                                        ),
                                        # Filled with the model's answer by handle_query.
                                        html.Div(id='openai-response', style={
                                            'color': '#fff',  
                                            'backgroundColor': '#394E60',  
                                            'padding': '20px',  
                                            'borderRadius': '5px', 
                                            'border': '1px solid #00ADB5'  
                                        })
                                    ], style={'padding': '20px'})  
                                ], tab_id='chat-with-simulator'),
                            ],
                            id='tabs',
                            active_tab='tab-locations'
                        ),
                        # Simulation controls; preset-1/preset-2 are inputs of
                        # update_values_and_displays.
                        dbc.Card(
                            dbc.CardBody(
                                [
                                    dbc.Row(
                                        [
                                            dbc.Col(html.Button('Run Sim', id='run-button', n_clicks=0, className='btn btn-primary'), width=3),
                                            dbc.Col(html.Button('Abort Sim', id='abort-button', n_clicks=0, className='btn btn-danger'), width=3),
                                            dbc.Col(html.Button('Preset 1', id='preset-1', n_clicks=0, className='btn btn-dark'), width=3),
                                            dbc.Col(html.Button('Preset 2', id='preset-2', n_clicks=0, className='btn btn-dark'), width=3),
                                        ],
                                        className='mt-4'
                                    ),
                                ]
                            ),
                            className='mt-4'
                        ),
                    ],
                    md=6
                ),
            ]
        ),
    ],
    fluid=True,
    className="dbc"
)

def advanced_filter_data_based_on_query(query, path='output_with_descriptions.xlsx'):
    """Select the rows of the simulation-description spreadsheet relevant to *query*.

    The query is lower-cased, tokenised with NLTK, and stripped of English
    stopwords and pure-punctuation tokens. A row is kept when either:

    * its ``Node`` contains any remaining query word (case-insensitive), or
    * the query mentions a trend word (rising/increasing, declining/decreasing,
      improving, stable) and the row's ``Description`` mentions that trend.

    Args:
        query: Free-text user question.
        path: Spreadsheet with at least ``Node`` and ``Description`` columns
            (default keeps the original hard-coded file name).

    Returns:
        The matching rows as a DataFrame, or the string
        "No data matching your query was found." when nothing matches.
    """
    import pandas as pd
    import re
    from nltk.corpus import stopwords
    from nltk.tokenize import word_tokenize

    df_descriptions = pd.read_excel(path)
    stop_words = set(stopwords.words('english'))
    words = word_tokenize(query.lower())
    # Drop stopwords AND punctuation-only tokens: word_tokenize emits ',' and
    # '?' as tokens, and ',' alone would match every "City, ST" node name.
    filtered_words = [
        word for word in words
        if word not in stop_words and any(ch.isalnum() for ch in word)
    ]

    filter_mask = pd.Series(False, index=df_descriptions.index)

    # Guard: an empty alternation regex ('') matches every row, which would
    # turn an all-stopword query into "return the whole dataset".
    if filtered_words:
        node_regex = '|'.join(map(re.escape, filtered_words))
        filter_mask |= df_descriptions['Node'].str.contains(
            node_regex, case=False, regex=True, na=False
        )

    # Trend word -> Description pattern. 'increasing' was previously missing
    # from the recognised words, so that synonym never matched.
    condition_patterns = {
        'rising': 'rising|increasing',
        'increasing': 'rising|increasing',
        'declining': 'declining|decreasing',
        'decreasing': 'declining|decreasing',
        'improving': 'improving',
        'stable': 'stable',
    }
    matched_patterns = {condition_patterns[w] for w in filtered_words if w in condition_patterns}
    for pattern in matched_patterns:
        filter_mask |= df_descriptions['Description'].str.contains(pattern, case=False, na=False)

    filtered_df = df_descriptions[filter_mask]
    if filtered_df.empty:
        return "No data matching your query was found."
    return filtered_df


def convert_data_to_text(data_df):
    """Flatten filtered simulation rows into a single prompt string.

    A string argument (the "no data" message produced by the filter step) is
    passed straight through. Otherwise every row becomes one sentence built
    from its Time, Node, Susceptible, Infected, Recovered and Description
    columns, and the sentences are joined with single spaces.
    """
    if isinstance(data_df, str):
        return data_df

    sentences = [
        f"At time {rec['Time']}, in node {rec['Node']}, there were {rec['Susceptible']} susceptible, {rec['Infected']} infected, and {rec['Recovered']} recovered. Description: {rec['Description']}"
        for _, rec in data_df.iterrows()
    ]
    return " ".join(sentences)

import openai

def ask_openai(query):
    """Answer *query* with the chat model, using filtered simulation data as context.

    The relevant rows are selected by advanced_filter_data_based_on_query and
    rendered to text by convert_data_to_text. That text and the query are sent
    as two user messages after a fixed system prompt.

    Args:
        query: Free-text user question.

    Returns:
        The model's reply, stripped of surrounding whitespace ("" if the model
        returned no text content).
    """
    filtered_df = advanced_filter_data_based_on_query(query)
    text_for_ai = convert_data_to_text(filtered_df)
    print("Text for AI:", text_for_ai)

    messages = [
        {"role": "system", "content": "You are an AI that assists with understanding simulation data. Answer questions based on the data."},
        {"role": "user", "content": text_for_ai},
        {"role": "user", "content": query}
    ]

    # NOTE(review): `client` is a module-level OpenAI client created outside
    # this chunk -- confirm it is initialised before this is called.
    response = client.chat.completions.create(
        model="gpt-4",
        messages=messages,
        max_tokens=150,
        temperature=0.7
    )

    # openai>=1.0 (`client.chat.completions.create`) returns a ChatCompletion
    # object rather than a dict, so `response['choices']` raised TypeError.
    # `message.content` may also be None (e.g. tool-call replies).
    content = response.choices[0].message.content
    return (content or "").strip()

@app.callback(
    Output('openai-response', 'children'),
    [Input('submit-query', 'n_clicks')],
    [State('user-query', 'value'), State('tabs', 'active_tab')]
)
def handle_query(n_clicks, query, active_tab):
    """Answer the chat-tab question when "Submit Query" is clicked.

    The update is skipped unless the chat tab is active, the button has been
    clicked at least once, and the query contains non-whitespace text.

    Args:
        n_clicks: Click count of ``submit-query`` (None is treated as 0).
        query: Text of the ``user-query`` textarea.
        active_tab: Id of the currently selected tab.

    Returns:
        The model's answer, or a short error message if the request failed.
    """
    # `not n_clicks` covers both 0 and None (`None < 1` raises TypeError);
    # a whitespace-only query is not worth a paid API call.
    if active_tab != 'chat-with-simulator' or not n_clicks or not query or not query.strip():
        raise PreventUpdate

    try:
        return ask_openai(query)
    except Exception as exc:  # UI boundary: surface the failure instead of a blank panel
        print(f"OpenAI query failed: {exc!r}")
        return f"Sorry, the query could not be answered: {exc}"

if __name__ == '__main__':
    # `Dash.run` is the current entry point; `run_server` is deprecated and
    # was removed in Dash 3. Fall back to it only on old Dash versions.
    runner = getattr(app, 'run', None) or app.run_server
    runner(debug=True)

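[REDACTED: an OpenAI API key was printed in this cell output. Revoke that key and load keys from the OPENAI_API_KEY environment variable instead of hard-coding or printing them.]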


[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\a2nem\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\a2nem\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Using default mobility values
Map figure setup with frames: 140
3-D Map figure setup with frames: 140
Stat figure setup with frames: 141
Using new mobility values
Map figure setup with frames: 140
3-D Map figure setup with frames: 140
Stat figure setup with frames: 141
Text for AI: At time 0.0, in node Boston, MA, there were 20000.0 susceptible, 0.0 infected, and 0.0 recovered. Description: Initial data point for node. At time 0.5, in node Boston, MA, there were 19948.42164435532 susceptible, 50.0 infected, and 0.0 recovered. Description: Infection rising sharply. Susceptible population decreasing. At time 1.0, in node Boston, MA, there were 19886.41981312326 susceptible, 108.2171724665536 infected, and 0.8750000000000001 recovered. Description: Infection rising sharply. Recovery numbers improving. Susceptible population decreasing. At time 1.5, in node Boston, MA, there were 19814.84916050824 susceptible, 173.4358198742998 infected, and 2.761135491232705 recovered. Description: Infect

In [50]:
import json
from urllib.request import urlopen

import pandas as pd
import plotly.graph_objects as go

# County outlines as GeoJSON; feature ids are matched against `df.fips` below.
with urlopen('https://raw.githubusercontent.com/plotly/datasets/master/geojson-counties-fips.json') as response:
    counties = json.load(response)

# 2016 unemployment by county. FIPS codes are read as strings, not integers,
# so they compare equal to the GeoJSON ids.
df = pd.read_csv(
    "https://raw.githubusercontent.com/plotly/datasets/master/fips-unemp-16.csv",
    dtype={"fips": str},
)

trace = go.Choroplethmapbox(
    geojson=counties,
    locations=df.fips,
    z=df.unemp,
    colorscale="Viridis",
    zmin=0,
    zmax=12,
    marker_opacity=0.5,
    marker_line_width=0,
)
fig = go.Figure(trace)
# Centre on the continental US and drop the default margins.
fig.update_layout(
    mapbox_style="carto-positron",
    mapbox_zoom=3,
    mapbox_center={"lat": 37.0902, "lon": -95.7129},
    margin={"r": 0, "t": 0, "l": 0, "b": 0},
)
# NOTE: inline rendering needs nbformat>=4.2.0 installed; without it
# fig.show() raises "Mime type rendering requires nbformat>=4.2.0".
fig.show()

ValueError: Mime type rendering requires nbformat>=4.2.0 but it is not installed

In [5]:
import pandas as pd
import numpy as np

# Reference table of 50 large US cities. The four lists are parallel: entry i
# of every list describes the same city, so they must stay the same length and
# order. Coordinates are decimal degrees (negative longitude = west).
cities_data = {
    "City": ["New York", "Los Angeles", "Chicago", "Houston", "Phoenix", "Philadelphia", "San Antonio", "San Diego", 
             "Dallas", "San Jose", "Austin", "Jacksonville", "Fort Worth", "Columbus", "Charlotte", "San Francisco", 
             "Indianapolis", "Seattle", "Denver", "Washington", "Boston", "El Paso", "Detroit", "Nashville", 
             "Portland", "Memphis", "Oklahoma City", "Las Vegas", "Louisville", "Baltimore", "Milwaukee", 
             "Albuquerque", "Tucson", "Fresno", "Mesa", "Sacramento", "Atlanta", "Kansas City", 
             "Colorado Springs", "Miami", "Raleigh", "Omaha", "Long Beach", "Virginia Beach", "Oakland", 
             "Minneapolis", "Tulsa", "Arlington", "Tampa", "New Orleans"],
    # Two-letter state codes ("DC" for Washington).
    "State": ["NY", "CA", "IL", "TX", "AZ", "PA", "TX", "CA", "TX", "CA", "TX", "FL", "TX", "OH", "NC", "CA", 
              "IN", "WA", "CO", "DC", "MA", "TX", "MI", "TN", "OR", "TN", "OK", "NV", "KY", "MD", "WI", 
              "NM", "AZ", "CA", "AZ", "CA", "GA", "MO", "CO", "FL", "NC", "NE", "CA", "VA", "CA", "MN", 
              "OK", "TX", "FL", "LA"],
    "Latitude": [40.7128, 34.0522, 41.8781, 29.7604, 33.4484, 39.9526, 29.4241, 32.7157, 32.7767, 37.3382, 
                 30.2672, 30.3322, 32.7555, 39.9612, 35.2271, 37.7749, 39.7684, 47.6062, 39.7392, 38.9072, 
                 42.3601, 31.7619, 42.3314, 36.1627, 45.5152, 35.1495, 35.4676, 36.1699, 38.2527, 39.2904, 
                 43.0389, 35.0844, 32.2226, 36.7378, 33.4223, 38.5816, 33.7490, 39.0997, 38.8339, 25.7617, 
                 35.7796, 41.2565, 33.7701, 36.8529, 37.8044, 44.9778, 36.1539, 32.7357, 27.9506, 29.9511],
    "Longitude": [-74.0060, -118.2437, -87.6298, -95.3698, -112.0740, -75.1652, -98.4936, -117.1611, -96.7970, 
                  -121.8863, -97.7431, -81.6557, -97.3308, -82.9988, -80.8431, -122.4194, -86.1581, -122.3321, 
                  -104.9903, -77.0369, -71.0589, -106.4850, -83.0458, -86.7816, -122.6765, -90.0489, -97.5164, 
                  -115.1398, -85.7585, -76.6122, -87.9065, -106.6504, -110.9747, -119.7871, -111.8315, -121.4944, 
                  -84.3880, -94.5786, -104.8214, -80.1918, -78.6382, -95.9345, -118.1892, -75.9780, -122.2711, 
                  -93.2650, -95.9928, -97.1081, -82.4572, -90.0715]
}

# One row per city with columns City, State, Latitude, Longitude.
us_cities_df = pd.DataFrame(cities_data)

def generate_us_place_names(us_cities_df, n):
    """Return *n* distinct, randomly chosen rows of *us_cities_df*.

    Sampling is without replacement and uses NumPy's global RNG, so seed it
    for reproducible picks. NumPy raises ValueError when *n* exceeds the
    number of rows.
    """
    return us_cities_df.loc[
        np.random.choice(us_cities_df.index, size=n, replace=False)
    ]


# Build a random synthetic scenario of `size` locations and save it to Excel.
size = 20
# Seed NumPy's global RNG so the city choice and every attribute drawn below
# are reproducible. The draws happen in a fixed order, so reordering these
# statements changes the generated data.
np.random.seed(0)  

selected_cities = generate_us_place_names(us_cities_df, size)

# randint's upper bound is exclusive: populations fall in [10000, 19999] and
# infection sources in [10, 99].
populations = np.random.randint(10000, 20000, size=size)
infection_sources = np.random.randint(10, 100, size=size)
# Uniform in [0, 0.2), rounded to two decimals.
vaccination_rates = np.round(np.random.uniform(0, 0.2, size=size), 2)
mobility = np.random.choice(['Yes', 'No'], size=size)

# NOTE(review): 'X Coordinate' holds latitude and 'Y Coordinate' holds
# longitude (the reverse of the usual x = longitude convention); readers of
# the spreadsheet must interpret them that way.
data = {
    'Place': selected_cities['City'] + ", " + selected_cities['State'],
    'X Coordinate': selected_cities['Latitude'],
    'Y Coordinate': selected_cities['Longitude'],
    'Population': populations,
    'Infection Source': infection_sources,
    'Vaccination Rate': vaccination_rates,
    'Mobility': mobility
}

df = pd.DataFrame(data)
filename = 'random_us_places.xlsx'

# Write without the DataFrame index column.
df.to_excel(filename, index=False)

# Bare expression so the notebook displays the output file name.
filename

'random_us_places.xlsx'

Using default mobility values
{'mode': 'exp', 'coords': array([[  29.7604,  -95.3698],
       [  33.4484, -112.074 ],
       [  39.7392, -104.9903],
       [  32.7157, -117.1611]]), 'N': 4, 'P': array([[1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.]]), 'M': array([[0.7, 0.3, 1. ],
       [0.7, 0.3, 1. ],
       [0.7, 0.3, 1. ],
       [0.7, 0.3, 1. ]]), 'alpha': array([0.25, 0.25, 0.25, 0.25]), 'beta': array([0.07, 0.07, 0.07, 0.07]), 'kappa': array([0., 0., 0., 0.]), 'mu': array([[1.e-05, 1.e-05, 1.e-05],
       [1.e-05, 1.e-05, 1.e-05],
       [1.e-05, 1.e-05, 1.e-05],
       [1.e-05, 1.e-05, 1.e-05]]), 'gamma': array([0.1, 0.1, 0.1, 0.1]), 'zeta': array([0.1, 0.1, 0.1, 0.1]), 'u0': array([0, 0, 0, 0]), 'q0': array([10,  0,  0,  0]), 'v': array([1, 1, 1, 1]), 'S0': array([23592, 10770, 25875, 18286])}
[[[2.35920000e+04 0.00000000e+00 0.00000000e+00]
  [1.07700000e+04 0.00000000e+00 0.00000000e+00]
  [2.58750000e+04 0.00000000e+00 0.00000000

In [3]:
pip install dash dash-bootstrap-components plotly numpy scipy scikit-learn dash_ag_grid 


Collecting dash_ag_grid
  Downloading dash_ag_grid-31.0.1-py3-none-any.whl.metadata (4.4 kB)
Downloading dash_ag_grid-31.0.1-py3-none-any.whl (4.9 MB)
   ---------------------------------------- 0.0/4.9 MB ? eta -:--:--
   ---------------------------------------- 0.0/4.9 MB 682.7 kB/s eta 0:00:08
   - -------------------------------------- 0.1/4.9 MB 1.6 MB/s eta 0:00:04
   ---- ----------------------------------- 0.5/4.9 MB 4.3 MB/s eta 0:00:02
   --------- ------------------------------ 1.1/4.9 MB 6.5 MB/s eta 0:00:01
   ---------------- ----------------------- 2.0/4.9 MB 9.1 MB/s eta 0:00:01
   ------------------------- -------------- 3.1/4.9 MB 11.8 MB/s eta 0:00:01
   ---------------------------------------  4.8/4.9 MB 15.4 MB/s eta 0:00:01
   ---------------------------------------  4.9/4.9 MB 15.6 MB/s eta 0:00:01
   ---------------------------------------- 4.9/4.9 MB 13.5 MB/s eta 0:00:00
Installing collected packages: dash_ag_grid
Successfully installed dash_ag_grid-31.0.1



[notice] A new release of pip is available: 23.3.1 -> 24.0
[notice] To update, run: python.exe -m pip install --upgrade pip


In [7]:
import dash
from dash import dcc, html
from dash.dependencies import Input, Output, State
import plotly.graph_objs as go
import numpy as np
import dash_bootstrap_components as dbc
import json

# Initialize parameters and settings
# Fixed seed so the randomly drawn populations (e.g. p0['S0']) are reproducible.
np.random.seed(42)

# Page-wide font override for the Dash app.
css_style = """
body {
    font-family: Arial, sans-serif !important;
}
"""

# `external_stylesheets` entries are rendered as <link rel="stylesheet" href=...>
# tags, so they must be URLs; raw CSS text placed there is never applied.
# Load only the theme URL and inline the custom rule into the page template.
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.CYBORG])
app.index_string = app.index_string.replace(
    '</head>', '<style>' + css_style + '</style>\n    </head>', 1
)

@app.callback(
    [Output('preset_id', 'data')],
    [Input('preset-1', 'n_clicks'), Input('preset-2', 'n_clicks'), Input('preset-3', 'n_clicks')]
)
def update_presets(preset1_clicks, preset2_clicks, preset3_clicks):
    """Store which preset button fired most recently.

    The click counts themselves are ignored; only the triggering component id
    matters. Returns a one-element list (one entry per Output) holding ``[k]``:
    k is 1, 2 or 3 for the preset button that fired, or 0 when the callback ran
    without a trigger. Any other trigger id yields ``None``.
    """
    fired = dash.callback_context.triggered
    if not fired:
        return [[0]]

    component_id = fired[0]['prop_id'].split('.')[0]
    preset_for_button = {'preset-1': 1, 'preset-2': 2, 'preset-3': 3}
    if component_id not in preset_for_button:
        return None
    return [[preset_for_button[component_id]]]

def _select_preset(preset_id):
    """Return the parameter dict for a ``preset_id`` Store value (``[k]``, k in 0..3).

    ``None``, an empty list, or an unknown k fall back to ``p0``.
    """
    key = preset_id[0] if preset_id else 0
    if key == 1:
        return p1
    if key == 2:
        return p2
    if key == 3:
        return p3
    return p0


def _compartment_sizes(state, max_size=50):
    """Return cumulative marker sizes (S, S+I, S+I+R shares) for one state snapshot.

    Column 0/1/2 of ``state`` are read as S/I/R. A node's outer size scales with
    its total relative to the largest total; the three sizes are cumulative so
    the R, I and S markers can be stacked as concentric circles. Nodes with a
    zero total get size 0 instead of NaN.
    """
    S, I, R = state[:, 0], state[:, 1], state[:, 2]
    total = S + I + R
    peak = total.max()
    if peak != 0:
        total_normalized = total / peak
    else:
        total_normalized = np.zeros(total.shape)
    # Avoid 0/0 for empty nodes; their total_normalized is already 0.
    safe_total = np.where(total != 0, total, 1)
    S_sizes = total_normalized * max_size * (S / safe_total)
    I_sizes = total_normalized * max_size * (I / safe_total) + S_sizes
    R_sizes = total_normalized * max_size * (R / safe_total) + I_sizes
    return S_sizes, I_sizes, R_sizes


def _layer_traces(coords, sizes, counts, color, name, border_name):
    """Return [black border trace, filled trace] for one compartment layer."""
    lat, lon = coords[:, 0], coords[:, 1]
    border = go.Scattermapbox(
        lat=lat,
        lon=lon,
        mode='markers',
        marker=dict(size=[s + 2 for s in sizes], color='black'),  # +2 px acts as the border width
        name=border_name,
        showlegend=False
    )
    fill = go.Scattermapbox(
        lat=lat,
        lon=lon,
        mode='markers',
        marker=dict(size=sizes, color=color, opacity=1),
        name=name,
        hoverinfo='text',
        hovertext=[f'{name}: {count:0.0f}' for count in counts]
    )
    return [border, fill]


def _edge_traces(coords, tau, n_nodes):
    """Return one yellow line trace per node pair i < j whose summed tau[i, j, :] is positive.

    Line width is 2 * flow, capped at 5.
    """
    # NOTE(review): only tau[i, j] with i < j is inspected; if tau is not
    # symmetric, flow recorded in tau[j, i] is never drawn — confirm intent.
    traces = []
    for i in range(n_nodes):
        for j in range(i + 1, n_nodes):
            flow = np.sum(tau[i, j, :])
            if flow > 0:
                traces.append(go.Scattermapbox(
                    lat=[coords[i, 0], coords[j, 0]],
                    lon=[coords[i, 1], coords[j, 1]],
                    mode='lines',
                    line=dict(width=min(2 * flow, 5), color='yellow'),
                    hoverinfo='none',
                    showlegend=False
                ))
    return traces


@app.callback(
    Output('simulation-plot', 'figure'),
    [Input('run-button', 'n_clicks'), Input('preset_id', 'data')],
    [State('preset_id', 'data')]
)
def update_simulation(n_clicks, _, preset_id):
    """Run the simulator for the selected preset and build the animated map figure.

    Args:
        n_clicks: click count of ``run-button``; ``None`` skips the update.
        _: the ``preset_id`` Store value delivered as an Input so that choosing
            a preset re-runs the simulation (same value as ``preset_id``).
        preset_id: the ``preset_id`` Store value, ``[k]`` with k in 0..3.

    Returns:
        A ``go.Figure`` with one animation frame per 10 simulator steps, a
        Play/Pause button pair and a time slider.

    Raises:
        dash.exceptions.PreventUpdate: if ``n_clicks`` is ``None``.
    """
    if n_clicks is None:
        raise dash.exceptions.PreventUpdate

    p = _select_preset(preset_id)

    # Assume the Simulator class and forward_euler method are defined elsewhere
    sim = Simulator(p)
    tspan = (0, 200)
    dt = 0.1
    timesteps, state_trajectories, tau_values = sim.forward_euler(tspan, dt, return_tau=True)

    coords = p['coords']
    frames = []
    for idx, (_t, state, tau) in enumerate(zip(timesteps, state_trajectories, tau_values)):
        if idx % 10 != 0:  # keep every 10th step to limit the number of frames
            continue

        S_sizes, I_sizes, R_sizes = _compartment_sizes(state)
        # Largest (R) markers first so the smaller I and S markers draw on top.
        frame_traces = (
            _layer_traces(coords, R_sizes, state[:, 2], 'green', 'Recovered', 'Recovered_border')
            + _layer_traces(coords, I_sizes, state[:, 1], '#E62020', 'Infected', 'I_border')
            + _layer_traces(coords, S_sizes, state[:, 0], 'orange', 'Susceptible', 'Sus_border')
            + _edge_traces(coords, tau, p['N'])
        )
        frames.append(go.Frame(data=frame_traces, name=str(idx)))

    # Frame names are step indices; the label converts them to simulated time.
    slider_steps = [
        {
            "args": [
                [frame.name],
                {"frame": {"duration": 100, "redraw": True}, "mode": "immediate", "transition": {"duration": 100}}
            ],
            "label": f"{int(dt*float(frame.name))}",
            "method": "animate"
        }
        for frame in frames
    ]

    fig = go.Figure()
    fig.update_layout(
        sliders=[{
            "active": 0,
            "yanchor": "top",
            "xanchor": "left",
            "currentvalue": {
                "font": {"size": 16, "color": "white"},
                "prefix": "Time: ",
                "visible": True,
                "xanchor": "right"
            },
            "transition": {"duration": 300, "easing": "cubic-in-out"},
            "pad": {"b": 10, "t": 10},
            "len": 0.9,
            "x": 0.1,
            "y": 0,
            "steps": slider_steps,
            "bgcolor": "#000000",  # Black background for slider
            "bordercolor": "#666666",
            "borderwidth": 1,
            "font": {"color": "white"}
        }],
        mapbox_style="carto-darkmatter",
        mapbox=dict(center=dict(lat=37.04358, lon=-97.23412), zoom=3),
        updatemenus=[{
            'type': 'buttons',
            'buttons': [
                {
                    'label': 'Play',
                    'method': 'animate',
                    'args': [None, {'frame': {'duration': 100, 'redraw': True}, 'fromcurrent': True}],
                },
                {
                    'label': 'Pause',
                    'method': 'animate',
                    'args': [[None], {'frame': {'duration': 0, 'redraw': False}, 'mode': 'immediate'}],
                }
            ],
            'direction': 'left',
            'pad': {'r': 0, 't': 87},
            'showactive': False,
            'x': 0.1,
            'xanchor': 'right',
            'y': 0,
            'yanchor': 'top',
            'bgcolor': '#000000'  # Black background for buttons
        }],
        margin={'l': 0, 't': 0, 'b': 0, 'r': 0},
        legend=dict(x=0, y=1, orientation='h', bgcolor='rgba(0,0,0,0.7)'),
        font=dict(family="Arial, sans-serif", size=12, color="white"),
        paper_bgcolor='black',
        plot_bgcolor='black'
    )
    fig.frames = frames
    # NOTE(review): the number of edge traces varies between frames while the
    # figure's traces are seeded from frame 0 only; plotly animates traces by
    # index, so edges appearing only in later frames may not render — confirm.
    if frames:
        fig.add_traces(frames[0].data)
    return fig


# Page layout: a title bar, then a narrow control column (preset Store plus
# run/preset buttons) beside a wide column holding the animated map.
app.layout = dbc.Container(
    [
        dbc.Row(
            dbc.Col(
                html.H4("Simufection (6.7300)", className="text-black p-2 mb-2 text-center", style={'backgroundColor': 'rgba(255, 255, 255, 0.8)'}),
                width=12
            )
        ),
        dbc.Row(
            [
                dbc.Col(
                    [
                        # TODO: per-node α, β, κ and ν text inputs were prototyped
                        # here and are currently disabled.
                        # Selected preset as [k]; written by update_presets and
                        # read by update_simulation.
                        dcc.Store(id='preset_id', data=[0]),
                        # n_clicks starts at 1 so update_simulation does not skip
                        # the initial render. `w-100` replaces `btn-block`, which
                        # Bootstrap 5 removed, to keep the buttons full width.
                        html.Button('Run Simulator', id='run-button', n_clicks=1, className='btn btn-warning w-100 my-2'),
                        html.Button('High Infection Case', id='preset-1', n_clicks=0, className='btn btn-info w-100 my-2'),
                        html.Button('Vaccination Case', id='preset-2', n_clicks=0, className='btn btn-info w-100 my-2'),
                        html.Button('Lockdown Case', id='preset-3', n_clicks=0, className='btn btn-info w-100 my-2')
                    ],
                    width=2
                ),
                dbc.Col(
                    dcc.Loading(
                        id="loading-1",
                        type="cube",
                        children=dcc.Graph(id='simulation-plot')
                    ),
                    width=10
                )
            ]
        )
    ],
    fluid=True
)

if __name__ == '__main__':
    # `run_server` is a deprecated alias of `run` in Dash 2.x and was removed in Dash 3.
    # NOTE(review): p0 (and presumably p1-p3) are assigned further down this file.
    # Run as a plain script, `app.run` blocks here, so those names would not exist
    # when the callbacks fire; consider moving this guard to the end of the file.
    app.run(debug=True)

# Node locations as (lat, lon) rows.
us_coords = np.array([
    [40.6892, -74.0445],  # Statue of Liberty, New York
    [31.9686, -99.9018], # Texas
    [28.3852, -81.5639],  # Walt Disney World Resort, Orlando, Florida
    [38.8977, -77.0365],  # The White House, Washington D.C.
    [36.1147, -115.1728], # Las Vegas Strip, Las Vegas, Nevada
    [43.8791, -103.4591], # Mount Rushmore, South Dakota
    [37.8267, -122.4233], # Alcatraz Island, San Francisco, California
    [29.9584, -90.0644]   # French Quarter, New Orleans, Louisiana
])
# Baseline ("p0") scenario parameters. Per-node arrays all have length N, which
# is derived from the coordinate list so adding/removing a node stays consistent.
p0 = {}
p0['coords'] = np.array(us_coords)
p0['N'] = us_coords.shape[0]
p0['P'] = np.ones([p0['N'], p0['N']])
p0['M'] = np.tile(np.array([0.7, 0.3, 1])[np.newaxis, :], [p0['N'], 1])
p0['S0'] = np.random.randint(10000, 20000, size=p0['N'])  # initial susceptible count per node
p0['mu'] = np.tile(0.00001 * np.array([1.00, 1.00, 1.00])[np.newaxis, :], [p0['N'], 1])
p0['gamma'] = np.repeat(0.2, p0['N'])
p0['zeta'] = np.repeat(0.2, p0['N'])
p0['u0'] = np.repeat(10, p0['N'])
# q0 is 10 at node 0 and 0 at every other node.
p0['q0'] = np.array([10] + [0] * (p0['N'] - 1))
p0['mode'] = 'exp'
p0['alpha'] = np.repeat(0.31, p0['N'])
p0['beta'] = np.repeat(0.01, p0['N'])
p0['kappa'] = np.repeat(0.00, p0['N'])
p0['v'] = np.array(p0['N'] * [0])