In [15]:
import os
import polars as pl


# Load the data using polars
directory = r'D:\github\Cricket-Prediction\data\filteredData'
balltoball = pl.read_csv(os.path.join(directory, 'balltoball.csv'))
teamStats = pl.read_csv(os.path.join(directory, 'team12Stats.csv'))
playersStats = pl.read_csv(os.path.join(directory, 'playersStats.csv'))

In [16]:
# Preprocess the data
def partition_data(df, group_keys):
    partitions = df.partition_by(group_keys, as_dict=True)
    partition_dict = {}
    for key, partition in partitions.items():
        match_id = partition[group_keys[0]][0]
        flip = partition[group_keys[1]][0]
        key = (match_id, flip)
        partition_dict[key] = partition.drop(group_keys).to_numpy()
    return partition_dict

ball_stats_partitions = partition_data(balltoball, ['match_id', 'flip'])
team_stats_partitions = partition_data(teamStats, ['match_id', 'flip'])
player_stats_partitions = partition_data(playersStats, ['match_id', 'flip'])

len(ball_stats_partitions.keys()), len(team_stats_partitions.keys()), len(player_stats_partitions.keys())

(2312, 2312, 2312)

In [17]:
player_stats_partitions_reshaped = {}
for key, value in player_stats_partitions.items():
    player_stats_partitions_reshaped[key] = value.reshape(1,22,22)

player_stats_partitions_reshaped[(211028,0)].shape

(1, 22, 22)

In [18]:
player_stats_partitions = player_stats_partitions_reshaped

In [19]:
from sklearn.model_selection import train_test_split
# Split match IDs into training and validation sets
match_ids = list(set([key[0] for key in team_stats_partitions.keys()]))
train_matches, val_matches = train_test_split(match_ids, test_size=0.3, random_state=42)
val_matches, test_matches = train_test_split(val_matches, test_size=0.5, random_state=42)
len(train_matches), len(val_matches), len(test_matches)

(809, 173, 174)

# Data Augumentation

In [20]:
# Function to prepare data with overs limit
def prepare_data(matches, is_test=False):
    team_stats_list = []
    player_stats_list = []
    ball_stats_list = []

    for match_id in matches:
        for flip in [0, 1]:  # Team perspectives
            key = (match_id, flip)
            if key in team_stats_partitions and key in player_stats_partitions and key in ball_stats_partitions:
                team_stats = team_stats_partitions[key]
                player_stats = player_stats_partitions[key]
                ball_stats = ball_stats_partitions[key]

                # For any match min overs should be 5
                min_overs = 6

                # Determine max overs
                if is_test:
                    max_overs = 35  # Up to 2nd innings 15 overs
                else:
                    max_overs = 40  # Up to 2nd innings 20 overs

                total_overs = ball_stats.shape[0] // 6  # Assuming 6 balls per over
                max_overs = min(total_overs, max_overs)
                start_idx = min_overs * 6
                end_idx = max_overs * 6
                
                # Append the data
                for idx in range(start_idx, end_idx):
                    ball_stats_list.append(ball_stats[:idx])
                    team_stats_list.append(team_stats)
                    player_stats_list.append(player_stats)

    return team_stats_list, player_stats_list, ball_stats_list

In [21]:
# Prepare the data
X_train_team, X_train_player, X_train_ball = prepare_data(train_matches, is_test=False)
X_val_team, X_val_player, X_val_ball = prepare_data(val_matches, is_test=False)
X_test_team, X_test_player, X_test_ball = prepare_data(test_matches, is_test=True)

len(X_train_team), len(X_train_player), len(X_train_ball), len(X_val_team), len(X_val_player), len(X_val_ball), len(X_test_team), len(X_test_player), len(X_test_ball)

(306888, 306888, 306888, 65796, 65796, 65796, 58452, 58452, 58452)

In [22]:
from torch.utils.data import Dataset
import torch

class CricketDataset(Dataset):
    def __init__(self, teamStats, playersStats, balltoball):
        self.teamStats = teamStats
        self.playersStats = playersStats
        self.balltoball = balltoball

    def __len__(self):
        return len(self.teamStats)
    
    def __getitem__(self, idx):
        team_input = torch.tensor(self.teamStats[idx], dtype=torch.float32)
        player_input = torch.tensor(self.playersStats[idx], dtype=torch.float32)
        ball_stats = torch.tensor(self.balltoball[idx], dtype=torch.float32)
        label = ball_stats[:,-1].mean()
        ball_input = ball_stats[:,:-1]
        return team_input, player_input, ball_input, label
train_dataset = CricketDataset(X_train_team, X_train_player, X_train_ball)
val_dataset = CricketDataset(X_val_team, X_val_player, X_val_ball)
test_dataset = CricketDataset(X_test_team, X_test_player, X_test_ball)

In [23]:
max_len = float('-inf')
for i in ball_stats_partitions.keys():
    if ball_stats_partitions[i].shape[0] > max_len:
        max_len = ball_stats_partitions[i].shape[0]
max_len

271

In [24]:
# collate_fn to pad the sequences
def collate_fn(batch):
    max_len = 280
    # pad balls upto max_len
    team, player, balls, labels = zip(*batch)
    padded_balls = torch.zeros(len(balls), max_len, balls[0].shape[1])
    for i, ball in enumerate(balls):
        padded_balls[i, :ball.shape[0], :] = ball
    return torch.stack(team), torch.stack(player), padded_balls, torch.tensor(labels, dtype=torch.float32)

from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=64, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=64, collate_fn=collate_fn)

In [25]:
for team_input, player_input, ball_input, label in train_loader:
    print(team_input.shape, player_input.shape, ball_input.shape, label.shape)
    break

torch.Size([64, 1, 23]) torch.Size([64, 1, 22, 22]) torch.Size([64, 280, 4]) torch.Size([64])


In [26]:
# Save loaders
import pickle

# Save loaders
with open(os.path.join(directory, '../pytorchData\\train_loader.pkl'), 'wb') as f:
    pickle.dump(train_loader, f)
with open(os.path.join(directory, '../pytorchData\\val_loader.pkl'), 'wb') as f:
    pickle.dump(val_loader, f)
with open(os.path.join(directory, '../pytorchData\\test_loader.pkl'), 'wb') as f:
    pickle.dump(test_loader, f)

In [27]:
i = 0
for team_input, player_input, ball_input, label in train_loader:
    i += 1
    print(team_input.shape, player_input.shape, ball_input.shape, label.shape)
    if i>5:
        break

torch.Size([64, 1, 23]) torch.Size([64, 1, 22, 22]) torch.Size([64, 280, 4]) torch.Size([64])
torch.Size([64, 1, 23]) torch.Size([64, 1, 22, 22]) torch.Size([64, 280, 4]) torch.Size([64])
torch.Size([64, 1, 23]) torch.Size([64, 1, 22, 22]) torch.Size([64, 280, 4]) torch.Size([64])
torch.Size([64, 1, 23]) torch.Size([64, 1, 22, 22]) torch.Size([64, 280, 4]) torch.Size([64])
torch.Size([64, 1, 23]) torch.Size([64, 1, 22, 22]) torch.Size([64, 280, 4]) torch.Size([64])
torch.Size([64, 1, 23]) torch.Size([64, 1, 22, 22]) torch.Size([64, 280, 4]) torch.Size([64])
