In [18]:
from __future__ import annotations

import abc

from beartype import beartype
from beartype.typing import List, Tuple


@beartype
class Scalar(abc.ABC):
    """Abstract base class for scalar expression nodes.

    Subclasses must implement ``_repr_latex_``. Adding two scalars with ``+``
    does not evaluate anything; it builds a ``Sum`` node from the operands.
    """

    def __init__(self):
        """Do nothing; the base class holds no state."""

    def __add__(self, other: Scalar):
        """Return a ``Sum`` with ``self`` as the left and ``other`` as the right operand."""
        return Sum(left=self, right=other)

    @abc.abstractmethod
    def _repr_latex_(self):
        """Return this scalar rendered as a string (provided by subclasses)."""


@beartype
class ScalarSymbol(Scalar):
    """Scalar leaf node identified by a name."""

    def __init__(self, name: str):
        """Remember ``name`` as the symbol's identifier."""
        self.name = name

    def _repr_latex_(self):
        """Return the symbol's name, formatted with the ``s`` format spec."""
        return format(self.name, "s")


@beartype
class ScalarConstant(Scalar):
    """Scalar leaf node holding a floating-point value."""

    def __init__(self, value: float):
        """Remember ``value`` as the constant's numeric value."""
        self.value = value

    def _repr_latex_(self):
        """Return the value formatted with the general (``g``) format spec."""
        return format(self.value, "g")


@beartype
class Neg(Scalar):
    """Unary negation of a scalar expression.

    NOTE: a second ``Neg`` class is defined further down in this file and
    rebinds the name once that definition has been executed.
    """

    def __init__(self, right: Scalar):
        """Store the operand being negated.

        Args:
            right: The scalar expression to negate.
        """
        # The operand must be kept; without it the node cannot be rendered.
        self.right = right

    def _repr_latex_(self) -> str:
        """Return the operand's rendering prefixed with a minus sign.

        A bare symbol is written as ``-x``. Any other operand is wrapped in
        ``\\left( ... \\right)`` so the sign visibly applies to all of it.
        """
        inner = self.right._repr_latex_()
        if isinstance(self.right, ScalarSymbol):
            return f"-{inner}"
        return f"-\\left({inner}\\right)"


@beartype
class Sum(Scalar):
    """Binary addition node of two scalar expressions.

    ``Sum`` derives from ``Scalar`` so that the result of ``a + b`` is itself
    a scalar: it can be added again (``(a + b) + c``) and passed wherever a
    ``Scalar`` annotation is enforced.
    """

    def __init__(self, left: Scalar, right: Scalar):
        """Store both operands.

        Args:
            left: Left-hand operand.
            right: Right-hand operand.
        """
        self.left = left
        self.right = right

    def _repr_latex_(self) -> str:
        """Return ``"<left> + <right>"`` built from the operands' renderings."""
        return f"{self.left._repr_latex_():s} + {self.right._repr_latex_():s}"


@beartype
class Matrix2D(abc.ABC):
    """Abstract two-dimensional matrix whose elements are ``Scalar`` objects.

    Concrete subclasses supply element access, element assignment and
    inversion; this base class stores the dimensions and renders the matrix
    as a LaTeX ``bmatrix``.
    """

    def __init__(self, n_rows: int, n_cols: int):
        """Record the matrix dimensions.

        Args:
            n_rows: Number of rows.
            n_cols: Number of columns.
        """
        self.n_rows = n_rows
        self.n_cols = n_cols

    @abc.abstractmethod
    def __getitem__(self, index) -> Scalar:
        """Return the element at ``index`` (a ``(row, col)`` pair) as a ``Scalar``."""

    @abc.abstractmethod
    def __setitem__(self, index, value) -> None:
        """Assign ``value`` to the element at ``index`` (a ``(row, col)`` pair)."""

    @abc.abstractmethod
    def inv(self) -> Matrix2D:
        """Return the inverse of this matrix as a new ``Matrix2D``."""

    def _repr_latex_(self) -> str:
        """Return the matrix as a LaTeX ``bmatrix`` environment.

        Cells within a row are separated by ``&`` and each row, including the
        last, is terminated by ``\\\\``.
        """
        # Collect the pieces and join once instead of growing a string.
        parts = ["\\begin{bmatrix}"]
        last_col = self.n_cols - 1
        for i in range(self.n_rows):
            for j in range(self.n_cols):
                parts.append(self[i, j]._repr_latex_())
                parts.append("&" if j < last_col else "\\\\")
        parts.append("\\end{bmatrix}")
        return "".join(parts)


import numpy as np


@beartype
class Matrix2DNumpy(Matrix2D):
    """``Matrix2D`` backed by a NumPy array."""

    def __init__(self, n_rows: int, n_cols: int, data=None):
        """Create the matrix.

        Args:
            n_rows: Number of rows.
            n_cols: Number of columns.
            data: Backing array to use as-is (not copied). When ``None`` a
                zero array of shape ``(n_rows, n_cols)`` is allocated.
        """
        super().__init__(n_rows, n_cols)
        self._data = np.zeros((n_rows, n_cols)) if data is None else data

    def __getitem__(self, index) -> Scalar:
        """Return element ``(index[0], index[1])`` wrapped in a ``ScalarConstant``."""
        # Convert to a built-in float: integer-dtype data yields NumPy integer
        # scalars, which are not ``float`` instances and would be rejected by
        # the ``value: float`` annotation on ``ScalarConstant``.
        return ScalarConstant(float(self._data[index[0], index[1]]))

    def __setitem__(self, index, value) -> None:
        """Store ``value`` at element ``(index[0], index[1])`` of the backing array."""
        self._data[index[0], index[1]] = value

    def inv(self) -> Matrix2D:
        """Return a new ``Matrix2DNumpy`` holding ``np.linalg.inv`` of the data.

        The matrix is asserted to be square before inversion.
        """
        assert self.n_rows == self.n_cols
        return Matrix2DNumpy(self.n_rows, self.n_cols, np.linalg.inv(self._data))


import casadi as ca


@beartype
class Matrix2DCasadiSX(Matrix2D):
    """``Matrix2D`` backed by a CasADi ``SX`` matrix."""

    def __init__(self, n_rows: int, n_cols: int, data=None):
        """Create the matrix.

        Args:
            n_rows: Number of rows.
            n_cols: Number of columns.
            data: Backing ``SX`` object to use as-is. When ``None`` a zero
                ``SX`` of shape ``(n_rows, n_cols)`` is created.
        """
        super().__init__(n_rows, n_cols)
        if data is None:
            self._data = ca.SX.zeros((n_rows, n_cols))
        else:
            self._data = data

    def __getitem__(self, index: Tuple[int, int]) -> Scalar:
        """Return element ``(index[0], index[1])`` converted to a ``Scalar``.

        Constants become ``ScalarConstant`` and free symbols become
        ``ScalarSymbol``.

        Raises:
            NotImplementedError: If the element's operation is not handled.
        """
        v = self._data[index[0], index[1]]
        # Compare against CasADi's named operation constants rather than
        # their raw integer values (44, 47, 5), which are opaque to the reader
        # and may not be stable across CasADi releases.
        op = v.op()
        if op == ca.OP_CONST:
            return ScalarConstant(float(v))
        if op == ca.OP_PARAMETER:
            return ScalarSymbol(v.name())
        if op == ca.OP_NEG:
            # TODO: negation is not converted yet; a 0.0 placeholder is
            # returned, so a negated entry currently renders as "0".
            return ScalarConstant(0.0)
        raise NotImplementedError("unhandled casadi op code")

    def __setitem__(self, index, value) -> None:
        """Store ``value`` at element ``(index[0], index[1])`` of the backing ``SX``."""
        self._data[index[0], index[1]] = value

    def inv(self) -> Matrix2D:
        """Return a new ``Matrix2DCasadiSX`` holding ``ca.inv`` of the data.

        The matrix is asserted to be square before inversion.
        """
        assert self.n_rows == self.n_cols
        return Matrix2DCasadiSX(self.n_rows, self.n_cols, ca.inv(self._data))


@beartype
class Neg(Scalar):
    """Unary negation of a scalar expression.

    NOTE: this definition rebinds the ``Neg`` name introduced by an earlier
    class of the same name in this file.
    """

    def __init__(self, right: Scalar):
        """Store the operand being negated.

        ``__init__`` must return ``None``; returning anything else makes
        instantiation raise ``TypeError``.

        Args:
            right: The scalar expression to negate.
        """
        self.right = right

    def _repr_latex_(self) -> str:
        """Return the operand's rendering prefixed with a minus sign.

        A bare symbol is written as ``-x``. Any other operand is wrapped in
        ``\\left( ... \\right)`` so the sign visibly applies to all of it.
        """
        inner = self.right._repr_latex_()
        if isinstance(self.right, ScalarSymbol):
            return f"-{inner}"
        return f"-\\left({inner}\\right)"


import sympy


@beartype
class Matrix2DSympy(Matrix2D):
    """``Matrix2D`` backed by a SymPy matrix."""

    def __init__(self, n_rows: int, n_cols: int, data=None):
        """Create the matrix.

        Args:
            n_rows: Number of rows.
            n_cols: Number of columns.
            data: Backing SymPy matrix to use as-is. When ``None`` a zero
                matrix of shape ``(n_rows, n_cols)`` is created.
        """
        super().__init__(n_rows, n_cols)
        self._data = sympy.zeros(n_rows, n_cols) if data is None else data

    def __getitem__(self, index) -> Scalar:
        """Return element ``(index[0], index[1])`` converted to a ``Scalar``.

        Symbols become ``ScalarSymbol`` and numeric entries become
        ``ScalarConstant``.

        Raises:
            NotImplementedError: If the element is neither a symbol nor a
                SymPy number.
        """
        v = self._data[index[0], index[1]]
        if isinstance(v, sympy.Symbol):
            return ScalarSymbol(v.name)
        # Accept every SymPy number, not only the exact-zero singleton, so
        # that assigned values such as 0.5 (stored as ``sympy.Float``) can be
        # read back instead of raising.
        if isinstance(v, sympy.Number):
            return ScalarConstant(float(v))
        raise NotImplementedError("unhandled type: {:s}".format(str(type(v))))

    def __setitem__(self, index, value) -> None:
        """Store ``value`` at element ``(index[0], index[1])`` of the backing matrix."""
        self._data[index[0], index[1]] = value

    def inv(self) -> Matrix2D:
        """Return a new ``Matrix2DSympy`` holding the inverse of the data.

        The matrix is asserted to be square before inversion.
        """
        assert self.n_rows == self.n_cols
        return Matrix2DSympy(self.n_rows, self.n_cols, self._data.inv())

In [19]:
# Demo: a 3x3 NumPy-backed matrix with 0.5 on the diagonal, then its inverse.
A = Matrix2DNumpy(3, 3)
for k in range(3):
    A[k, k] = 0.5
A.inv()

<__main__.Matrix2DNumpy at 0x7fd9b16cec50>

In [26]:
# Demo: a 3x3 CasADi-backed matrix with 0.5 on the diagonal and the symbol
# "y" in the top-right corner; the final expression displays the matrix.
A = Matrix2DCasadiSX(3, 3)
for k in range(3):
    A[k, k] = 0.5
A[0, 2] = ca.SX.sym("y")
A

<__main__.Matrix2DCasadiSX at 0x7fd9b16ce8f0>

In [28]:
# Invert the underlying SX data directly with CasADi, bypassing the wrapper,
# to inspect the raw symbolic result.
ca.inv(A._data)

SX(@1=2, 
[[@1, 00, (-((@1*y)/0.5))], 
 [00, @1, 00], 
 [00, 00, @1]])

In [21]:
# Invert through the wrapper; the result is a new matrix object.
A.inv()

<__main__.Matrix2DCasadiSX at 0x7fd9b16cc700>

In [22]:
# Demo: a 3x3 SymPy-backed matrix with 0.5 on the diagonal and the symbol
# "y" in the top-right corner; the final expression displays the matrix.
A = Matrix2DSympy(3, 3)
for k in range(3):
    A[k, k] = 0.5
A[0, 2] = sympy.symbols("y")
A

<__main__.Matrix2DSympy at 0x7fd9b16cf070>

In [23]:
# Invert through the wrapper; the result is a new matrix object.
A.inv()

<__main__.Matrix2DSympy at 0x7fd9b16cef20>

In [24]:
# `+` on scalars goes through Scalar.__add__, which builds a Sum node
# rather than computing a value.
s = ScalarSymbol("x") + ScalarConstant(1.0)
s

<__main__.Sum at 0x7fd9b16cf8b0>

In [25]:
# Display the Sum node built in the previous cell again.
s

<__main__.Sum at 0x7fd9b16cf8b0>