In [None]:
import pandas as pd
import numpy as np
import matplotlib as mpl

def map_feature_portfolio(infile, metric='%score',
                          mean_across_maps=False, filters=None,
                          unstack_level=2, renaming=None):
    """
    Load a results CSV and reshape it into a table of ``metric`` with one column per value
    of the unstacked index level (by default: one column per portfolio).

    :param infile: path or file-like object accepted by ``pd.read_csv``; its 'strat'
        column, if present, is renamed to 'portfolio'
    :param metric: name of the value column to report (default '%score')
    :param mean_across_maps: if True, return the mean of the table over all rows that share
        the same 'feature' level (i.e. averaged over maps and any remaining index levels)
    :param filters: dict ``{column: value}``. Only rows where ``column == value`` are kept
        (float values are compared with ``np.isclose``), and ``column`` is then dropped and
        removed from the index. Defaults to ``{'test_pos': 2}``.
    :param unstack_level: index level (position or name) moved to the columns. Positions are
        counted AFTER the filtered columns have been removed from the index.
    :param renaming: optional extra ``DataFrame.replace`` mapping, applied after the
        built-in human-friendly renames (so it may also target those new names)
    :return: DataFrame of ``metric`` values (grouped by 'feature' if ``mean_across_maps``)
    """
    # None instead of a dict literal: avoids sharing one mutable default between calls
    if filters is None:
        filters = {'test_pos': 2}

    # renames the portfolio column
    df = pd.read_csv(infile).rename(columns={'strat': 'portfolio'})

    # human-friendly portfolio names
    # NOTE(review): 'rush+suport' looks like a typo of 'rush+support'; kept unchanged because
    # callers may filter on this exact label
    df = df.replace({'WR_LR_RR_HR': 'rush',
                     'WR_LR_RR_HR_BB_BK': 'rush+suport',
                     'WR_LR_RR_HR_WD_LD_RD_HD': 'rush+defense',
                     'WR_LR_RR_HR_WD_LD_RD_HD_BB_BK': 'all'})

    # human-friendly feature selectors
    df = df.replace({'materialdistancehp': 'Simple', 'quadrantmodel': 'Per-quadrant'})

    # user-supplied renames run last, as a separate pass, so they see the names set above
    if renaming is not None:
        df = df.replace(renaming)

    # the columns that identify one experiment configuration
    indexes = ['mapname', 'feature', 'portfolio', 'alpha', '_lambda', 'train_matches', 'dec_int', 'test_pos']

    # keep only the rows matching the user criteria
    for fname, fvalue in filters.items():
        if isinstance(fvalue, float):
            # exact float equality is unreliable for values read back from CSV
            mask = np.isclose(df[fname], fvalue)
        else:
            mask = df[fname] == fvalue
        # the filtered column is now constant, so it carries no information: drop it
        # (chained, rather than an inplace drop on a filtered slice)
        df = df[mask].drop(columns=fname)
        # filters may also target columns that are not part of the index
        if fname in indexes:
            indexes.remove(fname)

    # hierarchical (multiindex) frame: one row per remaining configuration
    hier = df.set_index(indexes)

    # move the requested level to the columns and keep only the requested metric
    unstacked = hier.unstack(unstack_level)[metric]

    if mean_across_maps:
        # grouping by the 'feature' level averages over every map played with that extractor
        return unstacked.groupby(level=['feature']).mean()
    return unstacked
        
        
def autolabel(ax):
    """
    Attach a text label above each bar displaying its height, rounded to an integer
    (source: https://stackoverflow.com/a/22689127/1251716)

    :param ax: matplotlib Axes holding a bar plot, e.g. the one returned by
        ``DataFrame.plot(kind='bar')``
    """
    import math

    # ax.patches lists only the patches added to the axes (the bars). Unlike ax.get_children(),
    # it does not include the axes background rectangle (height 1 in axes coordinates), so no
    # "height != 1" special case is needed and real bars of height 1 get labelled too.
    for rect in ax.patches:
        height = rect.get_height()
        # bars for missing values have a NaN height; round()/int() would raise on them
        if not math.isfinite(height):
            continue
        ax.text(rect.get_x() + rect.get_width() / 2., 1.001 * height,
                '%d' % round(height),
                ha='center', va='bottom')

In [None]:
# Mean score of every feature extractor against each adversary, with and without a search budget
opponents = ['A3N', 'SSSmRTS', 'StrategyTactics', 'PuppetSearchMCTS', 'WorkerRush', 'LightRush']

# a single fixed parameter setting: test position 2, end of training (100000 matches)
filters = {'test_pos': 2, 'dec_int': 100, 'train_matches': 100000, '_lambda': 0.5, 'alpha': 0.01}

# (header template, results-file template) for each budget, in display order
# NOTE(review): the no-budget file here is '{}_avg.csv' while later cells read '{}_b0_avg.csv';
# confirm both files exist and hold the same data
budget_runs = [
    ("---{}, no search budget---", '../results/parameters/{}_avg.csv'),
    ("---{}, search budget=100ms---", '../results/parameters/{}_b100_avg.csv'),
]

for o in opponents:
    for header_tpl, path_tpl in budget_runs:
        print(header_tpl.format(o))
        display(map_feature_portfolio(path_tpl.format(o), '%score', True, filters))

In [None]:
### ONLY LIGHT & WORKER RUSH WITH b0 and b100:
# per-map scores of the full ('all') portfolio, for a fixed parameter setting
filters = {'test_pos': 2, 'dec_int': 100, '_lambda': 0.5, 'alpha': 0.01, 'portfolio': 'all'}

# every (opponent, search budget in ms) pair, opponent-major order
for o, b in [(opp, budget) for opp in ['WorkerRush', 'LightRush'] for budget in [0, 100]]:
    print("---{}, search budget={}ms---".format(o, b))
    csv_path = '../results/parameters/{}_b{}_avg.csv'.format(o, b)
    display(map_feature_portfolio(csv_path, '%score', False, filters))

In [None]:
### This is to check the feature extractor on regular maps
import matplotlib.pyplot as plt

print('# Comparing feature extractors on "regular" maps')

# only the full portfolio ('all'), at the end of training, for a fixed parameter setting
filters = {'test_pos': 2, 'dec_int': 100, '_lambda': 0.5, 'alpha': 0.01, 'portfolio': 'all', 'train_matches': 100000}

for o in ['WorkerRush', 'LightRush']:
    # with every other index column filtered out, unstack_level=1 moves 'feature' to the
    # columns: rows are maps, one bar per feature extractor
    scores = map_feature_portfolio('../results/parameters/{}_b0_avg.csv'.format(o), '%score', False, filters, 1)
    # descending name order, so that 8x8 comes before 32x32
    data = scores.sort_values(by=['mapname'], ascending=False)

    ax = data.plot(kind='bar', title=o, ylim=[0, 105])
    ax.set(xlabel="Map", ylabel="Score (%)")
    # re-applies the current tick labels, only to make them horizontal
    ax.set_xticklabels(ax.get_xticklabels(), rotation=0)
    autolabel(ax)

    pdf_path = f'/tmp/feature-vs_{o}_regular.pdf'
    ax.get_figure().savefig(pdf_path)

In [None]:
### This is to check the feature extractor on irregular/split maps

print('# Comparing feature extractors on "irregular/split" maps')

# only the full portfolio ('all'), at the end of training, for a fixed parameter setting
filters = {'test_pos': 2, 'dec_int': 100, '_lambda': 0.5, 'alpha': 0.01, 'portfolio': 'all', 'train_matches': 100000}

# shortens each map name to its size; the insertion order doubles as the bar order below
map_rename = {'FourBasesWorkers8x8': '8x8', 'NoWhereToRun9x8': '9x8', 'TwoBasesBarracks16x16': '16x16'}
bar_order = list(map_rename.values())

for o in ['WorkerRush', 'LightRush']:
    csv_path = '../results/features/{}_b0_avg.csv'.format(o)
    scores = map_feature_portfolio(csv_path, filters=filters, unstack_level=1, renaming=map_rename)
    # selecting rows with an explicit label list puts the bar groups (maps) in that order
    # (https://stackoverflow.com/a/20555406/1251716)
    data = scores.loc[bar_order]

    ax = data.plot(kind='bar', title=o, ylim=[0, 105])
    ax.set(xlabel="Map", ylabel="Score (%)")
    # re-applies the current tick labels, only to make them horizontal
    ax.set_xticklabels(ax.get_xticklabels(), rotation=0)
    autolabel(ax)

    pdf_path = f'/tmp/feature-vs_{o}_irregular.pdf'
    ax.get_figure().savefig(pdf_path)