In [1]:
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

import numpy as np
import pandas as pd
import itertools as it
import networkx as nx
import sympy as sp
import matplotlib.pyplot as plt

from finite_groups import *

from activation_funcs import *

In [None]:
# ==============================================================================
# SECTION 1: DEFINE GROUP (G)
# ==============================================================================
def define_group():
    """
    Return the group G as a pair ``(elements, mult_func)``.

    The template ships with S3, the permutations of {0, 1, 2}.  Every element
    is a tuple ``p`` in one-line notation (``p[i]`` is the image of ``i``).
    Edit both the element list and the product rule to study another group.
    """
    # EXAMPLE: S3, grouped by cycle type (cycle notation in the comments).
    identity = (0, 1, 2)  # e
    transpositions = [
        (1, 0, 2),  # (0 1)
        (2, 1, 0),  # (0 2)
        (0, 2, 1),  # (1 2)
    ]
    three_cycles = [
        (1, 2, 0),  # (0 1 2)
        (2, 0, 1),  # (0 2 1)
    ]
    elements = [identity, *transpositions, *three_cycles]

    def mult_func(p, q):
        # Composition "p after q": entry i of the product is p[q[i]].
        images = [p[idx] for idx in q]
        return tuple(images)

    return elements, mult_func

# ==============================================================================
# SECTION 2: DEFINE SUBGROUP (H)
# ==============================================================================
def define_subgroup(group_elements):
    """
    Define the subgroup H (must be a subset of group_elements).

    Parameters
    ----------
    group_elements : collection of group elements
        The elements of G; each element listed in H below is looked up in it
        with the ``in`` operator.

    Returns
    -------
    list
        The elements of H, in the order written below.

    Raises
    ------
    ValueError
        If any element listed in H is not found in ``group_elements``.
        Only membership is checked; closure under the group product is not.
    """
    # EXAMPLE: H = { e, (0 1) }
    H = [
        (0, 1, 2),
        (1, 0, 2)
    ]

    # Catch a hand-edited H that has drifted out of sync with the group
    # definition here, rather than somewhere deep inside the solver.
    strays = [h for h in H if h not in group_elements]
    if strays:
        raise ValueError(
            f"Subgroup elements not found in the group: {strays}"
        )
    return H

# ==============================================================================
# SECTION 3: DEFINE CONJUGACY CLASSES (MANUAL)
# ==============================================================================
def define_conjugacy_classes():
    """
    Hand-written conjugacy classes of G.

    Returns a list with one ``(representative, members)`` tuple per class,
    where ``members`` is the list of every element in that class and the
    representative is its first entry.
    """
    identity_class = [(0, 1, 2)]
    transposition_class = [(1, 0, 2), (2, 1, 0), (0, 2, 1)]
    three_cycle_class = [(1, 2, 0), (2, 0, 1)]

    # Order matters: identity, transpositions, 3-cycles.
    all_classes = (identity_class, transposition_class, three_cycle_class)
    return [(members[0], members) for members in all_classes]

# ==============================================================================
# SECTION 4: CHARACTER TABLE (SYMBOLIC)
# ==============================================================================
def define_character_table():
    """
    Build the character table of G as ``(char_map, labels)``.

    ``char_map`` maps each conjugacy-class representative to the list of
    character values on that class, one value per irrep, in the same order as
    ``labels``.  Entries may be SymPy expressions (sp.sqrt(2), sp.I, ...)
    when exact irrational or complex values are needed.
    """
    # Column order of the table: one representative per conjugacy class.
    class_reps = [(0, 1, 2), (1, 0, 2), (1, 2, 0)]

    # Conventional layout: one row per irrep, one column per class.
    irrep_rows = [
        ("Trivial",  (1,  1,  1)),
        ("Sign",     (1, -1,  1)),
        ("Standard", (2,  0, -1)),
    ]

    labels = [name for name, _ in irrep_rows]

    # Transpose the rows so that each class representative indexes the
    # values of all irreps on that class.
    char_map = {
        rep: [values[col] for _, values in irrep_rows]
        for col, rep in enumerate(class_reps)
    }
    return char_map, labels

# ==============================================================================
# MAIN EXECUTION
# ==============================================================================
def run_analysis():
    """
    Run the whole pipeline end to end and report the results.

    Steps, in order:
      1. Build a FiniteGroup from define_group() and the manually supplied
         conjugacy classes from define_conjugacy_classes().
      2. Create an InducedRepSolver for it and register the subgroup from
         define_subgroup().
      3. Load the character table from define_character_table().
      4. Ask the solver to compute its projectors.
      5. Print every entry of solver.Qblocks that has at least one column.
      6. Build the interaction graph with a symbolic ReLU and, when the graph
         has nodes, draw it.

    Returns nothing; all results are printed or plotted.
    NOTE(review): FiniteGroup and InducedRepSolver come from the star imports
    at the top of the notebook; their exact semantics are not visible here.
    """
    print(">>> 1. Loading Group Data...")
    elements, mult_func = define_group()
    classes_data = define_conjugacy_classes()

    # Initialize Group with MANUAL classes
    G = FiniteGroup(elements, mult_func, classes=classes_data)
    print(f"    Group Order: {G.n}")

    print("\n>>> 2. Loading Subgroup Data...")
    H_elements = define_subgroup(elements)
    solver = InducedRepSolver(G)
    solver.set_subgroup(H_elements)

    print("\n>>> 3. Loading Character Table...")
    char_map, labels = define_character_table()
    solver.load_character_table(char_map, irrep_labels=labels)

    print("\n>>> 4. Computing Exact Projectors (SymPy)...")
    solver.compute_projectors()

    print("\n>>> 5. Induced Representation Decomposition:")
    for label, Q in solver.Qblocks.items():
        # Blocks with zero columns are skipped -- presumably irreps that do
        # not occur in the induced representation (confirm against the solver).
        if Q.shape[1] > 0:
            print(f"    Irrep '{label}' appears with dimension {Q.shape[1]}")

    print("\n>>> 6. Building Interaction Graph...")

    # Define activation function using SymPy (e.g. Max(0, x) for ReLU)
    def relu_sym(x):
        return sp.Max(0, x)

    graph = solver.build_interaction_graph(activation_fn=relu_sym, verbose=True)
    print(f"    Edges found: {graph.edges()}")

    if len(graph.nodes) > 0:
        # Draw on an explicit Axes instead of relying on pyplot's implicit
        # "current axes" state, which can leak between notebook cells.
        _, ax = plt.subplots(figsize=(6, 6))
        # Fixed seed keeps the layout identical across reruns.
        pos = nx.spring_layout(graph, seed=42)
        nx.draw_networkx(graph, pos, ax=ax, node_color="#E8F0FF", edgecolors="blue",
                         node_size=1000, font_weight="bold", with_labels=True)
        ax.set_title("Interaction Graph (SymPy Exact)")
        plt.show()

In [None]:
# Run the complete analysis pipeline: run_analysis() prints its progress for
# each step and draws the interaction graph when that graph has nodes.
run_analysis()