In [12]:
import os
import sys

# Put the parent directory on the import path so the project-local
# `model_utils` module (one level above this notebook) can be imported.
sys.path.append(os.path.join(os.getcwd(), os.pardir))

import model_utils

import polars as pl
# import data
def load_data():
    balltoball = pl.read_csv(os.path.join(os.path.join( '..',"data", "filtered_data" , "balltoball.csv")))
    team_stats = pl.read_csv(os.path.join(os.path.join( '..',"data", "filtered_data" , "team12_stats.csv")))
    players_stats = pl.read_csv(os.path.join(os.path.join( '..',"data", "filtered_data" , "players_stats.csv")))
    return balltoball, team_stats, players_stats
balltoball, team_stats, players_stats = load_data()

# Quick sanity check: print each table's column list, then the first row of
# all three tables on a single print line.
for frame in (balltoball, team_stats, players_stats):
    print(frame.columns)
print(*(frame.head(1) for frame in (balltoball, team_stats, players_stats)))

['match_id', 'innings', 'ball', 'runs', 'wickets', 'curr_score', 'curr_wickets', 'overs', 'run_rate', 'required_run_rate', 'target', 'won']
['match_id', 'gender', 'Cumulative Won team1', 'Cumulative Lost team1', 'Cumulative Tied team1', 'Cumulative W/L team1', 'Cumulative AveRPW team1', 'Cumulative AveRPO team1', 'Cumulative Won team2', 'Cumulative Lost team2', 'Cumulative Tied team2', 'Cumulative W/L team2', 'Cumulative AveRPW team2', 'Cumulative AveRPO team2']
['match_id', 'Cum Mat Total', 'Cum Runs Total', 'Cum SR', 'Cumulative Overs', 'Cumulative Bowling Runs', 'Cumulative Wkts', 'Cumulative Econ', 'Cumulative Dis', 'Cumulative Ct', 'Cumulative St', 'Cumulative D/I']
shape: (1, 12)
┌──────────┬─────────┬──────┬──────┬───┬──────────┬───────────────────┬────────┬─────┐
│ match_id ┆ innings ┆ ball ┆ runs ┆ … ┆ run_rate ┆ required_run_rate ┆ target ┆ won │
│ ---      ┆ ---     ┆ ---  ┆ ---  ┆   ┆ ---      ┆ ---               ┆ ---    ┆ --- │
│ i64      ┆ i64     ┆ f64  ┆ i64  ┆   ┆ f64

In [None]:
def partition_data_with_keys(df, group_keys, drop_keys=True, as_numpy=True):
    """Split ``df`` into one partition per distinct value of ``group_keys``.

    Parameters
    ----------
    df : polars.DataFrame
        Table to split.
    group_keys : str or list of str
        Column(s) to group on, e.g. ``["match_id"]``.
    drop_keys : bool, default True
        Remove the key columns from every partition (so e.g. ``match_id`` is
        not fed to the model as a feature). Pass ``False`` to keep them.
    as_numpy : bool, default True
        Convert each partition with ``DataFrame.to_numpy()``. Pass ``False`` to
        keep polars DataFrames for inspection/debugging.

    Returns
    -------
    keys : list of tuple
        The key value(s) of each partition, in the same order as ``partitions``.
    partitions : list
        One numpy array (or polars DataFrame when ``as_numpy=False``) per key.
    """
    keys = []
    partitions = []
    for partition in df.partition_by(group_keys):
        # All rows of a partition share the same key, so the first row suffices.
        # row() also keeps each key column's own Python type instead of
        # coercing a multi-column key to a single numpy dtype.
        keys.append(partition.select(group_keys).row(0))
        if drop_keys:
            partition = partition.drop(group_keys)
        partitions.append(partition.to_numpy() if as_numpy else partition)
    return keys, partitions

import numpy as np

# Split every table into one partition per match.
balltoball_keys, balltoball_partitions = partition_data_with_keys(balltoball, ["match_id"])
team_stats_keys, team_stats_partitions = partition_data_with_keys(team_stats, ["match_id"])
players_stats_keys, players_stats_partitions = partition_data_with_keys(players_stats, ["match_id"])

# Keep only matches present in all three tables. The keys are sorted so that the
# sample order (and therefore the seeded train/val/test split downstream) does
# not depend on set iteration order.
common_keys = sorted(set(balltoball_keys) & set(team_stats_keys) & set(players_stats_keys))

balltoball_dict = dict(zip(balltoball_keys, balltoball_partitions))
team_stats_dict = dict(zip(team_stats_keys, team_stats_partitions))
players_stats_dict = dict(zip(players_stats_keys, players_stats_partitions))

aligned_balltoball_partitions = []
aligned_team_stats_partitions = []
aligned_players_stats_partitions = []
labels = []

for key in common_keys:
    balltoball_partition = balltoball_dict[key]
    # The last ball-by-ball column is `won` (see the printed column list). Its
    # value on the first row is the match label, and the column is stripped from
    # the features so the label does not leak into the model input.
    labels.append(balltoball_partition[0, -1])
    aligned_balltoball_partitions.append(balltoball_partition[:, :-1])
    aligned_team_stats_partitions.append(team_stats_dict[key])
    aligned_players_stats_partitions.append(players_stats_dict[key])

labels = np.array(labels)

In [29]:
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset

SEED = 42        # shared by both split calls so the split is reproducible
BATCH_SIZE = 32

# Partitions are still polars DataFrames when partition_data_with_keys is run
# with as_numpy=False; the dataset is given plain numpy arrays either way.
team_data = [team.to_numpy() if isinstance(team, pl.DataFrame) else team for team in aligned_team_stats_partitions]
player_data = [players.to_numpy() if isinstance(players, pl.DataFrame) else players for players in aligned_players_stats_partitions]
ball_data = [ball.to_numpy() if isinstance(ball, pl.DataFrame) else ball for ball in aligned_balltoball_partitions]

# 80 / 10 / 10 split over match indices: hold out 20%, then halve the holdout.
train_indices, holdout_indices = train_test_split(np.arange(len(labels)), test_size=0.2, random_state=SEED)
val_indices, test_indices = train_test_split(holdout_indices, test_size=0.5, random_state=SEED)

dataset = model_utils.CricketDataset(
    team_data,
    player_data,
    ball_data,
    labels
)

train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)
test_dataset = Subset(dataset, test_indices)

# Only the training loader needs shuffling; evaluation loaders keep a fixed
# order so validation/test passes are reproducible batch-for-batch.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=model_utils.collate_fn_with_padding)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=model_utils.collate_fn_with_padding)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=model_utils.collate_fn_with_padding)

# Peek at one training batch. The batch labels are deliberately NOT bound to
# `labels`: that would overwrite the match-level label array used above and
# break a re-run of this cell.
team_input, player_input, ball_input, batch_labels, mask = next(iter(train_dataloader))
print(f"Team input shape: {team_input.shape}")  # [batch_size, team_feature_dim]
print(f"Player input shape: {player_input.shape}")  # [batch_size, n_players, player_feature_dim] — middle dim assumed to be players per match
print(f"Padded ball input shape: {ball_input.shape}")  # [batch_size, max_seq_len, ball_feature_dim]
print(f"Mask shape: {mask.shape}")  # [batch_size, max_seq_len]
print(f"Labels shape: {batch_labels.shape}")  # [batch_size]

Team input shape: torch.Size([32, 14])
Player input shape: torch.Size([32, 22, 12])
Padded ball input shape: torch.Size([32, 256, 11])
Mask shape: torch.Size([32, 256])
Labels shape: torch.Size([32])


In [15]:
import pickle

# Persist the three loaders so later notebooks can reuse the exact same split.
# NOTE: pickle stores `model_utils.CricketDataset` and
# `model_utils.collate_fn_with_padding` by reference, so `model_utils` must be
# importable wherever these files are loaded.
PYTORCH_DATA_DIR = os.path.join('..', "data", "pytorch_data")
os.makedirs(PYTORCH_DATA_DIR, exist_ok=True)  # a fresh checkout may not have this folder yet

for loader_name, loader in (
    ("train_dataloader", train_dataloader),
    ("val_dataloader", val_dataloader),
    ("test_dataloader", test_dataloader),
):
    with open(os.path.join(PYTORCH_DATA_DIR, f"{loader_name}.pkl"), "wb") as f:
        pickle.dump(loader, f)

In [16]:
import pickle

# Reload the persisted loaders. Only unpickle files this project wrote itself:
# pickle.load can execute arbitrary code embedded in an untrusted file.
PYTORCH_DATA_DIR = os.path.join('..', "data", "pytorch_data")

def _load_pickle(path):
    """Unpickle a single object from ``path``, closing the file afterwards."""
    with open(path, "rb") as f:
        return pickle.load(f)

train_dataloader = _load_pickle(os.path.join(PYTORCH_DATA_DIR, "train_dataloader.pkl"))
val_dataloader = _load_pickle(os.path.join(PYTORCH_DATA_DIR, "val_dataloader.pkl"))
test_dataloader = _load_pickle(os.path.join(PYTORCH_DATA_DIR, "test_dataloader.pkl"))

In [17]:
# len() of a DataLoader is the number of batches (ceil(n_samples / batch_size)
# with the default drop_last=False), so every batch need not be loaded to count them.
n_train_batches = len(train_dataloader)
n_train_batches

26


In [None]:
import numpy as np

# Take one training batch and turn a randomly chosen sample back into labelled
# polars frames for eyeballing. Column names skip `match_id` (dropped by
# partition_data_with_keys) and, for ball-by-ball data, the trailing `won` label.
INSPECT_SEED = 0
inspect_rng = np.random.default_rng(INSPECT_SEED)

batch = next(iter(train_dataloader))
# A batch can hold fewer than BATCH_SIZE samples, so sample within its real size.
sample_idx = int(inspect_rng.integers(len(batch[0])))

data0 = np.array(batch[0][sample_idx]).reshape(1, -1)
data1 = np.array(batch[1][sample_idx])
data2 = np.array(batch[2][sample_idx])
print(data0.shape, data1.shape, data2.shape)

# orient="row" avoids polars guessing the orientation, which is ambiguous when an
# array happens to be square.
team0 = pl.DataFrame(data0, schema=team_stats.columns[1:], orient="row")
players1 = pl.DataFrame(data1, schema=players_stats.columns[1:], orient="row")
balltoball2 = pl.DataFrame(data2, schema=balltoball.columns[1:-1], orient="row")
team0

(1, 14) (22, 12) (258, 11)


match_id,gender,Cumulative Won team1,Cumulative Lost team1,Cumulative Tied team1,Cumulative W/L team1,Cumulative AveRPW team1,Cumulative AveRPO team1,Cumulative Won team2,Cumulative Lost team2,Cumulative Tied team2,Cumulative W/L team2,Cumulative AveRPW team2,Cumulative AveRPO team2
f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
533289.0,0.0,24.0,17.0,0.0,1.41,23.5,7.67,16.0,20.0,2.0,0.8,22.99,7.66


In [31]:
# Display the per-player feature rows of the sample drawn in the inspection cell.
players1

match_id,Cum Mat Total,Cum Runs Total,Cum SR,Cumulative Overs,Cumulative Bowling Runs,Cumulative Wkts,Cumulative Econ,Cumulative Dis,Cumulative Ct,Cumulative St,Cumulative D/I
f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
533289.0,37.0,981.0,134.350006,1.0,8.0,0.0,8.0,11.0,11.0,0.0,0.3
533289.0,38.0,917.0,114.610001,24.6,185.0,5.0,8.06,20.0,18.0,2.0,0.53
533289.0,35.0,910.0,117.089996,0.0,0.0,0.0,0.0,28.0,17.0,11.0,0.8
533289.0,28.0,363.0,123.669998,56.099998,386.0,17.0,6.78,9.0,9.0,0.0,0.32
533289.0,4.0,41.0,111.400002,6.0,36.0,1.0,6.13,1.0,1.0,0.0,0.25
…,…,…,…,…,…,…,…,…,…,…,…
533289.0,29.0,184.0,111.360001,83.0,542.0,31.0,6.56,14.0,14.0,0.0,0.48
533289.0,26.0,229.0,120.790001,0.0,0.0,0.0,0.0,26.0,21.0,5.0,1.0
533289.0,5.0,2.0,50.0,20.0,128.0,7.0,6.4,0.0,0.0,0.0,0.0
533289.0,12.0,11.0,33.330002,45.5,408.0,16.0,9.0,0.0,0.0,0.0,0.0


In [32]:
# Display the ball-by-ball rows of the same sample. Trailing all-zero rows are
# presumably padding added by collate_fn_with_padding up to the batch's longest
# sequence — confirm against model_utils.
balltoball2

match_id,innings,ball,runs,wickets,curr_score,curr_wickets,overs,run_rate,required_run_rate,target
f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
533289.0,1.0,0.1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
533289.0,1.0,0.2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
533289.0,1.0,0.3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
533289.0,1.0,0.4,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0
533289.0,1.0,0.5,1.0,0.0,2.0,0.0,0.0,0.0,0.0,0.0
…,…,…,…,…,…,…,…,…,…,…
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
