High code complexity #917

Open
linusheck opened this issue Oct 17, 2023 · 5 comments

@linusheck
Contributor

One last complaint from me :D The code for this library is quite complex: it contains many different implementations of the same function, selected by branching inside the function. There are loads of expressions like this:

```julia
if dy !== nothing
    if W === nothing
        if inplace_sensitivity(S)
            f(dy, y, p, t)
        else
            recursive_copyto!(dy, vec(f(y, p, t)))
        end
    else
        if inplace_sensitivity(S)
            f(dy, y, p, t, W)
        else
            recursive_copyto!(dy, vec(f(y, p, t, W)))
        end
    end
end
```

I believe the library would be much easier to work with if these different implementations were put into separate functions, e.g. one function that handles the computation when inplace_sensitivity(S) is true, and so on.
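
A minimal sketch of the "one function per branch" idea, assuming the in-place flag is lifted into a Val type so that dispatch, rather than an if, picks the implementation. eval_rhs!, rhs_iip! and rhs_oop are hypothetical names, and Base's copyto! stands in for recursive_copyto!.

```julia
# Hypothetical: Val(true)/Val(false) plays the role of inplace_sensitivity(S),
# so each branch of the original if/else becomes its own method.
eval_rhs!(dy, f, y, p, t, ::Val{true})  = (f(dy, y, p, t); dy)
eval_rhs!(dy, f, y, p, t, ::Val{false}) = (copyto!(dy, vec(f(y, p, t))); dy)

# Illustrative right-hand sides: one in-place, one out-of-place.
function rhs_iip!(dy, y, p, t)
    dy .= p .* y
    return nothing
end
rhs_oop(y, p, t) = p .* y

y = [1.0, 2.0]
dy1 = zeros(2); eval_rhs!(dy1, rhs_iip!, y, 3.0, 0.0, Val(true))
dy2 = zeros(2); eval_rhs!(dy2, rhs_oop, y, 3.0, 0.0, Val(false))
```

Both calls fill dy with the same values. The W variants would double these methods again, which is the duplication concern ChrisRackauckas raises in reply.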

@ChrisRackauckas
Member

The complaint is fine. It is complicated, but I'm not sure it's unnecessary complexity. If these different calls were in separate functions, there would be a lot more duplicated code, since those dispatches are exactly the same except for the W at the end. A cleaner strategy might be a macro, something like @Wcall f(dy, y, p, t, W), that expands to W !== nothing ? f(dy, y, p, t, W) : f(dy, y, p, t).
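
A minimal sketch of such a macro, assuming the last argument of the wrapped call is always W. Only the name @Wcall comes from the comment; this implementation and the example rhs! methods are hypothetical.

```julia
# Hypothetical @Wcall: rewrites f(args..., W) into a call that drops W when it is nothing.
macro Wcall(ex)
    (ex isa Expr && ex.head === :call && length(ex.args) >= 2) ||
        throw(ArgumentError("@Wcall expects a call whose last argument is W"))
    f = ex.args[1]           # the function being called
    rest = ex.args[2:end-1]  # every argument except the trailing W
    Wexpr = ex.args[end]     # the W argument itself
    w = gensym(:W)           # fresh name so W is evaluated only once
    return esc(quote
        let $w = $Wexpr
            $w !== nothing ? $f($(rest...), $w) : $f($(rest...))
        end
    end)
end

# Illustrative RHS with and without a W argument (W added as a plain term).
function rhs!(dy, y, p, t)
    dy .= p .* y
    return nothing
end
function rhs!(dy, y, p, t, W)
    dy .= p .* y .+ W
    return nothing
end

y = [1.0, 2.0]
dy1 = zeros(2); W = nothing;    @Wcall rhs!(dy1, y, 2.0, 0.0, W)
dy2 = zeros(2); W = [0.5, 0.5]; @Wcall rhs!(dy2, y, 2.0, 0.0, W)
```

The branch still exists, but it is written once inside the macro instead of at every call site.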

@linusheck
Contributor Author

linusheck commented Oct 30, 2023

Another example:

```julia
if W === nothing
    if DiffEqBase.has_paramjac(f)
        # Calculate the parameter Jacobian into pJ
        f.paramjac(pJ, y, p, t)
    else
        pf.t = t
        pf.u = y
        if inplace_sensitivity(S)
            jacobian!(pJ, pf, p, f_cache, sensealg, paramjac_config)
        else
            temp = jacobian(pf, p, sensealg)
            pJ .= temp
        end
    end
else
    if DiffEqBase.has_paramjac(f)
        # Calculate the parameter Jacobian into pJ
        f.paramjac(pJ, y, p, t, W)
    else
        pf.t = t
        pf.u = y
        pf.W = W
        if inplace_sensitivity(S)
            jacobian!(pJ, pf, p, f_cache, sensealg, paramjac_config)
        else
            temp = jacobian(pf, p, sensealg)
            pJ .= temp
        end
    end
end
```

I don't really know how to fix this, but a lot of this code branches heavily even though, for a given problem, it always takes the same branch.

> If these different calls were in a separate function, then there would be a lot more duplicated code since those dispatches are exactly the same but with the W on the end of them.

I think it would be better if functions like _vecjacobian only called abstract functions without any branching, leaving those functions to figure out the details. In my opinion, there is already a lot of code duplication here that such a strategy could reduce.

Wouldn't multiple dispatch be enough for this? Maybe this is such a big change that it could only work in a rewrite: encode properties like inplace_sensitivity and W === nothing at the type level, then define a single in-place Jacobian function that does what f.paramjac, jacobian!, or jacobian does, depending on the types of, e.g., S.

Variables like paramjac_config don't have to be passed around through all the top-level functions. If behavior switches are encoded at the type level, such configuration can be a field of the type.
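
A minimal sketch of encoding one behavior switch at the type level, assuming a single paramjac! entry point that dispatches on a method type carrying its own configuration. All type and function names are hypothetical, and a plain forward-difference loop is swapped in for the jacobian!/jacobian AD calls so the example is self-contained.

```julia
# Hypothetical types: each behaviour switch becomes a type, and its
# configuration (here just a step size) becomes a field of that type.
abstract type ParamJacMethod end

# f supplies its own parameter Jacobian (the f.paramjac branch).
struct UserParamJac{J} <: ParamJacMethod
    paramjac::J          # paramjac(pJ, y, p, t) fills pJ in place
end

# Fallback: forward differences on p -> f(y, p, t). This stands in for the
# jacobian!/jacobian AD calls; h plays the role of a config like paramjac_config.
struct FiniteDiffParamJac{F} <: ParamJacMethod
    f::F                 # out-of-place RHS f(y, p, t)
    h::Float64           # forward-difference step size
end

# A single entry point; dispatch on the method type replaces the branches.
function paramjac!(pJ, m::UserParamJac, y, p, t)
    m.paramjac(pJ, y, p, t)
    return pJ
end

function paramjac!(pJ, m::FiniteDiffParamJac, y, p, t)
    f0 = m.f(y, p, t)
    pp = copy(p)                     # perturbed copy of the parameters
    for j in eachindex(p)
        pp[j] += m.h
        pJ[:, j] .= (m.f(y, pp, t) .- f0) ./ m.h
        pp[j] = p[j]                 # restore before the next column
    end
    return pJ
end

# Usage: f(y, p, t) = p .* y, whose parameter Jacobian is diagonal with entries y.
rhs(y, p, t) = p .* y
exact_paramjac!(pJ, y, p, t) = (pJ .= [y[1] 0.0; 0.0 y[2]]; nothing)

y = [1.0, 2.0]
p = [0.5, 1.5]
J_user = paramjac!(zeros(2, 2), UserParamJac(exact_paramjac!), y, p, 0.0)
J_fd   = paramjac!(zeros(2, 2), FiniteDiffParamJac(rhs, 1e-6), y, p, 0.0)
```

The caller no longer checks has_paramjac or threads a config argument through; an in-place/out-of-place flag could be added as another type parameter in the same way.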

@linusheck
Contributor Author

IDK, you know much more about what this library actually does, and such an architecture may be worse or impossible to implement. Feel free to close the issue; these are just some ideas.

@ChrisRackauckas
Member

We can probably do something via dispatch where we wrap all of them so that the call is always f(dy, y, p, t, W), and then define f(dy, y, p, t, W::Nothing) = f.f(dy, y, p, t). Then the code can just pass W everywhere and cut down on the number of branches. @frankschae does that sound good to you?
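
A minimal sketch of that wrapper, assuming a callable struct whose field f holds the user's function (mirroring f.f in the comment). The name WRHS and the example rhs! methods are hypothetical.

```julia
# Hypothetical wrapper: call sites always pass W, and the W::Nothing method
# forwards to the four-argument RHS, i.e. f(dy, y, p, t, W::Nothing) = f.f(dy, y, p, t).
struct WRHS{F}
    f::F
end
(w::WRHS)(dy, y, p, t, W) = w.f(dy, y, p, t, W)
(w::WRHS)(dy, y, p, t, ::Nothing) = w.f(dy, y, p, t)  # more specific, so no ambiguity

# Illustrative RHS with and without a W argument (W added as a plain term).
function rhs!(dy, y, p, t)
    dy .= p .* y
    return nothing
end
function rhs!(dy, y, p, t, W)
    dy .= p .* y .+ W
    return nothing
end

g = WRHS(rhs!)
y = [1.0, 2.0]
dy_plain = zeros(2); g(dy_plain, y, 2.0, 0.0, nothing)   # no W === nothing check here
dy_w     = zeros(2); g(dy_w, y, 2.0, 0.0, [0.5, 0.5])
```

With such a wrapper, the first snippet quoted in this issue would keep only its inplace_sensitivity(S) branch, and that one could be lifted into dispatch the same way.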

@frankschae
Member

Yeah, sounds like a good idea. I think the adjoint_common.jl file is probably the worst in that regard, and I am not sure when AbstractDifferentiation might be ready to handle that part.

Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants