# Battle Simulator - Interactive Visualizer

This notebook provides interactive visualization for debugging and verifying combat mechanics.

In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.colors import LinearSegmentedColormap
from IPython.display import display, clear_output
import ipywidgets as widgets

from src.simulator import BattleSimulator, GameDataLoader, Position
from src.simulator.battle import BattleResult

In [None]:
# Build the simulator from the data directory one level up and report how
# much game content its loader picked up.
sim = BattleSimulator('../data')
loader = sim.data_loader

for collection_name in ("units", "abilities", "encounters"):
    print(f"Loaded {len(getattr(loader, collection_name))} {collection_name}")

In [None]:
class BattleGridVisualizer:
    """Matplotlib-based battle grid visualizer.

    Each side of a battle is drawn as a ``grid_width`` x ``grid_height`` grid
    of cells; living units appear as circles coloured by remaining HP.
    """

    # Row names indexed by grid row ``y``. Assumes row 0 is the front line on
    # both sides — TODO confirm against the simulator's Position semantics.
    _ROW_LABELS = ('Front', 'Mid', 'Back')

    def __init__(self, figsize=(14, 10)):
        """Create a visualizer.

        Args:
            figsize: Matplotlib figure size used by ``plot_battle``.
        """
        self.figsize = figsize
        # Fixed battlefield dimensions: 5 columns x 3 rows per side.
        self.grid_width = 5
        self.grid_height = 3

    @classmethod
    def _row_label(cls, y):
        """Return the display name of grid row ``y`` (generic for extra rows)."""
        if 0 <= y < len(cls._ROW_LABELS):
            return cls._ROW_LABELS[y]
        return f'Row {y}'

    @staticmethod
    def _alive_unit_lookup(units):
        """Map ``(x, y) -> (index in units, unit)`` for every living unit.

        Keeping the list index lets callers match ``selected_unit`` by position
        instead of by equality (``list.index`` would pick the first *equal*
        unit and costs O(n) per lookup). If two units share a cell, the later
        one wins.
        """
        return {
            (u.position.x, u.position.y): (idx, u)
            for idx, u in enumerate(units)
            if u.is_alive
        }

    @staticmethod
    def _hp_fraction(unit):
        """Return current HP as a fraction of template max HP (0.0 if max HP is 0)."""
        max_hp = unit.template.stats.hp
        return unit.current_hp / max_hp if max_hp else 0.0

    def plot_battle(self, battle, highlighted_targets=None, aoe_pattern=None,
                    selected_unit=None, title=None):
        """Plot the current battle state.

        Args:
            battle: Battle exposing ``enemy_units``, ``player_units``,
                ``is_player_turn`` and ``turn_number``.
            highlighted_targets: Set of ``(x, y)`` enemy cells to shade cyan.
            aoe_pattern: Mapping ``(x, y) -> damage percent`` overlaid on the
                enemy grid.
            selected_unit: Index into ``battle.player_units`` to outline in blue.
            title: Figure title; defaults to the turn number and whose turn it is.

        Returns:
            ``(fig, axes)`` with the enemy grid in ``axes[0]`` and the player
            grid in ``axes[1]``.
        """
        fig, axes = plt.subplots(2, 1, figsize=self.figsize,
                                 gridspec_kw={'height_ratios': [1, 1], 'hspace': 0.3})

        highlighted_targets = highlighted_targets or set()
        aoe_pattern = aoe_pattern or {}

        # Enemy grid (top) carries the targeting overlays.
        self._plot_grid(axes[0], battle.enemy_units, "Enemy Side",
                        highlighted_targets, aoe_pattern, is_enemy=True)

        # Player grid (bottom) only shows the selected unit.
        self._plot_grid(axes[1], battle.player_units, "Player Side",
                        set(), {}, is_enemy=False, selected_unit=selected_unit)

        if title:
            fig.suptitle(title, fontsize=14, fontweight='bold')
        else:
            turn_str = "Player Turn" if battle.is_player_turn else "Enemy Turn"
            fig.suptitle(f"Turn {battle.turn_number} - {turn_str}", fontsize=14)

        plt.tight_layout()
        return fig, axes

    def _plot_grid(self, ax, units, title, highlighted, aoe, is_enemy=False, selected_unit=None):
        """Plot a single side's grid onto ``ax``.

        Args:
            ax: Matplotlib axes to draw on.
            units: Sequence of units for this side; ``selected_unit`` indexes it.
            title: Axes title.
            highlighted: Iterable of ``(x, y)`` cells to shade cyan.
            aoe: Mapping ``(x, y) -> damage percent``; out-of-grid cells are skipped.
            is_enemy: If True, the title is red and the y-axis is inverted.
            selected_unit: Optional index into ``units`` to outline in blue.
        """
        ax.set_xlim(-0.5, self.grid_width - 0.5)
        ax.set_ylim(-0.5, self.grid_height - 0.5)
        ax.set_aspect('equal')
        ax.set_title(title, fontsize=12, color='red' if is_enemy else 'green')

        # Cell borders
        for x in range(self.grid_width + 1):
            ax.axvline(x - 0.5, color='gray', linewidth=0.5)
        for y in range(self.grid_height + 1):
            ax.axhline(y - 0.5, color='gray', linewidth=0.5)

        # Row labels are placed at the data row they name, so they stay aligned
        # with the units on both sides regardless of axis inversion.
        for y in range(self.grid_height):
            ax.text(-0.8, y, self._row_label(y), ha='right', va='center', fontsize=9)

        # Column labels
        for x in range(self.grid_width):
            ax.text(x, self.grid_height - 0.3, f'Col {x}', ha='center', fontsize=8)

        # Highlight targets
        for (x, y) in highlighted:
            rect = patches.Rectangle((x - 0.5, y - 0.5), 1, 1,
                                     linewidth=2, edgecolor='cyan',
                                     facecolor='cyan', alpha=0.3)
            ax.add_patch(rect)

        # AOE pattern: opacity scales with damage %, colour buckets at 80/40.
        for (x, y), dmg_pct in aoe.items():
            if 0 <= x < self.grid_width and 0 <= y < self.grid_height:
                alpha = min(0.8, dmg_pct / 100)
                color = 'red' if dmg_pct >= 80 else 'orange' if dmg_pct >= 40 else 'yellow'
                rect = patches.Rectangle((x - 0.5, y - 0.5), 1, 1,
                                         facecolor=color, alpha=alpha)
                ax.add_patch(rect)
                ax.text(x, y - 0.3, f'{dmg_pct}%', ha='center', fontsize=7)

        # Units (only those inside the grid bounds are drawn)
        unit_lookup = self._alive_unit_lookup(units)

        for y in range(self.grid_height):
            for x in range(self.grid_width):
                entry = unit_lookup.get((x, y))
                if entry is None:
                    continue
                idx, unit = entry
                hp_pct = self._hp_fraction(unit)

                # Unit circle coloured by HP; the selected unit gets a thick blue outline.
                color = 'green' if hp_pct > 0.7 else 'yellow' if hp_pct > 0.3 else 'red'
                is_selected = selected_unit is not None and idx == selected_unit
                circle = patches.Circle((x, y), 0.35, facecolor=color,
                                        edgecolor='blue' if is_selected else 'black',
                                        linewidth=4 if is_selected else 2)
                ax.add_patch(circle)

                # Unit label: first 4 chars of the class name plus HP percentage.
                class_name = unit.template.class_type.name[:4]
                ax.text(x, y + 0.05, class_name, ha='center', va='center',
                        fontsize=8, fontweight='bold')
                ax.text(x, y - 0.2, f'{int(hp_pct*100)}%', ha='center',
                        va='center', fontsize=7)

        ax.set_xticks([])
        ax.set_yticks([])
        if is_enemy:
            ax.invert_yaxis()

    def plot_ability_pattern(self, ability, target_x=2, target_y=1):
        """Visualize an ability's targeting and damage (splash) patterns.

        Args:
            ability: Ability whose ``stats`` provide ``target_area``,
                ``damage_area``, ``min_range``, ``max_range`` and ``targets``.
            target_x, target_y: Forwarded to the pattern plotter, which currently
                draws offsets relative to (0, 0) and does not use them.

        Returns:
            The created matplotlib figure.
        """
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))

        stats = ability.stats

        # Target area pattern
        ax = axes[0]
        ax.set_title(f"Target Area Pattern\n(Type: {stats.target_area.target_type if stats.target_area else 'None'})")
        self._plot_pattern(ax, stats.target_area.data if stats.target_area else [],
                           target_x, target_y, 'Target')

        # Damage area pattern (splash)
        ax = axes[1]
        ax.set_title("Damage Area Pattern (Splash)")
        self._plot_pattern(ax, stats.damage_area, target_x, target_y, 'Splash')

        fig.suptitle(f"Ability: {ability.name}\nRange: {stats.min_range}-{stats.max_range}, Targets: {stats.targets}",
                     fontsize=11)
        plt.tight_layout()
        return fig

    def _plot_pattern(self, ax, pattern_data, center_x, center_y, pattern_type):
        """Plot pattern entries as cells at their ``pos`` offset from (0, 0).

        ``center_x``, ``center_y`` and ``pattern_type`` are accepted for API
        compatibility but currently unused. Each entry is labelled with its
        ``weight`` if present, otherwise its ``damage_percent`` (default 100).
        """
        # Show a larger area than the battle grid so wide patterns fit.
        extent = 6
        ax.set_xlim(-extent, extent)
        ax.set_ylim(-extent, extent)
        ax.set_aspect('equal')
        ax.grid(True, alpha=0.3)

        # Mark center
        ax.plot(0, 0, 'ko', markersize=10, label='Target')

        if not pattern_data:
            ax.text(0, 0.5, 'No pattern data', ha='center')
            return

        for entry in pattern_data:
            x = entry.pos.x
            y = entry.pos.y
            dmg_pct = getattr(entry, 'damage_percent', 100)
            weight = getattr(entry, 'weight', None)

            # Color and opacity based on damage
            alpha = min(0.9, dmg_pct / 100 + 0.1)
            color = plt.cm.Reds(dmg_pct / 100)

            rect = patches.Rectangle((x - 0.5, y - 0.5), 1, 1,
                                     facecolor=color, alpha=alpha,
                                     edgecolor='black', linewidth=1)
            ax.add_patch(rect)

            label = f'{dmg_pct}%' if weight is None else f'w:{weight}'
            ax.text(x, y, label, ha='center', va='center', fontsize=7)

        ax.set_xlabel('X offset from target')
        ax.set_ylabel('Y offset from target')


# Shared visualizer instance reused by the plotting cells below.
viz = BattleGridVisualizer()

## Create a Test Battle

In [None]:
# Pick up to 8 units that have weapons: the first 4 form the player team and
# the next 4 (when available) form the enemy team.
units_with_weapons = [uid for uid, unit in loader.units.items() if unit.weapons][:8]

print(f"Using units: {units_with_weapons}")

# Grid slots for each team: front and mid row
# (assumes slot index = row * 5 + column — TODO confirm).
slot_positions = [0, 1, 5, 6]
player_ids = units_with_weapons[:4]
# Mirror the player team when there are not enough distinct units for an enemy team.
enemy_ids = units_with_weapons[4:8] if len(units_with_weapons) > 4 else units_with_weapons[:4]

# Trim the slot lists so each team gets exactly one position per unit id;
# with fewer than 8 (or 4) weapon units the id lists are shorter than 4.
battle = sim.create_custom_battle(
    layout_id=2,
    player_unit_ids=player_ids,
    player_positions=slot_positions[:len(player_ids)],
    enemy_unit_ids=enemy_ids,
    enemy_positions=slot_positions[:len(enemy_ids)],
)

print(f"Battle created with {len(battle.player_units)} player units and {len(battle.enemy_units)} enemy units")

In [None]:
# Render the freshly created battle before any actions are taken.
initial_fig, initial_axes = viz.plot_battle(battle, title="Initial Battle State")
plt.show()

## Explore Unit Info

In [None]:
# Print a summary of player unit 0: template details, weapons, and the
# range / target tags of each ability those weapons grant.
unit = battle.player_units[0]
print("=== Unit Info ===")
print(f"Name: {unit.template.name}")
print(f"Class: {unit.template.class_type.name}")
print(f"Tags: {unit.template.tags}")
print(f"HP: {unit.current_hp}/{unit.template.stats.hp}")
print(f"Position: ({unit.position.x}, {unit.position.y})")
print()
print("=== Weapons ===")
for wid, weapon in unit.template.weapons.items():
    print(f"  [{wid}] {weapon.name}")
    print(f"      Damage: {weapon.stats.base_damage_min}-{weapon.stats.base_damage_max}")
    for aid in weapon.abilities:
        ability = loader.get_ability(aid)
        # Skip ids the loader cannot resolve (presumably returns None — verify).
        if ability:
            # ASCII arrow: the original "→" had been mis-encoded as "â†’".
            print(f"      -> Ability: {ability.name}")
            print(f"        Range: {ability.stats.min_range}-{ability.stats.max_range}")
            print(f"        Target tags: {ability.stats.targets}")

## Visualize Valid Targets

In [None]:
# Highlight every cell player unit 0 can hit with its first weapon
# (weapons are keyed by id, so this is the first key, not necessarily id 1).
unit_idx = 0
focus_unit = battle.player_units[unit_idx]
weapon_id = tuple(focus_unit.template.weapons)[0]

targets = battle.get_valid_targets(focus_unit, weapon_id)
target_cells = [(t.x, t.y) for t in targets]
highlighted = set(target_cells)

print(f"Valid targets for Unit {unit_idx}, Weapon {weapon_id}: {target_cells}")

viz.plot_battle(
    battle,
    highlighted_targets=highlighted,
    selected_unit=unit_idx,
    title=f"Valid Targets for Unit {unit_idx}",
)
plt.show()

## Visualize Ability Patterns

In [None]:
# Collect up to five abilities whose splash pattern, or whose target area,
# covers more than three cells; then list them.
aoe_abilities = []
for aid, ability in loader.abilities.items():
    stats = ability.stats
    is_aoe = len(stats.damage_area) > 3 or bool(
        stats.target_area and len(stats.target_area.data) > 3
    )
    if not is_aoe:
        continue
    aoe_abilities.append((aid, ability))
    if len(aoe_abilities) == 5:
        break

print("AOE Abilities found:")
for aid, ability in aoe_abilities:
    print(f"  [{aid}] {ability.name}")

In [None]:
# Plot the targeting and splash patterns of the first AOE ability found
# (does nothing when the search above came up empty).
for aid, ability in aoe_abilities[:1]:
    viz.plot_ability_pattern(ability)
    plt.show()

## Interactive Battle Controls

In [None]:
# Interactive battle widget
import random

class InteractiveBattle:
    """ipywidgets front-end for stepping through a battle in the notebook.

    Offers unit/weapon pickers plus buttons to highlight valid targets, take a
    random action (followed by a random enemy reply), or end the turn.
    """

    def __init__(self, battle, viz):
        """
        Args:
            battle: Battle to drive; the action buttons mutate it.
            viz: Visualizer whose ``plot_battle`` renders each state.
        """
        self.battle = battle
        self.viz = viz
        # All plots and messages go into this output area so each button press
        # can clear and redraw it.
        self.output = widgets.Output()

    def _living_unit_options(self):
        """Return ``(label, index)`` dropdown options for living player units."""
        return [(f"Unit {i}: {u.template.class_type.name}", i)
                for i, u in enumerate(self.battle.player_units) if u.is_alive]

    def display(self):
        """Render the controls and the initial battle state."""
        # Controls
        unit_dropdown = widgets.Dropdown(
            options=self._living_unit_options(),
            description='Unit:'
        )

        weapon_dropdown = widgets.Dropdown(
            options=[],
            description='Weapon:'
        )

        show_targets_btn = widgets.Button(description='Show Targets')
        random_action_btn = widgets.Button(description='Random Action')
        next_turn_btn = widgets.Button(description='Next Turn')

        def update_weapons(change):
            # With no living units the dropdown value is None; clear the weapon
            # list instead of indexing player_units with None.
            if unit_dropdown.value is None:
                weapon_dropdown.options = []
                return
            unit = self.battle.player_units[unit_dropdown.value]
            weapon_dropdown.options = [
                (f"{wid}: {w.name[:20]}", wid)
                for wid, w in unit.template.weapons.items()
            ]

        unit_dropdown.observe(update_weapons, names='value')
        update_weapons(None)

        def refresh_units():
            # Drop units that died since the dropdown was built, keeping the
            # current selection when that unit is still alive.
            current = unit_dropdown.value
            options = self._living_unit_options()
            unit_dropdown.options = options
            if any(value == current for _, value in options):
                unit_dropdown.value = current
            update_weapons(None)

        def show_targets(b):
            with self.output:
                clear_output(wait=True)
                if unit_dropdown.value is None or weapon_dropdown.value is None:
                    print("Select a living unit and a weapon first.")
                    return
                targets = self.battle.get_valid_targets(
                    self.battle.player_units[unit_dropdown.value],
                    weapon_dropdown.value
                )
                target_cells = [(t.x, t.y) for t in targets]
                self.viz.plot_battle(self.battle, highlighted_targets=set(target_cells),
                                     selected_unit=unit_dropdown.value)
                plt.show()
                print(f"Valid targets: {target_cells}")

        def random_action(b):
            with self.output:
                clear_output(wait=True)
                # Once the battle is decided, just show the final state.
                if self.battle.result == BattleResult.IN_PROGRESS:
                    actions = self.battle.get_legal_actions()
                    if actions:
                        action = random.choice(actions)
                        result = self.battle.execute_action(action)
                        self.battle.end_turn()
                        print(f"Executed: Unit {action.unit_index}, Weapon {action.weapon_id}, Target ({action.target_position.x}, {action.target_position.y})")
                        print(f"Damage: {result.damage_dealt}, Kills: {result.kills}")

                        # Enemy turn (random)
                        if not self.battle.is_player_turn and self.battle.result == BattleResult.IN_PROGRESS:
                            enemy_actions = self.battle.get_legal_actions()
                            if enemy_actions:
                                enemy_action = random.choice(enemy_actions)
                                self.battle.execute_action(enemy_action)
                            self.battle.end_turn()

                refresh_units()
                self.viz.plot_battle(self.battle)
                plt.show()
                print(f"Battle result: {self.battle.result.name}")

        def next_turn(b):
            with self.output:
                clear_output(wait=True)
                self.battle.end_turn()
                refresh_units()
                self.viz.plot_battle(self.battle)
                plt.show()

        show_targets_btn.on_click(show_targets)
        random_action_btn.on_click(random_action)
        next_turn_btn.on_click(next_turn)

        controls = widgets.HBox([unit_dropdown, weapon_dropdown, show_targets_btn,
                                 random_action_btn, next_turn_btn])

        # Module-level IPython ``display`` (this method's name does not shadow it).
        display(controls)
        display(self.output)

        # Initial display
        with self.output:
            self.viz.plot_battle(self.battle)
            plt.show()

# Wrap the test battle in the widget controller and render its controls.
interactive = InteractiveBattle(battle=battle, viz=viz)
interactive.display()

## Explore Tag Hierarchy

In [None]:
import json

# Load the battle config; the class-modifier cell below also reads `config`.
# A context manager closes the file, and an explicit encoding avoids
# platform-dependent decoding.
with open('../data/Assets/Config/battle/battle_config.json', encoding='utf-8') as config_file:
    config = json.load(config_file)
tag_hierarchy = config.get('tag_hierarchy', {})

print("=== Tag Hierarchy ===")
print("Parent Tag -> Children")
for parent, children in tag_hierarchy.items():
    print(f"  {parent} -> {children}")

# Most common target tags (hand-written notes, not read from the config)
print("\n=== Most Used Target Tags ===")
print("Tag 24: Ground units (land-based)")
print("Tag 39: All targetable units")
print("Tag 15: Air units")
print("Tag 6: Buildings")

## Explore Class Damage Modifiers

In [None]:
# Attacker x defender damage-multiplier heatmap built from
# config['classes']['class_types'][id]['damage_mods'].
class_types = config.get('classes', {}).get('class_types', {})

# Sort class ids numerically (requires numeric id strings) for a stable axis order.
classes = sorted(class_types.keys(), key=int)
class_names = {k: class_types[k]['display_name'] for k in classes}
# O(1) id -> row/column lookup instead of repeated list.index / `in` scans.
class_index = {cid: idx for idx, cid in enumerate(classes)}

print("=== Class Damage Modifiers ===")
print("Values > 1.0 = bonus damage, < 1.0 = reduced damage")
print("")

# Pairs without an explicit modifier stay at a neutral 1.0.
n = len(classes)
matrix = np.ones((n, n))

for i, attacker in enumerate(classes):
    mods = class_types[attacker].get('damage_mods', {})
    for defender, mult in mods.items():
        j = class_index.get(defender)
        # Ignore modifiers that reference classes not present in class_types.
        if j is not None:
            matrix[i, j] = mult

if n == 0:
    print("No class types found in config; nothing to plot.")
else:
    # Heatmap; colours are clipped to the 0.5-1.5 range.
    fig, ax = plt.subplots(figsize=(12, 10))
    im = ax.imshow(matrix, cmap='RdYlGn', vmin=0.5, vmax=1.5)

    # Labels (display names truncated to 10 chars)
    labels = [class_names[c][:10] for c in classes]
    ax.set_xticks(range(n))
    ax.set_yticks(range(n))
    ax.set_xticklabels(labels, rotation=45, ha='right')
    ax.set_yticklabels(labels)
    ax.set_xlabel('Defender Class')
    ax.set_ylabel('Attacker Class')
    ax.set_title('Class Damage Modifiers\n(Green = bonus, Red = penalty)')

    # Annotate only non-neutral cells to keep the plot readable.
    for i in range(n):
        for j in range(n):
            if matrix[i, j] != 1.0:
                ax.text(j, i, f'{matrix[i,j]:.2f}', ha='center', va='center', fontsize=7)

    # Attach the colorbar explicitly to this figure/axes rather than pyplot's current axes.
    fig.colorbar(im, ax=ax)
    plt.tight_layout()
    plt.show()