In [1]:
from sympy import ImmutableDenseMatrix, Rational, symbols, srepr, pretty
import sympy as sy

In [None]:
import sys
sys.path.append('../../src')

from pyhilbert.spatials import AffineSpace, Lattice, cartes, Momentum, PointGroupBasis, AbelianGroup
from pyhilbert.hilbert import brillouin_zone, MomentumSpace, StateSpace
from pyhilbert.utils import FrozenDict
# from pyhilbert.hilbert import Mode

from itertools import product
from typing import Tuple
from collections import OrderedDict
from functools import reduce

In [None]:
from pyhilbert.spatials import AbelianGroup, PointGroupBasis

In [4]:
basis = ImmutableDenseMatrix([
    [sy.sqrt(3)/2, -sy.Rational(1,2)], 
    [0, 1]
])
triangular_lattice = Lattice(basis=basis, shape=(4,4))
basis = ImmutableDenseMatrix([
    [sy.sqrt(3)/2, -sy.Rational(1,2)], 
    [0, 1]
])
another_lattice = Lattice(basis=basis, shape=(5,5))

In [5]:
triangular_lattice.dual

ReciprocalLattice(basis=Matrix([
[4*sqrt(3)*pi/3,    0],
[2*sqrt(3)*pi/3, 2*pi]]), shape=(4, 4))

In [6]:
basis.eigenvects()[0][2][0]

Matrix([
[-2 - sqrt(3)],
[           1]])

In [7]:
from itertools import product
from typing import Tuple
from collections import OrderedDict
from functools import reduce


def euclidean_full_indices(axes: str, order: int):
    axes_symbols = sy.symbols(axes)
    return tuple(product(*((axes_symbols,) * order)))


def euclidean_commute_indices(indices: Tuple[Tuple[sy.Symbol, ...], ...]):
    commute_table = OrderedDict()
    for idx in indices:
        commute_table[idx] = set(idx)
    return commute_table


def point_group_contract_select(indices: Tuple[Tuple[sy.Symbol, ...], ...]):
    commute_index_table = OrderedDict()
    contract_indices = []
    select_indices = []
    order_indices = set()
    order_idx = 0
    for n, idx in enumerate(indices):
        key = frozenset(idx)
        m = commute_index_table.setdefault(key, order_idx)

        contract_indices.append((n, m))
        if m not in order_indices:
            select_indices.append((n, m))
            order_indices.add(m)
            order_idx += 1

    return contract_indices, select_indices


def full_point_group_rep(irrep: sy.ImmutableDenseMatrix, order: int):
    return reduce(sy.kronecker_product, (irrep,) * order)


def point_group_rep(irrep: sy.ImmutableDenseMatrix, axes: str, order: int):
    indices = euclidean_full_indices(axes, order)
    commute_indices = euclidean_commute_indices(indices)
    contract_indices, select_indices = point_group_contract_select(commute_indices)

    full_rep = full_point_group_rep(irrep, order)

    contract_matrix = sy.zeros(len(indices), len(select_indices))
    for (i, j) in contract_indices:
        contract_matrix[i, j] = 1

    select_matrix = sy.zeros(len(indices), len(select_indices))
    for (i, j) in select_indices:
        select_matrix[i, j] = 1

    return select_matrix.T @ full_rep @ contract_matrix

In [None]:
from functools import lru_cache
from dataclasses import dataclass
from multipledispatch import dispatch

from pyhilbert.abstracts import Operable
from pyhilbert.spatials import Spatial


@dataclass(frozen=True)
class PointGroupBasis(Spatial):
    expr: sy.Expr
    axes: Tuple[sy.Symbol, ...]
    order: int
    rep: sy.ImmutableDenseMatrix
    
    @property
    def dim(self):
        return len(self.axes)
    
    def __str__(self):
        return f'PointGroupBasis({str(self.expr)})'
    
    def __repr__(self):
        return f'PointGroupBasis({repr(self.expr)})'


@dataclass(frozen=True)
class PointGroupOrder:
    irrep: sy.ImmutableDenseMatrix
    axes: Tuple[sy.Symbol, ...]
    basis_function_order: int

    @lru_cache
    def __full_indices(self):
        return tuple(product(*((self.axes,) * self.basis_function_order)))

    @lru_cache
    def __commute_indices(self):
        indices = self.__full_indices()
        _, select_rules = PointGroupOrder.__get_contract_select_rules(indices)
        sorted_rules = sorted(select_rules, key=lambda x: x[1])
        return tuple(indices[n] for n, _ in sorted_rules)
    
    @property
    @lru_cache
    def euclidean_basis(self) -> sy.ImmutableDenseMatrix:
        indices = self.__commute_indices()
        return sy.ImmutableDenseMatrix([sy.prod(idx) for idx in indices]).T

    @staticmethod
    @lru_cache
    def __get_contract_select_rules(indices: Tuple[Tuple[sy.Symbol, ...], ...]):
        commute_index_table = OrderedDict()
        contract_indices = []
        select_indices = []
        order_indices = set()
        order_idx = 0
        for n, idx in enumerate(indices):
            key = tuple(sorted(idx, key=lambda s: s.name))
            m = commute_index_table.setdefault(key, order_idx)

            contract_indices.append((n, m))
            if m not in order_indices:
                select_indices.append((n, m))
                order_indices.add(m)
                order_idx += 1

        return contract_indices, select_indices

    @property
    @lru_cache
    def full_rep(self):
        return reduce(sy.kronecker_product, (self.irrep,) * self.basis_function_order)
    
    @property
    @lru_cache
    def rep(self):
        indices = self.__full_indices()
        contract_indices, select_indices = self.__get_contract_select_rules(indices)

        contract_matrix = sy.zeros(len(indices), len(select_indices))
        for (i, j) in contract_indices:
            contract_matrix[i, j] = 1

        select_matrix = sy.zeros(len(indices), len(select_indices))
        for (i, j) in select_indices:
            select_matrix[i, j] = 1

        return select_matrix.T @ self.full_rep @ contract_matrix
    
    @property
    @lru_cache
    def basis(self) -> FrozenDict:
        transform = self.rep
        eig = transform.eigenvects()
        
        tbl = {}
        for v, _, vec_group in eig:
            vec = vec_group[0]
            # principle term is the first non-zero term
            principle_term = next(x for x in vec if x != 0)

            rep = vec / principle_term
            expr = sy.simplify(rep.dot(self.euclidean_basis))
            tbl[v] = PointGroupBasis(expr=expr, axes=self.axes, order=self.basis_function_order, rep=rep)

        return FrozenDict(tbl)
    

@dataclass(frozen=True)
class PointGroup(Operable):
    irrep: sy.ImmutableDenseMatrix
    axes: Tuple[sy.Symbol, ...]
    order: int

    @lru_cache
    def group_order(self, order: int):
        return PointGroupOrder(self.irrep, self.axes, order)

    @property
    @lru_cache
    def basis(self):
        tbl = {}
        for o in range(1, self.order):
            group_order = self.group_order(o)
            for k, v in group_order.basis.items():
                tbl.setdefault(k, v)

            if len(tbl) == self.order:
                break
            
        return FrozenDict(tbl)
    

@dispatch(PointGroup, PointGroupBasis)
def operator_mul(g: PointGroup, basis: PointGroupBasis) -> Tuple[sy.Expr, PointGroupBasis]:
    if set(g.axes) != set(basis.axes):
        raise ValueError(f"Axes of PointGroup and PointGroupBasis must match: {g.axes} != {basis.axes}")
    
    g_irrep = g.group_order(basis.order).rep
    basis_rep = basis.rep
    transformed_rep = g_irrep @ basis_rep

    phases = set()
    for n in range(transformed_rep.rows):
        if basis_rep[n] != 0:
            phases.add(sy.simplify(transformed_rep[n] / basis_rep[n]))
        else:
            if sy.simplify(transformed_rep[n]) != 0:
                raise ValueError(f'{basis} is not a basis function!')
        
    if not phases:
        raise ValueError(f'{basis} is a trivial basis function: zero')
    
    if len(phases) > 1:
        raise ValueError(f'{basis} is not a basis function!')
    
    return phases.pop(), basis

In [9]:
c3_irrep = ImmutableDenseMatrix([
    [sy.cos(2*sy.pi/3), -sy.sin(2*sy.pi/3)], 
    [sy.sin(2*sy.pi/3), sy.cos(2*sy.pi/3)]
])
c3 = AbelianGroup(c3_irrep, sy.symbols('x y'), 3)

In [10]:
c3*c3.basis[-Rational(1, 2) + sy.sqrt(3)*sy.I/2]

(-1/2 + sqrt(3)*I/2, PointGroupBasis(x - I*y))

In [11]:
c3.basis

((1, PointGroupBasis(x**2 + y**2)), (-1/2 + sqrt(3)*I/2, PointGroupBasis(x - I*y)), (-1/2 - sqrt(3)*I/2, PointGroupBasis(x + I*y)))

In [17]:
c6_irrep = ImmutableDenseMatrix([
    [sy.cos(sy.pi/3), -sy.sin(sy.pi/3)], 
    [sy.sin(sy.pi/3), sy.cos(sy.pi/3)]
])
c6 = PointGroup(c6_irrep, sy.symbols('x y'), 6)

In [18]:
mirror_irrep = ImmutableDenseMatrix([
    [1, 0], 
    [0, -1]
])
mirror = PointGroup(mirror_irrep, sy.symbols('x y'), 2)

In [19]:
PointGroupOrder(mirror_irrep, sy.symbols('x y'), 1).basis

Key              Value
 -1 PointGroupBasis(y)
  1 PointGroupBasis(x)

In [20]:
c3 * c3.basis[-Rational(1, 2) - sy.sqrt(3)*sy.I/2]

(-1/2 - sqrt(3)*I/2, PointGroupBasis(x + I*y))

In [21]:
c6.basis

               Key                                Value
-1/2 - sqrt(3)*I/2 PointGroupBasis(x**2 + I*x*y - y**2)
-1/2 + sqrt(3)*I/2 PointGroupBasis(x**2 - I*x*y - y**2)
 1/2 + sqrt(3)*I/2             PointGroupBasis(x - I*y)
 1/2 - sqrt(3)*I/2             PointGroupBasis(x + I*y)
                -1     PointGroupBasis(x*(x**2 - y**2))
                 1         PointGroupBasis(x**2 + y**2)

In [68]:
c3.basis.keys()

(1, -1/2 + sqrt(3)*I/2, -1/2 - sqrt(3)*I/2)

In [15]:
-Rational(1, 2) - sy.sqrt(3)*sy.I/2

-1/2 - sqrt(3)*I/2

In [106]:
PointGroup.commute_indices(euclidean_full_indices('x y', 2))

((x, x), (x, y), (y, y))

In [107]:
x, y = sy.symbols('x y')
sy.prod((x, y, y))

x*y**2

In [4]:
print(triangular_lattice.affine)

AffineSpace(basis=[['sqrt(3)/2', '-1/2'], ['0', '1']])


In [5]:
print(f'{srepr(basis)}')

ImmutableDenseMatrix([[Mul(Rational(1, 2), Pow(Integer(3), Rational(1, 2))), Rational(-1, 2)], [Integer(0), Integer(1)]])


In [6]:
a = brillouin_zone(dual(triangular_lattice))
b = brillouin_zone(dual(another_lattice))

In [9]:
StateSpace.flat_permutation_order(a, a)

(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)

In [4]:
brillouin_zone(dual(triangular_lattice))

MomentumSpace(16):
	0: Offset(['0', '0'] ∈ [['4*sqrt(3)*pi/3', '0'], ['2*sqrt(3)*pi/3', '2*pi']])
	1: Offset(['0', '1/4'] ∈ [['4*sqrt(3)*pi/3', '0'], ['2*sqrt(3)*pi/3', '2*pi']])
	2: Offset(['0', '1/2'] ∈ [['4*sqrt(3)*pi/3', '0'], ['2*sqrt(3)*pi/3', '2*pi']])
	3: Offset(['0', '3/4'] ∈ [['4*sqrt(3)*pi/3', '0'], ['2*sqrt(3)*pi/3', '2*pi']])
	4: Offset(['1/4', '0'] ∈ [['4*sqrt(3)*pi/3', '0'], ['2*sqrt(3)*pi/3', '2*pi']])
	5: Offset(['1/4', '1/4'] ∈ [['4*sqrt(3)*pi/3', '0'], ['2*sqrt(3)*pi/3', '2*pi']])
	6: Offset(['1/4', '1/2'] ∈ [['4*sqrt(3)*pi/3', '0'], ['2*sqrt(3)*pi/3', '2*pi']])
	7: Offset(['1/4', '3/4'] ∈ [['4*sqrt(3)*pi/3', '0'], ['2*sqrt(3)*pi/3', '2*pi']])
	8: Offset(['1/2', '0'] ∈ [['4*sqrt(3)*pi/3', '0'], ['2*sqrt(3)*pi/3', '2*pi']])
	9: Offset(['1/2', '1/4'] ∈ [['4*sqrt(3)*pi/3', '0'], ['2*sqrt(3)*pi/3', '2*pi']])
	10: Offset(['1/2', '1/2'] ∈ [['4*sqrt(3)*pi/3', '0'], ['2*sqrt(3)*pi/3', '2*pi']])
	11: Offset(['1/2', '3/4'] ∈ [['4*sqrt(3)*pi/3', '0'], ['2*sqrt(3)*pi/3', '2*pi']

In [None]:
[1,2,3][(2,3,1)]

  [1,2,3][*(2,3,1)]


TypeError: list indices must be integers or slices, not tuple

In [None]:
def explain_hash(cls):
    print("MRO:", " -> ".join(c.__name__ for c in cls.__mro__))
    for c in cls.__mro__:
        h = c.__dict__.get("__sub__", "<inherited>")
        print(f"{c.__name__}: __sub__ =", h)

explain_hash(MomentumSpace)

In [None]:
hash(slice(0, 1)), hash(slice(0, 1))

In [None]:
m = Mode(count=1, attr=FrozenDict({'a': 1, 'b': 2, 'c': 3}))
FrozenDict({**m.attr})

In [None]:
triangular = ImmutableDenseMatrix([
    [sy.sqrt(3)/2, -sy.Rational(1,2)], 
    [0, 1]
])
triangular = Lattice(triangular, )