In [36]:
import os
import numpy as np
import pandas as pd
import polars as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.utils.class_weight import compute_class_weight
import logging
import yaml

# Set up logging: INFO level and above, each record rendered as
# "<timestamp> - <level name> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Load configuration
def load_config(config_path='config.yaml'):
    """Load a YAML configuration file and coerce its numeric entries.

    Args:
        config_path: Path of the YAML file to read. Defaults to 'config.yaml'.

    Returns:
        The object produced by ``yaml.safe_load`` with 'batch_size',
        'num_epochs' and 'patience' converted to ``int`` and 'learning_rate'
        and 'weight_decay' converted to ``float``. Other entries are left
        exactly as parsed.
    """
    with open(config_path, 'r') as handle:
        settings = yaml.safe_load(handle)

    # Explicit casts, applied in this order, so a value that was parsed as a
    # string (or as the other numeric type) still ends up with the intended type.
    numeric_casts = (
        ('batch_size', int),
        ('learning_rate', float),
        ('weight_decay', float),
        ('num_epochs', int),
        ('patience', int),
    )
    for key, cast in numeric_casts:
        settings[key] = cast(settings[key])
    return settings

# Check if GPU is available
def get_device():
    """Pick the torch device to run on and log the choice.

    Returns:
        ``torch.device('cuda')`` when ``torch.cuda.is_available()`` is true,
        otherwise ``torch.device('cpu')``.
    """
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    logging.info(f'Using device: {device}')
    return device

# Load the data using polars
def load_data(directory):
    """Read the three input CSV tables from ``directory`` with polars.

    Args:
        directory: Folder containing 'team12Stats.csv', 'playersStats.csv'
            and 'balltoball.csv'.

    Returns:
        A 3-tuple ``(teamStats, playersStats, balltoball)`` holding the result
        of ``pl.read_csv`` for each file, in that order.
    """
    file_names = ('team12Stats.csv', 'playersStats.csv', 'balltoball.csv')
    frames = [pl.read_csv(os.path.join(directory, file_name)) for file_name in file_names]
    return tuple(frames)

# Preprocess the data
def partition_data(df, group_keys):
    """Split ``df`` by ``group_keys`` and return one array per group.

    Args:
        df: Object exposing ``partition_by`` (a polars DataFrame in this
            notebook).
        group_keys: Column names passed to ``partition_by``; the same columns
            are dropped from every group before conversion.

    Returns:
        A list with ``group.drop(group_keys).to_numpy()`` for each group, in
        the order ``partition_by`` returned the groups.
    """
    arrays = []
    for group in df.partition_by(group_keys):
        without_keys = group.drop(group_keys)
        arrays.append(without_keys.to_numpy())
    return arrays

# Data augmentation function
def augment_data(team_stats_list, player_stats_list, ball_stats_list, over_segments=np.arange(7, 40,5),
                 balls_per_over=6, drop_label_column=False):
    """Create several truncated copies of every match, one per over cut-off.

    For each match and each cut-off ``segment`` (in overs) that fits inside
    the match, one sample is emitted: the unchanged team and player arrays,
    the first ``segment * balls_per_over`` rows of the ball-by-ball array, and
    the match label.

    Args:
        team_stats_list: Per-match team arrays.
        player_stats_list: Per-match player arrays, aligned with
            ``team_stats_list`` by position.
        ball_stats_list: Per-match 2-D ball-by-ball arrays (rows are
            deliveries), aligned by position. The label is read from the last
            column of the first row.
        over_segments: Iterable of cut-offs in overs. Cut-offs that are not
            positive or exceed the match length are skipped.
        balls_per_over: Number of rows counted as one over (default 6).
        drop_label_column: When True, the last column (the one the label is
            read from) is removed from the returned ball sequences.
            NOTE(review): with the default (False) the label column stays
            inside the sequences, so any model consuming them sees the target
            unless it is removed elsewhere. Enabling this reduces the ball
            feature count by one, so the consumer's input size must match.

    Returns:
        Four parallel lists: team arrays, player arrays, truncated ball
        arrays and labels.

    Raises:
        ValueError: If the three input lists differ in length (they would
            otherwise be paired up incorrectly and silently truncated), or if
            ``balls_per_over`` is smaller than 1.
    """
    # Materialise the inputs so lengths can be compared and so a one-shot
    # iterable of segments is still usable for every match.
    team_stats_list = list(team_stats_list)
    player_stats_list = list(player_stats_list)
    ball_stats_list = list(ball_stats_list)
    segments = list(over_segments)

    if balls_per_over < 1:
        raise ValueError(f'balls_per_over must be >= 1, got {balls_per_over}')
    if not (len(team_stats_list) == len(player_stats_list) == len(ball_stats_list)):
        raise ValueError(
            'team, player and ball lists must describe the same matches: got lengths '
            f'{len(team_stats_list)}, {len(player_stats_list)} and {len(ball_stats_list)}'
        )

    augmented_team_stats = []
    augmented_player_stats = []
    augmented_ball_stats = []
    augmented_labels = []

    for team_stats, player_stats, ball_stats in zip(team_stats_list, player_stats_list, ball_stats_list):
        n_balls = ball_stats.shape[0]
        if n_balls == 0:
            # No deliveries: nothing to truncate and no row to read a label from.
            continue

        total_overs = n_balls // balls_per_over
        label = ball_stats[0, -1]  # same for every segment of this match
        sequence = ball_stats[:, :-1] if drop_label_column else ball_stats

        for segment in segments:
            # A non-positive cut-off would give an empty (or negatively
            # sliced) sequence, which is never a valid sample.
            if 0 < segment <= total_overs:
                augmented_team_stats.append(team_stats)
                augmented_player_stats.append(player_stats)
                augmented_ball_stats.append(sequence[:segment * balls_per_over])
                augmented_labels.append(label)

    return augmented_team_stats, augmented_player_stats, augmented_ball_stats, augmented_labels

# Create a custom Dataset
class CricketDataset(Dataset):
    """Dataset over parallel per-sample lists of team, player and ball data.

    Each item is a 4-tuple of float32 tensors
    ``(team_input, player_input, ball_input, label)``. The team tensor has
    all size-1 dimensions squeezed out; the others keep the shape of the
    stored value.
    """

    def __init__(self, team_stats_list, player_stats_list, ball_stats_list, labels):
        self.team_stats_list = team_stats_list
        self.player_stats_list = player_stats_list
        self.ball_stats_list = ball_stats_list
        self.labels = labels

    def __len__(self):
        # The number of samples is taken from the team list.
        return len(self.team_stats_list)

    @staticmethod
    def _as_float32(values):
        # Copy ``values`` into a new float32 tensor.
        return torch.tensor(values, dtype=torch.float32)

    def __getitem__(self, idx):
        team_tensor = self._as_float32(self.team_stats_list[idx]).squeeze()
        player_tensor = self._as_float32(self.player_stats_list[idx])
        ball_tensor = self._as_float32(self.ball_stats_list[idx])
        label_tensor = self._as_float32(self.labels[idx])
        return team_tensor, player_tensor, ball_tensor, label_tensor

# Define a collate function to handle variable-length sequences
def collate_fn(batch):
    """Collate samples whose ball sequences have different lengths.

    Args:
        batch: Sequence of ``(team_input, player_input, ball_input, label)``
            tuples; every ``ball_input`` is 2-D and all share the size of
            their second dimension.

    Returns:
        ``(team_inputs, player_inputs, padded_ball_inputs, labels, ball_lengths)``
        where team and player inputs are stacked along a new first dimension,
        ball inputs are zero-padded at the end up to the longest sequence in
        the batch, ``labels`` is a float32 tensor and ``ball_lengths`` is a
        list with the unpadded length of each sequence.
    """
    team_inputs, player_inputs, ball_inputs, labels = (list(column) for column in zip(*batch))
    ball_lengths = [sequence.shape[0] for sequence in ball_inputs]

    # Zero-initialised buffer sized for the longest sequence; shorter ones
    # only overwrite their leading rows.
    longest = max(ball_lengths)
    feature_dim = ball_inputs[0].shape[1]
    padded_ball_inputs = torch.zeros(len(ball_inputs), longest, feature_dim)
    for row, (sequence, length) in enumerate(zip(ball_inputs, ball_lengths)):
        padded_ball_inputs[row, :length, :] = sequence

    return (
        torch.stack(team_inputs),
        torch.stack(player_inputs),
        padded_ball_inputs,
        torch.tensor(labels, dtype=torch.float32),
        ball_lengths,
    )

# Define the models
class TeamStatsModel(nn.Module):
    """Feed-forward encoder mapping a team-stats vector to 16 features.

    Two hidden stages (64 then 32 units), each Linear -> ReLU -> BatchNorm1d
    -> Dropout(0.5), followed by a final Linear(32, 16) -> ReLU.
    """

    def __init__(self, input_size):
        super().__init__()
        widths = (input_size, 64, 32)
        layers = []
        # Hidden stages are created in the same order as they are applied.
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [
                nn.Linear(fan_in, fan_out),
                nn.ReLU(),
                nn.BatchNorm1d(fan_out),
                nn.Dropout(0.5),
            ]
        layers += [nn.Linear(32, 16), nn.ReLU()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stack and return the 16-feature output."""
        return self.model(x)

class PlayerStatsModel(nn.Module):
    """1-D CNN encoder mapping a (players x features) table to 16 features.

    Two Conv1d(kernel_size=3) -> BatchNorm1d -> ReLU -> MaxPool1d(2) stages
    (32 then 64 channels) followed by a flatten and a Linear layer to 16
    outputs with ReLU.

    Args:
        input_size: Number of features per row; used as the convolution's
            input channel count.
        seq_len: Number of rows in each input. Must be at least 10 so that at
            least one position survives both conv/pool stages.

    Raises:
        ValueError: If ``seq_len`` is too short for the two conv/pool stages.
    """

    def __init__(self, input_size, seq_len):
        super(PlayerStatsModel, self).__init__()
        pooled_len = self._pooled_length(seq_len)
        if pooled_len < 1:
            raise ValueError(
                f'seq_len={seq_len} is too short: two conv(kernel=3)/pool(2) stages need seq_len >= 10'
            )
        self.conv1 = nn.Conv1d(in_channels=input_size, out_channels=32, kernel_size=3)
        self.bn1 = nn.BatchNorm1d(32)
        self.pool1 = nn.MaxPool1d(2)
        self.conv2 = nn.Conv1d(32, 64, kernel_size=3)
        self.bn2 = nn.BatchNorm1d(64)
        self.pool2 = nn.MaxPool1d(2)
        self.flatten = nn.Flatten()
        # The flattened size must follow the real layer arithmetic. The former
        # shortcut (seq_len - 4) // 4 over-counts by one position whenever
        # seq_len % 4 is 0 or 1, which made forward() fail with a shape error.
        self.fc = nn.Linear(64 * pooled_len, 16)

    @staticmethod
    def _pooled_length(seq_len):
        # Length left after two stages of: conv k=3 without padding (-2),
        # then max-pool of 2 with floor rounding (// 2).
        length = seq_len
        for _ in range(2):
            length = (length - 2) // 2
        return length

    def forward(self, x):
        """Encode ``x`` of shape (batch, seq_len, input_size) into (batch, 16)."""
        # Conv1d expects channels before positions.
        x = x.permute(0, 2, 1)
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool1(x)
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool2(x)
        x = self.flatten(x)
        x = F.relu(self.fc(x))
        return x

class BallToBallModel(nn.Module):
    """LSTM encoder mapping a padded ball-by-ball sequence to 16 features.

    A 2-layer unidirectional LSTM (hidden size 128) reads the packed
    sequence; the final hidden state of the top layer goes through
    Dropout(0.5), Linear(128, 16) and ReLU.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, 128, num_layers=2, batch_first=True, bidirectional=False)
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(128, 16)

    def forward(self, x, lengths):
        """Encode padded batch ``x`` (batch first) given per-sample ``lengths``."""
        # Packing makes the LSTM stop at each sequence's true length, so the
        # padded tail never reaches the recurrent state.
        packed = nn.utils.rnn.pack_padded_sequence(
            x, lengths, batch_first=True, enforce_sorted=False
        )
        _, (final_hidden, _) = self.lstm(packed)
        top_layer_state = final_hidden[-1, :, :]
        return F.relu(self.fc(self.dropout(top_layer_state)))

class CombinedModel(nn.Module):
    """Fuse the team, player and ball-by-ball encoders into one output per sample.

    Each sub-model contributes 16 features; the 48 concatenated features go
    through Linear(48, 64) -> ReLU -> Dropout(0.5) -> Linear(64, 32) -> ReLU
    -> Dropout(0.5) -> Linear(32, 1). No activation is applied to the final
    value.

    Args:
        team_input_size: Feature count of the team vector.
        player_input_size: Feature count per row of the player table.
        player_seq_len: Number of rows of the player table.
        ball_input_dim: Feature count per delivery of the ball sequence.
    """

    def __init__(self, team_input_size, player_input_size, player_seq_len, ball_input_dim):
        super(CombinedModel, self).__init__()
        self.team_model = TeamStatsModel(team_input_size)
        self.player_model = PlayerStatsModel(player_input_size, player_seq_len)
        self.ball_model = BallToBallModel(ball_input_dim)
        self.fc = nn.Sequential(
            nn.Linear(16+16+16, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(32, 1)
        )

    def forward(self, team_input, player_input, ball_input, ball_lengths):
        """Return a 1-D tensor with one value per sample in the batch."""
        team_output = self.team_model(team_input)
        player_output = self.player_model(player_input)
        ball_output = self.ball_model(ball_input, ball_lengths)
        combined = torch.cat((team_output, player_output, ball_output), dim=1)
        output = self.fc(combined)
        # Drop only the trailing size-1 feature dimension. A bare squeeze()
        # would also remove the batch dimension when the batch holds a single
        # sample, giving a 0-d tensor that no longer matches (1,)-shaped labels.
        return output.squeeze(-1)

# Function to print a sample of the training and validation data
def print_data_sample(train_dataloader, val_dataloader):
    """Print the first sample of the first batch of each loader.

    Args:
        train_dataloader: Iterable yielding
            ``(team_input, player_input, ball_input, labels, ball_lengths)``
            batches for training.
        val_dataloader: Same, for validation.
    """
    sections = (
        ("Sample from training data:", train_dataloader),
        ("\nSample from validation data:", val_dataloader),
    )
    for heading, loader in sections:
        print(heading)
        for batch_no, (team_input, player_input, ball_input, labels, ball_lengths) in enumerate(loader):
            if batch_no >= 1:  # Only the first batch is shown
                break
            print(f"Team input: {team_input[0]}")
            print(f"Player input: {player_input[0]}")
            print(f"Ball input: {ball_input[0]}")
            print(f"Label: {labels[0]}")
            print(f"Ball lengths: {ball_lengths[0]}")

In [37]:
# Read settings from the default 'config.yaml' and choose the compute device.
config = load_config()
device = get_device()

# Load data
# The three CSV tables are read from the folder named by config['data_directory'].
teamStats, playersStats, balltoball = load_data(config['data_directory'])

# Preprocess data
# Each table becomes a list with one NumPy array per (match_id, flip) group,
# with those two key columns removed. augment_data pairs the three lists by
# position, so they are assumed to list the same groups in the same order
# -- TODO confirm against the data.
team_stats_partitions = partition_data(teamStats, ['match_id', 'flip'])
player_stats_partitions = partition_data(playersStats, ['match_id', 'flip'])
ball_stats_partitions = partition_data(balltoball, ['match_id', 'flip'])

# Apply data augmentation
# With the default over_segments each match is truncated after 7, 12, ..., 37
# overs (only the cut-offs that fit inside the match).
# NOTE(review): the label is read from the last column of each ball-by-ball
# array and that column is still part of the returned sequences -- check it is
# removed before the sequences are fed to a model.
augmented_team_stats, augmented_player_stats, augmented_ball_stats, augmented_labels = augment_data(
    team_stats_partitions, player_stats_partitions, ball_stats_partitions)

# Convert labels to integers
augmented_labels = [int(label) for label in augmented_labels]

2024-11-08 09:32:41,172 - INFO - Using device: cuda


In [38]:
# NOTE(review): train_dataloader and val_dataloader are not created in any cell
# shown above this one; the call relies on them already existing in the kernel.
# Confirm the cell that builds them is present and runs before this cell.
print_data_sample(train_dataloader, val_dataloader)

Sample from training data:
Team input: tensor([ 0.0000, 70.0000, 88.0000,  3.0000, 10.0000,  0.8000, 23.7100,  7.9100,
        91.0000, 63.0000,  1.0000,  2.0000,  1.4400, 28.0300,  8.2300])
Player input: tensor([[3.6000e+01, 3.2000e+01, 6.5100e+02, 2.4200e+01, 1.4130e+02, 1.0000e+00,
         1.4000e+01, 0.0000e+00, 1.4000e+01, 2.7000e+01, 2.4000e+01, 3.0000e+00,
         7.5000e-01],
        [6.9000e+01, 6.8000e+01, 1.8940e+03, 3.2070e+01, 1.2744e+02, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 7.7000e+01, 6.2000e+01, 1.5000e+01,
         1.1200e+00],
        [4.5000e+01, 1.6000e+01, 1.3900e+02, 2.2000e+01, 1.1197e+02, 1.4330e+02,
         1.1910e+03, 5.0000e+01, 8.4700e+00, 6.0000e+00, 6.0000e+00, 0.0000e+00,
         1.3000e-01],
        [2.3000e+01, 2.1000e+01, 7.2200e+02, 4.4660e+01, 1.4963e+02, 1.9000e+01,
         1.3900e+02, 6.0000e+00, 7.1800e+00, 2.1000e+01, 2.1000e+01, 0.0000e+00,
         9.1000e-01],
        [1.3000e+01, 6.0000e+00, 4.0000e+01, 1.1170e+01, 8.