# Training FormationExpertCNN (Expert)
This notebook trains the CNN expert model for NFL play prediction. The CNN (convolutional neural network) expert uses pre-snap team formations to predict whether the offensive team will pass or run.

## Imports

In [1]:
import sys
sys.path.append('../')

In [2]:
import pandas as pd
import matplotlib.pyplot as plt

In [3]:
from experts.cnn import FormationExpertCNN
from etl.dataloader import ExpertDataset

## Loading Dataset

See the `etl` module for details.

In [4]:
# Make a train/test split
from sklearn.model_selection import train_test_split


df = pd.read_csv('features/target.csv')

# Split into train (80%) and test (20%) sets
df_train, df_test = train_test_split(df, test_size=0.2, random_state=42)

# Reset indices
df_train = df_train.reset_index(drop=True)
df_test = df_test.reset_index(drop=True)

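The split above is purely random. If the pass/run labels are imbalanced, scikit-learn's `stratify=` argument keeps the class ratio the same in both subsets. Below is a minimal sketch on synthetic data. The column name `is_pass` and the `demo_*` variables are hypothetical, since the target column in `features/target.csv` is not shown in this notebook.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Synthetic stand-in for features/target.csv; `is_pass` is a hypothetical label column
rng = np.random.default_rng(0)
demo_df = pd.DataFrame({
    "play_id": np.arange(1000),
    "is_pass": (rng.random(1000) < 0.6).astype(int),
})

# stratify= keeps the pass/run ratio (up to rounding) equal in both splits
demo_train, demo_test = train_test_split(
    demo_df, test_size=0.2, random_state=42, stratify=demo_df["is_pass"]
)
demo_train = demo_train.reset_index(drop=True)
demo_test = demo_test.reset_index(drop=True)

print(f"train pass rate: {demo_train['is_pass'].mean():.3f}")
print(f"test pass rate:  {demo_test['is_pass'].mean():.3f}")
```

In the notebook itself, the equivalent change would be passing `stratify=` with whatever column holds the pass/run label.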

In [5]:
# Create the train and validation datasets for the CNN expert
cnn_train_dataset = ExpertDataset(df_train, expert_name="cnn")
cnn_val_dataset = ExpertDataset(df_test, expert_name="cnn")

In [6]:
# Get features and targets from each dataset
X_train, y_train = cnn_train_dataset.features, cnn_train_dataset.targets
X_val, y_val = cnn_val_dataset.features, cnn_val_dataset.targets

In [7]:
X_train.shape, y_train.shape

((10284, 158, 300, 4), (10284,))

## Training expert
See `experts/cnn.py` in the `experts` module for implementation details.

**Architecture**:
- Input: grid-based image of pre-snap player positions, with dimensions 158x300x4
- 3 conv layers (16→32→64 channels)
- Batch norm + max pooling
- 3 FC layers (128→64→1)
- Output: binary classification (1 for pass, 0 for run)
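The architecture list above can be read as the following PyTorch module. This is a hypothetical sketch, not the code in `experts/cnn.py`: the framework (PyTorch), the 3x3 kernels, the ReLU activations, the conv, batch norm, ReLU, pool ordering, and the reading of the FC layers as flattened features → 128 → 64 → 1 are all assumptions. The class name `FormationCNNSketch` is illustrative only.

```python
import torch
import torch.nn as nn


class FormationCNNSketch(nn.Module):
    """Hypothetical re-creation of the listed architecture (assumed details, not experts/cnn.py)."""

    def __init__(self, in_channels: int = 4, height: int = 158, width: int = 300):
        super().__init__()

        def block(c_in: int, c_out: int) -> nn.Sequential:
            # 3x3 conv (padding=1 keeps H, W) -> batch norm -> ReLU -> 2x2 max pool (halves H, W)
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.MaxPool2d(2),
            )

        # Three conv blocks: 4 -> 16 -> 32 -> 64 channels
        self.features = nn.Sequential(
            block(in_channels, 16), block(16, 32), block(32, 64)
        )

        # Each pool floors the size: 158 -> 79 -> 39 -> 19 and 300 -> 150 -> 75 -> 37
        h, w = height, width
        for _ in range(3):
            h, w = h // 2, w // 2

        # Three fully connected layers: flattened features -> 128 -> 64 -> 1
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * h * w, 128), nn.ReLU(),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, 1),  # single logit; sigmoid(logit) > 0.5 -> pass (1), else run (0)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The notebook's arrays are channels-last (N, 158, 300, 4);
        # nn.Conv2d expects channels-first (N, 4, 158, 300)
        x = x.permute(0, 3, 1, 2)
        return self.classifier(self.features(x))


model = FormationCNNSketch()
logits = model(torch.randn(2, 158, 300, 4))
print(logits.shape)
```

Because the arrays above are stored channels-last, the sketch permutes them inside `forward`; an implementation could equally transpose the arrays once before training.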
 

In [8]:
cnn_model = FormationExpertCNN()

In [9]:
# Training parameters
n_epochs = 5
batch_size = 32
learning_rate = 0.001

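# Train the model and collect per-epoch metrics.
# The validation split is not passed in (X_val=None, y_val=None),
# so the logged loss and accuracy are presumably training-set metrics.
train_metrics = cnn_model.train(X_train=X_train, y_train=y_train,
                                X_val=None, y_val=None,
                                num_epochs=n_epochs,
                                batch_size=batch_size,
                                alpha=learning_rate)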

Epoch [1/5], Loss: 0.7090, Accuracy: 59.52%
Epoch [2/5], Loss: 0.6702, Accuracy: 60.68%
Epoch [3/5], Loss: 0.6703, Accuracy: 60.68%
Epoch [4/5], Loss: 0.6699, Accuracy: 60.68%
Epoch [5/5], Loss: 0.6703, Accuracy: 60.68%
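From epoch 2 onward the reported accuracy stays at exactly 60.68% and the loss stays near 0.670. This pattern can occur when a binary classifier outputs the same prediction for every play. One way to check is to compare the log against the majority-class baseline (always predicting the more frequent label) and against the cross-entropy of a constant prediction equal to the class frequency. A minimal NumPy sketch follows; the helper `majority_baseline` is hypothetical, and a toy label array stands in for `y_train`.

```python
import numpy as np


def majority_baseline(y: np.ndarray) -> tuple[float, float]:
    """Accuracy and binary cross-entropy of predictors that ignore their input.

    Always predicting the majority label gives accuracy max(p, 1 - p), where p is
    the fraction of 1s (passes). Always outputting probability p gives the lowest
    BCE any constant predictor can reach: -(p*ln p + (1-p)*ln(1-p)).
    """
    p = float(np.mean(y))
    accuracy = max(p, 1.0 - p)
    bce = -(p * np.log(p) + (1.0 - p) * np.log(1.0 - p))
    return accuracy, float(bce)


# Toy labels with a 60.68% positive rate; in the notebook, pass y_train instead
y_toy = np.array([1] * 6068 + [0] * 3932)
acc, bce = majority_baseline(y_toy)
print(f"baseline accuracy: {acc:.2%}, constant-prediction BCE: {bce:.4f}")
```

For a 60.68% / 39.32% label split, the constant-prediction BCE works out to about 0.670, very close to the logged loss. If `majority_baseline(y_train)` gives similar numbers, the model has likely not learned anything beyond the class prior.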
