# Day 22: Wizard Simulator 20XX

In [1]:
from __future__ import annotations

import heapq
import re
from copy import deepcopy
from dataclasses import dataclass, field

from tools import loader, parsers

# The input text is expected to contain exactly two integers, in this order:
# the boss's hit points, then the boss's attack damage.
# NOTE(review): loader.get(2015, 22) presumably fetches the Day 22 puzzle input - confirm.
BOSS_HP, BOSS_ATTACK = (
    int(number) for number in re.findall(r'\d+', parsers.string(loader.get(2015, 22)))
)

Oh my. This one took me 2 days and 3 complete rewrites to get right. We need to manage a lot of stuff: player state, boss state, active spells, mana spent, spell timers... To simplify things, I decided to make a `State` class to store all of that. This also allows us to define a custom `__lt__` method for determining priority (we want the states with the least mana spent to come first).

In [2]:
@dataclass
class State:
    """Snapshot of a fight in progress: both combatants plus running effects.

    Instances are ordered by ``mana_used`` only, so a heap of states always
    yields the one that has spent the least mana so far.
    """

    mana_used: int = 0  # total mana spent on spells so far
    boss_hp: int = BOSS_HP
    boss_attack: int = BOSS_ATTACK
    player_hp: int = 50
    player_armour: int = 0  # subtracted from the boss's attack damage
    player_mana: int = 500  # mana still available for casting
    boss_turn: bool = False  # False while it is the player's turn
    active_spells: list[Spell] = field(default_factory=list)  # effects still ticking

    def __lt__(self, other: State) -> bool:
        """Return True when this state has spent strictly less mana than *other*."""
        return other.mana_used > self.mana_used

Next, we need to handle spells. We have 2 types of spells: ones that activate immediately and ones that last for some time. We are going to call the `cast` method to cast all spells, and then each turn we will call `proc` to trigger lasting spells. This is probably not a very good implementation; I'd prefer this class to be immutable, but then we'd need to come up with a different way to keep track of spell timers.

In [3]:
@dataclass
class Spell:
    """A spell the player can cast.

    A spell with ``timer == 0`` is instant: its damage and mana gain are
    applied at the moment it is cast.  A spell with ``timer > 0`` is a lasting
    effect: casting puts a private copy into ``state.active_spells`` and the
    copy's ``proc`` is expected to be called once per turn until its timer
    runs out.

    Two spells compare equal when they share a name, which is how ``cast``
    detects that an effect is already running.
    """

    name: str
    cost: int  # mana needed to cast
    damage: int = 0  # boss HP removed per application
    heal: int = 0  # player HP restored once, at cast time
    armour: int = 0  # player armour while the effect is running
    mana: int = 0  # player mana restored per application
    timer: int = 0  # number of procs the effect lasts; 0 means instant

    def __repr__(self) -> str:
        return self.name

    def __eq__(self, other: object) -> bool:
        # Defer to the other operand for foreign types instead of raising
        # AttributeError on a missing ``name``.
        if not isinstance(other, Spell):
            return NotImplemented
        return self.name == other.name

    def cast(self, state: State) -> bool:
        """Try to cast this spell, updating *state* in place.

        Returns False and leaves *state* untouched when the player cannot
        afford the spell or when a lasting effect of the same name is still
        active; returns True otherwise.
        """
        effect_running = self.timer > 0 and self in state.active_spells
        if effect_running or state.player_mana < self.cost:
            return False
        state.player_mana -= self.cost
        state.mana_used += self.cost
        state.player_hp += self.heal
        if self.timer > 0:
            # Copy so the countdown never touches the shared spell template.
            state.active_spells.append(deepcopy(self))
        else:
            # Instant spells take effect right away.  They deliberately do not
            # go through proc(), which would decrement the template's timer.
            self._apply_effects(state)
        return True

    def proc(self, state: State) -> None:
        """Tick this active effect once and drop expired effects from *state*."""
        self.timer -= 1
        self._apply_effects(state)
        if self.armour:
            # Armour is present while the effect runs and removed on the tick
            # that brings the timer to zero.
            state.player_armour = self.armour
            if self.timer == 0:
                state.player_armour = 0
        # Rebind (rather than mutate) the list so a caller that is iterating
        # over the previous list object is not disturbed.
        state.active_spells = [i for i in state.active_spells if i.timer > 0]

    def _apply_effects(self, state: State) -> None:
        """Apply one application of this spell's damage and mana gain."""
        state.boss_hp -= self.damage
        state.player_mana += self.mana


# Every spell the player may choose from on their turn.  The first two have no
# timer and act instantly; the other three start a multi-turn effect.
SPELLS = [
    Spell('Magic Missile', 53, damage=4),
    Spell('Drain', 73, damage=2, heal=2),
    Spell('Shield', 113, armour=7, timer=6),
    Spell('Poison', 173, damage=3, timer=6),
    Spell('Recharge', 229, mana=101, timer=5),
]

In the main loop we use a heap to store our game states, which allows us to prioritize the states with the lowest mana usage. This means that as soon as we get a result, it's going to be the best one, so we don't have to explore other branches. Effects do not proc immediately after casting, so the order of operations is important. 

Splitting turns doubles heap operations and slows down the program significantly, but it's the only way I got it to work while maintaining all checks and operations in the correct order, as well as keeping the code readable.

In [4]:
def battle(part2: bool) -> int:
    """Return the least mana the player must spend to defeat the boss.

    Game states are explored cheapest-first from a heap ordered by mana spent,
    so the first state in which the boss is dead is the optimal one.  Each
    heap entry represents a single turn (player or boss), not a full round.

    Args:
        part2: When True, the player loses 1 HP at the start of each of their
            own turns, before any effects apply.

    Returns:
        The mana spent in the cheapest winning sequence of spells.

    Raises:
        ValueError: If every branch ends without the boss being defeated.
    """
    queue = [State()]
    while queue:
        state = heapq.heappop(queue)
        if state.player_hp <= 0:
            continue
        if part2 and not state.boss_turn:
            state.player_hp -= 1
            # The drain itself can be fatal.  Without this check a dead player
            # could still win through a Poison tick or revive with Drain.
            if state.player_hp <= 0:
                continue

        # Active effects tick at the start of every turn, before any action.
        # proc() rebinds state.active_spells, so this loop keeps iterating the
        # list object it started with.
        for spell in state.active_spells:
            spell.proc(state)
        if state.boss_hp <= 0:
            return state.mana_used

        if state.boss_turn:
            # The boss always deals at least 1 damage, whatever the armour.
            state.player_hp -= max(1, state.boss_attack - state.player_armour)
            state.boss_turn = False
            heapq.heappush(queue, state)
            continue

        # Player turn: branch once per spell that can legally be cast.
        for spell in SPELLS:
            new_state = deepcopy(state)
            new_state.boss_turn = True
            if spell.cast(new_state):
                heapq.heappush(queue, new_state)
    raise ValueError('Solution not found')


# Solve both parts in order: first the normal rules, then the variant where
# the player loses 1 HP at the start of each of their turns.
for hard_mode in (False, True):
    print(battle(part2=hard_mode))

900
1216
