# imports


In [14]:
from __future__ import annotations

import os
import sys
from typing import Any, Callable, List, Literal, Optional, Dict

import torch
import lightning as L
import loguru
import numpy as np
from torch import Tensor, nn

sys.path.insert(0, "../")
from src.utils.logger import create_logger

# data

Prepare the dataset first.


In [15]:
from loguru import logger

In [20]:
class Group:
    """
    Generic finite group defined by an explicit carrier set and a binary operation.

    This class validates the group axioms exhaustively:
        - Closure
        - Associativity
        - Identity
        - Inverses

    It also detects whether the group is Abelian (commutative).

    Notes
    -----
    - This implementation is intended for *finite* groups only.
    - Associativity checking is O(n^3) and should only be used for small sets.
    - Equality is handled for scalars, NumPy arrays, and Torch tensors.
    - Inverses are stored by position in ``elements`` rather than keyed by the
      element itself, so unhashable elements (e.g. NumPy arrays) and tensors
      (which hash by identity) work with :meth:`inverse`.
    """

    def __init__(
        self,
        elements: List[int | float] | np.ndarray | Tensor,
        op: Callable,
        logger: Optional[loguru.Logger] = None,
        *args,
        **kwargs,
    ) -> None:
        """
        Parameters
        ----------
        elements : list | np.ndarray | torch.Tensor
            Explicit enumeration of group elements.
        op : Callable
            Binary operation defining the group law.
        logger : loguru.Logger, optional
            Logger instance for diagnostics. When omitted, ``create_logger()``
            is used.
        *args, **kwargs
            Accepted for backward compatibility; currently unused.

        Raises
        ------
        ValueError
            If any group axiom is violated.
        """
        self.elements = list(elements)
        self._op = op

        if logger is None:
            logger = create_logger()
        self.logger = logger

        self.identity: Any = None
        # Index of an element in ``self.elements`` -> its inverse element.
        # Keyed by position so elements need not be hashable.
        self._right_inverses: Dict[int, Any] = {}
        self._left_inverses: Dict[int, Any] = {}
        self.is_abelian: bool = False

        self._validate_group()

    # ------------------------------------------------------------------
    # internal utilities
    # ------------------------------------------------------------------

    def _equal(self, a, b) -> bool:
        """
        Robust equality comparison across supported element types.
        """
        if isinstance(a, Tensor) and isinstance(b, Tensor):
            return torch.equal(a, b)
        if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
            # array_equal also copes with an array compared to a scalar/list,
            # where ``==`` would return an array of ambiguous truth value.
            return bool(np.array_equal(a, b))
        return bool(a == b)

    def _index_of(self, x) -> Optional[int]:
        """
        Return the position of ``x`` in the carrier set, or None if absent.
        """
        for idx, e in enumerate(self.elements):
            if self._equal(x, e):
                return idx
        return None

    def _in_elements(self, x) -> bool:
        """
        Check whether an element belongs to the carrier set.
        """
        return self._index_of(x) is not None

    # ------------------------------------------------------------------
    # validation logic
    # ------------------------------------------------------------------

    def _validate_group(self) -> None:
        """
        Validate all group axioms and compute derived properties.

        Raises
        ------
        Exception
            Re-raises whatever the individual checks raise (``ValueError`` for
            axiom violations) after logging it with its traceback.
        """
        self.logger.info("Validating group axioms")
        try:
            self._check_closure()
            self._check_associativity()
            self.identity = self._find_op_identity()
            self._find_op_inverses()
        except Exception as e:
            # Log through the instance logger (not a module-level global).
            # ``exception`` attaches the traceback; loguru ignores ``exc_info``.
            self.logger.exception(
                f"Group validation failed due to the following reason: {e}"
            )
            raise
        self.is_abelian = self._check_abelian()

        self.logger.success("Group validation successful")

    def _check_closure(self) -> None:
        """
        Verify closure of the operation over the carrier set.
        """
        for a in self.elements:
            for b in self.elements:
                result = self.op(a, b)
                if not self._in_elements(result):
                    raise ValueError(
                        f"Closure violated: {a} * {b} = {result} not in elements"
                    )

    def _check_associativity(self) -> None:
        """
        Verify associativity of the binary operation.
        """
        for a in self.elements:
            for b in self.elements:
                for c in self.elements:
                    left = self.op(self.op(a, b), c)
                    right = self.op(a, self.op(b, c))
                    if not self._equal(left, right):
                        raise ValueError(
                            f"Associativity violated: ({a} * {b}) * {c} != {a} * ({b} * {c})"
                        )

    def _find_op_identity(self):
        """
        Determine the identity element of the group.

        Returns
        -------
        Any
            The identity element.

        Raises
        ------
        ValueError
            If no identity element exists.
        """
        for e in self.elements:
            if all(
                self._equal(self.op(e, x), x) and self._equal(self.op(x, e), x)
                for x in self.elements
            ):
                self.logger.info(f"Identity found: {e}")
                return e
        raise ValueError("No identity element found")

    def _find_op_inverses(self) -> None:
        """
        Compute left and right inverses for all elements.

        Raises
        ------
        RuntimeError
            If the identity has not been computed yet.
        ValueError
            If some element lacks a left or right inverse.
        """
        if self.identity is None:
            raise RuntimeError("Identity must be computed before inverses")

        for idx, a in enumerate(self.elements):
            # Associativity and identity are already verified, so inverses
            # are unique and the first match is the only match.
            left_inv = next(
                (b for b in self.elements if self._equal(self.op(b, a), self.identity)),
                None,
            )
            right_inv = next(
                (b for b in self.elements if self._equal(self.op(a, b), self.identity)),
                None,
            )

            if left_inv is None or right_inv is None:
                raise ValueError(f"No inverse found for element {a}")

            self._left_inverses[idx] = left_inv
            self._right_inverses[idx] = right_inv

        self.logger.info("All inverses found")

    def _check_abelian(self) -> bool:
        """
        Check whether the group is Abelian (commutative).

        Returns
        -------
        bool
            True if the group is Abelian, False otherwise.
        """
        for a in self.elements:
            for b in self.elements:
                if not self._equal(self.op(a, b), self.op(b, a)):
                    return False
        return True

    # ------------------------------------------------------------------
    # public API
    # ------------------------------------------------------------------

    def op(
        self,
        a: int | float | Tensor | np.ndarray,
        b: int | float | Tensor | np.ndarray,
    ):
        """
        Apply the group operation.

        Parameters
        ----------
        a, b : group elements

        Returns
        -------
        group element
        """
        return self._op(a, b)

    def inverse(self, element, *, side: Literal["right", "left"] = "right"):
        """
        Return the inverse of an element.

        Parameters
        ----------
        element : group element
            Element whose inverse is requested. Matched against the carrier
            set by value (see ``_equal``), not by hash or identity.
        side : {"right", "left"}, optional
            Whether to return the right or left inverse.
            Defaults to "right".

        Returns
        -------
        group element
            The requested inverse.

        Raises
        ------
        ValueError
            If the element is not in the group or side is invalid.
        """
        idx = self._index_of(element)
        if idx is None:
            raise ValueError(f"Element {element} is not in the group")

        if side == "right":
            return self._right_inverses[idx]
        if side == "left":
            return self._left_inverses[idx]
        raise ValueError("side must be either 'right' or 'left'")

In [23]:
n = 5
# Nonzero residues {1, ..., n - 1} under multiplication mod n (a group since n is prime).
Z = Group(
    op=lambda a, b: (a * b) % n,
    elements=np.arange(1, n),
)

[37m[2m11:16 AM[0m[37m[0m | [37m[2m__main__[0m[37m[0m | [34mINFO[0m | Validating group axioms
[37m[2m11:16 AM[0m[37m[0m | [37m[2m__main__[0m[37m[0m | [34mINFO[0m | Identity found: 1
[37m[2m11:16 AM[0m[37m[0m | [37m[2m__main__[0m[37m[0m | [34mINFO[0m | All inverses found
[37m[2m11:16 AM[0m[37m[0m | [37m[2m__main__[0m[37m[0m | [32m[1mSUCCESS[0m | Group validation successful


In [28]:
Z.inverse(4, side="right")  # 4 * 4 = 16 ≡ 1 (mod 5): 4 is its own inverse

np.int64(4)

In [29]:
class PolyF2:
    """
    Polynomial over F₂ modulo x^2 + 1, represented as a0 + a1 x.

    Coefficients are reduced mod 2 on construction, so instances compare and
    hash by their reduced coefficient pair. Only addition is implemented.
    """

    MOD = 2

    def __init__(self, a0: int, a1: int):
        # Reduce into F₂ so that e.g. PolyF2(3, -1) == PolyF2(1, 1).
        self.coeffs = (a0 % self.MOD, a1 % self.MOD)

    def __eq__(self, other):
        return isinstance(other, PolyF2) and self.coeffs == other.coeffs

    def __hash__(self):
        return hash(self.coeffs)

    def __add__(self, other):
        """Coefficient-wise addition in F₂[x] / (x^2 + 1)."""
        # Let Python raise a proper TypeError (or try other.__radd__)
        # instead of failing with AttributeError on ``other.coeffs``.
        if not isinstance(other, PolyF2):
            return NotImplemented
        # Degree stays <= 1 under addition, so no reduction by x^2 + 1 is
        # needed; the constructor performs the mod-2 reduction.
        return PolyF2(
            self.coeffs[0] + other.coeffs[0],
            self.coeffs[1] + other.coeffs[1],
        )

    def __repr__(self):
        a0, a1 = self.coeffs
        if not a1:
            return str(a0)
        # Canonical form: "x" rather than "0 + x".
        return "1 + x" if a0 else "x"

In [30]:
# All four polynomials a0 + a1 x with a0, a1 in {0, 1}, ordered 0, 1, x, 1 + x.
elements = [PolyF2(a0, a1) for a1 in (0, 1) for a0 in (0, 1)]

In [35]:
# Additive group on the four PolyF2 elements listed above.
G = Group(
    op=lambda lhs, rhs: lhs + rhs,
    elements=elements,
)

[37m[2m11:18 AM[0m[37m[0m | [37m[2m__main__[0m[37m[0m | [34mINFO[0m | Validating group axioms
[37m[2m11:18 AM[0m[37m[0m | [37m[2m__main__[0m[37m[0m | [34mINFO[0m | Identity found: 0
[37m[2m11:18 AM[0m[37m[0m | [37m[2m__main__[0m[37m[0m | [34mINFO[0m | All inverses found
[37m[2m11:18 AM[0m[37m[0m | [37m[2m__main__[0m[37m[0m | [32m[1mSUCCESS[0m | Group validation successful
