# In [None]:
from collections import Counter
import heapq

from IPython.display import display, HTML

from troops import Troop, ALL_TROOPS
from traits import Trait
from prefix_tree import Trie

# Part 1: Build tries for Troops and Traits
# Both tries are created up front and then populated. Traits are inserted
# first, then troops, matching the original insertion order.
trait_trie = Trie()
troop_trie = Trie()

# Every Trait member is keyed by its `.name`, so user input can be resolved
# back to the Trait itself.
for trait in Trait:
    trait_trie.insert(trait.name, trait)

# Every alias in `troop.names` points at the same Troop object.
for troop in ALL_TROOPS:
    for name in troop.names:
        troop_trie.insert(name, troop)

# Part 2: Getting the current composition by input
# Accepted whitespace-separated tokens:
#   <troop>        include a troop in the starting composition
#   !<troop>       exclude a troop from the search
#   TD:<t1>:<t2>   add one trait dummy carrying traits t1 and t2
# The included troops' traits (and the trait dummy's traits) are accumulated
# into trait_counter as the input is read.
current_troops: set[Troop] = set()
excluded_troops: set[Troop] = set()
trait_counter = Counter()
trait_dummy = None

# split() with no argument drops every run of whitespace (spaces, tabs, ...),
# so no empty or whitespace-padded tokens reach the lookups below.
for troop_name in input("Current troops: ").split():
    # Handling trait dummies
    if troop_name.startswith("TD:"):
        # Each dummy's traits go into trait_counter, but only one dummy is
        # stored and displayed. A second one would silently inflate the counts.
        if trait_dummy is not None:
            raise ValueError("Only one trait dummy may be specified")

        trait_names = troop_name[3:].split(":")
        if len(trait_names) != 2 or not all(trait_names):
            raise ValueError("Trait dummy must have exactly two traits, separated by a colon")

        trait_dummy_traits = []
        for trait_name in trait_names:
            trait = trait_trie.search_with_fallback(trait_name)
            if not trait: raise ValueError(f"Trait '{trait_name}' not found")
            trait_counter[trait] += 1
            trait_dummy_traits.append(trait)

        trait_dummy = Troop(trait_dummy_traits[0], trait_dummy_traits[1], "traitdummy")
        continue

    # Handling troop exclusions
    exclude = troop_name.startswith("!")
    if exclude:
        troop_name = troop_name[1:]
        if not troop_name: raise ValueError("Expected a troop name after '!'")

    # Finding the troop using the Troop Trie. A falsy result means no match.
    troop = troop_trie.search_with_fallback(troop_name)
    if not troop: raise ValueError(f"Troop '{troop_name}' not found")  # troop not found
    if troop in current_troops or troop in excluded_troops: raise ValueError(f"Troop '{troop_name}' is already included/excluded")  # prevent duplicates

    if exclude:
        excluded_troops.add(troop)
    else:
        current_troops.add(troop)
        troop.add_traits(trait_counter)

if len(current_troops) > 6: raise ValueError("Current troops exceed maximum of 6")  # prevent an invalid starting team size

# Part 3: Define the scoring method for compositions
TWO_TRAIT_SCORE = 2
FOUR_TRAIT_SCORE = 5
SIX_TRAIT_SCORE = 10

def composition_score(trait_counter: "Counter[Trait]") -> int:
    """Score a composition from how many units carry each trait.

    Each trait contributes according to the highest breakpoint it reaches:
    2-3 units give TWO_TRAIT_SCORE, 4-5 give FOUR_TRAIT_SCORE, and 6 or more
    give SIX_TRAIT_SCORE. Traits with fewer than 2 units contribute nothing.

    Args:
        trait_counter: mapping of trait -> number of units carrying it.

    Returns:
        The sum of the per-trait contributions.
    """
    score = 0
    for count in trait_counter.values():
        if count >= 6:
            score += SIX_TRAIT_SCORE
        elif count >= 4:
            score += FOUR_TRAIT_SCORE
        elif count >= 2:
            score += TWO_TRAIT_SCORE
    return score

# Part 4: Backtracking to find all compositions
MAX_TROOPS = 6
# Sort both pools by primary name. Iterating a set follows hash order, which
# can differ between runs. That order decides the enumeration order, the order
# troops are listed within each composition, and which of several equally
# scored compositions heapq.nlargest ends up keeping.
potential_troops = sorted(ALL_TROOPS - current_troops - excluded_troops, key=lambda t: t.names[0])
current_troops = sorted(current_troops, key=lambda t: t.names[0])
all_compositions: list[tuple[int, Counter[Trait], list[Troop]]] = []

def backtrack(start: int):
    """Enumerate every way to fill the open slots from potential_troops[start:].

    For each team of exactly MAX_TROOPS troops, appends
    (score, copy of trait_counter, copy of current_troops) to all_compositions.
    current_troops and trait_counter are mutated while exploring, and each
    change is undone via pop() / subtract_traits() after the recursive call.
    """
    slots_left = MAX_TROOPS - len(current_troops)
    if slots_left < 0:
        return  # already over the cap; no valid team can be formed
    if slots_left == 0:
        all_compositions.append((composition_score(trait_counter), trait_counter.copy(), current_troops.copy()))
        return

    # Only try starting indices that still leave enough candidates to fill
    # every open slot. Later starts can never reach MAX_TROOPS.
    for i in range(start, len(potential_troops) - slots_left + 1):
        candidate = potential_troops[i]
        current_troops.append(candidate)
        candidate.add_traits(trait_counter)
        backtrack(i + 1)
        current_troops.pop()
        candidate.subtract_traits(trait_counter)

backtrack(0)
print(f"{len(all_compositions)} compositions found.")

# Part 5: Displaying the best compositions
NUM_BEST = 10
best_compositions = heapq.nlargest(NUM_BEST, all_compositions, key=lambda x: x[0])

# The trait dummy is the same for every composition, so render its card
# (dummy image plus both trait icons overlaid) once, outside the loop.
trait_dummy_html = ""
if trait_dummy:
    t1, t2 = trait_dummy.trait_1, trait_dummy.trait_2
    trait_dummy_html = f"""
        <div style="position:relative; display:inline-block; width:72px;">
            <img src="{trait_dummy.image_path()}" alt="{trait_dummy.names[0]}" width="72"
                style="border-radius:10px; width:100%; height:100%; display:block; object-fit:cover;">
            <img src="{t1.image_path()}" title="{t1.name}"
                style="position:absolute;bottom:4px;left:4px;width:26px;height:26px;border-radius:6px;">
            <img src="{t2.image_path()}" title="{t2.name}"
                style="position:absolute;bottom:4px;right:4px;width:26px;height:26px;border-radius:6px;">
        </div>
        """

# Collect fragments and join once, instead of growing one string with +=.
html_parts = ["<div style='font-family:sans-serif;color:white;background:#1e1e2f;padding:12px;border-radius:12px;width:max-content'>"]

for score, trait_counts, troops in best_compositions:
    # Composition container
    html_parts.append("""
    <div style='margin-bottom:12px;padding:8px;
                border-radius:8px;background:#2a2a40;
                box-shadow:0 0 6px rgba(0,0,0,0.4);'>
    """)

    # Troop images row, followed by the trait dummy card (empty if none)
    html_parts.append("<div style='display:flex;gap:4px;justify-content:center;margin-bottom:4px;'>")
    for troop in troops:
        html_parts.append(f'<img src="{troop.image_path()}" alt="{troop.names[0]}" width="72" style="border-radius:10px;">')
    html_parts.append(trait_dummy_html)
    html_parts.append("</div>")

    # Traits + score row
    html_parts.append("<div style='display:flex;align-items:center;justify-content:space-between;margin-top:2px;'>")

    # Trait icons container
    html_parts.append("<div style='display:flex;flex-wrap:wrap;'>")
    for trait, count in trait_counts.items():
        count = count // 2 * 2  # round down to an even number (the 2/4/6 breakpoints)
        if count <= 0:
            continue  # only show active traits
        # Title uses trait.name, like the trait dummy overlay above.
        html_parts.append(f"""
        <span style="display:inline-flex;align-items:center;margin-right:8px;margin-bottom:2px;">
            <img src="{trait.image_path()}" title="{trait.name}" width="28" style="vertical-align:middle;">
            <span style="font-weight:bold;font-size:14px;margin-left:2px;">{count}</span>
        </span>
        """)
    html_parts.append("</div>")

    # Score display
    html_parts.append(f"<div style='color:#ccc;font-size:13px;font-weight:bold;margin-right:4px;'>Score: {score}</div>")
    html_parts.append("</div></div>")

html_parts.append("</div>")
html = "".join(html_parts)
display(HTML(html))

# Output: 230230 compositions found.
