In [None]:
# whiteout_survival_battle_simulator.py
"""
Whiteout Survival Battle Simulator (No Heroes)
=============================================

*Fully tier‑aware, per‑stat combat sandbox* for theory‑crafting in
**Whiteout Survival**.

Version **v0.5** – Key changes
-----------------------------
1. **Separate Attack vs Lethality damage paths**  
   *Attack* must first chew through **Defense**, while *Lethality* bypasses
   Defense and goes straight to **Health**.  Any overflow from either type
   propagates to the next troop line (Inf → Cav → Arch).
2. **New `receive_damage()` signature**  
   `receive_damage(atk_dmg, leth_dmg)` returns the unspent (overflow) portions of
   both, enabling proper spill‑over logic.
3. **Cleaner battle loop** – Each side now computes *separate* outgoing Attack
   and Lethality pools each turn.

Base templates (unchanged)
~~~~~~~~~~~~~~~~~~~~~~~~~~
```
Infantry  attack 0  defense 3  health 5  lethality 0
Cavalry   attack 3  defense 1  health 1  lethality 4
Archer    attack 4  defense 0  health 0  lethality 4
```
Then add the **tier number** to each stat (T10 = +10, etc.).
"""

import logging
import math
from dataclasses import dataclass, field
from typing import Dict, Tuple

# ----------------------------------------------------------------------
# Configurable constants
# ----------------------------------------------------------------------
# Hard cap on the number of simulated turns, so a stalemate can never spin forever.
MAX_TURNS: int = 50

# Shared logger for the simulator (per-turn DEBUG lines, end-of-battle INFO).
logger = logging.getLogger("whiteout.battle")

# ----------------------------------------------------------------------
# Tier‑based stat generator
# ----------------------------------------------------------------------
# Base stats per troop type *before* the tier bonus is added,
# ordered as (attack, defense, health, lethality).
TEMPLATE: Dict[str, Tuple[int, int, int, int]] = dict(
    Infantry=(0, 3, 5, 0),
    Cavalry=(3, 1, 1, 4),
    Archer=(4, 0, 0, 4),
)


def tier_stats(tier: int) -> Dict[str, "TroopStats"]:
    """Build per-troop-type stats for the given *tier* (1-12 etc.).

    Every stat of every ``TEMPLATE`` entry is increased by ``tier``; the
    returned dict keeps ``TEMPLATE``'s key order.
    """
    return {
        troop: TroopStats(
            name=troop,
            attack=atk + tier,
            defense=dfn + tier,
            health=hp + tier,
            lethality=leth + tier,
        )
        for troop, (atk, dfn, hp, leth) in TEMPLATE.items()
    }


# ----------------------------------------------------------------------
# Core data classes
# ----------------------------------------------------------------------
@dataclass
class TroopStats:
    """Per-soldier combat stats for a single troop type.

    Built by ``tier_stats`` (template + tier bonus) and copied with scaled
    values by ``Buffs.apply``.
    """

    name: str  # troop type label, e.g. "Infantry"
    attack: float  # damage that must first get through the target's Defense
    defense: float  # soaks incoming Attack before it reaches Health
    health: float  # hit points of one soldier
    lethality: float  # damage that skips Defense and hits Health directly


@dataclass
class Buffs:
    """Percentage modifiers layered on top of a troop's base stats.

    Each field is a percent: ``10.0`` means +10 %, ``-25.0`` means -25 %, and
    the default ``0.0`` leaves the stat as is.
    """

    attack_pct: float = 0.0
    defense_pct: float = 0.0
    health_pct: float = 0.0
    lethality_pct: float = 0.0

    def apply(self, base: "TroopStats") -> "TroopStats":
        """Return a new ``TroopStats`` with each stat scaled by its buff.

        *base* is left untouched; its ``name`` is copied through.
        """

        def scaled(value: float, pct: float) -> float:
            # Multiplier is 1 + pct/100, e.g. pct=10 -> x1.1.
            return value * (1.0 + pct / 100.0)

        return TroopStats(
            name=base.name,
            attack=scaled(base.attack, self.attack_pct),
            defense=scaled(base.defense, self.defense_pct),
            health=scaled(base.health, self.health_pct),
            lethality=scaled(base.lethality, self.lethality_pct),
        )


@dataclass
class TroopGroup:
    """All soldiers of one troop type in an army, modelled as shared pools.

    On construction the buffed per-soldier stats (``eff``) are multiplied by
    ``count`` to build one Defense pool and one Health pool for the whole
    line.  Incoming damage drains those pools, and the number of soldiers
    still fighting is derived from what is left of the Health pool.
    """

    base: TroopStats
    count: int
    buffs: Buffs

    eff: TroopStats = field(init=False)  # ``base`` after ``buffs`` are applied
    defense_pool: float = field(init=False)  # Defense left to soak Attack damage
    health_pool: float = field(init=False)  # HP left across the whole line
    injured: int = field(default=0, init=False)  # soldiers lost so far

    def __post_init__(self) -> None:
        self.eff = self.buffs.apply(self.base)
        self.defense_pool = self.count * self.eff.defense
        self.health_pool = self.count * self.eff.health

    # ------------------------------------------------------------------
    # Combat resolution
    # ------------------------------------------------------------------
    def receive_damage(self, atk_dmg: float, leth_dmg: float) -> Tuple[float, float]:
        """Apply one volley of damage to this line and report the overflow.

        Attack is first absorbed 1:1 by ``defense_pool`` (which stays
        depleted); leftover Attack plus all Lethality then drains
        ``health_pool``.  When HP damage is capped by the remaining pool,
        the Attack share is counted as consumed before the Lethality share.

        Args:
            atk_dmg: Incoming Attack damage (reduced by Defense).
            leth_dmg: Incoming Lethality damage (bypasses Defense).

        Returns:
            ``(remaining_atk, remaining_leth)``: the portions not absorbed
            here, to be passed on to the next troop line.  Both inputs are
            returned unchanged if this line has no active soldiers or there
            is no positive damage to apply.
        """
        if self.active_count == 0 or (atk_dmg <= 0 and leth_dmg <= 0):
            return atk_dmg, leth_dmg

        prev_active = self.active_count

        # 1) Attack first: chew through Defense, leftover (if any) threatens HP
        absorbed_def = min(self.defense_pool, atk_dmg)
        self.defense_pool -= absorbed_def
        atk_after_def = atk_dmg - absorbed_def  # can be > 0

        # 2) Total HP damage this round, capped by what the line has left
        potential_hp_dmg = atk_after_def + leth_dmg
        absorbed_hp = min(self.health_pool, potential_hp_dmg)
        self.health_pool -= absorbed_hp

        # 3) Update injuries (any drop in soldier count)
        self.injured += prev_active - self.active_count

        # 4) Split what was actually consumed between atk and leth so we know leftovers
        atk_used = min(atk_after_def, absorbed_hp)
        leth_used = absorbed_hp - atk_used

        rem_atk = atk_after_def - atk_used
        rem_leth = leth_dmg - leth_used
        return rem_atk, rem_leth

    # ------------------------------------------------------------------
    @property
    def active_count(self) -> int:
        """Number of soldiers still able to fight.

        Computed as how many *full* per-soldier HP amounts remain in
        ``health_pool``; a partially damaged soldier no longer counts.
        Returns 0 when the per-soldier health is not positive.
        """
        per_soldier = self.eff.health
        if per_soldier <= 0:
            # A non-positive per-soldier HP cannot field anyone, and dividing
            # a negative pool by a negative HP would yield a bogus positive count.
            return 0
        ratio = self.health_pool / per_soldier
        nearest = round(ratio)
        # ``count * health`` (and later subtractions) can land a hair below the
        # exact value in floating point; a plain floor would then drop a whole
        # soldier.  Snap ratios that are within float noise of an integer.
        if math.isclose(ratio, nearest, rel_tol=1e-9, abs_tol=1e-9):
            return max(0, int(nearest))
        return max(0, math.floor(ratio))


@dataclass
class Army:
    """One side of a battle: its troop lines keyed by troop type."""

    troops: Dict[str, TroopGroup]  # keys: "Infantry", "Cavalry", "Archer"

    def total_attack(self) -> float:
        """Total Attack of every soldier still standing, across all lines."""
        total = 0
        for group in self.troops.values():
            total += group.active_count * group.eff.attack
        return total

    def total_lethality(self) -> float:
        """Total Lethality of every soldier still standing, across all lines."""
        total = 0
        for group in self.troops.values():
            total += group.active_count * group.eff.lethality
        return total

    def has_active(self) -> bool:
        """True while at least one line still has a soldier able to fight."""
        for group in self.troops.values():
            if group.active_count > 0:
                return True
        return False

    def injured_report(self) -> Dict[str, int]:
        """Map each troop type to the number of soldiers it has lost."""
        report: Dict[str, int] = {}
        for name, group in self.troops.items():
            report[name] = group.injured
        return report


# ----------------------------------------------------------------------
# Battle simulator
# ----------------------------------------------------------------------

def _apply_volley(
    army: "Army",
    atk_dmg: float,
    leth_dmg: float,
    order: Tuple[str, ...],
) -> None:
    """Send one side's Attack/Lethality into *army*, front line first.

    Each line's ``receive_damage`` hands its overflow to the next line in
    *order*.  Lines absent from ``army.troops`` are skipped, and whatever
    overflows past the last line is discarded.
    """
    spill_atk, spill_leth = atk_dmg, leth_dmg
    for name in order:
        group = army.troops.get(name)
        if group is None:
            continue  # army does not field this troop type
        spill_atk, spill_leth = group.receive_damage(spill_atk, spill_leth)
        if spill_atk <= 0 and spill_leth <= 0:
            break  # everything absorbed; nothing left to spill


def simulate(
    attacker: "Army",
    defender: "Army",
    *,
    max_turns: int = MAX_TURNS,
) -> Tuple[int, Dict[str, int], Dict[str, int]]:
    """Run the fight and return (*turns, attacker_injured, defender_injured*).

    Each turn both sides' outgoing Attack and Lethality are computed from the
    state at the start of the turn and then applied, so the exchange is
    simultaneous.  Damage hits Infantry, then Cavalry, then Archers.  The
    fight stops when either side has no active soldiers or after
    *max_turns* turns (a warning is logged in the latter case).

    Args:
        attacker: The attacking army (mutated in place).
        defender: The defending army (mutated in place).
        max_turns: Failsafe cap on the number of turns.

    Returns:
        The number of turns played and each army's ``injured_report()``.
    """
    order = ("Infantry", "Cavalry", "Archer")  # front → back
    turn = 0

    while attacker.has_active() and defender.has_active() and turn < max_turns:
        turn += 1

        # Snapshot both sides before anyone takes damage (simultaneous exchange).
        atk_atk = attacker.total_attack()
        atk_leth = attacker.total_lethality()
        def_atk = defender.total_attack()
        def_leth = defender.total_lethality()

        logger.debug(
            "Turn %d | atkA %.0f/%.0f | atkD %.0f/%.0f",
            turn, atk_atk, atk_leth, def_atk, def_leth,
        )

        _apply_volley(defender, atk_atk, atk_leth, order)  # Attacker -> Defender
        _apply_volley(attacker, def_atk, def_leth, order)  # Defender -> Attacker

    if attacker.has_active() and defender.has_active():
        # Loop exited only because of the failsafe; make that visible.
        logger.warning(
            "Battle stopped at max_turns=%d with both sides still active", max_turns
        )
    logger.info("Battle ended in %d turns", turn)
    return turn, attacker.injured_report(), defender.injured_report()


# ----------------------------------------------------------------------
# Quick demo when run directly
# ----------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(format="%(levelname)s | %(message)s", level=logging.DEBUG)

    # --- Stats & armies -------------------------------------------------
    # Both armies use the same unbuffed stats at this tier; only counts differ.
    demo_tier = 5
    stats = tier_stats(demo_tier)

    atk_counts = {"Infantry": 3000, "Cavalry": 0, "Archer": 1000}
    def_counts = {"Infantry": 6000, "Cavalry": 0, "Archer": 1000}
    no_buffs = {t: Buffs() for t in stats}

    attacker = Army({k: TroopGroup(stats[k], atk_counts[k], no_buffs[k]) for k in stats})
    defender = Army({k: TroopGroup(stats[k], def_counts[k], no_buffs[k]) for k in stats})

    # --- Fight ----------------------------------------------------------
    # simulate() already logs the number of turns, so it is not reused here.
    _, atk_inj, def_inj = simulate(attacker, defender)
    logger.info("Attacker injured -> %s", atk_inj)
    logger.info("Defender injured -> %s", def_inj)


DEBUG | Turn 1 | atkA 24000/24000 | atkD 39000/39000
INFO | Battle ended in 1 turns
INFO | Attacker injured -> {'Infantry': 3000, 'Cavalry': 0, 'Archer': 1000}
INFO | Defender injured -> {'Infantry': 2400, 'Cavalry': 0, 'Archer': 0}
