In [16]:
pip install torch torch-geometric wandb pandas





In [39]:
import os
import torch
import torch.nn.functional as F
import glob
import numpy as np
import pandas as pd
import wandb
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv



In [18]:
# Initialize Weights & Biases
wandb.init(project="rocket-league-gnn")





In [49]:
class RocketLeagueDataset(Dataset):
    def __init__(self, root_dir):
        super().__init__()
        self.file_paths = glob.glob(os.path.join(root_dir, "**", "*.csv"), recursive=True)
    
    def len(self):
        return len(self.file_paths)
    
    def get(self, idx):
        df = pd.read_csv(self.file_paths[idx])
        data_list = []
        
        for _, row in df.iterrows():
            # Extract player features (nodes)
            player_features = []
            for i in range(6):  # 6 players
                start_idx = i * 9
                player_features.append(row[start_idx:start_idx + 9].values)
            x = torch.tensor(player_features, dtype=torch.float)
            
            # Compute edges (player distances)
            edge_index, edge_attr = self.compute_edges(x[:, :3])
            
            # State vector (global feature)
            state_vector = torch.tensor([row['ball_x'], row['ball_y'], row['ball_z'], row['ball_dist_net']], dtype=torch.float)
            
            # Labels
            y = torch.tensor([row['blue_score'], row['orange_score']], dtype=torch.float)
            
            data_list.append(Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, state=state_vector))
        
        return data_list
    
    def compute_edges(self, positions):
        num_players = positions.shape[0]
        edge_index = []
        edge_attr = []
        
        for i in range(num_players):
            for j in range(i + 1, num_players):
                edge_index.append([i, j])
                edge_index.append([j, i])
                distance = np.linalg.norm(positions[i] - positions[j])
                edge_attr.append([distance])
                edge_attr.append([distance])
                
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_attr, dtype=torch.float)
        return edge_index, edge_attr

In [51]:
# Define GCN model
class RocketLeagueGCN(torch.nn.Module):
    def __init__(self, input_dim, state_dim, hidden_dim, output_dim):
        super().__init__()
        self.gcn1 = GCNConv(input_dim, hidden_dim)
        self.gcn2 = GCNConv(hidden_dim, hidden_dim)
        self.fc = torch.nn.Linear(hidden_dim + state_dim, output_dim)

    def forward(self, data):
        x, edge_index, state = data.x, data.edge_index, data.state
        x = F.relu(self.gcn1(x, edge_index))
        x = F.relu(self.gcn2(x, edge_index))
        x = torch.cat([x.mean(dim=0), state], dim=-1)  # Aggregate graph + state vector
        return torch.sigmoid(self.fc(x))  # Output probability

In [53]:
# Training setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset_root = "E:\\RL Esports Replays"
dataset = RocketLeagueDataset(dataset_root)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

model = RocketLeagueGCN(input_dim=6, state_dim=3, hidden_dim=32, output_dim=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.BCELoss()

In [55]:
# Training loop
epochs = 50
for epoch in range(epochs):
    model.train()
    total_loss = 0

    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch)
        loss = loss_fn(out, batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    wandb.log({"epoch": epoch, "loss": total_loss / len(train_loader)})
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader):.4f}")

KeyError: 'boost'

In [None]:
# Save model
torch.save(model.state_dict(), "rocket_league_gcn.pth")
wandb.save("rocket_league_gcn.pth")