Here we verify that the ODE sampler and task works correctly.

In [1]:
import math
from math import exp
import numpy as np

import jax
import jax.numpy as jnp
from jax import random

import torch
import torch.nn.functional as F

import matplotlib.pyplot as plt

from time import time
import os, sys
sys.path.append("/pscratch/sd/j/jwl50/icl_odes_dev/src/datagen/main")

In [2]:
class DataSampler:
    def __init__(self, n_dims, name=None, **kwargs):
        self.n_dims = n_dims
        self.name = name

    # Sample = (b l d)
    def sample(self, n_points, batch_size):
        raise NotImplementedError
        
class Task:
    def __init__(self, name=None, **kwargs):
        self.name = name

    # For use in training, includes formatting. Output a dictionary of tensors: prefix, x, and y.
    def evaluate(self, x):
        raise NotImplementedError

    # Can be used separately from training, without formatting.
    def out(self, x):
        raise NotImplementedError

    @staticmethod
    def get_metric():
        raise NotImplementedError

    @staticmethod
    def get_training_metric():
        raise NotImplementedError
        
def squared_error(y_pred, y):
    return (y - y_pred).square()

def mean_squared_error(y_pred, y):
    return (y - y_pred).square().mean()

## NP

In [33]:
from util.odes import sample_gp, sample_params, interpolate_cheb, pseudospectral_solve_batch

In [62]:
# Sample parameters for ODE, initial condition, and forcing function
# Solve ODE on [-1, 1] with pseudospectral method
class ODEOperatorSampler(DataSampler):
    def __init__(
        self,
        n_dims,
        c_sampling="equispaced", # equispaced, cheb, randt
        # u_sampling="randt", # randt, vect
        eqn_class=0, # 0, 1, 2
        operator_scale=1,
        gp_length=1,
        u0_scale=1,
        seed=0,
        device="cuda",
        **kwargs
    ):
        super().__init__(
            n_dims=n_dims,
            name="ode_operator",
            **kwargs
        )
        self.c_sampling = c_sampling
        self.gp_length = gp_length
        self.u0_scale = u0_scale

        # Cheb
        self.n_cheb = 40
        self.to_cheb = lambda x : np.cos(np.pi*x)
        self.t_cheb = self.to_cheb(np.linspace(0, 1, self.n_cheb+1))

        # Equation
        self.eqn_class = eqn_class
        self.operator_scale = operator_scale

        self.seed = seed
        self.device = device
        self.generator = np.random.default_rng(seed=seed)

    def sample(self, n_points, batch_size, n_dims_truncated=None):
        n_samples = n_dims_truncated

        # Define t_eval based on sampling
        time_start = time()
        if self.c_sampling == "equispaced":
            t_eval = np.linspace(-1, 1, n_samples)
        elif self.c_sampling == "cheb":
            t_eval = self.to_cheb(np.linspace(0, 1, n_samples))
        elif self.c_sampling == "randt":
            t_eval = self.generator.uniform(-1, 1, n_samples)
        t = np.concatenate([self.t_cheb, t_eval])
        time_end = time()
        print(f"Sampling t_eval time: {time_end-time_start}")

        # Sample c
        time_start = time()
        c = sample_gp(t, shape=(batch_size, n_points), variance=self.gp_length, rng=self.generator) # (batch_size, n_points, len(t))
        c_cheb = c[:, :, :self.n_cheb+1] # (batch_size, n_points, n_cheb+1)
        c_eval = c[:, :, self.n_cheb+1:] # (batch_size, n_points, n_samples)
        time_end = time()
        print(f"Sampling c time: {time_end-time_start}")

        # Sample u0
        time_start = time()
        u0 = self.generator.uniform(-self.u0_scale, self.u0_scale, (batch_size, n_points)) # (batch_size, n_points)
        time_end = time()
        print(f"Sampling u0 time: {time_end-time_start}")

        # Sample equation parameters
        if self.eqn_class in [0, 1, 2]:
            time_start = time()
            a1, a2, a3, a4 = sample_params(
                eqn_class=self.eqn_class,
                size=(batch_size),
                scale=self.operator_scale,
                rng=self.generator
            )
            time_end = time()
            print(f"Sampling eqn params time: {time_end-time_start}")
        else:
            print(f"eqn_class {self.eqn_class} not implemented in ODEOperator")
            raise NotImplementedError

        # Solve for u pseudospectrally on [-1, 1]
        if self.eqn_class in [0, 1, 2]:
            time_start = time()
            a1_exp = np.tile(a1.reshape(-1, 1), (1, n_points))
            a2_exp = np.tile(a2.reshape(-1, 1), (1, n_points))
            a3_exp = np.tile(a3.reshape(-1, 1), (1, n_points))
            a4_exp = np.tile(a4.reshape(-1, 1), (1, n_points))
            u_cheb = pseudospectral_solve_batch(
                a1_exp,
                a2_exp,
                a3_exp,
                a4_exp,
                u0,
                c_cheb,
                c_cheb,
            ) # (batch_size, n_points, n_cheb+1)
            time_end = time()
            print(f"Solving u time: {time_end-time_start}")
            debug_data = {
                "a1_exp": a1_exp,
                "a2_exp": a2_exp,
                "a3_exp": a3_exp,
                "a4_exp": a4_exp,
                "u0": u0,
                "c_cheb": c_cheb,
                "u_cheb": u_cheb,
            }

        # Raw task data
        time_start = time()
        task_data = {}
        task_data["c_cheb"] = c_cheb # (batch_size, n_points, n_cheb+1)
        task_data["c_eval"] = c_eval # (batch_size, n_points, n_samples)
        task_data["u0"] = u0 # (batch_size, n_points)
        task_data["operator_params"] = (a1, a2, a3, a4)
        task_data["u_cheb"] = u_cheb # (batch_size, n_points, N_cheb+1)
        time_end = time()
        print(f"Constructing task_data time: {time_end-time_start}")

        return task_data, debug_data


class ODEOperatorICL(Task):
    def __init__(
        self,
        noise_variance=0,
        seed=0,
        device="cuda",
        **kwargs
    ):
        super().__init__(name="ode_operator", **kwargs)
        self.device = device
        self.rng = np.random.default_rng(seed=seed)
        self.noise_variance = noise_variance

    def evaluate(self, task_data):
        c_cheb = task_data["c_cheb"] # (batch_size, n_points, n_cheb+1)
        c_eval = task_data["c_eval"] # (batch_size, n_points, n_samples)
        B, L, D = c_eval.shape

        # Cheb interpolate onto sampled points
        time_start = time()
        t_eval = self.rng.uniform(-1, 1, B)
        u_cheb = task_data["u_cheb"]
        u_eval = interpolate_cheb(u_cheb, t_eval) # (batch_size, n_points, n_samples)
        u_eval = u_eval[np.arange(B), :, np.arange(B)] # (batch_size, n_points)
        time_end = time()
        print(f"Interpolating u time: {time_end-time_start}")

        if self.noise_variance != 0:
            u_ood = u_eval + self.noise_variance*self.rng.standard_normal(u_eval.shape)
        else:
            u_ood = u_eval

        # Inputs: [(c, u0, 1), u_ood]
        # Outputs: [u_ood]
        time_start = time()
        in_data = torch.zeros((B, 2*L-1, D+3)).to(device=self.device) # (B, 2L-1, D+3)
        in_data[:, ::2, :-3] = torch.tensor(c_eval)
        in_data[:, ::2, -3] = torch.tensor(task_data["u0"])
        in_data[:, ::2, -2] = 1
        in_data[:, 1::2, -1] = torch.tensor(u_ood)[:, :-1]

        out_data = torch.tensor(u_eval).to(device=self.device).unsqueeze(-1) # (B, L, 1)
        time_end = time()
        print(f"Constructing in/out data time: {time_end-time_start}")

        return {
            "in": in_data, # (B, 2L-1, D+3)
            "out": out_data, # (B, L, 1)
        }

    def get_metric(self):
        def metric(out_pred, out):
            assert out_pred[:, ::2, -1:].shape == out.shape
            return squared_error(out_pred[:, ::2, -1:], out)
        return metric

    def get_training_metric(self):
        def training_metric(out_pred, out):
            assert out_pred[:, ::2, -1:].shape == out.shape
            return mean_squared_error(out_pred[:, ::2, -1:], out)
        return training_metric

    # JL 11/18/24 TODO
    def get_eval_metrics(self):
        def eval_metric(out_pred, out, in_data):
            return {}
        return eval_metric

In [63]:
n_dims = 21
n_points = 25
batch_size = 64
seed = 1

sampler = ODEOperatorSampler(
    n_dims,
    c_sampling="equispaced",
    eqn_class=0,
    operator_scale=1,
    u0_scale=1,
    seed=seed,
    device="cuda",
)
task_data, debug_data = sampler.sample(n_points, batch_size, n_dims_truncated=n_dims)
assert task_data["c_cheb"].shape == (batch_size, n_points, 41)
assert task_data["c_eval"].shape == (batch_size, n_points, n_dims)
assert task_data["u0"].shape == (batch_size, n_points)
assert task_data["u_cheb"].shape == (batch_size, n_points, 41)

print(np.linalg.norm(task_data["c_cheb"]))
print(np.linalg.norm(task_data["c_eval"]))
print(np.linalg.norm(task_data["u0"]))
print(np.linalg.norm(task_data["u_cheb"]))

task = ODEOperatorICL(
    noise_variance=0,
    seed=seed,
    device="cuda",
)
out = task.evaluate(task_data)

assert out["in"].shape == (batch_size, 2*n_points-1, n_dims+3)
assert out["out"].shape == (batch_size, n_points, 1)

print(torch.norm(out["in"]))
print(torch.norm(out["out"]))

Sampling t_eval time: 5.53131103515625e-05
Sampling c time: 0.00458979606628418
Sampling u0 time: 1.8596649169921875e-05
Sampling eqn params time: 1.5735626220703125e-05
Solving u time: 0.06065702438354492
Constructing task_data time: 1.430511474609375e-06
251.02713262556887
179.68895480489982
23.169067690613815
698.0256477530816
Interpolating u time: 0.004431724548339844
Constructing in/out data time: 0.0019643306732177734
tensor(194.4633, device='cuda:0')
tensor(58.5164, device='cuda:0', dtype=torch.float64)


In [35]:
u_cheb_np = pseudospectral_solve_batch(
    debug_data["a1_exp"].cpu().numpy(),
    debug_data["a2_exp"].cpu().numpy(),
    debug_data["a3_exp"].cpu().numpy(),
    debug_data["a4_exp"].cpu().numpy(),
    debug_data["u0"].cpu().numpy(),
    debug_data["c_cheb"].cpu().numpy(),
    debug_data["c_cheb"].cpu().numpy(),
)
print(np.linalg.norm(u_cheb_np))

1805.679921053933


## Torch

In [3]:
from util.odes_torch import sample_gp, sample_params, interpolate_cheb, pseudospectral_solve_batch

In [28]:
class ODEOperatorSampler(DataSampler):
    def __init__(
        self,
        n_dims,
        c_sampling="equispaced", # equispaced, cheb, randt
        eqn_class=0, # 0, 1, 2
        operator_scale=1,
        gp_length=1,
        u0_scale=1,
        seed=0,
        device="cuda",
        dtype=torch.float64,
        **kwargs
    ):
        super().__init__(
            n_dims=n_dims,
            name="ode_operator",
            **kwargs
        )
        self.c_sampling = c_sampling
        self.gp_length = gp_length
        self.u0_scale = u0_scale
        self.device = device
        self.dtype = dtype

        # Chebyshev setup
        self.n_cheb = 41
        t = torch.linspace(0, 1, self.n_cheb, device=device, dtype=dtype)
        self.t_cheb = torch.cos(torch.pi * t)

        # Equation setup
        self.eqn_class = eqn_class
        self.operator_scale = operator_scale
        self.seed = seed

    def sample(self, n_points, batch_size, n_dims_truncated=None):
        n_samples = n_dims_truncated
        device, dtype = self.device, self.dtype

        # Define t_eval based on sampling method
        time_start = time()
        if self.c_sampling == "equispaced":
            t_eval = torch.linspace(-1, 1, n_samples, device=device, dtype=dtype)
        elif self.c_sampling == "cheb":
            t = torch.linspace(0, 1, n_samples, device=device, dtype=dtype)
            t_eval = torch.cos(torch.pi * t)
        elif self.c_sampling == "randt":
            t_eval = torch.empty(n_samples, device=device, dtype=dtype).uniform_(-1, 1)
        t = torch.cat([self.t_cheb, t_eval])
        time_end = time()
        # print(f"Sampling t_eval time: {time_end-time_start}")

        # Sample c using Gaussian process
        time_start = time()
        c = sample_gp(
            t, 
            shape=(batch_size, n_points), 
            variance=self.gp_length,
            seed=self.seed,
        ).to(dtype=dtype, device=device) # (batch_size, n_points, len(t))
        c_cheb = c[..., :self.n_cheb] # (batch_size, n_points, n_cheb)
        c_eval = c[..., self.n_cheb:] # (batch_size, n_points, n_samples)
        time_end = time()
        # print(f"Sampling c time: {time_end-time_start}")

        # Sample initial conditions
        time_start = time()
        u0 = torch.empty(batch_size, n_points, device=device, dtype=dtype).uniform_(
            -self.u0_scale, self.u0_scale
        )
        time_end = time()
        # print(f"Sampling u0 time: {time_end-time_start}")

        # Sample equation parameters
        if self.eqn_class in [0, 1, 2]:
            time_start = time()
            a1, a2, a3, a4 = sample_params(
                eqn_class=self.eqn_class,
                size=(batch_size,),
                scale=self.operator_scale,
                device=device
            )
            time_end = time()
            # print(f"Sampling eqn params time: {time_end-time_start}")
        else:
            raise NotImplementedError(f"eqn_class {self.eqn_class} not implemented")

        # Solve ODE using pseudospectral method
        if self.eqn_class in [0, 1, 2]:
            time_start = time()
            # Expand parameters for each point
            a1_exp = a1.unsqueeze(1).expand(-1, n_points)
            a2_exp = a2.unsqueeze(1).expand(-1, n_points)
            a3_exp = a3.unsqueeze(1).expand(-1, n_points)
            a4_exp = a4.unsqueeze(1).expand(-1, n_points)
            
            u_cheb = pseudospectral_solve_batch(
                a1_exp, a2_exp, a3_exp, a4_exp,
                u0, c_cheb, c_cheb,
                device=device,
                dtype=dtype
            ) # (batch_size, n_points, n_cheb)
            time_end = time()
            # print(f"Solving u time: {time_end-time_start}")
            
            debug_data = {
                "a1_exp": a1_exp,
                "a2_exp": a2_exp,
                "a3_exp": a3_exp,
                "a4_exp": a4_exp,
                "u0": u0,
                "c_cheb": c_cheb,
                "u_cheb": u_cheb,
            }

        # Package results
        time_start = time()
        task_data = {
            "c_cheb": c_cheb,
            "c_eval": c_eval,
            "u0": u0,
            "operator_params": (a1, a2, a3, a4),
            "u_cheb": u_cheb
        }
        time_end = time()
        # print(f"Constructing task_data time: {time_end-time_start}")

        return task_data, debug_data

class ODEOperatorICL(Task):
    def __init__(
        self,
        noise_variance=0,
        seed=0,
        device="cuda",
        dtype=torch.float64,
        **kwargs
    ):
        super().__init__(name="ode_operator", **kwargs)
        self.device = device
        self.dtype = dtype
        self.noise_variance = noise_variance
        self.seed = seed
        if seed is not None:
            torch.manual_seed(seed)

    def evaluate(self, task_data):
        device, dtype = self.device, self.dtype
        c_cheb = task_data["c_cheb"]  # (batch_size, n_points, n_cheb+1)
        c_eval = task_data["c_eval"]  # (batch_size, n_points, n_samples)
        B, L, D = c_eval.shape

        # Interpolate solution onto sampled points
        time_start = time()
        t_eval = torch.empty(B, device=device, dtype=dtype).uniform_(-1, 1)
        u_cheb = task_data["u_cheb"]
        u_eval = interpolate_cheb(u_cheb, t_eval)  # (batch_size, n_points, n_samples)
        # Select diagonal elements
        idx = torch.arange(B, device=device)
        u_eval = u_eval[idx, :, idx]  # (batch_size, n_points)
        time_end = time()
        # print(f"Interpolating u time: {time_end-time_start}")

        # Add noise if specified
        if self.noise_variance != 0:
            noise = torch.randn_like(u_eval) * self.noise_variance
            u_ood = u_eval + noise
        else:
            u_ood = u_eval

        # Construct input and output data
        time_start = time()
        in_data = torch.zeros(B, 2*L-1, D+3, device=device, dtype=dtype)
        in_data[:, ::2, :-3] = c_eval
        in_data[:, ::2, -3] = task_data["u0"]
        in_data[:, ::2, -2] = 1
        in_data[:, 1::2, -1] = u_ood[:, :-1]

        out_data = u_eval.unsqueeze(-1)  # (B, L, 1)
        time_end = time()
        # print(f"Constructing in/out data time: {time_end-time_start}")

        return {
            "in": in_data,  # (B, 2L-1, D+3)
            "out": out_data,  # (B, L, 1)
        }

    def get_metric(self):
        def metric(out_pred, out):
            assert out_pred[:, ::2, -1:].shape == out.shape
            return squared_error(out_pred[:, ::2, -1:], out)
        return metric

    def get_training_metric(self):
        def training_metric(out_pred, out):
            assert out_pred[:, ::2, -1:].shape == out.shape
            return mean_squared_error(out_pred[:, ::2, -1:], out)
        return training_metric

    def get_eval_metrics(self):
        def eval_metric(out_pred, out, in_data):
            return {}
        return eval_metric

In [29]:
n_dims = 21
n_points = 25
batch_size = 64
seed = 3

sampler = ODEOperatorSampler(
    n_dims,
    c_sampling="equispaced",
    eqn_class=0,
    operator_scale=1,
    u0_scale=1,
    seed=seed,
    device="cuda",
    dtype=torch.float64,
)
task_data, debug_data = sampler.sample(n_points, batch_size, n_dims_truncated=n_dims)
assert task_data["c_cheb"].shape == (batch_size, n_points, 41)
assert task_data["c_eval"].shape == (batch_size, n_points, n_dims)
assert task_data["u0"].shape == (batch_size, n_points)
assert task_data["u_cheb"].shape == (batch_size, n_points, 41)

print(torch.norm(task_data["c_cheb"]))
print(torch.norm(task_data["c_eval"]))
print(torch.norm(task_data["u0"]))
print(torch.norm(task_data["u_cheb"]))

task = ODEOperatorICL(
    noise_variance=0,
    seed=seed,
    device="cuda",
    dtype=torch.float64,
)
out = task.evaluate(task_data)

assert out["in"].shape == (batch_size, 2*n_points-1, n_dims+3)
assert out["out"].shape == (batch_size, n_points, 1)

print(torch.norm(out["in"]))
print(torch.norm(out["out"]))

tensor(258.4625, device='cuda:0', dtype=torch.float64)
tensor(184.7523, device='cuda:0', dtype=torch.float64)
tensor(22.8554, device='cuda:0', dtype=torch.float64)
tensor(1805.6799, device='cuda:0', dtype=torch.float64)
tensor(212.0845, device='cuda:0', dtype=torch.float64)
tensor(94.4553, device='cuda:0', dtype=torch.float64)


In [67]:
print(debug_data["a1_exp"].shape)
print(debug_data["a2_exp"].shape)
print(debug_data["a3_exp"].shape)
print(debug_data["a4_exp"].shape)
print(debug_data["u0"].shape)
print(debug_data["c_cheb"].shape)

(64, 25)
(64, 25)
(64, 25)
(64, 25)
(64, 25)
(64, 25, 41)


In [32]:
u_cheb_torch = pseudospectral_solve_batch(
    torch.tensor(debug_data["a1_exp"]),
    torch.tensor(debug_data["a2_exp"]),
    torch.tensor(debug_data["a3_exp"]),
    torch.tensor(debug_data["a4_exp"]),
    torch.tensor(debug_data["u0"]),
    torch.tensor(debug_data["c_cheb"]),
    torch.tensor(debug_data["c_cheb"]),
    device="cpu",
    dtype=torch.float64,
)
print(torch.norm(u_cheb_torch))

tensor(1805.6799, device='cuda:0', dtype=torch.float64)


  torch.tensor(debug_data["a1_exp"]),
  torch.tensor(debug_data["a2_exp"]),
  torch.tensor(debug_data["a3_exp"]),
  torch.tensor(debug_data["a4_exp"]),
  torch.tensor(debug_data["u0"]),
  torch.tensor(debug_data["c_cheb"]),
  torch.tensor(debug_data["c_cheb"]),


In [69]:
u_cheb_torch.shape

torch.Size([64, 25, 41])

In [70]:
debug_data["u_cheb"].shape

(64, 25, 41)

In [71]:
np.max(np.abs(debug_data["u_cheb"] - u_cheb_torch.numpy()))

4.668795760665034e-06

## Logs

```
NP
Sampling t_eval time: 5.9604644775390625e-05
Sampling c time: 0.005355358123779297
Sampling u0 time: 4.00543212890625e-05
Sampling eqn params time: 1.9073486328125e-05
Solving u time: 0.10842704772949219
Constructing task_data time: 1.6689300537109375e-06
256.90702322639004
183.5376963813352
23.04420699960258
545.8161496831693
Interpolating u time: 0.00507354736328125
Constructing in/out data time: 0.9835073947906494
tensor(198.0279, device='cuda:0')
tensor(59.1095, device='cuda:0', dtype=torch.float64)

Torch
Sampling t_eval time: 7.152557373046875e-05
Sampling c time: 0.01267552375793457
Sampling u0 time: 6.461143493652344e-05
Sampling eqn params time: 8.58306884765625e-05
Solving u time: 0.0841209888458252
Constructing task_data time: 7.152557373046875e-07
tensor(253.5806, device='cuda:0')
tensor(181.5987, device='cuda:0')
tensor(23.2177, device='cuda:0')
tensor(2164.1528, device='cuda:0')
Interpolating u time: 0.0015714168548583984
Constructing in/out data time: 0.00010728836059570312
tensor(233.0792, device='cuda:0')
tensor(138.7715, device='cuda:0')
```

In [28]:
debug_data

{'a1_exp': array([[1.35087079, 1.35087079, 1.35087079, ..., 1.35087079, 1.35087079,
         1.35087079],
        [0.76731776, 0.76731776, 0.76731776, ..., 0.76731776, 0.76731776,
         0.76731776],
        [1.05889367, 1.05889367, 1.05889367, ..., 1.05889367, 1.05889367,
         1.05889367],
        ...,
        [1.02783707, 1.02783707, 1.02783707, ..., 1.02783707, 1.02783707,
         1.02783707],
        [1.10884436, 1.10884436, 1.10884436, ..., 1.10884436, 1.10884436,
         1.10884436],
        [0.52515809, 0.52515809, 0.52515809, ..., 0.52515809, 0.52515809,
         0.52515809]]),
 'a2_exp': array([[1.00019229, 1.00019229, 1.00019229, ..., 1.00019229, 1.00019229,
         1.00019229],
        [1.11113001, 1.11113001, 1.11113001, ..., 1.11113001, 1.11113001,
         1.11113001],
        [0.79960001, 0.79960001, 0.79960001, ..., 0.79960001, 0.79960001,
         0.79960001],
        ...,
        [1.4817649 , 1.4817649 , 1.4817649 , ..., 1.4817649 , 1.4817649 ,
         1.481