In [None]:
from sympy import eye, Matrix, Rational
from typing import Literal
from dataclasses import dataclass

# Names of sympy's three elementary row operations, exactly as passed to
# Matrix.elementary_row_op.
OpType = Literal[
    "n->kn",  # scale one row
    "n<->m",  # swap two rows
    "n->n+km",  # add a multiple of one row to another
]


@dataclass(frozen=True)
class ScaleOp:
    """Row operation ``R_n -> k * R_n``: multiply row ``n`` by a nonzero ``k``."""

    n: int  # 0-based index of the row being scaled
    k: int | Rational  # scale factor; plain ints are passed as well as Rationals

    def __post_init__(self) -> None:
        # Scaling by zero is not invertible, so it is not an elementary row
        # operation; sympy's elementary_row_op does not reject it on its own.
        if self.k == 0:
            raise ValueError("ScaleOp requires a nonzero factor k")


@dataclass(frozen=True)
class SwapOp:
    n: int
    m: int


@dataclass(frozen=True)
class AddOp:
    """Row operation ``R_n -> R_n + k * R_m``: add ``k`` times row ``m`` to row ``n``."""

    n: int  # 0-based index of the row that is modified
    m: int  # 0-based index of the row whose multiple is added; must differ from n
    k: int | Rational  # multiplier; plain ints are passed as well as Rationals

    def __post_init__(self) -> None:
        # With n == m this would be R_n -> (1 + k) R_n, a scaling rather than a
        # row addition; sympy's "n->n+km" rejects it too, so fail before the op
        # can be recorded.
        if self.n == self.m:
            raise ValueError(f"AddOp needs two distinct rows, got n = m = {self.n}")


# Any single recorded elementary row operation. RowReduction keeps a list of
# these and dispatches on the concrete class with `match`.
Op = SwapOp | ScaleOp | AddOp


class RowReduction:
    def __init__(self, A: Matrix):
        self.A = A
        self.m = A.rows
        self.n = A.cols
        self.ops: list[Op] = []

    def op(
        self,
        i: int | None = None,
        j: int | None = None,
        c: int | Rational | None = None,
    ):
        assert i is not None
        if j is not None and c is None:
            op = SwapOp(i, j)
        elif j is None and c is not None:
            op = ScaleOp(i, c)
        elif j is not None and c is not None:
            op = AddOp(i, j, c)
        else:
            raise ValueError()
        self.ops.append(op)
        display(self.result)

    @property
    def result(self) -> Matrix:
        B = self.A.copy()
        for op in self.ops:
            match op:
                case ScaleOp(i, c):
                    B = B.elementary_row_op("n->kn", row=i, k=c)  # type: ignore
                case SwapOp(i, j):
                    B = B.elementary_row_op("n<->m", row1=i, row2=j)  # type: ignore
                case AddOp(i, j, c):
                    B = B.elementary_row_op("n->n+km", row1=i, row2=j, k=c)  # type: ignore
        return B  # type: ignore

    @property
    def elementary_matrices(self):
        mats: list[Matrix] = []
        for op in self.ops:
            I = eye(self.A.rows)  # type: ignore
            match op:
                case ScaleOp(i, c):
                    mat = I.elementary_row_op("n->kn", row=i, k=c)  # type: ignore
                case SwapOp(i, j):
                    mat = I.elementary_row_op("n<->m", row1=i, row2=j)  # type: ignore
                case AddOp(i, j, c):
                    mat = I.elementary_row_op("n->n+km", row1=i, row2=j, k=c)  # type: ignore
            mats.append(mat)  # type: ignore
        return mats

    @property
    def reducing_matrix(self) -> Matrix:
        mat = eye(self.A.rows)  # type: ignore
        for E in self.elementary_matrices:
            mat = E @ mat  # type: ignore
        return mat  # type: ignore

    def is_rref(self) -> bool:
        rref, _ = self.A.rref()  # type: ignore
        return rref == self.result()  # type: ignore

In [None]:
# Example 3x3 matrix. Its top-left entry is 0, so the reduction below starts
# with a row swap.
A = Matrix([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
A

In [None]:
# Start recording row operations on A; each R.op(...) call displays the
# updated matrix.
R = RowReduction(A)

# Swap rows 0 and 1 so that the (0, 0) pivot position is nonzero (it becomes 3).
R.op(i=0, j=1)

In [None]:
# Scale row 0 by 1/3 so its leading entry becomes 1.
R.op(i=0, c=Rational(1, 3))

In [None]:
# Row 2 -> row 2 - 6 * row 0: clear the entry below the first pivot.
R.op(i=2, j=0, c=-6)

In [None]:
# Row 1 already starts with a leading 1 in column 1. Clear that column above
# it (row 0 -> row 0 - 4/3 * row 1) and below it (row 2 -> row 2 + row 1),
# which leaves [[1, 0, -1], [0, 1, 2], [0, 0, 0]].
R.op(i=0, j=1, c=Rational(-4, 3))
R.op(i=2, j=1, c=1)

In [None]:
# U is the product of the recorded elementary matrices (newest on the left),
# so left-multiplying A by U reproduces R.result.
U = R.reducing_matrix
U

In [None]:
# Consistency check: U @ A should equal the replayed result (evaluates to True).
U @ A == R.result