Massive hat tip (HT) to Dr. Ruckus MTG, whose video inspired this code:
https://www.youtube.com/watch?v=Xq4T44EvPvo


In [None]:
import cards
import sys
import random
import time
import ipywidgets as widgets
import matplotlib.pyplot as plt
import datetime
import os
import multiprocess as mp

from typing import List


In [None]:
starting_decklist = """
## Epoch 84 baseline decklist
0 Arbor Elf
0 Arboreal Grazer
1 Elvish Mystic
0 Krosan Wayfarer
0 Skyshroud Ranger
4 Sakura-Tribe Elder
4 Tangled Florahedron
0 Wall of Roots
4 Elvish Spirit Guide
1 Simian Spirit Guide
2 Generous Ent
0 Beanstalk Giant
2 Chancellor of the Tangle
0 Panglacial Wurm
1 Sol Ring
4 Goblin Charbelcher
4 Wild Growth
5 Forest
4 Abundant Harvest
3 Ancient Stirrings
4 Caravan Vigil
1 Lay of the Land
4 Reclaim the Wastes
0 Edge of Autumn
2 Explore
4 Land Grant
4 Manamorphose
0 Nissa's Triumph
0 Rampant Growth
0 Beneath the Sands
1 Cultivate
0 Grow from the Ashes
0 Journey of Discovery
0 Nissa's Pilgrimage
1 Recross the Paths
0 Search for Tomorrow
0 Migration Path


"""

def parse_decklist(text):
    """Parse "<count> <card name>" lines into [{'quant': int, 'name': str}, ...].

    Every line is stripped before parsing, so indentation, trailing spaces or a
    trailing carriage return cannot end up inside a card name (card names are
    later looked up by exact string). Blank lines and lines starting with '#'
    are skipped.

    Raises:
        ValueError: if a line has no card name, or its count is not an integer.
    """
    parsed = []
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith('#'):
            continue
        quantity, _, cardname = line.partition(' ')
        cardname = cardname.strip()
        if not cardname:
            raise ValueError(f'Malformed decklist line (expected "<count> <card name>"): {raw_line!r}')
        parsed.append({'quant': int(quantity), 'name': cardname})
    return parsed

deckrange = parse_decklist(starting_decklist)

# Ensure that we've got a 60-card list
cardcount = sum(card['quant'] for card in deckrange)
print(f'### Total cards: {cardcount}')
if cardcount != 60:
    print("Decklist must contain exactly 60 cards.")
    sys.exit(1)

# Card names in decklist order; used as the per-card TSV columns.
cardnames = [card['name'] for card in deckrange]

In [None]:

def get_deck_variants(deckrange):
    """Return the baseline decklist text plus every deck one card away from it.

    The result is a 5-tuple ``(deck_baseline, decks_61, cards_61, decks_59, cards_59)``:

    * ``deck_baseline`` renders ``deckrange`` as one "<quant> <name>" line per entry.
    * ``decks_61[k]`` is that text with one extra copy of ``cards_61[k]``. It is
      produced only for cards still below
      ``cards.get_card_by_name(name).deck_max_quant``.
    * ``decks_59[k]`` is that text with one copy fewer of ``cards_59[k]``. It is
      produced only for cards that currently have at least one copy.
    """
    def render(adjusted_name=None, change=0):
        # Render the full deck, shifting the count of every entry named
        # `adjusted_name` by `change`.
        pieces = []
        for entry in deckrange:
            count = entry['quant']
            if entry['name'] == adjusted_name:
                count += change
            pieces.append(f"{count} {entry['name']}\n")
        return "".join(pieces)

    deck_baseline = render()

    # One copy added: only while under the card's per-deck maximum.
    decks_61, cards_61 = [], []
    for entry in deckrange:
        name = entry['name']
        if entry['quant'] < cards.get_card_by_name(name).deck_max_quant:
            decks_61.append(render(name, 1))
            cards_61.append(name)

    # One copy removed: only for cards that are actually in the deck.
    cards_59 = [entry['name'] for entry in deckrange if entry['quant'] > 0]
    decks_59 = [render(name, -1) for name in cards_59]

    return deck_baseline, decks_61, cards_61, decks_59, cards_59


In [None]:

PRUNE_LIMIT = 1000 # Max number of leaf nodes kept per search round (excess is randomly pruned)

def print_tree(state:cards.Player, depth = 0):
    """Debug helper: print `state` and, recursively, every child state beneath it.

    Each line is ``state.short_str()`` preceded by two spaces per tree level.
    print()'s argument separator adds one more space, even at depth 0.
    Pruned states are printed too.
    """
    prefix = "  " * depth
    print (prefix, state.short_str())
    next_depth = depth + 1
    for child_state in state.childstates:
        print_tree(child_state, next_depth)

def get_all_leaf_nodes(state: "cards.Player") -> "List[cards.Player]":
    """Return every unpruned leaf under `state`, in left-to-right (pre-order) order.

    A node whose ``is_pruned`` is set is skipped together with its whole
    subtree (including `state` itself). A node with no ``childstates`` is a leaf.

    The traversal uses an explicit stack instead of recursion. This keeps very
    deep game trees clear of the interpreter recursion limit, and it avoids
    re-copying the accumulated leaf list at every level of the tree.
    """
    leaf_nodes = []
    stack = [state]
    while stack:
        node = stack.pop()
        if node.is_pruned:
            continue
        children = node.childstates
        if not children:
            leaf_nodes.append(node)
        else:
            # Push children in reverse so the first child is visited first,
            # preserving the original recursive leaf order.
            stack.extend(reversed(children))
    return leaf_nodes

def find_fastest_win(state:cards.Player, maxturn = 10):
    """Search the game tree rooted at `state`, turn by turn, for the earliest win.

    Each round:
      1. Collect the unpruned leaves at the lowest current turn.
      2. Stop if any of them passes check_win().
      3. Prune leaves whose str() was already seen.
      4. Randomly cap the survivors at PRUNE_LIMIT.
      5. Expand each survivor with step_next_actions().

    The search ends on a win, once every remaining leaf is past `maxturn`, or
    when no unpruned leaves are left.

    Args:
        state: root player state. start_game() and start_turn() are called on
            it here, so the caller's object is mutated.
        maxturn: last turn that is still searched.

    Returns:
        A tuple ``(win_state, action_count, max_leaf_nodes)``:
          * win_state: the first winning state found, or None if there is no
            win by `maxturn` or the search ran out of live leaves.
          * action_count: the number of search rounds.
          * max_leaf_nodes: the largest leaf count seen in any round.
    """
    did_win = False
    win_state = None
    max_leaf_nodes = 0
    action_count = 0

    state.start_game()
    state.start_turn()

    # str(leaf) -> first leaf seen with that representation. This persists
    # across rounds, so a state reached again later is pruned as a duplicate.
    unique_leaf_nodes = {}

    while not did_win:
        action_count += 1
        leaf_nodes = get_all_leaf_nodes(state)
        if not leaf_nodes:
            # Every branch has been pruned; there is nothing left to search
            # (min() below would raise ValueError on an empty list).
            break

        max_leaf_nodes = max(max_leaf_nodes, len(leaf_nodes))

        # Only the frontier at the lowest turn is considered this round.
        min_turn = min(leaf.current_turn for leaf in leaf_nodes)
        min_turn_leaf_nodes = [leaf for leaf in leaf_nodes if leaf.current_turn == min_turn]

        win_leaf_nodes = [leaf for leaf in min_turn_leaf_nodes if leaf.check_win()]
        if win_leaf_nodes:
            did_win = True
            win_state = win_leaf_nodes[0]
            break
        if min_turn > maxturn:
            break

        # Deduplicate by string representation; duplicates are pruned in place.
        next_min_turn_leaf_nodes = []
        for leaf in min_turn_leaf_nodes:
            string_rep = str(leaf)
            if string_rep in unique_leaf_nodes:
                leaf.is_pruned = True
            else:
                unique_leaf_nodes[string_rep] = leaf
                next_min_turn_leaf_nodes.append(leaf)
        min_turn_leaf_nodes = next_min_turn_leaf_nodes

        # If the frontier is too large, keep a seeded random subset of PRUNE_LIMIT leaves.
        if len(min_turn_leaf_nodes) > PRUNE_LIMIT:
            # NOTE(review): this reseeds the process-wide `random` module. That
            # also affects later random.* calls in the same process (e.g.
            # test_decklist's seed_base when run sequentially).
            # random.Random(state.randseed).shuffle(...) would give the same
            # shuffle without that side effect, but cards.Player may rely on
            # the global RNG — confirm before changing.
            random.seed(state.randseed)
            random.shuffle(min_turn_leaf_nodes)
            for leaf in min_turn_leaf_nodes[PRUNE_LIMIT:]:
                leaf.is_pruned = True
            min_turn_leaf_nodes = min_turn_leaf_nodes[:PRUNE_LIMIT]

        # Expand the surviving frontier, stopping at the first winning successor.
        for leaf in min_turn_leaf_nodes:
            for next_state in leaf.step_next_actions():
                if next_state.check_win():
                    did_win = True
                    win_state = next_state
                    break
            if did_win:
                break

    return win_state, action_count, max_leaf_nodes

In [None]:
USE_PARALLEL = True
PARALLEL_SPARE_CORES = 2 # How many cores do we save for doing other things on the computer?
DETERMINISTIC = False
RECORD_WINNING_LOG_MESSAGES = False
cards.LOGGING_ENABLED = False

fastest_recorded_win_turns = 4
fastest_recorded_win = None

def test_decklist(decklist, num_trials, max_turns, seed_base = 0):
    """Simulate `num_trials` games of `decklist` and return the average winning turn.

    Game i is a cards.Player seeded with ``seed_base + i``. Unless DETERMINISTIC
    is set, ``seed_base`` is replaced by a random value. Each game is searched
    with find_fastest_win() up to ``max_turns``. A game with no win found
    counts as ``max_turns + 2`` turns. Any win faster than
    fastest_recorded_win_turns is written to ``turn_<n>_win.txt`` in the
    current working directory.

    Args:
        decklist: decklist text, one "<count> <card name>" line per card.
        num_trials: number of games to simulate.
        max_turns: search horizon; also sets the no-win penalty.
        seed_base: first RNG seed; only honoured when DETERMINISTIC is True.

    Returns:
        The average winning turn over all simulated games.
    """
    # NOTE(review): these globals are declared but never assigned, so the
    # threshold stays at its initial value and every qualifying win rewrites
    # its turn_<n>_win.txt — confirm whether a faster win should update them.
    global fastest_recorded_win_turns
    global fastest_recorded_win

    total_turns = 0
    # log line -> number of winning games containing it (only filled when
    # RECORD_WINNING_LOG_MESSAGES is set).
    winning_log_messages = {}
    then = time.time()

    # NOTE: Use a deterministic seed for testing performance improvements
    if not DETERMINISTIC:
        seed_base = random.randint(0, 2**31-1)
    players = [cards.Player(decklist, seed_base + i) for i in range(num_trials)]

    # max_turns is passed through so the search horizon matches the no-win
    # penalty below. Previously find_fastest_win always used its default.
    if USE_PARALLEL:
        # h.t. https://www.machinelearningplus.com/python/parallel-processing-python/ for the multiprocessing code
        # Keep at least one worker: Pool rejects a process count below 1, and
        # cpu_count() - PARALLEL_SPARE_CORES reaches that on <= 2-core machines.
        worker_count = max(1, mp.cpu_count() - PARALLEL_SPARE_CORES)
        with mp.Pool(worker_count) as pool:
            results = pool.starmap(find_fastest_win, [(player, max_turns) for player in players])
    else:
        results = [find_fastest_win(player, max_turns) for player in players]

    for original_state, (win_state, _action_count, _max_leaf_nodes) in zip(players, results):
        if win_state is None:
            # No win found within the search horizon: charge a penalty past the cap.
            won_turn = max_turns + 2
        else:
            won_turn = win_state.current_turn

            if won_turn < fastest_recorded_win_turns:
                with open(f'turn_{won_turn}_win.txt', 'w') as f:
                    f.write(f'Won in {won_turn} turns')
                    if USE_PARALLEL:
                        # The search ran in a worker process on a copy, so this
                        # local Player was never started.
                        # NOTE(review): find_fastest_win calls start_game() before
                        # start_turn(), but only start_turn() is replayed here —
                        # confirm this is the intended "Original state".
                        original_state.start_turn()
                    f.write(f' Original state: {original_state}\n')
                    f.write('\n'.join(win_state.log))
                    f.write(str(win_state))

            if RECORD_WINNING_LOG_MESSAGES:
                for log_message in win_state.log:
                    log_message = log_message.strip()
                    winning_log_messages[log_message] = winning_log_messages.get(log_message, 0) + 1

        total_turns += won_turn

        # TODO: Also save the total number of plays / alt-plays / activations that each card had

    duration = time.time() - then
    avg_duration = duration / num_trials

    avg_win_turn = total_turns / num_trials
    print (f'  Average win turn: {avg_win_turn}')
    print (f'  Tested decklist in {duration} ({avg_duration} each)')

    # Return the average winning turn number
    return avg_win_turn
    


In [None]:
def update_plots(baseline_wins, running_wins_61_avgs, running_wins_59_avgs, running_best_win, running_delta):
    """Draw every tracked series on one 20x10 figure, with a legend, and show it.

    NOTE(review): run_epoch builds running_wins_61_avgs / running_wins_59_avgs
    as lists of dicts. Those likely need converting to numeric series before
    plotting; confirm before re-enabling the call there.
    """
    labelled_series = (
        (baseline_wins, 'Baseline'),
        (running_wins_61_avgs, '61 card decklist'),
        (running_wins_59_avgs, '59 card decklist'),
        (running_best_win, 'Best win'),
        (running_delta, 'Delta'),
    )
    plt.figure(figsize=(20, 10))
    for values, label in labelled_series:
        plt.plot(values, label=label)
    plt.legend()
    plt.show()


In [None]:
from re import L


epoch_num = 0  # incremented by run_epoch() at the start of each epoch
log_folder = ''  # output directory for log files; '' means the current working directory

def log_to_file(log_filename, log_message):
    """Append `log_message` to `log_filename` inside the module-level `log_folder`.

    The folder, including any missing parents, is created on demand. An empty
    `log_folder` writes to the current working directory. Previously that case
    crashed, because os.makedirs('') raises FileNotFoundError.
    """
    if log_folder:
        # exist_ok=True makes this a no-op when the folder already exists.
        os.makedirs(log_folder, exist_ok=True)

    # os.path.join works whether or not log_folder ends with a separator.
    with open(os.path.join(log_folder, log_filename), 'a') as f:
        f.write(log_message)

def _print_deltas(avgs, baseline_avg):
    """Print one line per card: its average win turn minus `baseline_avg`, '+'-prefixed when positive."""
    for card, avg_win in avgs.items():
        delta = avg_win - baseline_avg
        if delta > 0:
            print(f'   {card}: +{delta}')
        else:
            print(f'   {card}: {delta}')

def run_epoch(deckrange, num_trials, max_turns, step_size):
    """Run one hill-climbing epoch over every one-card-off variant of the deck.

    Each of the `num_trials` steps simulates `step_size` games of the baseline
    deck and of every 61-card (one copy added) and 59-card (one copy removed)
    variant from get_deck_variants(), using test_decklist().

    After each step, the running averages are printed and one row is appended
    to ``epoch_<n>.tsv`` in log_folder. The final step's row is also appended
    to ``progress.tsv``.

    Args:
        deckrange: the current deck as [{'quant': int, 'name': str}, ...]; not modified.
        num_trials: number of steps to run; must be at least 1.
        max_turns: turn cap passed to test_decklist().
        step_size: number of games simulated per deck per step.

    Returns:
        A tuple ``(baseline_wins, best_win, best_card_to_add, best_card_to_remove)``:
          * baseline_wins: the per-step list of baseline average win turns.
          * best_win, best_card_to_add, best_card_to_remove: values from the
            final step.

    Raises:
        ValueError: if num_trials < 1. Previously this surfaced as an
            UnboundLocalError at the return statement.
    """
    global epoch_num
    if num_trials < 1:
        raise ValueError(f'num_trials must be at least 1, got {num_trials}')

    deck_baseline, decks_61, cards_61, decks_59, cards_59 = get_deck_variants(deckrange)
    # wins_61[k] / wins_59[k]: per-step average win turns of decks_61[k] / decks_59[k].
    wins_61 = [[] for _ in decks_61]
    wins_59 = [[] for _ in decks_59]
    baseline_wins = []
    epoch_num += 1
    overall_tsv_filename = 'progress.tsv'
    tsv_filename = f'epoch_{epoch_num}.tsv'
    decklist_filename = f'epoch_{epoch_num}_decklist.txt'

    print (f'Running epoch {epoch_num} with {num_trials} trials and {max_turns} max turns.')
    # Each step simulates the baseline deck plus every variant, step_size games
    # each. The baseline used to be left out, which under-reported the game
    # count and inflated the per-game average duration.
    simulations_per_step = (1 + len(decks_61) + len(decks_59)) * step_size
    print (f' Total number of simulated games in this epoch: {num_trials * simulations_per_step}')
    log_to_file(decklist_filename, f'Epoch {epoch_num} baseline decklist\n{deck_baseline}')
    print(f' Baseline decklist:\n{deck_baseline}')
    print(f' Number of 61-card decks: {len(decks_61)}')
    print(f' Number of 59-card decks: {len(decks_59)}')

    # Histories for update_plots() (the call below is currently disabled).
    running_wins_61_avgs = []
    running_wins_59_avgs = []
    running_best_win = []
    running_delta = []

    # Copies of each card in this epoch's deck, for the '#<card>' TSV columns.
    # A later duplicate name wins and a missing name reads as 0, matching the
    # previous per-row scan.
    quant_by_name = {card['name']: card['quant'] for card in deckrange}

    for i in range(num_trials):
        print(f'Step {i+1}/{num_trials}:')

        print(' Current Decklist:')
        print(deck_baseline)

        then = time.time()
        # Every deck in a step shares this seed_base (only honoured when DETERMINISTIC).
        seed_base = i * step_size
        print('Testing baseline')
        baseline_wins.append(test_decklist(deck_baseline, step_size, max_turns, seed_base = seed_base))

        for idx, (deck_61, card_name) in enumerate(zip(decks_61, cards_61)):
            print(f' Testing addition of {card_name} ({idx+1} / {len(decks_61)})')
            wins_61[idx].append(test_decklist(deck_61, step_size, max_turns, seed_base = seed_base))
        for idx, (deck_59, card_name) in enumerate(zip(decks_59, cards_59)):
            print(f' Testing removal of {card_name} ({idx+1} / {len(decks_59)})')
            wins_59[idx].append(test_decklist(deck_59, step_size, max_turns, seed_base = seed_base))

        duration = time.time() - then
        avg_duration = duration / simulations_per_step

        # Mean over all steps so far, ordered from fastest (lowest) to slowest average win turn.
        wins_61_avgs = {name: sum(history) / len(history) for name, history in zip(cards_61, wins_61)}
        wins_59_avgs = {name: sum(history) / len(history) for name, history in zip(cards_59, wins_59)}
        wins_61_avgs = dict(sorted(wins_61_avgs.items(), key=lambda item: item[1]))
        wins_59_avgs = dict(sorted(wins_59_avgs.items(), key=lambda item: item[1]))

        baseline_wins_avg = sum(baseline_wins) / len(baseline_wins)

        running_wins_61_avgs.append(wins_61_avgs)
        running_wins_59_avgs.append(wins_59_avgs)

        # Print out the sorted list of cards and their average winning turn
        print(f' Baseline wins: {baseline_wins_avg}')
        print(f' Average duration: {avg_duration}')
        print('  Best cards to add:')
        _print_deltas(wins_61_avgs, baseline_wins_avg)
        print('  Best cards to remove:')
        _print_deltas(wins_59_avgs, baseline_wins_avg)

        # The first key of each sorted dict is the fastest variant.
        best_card_to_add = list(wins_61_avgs)[0]
        best_card_to_remove = list(wins_59_avgs)[0]

        # Estimate the 60-card swap as the mean of the best 61- and 59-card results.
        best_61_win = wins_61_avgs[best_card_to_add]
        best_59_win = wins_59_avgs[best_card_to_remove]
        best_win = (best_61_win + best_59_win) / 2

        print (f' Best card to add: {best_card_to_add} ({best_61_win - best_win})')
        print (f' Best card to remove: {best_card_to_remove} ({best_59_win - best_win})')
        delta = best_win - baseline_wins_avg
        print (f' Best win: {best_win} vs. prev {baseline_wins_avg} (change: {delta})')

        running_best_win.append(best_win)
        running_delta.append(delta)

        #update_plots(baseline_wins, running_wins_61_avgs, running_wins_59_avgs, running_best_win, running_delta)

        # Output to TSV. log_to_file() creates log_folder on demand; it already
        # did so for the decklist file written above.
        if not os.path.exists(os.path.join(log_folder, tsv_filename)):
            headers = 'Trials\tAvg. Time Per Test\tBaseline\tBest Win\tDelta\tBest Card to Add\tBest Card to Remove\t'
            headers += ''.join(f"'#{name}\t'+{name}\t'-{name}\t" for name in cardnames)
            headers += '\n'
            log_to_file(tsv_filename, headers)
            if not os.path.exists(os.path.join(log_folder, overall_tsv_filename)):
                log_to_file(overall_tsv_filename, f'Epoch\t{headers}')

        fields = [
            (i+1)*step_size,
            f'{avg_duration:.4f}',
            f'{baseline_wins_avg:.3f}',
            f'{best_win:.3f}',
            f'{delta:.3f}',
            best_card_to_add,
            best_card_to_remove,
        ]
        for name in cardnames:
            fields.append(quant_by_name.get(name, 0))
            # Average win turn with one copy added / removed; blank when that variant wasn't tested.
            fields.append(f'{wins_61_avgs[name]:.3f}' if name in wins_61_avgs else '')
            fields.append(f'{wins_59_avgs[name]:.3f}' if name in wins_59_avgs else '')
        log_line = ''.join(f'{field}\t' for field in fields) + '\n'
        log_to_file(tsv_filename, log_line)

        # The final step's row also goes into the cross-epoch progress file.
        if i == num_trials - 1:
            log_to_file(overall_tsv_filename, f'{epoch_num}\t{log_line}')

    return baseline_wins, best_win, best_card_to_add, best_card_to_remove


In [None]:

# Experiment settings (trailing comments list values used in earlier runs).
num_epochs = 100 # 1000
max_turns = 10 # 20
num_trials = 1 # 20 # 10000 # How many times to run a step of simulations within each epoch
step_size = 1000 #250 #150 # How many times to run each deck in each step.
# Each deck variant is therefore simulated step_size * num_trials times per epoch.

# One output folder per run, stamped with year_month_day_hour_minute_second.
run_timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
log_folder = f'logs/output_prune{PRUNE_LIMIT}_turns{max_turns}_{run_timestamp}/'

for epoch_index in range(num_epochs):
    print(f'Epoch {epoch_index+1} of {num_epochs}')
    baseline_wins, best_card_info = None, None
    baseline_wins, best_win, best_card_to_add, best_card_to_remove = run_epoch(deckrange, num_trials, max_turns, step_size)

    # Apply the suggested swap in place: one more copy of the best card to add,
    # one fewer of the best card to remove. This is a no-op when they are the
    # same card.
    # NOTE(review): the swap is applied every epoch, even when the reported
    # delta is positive (a slower average win) — confirm this is intended.
    for entry in deckrange:
        entry_name = entry['name']
        if entry_name == best_card_to_add:
            entry['quant'] += 1
        if entry_name == best_card_to_remove:
            entry['quant'] -= 1

print('Final decklist:')
final_decklist = "".join(f"{entry['quant']} {entry['name']}\n" for entry in deckrange)
print(final_decklist)


