# Initialisation

The following cells are for initialisation (loading the model and dataset) and should always be run at the start of a session.

In [1]:
from google.colab import drive
drive.mount('/content/drive')
cwd = '/content/drive/MyDrive/0258_poker_project_personal'

Mounted at /content/drive


In [None]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
    !pip install --no-deps unsloth

In [None]:
import random
import torch
import csv
from tqdm import tqdm
import pandas as pd
import json

In [None]:
from unsloth import FastLanguageModel

max_seq_length = 8192  # was 4096

model_name = "Meta-Llama-3.2-3B-Instruct"
# model_name = "Meta-Llama-3.1-8B"

model, tokenizer = FastLanguageModel.from_pretrained(
    # model_name = "unsloth/llama-3-8b-Instruct-bnb-4bit",
    # model_name = "unsloth/Meta-Llama-3.1-8B",  # change this for different models
    model_name = "unsloth/Llama-3.2-3B-Instruct",
    max_seq_length = max_seq_length,
    dtype = None,  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
    load_in_4bit = True,  # to reduce memory usage
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

EOS_TOKEN = tokenizer.eos_token

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.3.19: Fast Llama patching. Transformers: 4.50.2.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors:   0%|          | 0.00/2.35G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/234 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/54.7k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/454 [00:00<?, ?B/s]

In [None]:
# add LoRA adapters - we aren't finetuning but it seems like this is still needed for inference

model = FastLanguageModel.get_peft_model(
    model,
    r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

Unsloth 2025.3.19 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.


In [None]:
# load data from https://huggingface.co/datasets/RZ412/PokerBench. format is {'instruction': ..., 'output': ...}.

from datasets import load_dataset

train_ds = load_dataset("RZ412/PokerBench", split="train")

# separate out a validation set to test few-shot prompts on
train_val_split = train_ds.train_test_split(test_size=0.0001)

# read test set as pre-flop and post-flop
dataset_dir = f"{cwd}/datasets"

with open(f'{dataset_dir}/postflop_10k_test_set_prompt_and_label.json', 'r') as f:
  postflop_test_set = json.load(f)

with open(f'{dataset_dir}/preflop_1k_test_set_prompt_and_label.json', 'r') as f:
  preflop_test_set = json.load(f)

README.md:   0%|          | 0.00/5.16k [00:00<?, ?B/s]

(…)lop_500k_train_set_prompt_and_label.json:   0%|          | 0.00/561M [00:00<?, ?B/s]

(…)flop_60k_train_set_prompt_and_label.json:   0%|          | 0.00/59.2M [00:00<?, ?B/s]

(…)tflop_10k_test_set_prompt_and_label.json:   0%|          | 0.00/11.2M [00:00<?, ?B/s]

(…)reflop_1k_test_set_prompt_and_label.json:   0%|          | 0.00/921k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/563200 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/11000 [00:00<?, ? examples/s]

In [None]:
# split datasets by action

fold_examples = train_val_split["train"].filter(lambda example: example['output'].startswith('fold'))
raise_examples = train_val_split["train"].filter(lambda example: example['output'].startswith('raise'))
bet_examples = train_val_split["train"].filter(lambda example: example['output'].startswith('bet'))
check_examples = train_val_split["train"].filter(lambda example: example['output'].startswith('check'))
call_examples = train_val_split["train"].filter(lambda example: example['output'].startswith('call'))

Filter:   0%|          | 0/563143 [00:00<?, ? examples/s]

Filter:   0%|          | 0/563143 [00:00<?, ? examples/s]

Filter:   0%|          | 0/563143 [00:00<?, ? examples/s]

Filter:   0%|          | 0/563143 [00:00<?, ? examples/s]

Filter:   0%|          | 0/563143 [00:00<?, ? examples/s]

# Utility functions for running the model and constructing few-shot prompts

The following cells define utility functions for sampling outputs from the model and constructing a few-shot prompt using examples sampled from the training set.

In [None]:
# functions to run inference on one prompt

from unsloth.chat_templates import get_chat_template

FastLanguageModel.for_inference(model) # Enable native 2x faster inference

tokenizer = get_chat_template(
    tokenizer,
    chat_template = "llama-3.1",
)

def clean_outputs(output):
  # calling generate on the llama model seems to always return a certain format
  # that includes the prompt so we have to extract the output.
  output = output.replace(EOS_TOKEN, "")

  output_only = output.split("<|start_header_id|>assistant<|end_header_id|>")[-1]  # get back prompt response only
  output_only = output_only.split("optimal action is:")[-1]  # try to extract the answered action
  output_only = output_only.strip()
  return output_only


def run_inference_on_one_prompt(prompt):
  messages = [
      {"role": "user", "content": prompt},
  ]
  inputs = tokenizer.apply_chat_template(
      messages,
      tokenize = True,
      add_generation_prompt = True, # Must add for generation
      return_tensors = "pt",
  ).to("cuda")

  outputs = tokenizer.batch_decode(
      model.generate(input_ids = inputs, max_new_tokens = 8, use_cache = True,
                     temperature = 1.5, min_p = 0.1)
  )

  cleaned_outputs = map(clean_outputs, outputs)
  return list(cleaned_outputs)[0]

In [None]:
# prompt composition and formatting, where a fixed ordering is used.

alpaca_prompt = "### Scenario:\n{} {}"

EOS_TOKEN = tokenizer.eos_token

# sample a set of examples to use in all few-shot prompts

fold_example = fold_examples[random.randrange(len(fold_examples))]
raise_example = raise_examples[random.randrange(len(raise_examples))]
bet_example = bet_examples[random.randrange(len(bet_examples))]
check_example = check_examples[random.randrange(len(check_examples))]
call_example = call_examples[random.randrange(len(call_examples))]

examples = [fold_example, raise_example, bet_example, check_example, call_example]

prompt_context = "You are a specialist in playing 6-handed No Limit Texas Holdem. Given a game scenario, you need to make the optimal decision."

for example in examples:
  prompt_context += "\n\n" + alpaca_prompt.format(example['instruction'].split("Here is a game summary:")[1].strip(), example['output']) + EOS_TOKEN


def compose_fewshot_prompt(query):
  return prompt_context + "\n\n" + alpaca_prompt.format(query['instruction'].split("Here is a game summary:")[1].strip(), "")


In [None]:
prompt = compose_fewshot_prompt(query=train_ds[0])

run_inference_on_one_prompt(prompt)

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


'check'

# Evaluation functionality

The following functions are for evaluating the model on a test set. The first (`run_testset_inference_no_csv`) returns the predictions and ground truths as a dataframe, to facilitate integrated evaluation when selecting the most performant prompt ordering. The second (`run_testset_inference`) saves the predictions to a csv so that inference can be resumed if interrupted. Evaluation of results produced by this function should be evaluated using the separate `output_modify.ipynb` notebook.

In [None]:
# function to run inference on a test set and store the results in a dataframe for direct evaluation.

def run_testset_inference_no_csv(test_set):
  results = []
  for i in tqdm(range(len(test_set))):
    query = test_set[i]
    prompt = compose_fewshot_prompt(query)
    pred = run_inference_on_one_prompt(prompt)

    results.append({"Ground Truth": query["output"], "Prediction": pred})

  return pd.DataFrame(results)

In [None]:
# function to run inference on a test set and store the results in a csv for later evaluation.

def run_testset_inference(test_set, division):
  # division = "preflop" or "postflop"

  csv_name = f"{cwd}/fewshot_{model_name}_{division}_predictions.csv"

  with open(csv_name, 'a') as preds_file:
    writer = csv.writer(preds_file)
    writer.writerow(["Ground Truth", "Prediction"])  # name the columns this for compatibility w evaluation code

    for i in tqdm(range(4312, len(test_set))):
      query = test_set[i]
      prompt = compose_fewshot_prompt(query)
      pred = run_inference_on_one_prompt(prompt)

      writer.writerow([query["output"], pred])

In [None]:
# FUNCTIONS FOR DIRECT EVALUATION - adapted from the output_modify notebook

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, accuracy_score
import argparse
import os


def clean_prediction(text):
    """Clean the prediction text by removing <end_of_turn> and <pad> tags."""
    text = str(text)
    text = text.replace("<end_of_turn>\n<pad>", "").strip()
    text = text.replace("<end_of_turn>", "").strip()
    text = text.replace(" chips", "").strip()
    text = text.replace(".", "").strip()

    return text.lower()

def extract_action(move):
    """Extract the action (call, fold, check, raise) from a move."""
    if pd.isna(move):
        return np.nan
    if "raise" in move:
        return "raise"
    elif "call" in move:
        return "call"
    elif "check" in move:
        return "check"
    elif "fold" in move:
        return "fold"
    elif "bet" in move:
        return "bet"
    else:
        return move


def extract_amount(move):
    """Extract the bet amount from a raise or bet move."""
    if pd.isna(move):
        return np.nan
    if "raise" in move:
        try:
            return float(move.split("raise")[1].strip())
        except:
            return 0
    elif "bet" in move:
        try:
            return float(move.split("bet")[1].strip())
        except:
            return 0
    else:
        return 0


def evaluate(df):
    # df: validation set as a dataframe

    # Clean predictions
    df['Prediction_Clean'] = df['Prediction'].apply(clean_prediction)
    df['Ground_Truth_Clean'] = df['Ground Truth']

    # Extract action and amount
    df['Pred_Action'] = df['Prediction_Clean'].apply(extract_action)
    df['True_Action'] = df['Ground_Truth_Clean'].apply(extract_action)
    df['Pred_Amount'] = df['Prediction_Clean'].apply(extract_amount)
    df['True_Amount'] = df['Ground_Truth_Clean'].apply(extract_amount)

    # Calculate overall accuracy for the action type
    action_accuracy = accuracy_score(df['True_Action'], df['Pred_Action'])
    print(f"Action prediction accuracy: {action_accuracy:.4f}")

    # define actions
    actions = ['fold', 'check', 'call', 'raise', 'bet']

    # Calculate per-action accuracy
    action_specific_accuracy = {}
    for action in actions:
        action_rows = df[df['True_Action'] == action]
        if len(action_rows) > 0:
            correct = sum(action_rows['Pred_Action'] == action)
            accuracy = correct / len(action_rows)
            action_specific_accuracy[action] = accuracy
            # print(
            #     f"Accuracy for {action}: {accuracy:.4f} ({correct}/{len(action_rows)})")

    # Analyze raise and bet amount accuracy
    # For raises
    raise_df = df[(df['True_Action'] == 'raise') &
                  (df['Pred_Action'] == 'raise')]
    raise_df = raise_df.dropna(subset=['True_Amount', 'Pred_Amount'])

    # For bets
    bet_df = df[(df['True_Action'] == 'bet') &
                (df['Pred_Action'] == 'bet')]
    bet_df = bet_df.dropna(subset=['True_Amount', 'Pred_Amount'])

    return action_accuracy

# Prompt selection: test all possible example orderings

The following cell tests all possible permutations of the training set examples, using their performance on a separate validation set, and selects the best-performing one.

In [None]:
# select an ordering for the few-shot prompt based on performance on a validation set.

import itertools

# # sample a set of examples to use in all few-shot prompts
# fold_example = fold_examples[random.randrange(len(fold_examples))]
# raise_example = raise_examples[random.randrange(len(raise_examples))]
# bet_example = bet_examples[random.randrange(len(bet_examples))]
# check_example = check_examples[random.randrange(len(check_examples))]
# call_example = call_examples[random.randrange(len(call_examples))]

examples = [fold_example, raise_example, bet_example, check_example, call_example]

best_ordering, best_accuracy = None, -float('inf')

# since testing all 150 permutations is quite slow, we test a random sample of 20 permutations.
# sample_permutations = random.sample(list(itertools.permutations(examples)), 20)

for i, ordering in enumerate(itertools.permutations(examples)):
  print("Running inference with ordering", i)
  prompt_context = "You are a specialist in playing 6-handed No Limit Texas Holdem. Given a game scenario, you need to make the optimal decision."

  for example in ordering:
    prompt_context += "\n\n" + alpaca_prompt.format(example['instruction'].split("Here is a game summary:")[1].strip(), example['output']) + EOS_TOKEN

  results = run_testset_inference_no_csv(train_val_split["test"])
  accuracy = evaluate(results)

  if accuracy > best_accuracy:
    best_ordering = ordering
    best_accuracy = accuracy


Running inference with ordering 0


100%|██████████| 57/57 [00:39<00:00,  1.43it/s]


Action prediction accuracy: 0.3158
Running inference with ordering 1


100%|██████████| 57/57 [00:39<00:00,  1.46it/s]


Action prediction accuracy: 0.3333
Running inference with ordering 2


100%|██████████| 57/57 [00:37<00:00,  1.50it/s]


Action prediction accuracy: 0.3333
Running inference with ordering 3


100%|██████████| 57/57 [00:37<00:00,  1.53it/s]


Action prediction accuracy: 0.4211
Running inference with ordering 4


100%|██████████| 57/57 [00:38<00:00,  1.48it/s]


Action prediction accuracy: 0.4035
Running inference with ordering 5


100%|██████████| 57/57 [00:37<00:00,  1.53it/s]


Action prediction accuracy: 0.4035
Running inference with ordering 6


100%|██████████| 57/57 [00:38<00:00,  1.48it/s]


Action prediction accuracy: 0.4035
Running inference with ordering 7


100%|██████████| 57/57 [00:38<00:00,  1.49it/s]


Action prediction accuracy: 0.2807
Running inference with ordering 8


100%|██████████| 57/57 [00:38<00:00,  1.48it/s]


Action prediction accuracy: 0.3333
Running inference with ordering 9


100%|██████████| 57/57 [00:38<00:00,  1.49it/s]


Action prediction accuracy: 0.4035
Running inference with ordering 10


100%|██████████| 57/57 [00:39<00:00,  1.46it/s]


Action prediction accuracy: 0.3158
Running inference with ordering 11


100%|██████████| 57/57 [00:37<00:00,  1.52it/s]


Action prediction accuracy: 0.3333
Running inference with ordering 12


100%|██████████| 57/57 [00:37<00:00,  1.51it/s]


Action prediction accuracy: 0.4386
Running inference with ordering 13


100%|██████████| 57/57 [00:37<00:00,  1.52it/s]


Action prediction accuracy: 0.4737
Running inference with ordering 14


100%|██████████| 57/57 [00:37<00:00,  1.53it/s]


Action prediction accuracy: 0.3333
Running inference with ordering 15


100%|██████████| 57/57 [00:37<00:00,  1.53it/s]


Action prediction accuracy: 0.4211
Running inference with ordering 16


100%|██████████| 57/57 [00:37<00:00,  1.51it/s]


Action prediction accuracy: 0.3860
Running inference with ordering 17


100%|██████████| 57/57 [00:37<00:00,  1.50it/s]


Action prediction accuracy: 0.4035
Running inference with ordering 18


100%|██████████| 57/57 [00:38<00:00,  1.49it/s]


Action prediction accuracy: 0.2632
Running inference with ordering 19


100%|██████████| 57/57 [00:37<00:00,  1.52it/s]


Action prediction accuracy: 0.3333
Running inference with ordering 20


100%|██████████| 57/57 [00:37<00:00,  1.53it/s]


Action prediction accuracy: 0.2456
Running inference with ordering 21


100%|██████████| 57/57 [00:37<00:00,  1.51it/s]


Action prediction accuracy: 0.3509
Running inference with ordering 22


100%|██████████| 57/57 [00:37<00:00,  1.53it/s]


Action prediction accuracy: 0.3860
Running inference with ordering 23


100%|██████████| 57/57 [00:37<00:00,  1.52it/s]


Action prediction accuracy: 0.3860
Running inference with ordering 24


100%|██████████| 57/57 [00:43<00:00,  1.30it/s]


Action prediction accuracy: 0.2807
Running inference with ordering 25


100%|██████████| 57/57 [00:42<00:00,  1.34it/s]


Action prediction accuracy: 0.2632
Running inference with ordering 26


100%|██████████| 57/57 [00:42<00:00,  1.35it/s]


Action prediction accuracy: 0.3333
Running inference with ordering 27


100%|██████████| 57/57 [00:42<00:00,  1.33it/s]


Action prediction accuracy: 0.3509
Running inference with ordering 28


100%|██████████| 57/57 [00:42<00:00,  1.35it/s]


Action prediction accuracy: 0.3158
Running inference with ordering 29


100%|██████████| 57/57 [00:43<00:00,  1.32it/s]


Action prediction accuracy: 0.3684
Running inference with ordering 30


100%|██████████| 57/57 [00:41<00:00,  1.36it/s]


Action prediction accuracy: 0.4035
Running inference with ordering 31


100%|██████████| 57/57 [00:42<00:00,  1.36it/s]


Action prediction accuracy: 0.3158
Running inference with ordering 32


100%|██████████| 57/57 [00:42<00:00,  1.33it/s]


Action prediction accuracy: 0.2982
Running inference with ordering 33


100%|██████████| 57/57 [00:42<00:00,  1.34it/s]


Action prediction accuracy: 0.2807
Running inference with ordering 34


100%|██████████| 57/57 [00:41<00:00,  1.37it/s]


Action prediction accuracy: 0.3158
Running inference with ordering 35


100%|██████████| 57/57 [00:42<00:00,  1.34it/s]


Action prediction accuracy: 0.3509
Running inference with ordering 36


100%|██████████| 57/57 [00:41<00:00,  1.37it/s]


Action prediction accuracy: 0.3509
Running inference with ordering 37


100%|██████████| 57/57 [00:40<00:00,  1.41it/s]


Action prediction accuracy: 0.3509
Running inference with ordering 38


100%|██████████| 57/57 [00:41<00:00,  1.36it/s]


Action prediction accuracy: 0.3684
Running inference with ordering 39


100%|██████████| 57/57 [00:41<00:00,  1.36it/s]


Action prediction accuracy: 0.3333
Running inference with ordering 40


100%|██████████| 57/57 [00:40<00:00,  1.42it/s]


Action prediction accuracy: 0.4561
Running inference with ordering 41


100%|██████████| 57/57 [00:41<00:00,  1.38it/s]


Action prediction accuracy: 0.3158
Running inference with ordering 42


100%|██████████| 57/57 [00:41<00:00,  1.37it/s]


Action prediction accuracy: 0.3158
Running inference with ordering 43


100%|██████████| 57/57 [00:41<00:00,  1.38it/s]


Action prediction accuracy: 0.4035
Running inference with ordering 44


100%|██████████| 57/57 [00:42<00:00,  1.35it/s]


Action prediction accuracy: 0.1930
Running inference with ordering 45


100%|██████████| 57/57 [00:42<00:00,  1.33it/s]


Action prediction accuracy: 0.2456
Running inference with ordering 46


100%|██████████| 57/57 [00:40<00:00,  1.40it/s]


Action prediction accuracy: 0.4912
Running inference with ordering 47


100%|██████████| 57/57 [00:41<00:00,  1.38it/s]


Action prediction accuracy: 0.4561
Running inference with ordering 48


100%|██████████| 57/57 [00:40<00:00,  1.42it/s]


Action prediction accuracy: 0.3333
Running inference with ordering 49


100%|██████████| 57/57 [00:41<00:00,  1.39it/s]


Action prediction accuracy: 0.4386
Running inference with ordering 50


100%|██████████| 57/57 [00:41<00:00,  1.37it/s]


Action prediction accuracy: 0.2281
Running inference with ordering 51


100%|██████████| 57/57 [00:40<00:00,  1.40it/s]


Action prediction accuracy: 0.2982
Running inference with ordering 52


100%|██████████| 57/57 [00:41<00:00,  1.37it/s]


Action prediction accuracy: 0.2807
Running inference with ordering 53


100%|██████████| 57/57 [00:41<00:00,  1.38it/s]


Action prediction accuracy: 0.4211
Running inference with ordering 54


100%|██████████| 57/57 [00:41<00:00,  1.38it/s]


Action prediction accuracy: 0.1930
Running inference with ordering 55


100%|██████████| 57/57 [00:41<00:00,  1.37it/s]


Action prediction accuracy: 0.2281
Running inference with ordering 56


100%|██████████| 57/57 [00:41<00:00,  1.38it/s]


Action prediction accuracy: 0.3860
Running inference with ordering 57


100%|██████████| 57/57 [00:42<00:00,  1.34it/s]


Action prediction accuracy: 0.3509
Running inference with ordering 58


100%|██████████| 57/57 [00:41<00:00,  1.36it/s]


Action prediction accuracy: 0.2807
Running inference with ordering 59


100%|██████████| 57/57 [00:41<00:00,  1.37it/s]


Action prediction accuracy: 0.3333
Running inference with ordering 60


100%|██████████| 57/57 [00:41<00:00,  1.39it/s]


Action prediction accuracy: 0.3158
Running inference with ordering 61


100%|██████████| 57/57 [00:40<00:00,  1.40it/s]


Action prediction accuracy: 0.2632
Running inference with ordering 62


100%|██████████| 57/57 [00:42<00:00,  1.35it/s]


Action prediction accuracy: 0.3158
Running inference with ordering 63


100%|██████████| 57/57 [00:41<00:00,  1.38it/s]


Action prediction accuracy: 0.3684
Running inference with ordering 64


100%|██████████| 57/57 [00:40<00:00,  1.42it/s]


Action prediction accuracy: 0.4035
Running inference with ordering 65


100%|██████████| 57/57 [00:41<00:00,  1.36it/s]


Action prediction accuracy: 0.3333
Running inference with ordering 66


100%|██████████| 57/57 [00:40<00:00,  1.41it/s]


Action prediction accuracy: 0.2982
Running inference with ordering 67


100%|██████████| 57/57 [00:39<00:00,  1.43it/s]


Action prediction accuracy: 0.2632
Running inference with ordering 68


100%|██████████| 57/57 [00:41<00:00,  1.37it/s]


Action prediction accuracy: 0.3158
Running inference with ordering 69


100%|██████████| 57/57 [00:41<00:00,  1.37it/s]


Action prediction accuracy: 0.2807
Running inference with ordering 70


100%|██████████| 57/57 [00:39<00:00,  1.44it/s]


Action prediction accuracy: 0.4211
Running inference with ordering 71


100%|██████████| 57/57 [00:41<00:00,  1.38it/s]


Action prediction accuracy: 0.3333
Running inference with ordering 72


100%|██████████| 57/57 [00:38<00:00,  1.49it/s]


Action prediction accuracy: 0.4912
Running inference with ordering 73


100%|██████████| 57/57 [00:38<00:00,  1.48it/s]


Action prediction accuracy: 0.4912
Running inference with ordering 74


100%|██████████| 57/57 [00:38<00:00,  1.48it/s]


Action prediction accuracy: 0.5263
Running inference with ordering 75


100%|██████████| 57/57 [00:38<00:00,  1.50it/s]


Action prediction accuracy: 0.4386
Running inference with ordering 76


100%|██████████| 57/57 [00:37<00:00,  1.51it/s]


Action prediction accuracy: 0.4737
Running inference with ordering 77


100%|██████████| 57/57 [00:37<00:00,  1.51it/s]


Action prediction accuracy: 0.4737
Running inference with ordering 78


100%|██████████| 57/57 [00:38<00:00,  1.48it/s]


Action prediction accuracy: 0.4035
Running inference with ordering 79


100%|██████████| 57/57 [00:38<00:00,  1.50it/s]


Action prediction accuracy: 0.4386
Running inference with ordering 80


100%|██████████| 57/57 [00:39<00:00,  1.46it/s]


Action prediction accuracy: 0.4737
Running inference with ordering 81


100%|██████████| 57/57 [00:38<00:00,  1.50it/s]


Action prediction accuracy: 0.5088
Running inference with ordering 82


100%|██████████| 57/57 [00:37<00:00,  1.52it/s]


Action prediction accuracy: 0.4561
Running inference with ordering 83


100%|██████████| 57/57 [00:39<00:00,  1.46it/s]


Action prediction accuracy: 0.3333
Running inference with ordering 84


100%|██████████| 57/57 [00:38<00:00,  1.47it/s]


Action prediction accuracy: 0.3860
Running inference with ordering 85


100%|██████████| 57/57 [00:38<00:00,  1.49it/s]


Action prediction accuracy: 0.4737
Running inference with ordering 86


100%|██████████| 57/57 [00:38<00:00,  1.49it/s]


Action prediction accuracy: 0.3860
Running inference with ordering 87


100%|██████████| 57/57 [00:38<00:00,  1.49it/s]


Action prediction accuracy: 0.5088
Running inference with ordering 88


100%|██████████| 57/57 [00:38<00:00,  1.49it/s]


Action prediction accuracy: 0.5614
Running inference with ordering 89


100%|██████████| 57/57 [00:38<00:00,  1.47it/s]


Action prediction accuracy: 0.4737
Running inference with ordering 90


100%|██████████| 57/57 [00:37<00:00,  1.51it/s]


Action prediction accuracy: 0.5614
Running inference with ordering 91


100%|██████████| 57/57 [00:37<00:00,  1.52it/s]


Action prediction accuracy: 0.5263
Running inference with ordering 92


100%|██████████| 57/57 [00:37<00:00,  1.53it/s]


Action prediction accuracy: 0.4386
Running inference with ordering 93


100%|██████████| 57/57 [00:38<00:00,  1.50it/s]


Action prediction accuracy: 0.3509
Running inference with ordering 94


100%|██████████| 57/57 [00:37<00:00,  1.51it/s]


Action prediction accuracy: 0.4737
Running inference with ordering 95


100%|██████████| 57/57 [00:38<00:00,  1.49it/s]


Action prediction accuracy: 0.3860
Running inference with ordering 96


100%|██████████| 57/57 [00:39<00:00,  1.45it/s]


Action prediction accuracy: 0.3860
Running inference with ordering 97


100%|██████████| 57/57 [00:38<00:00,  1.48it/s]


Action prediction accuracy: 0.4386
Running inference with ordering 98


100%|██████████| 57/57 [00:39<00:00,  1.46it/s]


Action prediction accuracy: 0.3158
Running inference with ordering 99


100%|██████████| 57/57 [00:41<00:00,  1.37it/s]


Action prediction accuracy: 0.2807
Running inference with ordering 100


100%|██████████| 57/57 [00:39<00:00,  1.45it/s]


Action prediction accuracy: 0.3333
Running inference with ordering 101


100%|██████████| 57/57 [00:40<00:00,  1.42it/s]


Action prediction accuracy: 0.3509
Running inference with ordering 102


100%|██████████| 57/57 [00:38<00:00,  1.48it/s]


Action prediction accuracy: 0.4912
Running inference with ordering 103


100%|██████████| 57/57 [00:37<00:00,  1.51it/s]


Action prediction accuracy: 0.4737
Running inference with ordering 104


100%|██████████| 57/57 [00:38<00:00,  1.49it/s]


Action prediction accuracy: 0.3333
Running inference with ordering 105


100%|██████████| 57/57 [00:37<00:00,  1.50it/s]


Action prediction accuracy: 0.3333
Running inference with ordering 106


100%|██████████| 57/57 [00:36<00:00,  1.54it/s]


Action prediction accuracy: 0.5263
Running inference with ordering 107


100%|██████████| 57/57 [00:38<00:00,  1.48it/s]


Action prediction accuracy: 0.3860
Running inference with ordering 108


100%|██████████| 57/57 [00:38<00:00,  1.49it/s]


Action prediction accuracy: 0.2807
Running inference with ordering 109


100%|██████████| 57/57 [00:39<00:00,  1.46it/s]


Action prediction accuracy: 0.3509
Running inference with ordering 110


100%|██████████| 57/57 [00:37<00:00,  1.52it/s]


Action prediction accuracy: 0.2982
Running inference with ordering 111


100%|██████████| 57/57 [00:39<00:00,  1.44it/s]


Action prediction accuracy: 0.4211
Running inference with ordering 112


100%|██████████| 57/57 [00:38<00:00,  1.48it/s]


Action prediction accuracy: 0.4035
Running inference with ordering 113


100%|██████████| 57/57 [00:38<00:00,  1.48it/s]


Action prediction accuracy: 0.3860
Running inference with ordering 114


100%|██████████| 57/57 [00:37<00:00,  1.51it/s]


Action prediction accuracy: 0.4912
Running inference with ordering 115


100%|██████████| 57/57 [00:36<00:00,  1.54it/s]


Action prediction accuracy: 0.4561
Running inference with ordering 116


100%|██████████| 57/57 [00:37<00:00,  1.51it/s]


Action prediction accuracy: 0.4211
Running inference with ordering 117


100%|██████████| 57/57 [00:38<00:00,  1.48it/s]


Action prediction accuracy: 0.2807
Running inference with ordering 118


100%|██████████| 57/57 [00:37<00:00,  1.53it/s]


Action prediction accuracy: 0.3509
Running inference with ordering 119


100%|██████████| 57/57 [00:37<00:00,  1.52it/s]

Action prediction accuracy: 0.2456





The following three cells are for setting the prompt context to an ordering that was empirically found to be the best. If the preceding cell was previously run, then run the cell that saves the prompt; otherwise, run the cell that loads a previously selected prompt. The third cell then sets the `prompt_context` variable to the desired prompt.

In [None]:
# save the found prompt

with open(os.path.join(cwd, f"best_fewshot_prompt_{model_name}.json"), "w") as f:
  json.dump(best_ordering, f)

In [10]:
# alternatively, load a previously-saved prompt

with open(os.path.join(cwd, f"best_fewshot_prompt_{model_name}.json"), "r") as f:
  # print(f.readline())
  best_ordering = json.load(f)

In [11]:
# set the prompting context to the one that was found to perform best.

prompt_context = "You are a specialist in playing 6-handed No Limit Texas Holdem. Given a game scenario, you need to make the optimal decision."

for example in best_ordering:
  prompt_context += "\n\n" + alpaca_prompt.format(example['instruction'].split("Here is a game summary:")[1].strip(), example['output']) + EOS_TOKEN

# Run inference

The following function calls run inference on a given test set (post-flop or pre-flop). Note that the prompt context must have been set elsewhere in the code first, whether it is one with a fixed ordering or a strategically-selected ordering.

In [None]:
run_testset_inference(postflop_test_set, "postflop")

  0%|          | 0/5688 [00:00<?, ?it/s]The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
100%|██████████| 5688/5688 [1:02:30<00:00,  1.52it/s]


In [None]:
run_testset_inference(preflop_test_set, "preflop")

100%|██████████| 1000/1000 [10:58<00:00,  1.52it/s]
