- Toolformer data generation test
- [toolformer-conceptofmind](https://github.com/conceptofmind/toolformer)를 clone 해 놓은 뒤 그 폴더 안에서 실행하시오.

In [1]:
import os

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig
)
from datasets import load_dataset
from data_generation.retrieval import RetrievalPostprocessing
from data_generation.calendar import CalendarPostprocessing
from data_generation.calculator import CalculatorPostprocessing
import json
import time
import argparse

import torch
# Report whether a CUDA device is visible, then pick the matching device string.
cuda_available = torch.cuda.is_available()
print(cuda_available)
if cuda_available:
    device = "cuda"
else:
    device = "cpu"

  from .autonotebook import tqdm as notebook_tqdm


True


[nltk_data] Downloading package punkt to /home/jinhakai2/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
import sys

# argparse reads sys.argv, which inside a notebook holds the kernel's own
# arguments; overwrite it so the parser sees script-style options instead.
sys.argv = ['script_name.py', '--device_id', '0', '--num_devices', '1']

parser = argparse.ArgumentParser(description='do some continuations')
# Both options are integers; the defaults only apply when a flag is omitted.
for option_name, option_default in (('--device_id', 0), ("--num_devices", 8)):
    parser.add_argument(option_name, type=int, default=option_default)
args = parser.parse_args()

In [3]:
model_id = "beomi/Llama-3-Open-Ko-8B-Instruct-preview"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Token ids that open / close an API call, with and without a leading space.
# add_special_tokens=False is required: this tokenizer prepends a
# beginning-of-text token by default, so without it element [0] would be the
# BOS id for every entry instead of the id of the bracket itself.
start_tokens = [
    tokenizer(text, add_special_tokens=False)["input_ids"][0]
    for text in ("[", " [")
]
end_tokens = [
    tokenizer(text, add_special_tokens=False)["input_ids"][0]
    for text in ("]", " ]")
]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
import torch
from transformers import (
    PreTrainedTokenizerBase,
    PreTrainedModel,
)
import nltk
from nltk import tokenize
from tools import Retriever
from prompts import retrieval_prompt
from typing import List

# Fetch the NLTK "punkt" data; presumably required by tokenize.sent_tokenize,
# which RetrievalPostprocessing.parse_article calls.
nltk.download("punkt")

# NOTE(review): MAX_BATCH_SIZE, N and M are assigned again further down in
# this cell (N becomes 64 there), so only MAX_LEN from this group keeps the
# value given here.
MAX_BATCH_SIZE = 1  # My 3090 is weak 😔
N = 128  # SEQ Len
MAX_LEN = 1024  # Maximum retrieval length
M = 16  # Min Loss Span To Consider

import json
from typing import List
import torch
from transformers import (
    PreTrainedTokenizerBase,
    pipeline,
    PreTrainedModel,
    TextGenerationPipeline,
)
from torch import nn

# These assignments replace the values bound earlier in this cell; in
# particular N is 64 (not 128) for every class defined below.
MAX_BATCH_SIZE = 1  # My 3090 is weak 😔
N = 64  # SEQ Len
M = 16  # Min Loss Span To Consider


class APICallPostprocessing:
    """
    Base class that proposes API-call insertion points and scores them by loss.

    Subclasses implement ``add_api_calls`` (run the sampled calls and build the
    token tensors to score) and ``parse_article`` (walk over a document).
    The module-level constant ``M`` is the number of tokens over which losses
    are compared.
    """

    def __init__(
        self,
        start_tokens: List[int],
        end_tokens: List[int],
        minimum_percentage: float = 0.1,
    ):
        """
        Base API Postprocesing class

        :param start_tokens: token representation for [ or other tokens
        :param end_tokens:  token representation for ] or other tokens
        :param minimum_percentage: pass percentage for candidate generation, less than this are ignored.
        """
        self.start_tokens = start_tokens
        self.end_tokens = end_tokens
        self.minimum_percentage = minimum_percentage
        self.api_text = ""  # API text, might be better to pass it in
        self.k_values = 5  # Default topk generation, might be better to pass it in

    @staticmethod
    def _pad_right(tokens: torch.Tensor, amount: int) -> torch.Tensor:
        """Append ``amount`` zero-valued token ids to a (1, length) tensor."""
        padding = torch.zeros((1, amount), dtype=tokens.dtype, device=tokens.device)
        return torch.cat((tokens, padding), dim=1)

    def filter_continuations(
        self,
        input_tokens: torch.Tensor,
        input_logits: torch.Tensor,
        labels: torch.Tensor,
        input_start: int,
        tokenizer: PreTrainedTokenizerBase,
    ) -> (torch.Tensor, torch.Tensor):
        """
        Grab continuations that are valid

        :param input_tokens: tokenized inputs (unused here)
        :param input_logits: input logits
        :param labels: labels for input logits
        :param input_start: start of real input (unused here)
        :param tokenizer: (unused here)
        :return: result of ``torch.topk`` (values and indices) holding the
            ``self.k_values`` positions with the highest start-token probability
        """
        # First, figure out locations...
        probs = torch.softmax(input_logits, dim=-1)
        # Make sure we don't keep any tokens that are supposed to be [
        remove_tokens = 1.0 - torch.sum(
            torch.stack([labels == start_token for start_token in self.start_tokens]),
            dim=0,
        )
        # Get maximum probability... Should be sufficient. Maybe switch to sum if there's issues later
        max_start_tokens = torch.amax(
            torch.stack(
                [probs[:, :, start_token] for start_token in self.start_tokens]
            ),
            dim=0,
        )
        max_start_tokens = max_start_tokens * remove_tokens
        # The last M + 1 positions are excluded so that a full M-token loss
        # window always fits after a chosen position.
        return torch.topk(max_start_tokens[:, : -(M + 1)], k=self.k_values, dim=1)

    def create_candidates(
        self,
        indices: torch.Tensor,
        values: torch.Tensor,
        input_tokens: torch.Tensor,
        labels: torch.Tensor,
        input_start: int,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizerBase,
        generator: TextGenerationPipeline,
        criterion: nn.CrossEntropyLoss,
    ):
        """
        Generates continuations of valid API calls

        :param indices: index to start
        :param values: values for filtering
        :param input_tokens: tokenized input
        :param labels: labels for input
        :param input_start: real start for base loss calculation
        :param model:
        :param tokenizer:
        :param generator: pipeline for text generation
        :param criterion: Should just be CE loss
        :return: (outputs, num_to_keeps, texts_to_test, max_index) where
            outputs[n] is the list of generations for the n-th kept position
        """
        # Setup lists...
        outputs = list()
        num_to_keeps = list()
        texts_to_test = list()
        max_index = 0
        real_input = input_tokens[:, input_start:]
        # The forward pass over the real input does not depend on the
        # candidate position, so it is run at most once (and only if some
        # position passes the threshold) instead of once per candidate.
        full_logits = None
        for i, batch in enumerate(indices):
            for j, index in enumerate(batch):
                if values[i][j] < self.minimum_percentage:
                    continue
                if full_logits is None:
                    full_logits = model(real_input.cuda()).logits
                # Get base output
                base_outputs = full_logits[:, index : index + M]
                # Find starting location...
                num_keep = int(real_input.shape[1] - index)
                # Calculate loss without API
                base_loss = criterion(
                    base_outputs.view(-1, base_outputs.size(-1)),
                    labels[:, index : index + M].cuda().view(-1),
                )
                # For padding later
                max_index = max(max_index, index)
                # API Text
                texts_to_test.append(
                    tokenizer.decode(input_tokens[:, : input_start + index][i])
                    + f" [{self.api_text}"
                )
                # grab 5 generations
                generations = generator(
                    texts_to_test[-1], max_new_tokens=28, num_return_sequences=5
                )
                # Add additional items to generation outputs...
                for generation in generations:
                    generation["index"] = int(index)
                    generation["base_loss"] = float(base_loss.item())
                    generation["base_outputs"] = base_outputs
                outputs.append(generations)
                # So we know where to look
                num_to_keeps.append(num_keep)
        return outputs, num_to_keeps, texts_to_test, max_index

    def add_api_calls(
        self,
        candidate: int,
        outputs: dict,
        texts_to_test: List[str],
        tokenizer: PreTrainedTokenizerBase,
        input_tokens: torch.Tensor,
        input_start: int,
        nums_to_keep: List[int],
        base_loss: float,
        *args,
        **kwargs,
    ):
        """
        Add API calls here.

        :param candidate: which candidate is being parsed
        :param outputs: individual candidate outputs
        :param texts_to_test: text for candidates
        :param tokenizer:
        :param input_tokens:
        :param input_start:
        :param nums_to_keep: values kept after generation
        :param base_loss: base loss value for candidate
        :param args: args to pass to subclass
        :param kwargs: kwargs to pass to subclass
        :return: (generated_texts, max_token_len, max_token_len_base); each
            generated_texts entry is
            [test_inputs, base_inputs, num_to_keep, base_loss, output_dict]
        """
        raise NotImplementedError("Fill this in with your API code please!")

    def generate_continuations(
        self,
        input_tokens: torch.Tensor,
        input_logits: torch.Tensor,
        labels: torch.Tensor,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizerBase,
        *args,
        **kwargs,
    ):
        """
        Generate continuations

        For every kept position the sampled API calls are scored as
        ``min(loss without call, loss with call but no result) - loss with
        call and result`` over an M-token window; the best-scoring generation
        replaces the list at that position.

        :param input_tokens: input to model
        :param input_logits: output from model
        :param labels: labels for logits
        :param model:
        :param tokenizer:
        :param args: args to pass to add_api_calls
        :param kwargs: kwargs to pass to add_api_calls
        :return: individual candidate outputs (``None`` where no usable API
            call was produced); each dict gains "Score" and "base_api_loss"
        """
        # Setup token stuff...
        input_start = input_tokens.shape[1] - input_logits.shape[1]
        start_str = tokenizer.decode(input_tokens[:, :input_start][0])
        # Find top tokens...
        values, indices = self.filter_continuations(
            input_tokens, input_logits, labels, input_start, tokenizer
        )
        # setup generation calls...
        generator = pipeline(
            "text-generation", model=model, tokenizer=tokenizer
        )  # type: TextGenerationPipeline
        criterion = nn.CrossEntropyLoss()
        with torch.no_grad():
            outputs, num_to_keeps, texts_to_test, max_index = self.create_candidates(
                indices,
                values,
                input_tokens,
                labels,
                input_start,
                model,
                tokenizer,
                generator,
                criterion,
            )
            for i in range(len(outputs)):
                generated_texts, max_token_len, max_token_len_base = self.add_api_calls(
                    i,
                    outputs[i],
                    texts_to_test,
                    tokenizer,
                    input_tokens,
                    input_start,
                    num_to_keeps,
                    outputs[i][0]["base_loss"],
                    *args,
                    **kwargs,
                )
                if len(generated_texts) == 0:
                    outputs[i] = None
                    continue
                # shape the batches: right-pad every sequence to a common
                # length and remember the pad amounts as the last two items
                # of each entry (test padding, then base padding).
                for entry in generated_texts:
                    test_pad = max_token_len - entry[0].shape[1]
                    entry.append(test_pad)
                    if test_pad != 0:
                        entry[0] = self._pad_right(entry[0], test_pad)
                    base_pad = max_token_len_base - entry[1].shape[1]
                    entry.append(base_pad)
                    if base_pad != 0:
                        entry[1] = self._pad_right(entry[1], base_pad)

                test_outputs = model(
                    torch.cat([entry[0] for entry in generated_texts], dim=0)
                ).logits
                base_outputs = model(
                    torch.cat([entry[1] for entry in generated_texts], dim=0)
                ).logits
                best_loss = -99.0
                best_output = outputs[i][0]
                best_base_api_loss = None
                for j, entry in enumerate(generated_texts):
                    num_to_keep = entry[2]
                    candidate = entry[-3]
                    test_pad, base_pad = entry[-2], entry[-1]
                    vocab_size = candidate["base_outputs"].size(-1)
                    # M-token window counted from the end of the sequence.
                    window = slice(-num_to_keep, -(num_to_keep - M))
                    targets = labels[:, window].cuda().view(-1)
                    # Strip the padding again before taking the window.
                    test_logits = test_outputs[j]
                    if test_pad != 0:
                        test_logits = test_logits[:-test_pad]
                    test_loss = criterion(
                        test_logits[window].view(-1, vocab_size), targets
                    )
                    base_logits = base_outputs[j]
                    if base_pad != 0:
                        base_logits = base_logits[:-base_pad]
                    base_loss = criterion(
                        base_logits[window].view(-1, vocab_size), targets
                    )
                    candidate["generated_text"] = candidate["generated_text"].replace(
                        start_str, ""
                    )
                    # Use the same quantity for the comparison and for the
                    # stored score, so later candidates are compared against
                    # the value that actually selected the current best.
                    score = min(base_loss.item(), candidate["base_loss"]) - test_loss
                    if score > best_loss:
                        best_output = candidate
                        best_loss = score
                        best_base_api_loss = base_loss
                if best_base_api_loss is None:
                    # No candidate beat the sentinel (e.g. NaN losses); fall
                    # back to the last computed value.
                    best_base_api_loss = base_loss
                outputs[i] = best_output
                # float() accepts both a 0-dim tensor and the float sentinel.
                outputs[i]["Score"] = float(best_loss)
                outputs[i]["base_api_loss"] = float(best_base_api_loss)
                del outputs[i]["base_outputs"]
        # print(json.dumps(outputs, indent=2))
        return outputs

    def parse_article(
        self, data: dict, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase
    ):
        """
        Takes in data dict and parses it into API continuations
        :param data: data, assuming it's from load_dataset and has a text field
        :param model:
        :param tokenizer:
        :return: outputs for the input data, should have index of API call insertion, API, and score value at minimum.
        """
        raise NotImplementedError("Fill this in for what you need to do please!")

class RetrievalPostprocessing(APICallPostprocessing):
    """
    Retrieval flavour of APICallPostprocessing.

    Samples ``[Retrieval(query)]`` calls, answers them with ``Retriever`` over
    earlier sentences of the same document, and keeps calls whose score
    exceeds 1.0. Relies on the module-level names ``N``, ``MAX_LEN``,
    ``retrieval_prompt`` and ``text_field``.
    """

    def __init__(
        self,
        start_tokens: List[int],
        end_tokens: List[int],
        minimum_percentage: float = 0.1,
    ):
        """
        :param start_tokens: token representation for [ or other tokens
        :param end_tokens:  token representation for ] or other tokens
        :param minimum_percentage: pass percentage for candidate generation, less than this are ignored.
        """
        # The base initialiser resets api_text to "", so it has to run before
        # the retrieval-specific value is assigned.
        super().__init__(start_tokens, end_tokens, minimum_percentage)
        self.retriever = Retriever()
        self.api_text = "Retrieval("

    def add_api_calls(
        self,
        candidate: int,
        outputs: dict,
        texts_to_test: List[str],
        tokenizer: PreTrainedTokenizerBase,
        input_tokens: torch.Tensor,
        input_start: int,
        nums_to_keep: List[int],
        base_loss: float,
        *args,
        **kwargs
    ):
        """
        Execute the sampled retrieval calls and build the sequences to score.

        :param candidate: which candidate is being parsed
        :param outputs: generations for this candidate (mutated in place)
        :param texts_to_test: prompt text for every candidate
        :param tokenizer:
        :param input_tokens:
        :param input_start:
        :param nums_to_keep: values kept after generation
        :param base_loss: base loss value for candidate
        :param args: args[0] is the list of sentences to retrieve from
        :param kwargs: unused
        :return: (generated_texts, max_token_len, max_token_len_base); each
            entry is [test_inputs, base_inputs, num_to_keep, base_loss, output]
        """
        retrieval_strings = args[0]
        generated_texts = list()
        max_token_len = N
        max_token_len_base = N
        real_input = input_tokens[:, input_start:].cuda()
        for output in outputs:
            # Whatever was generated after the prompt for this candidate.
            output["Retrieval"] = output["generated_text"].replace(
                texts_to_test[candidate], ""
            )
            output["Generated"] = output["generated_text"].split("Output:")[-1]
            # Only generations that close the call with "]" are usable.
            if "]" not in output["Retrieval"]:
                continue
            query = output["Retrieval"].replace("Retrieval(", "").split("]")[0]
            if ")" in query:
                query = query.split(")")[0]
            output["Retrieval"] = query
            output["Retrieval_text"] = "[Retrieval(" + query + ")"
            # Baseline: the call without any result.
            base_inputs = tokenizer(
                output["Retrieval_text"] + "]" + "\n",
                return_tensors="pt",
            )["input_ids"].cuda()
            output["Retrieval"] = self.retriever.retrieval(
                retrieval_strings, query, 3
            )
            joined_results = ", ".join(output["Retrieval"])
            output["Retrieval_output"] = [output["Retrieval_text"][1:], joined_results]
            output["Retrieval_text"] = (
                output["Retrieval_text"] + "->" + joined_results + "]"
            )
            # Test sequence: the call together with its result.
            test_inputs = tokenizer(
                output["Retrieval_text"] + "\n",
                return_tensors="pt",
            )["input_ids"].cuda()
            test_inputs = torch.concat([test_inputs, real_input], dim=1)
            if test_inputs.shape[1] > MAX_LEN:
                continue
            base_inputs = torch.concat([base_inputs, real_input], dim=1)
            max_token_len = max(max_token_len, test_inputs.shape[1])
            # Track the base sequences by their own length; measuring
            # test_inputs here only added needless padding to the base batch.
            max_token_len_base = max(max_token_len_base, base_inputs.shape[1])
            generated_texts.append(
                [
                    test_inputs,
                    base_inputs,
                    nums_to_keep[candidate],
                    base_loss,
                    output,
                ]
            )
        return generated_texts, max_token_len, max_token_len_base

    def parse_article(
        self, data: dict, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase
    ):
        """
        Walk over a document in N-token windows and collect retrieval calls.

        :param data: record whose ``text_field`` entry holds the raw text
        :param model:
        :param tokenizer:
        :return: list of [score, index, call text, retrieved text] for every
            call whose score is above 1.0
        """
        outputs = list()
        # tokens = tokenizer(data["text"], return_tensors="pt")["input_ids"]
        tokens = tokenizer(data[text_field], return_tensors="pt")["input_ids"]
        print(tokens)
        # Windows are counted from the end of the document; the loop below is
        # empty unless the document has more than 2048 tokens.
        start_step = 2048//N
        ret_skip = 1024//N  # naively assuming the model should be able to look back if it's less than this.
        total_steps = tokens.shape[1]//N
        for i in range(start_step, total_steps):
            input_tokens = tokens[:, (-N * (i + 1) - 1) : (-N * (i) - 1)]
            labels = tokens[
                :,
                int(tokens.shape[1] + (-N * (i + 1))) : int(tokens.shape[1] + (-N * i)),
            ]
            ret_tokens = tokens[:, : (-(N) * ((i - ret_skip) + 1) - 1)]
            string = tokenizer.decode(input_tokens[0])
            print(string)
            ret_strings = tokenize.sent_tokenize(tokenizer.decode(ret_tokens[0]))
            print(ret_strings)
            model_input = tokenizer(
                retrieval_prompt.replace("<REPLACEGPT>", string) + string,
                return_tensors="pt",
            )["input_ids"]
            with torch.no_grad():
                window_logits = model(model_input.cuda()).logits.cpu()[:, -N:]
            new_outputs = self.generate_continuations(
                model_input,
                window_logits,
                labels,
                model,
                tokenizer,
                ret_strings,
            )
            for new_output in new_outputs:
                if new_output is None:
                    continue
                # Convert the window-relative index to a document position.
                new_output["index"] += int(tokens.shape[1] + (-N * (i + 1)))
                # filter by score
                if new_output["Score"] > 1.0:
                    outputs.append(
                        [new_output["Score"], new_output["index"]]
                        + new_output["Retrieval_output"]
                    )
        return outputs



# Uses the RetrievalPostprocessing class defined in this cell, which shadows
# the one imported from data_generation.retrieval in the first cell.
api_handler = RetrievalPostprocessing(start_tokens, end_tokens)

[nltk_data] Downloading package punkt to /home/jinhakai2/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [5]:
# Load the beomi/Llama-3-Open-Ko-8B-Instruct-preview model in 4-bit precision.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype="float16",
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
)

`low_cpu_mem_usage` was None, now set to True since model is quantized.
Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.17it/s]


In [6]:
# Earlier experiments used Korean corpora (KLUE "dp", a dependency-parsing
# dataset chosen as a source of plain sentences, and squad_kor_v1):
# dataset = load_dataset("klue", "dp", split="train", streaming=True)
# dataset = load_dataset("squad_kor_v1", split="train", streaming=True)
# The run below streams English C4. It is addressed by its Hub repository id
# because the legacy script-based "c4" id is what triggers the
# trust_remote_code deprecation warning.
dataset = load_dataset("allenai/c4", "en", split="train", streaming=True)
iter_data = iter(dataset)
# Bookkeeping shared with the generation loop further down.
test = False
counter = 0  # documents actually parsed
file_counter = 0  # position within the dataset stream
found_examples = 0  # API calls kept so far in this run
output_dataset = list()
start_time = time.process_time()

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [7]:
# Pull one record from the stream for inspection. This consumes it, so the
# generation loop later starts from the following record.
data = next(iter_data)
# Key under which a record stores its raw text; read as a global by
# RetrievalPostprocessing.parse_article and check_apis_available.
text_field = "text"
# Bare expression: shown as the cell output.
data[text_field]

'Beginners BBQ Class Taking Place in Missoula!\nDo you want to get better at making delicious BBQ? You will have the opportunity, put this on your calendar now. Thursday, September 22nd join World Class BBQ Champion, Tony Balay from Lonestar Smoke Rangers. He will be teaching a beginner level class for everyone who wants to get better with their culinary skills.\nHe will teach you everything you need to know to compete in a KCBS BBQ competition, including techniques, recipes, timelines, meat selection and trimming, plus smoker and fire information.\nThe cost to be in the class is $35 per person, and for spectators it is free. Included in the cost will be either a t-shirt or apron and you will be tasting samples of each meat that is prepared.'

In [8]:
# Target number of examples to generate on this device.

num_examples = int(3/float(args.num_devices))
start_count = -1
# Resume from an earlier run if its output is present.
# NOTE(review): the generation loop writes to a timestamped filename, so this
# un-timestamped path only exists if a file was renamed or copied into place.
resume_path = f"retrieval_data_{args.device_id}.json"
if os.path.isfile(resume_path):
    # The results are saved as UTF-8 with non-ASCII text left unescaped, so
    # the encoding must be given explicitly to read them back reliably.
    with open(resume_path, encoding="utf-8") as f:
        output_dataset = json.load(f)
    # An empty saved list carries nothing to resume from.
    if output_dataset:
        start_count = output_dataset[-1]['file_index']
        num_examples -= sum(
            len(item['retrieval_outputs']) for item in output_dataset
        )

In [9]:
from dataclasses import dataclass
from transformers import PreTrainedTokenizerBase
import dateutil.parser as dparser
import random
import re

@dataclass
class AvailableAPIs:
    """Boolean switches recording which APIs may be used for a document."""

    retrieval: bool = True
    calendar: bool = True
    calculator: bool = True
    llmchain: bool = True

    def check_any_available(self):
        """Return True when retrieval, calendar or calculator is enabled.

        The ``llmchain`` flag is not consulted.
        """
        for enabled in (self.retrieval, self.calendar, self.calculator):
            if enabled:
                return True
        return False

def check_apis_available(
    data: dict, tokenizer: PreTrainedTokenizerBase
) -> AvailableAPIs:
    """
    Returns available APIs with boolean flags

    Retrieval is always enabled and calendar always disabled here. The
    calculator flag is enabled when a 100-token chunk matches the arithmetic
    pattern and mentions a result phrase, or, at most once per document, with
    a 1-in-100 chance when the first non-matching chunk holds at least three
    numbers.

    :param data: from load_dataset, read through the module-level ``text_field`` key
    :param tokenizer: Tokenizer to tokenize data
    :return: AvailableAPIs
    """
    tokenized_data = tokenizer(data[text_field])["input_ids"]
    available = AvailableAPIs()
    # In case we need a different version, found this here:
    # https://stackoverflow.com/questions/28198370/regex-for-validating-correct-input-for-calculator
    # Raw string: same pattern as before, without invalid-escape warnings.
    calc_pattern = re.compile(r"^(\d+[\+\-\*\/]{1})+\d+$")
    # The token-count check for retrieval and the URL date check for calendar
    # are disabled in this notebook; the flags are fixed instead.
    available.retrieval = True
    available.calendar = False
    available.calculator = False
    tried_rand = False
    result_markers = ("=", "equal to", "total of", "average of")
    for i in range(len(tokenized_data) // 100):
        text = tokenizer.decode(tokenized_data[i * 100 : (i + 1) * 100])

        operators = bool(calc_pattern.search(text))
        equals = any(marker in text for marker in result_markers)
        if operators and equals:
            available.calculator = True
        elif not tried_rand:
            # One random attempt per document. Previously every chunk after
            # this attempt fell into the "available" branch regardless of
            # its content.
            tried_rand = True
            words = text.replace("\n", " ").split(" ")
            numbers = [word for word in words if word.replace(".", "", 1).isnumeric()]
            if len(numbers) >= 3 and random.randint(0, 99) == 0:
                available.calculator = True

    return available

In [None]:
import pytz
import datetime
iter_count = 0
max_iter = 300  # hard cap on the number of records pulled from the stream
korea_time_zone = pytz.timezone('Asia/Seoul')
current_time_kst = datetime.datetime.now(korea_time_zone)
# NOTE(review): the resume cell looks for the un-timestamped
# retrieval_data_<device_id>.json, so files written here are not picked up
# by it automatically.
filename = f"retrieval_data_{args.device_id}_{current_time_kst.strftime('%Y-%m-%d_%H-%M-%S')}.json"


def _save_dataset(path, records):
    """Write the collected records as UTF-8 JSON, keeping non-ASCII text readable."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(records, f, indent=2, ensure_ascii=False)


while found_examples < num_examples:
    if iter_count == max_iter:
        break
    iter_count += 1
    try:
        data = next(iter_data)
    except StopIteration:
        break  # No more data, so leave the loop.
    # Skip everything up to and including the last record saved by an earlier
    # run; "<" would process that record a second time. With the default
    # start_count of -1 nothing is skipped.
    if file_counter <= start_count:
        file_counter += 1
        continue
    # Shard the stream across devices by position.
    if file_counter % args.num_devices != args.device_id:
        file_counter += 1
        continue
    available = check_apis_available(data, tokenizer)
    test = available.retrieval
    if test:
        data_outputs = api_handler.parse_article(data, model, tokenizer)
        output_dataset.append(
            {
                "file_index": file_counter,
                "text": data[text_field],
                "retrieval_outputs": data_outputs
            }
        )
        prev_found = found_examples
        found_examples += len(output_dataset[-1]["retrieval_outputs"])
        # Remaining work is clamped at zero so that overshooting the target
        # does not print a negative ETA.
        remaining = max(0, num_examples - found_examples)
        eta_s = remaining * (time.process_time()-start_time) / max(1, found_examples)
        eta_m = eta_s // 60
        eta_h = eta_m // 60
        eta_m = eta_m - (eta_h*60)
        eta_s = eta_s - ((eta_m*60) + (eta_h*60*60))
        print(f"Found: {found_examples}/{num_examples}, ETA: {eta_h}H:{eta_m}M:{eta_s}s")
        # Checkpoint every time another hundred examples has been reached,
        # in the same format as the final dump.
        if found_examples//100 > prev_found//100:
            _save_dataset(filename, output_dataset)
        counter += 1
    file_counter += 1
_save_dataset(filename, output_dataset)