In [1]:
import pandas as pd
pd.set_option('display.max_columns', 500)

# --------------------------------------------------------------------------------------------------------------------

def LLoc(loc):
    # return opposite of WLoc for L team
    key = {'H': 'A', 'A': 'H', 'N': 'N'}
    return key[loc]
    
def split_rows(df, w_map, l_map):
    df1 = df.copy()
    df2 = df.copy()
    
    to_keep = ['Season', 'NumOT']
    df1.drop(columns=[x for x in l_map.keys() if x not in to_keep], inplace=True)
    df2.drop(columns=[x for x in w_map.keys() if x not in to_keep], inplace=True)
    
    df1.rename(columns=w_map, inplace=True)
    df2.rename(columns=l_map, inplace=True)
    
    # win-loss count
    df1["W"], df1["L"] = 1, 0
    df2["W"], df2["L"] = 0, 1
    
    return pd.concat([df1, df2]).sort_index()

def add_column(df, after, name, values):
    df        = df.copy()
    col_index = df.columns.get_loc(after) + 1
    df.insert(col_index, name, values)
    return df

def add_seeds(df):
    seed_df = pd.read_csv('data/MNCAATourneySeeds.csv')
    seed_df['Seed'] = seed_df['Seed'].apply(lambda x: int(x[1:3]))
    return seed_df[["Season", "TeamID", "Seed"]].merge(df)

def add_conferences(df):
    conf = pd.read_csv("data/MTeamConferences.csv")
    team = pd.read_csv("data/MTeams.csv")
    team_and_conf = conf.merge(team)
    df = team_and_conf[["Season", "TeamID", "TeamName", "ConfAbbrev"]].merge(df)
    # conf avg rank
    by_conf = df.groupby(['Season', 'ConfAbbrev'], as_index=False).mean()
    df = add_column(
        df=df, after="ConfAbbrev", name="ConfAverageOrdinalRank", 
        values=df.apply(
            lambda x: by_conf.loc[(by_conf['Season'] == x.Season) & (by_conf['ConfAbbrev']==x.ConfAbbrev)].AverageOrdinalRank.values[0],
            axis=1,
        )
    )
    return df

def add_percentages(df):
    df = add_column(df=df, after="FGA", name="FG%", values=100*(df["FGM"] / df["FGA"]))
    df = add_column(df=df, after="FGA3", name="FG3%", values=100*(df["FGM3"] / df["FGA3"]))
    df = add_column(df=df, after="FTA", name="FT%", values=100*(df["FTM"] / df["FTA"]))
    return df
    
def add_ratings(df):
    ratings = pd.read_csv("data/MMasseyOrdinals.csv")
    # some weeks have different numbers of rankings so get average by each ranking day
    grouped = ratings.groupby(['Season', 'TeamID', 'RankingDayNum'], as_index=False).mean()
    grouped.drop(columns=['RankingDayNum'], inplace=True)
    # then group the group of rankings with day to get season
    average = grouped.groupby(['Season', 'TeamID'], as_index=False).mean().rename(columns={"OrdinalRank": "AverageOrdinalRank"})
    minimum = grouped.groupby(['Season', 'TeamID'], as_index=False).min().rename(columns={"OrdinalRank": "MinOrdinalRank"})
    maximum = grouped.groupby(['Season', 'TeamID'], as_index=False).max().rename(columns={"OrdinalRank": "MaxOrdinalRank"})
    latest  = grouped.groupby(['Season', 'TeamID'], as_index=False).tail(1).rename(columns={"OrdinalRank": "LatestOrdinalRank"})
    ratings_summary = average.merge(minimum).merge(maximum).merge(latest)
    return ratings_summary.merge(df)
 
def get_head_to_head_record(season, team1_id, team2_id):
    results  = pd.read_csv("data/MRegularSeasonCompactResults.csv")
    matchups = results.loc[
        (results['Season'] == season) & 
        (results['WTeamID'].isin([team1_id, team2_id])) & 
        (results['LTeamID'].isin([team1_id, team2_id]))
    ]
    if not matchups.empty: 
        counts = matchups[['WTeamID','LTeamID']].apply(pd.value_counts).fillna(0)
    else:
        counts = {'WTeamID': {team1_id : 0, team2_id: 0}, 'LTeamID': {team1_id : 0, team2_id: 0}}
    
    return {
        team1_id: {"W": counts['WTeamID'][team1_id], "L": counts['LTeamID'][team1_id]}, 
        team2_id: {"W": counts['WTeamID'][team2_id], "L": counts['LTeamID'][team2_id]},
    }
    
def get_common_opponent_record(season, team1_id, team2_id):
    results  = pd.read_csv("data/MRegularSeasonCompactResults.csv")
    team1_matchups = results.loc[
        (results['Season'] == season) & 
        ((results['WTeamID'] == team1_id) | (results['LTeamID'] == team1_id)) 
    ]
    team2_matchups = results.loc[
        (results['Season'] == season) & 
        ((results['WTeamID'] == team2_id) | (results['LTeamID'] == team2_id)) 
    ]
    
    team1_opponents = set(list(team1_matchups['WTeamID'].unique()) + list(team1_matchups['LTeamID'].unique()))
    team2_opponents = set(list(team2_matchups['WTeamID'].unique()) + list(team2_matchups['LTeamID'].unique()))
    
    # remove self
    team1_opponents.remove(team1_id)
    team2_opponents.remove(team2_id)

    common_opponents = list(team1_opponents.intersection(team2_opponents))
    # keep only common matchups
    team1_matchups = team1_matchups.loc[
        (results['WTeamID'].isin(common_opponents)) | (results['LTeamID'].isin(common_opponents))
    ]
    team2_matchups = team2_matchups.loc[
        (results['WTeamID'].isin(common_opponents)) | (results['LTeamID'].isin(common_opponents))
    ]
    
    if not team1_matchups.empty: # team2 is also not empty
        team1_counts = team1_matchups[['WTeamID','LTeamID']].apply(pd.value_counts).fillna(0)
        team2_counts = team2_matchups[['WTeamID','LTeamID']].apply(pd.value_counts).fillna(0)
    else:
        team1_counts = {'WTeamID': {team1_id : 0}, 'LTeamID': {team1_id : 0}}
        team2_counts = {'WTeamID': {team2_id: 0}, 'LTeamID': {team2_id: 0}}
        

    return {
        team1_id: {"W": team1_counts['WTeamID'][team1_id], "L": team1_counts['LTeamID'][team1_id]}, 
        team2_id: {"W": team2_counts['WTeamID'][team2_id], "L": team2_counts['LTeamID'][team2_id]},
    }

def get_last_k_games_record(season, team_id):
    k = 10
    
    results  = pd.read_csv("data/MRegularSeasonCompactResults.csv")
    matchups = results.loc[
        (results['Season'] == season) & 
        ((results['WTeamID'] == team_id) | (results['LTeamID'] == team_id))
    ].tail(k)
    
    counts = matchups[['WTeamID','LTeamID']].apply(pd.value_counts).fillna(0)
    return {"W": counts['WTeamID'][team_id], "L": counts['LTeamID'][team_id]}
    
def add_empty_record_cols(df):
    print('---')
    df = add_column(df=df, after="L", name="WH2H", values=0)
    df = add_column(df=df, after="WH2H", name="LH2H", values=0)
    df = add_column(df=df, after="LH2H", name="WCommonOpp", values=0)
    df = add_column(df=df, after="WCommonOpp", name="LCommonOpp", values=0)
    df = add_column(df=df, after="LCommonOpp", name="WLastK", values=0)
    df = add_column(df=df, after="WLastK", name="LLastK", values=0)
    return df

def populate_record_columns(df):
    df.loc[:, ['AWH2H', 'ALH2H']] = df.apply(
        lambda row: pd.Series(get_head_to_head_record(row.Season, row.ATeamID, row.BTeamID)[row.ATeamID].values(),index=['AWH2H', 'ALH2H']), 
        axis=1,
    )
    print('-')
    df.loc[:, ['AWCommonOpp', 'ALCommonOpp']] = df.apply(
        lambda row: pd.Series(get_common_opponent_record(row.Season, row.ATeamID, row.BTeamID)[row.ATeamID].values(),index=['AWCommonOpp', 'ALCommonOpp']), 
        axis=1,
    )
    print('--')
    df.loc[:, ['AWLastK', 'ALLastK']] = df.apply(
        lambda row: pd.Series(get_last_k_games_record(row.Season, row.ATeamID).values(),index=['AWLastK', 'ALLastK']), 
        axis=1,
    )
    print('---')

    df.loc[:, ['BWH2H', 'BLH2H']] = df.apply(
        lambda row: pd.Series(get_head_to_head_record(row.Season, row.ATeamID, row.BTeamID)[row.BTeamID].values(),index=['BWH2H', 'BLH2H']), 
        axis=1,
    )
    print('----')
    df.loc[:, ['BWCommonOpp', 'BLCommonOpp']] = df.apply(
        lambda row: pd.Series(get_common_opponent_record(row.Season, row.ATeamID, row.BTeamID)[row.BTeamID].values(),index=['BWCommonOpp', 'BLCommonOpp']), 
        axis=1,
    )
    print('-----')
    df.loc[:, ['BWLastK', 'BLLastK']] = df.apply(
        lambda row: pd.Series(get_last_k_games_record(row.Season, row.BTeamID).values(),index=['BWLastK', 'BLLastK']), 
        axis=1,
    )
    print('------')
    return df


#### Box Scores

In [2]:
box_scores = pd.read_csv("data/MRegularSeasonDetailedResults.csv")

box_scores = add_column(df=box_scores, after="WLoc", name="LLoc", values=box_scores["WLoc"].map(LLoc))
box_scores = add_column(df=box_scores, after="WDR", name="WTR", values=box_scores["WOR"]+box_scores["WDR"])
box_scores = add_column(df=box_scores, after="LDR", name="LTR", values=box_scores["LOR"]+box_scores["LDR"])

# map from old columns to new_columns
w_map = {
    'Season': 'Season', 'DayNum': 'DayNum', 'WTeamID': 'TeamID', 'WScore': 'Score', 'WLoc': 'Loc', 'NumOT': 'NumOT', 
    'WFGM': 'FGM', 'WFGA': 'FGA', 'WFGM3': 'FGM3', 'WFGA3': 'FGA3', 'WFTM': 'FTM', 'WFTA': 'FTA', 'WOR': 'OR', 
    'WDR': 'DR', 'WTR': 'TR', 'WAst': 'Ast', 'WTO': 'TO', 'WStl': 'Stl', 'WBlk': 'Blk', 'WPF': 'PF',
}
l_map = {
    'Season': 'Season', 'DayNum': 'DayNum', 'LTeamID': 'TeamID', 'LScore': 'Score', 'LLoc': 'Loc', 'NumOT': 'NumOT',
    'LFGM': 'FGM', 'LFGA': 'FGA', 'LFGM3': 'FGM3', 'LFGA3': 'FGA3', 'LFTM': 'FTM', 'LFTA': 'FTA', 'LOR': 'OR', 
    'LDR': 'DR', 'LTR': 'TR', 'LAst': 'Ast', 'LTO': 'TO', 'LStl': 'Stl', 'LBlk': 'Blk', 'LPF': 'PF',
}


detailed_results_GAME = split_rows(box_scores.copy(), w_map, l_map)
detailed_results_SZN_AVG = detailed_results_GAME.groupby(['Season', 'TeamID'], as_index=False).mean()

In [3]:
detailed_results_GAME

Unnamed: 0,Season,TeamID,Score,Loc,NumOT,FGM,FGA,FGM3,FGA3,FTM,FTA,OR,DR,TR,Ast,TO,Stl,Blk,PF,W,L
0,2003,1104,68,N,0,27,58,3,14,11,18,14,24,38,13,23,7,1,22,1,0
0,2003,1328,62,N,0,22,53,2,10,16,22,10,22,32,8,18,9,2,20,0,1
1,2003,1393,63,N,0,24,67,6,24,9,20,20,25,45,7,12,8,6,16,0,1
1,2003,1272,70,N,0,26,62,8,20,10,19,15,28,43,16,13,4,4,18,1,0
2,2003,1266,73,N,0,24,58,8,18,17,29,17,26,43,15,10,5,2,25,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
96684,2021,1222,91,N,0,37,66,11,23,6,7,10,26,36,24,6,8,4,16,1,0
96685,2021,1326,88,N,0,29,65,8,25,22,30,8,27,35,11,8,7,4,26,0,1
96685,2021,1228,91,N,0,29,64,9,21,24,32,13,27,40,15,10,5,2,23,1,0
96686,2021,1382,74,N,0,25,62,8,19,16,20,11,26,37,12,9,3,4,15,1,0


In [4]:
detailed_results_SZN_AVG

Unnamed: 0,Season,TeamID,Score,NumOT,FGM,FGA,FGM3,FGA3,FTM,FTA,OR,DR,TR,Ast,TO,Stl,Blk,PF,W,L
0,2003,1102,57.250000,0.000000,19.142857,39.785714,7.821429,20.821429,11.142857,17.107143,4.178571,16.821429,21.000000,13.000000,11.428571,5.964286,1.785714,18.750000,0.428571,0.571429
1,2003,1103,78.777778,0.296296,27.148148,55.851852,5.444444,16.074074,19.037037,25.851852,9.777778,19.925926,29.703704,15.222222,12.629630,7.259259,2.333333,19.851852,0.481481,0.518519
2,2003,1104,69.285714,0.035714,24.035714,57.178571,6.357143,19.857143,14.857143,20.928571,13.571429,23.928571,37.500000,12.107143,13.285714,6.607143,3.785714,18.035714,0.607143,0.392857
3,2003,1105,71.769231,0.153846,24.384615,61.615385,7.576923,20.769231,15.423077,21.846154,13.500000,23.115385,36.615385,14.538462,18.653846,9.307692,2.076923,20.230769,0.269231,0.730769
4,2003,1106,63.607143,0.035714,23.428571,55.285714,6.107143,17.642857,10.642857,16.464286,12.285714,23.857143,36.142857,11.678571,17.035714,8.357143,3.142857,18.178571,0.464286,0.535714
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6529,2021,1467,66.277778,0.055556,24.111111,53.722222,7.000000,22.111111,11.055556,16.666667,5.777778,22.111111,27.888889,12.000000,13.000000,8.000000,3.777778,7.666667,0.500000,0.500000
6530,2021,1468,72.555556,0.000000,27.277778,54.500000,6.333333,17.055556,11.666667,15.500000,6.722222,19.833333,26.555556,15.222222,10.833333,6.611111,2.055556,2.666667,0.611111,0.388889
6531,2021,1469,67.631579,0.000000,23.368421,57.842105,5.842105,19.684211,15.052632,20.473684,7.684211,23.157895,30.842105,14.421053,15.263158,6.736842,2.052632,4.315789,0.315789,0.684211
6532,2021,1470,63.866667,0.000000,22.000000,50.866667,5.466667,14.600000,14.400000,19.400000,4.400000,20.133333,24.533333,11.333333,10.933333,6.733333,2.266667,6.733333,0.333333,0.666667


In [5]:
df= add_empty_record_cols(
    add_conferences( # do after ratings added
    add_ratings(
    add_percentages(
        detailed_results_SZN_AVG
))))
df

---


Unnamed: 0,Season,TeamID,TeamName,ConfAbbrev,ConfAverageOrdinalRank,AverageOrdinalRank,MinOrdinalRank,MaxOrdinalRank,LatestOrdinalRank,Score,NumOT,FGM,FGA,FG%,FGM3,FGA3,FG3%,FTM,FTA,FT%,OR,DR,TR,Ast,TO,Stl,Blk,PF,W,L,WH2H,LH2H,WCommonOpp,LCommonOpp,WLastK,LLastK
0,2003,1102,Air Force,mwc,92.915339,145.444516,83.000000,191.000000,156.031250,57.250000,0.000000,19.142857,39.785714,48.114901,7.821429,20.821429,37.564322,11.142857,17.107143,65.135699,4.178571,16.821429,21.000000,13.000000,11.428571,5.964286,1.785714,18.750000,0.428571,0.571429,0,0,0,0,0,0
1,2004,1102,Air Force,mwc,90.656348,74.131669,17.000000,198.000000,45.212121,60.178571,0.000000,20.285714,42.035714,48.258284,8.464286,22.214286,38.102894,11.142857,15.714286,70.909091,6.142857,15.357143,21.500000,13.250000,10.785714,7.785714,2.357143,16.642857,0.785714,0.214286,0,0,0,0,0,0
2,2005,1102,Air Force,mwc,112.314462,84.200194,28.500000,147.500000,95.055556,61.241379,0.068966,21.413793,47.724138,44.869942,8.896552,24.413793,36.440678,9.517241,13.068966,72.823219,7.620690,15.413793,23.034483,13.655172,10.068966,8.793103,1.724138,16.482759,0.586207,0.413793,0,0,0,0,0,0
3,2006,1102,Air Force,mwc,119.374315,49.085155,26.333333,87.500000,52.514286,63.500000,0.000000,21.928571,45.571429,48.119122,8.607143,21.678571,39.703460,11.035714,14.785714,74.637681,6.785714,17.785714,24.571429,14.178571,10.928571,8.000000,2.107143,15.392857,0.785714,0.214286,0,0,0,0,0,0
4,2007,1102,Air Force,mwc,91.564489,18.264046,5.500000,41.250000,41.250000,68.500000,0.000000,22.800000,46.933333,48.579545,9.066667,22.466667,40.356083,13.833333,18.200000,76.007326,6.466667,20.866667,27.333333,15.033333,10.166667,6.266667,1.466667,15.166667,0.733333,0.266667,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6520,2021,1467,Merrimack,nec,247.213668,246.231401,206.255319,297.974359,239.530612,66.277778,0.055556,24.111111,53.722222,44.881075,7.000000,22.111111,31.658291,11.055556,16.666667,66.333333,5.777778,22.111111,27.888889,12.000000,13.000000,8.000000,3.777778,7.666667,0.500000,0.500000,0,0,0,0,0,0
6521,2021,1468,Bellarmine,a_sun,244.005046,225.029268,155.375000,302.111111,180.729167,72.555556,0.000000,27.277778,54.500000,50.050968,6.333333,17.055556,37.133550,11.666667,15.500000,75.268817,6.722222,19.833333,26.555556,15.222222,10.833333,6.611111,2.055556,2.666667,0.611111,0.388889,0,0,0,0,0,0
6522,2021,1469,Dixie St,wac,245.284882,307.822061,276.800000,333.666667,314.893617,67.631579,0.000000,23.368421,57.842105,40.400364,5.842105,19.684211,29.679144,15.052632,20.473684,73.521851,7.684211,23.157895,30.842105,14.421053,15.263158,6.736842,2.052632,4.315789,0.315789,0.684211,0,0,0,0,0,0
6523,2021,1470,Tarleton St,wac,245.284882,296.842426,254.250000,314.210526,254.361702,63.866667,0.000000,22.000000,50.866667,43.250328,5.466667,14.600000,37.442922,14.400000,19.400000,74.226804,4.400000,20.133333,24.533333,11.333333,10.933333,6.733333,2.266667,6.733333,0.333333,0.666667,0,0,0,0,0,0


#### Training Data (?)

In [6]:
a_map = {
    'Season': 'Season', 'TeamID': 'ATeamID', 'Seed': 'ASeed', 'TeamName': 'ATeamName', 'ConfAbbrev': 'AConfAbbrev', 'ConfAverageOrdinalRank': 'AConfAverageOrdinalRank',
    'AverageOrdinalRank': 'AAverageOrdinalRank', 'MinOrdinalRank': 'AMinOrdinalRank', 'MaxOrdinalRank': 'AMaxOrdinalRank', 
    'LatestOrdinalRank': 'ALatestOrdinalRank', 'Score': 'AScore', 'NumOT': 'ANumOT', 'FGM': 'AFGM', 'FGA': 'AFGA', 
    'FG%': 'AFG%', 'FGM3': 'FGM3', 'FGA3': 'AFGA3', 'FG3%': 'AFG3%', 'FTM': 'AFTM', 'FTA': 'AFTA', 'FT%': 'AFT%', 
    'OR': 'AOR', 'DR': 'ADR', 'TR': 'ATR', 'Ast': 'AAst', 'TO': 'ATO', 'Stl': 'AStl', 'Blk': 'ABlk', 'PF': 'APF',
    'W': 'AW', 'L': 'AL', 'WH2H': 'AWH2H', 'LH2H': 'ALH2H', 'WCommonOpp': 'AWCommonOpp', 'LCommonOpp': 'ALCommonOpp', 
    'WLastK': 'AWLastK', 'LLastK': 'ALLastK',
}
b_map = {
    'Season': 'Season', 'TeamID': 'BTeamID', 'Seed': 'BSeed', 'TeamName': 'BTeamName', 'ConfAbbrev': 'BConfAbbrev', 'ConfAverageOrdinalRank': 'BConfAverageOrdinalRank',
    'AverageOrdinalRank': 'BAverageOrdinalRank', 'MinOrdinalRank': 'BMinOrdinalRank', 'MaxOrdinalRank': 'BMaxOrdinalRank', 
    'LatestOrdinalRank': 'BLatestOrdinalRank', 'Score': 'BScore', 'NumOT': 'BNumOT', 'FGM': 'BFGM', 'FGA': 'BFGA', 
    'FG%': 'BFG%', 'FGM3': 'BFGM3', 'FGA3': 'BFGA3', 'FG3%': 'BFG3%', 'FTM': 'BFTM', 'FTA': 'BFTA', 'FT%': 'BFT%', 
    'OR': 'BOR', 'DR': 'BDR', 'TR': 'BTR', 'Ast': 'BAst', 'TO': 'BTO', 'Stl': 'BStl', 'Blk': 'BBlk', 'PF': 'BPF',
    'W': 'BW', 'L': 'BL', 'WH2H': 'BWH2H', 'LH2H': 'BLH2H', 'WCommonOpp': 'BWCommonOpp', 'LCommonOpp': 'BLCommonOpp', 
    'WLastK': 'BWLastK', 'LLastK': 'BLLastK',
}

df_copy = add_seeds(df.copy()) # do last b/c merge removes non-tourney teams
full_df = df

df_a = df_copy.rename(columns=a_map)
df_b =df_copy.rename(columns=b_map)



In [7]:
df_a

Unnamed: 0,Season,ATeamID,ASeed,ATeamName,AConfAbbrev,AConfAverageOrdinalRank,AAverageOrdinalRank,AMinOrdinalRank,AMaxOrdinalRank,ALatestOrdinalRank,AScore,ANumOT,AFGM,AFGA,AFG%,FGM3,AFGA3,AFG3%,AFTM,AFTA,AFT%,AOR,ADR,ATR,AAst,ATO,AStl,ABlk,APF,AW,AL,AWH2H,ALH2H,AWCommonOpp,ALCommonOpp,AWLastK,ALLastK
0,2003,1328,1,Oklahoma,big_twelve,59.997570,16.628562,2.00000,48.750000,6.117647,71.166667,0.100000,25.266667,56.533333,44.693396,7.466667,18.966667,39.367311,13.166667,18.600000,70.788530,12.133333,24.966667,37.100000,14.166667,11.800000,6.933333,3.766667,18.600000,0.800000,0.200000,0,0,0,0,0,0
1,2003,1448,2,Wake Forest,acc,47.512493,11.048289,2.00000,23.000000,11.205882,78.413793,0.068966,26.137931,57.241379,45.662651,6.103448,17.724138,34.435798,20.034483,26.620690,75.259067,14.758621,26.931034,41.689655,14.586207,15.103448,6.413793,4.379310,18.482759,0.827586,0.172414,0,0,0,0,0,0
2,2003,1393,3,Syracuse,big_east,74.207905,30.175754,9.00000,120.000000,13.088235,80.103448,0.034483,29.241379,62.206897,47.006652,5.241379,15.862069,33.043478,16.379310,23.620690,69.343066,14.310345,26.896552,41.206897,14.965517,13.620690,8.310345,7.275862,16.586207,0.827586,0.172414,0,0,0,0,0,0
3,2003,1257,4,Louisville,cusa,98.071257,10.386676,1.00000,51.000000,10.029412,81.833333,0.066667,27.966667,60.833333,45.972603,8.433333,23.400000,36.039886,17.466667,25.100000,69.588313,13.200000,25.100000,38.300000,16.600000,13.366667,7.200000,4.733333,22.666667,0.800000,0.200000,0,0,0,0,0,0
4,2003,1280,5,Mississippi St,sec,49.649729,27.019024,7.00000,133.000000,16.647059,70.166667,0.066667,26.333333,55.800000,47.192354,5.566667,16.200000,34.362140,11.933333,17.633333,67.674858,12.666667,24.600000,37.266667,14.500000,15.633333,8.733333,3.733333,15.900000,0.700000,0.300000,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1195,2021,1457,12,Winthrop,big_south,257.627337,81.198636,66.36000,140.055556,66.360000,79.541667,0.000000,28.083333,60.958333,46.069720,8.458333,23.958333,35.304348,14.916667,21.750000,68.582375,11.208333,26.458333,37.666667,15.041667,13.666667,7.750000,2.375000,3.500000,0.958333,0.041667,0,0,0,0,0,0
1196,2021,1317,13,North Texas,cusa,166.621732,103.962250,74.72000,138.388889,74.720000,68.200000,0.000000,25.080000,53.280000,47.072072,7.400000,19.840000,37.298387,10.640000,14.480000,73.480663,7.280000,23.320000,30.600000,13.040000,12.560000,6.840000,3.000000,5.520000,0.640000,0.360000,0,0,0,0,0,0
1197,2021,1159,14,Colgate,patriot,204.460734,87.990402,38.84000,155.300000,38.840000,86.333333,0.066667,31.466667,63.000000,49.947090,9.066667,22.533333,40.236686,14.333333,19.600000,73.129252,8.533333,28.600000,37.133333,17.133333,9.733333,7.200000,3.400000,4.533333,0.933333,0.066667,0,0,0,0,0,0
1198,2021,1331,15,Oral Roberts,summit,237.369461,180.950166,151.34375,200.222222,159.920000,79.391304,0.000000,27.260870,61.000000,44.689950,10.869565,28.000000,38.819876,14.000000,16.956522,82.564103,6.565217,23.652174,30.217391,11.608696,11.173913,6.217391,4.086957,8.000000,0.565217,0.434783,0,0,0,0,0,0


In [8]:
df_b

Unnamed: 0,Season,BTeamID,BSeed,BTeamName,BConfAbbrev,BConfAverageOrdinalRank,BAverageOrdinalRank,BMinOrdinalRank,BMaxOrdinalRank,BLatestOrdinalRank,BScore,BNumOT,BFGM,BFGA,BFG%,BFGM3,BFGA3,BFG3%,BFTM,BFTA,BFT%,BOR,BDR,BTR,BAst,BTO,BStl,BBlk,BPF,BW,BL,BWH2H,BLH2H,BWCommonOpp,BLCommonOpp,BWLastK,BLLastK
0,2003,1328,1,Oklahoma,big_twelve,59.997570,16.628562,2.00000,48.750000,6.117647,71.166667,0.100000,25.266667,56.533333,44.693396,7.466667,18.966667,39.367311,13.166667,18.600000,70.788530,12.133333,24.966667,37.100000,14.166667,11.800000,6.933333,3.766667,18.600000,0.800000,0.200000,0,0,0,0,0,0
1,2003,1448,2,Wake Forest,acc,47.512493,11.048289,2.00000,23.000000,11.205882,78.413793,0.068966,26.137931,57.241379,45.662651,6.103448,17.724138,34.435798,20.034483,26.620690,75.259067,14.758621,26.931034,41.689655,14.586207,15.103448,6.413793,4.379310,18.482759,0.827586,0.172414,0,0,0,0,0,0
2,2003,1393,3,Syracuse,big_east,74.207905,30.175754,9.00000,120.000000,13.088235,80.103448,0.034483,29.241379,62.206897,47.006652,5.241379,15.862069,33.043478,16.379310,23.620690,69.343066,14.310345,26.896552,41.206897,14.965517,13.620690,8.310345,7.275862,16.586207,0.827586,0.172414,0,0,0,0,0,0
3,2003,1257,4,Louisville,cusa,98.071257,10.386676,1.00000,51.000000,10.029412,81.833333,0.066667,27.966667,60.833333,45.972603,8.433333,23.400000,36.039886,17.466667,25.100000,69.588313,13.200000,25.100000,38.300000,16.600000,13.366667,7.200000,4.733333,22.666667,0.800000,0.200000,0,0,0,0,0,0
4,2003,1280,5,Mississippi St,sec,49.649729,27.019024,7.00000,133.000000,16.647059,70.166667,0.066667,26.333333,55.800000,47.192354,5.566667,16.200000,34.362140,11.933333,17.633333,67.674858,12.666667,24.600000,37.266667,14.500000,15.633333,8.733333,3.733333,15.900000,0.700000,0.300000,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1195,2021,1457,12,Winthrop,big_south,257.627337,81.198636,66.36000,140.055556,66.360000,79.541667,0.000000,28.083333,60.958333,46.069720,8.458333,23.958333,35.304348,14.916667,21.750000,68.582375,11.208333,26.458333,37.666667,15.041667,13.666667,7.750000,2.375000,3.500000,0.958333,0.041667,0,0,0,0,0,0
1196,2021,1317,13,North Texas,cusa,166.621732,103.962250,74.72000,138.388889,74.720000,68.200000,0.000000,25.080000,53.280000,47.072072,7.400000,19.840000,37.298387,10.640000,14.480000,73.480663,7.280000,23.320000,30.600000,13.040000,12.560000,6.840000,3.000000,5.520000,0.640000,0.360000,0,0,0,0,0,0
1197,2021,1159,14,Colgate,patriot,204.460734,87.990402,38.84000,155.300000,38.840000,86.333333,0.066667,31.466667,63.000000,49.947090,9.066667,22.533333,40.236686,14.333333,19.600000,73.129252,8.533333,28.600000,37.133333,17.133333,9.733333,7.200000,3.400000,4.533333,0.933333,0.066667,0,0,0,0,0,0
1198,2021,1331,15,Oral Roberts,summit,237.369461,180.950166,151.34375,200.222222,159.920000,79.391304,0.000000,27.260870,61.000000,44.689950,10.869565,28.000000,38.819876,14.000000,16.956522,82.564103,6.565217,23.652174,30.217391,11.608696,11.173913,6.217391,4.086957,8.000000,0.565217,0.434783,0,0,0,0,0,0


```
tourney = pd.read_csv("data/MNCAATourneyCompactResults.csv")
tourney = tourney.loc[tourney['Season'] >= 2003] # seasons w/ ranking data
tourney = tourney[['Season','WTeamID','LTeamID']]

generalize = {'WTeamID': 'ATeamID', 'LTeamID': 'BTeamID'}
tourney.rename(columns=generalize, inplace=True)
tourney['AWin?'] = 1

mirror_tourney = tourney.copy() # copy and flip teams for more data
mirror_tourney[['BTeamID','ATeamID']] = mirror_tourney[['ATeamID','BTeamID']]
mirror_tourney['AWin?'] = 0


tourney_matchups = pd.concat([tourney, mirror_tourney]).sort_index()
tourney_matchups

tourney_matchups = tourney_matchups.merge(df_a, on=['Season', 'ATeamID']).merge(df_b, on=['Season', 'BTeamID'])

# SUPERRRRR SLOWWWWWWWWWWWWW (~20+ mins i think, so save df after this to csv to avoid re-runnning)
# Now already exists in 'data/FinalDF.csv', no need to run this cell


tourney_matchups = populate_record_columns(tourney_matchups)
tourney_matchups.to_csv('data/TournamentFullDF.csv', index=False)
```


#### Starting models

In [36]:
import numpy as np

from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB, BernoulliNB
from sklearn.svm import SVC
from sklearn.neural_network import MLPClassifier
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import RBF
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis

from sklearn.metrics import accuracy_score

# ------------------------------------------------------------------------------------------------------------------
def train_test_split_by_year(df, year):
    train = df[df.Season != year]
    test  = df[df.Season == year]
    
    train = train.drop(columns=['Season', 'ATeamID', 'BTeamID', 'ATeamName', 'BTeamName', 'AConfAbbrev', 'BConfAbbrev'])
    test  = test.drop(columns=['Season', 'ATeamID', 'BTeamID', 'ATeamName', 'BTeamName', 'AConfAbbrev', 'BConfAbbrev'])
   
    # in form (X_train, X_test, y_train, y_test)
    X_train, X_test = train.loc[:, train.columns != 'AWin?'], test.loc[:, test.columns != 'AWin?']
    y_train, y_test = train.loc[:, train.columns == 'AWin?'], test.loc[:, test.columns == 'AWin?']
    return [X_train, X_test, y_train, y_test]

def tourney_teams_by_year(year):
    seeds = pd.read_csv('data/MNCAATourneySeeds.csv')
    return (seeds.loc[(seeds['Season'] == YEAR)]['TeamID'].unique())

def get_matchups_and_test_data(YEAR, df_col_names, df_a, df_b):
    teams = tourney_teams_by_year(YEAR)
    rows  = []
    matchups = []
    for t1 in teams:
        for t2 in teams:
            if (t1 != t2) and (t1 < t2):
                # set team_a id to be smaller of two teams
                team_a_id = t1
                team_b_id = t2 

                a = df_a.loc[(df_a['Season'] == YEAR) & (df_a['ATeamID'] == team_a_id)].reset_index(drop=True)
                b = df_b.loc[(df_b['Season'] == YEAR) & (df_b['BTeamID'] == team_b_id)].reset_index(drop=True)
                
                merge_df = pd.DataFrame([[YEAR, team_a_id, team_b_id]], columns=['Season', 'ATeamID', 'BTeamID'])
                merge_df = merge_df.merge(a, on=['Season', 'ATeamID']).merge(b, on=['Season', 'BTeamID'])

                rows.append(merge_df.values[0].tolist())
                matchups.append([YEAR, team_a_id, team_b_id])

    test_df     = pd.DataFrame(rows, columns=[x for x in df_col_names if x != 'AWin?'])
    matchups_df = pd.DataFrame(matchups, columns=['Season', 'ATeamID', 'BTeamID'])
    return matchups_df, test_df

def chalk_prediction(df):
    # results_as_bool = df['ASeed'] < df['BSeed']
    # return [1 if x else 0 for x in results_as_bool]
    
    # calculate win probability based on seed. pr(W_a) = (a_seed) / ((a_seed) + (b_seed))
    # example: a=8, b=9; pr(W_a) = 9/17 (53%)
    #          a=1, b=16; pr(W_a) = 16/17 (94%)
    #          a=2, b=10; pr(W_a) = 10/12 (83%)
    return (df['BSeed'] / (df['ASeed'] + df['BSeed']))

def write_submission(df, predictions, filename, year):
    df['ID']   = df.apply(lambda x: f'{int(x.Season)}_{(x.ATeamID)}_{x.BTeamID}', 1)
    df['Pred'] = predictions
    df.to_csv(f'submissions/{year}/{filename}.csv', columns=['ID', 'Pred'], index=False)
    return df
# ------------------------------------------------------------------------------------------------------------------


In [11]:
YEAR = range(2003, 2019+1)
data = pd.read_csv('data/TournamentFullDF.csv')

for year in YEAR: 
    X_train, X_test, y_train, y_test = train_test_split_by_year(data, year)
    
    gnb = GaussianNB() # (DONE)
    bnb = BernoulliNB() # (DONE)
    mlp = MLPClassifier(alpha=1, max_iter=1000) # (DONE)
    qda = QuadraticDiscriminantAnalysis() # (DONE)
    neigh = KNeighborsClassifier(n_neighbors=100, weights='distance') # (DONE)
    gauss = GaussianProcessClassifier(1.0 * RBF(1.0)) # (DONE)
    dtree = DecisionTreeClassifier(max_depth=5) # (DONE)
    adaboost = AdaBoostClassifier() # (DONE)
    rforest = RandomForestClassifier(max_depth=5, n_estimators=10, max_features=1)
    svc_linear = SVC(kernel="linear", C=0.025)
    svc_gamma  = SVC(gamma=2, C=1)

    types  = ['GNB', 'BNB', 'MLP', 'QDA', 'KNN', 'GAUSS', 'DTREE', 'ADABOOST', 'RF', 'SVC_LIN', 'SVC_GAMMA']
    models = [gnb, bnb, mlp, qda, neigh, gauss, dtree, adaboost, rforest, svc_linear, svc_gamma]
    
    chalk_score, chalk_count = accuracy_score(chalk_prediction(X_test), y_test), accuracy_score(chalk_prediction(X_test), y_test, normalize=False)
    scores = []
    counts = []
    
    print(year)
    print(f'CHALK: {round(chalk_score, 2)} ({chalk_count} / {len(y_test)})')
    
    for i,m in enumerate(models):
        
        m.fit(X_train, np.ravel(y_train));
        score, count = accuracy_score(m.predict(X_test), y_test), accuracy_score(m.predict(X_test), y_test, normalize=False)
        print(f'{types[i]}: {round(score, 2)} ({count} / {len(y_test)}) | ({count - chalk_count})')
    
        scores.append(score)
        counts.append(count)
    
    index  = counts.index(max(counts))
    print(f'BEST: {types[index]}')
    print()

2003
CHALK: 0.66 (85 / 128)
GNB: 0.73 (94 / 128) | (9)
BNB: 0.58 (74 / 128) | (-11)
MLP: 0.69 (88 / 128) | (3)
QDA: 0.62 (79 / 128) | (-6)
KNN: 0.66 (84 / 128) | (-1)




GAUSS: 0.57 (73 / 128) | (-12)
DTREE: 0.64 (82 / 128) | (-3)
ADABOOST: 0.66 (84 / 128) | (-1)
RF: 0.69 (88 / 128) | (3)
SVC_LIN: 0.7 (90 / 128) | (5)
SVC_GAMMA: 0.5 (64 / 128) | (-21)
BEST: GNB

2004
CHALK: 0.74 (95 / 128)
GNB: 0.72 (92 / 128) | (-3)
BNB: 0.58 (74 / 128) | (-21)
MLP: 0.7 (89 / 128) | (-6)
QDA: 0.7 (89 / 128) | (-6)
KNN: 0.72 (92 / 128) | (-3)




GAUSS: 0.57 (73 / 128) | (-22)
DTREE: 0.66 (85 / 128) | (-10)
ADABOOST: 0.73 (93 / 128) | (-2)
RF: 0.7 (89 / 128) | (-6)
SVC_LIN: 0.7 (90 / 128) | (-5)
SVC_GAMMA: 0.5 (64 / 128) | (-31)
BEST: ADABOOST

2005
CHALK: 0.69 (88 / 128)
GNB: 0.73 (94 / 128) | (6)
BNB: 0.62 (80 / 128) | (-8)
MLP: 0.57 (73 / 128) | (-15)
QDA: 0.55 (71 / 128) | (-17)
KNN: 0.67 (86 / 128) | (-2)




GAUSS: 0.52 (67 / 128) | (-21)
DTREE: 0.68 (87 / 128) | (-1)
ADABOOST: 0.68 (87 / 128) | (-1)
RF: 0.7 (89 / 128) | (1)
SVC_LIN: 0.69 (88 / 128) | (0)
SVC_GAMMA: 0.5 (64 / 128) | (-24)
BEST: GNB

2006
CHALK: 0.66 (85 / 128)
GNB: 0.72 (92 / 128) | (7)
BNB: 0.59 (76 / 128) | (-9)
MLP: 0.69 (88 / 128) | (3)
QDA: 0.7 (89 / 128) | (4)
KNN: 0.69 (88 / 128) | (3)




GAUSS: 0.57 (73 / 128) | (-12)
DTREE: 0.64 (82 / 128) | (-3)
ADABOOST: 0.63 (81 / 128) | (-4)
RF: 0.69 (88 / 128) | (3)
SVC_LIN: 0.73 (94 / 128) | (9)
SVC_GAMMA: 0.5 (64 / 128) | (-21)
BEST: SVC_LIN

2007
CHALK: 0.8 (102 / 128)
GNB: 0.8 (102 / 128) | (0)
BNB: 0.68 (87 / 128) | (-15)
MLP: 0.79 (101 / 128) | (-1)
QDA: 0.71 (91 / 128) | (-11)
KNN: 0.72 (92 / 128) | (-10)




GAUSS: 0.66 (84 / 128) | (-18)
DTREE: 0.68 (87 / 128) | (-15)
ADABOOST: 0.75 (96 / 128) | (-6)
RF: 0.73 (94 / 128) | (-8)
SVC_LIN: 0.8 (102 / 128) | (0)
SVC_GAMMA: 0.5 (64 / 128) | (-38)
BEST: GNB

2008
CHALK: 0.77 (98 / 128)
GNB: 0.77 (98 / 128) | (0)
BNB: 0.66 (85 / 128) | (-13)
MLP: 0.7 (90 / 128) | (-8)
QDA: 0.63 (81 / 128) | (-17)
KNN: 0.75 (96 / 128) | (-2)




GAUSS: 0.62 (80 / 128) | (-18)
DTREE: 0.76 (97 / 128) | (-1)
ADABOOST: 0.77 (99 / 128) | (1)
RF: 0.69 (88 / 128) | (-10)
SVC_LIN: 0.72 (92 / 128) | (-6)
SVC_GAMMA: 0.5 (64 / 128) | (-34)
BEST: ADABOOST

2009
CHALK: 0.74 (95 / 128)
GNB: 0.77 (98 / 128) | (3)
BNB: 0.58 (74 / 128) | (-21)
MLP: 0.73 (94 / 128) | (-1)
QDA: 0.66 (85 / 128) | (-10)
KNN: 0.81 (104 / 128) | (9)




GAUSS: 0.57 (73 / 128) | (-22)
DTREE: 0.74 (95 / 128) | (0)
ADABOOST: 0.75 (96 / 128) | (1)
RF: 0.7 (90 / 128) | (-5)
SVC_LIN: 0.75 (96 / 128) | (1)
SVC_GAMMA: 0.49 (63 / 128) | (-32)
BEST: KNN

2010
CHALK: 0.67 (86 / 128)
GNB: 0.7 (90 / 128) | (4)
BNB: 0.62 (79 / 128) | (-7)
MLP: 0.65 (83 / 128) | (-3)
QDA: 0.58 (74 / 128) | (-12)
KNN: 0.72 (92 / 128) | (6)




GAUSS: 0.53 (68 / 128) | (-18)
DTREE: 0.64 (82 / 128) | (-4)
ADABOOST: 0.66 (84 / 128) | (-2)
RF: 0.66 (85 / 128) | (-1)
SVC_LIN: 0.64 (82 / 128) | (-4)
SVC_GAMMA: 0.49 (63 / 128) | (-23)
BEST: KNN

2011
CHALK: 0.67 (90 / 134)
GNB: 0.67 (90 / 134) | (0)
BNB: 0.52 (70 / 134) | (-20)
MLP: 0.64 (86 / 134) | (-4)
QDA: 0.5 (67 / 134) | (-23)
KNN: 0.61 (82 / 134) | (-8)




GAUSS: 0.53 (71 / 134) | (-19)
DTREE: 0.61 (82 / 134) | (-8)
ADABOOST: 0.65 (87 / 134) | (-3)
RF: 0.64 (86 / 134) | (-4)
SVC_LIN: 0.67 (90 / 134) | (0)
SVC_GAMMA: 0.5 (67 / 134) | (-23)
BEST: GNB

2012
CHALK: 0.71 (95 / 134)
GNB: 0.67 (90 / 134) | (-5)
BNB: 0.67 (90 / 134) | (-5)
MLP: 0.69 (93 / 134) | (-2)
QDA: 0.67 (90 / 134) | (-5)
KNN: 0.69 (92 / 134) | (-3)




GAUSS: 0.63 (85 / 134) | (-10)
DTREE: 0.69 (92 / 134) | (-3)
ADABOOST: 0.71 (95 / 134) | (0)
RF: 0.71 (95 / 134) | (0)
SVC_LIN: 0.64 (86 / 134) | (-9)
SVC_GAMMA: 0.5 (67 / 134) | (-28)
BEST: ADABOOST

2013
CHALK: 0.66 (89 / 134)
GNB: 0.67 (90 / 134) | (1)
BNB: 0.57 (77 / 134) | (-12)
MLP: 0.69 (93 / 134) | (4)
QDA: 0.57 (76 / 134) | (-13)
KNN: 0.7 (94 / 134) | (5)




GAUSS: 0.55 (74 / 134) | (-15)
DTREE: 0.66 (89 / 134) | (0)
ADABOOST: 0.72 (96 / 134) | (7)
RF: 0.64 (86 / 134) | (-3)
SVC_LIN: 0.64 (86 / 134) | (-3)
SVC_GAMMA: 0.49 (66 / 134) | (-23)
BEST: ADABOOST

2014
CHALK: 0.64 (86 / 134)
GNB: 0.69 (92 / 134) | (6)
BNB: 0.6 (81 / 134) | (-5)
MLP: 0.64 (86 / 134) | (0)
QDA: 0.59 (79 / 134) | (-7)
KNN: 0.66 (88 / 134) | (2)




GAUSS: 0.46 (62 / 134) | (-24)
DTREE: 0.6 (81 / 134) | (-5)
ADABOOST: 0.65 (87 / 134) | (1)
RF: 0.61 (82 / 134) | (-4)
SVC_LIN: 0.73 (98 / 134) | (12)
SVC_GAMMA: 0.5 (67 / 134) | (-19)
BEST: SVC_LIN

2015
CHALK: 0.78 (104 / 134)
GNB: 0.69 (92 / 134) | (-12)
BNB: 0.56 (75 / 134) | (-29)
MLP: 0.71 (95 / 134) | (-9)
QDA: 0.61 (82 / 134) | (-22)
KNN: 0.76 (102 / 134) | (-2)




GAUSS: 0.57 (77 / 134) | (-27)
DTREE: 0.69 (92 / 134) | (-12)
ADABOOST: 0.74 (99 / 134) | (-5)
RF: 0.75 (100 / 134) | (-4)
SVC_LIN: 0.72 (96 / 134) | (-8)
SVC_GAMMA: 0.5 (67 / 134) | (-37)
BEST: KNN

2016
CHALK: 0.66 (89 / 134)
GNB: 0.67 (90 / 134) | (1)
BNB: 0.57 (76 / 134) | (-13)
MLP: 0.74 (99 / 134) | (10)
QDA: 0.63 (84 / 134) | (-5)
KNN: 0.66 (88 / 134) | (-1)




GAUSS: 0.55 (74 / 134) | (-15)
DTREE: 0.66 (88 / 134) | (-1)
ADABOOST: 0.72 (96 / 134) | (7)
RF: 0.72 (97 / 134) | (8)
SVC_LIN: 0.69 (92 / 134) | (3)
SVC_GAMMA: 0.49 (66 / 134) | (-23)
BEST: MLP

2017
CHALK: 0.75 (101 / 134)
GNB: 0.69 (92 / 134) | (-9)
BNB: 0.57 (76 / 134) | (-25)
MLP: 0.72 (97 / 134) | (-4)
QDA: 0.68 (91 / 134) | (-10)
KNN: 0.75 (100 / 134) | (-1)




GAUSS: 0.5 (67 / 134) | (-34)
DTREE: 0.68 (91 / 134) | (-10)
ADABOOST: 0.67 (90 / 134) | (-11)
RF: 0.68 (91 / 134) | (-10)
SVC_LIN: 0.66 (88 / 134) | (-13)
SVC_GAMMA: 0.5 (67 / 134) | (-34)
BEST: KNN

2018
CHALK: 0.66 (89 / 134)
GNB: 0.69 (92 / 134) | (3)
BNB: 0.54 (72 / 134) | (-17)
MLP: 0.69 (92 / 134) | (3)
QDA: 0.64 (86 / 134) | (-3)
KNN: 0.67 (90 / 134) | (1)




GAUSS: 0.6 (80 / 134) | (-9)
DTREE: 0.66 (89 / 134) | (0)
ADABOOST: 0.66 (89 / 134) | (0)
RF: 0.67 (90 / 134) | (1)
SVC_LIN: 0.73 (98 / 134) | (9)
SVC_GAMMA: 0.5 (67 / 134) | (-22)
BEST: SVC_LIN

2019
CHALK: 0.67 (90 / 134)
GNB: 0.69 (92 / 134) | (2)
BNB: 0.49 (65 / 134) | (-25)
MLP: 0.72 (96 / 134) | (6)
QDA: 0.63 (84 / 134) | (-6)
KNN: 0.72 (96 / 134) | (6)




GAUSS: 0.56 (75 / 134) | (-15)
DTREE: 0.7 (94 / 134) | (4)
ADABOOST: 0.72 (97 / 134) | (7)
RF: 0.72 (96 / 134) | (6)
SVC_LIN: 0.73 (98 / 134) | (8)
SVC_GAMMA: 0.5 (67 / 134) | (-23)
BEST: SVC_LIN



In [12]:
YEAR  = 2019
data  = pd.read_csv('data/TournamentFullDF.csv')
teams = tourney_teams_by_year(YEAR)
X_train, _, y_train, _ = train_test_split_by_year(data, YEAR)

gnb = GaussianNB()
neigh = KNeighborsClassifier(n_neighbors=100, weights='distance')

neigh.fit(X_train, np.ravel(y_train));
gnb.fit(X_train, np.ravel(y_train));

# DONE ALREADY for 2019, 2021
# uncomment, rerun when changing year if files alread DNE
# matchups_df, test_df = get_matchups_and_test_data(YEAR, data.columns, df_a, df_b):
# test_df = populate_record_columns(test_df)
# test_df.to_csv(f'data/TournamentDF_{YEAR}.csv', index=False)
# matchups_df.to_csv(f'data/TournamentMatchupsDF_{YEAR}.csv')
test_df = pd.read_csv(f'data/TournamentDF_{YEAR}.csv')
matchups_df = pd.read_csv(f'data/TournamentMatchupsDF_{YEAR}.csv')


In [13]:
matchups_df

Unnamed: 0,Season,ATeamID,BTeamID
0,2019,1181,1277
1,2019,1181,1261
2,2019,1181,1439
3,2019,1181,1280
4,2019,1181,1268
...,...,...,...
2273,2019,1205,1234
2274,2019,1205,1388
2275,2019,1205,1332
2276,2019,1205,1414


In [14]:
test_df

Unnamed: 0,Season,ATeamID,BTeamID,ASeed,ATeamName,AConfAbbrev,AConfAverageOrdinalRank,AAverageOrdinalRank,AMinOrdinalRank,AMaxOrdinalRank,ALatestOrdinalRank,AScore,ANumOT,AFGM,AFGA,AFG%,FGM3,AFGA3,AFG3%,AFTM,AFTA,AFT%,AOR,ADR,ATR,AAst,ATO,AStl,ABlk,APF,AW,AL,AWH2H,ALH2H,AWCommonOpp,ALCommonOpp,AWLastK,ALLastK,BSeed,BTeamName,BConfAbbrev,BConfAverageOrdinalRank,BAverageOrdinalRank,BMinOrdinalRank,BMaxOrdinalRank,BLatestOrdinalRank,BScore,BNumOT,BFGM,BFGA,BFG%,BFGM3,BFGA3,BFG3%,BFTM,BFTA,BFT%,BOR,BDR,BTR,BAst,BTO,BStl,BBlk,BPF,BW,BL,BWH2H,BLH2H,BWCommonOpp,BLCommonOpp,BWLastK,BLLastK
0,2019,1181,1277,1,Duke,acc,53.825731,2.094753,1.117647,3.600000,2.00000,83.500000,0.0,30.500000,63.911765,47.722043,7.264706,24.058824,30.195599,15.235294,22.088235,68.974700,13.382353,28.382353,41.764706,15.911765,13.117647,9.470588,6.823529,15.852941,0.852941,0.147059,0.0,0.0,2.0,0.0,7.0,3.0,2,Michigan St,big_ten,43.378573,7.003422,4.000000,12.138889,5.149254,78.823529,0.000000,27.764706,57.176471,48.559671,8.294118,21.617647,38.367347,15.000000,20.000000,75.000000,10.823529,30.088235,40.911765,18.941176,12.852941,5.235294,5.470588,16.911765,0.823529,0.176471,0.0,0.0,0.0,3.0,9.0,1.0
1,2019,1181,1261,1,Duke,acc,53.825731,2.094753,1.117647,3.600000,2.00000,83.500000,0.0,30.500000,63.911765,47.722043,7.264706,24.058824,30.195599,15.235294,22.088235,68.974700,13.382353,28.382353,41.764706,15.911765,13.117647,9.470588,6.823529,15.852941,0.852941,0.147059,0.0,0.0,4.0,0.0,7.0,3.0,3,LSU,sec,58.269956,31.006135,14.238806,55.555556,16.089552,81.375000,0.000000,28.500000,61.718750,46.177215,6.843750,21.187500,32.300885,17.531250,23.250000,75.403226,13.500000,25.281250,38.781250,13.125000,13.218750,9.093750,4.312500,18.687500,0.812500,0.187500,0.0,0.0,2.0,1.0,8.0,2.0
2,2019,1181,1439,1,Duke,acc,53.825731,2.094753,1.117647,3.600000,2.00000,83.500000,0.0,30.500000,63.911765,47.722043,7.264706,24.058824,30.195599,15.235294,22.088235,68.974700,13.382353,28.382353,41.764706,15.911765,13.117647,9.470588,6.823529,15.852941,0.852941,0.147059,0.0,1.0,17.0,3.0,7.0,3.0,4,Virginia Tech,acc,53.825731,14.353912,10.037037,24.850000,15.925373,74.000000,0.031250,25.750000,54.156250,47.547605,9.593750,24.343750,39.409499,12.906250,17.031250,75.779817,8.906250,23.687500,32.593750,15.343750,11.375000,6.656250,2.312500,15.281250,0.750000,0.250000,1.0,0.0,12.0,7.0,6.0,4.0
3,2019,1181,1280,1,Duke,acc,53.825731,2.094753,1.117647,3.600000,2.00000,83.500000,0.0,30.500000,63.911765,47.722043,7.264706,24.058824,30.195599,15.235294,22.088235,68.974700,13.382353,28.382353,41.764706,15.911765,13.117647,9.470588,6.823529,15.852941,0.852941,0.147059,0.0,0.0,4.0,0.0,7.0,3.0,5,Mississippi St,sec,58.269956,25.131510,17.400000,33.666667,22.500000,77.333333,0.000000,27.606061,58.484848,47.202073,8.606061,22.787879,37.765957,13.515152,18.909091,71.474359,11.666667,24.121212,35.787879,14.242424,13.303030,8.090909,4.969697,17.090909,0.696970,0.303030,0.0,0.0,3.0,3.0,7.0,3.0
4,2019,1181,1268,1,Duke,acc,53.825731,2.094753,1.117647,3.600000,2.00000,83.500000,0.0,30.500000,63.911765,47.722043,7.264706,24.058824,30.195599,15.235294,22.088235,68.974700,13.382353,28.382353,41.764706,15.911765,13.117647,9.470588,6.823529,15.852941,0.852941,0.147059,0.0,0.0,3.0,0.0,7.0,3.0,6,Maryland,big_ten,43.378573,27.773058,13.868852,41.684211,26.140625,71.343750,0.000000,25.218750,55.531250,45.413618,7.250000,20.531250,35.312024,13.656250,18.250000,74.828767,10.968750,28.187500,39.156250,13.218750,13.187500,4.343750,4.750000,15.468750,0.687500,0.312500,0.0,0.0,1.0,1.0,5.0,5.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2273,2019,1205,1234,16,Gardner Webb,big_south,234.636352,186.270108,150.068182,253.111111,163.65625,75.774194,0.0,26.290323,55.032258,47.772567,7.741935,20.548387,37.676609,15.451613,21.580645,71.599402,7.129032,25.741935,32.870968,13.612903,11.774194,6.225806,2.483871,16.193548,0.645161,0.354839,0.0,0.0,2.0,0.0,8.0,2.0,10,Iowa,big_ten,43.378573,35.218203,23.639344,72.222222,39.812500,78.303030,0.000000,26.121212,57.393939,45.512144,8.090909,22.393939,36.129905,17.969697,24.272727,74.032459,10.333333,25.393939,35.727273,15.727273,12.151515,6.181818,3.272727,16.030303,0.666667,0.333333,0.0,0.0,1.0,0.0,4.0,6.0
2274,2019,1205,1388,16,Gardner Webb,big_south,234.636352,186.270108,150.068182,253.111111,163.65625,75.774194,0.0,26.290323,55.032258,47.772567,7.741935,20.548387,37.676609,15.451613,21.580645,71.599402,7.129032,25.741935,32.870968,13.612903,11.774194,6.225806,2.483871,16.193548,0.645161,0.354839,0.0,0.0,0.0,0.0,8.0,2.0,11,St Mary's CA,wcc,131.835860,57.371453,34.166667,83.424242,40.984375,72.909091,0.000000,26.545455,56.060606,47.351351,7.424242,19.636364,37.808642,12.393939,16.636364,74.499089,9.818182,24.909091,34.727273,10.060606,10.575758,6.000000,2.545455,16.969697,0.666667,0.333333,0.0,0.0,0.0,0.0,8.0,2.0
2275,2019,1205,1332,16,Gardner Webb,big_south,234.636352,186.270108,150.068182,253.111111,163.65625,75.774194,0.0,26.290323,55.032258,47.772567,7.741935,20.548387,37.676609,15.451613,21.580645,71.599402,7.129032,25.741935,32.870968,13.612903,11.774194,6.225806,2.483871,16.193548,0.645161,0.354839,0.0,0.0,0.0,0.0,8.0,2.0,12,Oregon,pac_twelve,98.884306,58.080394,25.600000,78.682540,45.765625,70.485714,0.028571,25.314286,56.314286,44.951801,7.600000,22.142857,34.322581,12.257143,17.057143,71.859296,9.714286,24.714286,34.428571,13.342857,11.800000,7.771429,4.228571,17.771429,0.657143,0.342857,0.0,0.0,0.0,0.0,8.0,2.0
2276,2019,1205,1414,16,Gardner Webb,big_south,234.636352,186.270108,150.068182,253.111111,163.65625,75.774194,0.0,26.290323,55.032258,47.772567,7.741935,20.548387,37.676609,15.451613,21.580645,71.599402,7.129032,25.741935,32.870968,13.612903,11.774194,6.225806,2.483871,16.193548,0.645161,0.354839,0.0,0.0,0.0,0.0,8.0,2.0,13,UC Irvine,big_west,220.228381,94.347876,66.984375,109.848485,66.984375,72.470588,0.000000,26.941176,58.970588,45.685786,6.794118,18.882353,35.981308,11.794118,16.852941,69.982548,11.647059,28.382353,40.029412,13.235294,11.529412,5.500000,4.264706,18.500000,0.852941,0.147059,0.0,0.0,0.0,0.0,10.0,0.0


In [15]:
# sample of predicitons of probability of [Loss, Win] for 2019
# Looking at (1) Duke v: 
# (2) Michigan St (L)
# (3) LSU (-)
# (4) Villanova (-)
# (5) Mississippi St (-)
# (6) Maryland (-)
# (7) Lousiville (-)
# (8) VCU (-)
# (9) UCF (W)
# (10) Minnessota (-)
# (11) Temple (-)
# (12) Liberty (-)
# (13) St Louis (-)
# (14) Yale (-)
# (16) ND St (W)
# (16) NC Central (-)
_, X_test,_,_ = train_test_split_by_year(test_df, YEAR)


print('KNN')
print(neigh.predict_proba(X_test.iloc[0:15]))
print('\nGNB')
print(gnb.predict_proba(X_test.iloc[0:15]).round(3))

X_test.iloc[0:15]

KNN
[[0.47403709 0.52596291]
 [0.2560427  0.7439573 ]
 [0.37125414 0.62874586]
 [0.31310818 0.68689182]
 [0.25313034 0.74686966]
 [0.25504508 0.74495492]
 [0.28765225 0.71234775]
 [0.28205396 0.71794604]
 [0.2882469  0.7117531 ]
 [0.29811025 0.70188975]
 [0.16517808 0.83482192]
 [0.18640707 0.81359293]
 [0.1319736  0.8680264 ]
 [0.02778746 0.97221254]
 [0.02372182 0.97627818]]

GNB
[[0.045 0.955]
 [0.008 0.992]
 [1.    0.   ]
 [0.001 0.999]
 [0.    1.   ]
 [0.    1.   ]
 [0.    1.   ]
 [0.    1.   ]
 [0.    1.   ]
 [0.    1.   ]
 [0.    1.   ]
 [0.    1.   ]
 [0.    1.   ]
 [0.    1.   ]
 [0.    1.   ]]


Unnamed: 0,ASeed,AConfAverageOrdinalRank,AAverageOrdinalRank,AMinOrdinalRank,AMaxOrdinalRank,ALatestOrdinalRank,AScore,ANumOT,AFGM,AFGA,AFG%,FGM3,AFGA3,AFG3%,AFTM,AFTA,AFT%,AOR,ADR,ATR,AAst,ATO,AStl,ABlk,APF,AW,AL,AWH2H,ALH2H,AWCommonOpp,ALCommonOpp,AWLastK,ALLastK,BSeed,BConfAverageOrdinalRank,BAverageOrdinalRank,BMinOrdinalRank,BMaxOrdinalRank,BLatestOrdinalRank,BScore,BNumOT,BFGM,BFGA,BFG%,BFGM3,BFGA3,BFG3%,BFTM,BFTA,BFT%,BOR,BDR,BTR,BAst,BTO,BStl,BBlk,BPF,BW,BL,BWH2H,BLH2H,BWCommonOpp,BLCommonOpp,BWLastK,BLLastK
0,1,53.825731,2.094753,1.117647,3.6,2.0,83.5,0.0,30.5,63.911765,47.722043,7.264706,24.058824,30.195599,15.235294,22.088235,68.9747,13.382353,28.382353,41.764706,15.911765,13.117647,9.470588,6.823529,15.852941,0.852941,0.147059,0.0,0.0,2.0,0.0,7.0,3.0,2,43.378573,7.003422,4.0,12.138889,5.149254,78.823529,0.0,27.764706,57.176471,48.559671,8.294118,21.617647,38.367347,15.0,20.0,75.0,10.823529,30.088235,40.911765,18.941176,12.852941,5.235294,5.470588,16.911765,0.823529,0.176471,0.0,0.0,0.0,3.0,9.0,1.0
1,1,53.825731,2.094753,1.117647,3.6,2.0,83.5,0.0,30.5,63.911765,47.722043,7.264706,24.058824,30.195599,15.235294,22.088235,68.9747,13.382353,28.382353,41.764706,15.911765,13.117647,9.470588,6.823529,15.852941,0.852941,0.147059,0.0,0.0,4.0,0.0,7.0,3.0,3,58.269956,31.006135,14.238806,55.555556,16.089552,81.375,0.0,28.5,61.71875,46.177215,6.84375,21.1875,32.300885,17.53125,23.25,75.403226,13.5,25.28125,38.78125,13.125,13.21875,9.09375,4.3125,18.6875,0.8125,0.1875,0.0,0.0,2.0,1.0,8.0,2.0
2,1,53.825731,2.094753,1.117647,3.6,2.0,83.5,0.0,30.5,63.911765,47.722043,7.264706,24.058824,30.195599,15.235294,22.088235,68.9747,13.382353,28.382353,41.764706,15.911765,13.117647,9.470588,6.823529,15.852941,0.852941,0.147059,0.0,1.0,17.0,3.0,7.0,3.0,4,53.825731,14.353912,10.037037,24.85,15.925373,74.0,0.03125,25.75,54.15625,47.547605,9.59375,24.34375,39.409499,12.90625,17.03125,75.779817,8.90625,23.6875,32.59375,15.34375,11.375,6.65625,2.3125,15.28125,0.75,0.25,1.0,0.0,12.0,7.0,6.0,4.0
3,1,53.825731,2.094753,1.117647,3.6,2.0,83.5,0.0,30.5,63.911765,47.722043,7.264706,24.058824,30.195599,15.235294,22.088235,68.9747,13.382353,28.382353,41.764706,15.911765,13.117647,9.470588,6.823529,15.852941,0.852941,0.147059,0.0,0.0,4.0,0.0,7.0,3.0,5,58.269956,25.13151,17.4,33.666667,22.5,77.333333,0.0,27.606061,58.484848,47.202073,8.606061,22.787879,37.765957,13.515152,18.909091,71.474359,11.666667,24.121212,35.787879,14.242424,13.30303,8.090909,4.969697,17.090909,0.69697,0.30303,0.0,0.0,3.0,3.0,7.0,3.0
4,1,53.825731,2.094753,1.117647,3.6,2.0,83.5,0.0,30.5,63.911765,47.722043,7.264706,24.058824,30.195599,15.235294,22.088235,68.9747,13.382353,28.382353,41.764706,15.911765,13.117647,9.470588,6.823529,15.852941,0.852941,0.147059,0.0,0.0,3.0,0.0,7.0,3.0,6,43.378573,27.773058,13.868852,41.684211,26.140625,71.34375,0.0,25.21875,55.53125,45.413618,7.25,20.53125,35.312024,13.65625,18.25,74.828767,10.96875,28.1875,39.15625,13.21875,13.1875,4.34375,4.75,15.46875,0.6875,0.3125,0.0,0.0,1.0,1.0,5.0,5.0
5,1,53.825731,2.094753,1.117647,3.6,2.0,83.5,0.0,30.5,63.911765,47.722043,7.264706,24.058824,30.195599,15.235294,22.088235,68.9747,13.382353,28.382353,41.764706,15.911765,13.117647,9.470588,6.823529,15.852941,0.852941,0.147059,1.0,0.0,18.0,4.0,7.0,3.0,7,53.825731,29.106922,14.846154,61.888889,24.65625,74.545455,0.0,25.090909,57.818182,43.396226,8.636364,25.272727,34.172662,15.727273,20.30303,77.462687,10.242424,28.0,38.242424,13.393939,12.393939,4.454545,2.969697,17.272727,0.606061,0.393939,0.0,1.0,11.0,10.0,3.0,7.0
6,1,53.825731,2.094753,1.117647,3.6,2.0,83.5,0.0,30.5,63.911765,47.722043,7.264706,24.058824,30.195599,15.235294,22.088235,68.9747,13.382353,28.382353,41.764706,15.911765,13.117647,9.470588,6.823529,15.852941,0.852941,0.147059,0.0,0.0,3.0,0.0,7.0,3.0,8,155.835393,60.742105,32.0625,116.722222,36.828125,71.4375,0.0,25.03125,56.59375,44.229707,7.15625,23.28125,30.738255,14.21875,20.375,69.785276,10.84375,25.9375,36.78125,13.875,14.0,7.96875,4.53125,19.8125,0.78125,0.21875,0.0,0.0,0.0,2.0,9.0,1.0
7,1,53.825731,2.094753,1.117647,3.6,2.0,83.5,0.0,30.5,63.911765,47.722043,7.264706,24.058824,30.195599,15.235294,22.088235,68.9747,13.382353,28.382353,41.764706,15.911765,13.117647,9.470588,6.823529,15.852941,0.852941,0.147059,0.0,0.0,1.0,0.0,7.0,3.0,9,109.834451,45.860284,31.769231,97.611111,37.9375,72.129032,0.0,24.677419,53.290323,46.307506,6.806452,19.225806,35.402685,15.967742,24.741935,64.537158,9.806452,26.935484,36.741935,13.258065,11.903226,5.741935,4.483871,16.903226,0.741935,0.258065,0.0,0.0,1.0,0.0,7.0,3.0
8,1,53.825731,2.094753,1.117647,3.6,2.0,83.5,0.0,30.5,63.911765,47.722043,7.264706,24.058824,30.195599,15.235294,22.088235,68.9747,13.382353,28.382353,41.764706,15.911765,13.117647,9.470588,6.823529,15.852941,0.852941,0.147059,0.0,0.0,2.0,0.0,7.0,3.0,10,43.378573,53.257292,42.020833,75.888889,44.9375,70.794118,0.029412,24.705882,56.529412,43.704475,5.235294,16.294118,32.129964,16.147059,23.764706,67.945545,11.441176,25.588235,37.029412,14.882353,12.0,4.794118,4.0,16.176471,0.617647,0.382353,0.0,0.0,1.0,1.0,5.0,5.0
9,1,53.825731,2.094753,1.117647,3.6,2.0,83.5,0.0,30.5,63.911765,47.722043,7.264706,24.058824,30.195599,15.235294,22.088235,68.9747,13.382353,28.382353,41.764706,15.911765,13.117647,9.470588,6.823529,15.852941,0.852941,0.147059,0.0,0.0,0.0,0.0,7.0,3.0,11,109.834451,65.9643,53.061538,93.5,58.84375,74.84375,0.0,26.40625,60.21875,43.850545,7.5,22.6875,33.057851,14.53125,19.875,73.113208,9.71875,24.5625,34.28125,14.46875,11.15625,8.65625,2.25,17.46875,0.71875,0.28125,0.0,0.0,0.0,0.0,7.0,3.0


In [16]:
# DONE ALREADY... 
# write_submission(matchups_df.copy(), chalk_prediction(X_test), 'first_chalk_pred'):
# write_submission(matchups_df.copy(), [p_win for [p_loss,p_win] in neigh.predict_proba(X_test)], 'first_knn_pred'):
# write_submission(matchups_df.copy(), [p_win for [p_loss,p_win] in gnb.predict_proba(X_test)], 'first_gnb_pred'):


In [17]:
# write submission to gsheet as bracket
from gsheets import gsheets_bracket
# --------------------------------------------------------------------------------------------------------------------
# DONT CHANGE
BRACKET_KEY_2019 = '1zlAuBQPCesbe3Monu6GBUm8M6Sgy4h45_AawJaHeFWI'
BRACKET_KEY_2021 = '1NXHBYyCgwrbseRgOpfK7aOCSxU65-nYXsjL_rbXMtLs'
# --------------------------------------------------------------------------------------------------------------------

# be sure to run submissions that DNE already on sheet or delete sheet to re-run
# DONE ALREADY...
# print(gsheets_bracket.bracket_from_submission('first_chalk_pred', BRACKET_KEY_2019, 2019))
# print(gsheets_bracket.bracket_from_submission('first_knn_pred', bracket_key_2019, 2019))
# print(gsheets_bracket.bracket_from_submission('first_gnb_pred', bracket_key_2019, 2019))


##### Trying to make a better model

In [18]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from IPython.display import display
from ipywidgets import interact, interactive_output, fixed, Layout, Dropdown, HBox, ToggleButtons

# --------------------------- DISTRIBTUIONS ------------------------------------
def distribution_helper(x, y, data):
    '''
    Plots the distribution of training data against the (continuous) target 
    variable.
    @param x, y: column vector of data
    '''
    fig  = plt.figure(figsize=(20,15))

    # quantitative scatterplot
    sns.scatterplot(data=data, x=x, y=y, hue='AWin?')
                
    plt.xlabel(x, {'fontsize': 16})
    plt.ylabel(y, {'fontsize': 16})
    
    plt.show()
    
def interactive_distributions(data):
    '''
    Plots the distribution of training data against the (continuous) target 
    variable.
    @param data: dataframe
    '''
    f1 = Dropdown(options=data.columns, value='AAverageOrdinalRank', description='Feature-1')
    f2 = Dropdown(options=data.columns, value='BAverageOrdinalRank', description='Feature-2')
    
    ui = HBox([f1, f2])
    out = interactive_output(distribution_helper, {'x': f1, 'y': f2, 'data': fixed(data)})
    return display(ui, out)


In [19]:
train = pd.concat([X_train, y_train],axis=1)
interactive_distributions(train)
# not super helpful....

HBox(children=(Dropdown(description='Feature-1', index=2, options=('ASeed', 'AConfAverageOrdinalRank', 'AAvera…

Output()

In [158]:
# Can use this for spamming models...

# from gsheets import gsheets_bracket
# --------------------------------------------------------------------------------------------------------------------
# DONT CHANGE
BRACKET_KEY = {
    2019: '1zlAuBQPCesbe3Monu6GBUm8M6Sgy4h45_AawJaHeFWI',
    2021: '1NXHBYyCgwrbseRgOpfK7aOCSxU65-nYXsjL_rbXMtLs',
}
# --------------------------------------------------------------------------------------------------------------------

YEAR  = 2021 
data  = pd.read_csv('data/TournamentFullDF.csv')
teams = tourney_teams_by_year(YEAR)
X_train, _, y_train, _ = train_test_split_by_year(data, YEAR)

# DONE ALREADY for 2019, 2021
# uncomment, rerun when changing year if files alread DNE
# matchups_df, test_df = get_matchups_and_test_data(YEAR, data.columns, df_a, df_b):
# test_df = populate_record_columns(test_df)
# test_df.to_csv(f'data/TournamentDF_{YEAR}.csv', index=False)
# matchups_df.to_csv(f'data/TournamentMatchupsDF_{YEAR}.csv', index=False)

test_df = pd.read_csv(f'data/TournamentDF_{YEAR}.csv')
matchups_df = pd.read_csv(f'data/TournamentMatchupsDF_{YEAR}.csv')
_, X_test,_,_ = train_test_split_by_year(test_df, YEAR)


to_keep = [
    'ASeed', 'AConfAverageOrdinalRank', 'AAverageOrdinalRank', 'ALatestOrdinalRank', 
    'AFG%', 'AFG3%', 'AFT%', 'AW', 'AL', 'AWCommonOpp', 'ALCommonOpp', 
    'BSeed', 'BConfAverageOrdinalRank', 'BAverageOrdinalRank', 'BLatestOrdinalRank', 
    'BFG%', 'BFG3%', 'BFT%', 'AW', 'AL', 'BWCommonOpp', 'BLCommonOpp', 
]
X_train = X_train.drop(columns=[col for col in X_train if col not in to_keep])
X_test  = X_test.drop(columns=[col for col in X_test if col not in to_keep])



gnb = GaussianNB() # (DONE)
bnb = BernoulliNB() # (DONE)
mlp = MLPClassifier(alpha=1, max_iter=1000) # (DONE)
qda = QuadraticDiscriminantAnalysis() # (DONE)
neigh = KNeighborsClassifier(n_neighbors=5, weights='distance') # (DONE)
gauss = GaussianProcessClassifier(1.0 * RBF(1.0)) # (DONE)
dtree = DecisionTreeClassifier(max_depth=5) # (DONE)
adaboost = AdaBoostClassifier() # (DONE)
rforest = RandomForestClassifier(max_depth=5, n_estimators=10, max_features=1) # (DONE)
svc_linear = SVC(kernel="linear", C=0.025, probability=True) # (DONE)
svc_gamma  = SVC(gamma=2, C=1, probability=True) # (DONE)

models = [gnb, bnb, mlp, qda, gauss, dtree, adaboost, svc_linear, svc_gamma]
name   = ['gnb', 'bnb', 'mlp', 'qda', 'gauss', 'dtree', 'adaboost', 'svc_linear', 'svc_gamma']

# for i, model in enumerate(models):
model = neigh
model.fit(X_train, np.ravel(y_train));

# be sure name is unique
FILENAME = f'neigh_less_cols2_{YEAR}'
write_submission(matchups_df.copy(), [p_win for [p_loss,p_win] in model.predict_proba(X_test)], FILENAME)
print(gsheets_bracket.bracket_from_submission(FILENAME, BRACKET_KEY[YEAR], YEAR))


https://docs.google.com/spreadsheets/d/1NXHBYyCgwrbseRgOpfK7aOCSxU65-nYXsjL_rbXMtLs/edit#gid=4020804


In [None]:
# THOUGHT 4 NEXT YEAR
# maaybe instead of picking winner based on 50% prob threshold, 
# raise/lower threshold based on mathcup. for example 5 v 12, maybe 
# threshold for 12 to win is like 30-40% probability. 