In [1]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import scipy.signal
import torch
import torch.nn.functional as F
import matplotlib.ticker as plticker
import time
from mpl_toolkits.axes_grid1 import make_axes_locatable
from tqdm import tqdm
from collections import defaultdict
from typing import Optional
from scipy.integrate import solve_ivp
from scipy.fftpack import diff as psdiff

In [5]:
def generate_params() -> (int, np.ndarray, np.ndarray, np.ndarray):
    """
    Returns parameters for initial conditions.
    Args:
        None
    Returns:
        int: number of Fourier series terms
        np.ndarray: amplitude of different sine waves
        np.ndarray: phase shift of different sine waves
        np.ndarray: frequency of different sine waves
    """
    N = 10
    lmin, lmax = 1, 3
    A = (np.random.rand(1, N) - 0.5)
    phi = 2.0*np.pi*np.random.rand(1, N)
    l = np.random.randint(lmin, lmax, (1, N))
    return (N, A, phi, l)

def initial_conditions(x: np.ndarray, L: int, params: Optional[list]=None) -> np.ndarray:
    """
    Return initial conditions based on initial parameters.
    Args:
        x (np.ndarray): input array of spatial grid
        L (float): length of the spatial domain
        params (Optinal[list]): input parameters for generating initial conditions
    Returns:
        np.ndarray: initial condition
    """
    if params is None:
        params = generate_params()
    N, A, phi, l = params   
    u = np.sum(A * np.sin((2 * np.pi * l * x[:, None] / L ) + phi), -1)
    return u

In [6]:
def ks_pseudospectral_reconstruction(t: float, u: np.ndarray, L: float) -> np.ndarray:
        """
        Pseudospectral reconstruction of the spatial derivatives of the KS equation, discretized in x.
        Args:
            t (float): time point
            u (np.ndarray): 1D input field
            L (float): length of the spatial domain
        Returns:
            np.ndarray: reconstructed pseudospectral time derivative
        """
        # Compute the x derivatives using the pseudo-spectral method.
        ux = psdiff(u, period=L)
        uxx = psdiff(u, period=L, order=2)
        uxxxx = psdiff(u, period=L, order=4)
        # Compute du/dt.
        dudt = - u*ux - uxx - uxxxx
        return dudt

def to_coords(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """
    Transforms the coordinates to a tensor X of shape [time, space, 2].
    Args:
        x: spatial coordinates
        t: temporal coordinates
    Returns:
        torch.Tensor: X[..., 0] is the space coordinate (in 2D)
                      X[..., 1] is the time coordinate (in 2D)
    """
    x_, t_ = torch.meshgrid(x, t)
    x_, t_ = x_.T, t_.T
    return torch.stack((x_, t_), -1)

L = 25 #64 #128
N = 25 #64 #2**7
x = np.linspace(0, (1-1.0/N)*L, N)
print(f"On interval [{x.min()}, {x.max()}] ")
# Set the tolerance of the solver
tol = 1e-6

# Set the initial conditions.
u0 = initial_conditions(x, L)

# Set the time sample grid.
T = 25 # 100.
t = np.linspace(0, T, T)
X = to_coords(torch.tensor(x), torch.tensor(t))

# Compute the solution using kdv_pseudospectral as spatial solver
sol_example = solve_ivp(fun=ks_pseudospectral_reconstruction, 
                               t_span=[t[0], t[-1]], 
                               y0=u0, 
                               method='Radau', 
                               t_eval=t, 
                               args=(L,), 
                               atol=tol, 
                               rtol=tol)



On interval [0.0, 24.0] 


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [7]:
from torchvision import datasets
from tqdm import tqdm

L = 25 #64 #128
N = 25 #64 #2**7
x = np.linspace(0, (1-1.0/N)*L, N)
print(f"On interval [{x.min()}, {x.max()}] ")
# Set the tolerance of the solver
tol = 1e-6

# Set the time sample grid.
T = 25 # 100.
t = np.linspace(0, T, T)
X = to_coords(torch.tensor(x), torch.tensor(t))


class RandKS(datasets.VisionDataset): 
    def __init__(self, *args,  dataseed=0, N=1000, size=(25, 25), max_x_shift=0.5, max_velocity=1., train=True, **kwargs):

        super().__init__(*args, **kwargs)
        torch.manual_seed(dataseed) 
        
        self.data = torch.zeros((N, 1, size[0], size[1]))
        self.data_aug = torch.zeros((N, 1, size[0], size[1]))
        #self.params = torch.zeros((N, 3))

        with torch.no_grad():  
            for idx in tqdm(range(N)): 
            # TODO: ensure that fixed seed is used
                A, phi, l = params() 
          #      print(A.shape, phi.shape, l.shape)

                u0 = initial_conditions(A, phi, l, L)(x[:, None])

                sol_example = solve_ivp(fun=ks_pseudospectral_reconstruction, 
                                   t_span=[t[0], t[-1]], 
                                   y0=u0, 
                                   method='Radau', 
                                   t_eval=t, 
                                   args=(L,), 
                                   atol=tol, 
                                   rtol=tol)
                
                #self.data[idx] = torch.tensor(sol_example.y.T[::-1].copy())
                self.data[idx] = torch.tensor(sol_example.y.T.copy())
                
                sample = (torch.tensor(sol_example.y.T), X)

                sol = SpaceTranslate(max_x_shift=max_x_shift)(sample=sample, shift='fourier')

                soln, Xn = sol[0], sol[1]

                self.data_aug[idx] = soln # [::-1]
         #       self.params[idx] = torch.tensor([A, phi, l])


    def __getitem__(self, idx): 
        return (self.data[idx], self.data_aug[idx]), 1. # TODO: needs param for problem 
    
    def __len__(self): 
        return len(self.data)     

On interval [0.0, 24.0] 


In [8]:
train_loader = torch.load( "ks_space_train5000.pt")
test_loader = torch.load( "ks_space_test1000.pt")


In [9]:
next(iter(train_loader))

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/ai/miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/ai/miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/ai/miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/tmp/ipykernel_69857/416868605.py", line 58, in __getitem__
    return (self.data[idx], self.data_aug[idx]), 1. # TODO: needs param for problem
RuntimeError: CUDA error: initialization error
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

