In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, accuracy_score
import argparse


def clean_prediction(text):
    """Clean the prediction text by removing <end_of_turn> and <pad> tags."""
    text = text.replace("<end_of_turn>\n<pad>", "").strip()
    text = text.replace("<end_of_turn>", "").strip()
    text = text.replace(" chips", "").strip()
    text = text.replace(".", "").strip()

    return text.lower()

def extract_action(move):
    """Extract the action (call, fold, check, raise) from a move."""
    if pd.isna(move):
        return np.nan
    if "raise" in move:
        return "raise"
    elif "call" in move:
        return "call"
    elif "check" in move:
        return "check"
    elif "fold" in move:
        return "fold"
    else:
        return move


def extract_amount(move):
    """Extract the bet amount from a raise move."""
    if pd.isna(move):
        return np.nan
    try:
        return float(move.split("raise")[1].strip())
    except:
        return 0


def main(model_name):
    # Load the CSV file based on the model name
    file_path = f'/Users/weber/Github/COMP0258/testing-results/{model_name}_predictions.csv'
    if file_path is None:
        print(f"File not found: {file_path}")
        return
    else:
        df = pd.read_csv(file_path)

    # Clean predictions
    df['Prediction_Clean'] = df['Prediction'].apply(clean_prediction)
    df['Ground_Truth_Clean'] = df['Ground Truth']

    # Extract action and amount
    df['Pred_Action'] = df['Prediction_Clean'].apply(extract_action)
    df['True_Action'] = df['Ground_Truth_Clean'].apply(extract_action)
    df['Pred_Amount'] = df['Prediction_Clean'].apply(extract_amount)
    df['True_Amount'] = df['Ground_Truth_Clean'].apply(extract_amount)

    # save the cleaned dataframe
    cleaned_file_path = f'/Users/weber/Github/COMP0258/testing-results/{model_name}_cleaned.csv'
    df[['Pred_Action', 'True_Action', 'Pred_Amount', 'True_Amount']].to_csv(
        cleaned_file_path, index=False)

    # Calculate overall accuracy for the action type
    action_accuracy = accuracy_score(df['True_Action'], df['Pred_Action'])
    print(f"Action prediction accuracy: {action_accuracy:.4f}")


    # Confusion matrix for actions
    actions = ['fold', 'check', 'call', 'raise']
    action_df = df[df['True_Action'].isin(
        actions) & df['Pred_Action'].isin(actions)]
    cm = confusion_matrix(
        action_df['True_Action'], action_df['Pred_Action'], labels=actions)
    cm_df = pd.DataFrame(cm, index=actions, columns=actions)

    # Calculate per-action accuracy
    action_specific_accuracy = {}
    for action in actions:
        action_rows = df[df['True_Action'] == action]
        if len(action_rows) > 0:
            correct = sum(action_rows['Pred_Action'] == action)
            accuracy = correct / len(action_rows)
            action_specific_accuracy[action] = accuracy
            print(
                f"Accuracy for {action}: {accuracy:.4f} ({correct}/{len(action_rows)})")

    # Plot confusion matrix
    plt.figure(figsize=(20, 8))
    # adding title for the confusion matrix
    plt.suptitle(f'{model_name}', fontsize=20)


    plt.subplot(1, 2, 1)
    sns.heatmap(cm_df, annot=True, fmt='d', cmap='Blues')
    plt.title('Confusion Matrix for Poker Actions')
    plt.ylabel('True Action')
    plt.xlabel('Predicted Action')

    # Analyze raise amount accuracy
    raise_df = df[(df['True_Action'] == 'raise') &
                  (df['Pred_Action'] == 'raise')]
    raise_df = raise_df.dropna(subset=['True_Amount', 'Pred_Amount'])

    if len(raise_df) > 0:
        # Calculate RMSE for raise amounts
        rmse = np.sqrt(
            ((raise_df['True_Amount'] - raise_df['Pred_Amount']) ** 2).mean())
        print(f"RMSE for raise amounts: {rmse:.4f}")

        # Calculate NRMSE for raise amounts
        true_amount_range = raise_df['True_Amount'].max(
        ) - raise_df['True_Amount'].min()
        nrmse = rmse / true_amount_range
        print(f"NRMSE for raise amounts: {nrmse:.4f}")

        # Plot raise amount comparison
        plt.subplot(1, 2, 2)
        plt.scatter(raise_df['True_Amount'],
                    raise_df['Pred_Amount'], alpha=0.6)

        # Add perfect prediction line
        max_val = max(raise_df['True_Amount'].max(),
                      raise_df['Pred_Amount'].max())
        min_val = min(raise_df['True_Amount'].min(),
                      raise_df['Pred_Amount'].min())
        plt.plot([min_val, max_val], [min_val, max_val], 'r--')

        plt.title('Raise Amount: Predicted vs True')
        plt.xlabel('True Raise Amount')
        plt.ylabel('Predicted Raise Amount')
        plt.grid(True, alpha=0.3)

    # Save the combined image
    combined_image_path = f'/Users/weber/Github/COMP0258/testing-results/{model_name}_confusion-matrix.png'

    # Add accuracy information to the plot
    accuracy_text = f"Action accuracy: {action_accuracy:.2f}"

    # Add per-action accuracy to the text
    for action in actions:
        if action in action_specific_accuracy:
            accuracy_text += f" | {action.capitalize()} accuracy: {action_specific_accuracy[action]:.2f}"

    # Add RMSE to the text if available
    if len(raise_df) > 0:
        accuracy_text += f" | RMSE for raise amounts: {rmse:.2f}"
    plt.figtext(0.5, 0.01, accuracy_text, ha="center", fontsize=12)

    plt.savefig(combined_image_path)
    plt.close()

    


if __name__ == "__main__":
    # parser = argparse.ArgumentParser(
    #     description='Evaluate poker move predictions.')
    # parser.add_argument('model_name', type=str,
    #                     help='The name of the model to evaluate (e.g., gemma-2-9b-it, lora-gemma-2-9b-it)')
    # args = parser.parse_args()
    # main(args.model_name)
    models = ['gemma-2-9b-it', 'lora-gemma-2-9b-it',
              'Qwen2.5-7B-Instruct-1M', 'lora-Qwen2.5-7B-Instruct-1M',
              'Meta-Llama-3.1-8B-Instruct', 'lora-Meta-Llama-3.1-8B-Instruct', 'lora-Meta-Llama-3.1-8B-Instruct-lr-5',
              'Meta-Llama-3-8B-Instruct', 'lora-Meta-Llama-3-8B-Instruct-lr-5*6', 'lora-Meta-Llama-3-8B-Instruct-lr-6',
              'lora-Llama-3.2-3B-Instruct-lr-5', 'lora-Llama-3.2-3B-Instruct-lr-5-postflop', 'lora-Llama-3.2-3B-Instruct-lr-5-total',
              'Llama-3.2-3B-Instruct', 'Llama-3.2-3B-Instruct-postflop',
              'lora_llama_3.1_8B_model-1600', 'lora_llama_3.1_8B_model-2400', 'lora_llama_3.1_8B_model-3200', 'lora_llama_3.1_8B_model-4000',
              'lora_llama_3.2-3B_model-500', 'lora_llama_3.2-3B_model-1000',
              'lora-Meta-Llama-3.1-8B-Instruct-total', 'Meta-Llama-3.1-8B-Instruct-total']
    for model in models:
        main(model)

Action prediction accuracy: 0.4120
Accuracy for fold: 0.4440 (111/250)
Accuracy for check: 0.0000 (0/250)
Accuracy for call: 0.8880 (222/250)
Accuracy for raise: 0.3160 (79/250)
RMSE for raise amounts: 5.4267
NRMSE for raise amounts: 0.4174
Action prediction accuracy: 0.5240
Accuracy for fold: 0.5960 (149/250)
Accuracy for check: 0.1880 (47/250)
Accuracy for call: 0.7680 (192/250)
Accuracy for raise: 0.5440 (136/250)
RMSE for raise amounts: 59.0208
NRMSE for raise amounts: 3.2789
Action prediction accuracy: 0.2510
Accuracy for fold: 0.0000 (0/250)
Accuracy for check: 0.0000 (0/250)
Accuracy for call: 1.0000 (250/250)
Accuracy for raise: 0.0040 (1/250)
RMSE for raise amounts: 2.0000
NRMSE for raise amounts: inf
Action prediction accuracy: 0.4750
Accuracy for fold: 0.7640 (191/250)
Accuracy for check: 0.4800 (120/250)
Accuracy for call: 0.4400 (110/250)
Accuracy for raise: 0.2160 (54/250)
RMSE for raise amounts: 26.6451
NRMSE for raise amounts: 1.2057
Action prediction accuracy: 0.2910
A

  nrmse = rmse / true_amount_range


RMSE for raise amounts: 18.5475
NRMSE for raise amounts: 18.5475
Action prediction accuracy: 0.6540
Accuracy for fold: 0.5120 (128/250)
Accuracy for check: 1.0000 (250/250)
Accuracy for call: 0.7840 (196/250)
Accuracy for raise: 0.3200 (80/250)
RMSE for raise amounts: 44.8373
NRMSE for raise amounts: 2.0288
Action prediction accuracy: 0.5930
Accuracy for fold: 0.8800 (220/250)
Accuracy for check: 0.0120 (3/250)
Accuracy for call: 0.6920 (173/250)
Accuracy for raise: 0.7880 (197/250)
RMSE for raise amounts: 106.5213
NRMSE for raise amounts: 3.2279
Action prediction accuracy: 0.2810
Accuracy for fold: 0.1080 (27/250)
Accuracy for check: 0.0000 (0/250)
Accuracy for call: 0.9600 (240/250)
Accuracy for raise: 0.0560 (14/250)
RMSE for raise amounts: 44.4610
NRMSE for raise amounts: 1.4820
Action prediction accuracy: 0.6070
Accuracy for fold: 0.8720 (218/250)
Accuracy for check: 0.0000 (0/250)
Accuracy for call: 0.7640 (191/250)
Accuracy for raise: 0.7920 (198/250)
RMSE for raise amounts: 115