In [2]:
import wandb

# Initialize your WandB API
api = wandb.Api()

# Define your project and sweep details
entity = "ravikumarchavva-org"  # Replace with your WandB organization or username
project = "T20I-CRICKET-WINNER-PREDICTION"
sweep_id = "qqakx1g3"  # The specific sweep ID

# Fetch the sweep
sweep = api.sweep(f"{entity}/{project}/{sweep_id}")

# Retrieve all runs in the sweep
runs = sweep.runs

# Sort runs by a specific metric, e.g., validation accuracy (replace with your metric name)
# Use the metric key name you logged in your WandB runs
best_run = sorted(
    runs, key=lambda run: run.summary.get("val_accuracy", float("-inf")), reverse=True
)[1]

# Print details of the best run
print(f"Best run ID: {best_run.id}")
print(f"Validation Accuracy: {best_run.summary.get('val_accuracy')}")
config = best_run.config
config

[34m[1mwandb[0m: Network error (ConnectionError), entering retry loop.


Best run ID: yjilby5q
Validation Accuracy: 83.49514563106796


{'lr': 0.0001,
 'dropout': 0.6882847989323215,
 'batch_size': 32,
 'num_epochs': 100,
 'num_layers': 3,
 'hidden_size': 256,
 'enable_plots': False,
 'weight_decay': 5.754095511533712e-06,
 'learning_rate': 0.0009693209823947022}

In [3]:
# Get the model from run id
run_path = f"ravikumarchavva-org/T20I-CRICKET-WINNER-PREDICTION/{best_run.id}"

# Get the specific run
run = api.run(run_path)

# List and download output artifacts
for artifact in run.logged_artifacts():
    if artifact.name.startswith("best_model"):
        print(f"Downloading artifact: {artifact.name}")
        artifact_dir = artifact.download()
        print(f"Artifact downloaded to: {artifact_dir}")

Downloading artifact: best_model_val_loss_0.3524:v0


wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
wandb:   1 of 1 files downloaded.  


Artifact downloaded to: d:\github\Cricket-Prediction\ml_modeling\5_selecting_best_model_to_onnx\artifacts\best_model_val_loss_0.3524-v0


In [4]:
import sys
import os
sys.path.append(os.path.join(os.getcwd(),"..",".."))

from utils.data_utils import collate_fn_with_padding, load_datasets, augument_data
from utils.model_utils import set_seed
from torch.utils.data import DataLoader

set_seed()
# Load the Datasets
train_dataset, val_dataset, test_dataset = load_datasets()

# Step 2: Augment Data
train_dataset, val_dataset, test_dataset = augument_data(train_dataset, val_dataset, test_dataset)

# Step 3: Create DataLoaders
train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, collate_fn=collate_fn_with_padding)
val_dataloader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=collate_fn_with_padding)
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=collate_fn_with_padding)

wandb: Currently logged in as: ravikumarchavva (ravikumarchavva-org). Use `wandb login --relogin` to force relogin


In [5]:
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import os

# Load the model
model = torch.load(os.path.join(artifact_dir, 'best_model.pth'),weights_only=False)

# Evaluating on `test_dataset`

In [6]:
from utils.model_utils import evaluate_model

config['enable_plots'] = False

save_dir = os.path.dirname(os.getcwd())

# Define window sizes
window_sizes = [5,10,15,20, 25, 30, 35, 40, 45]

# Evaluate the model
metrics, all_labels, all_predictions, all_probs = evaluate_model(
    model, test_dataloader, device, window_sizes, config, save_dir=os.getcwd()
)
overall_metrics = metrics["overall_metrics"]
stage_metrics = metrics["stage_metrics"]
import pandas as pd
# Convert metrics to pandas DataFrames
stage_df = pd.DataFrame(stage_metrics).T
stage_df.index.name = "Stage"
stage_df.reset_index(inplace=True)

overall_df = pd.DataFrame(overall_metrics, index=["Overall"]).reset_index()
overall_df.rename(columns={"index": "Stage"}, inplace=True)

print("\nTest Data Metrics:")
# Print metrics in DataFrame format
print("Overall Metrics:")
print(overall_df.to_string(index=False))

print("\nStage Metrics:")
stage_df

Accuracy: 85.58 %

Test Data Metrics:
Overall Metrics:
  Stage  accuracy  precision   recall      f1
Overall  0.855769   0.924138 0.797619 0.85623

Stage Metrics:


Unnamed: 0,Stage,accuracy,precision,recall,f1
0,5 overs,0.7,0.666667,0.5,0.571429
1,10 overs,0.783333,0.869565,0.666667,0.754717
2,15 overs,0.811111,0.878049,0.75,0.808989
3,20 overs,0.85,0.9,0.818182,0.857143
4,25 overs,0.86,0.909091,0.833333,0.869565
5,30 overs,0.855556,0.913043,0.823529,0.865979
6,35 overs,0.861905,0.925234,0.825,0.872247
7,40 overs,0.875,0.928571,0.847826,0.886364
8,45 overs,0.877778,0.933333,0.84,0.884211


In [7]:
# i = torch.randint(0, len(test_dataset), (1,)).item()
i = 0
train_dataset[i][0].unsqueeze(0).shape, train_dataset[i][1].unsqueeze(0).unsqueeze(0).shape, train_dataset[i][2].unsqueeze(0).shape

(torch.Size([1, 13]), torch.Size([1, 1, 22, 12]), torch.Size([1, 120, 10]))

In [8]:
print(train_dataset[i][0].shape[-1])

13


In [9]:
# Move the model to the device
model.to(device)

# Set the model to evaluation mode
model.eval()

# Visualize the model architecture
from torchinfo import summary  # Replace torchsummary with torchinfo
from torchviz import make_dot  # Add import for torchviz

# Visualize the model architecture using torchinfo
summary(model, input_size=[(1, 13), (1, 1, 22, 12), (1, 10)])

# Create a dummy input to visualize the graph
team_dummy = train_dataset[i][0].unsqueeze(0).to(device)
player_dummy = train_dataset[i][1].unsqueeze(0).to(device)
ball_dummy = train_dataset[i][2].unsqueeze(0).to(device)
# Forward pass to get the output
output = model(team_dummy, player_dummy, ball_dummy)

# Generate and save the model visualization
dot = make_dot(output, params=None)
dot.format = 'png'
dot.render('model_visualization')  # Saves as model_visualization.png

'model_visualization.png'

# Saving to ONNX format

In [11]:
# Export the model to ONNX

from utils.model_utils import export_model_to_onnx

export_path = os.path.join(os.getcwd(), "model.onnx")
export_model_to_onnx(model, export_path, (team_dummy, player_dummy, ball_dummy))



Model has been exported to d:\github\Cricket-Prediction\ml_modeling\5_selecting_best_model_to_onnx\model.onnx


# Predict using onnx

In [12]:
#import from onnx
import onnx
import onnxruntime as ort

# Load the ONNX model
onnx_model = onnx.load(export_path)

# Check the ONNX model
onnx.checker.check_model(onnx_model)

# Initialize the ONNX runtime session
ort_session = ort.InferenceSession(export_path)

# Get the input names
input_names = [input.name for input in ort_session.get_inputs()]

# Get the output names
output_names = [output.name for output in ort_session.get_outputs()]

# Print the input and output names
print(f"Input Names: {input_names}")
print(f"Output Names: {output_names}")

Input Names: ['team_input', 'player_input', 'ball_input']
Output Names: ['output']


In [13]:
i= torch.randint(0, len(val_dataset), (1,)).item()
team_input, player_input, ball_input,label = train_dataset[i]
team_input = team_input.unsqueeze(0).to(device)
player_input = player_input.unsqueeze(0).to(device)
ball_input = ball_input.unsqueeze(0).to(device)
label

tensor(0.)

In [14]:
onnx_input = {
    "team_input": team_input.cpu().numpy(),
    "player_input": player_input.cpu().numpy(),
    "ball_input": ball_input.cpu().numpy(),
}


In [15]:
# Run the ONNX model using the ONNX runtime session
outputs = ort_session.run(None, onnx_input)

# Print the outputs
print(outputs)

[array([[0.07113281]], dtype=float32)]


In [18]:
# select 2 random samples from the test dataset where 1 is a win and 1 is a loss
win_indices = [i for i, label in enumerate(all_labels) if label == 1]
loss_indices = [i for i, label in enumerate(all_labels) if label == 0]

win_index = torch.randint(0, len(win_indices), (1,)).item()
loss_index = torch.randint(0, len(loss_indices), (1,)).item()

win_index, loss_index

(208, 250)

In [19]:
win_team_input, win_player_input, win_ball_input, win_label = test_dataset[win_indices[win_index]]
loss_team_input, loss_player_input, loss_ball_input, loss_label = test_dataset[loss_indices[loss_index]]

win_team_input = win_team_input.unsqueeze(0).to(device)
win_player_input = win_player_input.unsqueeze(0).to(device)
win_ball_input = win_ball_input.unsqueeze(0).to(device)

loss_team_input = loss_team_input.unsqueeze(0).to(device)
loss_player_input = loss_player_input.unsqueeze(0).to(device)
loss_ball_input = loss_ball_input.unsqueeze(0).to(device)

win_label, loss_label

(tensor(1.), tensor(0.))

In [20]:
onnx_input_win = {
    "team_input": win_team_input.cpu().numpy(),
    "player_input": win_player_input.cpu().numpy(),
    "ball_input": win_ball_input.cpu().numpy(),
}

onnx_input_loss = {
    "team_input": loss_team_input.cpu().numpy(),
    "player_input": loss_player_input.cpu().numpy(),
    "ball_input": loss_ball_input.cpu().numpy(),
}

# Run the ONNX model using the ONNX runtime session
win_outputs = ort_session.run(None, onnx_input_win)
loss_outputs = ort_session.run(None, onnx_input_loss)

# Print the outputs
print("Win Outputs:", win_outputs)
print("Loss Outputs:", loss_outputs)

Win Outputs: [array([[0.8578379]], dtype=float32)]
Loss Outputs: [array([[0.01378292]], dtype=float32)]


# Printing from taking image

In [None]:
# printing for taking image
import pandas as pd
columns = ['innings', 'ball', 'runs', 'wickets', 'total_runs','total_wickets', 'overs', 'run_rate', 'req_run_rate', 'target']
print(f"Actual:{win_label}, Predicted:{win_outputs[0][0]}")
pd.DataFrame(win_ball_input.cpu().numpy().reshape(-1,10),columns=columns).tail(5)

Actual:1.0, Predicted:[0.8578379]


Unnamed: 0,innings,ball,runs,wickets,total_runs,total_wickets,overs,run_rate,req_run_rate,target
235,2.0,17.5,1.0,0.0,179.0,5.0,17.0,10.529411,5.0,194.0
236,2.0,17.6,2.0,0.0,181.0,5.0,17.0,10.647058,4.333333,194.0
237,2.0,18.1,1.0,0.0,182.0,5.0,18.0,10.111111,6.0,194.0
238,2.0,18.200001,6.0,0.0,188.0,5.0,18.0,10.444445,3.0,194.0
239,2.0,18.299999,1.0,0.0,189.0,5.0,18.0,10.5,2.5,194.0


In [35]:
print(f"Actual:{loss_label}, Predicted:{loss_outputs[0][0]}")
pd.DataFrame(loss_ball_input.cpu().numpy().reshape(-1,10),columns=columns).tail(5)

Actual:0.0, Predicted:[0.01378292]


Unnamed: 0,innings,ball,runs,wickets,total_runs,total_wickets,overs,run_rate,req_run_rate,target
222,2.0,15.2,0.0,1.0,99.0,8.0,15.0,6.6,21.0,204.0
223,2.0,15.3,1.0,0.0,100.0,8.0,15.0,6.666667,20.799999,204.0
224,2.0,15.4,0.0,0.0,100.0,8.0,15.0,6.666667,20.799999,204.0
225,2.0,15.5,1.0,0.0,101.0,8.0,15.0,6.733333,20.6,204.0
226,2.0,15.6,0.0,1.0,101.0,9.0,15.0,6.733333,20.6,204.0
