We are going to take all of the games in a season - about 1200. We need to collect the shot data of each team... this will come from two places

1. play by play -- the key information here is the shot_type. eventmsgactiontype gives a shot type in the form of a number. 
there are several types of shots these can be categorized further but this may be future work 

2. shot chart detail -- this lets us gather details about the shots for a specific team for a specific season

In order to successfully merge these two repositories, we need to gather key information first.

1. We want to gather all of the team ids. 30 teams.
   • this gives us a way to categorize the rest of the information.
   • we can run into issues requesting so much data...

2. After gathering a list of teams, we will get the shot chart data for every game in a season for each team. We can gather the unique game_ids and game events to merge with our other shot data.

    • What is the necessary data for training?

       1. shot distance

       2. shot location? 

3. Using the unique games_ids, we want to get the play by play data. We can pin point the data for a specific team using the game event id.

    • What is the data we need from here? Primarily interested in shot type.

       There are over 30 shot types possibly 40. These can likely be split further.

# to put our minds at ease about the moving parts, lets import our important libraries

In [1]:
from nba_api.stats.endpoints import leaguegamefinder, playbyplayv2, ShotChartDetail
from nba_api.stats.static import teams
import pandas as pd
import time
import numpy as np


1. Gather team ids

In [5]:
nba_teams = teams.get_teams()
print(len(nba_teams), nba_teams[0])


30 {'id': 1610612737, 'full_name': 'Atlanta Hawks', 'abbreviation': 'ATL', 'nickname': 'Hawks', 'city': 'Atlanta', 'state': 'Georgia', 'year_founded': 1949}


In [7]:
tIDs={}
for tm in nba_teams:
    tIDs[tm['id']]=tm['full_name']
print(tIDs)

{1610612737: 'Atlanta Hawks', 1610612738: 'Boston Celtics', 1610612739: 'Cleveland Cavaliers', 1610612740: 'New Orleans Pelicans', 1610612741: 'Chicago Bulls', 1610612742: 'Dallas Mavericks', 1610612743: 'Denver Nuggets', 1610612744: 'Golden State Warriors', 1610612745: 'Houston Rockets', 1610612746: 'Los Angeles Clippers', 1610612747: 'Los Angeles Lakers', 1610612748: 'Miami Heat', 1610612749: 'Milwaukee Bucks', 1610612750: 'Minnesota Timberwolves', 1610612751: 'Brooklyn Nets', 1610612752: 'New York Knicks', 1610612753: 'Orlando Magic', 1610612754: 'Indiana Pacers', 1610612755: 'Philadelphia 76ers', 1610612756: 'Phoenix Suns', 1610612757: 'Portland Trail Blazers', 1610612758: 'Sacramento Kings', 1610612759: 'San Antonio Spurs', 1610612760: 'Oklahoma City Thunder', 1610612761: 'Toronto Raptors', 1610612762: 'Utah Jazz', 1610612763: 'Memphis Grizzlies', 1610612764: 'Washington Wizards', 1610612765: 'Detroit Pistons', 1610612766: 'Charlotte Hornets'}


In [None]:
2. Now we will get the shot data for each team in the 23-24 season

In [9]:
def get_shot_chart_data(player_id, team_id, season_type):
    try:
        # Fetch data using nba_api's ShotChartDetail endpoint
        shotchart = ShotChartDetail(
            player_id=player_id,
            team_id=team_id,
            context_measure_simple='FGA',
            season_type_all_star=season_type,  # Adjust season type (e.g., 'Regular Season', 'Playoffs')
            season_nullable='2023-24'
        )
        
        # Get the shot chart data from the API response
        shot_data = shotchart.get_data_frames()[0]  # Data is returned as a list of dataframes
        return shot_data
    except Exception as e:
        print(f"An error occurred: {e}")
        return pd.DataFrame()  # Return an empty DataFrame on error

In [14]:
player_id = 0  # Replace with the actual Player ID
season_type = 'Regular Season'
season_data={}
for tid in tIDs:
    
# Fetch and display shot chart data
    season_data[tid]=get_shot_chart_data(player_id, tid, season_type)

# Display the first few rows of the shot data
print(season_data[i])
print(season_data.keys())

              GRID_TYPE     GAME_ID  GAME_EVENT_ID  PLAYER_ID  \
0     Shot Chart Detail  0022300009             11    1630163   
1     Shot Chart Detail  0022300009             15    1629023   
2     Shot Chart Detail  0022300009             19    1630163   
3     Shot Chart Detail  0022300009             24    1641706   
4     Shot Chart Detail  0022300009             28    1631109   
...                 ...         ...            ...        ...   
7128  Shot Chart Detail  0022301216            587     202330   
7129  Shot Chart Detail  0022301216            596    1628970   
7130  Shot Chart Detail  0022301216            598    1641706   
7131  Shot Chart Detail  0022301216            617    1641706   
7132  Shot Chart Detail  0022301216            621    1626179   

          PLAYER_NAME     TEAM_ID          TEAM_NAME  PERIOD  \
0         LaMelo Ball  1610612766  Charlotte Hornets       1   
1     P.J. Washington  1610612766  Charlotte Hornets       1   
2         LaMelo Ball  1610

In [8]:
for i in tIDs:
    print(i)

1610612737
1610612738
1610612739
1610612740
1610612741
1610612742
1610612743
1610612744
1610612745
1610612746
1610612747
1610612748
1610612749
1610612750
1610612751
1610612752
1610612753
1610612754
1610612755
1610612756
1610612757
1610612758
1610612759
1610612760
1610612761
1610612762
1610612763
1610612764
1610612765
1610612766


In [33]:
def get_season_games(season_year='2023-24'):
    gamefinder = leaguegamefinder.LeagueGameFinder(season_nullable=season_year)
    games_df = gamefinder.get_data_frames()[0]
    return games_df['GAME_ID'].unique()
game_ids = get_season_games('2023-24')
game_ids

array(['0042300405', '0042300404', '0042300403', ..., '0012300001',
       '2012300002', '2012300001'], dtype=object)

In [34]:
len(game_ids)

2191

In [35]:
check=season_data[i]['GAME_ID'].unique()
len(check)

check

array(['0022300009', '0022300017', '0022300027', '0022300057',
       '0022300063', '0022300077', '0022300101', '0022300122',
       '0022300133', '0022300143', '0022300157', '0022300177',
       '0022300202', '0022300217', '0022300225', '0022300248',
       '0022300266', '0022300281', '0022300292', '0022300312',
       '0022300326', '0022300335', '0022300354', '0022300365',
       '0022300390', '0022300414', '0022300428', '0022300436',
       '0022300455', '0022300464', '0022300484', '0022300506',
       '0022300518', '0022300541', '0022300551', '0022300575',
       '0022300584', '0022300593', '0022300610', '0022300619',
       '0022300635', '0022300647', '0022300657', '0022300674',
       '0022300695', '0022300707', '0022300713', '0022300726',
       '0022300745', '0022300751', '0022300761', '0022300777',
       '0022300801', '0022300811', '0022300826', '0022300840',
       '0022300849', '0022300858', '0022300875', '0022300885',
       '0022300910', '0022300918', '0022300933', '00223

In [22]:
len(check)

82

In [28]:
temp=[]
gameSet=set()
for tm in season_data:
    check=season_data[tm]['GAME_ID'].unique()
    temp.append(check)
    gameSet.update(check)
print(len(temp))
print(len(gameSet))

30
1230


In [31]:
def get_play_by_play_data(gameSet):
    all_play_by_play_data = []
    dloadProg=0
    for game_id in gameSet:
        dloadProg+=1
        if dloadProg%20 == 0:
            print(str(dloadProg/1230*100)+'%')
        try:
            pbp = playbyplayv2.PlayByPlayV2(game_id=game_id)
            pbp_df = pbp.get_data_frames()[0]
            pbp_df['GAME_ID'] = game_id
            all_play_by_play_data.append(pbp_df)
            time.sleep(1)  # To avoid hitting rate limits
        except Exception as e:
            print(f"Error fetching data for game {game_id}: {e}")
    
    combined_df = pd.concat(all_play_by_play_data, ignore_index=True)
    return combined_df

In [32]:
pbp_all = get_play_by_play_data(gameSet)

1.6260162601626018%
3.2520325203252036%
4.878048780487805%
6.504065040650407%
8.130081300813007%
9.75609756097561%
11.38211382113821%
13.008130081300814%
14.634146341463413%
16.260162601626014%
17.88617886178862%
19.51219512195122%
21.138211382113823%
22.76422764227642%
24.390243902439025%
26.01626016260163%
27.64227642276423%
29.268292682926827%
30.89430894308943%
32.52032520325203%
34.146341463414636%
35.77235772357724%
37.39837398373984%
39.02439024390244%
40.65040650406504%
42.27642276422765%
43.90243902439025%
45.52845528455284%
47.15447154471545%
48.78048780487805%
50.40650406504065%
52.03252032520326%
53.65853658536586%
55.28455284552846%
56.91056910569105%
58.536585365853654%
60.16260162601627%
61.78861788617886%
63.41463414634146%
65.04065040650406%
66.66666666666666%
68.29268292682927%
69.91869918699187%
Error fetching data for game 0022301219: Expecting value: line 1 column 1 (char 0)
71.54471544715447%
73.17073170731707%
74.79674796747967%
76.42276422764228%
78.048780487804

Now that we have all of our data. We need to merge what we want and omit
what we dont need. we have a dictionary of season data for each team. the contents
of each key is a df. we can merge the individual dfs i think. lets try

In [36]:
len(pbp_all)

567185

In [37]:
len(pbp_all['GAME_ID'].unique())

1229

In [40]:
#pick a random game to test
pbp_all.iloc[0]

GAME_ID                                              0022300583
EVENTNUM                                                      2
EVENTMSGTYPE                                                 12
EVENTMSGACTIONTYPE                                            0
PERIOD                                                        1
WCTIMESTRING                                           10:11 PM
PCTIMESTRING                                              12:00
HOMEDESCRIPTION                                            None
NEUTRALDESCRIPTION           Start of 1st Period (10:11 PM EST)
VISITORDESCRIPTION                                         None
SCORE                                                      None
SCOREMARGIN                                                None
PERSON1TYPE                                                   0
PLAYER1_ID                                                    0
PLAYER1_NAME                                               None
PLAYER1_TEAM_ID                         

In [44]:
pbp_shots = pbp_all[pbp_all['EVENTMSGTYPE'].isin([1, 2])]

In [45]:
len(pbp_shots)

218510

In [43]:
type(pbp_all)

pandas.core.frame.DataFrame

In [48]:
pbp_shots.iloc[0]

GAME_ID                                                             0022300583
EVENTNUM                                                                     7
EVENTMSGTYPE                                                                 1
EVENTMSGACTIONTYPE                                                          98
PERIOD                                                                       1
WCTIMESTRING                                                          10:12 PM
PCTIMESTRING                                                             11:44
HOMEDESCRIPTION                                                           None
NEUTRALDESCRIPTION                                                        None
VISITORDESCRIPTION           Hield 5' Cutting Layup Shot (2 PTS) (Turner 1 ...
SCORE                                                                    2 - 0
SCOREMARGIN                                                                 -2
PERSON1TYPE                                         

In [46]:
#i am going to assign game events to each unique id i think

testDF1=season_data[1610612740]
testDF2=pbp_shots

In [47]:
print(testDF1['GAME_ID'][:20])
print(testDF2['GAME_ID'][:20])

0     0022300011
1     0022300011
2     0022300011
3     0022300011
4     0022300011
5     0022300011
6     0022300011
7     0022300011
8     0022300011
9     0022300011
10    0022300011
11    0022300011
12    0022300011
13    0022300011
14    0022300011
15    0022300011
16    0022300011
17    0022300011
18    0022300011
19    0022300011
Name: GAME_ID, dtype: object
2     0022300583
3     0022300583
4     0022300583
6     0022300583
7     0022300583
8     0022300583
9     0022300583
10    0022300583
12    0022300583
14    0022300583
16    0022300583
17    0022300583
19    0022300583
20    0022300583
22    0022300583
24    0022300583
29    0022300583
31    0022300583
34    0022300583
35    0022300583
Name: GAME_ID, dtype: object


In [54]:
df1Filter=testDF1[['GAME_ID','GAME_EVENT_ID','SHOT_DISTANCE','SHOT_ATTEMPTED_FLAG','SHOT_MADE_FLAG']]
df2Filter=testDF2[['GAME_ID','EVENTNUM','EVENTMSGACTIONTYPE','PERIOD']]
print(df1Filter.iloc[:10])
# print(df2Filter.iloc[0])

      GAME_ID  GAME_EVENT_ID  SHOT_DISTANCE  SHOT_ATTEMPTED_FLAG  \
0  0022300011             18              1                    1   
1  0022300011             22             24                    1   
2  0022300011             40              9                    1   
3  0022300011             46              1                    1   
4  0022300011             53              7                    1   
5  0022300011             61              1                    1   
6  0022300011             66             11                    1   
7  0022300011             69             26                    1   
8  0022300011             73             11                    1   
9  0022300011             75              1                    1   

   SHOT_MADE_FLAG  
0               1  
1               1  
2               1  
3               1  
4               1  
5               1  
6               1  
7               0  
8               0  
9               1  


In [55]:
merged_df = pd.merge(df1Filter, df2Filter, 
                     left_on=['GAME_ID', 'GAME_EVENT_ID'],  # Columns from df1
                     right_on=['GAME_ID', 'EVENTNUM'])  # Columns from df2

print(merged_df.iloc[:10])

      GAME_ID  GAME_EVENT_ID  SHOT_DISTANCE  SHOT_ATTEMPTED_FLAG  \
0  0022300011             18              1                    1   
1  0022300011             22             24                    1   
2  0022300011             40              9                    1   
3  0022300011             46              1                    1   
4  0022300011             53              7                    1   
5  0022300011             61              1                    1   
6  0022300011             66             11                    1   
7  0022300011             69             26                    1   
8  0022300011             73             11                    1   
9  0022300011             75              1                    1   

   SHOT_MADE_FLAG  EVENTNUM  EVENTMSGACTIONTYPE  PERIOD  
0               1        18                 108       1  
1               1        22                   1       1  
2               1        40                  86       1  
3               1  

In [56]:
season_data[1610612740]

Unnamed: 0,GRID_TYPE,GAME_ID,GAME_EVENT_ID,PLAYER_ID,PLAYER_NAME,TEAM_ID,TEAM_NAME,PERIOD,MINUTES_REMAINING,SECONDS_REMAINING,...,SHOT_ZONE_AREA,SHOT_ZONE_RANGE,SHOT_DISTANCE,LOC_X,LOC_Y,SHOT_ATTEMPTED_FLAG,SHOT_MADE_FLAG,GAME_DATE,HTM,VTM
0,Shot Chart Detail,0022300011,18,202685,Jonas Valanciunas,1610612740,New Orleans Pelicans,1,11,13,...,Center(C),Less Than 8 ft.,1,-13,7,1,1,20231110,HOU,NOP
1,Shot Chart Detail,0022300011,22,202685,Jonas Valanciunas,1610612740,New Orleans Pelicans,1,10,43,...,Center(C),24+ ft.,24,76,235,1,1,20231110,HOU,NOP
2,Shot Chart Detail,0022300011,40,1627742,Brandon Ingram,1610612740,New Orleans Pelicans,1,9,37,...,Right Side(R),8-16 ft.,9,87,29,1,1,20231110,HOU,NOP
3,Shot Chart Detail,0022300011,46,1629627,Zion Williamson,1610612740,New Orleans Pelicans,1,9,2,...,Center(C),Less Than 8 ft.,1,12,2,1,1,20231110,HOU,NOP
4,Shot Chart Detail,0022300011,53,1629627,Zion Williamson,1610612740,New Orleans Pelicans,1,8,21,...,Center(C),Less Than 8 ft.,7,-11,78,1,1,20231110,HOU,NOP
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7160,Shot Chart Detail,0022301230,636,1641722,Jordan Hawkins,1610612740,New Orleans Pelicans,4,2,35,...,Right Side(R),16-24 ft.,16,152,64,1,1,20231207,LAL,NOP
7161,Shot Chart Detail,0022301230,642,1630526,Jeremiah Robinson-Earl,1610612740,New Orleans Pelicans,4,1,58,...,Left Side Center(LC),16-24 ft.,18,-98,151,1,0,20231207,LAL,NOP
7162,Shot Chart Detail,0022301230,646,1641722,Jordan Hawkins,1610612740,New Orleans Pelicans,4,1,31,...,Left Side(L),16-24 ft.,21,-206,71,1,0,20231207,LAL,NOP
7163,Shot Chart Detail,0022301230,649,1641722,Jordan Hawkins,1610612740,New Orleans Pelicans,4,0,59,...,Center(C),16-24 ft.,18,-32,179,1,0,20231207,LAL,NOP


In [58]:
print(merged_df.iloc[-10:])

         GAME_ID  GAME_EVENT_ID  SHOT_DISTANCE  SHOT_ATTEMPTED_FLAG  \
7155  0022301230            603             26                    1   
7156  0022301230            607              2                    1   
7157  0022301230            613             26                    1   
7158  0022301230            627              2                    1   
7159  0022301230            631              9                    1   
7160  0022301230            636             16                    1   
7161  0022301230            642             18                    1   
7162  0022301230            646             21                    1   
7163  0022301230            649             18                    1   
7164  0022301230            652             28                    1   

      SHOT_MADE_FLAG  EVENTNUM  EVENTMSGACTIONTYPE  PERIOD  
7155               0       603                 103       4  
7156               1       607                   6       4  
7157               0       613     

ok we can now do this for every team

In [62]:
pbpShotFILT=pbp_shots[['GAME_ID','EVENTNUM','EVENTMSGACTIONTYPE','PERIOD']]
seasonFILT={}

In [74]:
for tm in season_data:
    tmFILT=season_data[tm][['GAME_ID','GAME_EVENT_ID','SHOT_DISTANCE','SHOT_MADE_FLAG']]
    mf=pd.merge(tmFILT, pbpShotFILT, 
                     left_on=['GAME_ID', 'GAME_EVENT_ID'],  # Columns from df1: pbpShotFILT
                     right_on=['GAME_ID', 'EVENTNUM'])
    seasonFILT[tm]=mf.drop(columns='GAME_EVENT_ID')

In [75]:
print(seasonFILT[list(seasonFILT.keys())[0]])

         GAME_ID  SHOT_DISTANCE  SHOT_MADE_FLAG  EVENTNUM  EVENTMSGACTIONTYPE  \
0     0022300018              6               1        12                  78   
1     0022300018             26               1        16                   1   
2     0022300018             15               1        25                  79   
3     0022300018              3               1        28                   5   
4     0022300018              2               1        32                   5   
...          ...            ...             ...       ...                 ...   
7579  0022301218             27               0       662                   1   
7580  0022301218              9               0       665                  97   
7581  0022301218              4               1       667                   1   
7582  0022301218             29               1       676                  79   
7583  0022301218             28               0       683                  79   

      PERIOD  
0          1

In [91]:
type(seasonFILT[list(seasonFILT.keys())[0]]['PERIOD'][0])

numpy.int64

In [68]:
# I could add a win/loss ratio to this to help rank instead of just categorize. 
    #there is also more positional data i could incorporate...

seasonFILT now has every shot from the season separated by each team. At this point we may be able to train our model

At first I was considering giving a stat line or some ratio/expected value of shot selections. Instead I will elect to classify which team has the shot profile that most accurately aligns with the input/tested/selected shot/parameters (in a specific matchup?).

In [97]:
dfs = []

# Iterate through the dictionary and add a new column with the dictionary keys
for key, df in seasonFILT.items():
    df['TEAM_ID'] = key  # Add the key as a new column
    dfs.append(df)      # Add the DataFrame to the list

# Concatenate all DataFrames in the list into a single DataFrame
sznFILTdf = pd.concat(dfs, ignore_index=True)

# Display the resulting DataFrame
print(sznFILTdf)

           GAME_ID  SHOT_DISTANCE  SHOT_MADE_FLAG  EVENTNUM  \
0       0022300018              6               1        12   
1       0022300018             26               1        16   
2       0022300018             15               1        25   
3       0022300018              3               1        28   
4       0022300018              2               1        32   
...            ...            ...             ...       ...   
218503  0022301216              2               1       587   
218504  0022301216             23               1       596   
218505  0022301216             26               1       598   
218506  0022301216              7               0       617   
218507  0022301216             22               1       621   

        EVENTMSGACTIONTYPE  PERIOD     TEAM_ID  FITS_PROFILE  
0                       78       1  1610612737             0  
1                        1       1  1610612737             0  
2                       79       1  1610612737        

In [98]:

# Example DataFrame setup
# Replace this with your actual `sznFILTdf` DataFrame
# sznFILTdf = pd.read_csv('path_to_your_file.csv')

# Step 1: Create a team shot profile based on shot distance and shot type
team_profiles = {}

# Group data by team_id
for team_id, df in sznFILTdf.groupby('TEAM_ID'):
    # Get shot distance and shot type stats per team
    shot_stats = df.groupby(['SHOT_DISTANCE', 'EVENTMSGACTIONTYPE']).agg({
        'SHOT_MADE_FLAG': ['mean', 'count']  # Mean shot success rate and count of attempts
    }).reset_index()
    
    # Rename columns for clarity
    shot_stats.columns = ['SHOT_DISTANCE', 'EVENTMSGACTIONTYPE', 'MEAN_SHOT_SUCCESS_RATE', 'COUNT']
    
    # Store the profile in the dictionary
    team_profiles[team_id] = shot_stats

# Example to show the output
for team_id, profile in team_profiles.items():
    print(f"Team: {team_id}")
    print(profile.head())  # Displaying only the first few rows for brevity


Team: 1610612737
   SHOT_DISTANCE  EVENTMSGACTIONTYPE  MEAN_SHOT_SUCCESS_RATE  COUNT
0              0                   3                0.500000      2
1              0                   5                0.800000     20
2              0                   6                0.603774     53
3              0                   7                0.923077     26
4              0                   9                0.950000     20
Team: 1610612738
   SHOT_DISTANCE  EVENTMSGACTIONTYPE  MEAN_SHOT_SUCCESS_RATE  COUNT
0              0                   3                1.000000      1
1              0                   5                0.909091     11
2              0                   6                0.769231     39
3              0                   7                1.000000      5
4              0                   9                0.945946     37
Team: 1610612739
   SHOT_DISTANCE  EVENTMSGACTIONTYPE  MEAN_SHOT_SUCCESS_RATE  COUNT
0              0                   3                0.750000     

In [99]:
# Step 2: Define a function to compare a shot to a team's profile (including shot type)
def does_shot_fit_profile(shot, team_id, team_profiles):
    shot_distance = shot['SHOT_DISTANCE']
    shot_type = shot['EVENTMSGACTIONTYPE']
    
    # Get the team's profile for shot distance and type
    team_profile = team_profiles.get(team_id)
    
    if team_profile is not None:
        # Match both shot distance and shot type
        closest_match = team_profile[
            (team_profile['SHOT_DISTANCE'] == shot_distance) & 
            (team_profile['EVENTMSGACTIONTYPE'] == shot_type)
        ]
        
        if not closest_match.empty:
            # Define the rule for fitting the profile (e.g., success rate > 50%)
            if closest_match['MEAN_SHOT_SUCCESS_RATE'].values[0] > 0.5:  # Example rule
                return 1  # Shot fits the profile
            else:
                return 0  # Shot does not fit
    return 0  # Default: Does not fit

# Step 3: Apply the function to generate the target column 'FITS_PROFILE'
def generate_profile_labels(df, team_profiles):
    df['FITS_PROFILE'] = df.apply(lambda row: does_shot_fit_profile(row, row['TEAM_ID'], team_profiles), axis=1)
    return df

# Apply to the whole dataset (sznFILTdf)
# Group data by team_id
team_profiles = {}

for team_id, df in sznFILTdf.groupby('TEAM_ID'):
    # Get shot distance and shot type stats per team
    shot_stats = df.groupby(['SHOT_DISTANCE', 'EVENTMSGACTIONTYPE']).agg({
        'SHOT_MADE_FLAG': ['mean', 'count']  # Mean shot success rate and count of attempts
    }).reset_index()
    
    # Rename columns for clarity
    shot_stats.columns = ['SHOT_DISTANCE', 'EVENTMSGACTIONTYPE', 'MEAN_SHOT_SUCCESS_RATE', 'COUNT']
    
    # Store the profile in the dictionary
    team_profiles[team_id] = shot_stats

# Apply to the whole dataset (sznFILTdf)
sznFILTdf = generate_profile_labels(sznFILTdf, team_profiles)

# Check the results
print(sznFILTdf.head())


      GAME_ID  SHOT_DISTANCE  SHOT_MADE_FLAG  EVENTNUM  EVENTMSGACTIONTYPE  \
0  0022300018              6               1        12                  78   
1  0022300018             26               1        16                   1   
2  0022300018             15               1        25                  79   
3  0022300018              3               1        28                   5   
4  0022300018              2               1        32                   5   

   PERIOD     TEAM_ID  FITS_PROFILE  
0       1  1610612737             0  
1       1  1610612737             0  
2       1  1610612737             0  
3       1  1610612737             0  
4       1  1610612737             0  


In [100]:
# Check the results
print(sznFILTdf)

           GAME_ID  SHOT_DISTANCE  SHOT_MADE_FLAG  EVENTNUM  \
0       0022300018              6               1        12   
1       0022300018             26               1        16   
2       0022300018             15               1        25   
3       0022300018              3               1        28   
4       0022300018              2               1        32   
...            ...            ...             ...       ...   
218503  0022301216              2               1       587   
218504  0022301216             23               1       596   
218505  0022301216             26               1       598   
218506  0022301216              7               0       617   
218507  0022301216             22               1       621   

        EVENTMSGACTIONTYPE  PERIOD     TEAM_ID  FITS_PROFILE  
0                       78       1  1610612737             0  
1                        1       1  1610612737             0  
2                       79       1  1610612737        

In [101]:
# Check the number of instances where shots fit the profile
fit_counts = sznFILTdf['FITS_PROFILE'].value_counts()

# Print the counts
print(fit_counts)


FITS_PROFILE
0    147702
1     70806
Name: count, dtype: int64


In [103]:
cont_names = ['SHOT_DISTANCE']  # Continuous variables
cat_names = ['EVENTMSGACTIONTYPE', 'PERIOD', 'TEAM_ID']  # Categorical variables including team_id
y_names = 'FITS_PROFILE'  # Target variable

# Create DataLoaders
dataloaders = TabularDataLoaders.from_df(train_df, 
                                         valid_df=valid_df, 
                                         path='.', 
                                         procs=[Categorify, FillMissing, Normalize], 
                                         cat_names=cat_names, 
                                         cont_names=cont_names, 
                                         y_names=y_names, 
                                         bs=64)  # Adjust batch size as needed



0.3240430556318304

In [124]:
cont_names = ['SHOT_DISTANCE']  # Continuous variables
cat_names = ['EVENTMSGACTIONTYPE', 'PERIOD', 'TEAM_ID']  # Categorical variables including team_id
y_names = 'FITS_PROFILE'  # Target variable


# Create DataLoaders
dataloaders = TabularDataLoaders.from_df(sznFILTdf, 
                                         path='.', 
                                         procs=[Categorify,FillMissing], 
                                         cat_names=cat_names, 
                                         cont_names=cont_names, 
                                         y_names=y_names,
                                         valid_pct=0.2,
                                         bs=128)  # Adjust batch size as needed


In [125]:
# Create a TabularLearner
learn = tabular_learner(dataloaders, 
                        metrics=accuracy)  # You can use different metrics as needed

# Train the model
learn.fit_one_cycle(5)  # Number of epochs


epoch,train_loss,valid_loss,accuracy,time
0,0.092699,0.089324,0.675774,00:13
1,0.0866,0.085361,0.675774,00:13
2,0.079786,0.079189,0.675774,00:13
3,0.077913,0.07961,0.675774,00:13
4,0.076814,0.080899,0.675774,00:13


In [126]:
# Check accuracy
learn.show_results()

# Plot confusion matrix
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()


Unnamed: 0,EVENTMSGACTIONTYPE,PERIOD,TEAM_ID,SHOT_DISTANCE,FITS_PROFILE,FITS_PROFILE_pred
0,11.0,1.0,29.0,6.0,0.0,0.252126
1,1.0,1.0,25.0,25.0,0.0,0.001477
2,10.0,1.0,29.0,2.0,1.0,0.901876
3,27.0,2.0,10.0,15.0,0.0,0.027947
4,36.0,4.0,20.0,2.0,1.0,0.934489
5,4.0,3.0,29.0,4.0,0.0,0.165376
6,28.0,1.0,10.0,25.0,0.0,0.00684
7,5.0,2.0,23.0,1.0,1.0,1.008828
8,1.0,1.0,18.0,27.0,0.0,0.019377


AttributeError: vocab

In [132]:
# Example of making predictions
new_data = pd.DataFrame({
    'SHOT_DISTANCE': [4],
    'EVENTMSGACTIONTYPE': [7],
    'PERIOD': [2],
    'TEAM_ID':[5.0]
})

# Convert to DataLoader format
new_dl = learn.dls.test_dl(new_data)

# Make predictions
preds, _ = learn.get_preds(dl=new_dl)
print(preds)


tensor([[0.5698]])
