Example/Implementation of Neural Controlled Differential Equations #408

Open
johnp-4dvanalytics opened this issue Sep 8, 2020 · 24 comments

Comments

@johnp-4dvanalytics

The paper provides a good method for encoding data with an ODE. This would be useful as the encoder of an encoder/decoder architecture, as an alternative to an RNN encoder.

https://arxiv.org/abs/2005.08926
https://github.com/patrick-kidger/NeuralCDE

I have played around with the example in the code repo, but it is very slow and could probably be significantly faster if written with DiffEqFlux.
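For readers new to the model: in the paper, a hidden state z evolves as z_t = z_0 + ∫ f_θ(z_s) dX_s, where X is a continuous path interpolated from the observations. Below is a minimal sketch of that update. It assumes a plain explicit Euler scheme on the observation grid and uses a hypothetical constant matrix field (toy_field) in place of the neural network f_θ. Neither euler_cde nor toy_field comes from the paper's code.

```python
from typing import Callable, List, Sequence

Vector = List[float]
Matrix = List[List[float]]  # shape (hidden_channels, input_channels)


def matvec(m: Matrix, v: Sequence[float]) -> Vector:
    """Multiply a (hidden x input) matrix by an input-sized vector."""
    return [sum(mij * vj for mij, vj in zip(row, v)) for row in m]


def euler_cde(f: Callable[[Vector], Matrix], z0: Vector,
              xs: Sequence[Sequence[float]]) -> Vector:
    """Explicit Euler for dz = f(z) dX over the observations xs.

    Each path increment dX_k = x_{k+1} - x_k is mapped into hidden space by
    the matrix f(z_k): the discrete analogue of z_t = z_0 + int f(z_s) dX_s.
    """
    z = list(z0)
    for x_prev, x_next in zip(xs, xs[1:]):
        dx = [b - a for a, b in zip(x_prev, x_next)]
        dz = matvec(f(z), dx)
        z = [zi + dzi for zi, dzi in zip(z, dz)]
    return z


def toy_field(z: Vector) -> Matrix:
    """Hypothetical stand-in for f_theta: 2 hidden channels, 2 input channels."""
    return [[1.0, 0.0], [0.0, 0.5]]


# Three observations of a 2-channel series; the final z is the encoding.
encoding = euler_cde(toy_field, [0.0, 0.0], [[0.0, 0.0], [1.0, 2.0], [3.0, 2.0]])
```

Whatever the sequence length, the result is a fixed-size hidden vector. That property is what lets the model stand in for an RNN encoder.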

@ChrisRackauckas
Member

Yeah, that could be a nice model to implement. Let me know if you need any help optimizing it.

@johnp-4dvanalytics
Author

johnp-4dvanalytics commented Sep 10, 2020

@ChrisRackauckas I was able to get a version of this working although it is very slow. If there are any optimizations that stand out please let me know. Thanks!

Here is the code for it:

using Statistics
using DataInterpolations
using Flux
using OrdinaryDiffEq
using DiffEqSensitivity
using Zygote
using ProgressBars
using Random

T = Float32

bs = 512
X = [rand(T, 10, 50) for _ in 1:bs*10]

function create_spline(i)
    x = X[i]
    t = x[end, :]
    t = (t .- minimum(t)) ./ (maximum(t) - minimum(t))

    spline = QuadraticInterpolation(x, t)
end

splines = [create_spline(i) for i in tqdm(1:length(X))]

rand_inds = randperm(length(X))

i_sz = size(X[1], 1)
h_sz = 16

use_gpu = true
batches = [[splines[rand_inds[(i-1)*bs+1:i*bs]]] for i in tqdm(1:length(X)÷bs)]

data_ = Iterators.cycle(batches)

function call_and_cat(splines, t)
    vals = Zygote.ignore() do
        vals = hcat([spline(t) for spline in splines]...)
    end
    vals |> (use_gpu ? gpu : cpu)
end

function derivative(A::QuadraticInterpolation, t::Number)
    idx = findfirst(x -> x >= t, A.t) - 1
    idx == 0 ? idx += 1 : nothing
    if idx == length(A.t) - 1
        i₀ = idx - 1; i₁ = idx; i₂ = i₁ + 1;
    else
        i₀ = idx; i₁ = i₀ + 1; i₂ = i₁ + 1;
    end
    dl₀ = (2t - A.t[i₁] - A.t[i₂]) / ((A.t[i₀] - A.t[i₁]) * (A.t[i₀] - A.t[i₂]))
    dl₁ = (2t - A.t[i₀] - A.t[i₂]) / ((A.t[i₁] - A.t[i₀]) * (A.t[i₁] - A.t[i₂]))
    dl₂ = (2t - A.t[i₀] - A.t[i₁]) / ((A.t[i₂] - A.t[i₀]) * (A.t[i₂] - A.t[i₁]))
    A.u[:, i₀] * dl₀ + A.u[:, i₁] * dl₁ + A.u[:, i₂] * dl₂
end

function derivative_call_and_cat(splines, t)
    vals = Zygote.ignore() do
        vals = hcat([derivative(spline, t) for spline in splines]...)
    end
    vals |> (use_gpu ? gpu : cpu)
end

cde = Chain(
    Dense(h_sz, h_sz, relu),
    Dense(h_sz, h_sz*i_sz, tanh),
) |> (use_gpu ? gpu : cpu)

h_to_out = Dense(h_sz, 2) |> (use_gpu ? gpu : cpu)

initial = Dense(i_sz, h_sz) |> (use_gpu ? gpu : cpu)

cde_p, cde_re = Flux.destructure(cde)
initial_p, initial_re = Flux.destructure(initial)
h_to_out_p, h_to_out_re = Flux.destructure(h_to_out)

basic_tgrad(u,p,t) = zero(u)

function predict_func(p, BX)
    By = call_and_cat(BX, 1)

    x0 = call_and_cat(BX, 0)
    i = 1
    j = (i-1)+length(initial_p)

    h0 = initial_re(p[i:j])(x0)

    function dhdt(h,p,t)
        x = derivative_call_and_cat(BX, t)
        bs = size(h, 2)
        a = reshape(cde_re(p)(h), (i_sz, h_sz, bs))
        b = reshape(x, (1, i_sz, bs))

        dh = batched_mul(b,a)[1,:,:]
    end

    i = j+1
    j = (i-1)+length(cde_p)

    tspan = (0.0f0, 0.8f0)

    ff = ODEFunction{false}(dhdt,tgrad=basic_tgrad)
    prob = ODEProblem{false}(ff,h0,tspan,p[i:j])
    sense = InterpolatingAdjoint(autojacvec=ZygoteVJP())
    solver = Tsit5()

    h = solve(prob,solver,u0=h0,saveat=tspan[end], save_start=false).u[end]

    i = j+1
    j = (i-1)+length(h_to_out_p)

    y_hat = h_to_out_re(p[i:j])(h)

    y_hat, By[1:2, :]
end

function loss_func(p, BX)
    y_hat, y = predict_func(p, BX)

    # Mean over the batch of the per-sample Euclidean (L2) error.
    mean(sqrt.(sum((y .- y_hat).^2, dims=1)))
end

p = vcat(initial_p, cde_p, h_to_out_p)

callback = function (p, l)
  display(l)
  return false
end

using DiffEqFlux

result_neuralode = DiffEqFlux.sciml_train(loss_func, p, ADAM(0.05),
    data_,
    cb = callback,
    maxiters = 10)

@ChrisRackauckas
Member

Hey, here's an updated version, with comments at the end on what was changed and the resulting timings. That little training step improved by about 6.5x:

using Statistics
using DataInterpolations
using Flux
using OrdinaryDiffEq
using DiffEqSensitivity
using Zygote
using ProgressBars
using Random

T = Float32

bs = 512
X = [rand(T, 10, 50) for _ in 1:bs*10]

function create_spline(i)
    x = X[i]
    t = x[end, :]
    t = (t .- minimum(t)) ./ (maximum(t) - minimum(t))

    spline = QuadraticInterpolation(x, t)
end

splines = [create_spline(i) for i in tqdm(1:length(X))]

rand_inds = randperm(length(X))

i_sz = size(X[1], 1)
h_sz = 16

use_gpu = true
batches = [[splines[rand_inds[(i-1)*bs+1:i*bs]]] for i in tqdm(1:length(X)÷bs)]

data_ = Iterators.cycle(batches)

function call_and_cat(splines, t)
    vals = Zygote.ignore() do
        vals = reduce(hcat,[spline(t) for spline in splines])
    end
    vals |> (use_gpu ? gpu : cpu)
end

function derivative(A::QuadraticInterpolation, t::Number)
    idx = findfirst(x -> x >= t, A.t) - 1
    idx == 0 ? idx += 1 : nothing
    if idx == length(A.t) - 1
        i₀ = idx - 1; i₁ = idx; i₂ = i₁ + 1;
    else
        i₀ = idx; i₁ = i₀ + 1; i₂ = i₁ + 1;
    end
    dl₀ = (2t - A.t[i₁] - A.t[i₂]) / ((A.t[i₀] - A.t[i₁]) * (A.t[i₀] - A.t[i₂]))
    dl₁ = (2t - A.t[i₀] - A.t[i₂]) / ((A.t[i₁] - A.t[i₀]) * (A.t[i₁] - A.t[i₂]))
    dl₂ = (2t - A.t[i₀] - A.t[i₁]) / ((A.t[i₂] - A.t[i₀]) * (A.t[i₂] - A.t[i₁]))
    @views @. A.u[:, i₀] * dl₀ + A.u[:, i₁] * dl₁ + A.u[:, i₂] * dl₂
end

function derivative_call_and_cat(splines, t)
    vals = Zygote.ignore() do
        reduce(hcat,[derivative(spline, t) for spline in splines]) |> (use_gpu ? gpu : cpu)
    end
end

cde = Chain(
    Dense(h_sz, h_sz, relu),
    Dense(h_sz, h_sz*i_sz, tanh),
) |> (use_gpu ? gpu : cpu)

h_to_out = Dense(h_sz, 2) |> (use_gpu ? gpu : cpu)

initial = Dense(i_sz, h_sz) |> (use_gpu ? gpu : cpu)

cde_p, cde_re = Flux.destructure(cde)
initial_p, initial_re = Flux.destructure(initial)
h_to_out_p, h_to_out_re = Flux.destructure(h_to_out)

basic_tgrad(u,p,t) = zero(u)

function predict_func(p, BX)
    By = call_and_cat(BX, 1)

    x0 = call_and_cat(BX, 0)
    i = 1
    j = (i-1)+length(initial_p)

    h0 = initial_re(p[i:j])(x0)

    function dhdt(h,p,t)
        x = derivative_call_and_cat(BX, t)
        bs = size(h, 2)
        a = reshape(cde_re(p)(h), (i_sz, h_sz, bs))
        b = reshape(x, (1, i_sz, bs))

        dh = batched_mul(b,a)[1,:,:]
    end

    i = j+1
    j = (i-1)+length(cde_p)

    tspan = (0.0f0, 0.8f0)

    ff = ODEFunction{false}(dhdt,tgrad=basic_tgrad)
    prob = ODEProblem{false}(ff,h0,tspan,p[i:j])
    sense = InterpolatingAdjoint(autojacvec=ZygoteVJP())
    solver = Tsit5()

    sol = solve(prob,solver,u0=h0,saveat=tspan[end], save_start=false, sensealg=sense)
    #@show sol.destats
    i = j+1
    j = (i-1)+length(h_to_out_p)

    y_hat = h_to_out_re(p[i:j])(sol[end])

    y_hat, By[1:2, :]
end

function loss_func(p, BX)
    y_hat, y = predict_func(p, BX)

    # Mean over the batch of the per-sample Euclidean (L2) error.
    mean(sqrt.(sum((y .- y_hat).^2, dims=1)))
end

p = vcat(initial_p, cde_p, h_to_out_p)

callback = function (p, l)
  display(l)
  return false
end

using DiffEqFlux

Zygote.gradient((p)->loss_func(p, first(data_)...),p)
@time Zygote.gradient((p)->loss_func(p, first(data_)...),p)

@time result_neuralode = DiffEqFlux.sciml_train(loss_func, p, ADAM(0.05),
    data_,
    cb = callback,
    maxiters = 10)

# Start
# 13.178288 seconds (40.04 M allocations: 3.362 GiB, 5.01% gc time)

# Reduce(hcat)
# 6.273153 seconds (21.84 M allocations: 2.443 GiB, 7.86% gc time)

# @views @.
# 3.315527 seconds (5.05 M allocations: 495.695 MiB, 3.81% gc time)

# gpu in do
# 2.652512 seconds (5.11 M allocations: 466.430 MiB, 2.64% gc time)

# Training time before:
# 199.442675 seconds (218.19 M allocations: 23.603 GiB, 66.46% gc time)

# Training time after:
# 30.587359 seconds (58.69 M allocations: 5.210 GiB, 3.74% gc time)
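The speedups implied by the timing comments at the end of that script can be recomputed directly. The numbers below are copied from those comments, and nothing else is assumed:

```python
# Wall-clock times (seconds) copied from the timing comments in the script above.
grad_timings = [
    ("start", 13.178288),
    ("reduce(hcat)", 6.273153),
    ("@views @.", 3.315527),
    ("gpu in do", 2.652512),
]
train_before, train_after = 199.442675, 30.587359


def speedup(before: float, after: float) -> float:
    """Ratio of old to new wall-clock time; values above 1 mean faster."""
    return before / after


# Speedup contributed by each successive change to the gradient call.
for (name_a, t_a), (name_b, t_b) in zip(grad_timings, grad_timings[1:]):
    print(f"{name_a} -> {name_b}: {speedup(t_a, t_b):.2f}x")

grad_speedup = speedup(grad_timings[0][1], grad_timings[-1][1])
train_speedup = speedup(train_before, train_after)
print(f"gradient overall: {grad_speedup:.1f}x, training: {train_speedup:.1f}x")
```

This works out to roughly 5.0x for the single gradient call and roughly 6.5x for the 10-iteration training run. The reduce(hcat) change contributes the largest single gain (about 2.1x).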

The rate-limiting step here is that the spline data is on the CPU while your computations are on the GPU, so the most costly portion now is simply moving the spline output to the GPU. That's now something like 90% of the cost, so that's the problem you'd have to tackle next (I was only giving myself 30 minutes to play with this). One way to do this would be to make your quadratic spline asynchronously pre-cache some of the next time points onto the GPU while other computations are taking place. Or, even cooler, train a neural network to mimic the spline but live entirely on the GPU, and then use that in place of the spline. Either way, shipping that much data every step is going to dominate the computation, so it has to be dealt with somehow.
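One way to read the pre-caching suggestion is as a producer/consumer pipeline. A background worker evaluates the spline for upcoming time points (and, in the real setting, copies the result to the GPU) while the solver works on the current one. Below is a minimal Python sketch under two assumptions. First, the evaluation times are known in advance; that holds for a fixed-step solver or a precomputed grid, but not for an adaptive solver that picks its next t on the fly. Second, the evaluation releases the GIL, as GPU copies and most array libraries do; otherwise Python threads cannot overlap the work. prefetch and slow_derivative are hypothetical names.

```python
import queue
import threading
from typing import Callable, Iterable, Iterator, List, Tuple


def prefetch(evaluate: Callable[[float], List[float]],
             times: Iterable[float],
             depth: int = 4) -> Iterator[Tuple[float, List[float]]]:
    """Yield (t, evaluate(t)) in order, computing up to `depth` items ahead.

    `evaluate` stands in for "evaluate the spline derivative at t and copy
    the result to the device"; here it is an ordinary Python function.
    """
    buf: queue.Queue = queue.Queue(maxsize=depth)
    done = object()  # sentinel marking the end of the stream

    def producer() -> None:
        try:
            for t in times:
                buf.put((t, evaluate(t)))  # blocks while `depth` items wait
        except Exception as exc:  # forward errors to the consumer
            buf.put(exc)
            return
        buf.put(done)

    worker = threading.Thread(target=producer, daemon=True)
    worker.start()
    while True:
        item = buf.get()
        if item is done:
            break
        if isinstance(item, Exception):
            raise item
        yield item
    worker.join()


def slow_derivative(t: float) -> List[float]:
    """Hypothetical evaluation: derivative of the path x(t) = t**2."""
    return [2.0 * t]


results = list(prefetch(slow_derivative, [0.0, 0.25, 0.5, 0.75], depth=2))
```

If the consumer stops iterating early, the daemon producer is left blocked on the full queue. A production version would need explicit cancellation.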

What's the baseline you want to beat here? Do you have that code around to time?

@johnp-4dvanalytics
Author

Awesome, thanks for reviewing the code and giving those optimizations! That's quite a speed up.

I will try to figure out how to avoid moving the data from CPU to GPU every step.

Btw the author of the paper / repo above recently released this new repo for this type of model:
https://github.com/patrick-kidger/torchcde

So probably we would want to show it outperforming some examples in that repo.

The baseline I was comparing against was based off of the code in https://github.com/patrick-kidger/NeuralCDE/blob/master/example/example.py

I'll do a speed comparison against that and the optimized code you posted.

@johnp-4dvanalytics
Author

Here is the code I am using as the efficiency baseline:

On my system the baseline code takes about 6 seconds to run, whereas the optimized version you posted takes about 18 seconds. I think once we stop moving the data to the GPU at each step it should be a lot faster, though.

import controldiffeq
import math
import torch

class CDEFunc(torch.nn.Module):
    def __init__(self, input_channels, hidden_channels):
        super(CDEFunc, self).__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels

        self.linear1 = torch.nn.Linear(hidden_channels, 128)
        self.linear2 = torch.nn.Linear(128, input_channels * hidden_channels)

    def forward(self, z):
        z = self.linear1(z)
        z = z.relu()
        z = self.linear2(z)
        z = z.tanh()
        z = z.view(*z.shape[:-1], self.hidden_channels, self.input_channels)
        return z


class NeuralCDE(torch.nn.Module):
    def __init__(self, input_channels, hidden_channels, output_channels):
        super(NeuralCDE, self).__init__()
        self.hidden_channels = hidden_channels

        self.func = CDEFunc(input_channels, hidden_channels)
        self.initial = torch.nn.Linear(input_channels, hidden_channels)
        self.readout = torch.nn.Linear(hidden_channels, output_channels)

    def forward(self, times, coeffs):
        spline = controldiffeq.NaturalCubicSpline(times, coeffs)

        z0 = self.initial(spline.evaluate(times[0]))

        z_T = controldiffeq.cdeint(dX_dt=spline.derivative,
                                   z0=z0,
                                   func=self.func,
                                   t=times[[0, int(len(times)*.8)]],
                                   atol=1e-2,
                                   rtol=1e-2)
        z_T = z_T[1]
        pred_y = self.readout(z_T)
        return pred_y

def get_data():
    X = torch.rand(512, 50, 10)
    t = torch.linspace(0., 1, X.shape[1])
    y = X[:, -1, :2]

    return t, X, y

def main():
    train_t, train_X, train_y = get_data()

    model = NeuralCDE(input_channels=train_X.shape[-1], hidden_channels=16, output_channels=2)
    optimizer = torch.optim.Adam(model.parameters())

    train_coeffs = controldiffeq.natural_cubic_spline_coeffs(train_t, train_X)
    
    import time
    from tqdm import tqdm
    
    start = time.time()
    for epoch in tqdm(range(10)):
        pred_y = model(train_t, train_coeffs).squeeze(-1)
        loss = (pred_y - train_y).norm(dim=-1).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        print('Epoch: {}   Training loss: {}'.format(epoch, loss.item()))
    end = time.time()
    print(end - start)

if __name__ == '__main__':
    main()

@ChrisRackauckas
Member

I don't see any shuttling to GPUs there: are the spline coefficients on the GPU in that implementation? If they are, that would make a massive difference.

Also, doing this as 1 spline instead of 5000 splines probably makes a decent difference.
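To illustrate the one-spline-versus-thousands point: if every series in a batch is observed on the same time grid, the interval search happens once per evaluation time, and the per-interval coefficients for the whole batch can be stored together. The Julia derivative helper above instead runs one findfirst per spline. The sketch below assumes a shared knot vector, which holds for the torchcde example but not for the synthetic Julia data, where each series gets its own normalized times. It uses piecewise-linear interpolation purely for brevity, and BatchedLinearPath is a hypothetical name.

```python
import bisect
from typing import List, Sequence


class BatchedLinearPath:
    """Piecewise-linear paths for a whole batch that share one knot vector.

    values[b][k] is series b at knot k (one channel per series, for brevity).
    """

    def __init__(self, knots: Sequence[float], values: Sequence[Sequence[float]]):
        self.knots = list(knots)
        # slopes[k][b]: slope of series b on interval k, computed once up front.
        self.slopes = [
            [(v[k + 1] - v[k]) / (self.knots[k + 1] - self.knots[k]) for v in values]
            for k in range(len(self.knots) - 1)
        ]

    def derivative(self, t: float) -> List[float]:
        """dX/dt for every series at time t (right-continuous at knots)."""
        # One interval search serves the whole batch.
        k = bisect.bisect_right(self.knots, t) - 1
        k = min(max(k, 0), len(self.slopes) - 1)  # clamp t outside the grid
        return list(self.slopes[k])


path = BatchedLinearPath([0.0, 1.0, 2.0], [[0.0, 1.0, 3.0], [0.0, -1.0, -1.0]])
```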

@johnp-4dvanalytics
Author

johnp-4dvanalytics commented Sep 10, 2020

This is the file where the spline code is defined:
https://github.com/patrick-kidger/NeuralCDE/blob/master/controldiffeq/interpolate.py#L229

It does seem like the spline operations are being done as a batch on the GPU:

def derivative(self, t):
    """Evaluates the derivative of the natural cubic spline at a point t, which should be a scalar tensor."""
    fractional_part, index = self._interpret_t(t)
    inner = self._two_c[..., index, :] + self._three_d[..., index, :] * fractional_part
    deriv = self._b[..., index, :] + inner * fractional_part
    return deriv
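On each interval, the natural cubic spline is a cubic a + b·s + c·s² + d·s³ in the offset s from the interval's left knot. Here s is assumed to be what _interpret_t returns as fractional_part. The derivative is therefore b + 2c·s + 3d·s². Judging by the names _two_c and _three_d, the snippet evaluates this in Horner form as b + (2c + 3d·s)·s, with 2c and 3d precomputed. Below is a scalar sketch of the same arithmetic with a finite-difference check. The function names are hypothetical.

```python
from typing import Tuple

Coeffs = Tuple[float, float, float, float]  # (a, b, c, d) for one interval


def cubic_value(coeffs: Coeffs, s: float) -> float:
    """a + b*s + c*s**2 + d*s**3, evaluated in Horner form."""
    a, b, c, d = coeffs
    return a + s * (b + s * (c + s * d))


def cubic_derivative(coeffs: Coeffs, s: float) -> float:
    """Derivative b + 2c*s + 3d*s**2, mirroring the PyTorch snippet."""
    _, b, c, d = coeffs
    two_c, three_d = 2.0 * c, 3.0 * d  # precomputable once per spline
    inner = two_c + three_d * s
    return b + inner * s


coeffs = (1.0, 2.0, -1.0, 0.5)
h = 1e-6
central_difference = (cubic_value(coeffs, 0.3 + h) - cubic_value(coeffs, 0.3 - h)) / (2 * h)
```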

@johnp-4dvanalytics
Author

I wrapped the PyTorch spline functions using PyCall and CUDA.jl and got a speedup of ~5.5x over the torch version: 92.7 seconds for 10 batches with torch versus ~17 seconds for 10 batches with DiffEqFlux.

Btw thanks for the awesome library!

Here's what the calls looked like (spl is the Python spline object):

x0, tspan = Zygote.ignore() do
    tspan = spl.interval.cpu().numpy()
    x0 = Base.unsafe_wrap(CuArray, CuPtr{Float32}(spl.evaluate(0).permute(1,0).contiguous().data_ptr()), (10, 512))
    x0, tspan
end

function dhdt(h,p,t)
    x = Zygote.ignore() do
        x = Base.unsafe_wrap(CuArray, CuPtr{Float32}(spl.derivative(t).permute(1,0).contiguous().data_ptr()), (10, 512))
    end
    ....

@ChrisRackauckas
Member

Awesome. So yeah, it would really be nice to get that directly implemented in Julia as a library function for people who want to use this method. It's the rate-limiting step.

@patrick-kidger

patrick-kidger commented Sep 14, 2020

Just been pointed at this. Three quick comments:

  • The PyTorch baseline you posted above is entirely on the CPU - it looks like you might be comparing timings between different devices.
  • I'm a bit dubious about the timings you posted. Suitably modifying the baseline you're using, I find:
    PyTorch CPU: 18 seconds; PyTorch GPU: 5 seconds.
    (Including the cost of a CPU-to-GPU copy of the coefficients on every epoch.)
  • I'd recommend against quadratic splines as used in the Julia examples. At least if these are linear-first-piece quadratic splines, they can in principle resonate, growing unboundedly from bounded data. This is one of the reasons for using natural cubic splines.
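To make the resonance point concrete, consider a C¹ quadratic spline whose first piece is linear. Each knot slope is fixed by the previous one through m_{k+1} = 2(y_{k+1} - y_k)/h - m_k. This recurrence has no damping, and its homogeneous mode flips sign at every knot, so data that alternates resonates with it. The sketch below shows the slopes growing without bound on data that merely alternates between 0 and 1. quadratic_spline_slopes is a hypothetical helper. It models the construction described in the bullet above, not necessarily DataInterpolations.jl's QuadraticInterpolation; judging by the Julia derivative helper earlier in the thread, that function works from three neighbouring points.

```python
from typing import List, Sequence


def quadratic_spline_slopes(ys: Sequence[float], h: float = 1.0) -> List[float]:
    """Knot slopes of a C1 quadratic spline with a linear first piece.

    On [t_k, t_k + h] the piece is y_k + m_k*s + a_k*s**2, with a_k chosen
    so the piece reaches y_{k+1}. Matching derivatives at t_{k+1} gives
    m_{k+1} = 2*(y_{k+1} - y_k)/h - m_k.
    """
    slopes = [(ys[1] - ys[0]) / h]  # a linear first piece means a_0 = 0
    for k in range(len(ys) - 1):
        slopes.append(2.0 * (ys[k + 1] - ys[k]) / h - slopes[k])
    return slopes


# Bounded data alternating between 0 and 1 ...
alternating = [float(k % 2) for k in range(12)]
# ... but the knot slopes grow linearly in magnitude: 1, 1, -3, 5, -7, ...
slopes = quadratic_spline_slopes(alternating)
```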

Anyway, I'm not Julia-proficient but let me know if I can help out over here.

@ChrisRackauckas
Member

ChrisRackauckas commented Sep 14, 2020

I think his latest version (the version that was timed) is using your spline code IIUC

@johnp-4dvanalytics
Author

johnp-4dvanalytics commented Sep 14, 2020

@patrick-kidger

In the Julia code I used for the benchmark I was doing a call out to your spline code so that I could do a consistent comparison.

I did have some errors in the python code I posted above that I fixed in my script that I used for the benchmarking. Here is the updated script:

class CDEFunc(torch.nn.Module):
    def __init__(self, input_channels, hidden_channels):
        super(CDEFunc, self).__init__()
        self.linear = torch.nn.Linear(hidden_channels,
                                      hidden_channels * input_channels)
        self.f1 = torch.nn.Linear(hidden_channels, hidden_channels)
        store_attr()

    def forward(self, t, z):
#         return (self.linear(z).view(len(z), self.hidden_channels, self.input_channels))
        z = F.relu(self.f1(z))
        
        return torch.tanh(self.linear(z).view(len(z), self.hidden_channels, self.input_channels))


class NeuralCDE(torch.nn.Module):
    def __init__(self, input_channels, hidden_channels, output_channels):
        super(NeuralCDE, self).__init__()

        self.func = CDEFunc(input_channels, hidden_channels)
        self.initial = torch.nn.Linear(input_channels, hidden_channels)
        self.readout = torch.nn.Linear(hidden_channels, output_channels)

    def forward(self, coeffs):
        spline = torchcde.NaturalCubicSpline(coeffs)
    
        z0 = self.initial(spline.evaluate(0.))
        
        zt = torchcde.cdeint(X=spline,
                              z0=z0,
                              func=self.func,
#                              t=t,
                              t=spline.interval,
                              method='dopri5',
#                               atol=1e-2,
#                               rtol=1e-2,
                              options=dict(grid_points=spline.grid_points, 
                                           eps=1e-5,
                                           
#                                            , dtype=torch.float32
                                          ))
        z_T = zt[:, -1]
        
        pred_y = self.readout(z_T)
        return pred_y

bs = 512
input_channels = X.size(2)
hidden_channels = 16  # hyperparameter, we can pick whatever we want for this

device="cuda"
torch_core.defaults.device = device

dls = DataLoaders(TfmdDL(train_ds, shuffle=True, bs=bs), 
                  TfmdDL(val_ds, shuffle=True, bs=bs), device=device)

model = NeuralCDE(input_channels=input_channels, hidden_channels=hidden_channels, output_channels=2).type(dtype).to(device)

optimizer = torch.optim.Adam(model.parameters())

import time
from tqdm import tqdm

start = time.time()
total = 0
for i, (BX, By) in enumerate(tqdm(dls[0])):
    pred_y = model(BX).squeeze(-1)
    loss = (pred_y - By).norm(dim=-1).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print('Batch: {}   Training loss: {}'.format(i+1, loss.item()))
    total += 1
    if total == 10:
        break
end = time.time()
end-start

@patrick-kidger

patrick-kidger commented Sep 14, 2020

TfmdDL (amongst others) isn't defined so I can't run that.

I wouldn't set the default device the way you're doing. In particular, this skips the CPU-to-GPU copy you would normally see when training a model.

You're using grid_points wrong: this should be passed when using linear interpolation, but not with cubic interpolation.

I suggest using https://github.com/patrick-kidger/torchcde/blob/master/example/example.py as a reference point.

@johnp-4dvanalytics
Author

Thanks for the clarification about grid_points, I'll fix that. I am using fastai to create the dataloaders, and it automatically puts the batches on the GPU. I'll post a minimal runnable script shortly with the fastai code taken out.

@johnp-4dvanalytics
Author

johnp-4dvanalytics commented Sep 14, 2020

Here are updated Python and Julia scripts based on the example you linked to:

EDIT: the loss for the Julia version isn't decreasing like the Python version, so I may have a bug with the model. I will try to fix that.

Python: Time to run the training part at the bottom: ~58 seconds

import math
import torch
import torchcde


######################
# A CDE model looks like
#
# z_t = z_0 + \int_0^t f_\theta(z_s) dX_s
#
# Where X is your data and f_\theta is a neural network. So the first thing we need to do is define such an f_\theta.
# That's what this CDEFunc class does.
# Here we've built a small single-hidden-layer neural network, whose hidden layer is of width 128.
######################
class CDEFunc(torch.nn.Module):
    def __init__(self, input_channels, hidden_channels):
        ######################
        # input_channels is the number of input channels in the data X. (Determined by the data.)
        # hidden_channels is the number of channels for z_t. (Determined by you!)
        ######################
        super(CDEFunc, self).__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels

        self.linear1 = torch.nn.Linear(hidden_channels, 128)
        self.linear2 = torch.nn.Linear(128, input_channels * hidden_channels)

    ######################
    # For most purposes the t argument can probably be ignored; unless you want your CDE to behave differently at
    # different times, which would be unusual. But it's there if you need it!
    ######################
    def forward(self, t, z):
        # z has shape (batch, hidden_channels)
        z = self.linear1(z)
        z = z.relu()
        z = self.linear2(z)
        ######################
        # Easy-to-forget gotcha: Best results tend to be obtained by adding a final tanh nonlinearity.
        ######################
        z = z.tanh()
        ######################
        # Ignoring the batch dimension, the shape of the output tensor must be a matrix,
        # because we need it to represent a linear map from R^input_channels to R^hidden_channels.
        ######################
        z = z.view(z.size(0), self.hidden_channels, self.input_channels)
        return z


######################
# Next, we need to package CDEFunc up into a model that computes the integral.
######################
class NeuralCDE(torch.nn.Module):
    def __init__(self, input_channels, hidden_channels, output_channels):
        super(NeuralCDE, self).__init__()

        self.func = CDEFunc(input_channels, hidden_channels)
        self.initial = torch.nn.Linear(input_channels, hidden_channels)
        self.readout = torch.nn.Linear(hidden_channels, output_channels)

    def forward(self, coeffs):
        X = torchcde.NaturalCubicSpline(coeffs)

        ######################
        # Easy to forget gotcha: Initial hidden state should be a function of the first observation.
        ######################
        z0 = self.initial(X.evaluate(0.))

        ######################
        # Actually solve the CDE.
        ######################
        z_T = torchcde.cdeint(X=X,
                              z0=z0,
                              func=self.func,
                              t=X.interval)

        ######################
        # Both the initial value and the terminal value are returned from cdeint; extract just the terminal value,
        # and then apply a linear map.
        ######################
        z_T = z_T[:, 1]
        pred_y = self.readout(z_T)
        return pred_y


######################
# Now we need some data.
# Here we have a simple example which generates some spirals, some going clockwise, some going anticlockwise.
######################
def get_data():
    t = torch.linspace(0., 4 * math.pi, 100)

    start = torch.rand(128) * 2 * math.pi
    x_pos = torch.cos(start.unsqueeze(1) + t.unsqueeze(0)) / (1 + 0.5 * t)
    x_pos[:64] *= -1
    y_pos = torch.sin(start.unsqueeze(1) + t.unsqueeze(0)) / (1 + 0.5 * t)
    x_pos += 0.01 * torch.randn_like(x_pos)
    y_pos += 0.01 * torch.randn_like(y_pos)
    ######################
    # Easy to forget gotcha: time should be included as a channel; Neural CDEs need to be explicitly told the
    # rate at which time passes. Here, we have a regularly sampled dataset, so appending time is pretty simple.
    ######################
    X = torch.stack([t.unsqueeze(0).repeat(128, 1), x_pos, y_pos], dim=2)
    y = torch.zeros(128)
    y[:64] = 1

    perm = torch.randperm(128)
    X = X[perm]
    y = y[perm]

    ######################
    # X is a tensor of observations, of shape (batch=128, sequence=100, channels=3)
    # y is a tensor of labels, of shape (batch=128,), either 0 or 1 corresponding to anticlockwise or clockwise respectively.
    ######################
    return X, y


def main(num_epochs=30):
    device = "cuda"
    dtype = torch.float32
    train_X, train_y = get_data()
    train_X, train_y = train_X.type(dtype), train_y.type(dtype)

    ######################
    # input_channels=3 because we have both the horizontal and vertical position of a point in the spiral, and time.
    # hidden_channels=8 is the number of hidden channels for the evolving z_t, which we get to choose.
    # output_channels=1 because we're doing binary classification.
    ######################
    model = NeuralCDE(input_channels=3, hidden_channels=8, output_channels=1).type(dtype).to(device)
    optimizer = torch.optim.Adam(model.parameters())

    ######################
    # Now we turn our dataset into a continuous path. We do this here via natural cubic spline interpolation.
    # The resulting `train_coeffs` is a tensor describing the path.
    # For most problems, it's probably easiest to save this tensor and treat it as the dataset.
    ######################
    train_coeffs = torchcde.natural_cubic_spline_coeffs(train_X)

    train_dataset = torch.utils.data.TensorDataset(train_coeffs, train_y)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)
    for epoch in range(num_epochs):
        for batch in train_dataloader:
            batch_coeffs, batch_y = batch
            batch_coeffs, batch_y = batch_coeffs.to(device), batch_y.to(device)
            pred_y = model(batch_coeffs).squeeze(-1)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(pred_y, batch_y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        print('Epoch: {}   Training loss: {}'.format(epoch, loss.item()))
        
def return_data():
    device = "cuda"
    dtype = torch.float32
    train_X, train_y = get_data()
    train_X, train_y = train_X.type(dtype), train_y.type(dtype)
    train_coeffs = torchcde.natural_cubic_spline_coeffs(train_X)

    train_dataset = torch.utils.data.TensorDataset(train_coeffs, train_y)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)
    data = [(torchcde.NaturalCubicSpline(batch[0]), batch[1]) for batch in train_dataloader]
    
    return data


import pickle as p

#get data for spline to use in Julia code
data = return_data()
p.dump(data, open("cde_data.p", "wb"))

import time
start = time.time()
main(10)
end = time.time()
end-start

Julia code using the spline from Python. Time to run training: ~27 seconds

using Statistics
using DataInterpolations
using Flux
using OrdinaryDiffEq
using DiffEqSensitivity
using Zygote
using ProgressBars
using Random
using CUDA
using PyCall

py"""
    import torch
    import pickle as p

    torch.cuda.set_device(0)

    spl_targ = p.load(open("cde_data.p", "rb"))
    spl = [x[0].cuda() for x in spl_targ]
    targ = [x[1].cuda() for x in spl_targ]
    bs = targ[0].shape[0]
    """

bs = py"bs"

T = Float32

i_sz = 3
h_sz = 8
o_sz = 1

use_gpu = true

batches = [(py"spl"[i], py"targ"[i]) for i in 1:py"len(spl)"]

data_ = Iterators.cycle(batches)

cde = Chain(
    Dense(h_sz, h_sz, relu),
    Dense(h_sz, h_sz*i_sz, tanh),
) |> (use_gpu ? gpu : cpu)

h_to_out = Dense(h_sz, o_sz) |> (use_gpu ? gpu : cpu)

initial = Dense(i_sz, h_sz) |> (use_gpu ? gpu : cpu)

cde_p, cde_re = Flux.destructure(cde)
initial_p, initial_re = Flux.destructure(initial)
h_to_out_p, h_to_out_re = Flux.destructure(h_to_out)

basic_tgrad(u,p,t) = zero(u)

function predict_func(p, spl)
    x0, tspan = Zygote.ignore() do
        tspan = spl.interval.cpu().numpy()
        x0 = Base.unsafe_wrap(CuArray, CuPtr{Float32}(spl.evaluate(0).permute(1,0).contiguous().data_ptr()), (i_sz, bs))
        x0, tspan
    end
    i = 1
    j = (i-1)+length(initial_p)

    local batch_size = bs

    h0 = initial_re(p[i:j])(x0)

    function dhdt(h,p,t)
        x = Zygote.ignore() do
            x = Base.unsafe_wrap(CuArray, CuPtr{Float32}(spl.derivative(t).permute(1,0).contiguous().data_ptr()), (i_sz, batch_size))
        end
        bs = size(h, 2)
        a = reshape(cde_re(p)(h), (i_sz, h_sz, bs))
        b = reshape(x, (1, i_sz, bs))

        dh = batched_mul(b,a)[1,:,:]
    end

    i = j+1
    j = (i-1)+length(cde_p)

    ff = ODEFunction{false}(dhdt,tgrad=basic_tgrad)
    prob = ODEProblem{false}(ff,h0,tspan,p[i:j])
    sense = InterpolatingAdjoint(autojacvec=ZygoteVJP())
    solver = Tsit5()

    sol = solve(prob,solver,u0=h0,saveat=tspan[end], save_start=false, sensealg=sense)
    i = j+1
    j = (i-1)+length(h_to_out_p)

    y_hat = h_to_out_re(p[i:j])(sol[end])

    y_hat
end

function loss_func(p, spl, targ)
    y_hat = predict_func(p, spl)

    y = Zygote.ignore() do
        y = Base.unsafe_wrap(CuArray, CuPtr{Float32}(targ.cuda().unsqueeze(-1).permute(1,0).contiguous().data_ptr()), (o_sz, bs))
    end

    return Flux.Losses.logitbinarycrossentropy(y_hat, y, agg=mean)
end

p = vcat(initial_p, cde_p, h_to_out_p)

callback = function (p, l)
  display(l)
  return false
end

predict_func(p, first(data_)[1])

using DiffEqFlux

Zygote.gradient((p)->loss_func(p, first(data_)...),p)
@time Zygote.gradient((p)->loss_func(p, first(data_)...),p)

@time result_neuralode = DiffEqFlux.sciml_train(loss_func, p, ADAM(0.05),
    data_,
    cb = callback,
    maxiters = 10*length(batches))

I suspect the gap between the Julia and Python code may widen in Julia's favor with larger batch sizes, hidden sizes, etc. but I still think doing something with DiffEqGPU will probably be the way to get a major speed up. I'll be trying some of the ideas Chris mentioned in the other thread to see if I can get a speed up with that.

@ChrisRackauckas
Member

I wouldn't get your hopes up there.

think doing something with DiffEqGPU will probably be the way to get a major speed up

That's for a different use case, i.e. ensembles of small ODEs.

I suspect the gap between the Julia and Python code may widen in Julia's favor with larger batch sizes, hidden sizes, etc.

On the contrary, it would probably shrink, since sooner or later the rate-limiting step will be the cost of the GPU kernels, and if both are calling into cuBLAS, that cost will be the same. So if they take the same number of steps (which they likely don't, due to some stabilizing tricks, but those are like 50% gains), you'd expect the cost to be the same. It's when the kernels aren't fully saturated that the extra codegen and fusion matter.

@johnp-4dvanalytics
Author

Okay, thanks for keeping me from going down that route. I was mainly thinking that if each trajectory has its own solver, it should need far fewer function calls, whereas (if I'm not mistaken) the batched version has to take a step whenever any of the trajectories needs one, so the current batched version does many more function calls. Is there a good way to reduce the number of function calls?
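The intuition in this question can be checked on a toy problem. The sketch below uses a simple adaptive Heun/Euler pair on the linear test equation dy/dt = -λy as a stand-in for the CDE vector field; this is not what Tsit5 or dopri5 do internally. Solving a batch together means one shared step size, controlled by the worst error in the batch. As a result, the total per-trajectory work typically exceeds solving each trajectory on its own. adaptive_steps is a hypothetical name.

```python
from typing import Sequence


def adaptive_steps(lams: Sequence[float], t_end: float = 1.0,
                   tol: float = 1e-4) -> int:
    """Steps (accepted + rejected) taken by an adaptive Heun/Euler solver on
    dy/dt = -lam*y, y(0) = 1, solving every lam in `lams` as one batch.

    The batch shares one step size, and the error estimate is the largest
    error over the batch, as in a batched neural-CDE solve.
    """
    y = [1.0] * len(lams)
    t, dt, steps = 0.0, 0.1, 0
    while t_end - t > 1e-12:
        dt = min(dt, t_end - t)
        k1 = [-lam * yi for lam, yi in zip(lams, y)]
        y_euler = [yi + dt * a for yi, a in zip(y, k1)]
        k2 = [-lam * ye for lam, ye in zip(lams, y_euler)]
        y_heun = [yi + 0.5 * dt * (a + b) for yi, a, b in zip(y, k1, k2)]
        err = max(abs(yh - ye) for yh, ye in zip(y_heun, y_euler))
        steps += 1
        if err <= tol:  # accept the step
            t += dt
            y = y_heun
        # Standard update for a first-order error estimate, clipped to [0.2, 2].
        dt *= min(2.0, max(0.2, 0.9 * (tol / max(err, 1e-16)) ** 0.5))
    return steps


lams = [0.5, 1.0, 40.0]  # one fast-decaying trajectory in the batch
individual = [adaptive_steps([lam]) for lam in lams]
batched = adaptive_steps(lams)
# Each batched step evaluates the vector field for every trajectory.
batched_work, separate_work = batched * len(lams), sum(individual)
```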

@ChrisRackauckas
Member

Oh, I read you wrong. If your goal is to make use of, say, 4 GPUs by running 4 trajectories at a time on different GPUs, then yeah, DiffEqGPU isn't the right tool, but EnsembleDistributed with each Julia process on a different GPU will do this. You could also try to pack multiple trajectories onto the same GPU, but sooner or later you'll become memory-limited.

Is there a good way to reduce the number of function calls?

That's a great research question. In the general context, that's just developing a "better" differential equation solver, which can be hard work given how heavily they've been optimized, but there are still some tricks no one has tried, and we will have some new methods coming out soonish. In the context of training a CDE, though, there are other tricks one can employ. For example, you don't necessarily need to fit the ODE solves themselves: you can regularize to find solutions that are fast to solve, you can drop accuracy and only increase it after descending a bit, etc.
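Below is a minimal sketch of the second trick: solve loosely at first, then tighten the tolerance once training has progressed. It assumes a fixed geometric schedule over epochs rather than one driven by the loss. tolerance_schedule and its defaults are hypothetical choices, not a DiffEqFlux feature. Each epoch's value would be passed as reltol/abstol to solve in Julia, or as rtol/atol to torchcde.cdeint.

```python
import math
from typing import List


def tolerance_schedule(num_epochs: int, loose: float = 1e-2,
                       tight: float = 1e-6, warmup: int = 0) -> List[float]:
    """Relative tolerance per epoch: `loose` for the first `warmup` epochs,
    then tightened geometrically, reaching `tight` at the final epoch when
    num_epochs > warmup + 1."""
    if num_epochs < 1:
        raise ValueError("num_epochs must be positive")
    span = max(num_epochs - warmup - 1, 1)
    log_loose, log_tight = math.log10(loose), math.log10(tight)
    schedule = []
    for epoch in range(num_epochs):
        frac = min(max((epoch - warmup) / span, 0.0), 1.0)
        schedule.append(10 ** (log_loose + frac * (log_tight - log_loose)))
    return schedule


rtols = tolerance_schedule(num_epochs=5, warmup=1)
```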

@patrick-kidger

Some discrepancies:

  • At least for the Python version, you're also timing the creation of the data and of the model, which shouldn't be measured.
  • You're using dopri5 with Python and Tsit5 with Julia.
  • It looks like the Julia version gets a one-iteration warm-up before measuring. Ideally both should get a warm-up.
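A timing harness that addresses all three points might look like the sketch below: setup (data and model creation) runs outside the timer, both implementations get the same warm-up, and the result is a mean and spread over repeats. `benchmark`, `setup`, and `step` are hypothetical names; matching the solver (dopri5 vs Tsit5) is a separate choice the harness cannot fix.

```python
import statistics
import time


def benchmark(setup, step, n_warmup=1, n_repeats=5):
    """Return (mean, stdev) in seconds of step(state), excluding setup and warm-up.

    setup() builds the data and model once and returns them as `state`;
    step(state) is the work being measured (e.g. one training epoch).
    """
    state = setup()                 # not timed: data and model creation
    for _ in range(n_warmup):       # not timed: JIT compilation, caches, allocator warm-up
        step(state)
    times = []
    for _ in range(n_repeats):
        start = time.perf_counter()
        step(state)
        times.append(time.perf_counter() - start)
    spread = statistics.stdev(times) if len(times) > 1 else 0.0
    return statistics.mean(times), spread


if __name__ == "__main__":
    def setup():
        return list(range(100_000))   # stand-in for building data and a model

    def step(data):
        sum(x * x for x in data)      # stand-in for one training epoch

    mean_s, std_s = benchmark(setup, step, n_warmup=1, n_repeats=5)
    print(f"{mean_s:.4f} s +/- {std_s:.4f} s")
```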

@ChrisRackauckas: You mention that the number of solver steps can be reduced via some stabilising tricks - I'm curious what you're referring to specifically?

@johnp-4dvanalytics
Author

johnp-4dvanalytics commented Sep 15, 2020

@patrick-kidger @ChrisRackauckas

I created a non-GPU, non-batched version that performs fairly well. I wasn't able to figure out the issue with the GPU version; I believe there is some non-trivial issue with the gradient calculation.
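For reference, the structure the script's `dhdt` computes is the neural CDE update dh = f(h) dX: the network output is reshaped into an h_sz by i_sz matrix and multiplied by the derivative of the interpolated path. The sketch below shows that structure with a fixed-step explicit Euler scheme on raw observation increments, in plain Python. It is a simplification for illustration, not what the script does (the script integrates spline derivatives with Tsit5), and `cde_euler` and `matvec` are hypothetical helpers.

```python
def matvec(matrix, vector):
    """Multiply an (m x n) nested-list matrix by a length-n vector."""
    return [sum(a * b for a, b in zip(row, vector)) for row in matrix]


def cde_euler(h0, xs, vector_field):
    """Explicit Euler discretisation of the controlled ODE dh = f(h) dX.

    h0: initial hidden state (length h_sz)
    xs: observations X_0..X_n, each of length i_sz (time included as a channel)
    vector_field: h -> (h_sz x i_sz) matrix, standing in for the network f_theta
    """
    h = list(h0)
    for x_prev, x_next in zip(xs, xs[1:]):
        dx = [b - a for a, b in zip(x_prev, x_next)]   # increment of the control path
        dh = matvec(vector_field(h), dx)                # f(h) @ dX
        h = [hi + dhi for hi, dhi in zip(h, dh)]
    return h


if __name__ == "__main__":
    # With a constant identity vector field, h_T = h_0 + (X_n - X_0).
    print(cde_euler([0, 0], [[0, 0], [1, 2], [3, 5]], lambda h: [[1, 0], [0, 1]]))
```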

Below is the script for training on the data from your example. It could probably incorporate multiprocessing easily to speed it up. It has a use_linear variable for choosing between linear interpolation and natural cubic interpolation. The linear interpolation version is quite slow because of the additional tstops, but the cubic interpolation version is quite fast: it takes ~83 seconds to run 10 epochs on the data from your example. The ODE solves get faster as the loss goes down, so the time taken depends on the starting parameters and the order in which the examples are seen, and there is some variance in the timing.
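The reason the linear variant needs `tstops` can be seen from its derivative: the slope of a piecewise-linear interpolant is constant on each interval and jumps at every observation time, so an adaptive solver stepping across a knot sees a discontinuous right-hand side. A natural cubic spline has a continuous first derivative, so no stops are needed. The sketch below is a hypothetical, scalar-valued Python mirror of the script's `derivative` helper.

```python
import bisect


def linear_derivative(ts, us, t):
    """Slope at time t of the piecewise-linear interpolant through (ts[k], us[k]).

    ts must be strictly increasing. Like the script's `derivative` helper, a time
    exactly on an interior knot uses the interval to its left.
    """
    # Index of the first knot >= t, minus one, is the left end of t's interval.
    i = bisect.bisect_left(ts, t) - 1
    # Clamp so t == ts[0] (and t outside [ts[0], ts[-1]]) uses the nearest end interval.
    i = min(max(i, 0), len(ts) - 2)
    return (us[i + 1] - us[i]) / (ts[i + 1] - ts[i])


if __name__ == "__main__":
    ts = [0.0, 1.0, 2.0]
    us = [0.0, 1.0, 3.0]
    for t in (0.5, 0.999, 1.0, 1.001, 1.5):
        print(t, linear_derivative(ts, us, t))
```

The slope jumps from 1.0 to 2.0 as t crosses the knot at 1.0; stopping the solver exactly at each knot lets every step stay within one smooth piece.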

I benchmarked against your example with the changes that you mentioned and found that on average the code took ~60 seconds to run. The script for that is included at the bottom.

EDIT: I was accidentally using h_sz=16 instead of 8 for the Julia version, and I forgot to include the compilation warm-up for sciml_train; I've updated the script with those changes.

using Statistics
using DataInterpolations
using Flux
using OrdinaryDiffEq
using DiffEqSensitivity
using Zygote
using ProgressBars
using Random
using CUDA
using PyCall
using DiffEqFlux
using BenchmarkTools

py"""
    from scipy.interpolate import CubicSpline
    import numpy as np
    import math
    import torch
    from tqdm import tqdm

    def get_data(N):
        t = torch.linspace(0., 4 * math.pi, 100)

        start = torch.rand(N) * 2 * math.pi
        x_pos = torch.cos(start.unsqueeze(1) + t.unsqueeze(0)) / (1 + 0.5 * t)
        x_pos[:64] *= -1
        y_pos = torch.sin(start.unsqueeze(1) + t.unsqueeze(0)) / (1 + 0.5 * t)
        x_pos += 0.01 * torch.randn_like(x_pos)
        y_pos += 0.01 * torch.randn_like(y_pos)
        ######################
        # Easy to forget gotcha: time should be included as a channel; Neural CDEs need to be explicitly told the
        # rate at which time passes. Here, we have a regularly sampled dataset, so appending time is pretty simple.
        ######################
        X = torch.stack([t.unsqueeze(0).repeat(N, 1), x_pos, y_pos], dim=2)
        y = torch.zeros(N)
        y[:64] = 1

        perm = torch.randperm(N)
        X = X[perm]
        y = y[perm]

        ######################
        # X is a tensor of observations, of shape (batch=N, sequence=100, channels=3)
        # y is a tensor of labels, of shape (batch=N,), either 0 or 1 corresponding to anticlockwise or clockwise respectively.
        ######################
        return X, y
    X, y = get_data(N=256)

    y = y.unsqueeze(-1)

    t = np.linspace(0, 1, X.shape[1], dtype=np.float32)

    y = y.permute(1, 0).numpy()

    Xy = [(CubicSpline(x=t,y=X[i].numpy(), axis=0, bc_type="natural"), t, y[:, i]) for i in tqdm(range(len(X)))]
    """

T = Float32
use_linear = false

Xy = py"Xy"
if use_linear
    Xy = [(DataInterpolations.LinearInterpolation(T[permutedims(Xy[i][1](Xy[i][2]), (2,1));], Xy[i][2]), Xy[i][2], Xy[i][3]) for i in tqdm(1:length(Xy))]
end

i_sz = length(Xy[1][1](0))
h_sz = 8
o_sz = length(Xy[1][3])

use_gpu = false
device = (use_gpu ? gpu : cpu)

cde = Chain(
    Dense(h_sz, h_sz, relu),
    Dense(h_sz, h_sz*i_sz, tanh),
) |> device
h_to_out = Dense(h_sz, o_sz) |> device
initial = Dense(i_sz, h_sz) |> device

cde_p, cde_re = Flux.destructure(cde)
initial_p, initial_re = Flux.destructure(initial)
h_to_out_p, h_to_out_re = Flux.destructure(h_to_out)

basic_tgrad(u,p,t) = zero(u)

function derivative(A::DataInterpolations.LinearInterpolation{<:AbstractArray{<:Number}}, t::Number)
    idx = findfirst(x -> x >= t, A.t) - 1
    idx == 0 ? idx += 1 : nothing
    (A.u[:, idx+1] - A.u[:, idx]) / (A.t[idx+1] - A.t[idx])
end

function predict_func(p, spl, t)
    x0, tspan = Zygote.ignore() do
        tspan = (t[1], t[end])

        x0 = T[spl(0);]
        x0, tspan
    end
    i = 1
    j = (i-1)+length(initial_p)

    h0 = initial_re(p[i:j])(x0)

    function dhdt(h,p,t)
        dx = Zygote.ignore() do
            # dx = SMatrix{i_sz, 1}(reshape(T[x_int_i.derivative(1)(t);], (i_sz, 1)))
            if use_linear
                dx = derivative(spl, t)
            else
                dx = T[spl.derivative(1)(t);]
            end
            dx
        end
        dh = reshape(cde_re(p)(h), (h_sz, i_sz))*dx
    end

    i = j+1
    j = (i-1)+length(cde_p)

    ff = ODEFunction{false}(dhdt,tgrad=basic_tgrad)
    prob = ODEProblem{false}(ff,h0,tspan,p[i:j])
    sense = InterpolatingAdjoint(autojacvec=ZygoteVJP())

    if use_linear
        tstops = t
        d_discontinuities = t
    else
        tstops = []
        d_discontinuities = []
    end

    solver = Tsit5()

    sol = solve(prob,solver,u0=h0,saveat=tspan[end],
        save_start=false,
        sensealg=sense,
        tstops=tstops,
        d_discontinuities=d_discontinuities,
        atol=1e-6, rtol=1e-4
    )

    out = sol[end]

    i = j+1
    j = (i-1)+length(h_to_out_p)

    y_hat = h_to_out_re(p[i:j])(sol[end])

    y_hat
end

N = length(Xy)

inds = randperm(N)
train_inds = inds[1:trunc(Int, length(inds)*.5)]
val_inds = inds[trunc(Int, length(inds)*.5)+1:end]
@assert length(train_inds) + length(val_inds) == length(inds)

train_dl = Iterators.cycle([Xy[i] for i in train_inds])
val_dl = [Xy[i] for i in val_inds]

function loss_func(p, spl, t, y; train=true)
    y_hat = predict_func(p, spl, t)

    loss = Flux.Losses.logitbinarycrossentropy(y_hat, y, agg=mean)

    if train
        for _ in 1:num_additional
            spl_i, t_i, y_i = Zygote.ignore() do
                train_i = train_inds[rand(1:length(train_inds))]
                spl_i, t_i, y_i = Xy[train_i]
            end

            y_hat_i = predict_func(p, spl_i, t_i)

            loss += Flux.Losses.logitbinarycrossentropy(y_hat_i, y_i, agg=mean)
        end
        loss = loss/(num_additional+1)
    end
    return loss
end

p = vcat(initial_p, cde_p, h_to_out_p)

# Zygote.gradient((p)->loss_func(p, first(train_dl)...),p)

callback = function (p, _)
    global display_i
    if display_i % display_every == 0
        l = Zygote.ignore() do
            l = 0
            subset_inds = randperm(length(val_inds))[1:num_val_to_test]
            val_inds_subset = val_inds[subset_inds]

            for i in val_inds_subset
                l += loss_func(p, Xy[i]..., train=false)
            end

            l/length(subset_inds)
        end

        display(l)
    end
    display_i += 1
    return false
end

display_i = 1
num_val_to_test = 16
display_every = Inf
# num_additional = 4 - 1
num_additional = 0

result_neuralode = DiffEqFlux.sciml_train(loss_func, p, ADAM(0.01),
    cb = callback,
    train_dl,
    maxiters = 1)

@time result_neuralode = DiffEqFlux.sciml_train(loss_func, p, ADAM(0.01),
    cb = callback,
    train_dl,
    maxiters = length(train_inds)*10) #time = 82.803470

function final_loss(p)
    l = 0

    for i in val_inds
        l += loss_func(p, Xy[i]..., train=false)
    end

    l/length(val_inds)
end

final_loss(result_neuralode.minimizer) #0.00112
import math
import torch
import torchcde

######################
# A CDE model looks like
#
# z_t = z_0 + \int_0^t f_\theta(z_s) dX_s
#
# Where X is your data and f_\theta is a neural network. So the first thing we need to do is define such an f_\theta.
# That's what this CDEFunc class does.
# Here we've built a small single-hidden-layer neural network, whose hidden layer is of width 128.
######################
class CDEFunc(torch.nn.Module):
    def __init__(self, input_channels, hidden_channels):
        ######################
        # input_channels is the number of input channels in the data X. (Determined by the data.)
        # hidden_channels is the number of channels for z_t. (Determined by you!)
        ######################
        super(CDEFunc, self).__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels

        self.linear1 = torch.nn.Linear(hidden_channels, 128)
        self.linear2 = torch.nn.Linear(128, input_channels * hidden_channels)

    ######################
    # For most purposes the t argument can probably be ignored; unless you want your CDE to behave differently at
    # different times, which would be unusual. But it's there if you need it!
    ######################
    def forward(self, t, z):
        # z has shape (batch, hidden_channels)
        z = self.linear1(z)
        z = z.relu()
        z = self.linear2(z)
        ######################
        # Easy-to-forget gotcha: Best results tend to be obtained by adding a final tanh nonlinearity.
        ######################
        z = z.tanh()
        ######################
        # Ignoring the batch dimension, the shape of the output tensor must be a matrix,
        # because we need it to represent a linear map from R^input_channels to R^hidden_channels.
        ######################
        z = z.view(z.size(0), self.hidden_channels, self.input_channels)
        return z


######################
# Next, we need to package CDEFunc up into a model that computes the integral.
######################
class NeuralCDE(torch.nn.Module):
    def __init__(self, input_channels, hidden_channels, output_channels):
        super(NeuralCDE, self).__init__()

        self.func = CDEFunc(input_channels, hidden_channels)
        self.initial = torch.nn.Linear(input_channels, hidden_channels)
        self.readout = torch.nn.Linear(hidden_channels, output_channels)

    def forward(self, coeffs):
        X = torchcde.NaturalCubicSpline(coeffs)

        ######################
        # Easy to forget gotcha: Initial hidden state should be a function of the first observation.
        ######################
        z0 = self.initial(X.evaluate(0.))

        ######################
        # Actually solve the CDE.
        ######################
        z_T = torchcde.cdeint(X=X,
                              z0=z0,
                              func=self.func,
                              t=X.interval)

        ######################
        # Both the initial value and the terminal value are returned from cdeint; extract just the terminal value,
        # and then apply a linear map.
        ######################
        z_T = z_T[:, 1]
        pred_y = self.readout(z_T)
        return pred_y


######################
# Now we need some data.
# Here we have a simple example which generates some spirals, some going clockwise, some going anticlockwise.
######################
def get_data():
    t = torch.linspace(0., 4 * math.pi, 100)

    start = torch.rand(128) * 2 * math.pi
    x_pos = torch.cos(start.unsqueeze(1) + t.unsqueeze(0)) / (1 + 0.5 * t)
    x_pos[:64] *= -1
    y_pos = torch.sin(start.unsqueeze(1) + t.unsqueeze(0)) / (1 + 0.5 * t)
    x_pos += 0.01 * torch.randn_like(x_pos)
    y_pos += 0.01 * torch.randn_like(y_pos)
    ######################
    # Easy to forget gotcha: time should be included as a channel; Neural CDEs need to be explicitly told the
    # rate at which time passes. Here, we have a regularly sampled dataset, so appending time is pretty simple.
    ######################
    X = torch.stack([t.unsqueeze(0).repeat(128, 1), x_pos, y_pos], dim=2)
    y = torch.zeros(128)
    y[:64] = 1

    perm = torch.randperm(128)
    X = X[perm]
    y = y[perm]

    ######################
    # X is a tensor of observations, of shape (batch=128, sequence=100, channels=3)
    # y is a tensor of labels, of shape (batch=128,), either 0 or 1 corresponding to anticlockwise or clockwise respectively.
    ######################
    return X, y


def main(num_epochs=30):
    device = "cuda"
    dtype = torch.float32
    train_X, train_y = get_data()
    train_X, train_y = train_X.type(dtype), train_y.type(dtype)

    ######################
    # input_channels=3 because we have both the horizontal and vertical position of a point in the spiral, and time.
    # hidden_channels=8 is the number of hidden channels for the evolving z_t, which we get to choose.
    # output_channels=1 because we're doing binary classification.
    ######################
    model = NeuralCDE(input_channels=3, hidden_channels=8, output_channels=1).type(dtype).to(device)
    optimizer = torch.optim.Adam(model.parameters())

    ######################
    # Now we turn our dataset into a continuous path. We do this here via natural cubic spline interpolation.
    # The resulting `train_coeffs` is a tensor describing the path.
    # For most problems, it's probably easiest to save this tensor and treat it as the dataset.
    ######################
    train_coeffs = torchcde.natural_cubic_spline_coeffs(train_X)

    train_dataset = torch.utils.data.TensorDataset(train_coeffs, train_y)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)
    
    #warm up
    for batch in train_dataloader:
        batch_coeffs, batch_y = batch
        batch_coeffs, batch_y = batch_coeffs.to(device), batch_y.to(device)
        pred_y = model(batch_coeffs).squeeze(-1)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(pred_y, batch_y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        break
    
    start = time.time()
    for epoch in range(num_epochs):
        for batch in train_dataloader:
            batch_coeffs, batch_y = batch
            batch_coeffs, batch_y = batch_coeffs.to(device), batch_y.to(device)
            pred_y = model(batch_coeffs).squeeze(-1)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(pred_y, batch_y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        print('Epoch: {}   Training loss: {}'.format(epoch, loss.item()))
    end = time.time()
    print("Time taken: {} seconds".format(end-start))
    
    return end-start

import time
from tqdm.notebook import tqdm

times_taken = [main(10) for _ in tqdm(range(64))]

import numpy as np

np.mean(times_taken) #59.79856628552079

@pharringtonp19

What's the latest status on this project? Seems useful

@ChrisRackauckas
Member

I don't think anyone has picked it up. On the differentiable-interpolation side, DataInterpolations.jl has gained some nice, stable differentiability overloads, so this should be easy pickings, but someone needs to package it all up.

@johnp-4dvanalytics
Author

johnp-4dvanalytics commented Aug 19, 2021

@ChrisRackauckas I refactored the CPU version and changed it to only use Julia code. I created a simple .md example file for it. Could you point me to a guide for how to open a pull request for it?

@ChrisRackauckas
Member

https://www.youtube.com/watch?v=QVmU29rCjaA is a tutorial for all of that kind of stuff.
