In [1]:
import copy
import os, sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch

from torch.utils.data import DataLoader

In [2]:
sys.path.append('..')
from src.data.dataset import GWDataset, GWGridDataset, Normalize
from src.model.handler import ModelHandler
from src.model.neuralop.fno import FNO
from src.model.neuralop.losses import LpLoss, H1Loss

In [3]:
# Root of the FEFLOW variable-density data. Set the GW_DATA_DIR environment
# variable instead of editing this machine-specific default. The previously
# commented-out alternative location was
# '/srv/scratch/z5370003/projects/data/groundwater/FEFLOW/coastal/variable_density/'.
base_data_dir = os.environ.get(
    'GW_DATA_DIR',
    '/Users/arpitkapoor/Library/CloudStorage/OneDrive-UNSW/Shared/Projects/01_PhD/05_groundwater/data/FEFLOW/variable_density/',
)
# Pre-interpolated grid data consumed by GWGridDataset below.
interpolated_data_dir = os.path.join(base_data_dir, 'interpolated')

In [4]:
# Normalisation statistics shared by inputs and outputs (they agree to 4 d.p.
# with the training-split mean/std printed a few cells further down).
_mean, _std = np.array([-0.5474]), np.array([0.6562])

# Two separate Normalize instances built from the same statistics.
input_transform = Normalize(mean=_mean,
                            std=_std)
output_transform = Normalize(mean=_mean,
                             std=_std)

# Sliding-window lengths (these also set the FNO in/out channel counts).
in_window_size = out_window_size = 5
val_ratio, batch_size = 0.3, 32

# Passed to GWGridDataset as `fillval`; cells equal to it are masked to NaN
# before plotting.
fill_value = -1

In [5]:
# Training split. In this notebook it is only inspected (statistics and sizes);
# the loader itself is not consumed further down.
train_ds = GWGridDataset(
    data_path=interpolated_data_dir,
    dataset='train',
    val_ratio=val_ratio,
    in_window_size=in_window_size,
    out_window_size=out_window_size,
    input_transform=input_transform,
    output_transform=output_transform,
    fillval=fill_value,
)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=False, pin_memory=True)

In [6]:
# Recompute the training-split statistics and check that they agree with the
# hard-coded Normalize constants (_mean / _std, typed to 4 decimals). If the
# data changes, this fails loudly instead of silently mis-normalising.
train_stats = (train_ds._data.mean(), train_ds._data.std())
assert np.allclose(train_stats, (_mean[0], _std[0]), atol=1e-4), \
    f'Normalize constants {(_mean[0], _std[0])} do not match training data stats {train_stats}'
train_stats

(np.float64(-0.5474661404044111), np.float64(0.6561926480601338))

In [7]:
# Shape of the raw training array; presumably (time steps, x, y, z). TODO confirm.
np.shape(train_ds._data)

(1336, 40, 40, 40)

In [8]:
# Validation split: everything downstream (predictions, plots, video) uses it.
val_ds = GWGridDataset(
    data_path=interpolated_data_dir,
    dataset='val',
    val_ratio=val_ratio,
    in_window_size=in_window_size,
    out_window_size=out_window_size,
    input_transform=input_transform,
    output_transform=output_transform,
    fillval=fill_value,
)
# Not shuffled, so prediction order matches the frame numbering used later.
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, pin_memory=True)

# Number of validation batches.
len(val_dl)

18

In [9]:
# Number of entries along the first axis of the validation array.
len(val_ds._data)

573

In [10]:
# Combined first-axis length of the validation and training arrays.
len(val_ds._data) + len(train_ds._data)

1909

In [11]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Found following device: ' + str(device))

Found following device: cpu


In [12]:
# Loss set-up, kept for reference only (no loss is used in this notebook):
#   l2loss = LpLoss(d=3, p=2)
#   h1loss = H1Loss(d=3)
#   train_loss = h1loss
#   eval_losses = {'h1': h1loss, 'l2': l2loss}

# FNO architecture. It must match the configuration of the checkpoint loaded
# below, or load_state_dict will reject it.
n_modes = (16, 16, 16)
in_channels, out_channels = in_window_size, out_window_size  # channels = window lengths
hidden_channels = projection_channels = 64
# scheduler_interval = 10
# scheduler_interval = 10

In [13]:
model = FNO(
    n_modes=n_modes,
    in_channels=in_channels,
    out_channels=out_channels,
    hidden_channels=hidden_channels,
    projection_channels=projection_channels,
)
# Cast all weights to float64. .double() converts in place and returns the same module.
model = model.double()

In [18]:
# Directory holding the trained checkpoint; figures and the video are also
# written under it further down. Override with GW_RESULTS_DIR. The previously
# hard-coded (and immediately overwritten) path was
# '/srv/scratch/z5370003/projects/04_groundwater/variable_density/results/FNO/20250529_184905'.
results_path = os.environ.get(
    'GW_RESULTS_DIR',
    '/Users/arpitkapoor/Library/CloudStorage/OneDrive-UNSW/Shared/Projects/01_PhD/05_groundwater/results/',
)
model_path = os.path.join(results_path, 'savedmodel_fno')

# NOTE(review): the recorded run of this cell failed with a state_dict key
# mismatch. The checkpoint stores top-level `convs.*` / `fno_skips.*` keys,
# while this FNO expects `fno_blocks.convs.<i>.*` / `fno_blocks.fno_skips.<i>.*`,
# so it was presumably saved with a different version of the FNO code.
# Predictions further down are only meaningful once this load succeeds.
model.load_state_dict(torch.load(model_path, weights_only=True, map_location=device))

RuntimeError: Error(s) in loading state_dict for FNO:
	Missing key(s) in state_dict: "fno_blocks.fno_skips.0.weight", "fno_blocks.fno_skips.1.weight", "fno_blocks.fno_skips.2.weight", "fno_blocks.fno_skips.3.weight", "fno_blocks.convs.0.bias", "fno_blocks.convs.0.weight.0.tensor", "fno_blocks.convs.1.bias", "fno_blocks.convs.1.weight.0.tensor", "fno_blocks.convs.2.bias", "fno_blocks.convs.2.weight.0.tensor", "fno_blocks.convs.3.bias", "fno_blocks.convs.3.weight.0.tensor". 
	Unexpected key(s) in state_dict: "convs.bias", "convs.weight.0.tensor", "convs.weight.1.tensor", "convs.weight.2.tensor", "convs.weight.3.tensor", "fno_skips.0.weight", "fno_skips.1.weight", "fno_skips.2.weight", "fno_skips.3.weight". 

In [17]:
# Sanity check that the configured results directory exists. This reuses
# results_path instead of repeating the hard-coded literal, so it stays correct
# when the path is changed.
os.path.exists(results_path)

True

In [15]:
# Predict on the validation loader, then undo the output Normalize transform on
# both the predictions and the targets.
# NOTE(review): in the recorded session this cell (In [15]) ran *before* the
# checkpoint-loading cell (In [18]), and that load failed. Unless an earlier,
# unshown load succeeded, the saved outputs came from weights not loaded from
# `savedmodel_fno`. Re-run this cell after a successful load.
model_handler = ModelHandler(model=model, device=device)

preds = output_transform.inverse_transform(
    np.array(model_handler.predict(val_dl))
)
targets = output_transform.inverse_transform(
    model_handler.get_targets(val_dl)
)

100%|██████████| 18/18 [00:07<00:00,  2.42it/s]


In [16]:
from torchsummary import summary

In [23]:
# Summarise a float32 *copy* of the model. Calling `model.float()` directly
# converts the parameters of `model` itself in place (it was built with
# `.double()`), silently changing the dtype and precision of every later
# prediction. Casting back with `.double()` would not restore the lost precision.
summary(copy.deepcopy(model).float(), val_ds[0][0].shape)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1      [-1, 256, 40, 40, 40]           1,536
            Conv3d-2       [-1, 64, 40, 40, 40]          16,448
               MLP-3       [-1, 64, 40, 40, 40]               0
            Conv3d-4       [-1, 64, 40, 40, 40]           4,096
      SpectralConv-5       [-1, 64, 40, 40, 40]             256
            Conv3d-6       [-1, 64, 40, 40, 40]           4,096
      SpectralConv-7       [-1, 64, 40, 40, 40]             256
            Conv3d-8       [-1, 64, 40, 40, 40]           4,096
      SpectralConv-9       [-1, 64, 40, 40, 40]             256
           Conv3d-10       [-1, 64, 40, 40, 40]           4,096
     SpectralConv-11       [-1, 64, 40, 40, 40]             256
           Conv3d-12       [-1, 64, 40, 40, 40]           4,160
           Conv3d-13        [-1, 5, 40, 40, 40]             325
              MLP-14        [-1, 5, 40,

In [29]:
# Blank out cells holding the dataset fill value so nanmin/nanmax and the
# projections ignore them. A single mask, taken from `targets` before it is
# modified, is applied to both arrays.
# NOTE(review): exact `==` assumes inverse_transform reproduces fill_value
# exactly for filled cells (i.e. filling happens in the un-normalised space). Confirm.
fill_mask = targets == fill_value
preds[fill_mask] = np.nan
targets[fill_mask] = np.nan
del fill_mask

In [30]:
def plot_2d_projection(x_grid, y_grid, z_grid, values, vmin=None, vmax=None, title=None,
                       cmap='viridis', alpha=0.1):
    """Draw a 3-D field as three translucent 2-D projections (YZ, XZ and XY).

    Each panel overlays *every* slice along the collapsed axis with ``imshow``
    at opacity ``alpha``. The panel therefore shows a stacked view of the whole
    volume rather than a single cross-section.

    Args:
        x_grid, y_grid, z_grid: Coordinate sequences. Only the first and last
            entries are used, to set each panel's ``extent``.
        values: 3-D array indexed as ``values[ix, iy, iz]``.
        vmin, vmax: Colour limits shared by all three panels. They default to
            the NaN-ignoring min / max of ``values``.
        title: Optional figure-level title.
        cmap: Matplotlib colormap name.
        alpha: Opacity of each overlaid slice (default 0.1).

    Returns:
        The matplotlib ``Figure``. The caller is responsible for closing it.

    Raises:
        ValueError: If ``values`` is not 3-D or has an empty axis.
    """
    # Validate before creating the figure, so a bad input does not leave an
    # open figure behind (an empty axis used to end in UnboundLocalError).
    if len(values.shape) != 3 or 0 in values.shape:
        raise ValueError(f'values must be a non-empty 3-D array, got shape {tuple(values.shape)}')

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    if title is not None:
        fig.suptitle(title, fontsize=16)

    # Shared colour limits for all panels; NaNs (e.g. masked cells) are ignored.
    if vmin is None:
        vmin = np.nanmin(values)
    if vmax is None:
        vmax = np.nanmax(values)

    # (collapsed axis, horizontal grid, vertical grid, panel title, x label, y label)
    panels = (
        (0, y_grid, z_grid, 'YZ plane (constant X)', 'Y', 'Z'),
        (1, x_grid, z_grid, 'XZ plane (constant Y)', 'X', 'Z'),
        (2, x_grid, y_grid, 'XY plane (constant Z)', 'X', 'Y'),
    )
    for ax, (axis, h_grid, v_grid, panel_title, xlabel, ylabel) in zip(axes, panels):
        extent = [h_grid[0], h_grid[-1], v_grid[0], v_grid[-1]]
        index = [slice(None)] * 3
        for i in range(values.shape[axis]):
            index[axis] = i
            # Transpose so rows follow the vertical coordinate (with origin='lower').
            im = ax.imshow(values[tuple(index)].T, aspect='auto', extent=extent,
                           origin='lower', cmap=cmap, alpha=alpha, vmin=vmin, vmax=vmax)
        ax.set_title(panel_title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        fig.colorbar(im, ax=ax)

    fig.tight_layout()

    return fig

In [31]:
# preds[np.isclose(preds, -1, atol=5e-3)] = -1
# preds

In [32]:
# Output folders for the per-frame target and prediction images.
targets_path = os.path.join(results_path, 'targets')
os.makedirs(name=targets_path, exist_ok=True)

preds_path = os.path.join(results_path, 'preds')
os.makedirs(name=preds_path, exist_ok=True)

In [33]:
# Output folder for the per-frame error (preds - targets) images.
errors_path = os.path.join(results_path, 'errors')
os.makedirs(name=errors_path, exist_ok=True)

In [34]:
from tqdm import trange

# Shared colour limits so every target and prediction frame uses the same scale.
vmin = np.nanmin(targets)
vmax = np.nanmax(targets)

# Symmetric limits for the error panels, over output step 0 of every sample.
max_error = np.nanmax(np.abs(preds[:, 0] - targets[:, 0]))


def _save_projection_frame(values, out_dir, frame_idx, **plot_kwargs):
    """Plot `values` with plot_2d_projection, save <out_dir>/<frame_idx:04d>.png, close the figure."""
    fig = plot_2d_projection(val_ds.x_grid, val_ds.y_grid, val_ds.z_grid, values, **plot_kwargs)
    try:
        fig.savefig(os.path.join(out_dir, f'{frame_idx:04d}.png'))
    finally:
        # Close even if saving fails, so hundreds of frames cannot pile up open figures.
        plt.close(fig)


# One frame per validation sample; only output step 0 is visualised.
for t in trange(targets.shape[0]):
    _save_projection_frame(targets[t, 0], targets_path, t,
                           vmin=vmin, vmax=vmax, title='Targets')
    _save_projection_frame(preds[t, 0], preds_path, t,
                           vmin=vmin, vmax=vmax, title='Predictions')
    _save_projection_frame(preds[t, 0] - targets[t, 0], errors_path, t,
                           title='Error (Preds - Targets)',
                           vmin=-max_error, vmax=max_error, cmap='coolwarm')

100%|██████████| 563/563 [31:18<00:00,  3.34s/it]


In [35]:
import cv2
from tqdm import tqdm
import os

In [36]:
video_name = os.path.join(results_path, 'hydraulic_head_error.avi')
fps = 4  # frames per second of the output video

# Frame file names come from the targets folder; the preds/errors folders are
# expected to contain identically named images (they are written together above).
frames = [f for f in sorted(os.listdir(targets_path)) if not f.startswith('.')]
if not frames:
    raise FileNotFoundError(f'No frame images found in {targets_path}')


def _combined_frame(frame_name):
    """Read the target, prediction and error images for one frame and stack them vertically."""
    images = []
    for frame_dir in (targets_path, preds_path, errors_path):
        frame_path = os.path.join(frame_dir, frame_name)
        image = cv2.imread(frame_path)
        # cv2.imread signals a missing/unreadable file by returning None, not by raising.
        if image is None:
            raise FileNotFoundError(f'Could not read frame image: {frame_path}')
        images.append(image)
    return cv2.vconcat(images)


# The first combined frame fixes the output resolution (half size in each dimension).
height, width = _combined_frame(frames[0]).shape[:2]
frame_size = (width // 2, height // 2)

# fourcc=0 is kept from the original. If the backend rejects it, or files get
# large, consider e.g. cv2.VideoWriter_fourcc(*'MJPG').
video = cv2.VideoWriter(video_name, 0, fps, frame_size)
if not video.isOpened():
    raise RuntimeError(f'Could not open video writer for {video_name}')

try:
    for frame in tqdm(frames):
        video.write(cv2.resize(_combined_frame(frame), frame_size))
finally:
    # Always finalise the file. No OpenCV windows are opened in this notebook,
    # so cv2.destroyAllWindows() is unnecessary; it is also unavailable on
    # headless builds and previously ran *before* release().
    video.release()

100%|██████████| 563/563 [00:18<00:00, 29.78it/s]
