# huMon Generator

# In [9]:
# !pip install dash
from dash import Dash, dcc, html, Input, Output
import networkx as nx
import pandas as pd
import random
import numpy as np
import plotly.graph_objs as go
from matplotlib.colors import to_hex, to_rgb

class StatTreeGenerator:
    """Generate a random tree of stat blocks and render it as Plotly traces.

    Node ids are strings: the root is ``"0"`` and each child appends
    ``-<index>`` to its parent's id (``"0-2"``, ``"0-2-1"``, ...). Every node
    stores its generation, a dict with one integer per name in
    ``stats_names``, and a display colour.

    Args:
        branch_pattern: Sequence of ints; entry ``g - 1`` is the number of
            children each node of generation ``g - 1`` receives, so the tree
            has ``len(branch_pattern)`` generations below the root.
    """

    def __init__(self, branch_pattern):
        self.branch_pattern = branch_pattern
        self.nodes = {}
        self.G = nx.DiGraph()
        self.base_stat = 50  # Floor used when rolling generation-1 stats
        self.max_stat_cap = 250  # Upper bound applied to every generated stat
        self.stats_names = ["Health", "PAttack", "PDefense", "MAttack", "MDefense", "Speed"]

        # Placeholder root; generate_tree() replaces it with the real root entry.
        self.nodes["0"] = {
            "generation": 0,
            "stats": {stat: self.base_stat for stat in self.stats_names},
            "color": 'white'
        }

        # Fixed colours for generation-1 nodes; ids not listed here fall back to gray.
        self.colors = {
            "0-0": 'magenta',
            "0-1": 'green',
            "0-2": 'yellow',
            "0-3": 'cyan',
            "0-4": 'black',
            "0-5": 'purple'
        }
        self.max_stats_per_generation = {}

    def generate_tree(self):
        """Build the whole tree and return it as a DataFrame.

        Returns:
            A pandas DataFrame with one row per node and the columns
            ``node_id``, ``generation``, ``stats`` and ``color``.
        """
        node_id = "0"
        # The root splits 100 points evenly across the stats (integer division).
        base_stat = 100 // 6
        self.nodes[node_id] = {
            "generation": 0,
            "stats": {stat: base_stat for stat in self.stats_names},
            "color": 'white'
        }
        self.G.add_node(node_id)
        self.create_branches(node_id, 1)

        self.calculate_max_stats_per_generation()

        return pd.DataFrame([{"node_id": node, **self.nodes[node]} for node in self.nodes])

    def calculate_max_stats_per_generation(self):
        """Fill ``max_stats_per_generation`` with one entry per generation.

        For generation ``g`` a point budget of ``64 * 2 ** (g - 1)`` is split
        across six fixed weights (each share truncated to an int); the stored
        value is the sum of those shares plus 100.
        """
        weights = [500, 333, 180, 45, 15, 3]
        total_weight = sum(weights)

        per_generation = {}
        # One entry per generation that create_branches() can actually produce,
        # i.e. len(branch_pattern) of them -- not max(branch_pattern), which is
        # a child count and also fails on an empty pattern.
        for generation in range(1, len(self.branch_pattern) + 1):
            max_points = 64 * (2 ** (generation - 1))
            shares = [int(max_points * weight / total_weight) for weight in weights]
            per_generation[generation] = sum(shares) + 100
        self.max_stats_per_generation = per_generation

    def create_branches(self, parent_id, generation):
        """Recursively attach children to ``parent_id`` for ``generation`` and below.

        Args:
            parent_id: Id of an existing node that receives the children.
            generation: Generation number of the children being created;
                recursion stops once it exceeds ``len(branch_pattern)``.
        """
        if generation > len(self.branch_pattern):
            return
        num_children = self.branch_pattern[generation - 1]
        for i in range(num_children):
            child_id = f"{parent_id}-{i}"

            if generation == 1:
                # First generation: fixed colour table and freshly rolled stats.
                color = self.colors.get(child_id, 'gray')
                stats = self.generate_ranked_stats(generation)
            else:
                # Later generations derive both colour and stats from the parent.
                color = self.adjust_color(self.nodes[parent_id]["color"], generation)
                stats = self.specialize_ranked_stats(parent_id, generation)

            self.nodes[child_id] = {"generation": generation, "stats": stats, "color": color}
            self.G.add_edge(parent_id, child_id)
            self.create_branches(child_id, generation + 1)

    def generate_ranked_stats(self, generation):
        """Roll a stat block from scratch using a randomly chosen profile.

        Args:
            generation: Generation number; the point budget is
                ``100 + 50 * generation``.

        Returns:
            Dict mapping every name in ``stats_names`` to an int.
        """
        base_points = 100 + 50 * generation
        stat_profile = random.choice(['balanced', 'extreme_one', 'extreme_two'])

        if stat_profile == 'balanced':
            # Even split with Gaussian jitter, floored at base_stat and capped.
            base_allocation = base_points // len(self.stats_names)
            stats = {
                stat: min(max(self.base_stat, base_allocation + int(np.random.normal(0, base_allocation * 0.2))), self.max_stat_cap)
                for stat in self.stats_names
            }

        elif stat_profile == 'extreme_one':
            # One stat takes the (capped) budget; the rest get base_stat plus
            # an even share of whatever is left, with a minimum bonus of 10.
            dominant_stat = random.choice(self.stats_names)
            stats = {stat: self.base_stat for stat in self.stats_names}
            stats[dominant_stat] = min(base_points, self.max_stat_cap)

            remaining_points = base_points - stats[dominant_stat]
            for stat in self.stats_names:
                if stat != dominant_stat:
                    stats[stat] += min(max(10, int(remaining_points / (len(self.stats_names) - 1))), self.max_stat_cap)

        elif stat_profile == 'extreme_two':
            # Two stats take half the (capped) budget each; the rest are
            # handled as in 'extreme_one'.
            dominant_stats = random.sample(self.stats_names, 2)
            stats = {stat: self.base_stat for stat in self.stats_names}
            for stat in dominant_stats:
                stats[stat] = min(base_points // 2, self.max_stat_cap)

            remaining_points = base_points - sum(stats[stat] for stat in dominant_stats)
            for stat in self.stats_names:
                if stat not in dominant_stats:
                    stats[stat] += min(max(10, int(remaining_points / (len(self.stats_names) - 2))), self.max_stat_cap)

        return stats

    def _jittered_stats(self, parent_stats, spread):
        """Return each parent stat plus int-truncated Gaussian noise, clamped to [10, cap]."""
        return {
            stat: min(max(10, parent_stats[stat] + int(np.random.normal(0, spread))), self.max_stat_cap)
            for stat in self.stats_names
        }

    def specialize_ranked_stats(self, parent_id, generation):
        """Derive a child's stat block from its parent's using a random profile.

        Args:
            parent_id: Id of the parent node, looked up in ``self.nodes``.
            generation: Generation of the child; the boost budget is
                ``80 + 40 * generation``.

        Returns:
            Dict mapping every name in ``stats_names`` to an int no greater
            than ``max_stat_cap``.
        """
        parent_stats = self.nodes[parent_id]["stats"]
        base_points = 80 + 40 * generation
        stat_profile = random.choice(['balanced', 'extreme_one', 'extreme_two', 'mixed'])

        if stat_profile == 'balanced':
            stats = self._jittered_stats(parent_stats, base_points * 0.1)

        elif stat_profile == 'extreme_one':
            # Push the parent's strongest stat up by the full budget.
            dominant_stat = max(parent_stats, key=parent_stats.get)
            stats = self._jittered_stats(parent_stats, base_points * 0.1)
            stats[dominant_stat] = min(parent_stats[dominant_stat] + base_points, self.max_stat_cap)

        elif stat_profile == 'extreme_two':
            # Push the parent's two strongest stats up by half the budget each.
            dominant_stats = sorted(parent_stats, key=parent_stats.get, reverse=True)[:2]
            stats = self._jittered_stats(parent_stats, base_points * 0.1)
            for stat in dominant_stats:
                stats[stat] = min(parent_stats[stat] + base_points // 2, self.max_stat_cap)

        elif stat_profile == 'mixed':
            # Wider jitter everywhere plus a modest boost on one random stat.
            dominant_stat = random.choice(self.stats_names)
            stats = self._jittered_stats(parent_stats, base_points * 0.15)
            # Re-apply the cap after boosting so this profile cannot exceed
            # max_stat_cap, matching the other profiles.
            stats[dominant_stat] = min(stats[dominant_stat] + base_points // 3, self.max_stat_cap)

        return stats

    def adjust_color(self, color, generation):
        """Scale ``color`` by ``0.8 + 0.05 * generation`` and return it as a hex string.

        Channels are clipped to ``[0, 1]`` after scaling.
        """
        rgb = np.array(to_rgb(color))
        factor = 0.8 + (generation * 0.05)
        return to_hex(np.clip(rgb * factor, 0, 1))

    def hierarchy_pos(self, G, root="0", width=1.0, vert_gap=0.4, vert_loc=0.5, xcenter=0.5):
        """Compute top-down layout coordinates for a tree.

        Each node is centred in a horizontal interval; a node's interval is
        divided evenly among its successors, which sit ``vert_gap`` lower.

        Args:
            G: Directed graph exposing ``successors(node)``.
            root: Id of the node to start from.
            width: Width of the root's horizontal interval, starting at 0.
            vert_gap: Vertical distance between consecutive generations.
            vert_loc: Vertical coordinate of the root.
            xcenter: Accepted for signature compatibility; not used.

        Returns:
            Dict mapping node id to an ``(x, y)`` tuple.
        """
        pos = {}

        def _place(node, left, right, level_y):
            pos[node] = ((left + right) / 2, level_y)
            children = list(G.successors(node))
            if children:
                dx = (right - left) / len(children)
                nextx = left
                for child in children:
                    _place(child, nextx, nextx + dx, level_y - vert_gap)
                    nextx += dx

        _place(root, 0, width, vert_loc)
        return pos

    def plot_tree_structure(self):
        """Return ``[edge_trace, node_trace]`` Plotly scatter traces for the tree.

        The node trace carries each node id in ``customdata`` so that click
        events on the figure can be mapped back to entries in ``self.nodes``.
        """
        pos = self.hierarchy_pos(self.G)

        # Edges are drawn as one trace; None separates the individual segments.
        edge_x, edge_y = [], []
        for edge in self.G.edges():
            x0, y0 = pos[edge[0]]
            x1, y1 = pos[edge[1]]
            edge_x.extend([x0, x1, None])
            edge_y.extend([y0, y1, None])
        edge_trace = go.Scatter(x=edge_x, y=edge_y, line=dict(width=0.5, color='#888'), hoverinfo='none', mode='lines')

        node_x, node_y, hover_text, marker_colors, node_ids = [], [], [], [], []
        for node in self.G.nodes():
            x, y = pos[node]
            node_x.append(x)
            node_y.append(y)
            hover_text.append(f"Node ID: {node}<br>Generation: {self.nodes[node]['generation']}")
            marker_colors.append(self.nodes[node]["color"])
            node_ids.append(node)

        node_trace = go.Scatter(
            x=node_x, y=node_y, mode='markers', hoverinfo='text', text=hover_text,
            marker=dict(size=20, color=marker_colors, line=dict(color='black', width=1)),
            customdata=node_ids, name="tree"
        )

        return [edge_trace, node_trace]


# Build one tree at import time: six generation-1 nodes, then 3, 2 and 1
# children per node in the following generations. `tree_data` is the same
# tree flattened into a DataFrame.
tree_generator = StatTreeGenerator([6, 3, 2, 1])
tree_data = tree_generator.generate_tree()

# Helper function to create dropdown options for multi-node comparison
def generate_flat_dropdown_options():
    """Return one dropdown option per node, ordered by ascending generation.

    Reads the module-level ``tree_data`` frame. Each option is a dict with a
    ``"label"`` of the form ``"Gen <generation> - <node_id>"`` and the node id
    as its ``"value"``. Within a generation, nodes keep their row order.
    """
    dropdown_entries = []
    for generation in sorted(tree_data["generation"].unique()):
        in_generation = tree_data["generation"] == generation
        for node_id in tree_data.loc[in_generation, "node_id"]:
            dropdown_entries.append({"label": f"Gen {generation} - {node_id}", "value": node_id})
    return dropdown_entries

app = Dash(__name__)


def _build_layout():
    """Assemble the page: tree plot on top, node detail charts, then the comparison area."""
    # --- Top section: the clickable ancestral tree -------------------------
    tree_figure = {
        'data': tree_generator.plot_tree_structure(),
        'layout': go.Layout(
            title='Ancestral Tree',
            showlegend=False,
            margin={'b': 20, 'l': 20, 'r': 20, 't': 40},
        ),
    }
    tree_section = html.Div(
        [dcc.Graph(id='tree-plot', figure=tree_figure)],
        style={'width': '100%', 'display': 'inline-block'},
    )

    # --- Middle section: clicked node (left) and its children (right) ------
    selected_node_column = html.Div(
        [
            dcc.Graph(id='single-node-radar-chart'),
            html.Div(id='single-node-total-stats', style={
                'margin-top': '10px',
                'padding': '10px',
                'font-size': '16px',
                'text-align': 'center',
                'background-color': '#f0f0f0',
                'border': '1px solid #ccc',
                'border-radius': '5px',
            }),
        ],
        style={'width': '48%', 'display': 'inline-block', 'vertical-align': 'top'},
    )
    children_column = html.Div(
        [
            html.H4("Immediate Children Stats Comparison"),
            dcc.Graph(id='children-radar-chart'),
        ],
        style={'width': '48%', 'display': 'inline-block', 'vertical-align': 'top', 'margin-left': '4%'},
    )
    detail_section = html.Div(
        [selected_node_column, children_column],
        style={'width': '100%', 'display': 'inline-block', 'margin-top': '20px'},
    )

    # --- Bottom section: six dropdown slots next to a comparison chart -----
    slot_wrappers = []
    for slot in range(6):
        selector = dcc.Dropdown(
            id=f'node-slot-{slot}',
            options=generate_flat_dropdown_options(),
            placeholder=f"Select Node {slot + 1}",
        )
        slot_wrappers.append(html.Div([selector], style={'margin-bottom': '10px'}))
    selector_column = html.Div(
        slot_wrappers,
        style={'width': '20%', 'display': 'inline-block', 'vertical-align': 'top'},
    )
    comparison_column = html.Div(
        [
            dcc.Graph(id='comparison-radar-chart'),
            html.Div(id='total-stats', style={
                'margin-top': '20px',
                'padding': '10px',
                'font-size': '14px',
                'text-align': 'center',
                'background-color': '#f0f0f0',
                'border': '1px solid #ccc',
                'border-radius': '5px',
            }),
        ],
        style={'width': '75%', 'display': 'inline-block', 'vertical-align': 'top', 'margin-left': '5%'},
    )
    comparison_row = html.Div(
        [selector_column, comparison_column],
        style={'width': '100%', 'display': 'flex', 'align-items': 'flex-start', 'margin-top': '20px'},
    )
    comparison_section = html.Div(
        [
            html.Label("Select up to 6 nodes for comparison:"),
            comparison_row,
        ],
        style={'width': '100%', 'display': 'inline-block', 'margin-top': '20px'},
    )

    return html.Div([tree_section, detail_section, comparison_section])


app.layout = _build_layout()

# Callback for single node radar chart and children radar chart
@app.callback(
    [Output('single-node-radar-chart', 'figure'), Output('single-node-total-stats', 'children'),
     Output('children-radar-chart', 'figure')],
    Input('tree-plot', 'clickData')
)
def update_single_node_radar_chart_and_children(clickData):
    """Refresh the detail panel after a click on the tree plot.

    Args:
        clickData: Click payload of the ``tree-plot`` graph, or ``None``
            before anything has been clicked. The node id is read from the
            first point's ``customdata``.

    Returns:
        A 3-tuple: the radar figure for the clicked node, the content for the
        stats summary box, and a radar figure overlaying the clicked node's
        immediate children. Empty figures and a short message are returned
        when no known node was clicked.
    """
    if not clickData:
        return go.Figure(), "Select a node to view stats", go.Figure()

    # Points without customdata (or an empty point list) cannot be mapped to a node.
    points = clickData.get('points') or []
    node_id = points[0].get('customdata') if points else None
    if not node_id or node_id not in tree_generator.nodes:
        return go.Figure(), "Node data not found", go.Figure()

    # Fixed axis order so every radar chart is laid out the same way.
    fixed_order = ["Health", "PAttack", "PDefense", "MAttack", "MDefense", "Speed"]
    stats = tree_generator.nodes[node_id]["stats"]
    values = [stats[stat] for stat in fixed_order]
    total_stat_value = sum(values)

    radar_chart = go.Figure()
    radar_chart.add_trace(go.Scatterpolar(
        r=values,
        theta=fixed_order,
        fill='toself',
        name=f"Node {node_id}",
        marker=dict(color=tree_generator.nodes[node_id]["color"])
    ))
    radar_chart.update_layout(
        polar=dict(radialaxis=dict(visible=True, range=[0, max(values)])),
        showlegend=False,
        title=f"Stats for Node {node_id}"
    )

    total_stats_text = html.Div([
        html.B(f"Total Stats for Node {node_id}"), html.Br(),
        html.Span(", ".join([f"{stat}: {value}" for stat, value in zip(fixed_order, values)])),
        html.Br(),
        html.B(f"Overall Total: {total_stat_value}")
    ])

    children = list(tree_generator.G.successors(node_id))
    children_radar_chart = go.Figure()

    # Palette is cycled if a node has more children than colours.
    colors = ["blue", "green", "red", "purple", "orange", "cyan"]

    # Track the largest plotted value here: the radial range must come from
    # the numbers themselves, not from iterating the trace objects.
    child_peaks = []
    for idx, child_id in enumerate(children):
        child_stats = tree_generator.nodes[child_id]["stats"]
        child_values = [child_stats[stat] for stat in fixed_order]
        child_peaks.append(max(child_values))

        children_radar_chart.add_trace(go.Scatterpolar(
            r=child_values,
            theta=fixed_order,
            fill='toself',
            name=f"Child {child_id}",
            marker=dict(color=colors[idx % len(colors)])
        ))

    # Leaf nodes have no children; fall back to a 0-100 axis in that case.
    radial_limit = max(child_peaks, default=100)
    children_radar_chart.update_layout(
        polar=dict(radialaxis=dict(visible=True, range=[0, radial_limit])),
        showlegend=True,
        title="Immediate Children Stats Comparison"
    )

    return radar_chart, total_stats_text, children_radar_chart

# Callback for multi-node comparison radar chart
@app.callback(
    [Output('comparison-radar-chart', 'figure'), Output('total-stats', 'children')],
    [Input(f'node-slot-{i}', 'value') for i in range(6)]
)
def update_comparison_radar_chart(*selected_nodes):
    """Overlay the stats of every node chosen in the six dropdown slots.

    Args:
        *selected_nodes: Current value of each ``node-slot-<i>`` dropdown;
            ``None`` for slots with no selection.

    Returns:
        A 2-tuple: the comparison radar figure and a ``html.Div`` listing the
        per-stat values and total for each plotted node. Ids that are not in
        the tree are skipped.
    """
    selected_nodes = [node_id for node_id in selected_nodes if node_id is not None]

    # Fixed axis order so every radar chart is laid out the same way.
    fixed_order = ["Health", "PAttack", "PDefense", "MAttack", "MDefense", "Speed"]
    radar_chart = go.Figure()
    colors = ["blue", "green", "red", "purple", "orange", "cyan"]
    total_stats_content = []

    # Track the largest plotted value here: the radial range must come from
    # the numbers themselves, not from iterating the trace objects.
    peaks = []
    for idx, node_id in enumerate(selected_nodes):
        if node_id not in tree_generator.nodes:
            continue

        stats = tree_generator.nodes[node_id]["stats"]
        values = [stats[stat] for stat in fixed_order]
        total_stat_value = sum(values)
        peaks.append(max(values))

        radar_chart.add_trace(go.Scatterpolar(
            r=values,
            theta=fixed_order,
            fill='toself',
            name=f"Node {node_id}",
            marker=dict(color=colors[idx % len(colors)])
        ))

        node_total_text = f"Node {node_id} - Total: {total_stat_value}"
        stat_details = ", ".join([f"{stat}: {value}" for stat, value in zip(fixed_order, values)])
        total_stats_content.append(html.Div([
            html.B(node_total_text), html.Br(),
            html.Span(stat_details)
        ], style={'margin-bottom': '10px'}))

    # With nothing selected, fall back to a 0-100 axis.
    radial_limit = max(peaks, default=100)
    radar_chart.update_layout(
        polar=dict(radialaxis=dict(visible=True, range=[0, radial_limit])),
        showlegend=True,
        title="Comparison of Selected Nodes"
    )

    total_stats_text = html.Div(total_stats_content, style={
        'margin-top': '10px',
        'font-size': '14px',
        'line-height': '1.6'
    })

    return radar_chart, total_stats_text

if __name__ == '__main__':
    # Newer Dash releases start the server with `app.run`; `run_server` is the
    # older spelling and is no longer available in recent versions. Prefer
    # `run` and only fall back when the installed Dash predates it.
    start_server = getattr(app, "run", None)
    if start_server is None:
        start_server = app.run_server
    start_server()