In [1]:
# Enable IPython's autoreload extension. Mode 2 reloads imported modules before
# each cell runs, so edits to local packages are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import typing

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from collections import defaultdict

Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at
the same time. Both libraries are known to be incompatible and this
can cause random crashes or deadlocks on Linux when loaded in the
same Python program.
Using threadpoolctl may cause crashes or deadlocks. For more
information and possible workarounds, please see
    https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md



In [3]:
# Exploded training sequences, read directly from the GCS bucket.
# NOTE(review): the loaded frame carries an "Unnamed: 0" column (visible in the
# head() preview), which suggests the CSV was written with its index -- confirm
# whether index_col=0 is wanted before relying on positional columns.
TRAIN_CSV_URI = "gs://pitch-sequencing/sequence_data/full_sequence_data/v2/kitchensink/exploded/large_cur_train.csv"
exploded_train_df = pd.read_csv(TRAIN_CSV_URI)

In [34]:
# Preview the first 10 rows of the exploded training frame to sanity-check its columns.
exploded_train_df.head(10)

Unnamed: 0,pitch_sequence,count_sequence,zone_sequence,p_throws,stand,pitcher_id,batter_id,at_bat_number,game_date,at_bat_pitch_number,...,setup_count,input_pitch_sequence,events,zone,outs_when_up,type,bb_type,on_3b,on_2b,on_1b
0,"FF,CB","0-0,0-1",214,R,L,572955,656976,54,2021-04-27,2,...,0-1,FF,,14.0,2.0,B,,,,
1,"FF,CB,FF","0-0,0-1,1-1",21413,R,L,572955,656976,54,2021-04-27,3,...,1-1,"FF,CB",,13.0,2.0,S,,,,
2,"FF,CB,FF,FF","0-0,0-1,1-1,1-2",2141312,R,L,572955,656976,54,2021-04-27,4,...,1-2,"FF,CB,FF",,12.0,2.0,B,,,,
3,"FF,CB,FF,FF,CB","0-0,0-1,1-1,1-2,2-2",214131214,R,L,572955,656976,54,2021-04-27,5,...,2-2,"FF,CB,FF,FF",,14.0,2.0,B,,,,
4,"FF,CB,FF,FF,CB,CB","0-0,0-1,1-1,1-2,2-2,3-2",2141312148,R,L,572955,656976,54,2021-04-27,6,...,3-2,"FF,CB,FF,FF,CB",strikeout,8.0,2.0,S,,,,
5,"CB,FF","0-0,1-0",142,R,L,592836,677595,42,2023-09-22,2,...,1-0,CB,,2.0,2.0,S,,,,
6,"CB,FF,FC","0-0,1-0,1-1",14211,R,L,592836,677595,42,2023-09-22,3,...,1-1,"CB,FF",,11.0,2.0,B,,,,
7,"CB,FF,FC,FS","0-0,1-0,1-1,2-1",1421113,R,L,592836,677595,42,2023-09-22,4,...,2-1,"CB,FF,FC",field_out,13.0,2.0,X,fly_ball,,,
8,"FF,FF","0-0,0-1",1211,L,R,663531,593643,10,2021-04-05,2,...,0-1,FF,field_out,11.0,0.0,X,fly_ball,,,
9,"CB,FF","0-0,0-1",45,R,L,677651,645277,15,2021-11-02,2,...,0-1,CB,,5.0,0.0,S,,,,


In [60]:
import pitch_sequencing.ml.tokenizers.vocab as vocab
import pitch_sequencing.ml.data.generators as gen

from pitch_sequencing.ml.data.sequences import PitchSequenceDataset, CSVSequenceDataGenPlan, collate_interleaved_and_target
from pitch_sequencing.ml.tokenizers.pitch_arsenal import PitchArsenalLookupTable
from pitch_sequencing.ml.tokenizers.pitch_sequence import PitchSequenceTokenizer, SequenceInfo, SequenceID
from pitch_sequencing.ml.models.last_pitch import LastPitchTransformerModel
from pitch_sequencing.io.join import join_paths
from pitch_sequencing.io.gcs import save_model_to_gcs

# Pitch arsenal table, read from GCS and wrapped in the project lookup helper.
# The lookup table is used further down to size and generate the arsenal segment.
ARSENAL_CSV_URI = "gs://pitch-sequencing/arsenal_data/pitch_arsenal_data.csv"
arsenal_df = pd.read_csv(ARSENAL_CSV_URI)
arsenal_lookup_table = PitchArsenalLookupTable(arsenal_df)

# Context segments placed ahead of the pitch history: arsenal, handedness, on-base.
# The second SequenceInfo argument appears to be the segment length in tokens
# (the decoded sample shows 2 handedness tokens and 3 on-base tokens) -- TODO confirm.
arsenal_info = SequenceInfo(
    SequenceID.ARSENAL,
    arsenal_lookup_table.max_arsenal_size,
    vocab_ids=[vocab.VocabID.PITCHES],
)
handedness_info = SequenceInfo(
    SequenceID.HANDEDNESS,
    2,
    vocab_ids=[vocab.VocabID.HANDEDNESS],
)
on_base_info = SequenceInfo(
    SequenceID.ON_BASE,
    3,
    vocab_ids=[vocab.VocabID.BOOLEAN],
)
sequential_sequence_infos = [arsenal_info, handedness_info, on_base_info]

# One generation plan per segment above, listed in the same order as the infos.
arsenal_plan = CSVSequenceDataGenPlan(
    SequenceID.ARSENAL,
    gen.ArsenalCSVGenerator(arsenal_lookup_table),
)
handedness_plan = CSVSequenceDataGenPlan(
    SequenceID.HANDEDNESS,
    gen.HandednessCSVGenerator(),
)
on_base_plan = CSVSequenceDataGenPlan(
    SequenceID.ON_BASE,
    gen.OnBaseCSVGenerator(),
)
sequential_sequence_gen_plans = [arsenal_plan, handedness_plan, on_base_plan]

# Size of the interleaved count/pitch segment; still a hard-coded placeholder.
MAX_INTERLEAVED_SEQUENCE_LENGTH = 63

# NOTE(review): despite the plural name, this is a single SequenceInfo, not a list.
interleaved_sequence_infos = SequenceInfo(
    SequenceID.INTERLEAVED,
    MAX_INTERLEAVED_SEQUENCE_LENGTH,
    vocab_ids=[vocab.VocabID.PITCHES, vocab.VocabID.COUNTS],
)

# Counts come from 'count_sequence', pitches from 'input_pitch_sequence'.
# NOTE(review): the count column is registered under SequenceID.INTERLEAVED rather
# than a counts-specific id -- confirm this is what the tokenizer expects.
count_lookup_plan = CSVSequenceDataGenPlan(
    SequenceID.INTERLEAVED,
    gen.DirectCSVLookupGenerator('count_sequence'),
)
pitch_lookup_plan = CSVSequenceDataGenPlan(
    SequenceID.PITCHES,
    gen.DirectCSVLookupGenerator('input_pitch_sequence'),
)
interleaved_sequence_gen_plans = [count_lookup_plan, pitch_lookup_plan]

# Vocabularies passed to the tokenizer, in the same order as before.
tokenizer_vocabs = [
    vocab.PITCH_VOCAB,
    vocab.HANDEDNESS_VOCAB,
    vocab.BOOLEAN_VOCAB,
    vocab.COUNT_VOCAB,
]
tokenizer = PitchSequenceTokenizer(
    sequential_sequence_infos,
    interleaved_sequence_infos,
    tokenizer_vocabs,
)

# Training dataset built over the exploded frame.
# target_df_key presumably names the label column in exploded_train_df -- TODO confirm
# (the column is hidden by the truncated head() preview).
train_dataset = PitchSequenceDataset(
    exploded_train_df,
    tokenizer,
    sequential_sequence_gen_plans,
    interleaved_sequence_gen_plans,
    target_df_key='target_pitch',
)

In [64]:
# Fetch the first training example and split it into model input and target.
# NOTE(review): `input` shadows the Python builtin for the rest of the session;
# the name is kept because other cells may reference it.
first_example = train_dataset[0]
input, target = first_example

# Decode the source token ids back to their string tokens for a visual check.
[tokenizer.get_token_for_id(token_id) for token_id in input.src]

572955


['<start>',
 '<arsenal_start>',
 'FF',
 'CB',
 'FC',
 'SI',
 'SL',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<handedness_start>',
 'R',
 'L',
 '<on_base_start>',
 'F',
 'F',
 'F',
 '<interleaved_start>',
 '0-0',
 'FF',
 '0-1',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>']

In [65]:
# Transformer sized to the tokenizer's vocabulary.
model = LastPitchTransformerModel(
    tokenizer.vocab_size(),
    d_model=64,
    nhead=4,
    num_layers=2,
)
collate_fn = collate_interleaved_and_target
loss = nn.CrossEntropyLoss()

# Shuffled mini-batches of 4 examples, assembled by the project collate function.
# NOTE(review): no random seed is set in the visible cells, so batch order and
# weight initialisation will differ between runs.
train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_fn,
)