In [1]:
import sympy as sp
from typing import List
from dataclasses import dataclass

In this notebook, we explore how to model contiguity and strides in a
universe where we support dynamic shapes.  We don't care about dynamic
strides/contiguity per se (we'd be OK with specializing on the input
being contiguous, channels-last, etc), but strides and contiguity
are *derived* from shapes, so if you have dynamic shapes, you
also end up with dynamic strides and contiguity.

Let's take a concrete look at this phenomenon in the simplest possible
context: a contiguous tensor.  Here is the C++ code which implements
computation of contiguous strides for a tensor:

In [2]:
"""
// From c10/util/strides.h
// Computes the contiguous strides of a tensor, given its sizes.
static inline std::vector<typename IntArrayRef::value_type> contiguous_strides(
    const IntArrayRef sizes) {
  using Int = IntArrayRef::value_type;
  const Int dims = static_cast<Int>(sizes.size());

  std::vector<Int> strides;

  if (dims > 0) {
    strides.assign(dims, 0);
    // Start by populating the last dimension: its strides is always 1.
    strides[dims - 1] = 1;
    for (auto i = dims - 2; i >= 0; --i) {
      // Strides can't be 0 even if sizes are 0.
      strides[i] = strides[i + 1] * std::max(sizes[i + 1], Int{1});
    }
  }

  return strides;
}
"""

And a port to Python:

In [3]:
def contiguous_strides(sizes: List[int]):
    dims = len(sizes)
    strides = []
    if dims > 0:
        strides = [0] * dims
        strides[dims - 1] = 1
        for i in range(dims - 2, -1, -1):
            strides[i] = strides[i + 1] * sp.Max(sizes[i + 1], 1)
    return strides

In [4]:
print(contiguous_strides([2, 3, 5]))

[15, 5, 1]
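
To make the meaning of these strides concrete, here is a small pure-Python
sketch.  The helper name flat_offsets is ours (it is not part of PyTorch),
and it assumes the usual convention that the element at index (i0, i1, ...)
lives at memory offset sum(i_k * stride_k).  Under that assumption,
contiguous strides visit offsets 0, 1, 2, ... in row-major order:

```python
from itertools import product
from typing import List

def flat_offsets(sizes: List[int], strides: List[int]) -> List[int]:
    """Memory offset of every element, enumerated in row-major index order."""
    return [
        sum(i * s for i, s in zip(index, strides))
        for index in product(*(range(n) for n in sizes))
    ]

offsets = flat_offsets([2, 3, 5], [15, 5, 1])
# Contiguous: the offsets are exactly 0..29, in order.
print(offsets == list(range(30)))  # True
```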


Let's look at the symbolic output of this function.  When only the batch
dimension is dynamic, things are pretty simple:

In [5]:
x = sp.symbols('x')
print(contiguous_strides([x, 3, 5]))

[15, 5, 1]


However, if an inner dimension is dynamic, the dynamic shape variable
shows up in the stride calculation:

In [6]:
print(contiguous_strides([2, x, 5]))

[5*Max(1, x), 5, 1]


The strides returned by contiguous_strides are guaranteed to be
contiguous, but the converse is not true: a contiguous tensor need not
have exactly these strides, because there are some degrees of freedom
in the strides when a size is one or zero.  Here is our definition of
"when something is contiguous" (not accounting for overflow):

In [7]:
"""
// In c10/core/TensorImpl.h
inline bool is_empty() const {
  return numel() == 0;
}

// In c10/core/TensorImpl.cpp
bool TensorImpl::compute_contiguous() const {
  bool is_contiguous = true;
  if (is_empty())
    return is_contiguous;
  int64_t z = 1;
  for (int64_t d = dim() - 1; d >= 0; d--) {
    const auto size_d = sizes_and_strides_.size_at_unchecked(d);
    if (size_d != 1) {
      if (sizes_and_strides_.stride_at_unchecked(d) == z) {
        z *= size_d;
      } else {
        is_contiguous = false;
        break;
      }
    }
  }
  return is_contiguous;
}
"""


In Python (note that we will use the suffix branchy to refer
to code which branches on the concrete value of sizes/strides):

In [8]:
def compute_numel(sizes: List[int]):
    numel = 1
    for s in sizes:
        numel *= s
    return numel

In [9]:
def compute_contiguous_branchy(sizes: List[int], strides: List[int]):
    is_contiguous = True
    if compute_numel(sizes) == 0:
        return is_contiguous
    z = 1
    for d in range(len(sizes) - 1, -1, -1):
        if sizes[d] != 1:
            if strides[d] == z:
                z *= sizes[d]
            else:
                is_contiguous = False
                break
    return is_contiguous

When a dimension has size 1, we are indifferent to the stride at that
dimension:

In [10]:
print(contiguous_strides([3, 1, 5]))

[5, 5, 1]


In [11]:
print(compute_contiguous_branchy([3, 1, 5], [5, 5, 1]))
print(compute_contiguous_branchy([3, 1, 5], [5, 999999, 1]))

True
True


When a tensor contains zero elements, we are indifferent to all the
strides:

In [12]:
print(contiguous_strides([3, 0, 5]))

[5, 5, 1]


In [13]:
print(compute_contiguous_branchy([3, 0, 5], [5, 5, 1]))
print(compute_contiguous_branchy([3, 0, 5], [123456, 999999, 424242]))

True
True
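
Both forms of indifference can be seen by looking at which memory offsets
are actually addressed.  This is a minimal sketch using a hypothetical
helper flat_offsets, assuming the usual offset convention
sum(i_k * stride_k): the index along a size-1 dimension is always 0, so
its stride never contributes, and a tensor with zero elements has no
valid indices at all.

```python
from itertools import product
from typing import List

def flat_offsets(sizes: List[int], strides: List[int]) -> List[int]:
    # Offset of each element, with indices enumerated in row-major order.
    return [
        sum(i * s for i, s in zip(index, strides))
        for index in product(*(range(n) for n in sizes))
    ]

# Size-1 dimension: its index is always 0, so its stride is multiplied by 0.
same = flat_offsets([3, 1, 5], [5, 5, 1]) == flat_offsets([3, 1, 5], [5, 999999, 1])
print(same)  # True

# Zero elements: there are no indices to enumerate, whatever the strides are.
print(flat_offsets([3, 0, 5], [123456, 999999, 424242]))  # []
```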


Can we evaluate compute_contiguous symbolically?  Unfortunately, the
"branchy" implementation as written above cannot be run directly on
SymPy integers, because at several points in the code we condition on
the concrete values of comparisons between integers.  Fortunately,
we can introduce a SymInt/SymBool abstraction (as done in previous
notebooks) to provide concrete values and record guards expressing
what is required to be true for the computation to be correct.

In [14]:

GUARDS = []

def is_constant(e):
    if hasattr(e, 'is_constant'):
        return e.is_constant()
    elif e is sp.true or e is sp.false:
        return True
    else:
        return False

class SymObject:
    def __post_init__(self):
        if self.expr is None:
            self.expr = sp.sympify(self.val)
        elif not isinstance(self.expr, sp.Expr):
            self.expr = sp.sympify(self.expr)

@dataclass
class SymBool(SymObject):
    val: bool
    expr: sp.Expr = None
    guarded: bool = False

    def __bool__(self):
        if not self.guarded:
            self.guarded = True
            if not is_constant(self.expr):
                if self.val:
                    GUARDS.append(self.expr)
                else:
                    GUARDS.append(sp.Not(self.expr))
        return self.val

def logical_and(self: bool, other: bool):
    if isinstance(self, SymBool) and isinstance(other, SymBool):
        return SymBool(self.val and other.val, sp.And(self.expr, other.expr))
    return sp.And(self, other)

def logical_or(self: bool, other: bool):
    if isinstance(self, SymBool) and isinstance(other, SymBool):
        return SymBool(self.val or other.val, sp.Or(self.expr, other.expr))
    return sp.Or(self, other)

@dataclass
class SymInt(SymObject):
    val: int
    expr: sp.Expr = None
    guarded: bool = False

    def __int__(self):
        if not self.guarded:
            self.guarded = True
            if not is_constant(self.expr):
                # Eq lives in sympy, not on SymInt (self.Eq would raise AttributeError).
                GUARDS.append(sp.Eq(self.expr, self.val).simplify())
        return self.val

    def __eq__(self, other):
        if not isinstance(other, SymInt):
            other = SymInt(other)
        return SymBool(self.val == other.val, sp.Eq(self.expr, other.expr))

    def __ne__(self, other):
        if not isinstance(other, SymInt):
            other = SymInt(other)
        return SymBool(self.val != other.val, sp.Ne(self.expr, other.expr))

    def __mul__(self, other):
        if not isinstance(other, SymInt):
            other = SymInt(other)
        return SymInt(self.val * other.val, sp.Mul(self.expr, other.expr))

    def __rmul__(self, other):
        if not isinstance(other, SymInt):
            other = SymInt(other)
        return SymInt(self.val * other.val, sp.Mul(self.expr, other.expr))

def I(val, expr=None):
    return SymInt(val, expr)


Let's run our example.  Under the guards model, we must provide
concrete values for every symbolic integer, so we can resolve
conditionals.

In [15]:
x1, x2, x3, y1, y2, y3 = sp.symbols("x1 x2 x3 y1 y2 y3")

In [16]:
GUARDS.clear()
print(compute_contiguous_branchy(
    [I(3, x1), I(1, x2), I(5, x3)],
    [I(5, y1), I(99999, y2), I(1, y3)]
))

True


We see that this tensor is contiguous...

In [17]:
print(GUARDS)

[Ne(x1*x2*x3, 0), Ne(x3, 1), Eq(y3, 1), Eq(x2, 1), Ne(x1, 1), Eq(y1, x3)]


...subject to these conditions.  These conditions say which particular
path through the loop we took: we require the sizes to be nonzero,
there are a number of size-one equalities/disequalities, and the
equality requirement between y1 and x3 is the "true" contiguity
requirement.
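
One way to read this list is as the precondition for reusing the traced
answer on new inputs.  The sketch below hand-transcribes the guards
printed above into plain Python predicates and checks them against fresh
sizes and strides.  The guards list and the guards_hold helper are
illustrative only; they are not part of the SymInt machinery in this
notebook.

```python
from typing import Callable, Dict, List

Env = Dict[str, int]

# Hand-transcribed from the printed GUARDS list.
guards: List[Callable[[Env], bool]] = [
    lambda e: e["x1"] * e["x2"] * e["x3"] != 0,  # Ne(x1*x2*x3, 0)
    lambda e: e["x3"] != 1,                      # Ne(x3, 1)
    lambda e: e["y3"] == 1,                      # Eq(y3, 1)
    lambda e: e["x2"] == 1,                      # Eq(x2, 1)
    lambda e: e["x1"] != 1,                      # Ne(x1, 1)
    lambda e: e["y1"] == e["x3"],                # Eq(y1, x3)
]

def guards_hold(sizes: List[int], strides: List[int]) -> bool:
    """True if every recorded guard is satisfied by the new inputs."""
    env = dict(zip(["x1", "x2", "x3"], sizes))
    env.update(zip(["y1", "y2", "y3"], strides))
    return all(g(env) for g in guards)

# Same path through the loop: the traced answer (True) can be reused.
print(guards_hold([7, 1, 4], [4, 123, 1]))  # True
# x2 is no longer 1, so a different path is taken and we must re-trace.
print(guards_hold([3, 2, 5], [10, 5, 1]))   # False
```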

If we are willing to rewrite the definition of compute_contiguous, we
can eliminate the branches, giving a symbolic expression with no
guards.

In [18]:
def compute_contiguous(sizes, strides):
    is_contiguous = True
    z = 1
    for d in range(len(sizes) - 1, -1, -1):
        is_contiguous = logical_and(
            is_contiguous,
            logical_or(
                sp.Eq(sizes[d], 1),
                sp.Eq(strides[d], z)
            )
        )
        z *= sizes[d]
    return logical_or(sp.Eq(compute_numel(sizes), 0), is_contiguous)

TODO: prove these two implementations are equivalent, somehow
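
Short of a proof, a brute-force check over small inputs gives some
confidence.  The sketch below restates both definitions over plain
Python integers (so and/or replace logical_and/logical_or; the names
branchy and branch_free are ours) and compares them on every size/stride
combination up to three dimensions.  It assumes sizes are non-negative,
and it is a sanity check, not a proof.

```python
from itertools import product
from math import prod
from typing import Sequence

def branchy(sizes: Sequence[int], strides: Sequence[int]) -> bool:
    # Mirrors compute_contiguous_branchy, with an early return on failure.
    if prod(sizes) == 0:
        return True
    z = 1
    for d in range(len(sizes) - 1, -1, -1):
        if sizes[d] != 1:
            if strides[d] != z:
                return False
            z *= sizes[d]
    return True

def branch_free(sizes: Sequence[int], strides: Sequence[int]) -> bool:
    # Mirrors compute_contiguous: z is always updated, no early exit.
    ok = True
    z = 1
    for d in range(len(sizes) - 1, -1, -1):
        ok = ok and (sizes[d] == 1 or strides[d] == z)
        z *= sizes[d]
    return prod(sizes) == 0 or ok

def count_mismatches(max_dims: int = 3, max_size: int = 3, max_stride: int = 9) -> int:
    mismatches = 0
    for dims in range(max_dims + 1):
        for sizes in product(range(max_size + 1), repeat=dims):
            for strides in product(range(max_stride + 1), repeat=dims):
                if branchy(sizes, strides) != branch_free(sizes, strides):
                    mismatches += 1
    return mismatches

print(count_mismatches())  # 0
```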

We can see that no matter the choice of the stride for a size one
dimension, the result is always contiguous:

In [19]:
print(compute_contiguous([3, 1, 5], [5, x, 1]))

True


And we can see the unflattened contiguity requirement for a completely
general size/stride tensor.

In [20]:
print(compute_contiguous([x1, x2, x3], [y1, y2, y3]))

Eq(x1*x2*x3, 0) | ((Eq(x2, 1) | Eq(y2, x3)) & (Eq(x3, 1) | Eq(y3, 1)) & (Eq(x1, 1) | Eq(y1, x2*x3)))


There's other stuff too:

  - We do not "just" have compute_contiguous; we also have variations
    of this for every memory layout we support.  So the same exercise
    needs to apply everywhere.

  - We also have non_overlapping_and_dense, which involves a sort,
    which is very annoying.
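
To see why the sort is a problem, consider the following simplified,
branchy sketch of a non-overlapping-and-dense check.  This is an
illustration written for this notebook, not PyTorch's actual
implementation: it assumes the property means that some permutation of
the dimensions is contiguous.  Every comparison made by the sort depends
on concrete stride values, so under the guards model each one would turn
into a guard.

```python
from math import prod
from typing import List

def non_overlapping_and_dense_branchy(sizes: List[int], strides: List[int]) -> bool:
    # Simplified sketch under the assumption stated above; not PyTorch's code.
    if prod(sizes) == 0:
        return True
    # Size-1 dimensions never constrain the layout, so drop them.
    dims = [d for d in range(len(sizes)) if sizes[d] != 1]
    # Order the remaining dimensions from innermost (smallest stride) outwards.
    # This sort compares concrete stride values against each other.
    dims.sort(key=lambda d: strides[d])
    expected = 1
    for d in dims:
        if strides[d] != expected:
            return False
        expected *= sizes[d]
    return True

print(non_overlapping_and_dense_branchy([2, 3], [3, 1]))  # row-major: True
print(non_overlapping_and_dense_branchy([2, 3], [1, 2]))  # column-major: True
print(non_overlapping_and_dense_branchy([2, 3], [4, 1]))  # gap between rows: False
```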

In conclusion:

  - We have an explicit choice whether or not to branch inside
    implementations of code that may be traced.  More trace-friendly
    code is not as good for eager execution (because you can't do
    things like short-circuit).

  - If we store SymInt inside TensorImpl, we need to make a call about
    how we represent the contiguity bits inside Tensor.  Each of these
    is literally a single bit, so we cannot store a symbolic boolean
    in it.  It seems the easiest fix is to ensure that
    is_contiguous() is virtualized (it is), and then internally run
    (and cache) the symbolic formula done here.