In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import re
import ast
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Model, GPT2Config

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
data = np.load('tensor.npy')

In [12]:
games = []
current_game = []
current_quarter = []

for i, row in enumerate(data):
    if np.array_equal(row, [0, 1, 0, 0, 0]):
        current_quarter.append(row)
        current_game.append(np.array(current_quarter))
        current_quarter = []
    if np.all(row == 0):
        games.append(current_game)
        current_game = []
    else:
        current_quarter.append(row)

In [18]:
flattened_quarters = []
TARGET_LENGTH = 705

for game in games:
    for quarter in game:
        current_quarter = np.array(quarter).flatten()

        if len(current_quarter) < TARGET_LENGTH:
            padded = np.pad(current_quarter, (0, TARGET_LENGTH - len(current_quarter)), mode='constant', constant_values=0)
            
        flattened_quarters.append(padded)

In [34]:
class QuarterDataset(Dataset):
    def __init__(self, flattened_quarters, pad_token=[0, 0, 0, 0, 0]):
        self.flattened_quarters = flattened_quarters
        self.pad_token = pad_token
    
    def __len__(self):
        return len(self.flattened_quarters)
    
    def __getitem__(self, idx):
        sequence = torch.tensor(self.flattened_quarters[idx], dtype=torch.long)
        plays = sequence.view(-1, 5)
        targets = plays.roll(-1, dims=0)
        targets[-1] = torch.tensor(self.pad_token)
        attention_mask = ~(plays == torch.tensor(self.pad_token)).all(dim=1).long()
        return plays, targets, attention_mask

In [21]:
class TokenEmbeddings(nn.Module):
    def __init__(self, vocab_sizes, embedding_dims, embedding_output_dim):
        super(TokenEmbeddings, self).__init__()

        self.embeddings = nn.ModuleList([
            nn.Embedding(vocab_sizes[0], embedding_dims[0]),
            nn.Embedding(vocab_sizes[1], embedding_dims[1]),
            nn.Embedding(vocab_sizes[2], embedding_dims[2]),
            nn.Embedding(vocab_sizes[3], embedding_dims[3]),
        ])
        self.embeddings.append(self.embeddings[2])

        self.projection = nn.Linear(sum(embedding_dims), embedding_output_dim)
    
    def forward(self, tokens):
        # tokens: (batch_size, sequence_length, 5)
        embedded = [
            self.embeddings[i](tokens[..., i]) for i in range(len(self.embeddings))
        ]

        concat_embeddings = torch.cat(embedded, dim=-1) # (batch_size, sequence_length, sum(embedding_dims))
        projected_embedding = self.projection(concat_embeddings) # (batch_size, sequence_length, embedding_output_dim)
        return projected_embedding

In [31]:
class ComponentSubModelFF(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, output_dim)
        )
    
    def forward(self, context_embedding, local_context):
        if local_context is None or local_context.size(1) == 0:
            combined_input = context_embedding
        else:
            combined_input = torch.cat([context_embedding, local_context], dim=-1)
        return self.layers(combined_input)

In [85]:
shot_components = [2, 3]
shots = [1, 2, 3, 4]
rebounds = [5, 6]

class PlaySubModel(torch.nn.Module):
    def __init__(self, context_embedding_dim, local_embedding_dims, hidden_dims, output_dims, token_embeddings):
        super().__init__()
        self.token_embeddings = token_embeddings
        
        self.submodels = torch.nn.ModuleList([
            ComponentSubModelFF(context_embedding_dim + sum(local_embedding_dims[:i]), hidden_dim, output_dim) for i, (hidden_dim, output_dim) in enumerate(zip(hidden_dims, output_dims))
        ])
    
    def forward(self, context_embedding):
        batch_size = context_embedding.size(0)
        play_context = torch.zeros(batch_size, 0, device=context_embedding.device)
        generated_play = []

        for component_idx, submodel in enumerate(self.submodels):
            logits = submodel(context_embedding, play_context)  # Filter the batch
            probs = torch.softmax(logits, dim=-1)
            generated_token = torch.multinomial(probs, num_samples=1).squeeze(-1)

            token_embedding = self.token_embeddings.embeddings[component_idx](generated_token)
            play_context = torch.cat([play_context, token_embedding], dim=-1)

            generated_play.append(generated_token)
        
        return torch.stack(generated_play, dim=1)

In [193]:
class PlayAttentionTransformer(nn.Module):
    def __init__(self, play_embedding_dim, transformer_hidden_dim, num_layers):
        super().__init__()

        self.config = GPT2Config(
            n_embd=transformer_hidden_dim,
            n_layer=num_layers,
            n_head=8,
            n_positions=141
        )

        self.transformer = GPT2Model(self.config)
        self.context_projection = nn.Linear(transformer_hidden_dim, play_embedding_dim)
    
    def forward(self, play_embeddings, attention_mask=None):
        print(f"play_embeddings shape: {play_embeddings.shape}")  # Expected: [batch_size, seq_length, embed_dim]
        print(f"attention_mask shape: {attention_mask.shape}")    # Expected: [batch_size, seq_length]
        transformer_outputs = self.transformer(
            inputs_embeds=play_embeddings, attention_mask=attention_mask
        )

        hidden_states = transformer_outputs.last_hidden_state # (batch_size, num_plays, transformer_hidden_dim)
        context_embedding = self.context_projection(hidden_states) # (batch_size, play_embedding_dim)
        return context_embedding


In [194]:
class NBAAutoregressiveModel(nn.Module):
    def __init__(self, vocab_sizes, embedding_dims, play_embedding_dim, transformer_hidden_dim, num_layers, hidden_dims):
        super().__init__()
        self.token_embeddings = TokenEmbeddings(vocab_sizes, embedding_dims, play_embedding_dim)
        self.play_transformer = PlayAttentionTransformer(play_embedding_dim, transformer_hidden_dim, num_layers)
        self.play_submodel = PlaySubModel(play_embedding_dim, embedding_dims, hidden_dims, vocab_sizes, self.token_embeddings)
    
    def forward(self, tokens, attention_mask=None):
        play_embeddings = self.token_embeddings(tokens) # (batch_size, num_plays, play_embedding_dim)
        context_embedding = self.play_transformer(play_embeddings, attention_mask) # (batch_size, play_embedding_dim)
        next_play = self.play_submodel(context_embedding) # (batch_size, 5)
        return next_play


In [195]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [196]:
vocab_sizes = [14, 2194, 9, 3, 2194]
embedding_dims = [8, 32, 4, 1, 32]
play_embedding_dim = 64
transformer_hidden_dim = 64
num_layers = 4
hidden_dims = [32, 32, 16, 16, 32]

model = NBAAutoregressiveModel(vocab_sizes, embedding_dims, play_embedding_dim, transformer_hidden_dim, num_layers, hidden_dims).to(device)
optimizer = optim.Adam(model.parameters(), lr = 1e-4)
criterion = nn.CrossEntropyLoss()

In [197]:
batch_size = 16

dataset = QuarterDataset(flattened_quarters)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [198]:
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for batch_idx, (plays, targets, attention_mask) in enumerate(dataloader):
        plays, targets, attention_mask = plays.to(device), targets.to(device), attention_mask.to(device)

        play_outputs = []
        cur_context = torch.zeros((batch_size, 5), dtype=torch.long).to(device)
        cur_attention_mask = torch.ones((plays.size(0), 5), dtype=torch.long).to(device)  # Match cur_context size

        for i in range(plays.size(1)):
            play_outputs.append(model(cur_context, cur_attention_mask))
            cur_context = torch.cat([cur_context[:, 1:], plays[:, i]], dim=1)
            cur_attention_mask = torch.cat([cur_attention_mask[:, 1:], attention_mask[:, i].unsqueeze(1)], dim=1)

        play_outputs = torch.stack(play_outputs, dim=1)

        loss = 0
        for i in range(5):
            loss += criterion(play_outputs[i].view(-1, play_outputs[i].size(-1)), targets[:, :, i].view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    print(f"Epoch {epoch + 1}, Loss: {total_loss / len(dataloader)}")

play_embeddings shape: torch.Size([16, 64])
attention_mask shape: torch.Size([16, 5])
play_embeddings shape: torch.Size([16, 64])
attention_mask shape: torch.Size([16, 5])


RuntimeError: The size of tensor a (5) must match the size of tensor b (16) at non-singleton dimension 3