In [1]:
from duckduckgo_search import DDGS
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import pickle
import textwrap
from tqdm import tqdm
import time
import numpy as np
import re
import torch.nn.functional as F
from collections import Counter

In [9]:
ds=load_dataset("squad")

In [49]:
torch.cuda.is_available()

True

In [2]:
model_name = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token_id=tokenizer.eos_token_id
model = AutoModelForCausalLM.from_pretrained(model_name,torch_dtype=torch.bfloat16,device_map="cuda")#attn_implementation="flash_attention_2" -> I can't do this as it requires to be at least on a A100

In [45]:
def make_web_search(query,max_results=1) :
    """
    Perform a web search using the DuckDuckGo Search API.

    Args:
        query (str): The search query string to be used for the web search.

    Returns:
        str: A string containing the body of the first search result, 
              concatenated with a separator " | " if there are multiple results.
    """
    with DDGS() as ddgs :
        results= ddgs.text(query,max_results=max_results)
    return " | ".join([result["body"] for result in results])


def generate(prompt, model, tokenizer, make_api_calls=False):
    """
    Generates a response from the prompt using model. 
    If make_api_calls=True, it will stop whenever
    """
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=True,
    ).to(model.device)
    with torch.no_grad():
      if not make_api_calls:
          outputs = model.generate(**inputs, max_new_tokens=500, tokenizer=tokenizer, do_sample=False,use_cache=True,pad_token_id=tokenizer.eos_token_id)
      else:
          continue_generation=True
          current_inputs=inputs
          while continue_generation:
            stop_strings=["<|eot_id|>","→"]
            outputs = model.generate(**current_inputs, max_new_tokens=500, tokenizer=tokenizer, do_sample=True,use_cache=True,pad_token_id=tokenizer.eos_token_id,stop_strings=stop_strings)
            text=tokenizer.decode(outputs[0])
            if  text[-1]=="→":
              api_call_text=text.split("[[")[-1][:-1] #This should have the form WS(question)
              query_text=api_call_text[3:-1]
              print(query_text)
              answer=make_web_search(query_text)
              text=text.replace(api_call_text+"→",api_call_text+"→"+answer+"]]") #This will be of the form WS(question)→ response1 | response2 | response3...
              current_inputs = tokenizer(
                            text,
                            return_tensors="pt",
                            add_special_tokens=True,
                        ).to(model.device)
            else :
              continue_generation=False


    result = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(result, skip_special_tokens=True)


def format(system_prompt,user_prompt) :
    """Format the LLM prompt in the appropriat format"""
    return f"""
<|start_header_id|>system<|end_header_id|>

{system_prompt}<|eot_id|>
<|start_header_id|>user<|end_header_id|>

{user_prompt}<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
"""

In [52]:
# Test web search
make_web_search("What is the capital of Italy?")

'Rome is the capital city of Italy and the centre of the Metropolitan City of Rome Capital. It is also the cradle of Western civilization and Western Christian culture, and the seat of the Catholic Church and several UN agencies.'

In [53]:
DATASET_GENERATION_PROMPT="""Your task is to add Web Search API calls to a
piece of text. The web searches should help you get
information required to complete the text. You can call the
API by writing " [[WS(query)]] " where "query" is the
question you want to ask. You should only add in the middle of the text some API calls but should not modify in any way the text.
Here are some examples of API calls:

INPUT:
'The ownership of the Spectre organisation—originally stylised "SPECTRE" as an acronym of SPecial Executive for Counter-intelligence, Terrorism, Revenge and Extortion—and its characters, had been at the centre of long-standing litigation starting in 1961 between Ian Fleming and Kevin McClory over the film rights to the novel Thunderball.'

OUTPUT:
'The ownership of the Spectre organisation—originally stylised "SPECTRE" as an acronym of [[WS(What is the acronym for the SPECTRE organization?)]] SPecial Executive for Counter-intelligence, Terrorism, Revenge and Extortion—and its characters, had been at the centre of long-standing litigation starting in 1961 between [[WS(Who was involved in the 1961 litigation about spectre ownership?)]] Ian Fleming and Kevin McClory over the film rights to the novel Thunderball.'

INPUT:
'In December 2015, West released a song titled "Facts". He announced in January 2016 on Twitter that SWISH would be released on February 11, after releasing new song "Real Friends" and a snippet of "No More Parties in L.A." with Kendrick Lamar.'

OUTPUT:
In December 2015, West released a song titled [[WS(What song did West release in October 2015?)]] "Facts". He announced in January 2016 on Twitter that SWISH would be released on February 11, after releasing new song "Real Friends" and a snippet of "No More Parties in L.A." with [[WS(Who did west sing "Real Friends" and "No More Parties in L.A." with?)]] Kendrick Lamar.'

INPUT:
The capital of France is Paris.

OUTPUT:
The capital of France is [[WS(What is the capital of France?)]] Paris.
"""

In [10]:
train_ds_size=400
# We use the wikipedia context as our training set and make sur to not have any duplicates
train_texts=[]
already_seen=set()
for text in ds["train"]["context"] :
    if len(train_texts)==train_ds_size :
        break
    if text not in already_seen :
        train_texts.append(text)
        already_seen.add(text)

In [55]:
# Investigate answers
answer=generate(format(DATASET_GENERATION_PROMPT,f"INPUT:\n{train_texts[0]}\n\nOUTPUT:\n"),model,tokenizer,make_api_calls=False)
print(textwrap.fill(train_texts[0],width=100))
print("----------------------------")
print(textwrap.fill(answer,width=100))



Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden
statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper
statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building
is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place
of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary
reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a
direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of
Mary.
----------------------------
Architecturally, the school has a [[WS(What is the Catholic character of the school?)]] Catholic
character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in
front of the Main Building and facing it, is a copper statue of Christ with ar

 Example of bad queries:
 WS(What is the Catholic character of the school?)

 The Web Search query in this case is not self contained and depends on the context making it bad. We hope the model will learn to not do this type of query as they will be filtered out in the next step.

# Step 1: Sampling API Calls

In [56]:
raw_tool_augmented_text=[]
for text  in tqdm(train_texts) :
    raw_tool_augmented_text.append(generate(format(DATASET_GENERATION_PROMPT,f"INPUT:\n{text}\n\nOUTPUT:\n"),model,tokenizer,make_api_calls=False))

100%|██████████| 400/400 [1:50:33<00:00, 16.58s/it]  


In [59]:
with open("raw_tool_augmented_text.pkl","wb") as f :
    pickle.dump(raw_tool_augmented_text,f)

In [7]:
with open("raw_tool_augmented_text.pkl","rb") as f :
    raw_tool_augmented_text=pickle.load(f)

In [5]:
api_call_pattern=r"\[\[.*?\]\]"

In [11]:
# We only keep valid annotations (e.g the annotation should be the same as the initial text if we remove API calls)
valid_raw_tool_augmented_texts=[]
for i in range(len(raw_tool_augmented_text)) :
    cleaned_tool_text=re.sub(api_call_pattern, "", raw_tool_augmented_text[i]).replace("  "," ").replace("<|eot_id|>","")
    # print(cleaned_tool_text==train_text[i])
    # print(cleaned_tool_text)
    # print("--------")
    # print(train_text[i])

    if cleaned_tool_text==train_texts[i] :
        valid_raw_tool_augmented_texts.append((raw_tool_augmented_text[i],train_texts[i]))
print(len(valid_raw_tool_augmented_texts))

204


When using sampling as our decoding strategy we get around 33% of valid annotation whereas we get closer to 50% with greedy decoding. We choose to use greedy decoding at this step.
- Sample -> about 1/3 valid
- Greedy decoding -> about 1/2 valid

In [63]:
valid_raw_tool_augmented_texts[:5]

[('Architecturally, the school has a [[WS(What is the Catholic character of the school?)]] Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
  'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with th

# Step 2: Executing API Calls

In [125]:
valid_parsed_tool_augmented_texts=[]

In [128]:
for raw_tool_augmented_text, train_text in tqdm(valid_raw_tool_augmented_texts[len(valid_parsed_tool_augmented_texts):]) :
    api_calls=re.findall(api_call_pattern,raw_tool_augmented_text)
    for api_call in api_calls :
        query=api_call[5:-3]
        # We need to sleep between each web search otherwise we reach duckduckgo rate limits
        # We also implement retries to not get blocked if receiving a rate limit exception
        num_retry=10
        for i in range(num_retry) :
            try :
                answer=make_web_search(query)
                break
            except Exception as e:
                if i==num_retry-1 :
                    raise(e)
                time.sleep(30)
                continue

        raw_tool_augmented_text=raw_tool_augmented_text.replace(api_call,f"[[WS({query})→{answer}]]")
    valid_parsed_tool_augmented_texts.append((raw_tool_augmented_text,train_text))

0it [00:00, ?it/s]


In [123]:
with open("valid_parsed_tool_augmented_texts.pkl","wb") as f:
    pickle.dump(valid_parsed_tool_augmented_texts,f)

In [127]:
with open("valid_parsed_tool_augmented_texts.pkl","rb") as f:
    valid_parsed_tool_augmented_texts=pickle.load(f)


In [12]:
#Counter of the number of proposed API calls per text
Counter([len(re.findall(api_call_pattern,valid_raw_tool_augmented_text)) for valid_raw_tool_augmented_text,_ in valid_raw_tool_augmented_texts])

Counter({1: 79, 0: 65, 2: 30, 3: 20, 4: 7, 6: 2, 5: 1})

# Step 3: Filtering API calls

In [None]:
# The threshold for the increase of performance by an API call to consider it useful
# We choose it by trial and error to keep good precision (the API calls kept are indeed of good quality) and recall(we do not throw too many API calls) 
tau=0.3
begin_token=tokenizer(" [[")["input_ids"][-1] #4416
end_token=tokenizer("]]")["input_ids"][-1] #5163

def  get_cross_entropy(logits,input_ids) :
    log_probs=F.log_softmax(logits,dim=1)
    return -torch.sum(log_probs.gather(-1,input_ids.unsqueeze(-1)))

def get_logits_input_ids(text) :
    input=tokenizer(text,return_tensors="pt",add_special_tokens=True).to(model.device)
    input_ids=input["input_ids"][0][1:] #We need to shift input_ids by 1 such that it is aligned with the logits. 
    with torch.no_grad() :
        logits=model(**input).logits[0]
    return logits,input_ids

def get_tool_call_to_remove(text_tool,text_no_tool):
    """
    This function determines which tool calls can be removed from the text based on their contribution to
    the overall performance (decrease in cross entropy). 
    
    When an API call is made at index i and the next call is at index j (j=len(text) if there is not another call)),
    we evaluate the gain in cross entropy loss for the segment [i,j] of the text. 
    This approach is different from the one used in the paper as we do not compute the loss on the rest
    of the text but only on the part up to the next API call. This enables to have multiple API calls per text
    in our trainset while the original paper allowed only 1 call per text.   

    We consider the loss with the API call(L1), without the API call(L2), and with the API call where we replace the
    response by ...(L3)
    If L1<min(L2,L3)-tau we consider the API call helpful enough and keep it, otherwise we discard it

    Parameters:
    - text_tool (str): The text containing tool calls.
    - text_no_tool (str): The text without any tool calls.
    
    Returns:
    - tool_call_to_remove_idx (list): A list of indices of tool calls that can be removed.
    """
    logits_no_tool,input_id_no_tool=get_logits_input_ids(text_no_tool)
    logits_tool,input_id_tool=get_logits_input_ids(text_tool)
    text_tool_no_resp=re.sub(r"→.*?]]","→...]]",text_tool)
    logits_tool_no_resp,input_id_tool_no_resp=get_logits_input_ids(text_tool_no_resp)   

    num_calls=(input_id_tool==begin_token).sum()
    tool_call_to_remove_idx=[]
    for i in range(num_calls) :
        next_api_call_start_idx=(input_id_tool==begin_token).nonzero()[1] if i<num_calls-1 else len(input_id_tool)
        api_call_start_idx=(input_id_tool==begin_token).nonzero()[0]
        api_call_end_idx=(input_id_tool==end_token).nonzero()[0]
        api_call_end_idx_no_resp=(input_id_tool_no_resp==end_token).nonzero()[0]
        section_length=next_api_call_start_idx-api_call_end_idx-1
        
        input_id_tool=input_id_tool[api_call_end_idx+1:]
        logits_tool=logits_tool[api_call_end_idx+1:]
        input_id_tool_no_resp=input_id_tool_no_resp[api_call_end_idx_no_resp+1:]
        logits_tool_no_resp=logits_tool_no_resp[api_call_end_idx_no_resp+1:]
        input_id_no_tool=input_id_no_tool[api_call_start_idx:]
        logits_no_tool=logits_no_tool[api_call_start_idx:]
        
        loss_no_tool=get_cross_entropy(logits_no_tool[:section_length],input_id_no_tool[:section_length])
        loss_tool=get_cross_entropy(logits_tool[:section_length],input_id_tool[:section_length])
        loss_tool_no_resp=get_cross_entropy(logits_tool_no_resp[:section_length],input_id_tool_no_resp[:section_length])
        
        if loss_tool>min(loss_no_tool,loss_tool_no_resp)-tau :
            tool_call_to_remove_idx.append(i)
    return tool_call_to_remove_idx


In [142]:
#We remove every tool call that is not "helpful" enough.
final_text_tool=[]
removed_tool_calls=[]
for text_tool,text_no_tool in valid_parsed_tool_augmented_texts:
    tool_calls=re.findall(api_call_pattern,text_tool)
    for tool_call_to_remove_idx in get_tool_call_to_remove(text_tool,text_no_tool):
        removed_tool_calls.append(tool_calls[tool_call_to_remove_idx])
        text_tool=text_tool.replace(tool_calls[tool_call_to_remove_idx],"")
    final_text_tool.append(text_tool)

In [149]:
# Save final_text_tool to a pickle file
with open('final_text_tool.pkl', 'wb') as f:
    pickle.dump(final_text_tool, f)

In [3]:
# Load final_text_tool from the pickle file
with open('final_text_tool.pkl', 'rb') as f:
    final_text_tool = pickle.load(f)

In [6]:
#Counter of the number of proposed API calls per text
Counter([len(re.findall(api_call_pattern,text)) for text in final_text_tool])

Counter({0: 169, 1: 27, 2: 6, 3: 2})

In [7]:
#We only keep examples that contain some API calls.
final_filtered_text_tool=[text for text in final_text_tool if "[[" in text]

In [145]:
for text in final_filtered_text_tool[:5] :
    if "[[" in text :
        print(textwrap.fill(text))
        print("------------")

The Lobund Institute grew out of pioneering research in [[WS(What is
the field of research that Lobund Institute is associated
with?)→Lobund Institute. The Lobund Institute grew out of pioneering
research in germ-free-life which began in 1928. This area of research
originated in a question posed by Pasteur as to whether animal life
was possible without bacteria. ... Lobund was the first research
organization to answer definitively, that such life is possible and
that it can ...]] germ-free-life which began in 1928. This area of
research originated in a question posed by Pasteur as to whether
animal life was possible without bacteria. Though others had taken up
this idea, their research was short lived and inconclusive. Lobund was
the first research organization to answer definitively, that such life
is possible and that it can be prolonged through generations. But the
objective was not merely to answer Pasteur's question but also to
produce the germ free animal as a new tool for biolog

In [147]:
removed_tool_calls[:10]

['[[WS(What is the Catholic character of the school?)→Catholic education is rooted in the conviction that Jesus Christ provides the most comprehensive and compelling example of the realization of full human potential (The Catholic School, 34, 35). In every aspect of programs, life, and activities, Catholic schools should foster a personal relationship with Jesus Christ and communal witness to the ...]]',
 '[[WS(What type of degree did the university first offer in 1854-1855?)→The 1855 diploma of William Gouverneur Morris, who earned an LL.B. degree from Harvard Law School. It was signed by Harvard President James Walker, under whose regime (1853-1860) Harvard constructed its first sciences building, offered its first music course, and hired its first black staffer, boxing instructor A. Molyneaux Hewlett.]]',
 '[[WS(What type of degree was developed in 1924?)→How did the first nursing doctorate offered by Columbia University in 1924 differ from the various practice-based degrees that we

We can see here that our methods seems to work ! 
All the web search that were made indeed make the generation easier. However we can see one issue arise: Sometimes bad question can lead to helpful results and so they will be accepted. On the otherhand some good questions do not lead to useful results and are thus discarded. This problem is more linked to the websearch part which we will not dive further on (the way it is done in the paper is by using RAG with a BM25 sparse embedding on wikipedia dumps). 

# Step 4: Finetuning on the filtered dataset obtained

In [None]:
from transformers import Trainer, TrainingArguments,DataCollatorForLanguageModeling
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
from datasets import Dataset

# Create a dataset from final_filtered_text_tool
dataset = Dataset.from_dict({"text": final_filtered_text_tool})

# 2. Define LoRA (Low Rank Adaptation) configuration
lora_config = LoraConfig(
    r=16,  # rank
    lora_alpha=32,
    target_modules=["gate_proj", "down_proj", "up_proj","q_proj", "v_proj", "k_proj", "o_proj"], #All the linear modules 
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM
)

peft_model=get_peft_model(model,lora_config)
peft_model.print_trainable_parameters()

def tokenize_function(examples):
    tokenized= tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=512,  
        return_tensors=None, 
        return_special_tokens_mask=True
    )
    tokenized["labels"] = tokenized["input_ids"].copy()
    return tokenized


tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=dataset.column_names  # This removes the original text column
)

# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=20,
    per_device_train_batch_size=4,
    learning_rate=2e-4,
    fp16=True,
    weight_decay=0.99,
    lr_scheduler_type="cosine",
    warmup_steps=10,
    save_total_limit=2,
    logging_dir="./logs",  # Directory for storing logs
    save_strategy="epoch",
    logging_strategy="epoch",  # Log every epoch
    # logging_steps=10,
    report_to=["tensorboard"],
    optim="adamw_torch",
    label_names=["input_ids"]
)

# Initialize trainer
trainer = Trainer(
    model=peft_model,  # Using the model defined earlier
    args=training_args,
    train_dataset=tokenized_dataset,
    # data_collator=data_collator
)

# Train the model
trainer.train()

# Save the model
trainer.save_model("./toolformer_finetuned")


trainable params: 11,272,192 || all params: 1,247,086,592 || trainable%: 0.9039


Map:   0%|          | 0/35 [00:00<?, ? examples/s]

Step,Training Loss
9,4.5176
18,1.439
27,1.2288
36,1.0513
45,0.8422
54,0.6079
63,0.3982
72,0.2198
81,0.1126
90,0.0632


In [9]:
trainer.save_model("./toolformer_finetuned")

# Results

In [23]:
trained_model=AutoModelForCausalLM.from_pretrained("./toolformer_finetuned",device_map="cuda")

In [24]:
trained_model.device

device(type='cuda', index=0)

In [47]:
SYSTEM_PROMPT="You are a nice and helpful AI assistant. Your goal is to provide accurate, respectful, and engaging responses to the user's queries. "
USER_PROMPT="What is the capital of France?"
generate(format(SYSTEM_PROMPT,USER_PROMPT),trained_model,tokenizer,make_api_calls=True)

'The capital of France is Paris.'

In [42]:
SYSTEM_PROMPT="You are a nice and helpful AI assistant. Your goal is to provide accurate, respectful, and engaging responses to the user's queries. "
USER_PROMPT="What is the capital of France?"
generate(format(SYSTEM_PROMPT,USER_PROMPT)+"The capital of France is [[WS",trained_model,tokenizer,make_api_calls=True)

'WS(What is the capital of France?)→Paris is a global city of culture, finance, diplomacy, and tourism, with an estimated population of 2 million residents in 2025. It is the centre of the Île-de-France region and has many famous landmarks, museums, and historical districts.]] Paris.'

In [44]:
SYSTEM_PROMPT="You are a nice and helpful AI assistant. Your goal is to provide accurate, respectful, and engaging responses to the user's queries. "
USER_PROMPT="Describe me the history of the second world war?"
generate(format(SYSTEM_PROMPT,USER_PROMPT)+"The second world war starts in [[",trained_model,tokenizer,make_api_calls=True)

" [[WS(What is the start date of the second world war?)→World War II [b] or the Second World War (1 September 1939 - 2 September 1945) was a global conflict between two coalitions: the Allies and the Axis powers. Nearly all of the world's countries participated, with many nations mobilising all resources in pursuit of total war. Tanks and aircraft played major roles, enabling the strategic bombing of cities and delivery of the first and only...]] September 1939, when Nazi Germany, led by Adolf Hitler, invaded Poland. The war spread quickly across Europe, with Germany advancing into France, Belgium, the Netherlands, and into England. In Asia, Japan, under the leadership of Emperor Hirohito, invaded China in 1937 and continued to expand into Southeast Asia and India. The war ended with the surrender of Japan to the Allied Powers in 1945, and the death of Adolf Hitler in 1945."

We see that the model knows how to make queries and how to use them to generate better answers. However we see that we still need to "push" the model to make such WebSearch calls and will not do them on its own. This limitation is most likely a consequence of the very small scale at which our experiments were done (We trained on about 40 examples while the origianl paper used 15.000 of them for each tool).

Better results could also have been obtained by having a more complex websearch part (for instance by doing RAG style retrieval on wikipedia dumps).  