In [3]:
from collections import defaultdict
import heapq
from itertools import combinations
import time

from IPython.display import display, HTML

from troops import Troop
from traits import Trait
from teams import Team
from prefix_tree import Trie
from search_algorithms import generate_connected_subgraphs

# -----------------------------------------
# Part 1: Build tries for Troops and Traits
# -----------------------------------------
# Two name -> object lookup tries; Part 2 resolves every user-typed token
# through trait_trie / troop_trie.search_with_fallback.
troop_trie = Trie()
trait_trie = Trie()

# Each Trait member is stored under its .name.
for trait in Trait:
    trait_trie.insert(trait.name, trait)

# A troop can be registered under several names (troop.all_names); every one
# of them maps back to the same Troop object.
for troop in Troop.all_troops:
    for name in troop.all_names:
        troop_trie.insert(name, troop)

# ------------------------------
# Part 2: Getting the user input
# ------------------------------
# Whitespace-separated tokens are interpreted as follows:
#   <n>              team size
#   -<prefix>        option: any non-empty prefix of "ascending", "descending"
#                    or "bruteforce" (checked in that order)
#   -<n>             number of teams to display
#   TRAIT1:TRAIT2    trait dummy carrying the two given traits
#   [<n>]TRAIT       all-uppercase token: require at least n (default 2) of TRAIT
#   !TOKEN           exclusion: for a trait, teams must have fewer than n of it;
#                    for a troop, it is removed from the candidate pool
#   anything else    a troop that must be part of every team
current_troops: set[Troop] = set()
excluded_troops: set[Troop] = set()
trait_dummy: Troop = None
wanted_traits: dict[Trait, int] = {}
excluded_traits: dict[Trait, int] = {}
sort_cost_asc: bool = None
brute_force = False
team_size = 6
display_count = 10

user_input = input("Current troops: ")
print(f"Input: {user_input}")

# str.split() with no argument already trims surrounding whitespace and never
# yields empty tokens, so every param below is non-empty.
for param in user_input.split():
    # Handling team size
    if param.isdecimal():
        team_size = int(param)
        continue

    # Handling options (cost sort order, search mode, display count)
    if param[0] == "-":
        option = param[1:]
        # The `option and` guards stop a lone "-" from prefix-matching
        # "ascending" and silently enabling ascending cost sort.
        if option and "ascending".startswith(option):
            sort_cost_asc = True
        elif option and "descending".startswith(option):
            sort_cost_asc = False
        elif option and "bruteforce".startswith(option):
            brute_force = True
        elif option.isdecimal():
            display_count = int(option)
        else:
            print(f"Ignoring unrecognised option: {param}")
        continue

    # Handling trait dummies ("TRAIT1:TRAIT2"); extra ":"-parts are ignored
    if ":" in param:
        first, second = param.split(":")[:2]
        trait_dummy = Troop(0, trait_trie.search_with_fallback(first),
                            trait_trie.search_with_fallback(second),
                            "traitdummy")
        continue

    # Handling exclusions
    exclude = param[0] == "!"
    if exclude:
        param = param[1:]
        if not param:
            # A lone "!" names nothing; don't look up (and exclude) an empty troop name.
            print("Ignoring '!' with nothing after it")
            continue

    # Handling traits: optional leading count, then an uppercase trait name
    if param.isupper():
        i = 0
        while i < len(param) and param[i].isdecimal():
            i += 1

        threshold = int(param[:i] or 2)
        target_traits = excluded_traits if exclude else wanted_traits
        target_traits[trait_trie.search_with_fallback(param[i:])] = threshold
        continue

    # Handling troops
    troop: Troop = troop_trie.search_with_fallback(param)
    if exclude:
        excluded_troops.add(troop)
    else:
        current_troops.add(troop)

# ------------------------------------------------
# Part 3: Use a search algorithm to find all teams
# ------------------------------------------------
potential_troops: tuple[Troop, ...] = tuple(Troop.all_troops - current_troops - excluded_troops)
current_troops: tuple[Troop, ...] = tuple(current_troops)

# Slots left to fill once the user's required troops are placed.
open_slots = team_size - len(current_troops)
if open_slots < 0:
    # Fail with a clear message instead of combinations()' "r must be
    # non-negative" (or whatever the greedy search does with a negative size).
    raise ValueError(
        f"{len(current_troops)} troops were required but the team size is only {team_size}"
    )

# perf_counter is monotonic and high-resolution, which is what elapsed-time
# measurement needs; time.time() can jump with wall-clock adjustments.
start = time.perf_counter()

# Option 1: Brute Force - Loop through all combinations of troops
if brute_force:
    all_teams = [Team(current_troops + troop_selection, trait_dummy)
                 for troop_selection in combinations(potential_troops, open_slots)]

# Option 2: Greedy - Build an adjacency list where troops are connected if they share a trait, then find all connected subgraphs
else:
    troop_graph = defaultdict(set)
    # Each unordered pair is visited once (combinations never pairs a troop
    # with itself), and the edge is recorded in both directions.
    for t1, t2 in combinations(potential_troops, 2):
        if {t1.trait1, t1.trait2} & {t2.trait1, t2.trait2}:
            troop_graph[t1].add(t2)
            troop_graph[t2].add(t1)

    all_teams = [Team(current_troops + troop_selection, trait_dummy)
                 for troop_selection in generate_connected_subgraphs(troop_graph, open_slots)]

print(f"Time taken: {time.perf_counter() - start:.6f} seconds")

# Pruning the team list to the user's required team constraints:
# at least `need` of every wanted trait, fewer than `limit` of every excluded one.
all_teams = [
    team for team in all_teams
    if all(team.traits[trait] >= need for trait, need in wanted_traits.items())
    and all(team.traits[trait] < limit for trait, limit in excluded_traits.items())
]

print(f"Teams found: {len(all_teams)}")

# --------------------------------------
# Part 4: Sorting to find the best teams
# --------------------------------------
# heapq.nlargest keeps the teams whose key tuples are largest. Score is always
# compared first and meta_ranking last (negated, so a smaller ranking wins a
# tie). When a cost order was requested, cost sits in between: negated to
# favour cheaper teams (ascending) or left as-is to favour pricier ones
# (descending).
if sort_cost_asc is True:
    sort_key = lambda team: (team.score, -team.cost, -team.meta_ranking)
elif sort_cost_asc is False:
    sort_key = lambda team: (team.score, team.cost, -team.meta_ranking)
elif sort_cost_asc is None:
    sort_key = lambda team: (team.score, -team.meta_ranking)

best_teams = heapq.nlargest(display_count, all_teams, key=sort_key)

# ---------------------------------
# Part 5: Displaying the best teams
# ---------------------------------
# The markup is collected as a list of fragments and joined once at the end,
# rather than growing one string with += for every fragment.
html_parts: list[str] = [
    "<div style='font-family:sans-serif;color:white;background:#1e1e2f;padding:12px;border-radius:12px;width:max-content'>"
]

for team in best_teams:
    # team container
    html_parts.append("""
    <div style='margin-bottom:12px;padding:8px;
                border-radius:8px;background:#2a2a40;
                box-shadow:0 0 6px rgba(0,0,0,0.4);'>
    """)

    # Troop images row, ordered by elixir cost (cheapest first)
    html_parts.append("<div style='display:flex;gap:4px;justify-content:center;margin-bottom:4px;'>")
    for troop in sorted(team.troops, key=lambda troop: troop.cost):
        html_parts.append(f'<img src="{troop.image}" alt="{troop.name}" width="72" style="border-radius:10px;">')

    # Trait dummy with its two trait icons overlaid on the bottom corners.
    # Icon titles use the trait's .name, the same as the trait row below.
    dummy = team.trait_dummy
    if dummy:
        html_parts.append(f"""
        <div style="position:relative; display:inline-block; width:72px;">
            <img src="{dummy.image}" alt="{dummy.name}" width="72"
                style="border-radius:10px; width:100%; height:100%; display:block; object-fit:cover;">
            <img src="{dummy.trait1.image}" title="{dummy.trait1.name}"
                style="position:absolute;bottom:4px;left:4px;width:26px;height:26px;border-radius:6px;">
            <img src="{dummy.trait2.image}" title="{dummy.trait2.name}"
                style="position:absolute;bottom:4px;right:4px;width:26px;height:26px;border-radius:6px;">
        </div>
        """)
    html_parts.append("</div>")

    # Traits + score + elixir cost row
    html_parts.append("<div style='display:flex;align-items:center;justify-content:space-between;margin-top:2px;'>")

    # Trait icons container; team.traits is read as a sequence indexed by trait value
    html_parts.append("<div style='display:flex;flex-wrap:wrap;'>")
    for idx, count in enumerate(team.traits):
        count = count // 2 * 2  # only even numbers for traits
        if count == 0:
            continue  # only show active traits
        trait = Trait(idx)
        html_parts.append(f"""
        <span style="display:inline-flex;align-items:center;margin-right:8px;margin-bottom:2px;">
            <img src="{trait.image}" title="{trait.name}" width="28" style="vertical-align:middle;">
            <span style="font-weight:bold;font-size:14px;margin-left:2px;">{count}</span>
        </span>
        """)
    html_parts.append("</div>")

    # Cost + score container
    html_parts.append("""
    <div style='display:flex;align-items:center;gap:12px;font-size:14px;font-weight:bold;margin-right:5px;'>
    """)

    # Elixir cost
    html_parts.append(f"""
        <span style='display:flex;align-items:center;gap:4px;'>
            <img src="images/elixir.webp" alt="Elixir" width="20" height="22" style="vertical-align:middle;">
            <span>{team.cost}</span>
        </span>
    """)

    # Score display; the alt text describes the score icon, not elixir
    html_parts.append(f"""
        <span style='display:flex;align-items:center;gap:2px;'>
            <img src="images/starsteel.png" alt="Score" width="20" height="20" style="vertical-align:middle;">
            <span>{team.score}</span>
        </span>
    """)
    html_parts.append("</div></div></div>")

html_parts.append("</div>")
html = "".join(html_parts)
display(HTML(html))

Input: barb !goldenknight FIRE:BRAWL 4NOBLE
Time taken: 0.011166 seconds
Teams found: 11
