In [None]:
class PositionDecoder(nn.Module):
    """Decode an encoded sequence into one 2-value prediction per target slot.

    The encoded sequence is first reduced with ``AdditiveAttentionPooling``;
    the pooled vector is then fed to ``max_targets`` independent funnel-shaped
    MLP heads (embed_size -> 256 -> 128 -> 64 -> 2).

    Args:
        max_targets: number of prediction heads / target slots.
        dropout: base dropout rate; the three hidden layers use
            ``1.5 * dropout``, ``dropout`` and ``0.5 * dropout`` respectively.
        embed_size: width of the encoder output fed to the pooling layer and
            to the first linear layer of every head. Defaults to 64, the value
            that used to be hard-coded.
    """

    def __init__(self, max_targets=9, dropout=0.15, embed_size=64):
        super(PositionDecoder, self).__init__()
        self.max_targets = max_targets
        self.embed_size = embed_size
        # attention pooling: learns the importance of the different frames
        self.global_pool = AdditiveAttentionPooling(embed_size=embed_size)
        # one funnel-shaped FCNN per target slot; layer widths will probably
        # need tuning once real training starts
        self.multi_targ_decoders = nn.ModuleList([
            nn.Sequential(
                nn.Linear(embed_size, 256),
                nn.BatchNorm1d(256),
                nn.GELU(),
                nn.Dropout(dropout * 1.5),

                nn.Linear(256, 128),
                nn.BatchNorm1d(128),
                nn.GELU(),
                nn.Dropout(dropout),

                nn.Linear(128, 64),
                nn.BatchNorm1d(64),
                nn.GELU(),
                nn.Dropout(dropout * .5),

                nn.Linear(64, 2)
            ) for _ in range(max_targets)
        ])

    def forward(self, encoded_sequence, target_mask):
        """Return masked per-target predictions.

        Args:
            encoded_sequence: encoder output handed to the pooling layer.
                NOTE(review): presumably (batch, seq, embed_size) - confirm
                against AdditiveAttentionPooling.
            target_mask: tensor with one entry per target slot, multiplied
                into the predictions (0 zeroes a slot out).
                NOTE(review): presumably (batch, max_targets) - confirm.

        Returns:
            The head outputs stacked along dim 1 and multiplied by the mask.
        """
        # pool across all steps of the input sequence; the second return value
        # of the pooling layer is not needed here
        pooled, _ = self.global_pool(encoded_sequence)
        # every head predicts from the same pooled vector
        predictions = torch.stack(
            [decoder(pooled) for decoder in self.multi_targ_decoders], dim=1
        )
        # zero out the slots that hold no real target
        return predictions * target_mask.unsqueeze(-1)
        

In [None]:
class DJmoorePOS(nn.Module):
    """CNN -> transformer encoder -> position decoder pipeline.

    Each frame of the input sequence is passed through ``CNN_DownSample``,
    flattened, encoded as a sequence by ``TransEncoder`` and finally decoded
    by ``PositionDecoder``.

    Args:
        embed_size: embedding size forwarded to ``TransEncoder``.
        num_layers: number of encoder layers forwarded to ``TransEncoder``.
        dropout: dropout rate shared by the encoder and the decoder.
        mask: forwarded unchanged to ``TransEncoder``.
        dev: device argument forwarded to ``TransEncoder``.
    """

    def __init__(self, embed_size, num_layers, dropout, mask, dev='cuda'):
        super().__init__()
        # per-frame down-sampling CNN
        self.cnn = CNN_DownSample()
        # size of one flattened CNN output
        # NOTE(review): presumably channels * height * width of the
        # CNN_DownSample output - confirm against that module
        flat_cnn_features = 64 * 14 * 31
        # sequence encoder over the per-frame feature vectors
        self.encoder = TransEncoder(
            input_dim=flat_cnn_features,
            embed_size=embed_size,
            num_layers=num_layers,
            device=dev,
            mask=mask,
            dropout=dropout,
            max_length=100,
        )
        # head that turns the encoded sequence into positions
        self.decoder = PositionDecoder(dropout=dropout)

    def forward(self, heatmap_sequence, target_mask):
        """Predict positions for a batch of heatmap sequences.

        Args:
            heatmap_sequence: tensor whose first two dims are batch and
                sequence; ``heatmap_sequence[:, t]`` is fed to the CNN.
            target_mask: forwarded to the decoder.

        Returns:
            Whatever ``self.decoder`` returns for the encoded sequence.
        """
        n_batch, n_frames = heatmap_sequence.shape[:2]
        # run the CNN frame by frame, keeping dim 0 and flattening the rest
        per_frame = [
            self.cnn(heatmap_sequence[:, step]).flatten(1)
            for step in range(n_frames)
        ]
        # (batch, frames, features) sequence for the encoder
        frame_features = torch.stack(per_frame, dim=1)
        # all-ones attention mask: nothing is hidden for now because only
        # a single frame is being predicted
        attend_everywhere = torch.ones(
            n_batch, n_frames, device=heatmap_sequence.device
        )
        encoded = self.encoder(frame_features, attend_everywhere)
        return self.decoder(encoded, target_mask)


In [None]:
def train_sequences(df_grids, max_seq, max_targets, df_tracking=None):
    """Build padded grid sequences, target positions and target masks.

    One sample is produced per play in ``df_grids`` that has at most
    ``max_seq`` frames, at least one player flagged ``player_to_predict`` and
    at least one of those players present in the play's last frame.

    Args:
        df_grids: DataFrame with columns ``play_id_n``, ``frame_id`` and
            ``grid`` (one numpy array per frame).
        max_seq: sequence length every sample is padded to; longer plays are
            skipped.
        max_targets: maximum number of players per play; extra players are
            dropped and missing ones are masked out.
        df_tracking: DataFrame with columns ``play_id_n``, ``frame_id``,
            ``nfl_id``, ``x``, ``y`` and ``player_to_predict``. When omitted,
            the module-level ``df`` is used, as before.

    Returns:
        ``(sequences, targets, masks)`` stacked along a new first dim, or
        ``(None, None, None)`` when no play qualifies.
    """
    tracking = df if df_tracking is None else df_tracking

    # ids of the players to predict, one array per play
    players_by_play = (
        tracking[tracking['player_to_predict'] == True]
        .groupby('play_id_n')['nfl_id']
        .unique()
    )

    sequences = []
    targets = []
    masks = []

    # sort=False keeps the plays in order of first appearance, i.e. the order
    # of df_grids['play_id_n'].unique()
    for play_id, play_data in df_grids.groupby('play_id_n', sort=False):
        play_data = play_data.sort_values('frame_id')
        n_frames = len(play_data)
        # plays longer than max_seq cannot be represented; a play of exactly
        # max_seq frames is valid and simply needs no padding
        if n_frames > max_seq:
            continue

        # skip plays with nobody flagged for prediction instead of raising
        players = players_by_play.get(play_id)
        if players is None:
            continue
        # only take max_targets players
        players = players[:max_targets]

        # 1 for real target slots, 0 for unused ones
        mask = torch.zeros(max_targets)
        mask[:len(players)] = 1

        grids = [torch.from_numpy(grid).float() for grid in play_data['grid']]
        sequence = torch.stack(grids, dim=0)
        # small jitter on the real frames only (padding is added afterwards)
        sequence = sequence + torch.randn_like(sequence) * 0.0001

        # pad with an out-of-distribution value; the frame shape is taken
        # from the data rather than hard-coded
        padding_needed = max_seq - n_frames
        if padding_needed > 0:
            padding = torch.full((padding_needed, *sequence.shape[1:]), -1.0)
            sequence = torch.cat([sequence, padding], dim=0)

        # targets are read from the last frame of the play
        target_frame = play_data['frame_id'].iloc[-1]
        frame_data = tracking[
            (tracking['play_id_n'] == play_id)
            & (tracking['frame_id'] == target_frame)
        ]

        target_positions = torch.zeros(max_targets, 2)
        valid_targets = 0
        for slot, receiver_id in enumerate(players):
            receiver_data = frame_data[frame_data['nfl_id'] == receiver_id]
            if receiver_data.empty:
                continue
            # NOTE(review): 120 and 53.3 presumably normalise by the field
            # length and width - confirm the units of x and y
            x = float(receiver_data['x'].iloc[0]) / 120
            y = float(receiver_data['y'].iloc[0]) / 53.3
            target_positions[slot] = torch.tensor([x, y])
            valid_targets += 1

        # keep the play only if at least one target position was found
        if valid_targets > 0:
            sequences.append(sequence)
            targets.append(target_positions)
            masks.append(mask)

    if not sequences:
        return None, None, None

    return (torch.stack(sequences, dim=0),
            torch.stack(targets, dim=0),
            torch.stack(masks, dim=0))

In [None]:
# training loop
def train(model, train_loader, val_loader, loss_func, optimizer, scheduler,
          epochs, patience=25, device=None):
    """Train ``model`` with per-epoch validation and early stopping.

    Args:
        model: module called as ``model(sequence, mask)``.
        train_loader: iterable of ``(sequence, targets, masks)`` batches.
        val_loader: iterable of ``(sequence, targets, masks)`` batches.
        loss_func: callable invoked as ``loss_func(preds, targets, masks)``.
        optimizer: optimizer over the model parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        epochs: maximum number of epochs.
        patience: stop after this many consecutive epochs without a new best
            validation loss.
        device: device the batches are moved to. Defaults to the device of
            the model's first parameter ('cuda' if it has no parameters).

    Returns:
        List with the sample-weighted mean training loss of every epoch run.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cuda')

    # only ReduceLROnPlateau takes the monitored metric in step(); every other
    # scheduler would interpret a positional argument as an epoch index
    metric_driven = isinstance(
        scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau
    )

    train_losses = []
    # early-stopping state: best validation loss and epochs since it improved
    es = np.inf
    es_count = 0

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        batches = 0  # number of samples seen, used to weight the batch losses
        for batch_sequence, batch_targets, batch_masks in train_loader:
            batch_sequence = batch_sequence.to(device)
            batch_targets = batch_targets.to(device)
            batch_masks = batch_masks.to(device)

            optimizer.zero_grad()
            preds = model(batch_sequence, batch_masks)
            loss = loss_func(preds, batch_targets, batch_masks)
            loss.backward()
            # prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
            optimizer.step()

            epoch_loss += loss.item() * batch_sequence.size(0)
            batches += batch_sequence.size(0)

        # epoch loss on the train set
        avg_loss = epoch_loss / batches

        val_loss = 0
        val_batches = 0
        model.eval()
        with torch.no_grad():
            for val_seq, val_targ, val_mask in val_loader:
                val_seq = val_seq.to(device)
                val_targ = val_targ.to(device)
                val_mask = val_mask.to(device)
                val_preds = model(val_seq, val_mask)
                # same masked loss call as in the training pass
                val_seq_loss = loss_func(val_preds, val_targ, val_mask)
                val_loss += val_seq_loss.item() * val_seq.size(0)
                val_batches += val_seq.size(0)
        val_loss = val_loss / val_batches

        # early stopping bookkeeping
        if val_loss < es:
            es = val_loss
            es_count = 0
        else:
            es_count += 1

        # save train loss, update LR
        train_losses.append(avg_loss)
        if metric_driven:
            scheduler.step(val_loss)
        else:
            scheduler.step()

        if epoch % 5 == 0:
            print(f'epoch: {epoch} train_loss: {avg_loss} val_loss:{val_loss}')

        if es_count >= patience:
            break

    return train_losses

# predict final position
def predict(model, sequence, target_mask=None):
    """Run the model on one sequence and return the predicted coordinates.

    Args:
        model: module called as ``model(sequence, target_mask)``; a model
            without a ``decoder.max_targets`` attribute is called as
            ``model(sequence)`` when no mask is given.
        sequence: input tensor; a batch dim is added when it is 4-D.
        target_mask: optional mask of target slots, 1-D or with a leading
            batch dim. Defaults to all ones (every slot predicted), sized
            from ``model.decoder.max_targets``.

    Returns:
        ``(x, y)`` taken from the first batch element: one array entry per
        target slot for a ``(targets, 2)`` prediction, or two scalars for a
        plain ``(2,)`` prediction.
    """
    model.eval()
    # do not compute gradients
    with torch.no_grad():
        # add the batch dim if it is not present
        if sequence.dim() == 4:
            sequence = sequence.unsqueeze(0)

        # run on whatever device the model lives on
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else sequence.device
        sequence = sequence.to(device)

        if target_mask is None:
            max_targets = getattr(
                getattr(model, 'decoder', None), 'max_targets', None
            )
            if max_targets is not None:
                target_mask = torch.ones(sequence.size(0), max_targets)

        if target_mask is None:
            prediction = model(sequence)
        else:
            if target_mask.dim() == 1:
                target_mask = target_mask.unsqueeze(0)
            prediction = model(sequence, target_mask.to(device))

        # move to cpu and convert to numpy; transposing puts the coordinate
        # axis first so it unpacks into x and y for any number of targets
        x, y = prediction[0].cpu().numpy().T

    return x, y

In [None]:
class VarianceRegularizedLoss(nn.Module):
    """MSE loss plus a penalty on mismatched prediction/target variance.

    Args:
        alpha: weight of the plain MSE term.
        beta: weight of the variance-matching term.
    """

    def __init__(self, alpha=1.0, beta=0.1):
        super().__init__()
        self.alpha, self.beta = alpha, beta

    def forward(self, predictions, targets):
        """Return ``alpha * mse + beta * variance_penalty``."""
        # how far the predictions are from the targets
        fit_term = F.mse_loss(predictions, targets)
        # how far the spread of the predictions (variance over dim 0) is from
        # the spread of the targets; penalises collapsed, low-variance output
        spread_gap = F.mse_loss(predictions.var(dim=0), targets.var(dim=0))
        return self.alpha * fit_term + self.beta * spread_gap

class MaskedMSELoss(nn.Module):
    """Squared error summed over unmasked targets, divided by their count."""

    def __init__(self):
        super().__init__()

    def forward(self, predictions, targets, mask):
        """Return the masked loss.

        Args:
            predictions: predicted values.
            targets: ground-truth values, same shape as ``predictions``.
            mask: one entry per target (one dim fewer than ``predictions``);
                1 keeps a target, 0 ignores it.

        Returns:
            Scalar tensor: the masked squared-error sum divided by the number
            of kept targets, or 0 when every target is masked out.
        """
        # broadcast the per-target mask over the coordinate dim
        mask = mask.unsqueeze(-1)
        masked_sq_err = ((predictions - targets) ** 2) * mask
        # Clamping the count gives the same value whenever a target is valid
        # and a zero that is still attached to the autograd graph when none
        # is, so loss.backward() keeps working on a fully masked batch.
        num_valid = mask.sum().clamp(min=1)
        return masked_sq_err.sum() / num_valid

In [None]:
# build the network and place it on the GPU in one go
model = DJmoorePOS(
    embed_size=64,
    num_layers=8,
    mask=None,
    dropout=0.15,
    dev="cuda",
).to("cuda")

# masked MSE: only real target slots contribute to the loss
loss_fun = MaskedMSELoss()

# AdamW rather than Adam for its decoupled weight decay
opti = torch.optim.AdamW(
    model.parameters(),
    lr=0.0001,
    weight_decay=0.001,
)

# cosine learning-rate schedule with a 20-step warm-up over 500 steps
lr_schedule = get_cosine_schedule_with_warmup(
    opti,
    num_warmup_steps=20,
    num_training_steps=500,
)

In [None]:
# build the padded sequences, target positions and target masks
seq, targ, mask = train_sequences(df_grids, max_seq=81, max_targets=9)

# 70 % train; the remaining 30 % is then halved into test and validation
X_train, X_test, y_train, y_test, mask_train, mask_test = train_test_split(
    seq, targ, mask,
    test_size=0.3, random_state=26, shuffle=True,
)
X_test, X_val, y_test, y_val, mask_test, mask_val = train_test_split(
    X_test, y_test, mask_test,
    test_size=0.5, random_state=26, shuffle=True,
)

# one dataset per split, each yielding (sequence, target, mask)
train_dataset = TensorDataset(X_train, y_train, mask_train)
test_dataset = TensorDataset(X_test, y_test, mask_test)
val_dataset = TensorDataset(X_val, y_val, mask_val)

# only the training loader shuffles and drops the last incomplete batch
train_load = DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

# run training for at most 500 epochs
train(
    model=model,
    train_loader=train_load,
    val_loader=val_loader,
    loss_func=loss_fun,
    optimizer=opti,
    scheduler=lr_schedule,
    epochs=500,
)

In [None]:
def eval_loop(model, test_loader, loss_fun):
    """Evaluate ``model`` on ``test_loader`` without computing gradients.

    Args:
        model: module called as ``model(sequence, mask)``.
        test_loader: iterable of ``(sequence, targets, masks)`` batches.
        loss_fun: callable invoked as ``loss_fun(preds, targets, masks)``.

    Returns:
        ``(avg_test_loss, all_pred, all_target)``: the sample-weighted mean
        loss and the predictions / targets of every batch concatenated along
        dim 0 on the CPU.
    """
    model.eval()
    # evaluate on whatever device the model lives on
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    total_loss = 0
    total_sample = 0
    all_pred = []
    all_target = []

    with torch.no_grad():
        # the loader yields the mask as a third element; the model and the
        # loss both need it
        for batch_seq, batch_targ, batch_masks in test_loader:
            batch_seq = batch_seq.to(device)
            batch_targ = batch_targ.to(device)
            batch_masks = batch_masks.to(device)

            preds = model(batch_seq, batch_masks)
            loss = loss_fun(preds, batch_targ, batch_masks)

            # weight each batch loss by its size so the mean is per sample
            total_loss += loss.item() * batch_seq.size(0)
            total_sample += batch_seq.size(0)

            all_pred.append(preds.cpu())
            all_target.append(batch_targ.cpu())

    avg_test_loss = total_loss / total_sample
    all_pred = torch.cat(all_pred, dim=0)
    all_target = torch.cat(all_target, dim=0)

    return avg_test_loss, all_pred, all_target

# evaluate on the held-out test split and report the mean loss
loss, pred, target = eval_loop(model, test_loader, loss_fun)
print(loss)