In [None]:
from the_well.data import WellDataset
from torch.utils.data import DataLoader
from the_well.utils.download import well_download
import torch
import os
import numpy as np
import matplotlib.pyplot as plt

# Compute device: use the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Root folder the dataset is downloaded into and read from.
BASE_PATH = './datasets'
# Name of the dataset to download and load.
DATASET_NAME = 'helmholtz_staircase'

In [None]:
# Download the dataset only if BASE_PATH/DATASET_NAME does not exist yet.
if not os.path.exists(os.path.join(BASE_PATH, DATASET_NAME)):
    print(f"Downloading {DATASET_NAME} dataset...")
    well_download(base_path=BASE_PATH, dataset=DATASET_NAME)
    print(f"Dataset {DATASET_NAME} downloaded to {BASE_PATH}.")
else:
    # The folder is already there: skip the download.
    print(f"Dataset {DATASET_NAME} already exists at {BASE_PATH}. Skipping download.")


In [None]:
# Load the 'train' split of DATASET_NAME from BASE_PATH.
dataset = WellDataset(
    well_base_path=BASE_PATH,
    well_dataset_name=DATASET_NAME,
    well_split_name='train',
    # 50 input steps and no output steps per sample -- presumably each sample
    # then spans one whole trajectory; TODO confirm against len(dataset).
    n_steps_input=50,
    n_steps_output=0,

)

# Quickly checking what data changes from one trajectory to another

In [None]:
# Compare two neighbouring samples (indices 99 and 100) to see which entries
# of the sample dict change from one sample to the next.
output_0 = dataset[99]
output_1 = dataset[100]

# Collect every key whose tensors are not exactly equal.
keys_with_differences = []
for key in output_0.keys():
    if torch.equal(output_0[key], output_1[key]):
        continue
    keys_with_differences.append(key)

    # torch.equal is also False when the shapes differ; subtracting would then
    # broadcast or raise, so report the shapes instead.
    if output_0[key].shape != output_1[key].shape:
        print(f"Key: {key}, Shape mismatch: {tuple(output_0[key].shape)} vs {tuple(output_1[key].shape)}")
        continue

    # Sum of absolute differences: a signed sum can cancel out and print 0
    # for tensors that are in fact different.
    diff = (output_0[key] - output_1[key]).abs()
    print(f"Key: {key}, Difference: {diff.sum().item()}")


print("Keys with non-zero differences:", keys_with_differences)

## Plot Picture

Create GIF: helpers to plot a single time step and to animate a full trajectory as a GIF.

In [None]:
from matplotlib.animation import PillowWriter, FuncAnimation

def plot_step(dataset, trajectory_idx, field_idx, time_idx, save_dir, vrange='auto', ax=None):
    """Plot one time step of one field of a sample as an image.

    The field is cropped to a fixed window before plotting.

    Args:
        dataset: Indexable dataset whose items are dicts holding an
            'input_fields' tensor indexed as [time, axis 1, axis 2, field],
            and which exposes ``metadata.field_names``.
        trajectory_idx: Index of the sample to read from ``dataset``.
        field_idx: Index of the field along the last axis of 'input_fields'.
        time_idx: Index of the time step to draw.
        save_dir: Directory to write a PNG into, or None to skip saving.
        vrange: 'auto' for a range symmetric around zero taken from the
            cropped data of all time steps, None to let matplotlib choose,
            or a ``(vmin, vmax)`` pair.
        ax: Axis to draw on; a new 12x12 inch figure is created when None.

    Returns:
        The axis the image was drawn on.
    """
    # Fixed crop: indices 384..896 along axis 1, everything from 128 on along axis 2.
    images = dataset[trajectory_idx]['input_fields'][:, 256+128:(256*3)+1+128, 128:, field_idx]
    # NOTE(review): the name is looked up in field_names[0] only -- assumes
    # field_idx refers to an entry of that list; confirm for other datasets.
    field_name = dataset.metadata.field_names[0][field_idx]

    if isinstance(vrange, str) and vrange == 'auto':
        # Symmetric limits so that zero sits at the centre of the 'bwr' colormap.
        abs_max = images.abs().max().item()
        vmin, vmax = -abs_max, abs_max
    elif vrange is None:
        vmin, vmax = None, None
    else:
        vmin, vmax = vrange

    # Create a figure and axis unless the caller supplied one
    if ax is None:
        _, ax = plt.subplots(figsize=(12, 12))
    # Work on the figure that owns `ax`; it is not necessarily pyplot's current figure.
    fig = ax.figure

    # Plot the requested time step
    im = ax.imshow(images[time_idx].T, cmap='bwr', vmin=vmin, vmax=vmax)
    ax.set_title(f"Trajectory {trajectory_idx}, Field {field_name}, Time {time_idx}")

    # Add colorbar
    cbar = fig.colorbar(im, ax=ax, fraction=0.01, pad=0.04)
    cbar.ax.set_ylabel(field_name, rotation=270, labelpad=15)
    ax.axis('off')

    if save_dir is not None:
        # An empty save_dir means the current working directory.
        os.makedirs(save_dir or '.', exist_ok=True)
        # Saves the whole figure that owns `ax`, including any other axes on it.
        output_path = os.path.join(save_dir, f"T{trajectory_idx}_{field_name}_time_{time_idx}.png")
        fig.savefig(output_path)
        print(f"Image saved at {output_path}")
    return ax

def generate_gif(dataset, trajectory_idx: int, field_idx: int, save_dir):
    """Animate every time step of one field of a sample and save it as a GIF.

    Args:
        dataset: Indexable dataset whose items are dicts holding an
            'input_fields' tensor indexed as [time, axis 1, axis 2, field]
            and a single-valued 'constant_scalars' tensor, and which exposes
            ``metadata.field_names``.
        trajectory_idx: Index of the sample to read from ``dataset``.
        field_idx: Index of the field along the last axis of 'input_fields'.
        save_dir: Directory the GIF is written into; created if missing.

    The file is named ``T{trajectory_idx}_{field_name}_omega_{omega:.3f}.gif``
    and plays at one frame per second. The figure is closed after saving.
    """
    # Look the sample up once instead of repeating the dataset access for
    # every frame of the animation.
    sample = dataset[trajectory_idx]
    field_frames = sample['input_fields'][:, :, :, field_idx]
    field_name = dataset.metadata.field_names[0][field_idx]
    # .item() requires 'constant_scalars' to hold exactly one value.
    omega = sample['constant_scalars'].item()

    # Colour limits symmetric around zero and taken over all frames, so that
    # zero maps to the centre of 'bwr' and later frames are not clipped.
    abs_max = field_frames.abs().max().item()
    vmin, vmax = -abs_max, abs_max

    # Prepare the figure with the first frame
    fig, ax = plt.subplots(figsize=(12, 12))
    im = ax.imshow(field_frames[0].T, cmap='bwr', vmin=vmin, vmax=vmax)
    # Add colorbar
    cbar = fig.colorbar(im, ax=ax)
    cbar.ax.set_ylabel(field_name, rotation=270, labelpad=15)

    # Place the colorbar label to the right of the bar, vertically centred
    cbar.ax.yaxis.set_label_coords(1.5, 0.5)
    ax.set_title("Frame 0")

    ax.axis('off')

    # Update function for the animation
    def update(frame):
        im.set_data(field_frames[frame].T)
        ax.set_title(f"Frame {frame}")
        return [im]

    try:
        # Create the animation
        num_frames = field_frames.shape[0]
        ani = FuncAnimation(fig, update, frames=num_frames, blit=True)

        # Save the animation as a GIF; an empty save_dir means the current directory
        os.makedirs(save_dir or '.', exist_ok=True)
        output_path = os.path.join(save_dir, f"T{trajectory_idx}_{field_name}_omega_{omega:.3f}.gif")
        ani.save(output_path, writer=PillowWriter(fps=1))
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    print(f"GIF saved at {output_path}")



In [None]:
# Draw field 1 at time step 40 for 32 consecutive trajectories on an 8x4 grid,
# all with the same colour range so the panels are comparable.
first_traj_idx = 352
fig, axes = plt.subplots(8, 4, figsize=(40, 40))
for i in range(axes.size):
    traj_idx = first_traj_idx + i
    # save_dir=None: plot_step saves the whole figure, so saving from inside the
    # loop would write the partially filled grid once per panel.
    plot_step(dataset, trajectory_idx=traj_idx, field_idx=1, time_idx=40, save_dir=None, vrange=(-0.6, 0.6), ax=axes.flat[i])
plt.tight_layout()

# Save the finished grid once, after the layout has been applied.
grid_path = os.path.join(BASE_PATH, f"T{first_traj_idx}-{traj_idx}_field_1_time_40_grid.png")
fig.savefig(grid_path)
print(f"Image saved at {grid_path}")


In [None]:
# Export GIFs of fields 0 and 1 for every 32nd sample; the counter stops the
# loop once it exceeds 10, i.e. after at most 11 samples.
cnt = 0
for i in range(0, len(dataset), 32):
    # One output folder per sample, created on first use.
    save_dir = os.path.join(BASE_PATH, DATASET_NAME, 'gifs', f'trajectory_{i}')
    if not os.path.exists(save_dir):
        print(f"Creating directory {save_dir}")
        os.makedirs(save_dir)

    for field_idx in (0, 1):
        generate_gif(dataset, trajectory_idx=i, field_idx=field_idx, save_dir=save_dir)

    cnt += 1
    if cnt > 10:
        break