In [12]:
from collections import Counter
import heapq

from IPython.display import display, HTML

from troops import Troop, ALL_TROOPS
from traits import Trait
from prefix_tree import Trie

# -----------------------------------------
# Part 1: Build tries for Troops and Traits
# -----------------------------------------
def _trie_from(entries) -> Trie:
    """Return a fresh Trie holding every (key, value) pair of *entries*, inserted in order."""
    trie = Trie()
    for key, value in entries:
        trie.insert(key, value)
    return trie


# Traits are keyed by each member's `.name`.
trait_trie = _trie_from((trait.name, trait) for trait in Trait)

# Troops are keyed by every one of their names, so each alias maps to the same troop.
troop_trie = _trie_from((name, troop) for troop in ALL_TROOPS for name in troop.names)

# ------------------------------
# Part 2: Getting the user input
# ------------------------------
# The input is a whitespace-separated list of tokens, each interpreted by shape:
#   7              -> team size (any all-digit token)
#   -a / -desc     -> break score ties by total cost; any prefix of
#                     "-ascending" / "-descending" is accepted
#   A:B            -> a trait dummy carrying traits A and B (last one wins)
#   4CLAN          -> all-uppercase token: want at least 4 of trait CLAN
#                     (the count defaults to 2 when omitted)
#   !4CLAN         -> reject teams with 4 or more of trait CLAN
#   name / !name   -> include / exclude a troop
current_troops: set[Troop] = set()
excluded_troops: set[Troop] = set()
trait_counter = Counter()
trait_dummy: Troop = None
wanted_traits: dict[Trait, int] = {}
excluded_traits: dict[Trait, int] = {}
sort_cost_asc: bool = None
team_size = 6

user_input = input("Current troops: ")
print(f"Input: {user_input}")

# str.split() with no separator already discards surrounding whitespace and
# never yields empty tokens, so no extra stripping is needed.
for param in user_input.split():
    # Handling team size
    if param.isdecimal():
        team_size = int(param)
        continue

    # Handling sorting by cost
    if param.startswith("-"):
        if "-ascending".startswith(param):
            sort_cost_asc = True
        elif "-descending".startswith(param):
            sort_cost_asc = False
        else:
            print(f"Ignoring unknown option: {param}")
        continue

    # Handling trait dummies
    if ":" in param:
        first, second, *_ = param.split(":")
        # Only one dummy is kept (and displayed), so a later dummy replaces an
        # earlier one; remove the earlier dummy's traits so the counter
        # matches the team that is actually shown.
        if trait_dummy is not None:
            trait_dummy.subtract_traits(trait_counter)
        trait_dummy = Troop(0, trait_trie.search_with_fallback(first),
                            trait_trie.search_with_fallback(second),
                            "traitdummy")
        trait_dummy.add_traits(trait_counter)
        continue

    # Handling exclusions
    exclude = param[0] == "!"
    if exclude:
        param = param[1:]

    # Handling traits: optional leading count, then the trait name
    if param.isupper():
        i = 0
        while i < len(param) and param[i].isdecimal():
            i += 1
        trait = trait_trie.search_with_fallback(param[i:])
        target_traits = excluded_traits if exclude else wanted_traits
        target_traits[trait] = int(param[:i] or 2)
        continue

    # Handling troops
    troop: Troop = troop_trie.search_with_fallback(param)
    if exclude:
        excluded_troops.add(troop)
    elif troop not in current_troops:
        # A repeated name (or a second alias of the same troop) resolves to a
        # troop already in the set; count its traits only once, like the set.
        current_troops.add(troop)
        troop.add_traits(trait_counter)

# -------------------------------------------
# Part 3: Define the scoring method for teams
# -------------------------------------------
TWO_TRAIT_SCORE = 2
FOUR_TRAIT_SCORE = 5
SIX_TRAIT_SCORE = 10

# (minimum count, points) pairs, ordered from the highest tier down so the
# first matching tier is the one a trait has reached.
TRAIT_TIERS = (
    (6, SIX_TRAIT_SCORE),
    (4, FOUR_TRAIT_SCORE),
    (2, TWO_TRAIT_SCORE),
)

def team_score(trait_counter: "Counter[Trait]", tiers=TRAIT_TIERS) -> int:
    """Return the summed tier score of every trait in *trait_counter*.

    Each trait earns the points of the highest tier it reaches. With the
    default tiers:
    - 2-3 copies give TWO_TRAIT_SCORE;
    - 4-5 copies give FOUR_TRAIT_SCORE;
    - 6 or more give SIX_TRAIT_SCORE;
    - fewer than 2 copies earn nothing.

    Args:
        trait_counter: mapping of trait -> number of copies on the team.
        tiers: (minimum count, points) pairs sorted by descending minimum.
    """
    score = 0
    for count in trait_counter.values():
        for threshold, points in tiers:
            if count >= threshold:
                score += points
                break  # only the highest reached tier counts
    return score

# --------------------------------------
# Part 4: Backtracking to find all teams
# --------------------------------------
# Candidates are every troop that is neither already chosen nor excluded; the
# chosen troops form the fixed start of every team the search builds.
potential_troops = list(ALL_TROOPS - current_troops - excluded_troops)
current_troops = list(current_troops)
all_teams: list[tuple[int, Counter[Trait], list[Troop]]] = []

if len(current_troops) > team_size:
    print(f"{len(current_troops)} troops were given, but the team size is only {team_size}.")

def backtrack(start: int):
    """Record every valid completion of current_troops using potential_troops[start:].

    Candidates are appended in index order, so each combination is visited
    once. trait_counter is kept in sync with current_troops: a troop's traits
    are added when it is pushed and subtracted when it is popped. Each full
    team that meets every wanted-trait minimum is appended to all_teams as
    (score, copy of trait counts, copy of troop list).
    """
    open_slots = team_size - len(current_troops)

    # Stop if the team is already over-full (more troops were given up front
    # than fit). Neither check below would ever fire in that case, and the
    # loop would enumerate every subset of the candidates for no result.
    if open_slots < 0:
        return

    # Stop if not enough troops left to reach full team
    if len(potential_troops) - start < open_slots:
        return

    # Stop if we have unwanted traits (this pruning assumes add_traits never
    # lowers a count, so deeper teams could not recover)
    if any(trait_counter[trait] >= count for trait, count in excluded_traits.items()):
        return

    # Base case: full team formed
    if open_slots == 0:
        # Ensure we have all wanted traits
        if all(trait_counter[trait] >= count for trait, count in wanted_traits.items()):
            all_teams.append((team_score(trait_counter), trait_counter.copy(), current_troops.copy()))
        return

    for i in range(start, len(potential_troops)):
        candidate = potential_troops[i]
        current_troops.append(candidate)
        candidate.add_traits(trait_counter)
        backtrack(i + 1)
        current_troops.pop()
        candidate.subtract_traits(trait_counter)

backtrack(0)
team_count = len(all_teams)
print(f"{team_count} {'team' if team_count == 1 else 'teams'} found.")

# --------------------------------------
# Part 5: Sorting to find the best teams
# --------------------------------------
NUM_BEST = 10

def _team_cost(team):
    """Total cost of the troops in an all_teams entry (score, trait counts, troops)."""
    members = team[2]
    return sum(member.cost for member in members)

if sort_cost_asc is None:
    # Rank by score alone.
    def sort_key(team):
        return team[0]
else:
    # Break score ties by total cost. nlargest keeps the largest keys, so the
    # cost is negated when cheaper teams should win ("-ascending").
    cost_sign = -1 if sort_cost_asc else 1

    def sort_key(team):
        return (team[0], cost_sign * _team_cost(team))

best_teams = heapq.nlargest(NUM_BEST, all_teams, key=sort_key)

# ---------------------------------
# Part 6: Displaying the best teams
# ---------------------------------
def _trait_dummy_html(dummy: "Troop") -> str:
    """Return the trait dummy's portrait with its two trait icons overlaid in the bottom corners."""
    return f"""
        <div style="position:relative; display:inline-block; width:72px;">
            <img src="{dummy.image_path()}" alt="{dummy.names[0]}" width="72"
                style="border-radius:10px; width:100%; height:100%; display:block; object-fit:cover;">
            <img src="{dummy.trait1.image_path()}" title="{dummy.trait1}"
                style="position:absolute;bottom:4px;left:4px;width:26px;height:26px;border-radius:6px;">
            <img src="{dummy.trait2.image_path()}" title="{dummy.trait2}"
                style="position:absolute;bottom:4px;right:4px;width:26px;height:26px;border-radius:6px;">
        </div>
        """


def _team_html(score: int, trait_counts: "Counter[Trait]", troops: "list[Troop]",
               dummy: "Troop" = None) -> str:
    """Return the HTML card for one team.

    The card holds a row of troop portraits (plus *dummy*, if given),
    followed by a row with the active trait icons on the left and the score
    on the right. Pieces are collected in a list and joined once, rather
    than grown with repeated string concatenation.
    """
    # Team container, then the troop images row.
    parts = [
        """
    <div style='margin-bottom:12px;padding:8px;
                border-radius:8px;background:#2a2a40;
                box-shadow:0 0 6px rgba(0,0,0,0.4);'>
    """,
        "<div style='display:flex;gap:4px;justify-content:center;margin-bottom:4px;'>",
    ]
    parts.extend(
        f'<img src="{troop.image_path()}" alt="{troop.names[0]}" width="72" style="border-radius:10px;">'
        for troop in troops
    )
    if dummy is not None:
        parts.append(_trait_dummy_html(dummy))
    parts.append("</div>")

    # Traits + score row; the trait icons get their own wrapping flex container.
    parts.append("<div style='display:flex;align-items:center;justify-content:space-between;margin-top:2px;'>")
    parts.append("<div style='display:flex;flex-wrap:wrap;'>")
    for trait, count in trait_counts.items():
        shown = count // 2 * 2  # round down to an even tier (2, 4, 6, ...)
        if shown < 2:
            continue  # only show active traits
        parts.append(f"""
        <span style="display:inline-flex;align-items:center;margin-right:8px;margin-bottom:2px;">
            <img src="{trait.image_path()}" title="{trait}" width="28" style="vertical-align:middle;">
            <span style="font-weight:bold;font-size:14px;margin-left:2px;">{shown}</span>
        </span>
        """)
    parts.append("</div>")

    # Score display, then close the traits row and the team container.
    parts.append(f"<div style='color:#ccc;font-size:13px;font-weight:bold;margin-right:4px;'>Score: {score}</div>")
    parts.append("</div></div>")
    return "".join(parts)


if best_teams:
    cards = "".join(
        _team_html(score, trait_counts, troops, trait_dummy)
        for score, trait_counts, troops in best_teams
    )
else:
    # Avoid rendering an empty dark box when the search found nothing.
    cards = "<div style='color:#ccc;font-size:13px;'>No teams found.</div>"

html = (
    "<div style='font-family:sans-serif;color:white;background:#1e1e2f;padding:12px;border-radius:12px;width:max-content'>"
    + cards
    + "</div>"
)
display(HTML(html))

Input: 4CLAN 4BR 7
1 teams found.
