<a href="https://colab.research.google.com/github/pymc-devs/pytensor-workshop/blob/main/notebooks/exercises/implementing_a_type.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**💡 To better engage your gray matter, we suggest you turn off Colab AI autocompletion in `Tools > Settings > AI Assistance`**

In [1]:
%%capture
try:
    import pytensor_workshop
except ModuleNotFoundError:
    !pip install git+https://github.com/pymc-devs/pytensor-workshop.git

In [2]:
# These exercises become trickier if we allow default inplace operations
# If this piques your curiosity, check:
# https://pytensor.readthedocs.io/en/latest/extending/inplace.html
%set_env PYTENSOR_FLAGS="optimizer_excluding=inplace"

env: PYTENSOR_FLAGS="optimizer_excluding=inplace"


In [3]:
from typing import Any, Sequence
from copy import deepcopy
import numpy as np

In [4]:
import pytensor
import pytensor.tensor as pt
from pytensor.graph import Apply, Variable, rewrite_graph
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.graph.rewriting.basic import out2in
from pytensor.graph.type import Type
from pytensor.tensor.type import TensorType

In [5]:
from pytensor_workshop import test

## Implementing a tuple type

In the walkthrough, we saw what implementing new variable types in PyTensor may look like. In this exercise we will implement a genuinely new PyTensor type: tuples!

Types in PyTensor require one method: `filter`. This method is responsible for accepting or rejecting concrete data as being compatible with the specified type.

It can also be given permission to convert the data into the appropriate type if `strict=False`.

To fully work with tuples in PyTensor, we need to know something about the types of their entries (this will become clear shortly). For this reason we parametrize each `TupleType` with a sequence of `Type`s.
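
To make the `filter` contract concrete without depending on PyTensor, here is a plain-Python sketch of the same logic. `ToyTupleType` and `float_entry` are hypothetical stand-ins, not PyTensor APIs: each entry "type" is just a callable that validates or converts one entry, mirroring how the `TupleType.filter` below delegates to each `entry_type.filter`.

```python
from typing import Any, Callable, Sequence

# Hypothetical entry "type": a callable taking (value, strict) and returning the
# validated/converted value, or raising TypeError (a stand-in for Type.filter).
EntryFilter = Callable[[Any, bool], Any]


def float_entry(value: Any, strict: bool) -> float:
    """Accept floats; when strict=False, also convert ints to float."""
    if isinstance(value, float):
        return value
    if not strict and isinstance(value, int):
        return float(value)
    raise TypeError(f"expected float, got {type(value).__name__}")


class ToyTupleType:
    """Plain-Python mimic of TupleType.filter (not a PyTensor Type subclass)."""

    def __init__(self, entry_filters: Sequence[EntryFilter]):
        self.entry_filters = tuple(entry_filters)

    def filter(self, data: Any, strict: bool = False) -> tuple:
        if not isinstance(data, tuple):
            if strict:
                raise TypeError("data should be a tuple")
            if not isinstance(data, list):
                raise TypeError("cannot convert data to tuple")
            data = tuple(data)  # lenient list -> tuple conversion, only when strict=False
        if len(data) != len(self.entry_filters):
            raise TypeError(f"data should be a tuple of length {len(self.entry_filters)}")
        # Delegate each entry to its own filter, propagating `strict`
        return tuple(f(entry, strict) for entry, f in zip(data, self.entry_filters))


pair = ToyTupleType([float_entry, float_entry])
print(pair.filter([1, 2.5]))  # (1.0, 2.5)
try:
    pair.filter([1, 2.5], strict=True)
except TypeError as exc:
    print("strict:", exc)  # strict: data should be a tuple
```

Note that `strict` propagates to every entry: whether some data is accepted depends both on the container check and on each entry's own rules.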

In [6]:
class TupleType(Type[tuple]):
    def __init__(self, entry_types: Sequence[Type]):
        self.entry_types = entry_types

    def filter(
        self,
        data: "Any",
        strict: bool = False,
        allow_downcast: bool | None = None
    ) -> tuple:
        """Return data or an appropriately wrapped/converted data.

        Subclass implementations should raise a TypeError exception if
        the data is not of an acceptable type.

        Parameters
        ----------
        data:
            The data to be filtered/converted.
        strict: bool (optional)
            If ``True``, the data returned must be the same as the
            data passed as an argument.
        allow_downcast: bool (optional)
            If `strict` is ``False``, and `allow_downcast` is ``True``, the
            data may be cast to an appropriate type. If `allow_downcast` is
            ``False``, it may only be up-cast and not lose precision. If
            `allow_downcast` is ``None`` (default), the behaviour can be
            type-dependent, but for now it means only Python floats can be
            down-casted, and only to floatX scalars.
        """
        if not isinstance(data, tuple):
            if strict:
                raise TypeError("data should be a tuple")
            elif isinstance(data, list):
                data = tuple(data)
            else:
                raise TypeError("cannot convert data to tuple")

        if len(data) != len(self.entry_types):
            raise TypeError(f"data should be a tuple of length {len(self.entry_types)}")

        data = tuple(
            entry_type.filter(entry, strict=strict, allow_downcast=allow_downcast)
            for entry, entry_type in zip(data, self.entry_types)
        )

        return data

    def __str__(self):
        return f"{tuple(str(t) for t in self.entry_types)}"

Let's define a tuple type that contains one vector and one matrix (of arbitrary size):

In [7]:
vector = TensorType(shape=(None,), dtype="float64", name="vector")
matrix = TensorType(shape=(None, None), dtype="float64", name="matrix")
vec_mat_tuple_type = TupleType([vector, matrix])

Now let's define a variable of this type, which we can hopefully evaluate:

In [8]:
xy = vec_mat_tuple_type("(x,y)")
xy

(x,y)

In [9]:
xy.eval({xy: ([1, 1], [[1, 2], [3, 4]])})

(array([1., 1.]),
 array([[1., 2.],
        [3., 4.]]))

The `filter` method kicks in when we try to provide a value that is not compatible with the type. Here the vector and the matrix are swapped:

In [10]:
try:
    xy.eval({xy: ([[1, 2], [3, 4]], [1, 1])})
except TypeError as exc:
    print(exc)

Bad input argument to pytensor function with name "<ipython-input-9-45228ad6e6c0>:1" at index 0 (0-based).  

Having new types is not very interesting if we can't do anything symbolically with them. So let's implement two basic operations: one to create tuples and one to select entries from them. We implement `PackTuple` below; selecting is left for Exercise 1.

In [11]:
class PackTuple(Op):
    """Pack arbitrary PyTensor variables into a tuple"""

    def make_node(self, *inputs):
        # Define a new tuple type and variable
        output_tuple_type = TupleType([inp.type for inp in inputs])
        output = output_tuple_type()
        return Apply(self, inputs, [output])

    def perform(self, node, inputs, output_storage):
        output_storage[0][0] = tuple(inputs)

pack_tuple = PackTuple()
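
`perform` does not return its results; it writes them into `output_storage`, which holds one single-element list (a storage cell) per output. That is why `PackTuple.perform` assigns to `output_storage[0][0]`. The plain-Python sketch below mimics this calling convention; `toy_pack_perform` and `run_node` are hypothetical helpers for illustration, not PyTensor functions.

```python
def toy_pack_perform(node, inputs, output_storage):
    # Same body as PackTuple.perform: write the result into the first output cell
    output_storage[0][0] = tuple(inputs)


def run_node(perform, inputs, n_outputs=1):
    """Hypothetical driver: allocate one empty cell per output, call perform, read cells."""
    output_storage = [[None] for _ in range(n_outputs)]
    perform(None, inputs, output_storage)  # node is unused by this toy perform
    return [cell[0] for cell in output_storage]


[packed] = run_node(toy_pack_perform, [[0, 1, 2], (0.0, 0.0)])
print(packed)  # ([0, 1, 2], (0.0, 0.0))
```

Writing into caller-owned cells, rather than returning values, leaves the caller in control of where results live between calls.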

In [12]:
x = pt.arange(3)
y = pt.zeros(7)
xy = pack_tuple(x, y)
xxy = pack_tuple(x, xy)
xxy.dprint()

PackTuple [id A]
 ├─ ARange{dtype='int64'} [id B]
 │  ├─ 0 [id C]
 │  ├─ 3 [id D]
 │  └─ 1 [id E]
 └─ PackTuple [id F]
    ├─ ARange{dtype='int64'} [id B]
    │  └─ ···
    └─ Alloc [id G]
       ├─ 0.0 [id H]
       └─ 7 [id I]


In [13]:
xxy.eval()

(array([0, 1, 2]), (array([0, 1, 2]), array([0., 0., 0., 0., 0., 0., 0.])))

## Exercise 1: Implement an Op that selects an entry from a Tuple

Implement the `make_node` and `perform` methods of the `SelectTuple` `Op` below.
Given a static index, it should select the corresponding entry from a symbolic tuple.

In [14]:
class SelectTuple(Op):
    """Select an entry from a PyTensor tuple"""
    __props__ = ("idx",)

    def __init__(self, idx: int):
        self.idx = idx

    def make_node(self, tpl):
        ...

    def perform(self, node, inputs, output_storage):
        ...

@test
def test_select_tuple(select_op_class):
    x = pt.arange(3)
    y = pt.zeros(7)
    xy = pack_tuple(x, y)
    xxy = pack_tuple(x, xy)

    y_again = select_op_class(1)(select_op_class(1)(xxy))
    assert y_again.type == y.type
    y_again.dprint()

    np.testing.assert_allclose(y_again.eval(), np.zeros((7,)))
    np.testing.assert_allclose(y_again.eval({y: np.arange(7)}), np.arange(7))

# test_select_tuple(SelectTuple)  # uncomment me

## Exercise 2: Use the new tuple type

Compile a PyTensor function that takes as input a tuple containing a vector and a matrix, squares the vector, cubes the matrix, and packs the results back into a tuple in reversed order.

In [15]:
inp = ...
out = ...
# fn = pytensor.function(inputs=[inp], outputs=[out])

@test
def test_tuple_function(fn):
    from pytensor.compile.function.types import Function
    assert isinstance(fn, Function)
    assert isinstance(fn.maker.fgraph.inputs[0].type, TupleType)
    assert isinstance(fn.maker.fgraph.outputs[0].type, TupleType)

    x_size = np.random.poisson(7)
    x_test = np.random.normal(size=x_size)
    y_size = np.random.poisson(7, size=(2,))
    y_test = np.abs(np.random.normal(size=y_size))

    [(out_1, out_2)] = fn((x_test, y_test))
    np.testing.assert_allclose(out_2, x_test**2)
    np.testing.assert_allclose(out_1, y_test**3)

# test_tuple_function(fn)  # uncomment me

## Exercise 3: Define a rewrite that undoes a SelectTuple(PackTuple)

PyTensor's existing machinery has little ability to reason about graphs that contain our newly defined tuple type.

One thing it can do is constant-fold the operations we defined on tuples.



In [16]:
# Uncomment me when SelectTuple is implemented

# x = pt.arange(3)
# y = pt.zeros(7)
# xy = pack_tuple(x, y)
# xxy = pack_tuple(x, xy)
# y_again = SelectTuple(1)(SelectTuple(1)(xxy))

# with pytensor.config.change_flags(optimizer_verbose=True):
#     rewrite_graph(y_again).dprint()

But if the graph is less static, it can't do much. For example, the next graph is equivalent to zero, but PyTensor cannot figure that out because it can't see through the `SelectTuple(PackTuple)` sequence.

It cannot constant-fold either, because `x` is not a constant.
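
To see why folding works in the first graph but not in the second, here is a toy model in plain Python. It is not PyTensor's implementation, only a sketch of the idea as we understand it: a node is replaced by a constant only when all of its inputs are constants, so a single symbolic leaf such as `x` blocks folding of everything above it. The tuple-based graph encoding, the `OPS` table, and the `Const`/`Sym` classes are all hypothetical.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Const:
    value: Any


@dataclass(frozen=True)
class Sym:
    name: str


# Hypothetical ops, keyed by name; a node is encoded as ("op_name", *inputs)
OPS = {
    "pack": lambda *args: tuple(args),
    "select0": lambda tpl: tpl[0],
    "sub": lambda a, b: a - b,
}


def constant_fold(node):
    """Fold bottom-up: replace a node by a Const only if every input folded to a Const."""
    if isinstance(node, (Const, Sym)):
        return node
    op, *inputs = node
    folded = [constant_fold(inp) for inp in inputs]
    if all(isinstance(inp, Const) for inp in folded):
        return Const(OPS[op](*(inp.value for inp in folded)))
    return (op, *folded)  # some input is symbolic: keep the node


static = ("select0", ("pack", Const(3.0), Const(1.0)))
print(constant_fold(static))  # Const(value=3.0)

x = Sym("x")
dynamic = ("sub", x, ("select0", ("pack", x, Const(1.0))))
print(constant_fold(dynamic))  # unchanged: x is not a constant
```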

In [17]:
# Uncomment me when SelectTuple is implemented

# x = pt.scalar("x")
# x1 = pack_tuple(x, pt.ones(()))
# x_again = SelectTuple(0)(x1)
# out = x - x_again
# out.dprint()

# print()

# rewrite_graph(out).dprint()

Fill in the `select_entry_of_packed_tuple` node rewriter so that it undoes a `SelectTuple(PackTuple)` pattern.

The function should return `None` when the node in question does not match the pattern we care about. When it does match, it should return a list containing the variable(s) that should replace the node's output(s).

In [18]:
from pytensor.graph.rewriting.basic import out2in, node_rewriter

@node_rewriter(tracks=None)
def select_entry_of_packed_tuple(fgraph, node) -> list[Variable] | None:
    """Rewrite SelectTuple(idx)(PackTuple()(tpl)) -> tpl[idx]"""
    ...

@test
def test_useless_pack_tuple(rewrite_fn):
    from pytensor.graph.rewriting.basic import EquilibriumGraphRewriter
    from pytensor.tensor.rewriting.math import local_add_canonizer
    from pytensor.tensor.rewriting.basic import constant_folding
    from pytensor.graph.rewriting.utils import equal_computations

    # Define a graph rewriter that applies the following 3 rewrites until an equilibrium is achieved
    rewrite = EquilibriumGraphRewriter([rewrite_fn, local_add_canonizer, constant_folding], max_use_ratio=3)

    # Test that using those 3 rewrites, pytensor concludes the output equals zero
    x = pt.scalar("x")
    x1 = pack_tuple(x, pt.ones(()))
    x_again = SelectTuple(0)(x1)
    out = x - x_again

    fg = FunctionGraph(outputs=[out], clone=False)
    with pytensor.config.change_flags(optimizer_verbose=True):
        rewrite.apply(fg)
    [rewritten_out] = fg.outputs
    rewritten_out.dprint()

    assert equal_computations([rewritten_out], [pt.constant(np.array(0.0))])

    # Test that rewrite does not cause any failure in a case where it can't be applied
    scalar = pt.TensorType("float64", shape=())
    xx = TupleType([scalar, scalar])()
    out = SelectTuple(0)(xx) * 2

    fg = FunctionGraph(outputs=[out], clone=False)
    rewrite.apply(fg)
    [rewritten_out] = fg.outputs

    assert equal_computations([out], [rewritten_out])

# test_useless_pack_tuple(select_entry_of_packed_tuple)  # uncomment me

## Exercise 4: Allow symbolic indexing on tuple

Define a new `SelectTuple` `Op` where the indexing position is symbolic (it appears as a second input to the `Apply` node).

What constraints are needed on the input tuple type for this to be a well-defined operation at compile time?

A `TypeError` should be raised when these constraints aren't met.

In [19]:
class SymbolicSelectTuple(Op):
    """Select an entry from a PyTensor tuple"""
    __props__ = ()

    def make_node(self, tpl, idx):
        ...

    def perform(self, node, inputs, output_storage):
        ...

@test
def test_symbolic_select_tuple(select_op_class):
    select_op = select_op_class()

    vector_type = pt.TensorType(shape=(None,), dtype="float64", name="vector")
    vec_vec_tuple_type = TupleType([vector_type, vector_type])

    # Check Op semantics are correctly implemented
    idx = pt.scalar("idx", dtype="int64")
    xx = vec_vec_tuple_type("(x,x)")
    y = select_op(xx, idx)
    assert y.type == vector_type  # Only possible result

    # Try to evaluate it
    xx_test = ([0, 1], [2, 3, 4])
    np.testing.assert_allclose(
        y.eval({xx: xx_test, idx:0}),
        [0, 1],
    )

    # Test case that cannot possibly be disambiguated at compile time
    matrix_type = pt.TensorType(shape=(None, None), dtype="float64", name="matrix")
    vec_mat_tuple_type = TupleType([vector_type, matrix_type])
    try:
        xy = vec_mat_tuple_type("(x,y)")
        select_op(xy, idx)
    except TypeError as exc:
        pass
    else:
        assert 0, "Should have raised a TypeError"

    # Trickier case that should be valid
    nested_vec_vec_tuple_type = TupleType([vec_vec_tuple_type, vec_vec_tuple_type])
    nested_xx = nested_vec_vec_tuple_type("((x,x),(x,x))")

    xx = select_op(nested_xx, idx)
    assert xx.type == vec_vec_tuple_type

    x = select_op(xx, 1-idx)
    assert x.type == vector_type

    nested_xx_test_value = (
        ([0, 1], [2, 3, 4]),
        ([4, 5, 6, 7], [8, 9, 10, 11, 12]),
    )
    np.testing.assert_allclose(
        x.eval({nested_xx: nested_xx_test_value, idx: 0}),
        [2, 3, 4],
    )


# test_symbolic_select_tuple(SymbolicSelectTuple)  # uncomment me

## Open-ended challenge: implement your own type and a basic operation on it

If you've gotten a feel for how to work with non-array types in PyTensor, we challenge you to implement yet another type that speaks to you.

Here are some suggestions:
* Strings with a find operation
* [Xarray DataArray](https://docs.xarray.dev/en/latest/getting-started-guide/quick-overview.html) with a dim-based broadcasting or indexing operation
* [Numpy polynomials](https://numpy.org/doc/stable/reference/routines.polynomials.html) with an addition operation
* [Sparse COO matrices](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_array.html#scipy.sparse.coo_array) with an expm1 operation
* Anything else you fancy
