### In this notebook, we run experiments for in-context ODEs. We compare the performance of the Transformer, BaseConv, GD, and least squares.

In [1]:
# Imports
import sys, os
sys.path.append("/scratch/precision-ls/")
import numpy as np
import torch
import matplotlib.pyplot as plt

from src.models import get_model_from_run
from src.datagen.main import get_data_sampler, get_task_sampler

# Prefer the GPU when torch can see one; otherwise run everything on the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Directory that the figures below are written to
save_dir = "/scratch/precision-ls/notebooks"

### Load models

In [16]:
# Root folder that holds every model checkpoint
run_dir = "/scratch/precision-ls/models/"

# Transformer ODE checkpoints (the integer keys are plotted below as the number of layers)
run_ids = {
    4: "task=odesiclfinal/14c81a1c-960a-4b14-818f-092b3646175e",
    8: "task=odesiclfinal/576ade63-0753-4b28-aa1d-f4b485be83fe",
    12: "task=odesiclfinal/13eaab29-6a9b-4c62-8d2c-6ba5833c88b1",
    16: "task=odesiclfinal/95857faa-b95a-4406-98a7-218ed76f1184",
    24: "task=odesiclfinal/1fe0fc66-711a-4556-8110-df1f5a8931d9",
}

# Resolve each relative run id against run_dir
run_paths = {
    run_key: os.path.join(run_dir, run_id)
    for run_key, run_id in run_ids.items()
}

# BaseConv gradient descent checkpoint
gd_run_ids = {
    "gd": "task=explicitgradient/bbfcc74c-91ea-4879-9eae-0d5c32c63e61",
}

gd_run_paths = {
    run_key: os.path.join(run_dir, run_id)
    for run_key, run_id in gd_run_ids.items()
}

In [3]:
# Eval model
batch_size = 128

def get_model(key, run_paths=run_paths):
    """Load the run registered under ``key`` and build what is needed to evaluate it.

    The model is moved to ``device`` and printed. ``conf.training.data_kwargs``
    and ``conf.training.task_kwargs`` are created as empty dicts when absent,
    and ``data_kwargs["eqn_class"]`` is forced to 1 before the data sampler is
    built. Both samplers use the notebook-level ``batch_size`` (read when this
    function is called) and ``conf.training.curriculum.points.end`` points.

    Args:
        key: Key into ``run_paths``.
        run_paths: Mapping from key to run path. The default is the
            notebook-level ``run_paths`` as it existed when this function was
            defined.

    Returns:
        Tuple ``(model, conf, data_sampler, task, metric, eval_metric)``.
        ``task`` is one draw from the task sampler, ``metric`` is
        ``task.get_training_metric()`` and ``eval_metric`` is
        ``task.get_metric()``.
    """
    model, conf = get_model_from_run(run_paths[key])
    model = model.to(device=device)
    print(model)

    # Make sure both optional kwarg containers exist before they are unpacked
    for kwargs_name in ("data_kwargs", "task_kwargs"):
        if kwargs_name not in conf.training.keys():
            setattr(conf.training, kwargs_name, {})

    n_dims = conf.model.n_dims

    conf.training.data_kwargs["eqn_class"] = 1

    # Arguments common to the data sampler and the task sampler
    sampler_kwargs = dict(
        n_dims=n_dims,
        batch_size=batch_size,
        n_points=conf.training.curriculum.points.end,
    )

    data_sampler = get_data_sampler(
        conf.training.data, **sampler_kwargs, **conf.training.data_kwargs
    )
    task_sampler = get_task_sampler(
        conf.training.task, **sampler_kwargs, **conf.training.task_kwargs
    )

    task = task_sampler()
    return (
        model,
        conf,
        data_sampler,
        task,
        task.get_training_metric(),
        task.get_metric(),
    )

def eval_model(key):
    """Evaluate the model registered under ``key`` on one freshly sampled batch.

    Loads the model through ``get_model``, samples ``batch_size`` examples at
    the final curriculum size, turns them into task inputs/targets, and runs a
    single forward pass without gradient tracking.

    Args:
        key: Key into the notebook-level ``run_paths``.

    Returns:
        Tuple ``(out, eval_out)``: the training metric and the evaluation
        metric, each applied to ``(pred, ys)``.
    """
    model, conf, data_sampler, task, metric, eval_metric = get_model(key)
    model = model.to(device=device)

    # Draw a batch at the end-of-curriculum problem size
    curriculum = conf.training.curriculum
    data_sample = data_sampler.sample(
        batch_size=batch_size,
        n_points=curriculum.points.end,
        n_dims_truncated=curriculum.dims.end,
    )

    # Build model inputs and targets from the raw sample
    task_data = task.evaluate(data_sample)
    xs = task_data["in"].to(device=device)
    ys = task_data["out"].to(device=device)

    # Forward pass only; no gradients are needed for evaluation
    with torch.no_grad():
        pred = model(xs)

    return metric(pred, ys), eval_metric(pred, ys)

## Eval ODEs

In [None]:
# Per-checkpoint results: raw training-metric output and MSE quartiles
mses, mses_25, mses_50, mses_75 = {}, {}, {}, {}

for save_key in run_paths:
    mse, eval_out = eval_model(save_key)

    # Collapse every axis except the first (presumably the batch axis) so that
    # there is one error value per sample
    reduce_dims = list(range(1, eval_out.dim()))
    mse_batch = eval_out.mean(dim=reduce_dims).detach().cpu().numpy()

    mses[save_key] = mse
    mses_25[save_key] = np.percentile(mse_batch, 25, axis=-1)
    mses_50[save_key] = np.median(mse_batch, axis=-1)
    mses_75[save_key] = np.percentile(mse_batch, 75, axis=-1)

# x-axis values for the plot below
nlayers = list(run_paths)

In [None]:
# Plot: median MSE per checkpoint with a 25th-75th percentile band
median_curve = [mses_50[nlayer] for nlayer in nlayers]
lower_band = [mses_25[nlayer] for nlayer in nlayers]
upper_band = [mses_75[nlayer] for nlayer in nlayers]

plt.figure(figsize=(12,12))
plt.plot(nlayers, median_curve, marker="o", color="black", markersize=12, linewidth=2)
plt.fill_between(nlayers, lower_band, upper_band, color='black', alpha=0.2)

# Labels and scales
plt.title("In-context ODEs with Transformers", fontsize=32)
plt.xlabel("Number of layers", fontsize=32)
plt.ylabel("MSE", fontsize=32)
plt.yscale("log")
plt.grid(True, which='both', linestyle='--', linewidth=0.5)
plt.xticks(fontsize=24)
plt.yticks(fontsize=24)

# Save before show so the file contains the rendered figure
plt.savefig(os.path.join(save_dir, "iclodes_attn_nlayers.png"))
plt.show()

### Code for evaluating out-of-distribution ODEs

In [7]:
def get_model_ood(key, ood_key=None, ood_factor=1):
    """Load a run and build samplers with one data parameter optionally rescaled.

    When ``ood_key`` is truthy, ``conf.training.data_kwargs[ood_key]`` is
    multiplied in place by ``ood_factor`` before the data sampler is built, so
    the key must already be present in the run's ``data_kwargs``. The model is
    returned as loaded; it is not moved to ``device`` here.

    Args:
        key: Key into the notebook-level ``run_paths``.
        ood_key: Name of the ``data_kwargs`` entry to rescale, or ``None`` to
            leave the configuration untouched.
        ood_factor: Multiplier applied to that entry.

    Returns:
        Tuple ``(model, conf, data_sampler, task, metric, eval_metric)``.
    """
    model, conf = get_model_from_run(run_paths[key])

    # Make sure both optional kwarg containers exist before they are used
    for kwargs_name in ("data_kwargs", "task_kwargs"):
        if kwargs_name not in conf.training.keys():
            setattr(conf.training, kwargs_name, {})

    n_dims = conf.model.n_dims

    # Shift the chosen data-generation parameter away from its configured value
    if ood_key:
        conf.training.data_kwargs[ood_key] *= ood_factor

    # Arguments common to the data sampler and the task sampler
    sampler_kwargs = dict(
        n_dims=n_dims,
        batch_size=batch_size,
        n_points=conf.training.curriculum.points.end,
    )

    data_sampler = get_data_sampler(
        conf.training.data, **sampler_kwargs, **conf.training.data_kwargs
    )
    task_sampler = get_task_sampler(
        conf.training.task, **sampler_kwargs, **conf.training.task_kwargs
    )

    task = task_sampler()
    return (
        model,
        conf,
        data_sampler,
        task,
        task.get_training_metric(),
        task.get_metric(),
    )

In [8]:
def eval_ood(key, ood_key=None, ood_factor=1):
    """Evaluate a model on one batch drawn from a (possibly shifted) data distribution.

    The samplers come from ``get_model_ood``, so ``ood_key`` / ``ood_factor``
    select which data parameter is rescaled and by how much. The model is
    moved to ``device`` and run once without gradient tracking.

    Args:
        key: Key into the notebook-level ``run_paths``.
        ood_key: Name of the ``data_kwargs`` entry to rescale, or ``None``.
        ood_factor: Multiplier applied to that entry.

    Returns:
        Tuple ``(out, eval_out)``: the training metric and the evaluation
        metric, each applied to ``(pred, ys)``.
    """
    model, conf, data_sampler, task, metric, eval_metric = get_model_ood(
        key, ood_key=ood_key, ood_factor=ood_factor
    )
    model = model.to(device=device)

    # Draw a batch at the end-of-curriculum problem size
    curriculum = conf.training.curriculum
    data_sample = data_sampler.sample(
        batch_size=batch_size,
        n_points=curriculum.points.end,
        n_dims_truncated=curriculum.dims.end,
    )

    # Build model inputs and targets from the raw sample
    task_data = task.evaluate(data_sample)
    xs = task_data["in"].to(device=device)
    ys = task_data["out"].to(device=device)

    # Forward pass only; no gradients are needed for evaluation
    with torch.no_grad():
        pred = model(xs)

    return metric(pred, ys), eval_metric(pred, ys)

### Out-of-distribution: forcing functions

In [None]:
# Per-factor results: raw training-metric output and MSE quartiles
mses, mses_25, mses_50, mses_75 = {}, {}, {}, {}

# Checkpoint under test and the multipliers applied to gp_length
save_key = 12
ood_factors = np.array([0.1, 0.2, 0.3, 0.5, 1, 3, 10])

for ood_factor in ood_factors:
    mse, eval_out = eval_ood(save_key, "gp_length", ood_factor)

    # Collapse every axis except the first (presumably the batch axis) so that
    # there is one error value per sample
    reduce_dims = list(range(1, eval_out.dim()))
    mse_batch = eval_out.mean(dim=reduce_dims).detach().cpu().numpy()

    mses[ood_factor] = mse
    mses_25[ood_factor] = np.percentile(mse_batch, 25, axis=-1)
    mses_50[ood_factor] = np.median(mse_batch, axis=-1)
    mses_75[ood_factor] = np.percentile(mse_batch, 75, axis=-1)

In [None]:
# Plot: median MSE per gp_length factor with a 25th-75th percentile band
median_curve = [mses_50[ood_factor] for ood_factor in ood_factors]
lower_band = [mses_25[ood_factor] for ood_factor in ood_factors]
upper_band = [mses_75[ood_factor] for ood_factor in ood_factors]

plt.figure(figsize=(12,12))
plt.plot(ood_factors, median_curve, marker="o", color="black", markersize=12, linewidth=2)
plt.fill_between(ood_factors, lower_band, upper_band, color='black', alpha=0.2)

# Labels and scales (both axes logarithmic)
plt.title("OOD Forcing Functions", fontsize=32)
plt.xlabel("GP Length Parameter", fontsize=32)
plt.xscale("log")
plt.ylabel("MSE", fontsize=32)
plt.yscale("log")
plt.grid(True, which='both', linestyle='--', linewidth=0.5)
plt.xticks(fontsize=24)
plt.yticks(fontsize=24)

# Save before show so the file contains the rendered figure
plt.savefig(os.path.join(save_dir, "iclodes_attn_forcingfcn.png"))
plt.show()

### Out-of-distribution: ODE parameters

In [None]:
# Per-factor results: raw training-metric output and MSE quartiles
mses, mses_25, mses_50, mses_75 = {}, {}, {}, {}

# Checkpoint under test and the multipliers applied to operator_scale
save_key = 12
ood_factors = np.array([1/16, 1/4, 1, 4, 16])

for ood_factor in ood_factors:
    mse, eval_out = eval_ood(save_key, "operator_scale", ood_factor)

    # Collapse every axis except the first (presumably the batch axis) so that
    # there is one error value per sample
    reduce_dims = list(range(1, eval_out.dim()))
    mse_batch = eval_out.mean(dim=reduce_dims).detach().cpu().numpy()

    mses[ood_factor] = mse
    mses_25[ood_factor] = np.percentile(mse_batch, 25, axis=-1)
    mses_50[ood_factor] = np.median(mse_batch, axis=-1)
    mses_75[ood_factor] = np.percentile(mse_batch, 75, axis=-1)

In [None]:
# Plot: median MSE per operator_scale factor with a 25th-75th percentile band
median_curve = [mses_50[ood_factor] for ood_factor in ood_factors]
lower_band = [mses_25[ood_factor] for ood_factor in ood_factors]
upper_band = [mses_75[ood_factor] for ood_factor in ood_factors]

plt.figure(figsize=(12,12))
# The x positions are the factors scaled by 0.1
# NOTE(review): 0.1 is presumably the trained operator_scale - confirm against the run config
plt.plot(0.1*ood_factors, median_curve, marker="o", color="black", markersize=12, linewidth=2)
plt.fill_between(0.1*ood_factors, lower_band, upper_band, color='black', alpha=0.2)

# Labels and scales (both axes logarithmic)
plt.title("OOD ODE Parameters", fontsize=32)
plt.xlabel("ODE Parameter Scale", fontsize=32)
plt.xscale("log")
plt.ylabel("MSE", fontsize=32)
plt.yscale("log")
plt.grid(True, which='both', linestyle='--', linewidth=0.5)
plt.xticks(fontsize=24)
plt.yticks(fontsize=24)

# Save before show so the file contains the rendered figure
plt.savefig(os.path.join(save_dir, "iclodes_attn_odeparams.png"))
plt.show()

### Out-of-distribution: u0 scale

In [None]:
# Per-factor results: raw training-metric output and MSE quartiles
mses, mses_25, mses_50, mses_75 = {}, {}, {}, {}

# Checkpoint under test and the multipliers applied to u0_scale
save_key = 12
ood_factors = np.array([1/16, 1/4, 1, 4, 16])

for ood_factor in ood_factors:
    mse, eval_out = eval_ood(save_key, "u0_scale", ood_factor)

    # Collapse every axis except the first (presumably the batch axis) so that
    # there is one error value per sample
    reduce_dims = list(range(1, eval_out.dim()))
    mse_batch = eval_out.mean(dim=reduce_dims).detach().cpu().numpy()

    mses[ood_factor] = mse
    mses_25[ood_factor] = np.percentile(mse_batch, 25, axis=-1)
    mses_50[ood_factor] = np.median(mse_batch, axis=-1)
    mses_75[ood_factor] = np.percentile(mse_batch, 75, axis=-1)

In [None]:
# Plot: median MSE per u0_scale factor with a 25th-75th percentile band
median_curve = [mses_50[ood_factor] for ood_factor in ood_factors]
lower_band = [mses_25[ood_factor] for ood_factor in ood_factors]
upper_band = [mses_75[ood_factor] for ood_factor in ood_factors]

plt.figure(figsize=(12,12))
plt.plot(ood_factors, median_curve, marker="o", color="black", markersize=12, linewidth=2)
plt.fill_between(ood_factors, lower_band, upper_band, color='black', alpha=0.2)

# Labels and scales (both axes logarithmic)
plt.title("OOD Initial Conditions", fontsize=32)
plt.xlabel("Initial Conditions Scale", fontsize=32)
plt.xscale("log")
plt.ylabel("MSE", fontsize=32)
plt.yscale("log")
plt.grid(True, which='both', linestyle='--', linewidth=0.5)
plt.xticks(fontsize=24)
plt.yticks(fontsize=24)

# Save before show so the file contains the rendered figure
plt.savefig(os.path.join(save_dir, "iclodes_attn_initial.png"))
plt.show()

### Gradient descent and least squares

In [None]:
from src.datagen.main.odes import ODEOperatorSampler

# Load gradient descent model
# NOTE: get_model reads the notebook-level `batch_size` when it is called, so
# this call (and the get_model_ood call below) sees whatever value batch_size
# had BEFORE the reassignment at the bottom of this cell. Keep this order.
gd_model, gd_conf, gd_data_sampler, gd_task, gd_metric, gd_eval_metric = get_model("gd", run_paths=gd_run_paths)

# Load ODE task
# Only the task object (4th return value) is kept. ood_factor=1 multiplies
# gp_length by 1, i.e. the run's data configuration is left unchanged.
_, _, _, ode_task, _, _ = get_model_ood(12, ood_key="gp_length", ood_factor=1)

# Set batch size, number of points, and number of dimensions (hardcoded from default ODE config)
# From here on every function that reads the global `batch_size` (eval_ood,
# get_model_ood, gradient_descent, least_squares_solvers) uses 5, not the
# value set earlier in the notebook. L and D are read as globals by the
# least-squares / gradient-descent helpers defined in the following cells.
batch_size = 5
L = 26
D = 20

In [21]:
# Compute grad (requires setting L)
def get_ls_grad(A, b, x_init, L=L):
    """Exact least-squares gradient ``A^T (A x - b) / (L - 1)`` for each batch element.

    Args:
        A: Rank-3 tensor indexed as ``(batch, row, dim)``.
        b: Rank-2 tensor indexed as ``(batch, row)``.
        x_init: Rank-2 tensor indexed as ``(batch, dim)``; the point at which
            the gradient is evaluated.
        L: The gradient is divided by ``L - 1``. The default is the
            notebook-level ``L`` as it was when this function was defined.

    Returns:
        Tensor indexed as ``(batch, dim)``, placed on ``device``.
    """
    # Residual of the linear system at the current iterate
    residual = torch.einsum("bld,bd->bl", A, x_init) - b
    # Contract the residual with A over the row axis
    unnormalized = torch.einsum("bld,bl->bd", A, residual)
    return unnormalized.to(device=device) / (L-1)

def get_ls_grad_pred(model, A, b, x_init):
    """Ask ``model`` for the least-squares gradient at ``x_init``.

    Builds a zero prompt of shape ``(A.shape[0], L, D+3)`` from the
    notebook-level ``L`` and ``D``. Row 0 holds ``x_init`` in all but the last
    column. Rows 1 onward hold ``A`` in all but the last column and ``b`` in
    the last column.

    Args:
        model: Callable applied to the prompt tensor.
        A: Tensor written into ``prompt[:, 1:, :-1]``.
        b: Tensor written into ``prompt[:, 1:, -1]``.
        x_init: Tensor written into ``prompt[:, 0, :-1]``.

    Returns:
        The first ``D+2`` features of the model output at the last sequence
        position, on ``device``.
    """
    # Prompting
    n_prompts = A.shape[0]
    prompt = torch.zeros((n_prompts, L, D+3), device=device)
    prompt[:, 0, :-1] = x_init
    prompt[:, 1:, :-1] = A
    prompt[:, 1:, -1] = b

    # Predict gradient
    with torch.no_grad():
        model_out = model(prompt)
        grad_pred = model_out[:, -1, :D+2].to(device=device)

    return grad_pred

In [22]:
def gradient_descent(A, b, A_q, b_q, num_iters=10000, lr=0.01, noise_std=0, model=None, setting="pred", num_increasing=500):
    """Solve each least-squares problem by gradient descent and summarise the query error.

    Every problem in the batch is optimised independently from a zero iterate.
    After each step the squared error on that problem's query ``(A_q, b_q)``
    is computed, and the smallest value seen is recorded. The loop for one
    problem stops early once ``num_increasing`` steps in total (not
    necessarily consecutive) have failed to improve on the best error.

    The number of problems and the iterate size are taken from ``A``
    (``A.shape[0]`` and ``A.shape[-1]``) rather than from notebook globals, so
    the function keeps working when ``batch_size`` or ``D`` are reassigned
    elsewhere in the notebook.

    Args:
        A: Batched system matrices, sliced per problem along dim 0.
        b: Batched right-hand sides, sliced per problem along dim 0.
        A_q: Per-problem query rows, indexed as ``(batch, dim)``.
        b_q: Per-problem query targets, indexed as ``(batch,)``.
        num_iters: Maximum number of gradient steps per problem.
        lr: Step size.
        noise_std: Standard deviation of Gaussian noise added to each gradient.
        model: Model used to predict gradients; required when
            ``setting == "pred"``.
        setting: ``"pred"`` uses ``get_ls_grad_pred(model, ...)``. Any other
            value uses the exact gradient from ``get_ls_grad``.
        num_increasing: Number of non-improving steps tolerated before
            stopping early.

    Returns:
        ``[mse_25, mse_50, mse_75]``: the 25th percentile, median and 75th
        percentile of the per-problem minimum query error.

    Raises:
        AssertionError: If ``setting == "pred"`` and ``model`` is ``None``.
        ValueError: If ``A`` contains no problems.
    """
    use_model = setting == "pred"
    if use_model:
        assert model is not None, 'setting="pred" requires a model'

    n_problems = A.shape[0]
    if n_problems == 0:
        raise ValueError("gradient_descent needs at least one problem in the batch")
    n_params = A.shape[-1]

    min_losses = []

    for i in range(n_problems):
        # Single-problem views that keep the leading batch dim of size 1
        A_i, b_i = A[i:i+1], b[i:i+1]
        A_q_i, b_q_i = A_q[i:i+1], b_q[i:i+1]

        x_init = torch.zeros(1, n_params).to(device=device)
        min_loss = float("inf")
        increasing_counter = 0

        for iter_i in range(num_iters):
            if use_model:
                grad = get_ls_grad_pred(model, A_i, b_i, x_init)
            else:
                grad = get_ls_grad(A_i, b_i, x_init)

            grad += noise_std * torch.randn_like(grad)
            x_init -= lr * grad

            # Reduce to a Python float once per step and reuse it below
            loss = (torch.einsum('bd,bd->b', x_init, A_q_i) - b_q_i).square()
            loss_value = loss.mean().item()

            if loss_value < min_loss:
                min_loss = loss_value
            else:
                # Never reset: this counts every non-improving step so far
                increasing_counter += 1
            if increasing_counter >= num_increasing:
                break

            if iter_i % 100 == 0:
                print(f"Iter {iter_i}: loss {loss_value}")

        min_losses.append(min_loss)

    min_losses = np.array(min_losses)
    mse_25 = np.percentile(min_losses, 25, axis=-1)
    mse_75 = np.percentile(min_losses, 75, axis=-1)
    mse_50 = np.median(min_losses, axis=-1)

    return [mse_25, mse_50, mse_75]
        

In [23]:
def least_squares_solvers(ode_task_data, gd_num_iters=100000, gd_lr=0.03):
    """Score least squares, model-predicted GD and exact GD on one batch of ODE task data.

    The task inputs are split into a linear system and a held-out query: all
    rows but the last form ``(A, b)``, and the last row together with the task
    output forms the query ``(A_q, b_q)``. Each solver's solution is scored by
    its squared error on the query.

    The batch size is taken from the data itself (``A.shape[0]``) rather than
    from the notebook-level ``batch_size``, which is reassigned elsewhere in
    the notebook.

    Args:
        ode_task_data: Mapping with tensors under ``"in"`` (rank 3) and
            ``"out"``.
        gd_num_iters: Maximum gradient-descent steps per problem.
        gd_lr: Gradient-descent step size.

    Returns:
        Dict with keys ``"ls"``, ``"gdpred"`` and ``"gdtrue"``, each holding
        ``[25th percentile, median, 75th percentile]`` of the per-problem
        query error. ``"gdpred"`` uses the notebook-level ``gd_model``.
    """
    ode_xs, ode_ys = ode_task_data["in"], ode_task_data["out"]

    # Convert into least squares: context rows give the system, the final row is the query
    A = ode_xs[:, :-1, :-1].to(device=device)  # all rows but the last, all columns but the last
    b = ode_xs[:, :-1, -1].to(device=device)   # last column of those rows
    A_q = ode_xs[:, -1, :-1].to(device=device) # last row, all columns but the last
    b_q = ode_ys.flatten().to(device=device)   # one target per problem
    n_problems = A.shape[0]

    # Least squares. The solve runs on CPU copies because the "gelsd" driver
    # is only available for CPU inputs.
    x_ls_pred = torch.linalg.lstsq(A.cpu(), b.cpu(), driver="gelsd").solution
    x_ls_pred = x_ls_pred.to(device=device).reshape((n_problems, -1))
    mse_ls = (torch.einsum("bd,bd->b", x_ls_pred, A_q) - b_q).square().detach().cpu().numpy()
    mse_ls_25 = np.percentile(mse_ls, 25, axis=-1)
    mse_ls_75 = np.percentile(mse_ls, 75, axis=-1)
    mse_ls_50 = np.median(mse_ls, axis=-1)

    # Gradient descent: once with model-predicted gradients, once with exact gradients
    mses_gdpred = gradient_descent(A, b, A_q, b_q, num_iters=gd_num_iters, lr=gd_lr, model=gd_model, setting="pred")
    mses_gdtrue = gradient_descent(A, b, A_q, b_q, num_iters=gd_num_iters, lr=gd_lr, model=None, setting="true")

    out_dict = {
        "ls": [mse_ls_25, mse_ls_50, mse_ls_75],
        "gdpred": mses_gdpred,
        "gdtrue": mses_gdtrue,
    }
    return out_dict

In [24]:
def eval_ood_least_squares(ood_key=None, ood_factor=1, gd_lr=0.05):
    """Compare least squares, GD, predicted-gradient GD and the Transformer on ODE data.

    A fresh ``ODEOperatorSampler`` is built with the parameter named by
    ``ood_key`` rescaled by ``ood_factor``; the other parameters keep the
    values hard-coded below. One batch is sampled, converted by the
    notebook-level ``ode_task``, and handed to ``least_squares_solvers``. The
    Transformer selected by the notebook-level ``save_key`` is scored
    separately through ``eval_ood``, which draws its own batch.

    ``gd_lr`` is forwarded to ``least_squares_solvers``. Previously a
    hard-coded 0.01 was passed and this argument had no effect. The sampler is
    created on the notebook-level ``device`` instead of always on ``"cuda"``,
    so the function also runs on CPU-only machines.

    Args:
        ood_key: One of ``"operator_scale"``, ``"gp_length"``, ``"u0_scale"``,
            or ``None`` for the unshifted setting.
        ood_factor: Multiplier applied to the selected parameter.
        gd_lr: Step size for both gradient-descent baselines.

    Returns:
        Dict with keys ``"ls"``, ``"gdpred"``, ``"gdtrue"`` and ``"attn"``,
        each holding ``[25th percentile, median, 75th percentile]`` of the MSE.
    """
    # Define sampler
    ode_data_sampler = ODEOperatorSampler(
        n_dims=D,
        c_sampling="equispaced", # equispaced, cheb, randt
        eqn_class=1,
        operator_scale=0.1*ood_factor if ood_key=="operator_scale" else 0.1,
        gp_length=ood_factor if ood_key=="gp_length" else 1,
        u0_scale=ood_factor if ood_key=="u0_scale" else 1,
        seed=0,
        device=device,
    )

    # Sample
    ode_data = ode_data_sampler.sample(
        batch_size=batch_size,
        n_points=L,
        n_dims_truncated=D,
    )

    # Write task data
    ode_task_data = ode_task.evaluate(ode_data)

    # Eval the classical baselines with the caller's learning rate
    out_dict = least_squares_solvers(ode_task_data, gd_lr=gd_lr)

    # Transformer (checkpoint chosen by the notebook-level save_key)
    mse, eval_out = eval_ood(save_key, ood_key, ood_factor)
    # Collapse every axis except the first so there is one error per sample
    mse_batch = torch.mean(eval_out, dim=list(range(1, len(eval_out.shape)))).detach().cpu().numpy()
    mse_25 = np.percentile(mse_batch, 25, axis=-1)
    mse_75 = np.percentile(mse_batch, 75, axis=-1)
    mse_50 = np.median(mse_batch, axis=-1)
    out_dict["attn"] = [mse_25, mse_50, mse_75]

    return out_dict

In [None]:
# Run gradient descent and least squares
# ood_key=None evaluates the unshifted setting (no sampler parameter is rescaled).
# ode_dict maps a method key ("ls", "gdpred", "gdtrue", "attn") to
# [25th percentile, median, 75th percentile] of that method's MSE.
ode_dict = eval_ood_least_squares(ood_key=None, ood_factor=1, gd_lr=0.05)

### Plot: in-distribution Transformer, BaseConv, GD, and least squares

In [None]:
# Display name for each method key in ode_dict
key_dict = {
    "attn": "Transformer",
    "ls": "Least squares",
    "gdpred": "BaseConv", 
    "gdtrue": "GD",
}

# Marker / error-bar colour for each method key
color_dict = {
    "attn": "orange",
    "ls": "green",
    "gdpred": "blue",
    "gdtrue": "grey",
}

plt.figure(figsize=(12,12))

# One error bar per method: marker at the median, whiskers to the quartiles
x = np.arange(len(ode_dict))
for position, method in enumerate(ode_dict):
    lower, median, upper = ode_dict[method]
    whiskers = [[median-lower], [upper-median]]

    plt.errorbar(
        position,
        median,
        yerr=whiskers,
        fmt='o',
        capsize=12,
        capthick=5,
        markersize=24,
        color=color_dict[method],
        label=key_dict[method],
        linewidth=5,
        markeredgewidth=4,
    )

# Axes, labels and legend
plt.yscale('log')
plt.grid(True, which='both', linestyle='--', linewidth=1.0)
tick_labels = [key_dict[method] for method in ode_dict]
plt.xticks(x, tick_labels, rotation=45, fontsize=24)
plt.yticks(fontsize=24)
plt.ylabel('MSE', fontsize=32)
plt.title('Error Comparison', fontsize=32)
plt.legend(fontsize=24, markerscale=0.5)
plt.tight_layout()

# Save before show so the file contains the rendered figure
plt.savefig(os.path.join(save_dir, "iclodes_all_indistribution.png"))
plt.show()