In [171]:
import os
import sys
sys.path.append(os.path.join(os.getcwd(),".."))

import model_utils

import polars as pl
# import data
def load_data():
    """Read the three filtered CSV tables into polars DataFrames.

    Returns:
        A ``(balltoball, team_stats, players_stats)`` tuple read from
        ``balltoball.csv``, ``team12_stats.csv`` and ``players_stats.csv``
        inside ``../data/filtered_data`` (relative to the working directory).
    """
    filtered_dir = os.path.join('..', "data", "filtered_data")
    file_names = ("balltoball.csv", "team12_stats.csv", "players_stats.csv")
    # One read per file, kept in the order callers unpack the result.
    frames = [pl.read_csv(os.path.join(filtered_dir, file_name)) for file_name in file_names]
    return tuple(frames)
# Load the tables once, then print every schema followed by a one-row preview
# of each table.
balltoball, team_stats, players_stats = load_data()
loaded_frames = (balltoball, team_stats, players_stats)
for frame in loaded_frames:
    print(frame.columns)
print(*(frame.head(1) for frame in loaded_frames))

['match_id', 'innings', 'ball', 'runs', 'wickets', 'curr_score', 'curr_wickets', 'overs', 'run_rate', 'required_run_rate', 'target', 'won']
['match_id', 'gender', 'Cumulative Won team1', 'Cumulative Lost team1', 'Cumulative Tied team1', 'Cumulative W/L team1', 'Cumulative AveRPW team1', 'Cumulative AveRPO team1', 'Cumulative Won team2', 'Cumulative Lost team2', 'Cumulative Tied team2', 'Cumulative W/L team2', 'Cumulative AveRPW team2', 'Cumulative AveRPO team2']
['match_id', 'Cum Mat Total', 'Cum Inns Total', 'Cum Runs Total', 'Cum Batting Ave', 'Cum SR', 'Cumulative Overs', 'Cumulative Bowling Runs', 'Cumulative Wkts', 'Cumulative Econ', 'Cumulative Dis', 'Cumulative Ct', 'Cumulative St', 'Cumulative D/I']
shape: (1, 12)
┌──────────┬─────────┬──────┬──────┬───┬──────────┬───────────────────┬────────┬─────┐
│ match_id ┆ innings ┆ ball ┆ runs ┆ … ┆ run_rate ┆ required_run_rate ┆ target ┆ won │
│ ---      ┆ ---     ┆ ---  ┆ ---  ┆   ┆ ---      ┆ ---               ┆ ---    ┆ --- │
│ i64  

In [None]:
def partition_data_with_keys(df, group_keys):
    """Split ``df`` into per-group NumPy arrays and report each group's key.

    Args:
        df: Frame to split; it must provide ``partition_by``.
        group_keys: Column name(s) that define a group.

    Returns:
        ``(keys, partitions)`` - two parallel lists. ``keys[i]`` is a tuple
        holding the group-key values of the i-th partition, and
        ``partitions[i]`` is that partition converted with ``to_numpy()``
        after the key columns were dropped.
    """
    keys = []
    arrays = []
    for group in df.partition_by(group_keys):
        # All rows of a partition share the same key values, so the first
        # unique row of the key columns identifies the group.
        key_rows = group.select(group_keys).unique().to_numpy()
        keys.append(tuple(key_rows[0]))
        arrays.append(group.drop(group_keys).to_numpy())
    return keys, arrays

import numpy as np

# Number of trailing deliveries removed from every match before it is used as
# model input. Must stay > 0: a slice of `[:-0]` would return an empty array.
N_FINAL_BALLS_DROPPED = 30

# Split every table into one NumPy array per match; the keys are 1-tuples
# holding the match_id, and the match_id column is dropped from the arrays.
balltoball_keys, balltoball_partitions = partition_data_with_keys(balltoball, ["match_id"])
team_stats_keys, team_stats_partitions = partition_data_with_keys(team_stats, ["match_id"])
players_stats_keys, players_stats_partitions = partition_data_with_keys(players_stats, ["match_id"])

# Only matches present in all three tables can be used.
common_keys = set(balltoball_keys) & set(team_stats_keys) & set(players_stats_keys)

balltoball_dict = dict(zip(balltoball_keys, balltoball_partitions))
team_stats_dict = dict(zip(team_stats_keys, team_stats_partitions))
players_stats_dict = dict(zip(players_stats_keys, players_stats_partitions))

aligned_balltoball_partitions = []
aligned_team_stats_partitions = []
aligned_players_stats_partitions = []
labels = []

# Iterate in sorted key order: the iteration order of a set is arbitrary, and
# the position of a match in these lists decides which split it lands in when
# the indices are split with a fixed random_state later on.
for key in sorted(common_keys):
    balltoball_partition = balltoball_dict[key]
    team_stats_partition = team_stats_dict[key]
    players_stats_partition = players_stats_dict[key]

    # The last remaining ball-by-ball column is `won`; the first row's value
    # is used as the match label (presumably constant within a match - verify).
    label = balltoball_partition[0, -1]
    # Drop the final balls and the label column from the model input.
    # NOTE(review): a match with N_FINAL_BALLS_DROPPED or fewer balls yields an
    # empty sequence here - confirm the collate function tolerates that.
    aligned_balltoball_partitions.append(balltoball_partition[:-N_FINAL_BALLS_DROPPED, :-1])
    aligned_team_stats_partitions.append(team_stats_partition)
    aligned_players_stats_partitions.append(players_stats_partition)
    labels.append(label)

labels = np.array(labels)

In [173]:
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset

BATCH_SIZE = 32
SPLIT_SEED = 42


def _as_numpy(partitions):
    """Return the partitions as a list, converting any polars DataFrame to NumPy."""
    return [part.to_numpy() if isinstance(part, pl.DataFrame) else part for part in partitions]


team_data = _as_numpy(aligned_team_stats_partitions)
player_data = _as_numpy(aligned_players_stats_partitions)
ball_data = _as_numpy(aligned_balltoball_partitions)

# 80 % of the matches go to training; the remaining 20 % are halved into
# validation and test (10 % each). Both splits use the same fixed seed.
train_indices, holdout_indices = train_test_split(
    np.arange(len(labels)), test_size=0.2, random_state=SPLIT_SEED
)
val_indices, test_indices = train_test_split(
    holdout_indices, test_size=0.5, random_state=SPLIT_SEED
)

dataset = model_utils.CricketDataset(
    team_data,
    player_data,
    ball_data,
    labels
)

train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)
test_dataset = Subset(dataset, test_indices)

# Only the training loader shuffles; validation and test keep a fixed order so
# that evaluation runs are reproducible and consistent with each other.
train_dataloader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True,
    collate_fn=model_utils.collate_fn_with_packing,
)
val_dataloader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False,
    collate_fn=model_utils.collate_fn_with_packing,
)
test_dataloader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False,
    collate_fn=model_utils.collate_fn_with_packing,
)

In [174]:
import pickle
# Persist the three loaders as pickle files under ../data/pytorch_data.
pytorch_data_dir = os.path.join('..', "data", "pytorch_data")
# Create the output directory when it is missing so the cell also works on a
# fresh checkout instead of failing with FileNotFoundError.
os.makedirs(pytorch_data_dir, exist_ok=True)

for file_name, loader in (
    ("train_dataloader.pkl", train_dataloader),
    ("val_dataloader.pkl", val_dataloader),
    ("test_dataloader.pkl", test_dataloader),
):
    with open(os.path.join(pytorch_data_dir, file_name), "wb") as f:
        pickle.dump(loader, f)

In [175]:
def _load_pickled_dataloader(file_name):
    """Unpickle one object from ``../data/pytorch_data/<file_name>``.

    The file is opened in a ``with`` block so the handle is closed as soon as
    loading finishes (or fails).

    SECURITY: ``pickle.load`` can execute arbitrary code while loading - only
    use this on files written by this notebook, never on untrusted files.
    """
    with open(os.path.join('..', "data", "pytorch_data", file_name), "rb") as f:
        return pickle.load(f)


# Reload the loaders from disk.
train_dataloader = _load_pickled_dataloader("train_dataloader.pkl")
val_dataloader = _load_pickled_dataloader("val_dataloader.pkl")
test_dataloader = _load_pickled_dataloader("test_dataloader.pkl")

In [188]:
# Show one randomly chosen match from the full (unsplit) dataset.
# The index is drawn from the first 32 matches, capped at the number of
# available matches so a small dataset cannot cause an IndexError.
single_match_index = np.random.randint(0, min(32, len(labels)))
single_match_data = dataset[single_match_index]

team_input, player_input, ball_input, label = single_match_data

# Rebuild labelled frames for display. The column lists skip the leading
# `match_id` (dropped when the tables were partitioned); the ball-by-ball list
# also skips the trailing `won`, which was split off as the label.
team_df = pl.DataFrame(team_input.numpy().astype(float).reshape(1, -1), schema=team_stats.columns[1:])
player_df = pl.DataFrame(player_input.numpy().astype(float), schema=players_stats.columns[1:])
ball_df = pl.DataFrame(ball_input.numpy().astype(float), schema=balltoball.columns[1:-1])

print("Label:")
print(label)
print("Ball-to-Ball Stats:")
# Last expression of the cell: rendered as the cell output, with a row index.
ball_df.with_row_index()

Label:
tensor(0.)
Ball-to-Ball Stats:


index,innings,ball,runs,wickets,curr_score,curr_wickets,overs,run_rate,required_run_rate,target
u32,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
0,1.0,0.1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,1.0,0.2,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0
2,1.0,0.3,1.0,0.0,2.0,0.0,0.0,0.0,0.0,0.0
3,1.0,0.4,4.0,0.0,6.0,0.0,0.0,0.0,0.0,0.0
4,1.0,0.5,0.0,0.0,6.0,0.0,0.0,0.0,0.0,0.0
…,…,…,…,…,…,…,…,…,…,…
215,2.0,14.6,4.0,0.0,139.0,3.0,14.0,9.928572,4.333333,165.0
216,2.0,14.7,2.0,0.0,141.0,3.0,14.0,10.071428,4.0,165.0
217,2.0,15.1,0.0,0.0,141.0,3.0,15.0,9.4,4.8,165.0
218,2.0,15.2,1.0,0.0,142.0,3.0,15.0,9.466666,4.6,165.0


In [177]:
# Show the player statistics frame (`player_df`); as the last expression of
# the cell it is rendered as the cell output.
print("Player Stats:")
player_df

Player Stats:


Cum Mat Total,Cum Inns Total,Cum Runs Total,Cum Batting Ave,Cum SR,Cumulative Overs,Cumulative Bowling Runs,Cumulative Wkts,Cumulative Econ,Cumulative Dis,Cumulative Ct,Cumulative St,Cumulative D/I
f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
…,…,…,…,…,…,…,…,…,…,…,…,…
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4.0,4.0,24.0,6.0,64.860001,7.0,69.0,4.0,9.85,0.0,0.0,0.0,0.0
1.0,1.0,4.0,4.0,50.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4.0,3.0,19.0,6.33,135.710007,14.0,107.0,3.0,7.64,0.0,0.0,0.0,0.0


In [178]:
# Show the single-row team statistics frame (`team_df`); as the last
# expression of the cell it is rendered as the cell output.
print("Single Match Data:")
print("Team Stats:")
team_df

Single Match Data:
Team Stats:


gender,Cumulative Won team1,Cumulative Lost team1,Cumulative Tied team1,Cumulative W/L team1,Cumulative AveRPW team1,Cumulative AveRPO team1,Cumulative Won team2,Cumulative Lost team2,Cumulative Tied team2,Cumulative W/L team2,Cumulative AveRPW team2,Cumulative AveRPO team2
f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,4.0,0.0,0.0,0.0,5.14
