#  Data Preparation: Converting the Policy Tree to a Dataset

In [5]:
# Install the libraries used by the cells below into the running kernel's environment.
# As the cell output notes, a kernel restart may be needed before updated packages are picked up.
%pip install --quiet transformers accelerate bitsandbytes sentence-transformers 
%pip install -U --quiet outlines  peft  trl

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [10]:
import json
from huggingface_hub import hf_hub_download
from datasets import load_dataset, Dataset, concatenate_datasets

# --- Load your optimal policy tree from the Hub ---
repo_id = "shahriar7/optimal-20-question-game-policy-tree"
# hf_hub_download returns the local path of the downloaded file.
downloaded_file_path = hf_hub_download(
    repo_id=repo_id,
    filename="optimal-20-question-game-policy-tree.json",
    repo_type="dataset"
)

# JSON text is UTF-8; pass the encoding explicitly so the load does not depend
# on the platform's default locale encoding.
with open(downloaded_file_path, 'r', encoding="utf-8") as f:
    policy_tree = json.load(f)

print("Successfully loaded the policy tree.")

# --- Define System Prompts for both tasks ---
# These exact strings are also hard-coded in Agent.ask_question / Agent.make_guess
# inside evaluate_20Q.py. Keep the copies in sync so the prompts used at inference
# time match the prompts the training rows are built with.
SYSTEM_PROMPT_QUESTION = "You are a 20 Questions player. Your task is to ask the most informative next yes/no question based on the history. The keyword is a simple single-word noun. Start broad and narrow down systematically."
SYSTEM_PROMPT_GUESS = "You are a 20 Questions player. Based on the history, your task is to make the best guess for the secret word. Respond with only the single-word noun."

# --- Modified function to create datasets for both questions and guesses ---
def traverse_and_create_datasets(node, history, question_rows, guess_rows, max_depth, current_depth=0):
    """Walk the policy tree depth-first and emit SFT rows for both tasks.

    Args:
        node: Current tree node (a dict). Keys read here: 'question',
            'answer_yes', 'answer_no', and 'guess' on the child nodes.
        history: List of "Q: ...\\nA: ...\\n" strings for the turns played so far.
        question_rows: Output list; receives one {"prompt", "completion"} row
            per visited node that carries a question.
        guess_rows: Output list; receives one row per child that has a
            non-empty 'guess'.
        max_depth: Nodes at ``current_depth >= max_depth`` are not visited.
        current_depth: Depth of ``node``; the root call uses the default 0.

    Returns:
        None. Both output lists are extended in place.
    """
    # Nothing to emit past the depth limit or for a node without a question.
    if current_depth >= max_depth or 'question' not in node:
        return

    question = node['question']
    shown_history = ''.join(history) if history else 'No history yet.'

    # Row for the "ask the next question" task.
    question_rows.append({
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT_QUESTION},
            {"role": "user", "content": f"Game history:\n{shown_history}\n\nOnly output a single yes/no question."},
        ],
        "completion": question,
    })

    # 'yes' is handled (and fully recursed into) before 'no', so rows keep a
    # depth-first, yes-first order.
    for branch_key, answer in (('answer_yes', 'yes'), ('answer_no', 'no')):
        if branch_key not in node:
            continue
        child = node[branch_key]
        if 'guess' not in child:
            continue
        best_guess = child['guess']
        if not best_guess:
            # An empty guess ends this branch: no row and no further traversal.
            continue

        extended_history = history + [f"Q: {question}\nA: {answer}\n"]

        # Row for the "make a guess" task after this answer.
        guess_rows.append({
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT_GUESS},
                {"role": "user", "content": f"Game history:\n{''.join(extended_history)}\n\nWhat is your single-word guess?"},
            ],
            "completion": best_guess,
        })

        traverse_and_create_datasets(
            child, extended_history, question_rows, guess_rows, max_depth, current_depth + 1
        )

# --- Start the traversal ---
# Both lists are filled in place by traverse_and_create_datasets.
question_sft_rows = []
guess_sft_rows = []
MAX_TRAVERSAL_DEPTH = 11 # Set your desired depth here
traverse_and_create_datasets(policy_tree, [], question_sft_rows, guess_sft_rows, max_depth=MAX_TRAVERSAL_DEPTH)

# --- Convert lists to Hugging Face Datasets ---
# Each row has a "prompt" (list of chat messages) and a "completion" (string).
question_dataset = Dataset.from_list(question_sft_rows)
guess_dataset = Dataset.from_list(guess_sft_rows)

print(f"Generated {len(question_dataset)} samples for asking questions.")
print(f"Generated {len(guess_dataset)} samples for making guesses.")

# --- Combine both datasets for multi-task fine-tuning ---
# Fixed seed so the shuffled order is reproducible across runs.
combined_dataset = concatenate_datasets([question_dataset, guess_dataset]).shuffle(seed=42)

print(f"\nTotal combined samples: {len(combined_dataset)}.")

# --- Split the final dataset for training and evaluation ---
# 5% of the combined rows are held out for evaluation (same fixed seed).
dataset_splits = combined_dataset.train_test_split(test_size=0.05, seed=42)
train_ds = dataset_splits["train"]
eval_ds = dataset_splits["test"]

print(f"Final Training samples: {len(train_ds)}, Final Evaluation samples: {len(eval_ds)}")

(…)ptimal-20-question-game-policy-tree.json:   0%|          | 0.00/10.2M [00:00<?, ?B/s]

Successfully loaded the policy tree.
Generated 2046 samples for asking questions.
Generated 4092 samples for making guesses.

Total combined samples: 6138.
Final Training samples: 5831, Final Evaluation samples: 307


# Supervised Fine-Tuning (SFT) Setup

In [11]:
import torch
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainerCallback
from trl import SFTTrainer, SFTConfig
from huggingface_hub import login

# --- Login and Configuration ---
import os

# Access tokens must never be committed with the notebook. The token is read from
# the HF_TOKEN environment variable; when it is unset, `login` receives token=None
# and falls back to its own interactive prompt.
# NOTE(review): the token that was previously hard-coded here should be revoked.
login(token=os.environ.get("HF_TOKEN"))

model_name = "mistralai/Mistral-7B-Instruct-v0.3"

# --- QLoRA Configuration (4-bit quantization) ---
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# --- LoRA Configuration ---
# Only the attention projections are adapted; the commented-out list also
# includes the MLP projections.
peft_config = LoraConfig(
    lora_alpha=8,
    lora_dropout=0.05,
    r=16,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]# ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"] 
)

# --- Load Base Model ---
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)
# The KV cache is switched off for training.
model.config.use_cache = False

# --- Load Tokenizer ---
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# The EOS token is reused as the padding token; padding goes on the right.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# --- Formatting Function for the Trainer ---
def formatting_prompts_func(example):
    """Render a batch of prompt/completion pairs into single training strings.

    Args:
        example: Batch mapping with a 'prompt' column (one list of chat
            messages per row) and a 'completion' column (one string per row).

    Returns:
        A list with one string per row: the chat-templated prompt followed
        directly by that row's completion.
    """
    return [
        tokenizer.apply_chat_template(messages, tokenize=False) + example['completion'][idx]
        for idx, messages in enumerate(example['prompt'])
    ]

2025-08-26 09:49:33.040030: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1756201773.063808     113 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1756201773.070810     113 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [12]:
# --- Corrected data formatting to support completion_only_loss ---

def format_for_sft(dataset):
    """
    Return ``dataset`` with its 'prompt' column rendered to plain strings.

    Each 'prompt' entry (a list of chat-message dicts) is passed through the
    tokenizer's chat template without tokenizing and without a generation
    prompt. The 'completion' column is passed through unchanged.
    """
    def _render(conversation):
        # One list of chat messages -> one templated string.
        return tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=False,
        )

    def process_batch(batch):
        rendered_prompts = []
        for conversation in batch['prompt']:
            rendered_prompts.append(_render(conversation))
        return {"prompt": rendered_prompts, "completion": batch["completion"]}

    # Batched map: process_batch receives columns as lists.
    return dataset.map(process_batch, batched=True)

# Apply the formatting to your train and eval datasets
# NOTE(review): both names are rebound to the formatted datasets, so this cell is
# not idempotent. Running it a second time would pass already-rendered string
# prompts back into apply_chat_template; re-create train_ds / eval_ds from the
# split cell before re-running.
train_ds = format_for_sft(train_ds)
eval_ds = format_for_sft(eval_ds)



Map:   0%|          | 0/5831 [00:00<?, ? examples/s]

Map:   0%|          | 0/307 [00:00<?, ? examples/s]

# Training the Model

In [22]:
# ignore the warnings in the output.

import os

import wandb

# The API key must not be committed with the notebook. It is read from the
# WANDB_API_KEY environment variable; with key=None, wandb.login falls back to
# its own credential lookup.
# NOTE(review): the key that was previously hard-coded here should be revoked.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [14]:
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig, get_peft_model, PeftModel
# Hub repository that checkpoints are pushed to (and later resumed from).
hub_model_id = "shahriar7/mistral-7b-20q-SFT-finetuned-on-Optimal-Policy-tree"

# Training configuration: 3 epochs, with evaluation, checkpointing and logging
# once per epoch. Checkpoints are pushed to `hub_model_id`.
training_arguments = SFTConfig(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=1,   # starting guess
    auto_find_batch_size=True,       # enable automatic shrinking on OOM
    gradient_accumulation_steps=1,
    optim="paged_adamw_8bit",  
    learning_rate=2e-4,
    weight_decay=0.001,
    # fp16 mixed precision, matching the float16 compute dtype of the quantization config.
    fp16=True,
    bf16=False,
    max_grad_norm=0.3,
    max_steps=-1,
    warmup_ratio=0.03,
    group_by_length=True,
    lr_scheduler_type="cosine",
    report_to=["tensorboard", "wandb"], 
    eval_strategy="epoch", 
    save_strategy="epoch",      
    save_total_limit=3,          # keep a few recents locally during the run
    logging_strategy="epoch",   
    push_to_hub=True,
    hub_model_id=hub_model_id,
    hub_strategy="checkpoint", # "every_save",
    hub_private_repo=False,
    
    # The datasets carry separate "prompt" and "completion" columns.
    # NOTE(review): presumably completion_only_loss restricts the loss to the
    # completion tokens — confirm against the installed TRL version.
    assistant_only_loss=False,  
    completion_only_loss=True,
    
    max_length=512,            
    packing=False,               
    # dataset_text_field="text", # set this if your JSONL column name isn't "text"
    dataloader_num_workers=0,   
    dataset_num_proc=1,         
)

# optional VRAM saver
# model.gradient_checkpointing_enable()

# 2. Create a custom callback to print the loss
class LossMonitorCallback(TrainerCallback):
    """Trainer callback that echoes the training loss to stdout."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        """
        Print the global step and loss for every log payload that has a "loss" entry.

        Payloads that are None or lack the "loss" key are ignored.
        """
        if logs is None or "loss" not in logs:
            return
        print(f"Step {state.global_step}: Training Loss = {logs['loss']:.4f}")


# Start a W&B run and store the full training configuration with it.
wandb.init(
    project="mistral-20q-finetune", 
    config=training_arguments.to_dict(),
)
run_id = wandb.run.id  # NOTE(review): not referenced again in the visible cells

# If `model` is already a PeftModel (e.g. this cell was executed before), replace
# it with what `unload()` returns before the trainer is given `peft_config`.
# NOTE(review): presumably this avoids wrapping the model with a second adapter — confirm.
if isinstance(model, PeftModel):
    model = model.unload()
    
# The tokenizer is passed as `processing_class`; LoRA is applied through `peft_config`.
trainer = SFTTrainer(
    model=model,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    peft_config=peft_config,
    args=training_arguments,
    processing_class=tokenizer,
    callbacks=[LossMonitorCallback()],
)


Adding EOS to train dataset:   0%|          | 0/5831 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/5831 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/5831 [00:00<?, ? examples/s]

Adding EOS to eval dataset:   0%|          | 0/307 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/307 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/307 [00:00<?, ? examples/s]

In [None]:
# Initial training run using `training_arguments` as configured above. The
# following cells raise num_train_epochs and resume from the last checkpoint.
trainer.train()

In [15]:
# Now let's perform SFT:
from pathlib import Path
from transformers.trainer_utils import get_last_checkpoint
from huggingface_hub import snapshot_download

def prepare_resume(output_dir: str, repo_id: str):
    """Return the path of the checkpoint to resume from, or None if none is found.

    Lookup order:
      1. the checkpoint reported by ``get_last_checkpoint`` for ``output_dir``
         (skipped when ``output_dir`` does not exist yet);
      2. otherwise the checkpoint folders of Hub repo ``repo_id`` are downloaded
         into ``output_dir``; ``last-checkpoint`` wins if present, else the
         ``checkpoint-<step>`` directory with the highest numeric step.

    Args:
        output_dir: Local training output directory.
        repo_id: Hub repository to pull checkpoints from.

    Returns:
        The checkpoint path as a string, or None.
    """
    out_dir = Path(output_dir)

    # 1) try local (works during a single Kaggle session)
    if out_dir.is_dir():
        local = get_last_checkpoint(output_dir)
        if local:
            return local

    # 2) otherwise pull from the Hub (needs "Internet" ON in Kaggle settings)
    # `local_dir_use_symlinks` is deprecated in huggingface_hub and only triggered
    # a warning, so it is no longer passed.
    snapshot_download(
        repo_id=repo_id,
        local_dir=output_dir,
        allow_patterns=["last-checkpoint/*", "checkpoint-*/*"],
    )
    last = out_dir / "last-checkpoint"
    if last.exists():
        return str(last)

    def _step(path):
        # "checkpoint-1234" -> "1234"
        return path.name.rsplit("-", 1)[-1]

    # Ignore entries whose suffix is not a plain step number instead of letting
    # int() raise on them.
    numbered = [
        p for p in out_dir.glob("checkpoint-*")
        if p.is_dir() and _step(p).isdecimal()
    ]
    if not numbered:
        return None
    return str(max(numbered, key=lambda p: int(_step(p))))


print("--- Starting from last checkpoing ---")
# Extend the run to 4 epochs in total, then continue from the newest checkpoint
# (local first, otherwise pulled from the Hub repo).
trainer.args.num_train_epochs=4
resume_path = prepare_resume("./results", hub_model_id)
# `resume_path or None`: when no checkpoint is found, training starts without resuming.
trainer.train(resume_from_checkpoint=resume_path or None)

# Close the W&B run opened earlier.
wandb.finish()

--- Starting from last checkpoing ---


For more details, check out https://huggingface.co/docs/huggingface_hub/main/en/guides/download#download-files-to-local-folder.


Fetching 14 files:   0%|          | 0/14 [00:00<?, ?it/s]

last-checkpoint/scaler.pt:   0%|          | 0.00/988 [00:00<?, ?B/s]

last-checkpoint/scheduler.pt:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

last-checkpoint/rng_state.pth:   0%|          | 0.00/14.2k [00:00<?, ?B/s]

last-checkpoint/optimizer.pt:   0%|          | 0.00/28.0M [00:00<?, ?B/s]

last-checkpoint/adapter_model.safetensor(…):   0%|          | 0.00/54.6M [00:00<?, ?B/s]

chat_template.jinja: 0.00B [00:00, ?B/s]

README.md: 0.00B [00:00, ?B/s]

adapter_config.json:   0%|          | 0.00/896 [00:00<?, ?B/s]

last-checkpoint/tokenizer.model:   0%|          | 0.00/587k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/437 [00:00<?, ?B/s]

last-checkpoint/training_args.bin:   0%|          | 0.00/5.75k [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

trainer_state.json: 0.00B [00:00, ?B/s]

Epoch,Training Loss,Validation Loss
3,1.3116,1.469142
4,0.8079,1.786985


Step 17493: Training Loss = 1.3116
Step 23324: Training Loss = 0.8079


0,1
eval/loss,▁█
eval/mean_token_accuracy,█▁
eval/num_tokens,▁█
eval/runtime,█▁
eval/samples_per_second,▁█
eval/steps_per_second,▁█
train/epoch,▁▁███
train/global_step,▁▁███
train/grad_norm,▁█
train/learning_rate,█▁

0,1
eval/loss,1.78699
eval/mean_token_accuracy,0.68895
eval/num_tokens,2277480.0
eval/runtime,61.3996
eval/samples_per_second,5.0
eval/steps_per_second,0.635
total_flos,1.9479115115200512e+17
train/epoch,4.0
train/global_step,23324.0
train/grad_norm,31.50794


In [None]:
# Now let's perform SFT:
from pathlib import Path
from transformers.trainer_utils import get_last_checkpoint
from huggingface_hub import snapshot_download
# Open a new W&B run for this resumed session (the previous cell ended with wandb.finish()).
wandb.init(
    project="mistral-20q-finetune", 
    config=training_arguments.to_dict(),
)
def prepare_resume(output_dir: str, repo_id: str):
    """Return the path of the checkpoint to resume from, or None if none is found.

    NOTE(review): a function with this name is also defined in an earlier cell;
    this later definition shadows it. Keep a single definition.

    Lookup order:
      1. the checkpoint reported by ``get_last_checkpoint`` for ``output_dir``
         (skipped when ``output_dir`` does not exist yet);
      2. otherwise the checkpoint folders of Hub repo ``repo_id`` are downloaded
         into ``output_dir``; ``last-checkpoint`` wins if present, else the
         ``checkpoint-<step>`` directory with the highest numeric step.

    Args:
        output_dir: Local training output directory.
        repo_id: Hub repository to pull checkpoints from.

    Returns:
        The checkpoint path as a string, or None.
    """
    out_dir = Path(output_dir)

    # 1) try local (works during a single Kaggle session)
    if out_dir.is_dir():
        local = get_last_checkpoint(output_dir)
        if local:
            return local

    # 2) otherwise pull from the Hub (needs "Internet" ON in Kaggle settings)
    # `local_dir_use_symlinks` is deprecated in huggingface_hub and only triggered
    # a warning, so it is no longer passed.
    snapshot_download(
        repo_id=repo_id,
        local_dir=output_dir,
        allow_patterns=["last-checkpoint/*", "checkpoint-*/*"],
    )
    last = out_dir / "last-checkpoint"
    if last.exists():
        return str(last)

    def _step(path):
        # "checkpoint-1234" -> "1234"
        return path.name.rsplit("-", 1)[-1]

    # Ignore entries whose suffix is not a plain step number instead of letting
    # int() raise on them.
    numbered = [
        p for p in out_dir.glob("checkpoint-*")
        if p.is_dir() and _step(p).isdecimal()
    ]
    if not numbered:
        return None
    return str(max(numbered, key=lambda p: int(_step(p))))


print("--- Starting from last checkpoing ---")
# Extend the run to 6 epochs in total, then continue from the newest checkpoint
# (local first, otherwise pulled from the Hub repo).
trainer.args.num_train_epochs=6
resume_path = prepare_resume("./results", hub_model_id)
# `resume_path or None`: when no checkpoint is found, training starts without resuming.
trainer.train(resume_from_checkpoint=resume_path or None)

# Close the W&B run opened at the top of this cell.
wandb.finish()

--- Starting from last checkpoing ---


Epoch,Training Loss,Validation Loss
5,0.805,1.834578


Step 29155: Training Loss = 0.8050


# Evaluation

In [6]:
%%writefile test.py

import random
import torch
from typing import List, Dict
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from huggingface_hub import login
import logging
logging.getLogger("transformers").setLevel(logging.ERROR)

# Candidate secret words; ValidatorModel draws each game's keyword from a copy of this list.
# Note that some entries are hyphenated ("ice-cream", "yo-yo").
GAME_WORDS = [
    "cat", "dog", "cow", "horse", "rabbit", "lion", "bear", "shark", "eagle", "ant",
    "apple", "banana", "orange", "carrot", "bread", "cheese", "pizza", "cookie", "egg", "ice-cream",
    "chair", "table", "sofa", "bed", "lamp", "clock", "mirror", "door", "window", "carpet",
    "car", "bicycle", "bus", "train", "airplane", "boat", "rocket", "helmet", "engine", "wheel",
    "pencil", "pen", "book", "paper", "scissors", "ruler", "eraser", "backpack", "laptop", "phone",
    "ball", "doll", "puzzle", "kite", "yo-yo", "drum", "guitar", "camera", "radio", "television",
    "shirt", "pants", "jacket", "hat", "shoes", "gloves", "umbrella", "watch", "glasses",
    "moon", "sun", "star", "cloud", "rain", "snow", "mountain", "river", "ocean", "island",
    "doctor", "teacher", "chef", "farmer", "artist", "pilot", "police", "firefighter", "singer", "dancer",
    "gold", "silver", "iron", "sand", "water", "oil", "soap", "sugar", "salt", "honey"
]

# System message for the answering model (used by ValidatorModel.validate_question).
SYSTEM_PROMPT = "You are a precise answering engine. Based on the keyword and question, provide a 'Yes' or 'No' answer."
# Few-shot examples placed ahead of the real keyword/question in the user message.
FEW_SHOT_EXAMPLES = """
[EXAMPLE 1]
Keyword: car
Question: Is it a living thing?
Answer: No
[EXAMPLE 2]
Keyword: water
Question: Is it used for cleaning?
Answer: Yes
[EXAMPLE 3]
Keyword: tree
Question: Is it man-made?
Answer: No
"""

class ValidatorModel:
    """Answering side of the 20 Questions game.

    Each instance picks one secret keyword and can then answer yes/no questions
    about it (via a quantized instruct model) and check guesses against it.
    The model, the tokenizer and the pool of unused words are cached on the
    class, so creating a new instance per game reloads nothing.
    """

    # Process-wide cache, filled by the first instantiation.
    _model = None
    _tokenizer = None
    # Words not yet used as a keyword; shared so consecutive games get new words.
    _words_dataset = None

    def __init__(self):
        """Load (or reuse) the cached model and draw this game's secret keyword."""
        import os  # local import: this script's import block does not bring in `os`

        if ValidatorModel._model is None: # Check if the model has already been loaded (by a previous instantiation)
            print("--- Loading model and tokenizer first time ---")
            # The access token is read from the HF_TOKEN environment variable rather
            # than being committed in source. When it is unset, no explicit login is
            # made and whatever credentials are already configured are used.
            # NOTE(review): the token that was previously hard-coded here should be revoked.
            hf_token = os.environ.get("HF_TOKEN")
            if hf_token:
                login(hf_token, add_to_git_credential=False)
            model_id = "mistralai/Mistral-7B-Instruct-v0.3"

            # Load and assign to the CLASS attributes
            ValidatorModel._tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

            bnb_cfg = BitsAndBytesConfig(load_in_4bit=True,
                                         bnb_4bit_compute_dtype=torch.float16,
                                         bnb_4bit_use_double_quant=True,
                                         bnb_4bit_quant_type="nf4")

            ValidatorModel._model = AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map="auto",
                quantization_config=bnb_cfg,
                trust_remote_code=True,
            ).eval()
            print("--- Model and tokenizer loaded and cached. ---")

        # (Re)fill the shared pool on first use and whenever every word has been
        # played, so a long evaluation does not crash on an empty pool.
        if not ValidatorModel._words_dataset:
            ValidatorModel._words_dataset = GAME_WORDS.copy()

        # Assign the cached model/tokenizer to this specific instance
        self.model = ValidatorModel._model
        self.tokenizer = ValidatorModel._tokenizer
        self.words_dataset = ValidatorModel._words_dataset

        self.keyword = random.choice(self.words_dataset)
        print(f"\n--- Validator secret word: {self.keyword} ---")
        self.words_dataset.remove(self.keyword) # to not ask same word twice

    @staticmethod
    def _normalize(text: str) -> str:
        """Lower-case ``text`` and drop every non-alphanumeric character."""
        return "".join(ch for ch in str(text).lower() if ch.isalnum())

    def _ask_ai(self, messages: List[Dict], max_new=8, temp=0.01) -> str:
        """Generate a short reply for ``messages`` and return only the new text.

        Args:
            messages: Chat messages in the tokenizer's chat-template format.
            max_new: Maximum number of new tokens to generate.
            temp: Temperature forwarded to ``generate``.

        Returns:
            The decoded continuation (prompt tokens removed), stripped.
        """
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        # NOTE(review): `temperature` is passed without `do_sample`; whether it has
        # any effect depends on the model's generation config — confirm.
        out = self.model.generate(
            **inputs,
            max_new_tokens=max_new,
            temperature=temp,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.eos_token_id
        )
        # Slice off the prompt tokens so only the generated answer is decoded.
        return self.tokenizer.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()

    def validate_question(self, question: str) -> str:
        """Answer ``question`` about the secret keyword.

        Returns:
            "yes" if the model's reply contains "yes" (case-insensitive),
            otherwise "no".
        """
        user_content = f"{FEW_SHOT_EXAMPLES}\n[FINAL TASK]\nKeyword: {self.keyword}\nQuestion: {question}\nAnswer:"
        messages = [{"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": user_content}]
        model_ans = self._ask_ai(messages)
        return "yes" if "yes" in model_ans.lower() else "no"

    def validate_guess(self, guess: str) -> str:
        """Check ``guess`` against the secret keyword.

        The comparison ignores case and any non-alphanumeric characters, so that
        hyphenated keywords such as "ice-cream" can still be matched by a guess
        whose punctuation has been stripped ("icecream").

        Returns:
            'Yes' on a match, 'No' otherwise (including empty guesses).
        """
        if not guess or not self.keyword:
            return 'No'
        normalized_guess = self._normalize(guess)
        if not normalized_guess:
            return 'No'
        return 'Yes' if normalized_guess == self._normalize(self.keyword) else 'No'

Writing test.py


In [9]:
%%writefile evaluate_20Q.py

import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from test import ValidatorModel
from huggingface_hub import login
import re 

# System prompt for the question-asking model
# NOTE(review): this constant is not referenced anywhere in the visible part of
# this script — Agent.ask_question builds its own system prompt inline.
SYSTEM_PROMPT = "You play 20 Questions. Ask the most informative next yes/no question. The keyword is a simple single-word noun known by most people. Start Broad And divide possibilities in half. Then Narrow Systematically. Do not ask the same question twice"

class Agent:
    """
    Question-asking and guessing side of the 20 Questions game.

    One 4-bit quantized model, loaded from ``model_id``, is used both for asking
    questions and for making guesses. The adapter is enabled for both tasks.

    The prompts built here mirror the training rows exactly: the same system
    prompts, each history turn followed by a newline, and the literal text
    "No history yet." when no turn has been played.
    """

    # Number of extra generations allowed when a question is empty or repeated.
    _MAX_RETRIES = 4
    _REPEAT_NOTICE = "\n\nATTENTION: Your last question was a repeat. You MUST generate a new, unique question."

    def __init__(self, model_id="shahriar7/mistral-7b-20q-SFT-finetuned-on-Optimal-Policy-tree"):
        """Load the quantized model and its tokenizer from ``model_id``."""
        import os  # local import: this script's import block does not bring in `os`

        print("Loading agent model (base + LoRA adapter)...")

        # Authenticate with Hugging Face. The token is read from the HF_TOKEN
        # environment variable rather than being committed in source. When it is
        # unset, no explicit login is made and already-configured credentials are used.
        # NOTE(review): the token that was previously hard-coded here should be revoked.
        hf_token = os.environ.get("HF_TOKEN")
        if hf_token:
            login(token=hf_token, add_to_git_credential=False)

        # 4-bit quantization configuration
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

        # NOTE(review): `model_id` is the fine-tuned repo; presumably transformers
        # resolves the base model from the adapter config and attaches the adapter — confirm.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            quantization_config=bnb_config,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        print("Agent model loaded successfully.")

    @staticmethod
    def _format_history(history):
        """Join history turns the way the training data does: one newline after every turn."""
        return "".join(turn if turn.endswith("\n") else turn + "\n" for turn in history)

    def _generate_response(self, model, messages, temperature=0.7, max_tokens=20):
        """Sample a reply for ``messages`` from ``model`` and return only the new text.

        Args:
            model: Model whose ``generate`` method is called.
            messages: Chat messages in the tokenizer's chat-template format.
            temperature: Sampling temperature.
            max_tokens: Maximum number of new tokens.

        Returns:
            The decoded continuation (prompt tokens removed), stripped.
        """
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.tokenizer(prompt, return_tensors="pt").to(model.device)

        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=0.9,
            pad_token_id=self.tokenizer.eos_token_id
        )

        # Slice off the prompt tokens so only the generated text is decoded.
        response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
        return response.strip()

    def ask_question(self, history):
        """Return the next yes/no question for the given game history.

        Only the first line of the model output is kept. If it is empty or
        already appears in the history, generation is retried up to
        ``_MAX_RETRIES`` times, raising the temperature by 0.5 each time and
        adding a single "no repeats" notice to the prompt. The last attempt is
        returned even if it is still a repeat.

        Args:
            history: List of "Q: ...\\nA: ..." strings, oldest first.
        """
        history_str = self._format_history(history) if history else "No history yet."

        system_prompt = (
            "You are a 20 Questions player. Your task is to ask the most informative next yes/no question based on the history. The keyword is a simple single-word noun. Start broad and narrow down systematically."
        )
        base_user_prompt = f"Game history:\n{history_str}\n\nOnly output a single yes/no question."
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": base_user_prompt}
        ]

        # Text searched for repeats: the turns actually played, not the placeholder.
        asked_so_far = "\n".join(history)

        def _next_question(temp):
            raw = self._generate_response(self.model, messages, temperature=temp, max_tokens=25)
            return raw.split('\n')[0].strip()

        temperature = 0.7
        question = _next_question(temperature)
        for _ in range(self._MAX_RETRIES):
            if question and question not in asked_so_far:
                break
            print("repeated question, highering temperature")
            temperature += 0.5
            # Rebuild from the base prompt so the notice appears once, not once per retry.
            messages[1]["content"] = base_user_prompt + self._REPEAT_NOTICE
            question = _next_question(temperature)

        return question

    def make_guess(self, history):
        """Return a lower-cased, letters-only single-word guess for the given history.

        The adapter is explicitly enabled before generating.
        NOTE(review): earlier documentation of this method said the adapter is
        disabled for guessing, but the code has always enabled it — confirm
        which behaviour is intended.

        Returns:
            The first word of the model output with every non-letter removed,
            or "" when the model produced no text.
        """
        history_str = self._format_history(history)

        system_prompt = (
            "You are a 20 Questions player. Based on the history, your task is to make the best guess for the secret word. Respond with only the single-word noun."
        )
        user_prompt = (
            f"Game history:\n{history_str}\n\nWhat is your single-word guess?"
        )

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]

        self.model.enable_adapters()
        raw_guess = self._generate_response(self.model, messages, temperature=0.1, max_tokens=10)

        # An empty generation would make `[0]` raise; treat it as "no guess".
        words = raw_guess.strip().split()
        if not words:
            return ""
        return re.sub(r"[^a-zA-Z]", "", words[0]).lower()



def play_game(agent, validator):
    """
    Run one game of 20 Questions between ``agent`` and ``validator``.

    Every turn the agent asks a question, the validator answers it, the
    exchange is added to the history, and the agent then makes a guess that
    the validator checks. The game stops at the first correct guess.

    Returns:
        True if a guess is accepted within 20 turns, False otherwise.
    """
    transcript = []
    turn = 0
    while turn < 20:
        turn += 1
        print(f"\n--- Turn {turn} ---")

        # Question from the agent, answer from the validator.
        question = agent.ask_question(transcript)
        print(f"Agent Question: {question}")
        answer = validator.validate_question(question)
        print(f"Validator Answer: {answer}")

        # Record the exchange before guessing so the guess can use it.
        transcript.append(f"Q: {question}\nA: {answer}")

        guess = agent.make_guess(transcript)
        print(f"Agent Guess: {guess}")

        verdict = validator.validate_guess(guess)
        if "yes" not in verdict.lower():
            print("Agent's guess was incorrect.")
            continue

        print(f"Agent guessed correctly! The word was '{guess}'.")
        return True

    print(f"Agent failed to guess the word in 20 turns. The word was '{validator.keyword}'.")
    return False

def run_evaluation(num_games):
    """
    Play ``num_games`` games and print the win count and win rate.

    Args:
        num_games: Number of games to play. Values <= 0 are rejected up front:
            nothing would be played, and the win rate would divide by zero.

    Returns:
        None. Results are printed.
    """
    # Check before constructing the Agent so an invalid request never triggers
    # the expensive model load.
    if num_games <= 0:
        print("No games to play: num_games must be a positive integer.")
        return

    agent = Agent()
    wins = 0

    for i in range(num_games):
        print(f"\n====================\n  Starting Game {i + 1}/{num_games}\n====================")
        # A new validator is created for each game to get a new random word
        validator = ValidatorModel()
        if play_game(agent, validator):
            wins += 1

    print("\n--- Evaluation Finished ---")
    print(f"Final Score: {wins} / {num_games} wins ({wins/num_games:.2%})")


if __name__ == "__main__":
    import os

    parser = argparse.ArgumentParser(description="Evaluate a 20 Questions agent.")
    parser.add_argument(
        "-N",
        "--num_games",
        type=int,
        required=True,
        help="The number of games to play for the evaluation."
    )
    # Parse first so `--help` or a bad argument exits without any login attempt.
    args = parser.parse_args()

    # The access token is read from the HF_TOKEN environment variable rather than
    # being committed in source. When it is unset, no explicit login is made.
    # NOTE(review): the token that was previously hard-coded here should be revoked.
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        login(hf_token, add_to_git_credential=False)

    run_evaluation(args.num_games)

Overwriting evaluate_20Q.py


In [10]:
!python evaluate_20Q.py -N 10

Loading agent model (base + LoRA adapter)...
2025-08-26 18:20:12.321231: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1756232412.347693     239 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1756232412.357011     239 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Loading checkpoint shards: 100%|██████████████████| 3/3 [00:16<00:00,  5.48s/it]
Agent model loaded successfully.

  Starting Game 1/10
--- Loading model and tokenizer first time ---
Loading checkpoint shards: 100%|██████████████████| 3/3 [00:15<00:00,  5.06s/it]
--- Model and tokenizer loaded and cached. ---

--- Validator secret word: eagle ---

--- Turn 1 ---
Agent Question: Is it used 