In [46]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
# import pywt
from pydmd import DMD
import json
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, display

class Field:
    """Scalar quantity sampled on a uniform one-dimensional grid.

    Neighbour access (`r`, `l`) is implemented with `np.roll`, so the grid
    wraps around: the right neighbour of the last point is the first point.

    :param num_points: Number of grid points (stored as given).
    :param initial_conditions: Sequence of initial values, one per grid point.
    :param dx: Grid spacing used by `ddx`.
    :param dt: Time step; stored for use by the code that advances the field.
    """

    def __init__(self, num_points, initial_conditions, dx, dt):
        self.num_points = num_points
        # Always hold a floating-point copy.  Without the explicit dtype an
        # integer input (e.g. a list of ints) would yield an integer array and
        # in-place updates such as `values -= dt * flux` would fail with a
        # casting error.
        self.values = np.array(initial_conditions, dtype=float)
        self.dx = dx
        self.dt = dt
        self.prev_values = np.copy(self.values)

    def store_prev_values(self):
        """Save an independent copy of the current values in `prev_values`."""
        self.prev_values = np.copy(self.values)

    def r(self, shift=1):
        """Return the values `shift` points to the right of each point (wrapping)."""
        return np.roll(self.values, -shift)

    def l(self, shift=1):
        """Return the values `shift` points to the left of each point (wrapping)."""
        return np.roll(self.values, shift)

    def ddx(self, direction='r'):
        """Return a one-sided first difference divided by `dx`.

        :param direction: 'l' for the backward difference (value minus left
            neighbour); any other value gives the forward difference (right
            neighbour minus value).
        """
        if direction == 'l':
            return (self.values - self.l()) / self.dx
        else:
            return (self.r() - self.values) / self.dx

    def update(self):
        """Advance the field in time; must be provided by subclasses."""
        raise NotImplementedError("Subclasses should implement this method.")

class Species:
    """Fluid description of one particle species on a periodic 1D grid.

    Holds density, velocity, temperature, potential and correlation fields
    (all `Field` objects sharing `dx` and `dt`) and advances them one explicit
    time step at a time via `update`.

    :param density_initial: Initial density values, one per grid point.
    :param vel_initial: Initial velocity values.
    :param temp_initial: Initial temperature values.
    :param dx: Grid spacing.
    :param dt: Time step.
    :param mass: Particle mass (stored; not used by the methods in this class).
    :param gamma: Factor multiplying the potential gradient in the velocity update.
    :param thermal_diffusivity: Factor multiplying the second derivative of
        the temperature in the temperature update.
    """

    def __init__(self, density_initial, vel_initial, temp_initial, dx, dt, mass=1.0, gamma=1.0, thermal_diffusivity=0.01):
        self.density = Field(len(density_initial), density_initial, dx, dt)
        self.vel = Field(len(vel_initial), vel_initial, dx, dt)
        self.temp = Field(len(temp_initial), temp_initial, dx, dt)
        self.potential = Field(len(density_initial), np.zeros_like(density_initial), dx, dt)
        self.correlations = Field(len(density_initial), np.zeros_like(density_initial), dx, dt)
        self.mass = mass
        self.gamma = gamma
        self.thermal_diffusivity = thermal_diffusivity

    def solve_poisson_periodic(self):
        """Solve the discrete periodic Poisson equation and store the potential.

        The system solved is, with indices wrapping around the grid,

            -phi[i-1] + 2*phi[i] - phi[i+1] = -4*pi*dx**2 * (n[i] - mean(n))

        The matrix is circulant, so it is diagonal in Fourier space with
        eigenvalues 2 - 2*cos(2*pi*k/N).  The k = 0 eigenvalue is zero (the
        potential is only defined up to a constant); that mode is set to zero,
        which selects the zero-mean solution.  Subtracting the mean density
        makes the right-hand side compatible with this singular system.

        The result is written to `self.potential.values`.
        """
        # NOTE(review): the sign convention is inherited from the previous
        # implementation (phi'' = +4*pi*(n - mean)); confirm it matches the
        # intended interaction (attractive vs. repulsive).
        density = np.asarray(self.density.values, dtype=float)
        n_points = density.size
        rhs = -4 * np.pi * self.density.dx**2 * (density - np.mean(density))

        rhs_hat = np.fft.rfft(rhs)
        wavenumbers = np.arange(rhs_hat.size)
        eigenvalues = 2.0 - 2.0 * np.cos(2.0 * np.pi * wavenumbers / n_points)

        phi_hat = np.zeros_like(rhs_hat)
        # Skip k = 0: its eigenvalue is zero and the mean of phi is a free gauge.
        phi_hat[1:] = rhs_hat[1:] / eigenvalues[1:]

        # Pass n explicitly so odd grid sizes are reconstructed correctly.
        self.potential.values = np.fft.irfft(phi_hat, n=n_points)

    ###########
    ## TODO: ##
    ###########
    def calculate_correlations(self):
        """
        Placeholder method for calculating correlations.
        Currently sets the correlation field to a zero array of the same shape
        as the density.
        This should be replaced with the actual correlation calculation logic.
        """
        # Zeros for now, as we don't have the actual correlation logic
        self.correlations.values = np.zeros_like(self.density.values)

    def update_density(self):
        """Advance the density one step: n -= dt * u * dn/dx (first-order upwind).

        The one-sided difference is chosen from the side the flow comes from:
        backward where u >= 0, forward where u < 0.
        """
        # NOTE(review): only u * dn/dx is applied; a conservative continuity
        # equation would also carry n * du/dx — confirm this is intended.
        flux_n = np.where(self.vel.values >= 0,
                          (self.density.values - self.density.l()) / self.density.dx,
                          (self.density.r() - self.density.values) / self.density.dx) * self.vel.values
        self.density.values -= self.density.dt * flux_n

    def update_velocity(self):
        """Advance the velocity one step.

        Uses an upwind difference for the nonlinear term u * du/dx and central
        differences for the gradients of log-density, potential and
        correlations.  Also refreshes the correlation field via
        `calculate_correlations`.
        """
        du_dx = np.where(self.vel.values >= 0,
                         (self.vel.values - self.vel.l()) / self.vel.dx,
                         (self.vel.r() - self.vel.values) / self.vel.dx) * self.vel.values
        # np.log requires strictly positive density; non-positive values give NaN/-inf.
        dlogn_dx = (np.log(self.density.r()) - np.log(self.density.l())) / (2 * self.density.dx)
        dpotential_dx = (self.potential.r() - self.potential.l()) / (2 * self.vel.dx)

        # Calculate correlations and their gradient
        self.calculate_correlations()
        dC_dx = (self.correlations.r() - self.correlations.l()) / (2 * self.vel.dx)

        # NOTE(review): the T * dlogn/dx term enters with the opposite sign to
        # the potential term (it accelerates towards higher density) — confirm.
        self.vel.values -= self.vel.dt * (du_dx - self.temp.values * dlogn_dx + self.gamma * dpotential_dx - self.temp.values * dC_dx)

    def update_temperature(self):
        """Advance the temperature one step.

        Applies an upwind advection term T * u * dT/dx and a central-difference
        diffusion term thermal_diffusivity * d^2T/dx^2.
        """
        # Advection term T*u*dT/dx
        advective_term_T = np.where(self.vel.values >= 0,
                                    self.vel.values * (self.temp.values - self.temp.l()) / self.temp.dx,
                                    self.vel.values * (self.temp.r() - self.temp.values) / self.temp.dx) * self.temp.values
        # Diffusion term - lambda*d^2T/dx^2
        diffusive_term = (self.temp.r() - 2 * self.temp.values + self.temp.l()) / self.temp.dx**2
        self.temp.values -= self.temp.dt * (advective_term_T - self.thermal_diffusivity * diffusive_term)

    def update(self):
        """Advance all fields by one time step.

        Order: save previous values, solve for the potential from the current
        density, then update density, velocity and temperature in turn (each
        update sees the already-updated values of the earlier ones).
        """
        self.density.store_prev_values()
        self.vel.store_prev_values()
        self.temp.store_prev_values()

        self.solve_poisson_periodic()
        self.update_density()
        self.update_velocity()
        self.update_temperature()
        

class Orchestrator:
    """Runs a simulation and records evenly spaced snapshots of its fields.

    :param species: Object exposing `density`, `vel`, `temp` and `potential`
        fields (each with `num_points` and `values`) and an `update()` method.
    :param num_snapshots: Capacity of the snapshot buffers (rows).

    Snapshot arrays have shape (num_snapshots, num_points).  Rows at index
    `snapshot_counter` and beyond have not been written and remain zero.
    """

    def __init__(self, species, num_snapshots):
        self.species = species
        self.num_snapshots = num_snapshots
        # Initialize snapshot arrays with zeros based on the species fields
        self.density_snapshots = np.zeros((num_snapshots, species.density.num_points))
        self.vel_snapshots = np.zeros((num_snapshots, species.vel.num_points))
        self.temp_snapshots = np.zeros((num_snapshots, species.temp.num_points))
        self.potential_snapshots = np.zeros((num_snapshots, species.potential.num_points))
        self.snapshot_counter = 0

    def take_snapshot(self):
        """Copy the current field values into the next free row, if any is left."""
        if self.snapshot_counter < self.num_snapshots:
            self.density_snapshots[self.snapshot_counter] = self.species.density.values
            self.vel_snapshots[self.snapshot_counter] = self.species.vel.values
            self.temp_snapshots[self.snapshot_counter] = self.species.temp.values
            self.potential_snapshots[self.snapshot_counter] = self.species.potential.values
            self.snapshot_counter += 1

    def run_sim(self, num_steps):
        """Advance the species `num_steps` times, recording snapshots on the way.

        Snapshots are taken after the update of selected steps.  The selected
        steps are spread evenly over the run and always include the last step,
        so the final state is recorded.  (Previously the buffer could already
        be full when the last step was reached, silently dropping it.)

        :param num_steps: Number of calls to `species.update()`.
        """
        free_rows = self.num_snapshots - self.snapshot_counter
        count = min(free_rows, num_steps)

        if count <= 0:
            snapshot_steps = set()
        elif count == 1:
            snapshot_steps = {num_steps - 1}
        else:
            # Integer arithmetic keeps the indices distinct (count <= num_steps)
            # and maps the last index exactly onto num_steps - 1.
            snapshot_steps = {(i * (num_steps - 1)) // (count - 1) for i in range(count)}

        for step in range(num_steps):
            self.species.update()

            if step in snapshot_steps:
                self.take_snapshot()

class Plotter:
    """Collection of static plotting helpers built on Matplotlib."""

    @staticmethod
    def plot_heatmap(data, x_domain, y_domain, title, cmap_type='hot'):
        """
        Shows a 2D array as a heatmap.

        :param data: The 2D array to display.
        :param x_domain: Values whose first and last entries set the x extent.
        :param y_domain: Values whose first and last entries set the y extent.
        :param title: The title of the plot.
        :param cmap_type: Colormap name.
        :return: The (fig, ax) pair.
        """
        fig, ax = plt.subplots()
        cax = ax.imshow(data, cmap=cmap_type, interpolation='nearest', aspect='auto',
                        extent=[x_domain[0], x_domain[-1], y_domain[0], y_domain[-1]], origin='lower')
        fig.colorbar(cax)
        ax.set_title(title)
        ax.set_xlabel('Time')
        ax.set_ylabel('Space')
        plt.show(block=False)
        return fig, ax

    @staticmethod
    def plot_line(data_array, x_domain, title):
        """
        Plots several curves over a common x axis.

        :param data_array: Iterable of 1D arrays, one curve each.
        :param x_domain: The x-axis values shared by all curves.
        :param title: The title of the plot.
        :return: The (fig, ax) pair.
        """
        fig, ax = plt.subplots()
        for data in data_array:
            ax.plot(x_domain, data)
        ax.set_title(title)
        ax.set_xlabel('Domain')
        ax.set_ylabel('Value')
        ax.grid(True)
        plt.show(block=False)
        return fig, ax

    @staticmethod
    def plot_fourier_transform(signal, sampling_rate, title='Fourier Transform'):
        """
        Plots the one-sided amplitude spectrum of a 1D signal.

        :param signal: The 1D signal to transform.
        :param sampling_rate: Samples per unit of the signal's independent variable.
        :param title: The title of the plot.
        :return: The (fig, ax) pair.
        """
        N = len(signal)
        T = 1.0 / sampling_rate
        yf = np.fft.fft(signal)
        # Keep only the non-negative frequencies (first half of the FFT output).
        xf = np.fft.fftfreq(N, T)[:N//2]
        fig, ax = plt.subplots()
        ax.plot(xf, 2.0/N * np.abs(yf[0:N//2]))
        ax.set_title(title)
        ax.set_xlabel('Frequency')
        ax.set_ylabel('Amplitude')
        plt.show(block=False)
        return fig, ax

    @staticmethod
    def plot_2d_fourier_transform(data, title='2D Fourier Transform'):
        """
        Plots the 2D Fourier Transform of the provided data.

        :param data: The 2D data array to transform and plot.
        :param title: The title of the plot.
        :return: The (fig, ax) pair.
        """
        # Compute the 2D Fourier Transform
        fourier_transform = np.fft.fft2(data)
        # Shift the zero frequency component to the center of the spectrum
        fshift = np.fft.fftshift(fourier_transform)
        # Calculate the magnitude spectrum
        magnitude_spectrum = np.abs(fshift)

        fig, ax = plt.subplots()
        # Use logarithmic scaling to better visualize the spectrum.
        # Draw the image once and reuse its handle for the colorbar instead of
        # stacking a second, identical image on the axes.
        image = ax.imshow(magnitude_spectrum, norm=LogNorm(vmin=1), cmap='hot', aspect='equal')
        ax.set_title(title)
        fig.colorbar(image, ax=ax)
        plt.show(block=False)
        return fig, ax

    @staticmethod
    def plot_wavelet_transform(signal, scales, dt, waveletname='cmor', title='Wavelet Transform'):
        """
        Plots the magnitude of the continuous wavelet transform of a signal.

        :param signal: The 1D signal to transform.
        :param scales: Wavelet scales passed to `pywt.cwt`.
        :param dt: Sampling period of the signal.
        :param waveletname: Wavelet name passed to `pywt.cwt`.
        :param title: The title of the plot.
        :return: The (fig, ax) pair.
        """
        # Imported lazily so the rest of the class works without PyWavelets.
        import pywt

        duration = len(signal) * dt
        sampling_rate = 1 / dt

        # Perform the Continuous Wavelet Transform (CWT)
        coefficients, frequencies = pywt.cwt(signal, scales, waveletname, 1 / sampling_rate)

        # Plot the wavelet power spectrum
        fig, ax = plt.subplots(figsize=(10, 4))

        # Determine the extent of the plot
        extent = [0, duration, 0, len(frequencies) - 1]

        # Plot the coefficients with an image plot
        im = ax.imshow(np.abs(coefficients), extent=extent, cmap='jet', aspect='auto', origin='lower', vmax=abs(coefficients).max(), vmin=-abs(coefficients).max())

        # Create an array of y positions from 0 to the number of frequencies, this will be the new y-axis
        y_positions = np.linspace(start=0, stop=len(frequencies) - 1, num=len(frequencies))

        # Aim for about ten ticks; the step must be at least 1, otherwise
        # fewer than ten scales would produce a zero slice step (ValueError).
        tick_step = max(1, len(y_positions) // 10)

        # Set the y-ticks to correspond to the positions we just created
        ax.set_yticks(y_positions[::tick_step])

        # Set the y-tick labels to show the frequency values
        ax.set_yticklabels(np.round(frequencies, decimals=2)[::tick_step])

        # Add the plot details
        ax.set_title(title)
        ax.set_xlabel('Time (seconds)')
        ax.set_ylabel('Frequency (Hz)')

        # Add a colorbar for the magnitude
        fig.colorbar(im, ax=ax, label='Magnitude')

        plt.show(block=False)
        return fig, ax

    @staticmethod
    def _build_line_animation(data, x_domain, y_label, title, interval):
        """Create the figure and FuncAnimation shared by both animate methods."""
        fig, ax = plt.subplots()
        ax.set_title(title)
        ax.set_xlabel('Space')
        ax.set_ylabel(y_label)

        # Fixed axis limits for the whole animation; NaN entries are ignored
        # when computing the y range.
        ax.set_xlim(x_domain[0], x_domain[-1])
        ax.set_ylim(np.nanmin(data), np.nanmax(data))

        line, = ax.plot([], [], lw=2)

        def init():
            line.set_data([], [])
            return line,

        def animate(i):
            line.set_data(x_domain, data[i])
            return line,

        anim = FuncAnimation(fig, animate, init_func=init, frames=len(data), interval=interval, blit=True)
        return fig, anim

    @staticmethod
    def animate_solution(data, x_domain, y_label='Value', title='Solution Evolution', interval=200, cmap_type='hot'):
        """
        Creates an animation of the solution's evolution over time and shows
        it in a blocking Matplotlib window.

        :param data: The data to animate, expected shape is (time_steps, spatial_domain).
        :param x_domain: The spatial domain or x-axis values for the plot.
        :param y_label: Label for the y-axis.
        :param title: The title of the plot.
        :param interval: Time interval between frames in milliseconds.
        :param cmap_type: Unused; accepted for interface compatibility.
        :return: The FuncAnimation object.
        """
        _, anim = Plotter._build_line_animation(data, x_domain, y_label, title, interval)

        plt.show(block=True)
        return anim

    @staticmethod
    def animate_solution_ipynb(data, x_domain, y_label, title, interval, cmap_type='viridis'):
        """
        Creates an animation of the solution's evolution for notebook display.

        :param data: The data to animate, expected shape is (time_steps, spatial_domain).
        :param x_domain: The spatial domain or x-axis values for the plot.
        :param y_label: Label for the y-axis.
        :param title: The title of the plot.
        :param interval: Time interval between frames in milliseconds.
        :param cmap_type: Unused; accepted for interface compatibility.
        :return: An IPython HTML object wrapping the JavaScript animation.
        """
        fig, anim = Plotter._build_line_animation(data, x_domain, y_label, title, interval)

        plt.close(fig)  # Close the figure to prevent it from displaying twice
        return HTML(anim.to_jshtml())

if __name__ == "__main__":
    # All parameters and initial conditions are set inline below.  Loading
    # them from a JSON config file is disabled because that file's initial
    # conditions are wrong.

    # --- Physical parameters ---
    gamma = 1  # Coulomb coupling parameter
    mass = 1  # Mass of particles
    thermal_diffusivity = 0.01  # Thermal diffusivity

    # --- Discretisation: 100 points on [0, 2*pi), end point excluded ---
    grid_size = 100
    dx = 2 * np.pi / grid_size
    dt = 0.01
    x = np.linspace(0, 2 * np.pi, grid_size, endpoint=False)

    # --- Run length and number of recorded snapshots ---
    num_steps = 1000
    num_snapshots = 50

    # --- Initial profiles ---
    initial_density = 3/(4*np.pi) + 0.1 * np.sin(3/4 * np.pi * x)
    initial_vel = np.sin(x)
    initial_temp = 1 + 0.1 * np.cos(x)

    # --- Build the species, hand it to the orchestrator and run ---
    species = Species(initial_density, initial_vel, initial_temp, dx, dt, mass, gamma, thermal_diffusivity)
    orchestrator = Orchestrator(species, num_snapshots)
    orchestrator.run_sim(num_steps)

    # --- Sanity check: report non-finite values in the recorded density ---
    recorded_density = orchestrator.density_snapshots
    print("Contains NaN:", np.isnan(recorded_density).any())
    print("Contains Inf:", np.isinf(recorded_density).any())

    # --- Show one notebook animation per recorded field, in this order ---
    fields_to_animate = (
        ('Density', orchestrator.density_snapshots),
        ('Velocity', orchestrator.vel_snapshots),
        ('Temperature', orchestrator.temp_snapshots),
        ('Potential', orchestrator.potential_snapshots),
    )
    for field_label, field_snapshots in fields_to_animate:
        animation_html = Plotter.animate_solution_ipynb(field_snapshots, x, field_label, 'Solution Evolution', 200, 'hot')
        display(animation_html)

Contains NaN: False
Contains Inf: False
