In [4]:
!pip install torch transformers datasets accelerate evaluate trl protobuf sentencepiece huggingface-hub


Collecting datasets
  Downloading datasets-4.4.2-py3-none-any.whl.metadata (19 kB)
Collecting accelerate
  Downloading accelerate-1.12.0-py3-none-any.whl.metadata (19 kB)
Collecting evaluate
  Downloading evaluate-0.4.6-py3-none-any.whl.metadata (9.5 kB)
Collecting trl
  Downloading trl-0.26.2-py3-none-any.whl.metadata (11 kB)
Collecting sentencepiece
  Downloading sentencepiece-0.2.1-cp313-cp313-win_amd64.whl.metadata (10 kB)
Collecting pyarrow>=21.0.0 (from datasets)
  Downloading pyarrow-22.0.0-cp313-cp313-win_amd64.whl.metadata (3.3 kB)
Collecting dill<0.4.1,>=0.3.0 (from datasets)
  Downloading dill-0.4.0-py3-none-any.whl.metadata (10 kB)
Collecting multiprocess<0.70.19 (from datasets)
  Downloading multiprocess-0.70.18-py313-none-any.whl.metadata (7.2 kB)
Downloading datasets-4.4.2-py3-none-any.whl (512 kB)
Downloading dill-0.4.0-py3-none-any.whl (119 kB)
Downloading multiprocess-0.70.18-py313-none-any.whl (151 kB)
Downloading accelerate-1.12.0-py3-none-any.whl (380 kB)
Downloadi

In [5]:
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [6]:
base_model = "google/functiongemma-270m-it"
learning_rate = 5e-5

In [8]:
qdrant_tool_calling = [
    {
        "user_content": (
            "Create a new collection called documents "
            "with 384 dimensional vectors using cosine distance"
        ),
        "tool_name": "create_collection",
        "tool_arguments": {
            "collection_name": "documents",
            "vector_size": 384,
            "distance": "Cosine"
        }
    },
    {
        "user_content": (
            "Add a point with id 1 to the images collection "
            "with vector [0.5, 0.3] and metadata category cat"
        ),
        "tool_name": "upsert_points",
        "tool_arguments": {
            "collection_name": "images",
            "points": [
                {
                    "id": 1,
                    "vector": [0.5, 0.3],
                    "payload": {"category": "cat"}
                }
            ]
        }
    }
]


In [12]:
import json
from datasets import Dataset
from transformers.utils import get_json_schema

# ----------------------------
# Tool function definitions
# ----------------------------

def create_collection(collection_name: str, vector_size: int, distance: str) -> str:
    """
    Creates a new collection in Qdrant.

    Args:
        collection_name (str): Name of the collection to create.
        vector_size (int): Dimensionality of the vectors.
        distance (str): Distance metric (e.g. Cosine, Dot, Euclid).

    Returns:
        str: Status message.
    """
    return "collection_created"


def upsert_points(collection_name: str, points: list) -> str:
    """
    Inserts or updates points in a Qdrant collection.

    Args:
        collection_name (str): Name of the collection.
        points (list): List of points with vectors and payloads.

    Returns:
        str: Status message.
    """
    return "points_upserted"


def get_point(collection_name: str, point_id: int) -> dict:
    """
    Retrieves a point from a Qdrant collection.

    Args:
        collection_name (str): Name of the collection.
        point_id (int): ID of the point.

    Returns:
        dict: Retrieved point data.
    """
    return {"id": point_id}


def search_points(collection_name: str, vector: list, limit: int) -> list:
    """
    Searches for similar vectors in a Qdrant collection.

    Args:
        collection_name (str): Name of the collection.
        vector (list): Query vector.
        limit (int): Maximum number of results.

    Returns:
        list: Search results.
    """
    return []


def scroll_points(collection_name: str, limit: int) -> list:
    """
    Scrolls through points in a Qdrant collection.

    Args:
        collection_name (str): Name of the collection.
        limit (int): Number of points to retrieve.

    Returns:
        list: Retrieved points.
    """
    return []


def delete_points(collection_name: str, point_ids: list) -> str:
    """
    Deletes points from a Qdrant collection.

    Args:
        collection_name (str): Name of the collection.
        point_ids (list): IDs of points to delete.

    Returns:
        str: Status message.
    """
    return "points_deleted"


# ----------------------------
# Generate tool schemas
# ----------------------------

TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "create_collection",
            "description": "Create a new collection in Qdrant",
            "parameters": {
                "type": "object",
                "properties": {
                    "collection_name": {"type": "string"},
                    "vector_size": {"type": "integer"},
                    "distance": {"type": "string"},
                },
                "required": ["collection_name", "vector_size", "distance"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "upsert_points",
            "description": "Insert or update points in a Qdrant collection",
            "parameters": {
                "type": "object",
                "properties": {
                    "collection_name": {"type": "string"},
                    "points": {"type": "array"},
                },
                "required": ["collection_name", "points"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_point",
            "description": "Retrieve a point by ID",
            "parameters": {
                "type": "object",
                "properties": {
                    "collection_name": {"type": "string"},
                    "point_id": {"type": "integer"},
                },
                "required": ["collection_name", "point_id"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "search_points",
            "description": "Search for similar vectors",
            "parameters": {
                "type": "object",
                "properties": {
                    "collection_name": {"type": "string"},
                    "vector": {"type": "array"},
                    "limit": {"type": "integer"},
                },
                "required": ["collection_name", "vector", "limit"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "scroll_points",
            "description": "Scroll points in a collection",
            "parameters": {
                "type": "object",
                "properties": {
                    "collection_name": {"type": "string"},
                    "limit": {"type": "integer"},
                },
                "required": ["collection_name", "limit"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "delete_points",
            "description": "Delete points from a collection",
            "parameters": {
                "type": "object",
                "properties": {
                    "collection_name": {"type": "string"},
                    "point_ids": {"type": "array"},
                },
                "required": ["collection_name", "point_ids"],
            },
        },
    },
]



# ----------------------------
# Example tool-calling data
# ----------------------------

qdrant_tool_calling = [
    {
        "user_content": (
            "Create a new collection called documents with "
            "384 dimensional vectors using cosine distance"
        ),
        "tool_name": "create_collection",
        "tool_arguments": {
            "collection_name": "documents",
            "vector_size": 384,
            "distance": "Cosine",
        },
    },
    {
        "user_content": (
            "Add a point with id 1 to the images collection "
            "with vector [0.5, 0.3] and metadata category cat"
        ),
        "tool_name": "upsert_points",
        "tool_arguments": {
            "collection_name": "images",
            "points": [
                {
                    "id": 1,
                    "vector": [0.5, 0.3],
                    "payload": {"category": "cat"},
                }
            ],
        },
    },
]

# ----------------------------
# Format training data
# ----------------------------

formatted_data = []

for item in qdrant_tool_calling:
    messages = [
        {
            "role": "user",
            "content": item["user_content"],
        },
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "1",
                    "type": "function",
                    "function": {
                        "name": item["tool_name"],
                        "arguments": item["tool_arguments"],
                    },
                }
            ],
        },
    ]

    formatted_data.append({"messages": messages})

# ----------------------------
# Create HuggingFace dataset
# ----------------------------

dataset = Dataset.from_list(formatted_data).train_test_split(test_size=0.2)

# ----------------------------
# Debug / sanity checks
# ----------------------------

print("Tool schemas:")
print(json.dumps(TOOLS, indent=2))

print("\nSample training example:")
print(json.dumps(formatted_data[0], indent=2))

print("\nDataset split:")
print(dataset)


Tool schemas:
[
  {
    "type": "function",
    "function": {
      "name": "create_collection",
      "description": "Create a new collection in Qdrant",
      "parameters": {
        "type": "object",
        "properties": {
          "collection_name": {
            "type": "string"
          },
          "vector_size": {
            "type": "integer"
          },
          "distance": {
            "type": "string"
          }
        },
        "required": [
          "collection_name",
          "vector_size",
          "distance"
        ]
      }
    }
  },
  {
    "type": "function",
    "function": {
      "name": "upsert_points",
      "description": "Insert or update points in a Qdrant collection",
      "parameters": {
        "type": "object",
        "properties": {
          "collection_name": {
            "type": "string"
          },
          "points": {
            "type": "array"
          }
        },
        "required": [
          "collection_name",
          "

In [13]:
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    base_model,
    torch_dtype="auto",
    device_map="auto",
    attn_implementation="eager",
)


config.json:   0%|          | 0.00/1.32k [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
`torch_dtype` is deprecated! Use `dtype` instead!
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/536M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/176 [00:00<?, ?B/s]

In [14]:
tokenizer = AutoTokenizer.from_pretrained(base_model)

print(f"Device: {model.device}")
print(f"DType: {model.dtype}")

print("--- dataset input ---")
print(json.dumps(dataset["train"][0], indent=2))

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/63.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/706 [00:00<?, ?B/s]

chat_template.jinja:   0%|          | 0.00/13.8k [00:00<?, ?B/s]

Device: cpu
DType: torch.bfloat16
--- dataset input ---
{
  "messages": [
    {
      "content": "Create a new collection called documents with 384 dimensional vectors using cosine distance",
      "role": "user",
      "tool_calls": null
    },
    {
      "content": null,
      "role": "assistant",
      "tool_calls": [
        {
          "function": {
            "arguments": {
              "collection_name": "documents",
              "distance": "Cosine",
              "points": null,
              "vector_size": 384
            },
            "name": "create_collection"
          },
          "id": "1",
          "type": "function"
        }
      ]
    }
  ]
}


In [15]:
debug_msg = tokenizer.apply_chat_template(
    dataset["train"][0]["messages"],
    tools=TOOLS,
    add_generation_prompt=False,
    tokenize=False,
)

print(debug_msg)

<bos><start_of_turn>developer
<start_function_declaration>declaration:create_collection{description:<escape>Create a new collection in Qdrant<escape>,parameters:{properties:{collection_name:{description:<escape><escape>,type:<escape>STRING<escape>},distance:{description:<escape><escape>,type:<escape>STRING<escape>},vector_size:{description:<escape><escape>,type:<escape>INTEGER<escape>}},required:[<escape>collection_name<escape>,<escape>vector_size<escape>,<escape>distance<escape>],type:<escape>OBJECT<escape>}}<end_function_declaration><start_function_declaration>declaration:upsert_points{description:<escape>Insert or update points in a Qdrant collection<escape>,parameters:{properties:{collection_name:{description:<escape><escape>,type:<escape>STRING<escape>},points:{description:<escape><escape>,type:<escape>ARRAY<escape>}},required:[<escape>collection_name<escape>,<escape>points<escape>],type:<escape>OBJECT<escape>}}<end_function_declaration><start_function_declaration>declaration:get_

In [16]:
def check_success_rate():
    success_count = 0

    for idx, item in enumerate(dataset["test"]):
        # Use user + assistant ground-truth tool call
        messages = [
            item["messages"][0],
            item["messages"][1],
        ]

        inputs = tokenizer.apply_chat_template(
            messages,
            tools=TOOLS,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        )

        outputs = model.generate(
            **inputs.to(model.device),
            pad_token_id=tokenizer.eos_token_id,
            max_new_tokens=128,
        )

        output_text = tokenizer.decode(
            outputs[0][len(inputs["input_ids"][0]):],
            skip_special_tokens=True,
        )

        try:
            parsed = json.loads(output_text)

            gt_tool = item["messages"][1]["tool_calls"][0]["function"]
            pred_tool = parsed[0]

            if (
                pred_tool["name"] == gt_tool["name"]
                and json.loads(pred_tool["arguments"])
                == json.loads(gt_tool["arguments"])
            ):
                success_count += 1

        except Exception:
            pass

    total = len(dataset["test"])
    print(
        f"Success rate: {success_count}/{total} = "
        f"{(success_count / total) * 100:.2f}%"
    )


In [17]:
from trl import SFTConfig

In [18]:
torch_dtype = model.dtype

In [21]:
from trl import SFTConfig

args = SFTConfig(
    output_dir="./functiongemma-270m-it-qdrant-ft",
    max_length=512,
    packing=False,
    num_train_epochs=8,
    per_device_train_batch_size=1,   # CPU-friendly
    gradient_accumulation_steps=8,    # simulate larger batch
    gradient_checkpointing=False,     # not useful on CPU
    optim="adamw_torch",              # ❗ NOT fused
    learning_rate=learning_rate,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    fp16=False,                       # ❌
    bf16=False,                       # ❌
    logging_steps=10,
    eval_strategy="epoch",
    save_strategy="epoch",
    push_to_hub=False,
    report_to="none",
    use_cpu=True,                     # ✅
)
