# Temporal Fusion Transformer

This notebook outlines how we apply the Temporal Fusion Transformer (TFT) model to predict MLB attendance.

## Data Preparation

The first step is to do some light data cleaning and then arrange the data in the `TimeSeriesDataSet` format.

In [2]:
!pip install pytorch_forecasting pytorch_lightning -q


In [19]:
import pandas as pd
import numpy as np
import pytorch_lightning as pl
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data.encoders import GroupNormalizer
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
import re
from pytorch_forecasting.metrics import QuantileLoss
import torch

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
data = pd.read_csv('/content/drive/MyDrive/MinneMUDAC DS Challenge/data/merged_features_gsr_v0.1.csv', parse_dates=[1], index_col=0, low_memory=False)
data.head()

Unnamed: 0,Date,NumberofGames,DayofWeek,VisitingTeam,VisitingTeamLeague,VisitingTeamGameNumber,HomeTeam,HomeTeamLeague,HomeTeamGameNumber,VistingTeamScore,...,visiting_div,league_match_flag,Team_home,total_wins_home,win_rate_home,day_league_rank_home,Team_visiting,total_wins_visiting,win_rate_visiting,day_league_rank_visiting
0,2000-04-03,0,Mon,NYA,AL,1,ANA,AL,1,3,...,AL_EAST,0,ANA,0,0.0,2.5,NYA,1,0.0,2.0
1,2000-04-04,0,Tue,NYA,AL,2,ANA,AL,2,5,...,AL_EAST,0,ANA,0,0.0,4.5,NYA,2,2.0,1.5
2,2000-04-05,0,Wed,NYA,AL,3,ANA,AL,3,6,...,AL_EAST,0,ANA,1,0.5,4.5,NYA,2,1.0,1.5
3,2000-04-07,0,Fri,BOS,AL,4,ANA,AL,4,3,...,AL_EAST,0,ANA,2,0.5,4.0,BOS,1,0.333333,5.0
4,2000-04-08,0,Sat,BOS,AL,5,ANA,AL,5,5,...,AL_EAST,0,ANA,3,0.6,2.5,BOS,1,0.25,5.0


In [6]:
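# One attendance value is recorded as -1: treat it as missing and fill it
# by interpolating between that home team's neighbouring games (by date),
# so that other teams' attendance does not leak into the filled value.
fix_idx = data[data['Attendance'] < 0].index
data.loc[fix_idx, 'Attendance'] = np.nan
data['Attendance'] = (data.sort_values('Date')
                          .groupby('HomeTeam')['Attendance']
                          .transform(lambda s: s.interpolate()))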
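
As a small illustration on made-up data, the sketch below marks a negative attendance value as missing and fills it by interpolating within a single team's games. The team codes `AAA` and `BBB` and all figures are hypothetical and are not taken from our dataset.

```python
import numpy as np
import pandas as pd

# Made-up attendance for two hypothetical teams; -1 marks a bad record.
toy = pd.DataFrame({
    'Date': pd.to_datetime(['2000-04-01', '2000-04-02', '2000-04-03',
                            '2000-04-01', '2000-04-02', '2000-04-03']),
    'HomeTeam': ['AAA', 'AAA', 'AAA', 'BBB', 'BBB', 'BBB'],
    'Attendance': [10000.0, -1.0, 14000.0, 40000.0, 42000.0, 44000.0],
})

# Flag the invalid value as missing.
toy.loc[toy['Attendance'] < 0, 'Attendance'] = np.nan

# Interpolate within each team, in date order. The result is aligned back
# to the original rows through the index.
toy['Attendance'] = (
    toy.sort_values('Date')
       .groupby('HomeTeam')['Attendance']
       .transform(lambda s: s.interpolate())
)
print(toy)
```

Grouping by team before interpolating keeps the filled value between that team's own neighbouring games rather than between whichever games happen to be adjacent by date.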


The TFT model needs one column to be an increasing timestep for each time series. For this application each `HomeTeam` is its own time series, so we need to construct a timestep that increases with each home game that is played.

In [7]:
# add increasing time_idx
data['time_idx'] = 1
data['time_idx'] = data.sort_values('Date').groupby('HomeTeam')['time_idx'].cumsum()
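
To make the per-team index concrete, here is a minimal sketch on made-up rows. It uses `groupby(...).cumcount()`, which is an alternative way to get the same kind of counter as the cumulative sum of ones; the team codes are hypothetical.

```python
import pandas as pd

# Four made-up games for two hypothetical teams, deliberately out of date order.
toy = pd.DataFrame({
    'Date': pd.to_datetime(['2000-04-03', '2000-04-01',
                            '2000-04-02', '2000-04-01']),
    'HomeTeam': ['AAA', 'AAA', 'BBB', 'BBB'],
})

# Number each team's home games 1, 2, 3, ... in date order.
# cumcount() starts at 0, so add 1 to match a 1-based index.
toy['time_idx'] = toy.sort_values('Date').groupby('HomeTeam').cumcount() + 1
print(toy.sort_values(['HomeTeam', 'Date']))
```

Each team gets its own counter, so two teams playing on the same date can share a `time_idx` value without conflict.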

TFT can process string values directly so they do not need to be one-hot encoded. However, all string variables need to be converted into categorical types. 

In [8]:
# create and convert columns to categoricals
data['month'] = data['Date'].dt.month_name().astype('category')
categories = [
    'DayofWeek',
    'VisitingTeam',
    'VisitingTeamLeague',
    'HomeTeam',
    'HomeTeamLeague',
    'DayNight',
    'VisitingTeam_StartingPitcher_ID',
    'HomeTeam_StartingPitcher_ID',
    'home_div',
    'visiting_div']
data[categories] = data[categories].astype('category')
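
A quick way to confirm the conversion worked is to inspect the dtypes and category levels. The sketch below does this on a made-up frame; the values are illustrative only.

```python
import pandas as pd

# Made-up rows using two of the column names from our dataset.
toy = pd.DataFrame({
    'DayofWeek': ['Mon', 'Tue', 'Mon'],
    'DayNight': ['N', 'D', 'N'],
    'Attendance': [10000, 12000, 11000],
})

cat_cols = ['DayofWeek', 'DayNight']
toy[cat_cols] = toy[cat_cols].astype('category')

# String columns are now categorical; the numeric column is untouched.
print(toy.dtypes)
print(toy['DayofWeek'].cat.categories.tolist())
```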


Each feature needs to be one of the following types:
- `static_categoricals`: categorical features that are known and do not change over time within a time series. They can differ between time series. For example, `HomeTeam` is static for every time series because that is how we are grouping them.

- `static_reals`: continuous features that do not change over time. If no team ever changed stadiums, a feature like stadium capacity could fit this category. However, some teams have built new stadiums, so we do not have any static real features.

- `time_varying_known_categoricals`: categorical features that are known into the future but do change with time. The most common example is the day of the week. We always know the day of the week, but the value changes over time.

- `time_varying_known_reals`: continuous features that are known into the future but do change with time. The year and the age of a stadium are examples. Stadium age could in theory be unknown far into the future, say 100 years out, because a new stadium would likely be built at some unknown point. For the timespan over which this model could reasonably be expected to be used, however, these values are known.

- `time_varying_unknown_categoricals`: categorical features that we measure but whose values we do not know at future timesteps. This could be something like the starting lineup for a game. While we may know it for a single game, we would not know the values for an entire season ahead of time. For this application none of our features fall within this class.

- `time_varying_unknown_reals`: continuous features that we measure but that are not known into the future. This category covers most of our variables. All the stats from a game fall within this category.
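
Because every column should play exactly one of these roles, a quick consistency check can catch a column that was listed twice by accident. The sketch below uses the column assignments from the dataset definition later in this notebook; the helper `find_overlaps` is a hypothetical name introduced only for this illustration.

```python
from itertools import combinations

# Column assignments as used in the TimeSeriesDataSet definition in this notebook.
feature_roles = {
    'static_categoricals': ['HomeTeam', 'HomeTeamLeague'],
    'time_varying_known_categoricals': [
        'DayofWeek', 'month', 'VisitingTeam', 'DayNight',
        'home_div', 'visiting_div'],
    'time_varying_known_reals': [
        'HomeTeamGameNumber', 'VisitingTeamGameNumber', 'year',
        'park_age', 'league_match_flag'],
    'time_varying_unknown_reals': [
        'VistingTeamScore', 'HomeTeamScore', 'NumberofOuts', 'Attendance',
        'LengthofGame', 'total_wins_home', 'win_rate_home',
        'day_league_rank_home', 'total_wins_visiting', 'win_rate_visiting',
        'day_league_rank_visiting'],
}


def find_overlaps(roles):
    """Return columns that appear under more than one role (hypothetical helper)."""
    overlaps = {}
    for (name_a, cols_a), (name_b, cols_b) in combinations(roles.items(), 2):
        shared = set(cols_a) & set(cols_b)
        if shared:
            overlaps[(name_a, name_b)] = sorted(shared)
    return overlaps


overlaps = find_overlaps(feature_roles)
print(overlaps if overlaps else 'No column is assigned to more than one role.')
```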

In [9]:
# get names of all the statistic columns
pattern = '(V.*|H.*)Team(Offense|Defense|Pitchers)_'
stat_mask = [bool(re.match(pattern, col)) for col in data.columns]
stat_cols = data.columns[stat_mask].to_list()
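
To see what the pattern selects, the sketch below applies it to a few sample names. `VisitingTeamOffense_Hits` and `HomeTeamPitchers_Used` are made-up names for illustration and may not match the real statistic columns; the other names appear in our dataset.

```python
import re

pattern = '(V.*|H.*)Team(Offense|Defense|Pitchers)_'

sample_cols = [
    'VisitingTeamOffense_Hits',     # hypothetical statistic column
    'HomeTeamPitchers_Used',        # hypothetical statistic column
    'HomeTeamScore',                # no Offense/Defense/Pitchers after "Team"
    'HomeTeam_StartingPitcher_ID',  # "Team" is followed by "_", not a group name
    'Attendance',                   # does not start with V or H
]

# re.match anchors at the start of the string, so the name must begin
# with V... or H... and contain "Team" followed by one of the three groups.
matched = [col for col in sample_cols if re.match(pattern, col)]
print(matched)
```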

As part of the `TimeSeriesDataSet` we have to define other important features:  
- `max_prediction_length`: how far into the future do we want to predict? We want to predict an entire season. While every team plays the same number of games, they do not play the same number of home games. The max number of home games in our dataset is 84 so that is our max prediction length.

- `max_encoder_length`: how much history the model considers when making a prediction. This is a hyper-parameter to tune. A natural starting point is to predict the next season from the current season's information, which corresponds to 84. The cell below sets it to 1540, roughly 18 seasons of home games.

- `training_cutoff`: We want a validation set, so we treat games through the 2019 season as training and use the 2021 and 2022 seasons as validation.
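
The sketch below shows, on made-up rows, one way to derive the prediction length from the data and to build a year-based training mask. The team codes and years are hypothetical, and the cutoff year is chosen only to fit the toy data.

```python
import pandas as pd

# Made-up schedule: one row per home game.
toy = pd.DataFrame({
    'HomeTeam': ['AAA'] * 3 + ['BBB'] * 2 + ['AAA'] * 2,
    'year': [2018] * 5 + [2019] * 2,
})

# Count home games per team and season, then take the largest count.
home_games = toy.groupby(['HomeTeam', 'year']).size()
max_prediction_length = int(home_games.max())

# Boolean mask selecting the training rows (seasons up to the cutoff year).
training_cutoff = toy['year'] <= 2018

print(home_games)
print(max_prediction_length, int(training_cutoff.sum()))
```

Applying the same `groupby` to the full dataset is how a figure like the 84 home games quoted above can be checked.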

In [10]:
# limit to MN Twins initially while building
data = data[data['HomeTeam'] == 'MIN']

In [22]:
# original
# build TFT
max_prediction_length = 84  # predict 1 season
max_encoder_length = 1540  # 18 seasons
training_cutoff = data['year'] <= 2019

training = TimeSeriesDataSet(
    data=data[training_cutoff],
    time_idx='time_idx',
    target='Attendance',
    group_ids=['HomeTeam'],
    max_encoder_length=max_encoder_length,
    min_encoder_length=78,
    max_prediction_length=max_prediction_length,
    min_prediction_length=78,
    min_prediction_idx=max_prediction_length,
    static_categoricals=['HomeTeam', 'HomeTeamLeague'],
    time_varying_known_categoricals=['DayofWeek', 'month', 'VisitingTeam', 'DayNight',
                                     'home_div', 'visiting_div'],
    time_varying_known_reals=['HomeTeamGameNumber', 'VisitingTeamGameNumber', 'year',
                              'park_age', 'league_match_flag'],
    # time_varying_unknown_categoricals=['pitchers'],
    # variable_groups={'pitchers':['VisitingTeam_StartingPitcher_ID',
    #                              'HomeTeam_StartingPitcher_ID']},
    time_varying_unknown_reals=['VistingTeamScore', 'HomeTeamScore', 'NumberofOuts',
                                'Attendance', 'LengthofGame', 'total_wins_home',
                                'win_rate_home', 'day_league_rank_home', 
                                'total_wins_visiting', 'win_rate_visiting', 'day_league_rank_visiting'],
    lags={'Attendance': [78]},
    target_normalizer=GroupNormalizer(
        groups=['HomeTeam'], transformation='softplus'),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True
)

In [23]:
validation = TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True)
batch_size = 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size*10, num_workers=0)

## Model Training

With our datasets and dataloaders defined, the next step is to train the model.

In [24]:
torch.cuda.device_count()

1

In [27]:
early_stop_callback = EarlyStopping(monitor='val_loss', min_delta=1e-4,
                                    patience=10, verbose=False, mode='min')
lr_logger = LearningRateMonitor()  # log the learning rate
# logger = TensorBoardLogger('lightning_logs')  # logging results to a tensorboard

trainer = pl.Trainer(
    max_epochs=10,
    accelerator='gpu',
    devices=1,  # replaces the deprecated `gpus=1` argument
    enable_model_summary=True,
    gradient_clip_val=0.1,
    limit_train_batches=30,
    # fast_dev_run=True,
    callbacks=[early_stop_callback],
    # logger=logger,
)

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=8,
    output_size=7,
    loss=QuantileLoss(),
    log_interval=10,
    reduce_on_plateau_patience=4,
)
print(f'Number of parameters in network: {tft.size()/1e3:.1f}k')

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Number of parameters in network: 35.6k
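
The setting `output_size=7` matches the seven quantiles that `QuantileLoss()` uses by default (0.02, 0.1, 0.25, 0.5, 0.75, 0.9 and 0.98), so the network emits one value per quantile for each predicted game. As a rough illustration of what this loss measures, here is a minimal NumPy sketch of the quantile (pinball) loss. It is not the library's implementation, which may differ in scaling and reduction, and the attendance figures are made up.

```python
import numpy as np

quantiles = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]


def pinball_loss(y_true, y_pred, quantiles):
    """Per-sample, per-quantile pinball loss (illustrative sketch).

    y_true: shape (n,)   observed values
    y_pred: shape (n, q) one prediction per quantile
    """
    y_true = np.asarray(y_true, dtype=float)[:, None]
    y_pred = np.asarray(y_pred, dtype=float)
    q = np.asarray(quantiles, dtype=float)[None, :]
    errors = y_true - y_pred
    # Under-prediction is weighted by q, over-prediction by (1 - q).
    return np.maximum((q - 1) * errors, q * errors)


# One made-up game with a true attendance of 20,000.
y_true = [20000.0]
y_pred = [[12000.0, 15000.0, 18000.0, 20000.0, 22000.0, 25000.0, 28000.0]]

loss = pinball_loss(y_true, y_pred, quantiles)
print(loss.round(1))
print('mean loss:', loss.mean().round(1))
```

Low quantiles are penalised lightly for predicting too low and heavily for predicting too high, and the reverse holds for high quantiles. This asymmetry is what pushes the outputs apart into a prediction interval.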


In [None]:
# fit network (`trainer.fit` returns None, so the result is not assigned)
trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
   | Name                               | Type                            | Params
----------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0     
1  | logging_metrics                    | ModuleList                      | 0     
2  | input_embeddings                   | MultiEmbedding                  | 423   
3  | prescalers                         | ModuleDict                      | 336   
4  | static_variable_selection          | VariableSelectionNetwork        | 1.9 K 
5  | encoder_variable_selection         | VariableSelectionNetwork        | 14.0 K
6  | decoder_variable_selection         | VariableSelectionNetwork        | 5.1 K 
7  | static_context_variable_selection  | GatedResidualNetwork            | 1.1 K 
8  | static_context_initial_hidde


INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.


AttributeError: ignored

In [None]:
# read columns to keep
keep = []
with open('column_names.txt', 'r') as f:
    lines = f.readlines()
    for line in lines:
        keep.append(line.split(',')[0])

# write columns to file
with open('column_names.txt', 'w') as f:
    for k, col in zip(keep, data.columns):
        f.writelines(k + ',' + col + '\n')

1107.0

In [None]:
data['HomeTeam'].unique().shape

(30,)