In [1]:
import numpy as np
import pickle
import os
import copy
from scipy.stats import chi2_contingency
import time
import pprint

In [2]:
'''
Evaluate dicts
    DONE overall size
    TEST states that do not have enough samples
    swaps that are better
        overall number
        return sample size, p value etc and see
    would be nice to have how often the swaps occur as overall percent but the dict will stop collecting keys
        so not an accurate caught

Choose the p values
    which swap values are better with win %
    then how many by the different narrowing things

Maybe examine some specific states
'''

'\nEvaluate dicts\n    DONE overall size\n    TEST states that do not have enough samples\n    swaps that are better\n        overall number\n        return sample size, p value etc and see\n    would be nice to have how often the swaps occur as overall percent but the dict will stop collecting keys\n        so not an accurate caught\n\nChoose the p values\n    which swap values are better with win %\n    then how many by the different narrowing things\n\nMaybe examine some specific states\n'

### Constants and Functions

In [3]:
def load_pkl_object(pkl_path):
    '''
    Read the pickle file at pkl_path and return the object stored in it.

    Only use this on files you trust: unpickling can run arbitrary code.
    '''
    with open(pkl_path, mode='rb') as pkl_file:
        loaded_object = pickle.load(pkl_file)

    return loaded_object


def save_object_as_pkl(object_to_save, save_tag):
    '''
    Pickle object_to_save into the joined_dict_results folder.

    The file is named "<save_tag>.pickle" and is written with the highest
    pickle protocol available. The target path is printed once the file
    has been opened for writing.
    '''
    pickle_path = f'joined_dict_results\\{save_tag}.pickle'

    with open(pickle_path, mode='wb') as out_file:
        print("saving: ", pickle_path)
        pickle.dump(object_to_save, out_file, pickle.HIGHEST_PROTOCOL)


def get_chi_square_test_from_action_dict(
    action_dict,
    state_key,
    min_total_count=50,
    min_swap_count=25,
    min_attack_count=25,
    swap_key='swap_0', attack_key='attack',
    sum_wins_key='sum_wins', count_key='count',
    is_print_statistics=False):
    '''
    Recommend attack vs swap for one state stored in action_dict.

    Pulls the swap/attack win and visit counts for state_key out of
    action_dict and hands them to get_chi_square_test_from_count_wins, which
    owns the actual decision rule (so the rule lives in exactly one place).

    Parameters
        action_dict: mapping state_key -> {action name -> {sum_wins_key, count_key}}
        state_key: key of the state to evaluate
        min_total_count, min_swap_count, min_attack_count: sample-size
            thresholds, forwarded unchanged (strictly-greater-than comparisons)
        swap_key, attack_key: action names to compare inside the state entry
        sum_wins_key, count_key: field names inside each action entry
        is_print_statistics: forward to print win rates / test statistics

    Returns
        (recommended_action, swap_win_rate_better_rate, is_use_p_value,
         is_swap_better, p_value). When the state or one of the two actions is
        missing, or the entry cannot be read, the attack default
        (0, 0., False, False, None) is returned.
    '''
    attack_action = 0
    no_recommendation = (attack_action, 0., False, False, None)

    try:
        if state_key not in action_dict:
            return no_recommendation

        state_dict = action_dict[state_key]
        if swap_key not in state_dict or attack_key not in state_dict:
            # nothing to compare unless both actions were observed
            return no_recommendation

        swap_wins = state_dict[swap_key][sum_wins_key]
        swap_count = state_dict[swap_key][count_key]
        attack_wins = state_dict[attack_key][sum_wins_key]
        attack_count = state_dict[attack_key][count_key]
    except Exception as e:
        # best effort: a malformed entry falls back to attack instead of
        # stopping a scan over millions of states
        print("Error: in chi square test ", str(e) )
        return no_recommendation

    return get_chi_square_test_from_count_wins(
        swap_wins, swap_count, attack_wins, attack_count,
        min_total_count=min_total_count,
        min_swap_count=min_swap_count,
        min_attack_count=min_attack_count,
        is_print_statistics=is_print_statistics)


def get_chi_square_test_from_count_wins(
    swap_wins, swap_count, attack_wins, attack_count,
    min_total_count=50,
    min_swap_count=25,
    min_attack_count=25,
    is_print_statistics=False):
    '''
    Recommend attack (0) or swap to party slot 0 (1) from win/visit counts.

    Decision rule
        * Not enough samples (any count not strictly above its minimum):
          attack, no p value.
        * Attack always wins, or swap never wins: attack, no p value.
        * Swap always wins, or attack never wins: swap, p value reported as 0.
          These cases are handled by hand because a contingency table with a
          zero row/column total cannot be tested.
        * Otherwise run chi2_contingency on the 2x2 win/loss table (SciPy
          defaults) and recommend swap only if swap has the higher win rate and
          p < 0.25, or the win-rate gap is >= 0.10, or the gap is >= 0.05 with
          p < 0.6.

    Parameters
        swap_wins, swap_count: wins and visits when the first move was a swap
        attack_wins, attack_count: wins and visits when the first move was attack
        min_total_count, min_swap_count, min_attack_count: sample-size thresholds
        is_print_statistics: print the win rates (and test statistics if used)

    Returns
        (recommended_action, swap_win_rate_better_rate, is_use_p_value,
         is_swap_better, p_value); swap_win_rate_better_rate is the win-rate
        gap in favour of swap, or 0. when swap is not better. Any exception is
        printed and converted into the attack default.
    '''
    attack_action = 0
    swap_party_zero_action = 1

    is_use_p_value = False
    is_swap_better = False
    p_value = None
    swap_win_rate_better_rate = 0.
    recommended_action = attack_action

    try:
        total_count = swap_count + attack_count
        has_enough_samples = (
            total_count > min_total_count
            and swap_count > min_swap_count
            and attack_count > min_attack_count)

        if has_enough_samples:
            swap_win_percent = swap_wins / swap_count
            attack_win_percent = attack_wins / attack_count

            if swap_win_percent > attack_win_percent:
                is_swap_better = True
                swap_win_rate_better_rate = swap_win_percent - attack_win_percent

            headline = None   # label printed for the hand-resolved cases
            chi2 = None       # only set when the test is actually run

            # chi squared table breaks down if any 0 values
            # really should not have less than 5
            if attack_wins == attack_count:
                headline = "Attack always wins"
            elif swap_wins == swap_count:
                headline = "swap always wins, choosing swap"
                is_use_p_value = True
                is_swap_better = True
                p_value = 0.
                recommended_action = swap_party_zero_action
            elif swap_wins == 0:
                headline = "Swap always loses"
            elif attack_wins == 0:
                # attack always loses and swap won at least once so choose swap
                headline = "Attack always loses, choosing swap "
                is_use_p_value = True
                is_swap_better = True
                p_value = 0.
                recommended_action = swap_party_zero_action
            else:
                contingency_table = [
                    [swap_wins, swap_count - swap_wins],
                    [attack_wins, attack_count - attack_wins]]
                chi2, p_value, _dof, _expected = chi2_contingency(contingency_table)
                is_use_p_value = True

                if is_swap_better and (
                        p_value < 0.25
                        or swap_win_rate_better_rate >= .1
                        or (swap_win_rate_better_rate >= .05 and p_value < .6)):
                    recommended_action = swap_party_zero_action

            if is_print_statistics:
                if headline is not None:
                    print(headline)
                print(f"Swap win rate: {swap_win_percent:.3f} | Count {swap_count}")
                print(f"Attack win rate: {attack_win_percent:.3f} | Count {attack_count}")
                if chi2 is not None:
                    print(f'Chi-square statistic: {chi2:.3f}')
                    print(f'P-value: {p_value:.5f}')

    except Exception as e:
        # best effort: report and fall back to the attack default
        print("Error: in chi square test ", str(e) )
        is_use_p_value = False
        is_swap_better = False
        p_value = None
        swap_win_rate_better_rate = 0.
        recommended_action = attack_action

    return recommended_action, swap_win_rate_better_rate, is_use_p_value, is_swap_better, p_value



### 2v2

In [4]:
# Load the combined 2v2 results: state tuple -> {action name -> {'sum_wins', 'count'}}
# (a sample entry is printed a few cells below).
two_combined_dict = load_pkl_object('joined_dict_results\\two_vs_two_combined_results.pickle')


In [5]:
# Number of distinct states recorded in the combined 2v2 dict.
print(len(two_combined_dict))

3118229


In [6]:
# Show a single (state, per-action stats) pair to confirm the dict layout.
for k in two_combined_dict:
    v = two_combined_dict[k]
    print(k, v)
    break

(2, 2, 2, 1, 3, 3, 1, 1, 0, 0, 1, 0) {'attack': {'sum_wins': 455, 'count': 900}, 'swap_0': {'sum_wins': 306, 'count': 846}}


In [7]:
# getting the counts
# need to filter out the times when TTF is 0 (ie not a true 2v2)



def get_state_counts(check_dict):
    '''
    Collect visit counts and swap-better states from a 2v2 results dict.

    Only states that contain both an 'attack' and a 'swap_0' entry are
    considered. For each of them the swap and attack visit counts are
    recorded, and get_chi_square_test_from_action_dict decides whether swap
    has the higher win rate; those states are stored with their statistics.

    Side effects: prints a summary line and pickles the count list and the
    swap dict via save_object_as_pkl, tagged with the current unix time.

    Parameters
        check_dict: state key -> {action name -> {'sum_wins', 'count'}};
            it is only read, never modified.

    Returns
        (count_array, swap_better_count, swap_2v2_dict)
    '''
    count_key = 'count'
    sum_wins_key = 'sum_wins'
    agent_first_move_attack_key = 'attack'
    agent_first_move_swap_party_0_key = 'swap_0'
    time_int = int(time.time())

    count_list = []
    swap_2v2_dict = {}
    swap_better_count = 0

    for state_key, state_dict in check_dict.items():
        if agent_first_move_swap_party_0_key not in state_dict \
                or agent_first_move_attack_key not in state_dict:
            continue

        swap_stats = state_dict[agent_first_move_swap_party_0_key]
        attack_stats = state_dict[agent_first_move_attack_key]

        # np.nan (not np.NaN): the capitalised alias was removed in NumPy 2.0,
        # and the default argument is evaluated on every .get() call
        swap_count = swap_stats.get(count_key, np.nan)
        attack_count = attack_stats.get(count_key, np.nan)
        count_list.append(swap_count)
        count_list.append(attack_count)

        recommended_action, swap_win_rate_better_rate, is_use_p_value, is_swap_better, p_value = get_chi_square_test_from_action_dict(
            check_dict, state_key, is_print_statistics=False)

        if is_swap_better:
            swap_better_count += 1

            swap_2v2_dict[state_key] = {
                'recommended_action': recommended_action,
                'swap_win_rate_better_rate': swap_win_rate_better_rate,
                'is_use_p_value': is_use_p_value,
                'is_swap_better': is_swap_better,
                'p_value': p_value,
                'swap_0_count': swap_count,
                'attack_count': attack_count,
                'swap_0_win_count': swap_stats.get(sum_wins_key, np.nan),
                'attack_win_count': attack_stats.get(sum_wins_key, np.nan)
            }

    count_array = np.array(count_list)
    print(f"swap better count: {swap_better_count} | len swap dict {len(swap_2v2_dict)}")

    save_object_as_pkl(count_list, f'count_list_2v2_{time_int}')
    save_object_as_pkl(swap_2v2_dict, f'swap_2v2_dict_{time_int}')

    return count_array, swap_better_count, swap_2v2_dict

# get_state_counts only reads from the dict, so pass it directly: deep-copying
# the ~3.1M-entry structure first costs time and doubles memory for no benefit.
two_count_array, swap_better_count, swap_2v2_dict = get_state_counts(two_combined_dict)


swap better count: 38194 | len swap dict 38194
saving:  joined_dict_results\count_list_2v2_1720491755.pickle
saving:  joined_dict_results\swap_2v2_dict_1720491755.pickle


In [8]:
# Shape, mean, median, min and max of the per-action visit counts.
print(
    two_count_array.shape,
    two_count_array.mean(),
    np.median(two_count_array),
    two_count_array.min(),
    two_count_array.max(),
)



(2426156,) 89.4686932744638 3.0 1 3426669


In [12]:
# How many per-action visit counts fall below each sample-size threshold,
# as an absolute number and as a share of all recorded counts.
for x in (5, 10, 15, 20, 25, 50, 100):
    less_than_x = (two_count_array < x).sum()
    share_below = less_than_x / len(two_count_array)
    print(f'count less than {x}: {less_than_x} | {share_below:.3f} ')

count less than 5: 1403590 | 0.579 
count less than 10: 1775533 | 0.732 
count less than 15: 1928543 | 0.795 
count less than 20: 2015950 | 0.831 
count less than 25: 2073878 | 0.855 
count less than 50: 2206844 | 0.910 
count less than 100: 2292732 | 0.945 


In [10]:
# Pretty-print the first four swap-better states together with their statistics.
break_num = 0

for break_num, (k, v) in enumerate(swap_2v2_dict.items(), start=1):
    pprint.pprint(k)
    pprint.pprint(v)
    if break_num > 3:
        break

(1, 2, 1, 1, 2, 3, 1, 1, 0, 0, 0, 0)
{'attack_count': 38328,
 'attack_win_count': 19162,
 'is_swap_better': True,
 'is_use_p_value': True,
 'p_value': 0.21553126209387186,
 'recommended_action': 1,
 'swap_0_count': 1396,
 'swap_0_win_count': 722,
 'swap_win_rate_better_rate': 0.017244158250396657}
(1, 1, 2, 2, 1, 3, 1, 2, 0, 2, 0, 0)
{'attack_count': 80936,
 'attack_win_count': 27488,
 'is_swap_better': True,
 'is_use_p_value': True,
 'p_value': 1.084388778413714e-09,
 'recommended_action': 1,
 'swap_0_count': 1005,
 'swap_0_win_count': 434,
 'swap_win_rate_better_rate': 0.09221442456591217}
(2, 1, 1, 2, -1, -1, 1, 1, 0, 0, 1, 0)
{'attack_count': 3933,
 'attack_win_count': 329,
 'is_swap_better': True,
 'is_use_p_value': True,
 'p_value': 6.6842593945635835e-96,
 'recommended_action': 1,
 'swap_0_count': 4039,
 'swap_0_win_count': 1050,
 'swap_win_rate_better_rate': 0.17631418107723787}
(1, 1, 1, 2, -1, -1, 1, 1, 0, 0, 0, 0)
{'attack_count': 69193,
 'attack_win_count': 19315,
 'is_swap

In [11]:
# Of the states where swap has the higher win rate, count those where the
# decision rule actually recommends swapping (recommended_action == 1).
recommend_swap_count = sum(
    1 for v in swap_2v2_dict.values() if v['recommended_action'] == 1
)
print(recommend_swap_count, len(swap_2v2_dict))

29863 38194


### 2v3

In [None]:
'''
so for 2v3 recall the wins and counts are one part and the other part is final state ended up in so the math is something like
	num_wins no 2v2
	num battles no 2v2
	then for each state take the num actions
		find win rate for that state with the better action there
			given some min number. if no min number then don't use
			estimated_num_wins_in_state = win_rate * num_battles_in_state

			num_wins += estimated_num_wins_in_state
			num_battles += num_battles_in_state

	once that dict is made, can then do the analysis on it

steps
load the 2v3 dict
load the 2v2 dict

iterate through the steps in two vs three combined dict
	need to have min number of outcomes in the state
'''

In [None]:
# Combined 2v2 results (the same file that is loaded in the 2v2 section).
two_combined_dict = load_pkl_object('joined_dict_results\\two_vs_two_combined_results.pickle')
# Combined 2v3 results; each action entry may also carry end-state tuple -> count
# items next to 'sum_wins' / 'count' (see the example entries in the next cell).
two_vs_three_combined_dict = load_pkl_object('joined_dict_results\\two_vs_three_2v3_combined_results.pickle')
# Swap-better 2v2 states as pickled by get_state_counts in the run tagged 1720491755.
two_vs_two_recommended_dict = load_pkl_object('joined_dict_results\\swap_2v2_dict_1720491755.pickle')

In [None]:
# example entries

# (1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0) {'attack': {'sum_wins': 705, 'count': 5542, (1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0): 5407}, 'swap_0': {'sum_wins': 0, 'count': 4458, (1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0): 314}}
# (1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0) {'attack': {'sum_wins': 14931, 'count': 121805}}
# (2, 2, 2, 2, 2, 1, 1, 2, 1, 3, -1, -1, 0, 0, 1, 2, 0) {'attack': {'sum_wins': 0, 'count': 1}}
# (2, 2, 1, 0, 0, 0, 2, 0, 3, 0, -1, 0, 0, 0, 0, 2, 0) {'attack': {'sum_wins': 0, 'count': 18}}
# (1, 2, 1, 0, 0, 0, 1, 0, 1, 0, -1, 0, 0, 0, 0, 2, 0) {'attack': {'sum_wins': 0, 'count': 588}}
# (2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0) {'attack': {'sum_wins': 0, 'count': 82}, 'swap_0': {'sum_wins': 0, 'count': 7, (1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0): 7, (1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0): 1}}

In [None]:
def make_new_state_tuple(state_key, filter_list):
    '''
    Return the entries of state_key whose flag in filter_list is truthy.

    filter_list is indexed position by position, so it must be at least as
    long as state_key. A message is printed (the tuple is still returned)
    when the result does not have the 12 entries of a 2v2 state key.
    '''
    kept_values = tuple(
        state_key[position]
        for position in range(len(state_key))
        if filter_list[position]
    )

    if len(kept_values) != 12:
        print("Error: new state tuple not correct length of 12", kept_values)

    return kept_values

def convert_2v3_state_to_2v2_state(state_key):
    '''
    Drop the fainted opponent party member from a 17-long 2v3 state key.

    Layout of the 2v3 key (17 entries)
        idx 0,1,2    agent active attacking opp active / opp party 0 / opp party 1
        idx 3,4,5    agent party attacking opp active / opp party 0 / opp party 1
        idx 6,7      opp active attacking the agent's pkm
        idx 8,9      opp party 0 attacking the agent's pkm
        idx 10,11    opp party 1 attacking the agent's pkm
        idx 12,13    agent hp normalized
        idx 14,15,16 opp hp normalized (active, party 0, party 1)

    Exactly one opp party member must be marked as fainted, i.e. both agent
    entries against it are 0 while both entries against the other party member
    are non-zero. Every entry that refers to the fainted member is removed,
    leaving a 12-long 2v2 key.

    NOTE(review): the faint checks originally read indices 4, 5 and 6, which
    does not match the layout above or the filter lists (they remove indices
    1/4 and 2/5). They now use 3, 4 and 5 - please confirm against the code
    that builds the state keys.

    Returns
        (filtered_state_key, is_two_vs_two_key). On success the flag is True
        (it previously was never set, so every conversion was reported as a
        failure). On failure a message is printed and ((19,), False) is
        returned.
    '''
    # True = keep the entry; one mask per fainted opp party member
    opp_party_0_down_filter = [True, False, True,
                               True, False, True,
                               True, True,
                               False, False,
                               True, True,
                               True, True,
                               True, False, True,]
    opp_party_1_down_filter = [True, True, False,
                               True, True, False,
                               True, True,
                               True, True,
                               False, False,
                               True, True,
                               True, True, False,]

    is_two_vs_two_key = False
    filtered_state_key = tuple([19])  # sentinel returned when conversion fails

    if len(state_key) != 17:
        print("Error: state key not 17 long ", state_key)
    elif state_key[0] == 0 or state_key[3] == 0:
        print("Error: active opp is fainted ", state_key)
    else:
        filter_list = None

        if state_key[1] == 0 and state_key[4] == 0 \
            and state_key[2] != 0 and state_key[5] != 0:
            # opp party 0 is down
            filter_list = opp_party_0_down_filter
        elif state_key[2] == 0 and state_key[5] == 0 \
            and state_key[1] != 0 and state_key[4] != 0:
            # opp party 1 is down
            filter_list = opp_party_1_down_filter
        else:
            print("Error: unable to located the fainted opp ", state_key)

        if filter_list is not None:
            filtered_state_key = make_new_state_tuple(state_key, filter_list)
            is_two_vs_two_key = len(filtered_state_key) == 12

    if not is_two_vs_two_key:
        print("Warning: not a 2v2 key ", state_key, filtered_state_key)

    return filtered_state_key, is_two_vs_two_key


def _estimate_2v2_transition_totals(move_dict, two_vs_two_recommended_dict, two_vs_two_combined_dict,
        count_key, sum_wins_key, attack_key, swap_key):
    '''
    Sum battles and estimated wins over the end states recorded for one move.

    Every item of move_dict other than count_key / sum_wins_key is read as
    "end state key -> number of battles that ended in that state". Each end
    state is converted to a 2v2 key, and the 2v2 win rate of the action
    recommended there (attack when the state has no recommendation) is
    multiplied by the number of battles. End states that cannot be converted
    or have no 2v2 statistics are skipped entirely, so battles and estimated
    wins always stay paired.

    Returns (transition_count, estimated_wins).
    '''
    transition_count = 0
    estimated_wins = 0.

    for end_state_key, end_state_count in move_dict.items():
        if end_state_key == count_key or end_state_key == sum_wins_key:
            continue

        # turn the 2v3 keys into 2v2 keys
        filtered_state_key, is_use_two_vs_two_key = convert_2v3_state_to_2v2_state(end_state_key)
        if not is_use_two_vs_two_key:
            continue

        # states missing from the recommended dict default to attack
        recommended_action = two_vs_two_recommended_dict.get(
            filtered_state_key, {}).get('recommended_action', 0)
        lookup_key = swap_key if recommended_action == 1 else attack_key

        lookup_dict = two_vs_two_combined_dict.get(filtered_state_key, {}).get(lookup_key, {})
        lookup_2v2_count = lookup_dict.get(count_key, 0)
        if lookup_2v2_count > 0 and sum_wins_key in lookup_dict:
            lookup_win_percent = lookup_dict[sum_wins_key] / lookup_2v2_count
            estimated_wins += lookup_win_percent * end_state_count
            transition_count += end_state_count

    return transition_count, estimated_wins


def get_state_counts_2v3(two_vs_three_combined_dict, save_tag,
        two_vs_two_recommended_dict, two_vs_two_combined_dict):
    '''
    Evaluate attack vs swap_0 for every 2v3 state, folding in 2v2 estimates.

    For each state holding both an 'attack' and a 'swap_0' entry, the wins
    and counts stored directly on the action are combined with estimated wins
    for the battles that moved on to a 2v2 state (see
    _estimate_2v2_transition_totals). The totals are then passed to
    get_chi_square_test_from_count_wins.

    Fixes relative to the first draft
        * the end-state value is the battle count itself (it was read with
          .get(), which fails on an int)
        * swap estimates are added to the swap total (they were added to the
          attack total)
        * end states without a recommendation fall back to the 2v2 attack
          statistics instead of contributing battles with zero wins
        * three_vs_three_lookup_dict stores the totals of the recommended
          action (both branches stored the attack totals)
        * np.nan replaces np.NaN, which was removed in NumPy 2.0

    Side effects: prints a summary line and pickles the swap dict, the lookup
    dict and the count list via save_object_as_pkl, tagged with save_tag and
    the current unix time.

    Returns
        (count_array, swap_better_count, swap_dict, three_vs_three_lookup_dict)
    '''
    count_key = 'count'
    sum_wins_key = 'sum_wins'
    agent_first_move_attack_key = 'attack'
    agent_first_move_swap_party_0_key = 'swap_0'
    time_int = int(time.time())

    count_list = []
    swap_dict = {}
    swap_better_count = 0

    three_vs_three_lookup_dict = {}

    for state_key, state_dict in two_vs_three_combined_dict.items():
        if agent_first_move_swap_party_0_key not in state_dict \
                or agent_first_move_attack_key not in state_dict:
            continue

        attack_dict = state_dict[agent_first_move_attack_key]
        swap_move_dict = state_dict[agent_first_move_swap_party_0_key]

        # wins, counts stored directly on the action (battles that did not go to 2v2)
        attack_sub_count = attack_dict.get(count_key, 0)
        attack_sub_wins = attack_dict.get(sum_wins_key, 0)
        swap_sub_count = swap_move_dict.get(count_key, 0)
        swap_sub_wins = swap_move_dict.get(sum_wins_key, 0)

        # counts, estimated wins for the battles that went to 2v2
        attack_to_2v2_count, attack_to_2v2_estimated_wins = _estimate_2v2_transition_totals(
            attack_dict, two_vs_two_recommended_dict, two_vs_two_combined_dict,
            count_key, sum_wins_key,
            agent_first_move_attack_key, agent_first_move_swap_party_0_key)
        swap_to_2v2_count, swap_to_2v2_estimated_wins = _estimate_2v2_transition_totals(
            swap_move_dict, two_vs_two_recommended_dict, two_vs_two_combined_dict,
            count_key, sum_wins_key,
            agent_first_move_attack_key, agent_first_move_swap_party_0_key)

        # estimated wins are fractional; round so the chi square table holds whole battles
        total_attack_counts = attack_sub_count + attack_to_2v2_count
        total_attack_wins = attack_sub_wins + int(np.round(attack_to_2v2_estimated_wins, 0))

        total_swap_counts = swap_sub_count + swap_to_2v2_count
        total_swap_wins = swap_sub_wins + int(np.round(swap_to_2v2_estimated_wins, 0))

        # for 3v2 need to take the best swap
        recommended_action, swap_win_rate_better_rate, is_use_p_value, is_swap_better, p_value = get_chi_square_test_from_count_wins(
            total_swap_wins, total_swap_counts, total_attack_wins, total_attack_counts,
            min_total_count=50,
            min_swap_count=25,
            min_attack_count=25,
            is_print_statistics=False)

        count_list.append(total_swap_counts)
        count_list.append(total_attack_counts)

        if is_swap_better:
            swap_better_count += 1

            swap_dict[state_key] = {
                'recommended_action': recommended_action,
                'swap_win_rate_better_rate': swap_win_rate_better_rate,
                'is_use_p_value': is_use_p_value,
                'is_swap_better': is_swap_better,
                'p_value': p_value,
                # to do: fill these in
                'swap_0_count': swap_move_dict.get(count_key, np.nan),
                'attack_count': attack_dict.get(count_key, np.nan),
                '2v2_swap_visit_count': total_swap_counts,
                '2v2_swap_estimated_win_count': total_swap_wins,
                '2v2_attack_visit_count': total_attack_counts ,
                '2v2_attack_estimated_win_count': total_attack_wins,
            }

        # keep the totals of the action that would be played in this state,
        # mirroring how the 2v2 statistics are looked up by recommended action
        if recommended_action == 1:
            three_vs_three_lookup_dict[state_key] = {
                count_key: total_swap_counts,
                sum_wins_key: total_swap_wins,
            }
        else:
            three_vs_three_lookup_dict[state_key] = {
                count_key: total_attack_counts,
                sum_wins_key: total_attack_wins,
            }

    count_array = np.array(count_list)
    print(f"swap better count: {swap_better_count} | len swap dict {len(swap_dict)}")

    save_object_as_pkl(swap_dict, f'swap_{save_tag}_dict_{time_int}')
    save_object_as_pkl(three_vs_three_lookup_dict, f'three_vs_three_lookup_dict_{save_tag}_{time_int}')
    save_object_as_pkl(count_list, f'count_list_{save_tag}_{time_int}')

    return count_array, swap_better_count, swap_dict, three_vs_three_lookup_dict



In [None]:
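# NOTE: draft call, not run yet. `two_vs_three_save_tag` still has to be defined, the function to
# call is get_state_counts_2v3 (it returns four values), and the 2v2 dict is loaded above as
# `two_combined_dict`.
#two_v_three_count_array, two_v_three_swap_better_count, swap_two_v_three_dict, three_vs_three_lookup_dict = get_state_counts_2v3(
#    two_vs_three_combined_dict, two_vs_three_save_tag, two_vs_two_recommended_dict, two_combined_dict
#)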