In [None]:
from sage.combinat.symmetric_group_algebra import e
from sage.sets.family import Family
from sage.rings.rational_field import QQ
from sage.modules.with_basis.subquotient import SubmoduleWithBasis
from sage.misc.cachefunc import cached_method
from sage.matrix.constructor import matrix
from sage.combinat.permutation import Permutation
from sage.combinat.partition import Partition
from sage.combinat.diagram import Diagram
from sage.categories.modules_with_basis import ModulesWithBasis
from sympy.polys.matrices import DomainMatrix
from sympy.physics.quantum.matrixutils import matrix_tensor_product
from sympy import Matrix, SparseMatrix, eye
import matplotlib.pyplot as plt
from sage.misc.cachefunc import cached_method
from sage.modules.with_basis.subquotient import SubmoduleWithBasis
from sage.structure.list_clone import ClonableArray
from sage.sets.family import Family
from math import comb
from sage.matrix.special import diagonal_matrix


class Tabloid(ClonableArray):
    r"""
    A tabloid, stored as a sequence of rows where each row is a ``frozenset``.

    Storing rows as frozensets means two fillings of the same shape that only
    differ by reordering entries within rows give the same tabloid.
    """

    def __init__(self, parent, tableaux):
        r"""
        INPUT:

        - ``parent`` -- the :class:`Tabloids` parent this element belongs to
        - ``tableaux`` -- an iterable of rows; each row is an iterable of entries
        """
        super().__init__(parent, [frozenset(row) for row in tableaux])

    def check(self):
        r"""
        Validate that ``self`` is a tabloid of the parent's shape.

        Raises ``ValueError`` if the row lengths differ from the parent's
        partition, or if the entries are not exactly `1, \ldots, n`
        where `n` is the size of the partition.
        """
        partition = self.parent().partition
        row_lengths = [len(row) for row in self]
        if row_lengths != list(partition):
            raise ValueError(f"row lengths {row_lengths} do not match the shape {list(partition)}")
        # With the row lengths summing to n, the union having exactly the n
        # values 1..n also forces the rows to be pairwise disjoint.
        n = sum(partition)
        if set().union(*self) != set(range(1, n + 1)):
            raise ValueError(f"the entries of a tabloid must be exactly 1, ..., {n}")

    def symmetric_group_action(self, permutation):
        r"""
        Return the tabloid obtained by replacing every entry ``v`` with ``permutation(v)``.

        ``permutation`` must be callable on the entries (e.g. a Sage
        ``Permutation``); the result has the same parent as ``self``.
        """
        p = self.parent()
        return p.element_class(p, [frozenset(permutation(val) for val in row) for row in self])


class Tabloids(UniqueRepresentation, Parent):
    r"""
    The finite set of all tabloids of a fixed shape ``partition``.

    A tabloid of shape `\lambda \vdash n` assigns each of `1, \ldots, n` to one
    row of `\lambda` (row ``i`` receiving ``partition[i]`` entries); the order
    of entries inside a row is irrelevant.
    """

    @staticmethod
    def __classcall_private__(cls, partition):
        # Normalize so that e.g. ``[2, 1]`` and ``Partition([2, 1])`` give the
        # same (cached) parent.
        partition = Partition(partition)
        return super().__classcall__(cls, partition)

    def __init__(self, partition):
        r"""
        Initialize ``self`` as a finite enumerated set of tabloids of shape ``partition``.
        """
        from sage.categories.finite_enumerated_sets import FiniteEnumeratedSets
        Parent.__init__(self, category=FiniteEnumeratedSets())
        self.partition = partition

    def __iter__(self):
        r"""
        Iterate over all tabloids of shape ``self.partition``.
        """
        # One independent (empty) list per row; ``[[]] * k`` would alias them.
        yield from self._recursive_build(list(self.partition),
                                         [[] for _ in self.partition], 1)

    def _recursive_build(self, existing_partition, existing_tabloid, k):
        r"""
        Yield every tabloid obtained by distributing `k, k+1, \ldots` into the rows.

        INPUT:

        - ``existing_partition`` -- list; entry ``i`` is the number of free
          slots remaining in row ``i``
        - ``existing_tabloid`` -- list of lists; the entries placed so far
        - ``k`` -- the next entry to place

        Each value is tried in every row that still has room, so each tabloid
        is produced exactly once.
        """
        if all(slots == 0 for slots in existing_partition):
            yield self.from_tableaux(existing_tabloid)
            return

        for row_index, slots in enumerate(existing_partition):
            if slots != 0:
                new_tabloid = [list(r) for r in existing_tabloid]
                new_partition = list(existing_partition)

                new_tabloid[row_index].append(k)
                new_partition[row_index] -= 1

                yield from self._recursive_build(new_partition, new_tabloid, k + 1)

    def from_tableaux(self, tableaux):
        r"""
        Return the tabloid whose rows are the rows of ``tableaux`` (as sets).
        """
        ret = [frozenset(r) for r in tableaux]
        return self.element_class(self, ret)

    def list(self):
        r"""
        Return a list of all tabloids of shape ``self.partition``.
        """
        return [t for t in self]

    def cardinality(self):
        r"""
        Return the number of tabloids, i.e. the multinomial coefficient of the shape.
        """
        return multinomial(list(self.partition))

    Element = Tabloid


class TabloidModule(CombinatorialFreeModule):
    r"""
    The free module over ``SGA.base_ring()`` with basis indexed by the tabloids of
    shape ``partition``.

    Permutations (indices of ``SGA``) and elements of ``SGA`` act on the left
    by permuting the entries of each tabloid, extended linearly.
    """

    @staticmethod
    def __classcall_private__(cls, SGA, partition):
        # Normalize so that lists are accepted and the cache key is hashable.
        partition = Partition(partition)
        n = sum(partition)
        if SGA.group().rank() != n - 1:
            rk = SGA.group().rank() + 1
            raise ValueError(
                f"the domain size (={rk}) does not match the number of boxes (={n}) of the diagram")

        return super().__classcall__(cls, SGA, partition)

    def __init__(self, SGA, partition):
        r"""
        Initialize ``self``.

        This is a module (not an algebra), so it is placed in the category of
        finite-dimensional modules with basis over ``SGA.base_ring()``.
        """
        self.partition = partition
        self.SGA = SGA
        indices = Tabloids(partition)
        category = ModulesWithBasis(SGA.base_ring()).FiniteDimensional()
        super().__init__(SGA.base_ring(), indices, category=category)

    class Element(CombinatorialFreeModule.Element):
        def _acted_upon_(self, x, self_on_left):
            r"""
            Return ``x * self`` for a permutation or symmetric group algebra element ``x``.

            Scalars are handled by the parent class. Right actions are not
            supported, so ``None`` is returned when ``self_on_left`` is true.
            """
            ret = super()._acted_upon_(x, self_on_left)

            if ret is not None:
                return ret

            if self_on_left:
                return None

            p = self.parent()

            # Test for a single permutation first and apply it directly: going
            # through ``perm * self`` would re-enter this method via the
            # coercion model, and loop forever if permutations are also
            # members of ``p.SGA``.
            if x in p.SGA.indices():
                return self._act_by_permutation(x)

            if x in p.SGA:
                x = p.SGA(x)
                return p.sum(c * self._act_by_permutation(perm)
                             for perm, c in x.monomial_coefficients().items())

            return None

        def _act_by_permutation(self, perm):
            r"""
            Return the image of ``self`` under the permutation ``perm``, acting on each tabloid.
            """
            p = self.parent()
            # The action is a bijection on tabloids, so keys cannot collide.
            return p.element_class(p, {tabloid.symmetric_group_action(perm): c
                                       for tabloid, c in self._monomial_coefficients.items()})


class SpechtModuleOverTableaux(SubmoduleWithBasis):
    r"""
    A Specht module realized as a submodule of the tabloid module of shape ``partition``.

    The basis is indexed by the standard tableaux ``T`` of shape ``partition``;
    the basis element for ``T`` is ``e(T)`` applied to the tabloid of ``T``,
    where ``e`` is :func:`sage.combinat.symmetric_group_algebra.e`.
    """

    @staticmethod
    def __classcall_private__(cls, SGA, partition):
        # Normalize so that lists are accepted and the cache key is hashable.
        return super().__classcall__(cls, SGA, Partition(partition))

    def __init__(self, SGA, partition):
        r"""
        Initialize ``self`` and precompute the data used by :meth:`retract`.
        """
        self.partition = partition
        self.SGA = SGA

        tabloid_module = TabloidModule(SGA, partition)
        tabloids = Tabloids(partition)
        # Order the ambient support like the tabloid module's basis.
        support_order = list(tabloid_module.basis().keys())

        basis = Family({T: e(T) * tabloid_module(tabloids.from_tableaux(T))
                        for T in partition.standard_tableaux()})
        # A Specht module is a module, not an algebra.
        category = ModulesWithBasis(SGA.base_ring()).FiniteDimensional()
        SubmoduleWithBasis.__init__(self, basis, support_order, ambient=tabloid_module,
                                    unitriangular=False, category=category)

        # The columns of COB are the tabloid coordinates of the basis of self.
        # COB is m x n, with m = number of tabloids and n = dimension of self.
        COB = matrix([b.lift().to_vector() for b in self.basis()]).T
        P, L, U = COB.LU()  # COB == P * L * U
        m, n = U.nrows(), U.ncols()

        # This is a slight abuse as the codomain should be a module with a
        # different S_n action, but it is only used internally by ``retract``.
        self._PLinv = tabloid_module.module_morphism(matrix=(P * L).inverse(),
                                                     codomain=tabloid_module)

        # U is upper trapezoidal: its top n x n block is invertible when COB has
        # full column rank, and its last m - n rows are zero. Hence
        # [U_top^{-1} | 0] is a left inverse of U.
        U_top_inv = U.matrix_from_rows(list(range(n))).inverse()
        U_left_inv = U_top_inv.augment(matrix(U.base_ring(), n, m - n))

        self._section = tabloid_module.module_morphism(matrix=U_left_inv, codomain=self)

    def retract(self, v):
        r"""
        Return the element of ``self`` whose lift is ``v``.

        ``v`` is an element of the ambient tabloid module. The result is
        computed as ``[U_top^{-1} | 0] * (P*L)^{-1} * v``, which inverts
        ``lift`` on its image. NOTE(review): membership of ``v`` in the span of
        the basis of ``self`` is not checked; callers presumably only pass
        such elements.
        """
        return self._section(self._PLinv(v))

    def representation_matrix(self, elt):
        r"""
        Return the matrix whose ``i``-th row is the coordinate vector of
        ``SGA(elt) * b_i`` in the basis of ``self``, where ``b_i`` is the
        ``i``-th basis element.

        Note the row convention: images are stored as rows. NOTE(review):
        this appears tied to Sage's permutation multiplication convention;
        confirm before switching to columns.
        """
        g = self.SGA(elt)  # convert once, not once per basis element
        return matrix(self.base_ring(),
                      [self.retract(g * b.lift()).to_vector() for b in self.basis()])

    class Element(SubmoduleWithBasis.Element):
        def _acted_upon_(self, x, self_on_left=False):
            r"""
            Return ``x * self`` for ``x`` in ``SGA`` or in ``SGA.group()``.

            The action is computed in the ambient tabloid module and the
            result is retracted back into the Specht module. Right actions
            are not supported (``None`` is returned).
            """
            ret = super()._acted_upon_(x, self_on_left)
            if ret is not None:
                return ret

            if self_on_left:
                return None

            p = self.parent()

            if x in p.SGA or x in p.SGA.group():
                # ``SGA`` is not a submodule and has no ``_ambient``; convert
                # ``x`` into ``SGA`` directly, as ``representation_matrix`` does.
                return p.retract(p.SGA(x) * self.lift())

            return None

In [None]:
# Example: build the Specht module of shape [2, 1] over the symmetric group
# algebra SymmetricGroupAlgebra(QQ, 3), then compute the matrix returned by
# representation_matrix for the permutation [3, 2, 1].
# The last expression is left bare so the notebook cell displays its value.
sm = SpechtModuleOverTableaux(SymmetricGroupAlgebra(QQ, 3), Partition([2, 1]))
sm.representation_matrix(Permutation([3, 2, 1]))