This notebook illustrates the agent creation process for the **LLM 20 Questions** competition. Running this notebook produces a `submission.tar.gz` file. You may submit this file directly from the **Submit to competition** heading to the right. Alternatively, from the notebook viewer, click the *Output* tab, then find and download `submission.tar.gz`. Click **Submit Agent** at the upper-left of the competition homepage to upload your file and make your submission.

In [None]:
# Setup the environment
!pip install -q -U immutabledict sentencepiece 

import os
import shutil
import torch
import contextlib

# Remove existing directories if they exist
if os.path.exists('/kaggle/working/gemma_pytorch'):
    shutil.rmtree('/kaggle/working/gemma_pytorch')
if os.path.exists('/kaggle/working/submission/lib/gemma'):
    shutil.rmtree('/kaggle/working/submission/lib/gemma')

# Clone the repository
!git clone https://github.com/google/gemma_pytorch.git > /dev/null

# Create the gemma directory
os.makedirs('/kaggle/working/submission/lib/gemma', exist_ok=True)

# Move the necessary files
shutil.move('/kaggle/working/gemma_pytorch/gemma', '/kaggle/working/submission/lib/gemma')

# Verify the contents of the directory
print("Files in /kaggle/working/submission/lib/gemma:", os.listdir('/kaggle/working/submission/lib/gemma'))

# Set weights directory
weights_dir = '/kaggle/input/gemma/pytorch/7b-it-quant/2/'
print("Files in the weights directory:", os.listdir(weights_dir))

import sys

sys.path.append("/kaggle/working/submission/lib/gemma")

from gemma.config import GemmaConfig, get_config_for_7b, get_config_for_2b
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(torch.float)

# Load the model
VARIANT = "7b-it-quant"
MACHINE_TYPE = "cpu"

model_config = get_config_for_7b() if "7b" in VARIANT else get_config_for_2b()
model_config.tokenizer = os.path.join(weights_dir, "tokenizer.model")
# A "-quant" checkpoint must be loaded into the quantized model definition.
# main.py's GemmaAgent and the later load-check cell set this flag the same way.
model_config.quant = "quant" in VARIANT

# Verify if tokenizer model file exists
if not os.path.isfile(model_config.tokenizer):
    raise FileNotFoundError(f"Tokenizer model not found at {model_config.tokenizer}")

device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
    model = GemmaForCausalLM(model_config)
    ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')
    model.load_weights(ckpt_path)
    model = model.to(device).eval()

# Use the model
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn>\n"

prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt="What is a good place for travel in the US?"
    )
    + MODEL_CHAT_TEMPLATE.format(prompt="California.")
    + USER_CHAT_TEMPLATE.format(prompt="What can I do in California?")
    + "<start_of_turn>model\n"
)

# `prompt` is already a complete chat transcript that ends with an open model turn.
# Wrapping it in USER_CHAT_TEMPLATE again would nest the whole conversation inside
# an extra user turn and close that turn before the model could reply.
model.generate(
    prompt,
    device=device,
    output_len=100,
)

In [None]:
!rm -rf /kaggle/working/gemma_pytorch
!rm -rf /kaggle/working/submission/lib/gemma

In [None]:
%%bash
# Abort on the first failure. Otherwise a failed install or clone would still let
# later cells package an incomplete submission.
set -euo pipefail
cd /kaggle/working
pip install -q -U -t /kaggle/working/submission/lib immutabledict sentencepiece

# Start from a clean slate so this cell can be re-run on its own: `git clone`
# refuses to clone into an existing non-empty directory.
rm -rf /kaggle/working/gemma_pytorch /kaggle/working/submission/lib/gemma

# Re-clone the repository
git clone https://github.com/google/gemma_pytorch.git

# Create the gemma directory
mkdir -p /kaggle/working/submission/lib/gemma

# Move the package's files directly into submission/lib/gemma (the layout main.py imports)
mv /kaggle/working/gemma_pytorch/gemma/* /kaggle/working/submission/lib/gemma/

In [None]:
import os
import shutil

# Remove existing directories if they exist
if os.path.exists('/kaggle/working/gemma_pytorch'):
    shutil.rmtree('/kaggle/working/gemma_pytorch')
if os.path.exists('/kaggle/working/submission/lib/gemma'):
    shutil.rmtree('/kaggle/working/submission/lib/gemma')

# Clone the repository
!git clone https://github.com/google/gemma_pytorch.git > /dev/null

# Create the gemma directory
os.makedirs('/kaggle/working/submission/lib/gemma', exist_ok=True)

# Move the necessary files
shutil.move('/kaggle/working/gemma_pytorch/gemma', '/kaggle/working/submission/lib/gemma')

# Verify the contents of the directory
print("Files in /kaggle/working/submission/lib/gemma:", os.listdir('/kaggle/working/submission/lib/gemma'))

In [None]:
# Probe an alternative Gemma variant (1.1-2b-it) and list its files when it is attached.
# NOTE(review): this rebinds the global `weights_dir` that the first cell set to 7b-it-quant/2.
weights_dir = '/kaggle/input/gemma/pytorch/1.1-2b-it/1/'
if os.path.exists(weights_dir):
    print("Files in the weights directory:", os.listdir(weights_dir))
else:
    print(f"Directory {weights_dir} does not exist.")

In [None]:
import os

# Inspect the top of the Kaggle input mount to find where the Gemma dataset is attached.
input_dir = '/kaggle/input/'
input_entries = os.listdir(input_dir)
print("Directories in /kaggle/input/:", input_entries)

In [None]:
# One level down: which framework folders does the Gemma dataset provide?
gemma_dir = '/kaggle/input/gemma/'
gemma_entries = os.listdir(gemma_dir)
print("Directories in /kaggle/input/gemma/:", gemma_entries)

In [None]:
# Inside the PyTorch folder: which model variants are attached to this notebook?
pytorch_dir = '/kaggle/input/gemma/pytorch/'
pytorch_entries = os.listdir(pytorch_dir)
print("Directories in /kaggle/input/gemma/pytorch/:", pytorch_entries)

In [None]:
# List the contents of the candidate variant directories under /kaggle/input/gemma/pytorch/.
# A variant that is not attached is reported rather than raising FileNotFoundError,
# which keeps Restart & Run All going.
quant_dir = '/kaggle/input/gemma/pytorch/7b-it-quant/'
default_dir = '/kaggle/input/gemma/pytorch/default/'
for variant_dir in (quant_dir, default_dir):
    if os.path.isdir(variant_dir):
        print(f"Files in {variant_dir}:", os.listdir(variant_dir))
    else:
        print(f"Directory {variant_dir} does not exist.")

In [None]:
# The versioned checkpoint folder; main.py's WEIGHTS_PATH points here outside the simulator.
quant_subdir = '/kaggle/input/gemma/pytorch/7b-it-quant/2/'
quant_subdir_files = os.listdir(quant_subdir)
print("Files in /kaggle/input/gemma/pytorch/7b-it-quant/2:", quant_subdir_files)

In [11]:
%%writefile submission/main.py
# Setup
import os
import sys
import contextlib
import itertools
import re
from pathlib import Path
from typing import Iterable

import torch

# **IMPORTANT:** Set up your system path for both notebooks and simulations environment.
KAGGLE_AGENT_PATH = "/kaggle_simulations/agent/"
if os.path.exists(KAGGLE_AGENT_PATH):
    # Inside the simulator: bundled dependencies live in the agent's own lib/ folder.
    LIB_PATH = os.path.join(KAGGLE_AGENT_PATH, 'lib')
else:
    # In the notebook: use the submission directory built by the earlier cells.
    LIB_PATH = "/kaggle/working/submission/lib"
# Front of the path so the bundled packages win over anything preinstalled.
sys.path.insert(0, LIB_PATH)

# Import gemma.config and gemma.model with error handling
try:
    from gemma.config import get_config_for_7b, get_config_for_2b
    from gemma.model import GemmaForCausalLM
except ImportError as e:
    print(f"ImportError: {e}")
    sys.exit("Required modules from gemma are not available. Exiting.")

# Define WEIGHTS_PATH based on the environment
if os.path.exists(KAGGLE_AGENT_PATH):
    WEIGHTS_PATH = os.path.join(KAGGLE_AGENT_PATH, "gemma/pytorch/7b-it-quant/2")
else:
    WEIGHTS_PATH = "/kaggle/input/gemma/pytorch/7b-it-quant/2"

# Prompt Formatting
class GemmaFormatter:
    """Incrementally builds a prompt string in Gemma's chat-turn format.

    Each turn is rendered as ``<start_of_turn>{role}\\n{text}<end_of_turn>\\n`` with
    role ``user`` or ``model``. Mutating methods return ``self`` so calls chain, and
    ``repr(formatter)`` (and therefore ``str(formatter)``) is the accumulated prompt.
    """
    _start_token = '<start_of_turn>'
    _end_token = '<end_of_turn>'

    def __init__(self, system_prompt: str = None, few_shot_examples: Iterable = None):
        """
        Args:
            system_prompt: Optional text emitted as a user turn at the start of every reset.
            few_shot_examples: Optional alternating turns, starting with a user turn,
                replayed after the system prompt on every reset.
        """
        self._system_prompt = system_prompt
        # Materialize once. reset() replays these on every call, and a one-shot
        # iterable (e.g. a generator) would otherwise be empty from the second reset on.
        self._few_shot_examples = tuple(few_shot_examples) if few_shot_examples is not None else None
        self._turn_user = f"{self._start_token}user\n{{}}{self._end_token}\n"
        self._turn_model = f"{self._start_token}model\n{{}}{self._end_token}\n"
        self.reset()

    def __repr__(self):
        return self._state

    def user(self, prompt):
        """Append a complete user turn containing ``prompt``."""
        self._state += self._turn_user.format(prompt)
        return self

    def model(self, prompt):
        """Append a complete model turn containing ``prompt``."""
        self._state += self._turn_model.format(prompt)
        return self

    def start_user_turn(self):
        """Open a user turn without closing it."""
        self._state += f"{self._start_token}user\n"
        return self

    def start_model_turn(self):
        """Open a model turn without closing it (the model continues from here)."""
        self._state += f"{self._start_token}model\n"
        return self

    def end_turn(self):
        """Close the currently open turn."""
        self._state += f"{self._end_token}\n"
        return self

    def reset(self):
        """Clear the prompt, then re-emit the system prompt and few-shot examples."""
        self._state = ""
        if self._system_prompt:
            self.user(self._system_prompt)
        if self._few_shot_examples:
            self.apply_turns(self._few_shot_examples, start_agent='user')
        return self

    def apply_turns(self, turns: Iterable, start_agent: str):
        """Append ``turns`` as alternating turns beginning with ``start_agent``.

        ``start_agent == 'model'`` starts with a model turn; any other value starts
        with a user turn.
        """
        formatters = [self.model, self.user] if start_agent == 'model' else [self.user, self.model]
        formatters = itertools.cycle(formatters)
        for fmt, turn in zip(formatters, turns):
            fmt(turn)
        return self

# Agent Definitions
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Set the default torch dtype to the given dtype."""
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(torch.float)

class GemmaAgent:
    """Base class for a Gemma-backed 20 Questions player.

    Subclasses implement ``_start_session`` (fill ``self.formatter`` with the prompt
    for the current observation) and ``_parse_response`` (turn the raw model text
    into the string returned to the environment).
    """

    # Sampling used when _call_llm receives no sampler overrides. top_k=1 with a
    # tiny temperature keeps generation (near-)deterministic across turns.
    _DEFAULT_SAMPLER_KWARGS = {
        'temperature': 0.01,
        'top_p': 0.1,
        'top_k': 1,
    }

    def __init__(self, variant='7b-it-quant', device='cuda:0', system_prompt=None, few_shot_examples=None):
        """Create the prompt formatter and load the checkpoint from WEIGHTS_PATH.

        Args:
            variant: Checkpoint suffix. A name containing "2b" selects the 2B config
                (otherwise 7B); "quant" enables the quantized model definition.
            device: Torch device string the model is moved to.
            system_prompt: Forwarded to GemmaFormatter.
            few_shot_examples: Forwarded to GemmaFormatter.
        """
        self._variant = variant
        self._device = torch.device(device)
        self.formatter = GemmaFormatter(system_prompt=system_prompt, few_shot_examples=few_shot_examples)

        print("Initializing model")
        model_config = get_config_for_2b() if "2b" in variant else get_config_for_7b()
        model_config.tokenizer = os.path.join(WEIGHTS_PATH, "tokenizer.model")
        model_config.quant = "quant" in variant

        with _set_default_tensor_type(model_config.get_dtype()):
            model = GemmaForCausalLM(model_config)
            ckpt_path = os.path.join(WEIGHTS_PATH, f'gemma-{variant}.ckpt')
            model.load_weights(ckpt_path)
            self.model = model.to(self._device).eval()

    def __call__(self, obs, *args):
        """Return this agent's action for ``obs``; extra positional args are ignored."""
        self._start_session(obs)
        prompt = str(self.formatter)
        response = self._call_llm(prompt)
        response = self._parse_response(response, obs)
        print(f"{response=}")
        return response

    def _start_session(self, obs: dict):
        raise NotImplementedError

    def _call_llm(self, prompt, max_new_tokens=32, **sampler_kwargs):
        """Run ``self.model.generate`` on ``prompt`` and return its output.

        Args:
            prompt: Fully formatted prompt string.
            max_new_tokens: Passed to generate() as ``output_len``.
            **sampler_kwargs: Sampling overrides forwarded to generate(). When none
                are given, ``_DEFAULT_SAMPLER_KWARGS`` is used instead.
        """
        # `**sampler_kwargs` always collects into a dict, empty rather than None
        # when no overrides are passed. An `is None` test is never true, so the
        # defaults would never be applied and generate() would use its own
        # sampling settings.
        if not sampler_kwargs:
            sampler_kwargs = dict(self._DEFAULT_SAMPLER_KWARGS)
        response = self.model.generate(
            prompt,
            device=self._device,
            output_len=max_new_tokens,
            **sampler_kwargs,
        )
        return response

    def _parse_keyword(self, response: str):
        """Return the first ``**bold**`` span of ``response``, lower-cased, or ''."""
        match = re.search(r"(?<=\*\*)([^*]+)(?=\*\*)", response)
        return match.group().lower() if match else ''

    def _parse_response(self, response: str, obs: dict):
        raise NotImplementedError

def interleave_unequal(x, y):
    """Alternate items from ``x`` and ``y``, continuing with whichever is longer.

    None entries are skipped, both the padding added for the shorter sequence and
    any None already present in the inputs.
    """
    merged = []
    for first, second in itertools.zip_longest(x, y):
        for element in (first, second):
            if element is not None:
                merged.append(element)
    return merged

class GemmaQuestionerAgent(GemmaAgent):
    """Questioner role: asks yes-or-no questions and guesses the keyword."""

    def _start_session(self, obs):
        """Rebuild the prompt from the game history in ``obs``.

        Reads ``obs.questions``, ``obs.answers`` and ``obs.turnType``. The history
        is replayed starting with a model turn, since the questions are this agent's own.
        """
        self.formatter.reset()
        self.formatter.user("Let's play 20 Questions. You are playing the role of the Questioner.")
        turns = interleave_unequal(obs.questions, obs.answers)
        self.formatter.apply_turns(turns, start_agent='model')
        if obs.turnType == 'ask':
            self.formatter.user("Please ask a yes-or-no question.")
        elif obs.turnType == 'guess':
            self.formatter.user("Now guess the keyword. Surround your guess with double asterisks.")
        self.formatter.start_model_turn()

    def _parse_response(self, response: str, obs: dict):
        """Extract the question ('ask') or the bolded keyword ('guess') from ``response``.

        Raises:
            ValueError: if ``obs.turnType`` is neither 'ask' nor 'guess'.
        """
        if obs.turnType == 'ask':
            # Raw string: "\?" inside a plain literal is an invalid escape sequence
            # (DeprecationWarning, and SyntaxWarning on Python 3.12+).
            match = re.search(r".+?\?", response.replace('*', ''))
            return match.group() if match else "Is it a person?"
        elif obs.turnType == 'guess':
            return self._parse_keyword(response)
        else:
            raise ValueError(f"Unknown turn type: {obs.turnType}")

class GemmaAnswererAgent(GemmaAgent):
    """Answerer role: replies yes or no about the secret keyword."""

    def _start_session(self, obs):
        """Rebuild the prompt around ``obs.keyword``/``obs.category`` and the Q&A history."""
        fmt = self.formatter
        fmt.reset()
        keyword, category = obs.keyword, obs.category
        fmt.user(f"Let's play 20 Questions. You are playing the role of the Answerer. The keyword is {keyword} in the category {category}.")
        fmt.apply_turns(interleave_unequal(obs.questions, obs.answers), start_agent='user')
        fmt.user(f"The question is about the keyword {keyword} in the category {category}. Give yes-or-no answer and surround your answer with double asterisks, like **yes** or **no**.")
        fmt.start_model_turn()

    def _parse_response(self, response: str, obs: dict):
        """Map the bolded answer to 'yes' if it contains 'yes', otherwise 'no'."""
        bolded = self._parse_keyword(response)
        if 'yes' in bolded:
            return 'yes'
        return 'no'

# Agent Creation
system_prompt = "You are an AI assistant designed to play the 20 Questions game. In this game, the Answerer thinks of a keyword and responds to yes-or-no questions by the Questioner. The keyword is a specific person, place, or thing."

# Alternating turns replayed by GemmaFormatter.reset(), starting with a user turn.
few_shot_examples = [
    "Let's play 20 Questions. You are playing the role of the Questioner. Please ask your first question.",  # user
    "Is it a person?",                  # model
    "**no**",                           # user
    "Is it a place?",                   # model
    "**yes**",                          # user
    "Is it a country?",                 # model
    "**yes** Now guess the keyword.",   # user
    "**France**",                       # model
    "Correct!",                         # user
]

# **IMPORTANT:** Define agent as a global so you only have to load
# the agent you need. Loading both will likely lead to OOM.
agent = None

def get_agent(name: str):
    """Return the process-wide agent, creating it on first use.

    Only one agent is ever built. Once ``agent`` is set, it is returned for any
    ``name`` without re-checking the role. While it is still unset, ``name`` must
    be 'questioner' or 'answerer', otherwise ValueError is raised.
    """
    global agent

    if agent is not None:
        return agent

    if name == 'questioner':
        agent_cls = GemmaQuestionerAgent
    elif name == 'answerer':
        agent_cls = GemmaAnswererAgent
    else:
        raise ValueError("Unknown agent name:", name)

    agent = agent_cls(
        device='cuda:0',
        system_prompt=system_prompt,
        few_shot_examples=few_shot_examples,
    )
    assert agent is not None, "Agent not initialized."
    return agent

def agent_fn(obs, cfg):
    """Kaggle entry point: route ``obs`` to the agent for the current turn type.

    'ask' and 'guess' go to the questioner and 'answer' goes to the answerer.
    ``cfg`` is accepted for the environment's calling convention but not used.
    """
    turn = obs.turnType
    if turn in ("ask", "guess"):
        role = 'questioner'
    elif turn == "answer":
        role = 'answerer'
    else:
        raise ValueError("Unknown turn type:", turn)

    response = get_agent(role)(obs)

    # Never return a missing or single-character action; fall back to "yes".
    if response is None or len(response) <= 1:
        return "yes"
    return response

Overwriting submission/main.py


In [5]:
import sys
import os

# Check the same directory main.py puts first on sys.path (LIB_PATH outside the
# simulator). The cloned gemma_pytorch/ no longer contains the package, because
# its `gemma` folder was moved into the submission.
SUBMISSION_LIB = '/kaggle/working/submission/lib'
if SUBMISSION_LIB not in sys.path:
    sys.path.insert(0, SUBMISSION_LIB)

# An import can be satisfied by a `gemma` module cached from an earlier cell, so
# also confirm the on-disk layout that main.py relies on.
expected_config = os.path.join(SUBMISSION_LIB, 'gemma', 'config.py')
try:
    from gemma.config import get_config_for_7b
except ImportError as e:
    print(f"ImportError: {e}")
else:
    if os.path.isfile(expected_config):
        print("gemma_pytorch is correctly installed and accessible.")
    else:
        print(f"gemma imported, but {expected_config} is missing; the submission layout is wrong.")

gemma_pytorch is correctly installed and accessible.


In [6]:
pip list

Package                                  Version
---------------------------------------- -------------------
absl-py                                  1.4.0
accelerate                               0.29.3
access                                   1.1.9
affine                                   2.4.0
aiobotocore                              2.12.3
aiofiles                                 22.1.0
aiohttp                                  3.9.1
aiohttp-cors                             0.7.0
aioitertools                             0.11.0
aiorwlock                                1.3.0
aiosignal                                1.3.1
aiosqlite                                0.19.0
albumentations                           1.4.0
alembic                                  1.13.1
altair                                   5.3.0
annotated-types                          0.6.0
annoy                                    1.17.3
anyio                                    4.2.0
apache-beam                          

In [7]:
import torch

# main.py's get_agent builds its agents on 'cuda:0', so report GPU visibility
# alongside the version. A "+cpu" build cannot run the agent locally.
print(torch.__version__)
print("CUDA available:", torch.cuda.is_available())

2.1.2+cpu


In [8]:
import os

# The checkpoint folder that main.py's WEIGHTS_PATH uses outside the simulator.
weights_path = '/kaggle/input/gemma/pytorch/7b-it-quant/2'
weight_files = os.listdir(weights_path)
print("Files in weights directory:", weight_files)

Files in weights directory: ['config.json', 'gemma-7b-it-quant.ckpt', 'tokenizer.model']


In [9]:
import os

import torch
from gemma.config import get_config_for_7b
from gemma.model import GemmaForCausalLM

# Path to the model weights
weights_path = '/kaggle/input/gemma/pytorch/7b-it-quant/2'

# Initialize model configuration (quantized 7B, matching the checkpoint file name)
model_config = get_config_for_7b()
model_config.tokenizer = os.path.join(weights_path, "tokenizer.model")
model_config.quant = True

# Build in the config's dtype, as main.py does, instead of torch's float32 default.
# The previous default is restored even if loading fails.
previous_dtype = torch.get_default_dtype()
torch.set_default_dtype(model_config.get_dtype())
try:
    model = GemmaForCausalLM(model_config)
    ckpt_path = os.path.join(weights_path, 'gemma-7b-it-quant.ckpt')
    model.load_weights(ckpt_path)
finally:
    torch.set_default_dtype(previous_dtype)
model = model.eval()

# Reaching this line means construction and load_weights() did not raise
print("Model loaded successfully.")

Model loaded successfully.


In [None]:
!apt install pigz pv > /dev/null

In [None]:
# Bundle the agent into submission.tar.gz:
#   - everything under /kaggle/working/submission (main.py and lib/) at the archive root;
#   - the weights from /kaggle/input, stored under gemma/pytorch/7b-it-quant/2. That is
#     the relative path main.py's WEIGHTS_PATH reads below KAGGLE_AGENT_PATH
#     (presumably where the simulator unpacks the archive — confirm).
# Compression goes through pigz (installed in the previous cell), piped into pv.
!tar --use-compress-program='pigz --fast --recursive | pv' -cf submission.tar.gz -C /kaggle/working/submission . -C /kaggle/input/ gemma/pytorch/7b-it-quant/2