In [93]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from nba_api.stats.static import players
from nba_api.stats.endpoints import playergamelog
from nba_api.stats.library.parameters import SeasonAll

In [297]:
# Look up Jimmy Butler in nba_api's static list of active players.
player_dict = players.get_active_players()
butler = next(
    (player for player in player_dict if player['full_name'] == 'Jimmy Butler'),
    None,
)
if butler is None:
    # Fail with a readable message instead of an IndexError from indexing an empty list.
    raise LookupError("No active player with full_name 'Jimmy Butler' was found")

In [298]:
# Fetch the player's game log for all seasons and add PTS_L5: for each row, the
# mean of PTS over the five rows that follow it. In the output below the log is
# newest-first, so those are the five preceding games.
butler_log = playergamelog.PlayerGameLog(player_id=butler['id'], season=SeasonAll.all)
butler_games = (
    butler_log.get_data_frames()[0]
    .assign(PTS_L5=lambda games: games['PTS'].rolling(window=5).mean().shift(-5))
    # The last five rows have no complete look-back window, so drop them.
    .dropna(subset=['PTS_L5'])
)
butler_games.head(10)

Unnamed: 0,SEASON_ID,Player_ID,Game_ID,GAME_DATE,MATCHUP,WL,MIN,FGM,FGA,FG_PCT,...,REB,AST,STL,BLK,TOV,PF,PTS,PLUS_MINUS,VIDEO_AVAILABLE,PTS_L5
0,22023,202710,22301189,"APR 14, 2024",MIA vs. TOR,W,24,6,9,0.667,...,5,4,0,0,1,1,15,6,1,20.0
1,22023,202710,22301176,"APR 12, 2024",MIA vs. TOR,W,29,5,7,0.714,...,3,7,0,0,0,0,14,7,1,21.2
2,22023,202710,22301161,"APR 10, 2024",MIA vs. DAL,L,36,5,8,0.625,...,4,3,2,0,5,0,12,-9,1,22.2
3,22023,202710,22301147,"APR 09, 2024",MIA @ ATL,W,44,7,14,0.5,...,8,9,3,0,3,1,25,2,1,20.6
4,22023,202710,22301133,"APR 07, 2024",MIA @ IND,L,39,7,16,0.438,...,7,8,1,0,0,1,27,-6,1,16.8
5,22023,202710,22301120,"APR 05, 2024",MIA @ HOU,W,29,6,14,0.429,...,3,3,1,0,2,1,22,1,1,15.4
6,22023,202710,22301111,"APR 04, 2024",MIA vs. PHI,L,40,7,17,0.412,...,4,5,0,0,1,0,20,10,1,14.8
7,22023,202710,22301096,"APR 02, 2024",MIA vs. NYK,W,39,5,12,0.417,...,5,6,0,0,4,3,17,10,1,17.4
8,22023,202710,22301081,"MAR 31, 2024",MIA @ WAS,W,35,5,11,0.455,...,7,4,0,0,0,1,17,-2,1,16.8
9,22023,202710,22301068,"MAR 29, 2024",MIA vs. POR,W,25,2,4,0.5,...,4,8,1,0,0,1,8,42,1,18.2


In [299]:
# Hold out one season for validation; every other season is used for training.
VAL_SEASON_ID = '22023'
is_val_season = butler_games['SEASON_ID'] == VAL_SEASON_ID
butler_games_tr = butler_games[~is_val_season]
butler_games_val = butler_games[is_val_season]

In [300]:
# Training features: a synthetic integer points line in [10, 30] and the
# 5-game scoring average (PTS_L5).
# A dedicated, seeded generator keeps the random lines (and therefore the
# labels) identical on every run of this cell.
lines_rng_tr = torch.Generator().manual_seed(0)
pts_l5_tr = torch.tensor(butler_games_tr['PTS_L5'].values)
lines_tr = torch.randint(low=10, high=31, size=(pts_l5_tr.shape[0],), generator=lines_rng_tr)
Xtr = torch.stack((lines_tr, pts_l5_tr), dim=1).to(torch.float32)
pts_tr = torch.tensor(butler_games_tr['PTS'].values)
# Label is 1.0 when the actual points are strictly above the line, else 0.0
# (a tie with the line counts as 0.0).
Ytr = (pts_tr > lines_tr).to(torch.float32)
Xtr.shape, Ytr.shape

(torch.Size([749, 2]), torch.Size([749]))

In [301]:
# Validation features: a synthetic integer points line in [10, 30] and the
# 5-game scoring average (PTS_L5).
# A dedicated, seeded generator keeps the random lines (and therefore the
# labels) identical on every run of this cell.
lines_rng_val = torch.Generator().manual_seed(1)
pts_l5_val = torch.tensor(butler_games_val['PTS_L5'].values)
lines_val = torch.randint(low=10, high=31, size=(pts_l5_val.shape[0],), generator=lines_rng_val)
Xval = torch.stack((lines_val, pts_l5_val), dim=1).to(torch.float32)
pts_val = torch.tensor(butler_games_val['PTS'].values)
# Label is 1.0 when the actual points are strictly above the line, else 0.0
# (a tie with the line counts as 0.0).
Yval = (pts_val > lines_val).to(torch.float32)
Xval.shape, Yval.shape

(torch.Size([60, 2]), torch.Size([60]))

In [178]:
class BinaryClassificationModel(nn.Module):
    """Logistic-regression style classifier: one linear unit followed by a sigmoid.

    Takes a batch of two-feature rows and produces one value in (0, 1) per row.
    """

    def __init__(self):
        super().__init__()
        # Two input features -> a single output unit.
        self.l1 = nn.Linear(2, 1)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        """Return a 1-D tensor of sigmoid outputs, one per row of a (N, 2) input."""
        # Drop the trailing size-1 feature dimension so the output is (N,).
        logits = self.l1(x).squeeze(dim=1)
        return self.sig(logits)

In [374]:
# Training setup: learning rate, model, loss and optimizer.
lr = 0.01
model = BinaryClassificationModel()
# BCELoss takes probabilities; the model's forward already applies a sigmoid.
criterion = nn.BCELoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=lr)

In [376]:
num_epochs = 10000
# Full-batch training: each epoch is one forward/backward pass over all of Xtr
# followed by a single optimizer step.
for epoch in range(num_epochs):
    # Clear gradients left over from the previous step before computing new ones.
    optimizer.zero_grad()
    pred = model(Xtr)
    loss = criterion(pred, Ytr)
    loss.backward()
    optimizer.step()

# This is the loss from the last forward pass, i.e. before the final update.
print(loss)

tensor(0.4447, grad_fn=<BinaryCrossEntropyBackward0>)


In [438]:
# Load the multi-player game log from a CSV in the working directory.
ALL_GAMES_CSV = 'all_games.csv'
all_games = pd.read_csv(ALL_GAMES_CSV)

In [457]:
def l5_points_player(df):
    """Return a copy of ``df`` with a ``PTS_L5`` column added.

    ``PTS_L5`` at a given row is the mean of the (non-missing) ``PTS`` values in
    the five rows that follow it; the last five rows get NaN because there is no
    window to shift in.

    The input frame is left untouched. pandas documents mutating the object
    handed to a ``groupby(...).apply`` function as unsupported, so the column is
    added with ``assign`` on a new frame instead of in place.

    NOTE(review): "following rows" equal "previous games" only if each player's
    rows are sorted newest-first -- confirm against how the CSV was built.

    Parameters
    ----------
    df : pandas.DataFrame
        Game log for a single player; must contain a ``PTS`` column.

    Returns
    -------
    pandas.DataFrame
        New frame with the same rows and index plus ``PTS_L5`` as the last column.
    """
    pts_l5 = df['PTS'].rolling(window=5, min_periods=1).mean().shift(-5)
    return df.assign(PTS_L5=pts_l5)

# Add PTS_L5 per player, then keep only rows where it could be computed.
# Selecting every column explicitly (including the grouping key) keeps Player_ID
# available to the helper without relying on the deprecated implicit behaviour
# of apply operating on the grouping columns.
all_games_plus = (
    all_games
    .groupby('Player_ID')[list(all_games.columns)]
    .apply(l5_points_player)
    .dropna(subset=['PTS_L5'])
)
all_games_plus.info()

<class 'pandas.core.frame.DataFrame'>
MultiIndex: 159731 entries, (2544, 0) to (1641931, 17)
Data columns (total 28 columns):
 #   Column           Non-Null Count   Dtype  
---  ------           --------------   -----  
 0   SEASON_ID        159731 non-null  object 
 1   Player_ID        159731 non-null  object 
 2   Game_ID          159731 non-null  object 
 3   GAME_DATE        159731 non-null  object 
 4   MATCHUP          159731 non-null  object 
 5   WL               159731 non-null  object 
 6   MIN              159731 non-null  object 
 7   FGM              159731 non-null  object 
 8   FGA              159731 non-null  object 
 9   FG_PCT           159731 non-null  float64
 10  FG3M             159731 non-null  object 
 11  FG3A             159731 non-null  object 
 12  FG3_PCT          159731 non-null  float64
 13  FTM              159731 non-null  object 
 14  FTA              159731 non-null  object 
 15  FT_PCT           159731 non-null  float64
 16  OREB             159731

  all_games_plus = all_games.groupby('Player_ID').apply(l5_points_player)
