In [159]:
import os
import sys
sys.path.append(os.path.join(os.getcwd(),".."))

import model_utils

import polars as pl
# Load the filtered CSV tables (ball-by-ball, team stats, player stats)
def load_data(data_dir=os.path.join("..", "data", "filtered_data")):
    """Read the three filtered CSV tables used to build the dataset.

    Args:
        data_dir: Directory that contains ``balltoball.csv``,
            ``team12_stats.csv`` and ``players_stats.csv``. Defaults to
            ``../data/filtered_data`` (relative to the working directory).

    Returns:
        A ``(balltoball, team_stats, players_stats)`` tuple holding the
        result of ``pl.read_csv`` for each file, in that order.
    """
    balltoball = pl.read_csv(os.path.join(data_dir, "balltoball.csv"))
    team_stats = pl.read_csv(os.path.join(data_dir, "team12_stats.csv"))
    players_stats = pl.read_csv(os.path.join(data_dir, "players_stats.csv"))
    return balltoball, team_stats, players_stats
# Load the three tables once, then show each column list (one per line)
# followed by a one-row preview of every table on a single print call.
balltoball, team_stats, players_stats = load_data()
print(balltoball.columns, team_stats.columns, players_stats.columns, sep="\n")
print(*(table.head(1) for table in (balltoball, team_stats, players_stats)))

['match_id', 'innings', 'ball', 'runs', 'wickets', 'curr_score', 'curr_wickets', 'overs', 'run_rate', 'required_run_rate', 'target', 'won']
['match_id', 'gender', 'Cumulative Won team1', 'Cumulative Lost team1', 'Cumulative Tied team1', 'Cumulative W/L team1', 'Cumulative AveRPW team1', 'Cumulative AveRPO team1', 'Cumulative Won team2', 'Cumulative Lost team2', 'Cumulative Tied team2', 'Cumulative W/L team2', 'Cumulative AveRPW team2', 'Cumulative AveRPO team2']
['match_id', 'Cum Mat Total', 'Cum Inns Total', 'Cum Runs Total', 'Cum Batting Ave', 'Cum SR', 'Cumulative Overs', 'Cumulative Bowling Runs', 'Cumulative Wkts', 'Cumulative Econ', 'Cumulative Dis', 'Cumulative Ct', 'Cumulative St', 'Cumulative D/I']
shape: (1, 12)
┌──────────┬─────────┬──────┬──────┬───┬──────────┬───────────────────┬────────┬─────┐
│ match_id ┆ innings ┆ ball ┆ runs ┆ … ┆ run_rate ┆ required_run_rate ┆ target ┆ won │
│ ---      ┆ ---     ┆ ---  ┆ ---  ┆   ┆ ---      ┆ ---               ┆ ---    ┆ --- │
│ i64  

In [160]:
def partition_data_with_keys(df, group_keys):
    """Split ``df`` into one group per distinct ``group_keys`` combination.

    Args:
        df: Frame to split. It must support ``partition_by``, and each
            resulting group must support ``select``, ``unique``, ``drop``
            and ``to_numpy`` (a polars DataFrame in this notebook).
        group_keys: List of column names that identify a group.

    Returns:
        ``(keys, partitions)``: two parallel lists. ``keys[i]`` is the tuple
        of key values of group ``i``; ``partitions[i]`` is that group's
        remaining columns (key columns dropped) as a NumPy array.
    """
    keys = []
    arrays = []
    for group in df.partition_by(group_keys):
        # The first row of the de-duplicated key columns identifies the group.
        key_rows = group.select(group_keys).unique().to_numpy()
        keys.append(tuple(key_rows[0]))
        arrays.append(group.drop(group_keys).to_numpy())
    return keys, arrays

import numpy as np

# Split every table into per-match NumPy arrays keyed by match_id.
balltoball_keys, balltoball_partitions = partition_data_with_keys(balltoball, ["match_id"])
team_stats_keys, team_stats_partitions = partition_data_with_keys(team_stats, ["match_id"])
players_stats_keys, players_stats_partitions = partition_data_with_keys(players_stats, ["match_id"])

# Only matches present in all three tables can be turned into a sample.
common_keys = set(balltoball_keys) & set(team_stats_keys) & set(players_stats_keys)

# Key -> partition lookups, one per table.
balltoball_dict = dict(zip(balltoball_keys, balltoball_partitions))
team_stats_dict = dict(zip(team_stats_keys, team_stats_partitions))
players_stats_dict = dict(zip(players_stats_keys, players_stats_partitions))

aligned_balltoball_partitions = []
aligned_team_stats_partitions = []
aligned_players_stats_partitions = []
labels = []

# Iterate in sorted key order: set iteration order is arbitrary, and the
# seeded train/val/test split later selects samples by position, so a stable
# sample order is required for that split to be reproducible.
for key in sorted(common_keys):
    balltoball_partition = balltoball_dict[key]
    team_stats_partition = team_stats_dict[key]
    players_stats_partition = players_stats_dict[key]

    # The last ball-by-ball column ("won") is the target; the first row's
    # value is used as the match label (presumably it is constant within a
    # match; TODO confirm). The column is then removed from the features.
    label = balltoball_partition[0, -1]
    aligned_balltoball_partitions.append(balltoball_partition[:, :-1])
    aligned_team_stats_partitions.append(team_stats_partition)
    aligned_players_stats_partitions.append(players_stats_partition)
    labels.append(label)

labels = np.array(labels)

In [161]:
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset

def _as_numpy(partition):
    """Return ``partition`` as a NumPy array, converting polars DataFrames."""
    return partition.to_numpy() if isinstance(partition, pl.DataFrame) else partition


SPLIT_SEED = 42   # random_state shared by both train_test_split calls
BATCH_SIZE = 32   # batch size shared by all three DataLoaders

team_data = [_as_numpy(team) for team in aligned_team_stats_partitions]
player_data = [_as_numpy(players) for players in aligned_players_stats_partitions]
ball_data = [_as_numpy(ball) for ball in aligned_balltoball_partitions]

# 80% train; the remaining 20% is halved into validation and test.
train_indices, holdout_indices = train_test_split(
    np.arange(len(labels)), test_size=0.2, random_state=SPLIT_SEED
)
val_indices, test_indices = train_test_split(
    holdout_indices, test_size=0.5, random_state=SPLIT_SEED
)

dataset = model_utils.CricketDataset(
    team_data,
    player_data,
    ball_data,
    labels
)

train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)
test_dataset = Subset(dataset, test_indices)

# Only the training loader shuffles. The evaluation loaders keep a fixed
# order so validation and test passes are repeatable batch for batch.
train_dataloader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True,
    collate_fn=model_utils.collate_fn_with_packing,
)
val_dataloader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False,
    collate_fn=model_utils.collate_fn_with_packing,
)
test_dataloader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False,
    collate_fn=model_utils.collate_fn_with_packing,
)

In [162]:
import pickle
# Persist the three DataLoaders to ../data/pytorch_data as pickle files.
pytorch_data_dir = os.path.join('..', "data", "pytorch_data")

# open(..., "wb") does not create missing parent directories, so make sure
# the output directory exists before writing.
os.makedirs(pytorch_data_dir, exist_ok=True)

for file_name, dataloader in (
    ("train_dataloader.pkl", train_dataloader),
    ("val_dataloader.pkl", val_dataloader),
    ("test_dataloader.pkl", test_dataloader),
):
    with open(os.path.join(pytorch_data_dir, file_name), "wb") as f:
        pickle.dump(dataloader, f)

In [163]:
# Reload the pickled DataLoaders. Each file is opened in a `with` block so
# the handle is closed as soon as the object has been read.
# NOTE: pickle.load can execute arbitrary code; only load files that this
# project produced itself.
with open(os.path.join( '..',"data", "pytorch_data" , "train_dataloader.pkl"), "rb") as f:
    train_dataloader = pickle.load(f)

with open(os.path.join( '..',"data", "pytorch_data" , "val_dataloader.pkl"), "rb") as f:
    val_dataloader = pickle.load(f)

with open(os.path.join( '..',"data", "pytorch_data" , "test_dataloader.pkl"), "rb") as f:
    test_dataloader = pickle.load(f)

In [170]:
# Print a single match's data
# Pick a random match from the whole dataset. The upper bound is exclusive,
# and len(labels) is the number of samples the dataset was built from, so the
# index is always valid regardless of how many matches were loaded.
single_match_index = np.random.randint(0, len(labels))
single_match_data = dataset[single_match_index]

team_input, player_input, ball_input, label = single_match_data

# Rebuild labelled frames for display. The arrays no longer contain match_id
# (dropped when the tables were partitioned), hence columns[1:]. The
# ball-by-ball array also had its last column ("won") split off as the label,
# hence columns[1:-1].
team_df = pl.DataFrame(team_input.numpy().astype(float).reshape(1, -1), schema=team_stats.columns[1:])
player_df = pl.DataFrame(player_input.numpy().astype(float), schema=players_stats.columns[1:])
ball_df = pl.DataFrame(ball_input.numpy().astype(float), schema=balltoball.columns[1:-1])

print("Label:")
print(label)
print("Ball-to-Ball Stats:")
ball_df.with_row_index()

Label:
tensor(1.)
Ball-to-Ball Stats:


index,innings,ball,runs,wickets,curr_score,curr_wickets,overs,run_rate,required_run_rate,target
u32,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
0,1.0,0.1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,1.0,0.2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,1.0,0.3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,1.0,0.4,4.0,0.0,4.0,0.0,0.0,0.0,0.0,0.0
4,1.0,0.5,0.0,0.0,4.0,0.0,0.0,0.0,0.0,0.0
…,…,…,…,…,…,…,…,…,…,…
232,2.0,17.4,0.0,1.0,102.0,4.0,17.0,6.0,2.666667,110.0
233,2.0,17.5,2.0,0.0,104.0,4.0,17.0,6.117647,2.0,110.0
234,2.0,17.6,4.0,0.0,108.0,4.0,17.0,6.352941,0.666667,110.0
235,2.0,18.1,0.0,0.0,108.0,4.0,18.0,6.0,1.0,110.0


In [166]:
# Display the player statistics frame (player_df) built in the single-match
# preview cell.
# NOTE(review): this cell's execution count (166) is lower than the preview
# cell's (170), so the saved output may belong to an earlier sampled match;
# re-run in order to confirm.
print("Player Stats:")
player_df

Player Stats:


Cum Mat Total,Cum Inns Total,Cum Runs Total,Cum Batting Ave,Cum SR,Cumulative Overs,Cumulative Bowling Runs,Cumulative Wkts,Cumulative Econ,Cumulative Dis,Cumulative Ct,Cumulative St,Cumulative D/I
f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
42.0,35.0,739.0,36.349998,121.209999,11.2,113.0,1.0,10.98,19.0,19.0,0.0,0.45
6.0,6.0,101.0,16.83,102.220001,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.17
13.0,13.0,273.0,21.0,110.959999,0.0,0.0,0.0,0.0,12.0,12.0,0.0,0.92
44.0,38.0,947.0,33.5,128.509995,29.0,235.0,6.0,8.13,20.0,20.0,0.0,0.45
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
…,…,…,…,…,…,…,…,…,…,…,…,…
1.0,1.0,7.0,7.0,175.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
40.0,33.0,469.0,36.009998,149.919998,79.400002,661.0,26.0,8.62,13.0,13.0,0.0,0.33
17.0,7.0,32.0,17.57,90.910004,55.0,300.0,18.0,5.56,2.0,2.0,0.0,0.12
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [167]:
# Display the one-row team statistics frame (team_df) built in the
# single-match preview cell.
# NOTE(review): this cell's execution count (167) is lower than the preview
# cell's (170), so the saved output may belong to an earlier sampled match;
# re-run in order to confirm.
print("Single Match Data:")
print("Team Stats:")
team_df

Single Match Data:
Team Stats:


gender,Cumulative Won team1,Cumulative Lost team1,Cumulative Tied team1,Cumulative W/L team1,Cumulative AveRPW team1,Cumulative AveRPO team1,Cumulative Won team2,Cumulative Lost team2,Cumulative Tied team2,Cumulative W/L team2,Cumulative AveRPW team2,Cumulative AveRPO team2
f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
0.0,31.0,22.0,1.0,1.41,28.76,7.99,42.0,25.0,1.0,1.68,25.0,7.75
