Taking Derivatives of Candidate Expression Inside Custom Loss Function #299
-
Hi Miles, I'm working on a problem in which I would like to define a custom loss function that doesn't just use the predicted values of the candidate expression on the training data (y_pred | X). Instead, while the loss is being evaluated, it should evaluate the candidate expression and its derivatives at a number of new points that are necessarily unknown at the beginning. To draw an analogy with neural networks: assume the NN loss function has its own copy of the entire network at each step of the optimization, which it uses to make predictions and take derivatives at inputs never seen in the training set. Do you have any thoughts on how easy or tricky it would be to make something like this work? If this sounds too vague or confusing, I'd be happy to connect with you over email to provide more details about the problem. Thanks!
-
Hi @gopal-iyer,

Great question. This is possible by defining the `full_objective` parameter, which lets you write the loss over the whole dataset as a Julia function. For example:

```python
from pysr import PySRRegressor

model = PySRRegressor(
    # ...other options...
    full_objective="""function my_loss(tree, dataset::Dataset{T,L}, options)::L where {T,L}
    prediction, grad, complete = eval_grad_tree_array(tree, dataset.X, options; variable=true)
    if !complete
        return L(Inf)
    end
    loss = sum((prediction .- dataset.y) .^ 2) / dataset.n
    # The "grad" is a Julia array which contains the gradient with shape (num_features, num_rows)
    # e.g., loss += sum(grad)
    return loss
end""",
)
```

The specific Julia call that computes the derivatives is documented here: https://astroautomata.com/SymbolicRegression.jl/stable/api/#Derivatives. Let me know if this helps.
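The question is about evaluating the expression at *new* points rather than only at `dataset.X`. A minimal sketch of how that might look with the same `full_objective` mechanism follows. Loud assumptions: the new points (here 100 samples drawn uniformly from [-1, 1) for each feature) and the squared-gradient penalty are hypothetical placeholders for whatever the real problem needs. `eval_tree_array` and `eval_grad_tree_array` are the SymbolicRegression.jl evaluation functions from the docs linked above. The Julia source is kept in a Python string so it can be handed to PySR.

```python
# Hypothetical full objective (Julia source, passed to PySR as a string) that
# evaluates the candidate expression and its gradient at fresh points that are
# generated inside the loss rather than taken from the training data.
new_points_objective = """
function my_loss(tree, dataset::Dataset{T,L}, options)::L where {T,L}
    # Ordinary fit term on the training data
    prediction, complete = eval_tree_array(tree, dataset.X, options)
    !complete && return L(Inf)
    loss = sum((prediction .- dataset.y) .^ 2) / dataset.n

    # Hypothetical new points: 100 samples drawn uniformly from [-1, 1)
    # for every feature; X is stored as (num_features, num_rows)
    X_new = rand(T, size(dataset.X, 1), 100) .* 2 .- 1

    # Gradient of the candidate with respect to each feature at X_new,
    # shape (num_features, 100)
    _, grad_new, complete_new = eval_grad_tree_array(tree, X_new, options; variable=true)
    !complete_new && return L(Inf)

    # Hypothetical derivative penalty: mean squared gradient over the new points
    loss += sum(abs2, grad_new) / size(X_new, 2)
    return loss
end
"""
```

With PySR installed, this would be passed as `PySRRegressor(full_objective=new_points_objective)` alongside your other options. Because `X_new` is redrawn on every call, the loss becomes stochastic; a fixed set of points (or a seeded RNG) would make comparisons between candidates more consistent.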
-
Hi Miles,
I think this would be easiest done on the SymPy side. You can get the expressions with, e.g., `model.sympy(i)`, where `i` is the index of the expression you want. Then you can compute derivatives with `sympy.diff`: https://docs.sympy.org/latest/tutorials/intro-tutorial/calculus.html#derivatives
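A minimal sketch of the SymPy route, assuming SymPy and NumPy are installed. Because `model.sympy(i)` needs a fitted model, a hand-written expression stands in for it here, assuming PySR's default feature names `x0`, `x1`; with a real model you would set `expr = model.sympy(i)`. `sympy.lambdify` then turns the expression and its derivatives into NumPy functions that can be evaluated at arbitrary new points. Note that this works on the expressions after the search, not inside the loss during the search.

```python
import numpy as np
import sympy

# Stand-in for model.sympy(i); assumes PySR's default feature names x0, x1
x0, x1 = sympy.symbols("x0 x1")
expr = x0**2 * sympy.cos(x1)

# Symbolic partial derivatives of the expression
d_dx0 = sympy.diff(expr, x0)  # 2*x0*cos(x1)
d_dx1 = sympy.diff(expr, x1)  # -x0**2*sin(x1)

# Compile to NumPy functions so they can be evaluated at arbitrary new points
f = sympy.lambdify((x0, x1), expr, "numpy")
grad_f = sympy.lambdify((x0, x1), [d_dx0, d_dx1], "numpy")

# New points, shape (num_rows, num_features) as in PySR's X
X_new = np.array([[1.0, 0.0], [2.0, np.pi / 2]])
values = f(X_new[:, 0], X_new[:, 1])
# Stack the partials into shape (num_features, num_rows)
grads = np.array(grad_f(X_new[:, 0], X_new[:, 1]))
```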