In [6]:
import pandas as pd
import numpy as np

from tensorflow import keras
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Embedding, Flatten

In [2]:
# Read the season and tourney game tables from the inputs folder.
season_csv, tourney_csv = 'inputs/games_season.csv', 'inputs/games_tourney.csv'
games_season = pd.read_csv(season_csv)
games_tourney = pd.read_csv(tourney_csv)

# Preview the first five rows of the season table.
games_season.head(5)

Unnamed: 0,season,team_1,team_2,home,score_diff,score_1,score_2,won
0,1985,3745,6664,0,17,81,64,1
1,1985,126,7493,1,7,77,70,1
2,1985,288,3593,1,7,63,56,1
3,1985,1846,9881,1,16,70,54,1
4,1985,2675,10298,1,12,86,74,1


In [3]:
# Count the distinct team IDs that appear in either team column.
n_teams = np.unique(
    games_season[['team_1', 'team_2']].to_numpy()
).size
n_teams

10888

In [4]:
# Create an embedding layer.
# The embedding layer maps each team ID to a single number representing the team's strength.
# NOTE(review): Embedding indices must lie in [0, input_dim); this presumes team IDs
# are integers in 0..n_teams-1 -- confirm against the data.
# `input_length` is intentionally omitted: it is deprecated in Keras 3, and the
# sequence length is already fixed by the Input(shape=(1,)) this layer is applied to.

team_lookup = Embedding(input_dim=n_teams, output_dim=1, name='Team-Strength')


### Define the team model.

In [10]:
# Input layer: one team ID per example.
teamid_in = Input(shape=(1,))

# Pass the team ID through the shared team-strength embedding.
strength_lookup = team_lookup(teamid_in)

# Flatten the embedding output.
strength_lookup_flat = Flatten()(strength_lookup)

# Wrap input -> embedding -> flatten into a single, re-usable model.
team_strength_model = Model(
    inputs=teamid_in,
    outputs=strength_lookup_flat,
    name='Team-Strength-Model',
)

In [12]:
# Compile with the Adam optimizer and mean-absolute-error loss, then print the layer summary.
team_strength_model.compile(optimizer='adam', loss='mae')
team_strength_model.summary()

Model: "Team-Strength-Model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 1)]               0         
                                                                 
 Team-Strength (Embedding)   (None, 1, 1)              10888     
                                                                 
 flatten (Flatten)           (None, 1)                 0         
                                                                 
Total params: 10,888
Trainable params: 10,888
Non-trainable params: 0
_________________________________________________________________
