In [23]:
import sys
import warnings

import matplotlib
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import matplotlib.patches as mpatches

# modify the plotter
matplotlib.rcParams["font.family"] = ["serif"]

# import the colors
from helpers import *

In [24]:

class PlayBack:
    def __init__(
        self, fp: list[tuple[str, str]], cns: list[str] | None, grid: list[int],
        plot_data_range = -5,
    ):
        self.folder_pairs = fp
        self.folder_plot_data: dict[str, dict] = curate_plot_data(self.folder_pairs, plot_data_range) # get the plot at time 5 steps from the end
        self.paused = False

        self.grid = grid
        self.case_names = (
            np.array(["" for _ in range(int(grid[0] * grid[1]))])
            .reshape((grid[0], grid[1]))
            .tolist()
        )
        self.folder_pair_gird = []
        # k = 0
        _range = range(0, len(fp), 2)
        for i in _range:
            a = []
            for j in range(i, i+1):
                if j in _range:
                    a.append(fp[j])
            # l = fp[i]
            # a.append(fp[i])
            # if i + 1 in _range:
            #     r = fp[i + 1]
            #     a.append(r)

        for i in range(0, len(fp), 2):
            l = fp[i]
            r = fp[i + 1]
            self.folder_pair_gird.append([l, r])
        self.n_images = int(len(self.folder_plot_data))
        self.fig, self.axes = plt.subplots(grid[0], grid[1], figsize=(10, 10))
        # self.fig.subplots_adjust(hspace=20, left=20, right=40)

        if self.n_images == 1:
            self.axes = [self.axes]  # Ensure axes is a list for consistent indexing

        self.contours = []
        self.images = []

        for i in range(self.grid[0]):
            ctr = []
            imgs = []
            for j in range(self.grid[1]):
                ax = self.axes[i][j]
                l = self.folder_pair_gird[i]
                set_dir = self.folder_pair_gird[i][j][0]

                ax.set_ylabel(set_dir.split("\\")[-1])  # set the title of the plot
                opts = self.folder_plot_data[set_dir]

                ann_img = ax.imshow(
                    opts["annulus_images"][0],
                    aspect="equal",
                )
                ann_contour = ax.contour(
                    opts["annulus_streamlines_x"],
                    opts["annulus_streamlines_y"],
                    opts["annulus_streamlines"][0],
                    levels=10,
                    linewidths=(0.5,),
                    linestyles=("solid",),
                    colors="white",
                )

                ax.set_title(self.case_names[i][j])
                ax.set_ylabel("")
                ax.set_xlabel("Measured Depth (m)")

                ax.set_xticks(
                    opts["ticks"]["annulus"]["x_ticks"],
                    labels=opts["ticks"]["annulus"]["x_ticks_labels"],
                )
                ax.set_yticks(
                    opts["ticks"]["annulus"]["y_ticks"],
                    labels=opts["ticks"]["annulus"]["y_ticks_labels"],
                )
                ctr.append(ann_contour)
                imgs.append(ann_img)

            self.contours.append(ctr)
            self.images.append(imgs)

        # obtain the max number of frames
        self.n_frames = 10
        for i in range(self.grid[0]):
            for j in range(self.grid[1]):
                l = len(
                    self.folder_plot_data[self.folder_pair_gird[i][j][0]][
                        "annulus_images"
                    ]
                )
                if self.n_frames < l:
                    self.n_frames = l

        # Create the animation
        # n_frames = len(folder_plot_data[folder_pairs[0][0]]['annulus_images'])  # Assuming equal frame count across all
        self.anim = FuncAnimation(
            self.fig, self.update, frames=self.n_frames, interval=50, blit=False
        )

        self.fig.canvas.mpl_connect("button_press_event", self.toggle_pause)

        legend_patches = [
            mpatches.Patch(color=color, label=name)
            for color, name in zip(COLORS, COLOR_NAMES)
        ]
        self.fig.legend(
            handles=legend_patches,
            loc="upper center",
            ncol=5,
            bbox_to_anchor=(0.5, 1.0),
        )

    # Define the update function for animation
    def update(self, frame):

        if frame + 1 == self.n_frames:
            self.toggle_pause()

        for i in range(self.grid[0]):
            for j in range(self.grid[1]):
                img = self.images[i][j]
                contour = self.contours[i][j]

                opts = self.folder_plot_data[self.folder_pair_gird[i][j][0]]

                if frame not in range(len(opts["annulus_images"])):
                    # skip areas where the frames are not consistent
                    continue

                # Update the image data
                img.set_data(opts["annulus_images"][frame])

                # Clear the old contours
                for c in contour.collections:
                    c.remove()

                # Add new contours
                new_contour = self.axes[i][j].contour(
                    opts["annulus_streamlines_x"],
                    opts["annulus_streamlines_y"],
                    opts["annulus_streamlines"][frame],
                    levels=10,
                    linewidths=(0.5,),
                    linestyles=("solid",),
                    colors="white",
                )
                self.contours[i][j] = new_contour  # Update the contour reference

        # return self.images + [
        #     c for contour in self.contours for c in contour.collections
        # ]

    def toggle_pause(self, *args, **kwargs):
        if self.paused:
            self.anim.pause()
        else:
            self.anim.resume()

        self.paused = not self.paused
        

In [25]:
def create_folder_pairs() -> dict[str, list[tuple[str, str]]]:
    """Return, for every enabled grid, the (params, results) folder pairs of its cases.

    Each run is expected in a folder named ``"<grid>-<case>"`` inside the
    current working directory; the returned paths point at its ``params``
    and ``results`` subfolders. The paths are only built, not checked for
    existence.
    """
    # Comment entries in or out to choose which runs are included.
    enabled_grids = (
        "grid_1",
        # "grid_2",
        # "grid_3",
        # "grid_4",
    )
    enabled_cases = (
        # "case_1",
        # "case_2",
        # "case_3",
        "case_4",
    )

    def _pair_for(grid_name: str, case_name: str) -> tuple[str, str]:
        # Both paths share the run directory under the current working directory.
        run_dir = os.path.join(os.getcwd(), f"{grid_name}-{case_name}")
        return (
            os.path.join(run_dir, "params"),
            os.path.join(run_dir, "results"),
        )

    return {
        grid_name: [_pair_for(grid_name, case_name) for case_name in enabled_cases]
        for grid_name in enabled_grids
    }


In [26]:

# Mapping of grid name -> list of (params, results) folder pairs; despite the
# "_LIST" suffix this is a dict keyed by grid name.
FOLDER_PAIRS_LIST: dict[str, list[tuple[str, str]]] = create_folder_pairs()

In [27]:

# Grid whose runs are played back; used as the key into FOLDER_PAIRS_LIST
# and as the figure title.
target = "grid_1"

# grid=[1, 1] requests one subplot row and one column; cns=None passes no case names.
playback = PlayBack(fp=FOLDER_PAIRS_LIST[target], cns=None, grid=[1, 1], plot_data_range=-5)
playback.fig.suptitle(target)

plt.tight_layout()
plt.show()

IndexError: list index out of range