In [1]:
from dataclasses import dataclass, field, replace
from typing import Callable, Iterable

In [2]:
@dataclass()
class Character:
    """Mutable combat stats, used for both the player and the boss.

    Spell effects and attacks adjust these fields in place while a fight is simulated.
    """

    hit_points: int
    damage: int = 0  # attack strength before armor is subtracted (set for the boss)
    armor: int = 0  # subtracted from incoming attack damage
    mana: int = 0  # spent to cast spells (set for the player)

In [3]:
@dataclass(unsafe_hash=True)
class Spell:
    """A spell whose effect runs once per turn tick.

    `duration` is the number of remaining ticks. Calling the spell invokes
    `effect(spell)` and returns a copy whose duration is one lower.
    """

    name: str
    mana: int  # mana deducted from the caster when the spell is cast
    duration: int
    effect: Callable  # invoked as effect(spell) on every tick
    # Bound by `cast`. These fields are excluded from the hash because Character is a
    # mutable dataclass and therefore unhashable; including them made hash() of any
    # cast spell raise TypeError. They still take part in equality.
    cast_by: Character | None = field(default=None, hash=False)
    cast_on: Character | None = field(default=None, hash=False)

    def cast(self, by, on):
        """Return a copy of this spell bound to caster `by` and target `on`."""
        return replace(self, cast_by=by, cast_on=on)

    def __call__(self):
        """Apply one tick of the effect and return a copy with one less duration."""
        self.effect(self)
        return replace(self, duration=self.duration - 1)
    
def magic_missile(spell: Spell):
    """Magic Missile tick: deal 4 damage to the target while duration remains."""
    target = spell.cast_on
    assert target
    if spell.duration <= 0:
        return
    target.hit_points -= 4

def drain(spell: Spell):
    """Drain tick: move 2 hit points from the target to the caster while duration remains."""
    caster, target = spell.cast_by, spell.cast_on
    assert caster and target
    if spell.duration <= 0:
        return
    target.hit_points -= 2
    caster.hit_points += 2

def shield(spell: Spell):
    assert spell.cast_by
    if spell.duration == 6:
        spell.cast_by.armor += 7
    elif spell.duration == 0:
        spell.cast_by.armor -= 7

def poison(spell: Spell):
    """Poison tick: deal 3 damage to the target on every tick with duration left."""
    victim = spell.cast_on
    assert victim
    if spell.duration <= 0:
        return
    victim.hit_points -= 3

def recharge(spell: Spell):
    """Recharge tick: give the caster 101 mana on every tick with duration left."""
    wizard = spell.cast_by
    assert wizard
    if spell.duration <= 0:
        return
    wizard.mana += 101

# Every spell the player knows, as (name, mana cost, duration in ticks, effect).
spells = tuple(
    Spell(name, mana=cost, duration=ticks, effect=effect)
    for name, cost, ticks, effect in (
        ("Magic Missile", 53, 1, magic_missile),
        ("Drain", 73, 1, drain),
        ("Shield", 113, 6, shield),
        ("Poison", 173, 6, poison),
        ("Recharge", 229, 5, recharge),
    )
)

In [4]:
def simulate(spells_prepared: Iterable[Spell], player_stats: tuple[int, int], boss_stats: tuple[int, int], hard_mode=False):
    """Play out a fight where the player casts `spells_prepared` in order, one per player turn.

    Args:
        spells_prepared: spells to cast, one per player turn.
        player_stats: (hit points, mana) of the player.
        boss_stats: (hit points, damage) of the boss.
        hard_mode: if true, the player loses 1 hit point at the start of each player turn.

    Returns:
        "won" if the boss reaches 0 hit points; "died" if the player does;
        "out of mana" if the next spell costs more than the player has;
        "not possible" if the next spell's effect is still running;
        "exhausted" if every spell was cast and both sides are still standing.
    """
    boss = Character(damage=boss_stats[1], hit_points=boss_stats[0])
    player = Character(mana=player_stats[1], hit_points=player_stats[0])
    active = []

    def tick(pending):
        # Run every effect that has not yet expired; each call returns a copy one tick shorter.
        return [effect() for effect in pending if effect.duration >= 0]

    for spell in spells_prepared:
        # ---- player turn ----
        if hard_mode:
            player.hit_points -= 1
            if player.hit_points <= 0:
                return "died"

        active = tick(active)
        if boss.hit_points <= 0:
            return "won"

        if player.mana < spell.mana:
            return "out of mana"
        if any(effect.duration > 0 and effect.name == spell.name for effect in active):
            return "not possible"

        active.append(spell.cast(by=player, on=boss))
        player.mana -= spell.mana

        # ---- boss turn ----
        active = tick(active)
        if boss.hit_points <= 0:
            return "won"

        player.hit_points -= max(1, boss.damage - player.armor)
        if player.hit_points <= 0:
            return "died"

    return "exhausted"

In [5]:
def find_least_mana(hard_mode=False, player_stats: tuple[int, int] = (50, 500), boss_stats: tuple[int, int] = (71, 10)):
    """Return the least total mana that wins the fight, using a breadth-first search over spell sequences.

    The search grows sequences one spell at a time. Only sequences that `simulate`
    reports as "exhausted" (both sides still standing) are extended. Any candidate
    costing at least the best win found so far is pruned.

    Args:
        hard_mode: passed through to `simulate` (the player loses 1 hit point at
            the start of each player turn).
        player_stats: (hit points, mana) of the player. Defaults to the values
            previously hard-coded here.
        boss_stats: (hit points, damage) of the boss. Defaults to the values
            previously hard-coded here.

    Returns:
        The minimum mana spent by any winning sequence, or ``float("inf")`` if
        no winning sequence exists.
    """
    # A spell no one can afford. `simulate` checks the boss's hit points after the
    # start-of-turn effects and before the mana check. So appending this spell makes
    # "won" mean that effects killed the boss before another spell had to be cast;
    # only the sequence itself was paid for. Without this probe, such wins were
    # charged for one extra, never-cast spell.
    unaffordable = Spell("Unaffordable", mana=float("inf"), duration=0, effect=lambda spell: None)
    min_mana = float("inf")
    # Start from the empty sequence so one-spell fights are simulated too.
    frontier = {()}

    while frontier:
        next_frontier = set()
        for sequence in frontier:
            prepared = list(sequence)
            spent = sum(spell.mana for spell in prepared)
            if simulate(prepared + [unaffordable], player_stats, boss_stats, hard_mode) == "won":
                min_mana = min(min_mana, spent)
                continue
            for spell in spells:
                spells_prepared = prepared + [spell]
                mana = spent + spell.mana
                if mana >= min_mana:
                    continue
                result = simulate(spells_prepared, player_stats, boss_stats, hard_mode)
                if result == "won":
                    # The boss did not die at the start of this turn (probe above),
                    # so the final spell was actually cast and paid for.
                    min_mana = mana
                elif result == "exhausted":
                    next_frontier.add(tuple(spells_prepared))
        frontier = next_frontier

    return min_mana

In [6]:
for part_label, use_hard_mode in (("Part 1", False), ("Part 2", True)):
    print(f"{part_label}:", find_least_mana(hard_mode=use_hard_mode))

Part 1: 1824
Part 2: 1937
