In [8]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
# import pywt
# from pydmd import DMD
import json
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, display

from ml_functions import *

In [9]:
class Plotter:
    @staticmethod
    def plot_heatmap(data, x_domain, y_domain, title, cmap_type='hot', figsize=(10,6)):
        fig, ax = plt.subplots(figsize=figsize)
        cax = ax.imshow(data, cmap=cmap_type, interpolation='nearest', aspect='auto', 
                        extent=[x_domain[0], x_domain[-1], y_domain[0], y_domain[-1]], origin='lower')
        fig.colorbar(cax)
        ax.set_title(title)
        ax.set_xlabel('Space')
        ax.set_ylabel('Time')
        plt.show(block=False)
        return fig, ax

    @staticmethod
    def plot_line(data_array, x_domain, title, figsize=(10,6)):
        fig, ax = plt.subplots(figsize=figsize)
        for data in data_array:
            ax.plot(x_domain, data)
        ax.set_title(title)
        ax.set_xlabel('Space')
        ax.set_ylabel('Value')
        ax.grid(True)
        plt.show(block=False)
        return fig, ax
    
    @staticmethod
    def plot_fourier_transform(signal, sampling_rate, title='Fourier Transform'):
        N = len(signal)
        T = 1.0 / sampling_rate
        yf = np.fft.fft(signal)
        xf = np.fft.fftfreq(N, T)[:N//2]
        fig, ax = plt.subplots()
        ax.plot(xf, 2.0/N * np.abs(yf[0:N//2]))
        ax.set_title(title)
        ax.set_xlabel('Frequency')
        ax.set_ylabel('Amplitude')
        plt.show(block=False)
        return fig, ax

    @staticmethod
    def plot_2d_fourier_transform(data, title='2D Fourier Transform'):
        """
        Plots the 2D Fourier Transform of the provided data.

        :param data: The 2D data array to transform and plot.
        :param title: The title of the plot.
        """
        # Compute the 2D Fourier Transform
        fourier_transform = np.fft.fft2(data)
        # Shift the zero frequency component to the center of the spectrum
        fshift = np.fft.fftshift(fourier_transform)
        
        # Calculate the magnitude spectrum
        magnitude_spectrum = np.abs(fshift)
        
        fig, ax = plt.subplots()
        # Use logarithmic scaling to better visualize the spectrum
        ax.imshow(magnitude_spectrum, norm=LogNorm(vmin=1), cmap='hot', aspect='equal')
        ax.set_title(title)
        plt.colorbar(ax.imshow(magnitude_spectrum, norm=LogNorm(vmin=1), cmap='hot'), ax=ax)
        plt.show(block=False)
        return fig, ax
    
    @staticmethod
    def plot_wavelet_transform(signal, scales, dt, waveletname='cmor', title='Wavelet Transform'):
        import pywt

        duration = len(signal) * dt
        sampling_rate = 1 / dt

        # Perform the Continuous Wavelet Transform (CWT)
        coefficients, frequencies = pywt.cwt(signal, scales, waveletname, 1 / sampling_rate)
        
        # Plot the wavelet power spectrum
        fig, ax = plt.subplots(figsize=(10, 4))
        
        # Determine the extent of the plot
        extent = [0, duration, 0, len(frequencies) - 1]
        
        # Plot the coefficients with an image plot
        im = ax.imshow(np.abs(coefficients), extent=extent, cmap='jet', aspect='auto', origin='lower', vmax=abs(coefficients).max(), vmin=-abs(coefficients).max())
        
        # Create an array of y positions from 0 to the number of frequencies, this will be the new y-axis
        y_positions = np.linspace(start=0, stop=len(frequencies) - 1, num=len(frequencies))
        
        # Set the y-ticks to correspond to the positions we just created
        ax.set_yticks(y_positions[::len(y_positions) // 10])
        
        # Set the y-tick labels to show the frequency values
        ax.set_yticklabels(np.round(frequencies, decimals=2)[::len(y_positions) // 10])
        
        # Add the plot details
        ax.set_title(title)
        ax.set_xlabel('Time (seconds)')
        ax.set_ylabel('Frequency (Hz)')
        
        # Add a colorbar for the magnitude
        fig.colorbar(im, ax=ax, label='Magnitude')
        
        plt.show(block=False)
        return fig, ax

    @staticmethod
    def animate_solution(data, x_domain, y_label='Value', title='Solution Evolution', interval=200, cmap_type='hot'):
        """
        Creates an animation of the solution's evolution over time.

        :param data: The data to animate, expected shape is (time_steps, spatial_domain).
        :param x_domain: The spatial domain or x-axis values for the plot.
        :param y_label: Label for the y-axis.
        :param title: The title of the plot.
        :param interval: Time interval between frames in milliseconds.
        :param cmap_type: Colormap for the heatmap.
        """
        fig, ax = plt.subplots()
        ax.set_title(title)
        ax.set_xlabel('Space')
        ax.set_ylabel(y_label)
        
        # Setting the limits for x and y axes
        ax.set_xlim(x_domain[0], x_domain[-1])
        ax.set_ylim(np.min(data), np.max(data))
          
        line, = ax.plot([], [], lw=2)

        def init():
            line.set_data([], [])
            return line,

        def animate(i):
            y = data[i]
            line.set_data(x_domain, y)
            return line,

        anim = FuncAnimation(fig, animate, init_func=init, frames=len(data), interval=interval, blit=True)

        plt.show(block=True)
        return anim
    
    @staticmethod
    def animate_solution_ipynb(data, x_domain, y_label, title, interval, cmap_type='viridis'):
        fig, ax = plt.subplots()
        ax.set_title(title)
        ax.set_xlabel('Space')
        ax.set_ylabel(y_label)
        ax.set_xlim(x_domain[0], x_domain[-1])
        ax.set_ylim(np.nanmin(data), np.nanmax(data))

        # Set the face colors to be visible on dark backgrounds
        # ax.set_facecolor('white')
       
        # Create a border around the figure for visibility
        fig.patch.set_linewidth(2)
        fig.patch.set_facecolor('white')
        
        line, = ax.plot([], [], lw=2)
        
        def init():
            line.set_data([], [])
            return line,
        
        def animate(i):
            y = data[i]
            line.set_data(x_domain, y)
            return line,
        
        anim = FuncAnimation(fig, animate, init_func=init, frames=len(data), interval=interval, blit=True)
        
        plt.close(fig)  # Close the figure to prevent it from displaying twice
        return HTML(anim.to_jshtml())

In [13]:
from tensorflow.keras.models import load_model

# Paths to files
simulation_data_path = 'simulation_data.npz'
model_path = 'model.h5'
history_path = 'history.pkl'
predictions_path = 'predictions.pkl'
metrics_path = 'metrics.pkl'

# Load Information
data = np.load(simulation_data_path)
training_data = data['training_data']
testing_data = data['testing_data']

model = load_model(model_path)

# Load History
with open(history_path, 'rb') as f:
    history = pickle.load(f)

# Load Predictions
with open(predictions_path, 'rb') as f:
    predictions = pickle.load(f)

# Load Metrics
with open(metrics_path, 'rb') as f:
    metrics = pickle.load(f)


FileNotFoundError: [Errno 2] No such file or directory: 'predictions.pkl'

In [None]:
import matplotlib.pyplot as plt

# Function to plot training and validation loss
def plot_training_loss(history):
    plt.figure(figsize=(10, 6))
    epochs = range(1, len(history['loss']) + 1)  # assuming 'loss' and 'val_loss' have the same length
    plt.plot(epochs, history['loss'], label='Training Loss')
    plt.plot(epochs, history['val_loss'], label='Validation Loss')
    plt.title('Model Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(loc='upper right')
    plt.grid(True)
    plt.show()

plot_training_loss(history)