In [1]:
import argparse
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import confusion_matrix, accuracy_score


def clean_prediction(text):
    """Normalise a raw model prediction into a lowercase move string.

    Lowercases the text, then removes the chat-template tokens
    ``<end_of_turn>`` and ``<pad>`` (together or on their own), the word
    " chips", and any period that is not a decimal point. Surrounding
    whitespace is stripped.

    Periods immediately followed by a digit are kept, so fractional amounts
    such as ``"raise 2.5."`` become ``"raise 2.5"`` rather than ``"raise 25"``.

    Args:
        text: Raw prediction text. Non-string values (e.g. the float NaN that
            ``pd.read_csv`` produces for an empty cell) are mapped to NaN.

    Returns:
        The cleaned, lowercased string, or ``np.nan`` for non-string input.
    """
    if not isinstance(text, str):
        # Empty CSV cells arrive as float NaN; extract_action/extract_amount
        # already treat NaN as "missing".
        return np.nan

    # Lowercase first so token / " chips" removal is case-insensitive.
    text = text.lower()
    for token in ("<end_of_turn>", "<pad>", " chips"):
        text = text.replace(token, "")

    # Drop sentence punctuation but keep decimal points ("call." -> "call",
    # "raise 2.5" unchanged).
    text = re.sub(r"\.(?!\d)", "", text)
    return text.strip()

def extract_action(move):
    """Map a move string onto its poker action keyword.

    The keywords raise, call, check, fold and bet are tried in that order,
    and the first one found as a substring of ``move`` is returned. Missing
    input (NaN) yields NaN. A move that contains none of the keywords is
    returned unchanged.
    """
    if pd.isna(move):
        return np.nan

    # Earlier keywords take precedence, e.g. "call or raise" -> "raise".
    priority = ("raise", "call", "check", "fold", "bet")
    return next((keyword for keyword in priority if keyword in move), move)


def extract_amount(move):
    """Extract the numeric amount from a raise or bet move.

    The amount is the text after the keyword, parsed as a float
    (``"raise 10"`` -> ``10.0``). ``"raise"`` is checked before ``"bet"``,
    the same precedence extract_action uses.

    Args:
        move: A cleaned move string, or NaN.

    Returns:
        ``np.nan`` for missing input; the parsed amount for a raise/bet move;
        ``0.0`` for a raise/bet move whose amount cannot be parsed, and for
        every other move.
    """
    if pd.isna(move):
        return np.nan

    for keyword in ("raise", "bet"):
        if keyword in move:
            amount_text = move.split(keyword)[1].strip()
            try:
                return float(amount_text)
            except ValueError:
                # Unparseable amount (e.g. "raise to 10"): fall back to 0
                # instead of aborting the whole evaluation. Only ValueError
                # is caught so real errors and Ctrl-C still propagate.
                return 0.0
    return 0.0

def _matched_amount_rows(df, action):
    """Return rows where truth and prediction are both `action` and both amounts are present."""
    matched = df[(df['True_Action'] == action) & (df['Pred_Action'] == action)]
    return matched.dropna(subset=['True_Amount', 'Pred_Amount'])


def _report_amount_error(amount_df, label):
    """Print RMSE (and NRMSE when the true amounts vary) for `label` amounts; return the RMSE."""
    errors = amount_df['True_Amount'] - amount_df['Pred_Amount']
    rmse = np.sqrt((errors ** 2).mean())
    print(f"RMSE for {label} amounts: {rmse:.4f}")

    # NRMSE normalises by the spread of the true amounts; it is undefined
    # when every true amount is identical.
    true_max = amount_df['True_Amount'].max()
    true_min = amount_df['True_Amount'].min()
    if true_max != true_min:
        nrmse = rmse / (true_max - true_min)
        print(f"NRMSE for {label} amounts: {nrmse:.4f}")
    return rmse


def main(model_name, testing_type,
         results_dir='/Users/weber/Github/COMP0258/testing-results'):
    """Evaluate one model's poker-move predictions and save a summary figure.

    Reads ``<results_dir>/<model_name>_predictions.csv``, which must have the
    columns ``Prediction`` and ``Ground Truth``. Predictions are normalised
    with clean_prediction; ground truth is used as-is. Both are then split
    into an action and an amount. The function then:

    * writes the parsed columns to ``<model_name>_cleaned.csv``;
    * prints overall and per-action accuracy, plus RMSE/NRMSE of the amounts
      on rows where a raise (or bet) was predicted correctly;
    * saves a confusion matrix and a predicted-vs-true amount scatter plot to
      ``<model_name>_confusion-matrix.png``.

    Args:
        model_name: Prefix of the prediction CSV inside ``results_dir``.
        testing_type: ``'preflop'`` uses the labels fold/check/call/raise and
            skips bet-amount metrics. Any other value also evaluates ``bet``.
        results_dir: Directory that holds the prediction CSV. All outputs
            are written there as well.
    """
    file_path = os.path.join(results_dir, f'{model_name}_predictions.csv')
    if not os.path.isfile(file_path):
        print(f"File not found: {file_path}")
        return
    df = pd.read_csv(file_path)

    # Normalise predictions; ground truth is used verbatim.
    df['Prediction_Clean'] = df['Prediction'].apply(clean_prediction)
    df['Ground_Truth_Clean'] = df['Ground Truth']

    # Split each move into its action keyword and numeric amount.
    df['Pred_Action'] = df['Prediction_Clean'].apply(extract_action)
    df['True_Action'] = df['Ground_Truth_Clean'].apply(extract_action)
    df['Pred_Amount'] = df['Prediction_Clean'].apply(extract_amount)
    df['True_Amount'] = df['Ground_Truth_Clean'].apply(extract_amount)

    cleaned_file_path = os.path.join(results_dir, f'{model_name}_cleaned.csv')
    df[['Pred_Action', 'True_Action', 'Pred_Amount', 'True_Amount']].to_csv(
        cleaned_file_path, index=False)

    # sklearn rejects label arrays that mix str with float NaN. Give missing
    # values distinct string placeholders so a missing prediction never
    # counts as correct.
    true_actions = df['True_Action'].fillna('<missing truth>').astype(str)
    pred_actions = df['Pred_Action'].fillna('<missing prediction>').astype(str)

    action_accuracy = accuracy_score(true_actions, pred_actions)
    print(f"Action prediction accuracy: {action_accuracy:.4f}")

    # Preflop runs use a four-action label set without 'bet'.
    include_bets = testing_type != 'preflop'
    actions = ['fold', 'check', 'call', 'raise']
    if include_bets:
        actions.append('bet')

    # The confusion matrix only covers rows whose true and predicted actions
    # are both in the label set. Free-text predictions are excluded here but
    # still count as wrong in the overall accuracy above.
    in_label_set = true_actions.isin(actions) & pred_actions.isin(actions)
    cm = confusion_matrix(true_actions[in_label_set],
                          pred_actions[in_label_set], labels=actions)
    cm_df = pd.DataFrame(cm, index=actions, columns=actions)

    # Per-action recall: share of rows with true action X predicted as X.
    action_specific_accuracy = {}
    for action in actions:
        is_action = true_actions == action
        total = int(is_action.sum())
        if total == 0:
            continue
        correct = int((pred_actions[is_action] == action).sum())
        accuracy = correct / total
        action_specific_accuracy[action] = accuracy
        print(f"Accuracy for {action}: {accuracy:.4f} ({correct}/{total})")

    fig = plt.figure(figsize=(20, 8))
    fig.suptitle(f'{model_name}', fontsize=20)

    # Left panel: confusion matrix.
    plt.subplot(1, 2, 1)
    sns.heatmap(cm_df, annot=True, fmt='d', cmap='Blues')
    plt.title('Confusion Matrix for Poker Actions')
    plt.ylabel('True Action')
    plt.xlabel('Predicted Action')

    # Right panel: amount accuracy, only on correctly-predicted raises/bets.
    plt.subplot(1, 2, 2)

    raise_df = _matched_amount_rows(df, 'raise')
    raise_rmse = None
    if len(raise_df) > 0:
        raise_rmse = _report_amount_error(raise_df, 'raise')
        plt.scatter(raise_df['True_Amount'], raise_df['Pred_Amount'],
                    alpha=0.6, color='blue', label='Raise')

    bet_df = _matched_amount_rows(df, 'bet')
    bet_rmse = None
    if include_bets and len(bet_df) > 0:
        bet_rmse = _report_amount_error(bet_df, 'bet')
        plt.scatter(bet_df['True_Amount'], bet_df['Pred_Amount'],
                    alpha=0.6, color='green', label='Bet')

    # Size the y = x reference line to the points actually plotted.
    plotted_frames = [raise_df] + ([bet_df] if include_bets else [])
    combined_df = pd.concat(plotted_frames)
    if len(combined_df) > 0:
        max_val = max(combined_df['True_Amount'].max(),
                      combined_df['Pred_Amount'].max())
        min_val = min(combined_df['True_Amount'].min(),
                      combined_df['Pred_Amount'].min())
        plt.plot([min_val, max_val], [min_val, max_val],
                 'r--', label='Perfect Prediction')

        plt.title('Amount Prediction: Predicted vs True')
        plt.xlabel('True Amount')
        plt.ylabel('Predicted Amount')
        plt.grid(True, alpha=0.3)
        plt.legend()

    # Caption summarising the printed metrics. An RMSE is included only when
    # it was computed (this avoids the NameError the old code hit when no bet
    # was predicted correctly).
    caption_parts = [f"Action accuracy: {action_accuracy:.2f}"]
    caption_parts += [
        f"{action.capitalize()} accuracy: {action_specific_accuracy[action]:.2f}"
        for action in actions
        if action in action_specific_accuracy
    ]
    if raise_rmse is not None:
        caption_parts.append(f"RMSE for raise amounts: {raise_rmse:.2f}")
    if bet_rmse is not None:
        caption_parts.append(f"RMSE for bet amounts: {bet_rmse:.2f}")
    fig.text(0.5, 0.01, " | ".join(caption_parts), ha="center", fontsize=12)

    combined_image_path = os.path.join(
        results_dir, f'{model_name}_confusion-matrix.png')
    fig.savefig(combined_image_path)
    plt.close(fig)

    


if __name__ == "__main__":
    # Command-line entry point, currently disabled. If re-enabled, main()
    # also needs the testing type, which is derived from the name suffix:
    # parser = argparse.ArgumentParser(
    #     description='Evaluate poker move predictions.')
    # parser.add_argument('model_name', type=str,
    #                     help='The name of the model to evaluate (e.g., gemma-2-9b-it, lora-gemma-2-9b-it)')
    # args = parser.parse_args()
    # main(args.model_name, args.model_name.split('-')[-1])

    # The text after the last '-' in each name is passed to main() as
    # testing_type ('preflop' selects the bet-free label set).
    models = [
        'Meta-Llama-3-8B-Instruct-preflop',
        'Meta-Llama-3-8B-Instruct-postflop',
        'Meta-Llama-3.1-8B-Instruct-total',
        # Previously evaluated runs:
        # 'gpt-4o-preflop', 'gpt-4o-postflop',
        # 'gemma-2-9b-it-preflop', 'gemma-2-9b-it-postflop',
        # 'lora-gemma-2-9b-it-preflop', 'lora-gemma-2-9b-it-postflop',
        # 'Qwen2.5-7B-Instruct-1M-preflop', 'Qwen2.5-7B-Instruct-1M-postflop',
        # 'lora-Qwen2.5-7B-Instruct-1M-preflop', 'lora-Qwen2.5-7B-Instruct-1M-postflop',
        # 'lora_Qwen2.5_14B_model-5000-preflop', 'lora_Qwen2.5_14B_model-5000-postflop',
        # "qwen2.5-14b-instruct-bnb-4bit-preflop", "qwen2.5-14b-instruct-bnb-4bit-postflop",
        # 'Meta-Llama-3.1-8B-Instruct-preflop', 'Meta-Llama-3.1-8B-Instruct-postflop', 'Meta-Llama-3.1-8B-Instruct-total',
        # 'lora-Meta-Llama-3.1-8B-Instruct-preflop', 'lora-Meta-Llama-3.1-8B-Instruct-postflop',
        # 'Meta-Llama-3-8B-Instruct-preflop', 'Meta-Llama-3-8B-Instruct-postflop',
        # 'lora-Meta-Llama-3-8B-Instruct-lr-6-preflop', 'lora-Meta-Llama-3-8B-Instruct-lr-6-postflop',
        # 'lora-Llama-3.2-3B-Instruct-lr-5-preflop', 'lora-Llama-3.2-3B-Instruct-lr-5-postflop', 'lora-Llama-3.2-3B-Instruct-lr-5-total',
        # 'lora-Llama-3.2-3B-Instruct-lr-6-preflop', 'lora-Llama-3.2-3B-Instruct-lr-6-postflop',
        # 'Llama-3.2-3B-Instruct-preflop', 'Llama-3.2-3B-Instruct-postflop',
        # NOTE(review): the checkpoint names below have no preflop/postflop
        # suffix, so their "testing type" would be the step count.
        # 'lora_llama_3.1_8B_model-1600', 'lora_llama_3.1_8B_model-2400', 'lora_llama_3.1_8B_model-3200', 'lora_llama_3.1_8B_model-4000',
        # 'lora_llama_3.2-3B_model-500', 'lora_llama_3.2-3B_model-1000',
    ]
    for model_name in models:
        testing_type = model_name.rsplit('-', 1)[-1]
        main(model_name, testing_type)

Action prediction accuracy: 0.2810
Accuracy for fold: 0.1080 (27/250)
Accuracy for check: 0.0000 (0/250)
Accuracy for call: 0.9600 (240/250)
Accuracy for raise: 0.0560 (14/250)
RMSE for raise amounts: 44.4610
NRMSE for raise amounts: 1.4820
Action prediction accuracy: 0.3706
Accuracy for fold: 0.4284 (1071/2500)
Accuracy for check: 0.1348 (337/2500)
Accuracy for call: 0.5364 (1341/2500)
Accuracy for raise: 0.3703 (574/1550)
Accuracy for bet: 0.4032 (383/950)
RMSE for raise amounts: 90.7761
NRMSE for raise amounts: 0.9867
RMSE for bet amounts: 60.0036
NRMSE for bet amounts: 1.5790
Action prediction accuracy: 0.3936
Accuracy for fold: 0.0193 (53/2750)
Accuracy for check: 0.4007 (1102/2750)
Accuracy for call: 0.8767 (2411/2750)
Accuracy for raise: 0.1300 (234/1800)
Accuracy for bet: 0.5579 (530/950)
RMSE for raise amounts: 89.4485
NRMSE for raise amounts: 1.0165
RMSE for bet amounts: 37.3509
NRMSE for bet amounts: 0.9829
