In [None]:
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
import csv
import glob
import random
import cv2
from PIL import Image

## Initialize models
### Define constants
Set the width and height of images that will be input to the models. Define the timesteps for the recurrent neural network. `TIMESTEPS` is equal to the number of images in the sequence of frames that you want the AI to analyze at a time.

In [None]:
# Size (in pixels) that every frame is resized to before entering the CNN.
WIDTH, HEIGHT = 96, 96

# Number of consecutive frames the recurrent network looks at per prediction.
TIMESTEPS = 5

### Convolutional neural network

This project uses a simple architecture for the convolutional neural network with 3 convolutions in order to optimize the model for speed. [This tutorial](https://developers.google.com/machine-learning/practica/image-classification) is a good hands-on example of how convolutional neural networks work.

In [None]:
cnn_model = keras.models.Sequential()

# First convolution.
# Keras Conv2D (channels_last) expects inputs shaped (rows, cols, channels), i.e.
# (HEIGHT, WIDTH, 1) for a single-channel grayscale frame. (WIDTH, HEIGHT, 1) only
# happens to work while WIDTH == HEIGHT.
cnn_model.add(keras.layers.Conv2D(16, 3, activation='relu', input_shape=(HEIGHT, WIDTH, 1)))
cnn_model.add(keras.layers.MaxPooling2D(2))

# Second convolution
cnn_model.add(keras.layers.Conv2D(32, 3, activation='relu'))
cnn_model.add(keras.layers.MaxPooling2D(2))

# Third convolution
cnn_model.add(keras.layers.Conv2D(64, 3, activation='relu'))
cnn_model.add(keras.layers.MaxPooling2D(2))

# Flatten so that the per-frame feature maps can be passed to the recurrent neural network
cnn_model.add(keras.layers.Flatten())

cnn_model.summary()

### Recurrent neural network
The recurrent neural network used in this project is a simple [long short-term memory (LSTM)](https://www.tensorflow.org/api_docs/python/tf/keras/layers/LSTM) network. The [TimeDistributed](https://www.tensorflow.org/api_docs/python/tf/keras/layers/TimeDistributed) layer allows us to connect the convolutional neural network to the recurrent neural network. The `TIMESTEPS` variable defines how many images the recurrent neural network will analyze at a time.

In [None]:
model = keras.models.Sequential()
# Apply the same CNN to each of the TIMESTEPS frames. Each frame is laid out as
# (HEIGHT, WIDTH, 1), Keras' channels_last (rows, cols, channels) order.
model.add(keras.layers.TimeDistributed(cnn_model, input_shape=(TIMESTEPS, HEIGHT, WIDTH, 1)))
model.add(keras.layers.LSTM(256, return_sequences=False))
model.add(keras.layers.Dense(64, activation='relu'))
model.add(keras.layers.Dropout(0.2))
# Single sigmoid unit: predicted probability that the window is labeled 1 (swing initiated).
model.add(keras.layers.Dense(1, activation='sigmoid'))

# Name the metrics explicitly. Unnamed metric objects can receive auto-uniquified names
# (e.g. 'recall_1') when this cell is re-run in the same kernel, which would break the
# history['recall'] / history['val_recall'] lookups done when evaluating the model.
model.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=['acc', keras.metrics.Recall(name='recall'), keras.metrics.Precision(name='precision')],
)

model.summary()

### Configure checkpoints
We will use TensorFlow's [ModelCheckpoint](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/ModelCheckpoint) to save the weights of the model at each epoch so that we can select the weights from the best epoch to use for the final model once training is complete.

In [None]:
cp_filepath = 'swing_checkpoints/cp-{epoch:04d}.ckpt'

# Write a weights-only checkpoint after every epoch (save_best_only=False), so any epoch can
# be restored after inspecting the metrics. ModelCheckpoint only consults monitor/mode when
# save_best_only=True.
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    cp_filepath,
    monitor='val_loss',
    save_best_only=False,
    save_weights_only=True,
    mode='min',
)

## Load annotations
While collecting the data, I separated the data for pitches where I hit the ball and pitches where I purposely didn't hit the ball into different directories. Each pitch has its own csv file associated with it. The csv file is formatted like `filename,label`, where `filename` is the file name of an image for a single frame in the pitch, and `label` is 0 if a swing wasn't initiated during that frame, or 1 if a swing was initiated during that frame. The vast majority of frames will be labeled 0, so the data is separated into categories below to deal with this class imbalance.
### Load annotations for hits
The data for hits will always end as soon as a swing is initiated, so the last image in each csv file will be for the frame where the swing was initiated, and therefore will be labeled as 1. I categorized data for hits into three different categories:
- Time windows where the ball is still a significant distance away from the batter (the window ends two or more frames before the swing was initiated). This data falls into `csv_labels_other`
- The time window ending one frame before the swing is initiated. This data falls into `csv_labels_strikes_0`
- The time window ending on the frame where the swing is initiated. This data falls into `csv_labels_1`

In [None]:
CSV_HITS_DIR = '../examples/csvs_hits/'
IMAGE_HITS_DIR = '../examples/ims_hits/'
csv_hits_paths = glob.glob(f'{CSV_HITS_DIR}*.csv')

csv_labels_other = []
csv_labels_strikes_0 = []
csv_labels_1 = []

for csv_path in csv_hits_paths:
    # Read the frame file names (first column) of this pitch. Blank lines, which
    # csv.reader yields as empty rows, are skipped instead of crashing on row[0].
    with open(csv_path, newline='') as csv_file:
        frames = [row[0] for row in csv.reader(csv_file, delimiter=',') if row]

    # All sliding windows of TIMESTEPS consecutive frames (len(frames) - TIMESTEPS + 1 of them).
    all_windows = [frames[i:i + TIMESTEPS] for i in range(len(frames) - TIMESTEPS + 1)]

    # At least two windows are needed: the swing window and the one just before it.
    # A pitch with exactly TIMESTEPS frames has only one window, so all_windows[-2] would fail.
    if len(all_windows) < 2:
        continue

    # these are time windows where the ball is still far away from the batter
    for window in all_windows[:-2]:
        csv_labels_other.append([0, window, IMAGE_HITS_DIR])

    # this is the time window ending one frame before the swing is initiated
    csv_labels_strikes_0.append([0, all_windows[-2], IMAGE_HITS_DIR])

    # this is the time window ending on the frame where the swing is initiated
    # (assumes the last row of every hit CSV is the swing frame; the label column is not read)
    csv_labels_1.append([1, all_windows[-1], IMAGE_HITS_DIR])

### Load annotations for balls
Any pitches that were balls were pitches that I purposely didn't swing at. The data for balls ends approximately once the ball has passed the batter and is off the screen. I categorized data for balls into two different categories:
- The 3rd, 6th, and 9th time windows counted from the end of the pitch. This data falls into `csv_labels_balls`. This data is meant to train the AI to recognize images of balls when they are in approximately similar positions to when the AI would need to initiate a swing
- All other time windows (except the last two of each pitch, which are not used) fall into `csv_labels_other`

In [None]:
CSV_BALLS_DIR = '../examples/csvs_balls/'
IMAGE_BALLS_DIR = '../examples/ims_balls/'
csv_balls_paths = glob.glob(f'{CSV_BALLS_DIR}*.csv')

csv_labels_balls = []

# Positions counted from the end of the pitch (3rd, 6th and 9th window from the last one)
# whose windows go to csv_labels_balls.
save_idxs = [3, 6, 9]

for csv_path in csv_balls_paths:
    # Read the frame file names (first column), skipping blank lines.
    with open(csv_path, newline='') as csv_file:
        frames = [row[0] for row in csv.reader(csv_file, delimiter=',') if row]

    # skip pitches that are too short (fewer than 9 frames or fewer than TIMESTEPS frames)
    if len(frames) < max(9, TIMESTEPS):
        continue

    # All sliding windows of TIMESTEPS consecutive frames.
    all_windows = [frames[i:i + TIMESTEPS] for i in range(len(frames) - TIMESTEPS + 1)]

    # The last two windows of each pitch are not used. Every other window goes to exactly
    # one list, so the same window can never appear in both csv_labels_balls and
    # csv_labels_other (and end up on both sides of the train/validation split).
    for i, window in enumerate(all_windows[:-2]):
        if len(all_windows) - i in save_idxs:
            csv_labels_balls.append([0, window, IMAGE_BALLS_DIR])
        else:
            csv_labels_other.append([0, window, IMAGE_BALLS_DIR])

### Combine annotations for hits and balls
Combine all of the annotations for hits and balls and separate them into training and validation labels.

In [None]:
TRAIN_SPLIT = 0.8
# Derived from TRAIN_SPLIT so the two fractions can never disagree
# (rounded to avoid 1 - 0.8 == 0.19999999999999996).
VALIDATION_SPLIT = round(1 - TRAIN_SPLIT, 10)

# Seed Python's RNG so the shuffles below, and therefore the train/validation split, are
# reproducible across runs. (This does not seed NumPy or TensorFlow.)
SEED = 42
random.seed(SEED)


def _split_train_val(items, train_fraction=TRAIN_SPLIT):
    """Return (train, val) lists by cutting `items` at int(len(items) * train_fraction)."""
    cut = int(len(items) * train_fraction)
    return items[:cut], items[cut:]


# shuffle all of the annotations
for annotations in (csv_labels_other, csv_labels_balls, csv_labels_strikes_0, csv_labels_1):
    random.shuffle(annotations)

# number of "other" windows will be equivalent to number of windows where a swing was initiated
# in order to prevent a largely unbalanced dataset
csv_labels_other = csv_labels_other[:len(csv_labels_1)]

# split every category separately so each appears in both sets in the TRAIN_SPLIT proportion
balls_train, balls_val = _split_train_val(csv_labels_balls)
strikes_0_train, strikes_0_val = _split_train_val(csv_labels_strikes_0)
other_train, other_val = _split_train_val(csv_labels_other)
csv_labels_1_train, csv_labels_1_val = _split_train_val(csv_labels_1)

# combine all 0 labels together
csv_labels_0_train = balls_train + strikes_0_train + other_train
csv_labels_0_val = balls_val + strikes_0_val + other_val

# NOTE(review): windows are split individually, but consecutive windows from the same pitch
# share TIMESTEPS - 1 frames, so near-identical samples can land on both sides of the split.
# Splitting by pitch (CSV file) would give a more honest validation score.

# combine all training and validation labels together and shuffle
csv_labels_train = csv_labels_0_train + csv_labels_1_train
csv_labels_val = csv_labels_0_val + csv_labels_1_val
random.shuffle(csv_labels_train)
random.shuffle(csv_labels_val)

# Training windows first, validation windows last: the training step relies on this order.
csv_labels = csv_labels_train + csv_labels_val
print(len(csv_labels))

## Load training data
### Preprocess images
Function to preprocess images. The process is to mask the image, convert it to grayscale, resize it, and then normalize the data.

In [None]:
def preprocess_image(image):
    """Mask, grayscale, resize and normalize one BGR frame for the CNN.

    Args:
        image: BGR image as returned by ``cv2.imread`` (assumed uint8 with 3 channels).

    Returns:
        float32 array of shape (HEIGHT, WIDTH, 1); for uint8 input the values lie in [0, 1].
    """
    # Keep only low-saturation pixels (HSV S <= 57, any hue and value) and black out the rest.
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    mask = cv2.inRange(hsv, (0, 0, 0), (179, 57, 255))
    masked_im = cv2.bitwise_and(image, image, mask=mask)

    # convert image to grayscale
    masked_im = cv2.cvtColor(masked_im, cv2.COLOR_BGR2GRAY)

    # cv2.resize takes dsize as (width, height) but returns an array shaped (height, width)
    masked_im = cv2.resize(masked_im, (WIDTH, HEIGHT))

    # normalize values to be between 0 and 1; float32 is what Keras trains in and
    # uses half the memory of float64
    masked_im = masked_im.astype(np.float32) / 255

    # Add the channel axis: (HEIGHT, WIDTH) -> (HEIGHT, WIDTH, 1). Reshaping to
    # (WIDTH, HEIGHT, 1) would silently scramble the pixels whenever WIDTH != HEIGHT.
    return masked_im.reshape(HEIGHT, WIDTH, 1)

### Load images
Use the previously made annotations (csv_labels) to create the final data and labels set for the AI

In [None]:
data = []
labels = []

total = len(csv_labels)
# Each annotation is [label, list of frame file names, image directory].
for count, (label, window_names, image_dir) in enumerate(csv_labels, start=1):
    labels.append(label)

    window = []
    for im_name in window_names:
        image_path = f'{image_dir}{im_name}'
        image = cv2.imread(image_path)
        # cv2.imread returns None instead of raising when a file is missing or unreadable,
        # which would otherwise surface as a confusing error inside preprocess_image.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {image_path}')
        window.append(preprocess_image(image))
    data.append(window)

    # print progress of loading images
    print(f'{int((count/total)*100)}%', end='\r', flush=True)

if not data:
    raise ValueError('No annotated windows were loaded; check the CSV and image directories.')

labels = np.array(labels)
# float32 matches what Keras trains in and halves memory compared with float64.
data = np.array(data, dtype=np.float32)

# np.array already stacks the frames as (samples, TIMESTEPS, rows, cols, 1). Check that
# layout explicitly instead of reshaping: a reshape to a different axis order would
# silently scramble pixels rather than fail.
expected_shape = (len(labels), TIMESTEPS, HEIGHT, WIDTH, 1)
if data.shape != expected_shape:
    raise ValueError(f'Expected data of shape {expected_shape}, got {data.shape}')

## Train model
Train the model on the data collected. You can adjust the `BATCH_SIZE` and `EPOCHS` parameters as you like; however, these values worked best for me. The model does not need to train for very long before it starts to overfit.

In [None]:
BATCH_SIZE = 64
EPOCHS = 50

# Use the train/validation split built from the annotations explicitly. Samples were loaded
# in csv_labels order, which is all training windows followed by all validation windows.
# Keras' `validation_split` would instead cut the last fraction of samples at a boundary it
# computes itself. That boundary need not coincide with len(csv_labels_train) (the split was
# done per category with truncation), so some validation windows could end up in training.
num_train = len(csv_labels_train)

history = model.fit(
    x=data[:num_train],
    y=labels[:num_train],
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(data[num_train:], labels[num_train:]),
    callbacks=[model_checkpoint_callback],
)

## Evaluate model
Evaluate the loss, accuracy, [precision and recall](https://en.wikipedia.org/wiki/Precision_and_recall) of the model to determine the best checkpoint to load the weights from. The most important metrics here are precision and recall. Generally speaking, a model with low precision will swing at too many "bad" pitches. And a model with low recall will not swing at enough "good" pitches. Therefore, you should select a model that has decent and similar recall and precision.

In [None]:
def _best_checkpoint(values, pick=np.argmax):
    """Return the 1-based epoch (= checkpoint file number) that `pick` selects from `values`.

    np.argmin / np.argmax return 0-based indices, while ModelCheckpoint fills `{epoch}` in
    cp_filepath with 1-based epoch numbers, so 1 is added to match the checkpoint files.
    """
    return int(pick(np.asarray(values))) + 1


# Retrieve the per-epoch results on training and validation data
loss = history.history['loss']
val_loss = history.history['val_loss']
acc = history.history['acc']
val_acc = history.history['val_acc']
recall = history.history['recall']
val_recall = history.history['val_recall']
precision = history.history['precision']
val_precision = history.history['val_precision']

print(f'Checkpoint with lowest loss: {_best_checkpoint(val_loss, np.argmin)}')
print(f'Checkpoint with highest accuracy: {_best_checkpoint(val_acc)}')
print(f'Checkpoint with highest recall: {_best_checkpoint(val_recall)}')
print(f'Checkpoint with highest precision: {_best_checkpoint(val_precision)}')

# 1-based epochs so the x-axis matches the checkpoint numbers printed above. The range is
# taken from the recorded history rather than EPOCHS so the plots also work if training
# stopped early.
epochs_range = range(1, len(loss) + 1)

# Plot training and validation accuracy, recall, precision and loss per epoch
for title, train_values, val_values in (
    ('Training and validation accuracy', acc, val_acc),
    ('Training and validation recall', recall, val_recall),
    ('Training and validation precision', precision, val_precision),
    ('Training and validation loss', loss, val_loss),
):
    plt.figure()
    plt.plot(epochs_range, train_values, label='Training')
    plt.plot(epochs_range, val_values, label='Validation')
    plt.title(title)
    plt.xlabel('Epochs')
    plt.legend()

## Save Model
Set the `best_cp` variable to the epoch that had the best metrics and save the model as an [HDF5](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) file.

In [None]:
# Epoch whose weights are kept, chosen by hand from the evaluation metrics. ModelCheckpoint
# numbers its files with 1-based epochs, so 30 selects cp-0030.ckpt.
best_cp = 30
best_cp_filepath = cp_filepath.format(epoch=best_cp)

# Restore those weights, then write the whole model (architecture + weights) to one HDF5 file.
model.load_weights(filepath=best_cp_filepath)
model.save(filepath='swing_model.h5')