<a href="https://colab.research.google.com/github/SonicHedghog/CSE-6363-Final-Project/blob/main/Regular_Transformer_Test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
# !pip install --upgrade transformers timm
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

class LocalGemmaChatBot:
    """
    A chatbot that plays Rock, Paper, Scissors using a locally-run Gemma model.

    By default this loads the instruction-tuned Gemma 3 12B checkpoint
    ("google/gemma-3-12b-it") through Hugging Face ``transformers``; a
    different chat model can be selected with the ``model_id`` argument.
    """

    # Game instructions. They are sent as the first user turn, followed by a
    # canned assistant reply, rather than as a separate "system" message.
    SYSTEM_PROMPT = """You are a rock-paper-scissors game.

Please do the following:
1. First, check if the user wants to exit the game (they might say "exit", "quit", "stop", "bye", "goodbye", or similar)
- If they want to exit, respond with exactly: "EXIT_GAME_TOKEN"
2. Validate the user's choice (rock, paper, or scissors).
- If the user enters an invalid choice (not rock/paper/scissors and not wanting to exit), respond with an error message asking them to choose rock, paper, scissors, or exit.
3. If it's a valid game choice, choose your own move randomly (rock, paper, or scissors)
4. Determine who wins based on the rules:
- Rock beats scissors
- Paper beats rock
- Scissors beats paper
- Same choice = tie

For valid game moves, format your response like this:
My choice: [your choice]
Result: [who won and why]
"""

    def __init__(self, model_id="google/gemma-3-12b-it", max_new_tokens=150):
        """
        Initializes the LocalGemmaChatBot.

        This downloads the model (if it is not already cached) and loads it
        into memory.

        Args:
            model_id (str): Hugging Face model ID of the chat model to load.
                Defaults to the instruction-tuned Gemma 3 12B model.
            max_new_tokens (int): Maximum number of tokens generated per reply.
        """
        print("Initializing local Gemma model...")
        print("This may take a while and will download gigabytes of data the first time.")

        self.model_id = model_id
        self.max_new_tokens = max_new_tokens
        self.dtype = torch.bfloat16  # Use bfloat16 for better performance

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)

        # device_map="auto" will automatically use a GPU if it's available
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            torch_dtype=self.dtype,
            device_map="auto",
        )

        # Conversation history, primed with the instructions and a canned reply.
        # Gemma's chat template requires strictly alternating user/assistant
        # turns, so the history must always end on an assistant turn between calls.
        self.chat_history = [
            {"role": "user", "content": self.SYSTEM_PROMPT},
            {"role": "assistant", "content": "I am ready to play Rock, Paper, Scissors! What is your choice?"},
        ]
        print("\nModel ready!")

    def get_local_gemma_response(self, user_prompt):
        """
        Generates a response from the local Gemma model.

        The user's message and the model's reply are added to
        ``self.chat_history`` only when generation succeeds. On failure the
        history is left untouched. Otherwise a failed turn would leave two
        consecutive user messages behind, and every later call would then
        fail the chat template's role-alternation check.

        Args:
            user_prompt (str): The user's input.

        Returns:
            str: The model's response with surrounding whitespace stripped,
            or a fallback message if generation failed.
        """
        user_turn = {"role": "user", "content": user_prompt}
        # Build the candidate conversation without mutating the stored history yet.
        messages = self.chat_history + [user_turn]

        try:
            # Render the role/content list into the single prompt string the model expects.
            prompt_for_model = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )

            # The rendered template is expected to already contain the BOS token,
            # so special tokens are not added a second time here.
            inputs = self.tokenizer.encode(
                prompt_for_model, add_special_tokens=False, return_tensors="pt"
            ).to(self.model.device)

            outputs = self.model.generate(
                input_ids=inputs,
                max_new_tokens=self.max_new_tokens,  # Limit the length of the response
            )

            # generate() returns prompt + continuation; decode only the new tokens.
            response_text = self.tokenizer.decode(
                outputs[0][inputs.shape[-1]:], skip_special_tokens=True
            ).strip()
        except Exception as e:
            # Best-effort: report the problem and let the game loop continue.
            print(f"Error getting local Gemma response: {e}")
            return "I'm having trouble thinking. Let's try again!"

        # Commit the completed turn so the next call sees the full conversation.
        self.chat_history.extend(
            [user_turn, {"role": "assistant", "content": response_text}]
        )
        return response_text

    def play(self):
        """
        Starts the Rock, Paper, Scissors game loop.

        The loop ends when the model answers with EXIT_GAME_TOKEN, or when
        the input stream is closed (EOF) or interrupted (Ctrl+C).
        """
        print("\nLet's play Rock, Paper, Scissors with a local Gemma!")
        print("You can say 'exit', 'quit', 'stop', or similar to end the game.\n")

        while True:
            try:
                user_choice = input("Enter your choice (rock, paper, scissors) or say when you want to quit: ")
            except (EOFError, KeyboardInterrupt):
                # A closed input stream or an interrupt is treated as quitting.
                print()
                print("Thanks for playing! Goodbye!")
                break

            # Don't spend a full model generation on blank input; just re-prompt.
            if not user_choice.strip():
                continue

            response = self.get_local_gemma_response(user_choice)

            # Check if Gemma returned the exit token
            if "EXIT_GAME_TOKEN" in response:
                print("Thanks for playing! Goodbye!")
                break

            print("\n" + response)
            print("-" * 50)

if __name__ == "__main__":
    # The model runs entirely on this machine, so no API key is required.
    # Constructing the bot loads the model; play() then runs the game loop.
    bot = LocalGemmaChatBot()
    bot.play()

Initializing local Gemma model...
This may take a while and will download gigabytes of data the first time.


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]


Model ready!

Let's play Rock, Paper, Scissors with a local Gemma!
You can say 'exit', 'quit', 'stop', or similar to end the game.

Enter your choice (rock, paper, scissors) or say when you want to quit: quit
Thanks for playing! Goodbye!
