# Setup

Please ensure you have imported a Gemini API key from AI Studio.
You can do this directly in the Secrets tab on the left.

After doing so, please run the setup cell below.

In [1]:
!pip install git+https://github.com/huggingface/transformers

Collecting git+https://github.com/huggingface/transformers
  Cloning https://github.com/huggingface/transformers to /tmp/pip-req-build-mcfvshy9
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers /tmp/pip-req-build-mcfvshy9
  Resolved https://github.com/huggingface/transformers to commit b7fc2daf8b3fe783173c270d592073aabfb426cb
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


# Generated Code

In [2]:
!pip install -U "bitsandbytes"
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig




In [3]:
!huggingface-cli login -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


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    A token is already saved on your machine. Run `huggingface-cli whoami` to get more information or `huggingface-cli logout` if you want to log out.
    Setting a new token will erase the existing one.
    To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) y
Token is valid (permission: fineG

In [7]:
import torch
from transformers import AutoTokenizer, Gemma3ForCausalLM, BitsAndBytesConfig

ckpt = "google/gemma-3-4b-it"
tokenizer = AutoTokenizer.from_pretrained(ckpt)

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16 # Or torch.float16
)

model = Gemma3ForCausalLM.from_pretrained(
    ckpt,
    torch_dtype = torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

# Initialization

In [22]:
import json
setup_instructions = """
You are a highly competent and professional assistant with access to a set of specialized tools. Your task is to respond to user requests.
If a user request requires invoking one or more tools, you MUST respond exclusively by putting function calls in the following JSON format:

{ "function_calls": [ {"name": "function_name", "parameters": { ... }}, {"name": "function_name", "parameters": { ... }} ] }

WHILE ANSWERING TO FUNCTION CALLS DO NOT REPLY WITH ANY OTHER TEXT.
Only include valid JSON function calls as defined in the available tools.

If no tools are required simply answer the query in a normal way.
"""
function_definitions_list = [
    {
        "name": "get_current_weather",
        "description": "Get the current weather conditions for a specific location.",
        "parameters": { # Add parameters schema - crucial for the model!
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g., San Francisco, CA"
                }
            },
            "required": ["location"]
        }
    },
    {
        "name": "send_message",
        "description": "Sends a message to a recipient.",
        "parameters": { # Add parameters schema
             "type": "object",
            "properties": {
                "recipient": {
                    "type": "string",
                    "description": "The email address or identifier of the recipient."
                },
                "body": {
                    "type": "string",
                    "description": "The content of the message."
                }
            },
            "required": ["recipient", "body"]
        }
    }
]

function_definitions_json_string = json.dumps(function_definitions_list, indent=2)

user_query = "Can you check the current weather in odisha"

final_user_content = f"""{setup_instructions}

Available Tools:
{function_definitions_json_string}

User Query: {user_query}"""



## Parser

In [11]:
import json
import re

def parse_tool_calls(decoded_output: str) -> list | None:
    """
    Parses the specific '{"function_calls": [...] }' JSON format,
    handling potential markdown fences. Returns a list of tool call dicts
    or None if no valid structure is found.
    """
    cleaned_output = decoded_output.strip()
    match = re.search(r"```json\s*(\{.*?\})\s*```", cleaned_output, re.DOTALL | re.IGNORECASE)
    if match:
        json_string = match.group(1)
    else:
        json_string = cleaned_output.replace("<end_of_turn>", "").strip()

    if not json_string.startswith('{') or not json_string.endswith('}'):
        return None

    try:
        parsed_json = json.loads(json_string)
    except json.JSONDecodeError as e:
        # Keep error logging during development, can be replaced by proper logging later
        print(f"Parser Error: Failed to decode JSON: {e}")
        return None

    if isinstance(parsed_json, dict) and "function_calls" in parsed_json:
        function_calls_list = parsed_json["function_calls"]
        if isinstance(function_calls_list, list):
            validated_calls = []
            for call in function_calls_list:
                if isinstance(call, dict) and "name" in call and "parameters" in call:
                    validated_calls.append(call)
                else:
                    # Invalid item found in the list, treat as failure for strictness
                    print(f"Parser Error: Invalid item structure in function_calls list: {call}")
                    return None # Or decide to skip invalid items and return partial list
            return validated_calls # Return the list of validated calls
        else:
            print("Parser Error: 'function_calls' key did not contain a list.")
            return None
    else:
        # JSON was valid but didn't match the expected top-level structure
        return None

## Parsing function calls

In [17]:

import re

parsed_calls = parse_tool_calls(decoded)

if parsed_calls is not None:
  print("Successfully parsed function calls:")
  for call in parsed_calls:
    print(f" Name: {call['name']}, Parameters: {call['parameters']}")
else:
  print("Failed to parse function calls.")

Successfully parsed function calls:
 Name: send_message, Parameters: {'recipient': 'mom', 'body': 'Anand will be home tonight'}


## Defining functions

In [18]:
def get_current_weather(location: str) -> str:
    """Gets the current weather for a location."""
    print(f"--- TOOL EXECUTING: get_current_weather(location='{location}') ---")
    if "new york" in location.lower():
        return json.dumps({"location": location, "temperature": "95F", "condition": "Sunny"})
    elif "san francisco" in location.lower():
         return json.dumps({"location": location, "temperature": "60F", "condition": "Foggy"})
    else:
        return json.dumps({"location": location, "temperature": "80F", "condition": "Rainy"})

def send_message(recipient: str, body: str) -> str:
    """Sends a message to a recipient."""
    print(f"--- TOOL EXECUTING: send_message(recipient='{recipient}', body='{body}') ---")
    # Simulate success
    return json.dumps({"status": "Message sent successfully", "to": recipient})

In [19]:
available_tools = {
    "get_current_weather": get_current_weather,
    "send_message": send_message,
    # Add other function names and their corresponding Python function objects here
}

## Executing

In [20]:
# --- Executor Logic ---

def execute_tool_calls(parsed_tool_calls: list, tool_registry: dict) -> list:
    """
    Executes a list of tool calls based on the parsed output and a registry.

    Args:
        parsed_tool_calls: The list of dicts from the parser, e.g.,
                           [{"name": "...", "parameters": {...}}, ...].
                           Assumes this is not None (checked before calling).
        tool_registry: A dictionary mapping tool names (str) to callable functions.

    Returns:
        A list of results from executing each tool call. Each result is often
        stored as a dict containing the original call info and the output.
        Returns an empty list if the input list was empty.
    """
    execution_results = []

    if not parsed_tool_calls:
        return execution_results

    for tool_call in parsed_tool_calls:
        function_name = tool_call.get("name")
        parameters = tool_call.get("parameters", {})

        if not function_name:
            print("Executor Error: Tool call missing 'name'. Skipping.")
            execution_results.append({
                "call": tool_call,
                "error": "Missing function name"
            })
            continue

        if function_name not in tool_registry:
            print(f"Executor Error: Tool '{function_name}' not found in registry. Skipping.")
            execution_results.append({
                "call": tool_call,
                "error": f"Function '{function_name}' not registered."
            })
            continue

        function_to_call = tool_registry[function_name]

        try:
            result = function_to_call(**parameters)
            execution_results.append({
                "call": tool_call, # Store the original call for context
                "output": result   # Store the function's return value
            })
            print(f"Executor: Call to {function_name} succeeded. Result: {result}")

        except TypeError as e:
            print(f"Executor Error: TypeError calling {function_name}: {e}. Check parameters.")
            execution_results.append({
                "call": tool_call,
                "error": f"Parameter mismatch for '{function_name}': {e}"
            })
        except Exception as e:
            print(f"Executor Error: Exception during execution of {function_name}: {e}")
            execution_results.append({
                "call": tool_call,
                "error": f"Execution error in '{function_name}': {e}"
            })

    return execution_results

# Testing

In [23]:
MAX_TURNS = 5 # Prevent infinite loops
turn_count = 0


chat_history = [
    {"role": "user", "content": final_user_content}
]
model_inputs = tokenizer.apply_chat_template(
      chat_history,
      add_generation_prompt=True, # Adds the prompt for the model turn, e.g., <start_of_turn>model\n
      tokenize=True,
      return_tensors="pt"
    ).to(model.device)


input_len = model_inputs.shape[1]
with torch.inference_mode():
  generation = model.generate(input_ids = model_inputs, max_new_tokens=150, do_sample=False)
  generation = generation[0][input_len:]
decoded = tokenizer.decode(generation, skip_special_tokens = True)
chat_history.append({"role": "assistant", "content": decoded})


while turn_count < MAX_TURNS:
    turn_count += 1
    print(f"\n--- Agent Turn {turn_count} ---")

    model_inputs = tokenizer.apply_chat_template(
        chat_history,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt"
    ).to(model.device)

    # Clear input length calculation (needed inside loop if history changes)
    # Note: This might need adjustment depending on how padding/truncation is handled
    input_len = model_inputs.shape[1]

    with torch.inference_mode():
        generation = model.generate(input_ids=model_inputs, max_new_tokens=200, do_sample=False)
        # Ensure we decode only the newly generated tokens
        # Be careful with slicing if using padding
        new_generation_ids = generation[0][input_len:]
        decoded = tokenizer.decode(new_generation_ids, skip_special_tokens=True) # Skip special tokens for cleaner parsing/output

    print(f"Model Output (Turn {turn_count}):\n{decoded}")

    # 3. Append model's response to history IMMEDIATELY
    # We store the raw response that *might* contain function calls
    chat_history.append({"role": "assistant", "content": decoded})

    # 4. Parse for tool calls in the *latest* model response
    parsed_calls = parse_tool_calls(decoded)

    if parsed_calls:
        print("\n--- Function Calls Identified ---")
        print(json.dumps(parsed_calls, indent=2))
        print("\n--- Executing Tools ---")
        tool_results = execute_tool_calls(parsed_calls, available_tools)
        print("\n--- Tool Execution Results ---")
        print(json.dumps(tool_results, indent=2))

        # --- CRUCIAL: Feed results back correctly ---
        # This is the part that needs verification/experimentation based
        # on Gemma 3's specific chat template capabilities in transformers.
        # Option A (Conceptual - requires checking actual template support):
        # Try to use role: "tool" if supported. Might need tool_call_id.
        # for i, result in enumerate(tool_results):
        #     # Assuming parser gives unique IDs or we can generate them
        #     tool_call_id = parsed_calls[i].get("id", f"call_{turn_count}_{i}")
        #     content = result.get("output", json.dumps({"error": result.get("error")}))
        #     chat_history.append({
        #         "role": "tool",
        #         "tool_call_id": tool_call_id, # Might be needed
        #         "content": content
        #     })


        # Option C (Simpler Workaround - closer to your code):
        # Directly append the results JSON, maybe with a preamble.
        # Still using 'user' role here as in your example, BUT BE AWARE this might confuse the model.
        # Testing is needed!
        results_content = json.dumps(tool_results, indent=2)
        chat_history.append({"role": "user", "content": f"Tool results:\n{results_content}"})
        # --- End Crucial Section ---


    else:
        print("\n--- No Function Calls Identified: Assuming Final Answer ---")
        final_response_to_user = decoded
        break

if turn_count >= MAX_TURNS:
    print("\n--- Reached Max Turns ---")
    final_response_to_user = chat_history[-1]['content'] if chat_history[-1]['role'] == 'assistant' else "Reached max turns without final answer."

# --- Output the final result ---
print("\n--- Final Response to User ---")
print(final_response_to_user)


--- Agent Turn 1 ---
Model Output (Turn 1):


--- No Function Calls Identified: Assuming Final Answer ---

--- Final Response to User ---

