In [1]:
%load_ext autoreload
%autoreload 2

import os

import numpy as np
import pandas as pd

from strangefish.models.model_training_utils import GameHistorySequence
from strangefish.models.uncertainty_transformer import uncertainty_transformer_1



In [2]:
# Build the untrained model and fix the name used in every artifact path below.
model_name = 'uncertainty_transformer_1'
model = uncertainty_transformer_1()

# Later cells write checkpoints, weights and training history under
# uncertainty_model/; nothing else creates that directory. exist_ok=True keeps
# this cell safe to re-run.
os.makedirs(f'uncertainty_model/{model_name}', exist_ok=True)

model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input (InputLayer)          [(None, None, 8, 8, 37)]  0         
                                                                 
 masking (Masking)           (None, None, 8, 8, 37)    0         
                                                                 
 reshape_1 (Reshape)         (None, None, 2368)        0         
                                                                 
 transformer_block (Transfor  (None, None, 2368)       45190976  
 merBlock)                                                       
                                                                 
 global_average_pooling1d (G  (None, 2368)             0         
 lobalAveragePooling1D)                                          
                                                                 
 dropout_2 (Dropout)         (None, 2368)              0     

In [10]:
from sklearn.model_selection import train_test_split

data_path = 'game_logs/historical_games_extended'

# os.listdir returns entries in arbitrary order; sort first so the seeded
# split below is reproducible across machines and filesystems.
files = sorted(os.listdir(data_path))

# 80/20 split at the file level, seeded for reproducibility.
train_data, test_data = train_test_split(files, test_size=0.2, random_state=42)

# DataFrame.to_csv does not create parent directories; make sure the target
# exists so this cell works on a fresh checkout (idempotent on re-run).
os.makedirs('uncertainty_model', exist_ok=True)
pd.DataFrame({'files': train_data}).to_csv('uncertainty_model/train_data.csv')
pd.DataFrame({'files': test_data}).to_csv('uncertainty_model/test_data.csv')

In [11]:
# Third positional argument of GameHistorySequence; presumably the batch size.
# TODO confirm against GameHistorySequence's signature.
BATCH_SIZE = 16

training_sequence = GameHistorySequence(train_data, data_path, BATCH_SIZE)
# The held-out sequence is built with shuffle=False.
test_sequence = GameHistorySequence(test_data, data_path, BATCH_SIZE, shuffle=False)

In [12]:
# Public tf.keras API; tensorflow.python.keras is a private internal module.
from tensorflow.keras.callbacks import ModelCheckpoint

# Keep only the epoch with the lowest validation loss:
#  - mode='min' because a smaller loss is better (mode='max' kept the worst);
#  - save_best_only=True, otherwise monitor/mode are ignored and every epoch
#    simply overwrites the previous checkpoint;
#  - a separate '_best' path so saving the final model under
#    uncertainty_model/{model_name} does not overwrite the best checkpoint.
checkpoint = ModelCheckpoint(
    f'uncertainty_model/{model_name}_best',
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    verbose=1,
)

In [None]:
# Train for 5 epochs, evaluating on the held-out file split after each epoch.
# `callbacks` is documented as a list of Callback instances, so wrap the
# single checkpoint callback in a list.
hist = model.fit(
    training_sequence,
    validation_data=test_sequence,
    epochs=5,
    callbacks=[checkpoint],
)

Epoch 1/5

In [None]:
# Save the whole model as it stands after the training cell.
final_model_dir = f'uncertainty_model/{model_name}'
model.save(final_model_dir)

In [None]:
# Additionally store the weights alone, under the model's artifact directory.
weights_prefix = f'uncertainty_model/{model_name}/weights'
model.save_weights(weights_prefix)

In [None]:
import pickle

# Persist the per-epoch metrics recorded by model.fit (hist.history) so the
# training curves can be inspected later without retraining.
# Only ever unpickle this file from a trusted source.
history_path = f'uncertainty_model/{model_name}/train_hist'
with open(history_path, 'wb') as history_file:
    pickle.dump(hist.history, history_file)

In [16]:
# NOTE(review): the saved output of this cell lists conv_1 / fused_batch_norm
# layers, which do not match the transformer summary printed by the model-setup
# cell. It looks like stale output from a different model left in the kernel;
# re-run (or drop this duplicate cell) after Restart & Run All.
model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input (InputLayer)           [(None, None, 8, 8, 37)]  0         
_________________________________________________________________
masking (Masking)            (None, None, 8, 8, 37)    0         
_________________________________________________________________
conv_1 (Conv2D)              (None, None, 6, 6, 128)   42752     
_________________________________________________________________
tf.cast (TFOpLambda)         (None, None, 6, 6, 128)   0         
_________________________________________________________________
tf.compat.v1.nn.fused_batch_ ((None, None, 6, 6, 128), 0         
_________________________________________________________________
activation_1 (Activation)    (None, None, 6, 6, 128)   0         
_________________________________________________________________
dropout_1 (Dropout)          (None, None, 6, 6, 128)   0     

In [41]:
# NOTE(review): `game_map` is not defined in any cell shown in this notebook, so
# this cell depends on hidden kernel state and fails on Restart & Run All.
# [0:1] keeps only the first entry along game_map[0]'s leading axis, and
# np.array([...]) adds a leading batch axis of size 1. Presumably game_map[0]
# is one game's sequence of encoded boards; TODO confirm and define it here.
model.predict(np.array([game_map[0][0:1]]))

array([[[0.09206896]]], dtype=float32)

In [56]:
# NOTE(review): `test_sample` is not defined in any cell shown in this notebook,
# so this cell depends on hidden kernel state. It is presumably a batch drawn
# from test_sequence; TODO build it explicitly in a cell above.
model.predict(test_sample)

array([[[9.65302661e-02],
        [1.84438210e-02],
        [1.08479137e-04],
        ...,
        [0.00000000e+00],
        [0.00000000e+00],
        [0.00000000e+00]],

       [[1.46613687e-01],
        [1.66606791e-02],
        [3.14770907e-04],
        ...,
        [0.00000000e+00],
        [0.00000000e+00],
        [0.00000000e+00]],

       [[9.01013464e-02],
        [1.72635075e-02],
        [8.06811586e-05],
        ...,
        [0.00000000e+00],
        [0.00000000e+00],
        [0.00000000e+00]],

       ...,

       [[1.16691105e-01],
        [1.57225449e-02],
        [8.28637785e-05],
        ...,
        [0.00000000e+00],
        [0.00000000e+00],
        [0.00000000e+00]],

       [[8.83309990e-02],
        [1.63214374e-02],
        [6.22796797e-05],
        ...,
        [0.00000000e+00],
        [0.00000000e+00],
        [0.00000000e+00]],

       [[1.09717838e-01],
        [1.31376991e-02],
        [2.62407644e-04],
        ...,
        [0.00000000e+00],
        [0.0000