In [1]:
import pandas as pd
import numpy as np

from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_validate

# Render DataFrames without truncation: show every column and every row.
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)

import timeit

In [2]:
# Load the cleaned play-by-play training data.
# The loaded frame carries an 'Unnamed: 0' column (visible in the train.head()
# output), which looks like an index that was written into the CSV. Left in
# place it ends up in the feature matrix as a meaningless row counter, so it is
# dropped here. errors='ignore' keeps the cell working if the file is later
# re-exported without that column. Expect one fewer column downstream.
train = pd.read_csv('../Data/train_clean.csv').drop(columns='Unnamed: 0', errors='ignore')

### Preprocessing for classification

In [3]:
def process_data(df):
    """Turn the raw play-by-play frame into an all-numeric frame for modelling.

    Columns are appended in this order:
      * one-hot indicators for ``posteam`` (``is_posteam_*``), ``home_team``
        (``home_*``) and ``away_team`` (``away_*``), each with the first
        category dropped;
      * ``posteam_is_home``: 1 where ``posteam_type`` is ``'home'``, else 0;
      * one-hot indicators for ``play_type`` (``pt_*``), first category dropped;
      * ``side_of_field_is_hometeam``: 1 where ``side_of_field`` equals
        ``home_team``, else 0.

    ``game_half`` is recoded to 1 for ``'Half1'`` and 2 for any other value.
    The text columns listed at the end of the function are removed.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain every column read above and every column in the final
        drop list.

    Returns
    -------
    pandas.DataFrame
        A new frame; the frame passed in is not modified.
    """
    # Disabled experiments, kept for reference: a binary 'succesful_play' label
    # (pass gaining >= 6.3 yards, or run gaining >= 4.4 yards) and a
    # 'game_setting' feature concatenating home_team and away_team.

    encoded = df
    team_prefixes = (
        ('posteam', 'is_posteam'),
        ('home_team', 'home'),
        ('away_team', 'away'),
    )
    # join() returns a new frame, so the caller's frame is never written to.
    for source_col, prefix in team_prefixes:
        indicators = pd.get_dummies(encoded[source_col], prefix=prefix, drop_first=True)
        encoded = encoded.join(indicators)

    at_home = encoded['posteam_type'] == 'home'
    encoded['posteam_is_home'] = np.where(at_home, 1, 0)

    encoded = encoded.join(
        pd.get_dummies(encoded['play_type'], prefix='pt', drop_first=True)
    )

    # Anything that is not 'Half1' is coded as 2.
    in_first_half = encoded['game_half'] == 'Half1'
    encoded['game_half'] = np.where(in_first_half, 1, 2)

    on_home_side = encoded['side_of_field'] == encoded['home_team']
    encoded['side_of_field_is_hometeam'] = np.where(on_home_side, 1, 0)

    # Text columns that have been encoded above or are not used as features.
    text_columns = [
        'home_team', 'away_team', 'posteam', 'posteam_type', 'defteam',
        'side_of_field', 'game_date', 'time', 'yrdln', 'desc', 'play_type',
    ]
    return encoded.drop(columns=text_columns)

In [4]:
# Encode the raw frame. This rebinds `train`, so the cell cannot be re-run on
# its own: process_data drops 'posteam', 'home_team', etc., and a second call
# would fail when it looks those columns up. Re-run the load cell first.
train = process_data(train)

In [5]:
# Preview the first rows of the encoded frame.
train.head()

Unnamed: 0,play_id,game_id,yardline_100,quarter_seconds_remaining,half_seconds_remaining,game_seconds_remaining,game_half,quarter_end,drive,sp,qtr,down,goal_to_go,ydstogo,ydsnet,yards_gained,shotgun,no_huddle,qb_dropback,qb_kneel,qb_spike,qb_scramble,home_timeouts_remaining,away_timeouts_remaining,timeout,posteam_timeouts_remaining,defteam_timeouts_remaining,total_home_score,total_away_score,posteam_score,defteam_score,score_differential,posteam_score_post,defteam_score_post,score_differential_post,no_score_prob,opp_fg_prob,opp_safety_prob,opp_td_prob,fg_prob,safety_prob,td_prob,extra_point_prob,two_point_conversion_prob,ep,epa,total_home_epa,total_away_epa,total_home_rush_epa,total_away_rush_epa,total_home_pass_epa,total_away_pass_epa,comp_air_epa,comp_yac_epa,total_home_comp_air_epa,total_away_comp_air_epa,total_home_comp_yac_epa,total_away_comp_yac_epa,total_home_raw_air_epa,total_away_raw_air_epa,total_home_raw_yac_epa,total_away_raw_yac_epa,wp,def_wp,home_wp,away_wp,wpa,home_wp_post,away_wp_post,total_home_rush_wpa,total_away_rush_wpa,total_home_pass_wpa,total_away_pass_wpa,comp_air_wpa,comp_yac_wpa,total_home_comp_air_wpa,total_away_comp_air_wpa,total_home_comp_yac_wpa,total_away_comp_yac_wpa,total_home_raw_air_wpa,total_away_raw_air_wpa,total_home_raw_yac_wpa,total_away_raw_yac_wpa,punt_blocked,first_down_rush,first_down_pass,first_down_penalty,third_down_converted,third_down_failed,fourth_down_converted,fourth_down_failed,incomplete_pass,interception,punt_inside_twenty,punt_in_endzone,punt_out_of_bounds,punt_downed,punt_fair_catch,kickoff_inside_twenty,kickoff_in_endzone,kickoff_out_of_bounds,kickoff_downed,kickoff_fair_catch,fumble_forced,fumble_not_forced,fumble_out_of_bounds,solo_tackle,safety,penalty,tackled_for_loss,fumble_lost,own_kickoff_recovery,own_kickoff_recovery_td,qb_hit,rush_attempt,pass_attempt,sack,touchdown,pass_touchdown,rush_touchdown,return_touchdown,extra_point_attempt,two_point_attempt,field_goal_attempt,kickoff_attempt,punt_attempt,fum
ble,complete_pass,assist_tackle,lateral_reception,lateral_rush,lateral_return,lateral_recovery,return_yards,replay_or_challenge,defensive_two_point_attempt,defensive_two_point_conv,defensive_extra_point_attempt,defensive_extra_point_conv,is_posteam_ARI,is_posteam_ATL,is_posteam_BAL,is_posteam_BUF,is_posteam_CAR,is_posteam_CHI,is_posteam_CIN,is_posteam_CLE,is_posteam_DAL,is_posteam_DEN,is_posteam_DET,is_posteam_GB,is_posteam_HOU,is_posteam_IND,is_posteam_JAC,is_posteam_KC,is_posteam_MIA,is_posteam_MIN,is_posteam_NE,is_posteam_NO,is_posteam_NYG,is_posteam_NYJ,is_posteam_OAK,is_posteam_PHI,is_posteam_PIT,is_posteam_SD,is_posteam_SEA,is_posteam_SF,is_posteam_STL,is_posteam_TB,is_posteam_TEN,is_posteam_WAS,home_ATL,home_BAL,home_BUF,home_CAR,home_CHI,home_CIN,home_CLE,home_DAL,home_DEN,home_DET,home_GB,home_HOU,home_IND,home_JAC,home_KC,home_MIA,home_MIN,home_NE,home_NO,home_NYG,home_NYJ,home_OAK,home_PHI,home_PIT,home_SD,home_SEA,home_SF,home_STL,home_TB,home_TEN,home_WAS,away_ATL,away_BAL,away_BUF,away_CAR,away_CHI,away_CIN,away_CLE,away_DAL,away_DEN,away_DET,away_GB,away_HOU,away_IND,away_JAC,away_KC,away_MIA,away_MIN,away_NE,away_NO,away_NYG,away_NYJ,away_OAK,away_PHI,away_PIT,away_SD,away_SEA,away_SF,away_STL,away_TB,away_TEN,away_WAS,posteam_is_home,pt_extra_point,pt_field_goal,pt_kickoff,pt_no_play,pt_pass,pt_punt,pt_qb_kneel,pt_qb_spike,pt_run,side_of_field_is_hometeam
0,46,2009091000,30.0,900.0,1800.0,3600.0,1,0,1,0,1,0.0,0.0,0,0,0.0,0,0,0.0,0,0,0,3,3,0.0,3.0,3.0,0,0,0.0,0.0,0.0,0.0,0.0,0.0,0.001506,0.179749,0.006639,0.281138,0.2137,0.003592,0.313676,0.0,0.0,0.323526,2.014474,2.014474,-2.014474,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,39.0,0,0.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,0
1,68,2009091000,58.0,893.0,1793.0,3593.0,1,0,1,0,1,1.0,0.0,10,5,5.0,0,0,1.0,0,0,0,3,3,0.0,3.0,3.0,0,0,0.0,0.0,0.0,0.0,0.0,0.0,0.000969,0.108505,0.001061,0.169117,0.2937,0.003638,0.423011,0.0,0.0,2.338,0.077907,2.092381,-2.092381,0.0,0.0,0.077907,-0.077907,-0.938735,1.016643,-0.938735,0.938735,1.016643,-1.016643,-0.938735,0.938735,1.016643,-1.016643,0.546433,0.453567,0.546433,0.453567,0.004655,0.551088,0.448912,0.0,0.0,0.004655,-0.004655,-0.028383,0.033038,-0.028383,0.028383,0.033038,-0.033038,-0.028383,0.028383,0.033038,-0.033038,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,1,0,0,0,0,1
2,92,2009091000,53.0,856.0,1756.0,3556.0,1,0,1,0,1,2.0,0.0,5,2,-3.0,0,0,0.0,0,0,0,3,3,0.0,3.0,3.0,0,0,0.0,0.0,0.0,0.0,0.0,0.0,0.001057,0.105106,0.000981,0.162747,0.304805,0.003826,0.421478,0.0,0.0,2.415907,-1.40276,0.689621,-0.689621,-1.40276,1.40276,0.077907,-0.077907,0.0,0.0,-0.938735,0.938735,1.016643,-1.016643,-0.938735,0.938735,1.016643,-1.016643,0.551088,0.448912,0.551088,0.448912,-0.040295,0.510793,0.489207,-0.040295,0.040295,0.004655,-0.004655,0.0,0.0,-0.028383,0.028383,0.033038,-0.033038,-0.028383,0.028383,0.033038,-0.033038,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,0,0,0,1,1
3,113,2009091000,56.0,815.0,1715.0,3515.0,1,0,1,0,1,3.0,0.0,8,2,0.0,1,0,1.0,0,0,0,3,3,0.0,3.0,3.0,0,0,0.0,0.0,0.0,0.0,0.0,0.0,0.001434,0.149088,0.001944,0.234801,0.289336,0.004776,0.318621,0.0,0.0,1.013147,-1.712583,-1.022962,1.022962,-1.40276,1.40276,-1.634676,1.634676,0.0,0.0,-0.938735,0.938735,1.016643,-1.016643,2.473837,-2.473837,-4.108513,4.108513,0.510793,0.489207,0.510793,0.489207,-0.049576,0.461217,0.538783,-0.040295,0.040295,-0.044921,0.044921,0.0,0.0,-0.028383,0.028383,0.033038,-0.033038,0.081542,-0.081542,-0.126463,0.126463,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,1,0,0,0,0,1
4,139,2009091000,56.0,807.0,1707.0,3507.0,1,0,1,0,1,4.0,0.0,8,2,0.0,0,0,0.0,0,0,0,3,3,0.0,3.0,3.0,0,0,0.0,0.0,0.0,0.0,0.0,0.0,0.001861,0.21348,0.003279,0.322262,0.244603,0.006404,0.208111,0.0,0.0,-0.699436,2.097796,1.074834,-1.074834,-1.40276,1.40276,-1.634676,1.634676,0.0,0.0,-0.938735,0.938735,1.016643,-1.016643,2.473837,-2.473837,-4.108513,4.108513,0.461217,0.538783,0.461217,0.538783,0.097712,0.558929,0.441071,-0.040295,0.040295,-0.044921,0.044921,0.0,0.0,-0.028383,0.028383,0.033038,-0.033038,0.081542,-0.081542,-0.126463,0.126463,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.0,0.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,1,0,0,0,1


In [6]:
# No explicit drop of 'posteam' is needed here: process_data already removes it.
train.shape

(316538, 244)

In [7]:
# Target (y): the one-hot indicator columns for the team in possession.
posteams = list(filter(lambda name: name.startswith('is_posteam_'), train.columns))
len(posteams)

32

In [8]:
# Features (X): every column that is not one of the target indicators,
# kept in the frame's original column order.
cols = list(train.columns[~train.columns.isin(posteams)])
len(cols)

212

In [9]:
# Split the encoded frame into the feature matrix and the multi-column target.
X, y = train[cols], train[posteams]

### Classification 

In [10]:
from sklearn.tree import DecisionTreeClassifier 
from sklearn import metrics

In [11]:
# 5-fold cross-validation of a decision tree on the possession-team target.
# random_state is pinned so the fitted trees, and therefore the reported
# scores, are reproducible across kernel restarts.
# Note: y has one indicator column per team, so 'accuracy' is computed over
# the whole row of indicators (all columns must match for a row to count).
cross_validate(DecisionTreeClassifier(random_state=0), X, y,
               cv = 5, scoring = 'accuracy')

{'fit_time': array([67.92008018, 75.37106299, 79.55239582, 68.36521602, 70.11745811]),
 'score_time': array([0.21009421, 0.23390079, 0.21365213, 0.21072268, 0.21160293]),
 'test_score': array([0.82063878, 0.78058065, 0.81424149, 0.78544237, 0.81668694])}

The decision tree classifier can predict the team in possession with an accuracy of ca. 80%.

#### Testing different classification models 

In [None]:
# Same 5-fold cross-validation with a 100-tree random forest, using all cores.
# random_state is pinned because the forest's bootstrap sampling and feature
# selection are random; without a seed the scores change on every run.
cross_validate(RandomForestClassifier(n_estimators=100, n_jobs=-1, random_state=0), X, y,
               cv = 5, scoring = 'accuracy')

The Random Forest Classifier is less accurate, at ca. 70%.