In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import os


In [3]:
from src.utils.logger import Logging
from src.utils.fsi_visualization import create_animations_from_existing_frames


In [4]:
TEST_CHECKPOINT_PATH = os.path.join("result/simple_spectral_test")

logger = Logging(TEST_CHECKPOINT_PATH)
result_dir = logger.get_output_dir()
print(f"{result_dir=}")

result_dir='result/simple_spectral_test/2025-06-16_17-28-46-662299'


In [61]:
def load_model_class(solver_name):
    solver_to_module = {
        "grbf": "src.nn.grbf",
        "bspline": "src.nn.bspline",
        "jacobi": "src.nn.jacobi",
        "chebyshev": "src.nn.chebyshev",
        "param_tanh": "src.nn.tanh_parameterized",
        "tanh": "src.nn.tanh",
        "fourier": "src.nn.fourier",
    }
    module = __import__(solver_to_module[solver_name], fromlist=["PINNKAN"])
    return getattr(module, "PINNKAN")


class DNN(nn.Module):
    """Deep Neural Network for function approximation"""

    def __init__(self, input_dim=1, hidden_dims=[100, 100, 100], output_dim=1):
        super(DNN, self).__init__()

        layers = []
        prev_dim = input_dim

        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.Tanh())
            prev_dim = hidden_dim

        layers.append(nn.Linear(prev_dim, output_dim))

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)


def target_function(x):
    """Target function: sin(x) + sin(5x)"""
    return np.sin(x) + x * np.sin(np.exp(x))


def generate_training_data(x_range, n_samples=1000):
    """Generate training data sampled from the target function"""
    x = np.random.uniform(x_range[0], x_range[1], n_samples)
    y = target_function(x)
    return x, y


In [65]:
def train_and_visualize(model_str, model, x_range, device):
    torch.manual_seed(42)
    np.random.seed(42)

    x_train, y_train = generate_training_data(x_range, n_samples=1000)
    x_train_tensor = torch.FloatTensor(x_train).reshape(-1, 1)
    y_train_tensor = torch.FloatTensor(y_train).reshape(-1, 1)

    x_eval = np.linspace(x_range[0], x_range[1], 200)
    y_eval_true = target_function(x_eval)
    x_eval_tensor = torch.FloatTensor(x_eval).reshape(-1, 1)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.1, patience=5000
    )
    epochs_to_plot = [0, 100, 1000, 10000, 100000]
    total_epochs = 100000
    plot_every = 10000

    for epoch in range(total_epochs + 1):
        if epoch > 0:
            y_pred = model(x_train_tensor.to(device))
            loss = criterion(y_pred, y_train_tensor.to(device))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step(loss)

        if epoch % plot_every == 0:
            fig, axes = plt.subplots(1, 1, figsize=(5, 3))
            with torch.no_grad():
                y_pred_eval = model(x_eval_tensor.to(device)).cpu().numpy().flatten()

                l2_error = np.linalg.norm(y_pred_eval - y_eval_true) / np.linalg.norm(
                    y_eval_true
                )
                print(
                    f"Epoch {epoch:,}, L2 Error: {l2_error:.2e} , learning rate: {scheduler.get_last_lr()[0]:.2e}"
                )

            axes.plot(
                x_eval,
                np.sin(x_eval),
                "gray",
                linewidth=1,
                linestyle="--",
                alpha=0.5,
                label="sin(x)",
            )
            axes.plot(
                x_eval,
                y_eval_true,
                "green",
                linewidth=1,
                alpha=1,
                label=r"Target fn: $sin(x) + xsin(e^x)$",
            )
            axes.plot(
                x_eval,
                y_pred_eval,
                "red",
                linewidth=1,
                alpha=0.8,
                label=str("NN model:" + model_str),
            )

            axes.set_title(f"Epoch: {epoch:,}", fontsize=FONT_SIZE)
            axes.grid(True, alpha=0.3)
            # axes.set_ylim(-2.5, 2.5)
            axes.set_xlim(x_range)

            axes.set_xlabel("x", fontsize=FONT_SIZE)
            axes.set_ylabel("y", fontsize=FONT_SIZE)

            legend_labels = []
            legend_colors = []

            for line in axes.get_lines():
                legend_labels.append(line.get_label())
                legend_colors.append(line.get_color())

            legend = fig.legend(
                legend_labels,
                loc="upper left",
                bbox_to_anchor=(0.965, 0.7),
                ncol=1,
                prop={"size": FONT_SIZE},
                handlelength=0,
            )
            for text, color in zip(legend.get_texts(), legend_colors):
                text.set_color(color)
            legend.get_frame().set_linewidth(1)
            # legend.get_frame().set_facecolor("none")

            # Print progress
            if epoch in epochs_to_plot:
                with torch.no_grad():
                    y_pred = model(x_train_tensor.to(device))
                    loss = criterion(y_pred, y_train_tensor.to(device))
                    print(f"Epoch {epoch:,}, Loss: {loss.item():.6f}")

            plt.tight_layout()
            save_path = os.path.join(
                result_dir, str(epoch) + "_" + model_str + "_simple_spectral_test.png"
            )
            plt.savefig(save_path, bbox_extra_artists=(legend,), bbox_inches="tight")
            plt.close(
                "all",
            )
    return model


if __name__ == "__main__":
    FONT_SIZE = 10
    x_range = (-10, 10)
    print("Training DNN to approximate sin(x) + sin(e^x)...")

    models_strs = [
        "tanh",
    ]
    network = [1, 30, 30, 30, 1]
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    for model_str in models_strs:
        print(f"Using model: {model_str}")
        module = load_model_class(model_str)
        model = module(network)
        model.to(device)
        print(
            f"number of trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
        )
        trained_model = train_and_visualize(model_str, model, x_range, device)


Training DNN to approximate sin(x) + sin(e^x)...
Using model: tanh
number of trainable parameters: 1951
Epoch 0, L2 Error: 1.12e+00 , learning rate: 1.00e-03
Epoch 0, Loss: 10.017882
Epoch 10,000, L2 Error: 1.04e+00 , learning rate: 1.00e-03
Epoch 10,000, Loss: 6.972376


KeyboardInterrupt: 

In [None]:
animations_pred_dir = "./"

create_animations_from_existing_frames(
    frames_dirs=[
        os.path.join("result/simple_spectral_test/2025-06-13_22-57-28-709230"),
    ],
    output_dir=os.path.join(animations_pred_dir, "gif"),
)


Found 51 frames for 2025-06-13_22-57-28-709230
First few files in order:
  0: 0_tanh_simple_spectral_test.png
  1: 1000_tanh_simple_spectral_test.png
  2: 2000_tanh_simple_spectral_test.png
  3: 3000_tanh_simple_spectral_test.png
  4: 4000_tanh_simple_spectral_test.png
  ...
  49: 49000_tanh_simple_spectral_test.png
  50: 50000_tanh_simple_spectral_test.png
Creating animation for 2025-06-13_22-57-28-709230...
Saving animation to ./gif/2025-06-13_22-57-28-709230_animation.gif...
Saving frame 51/51
Animation saved to ./gif/2025-06-13_22-57-28-709230_animation.gif


In [12]:
get_output_dir = "result/simple_spectral_test/2025-06-15_18-08-31-196415"

animations_pred_dir = "./gif"
create_animations_from_existing_frames(
    frames_dirs=[
        os.path.join(get_output_dir),
    ],
    output_dir=os.path.join(animations_pred_dir, "gif"),
)

# remove the gif images from the log directory
for file in os.listdir(logger.get_output_dir()):
    if file.endswith(".png"):
        os.remove(os.path.join(logger.get_output_dir(), file))


logger.print(
    "Training completed. the gif file is saved in " + animations_pred_dir + "gif"
)


Found 5001 frames for 2025-06-15_18-08-31-196415
Creating animation for 2025-06-15_18-08-31-196415...
Saving animation to ./gif/gif/2025-06-15_18-08-31-196415_animation.mp4...
Saving frame 5000/5001

INFO:src.utils.logger:Training completed. the gif file is saved in ./gifgif


Saving frame 5001/5001
Animation saved to ./gif/gif/2025-06-15_18-08-31-196415_animation.mp4


In [18]:
from src.utils.combine_mp4_videos import combine_mp4_videos_grid


animations_combined = "./gif"
combine_mp4_videos_grid(
    input_dir=animations_combined,
    output_path="./gif/all_animations_grid.mp4",
    grid_cols=3,
    scale_width=640,
    scale_height=480,
)

Found 9 MP4 files to combine
Creating 3x3 grid with 0px spacing
Running FFmpeg command...
ffmpeg -y -ss 0 -i ./gif/2025-06-16_15-23-56-550055_animation.mp4 -ss 0 -i ./gif/2025-06-16_15-33-30-611314_animation.mp4 -ss 0 -i ./gif/2025-06-16_15-34-53-200596_animation.mp4 -ss 0 -i ./gif/2025-06-16_15-35-47-393683_animation.mp4 -ss 0 -i ./gif/2025-06-16_15-37-13-361307_animation.mp4 -ss 0 -i ./gif/2025-06-16_16-09-33-188557_animation.mp4 -ss 0 -i ./gif/2025-06-16_16-12-53-067559_animation.mp4 -ss 0 -i ./gif/2025-06-16_16-13-54-042695_animation.mp4 -ss 0 -i ./gif/2025-06-16_16-34-20-361763_animation.mp4 -filter_complex [0:v]scale=640:480:flags=lanczos[v0];[1:v]scale=640:480:flags=lanczos[v1];[2:v]scale=640:480:flags=lanczos[v2];[3:v]scale=640:480:flags=lanczos[v3];[4:v]scale=640:480:flags=lanczos[v4];[5:v]scale=640:480:flags=lanczos[v5];[6:v]scale=640:480:flags=lanczos[v6];[7:v]scale=640:480:flags=lanczos[v7];[8:v]scale=640:480:flags=lanczos[v8];[v0][v1][v2]hstack=inputs=3:shortest=1[row0];[v

ffmpeg version 7.1 Copyright (c) 2000-2024 the FFmpeg developers
  built with gcc 13.3.0 (conda-forge gcc 13.3.0-1)
  configuration: --prefix=/home/vlq26735/anaconda3/envs/pinn_ibm4fsi --cc=/home/conda/feedstock_root/build_artifacts/ffmpeg_1732155191655/_build_env/bin/x86_64-conda-linux-gnu-cc --cxx=/home/conda/feedstock_root/build_artifacts/ffmpeg_1732155191655/_build_env/bin/x86_64-conda-linux-gnu-c++ --nm=/home/conda/feedstock_root/build_artifacts/ffmpeg_1732155191655/_build_env/bin/x86_64-conda-linux-gnu-nm --ar=/home/conda/feedstock_root/build_artifacts/ffmpeg_1732155191655/_build_env/bin/x86_64-conda-linux-gnu-ar --disable-doc --enable-openssl --enable-demuxer=dash --enable-hardcoded-tables --enable-libfreetype --enable-libharfbuzz --enable-libfontconfig --enable-libopenh264 --enable-libdav1d --disable-gnutls --enable-libmp3lame --enable-libvpx --enable-libass --enable-pthreads --enable-vaapi --enable-libopenvino --enable-gpl --enable-libx264 --enable-libx265 --enable-libaom --en

Successfully created combined video: ./gif/all_animations_grid.mp4
Output file size: 31.45 MB


[mp4 @ 0x562a92210280] Starting second pass: moving the moov atom to the beginning of the file
[out#0/mp4 @ 0x562a92795400] video:32196KiB audio:0KiB subtitle:0KiB other streams:0KiB global headers:0KiB muxing overhead: 0.040596%
frame= 3001 fps= 28 q=-1.0 Lsize=   32209KiB time=00:01:40.03 bitrate=2637.7kbits/s speed=0.943x    
[libx264 @ 0x562a92236cc0] frame I:13    Avg QP:22.85  size:133994
[libx264 @ 0x562a92236cc0] frame P:2988  Avg QP:33.56  size: 10450
[libx264 @ 0x562a92236cc0] mb I  I16..4: 70.8%  0.0% 29.2%
[libx264 @ 0x562a92236cc0] mb P  I16..4:  0.3%  0.0%  0.7%  P16..4:  5.0%  2.5%  1.0%  0.4%  0.1%    skip:90.0%
[libx264 @ 0x562a92236cc0] coded y,uvDC,uvAC intra: 50.1% 55.1% 50.7% inter: 2.2% 3.6% 2.7%
[libx264 @ 0x562a92236cc0] i16 v,h,dc,p: 73% 24%  3%  0%
[libx264 @ 0x562a92236cc0] i4 v,h,dc,ddl,ddr,vr,hd,vl,hu: 73%  5% 11%  2%  1%  4%  1%  3%  1%
[libx264 @ 0x562a92236cc0] i8c dc,h,v,p: 46%  5% 47%  2%
[libx264 @ 0x562a92236cc0] kb/s:2636.55


In [9]:
ls ./gif

2025-06-16_15-23-56-550055_animation.mp4
2025-06-16_15-33-30-611314_animation.mp4
2025-06-16_15-34-53-200596_animation.mp4
2025-06-16_15-35-47-393683_animation.mp4
2025-06-16_15-37-13-361307_animation.mp4
2025-06-16_15-38-34-697404_animation.mp4
2025-06-16_15-39-33-703083_animation.mp4
2025-06-16_15-40-38-871683_animation.mp4
2025-06-16_16-09-33-188557_animation.mp4
2025-06-16_16-12-53-067559_animation.mp4
2025-06-16_16-13-54-042695_animation.mp4
2025-06-16_16-34-20-361763_animation.mp4
2025-06-16_16-52-05-897785_animation.mp4
[0m[01;34mgif[0m/
