- [lucidrains toolformer-pytorch](https://github.com/lucidrains/toolformer-pytorch)를 활용한 Toolformer 사용 예제 코드
- 기본으로 설정된 PaLM 대신 Microsoft Phi-3를 사용

In [None]:
# Quantization backends used when loading the model with the 4-bit BitsAndBytesConfig below.
!pip install bitsandbytes
!pip install accelerate
# Clone lucidrains' toolformer-pytorch; later cells expect it at /content/toolformer-pytorch (Colab layout)
# and import it by appending that path to sys.path rather than pip-installing the package.
!git clone https://github.com/lucidrains/toolformer-pytorch.git
# Extra dependencies listed in the cloned repo's tools-requirements.txt.
!pip install -r /content/toolformer-pytorch/tools-requirements.txt
# NOTE(review): presumably a runtime dependency of toolformer_pytorch itself that the clone does not install — confirm.
!pip install x-clip

In [None]:
import torch
import torch.nn as nn
import sys
sys.path.append('/content/toolformer-pytorch')
from toolformer_pytorch import Toolformer
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

In [None]:
# 4-bit quantization settings for loading Phi-3 with bitsandbytes:
# NF4 weight format, float16 compute dtype, nested (double) quantization off.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype="float16",
    bnb_4bit_use_double_quant=False,
)

In [None]:
# Load the Phi-3-mini (4k context, instruct) tokenizer and model from the Hugging Face Hub.
# The model weights are quantized according to `bnb_config`.
tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    quantization_config=bnb_config,
    trust_remote_code=True,
)

# Show the tokenizer's vocabulary size.
print(len(tokenizer))

In [None]:
# Few-shot prompt that teaches the model to insert "[Calendar()]" API calls into text.
# A plain (non-f) string: there is nothing to interpolate, and an f-string would silently
# evaluate any "{...}" added later. "[input]" is a literal placeholder, presumably replaced
# with each input sentence by Toolformer's prompt input tag — confirm against toolformer_pytorch.
prompt = """
Your task is to add calls to a Calendar API to a piece of text.
The API calls should help you get information required to complete the text.
You can call the API by writing "[Calendar()]"
Here are some examples of API calls:
Input: Today is the first Friday of the year.
Output: Today is the first [Calendar()] Friday of the year.
Input: The president of the United States is Joe Biden.
Output: The president of the United States is [Calendar()] Joe Biden.
Input: [input]
Output:
"""

In [None]:
class CustomModel(nn.Module):
    """Adapt a Hugging Face causal LM to a plain ``model(x) -> tensor`` interface.

    Two call forms are supported:

    * ``forward(str | list[str])``: the text is tokenized with ``tokenizer`` and
      softmax probabilities over the vocabulary are returned (original behavior).
    * ``forward(torch.Tensor)``: the tensor is treated as already-tokenized input
      ids and the raw logits are returned. NOTE(review): this is the form that
      lucidrains' Toolformer is assumed to use (token-id tensors in, logits out)
      — confirm against toolformer_pytorch.

    Args:
        base_model: causal LM exposing ``.device`` whose call returns an object
            with a ``.logits`` tensor.
        tokenizer: callable tokenizer used for the text path; it must accept
            ``return_tensors="pt"`` and ``padding=True`` and return a mapping
            with ``"input_ids"`` (and optionally ``"attention_mask"``).
    """

    def __init__(self, base_model, tokenizer):
        super().__init__()
        self.base_model = base_model
        self.tokenizer = tokenizer

    def forward(self, messages):
        """Run the wrapped model without building an autograd graph.

        Args:
            messages: a string or list of strings, or a tensor of token ids.

        Returns:
            Softmax probabilities for text input, raw logits for tensor input.
        """
        device = self.base_model.device

        if isinstance(messages, torch.Tensor):
            # Already tokenized: pass ids straight through and hand back logits.
            input_ids = messages.to(device)
            attention_mask = None
            return_probs = False
        else:
            encoded = self.tokenizer(messages, return_tensors="pt", padding=True)
            input_ids = encoded["input_ids"].to(device)
            # padding=True can introduce pad tokens in a batch; forward the mask so
            # the model does not attend to them (previously it was dropped).
            attention_mask = encoded.get("attention_mask")
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)
            return_probs = True

        # NOTE(review): gradients are disabled, so this wrapper cannot be trained
        # (e.g. with Toolformer's finetune=True) without removing no_grad.
        with torch.no_grad():
            logits = self.base_model(input_ids, attention_mask=attention_mask).logits

        if return_probs:
            return torch.softmax(logits, dim=-1)
        return logits

In [None]:
# Wrap the quantized Phi-3 model and its tokenizer behind the CustomModel interface.
mymodel = CustomModel(base_model=model, tokenizer=tokenizer)

In [None]:
# Smoke test: run the wrapper on a short Korean greeting ("안녕" = "hello") and let the
# notebook display the returned tensor (softmax probabilities for text input).
mymodel("안녕")

In [None]:
def Calendar():
    """Describe today's local date as a sentence for the Calendar tool.

    Returns:
        A string of the form 'Today is <Weekday>, <Month> <day>, <year>.'
        The weekday and month names come from the ``calendar`` module, which
        follows the current locale. The date comes from the naive local clock.
    """
    import calendar
    from datetime import datetime

    today = datetime.now()
    weekday = calendar.day_name[today.weekday()]
    month = calendar.month_name[today.month]
    return f'Today is {weekday}, {month} {today.day}, {today.year}.'

In [None]:
# Build lucidrains' Toolformer around the Phi-3 wrapper, teaching it the Calendar tool
# through the few-shot `prompt`. finetune=False presumably skips the fine-tuning stage
# and only runs the API-call annotation/filtering pipeline — confirm in toolformer_pytorch.
# NOTE(review): no tokenizer_encode / tokenizer_decode / pad_id is passed, so Toolformer
# falls back to its own defaults rather than the Phi-3 tokenizer loaded above; confirm the
# token ids it feeds `mymodel` actually match Phi-3's vocabulary.
# NOTE(review): model_seq_len=256 must fit the prompt plus each input sentence — confirm.
toolformer = Toolformer(
    model = mymodel,
    tool_id = 'Calendar',
    tool = Calendar,
    teach_tool_prompt = prompt,
    model_seq_len = 256,
    finetune = False,
)

In [None]:
# Sample sentences passed to `toolformer(...)` below, to be annotated with [Calendar()] calls.
data = [
    "The store is never open on the weekend, so today it is closed.",
    "The number of days from now until Christmas is 30",
    "The current day of the week is Wednesday.",
]

In [None]:
# Run the Toolformer pipeline over the sample sentences.
# NOTE(review): the structure of the returned value is defined by toolformer_pytorch — confirm before relying on it.
filtered_stats = toolformer(data)