In [3]:
import os

import pytensor
import pytensor.tensor as pt
import pytensor.tensor.type as ptt
import torch


#os.environ["OMP_PREFIX"] = "/opt/anaconda3/envs/pytensor-dev"

In [4]:
from copy import copy
from operator import itemgetter
from pytensor.compile.builders import construct_nominal_fgraph
from pytensor.graph.basic import Apply
from pytensor.graph.op import HasInnerGraph, Op
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify

In [20]:
class AutogradOp(Op, HasInnerGraph):
    def __init__(self, inputs, outputs, wrt=None):
        
        self.wrt = []
        self.input_types = [i.type for i in inputs]
        # todo: It's less that we need
        # the outputs, and more that we
        # want an fgraph passed in
        # todo: Shared variables?
        self.fgraph, _, _, _ = construct_nominal_fgraph(
            inputs, outputs
        )

        if wrt is None:
            # take all the vars
            self.wrt = list(enumerate(inputs))
        else:
            if not isinstance(wrt, list | tuple):
                # take only one var
                wrt = [wrt]
            if not set(wrt).issubset(set(inputs)):
                raise RuntimeError(
                    f"You can differentiate unknown inputs: {wrt}, {inputs}"
                )
            self.wrt = [(i, x) for i, x in enumerate(inputs) if x in wrt] 

    def make_node(self, *inputs):
        apply_inputs = [i_t.filter_variable(i) for i, i_t in zip(inputs, self.input_types, strict=True)]
        return Apply(self, apply_inputs, [t[1].type() for t in self.wrt])

    def perform(self, *args, **kwargs):
        raise RuntimeError("Should not go to c runtime")
    
    def clone(self):
        res = copy(self)
        res.fgraph = res.fgraph.clone()
        return res

    @property
    def fn(self):
        """Lazily compile the inner function graph."""
        if getattr(self, "_fn", None) is not None:
            return self._fn

        self._fn = function(self.inner_inputs, self.inner_outputs, **self.kwargs)
        self._fn.trust_input = True

        return self._fn

    @property
    def inner_inputs(self):
        return self.fgraph.inputs

    @property
    def inner_outputs(self):
        return self.fgraph.outputs

@pytorch_funcify.register(AutogradOp)
def autograd(op, node, **kwargs):
    inner_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
    indicies = list(map(itemgetter(0), op.wrt))

    def fn(*inputs):
        inputs = [i.requires_grad_(True) for i in inputs]
        out = inner_fn(*inputs)
        out.backward()
        grads = [inputs[i].grad for i in indicies]
        if len(grads) == 1:
            return grads[0]
        else:
            return grads

    return torch.compiler.disable(fn)

In [21]:
# x^2 + 2xy + y^2
# for this example, pow is broken
from pytensor.scalar.basic import Pow
@pytorch_funcify.register(Pow)
def pow(op, node, **kwargs):
    return torch.pow

x = ptt.scalar("x")
y = ptt.scalar("y")
res = pt.pow(x, 2) + (2 * x * y) + pt.pow(y, 3)
autograder = AutogradOp([x, y], [res], [x, y])
f = pytensor.function([x, y], autograder(x, y), mode="PYTORCH")

In [22]:
f(3, 2)

[array(10.), array(18.)]

In [41]:
from collections.abc import Sequence

import pymc as pm
from pymc import Model
from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.initial_point import PointType, make_initial_point_fn
from pymc.model.core import modelcontext, join_nonshared_inputs
from pymc.model.transform.optimization import freeze_dims_and_data
from pymc.vartypes import continuous_types, discrete_types, typefilter
from pytensor.graph.basic import Constant, Variable, graph_inputs

from pymc.util import (
    UNSET,
    VarName,
    WithMemoization,
    _add_future_warning_tag,
    _UnsetType,
    get_transformed_name,
    get_value_vars_from_user_vars,
    get_var_name,
    treedict,
    treelist,
)

class AutogradModel(Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        
    def logp_dlogp_function(self,
        grad_vars=None,
        tempered=False,
        initial_point: PointType | None = None,
        ravel_inputs: bool | None = None,
        **kwargs,
    ): 
        # this all comes from the base class
        if grad_vars is None:
            grad_vars = self.continuous_value_vars
        else:
            grad_vars = get_value_vars_from_user_vars(grad_vars, self)
            for i, var in enumerate(grad_vars):
                if var.dtype not in continuous_types:
                    raise ValueError(f"Can only compute the gradient of continuous types: {var}")

        if tempered:
            costs = [self.varlogp, self.datalogp]
        else:
            costs = [self.logp()]

        input_vars = {i for i in graph_inputs(costs) if not isinstance(i, Constant)}
        if initial_point is None:
            initial_point = self.initial_point(0)
        extra_vars_and_values = {
            var: initial_point[var.name]
            for var in self.value_vars
            if var in input_vars and var not in grad_vars
        }
        
        # new logic
        class FnWrapper:
            def __init__(
                self,
                costs,
                grad_vars,
                extra_vars_and_values=None,
                *,
                dtype=None,
                casting="no",
                compute_grads=True,
                model=None,
                initial_point: PointType | None = None,
                ravel_inputs: bool | None = None,
                **kwargs,
            ):
                if extra_vars_and_values is None:
                    extra_vars_and_values = {}

                names = [arg.name for arg in grad_vars + list(extra_vars_and_values.keys())]
                if any(name is None for name in names):
                    raise ValueError("Arguments must be named.")
                if len(set(names)) != len(names):
                    raise ValueError("Names of the arguments are not unique.")

                self._grad_vars = grad_vars
                self._extra_vars = list(extra_vars_and_values.keys())
                self._extra_var_names = {var.name for var in extra_vars_and_values.keys()}

                if dtype is None:
                    dtype = pytensor.config.floatX
                self.dtype = dtype

                self._n_costs = len(costs)
                if self._n_costs == 0:
                    raise ValueError("At least one cost is required.")

                cost = costs[0]

                self._extra_are_set = False
                givens = []
                self._extra_vars_shared = {}
                for var, value in extra_vars_and_values.items():
                    shared = pytensor.shared(value, var.name + "_shared__", shape=value.shape)
                    self._extra_vars_shared[var.name] = shared
                    givens.append((var, shared))

                if compute_grads:
                    grads = AutogradOp(grad_vars, [cost], grad_vars)
                    outputs = [cost, [grads]]
                else:
                    outputs = [cost]

                if ravel_inputs:
                    if initial_point is None:
                        initial_point = modelcontext(model).initial_point()
                    outputs, raveled_grad_vars = join_nonshared_inputs(
                        point=initial_point, inputs=grad_vars, outputs=outputs, make_inputs_shared=False
                    )
                    inputs = [raveled_grad_vars]
                else:
                    inputs = grad_vars

                self._pytensor_function = compile(inputs, outputs, givens=givens, **kwargs)
                self._raveled_inputs = ravel_inputs

            def __call__(self, grad_vars, *, extra_vars=None):
                if extra_vars is not None:
                    self.set_extra_values(extra_vars)
                elif not self._extra_are_set:
                    raise ValueError("Extra values are not set.")

                if isinstance(grad_vars, RaveledVars):
                    if self._raveled_inputs:
                        grad_vars = (grad_vars.data,)
                    else:
                        grad_vars = DictToArrayBijection.rmap(grad_vars).values()
                elif self._raveled_inputs and not isinstance(grad_vars, Sequence):
                    grad_vars = (grad_vars,)

                return self._pytensor_function(*grad_vars)
        return FnWrapper(
            costs,
            grad_vars,
            extra_vars_and_values,
            model=self,
            initial_point=initial_point,
            ravel_inputs=ravel_inputs,
            **kwargs
        )

In [42]:
with AutogradModel() as m:
    res = pm.Normal("blah", mu=0, sigma=1)
    pm.sample(10)

Only 10 samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate.
Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at
the same time. Both libraries are known to be incompatible and this
can cause random crashes or deadlocks on Linux when loaded in the
same Python program.
Using threadpoolctl may cause crashes or deadlocks. For more
information and possible workarounds, please see
    https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md

Initializing NUTS using jitter+adapt_diag...


TypeError: Outputs must be pytensor Variable or Out instances. Received AutogradOp of type <class '__main__.AutogradOp'>