In [18]:
import os
import numpy as np
import pandas as pd
import polars as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import torch.nn.functional as F  # Import F module
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

# Load the three filtered tables with polars.
# The input folder can be overridden with the CRICKET_FILTERED_DATA_DIR environment
# variable so the notebook runs outside the original author's machine; the default is
# the original hardcoded path.
directory = os.environ.get(
    'CRICKET_FILTERED_DATA_DIR',
    r'D:\github\Cricket-prediction\data\4_filteredData',
)
balltoball = pl.read_csv(os.path.join(directory, 'balltoball.csv'))
teamStats = pl.read_csv(os.path.join(directory, 'team12Stats.csv'))
playersStats = pl.read_csv(os.path.join(directory, 'playersStats.csv'))

# Preprocess the data
def partition_data(df, group_keys):
    """Split a polars DataFrame into one NumPy array per group.

    Args:
        df: polars DataFrame to split.
        group_keys: column name(s) to group on; these columns are removed from every
            returned array.

    Returns:
        list of arrays, one per group, in the order ``df.partition_by`` yields them.
    """
    arrays = []
    for group in df.partition_by(group_keys):
        arrays.append(group.drop(group_keys).to_numpy())
    return arrays

# Partition every table by (match_id, flip) and align the three partition lists on their
# shared keys. augment_data zips the lists by position, which is only correct when every
# table contains exactly the same groups in the same order. The displayed tables show
# teamStats ending at (1450765, 0) while balltoball/playersStats end at (1450765, 1), so
# that is not guaranteed here; groups missing from any table are dropped instead of
# silently pairing stats from different matches.
GROUP_KEYS = ['match_id', 'flip']

def _partitions_by_key(df):
    """Return {(match_id, flip): that group's rows as a NumPy array, key columns dropped}."""
    groups = df.partition_by(GROUP_KEYS, as_dict=True)
    return {tuple(key): part.drop(GROUP_KEYS).to_numpy() for key, part in groups.items()}

team_by_key = _partitions_by_key(teamStats)
player_by_key = _partitions_by_key(playersStats)
ball_by_key = _partitions_by_key(balltoball)

common_keys = sorted(team_by_key.keys() & player_by_key.keys() & ball_by_key.keys())
n_unmatched = len(team_by_key.keys() | player_by_key.keys() | ball_by_key.keys()) - len(common_keys)
if n_unmatched:
    print(f'Dropping {n_unmatched} (match_id, flip) groups that are not present in all three tables')

# Same key order for all three lists, so position i always refers to the same group.
team_stats_partitions = [team_by_key[key] for key in common_keys]
player_stats_partitions = [player_by_key[key] for key in common_keys]
ball_stats_partitions = [ball_by_key[key] for key in common_keys]

In [19]:

# Augment the data by creating new samples from truncated prefixes of each ball sequence
def augment_data(team_stats_list, player_stats_list, ball_stats_list, over_segments=np.arange(7, 40),
                 balls_per_over=6):
    """Expand each sequence into several truncated samples.

    For every aligned ``(team_stats, player_stats, ball_stats)`` triple and every
    ``segment`` in ``over_segments`` that fits, one sample is emitted. Its ball sequence
    is the first ``segment * balls_per_over`` rows of ``ball_stats``. The team and
    player arrays are appended unchanged (the same objects, not copies).

    Args:
        team_stats_list, player_stats_list, ball_stats_list: parallel lists. They are
            zipped, so trailing items of a longer list are ignored.
        over_segments: prefix lengths, in overs, to generate (default 7..39 inclusive).
        balls_per_over: number of rows treated as one over. Rows are counted as stored
            in ``ball_stats``; if extras are recorded as extra rows this is approximate.

    Returns:
        ``(augmented_team_stats, augmented_player_stats, augmented_ball_stats)``:
        three lists of equal length.
    """
    augmented_team_stats = []
    augmented_player_stats = []
    augmented_ball_stats = []

    for team_stats, player_stats, ball_stats in zip(team_stats_list, player_stats_list, ball_stats_list):
        # Whole overs available; a trailing partial over is never used as a cut point.
        total_overs = ball_stats.shape[0] // balls_per_over
        for segment in over_segments:
            if total_overs >= segment:
                end_idx = segment * balls_per_over
                augmented_team_stats.append(team_stats)
                augmented_player_stats.append(player_stats)
                augmented_ball_stats.append(ball_stats[:end_idx])

    return augmented_team_stats, augmented_player_stats, augmented_ball_stats

# Split whole (match_id, flip) sequences into train/val/test BEFORE augmenting.
# augment_data emits up to len(over_segments) prefixes of every sequence, all sharing the
# same team/player stats and the label read from row 0. Splitting after augmentation
# scattered prefixes of the same sequence across train and val/test, which leaks the label
# and inflates validation/test scores.
# NOTE(review): the two flip views of one match can still land in different splits;
# grouping the split by match_id would need the partition keys. TODO.
n_sequences = min(len(team_stats_partitions), len(player_stats_partitions), len(ball_stats_partitions))
if n_sequences != max(len(team_stats_partitions), len(player_stats_partitions), len(ball_stats_partitions)):
    # Mirrors the zip() truncation augment_data applies, but makes it visible.
    print(f'Partition lists differ in length; using the first {n_sequences} of each')

sequence_idx = np.arange(n_sequences)
# 75% train, then the remaining 25% halved into validation and test.
train_idx, holdout_idx = train_test_split(sequence_idx, test_size=0.25, random_state=42)
val_idx, test_idx = train_test_split(holdout_idx, test_size=0.5, random_state=42)

def _augment_subset(indices):
    """Augment only the sequences at ``indices`` (positions in the partition lists)."""
    return augment_data(
        [team_stats_partitions[i] for i in indices],
        [player_stats_partitions[i] for i in indices],
        [ball_stats_partitions[i] for i in indices],
    )

train_team_stats, train_player_stats, train_ball_stats = _augment_subset(train_idx)
val_team_stats, val_player_stats, val_ball_stats = _augment_subset(val_idx)
test_team_stats, test_player_stats, test_ball_stats = _augment_subset(test_idx)

# Map-style Dataset over the parallel team / player / ball-by-ball lists
class CricketDataset(Dataset):
    """One sample per index: ``(team_input, player_input, ball_input, label)``.

    The three lists are parallel; element ``idx`` of each belongs to the same sample.
    All arrays are converted to float32 tensors on access.
    """

    def __init__(self, team_stats_list, player_stats_list, ball_stats_list):
        self.team_stats_list = team_stats_list
        self.player_stats_list = player_stats_list
        self.ball_stats_list = ball_stats_list

    def __len__(self):
        # Length comes from the team list; the other two are assumed to match it.
        return len(self.team_stats_list)

    def __getitem__(self, idx):
        dtype = torch.float32
        # squeeze() drops size-1 dims, e.g. a single team-stats row becomes a vector.
        team_input = torch.tensor(self.team_stats_list[idx], dtype=dtype).squeeze()
        player_input = torch.tensor(self.player_stats_list[idx], dtype=dtype)
        ball_stats = torch.tensor(self.ball_stats_list[idx], dtype=dtype)
        # Final column is the label (read from the first row); the rest are features.
        ball_input, label = ball_stats[:, :-1], ball_stats[0, -1]
        return team_input, player_input, ball_input, label

# Batch samples of different sequence lengths by zero-padding the ball sequences
def collate_fn(batch):
    """Collate ``(team_input, player_input, ball_input, label)`` samples into a batch.

    Args:
        batch: list of 4-tuples as returned by ``CricketDataset.__getitem__``.

    Returns:
        ``(team_inputs, player_inputs, padded_ball_inputs, labels, ball_lengths)``.
        ``padded_ball_inputs`` has shape ``(batch, max_len, n_features)``, zero-padded
        at the end of each sequence. ``ball_lengths`` is a plain list of the unpadded
        lengths.
    """
    team_inputs, player_inputs, ball_inputs, labels = zip(*batch)
    ball_lengths = [seq.shape[0] for seq in ball_inputs]

    padded_ball_inputs = torch.nn.utils.rnn.pad_sequence(list(ball_inputs), batch_first=True, padding_value=0.0)
    # Keep the default floating dtype, as a torch.zeros(...) buffer would have.
    padded_ball_inputs = padded_ball_inputs.to(torch.get_default_dtype())

    return (
        torch.stack(team_inputs),
        torch.stack(player_inputs),
        padded_ball_inputs,
        torch.tensor(list(labels), dtype=torch.float32),
        ball_lengths,
    )

# Wrap each split (train / val / test) in a Dataset and a DataLoader. Only training
# batches are shuffled; collate_fn zero-pads the variable-length ball sequences.
BATCH_SIZE = 32

train_dataset, val_dataset, test_dataset = (
    CricketDataset(train_team_stats, train_player_stats, train_ball_stats),
    CricketDataset(val_team_stats, val_player_stats, val_ball_stats),
    CricketDataset(test_team_stats, test_player_stats, test_ball_stats),
)

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

In [20]:
# Sanity check: print tensor shapes for the first five training batches.
i = 0
for i, (team_inputs, player_inputs, ball_inputs, labels, ball_lengths) in enumerate(train_dataloader, start=1):
    print(team_inputs.shape, player_inputs.shape, ball_inputs.shape, labels.shape)
    if i >= 5:
        break

torch.Size([32, 15]) torch.Size([32, 22, 13]) torch.Size([32, 228, 4]) torch.Size([32])
torch.Size([32, 15]) torch.Size([32, 22, 13]) torch.Size([32, 222, 4]) torch.Size([32])
torch.Size([32, 15]) torch.Size([32, 22, 13]) torch.Size([32, 228, 4]) torch.Size([32])
torch.Size([32, 15]) torch.Size([32, 22, 13]) torch.Size([32, 234, 4]) torch.Size([32])
torch.Size([32, 15]) torch.Size([32, 22, 13]) torch.Size([32, 234, 4]) torch.Size([32])


In [21]:
# Save the dataloaders so a later notebook can reuse the exact same splits.
# NOTE(review): pickle stores classes and functions by reference (module + name), so the
# code that unpickles these files must itself define CricketDataset and collate_fn under
# the same names. Never load these pickles from an untrusted source: unpickling can run
# arbitrary code.
import pickle

# Output folder: override with CRICKET_PYTORCH_DATA_DIR; the default is the original path.
directory = os.environ.get(
    'CRICKET_PYTORCH_DATA_DIR',
    r'D:\github\Cricket-prediction\data\5_pytorchData',
)
os.makedirs(directory, exist_ok=True)  # open(..., 'wb') fails if the folder is missing

for split_name, loader in (('train', train_dataloader), ('val', val_dataloader), ('test', test_dataloader)):
    with open(os.path.join(directory, f'{split_name}_dataloader.pkl'), 'wb') as f:
        pickle.dump(loader, f)

In [22]:
# Display the raw ball-by-ball table (polars renders only the head and tail rows).
balltoball

match_id,flip,innings,ball,curr_score,curr_wickets,won
i64,i64,i64,f64,i64,i64,i64
211028,0,1,0.1,0,0,1
211028,0,1,0.2,1,0,1
211028,0,1,0.3,1,0,1
211028,0,1,0.4,1,0,1
211028,0,1,0.5,1,0,1
…,…,…,…,…,…,…
1450765,1,2,19.3,100,9,0
1450765,1,2,19.4,100,9,0
1450765,1,2,19.5,104,9,0
1450765,1,2,19.6,106,9,0


In [23]:
# Display the team statistics table, keyed by (match_id, flip).
teamStats

match_id,flip,gender,Cumulative Won team1,Cumulative Lost team1,Cumulative Tied team1,Cumulative NR team1,Cumulative W/L team1,Cumulative AveRPW team1,Cumulative AveRPO team1,Cumulative Won team2,Cumulative Lost team2,Cumulative Tied team2,Cumulative NR team2,Cumulative W/L team2,Cumulative AveRPW team2,Cumulative AveRPO team2
i64,i64,i64,i64,i64,i64,i64,f64,f64,f64,i64,i64,i64,i64,f64,f64,f64
211028,0,0,0,0,0,0,0.0,0.0,0.0,1,0,0,0,0.0,42.8,0.0
211028,1,0,1,0,0,0,0.0,42.8,0.0,0,0,0,0,0.0,0.0,0.0
211048,0,0,0,0,0,0,0.0,0.0,0.0,0,0,0,0,0.0,0.0,0.0
211048,1,0,0,0,0,0,0.0,0.0,0.0,0,0,0,0,0.0,0.0,0.0
225263,0,0,1,0,0,0,0.0,22.37,0.0,0,0,0,0,0.0,0.0,0.0
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
1450753,1,0,36,24,0,3,1.5,19.59,7.49,0,8,0,2,0.0,0.0,3.98
1450759,0,0,0,8,0,2,0.0,0.0,3.98,18,10,0,3,1.8,24.87,6.69
1450759,1,0,18,10,0,3,1.8,24.87,6.69,0,8,0,2,0.0,0.0,3.98
1450765,0,0,36,24,0,3,1.5,19.59,7.49,18,10,0,3,1.8,24.87,6.69


In [24]:
# Display the per-player statistics table; each (match_id, flip) key spans several player rows.
playersStats

match_id,flip,Cum Mat Total,Cum Inns Total,Cum Runs Total,Cum Batting Ave,Cum SR,Cumulative Overs,Cumulative Runs,Cumulative Wkts,Cumulative Econ,Cumulative Dis,Cumulative Ct,Cumulative St,Cumulative D/I
i64,i64,i64,i64,i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
211028,0,1,1,1,1.0,33.33,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
211028,0,1,1,31,31.0,206.66,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0
211028,0,1,1,3,3.0,60.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
211028,0,0,0,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
211028,0,1,1,98,98.0,178.18,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
1450765,1,22,20,216,16.36,107.44,56.799999,384.0,19.0,6.7,5.0,5.0,0.0,0.24
1450765,1,0,0,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1450765,1,62,39,294,15.42,104.92,157.499999,1109.0,62.0,6.97,21.0,21.0,0.0,0.35
1450765,1,27,18,94,7.92,145.24,75.0,485.0,23.0,6.46,2.0,2.0,0.0,0.07
