In [17]:
import os
import logging
import sys
import traceback

from functools import partial

import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch import tensor as Tensor

from beartype import beartype
from beartype.typing import List, Callable, Union, Dict, Tuple

# Import parent class of AutoTokenizer
from transformers import LlamaTokenizer, AutoTokenizer, GPT2Tokenizer

from importlib import reload
logging.shutdown()
reload(logging)

pad_sequence = partial(pad_sequence, batch_first=True)
longTensor = partial(Tensor, dtype=torch.long)

PAD_TOKEN = "[PAD]"
PAD_ID = -100
ARROW_TOKEN = 39310
TOOL_TOKEN = "["
END_TOOL_TOKEN = "]"
TOOL_TOKEN_IDS = []
END_API_TOKEN = 50401
OPEN_PARENTHESIS = "("
OPEN_PARENTHESIS_ID = 7
CLOSE_PARENTHESIS = 8

LOGIT_DISPLACEMENT = 0 # This is for models where model at position i gives logits of prediction AFTER seeing i. For models that give logits of prediction BEFORE seeing i, this should be 1.


def log(t, eps=1e-20): return t.clamp(min=eps).log()


def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))


def gumbel_sample(t, temperature=1., dim=-1, eps=1e-10):
    # Returns flat vector
    if temperature == 0:
        return t.argmax(dim=dim)

    return ((t / max(temperature, eps)) + gumbel_noise(t)).argmax(dim=dim)

def renumerate(sequence, start=None):
    n = start
    if start is None:
        n = len(sequence) - 1
    for elem in sequence[::-1]:
        yield n, elem
        n -= 1


# Tools given to the toolmaster must have:
# 1. Name: str - Unique identifier for the tool
# 2. Arg parser: Callable - A function that takes a string and returns a list of arguments
# 3. Tool: Callable - A function that takes a list of argumets and returns a string
# 4. Explanation prompt: Union[torch.Tensor, str] - A string that explains how to use the tool
# 5. Short description: Optional[str] - A short description of the tool



@beartype
class ToolMaster(nn.Module):
    def __init__(
        self,
        model: nn.Module,
        *,
        tool_specs: List[Dict],  # of the form {"name": str, "arg_parser": Callable, "tool": Callable, "explanation_prompt": Union[torch.Tensor, str], "short_desc": Optional[str]}
        tokenizer,
        tool_token_ids: List[int],
        free_generation_prompt: str = None,
        debug_level: int = 0,
        log_dir: str = "/vol/bitbucket/jg2619/augmenting_llms/model_training/models/logs",
        export_tool_execution: bool = False,
        max_new_tokens: int = 30,
        max_response_len: int = 100,
        temperature: float = 0.8,
        catch_answers: bool = False,
        answer_token_ids: List[int] = None,
        post_answer_token_ids: List[int] = None,
    ): 
        super().__init__()

        global PAD_ID, PAD_TOKEN, OPEN_PARENTHESIS_ID, TOOL_TOKEN_IDS
        
        self.model = model
        self.device = model.device
        self.encode = tokenizer.encode
        self.decode = tokenizer.decode

        PAD_ID = tokenizer.pad_token_id
        PAD_TOKEN = tokenizer.pad_token
        TOOL_TOKEN_IDS = longTensor(tool_token_ids, device=self.device).view(-1)
        print(f"Tool token ids device: {TOOL_TOKEN_IDS.device}")

        OPEN_PARENTHESIS_ID = tokenizer.encode(OPEN_PARENTHESIS)
        assert len(OPEN_PARENTHESIS_ID) == 1, "Open parenthesis token must be a single token"
        OPEN_PARENTHESIS_ID = OPEN_PARENTHESIS_ID[0]

        tool_names = [tool_spec["name"] for tool_spec in tool_specs]
        tokenized_tools = [tokenizer.encode(tool_name) for tool_name in tool_names]
        self.tokenized_tools = tokenized_tools
        self.tool_names = tool_names

        tool_name_desc = []
        for spec in tool_specs:
            name_desc = spec["name"]
            if 'tool_short_desc' in spec:
                name_desc += " (" + spec['short_desc'] +")"
            tool_name_desc.append(name_desc)

        if free_generation_prompt is None:
            free_generation_prompt = "You can use these tools to help you answer: [AVAILABLE TOOLS].\n\n"
        
        free_generation_prompt = free_generation_prompt.replace("[AVAILABLE TOOLS]", ", ".join(tool_name_desc))
        self.free_generation_prompt = free_generation_prompt.replace("[PROMPT]", "")
        
        if "[PROMPT]" in free_generation_prompt:
            self.free_gen_sub_idx = len(self.encode(free_generation_prompt.split("[PROMPT]")[0]))
        else:
            self.free_gen_sub_idx = len(self.encode(free_generation_prompt))
        self.tokenized_free_generation_prompt = longTensor(self.encode(self.free_generation_prompt)).to(self.device)

        tool_selection_dict = {}
        # This function creates a decision tree for the tool selection. The model chooses at each depth the token with the highest probability, until it reaches a tool id.
        def tree_maker(tree, token, id, depth):
            tokens = list(tree.keys())
            if token not in tokens:
                tree[token] = id 
            else:
                if token == OPEN_PARENTHESIS_ID:
                    print(f"Warning: tool {tokenized_tools[id]} is already in the tree")
                    return
                # Check if instance of dictionary:
                if not isinstance(tree[token], dict):
                    other_id = tree[token]
                    next_token = tokenized_tools[other_id][depth+1] if depth + 1 < len(tokenized_tools[other_id]) else OPEN_PARENTHESIS_ID
                    tree[token] = {next_token: other_id}
                next_token = tokenized_tools[id][depth+1] if depth + 1 < len(tokenized_tools[id]) else OPEN_PARENTHESIS_ID
                tree_maker(tree[token], next_token, id, depth + 1)

        for i, tool in enumerate(tokenized_tools):
            tree_maker(tool_selection_dict, tool[0], i, 0)

        self.tool_selection_dict = tool_selection_dict
        tool_explanation_prompts = [tool_spec["explanation_prompt"] for tool_spec in tool_specs]
        prepare_explan_for_gen = lambda x: longTensor(self.encode(x.replace("[PROMPT]", ""))).to(self.device)
        self.tool_explanation_prompts = list(map(prepare_explan_for_gen,  tool_explanation_prompts))

        self.tool_explan_sub_indices = []
        for explan in tool_explanation_prompts:
            if "[PROMPT]" in explan:
                self.tool_explan_sub_indices.append(len(self.encode(explan.split("[PROMPT]")[0])))
            else:
                self.tool_explan_sub_indices.append(len(self.encode(explan)))
    
        self.tools = [tool_spec["tool"] for tool_spec in tool_specs]
        self.arg_parsers = [tool_spec["arg_parser"] for tool_spec in tool_specs]
        self.tokenized_tools = tokenized_tools

        self.max_new_tokens = max_new_tokens
        self.max_response_len = max_response_len
        self.temperature = temperature

        self.debug_level = debug_level
        # Create log dir
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        # count files in log dir
        i = len(os.listdir(log_dir))
        logging.basicConfig(filename=f'{log_dir}/{i}.log', level=logging.DEBUG if debug_level>0 else logging.INFO, format='%(asctime)s:  %(message)s', datefmt='%m/%d/%Y %I:%M:%S  ')
        print(f"Logging to {log_dir}/{i}.log")

        self.catch_answers = catch_answers
        self.answer_token_ids = torch.tensor(answer_token_ids, dtype=torch.int32, device=self.device)
        if post_answer_token_ids is not None:
            post_answer_token_ids = longTensor(post_answer_token_ids).to(self.device)
        self.post_answer_token_ids = post_answer_token_ids
        
        handler = logging.StreamHandler(sys.stdout)
        handler.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s:  %(message)s')
        handler.setFormatter(formatter)

        logger_root = logging.getLogger()
        logger_root.addHandler(handler)

        # Tokens with →
        self.arg_gen_stoppers = []
        for key, value in tokenizer.get_vocab().items():
            if "→" in key or ")" in key:  # Remove close parenthesis TODO
                self.arg_gen_stoppers.append(value)
        self.arg_gen_stoppers = Tensor(self.arg_gen_stoppers).to(self.model.device)

        self.export_tool_execution = export_tool_execution

        # COPY PROMPT X TIMES
        # self.tokenized_free_generation_prompt.unsqueeze(0).repeat(batch_size,1)

    @torch.no_grad()
    def generate(self, 
                 user_prompts: list[torch.Tensor], 
                 explanation_prompts: Union[List[torch.Tensor], torch.Tensor],
                 generated_content: List[torch.Tensor] = None,
                 tool_history: List[List[Dict]] = None,
                 arg_selection_mode: bool = False,  # Arg selection mode VS free generation augmented with tool selection
                 max_new_tokens: int = 100,
                 temperature: float = 0.8, 
                 stop_tokens: Union[List[int],int,torch.Tensor] = [],
                 sub_indices: Union[int,List[int]] = None):
        
        global PAD_ID, PAD_TOKEN, OPEN_PARENTHESIS_ID, TOOL_TOKEN_IDS, LOGIT_DISPLACEMENT

        device = self.device

        # Each data point as it goes through the loop should have:
        # 1. An updated prime: The data point to be completed including the generated content
        # 2. The original prime: The prompt that was prepended to the prime
        # 3. Total generated content: The generated content for the data point
        # 4. Current generated content: The last generated content for the data point in the current loop section
        # 5. The tool history



        if isinstance(explanation_prompts, torch.Tensor) and explanation_prompts.dim() == 1:
            explanation_prompts = [explanation_prompts for _ in range(len(user_prompts))]

        assert user_prompts[0].dim() == 1, "Primes must be 1D tensor with the tokenized data"
        assert explanation_prompts[0].dim() == 1, "Prompt must be 1D tensor with the tokenized prompt"
        if isinstance(sub_indices, list):
            assert len(sub_indices) == len(explanation_prompts), "If Sub indices is a list, it must have the same length as the explanations"
        
        # Device assertions
        assert user_prompts[0].device == device, "Primes must be on the same device as the model"
        assert explanation_prompts[0].device == device, "Prompts must be on the same device as the model"
        if generated_content is not None:
            assert generated_content[0].device == device, "Generated content must be on the same device as the model"
        
        batch_size = len(user_prompts)                                                  # BATCH SIZE
        explan_prompt_lens = Tensor([prompt.shape[0] for prompt in explanation_prompts]).to(device)  # LENGTHS OF PREPENDED PROMPTS
        
        if tool_history is None:        # History of tools used for each row
            tool_history = [[] for _ in range(batch_size)]
        if generated_content is None:   # Generated content for each row
            generated_content = [longTensor([]).to(device) for _ in range(batch_size)]
        generated_content_lens = Tensor([content.shape[0] for content in generated_content]).to(device) 
        new_generated_content = [longTensor([]).to(device) for _ in range(batch_size)]
        if isinstance(stop_tokens, int):
            stop_tokens = [stop_tokens]
        if not isinstance(stop_tokens, torch.Tensor):
            stop_tokens = longTensor(stop_tokens).to(device).view(-1)

        # Position of where to start generating for each row
        positions = Tensor([user_prompt.shape[0]-1 for user_prompt in user_prompts]).to(device).unsqueeze(1)
        positions += explan_prompt_lens.unsqueeze(1) + generated_content_lens.unsqueeze(1) # Add prompt and generated content lengths
        initial_positions = positions.clone() # Initial positions of each row for data + prompt

        if sub_indices is None or sub_indices == -1:
            joined_prompts = [torch.cat([prompt, user_prompt]) for prompt, user_prompt in zip(explanation_prompts, user_prompts)]
        elif isinstance(sub_indices, int):
            # Insert user_prompt at sub_indiices
            joined_prompts = [torch.cat([prompt[:sub_indices], user_prompt, prompt[sub_indices:]]) for prompt, user_prompt in zip(explanation_prompts, user_prompts)]
        else:
            joined_prompts = [torch.cat([prompt[:sub_index], user_prompt, prompt[sub_index:]]) for prompt, user_prompt, sub_index in zip(explanation_prompts, user_prompts, sub_indices)]

        batch_input = [torch.cat([prompt, content]) for prompt, content in zip(joined_prompts, generated_content)]
        batch_lengths = Tensor([row.shape[0] for row in batch_input]).to(device)
        batch_input = pad_sequence(batch_input, padding_value=PAD_ID)
        extra_pad = (batch_lengths + max_new_tokens - generated_content_lens - batch_input.shape[1]).max().item()
        batch_input = F.pad(batch_input, (0, extra_pad,), value=PAD_ID)

        logging.debug(f"Extra pad: {extra_pad}")
        logging.debug(f"Batch lengths: {batch_lengths}")
        logging.debug(f"Max new tokens: {max_new_tokens}")
        logging.debug(f"Generated content lens: {generated_content_lens}")
        logging.debug(f"Batch input shape: {batch_input.shape}")

        logging.debug("Primes:")
        for row in batch_input:
            logging.debug(self.decode(row))
        
        # Indexing tensor utils
        loop_to_data_idx = torch.arange(batch_size).to(device)                 # Mapping from loop index to batch index
        batch_indices = torch.arange(batch_size).to(device).unsqueeze(1)       # ARANGE THAT ADJUSTS TO THE LOOP BATCH SIZE AS SAMPLES FINISH
        
        # Tool selection utils
        loop_selection_depth = torch.zeros(batch_size).int().to(device)        # Depth of the tool selection tree
        loop_is_selecting_tools = torch.zeros(batch_size).bool().to(device)    # Indices where we are selecting a tool
        current_opts = [self.tool_selection_dict for _ in range(batch_size)]               # Current tool options for row

        can_be_stopped = (torch.ones(batch_size).to(device)*(0 if stop_tokens == [] or self.catch_answers else 1)).bool()            # Indices where we can stop
        if self.catch_answers:
            stop_tokens = self.post_answer_token_ids

        logging.debug(f"can_be_stopped: {can_be_stopped}")

        batch_generated_count = generated_content_lens.clone() # Number of tokens generated for each row

        while batch_indices.shape[0] > 0:

            # Remove assertion: TODO
            assert loop_to_data_idx.shape[0] == batch_indices.shape[0], "Loop to data index and batch indices must have the same shape"

            # MODEL FORWARD CALL. MAINTAINS SHAPE EVEN AFTER INDEXING
            #print(f"Input shape {batch_input.shape}")
            #print(f"Positions shape {positions.shape}")
            #print(f"Loop to data idx shape {loop_to_data_idx.shape}")
            #print(f"Batch indices shape {batch_indices.shape}")
            #print(positions)
            output =self.model(batch_input[loop_to_data_idx], use_cache=False)
            loop_last_logits = output.logits[batch_indices, positions[loop_to_data_idx] + LOGIT_DISPLACEMENT]
            #loop_last_logits[:, :, TOOL_TOKEN_IDS[0]] += 10   #TOOL_TOKEN_ID , 13 fulstp

            positions[loop_to_data_idx] += 1

            if arg_selection_mode:   # Tool usage not available
                loop_last_logits[:, :, TOOL_TOKEN_IDS] = -1e10

            # Gumbel sample for rows not selecting a tool. Tool selection has different sampling procedure
            sample_ids = loop_to_data_idx[~loop_is_selecting_tools]
            loop_sampled = torch.ones(batch_indices.shape[0], 1).long().to(device)*-1
            loop_sampled[~loop_is_selecting_tools] = gumbel_sample(loop_last_logits[~loop_is_selecting_tools], temperature=temperature)
            batch_input[sample_ids.unsqueeze(1), positions[sample_ids]] = loop_sampled[~loop_is_selecting_tools]
            batch_generated_count[sample_ids] += 1

            # Catch answers
            if not arg_selection_mode and self.catch_answers:
                # Check if any of the tokens are answer tokens
                can_be_stopped += torch.isin(loop_sampled, self.answer_token_ids).view(-1).bool()
                print(f"Can be stopped: {can_be_stopped}")

            # Sampling procedure for rows selecting a tool
            if loop_is_selecting_tools.any():
                print("SELECTING TOOLS LOOP")
                for selecting_i in reversed(loop_is_selecting_tools.nonzero().squeeze(1)):
                    
                    data_i = loop_to_data_idx[selecting_i].item()
                        # Tool names are composed of tokens. ie. [CAL] [CUL] [ATOR]. We call each token a syllable
                        # Options for the next syllable. 
                    syllable_opts = Tensor(list(current_opts[data_i].keys())).to(device)
                    next_syllable_idx = loop_last_logits[selecting_i,0,syllable_opts].argmax(dim=-1)
                    next_syllable = syllable_opts[next_syllable_idx].item()
                    batch_input[data_i, positions[data_i]] = next_syllable
                    loop_selection_depth[selecting_i] += 1
                    current_opts[data_i] = current_opts[data_i][next_syllable]

                    batch_generated_count[data_i] += 1

                    # If current opts is a dict, there is a tie between possible tools. We need to keep selecting syllables.
                    if not isinstance(current_opts[data_i], dict):   # ELSE: We've reached a tool id
                        tool_id = current_opts[data_i]
                        depth = loop_selection_depth[selecting_i].item()-1   # Selection_depth = i means we've selected the ith syllable of tool name. -1 for indexing purposes.
                        tool_len = len(self.tokenized_tools[tool_id])

                        batch_generated_count[data_i] += tool_len + 1 - depth - 1 # +1 for open parenthesis

                        if batch_generated_count[data_i] >= max_new_tokens:
                            logging.warning(f"Stopping generation at row {data_i} that reached the generation limit")
                            logging.warning(f"Data: {self.decode(batch_input[data_i])}")
                            logging.warning(f"Data id {data_i}")
                            logging.warning(f"Tool history: {tool_history[data_i]}")
                            positions[data_i] = -1
                        else:
                            batch_input[data_i, positions[data_i]-depth:positions[data_i]-depth+tool_len] = Tensor(self.tokenized_tools[tool_id]).to(device)
                            batch_input[data_i, positions[data_i]-depth+tool_len] = OPEN_PARENTHESIS_ID

                            new_generated_content[data_i] = batch_input[data_i, initial_positions[data_i]+1:positions[data_i]-depth+tool_len+1]

                            tool_history[data_i].append({"id": tool_id})

                        # Remove index i
                        remove_index = torch.arange(loop_to_data_idx.shape[0]).to(device) != selecting_i
                        loop_is_selecting_tools = loop_is_selecting_tools[remove_index]
                        loop_selection_depth = loop_selection_depth[remove_index]
                        loop_to_data_idx = loop_to_data_idx[remove_index]
                        can_be_stopped = can_be_stopped[remove_index]
                        loop_sampled = loop_sampled[remove_index]
                        batch_indices = batch_indices[:-1]

                    

            print(f"Sampled: {', '.join([self.decode(sample) for sample in loop_sampled if sample != -1])}")

            # Check if any row wants to use a tool
            just_sampled_tool = torch.isin(loop_sampled, TOOL_TOKEN_IDS.to(device)).view(-1)
            if (just_sampled_tool).any():   # New rows selecting tools!
                loop_is_selecting_tools[just_sampled_tool] = True

            # Rows that reached the max number of tokens, we finish the call
            reached_limit = batch_generated_count[loop_to_data_idx] >= max_new_tokens
            logging.debug(f"Shape of reached limit: {reached_limit.shape}")
            # Sequence that reached the stop token
            finished = (can_be_stopped & torch.isin(loop_sampled.squeeze(1), stop_tokens)) | reached_limit
            if finished.any():
                print(f"{finished.sum()} FINISHED")
                print(f"Reached limit: {reached_limit.sum()}")
                print(f"Finished tensor: {finished}")
                print(f"reached limit tensor: {reached_limit}")
                print(f"Can be stopped: {can_be_stopped}")

                for finished_i in finished.nonzero().squeeze(1):
                    data_i = loop_to_data_idx[finished_i].item()
                    new_generated_content[data_i] = batch_input[data_i, initial_positions[data_i]+1:positions[data_i]+1]
                    if reached_limit[finished_i]:
                        logging.warning(f"Stopping generation at row {data_i} that reached the generation limit")
                        logging.warning(f"Data: {self.decode(batch_input[data_i])}")
                        if arg_selection_mode:
                            # model failed to generate arguments.
                            logging.warning(f"Model failed to generate arguments for: \ndata: {self.decode(batch_input[data_i])}")
                            logging.warning(f"Data id {data_i}")
                            logging.warning(f"Tool history: {tool_history[data_i]}")
                            tool_history[data_i][-1]["status"] = "Failed to generate arguments"
                            positions[data_i] = -1    # This marks tool use error - rectifies use and resumes generation
                            new_generated_content[data_i] = longTensor([]).to(device)


                if not arg_selection_mode:
                    # These rows are done generating. Mark them as finished
                    positions[loop_to_data_idx[finished]] = -1

                loop_is_selecting_tools = loop_is_selecting_tools[~finished]
                loop_selection_depth = loop_selection_depth[~finished]
                loop_to_data_idx = loop_to_data_idx[~finished]
                can_be_stopped = can_be_stopped[~finished]
                batch_indices = batch_indices[:-finished.sum().item()]


        output = {
            "user_prompts": user_prompts,
            "generated_content": [torch.cat([content, new_content]) for content, new_content in zip(generated_content, new_generated_content)],
            "tool_history": tool_history,
            "status": positions,
            "sampled_args": [arg[:-1] for arg in new_generated_content],
        }
        if not arg_selection_mode:
            del output["sampled_args"]

        return output

    


    def forward(self, 
                sentences: List[str],):

        # We receive a batch of texts. 
        logging.info("FORWARD TOOLMASTER")
        logging.info(f"Received batch of {len(sentences)} sentences")

        device = self.device

        # We tokenize the texts and store then in tuples with (tokenized_sentence, pos, count generation, tool_history)
        pending_completion = [(longTensor(self.encode(user_prompt)).to(device), longTensor([]).to(device), []) for user_prompt in sentences]
        finished_sentences = []
        ids = [i for i in range(len(sentences))]

        while len(pending_completion) > 0:

            ####################################################
            # FREE GENERATION MODE AUGMENTED WITH TOOL SELECTION
            ####################################################
            logging.info("STARTING FREE GENERATION MODE AUGMENTED WITH TOOL SELECTION")
            i = 0
            batch_size = 11
            pending_arg_sampling = []
            pending_count = len(pending_completion)

            while pending_count > 0:
                logging.debug(f"Processing batch {i+1}. Sentences processed: {len(pending_completion)-pending_count}/{len(pending_completion)}   ({(len(pending_completion)-pending_count)/len(pending_completion)*100:.2f}%))")
                start_idx = max(0, pending_count-batch_size)
                sentence_batch = pending_completion[start_idx:pending_count]

                try:
                    user_prompts, generated_content, tool_history = zip(*sentence_batch)
                    output_dict = self.generate(user_prompts = [prompt for prompt in user_prompts],
                                                generated_content=list(generated_content),
                                                explanation_prompts = self.tokenized_free_generation_prompt,
                                                tool_history=list(tool_history),
                                                max_new_tokens = self.max_new_tokens,
                                                arg_selection_mode = False,
                                                temperature=self.temperature,
                                                sub_indices=self.free_gen_sub_idx)            
                except torch.cuda.OutOfMemoryError as e: # type: ignore
                    batch_size-=5
                    sentence_batch = sentence_batch[5:]
                    logging.info(f"Out of memory error. Reducing batch size to {batch_size}")
                    continue
                
                pending_count -= len(sentence_batch)
                finished_count = 0
                tools_called = [0 for _ in range(len(self.tools))]
                for i, (user_prompt, generated_content, tool_history, status) in renumerate(list(zip(*output_dict.values()))):
                    if status == -1:
                        finished_sentences.append({"id":ids.pop(i), "user_prompt": self.decode(user_prompt.cpu()), "response":self.decode(generated_content.cpu()), "tool history":tool_history})
                        finished_count += 1
                    elif len(tool_history) > 0:
                        print(f"Tool use: {tool_history[-1]}")
                        tools_called[tool_history[-1]["id"]] += 1
                        pending_arg_sampling.append((user_prompt, generated_content, tool_history,))

                logging.info(f"Batch {i+1} processed. Finished sentences: {finished_count}/{len(sentence_batch)}, rest use tools.")
                logging.info(f"Tools were called the following number of times:")
                for tool_name, tool_count in zip(self.tool_names, tools_called):
                    logging.info(f"{tool_name}: {tool_count}")
                i+=1


            ####################################################
            # ARGUMENT GENERATION MODE
            ####################################################

            logging.info("STARTING ARGUMENT GENERATION MODE")
            batch_size = 11
            pending_completion = []
            pending_tool_execution = []
            total_pending_args = len(pending_arg_sampling)
            pending_count = total_pending_args
            logging.info(f"Pending: {pending_arg_sampling}")

            while pending_count > 0:
                logging.debug(f"Processing batch {i+1}. Sentences processed: {len(pending_arg_sampling)-pending_count}/{len(pending_arg_sampling)}   ({(len(pending_arg_sampling)-pending_count)/len(pending_arg_sampling)*100:.2f}%))")
                
                print(pending_arg_sampling)
                start_idx = max(0, pending_count-batch_size)
                sentence_batch = pending_arg_sampling[start_idx:pending_count]
                try:
                    user_prompt, generated_content, tool_histories = zip(*sentence_batch)

                    explanation_prompts = []
                    sub_idx = []
                    for hist in tool_histories:
                        tool_id = hist[-1]["id"]
                        explanation_prompts.append(self.tool_explanation_prompts[tool_id])
                        sub_idx.append(self.tool_explan_sub_indices[tool_id])

                    print("ARG GEN PROMPTS")
                    print(explanation_prompts)
                    output_dict = self.generate(user_prompts = list(user_prompt),
                                                generated_content=list(generated_content),
                                                explanation_prompts = explanation_prompts,
                                                tool_history=list(tool_histories),
                                                max_new_tokens = self.max_new_tokens,
                                                arg_selection_mode = True,
                                                stop_tokens=self.arg_gen_stoppers,
                                                temperature=self.temperature,
                                                sub_indices=sub_idx)
                except torch.cuda.OutOfMemoryError as e: # type: ignore
                    batch_size-=5
                    sentence_batch = sentence_batch[5:]
                    logging.info(f"Out of memory error. Reducing batch size to {batch_size}")
                    continue
                
                pending_count -= len(sentence_batch)
                finished_count = 0
                for i, (user_prompt, generated_content, tool_history, status, sampled_args) in enumerate(zip(*output_dict.values())):
                    if status == -1:
                        pending_completion.append((user_prompt, generated_content, tool_history))
                        finished_count += 1
                    else:
                        # TOOL SELECTION BABY
                        pending_tool_execution.append((user_prompt, generated_content, tool_history, sampled_args))

            ####################################################
            # TOOL EXECUTION
            ####################################################

            logging.info("STARTING TOOL EXECUTION")
            if not self.export_tool_execution:
                for i, (user_prompt, generated_content, tool_history, sampled_args) in renumerate(pending_tool_execution):
                    tool_id = tool_history[-1]["id"]
                    try:
                        parsed_args = self.arg_parsers[tool_id](self.decode(sampled_args))

                        logging.info(f"Executing tool {self.tool_names[tool_id]} with args {parsed_args}")
                        tool_output = self.tools[tool_id](*parsed_args)
                        tool_history[-1]["status"] = "Success"
                        tool_history[-1]["args"] = self.decode(sampled_args)
                        tool_history[-1]["parsed args"] = parsed_args
                        tool_history[-1]["output"] = tool_output

                    except Exception as e:
                        logging.warning(f"Error executing tool {self.tool_names[tool_id]} with args {parsed_args}")
                        # Print stack trace
                        logging.warning(traceback.format_exc())
                        logging.warning(f"Error: {e}")
                        tool_output = e
                        tool_history[-1]["status"] = "Error executing tool"

                        # Remove bad call from sentence
                        # sentence = sentence[:-sampled_args.shape[0]]
 
                    tool_output = self.encode(")→ " + str(tool_output), truncation=True, max_length=self.max_response_len)
                    tool_output = self.encode(self.decode(tool_output) + END_TOOL_TOKEN, return_tensors="pt")[0].to(device).long()
                    generated_content = torch.cat((generated_content[:-1], tool_output))

                    if generated_content.shape[0] < self.max_new_tokens:
                        pending_completion.append((user_prompt, generated_content, tool_history))
                    else:                        
                        finished_sentences.append({"id":ids.pop(i), "user_prompt":self.decode(user_prompt.cpu()), "response":self.decode(generated_content.cpu()), "tool_history":tool_history})

        finished_sentences.sort(key=lambda x: x["id"])

        if self.export_tool_execution:
            return finished_sentences, pending_tool_execution

        
        return finished_sentences
    

In [21]:

def print_dict(d, indent=0):
    for key, value in d.items():
        if isinstance(value, dict):
            print('  ' * indent + str(key))
            print_dict(value, indent+1)
        else:
            print('  ' * indent + str(key) + ": " + str(value))


a = {"a":{1:2, 3:4}, "b":{5:6, 7:8}, "c":2}

print_dict(a)

a
  1: 2
  3: 4
b
  5: 6
  7: 8
c: 2


In [7]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config



cache_dir = "/vol/bitbucket/jg2619/augmenting_llms/augmented_data_pipeline/toolformer/cache"
model = GPT2LMHeadModel.from_pretrained("gpt2", cache_dir=cache_dir)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", cache_dir=cache_dir)


tokenizer.pad_token = "!"

model.resize_token_embeddings(len(tokenizer))

Embedding(50257, 768)

In [18]:
# of the form {"name": str, "arg_parser": Callable, "tool": Callable, "explanation_prompt": Union[torch.Tensor, str], "short_desc": Optional[str]}

tool_specs =  [{
    "name": "pops", "arg_parser": lambda x: [x], "tool": lambda x: x, "explanation_prompt": "Use the pops tool to repeat a piece of text.", "short_desc": "(repeat a piece of text)"
}, {
    "name": "and she", "arg_parser": lambda x: [x], "tool": lambda x: x, "explanation_prompt": "tool"
}, {
    "name": "and he", "arg_parser": lambda x: [x*2], "tool": lambda x: x, "explanation_prompt": "lama"}]


toolmaster = ToolMaster(model,
                        tool_specs = tool_specs,
                        tokenizer = tokenizer,
                        tool_token_ids=tokenizer.encode("["),
                        log_dir = "/vol/bitbucket/jg2619/augmenting_llms/model_training/models/logs",
                        export_tool_execution = False,
                        max_new_tokens = 40,
                        debug_level=1,
                        catch_answers=True,
                        answer_token_ids=tokenizer.encode("Sentence"),
                        post_answer_token_ids=tokenizer.encode(":"),
                        )

#input = Tensor([tokenizer.encode(x) for x in ["This is a sentence", "This is another sentence"]]).long()

#print(input)
#model(input[Tensor([[0],[1]])])

for output in toolmaster(["Sentence 1: The man talked on and on about", "Sentence 2: This is another sentence but"]):
    print("HALLELUYAH")
    print(output["response"])

Tool token ids device: cpu
Logging to /vol/bitbucket/jg2619/augmenting_llms/model_training/models/logs/1.log
2023-08-11 17:36:26,273:  FORWARD TOOLMASTER
2023-08-11 17:36:26,274:  Received batch of 2 sentences
2023-08-11 17:36:26,278:  STARTING FREE GENERATION MODE AUGMENTED WITH TOOL SELECTION


Can be stopped: tensor([False, False])
Sampled:  how,  I
Can be stopped: tensor([False, False])
Sampled:  he,  don
Can be stopped: tensor([False, False])
Sampled:  was, 't
Can be stopped: tensor([False, False])
Sampled:  going,  intend
Can be stopped: tensor([False, False])
Sampled:  to,  to
Can be stopped: tensor([False, False])
Sampled:  be,  surprise
Can be stopped: tensor([False, False])
Sampled:  able,  you
Can be stopped: tensor([False, False])
Sampled:  to, .
Can be stopped: tensor([False, False])
Sampled:  keep,  I
Can be stopped: tensor([False, False])
Sampled:  him,  am
Can be stopped: tensor([False, False])
Sampled:  in,  not
Can be stopped: tensor([False, False])
Sampled:  his,  sure
Can be stopped: tensor([False, False])
Sampled:  situation,  what
Can be stopped: tensor([False, False])
Sampled:  and,  you
Can be stopped: tensor([False, False])
Sampled:  make, 're
Can be stopped: tensor([False, False])
Sampled:  a,  talking
Can be stopped: tensor([False, False])
Sampled:  l

In [7]:


def balance_parentheses(string):
    # While there exists "()", remove all occasions of "()" from the string
    while re.search(r'\(\)', string):
        string = re.sub(r'\(\)', '', string)

    # Find the starting substring that contains only "(" and ")":
    closing_at_start_count = length_of_match(r'^\)+', string)
    if closing_at_start_count > 0:
        string = string[closing_at_start_count:]

    opening_at_start_count = length_of_match(r'^\(+', string)

    opening_at_end_count = length_of_match(r'\(+$', string)
    if opening_at_end_count > 0:
        string = string[:-opening_at_end_count]
    closing_at_end_count = length_of_match(r'\)+$', string)

    remove = min(opening_at_start_count, closing_at_end_count)
    string = string[remove:-remove] if remove > 0 else string

    opening_count = string.count('(')
    closing_count = string.count(')')

    diff = opening_count - closing_count
    if diff > 0:
        string = string + ')' * diff
    else:
        string = '(' * (-diff) + string

    i = 0
    compensate = 0
    for char in string:
        if char == '(':
            i += 1
        elif char == ')':
            i -= 1
        if i < 0:
            compensate = min(i, compensate)

    string = '(' * abs(compensate) + string + ')' * abs(compensate)

    return string


@beartype
def calc_parse(args: str):
    multi_x_pattern = r'(?<=\d)x(?=\d)|(?<=\d)x(?!\d)|(?<!\d)x(?=\d)'
    args = re.sub(" ", "", args)
    args = re.sub(multi_x_pattern, "*", args)
    args = re.sub("÷", "/", args)
    args = re.sub("−", "-",args)
    args = re.sub("×", "*", args)
    args = re.sub(r'[^0-9+\-*/().]', '', args)

    args = balance_parentheses(args)

    return args



@beartype
def Calculator(input_query: str, extraargs=None, first=True, detail=False):
    operators = {
        '+': add,
        '-': sub,
        '*': mul,
        '/': truediv
    }
    if first:
        # calc_preprocess_args(input_query) FOR INFERENCE
        # Strip whitespace
        input_query = input_query.replace(" ", "")
        input_query = balance_parentheses(input_query)

    # Handle expressions within parentheses
    while '(' in input_query:
        start = input_query.rindex('(')
        end = start + input_query[start:].index(')')
        sub_expr = input_query[start + 1:end]
        result, op_performed = Calculator(sub_expr, first=False, detail=True)
        first = not op_performed if first else first
        input_query = input_query[:start] + str(result) + input_query[end + 1:]

    try:
        number = float(input_query)
        if first:
            raise Exception("Useless API call. 1 digit is not a calculation.")
        else:
            return (number, False) if detail else number
    except ValueError as e:
        pass

    for c in operators.keys():
        left, operator, right = input_query.partition(c)
        if len(operator) > 0:
            answer = round(operators[operator](Calculator(left, first=False), Calculator(right, first=False)), 2)
            answer = str(answer) if first else answer
            return (answer, True) if detail else answer

cache_dir = "/vol/bitbucket/jg2619/augmenting_llms/augmented_data_pipeline/toolformer/cache"

calculator_explanation = """You can use the Calculator tool to get information required to complete the text. You can call the API by writing "[Calculator(expression)]" where "expression" is the expression to be computed. Here are some examples of its usage:

Example 1: Last year we collected 237342 apples, double of what we collected this year: [Calculator(237342/2)→ 118671] 118671.

Example 2: The number in the next term is 18 + 12 x 3 = [Calculator(18+(12*3))→ 54] 54.

Example 3: A total of 252 matches were played, and 723 goals were scored (an average of [Calculator(723/252)→ 2.87] 2.87 per match). This is twenty goals more than the [Calculator(723-20)→703] 703 goals last year.

Example 4: I went to Paris in 1994 and stayed there until 2011, so in total, it was [Calculator(2011-1994)→ 17] 17 years.


With this in mind, complete the text below:\n\n"""

free_generation_prompt = """You are a question answering model that can use external tools to answer questions. You can call the API by writing "[ToolName(arguments)]" where "ToolName" is the name of the tool and "arguments" are the arguments to be passed to the tool. Here are some examples of its usage:

Example 1: Last year we collected 237342 apples, double of what we collected this year: [Calculator(237342/2)→ 118671] 118671.

Example 2: The number in the next term is 18 + 12 x 3 = [Calculator(18+(12*3))→ 54] 54.

Example 3: A total of 252 matches were played, and 723 goals were scored (an average of [Calculator(723/252)→ 2.87] 2.87 per match). This is twenty goals more than the [Calculator(723-20)→703] 703 goals last year.

Example 4: I went to Paris in 1994 and stayed there until 2011, so in total, it was [Calculator(2011-1994)→ 17] 17 years.


You can use the following tools: Calculator (can add, subtract, multiply and divide), Calendar (returns the current date), and WikiSearch (can search Wikipedia). Now, answer the following question. Lets think step by step. When you find the answer, write "Answer(<your_answer>)". For example, if the answer is 42, write "Answer(42)".

Question: """

In [9]:


tool_specs = [{
    "name": "Calculator",
    "arg_parser": lambda x: [calc_parse(x)],
    "tool": Calculator,
    "explanation_prompt": calculator_explanation,
    "short_description": "can add, subtract, multiply and divide"
}]

tool_token_ids = []
for key, value in tokenizer.get_vocab().items():
    if "[" in key:
        tool_token_ids.append(value)



# Load the Toolmaster model
toolmaster = ToolMaster(model, 
                        tokenizer = tokenizer,
                        tool_specs = tool_specs, 
                        tool_token_ids=tool_token_ids,
                        max_new_tokens=100,
                        free_generation_prompt=free_generation_prompt,
                        log_dir="/vol/bitbucket/jg2619/augmenting_llms/benchmarks/logs",
                        )

Tool token ids device: cpu
Logging to /vol/bitbucket/jg2619/augmenting_llms/benchmarks/logs/14.log


In [12]:
import json
# Function that loads and returns the GMS8K dataset
def load_gms8k_easy():
    with open("/vol/bitbucket/jg2619/augmenting_llms/benchmarks/ToolQA/data/questions/easy/gsm8k-easy.jsonl", "r") as f:
        data = [json.loads(line) for line in f.readlines()]
    return data

data = load_gms8k_easy()

In [14]:
questions = [d["question"] for d in data]

toolmaster(questions[:2])

2023-08-04 00:56:18,178:  FORWARD TOOLMASTER
2023-08-04 00:56:18,180:  Received batch of 2 sentences
2023-08-04 00:56:18,185:  STARTING FREE GENERATION MODE AUGMENTED WITH TOOL SELECTION
Pending: [(tensor([15309,    11,  7299,    11,   290,  1215,   417, 18829,    68,   389,
         2111,   284,  4929,   606,   477,    11, 14878,   326,   318,    13,
          220, 17083,   484,   423,  4978, 42489, 14878,    13,   220,  4422,
          468,  4978,   642,   517,   621,  7299,    11,   290,  7299,   468,
         4978,  1511,  1342,   621,   604,  1661,   355,   867,   355,  1215,
          417, 18829,    68,   468,  4978,    13,  1374,   867, 14878,   468,
         7299,  4978,    30]), tensor([], dtype=torch.int64), []), (tensor([ 9527,  4703,   468,   257,  9651,  8914,  2496,   286,   720,    16,
        11623,    13,   554,  3035,    11,   339,  3382,   284,  3613,  5403,
          355,   881,  4445,   287,   262,  1218,  2063,   355,   339, 16031,
          287,   262,   717,  20

[{'original_prime': tensor([15309,    11,  7299,    11,   290,  1215,   417, 18829,    68,   389,
           2111,   284,  4929,   606,   477,    11, 14878,   326,   318,    13,
            220, 17083,   484,   423,  4978, 42489, 14878,    13,   220,  4422,
            468,  4978,   642,   517,   621,  7299,    11,   290,  7299,   468,
           4978,  1511,  1342,   621,   604,  1661,   355,   867,   355,  1215,
            417, 18829,    68,   468,  4978,    13,  1374,   867, 14878,   468,
           7299,  4978,    30]),
  'response': tensor([  220, 23998,    25,   220,   807,   329,  4422,   290,   604,   329,
           7299,    13,  1867,   318,   262,  2811,  1271,   286, 14878,  7299,
            468,  4978,    30,   220, 23998,    25,   220,   513,    13,  3324,
             13,   198,   198,  7583,    11,  1309,   338,   910,  1215,   417,
          18829,    68,   468,  4978,   257,  1218, 10441,   287,   428,   983,
             13,   220,  1215,   417, 18829,    68,   468

In [23]:
a = torch.tensor([ 1639,   460,   779,   777,  4899,   284,  1037,   345,  3280,    25,           290,    11, 26384,    13,   628, 31837,   594,   352,    25,   383,           582,  6619,   319,   290,   319,   546,   261])

output = model(a.unsqueeze(0), use_cache=False).logits

print("HELLO", flush = True)
# Top 5 logits per word:
for i, word in enumerate(output[0]):
    top5 = torch.topk(word, 15)
    print(f"{i}th pos. Word is {tokenizer.decode([a[i]])}")
    print(f"Predictions: `{'`, `'.join([tokenizer.decode(idx) for idx in top5.indices])}`")




for word in a:
    print(f"'{tokenizer.decode(word)}'")

HELLO
0th pos. Word is You
Predictions: ` can`, `'re`, `'ll`, ` are`, ` know`, ` have`, ` want`, `'ve`, ` don`, ` need`, ` will`, ` do`, ` may`, ` must`, ` could`
1th pos. Word is  can
Predictions: ` also`, ` find`, ` see`, `'t`, ` use`, ` read`, ` download`, ` get`, ` check`, ` always`, ` do`, ` follow`, ` tell`, ` buy`, ` make`
2th pos. Word is  use
Predictions: ` the`, ` this`, ` it`, ` a`, ` any`, ` your`, ` these`, ` them`, ` our`, ` an`, ` either`, ` my`, ` that`, ` one`, ` all`
3th pos. Word is  these
Predictions: ` to`, ` tools`, ` methods`, ` as`, ` techniques`, ` links`, ` functions`, ` instructions`, ` resources`, ` options`, ` tips`, ` in`, ` files`, ` commands`, ` for`
4th pos. Word is  tools
Predictions: ` to`, ` in`, ` for`, ` on`, ` and`, ` with`, ` as`, ` at`, `,`, ` if`, ` when`, `:`, ` without`, ` from`, ` or`
5th pos. Word is  to
Predictions: ` help`, ` find`, ` create`, ` make`, ` get`, `:`, ` check`, ` build`, ` identify`, ` quickly`, ` learn`, ` improve`, ` deter

In [109]:
tokenizer.encode("pops")

[79, 2840]

In [26]:
Tensor((2,))

tensor([2])

In [76]:
Tensor([tokenizer.encode(sentence) for sentence in ["The cat sat on the", "The cow jumped over the"]]).long()

Tensor([[1639,  460,  779,  777, 4899,  284, 1037,  345, 3280,   25, 2891,   16,
           11, 2891,   17,   13,  628, 1212,  318,  257, 6827],
        [1639,  460,  779,  777, 4899,  284, 1037,  345, 3280,   25, 2891,   16,
           11, 2891,   17,   13,  628, 1212,  318, 1194, 6827]])


tensor([[  464,  3797,  3332,   319,   262],
        [  464,  9875, 11687,   625,   262]])

In [116]:
batch_inputs = Tensor([tokenizer.encode(sentence) for sentence in ["The cat sat on the", "The cow jumped over the"]]).long()

logits = model(batch_inputs).logits
print(logits.shape)

tokenizer.decode(gumbel_sample(model(batch_inputs).logits[:,4], temperature=0))

torch.Size([2, 5, 50257])


' floor fence'