In [1]:
import numpy as np
import pickle
import os
import copy
from scipy.stats import chi2_contingency
import time
import pprint

In [2]:
'''
Evaluate dicts
    DONE overall size
    TEST states that do not have enough samples
    swaps that are better
        overall number
        return sample size, p value etc and see
    would be nice to have how often the swaps occur as overall percent but the dict will stop collecting keys
        so not an accurate caught

Choose the p values
    which swap values are better with win %
    then how many by the different narrowing things

Maybe examine some specific states
'''

'\nEvaluate dicts\n    DONE overall size\n    TEST states that do not have enough samples\n    swaps that are better\n        overall number\n        return sample size, p value etc and see\n    would be nice to have how often the swaps occur as overall percent but the dict will stop collecting keys\n        so not an accurate caught\n\nChoose the p values\n    which swap values are better with win %\n    then how many by the different narrowing things\n\nMaybe examine some specific states\n'

### Constants and Functions

In [30]:
def load_pkl_object(pkl_path):
    '''
    Load a pickle object
    '''
    with open(pkl_path, 'rb') as handle:
        return pickle.load(handle)


def save_object_as_pkl(object_to_save, save_tag):
    '''
    Save object a pickle file
    '''
    save_path = f'joined_dict_results\\{save_tag}.pickle'

    with open(save_path, 'wb') as handle:
        print("saving: ", save_path)
        pickle.dump(object_to_save, handle, protocol=pickle.HIGHEST_PROTOCOL)


def get_chi_square_test_from_action_dict(
    action_dict,
    state_key,
    min_total_count=100,
    min_swap_count=50,
    min_attack_count=50,
    swap_key='swap_0', attack_key='attack',
    sum_wins_key='sum_wins', count_key='count',
    is_print_statistics=False):

    attack_action = 0
    swap_party_zero_action = 1
    swap_party_one_action = 2

    is_use_p_value = False
    is_swap_better = False
    p_value = None
    swap_win_rate_better_rate = 0.
    recommended_action = attack_action

    try:
        if state_key in action_dict:

            if swap_key in action_dict[state_key] and attack_key in action_dict[state_key]:
                swap_wins = action_dict[state_key][swap_key][sum_wins_key]
                swap_count = action_dict[state_key][swap_key][count_key]
                attack_wins = action_dict[state_key][attack_key][sum_wins_key]
                attack_count = action_dict[state_key][attack_key][count_key]

                total_count = swap_count + attack_count

                if total_count > min_total_count and swap_count > min_swap_count and attack_count > min_attack_count:

                    swap_win_percent = swap_wins / swap_count
                    attack_win_percent = attack_wins / attack_count
                    
                    if swap_win_percent > attack_win_percent:
                        is_swap_better = True
                        swap_win_rate_better_rate = swap_win_percent - attack_win_percent
                    else:
                        is_swap_better = False
                        swap_win_rate_better_rate = 0.

                    # chi squared table breaks down if any 0 values
                    # really should not have less than 5
                    if attack_wins == attack_count:
                        recommended_action = attack_action
                        # choose attack as attack always wins
                        if is_print_statistics:
                            print("Attack always wins")
                            print(f"Swap win rate: {swap_wins / swap_count:.3f} | Count {swap_count}")
                            print(f"Attack win rate: {attack_wins / attack_count:.3f} | Count {attack_count}")
                    elif swap_wins == swap_count:
                        # choose swap
                        is_use_p_value = True
                        is_swap_better = True
                        p_value = 0.
                        recommended_action = swap_party_zero_action
                        if is_print_statistics:
                            print("swap always wins, choosing swap")
                            print(f"Swap win rate: {swap_wins / swap_count:.3f} | Count {swap_count}")
                            print(f"Attack win rate: {attack_wins / attack_count:.3f} | Count {attack_count}")
                    elif swap_wins == 0:
                        recommended_action = attack_action
                        # swap always loses
                        if is_print_statistics:
                            print("Swap always loses")
                            print(f"Swap win rate: {swap_wins / swap_count:.3f} | Count {swap_count}")
                            print(f"Attack win rate: {attack_wins / attack_count:.3f} | Count {attack_count}")
                    elif attack_wins == 0:
                        # attack always loses and swap won at least once so choose swap
                        is_use_p_value = True
                        is_swap_better = True
                        p_value = 0.
                        recommended_action = swap_party_zero_action
                        if is_print_statistics:
                            print("Attack always loses, choosing swap ")
                            print(f"Swap win rate: {swap_wins / swap_count:.3f} | Count {swap_count}")
                            print(f"Attack win rate: {attack_wins / attack_count:.3f} | Count {attack_count}")
                    else:
                        contingency_table = [[swap_wins, swap_count - swap_wins], [attack_wins, attack_count - attack_wins]]
                        chi2, p_value, dof, expected = chi2_contingency(contingency_table)
                        is_use_p_value = True

                        if is_swap_better:
                            if p_value < 0.25:
                                recommended_action = swap_party_zero_action
                            elif swap_win_rate_better_rate >= .1:
                                recommended_action = swap_party_zero_action
                            elif swap_win_rate_better_rate >= .05 and p_value < .6:
                                recommended_action = swap_party_zero_action

                        if is_print_statistics:
                            #print(f'Swap Win : { win_loss_draw1[0] / sum(win_loss_draw1):.3f}')
                            print(f"Swap win rate: {swap_wins / swap_count:.3f} | Count {swap_count}")
                            print(f"Attack win rate: {attack_wins / attack_count:.3f} | Count {attack_count}")
                            print(f'Chi-square statistic: {chi2:.3f}')
                            print(f'P-value: {p_value:.5f}')

        else:
            is_use_p_value = False
            is_swap_better = False
            p_value = None
            swap_win_rate_better_rate = 0.
            recommended_action = attack_action
    except Exception as e:
        print("Error: in chi square test ", str(e) )
        is_use_p_value = False
        is_swap_better = False
        p_value = None
        swap_win_rate_better_rate = 0.
        recommended_action = attack_action
    
    return recommended_action, swap_win_rate_better_rate, is_use_p_value, is_swap_better, p_value


def get_chi_square_test_from_count_wins(
    swap_wins, swap_count, attack_wins, attack_count,
    min_total_count=50,
    min_swap_count=25,
    min_attack_count=25,
    is_print_statistics=False):

    attack_action = 0
    swap_party_zero_action = 1
    swap_party_one_action = 2

    is_use_p_value = False
    is_swap_better = False
    p_value = None
    swap_win_rate_better_rate = 0.
    recommended_action = attack_action

    try:
        total_count = swap_count + attack_count

        if total_count > min_total_count and swap_count > min_swap_count and attack_count > min_attack_count:

            swap_win_percent = swap_wins / swap_count
            attack_win_percent = attack_wins / attack_count
            
            if swap_win_percent > attack_win_percent:
                is_swap_better = True
                swap_win_rate_better_rate = swap_win_percent - attack_win_percent
            else:
                is_swap_better = False
                swap_win_rate_better_rate = 0.

            # chi squared table breaks down if any 0 values
            # really should not have less than 5
            if attack_wins == attack_count:
                recommended_action = attack_action
                # choose attack as attack always wins
                if is_print_statistics:
                    print("Attack always wins")
                    print(f"Swap win rate: {swap_wins / swap_count:.3f} | Count {swap_count}")
                    print(f"Attack win rate: {attack_wins / attack_count:.3f} | Count {attack_count}")
            elif swap_wins == swap_count:
                # choose swap
                is_use_p_value = True
                is_swap_better = True
                p_value = 0.
                recommended_action = swap_party_zero_action
                if is_print_statistics:
                    print("swap always wins, choosing swap")
                    print(f"Swap win rate: {swap_wins / swap_count:.3f} | Count {swap_count}")
                    print(f"Attack win rate: {attack_wins / attack_count:.3f} | Count {attack_count}")
            elif swap_wins == 0:
                recommended_action = attack_action
                # swap always loses
                if is_print_statistics:
                    print("Swap always loses")
                    print(f"Swap win rate: {swap_wins / swap_count:.3f} | Count {swap_count}")
                    print(f"Attack win rate: {attack_wins / attack_count:.3f} | Count {attack_count}")
            elif attack_wins == 0:
                # attack always loses and swap won at least once so choose swap
                is_use_p_value = True
                is_swap_better = True
                p_value = 0.
                recommended_action = swap_party_zero_action
                if is_print_statistics:
                    print("Attack always loses, choosing swap ")
                    print(f"Swap win rate: {swap_wins / swap_count:.3f} | Count {swap_count}")
                    print(f"Attack win rate: {attack_wins / attack_count:.3f} | Count {attack_count}")
            else:
                contingency_table = [[swap_wins, swap_count - swap_wins], [attack_wins, attack_count - attack_wins]]
                chi2, p_value, dof, expected = chi2_contingency(contingency_table)
                is_use_p_value = True

                if is_swap_better:
                    if p_value < 0.1:
                        recommended_action = swap_party_zero_action
                    # elif swap_win_rate_better_rate >= .1:
                    #     recommended_action = swap_party_zero_action
                    elif swap_win_rate_better_rate >= .05 and p_value < .25:
                        recommended_action = swap_party_zero_action
                    # if p_value < 0.25:
                    #     recommended_action = swap_party_zero_action
                    # elif swap_win_rate_better_rate >= .1:
                    #     recommended_action = swap_party_zero_action
                    # elif swap_win_rate_better_rate >= .05 and p_value < .6:
                    #     recommended_action = swap_party_zero_action

                if is_print_statistics:
                    #print(f'Swap Win : { win_loss_draw1[0] / sum(win_loss_draw1):.3f}')
                    print(f"Swap win rate: {swap_wins / swap_count:.3f} | Count {swap_count}")
                    print(f"Attack win rate: {attack_wins / attack_count:.3f} | Count {attack_count}")
                    print(f'Chi-square statistic: {chi2:.3f}')
                    print(f'P-value: {p_value:.5f}')

        else:
            is_use_p_value = False
            is_swap_better = False
            p_value = None
            swap_win_rate_better_rate = 0.
            recommended_action = attack_action

    except Exception as e:
        print("Error: in chi square test ", str(e) )
        is_use_p_value = False
        is_swap_better = False
        p_value = None
        swap_win_rate_better_rate = 0.
        recommended_action = attack_action
    
    return recommended_action, swap_win_rate_better_rate, is_use_p_value, is_swap_better, p_value



### 2v2

In [None]:
two_combined_dict = load_pkl_object('joined_dict_results\\two_vs_two_combined_results.pickle')


In [None]:
print(len(two_combined_dict))

In [None]:
for k, v in two_combined_dict.items():
    print(k, v)
    break

In [None]:
# getting the counts
# need to filter out the times when TTF is 0 (ie not a true 2v2)



def get_state_counts(check_dict):
    count_key = 'count'
    sum_wins_key = 'sum_wins'
    agent_first_move_attack_key = 'attack'
    agent_first_move_swap_party_0_key = 'swap_0'
    agent_first_move_swap_party_1_key = 'swap_1'
    time_int = int(time.time())

    count_list = []
    swap_2v2_dict = {}
    swap_better_count = 0

    for state_key, state_dict in check_dict.items():
        if agent_first_move_swap_party_0_key in state_dict and agent_first_move_attack_key in state_dict:
            count_list.append(state_dict[agent_first_move_swap_party_0_key].get(count_key, np.NaN))
            count_list.append(state_dict[agent_first_move_attack_key].get(count_key, np.NaN))

            recommended_action, swap_win_rate_better_rate, is_use_p_value, is_swap_better, p_value = get_chi_square_test_from_action_dict(
                check_dict, state_key, is_print_statistics=False)
        
            if is_swap_better:

                swap_better_count += 1

                swap_2v2_dict[state_key] = {
                    'recommended_action': recommended_action,
                    'swap_win_rate_better_rate': swap_win_rate_better_rate,
                    'is_use_p_value': is_use_p_value,
                    'is_swap_better': is_swap_better,
                    'p_value': p_value,
                    'swap_0_count': state_dict[agent_first_move_swap_party_0_key].get(count_key, np.NaN),
                    'attack_count': state_dict[agent_first_move_attack_key].get(count_key, np.NaN),
                    'swap_0_win_count': state_dict[agent_first_move_swap_party_0_key].get(sum_wins_key, np.NaN),
                    'attack_win_count': state_dict[agent_first_move_attack_key].get(sum_wins_key, np.NaN)
                }
            
                

        # for move_key, move_dict in state_dict.items():
        #     count_list.append(move_dict.get(count_key, np.NaN))

    count_array = np.array(count_list)
    print(f"swap better count: {swap_better_count} | len swap dict {len(swap_2v2_dict)}")

    save_object_as_pkl(count_list, f'count_list_2v2_{time_int}')
    save_object_as_pkl(swap_2v2_dict, f'swap_2v2_dict_{time_int}')

    return count_array, swap_better_count, swap_2v2_dict

two_count_array, swap_better_count, swap_2v2_dict = get_state_counts(copy.deepcopy(two_combined_dict))


In [None]:
print(two_count_array.shape, np.mean(two_count_array), np.median(two_count_array), np.min(two_count_array), np.max(two_count_array))
# more allowable version
# swap better count: 38194 | len swap dict 38194
# saving:  joined_dict_results\count_list_2v2_1720491755.pickle
# saving:  joined_dict_results\swap_2v2_dict_1720491755.pickle
#(2426156,) 89.4686932744638 3.0 1 3426669

# count less than 5: 1403590 | 0.579 
# count less than 10: 1775533 | 0.732 
# count less than 15: 1928543 | 0.795 
# count less than 20: 2015950 | 0.831 
# count less than 25: 2073878 | 0.855 
# count less than 50: 2206844 | 0.910 
# count less than 100: 2292732 | 0.945

# print(recommend_swap_count, len(swap_2v2_dict))
# 29863 38194



In [None]:
for x in [5, 10, 15, 20, 25, 50, 100]:
    less_than_x = np.sum(two_count_array < x)
    print(f'count less than {x}: {less_than_x} | {less_than_x / len(two_count_array):.3f} ')

In [None]:
break_num = 0

for k, v in swap_2v2_dict.items():
    pprint.pprint(k)
    pprint.pprint(v)
    break_num += 1
    if break_num > 3:
        break

In [None]:
break_num = 0
recommend_swap_count = 0
for k, v in swap_2v2_dict.items():
    if v['recommended_action'] == 1:
        recommend_swap_count += 1

    # pprint.pprint(k)
    # pprint.pprint(v)
    # break_num += 1
    # if break_num > 10:
    #     break
print(recommend_swap_count, len(swap_2v2_dict))

### 2v3

In [None]:
'''
so for 2v3 recall the wins and counts are one part and the other part is final state ended up in so the math is something like
	num_wins no 2v2
	num battles no 2v2
	then for each state take the num actions
		find win rate for that state with the better action there
			given some min number. if no min number then don't use
			estimated_num_wins_in_state = win_rate * num_battles_in_state

			num_wins += estiomated_num_wins
			num_battles += num_battles_instates

	once that dict is made, can then do the analysis on it

steps
load the 2v3 dict
load the 2v2 dict

iterate through the steps in two vs three combined dict
	need to have min number of outcomes in the state
'''

In [4]:
time_int = int(time.time())
two_vs_three_save_tag = f"two_vs_three_eval_{time_int}"

In [5]:
#two_vs_three_combined_dict = load_pkl_object('joined_dict_results\\two_vs_three_2v3_combined_results.pickle')
two_vs_three_combined_dict = load_pkl_object('G:\\3v3_vgc_saves_71124\\3v3_results\\2v3_2v3final_1000_1720922738_action_state_results_dict_checkpoint_2.pickle')

In [6]:
two_vs_two_recommended_dict = load_pkl_object('joined_dict_results\\swap_2v2_dict_1720783493.pickle')

In [7]:
two_vs_two_combined_dict = load_pkl_object('joined_dict_results\\two_vs_two_combined_results.pickle')



In [8]:
len(two_vs_three_combined_dict)

682425

In [9]:
# example entries

# (1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0) {'attack': {'sum_wins': 705, 'count': 5542, (1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0): 5407}, 'swap_0': {'sum_wins': 0, 'count': 4458, (1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0): 314}}
# (1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0) {'attack': {'sum_wins': 14931, 'count': 121805}}
# (2, 2, 2, 2, 2, 1, 1, 2, 1, 3, -1, -1, 0, 0, 1, 2, 0) {'attack': {'sum_wins': 0, 'count': 1}}
# (2, 2, 1, 0, 0, 0, 2, 0, 3, 0, -1, 0, 0, 0, 0, 2, 0) {'attack': {'sum_wins': 0, 'count': 18}}
# (1, 2, 1, 0, 0, 0, 1, 0, 1, 0, -1, 0, 0, 0, 0, 2, 0) {'attack': {'sum_wins': 0, 'count': 588}}
# (2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0) {'attack': {'sum_wins': 0, 'count': 82}, 'swap_0': {'sum_wins': 0, 'count': 7, (1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0): 7, (1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0): 1}}

In [10]:
def make_new_state_tuple(state_key, filter_list):
    new_state_list = []
    for i in range(len(state_key)):
        if filter_list[i]:
            new_state_list.append(state_key[i])

    new_state_tuple = tuple(new_state_list)
    if len(new_state_tuple) != 12:
        print("Error: new state tuple not correct length of 12", new_state_tuple)

    return new_state_tuple


def convert_2v3_state_to_2v2_state(state_key):
    '''
    to do: can be betetr about filtering ubt weird stuff can happen
    '''
    # example state
    #(1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0)
    #(1, 2, 1,  1, 1, 1,    1, 1,   1, 1,   1, 1,   0, 0,   0, 0, 0)
    # idx 0,1,2 are active attacking opp pkm
    # idx 3,4,5 are party attacking party pkm
    # idx 6,7 are opp active attacking agent
    # idx 8,9 are party 0 active attacking agent
    # idx 10,11 are party 1 active attacking agent
    # idx 12, 13, are agent hp normalized
    # idx 14,15,16 are opp hp normalized

    is_two_vs_two_key = False
    filtered_state_key = tuple([19])

    if len(state_key) == 17:
        if state_key[0] == 0 or state_key[3] == 0:
            print("Error: active opp is fainted ", state_key)
        else:
            if state_key[1] == 0 and state_key[4] == 0: #\
                #and state_key[2] != 0 and state_key[5] != 0:
                # opp party 0 is down
                filter_list = [True, False, True,
                               True, False, True,
                               True, True,
                               False, False,
                               True, True,
                               True, True,
                               True, False, True,]  
                
                filtered_state_key = make_new_state_tuple(state_key, filter_list)
                is_two_vs_two_key = True

            elif state_key[2] == 0 and state_key[5] == 0: #\
                #and state_key[1] != 0 and state_key[4] != 0:
                # opp party 1 is down
                filter_list = [True, True, False,
                               True, True, False,
                               True, True,
                               True, True,
                               False, False,
                               True, True,
                               True, True, False,]  
                
                filtered_state_key = make_new_state_tuple(state_key, filter_list)
                is_two_vs_two_key = True
                
            else:
                print("Error: unable to located the fainted opp ", state_key)
    else:
        print("Error: state key not 17 long ", state_key)

    if not is_two_vs_two_key:
        print("Warning: not a 2v2 key ", state_key, filtered_state_key)

    return filtered_state_key, is_two_vs_two_key


def get_state_counts_2v3(two_vs_three_combined_dict, save_tag,
        two_vs_two_recommended_dict, two_vs_two_combined_dict):
    count_key = 'count'
    sum_wins_key = 'sum_wins'
    agent_first_move_attack_key = 'attack'
    agent_first_move_swap_party_0_key = 'swap_0'
    agent_first_move_swap_party_1_key = 'swap_1'
    time_int = int(time.time())

    count_list = []
    swap_dict = {}
    swap_better_count = 0

    three_vs_three_lookup_dict = {}

    try:
        for state_key, state_dict in two_vs_three_combined_dict.items():
            if agent_first_move_swap_party_0_key in state_dict and agent_first_move_attack_key in state_dict:

                # get wins, counts for swap/attack that led to end state
                # and do not go to 2v2
                attack_sub_count = state_dict[agent_first_move_attack_key].get(count_key, 0)
                attack_sub_wins = state_dict[agent_first_move_attack_key].get(sum_wins_key, 0)

                swap_sub_count = state_dict[agent_first_move_swap_party_0_key].get(count_key, 0)
                swap_sub_wins = state_dict[agent_first_move_swap_party_0_key].get(sum_wins_key, 0)

                # get counts, estimated wins for states that went to 2v2
                attack_to_2v2_count = 0
                swap_to_2v2_count = 0
                attack_to_2v2_estimated_wins = 0.
                swap_to_2v2_estimated_wins = 0.

                # test_state = tuple([1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1, 0, 0, 0, -1, -1])
                # check_key = tuple([1, 0, -1, 1, 0, -1, 1, 1, 0, 0, -1, -1, 0, 0, 0, -1, -1])
                # print(check_key in state_dict[agent_first_move_attack_key])
                # if test_state == state_key:
                #     import pdb; pdb.set_trace()

                for inner_move_key, inner_move_value in state_dict[agent_first_move_attack_key].items():
                    if inner_move_key != count_key and inner_move_key != sum_wins_key:
                        # turn the 2v3 keys into 2v2 keys
                        filtered_state_key, is_use_two_vs_two_key = convert_2v3_state_to_2v2_state(inner_move_key)

                        # get results for that state using the best action by win %
                        if is_use_two_vs_two_key:
                            if filtered_state_key in two_vs_two_recommended_dict:
                                recommended_action = two_vs_two_recommended_dict[filtered_state_key].get('recommended_action', 0)

                                if recommended_action == 1:
                                    lookup_key = agent_first_move_swap_party_0_key
                                else:
                                    lookup_key = agent_first_move_attack_key
                            else:
                                lookup_key = agent_first_move_attack_key

                            lookup_dict = two_vs_two_combined_dict.get(filtered_state_key, {}).get(lookup_key, {})
                            lookup_2v2_count = lookup_dict.get(count_key, 0)
                            lookup_2v2_sum_wins = lookup_dict.get(sum_wins_key, 0) 
                            if lookup_2v2_count > 0 and sum_wins_key in lookup_dict:
                                lookup_win_percent = lookup_2v2_sum_wins / lookup_2v2_count
                                attack_to_2v2_estimated_wins += lookup_win_percent * inner_move_value
                                attack_to_2v2_count += inner_move_value
                        #(1, 0, -1, 1, 0, -1, 1, 1, 0, 0, -1, -1, 0, 0, 0, -1, -1)
                        # check_key = tuple([1, 0, -1, 1, 0, -1, 1, 1, 0, 0, -1, -1, 0, 0, 0, -1, -1])
                        # if state_key == test_state and check_key == inner_move_key:
                        #     import pdb; pdb.set_trace()

                for inner_move_key, inner_move_value in state_dict[agent_first_move_swap_party_0_key].items():
                    if inner_move_key != count_key and inner_move_key != sum_wins_key:
                        # turn the 2v3 keys into 2v2 keys
                        filtered_state_key, is_use_two_vs_two_key = convert_2v3_state_to_2v2_state(inner_move_key)

                        # get results for that state using the best action by win %
                        if is_use_two_vs_two_key:
                            if filtered_state_key in two_vs_two_recommended_dict:
                                recommended_action = two_vs_two_recommended_dict[filtered_state_key].get('recommended_action', 0)

                                if recommended_action == 1:
                                    lookup_key = agent_first_move_swap_party_0_key
                                else:
                                    lookup_key = agent_first_move_attack_key
                            else:
                                lookup_key = agent_first_move_attack_key

                            lookup_dict = two_vs_two_combined_dict.get(filtered_state_key, {}).get(lookup_key, {})
                            lookup_2v2_count = lookup_dict.get(count_key, 0)
                            lookup_2v2_sum_wins = lookup_dict.get(sum_wins_key, 0) 
                            if lookup_2v2_count > 0 and sum_wins_key in lookup_dict:
                                lookup_win_percent = lookup_2v2_sum_wins / lookup_2v2_count
                                swap_to_2v2_estimated_wins += lookup_win_percent * inner_move_value
                                swap_to_2v2_count += inner_move_value

                # get the total wins, total counts
                # make new chi square test look up function for this

                total_attack_counts = attack_sub_count + attack_to_2v2_count
                total_attack_wins = attack_sub_wins + int(np.round(attack_to_2v2_estimated_wins,0))

                total_swap_counts = swap_sub_count + swap_to_2v2_count
                total_swap_wins = swap_sub_wins + int(np.round(swap_to_2v2_estimated_wins,0))

                # for 3v2 need to take the best swap
                recommended_action, swap_win_rate_better_rate, is_use_p_value, is_swap_better, p_value = get_chi_square_test_from_count_wins(
                    total_swap_wins, total_swap_counts, total_attack_wins, total_attack_counts,
                    min_total_count=50,
                    min_swap_count=25,
                    min_attack_count=25,
                    is_print_statistics=False)
                
                count_list.append(total_swap_counts)
                count_list.append(total_attack_counts)
            
                if is_swap_better:
                    swap_better_count += 1
                    if recommended_action != 0:
                        swap_dict[state_key] = {
                            'recommended_action': recommended_action,
                            'swap_win_rate_better_rate': swap_win_rate_better_rate,
                            'is_use_p_value': is_use_p_value,
                            'is_swap_better': is_swap_better,
                            'p_value': p_value,
                            # to do: fill these in
                            'swap_0_count': swap_sub_count,
                            'attack_count': attack_sub_count,
                            'swap_0_wins': swap_sub_wins,
                            'attack_wins': attack_sub_wins,
                            '2v2_swap_visit_count': swap_to_2v2_count,
                            '2v2_swap_estimated_win_count': swap_to_2v2_estimated_wins,
                            '2v2_attack_visit_count': attack_to_2v2_count,
                            '2v2_attack_estimated_win_count': attack_to_2v2_estimated_wins,
                            'total_swap_visit_count': total_swap_counts,
                            'total_swap_estimated_win_count': total_swap_wins,
                            'total_attack_visit_count': total_attack_counts ,
                            'total_attack_estimated_win_count': total_attack_wins,
                        }

                        three_vs_three_lookup_dict[state_key] = {
                            count_key: total_swap_counts,
                            sum_wins_key: total_swap_wins,
                        }
                    else:
                        three_vs_three_lookup_dict[state_key] = {
                            count_key: total_attack_counts,
                            sum_wins_key: total_attack_wins,
                        }

                else:
                    three_vs_three_lookup_dict[state_key] = {
                        count_key: total_attack_counts,
                        sum_wins_key: total_attack_wins,
                    }
    except Exception as e:
        print("Error in iterating through dict: ", str(e), state_key)

    count_array = np.array(count_list)
    print(f"swap better count: {swap_better_count} | len swap dict {len(swap_dict)}")

    save_object_as_pkl(swap_dict, f'swap_{save_tag}_dict_{time_int}')
    save_object_as_pkl(three_vs_three_lookup_dict, f'three_vs_three_lookup_dict_{save_tag}_{time_int}')
    save_object_as_pkl(count_list, f'count_list_{save_tag}_{time_int}')

    return count_array, swap_better_count, swap_dict, three_vs_three_lookup_dict



In [11]:
two_v_three_count_array, two_v_three_swap_better_count, swap_two_v_three_dict, three_vs_three_from2v3_lookup_dict = get_state_counts_2v3(
   two_vs_three_combined_dict, two_vs_three_save_tag, two_vs_two_recommended_dict, two_vs_two_combined_dict,
)

swap better count: 925 | len swap dict 416
saving:  joined_dict_results\swap_two_vs_three_eval_1720955501_dict_1720955523.pickle
saving:  joined_dict_results\three_vs_three_lookup_dict_two_vs_three_eval_1720955501_1720955523.pickle
saving:  joined_dict_results\count_list_two_vs_three_eval_1720955501_1720955523.pickle


In [13]:
# num_iters = 0
# for k, v in two_vs_three_combined_dict.items():
#     pprint.pprint(len(k))
#     num_iters += 1
#     if num_iters > 10:
#         break
    

#### debugging 3v3 lookup dict

In [14]:
len(three_vs_three_from2v3_lookup_dict)
for k, v in three_vs_three_from2v3_lookup_dict.items():
    print(k, v)
    break

(2, -1, -1, 2, -1, -1, 2, 3, -1, -1, -1, -1, 1, 1, 0, -1, -1) {'count': 274, 'sum_wins': 72}


#### debugging 2v3 swap dict

In [15]:
def get_swap_info(swap_dict):
    print(len(swap_dict))

    break_num = 0
    for k, v in swap_dict.items():
        pprint.pprint(k)
        pprint.pprint(v)
        break_num += 1
        if break_num >= 1:
            break

get_swap_info(swap_two_v_three_dict)

416
(1, -1, -1, 1, -1, -1, 1, 2, -1, -1, -1, -1, 0, 1, 0, -1, -1)
{'2v2_attack_estimated_win_count': 12669.170391281394,
 '2v2_attack_visit_count': 27398,
 '2v2_swap_estimated_win_count': 2394.146754379163,
 '2v2_swap_visit_count': 5579,
 'attack_count': 27950,
 'attack_wins': 4056,
 'is_swap_better': True,
 'is_use_p_value': True,
 'p_value': 1.615733867201681e-67,
 'recommended_action': 1,
 'swap_0_count': 316,
 'swap_0_wins': 38,
 'swap_win_rate_better_rate': 0.1103740705045041,
 'total_attack_estimated_win_count': 16725,
 'total_attack_visit_count': 55348,
 'total_swap_estimated_win_count': 2432,
 'total_swap_visit_count': 5895}


In [None]:
test_state = (2, -1, -1, 1, -1, -1, 1, 2, -1, -1, -1, -1, 0, 1, 0, -1, -1)
#tuple([1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1, 0, 0, 0, -1, -1])
rs_state = copy.deepcopy(swap_two_v_three_dict.get(test_state))

print("Non 2v2 stats")
print("attack ", rs_state['attack_count'], rs_state['attack_wins']/ rs_state['attack_count'])
print("swap ", rs_state['swap_0_count'], rs_state['swap_0_wins']/ rs_state['swap_0_count'])

print("2v2 stats")
if rs_state['2v2_attack_visit_count'] > 0:
    print("attack ", rs_state['2v2_attack_visit_count'], rs_state['2v2_attack_estimated_win_count'] / rs_state['2v2_attack_visit_count'])
if rs_state['2v2_swap_visit_count'] >0:
    print("swap ", rs_state['2v2_swap_visit_count'], rs_state['2v2_swap_estimated_win_count'] / rs_state['2v2_swap_visit_count'])

print("Total stats")
if rs_state['total_attack_visit_count'] > 0:
    print("attack ", rs_state['total_attack_visit_count'], rs_state['total_attack_estimated_win_count']/rs_state['total_attack_visit_count'])
if rs_state['total_swap_visit_count'] > 0:
    print("swap ", rs_state['total_swap_visit_count'], rs_state['total_swap_estimated_win_count']/rs_state['total_swap_visit_count'])




In [None]:
#1, 0, -1, 1, 0, -1, 1, 1, 0, 0, -1, -1, 0, 0, 0, -1, -1
test_state = (2, -1, -1, 1, -1, -1, 1, 2, -1, -1, -1, -1, 0, 1, 0, -1, -1)
two_vs_three_combined_dict[test_state]

In [None]:
#(1, 0, -1, 1, 0, -1, 1, 1, 0, 0, -1, -1, 0, 0, 0, -1, -1)
test_2v2_tuple = (1, -1, 1,-1,  1, 1, -1, -1,   0, 0, 0, -1)
test_2v2_entry = copy.deepcopy(two_vs_two_combined_dict.get(test_2v2_tuple))
print(test_2v2_entry)

In [None]:
# # why is win % so low here:
# #(1, 0, -1, 1, 0, -1, 1, 1, 0, 0, -1, -1, 0, 0, 0, -1, -1)
# test_2v2_tuple = tuple([1,-1, 1, -1,     1, 1,-1, -1,    0, 0, 0, -1])
# test_2v2_entry = copy.deepcopy(two_vs_two_combined_dict.get(test_2v2_tuple))
# print(test_2v2_entry)

In [None]:
183 * .5

In [None]:
print(test_2v2_entry)

### 3v2

In [16]:
time_int = int(time.time())
three_vs_two_save_tag = f"three_vs_two_eval_{time_int}"

In [17]:


three_vs_two_combined_dict = load_pkl_object('G:\\3v3_vgc_saves_71124\\3v3_results\\3v2_3v2final_1001_1720922746_action_state_results_dict_checkpoint_2.pickle')

In [18]:
two_vs_two_combined_dict = load_pkl_object('joined_dict_results\\two_vs_two_combined_results.pickle')
two_vs_two_recommended_dict = load_pkl_object('joined_dict_results\\swap_2v2_dict_1720783493.pickle')

In [19]:
# example entries
# (1, 1, 2, 2, 1, 2, 1, 1, 1, -1, -1, -1, 0, 1, 0, 0, 1) {'attack': {'sum_wins': 31, 'count': 31, (1, 1, 2, 2, 1, 2, 1, 1, 1, -1, -1, -1, 0, 1, 0, 0, 1): 37}, 'swap_1': {(1, 1, 2, 2, 1, 2, 1, 1, 1, -1, -1, -1, 0, 1, 0, 0, 1): 50, 'count': 1, 'sum_wins': 1, (1, 2, 2, 2, 1, 1, 1, 1, 1, -1, -1, -1, 0, 1, 0, 0, 1): 1, (2, 2, 1, 1, 1, 2, 1, 1, 1, -1, -1, -1, 0, 0, 0, 0, 1): 4}, 'swap_0': {(1, 1, 2, 2, 1, 2, 1, 1, 1, -1, -1, -1, 0, 1, 0, 0, 1): 48, 'count': 2, 'sum_wins': 2, (1, 2, 2, 2, 1, 1, 1, 1, 1, -1, -1, -1, 0, 1, 0, 0, 1): 1, (2, 2, 1, 1, 1, 2, 1, 1, 1, -1, -1, -1, 1, 0, 0, 0, 1): 1}}
# (1, 0, 2, 0, 2, 0, 2, 2, 1, 0, 0, 0, 0, 1, 0, 1, 0) {'attack': {'sum_wins': 1105, 'count': 1105}}
# (1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1, 0, 1, 0, 0, 0) {'swap_0': {(1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1, 0, 1, 0, 0, 0): 5981, 'count': 176, 'sum_wins': 145, (2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0): 438, (2, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 0, 0, 0, 0): 24, (2, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 0, 0, 0, 0, 0): 50, (2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0): 3}, 'swap_1': {(2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0): 456, 'count': 180, 'sum_wins': 151, (1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1, 0, 1, 0, 0, 0): 6297, (2, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 0, 0, 0, 0): 33, (2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0): 10, (2, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 0, 0, 0, 0, 0): 68}, 'attack': {'sum_wins': 4366, 'count': 5112, (1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1, 0, 1, 0, 0, 0): 4994}}
# (1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 0, 0, 0, 0, 0) {'attack': {'sum_wins': 8772, 'count': 10011, (1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 0, 0, 0, 0, 0): 10059}, 'swap_0': {(1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 0, 0, 0, 0, 0): 10580, 'count': 688, 'sum_wins': 576, (1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0): 341, (1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 0, 0, 0, 0, 0): 55}, 'swap_1': {(1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 0, 0, 0, 0, 0): 10761, 'count': 668, 'sum_wins': 539, (1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0): 334, (1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 0, 0, 0, 0, 0): 46}}
# (1, 0, 2, 0, 1, 0, 1, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0) {'attack': {'sum_wins': 42151, 'count': 48134}}
# (1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0) {'swap_1': {(1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0): 37109, 'count': 1198, 'sum_wins': 1060, (2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0): 3251}, 'swap_0': {(1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0): 35867, 'count': 1199, 'sum_wins': 1035, (2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0): 3044}, 'attack': {'sum_wins': 32925, 'count': 37694, (1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0): 37277}}

In [20]:
def make_new_state_tuple(state_key, filter_list):
    new_state_list = []
    for i in range(len(state_key)):
        if filter_list[i]:
            new_state_list.append(state_key[i])

    new_state_tuple = tuple(new_state_list)
    if len(new_state_tuple) != 12:
        print("Error: new state tuple not correct length of 12", new_state_tuple)

    return new_state_tuple

def convert_3v2_state_to_2v2_state(state_key):
    # example state
    #(1, 1, 2, 2, 1, 2, 1, 1, 1, -1, -1, -1, 0, 1, 0, 0, 1)
    #(1, 1,     2, 2,    1, 2,      1, 1, 1,    -1, -1, -1,     0, 1, 0,    0, 1)

    # index 0,1 are agent active attacking opp
    # index 2,3 are party 0 active attacking opp
    # index 4,5 are party 1 active attacking opp
    # index 6,7,8 are opp active attacking agent
    # index 9,10,11 are opp party active attacking agent
    # index 12, 13, 14 are agent hp normalized
    # index 15, 16 are opp hp normalized

    is_two_vs_two_key = False
    filtered_state_key = tuple([19])

    if len(state_key) == 17:
        if state_key[0] == 0 and state_key[2] == 0 and state_key[4] == 0:
            print("Error: active opp is fainted ", state_key)
        else:
            if state_key[2] == 0 and state_key[3] == 0: #\
                #and state_key[4] != 0 and state_key[5] != 0:
                # agent party 0 is down
                filter_list = [True, True,
                               False, False,
                               True, True,
                               True, False, True,
                               True, False, True,
                               True, False, True,
                               True, True]  
                
                filtered_state_key = make_new_state_tuple(state_key, filter_list)
                is_two_vs_two_key = True

            elif state_key[4] == 0 and state_key[5] == 0:# \
                #and state_key[2] != 0 and state_key[3] != 0:
                # agent party 1 is down
                filter_list = [True, True,
                               True, True,
                               False, False,
                               True, True, False,
                               True, True, False,
                               True, True, False,
                               True, True]    
                
                filtered_state_key = make_new_state_tuple(state_key, filter_list)
                is_two_vs_two_key = True
                
            else:
                print("Error: unable to located the fainted agent ", state_key)
    else:
        print("Error: state key not 17 long ", state_key)

    if not is_two_vs_two_key:
        print("Warning: not a 2v2 key ", state_key, filtered_state_key)

    return filtered_state_key, is_two_vs_two_key


def get_state_counts_3v2(three_vs_two_combined_dict, save_tag,
        two_vs_two_recommended_dict, two_vs_two_combined_dict):
    count_key = 'count'
    sum_wins_key = 'sum_wins'
    agent_first_move_attack_key = 'attack'
    agent_first_move_swap_party_0_key = 'swap_0'
    agent_first_move_swap_party_1_key = 'swap_1'
    time_int = int(time.time())

    count_list = []
    swap_dict = {}
    swap_better_count = 0

    three_vs_three_lookup_dict = {}

    try:
        for state_key, state_dict in three_vs_two_combined_dict.items():
            if ( agent_first_move_swap_party_0_key in state_dict or agent_first_move_swap_party_1_key in state_dict ) \
                and agent_first_move_attack_key in state_dict:
                # either swap move is in the dict and attack key is in the dict

                if agent_first_move_swap_party_0_key in state_dict:
                    is_swap_0_in_dict = True
                else:
                    is_swap_0_in_dict = False

                if agent_first_move_swap_party_1_key in state_dict:
                    is_swap_1_in_dict = True
                else:
                    is_swap_1_in_dict = False

                # get wins, counts for swap/attack that led to end state
                # and do not go to 2v2
                attack_sub_count = state_dict[agent_first_move_attack_key].get(count_key, 0)
                attack_sub_wins = state_dict[agent_first_move_attack_key].get(sum_wins_key, 0)
                attack_to_2v2_count = 0
                attack_to_2v2_estimated_wins = 0.

                swap_0_sub_count = state_dict.get(agent_first_move_swap_party_0_key, {}).get(count_key, 0)
                swap_0_sub_wins = state_dict.get(agent_first_move_swap_party_0_key, {}).get(sum_wins_key, 0)
                swap_0_to_2v2_count = 0
                swap_0_to_2v2_estimated_wins = 0.

                swap_1_sub_count = state_dict.get(agent_first_move_swap_party_1_key, {}).get(count_key, 0)
                swap_1_sub_wins = state_dict.get(agent_first_move_swap_party_1_key, {}).get(sum_wins_key, 0)
                swap_1_to_2v2_count = 0
                swap_1_to_2v2_estimated_wins = 0.

                # get counts, estimated wins for states that went to 2v2
                for inner_move_key, inner_move_value in state_dict[agent_first_move_attack_key].items():
                    if inner_move_key != count_key and inner_move_key != sum_wins_key:

                        # turn the 3v2 keys into 2v2 keys
                        filtered_state_key, is_use_two_vs_two_key = convert_3v2_state_to_2v2_state(inner_move_key)

                        # get results for that state using the best action by win %
                        if is_use_two_vs_two_key:
                            if filtered_state_key in two_vs_two_recommended_dict:
                                recommended_action = two_vs_two_recommended_dict[filtered_state_key].get('recommended_action', 0)

                                if recommended_action == 1:
                                    lookup_key = agent_first_move_swap_party_0_key
                                else:
                                    lookup_key = agent_first_move_attack_key
                            else:
                                lookup_key = agent_first_move_attack_key

                            lookup_dict = two_vs_two_combined_dict.get(filtered_state_key, {}).get(lookup_key, {})
                            lookup_2v2_count = lookup_dict.get(count_key, 0)
                            lookup_2v2_sum_wins = lookup_dict.get(sum_wins_key, 0) 
                            if lookup_2v2_count > 0 and sum_wins_key in lookup_dict:
                                lookup_win_percent = lookup_2v2_sum_wins / lookup_2v2_count
                                attack_to_2v2_estimated_wins += lookup_win_percent * inner_move_value
                                attack_to_2v2_count += inner_move_value

                if is_swap_0_in_dict:
                    for inner_move_key, inner_move_value in state_dict[agent_first_move_swap_party_0_key].items():
                        if inner_move_key != count_key and inner_move_key != sum_wins_key:

                            # turn the 3v2 keys into 2v2 keys
                            filtered_state_key, is_use_two_vs_two_key = convert_3v2_state_to_2v2_state(inner_move_key)

                            # get results for that state using the best action by win %
                            if is_use_two_vs_two_key:
                                if filtered_state_key in two_vs_two_recommended_dict:
                                    recommended_action = two_vs_two_recommended_dict[filtered_state_key].get('recommended_action', 0)

                                    if recommended_action == 1:
                                        lookup_key = agent_first_move_swap_party_0_key
                                    else:
                                        lookup_key = agent_first_move_attack_key
                                else:
                                    lookup_key = agent_first_move_attack_key

                                lookup_dict = two_vs_two_combined_dict.get(filtered_state_key, {}).get(lookup_key, {})
                                lookup_2v2_count = lookup_dict.get(count_key, 0)
                                lookup_2v2_sum_wins = lookup_dict.get(sum_wins_key, 0) 
                                if lookup_2v2_count > 0 and sum_wins_key in lookup_dict:
                                    lookup_win_percent = lookup_2v2_sum_wins / lookup_2v2_count
                                    swap_0_to_2v2_estimated_wins += lookup_win_percent * inner_move_value
                                    swap_0_to_2v2_count += inner_move_value

                if is_swap_1_in_dict:
                    for inner_move_key, inner_move_value in state_dict[agent_first_move_swap_party_1_key].items():
                        if inner_move_key != count_key and inner_move_key != sum_wins_key:

                            # turn the 3v2 keys into 2v2 keys
                            filtered_state_key, is_use_two_vs_two_key = convert_3v2_state_to_2v2_state(inner_move_key)

                            # get results for that state using the best action by win %
                            if is_use_two_vs_two_key:
                                if filtered_state_key in two_vs_two_recommended_dict:
                                    recommended_action = two_vs_two_recommended_dict[filtered_state_key].get('recommended_action', 0)

                                    if recommended_action == 1:
                                        lookup_key = agent_first_move_swap_party_0_key
                                    else:
                                        lookup_key = agent_first_move_attack_key
                                else:
                                    lookup_key = agent_first_move_attack_key

                                lookup_dict = two_vs_two_combined_dict.get(filtered_state_key, {}).get(lookup_key, {})
                                lookup_2v2_count = lookup_dict.get(count_key, 0)
                                lookup_2v2_sum_wins = lookup_dict.get(sum_wins_key, 0) 
                                if lookup_2v2_count > 0 and sum_wins_key in lookup_dict:
                                    lookup_win_percent = lookup_2v2_sum_wins / lookup_2v2_count
                                    swap_1_to_2v2_estimated_wins += lookup_win_percent * inner_move_value
                                    swap_1_to_2v2_count += inner_move_value

                # get the total wins, total counts
                # make new chi square test look up function for this

                total_attack_counts = attack_sub_count + attack_to_2v2_count
                total_attack_wins = attack_sub_wins + int(np.round(attack_to_2v2_estimated_wins,0))

                total_swap_0_counts = swap_0_sub_count + swap_0_to_2v2_count
                total_swap_0_wins = swap_0_sub_wins + int(np.round(swap_0_to_2v2_estimated_wins,0))

                total_swap_1_counts = swap_1_sub_count + swap_1_to_2v2_count
                total_swap_1_wins = swap_1_sub_wins + int(np.round(swap_1_to_2v2_estimated_wins,0))

                # find the best swap action and compare to attack
                if swap_0_sub_count < 25 and swap_1_sub_count >= 25:
                    total_swap_counts = total_swap_1_counts
                    total_swap_wins = total_swap_1_wins
                    potential_recommended_swap_action = 2
                elif swap_1_sub_count < 25 and swap_0_sub_count >= 25:
                    total_swap_counts = total_swap_0_counts
                    total_swap_wins = total_swap_0_wins
                    potential_recommended_swap_action = 1
                elif swap_0_sub_count >= 25 and swap_1_sub_count >= 25:
                    swap_0_win_rate = total_swap_0_wins / (total_swap_0_counts)
                    swap_1_win_rate = total_swap_1_wins / (total_swap_1_counts)

                    if swap_0_win_rate >= swap_1_win_rate:
                        total_swap_counts = total_swap_0_counts
                        total_swap_wins = total_swap_0_wins
                        potential_recommended_swap_action = 1
                    else:
                        total_swap_counts = total_swap_1_counts
                        total_swap_wins = total_swap_1_wins
                        potential_recommended_swap_action = 2
                else:
                    total_swap_counts = 0
                    total_swap_wins = 0
                    potential_recommended_swap_action = 0

                # for 3v2 need to take the best swap
                recommended_action, swap_win_rate_better_rate, is_use_p_value, is_swap_better, p_value = get_chi_square_test_from_count_wins(
                    total_swap_wins, total_swap_counts, total_attack_wins, total_attack_counts,
                    min_total_count=50,
                    min_swap_count=25,
                    min_attack_count=25,
                    is_print_statistics=False)
                
                count_list.append(total_swap_counts)
                count_list.append(total_attack_counts)
            
                if is_swap_better:

                    swap_better_count += 1

                    if recommended_action != 0:
                        actual_recommended_action = potential_recommended_swap_action
                    else:
                        actual_recommended_action = 0

                    if actual_recommended_action != 0:
                        swap_dict[state_key] = {
                            'recommended_action': actual_recommended_action,
                            'swap_win_rate_better_rate': swap_win_rate_better_rate,
                            'is_use_p_value': is_use_p_value,
                            'is_swap_better': is_swap_better,
                            'p_value': p_value,
                            'best_swap_visit_count': total_swap_counts,
                            'best_swap_estimated_win_count': total_swap_wins,
                            'attack_visit_count': total_attack_counts ,
                            'attack_estimated_win_count': total_attack_wins,
                        }

                        three_vs_three_lookup_dict[state_key] = {
                            count_key: total_swap_counts,
                            sum_wins_key: total_swap_wins,
                        }
                    else:
                        three_vs_three_lookup_dict[state_key] = {
                            count_key: total_attack_counts,
                            sum_wins_key: total_attack_wins,
                        }
                else:
                    three_vs_three_lookup_dict[state_key] = {
                        count_key: total_attack_counts,
                        sum_wins_key: total_attack_wins,
                    }
    except Exception as e:
        print("Error in iterating through dict: ", str(e), state_key)

    count_array = np.array(count_list)
    print(f"swap better count: {swap_better_count} | len swap dict {len(swap_dict)}")

    save_object_as_pkl(swap_dict, f'swap_{save_tag}_dict_{time_int}')
    save_object_as_pkl(three_vs_three_lookup_dict, f'three_vs_three_lookup_dict_{save_tag}_{time_int}')
    save_object_as_pkl(count_list, f'count_list_{save_tag}_{time_int}')

    return count_array, swap_better_count, swap_dict, three_vs_three_lookup_dict



In [21]:
three_v_two_count_array, three_v_two_swap_better_count, swap_three_v_two_dict, three_vs_three_from3v2_lookup_dict = get_state_counts_3v2(
   three_vs_two_combined_dict, three_vs_two_save_tag, two_vs_two_recommended_dict, two_vs_two_combined_dict
)

swap better count: 124 | len swap dict 26
saving:  joined_dict_results\swap_three_vs_two_eval_1720955581_dict_1720955593.pickle
saving:  joined_dict_results\three_vs_three_lookup_dict_three_vs_two_eval_1720955581_1720955593.pickle
saving:  joined_dict_results\count_list_three_vs_two_eval_1720955581_1720955593.pickle


#### Debut lookup dict for 3v3

In [22]:
len(three_vs_three_from3v2_lookup_dict)
for k, v in three_vs_three_from3v2_lookup_dict.items():
    print(k, v)
    break

(2, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 1, -1) {'count': 3129, 'sum_wins': 1834}


### 3v3

In [23]:
time_int = int(time.time())
three_vs_three_save_tag = f"three_vs_three_eval_{time_int}"

In [24]:
three_vs_three_combined_dict = load_pkl_object('G:\\3v3_vgc_saves_71124\\3v3_results\\3v3_3v3final_1003_1720922757_action_state_results_dict_checkpoint_2.pickle')

In [25]:

two_vs_three_lookup_dict = load_pkl_object('C:\\Users\\james\\github_repos\\pokemon-vgc-engine\\s3stuff\\joined_dict_results\\three_vs_three_lookup_dict_two_vs_three_eval_1720955501_1720955523.pickle')
three_vs_two_lookup_dict = load_pkl_object('C:\\Users\\james\\github_repos\\pokemon-vgc-engine\\s3stuff\\joined_dict_results\\three_vs_three_lookup_dict_three_vs_two_eval_1720955581_1720955593.pickle')

In [26]:
# example states

In [31]:
# eval and save function
# don't need look up

def make_new_state_tuple_flexible(state_key, filter_list, final_tuple_length):
    new_state_list = []
    for i in range(len(state_key)):
        if filter_list[i]:
            new_state_list.append(state_key[i])

    new_state_tuple = tuple(new_state_list)
    if len(new_state_tuple) != final_tuple_length:
        print(f"Error: new state tuple not correct length of {final_tuple_length} | {len(new_state_tuple)} | {new_state_tuple}")

    return new_state_tuple

def convert_3v3_state_to_2v3_or_3v2_state(state_key):
    # example state

    #(1,1,1,    2,2,2,  3,3,3,  4,4,4,  5,5,5,  6,6,6,  1,2,3,  4,5,6)

    # index 0,1,2 are agent active attacking opp
    # index 3,4,5 are party 0 active attacking opp
    # index 6,7,8 are party 1 active attacking opp
    # index 9,10,11 are opp active attacking agent
    # index 12,13,14 are opp party 0 attacking agent
    # index 15,16,17 are opp party 1 attacking agent
    # index 18,19,20 are agent hp normalized
    # index 21,22,23 are opp hp normalized

    is_3v2_or_2v3_key = False
    filtered_state_key = tuple([19])

    if len(state_key) == 24:

        if state_key[3] == 0 and state_key[4] == 0 or state_key[5] == 0:
            # agent party 0 is fainted
            filter_list = [
                True, True, True,
                False, False, False,
                True, True, True,
                True, False, True,
                True, False, True,
                True, False, True,

                True, False, True,
                True, True, True,
            ]
            filtered_state_key = make_new_state_tuple_flexible(state_key, filter_list, 17)
            is_3v2_or_2v3_key = True
        elif state_key[6] == 0 and state_key[7] == 0 or state_key[8] == 0:
            # agent party 1 is fainted
            filter_list = [
                True, True, True,
                True, True, True,
                False, False, False,
                True, True, False,
                True, True, False,
                True, True, False,

                True, True, False,
                True, True, True,
            ]
            filtered_state_key = make_new_state_tuple_flexible(state_key, filter_list, 17)
            is_3v2_or_2v3_key = True
        elif state_key[12] == 0 and state_key[13] == 0 or state_key[14] == 0:
            # opp party 0 is fainted
            filter_list = [
                True, False, True,
                True, False, True,
                True, False, True,
                True, True, True,
                False, False, False,
                True, True, True,

                True, True, True,
                True, False, True,
            ]
            filtered_state_key = make_new_state_tuple_flexible(state_key, filter_list, 17)
            is_3v2_or_2v3_key = True
        elif state_key[15] == 0 and state_key[16] == 0 or state_key[17] == 0:
            # opp party 1 is fainted
            filter_list = [
                True, True, False,
                True, True, False,
                True, True, False,
                True, True, True,
                True, True, True,
                False, False, False,

                True, True, True,
                True, True, False,
            ]
            filtered_state_key = make_new_state_tuple_flexible(state_key, filter_list, 17)
            is_3v2_or_2v3_key = True
        else:
            print("Error: unable to located the fainted agent ", state_key)
        
    else:
        print("Error: state key not 24 long ", state_key)

    if not is_3v2_or_2v3_key:
        print("Warning: not a 2v2 key ", state_key, filtered_state_key)

    return filtered_state_key, is_3v2_or_2v3_key


def get_state_counts_3v3(three_vs_three_combined_dict, save_tag,
        two_vs_three_lookup_dict, three_vs_two_lookup_dict):
    count_key = 'count'
    sum_wins_key = 'sum_wins'
    agent_first_move_attack_key = 'attack'
    agent_first_move_swap_party_0_key = 'swap_0'
    agent_first_move_swap_party_1_key = 'swap_1'
    time_int = int(time.time())

    count_list = []
    swap_dict = {}
    swap_better_count = 0

    try:
        for state_key, state_dict in three_vs_three_combined_dict.items():
            if ( agent_first_move_swap_party_0_key in state_dict or agent_first_move_swap_party_1_key in state_dict ) \
                and agent_first_move_attack_key in state_dict:
                # either swap move is in the dict and attack key is in the dict

                if agent_first_move_swap_party_0_key in state_dict:
                    is_swap_0_in_dict = True
                else:
                    is_swap_0_in_dict = False

                if agent_first_move_swap_party_1_key in state_dict:
                    is_swap_1_in_dict = True
                else:
                    is_swap_1_in_dict = False

                # get wins, counts for swap/attack that led to 2v3 or 3v2 state
                attack_count = 0
                attack_wins = 0
                swap_0_count = 0
                swap_0_wins = 0
                swap_1_count = 0
                swap_1_wins = 0
                
                # to do DRY this up
                # get counts, estimated wins for states that went to 2v2
                for inner_move_key, inner_move_value in state_dict[agent_first_move_attack_key].items():
                    if inner_move_key != count_key and inner_move_key != sum_wins_key:

                        # turn the 3v3 keys into 3v2 or 3v2 keys
                        filtered_state_key, is_use_key = convert_3v3_state_to_2v3_or_3v2_state(inner_move_key)

                        if is_use_key:
                            if filtered_state_key in two_vs_three_lookup_dict:
                                look_up_dict = two_vs_three_lookup_dict[filtered_state_key]
                            elif filtered_state_key in three_vs_two_lookup_dict:
                                look_up_dict = three_vs_two_lookup_dict[filtered_state_key]
                            else:
                                look_up_dict = {}
                            
                            lookup_count = look_up_dict.get(count_key, 0)
                            lookup_sum_wins = look_up_dict.get(sum_wins_key, 0)

                            if lookup_count > 0 and sum_wins_key in look_up_dict:
                                lookup_win_percent = lookup_sum_wins / lookup_count
                                attack_count += inner_move_value
                                attack_wins += lookup_win_percent * inner_move_value


                if is_swap_0_in_dict:
                    for inner_move_key, inner_move_value in state_dict[agent_first_move_swap_party_0_key].items():
                        if inner_move_key != count_key and inner_move_key != sum_wins_key:

                            # turn the 3v3 keys into 3v2 or 3v2 keys
                            filtered_state_key, is_use_key = convert_3v3_state_to_2v3_or_3v2_state(inner_move_key)

                            if is_use_key:
                                if filtered_state_key in two_vs_three_lookup_dict:
                                    look_up_dict = two_vs_three_lookup_dict[filtered_state_key]
                                elif filtered_state_key in three_vs_two_lookup_dict:
                                    look_up_dict = three_vs_two_lookup_dict[filtered_state_key]
                                else:
                                    look_up_dict = {}

                                if lookup_count > 0 and sum_wins_key in look_up_dict:
                                    lookup_win_percent = lookup_sum_wins / lookup_count
                                    swap_0_count += inner_move_value
                                    swap_0_wins += lookup_win_percent * inner_move_value


                if is_swap_1_in_dict:
                    for inner_move_key, inner_move_value in state_dict[agent_first_move_swap_party_1_key].items():
                        if inner_move_key != count_key and inner_move_key != sum_wins_key:

                            # turn the 3v3 keys into 3v2 or 3v2 keys
                            filtered_state_key, is_use_key = convert_3v3_state_to_2v3_or_3v2_state(inner_move_key)

                            if is_use_key:
                                if filtered_state_key in two_vs_three_lookup_dict:
                                    look_up_dict = two_vs_three_lookup_dict[filtered_state_key]
                                elif filtered_state_key in three_vs_two_lookup_dict:
                                    look_up_dict = three_vs_two_lookup_dict[filtered_state_key]
                                else:
                                    look_up_dict = {}

                                if lookup_count > 0 and sum_wins_key in look_up_dict:
                                    lookup_win_percent = lookup_sum_wins / lookup_count
                                    swap_1_count += inner_move_value
                                    swap_1_wins += lookup_win_percent * inner_move_value

                if swap_0_count < 50 and swap_1_count >= 50:
                    best_swap_count = swap_1_count
                    best_swap_wins = swap_1_wins
                    potential_recommended_swap_action = 2
                elif swap_1_count < 50 and swap_0_count >= 50:
                    best_swap_count = swap_0_count
                    best_swap_wins = swap_0_wins
                    potential_recommended_swap_action = 1
                elif swap_1_count >= 50 and swap_0_count >= 50:
                    swap_0_win_rate = swap_0_wins / (swap_0_count)
                    swap_1_win_rate = swap_1_wins / (swap_1_count)

                    if swap_0_win_rate >= swap_1_win_rate:
                        best_swap_count = swap_0_count
                        best_swap_wins = swap_0_wins
                        potential_recommended_swap_action = 1
                    else:
                        best_swap_count = swap_1_count
                        best_swap_wins = swap_1_wins
                        potential_recommended_swap_action = 2
                else:
                    best_swap_count = 0
                    best_swap_wins = 0
                    potential_recommended_swap_action = 0

                recommended_action, swap_win_rate_better_rate, is_use_p_value, is_swap_better, p_value = get_chi_square_test_from_count_wins(
                    best_swap_wins, best_swap_count, attack_wins, attack_count,
                    min_total_count=100,
                    min_swap_count=50,
                    min_attack_count=50,
                    is_print_statistics=False)
                
                count_list.append(best_swap_count)
                count_list.append(attack_count)
            
                if is_swap_better:

                    swap_better_count += 1

                    if recommended_action != 0:
                        # above looked to see if the best swap is better than attack
                        # here assign the actual swap action to the lookup dict
                        actual_recommended_action = potential_recommended_swap_action
                    else:
                        actual_recommended_action = 0

                    if actual_recommended_action != 0:
                        swap_dict[state_key] = {
                            'recommended_action': actual_recommended_action,
                            'swap_win_rate_better_rate': swap_win_rate_better_rate,
                            'is_use_p_value': is_use_p_value,
                            'is_swap_better': is_swap_better,
                            'p_value': p_value,
                            'best_swap_count': best_swap_count,
                            'attack_count': attack_count,
                            'best_swap_wins': best_swap_wins,
                            'attack_wins': attack_wins,
                        }


    except Exception as e:
        print("Error in iterating through dict: ", str(e), state_key)

    count_array = np.array(count_list)
    print(f"swap better count: {swap_better_count} | len swap dict {len(swap_dict)}")

    save_object_as_pkl(swap_dict, f'swap_{save_tag}_dict_{time_int}')
    save_object_as_pkl(count_list, f'count_list_{save_tag}_{time_int}')

    return count_array, swap_better_count, swap_dict



In [32]:
# call eval and save function
three_vs_three_count_array, three_vs_three_swap_better_count, three_vs_three_swap_dict = get_state_counts_3v3(
    three_vs_three_combined_dict,three_vs_three_save_tag, two_vs_three_lookup_dict, three_vs_two_lookup_dict)

# # debugging using the prior made dicts
# three_vs_three_count_array, three_vs_three_swap_better_count, three_vs_three_swap_dict = get_state_counts_3v3(
#     three_vs_three_combined_dict,three_vs_three_save_tag, three_vs_three_from2v3_lookup_dict, three_vs_three_from3v2_lookup_dict)

swap better count: 292 | len swap dict 160
saving:  joined_dict_results\swap_three_vs_three_eval_1720955610_dict_1720956546.pickle
saving:  joined_dict_results\count_list_three_vs_three_eval_1720955610_1720956546.pickle


#### debug

In [29]:
len(three_vs_three_swap_dict)

269