In [21]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
from deuces import Card, Evaluator

import sys
sys.path.append('../src')
from calc_preflop_rank import evaluate2


data = pd.read_csv('../data/pluribus_parsed.csv')

data.head()

Unnamed: 0,hand_id,flop_cards,turn_card,river_card,players,hole_cards,preflop_actions,flop_actions,turn_actions,river_actions,showdown_actions,winners
0,100000,7d 5h 9d,,,"MrBlue, MrBlonde, MrWhite, MrPink, MrBrown, Pl...","MrBlue: Tc Qc, MrBlonde: 8s 4c, MrWhite: 9c 3d...","MrBlue: posts small blind 50, MrBlonde: posts ...","MrBlue: checks, MrPink: checks","MrBlue: checks, MrPink: checks","MrBlue: bets 230, MrPink: folds, MrBlue: uncal...",,MrBlue: collected 520.0 from pot
1,100001,7s 9c Tc,,,"MrBlonde, MrWhite, MrPink, MrBrown, Pluribus, ...","MrBlonde: Qh 5c, MrWhite: 9h 6h, MrPink: Kc Jh...","MrBlonde: posts small blind 50, MrWhite: posts...","MrWhite: checks, MrPink: bets 235, MrWhite: ca...","MrWhite: checks, MrPink: bets 600, MrWhite: fo...",,,MrPink: collected 940.0 from pot
2,100002,,,,"MrWhite, MrPink, MrBrown, Pluribus, MrBlue, Mr...","MrWhite: Jc 2c, MrPink: 2d Qh, MrBrown: 9d Jh,...","MrWhite: posts small blind 50, MrPink: posts b...",,,,,MrBlonde: collected 250.0 from pot
3,100003,3d 6h 9d,,,"MrPink, MrBrown, Pluribus, MrBlue, MrBlonde, M...","MrPink: 8d 8s, MrBrown: 2h Kc, Pluribus: 4s 9s...","MrPink: posts small blind 50, MrBrown: posts b...","MrPink: bets 170, MrBrown: folds, MrBlue: call...","MrPink: bets 600, MrBlue: folds, MrPink: uncal...",,,MrPink: collected 1015.0 from pot
4,100004,7c Ah Th,,,"MrBrown, Pluribus, MrBlue, MrBlonde, MrWhite, ...","MrBrown: Ts Ac, Pluribus: 2c 5c, MrBlue: 7d 3c...","MrBrown: posts small blind 50, Pluribus: posts...","MrBrown: checks, MrWhite: checks","MrBrown: checks, MrWhite: bets 400, MrBrown: r...","MrBrown: bets 3500, MrWhite: folds, MrBrown: u...",,MrBrown: collected 3500.0 from pot


In [22]:
preflop_features = []
preflop_labels = []

features = []
labels = []

evaluator = Evaluator()

def convert_to_deuces_format(card_str):
    return Card.new(card_str)

for i, row in data.iterrows():
    winners = row['winners']
    
    flop_cards = str(row['flop_cards']).split() if pd.notna(row['flop_cards']) else []
    turn_card = [str(row['turn_card'])] if pd.notna(row['turn_card']) else []
    river_card = [str(row['river_card'])] if pd.notna(row['river_card']) else []
    community_cards = [convert_to_deuces_format(card) for card in flop_cards + turn_card + river_card]
    hole_cards_dict = dict([player.split(': ') for player in row['hole_cards'].split(', ')])
    for player, cards in hole_cards_dict.items():
        player_hole_cards = [convert_to_deuces_format(card) for card in cards.split()]
        print(player_hole_cards)
        if len(player_hole_cards)+ len(community_cards) not in [2, 5, 6, 7]:
            continue
        
        preflop_strength = evaluate2(cards)
        preflop_features.append(preflop_strength)
        preflop_labels.append(1 if player in winners else 0)
        
        if(len(community_cards) == 0):
            continue
        
        hand_strength = evaluator.evaluate(player_hole_cards, community_cards)
        features.append(hand_strength)
        labels.append(1 if player in winners else 0)

[16812055, 67144223]
[4199953, 295429]
[8423187, 147715]
[268446761, 270853]
[16787479, 529159]
[1082379, 2102541]
[67119647, 557831]
[8398611, 1057803]
[134253349, 33564957]
[4204049, 67144223]
[2106637, 134228773]
[134224677, 164099]
[33589533, 98306]
[81922, 67119647]
[8406803, 33564957]
[2131213, 4204049]
[69634, 295429]
[1053707, 134224677]
[4212241, 4199953]
[73730, 134253349]
[266757, 8394515]
[134228773, 67119647]
[69634, 2114829]
[2106637, 33564957]
[16783383, 268471337]
[98306, 557831]
[2114829, 164099]
[2102541, 4212241]
[81922, 268454953]
[266757, 1053707]
[98306, 268471337]
[4204049, 69634]
[134253349, 33573149]
[33560861, 2102541]
[1082379, 268446761]
[268454953, 557831]
[8423187, 139523]
[2131213, 4204049]
[2106637, 16812055]
[33589533, 164099]
[135427, 98306]
[268442665, 147715]
[134224677, 134253349]
[2102541, 295429]
[16787479, 33564957]
[164099, 268471337]
[8398611, 2106637]
[134236965, 4228625]
[69634, 147715]
[67144223, 33589533]
[2114829, 295429]
[134253349, 10659

In [3]:
X = np.array(preflop_features).reshape(-1, 1)
y = np.array(preflop_labels)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
model = LogisticRegression(class_weight='balanced')
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy}')
print(classification_report(y_test, y_pred, zero_division=0))
print(confusion_matrix(y_test, y_pred))

Accuracy: 0.7404444444444445
              precision    recall  f1-score   support

           0       0.92      0.75      0.83     15008
           1       0.35      0.67      0.46      2992

    accuracy                           0.74     18000
   macro avg       0.64      0.71      0.65     18000
weighted avg       0.83      0.74      0.77     18000

[[11310  3698]
 [  974  2018]]


In [4]:
X = np.array(features).reshape(-1, 1)
y = np.array(labels)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
model = LogisticRegression(class_weight='balanced')
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy}')
print(classification_report(y_test, y_pred, zero_division=0))
print(confusion_matrix(y_test, y_pred))

Accuracy: 0.6219169528566968
              precision    recall  f1-score   support

           0       0.90      0.62      0.73      8012
           1       0.25      0.64      0.36      1597

    accuracy                           0.62      9609
   macro avg       0.57      0.63      0.55      9609
weighted avg       0.79      0.62      0.67      9609

[[4959 3053]
 [ 580 1017]]
