In [None]:
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

import numpy as np
import pandas as pd
import itertools as it
import networkx as nx
import matplotlib.pyplot as plt


from finite_groups import *

from activation_funcs import *

In [None]:
# ==============================================================================
# SECTION 1: DEFINE GROUP
# ==============================================================================
def define_group():
    """
    Build the example group G = S3 together with its multiplication.

    Each element is a tuple ``p`` read as the map ``i -> p[i]`` on {0, 1, 2};
    for instance (1, 0, 2) swaps 0 and 1 and fixes 2. Replace this body with
    your own group (e.g. elements loaded from a CSV) to analyse a different G.

    Returns:
        elements (list): Every group element, in a fixed order (hashable).
        mult_func (callable): Binary product f(a, b) -> a * b.
    """
    # Identity first, then the three transpositions, then the two 3-cycles.
    identity = (0, 1, 2)
    transpositions = [
        (1, 0, 2),  # (0 1)
        (2, 1, 0),  # (0 2)
        (0, 2, 1),  # (1 2)
    ]
    three_cycles = [
        (1, 2, 0),  # (0 1 2)
        (2, 0, 1),  # (0 2 1)
    ]
    elements = [identity, *transpositions, *three_cycles]

    def compose(p, q):
        # (p * q)[i] = p[q[i]]: apply q first, then p (same convention as
        # the A5 example).
        return tuple(map(p.__getitem__, q))

    return elements, compose

# ==============================================================================
# SECTION 2: DEFINE SUBGROUP
# ==============================================================================
def define_subgroup(group_elements):
    """
    Return the elements of the subgroup H that the representation is induced from.

    Every returned element must also appear in ``group_elements``. The example
    picks H = {e, (0 1)}, the stabiliser of the point 2 in S3 (isomorphic to S2).

    Args:
        group_elements (list): All elements of G. The example ignores them,
            but they are handy for selecting H programmatically, e.g.
            ``[g for g in group_elements if g[2] == 2]``.

    Returns:
        list: The elements of H.
    """
    identity = (0, 1, 2)
    swap_01 = (1, 0, 2)
    return [identity, swap_01]

# ==============================================================================
# SECTION 3: CHARACTER TABLE
# ==============================================================================
def define_character_table(discovered_classes):
    """
    Map the discovered conjugacy classes to their character vectors.

    The solver will automatically discover classes and assign them a 'Representative'.
    You must map these representatives to the rows of your character table.

    Args:
        discovered_classes (list): List of ConjugacyClass objects found by the solver.
                                   Each object has .representative (the element).

    Returns:
        char_map (dict): { representative_element : [chi_1, chi_2, ...] }, where
                         the i-th entry is the character of irrep ``labels[i]``
                         evaluated on that representative's class.
        labels (list): Names of the irreps ["A", "B", "E", ...]

    Raises:
        ValueError: if a representative is not a permutation of (0, 1, 2),
                    i.e. it cannot be an element of the example S3.
    """
    # --------------------------------------------------------------------------
    # EXAMPLE: Character Table for S3
    # Classes:
    #   1. Identity e      (Size 1)
    #   2. Transpositions  (Size 3) -> (0 1)
    #   3. 3-Cycles        (Size 2) -> (0 1 2)
    #
    # Irreps (character values on classes 1, 2, 3):
    #   1 (Trivial):  [1,  1,  1]
    #   1- (Sign):    [1, -1,  1]
    #   2 (Standard): [2,  0, -1]
    # --------------------------------------------------------------------------
    identity = (0, 1, 2)
    labels = ["Trivial", "Sign", "Standard"]

    # In S3 the conjugacy class of an element is determined by its order
    # (1 -> identity, 2 -> transpositions, 3 -> 3-cycles), so classes are
    # matched by element order. Each list is a column of the table above,
    # ordered like `labels`.
    chars_by_order = {
        1: [1, 1, 2],    # Class 1A: identity
        2: [1, -1, 0],   # Class 2A: transpositions
        3: [1, 1, -1],   # Class 3A: 3-cycles
    }

    def element_order(p):
        # Smallest k >= 1 with p^k == identity, using the group's product
        # convention (p * q)[i] = p[q[i]]. Terminates because p is a
        # permutation of (0, 1, 2) (validated by the caller).
        power, k = tuple(p), 1
        while power != identity:
            power = tuple(p[x] for x in power)
            k += 1
        return k

    char_map = {}

    # Iterate over every class the solver found so that none is left unmapped.
    for cls in discovered_classes:
        rep = cls.representative
        if tuple(sorted(rep)) != identity:
            raise ValueError(
                f"Representative {rep!r} is not a permutation of (0, 1, 2); "
                "adapt define_character_table to your group."
            )
        # Copy so callers mutating one entry do not alter the shared table.
        char_map[rep] = list(chars_by_order[element_order(rep)])

    return char_map, labels

# ==============================================================================
# MAIN EXECUTION (No need to edit below usually)
# ==============================================================================
def run_analysis():
    """
    Run the whole pipeline on the user-defined group, subgroup and character table.

    In order, this builds the FiniteGroup, lists its conjugacy classes,
    registers the subgroup with an InducedRepSolver, loads the character table,
    computes the projectors, reports which irreps appear in the induced
    representation, and builds the interaction graph. The graph is drawn
    only when it has at least one node.
    """
    print(">>> 1. Initializing Group...")
    group_elements, multiply = define_group()
    group = FiniteGroup(group_elements, multiply)
    print(f"    Group Order: {group.n}")

    print("\n>>> 2. Conjugacy Classes Discovered:")
    for conj_class in group.classes:
        print(
            f"    Class {conj_class.index}: Rep {conj_class.representative} "
            f"(Size {conj_class.size})"
        )

    print("\n>>> 3. Setting Subgroup...")
    subgroup = define_subgroup(group_elements)
    solver = InducedRepSolver(group)
    solver.set_subgroup(subgroup)

    print("\n>>> 4. Loading Character Table...")
    # The user callback receives the solver-discovered classes so it can key
    # its character columns on their representatives.
    characters, irrep_names = define_character_table(group.classes)
    solver.load_character_table(characters, irrep_labels=irrep_names)

    print("\n>>> 5. Computing Projectors...")
    solver.compute_projectors()

    print("\n>>> 6. Induced Representation Decomposition:")
    for irrep, basis in solver.Qblocks.items():
        # shape[1] is only read when the block is non-empty.
        line = (
            f"    Irrep '{irrep}' appears with dimension {basis.shape[1]}"
            if basis.size > 0
            else f"    Irrep '{irrep}' does not appear."
        )
        print(line)

    print("\n>>> 7. Building Interaction Graph...")

    def relu(x):
        # Elementwise ReLU passed to the solver as the activation function.
        return np.maximum(0, x)

    graph = solver.build_interaction_graph(activation_fn=relu)
    print(f"    Edges found: {graph.edges()}")

    if len(graph.nodes) == 0:
        print("    Graph is empty.")
        return

    plt.figure(figsize=(6, 6))
    layout = nx.spring_layout(graph, seed=42)
    nx.draw_networkx(
        graph,
        layout,
        node_color="#E8F0FF",
        edgecolors="blue",
        node_size=1000,
        font_weight="bold",
        with_labels=True,
    )
    plt.title(f"Interaction Graph (G order {group.n})")
    plt.show()

# Run the full analysis when this module is executed directly rather than imported.
if __name__ == "__main__":
    run_analysis()