In [1]:
%load_ext nb_black

<IPython.core.display.Javascript object>

In [2]:
from concurrent.futures import ThreadPoolExecutor
from typing import (
    Dict,
    List,
    Set,
    Union,
    Any,
    Optional,
    ClassVar,
    AsyncGenerator,
    Literal,
)
from abc import ABC, abstractmethod
import regex as re
from transformers import PreTrainedModel, PreTrainedTokenizer
from tiktoken import Encoding
import numpy as np
from dataclasses import dataclass, field
import warnings
import asyncio
import aiohttp
import torch
from enum import Enum
from itertools import chain
import regex as re
from typing import List, Dict
from uuid import UUID, uuid4
import openai
import tiktoken


# Sequence of token ids, as returned by `Model.encode` and accepted by `Model.decode`.
TokenIds = List[int]
# Vocabulary table mapping a token id to that token's decoded text.
Tokens = Dict[int, str]
# NOTE(review): not referenced by the code visible in this notebook; presumably
# maps a token's text to a count or score -- confirm or remove.
TokenDistribution = Dict[str, int]
# Set of token ids that sampling is restricted to.
SelectedTokens = Set[int]
# Result of `Constraint.constrain_tokens`: None (no restriction), a set of
# allowed token ids, or a str holding the finished completion.
TokenConstraint = Union[None, SelectedTokens, str]


def transformers_tokens(tokenizer: PreTrainedTokenizer) -> Tokens:
    """Build the id -> text lookup table for a Hugging Face tokenizer.

    Args:
        tokenizer: Tokenizer whose vocabulary should be enumerated.

    Returns:
        Tokens: Every token id in ``tokenizer.get_vocab()`` mapped to the text
        produced by decoding that single id.
    """
    id_to_text = {}
    for token_id in tokenizer.get_vocab().values():
        id_to_text[token_id] = tokenizer.decode(token_id)
    return id_to_text


def openai_tokens(tokenizer: "Encoding") -> "Tokens":
    """Build the id -> text lookup table for a tiktoken encoding.

    Ordinary tokens (ids ``0 .. len(token_byte_values()) - 1``) are decoded to
    their text. Every remaining id up to and including
    ``tokenizer.max_token_value`` gets a ``<|special_{id}|>`` placeholder.

    Args:
        tokenizer: The tiktoken encoding to enumerate.

    Returns:
        Tokens: Mapping that covers every id from 0 through
        ``tokenizer.max_token_value`` inclusive.
    """
    vocab_len = len(tokenizer.token_byte_values())
    # range(vocab_len) -- not vocab_len - 1 -- so the last ordinary token is kept.
    tokens = {i: tokenizer.decode([i]) for i in range(vocab_len)}
    # max_token_value is itself a valid id, hence the inclusive upper bound.
    for i in range(vocab_len, tokenizer.max_token_value + 1):
        tokens[i] = f"<|special_{i}|>"
    return tokens

<IPython.core.display.Javascript object>

In [32]:
!pip freeze | grep reg

regex==2023.5.5


<IPython.core.display.Javascript object>

In [None]:
tiktoken==0.4.0
transformers==4.30.2
regex==2023.5.5

In [3]:
class DecodingStrategy(str, Enum):
    """How the next token is picked from the model's output distribution.

    The ``str`` mixin makes each member compare equal to its plain string value.
    """

    # Deterministic decoding; OpenAIChat approximates it with temperature 0.0 and top_p 0.01.
    GREEDY = "GREEDY"
    # Stochastic decoding; Huggingface.generate passes ``do_sample=True`` for this strategy.
    SAMPLE = "SAMPLE"


@dataclass
class Decoder:
    """Decoding parameters handed to ``Model.generate`` / ``Model.sample``."""

    # Sampling temperature forwarded to the backing model.
    temperature: float = 0.7
    # Nucleus-sampling threshold forwarded to the backing model.
    top_p: float = 0.95
    # Token selection strategy; models reject strategies they do not support.
    strategy: DecodingStrategy = DecodingStrategy.GREEDY

<IPython.core.display.Javascript object>

In [4]:
class Model(ABC):
    """Abstract interface to a token-level language model.

    Concrete subclasses provide the vocabulary table and the
    ``generate`` / ``encode`` / ``decode`` primitives.

    Attributes:
        tokens: Mapping from token id to that token's decoded text.
        supported_decodings: Decoding strategies the model accepts.
        max_total_tokens: Upper bound on prompt tokens plus generated tokens.
    """

    tokens: Tokens
    supported_decodings: ClassVar[Set[DecodingStrategy]]
    max_total_tokens: int = 512

    @abstractmethod
    async def generate(
        self,
        text: str,
        max_tokens: int = 1,
        selected_tokens: Optional[Set[int]] = None,
        decoder: Optional[Decoder] = None,
        timeout: float = 10.0,
    ) -> AsyncGenerator[str, None]:
        """
        Generate text from the language model.

        Args:
            text: The text to generate from.
            max_tokens: The maximum number of tokens to generate.
            selected_tokens: Token ids the generation should be restricted to.
                ``None`` leaves the whole vocabulary available.
            decoder: A parameterized description of how to select tokens from the distribution
            timeout: The timeout for the generation process.

        Returns:
            An async iterator of generated text pieces.
        """

    async def sample(
        self,
        text: str,
        selected_tokens: Optional[SelectedTokens] = None,
        decoder: Optional[Decoder] = None,
        timeout: float = 10.0,
    ) -> str:
        """Sample a single token from the model, optionally constrained.

        Args:
            text (str): The input text to the language model.
            selected_tokens (Optional[Set[int]]): The set of token ids to constrain the sampling. Defaults to None.
            decoder (Optional[Decoder]): Decoding parameters forwarded to ``generate``.
            timeout (float): Timeout forwarded to ``generate``.

        Returns:
            str: The first piece of text yielded by ``generate``.
        """
        stream = self.generate(
            text=text,
            max_tokens=1,
            selected_tokens=selected_tokens,
            decoder=decoder,
            timeout=timeout,
        )
        try:
            return await anext(stream)
        finally:
            # Only the first item is consumed; close the generator explicitly so
            # its cleanup runs now instead of whenever it is garbage collected.
            aclose = getattr(stream, "aclose", None)
            if aclose is not None:
                await aclose()

    @abstractmethod
    def encode(self, text: str) -> TokenIds:
        """Encode the input text as token ids.

        Args:
            text (str): The input text to encode.

        Returns:
            TokenIds: The encoded token ids.
        """

    @abstractmethod
    def decode(self, ids: TokenIds) -> str:
        """Decode the token ids into text.

        Args:
            ids (TokenIds): The token ids to decode.

        Returns:
            str: The decoded text.
        """

    @property
    def vocab_size(self) -> int:
        """Get the vocabulary size of the language model."""
        return len(self.tokens)

    @property
    @abstractmethod
    def eos_token_id(self) -> int:
        """Get the token id of the end of sequence (eos) token."""

    @property
    @abstractmethod
    def bos_token_id(self) -> int:
        """Get the token id of the beginning of sequence (bos) token."""

<IPython.core.display.Javascript object>

In [5]:
@dataclass
class Huggingface(Model):
    """``Model`` backed by a local Hugging Face causal language model.

    Pass either ``model_name`` (model and tokenizer are then loaded with
    ``from_pretrained``) or both an already-loaded ``model`` and ``tokenizer``.

    Attributes:
        model_name: Checkpoint name to load when no model/tokenizer is supplied.
        model: The causal language model used for generation.
        tokenizer: The tokenizer matching ``model``.
        chunk_size: Maximum number of new tokens requested per ``model.generate`` call.
        supported_decodings: Decoding strategies accepted by ``generate``.
    """

    model_name: Optional[str] = None
    model: Optional[PreTrainedModel] = None
    tokenizer: Optional[PreTrainedTokenizer] = None
    chunk_size: int = 64
    supported_decodings: Set[DecodingStrategy] = frozenset(
        (
            DecodingStrategy.GREEDY,
            DecodingStrategy.SAMPLE,
        )
    )

    def __post_init__(self):
        """Validate the arguments, load missing components and build the vocab table.

        Raises:
            ValueError: If ``chunk_size`` is not positive, or if neither
                ``model_name`` nor a complete ``model``/``tokenizer`` pair is given.
        """
        # Cheap validation first, so a bad argument never triggers a model download.
        if self.chunk_size < 1:
            raise ValueError(f"`chunk_size` must be positive, got {self.chunk_size}.")
        has_model = self.model is not None
        has_tokenizer = self.tokenizer is not None
        if has_model != has_tokenizer or (self.model_name is None and not has_model):
            raise ValueError(
                "must specify either `model_name` or both `model` and `tokenizer`"
            )
        if not has_model:
            # Imported here so the class does not depend on a later notebook cell.
            from transformers import AutoModelForCausalLM, AutoTokenizer

            self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.tokens = transformers_tokens(self.tokenizer)
        self._completion_buffer = {}

    def encode(self, text: str) -> TokenIds:
        """Encode ``text`` with the wrapped tokenizer."""
        return self.tokenizer.encode(text)

    def decode(self, ids: TokenIds) -> str:
        """Decode ``ids`` with the wrapped tokenizer."""
        return self.tokenizer.decode(ids)

    @property
    def eos_token_id(self) -> int:
        """Token id of the tokenizer's end-of-sequence token."""
        return self.tokenizer.eos_token_id

    @property
    def bos_token_id(self) -> int:
        """Token id of the tokenizer's beginning-of-sequence token."""
        return self.tokenizer.bos_token_id

    def _logit_processor(self, selected_tokens: Optional[SelectedTokens] = None):
        """Build the logits processors that restrict sampling to ``selected_tokens``.

        Args:
            selected_tokens: Token ids that remain available. ``None`` means no
                restriction.

        Returns:
            list: Empty when unrestricted, otherwise one callable
            ``(input_ids, scores) -> scores`` that pushes the score of every
            non-selected token down by a large constant.
        """
        logits_processor = []
        if selected_tokens is not None:
            allowed_ids = sorted(selected_tokens)

            def _logits_processor(input_ids, scores):
                # Build the mask with torch on the scores' own device/dtype; a
                # numpy mask cannot be created from or added to non-CPU tensors.
                # The fill is clamped so low-precision dtypes do not overflow.
                fill_value = max(-1e10, torch.finfo(scores.dtype).min)
                mask = torch.full_like(scores, fill_value)
                index = torch.tensor(
                    allowed_ids, dtype=torch.long, device=scores.device
                )
                mask[:, index] = 0
                return scores + mask

            logits_processor.append(_logits_processor)
        return logits_processor

    async def generate(
        self,
        text: str,
        max_tokens: int = 1,
        selected_tokens: Optional[SelectedTokens] = None,
        decoder: Optional[Decoder] = None,
        timeout: float = 10.0,
    ) -> AsyncGenerator[str, None]:
        """Generate up to ``max_tokens`` tokens, yielding text chunk by chunk.

        The blocking ``model.generate`` call runs in a worker thread, at most
        ``chunk_size`` new tokens at a time. Generation stops early once the
        tokenizer's end-of-sequence token has been produced.

        Args:
            text: The prompt to continue.
            max_tokens: Maximum number of tokens to generate in total.
            selected_tokens: Token ids generation is restricted to; ``None`` for all.
            decoder: Decoding parameters; defaults to ``Decoder()``.
            timeout: Accepted for interface compatibility; not used here.

        Yields:
            str: The decoded text of each generated chunk (special tokens removed).

        Raises:
            ValueError: If ``decoder.strategy`` is not in ``supported_decodings``.
        """
        decoder = decoder or Decoder()
        if decoder.strategy not in self.supported_decodings:
            raise ValueError(
                f"Unsupported decoding strategy for Huggingface model `{decoder.strategy}`."
            )

        gen_kwargs = dict(temperature=decoder.temperature, top_p=decoder.top_p)
        if decoder.strategy == DecodingStrategy.SAMPLE:
            gen_kwargs["do_sample"] = True

        eos_id = self.tokenizer.eos_token_id
        n_gen = 0
        prompt_token_ids = self.tokenizer.encode(text)
        while n_gen < max_tokens:
            max_new_tokens = min(self.chunk_size, max_tokens - n_gen)
            output = await asyncio.to_thread(
                self.model.generate,
                input_ids=torch.tensor(prompt_token_ids)
                .unsqueeze(0)
                .to(self.model.device),
                max_new_tokens=max_new_tokens,
                logits_processor=self._logit_processor(selected_tokens),
                pad_token_id=eos_id,
                **gen_kwargs,
            )
            new_token_ids = output[0, len(prompt_token_ids) :].detach().cpu().tolist()
            # The next chunk continues from everything generated so far.
            prompt_token_ids += new_token_ids
            n_gen += max_new_tokens
            yield self.tokenizer.decode(new_token_ids, skip_special_tokens=True)
            if eos_id is not None and eos_id in new_token_ids:
                # The model ended the sequence; do not keep generating past it.
                break

<IPython.core.display.Javascript object>

In [6]:
class Constraint(ABC):
    """Base class for rules that limit which tokens may be sampled next.

    Constraints compose with ``&`` (all must allow a token) and ``|`` (any may
    allow a token).
    """

    @abstractmethod
    def constrain_tokens(
        self, base_text: str, completion_text: str, model: "Model"
    ) -> TokenConstraint:
        """Decide which token ids may follow the completion produced so far.

        Args:
            base_text (str): The prompt the completion is being appended to.
            completion_text (str): The completion generated up to this point.
            model (Model): The language model whose vocabulary is being constrained.

        Returns:
            None: When the full vocabulary may be used.
            set: The token ids that are valid next tokens.
            str: When the constraint is satisfied; the string is the finished
                completion and may differ from ``completion_text``.
        """

    def __and__(self, other):
        """``self & other``: an ``AndConstraint`` over both operands."""
        return AndConstraint(constraints=[self, other])

    def __or__(self, other):
        """``self | other``: an ``OrConstraint`` over both operands."""
        return OrConstraint(constraints=[self, other])


@dataclass
class NotConstraint(Constraint):
    """Negation of another constraint's token selection.

    Attributes:
        constraint (Constraint): The constraint whose selection is inverted.
    """

    constraint: Constraint

    def constrain_tokens(
        self, base_text: str, completion_text: str, model: "Model"
    ) -> TokenConstraint:
        """Return every vocabulary id the wrapped constraint did not select.

        ``None`` and finished-completion strings from the wrapped constraint
        are passed through unchanged.
        """
        inner = self.constraint.constrain_tokens(base_text, completion_text, model)
        if not (inner is None or isinstance(inner, str)):
            return set(model.tokens).difference(inner)
        return inner


@dataclass
class AndConstraint(Constraint):
    """Constrain token ids that can be sampled by applying multiple constraints.

    A token is allowed only if every constraint that returns a token set
    allows it.

    Attributes:
        constraints (List[Constraint]): The list of constraints to apply.
    """

    constraints: List[Constraint]

    def constrain_tokens(
        self, base_text: str, completion_text: str, model: "Model"
    ) -> Union[None, Set[int], str]:
        """Combine the results of all constraints.

        Returns:
            str: When every constraint reports the same finished completion.
            set: The intersection of all token sets returned by the constraints.
            None: When no constraint restricts the vocabulary (including when
                the constraint list is empty).

        Raises:
            ValueError: If all constraints are finished but disagree on the
                completion text.
        """
        allowed = None
        # Collected across ALL constraints; it must not be reset per iteration,
        # otherwise the "everything finished" check below can never succeed for
        # more than one constraint.
        completions = []
        for constraint in self.constraints:
            selected_tokens = constraint.constrain_tokens(
                base_text, completion_text, model
            )
            if isinstance(selected_tokens, str):
                completions.append(selected_tokens)
            elif isinstance(selected_tokens, set):
                allowed = (
                    selected_tokens if allowed is None else allowed & selected_tokens
                )
            # A None result means "all tokens valid" and restricts nothing.
        if self.constraints and len(completions) == len(self.constraints):
            if len(set(completions)) != 1:
                raise ValueError(
                    f"Got different completions for constraints `{self}`. Completions: `{set(completions)}`"
                )
            return completions[0]
        return allowed


@dataclass
class OrConstraint(Constraint):
    """Allow any token that at least one of several constraints allows.

    Attributes:
        constraints (List[Constraint]): The alternatives to combine.
    """

    constraints: List[Constraint]

    def constrain_tokens(
        self, base_text: str, completion_text: str, model: "Model"
    ) -> Union[None, Set[int], str]:
        """Union the token sets of all alternatives.

        The first alternative that returns ``None`` (no restriction) or a
        finished completion string ends the evaluation with that value.
        """
        union = set()
        for alternative in self.constraints:
            outcome = alternative.constrain_tokens(base_text, completion_text, model)
            if outcome is None or isinstance(outcome, str):
                # Unrestricted, or already complete: nothing left to combine.
                return outcome
            if isinstance(outcome, set):
                union.update(outcome)
        return union


@dataclass
class RegexConstraint(Constraint):
    """Constrain token ids that can be sampled based on a regex pattern.

    Attributes:
        pattern (str): The regex pattern to match.

    Notes:
        Based on https://github.com/r2d4/rellm
    """

    pattern: str

    def __post_init__(self):
        """Compile ``pattern`` once so it can be reused for every token check."""
        self._pattern = re.compile(self.pattern)

    def _is_valid_token(
        self, token_id: int, partial_completion: str, model: "Model"
    ) -> bool:
        """Tell whether appending the token keeps a full match of the pattern possible.

        Uses partial matching: the candidate text may be an incomplete prefix
        of a string the pattern fully matches.
        """
        candidate = partial_completion + model.tokens[token_id]
        return self._pattern.fullmatch(candidate, partial=True) is not None

    def constrain_tokens(
        self, base_text: str, completion_text: str, model: "Model"
    ) -> TokenConstraint:
        """Return the finished completion or the token ids that can extend it.

        Returns:
            str: ``completion_text`` once the pattern matches at its start.
            set: Otherwise, every token id that still allows a full match.
        """
        if self._pattern.match(completion_text):
            return completion_text

        # The check runs serially over the vocabulary; the thread pool that used
        # to wrap this code was never given any work.
        return {
            token_id
            for token_id in model.tokens
            if self._is_valid_token(token_id, completion_text, model)
        }


@dataclass
class StopsConstraint(Constraint):
    """Constrain generation so that it ends once a stop string is reached.

    Implemented by delegating to a ``RegexConstraint`` built from ``stop``.

    Attributes:
        stop (str): The string after which to stop.
        include (bool): Whether to include the stop string in the completion or not.

    Notes:
        NOTE(review): ``stop`` is interpolated into the regex as-is, so regex
        metacharacters in it keep their special meaning. Confirm whether it
        should be escaped to be treated literally.
    """

    stop: str
    include: bool = True

    def __post_init__(self):
        """Build the delegate regex constraint from ``stop`` and ``include``."""
        end = self.stop
        if not self.include:
            # A lookahead requires the stop string without consuming it.
            end = f"(?={self.stop})"
        self._re_constraint = RegexConstraint(".*?" + end)

    def constrain_tokens(
        self, base_text: str, completion_text: str, model: Model
    ) -> TokenConstraint:
        """Delegate to the underlying ``RegexConstraint``."""
        return self._re_constraint.constrain_tokens(base_text, completion_text, model)


@dataclass
class OptionsConstraint(Constraint):
    """
    Options constraint constrains output based on a list of string options

    Attributes:
        options (Set[str]): The strings the completion may end up being.
        short_circuit (bool): When true, finish early as soon as the text
            generated so far is a prefix of at most one option.
    """

    options: Set[str]
    short_circuit: bool = (
        True  # early return when available options based on completed text are <=1
    )

    def _is_valid_token(
        self, token_id: int, partial_completion: str, model: Model
    ) -> bool:
        """Tell whether appending the token still yields a prefix of some option."""
        decoded_token = model.tokens[token_id]
        return any(
            option.startswith(partial_completion + decoded_token)
            for option in self.options
        )

    def constrain_tokens(
        self, base_text: str, completion_text: str, model: Model
    ) -> TokenConstraint:
        """Return the finished option or the token ids that can still lead to one.

        Returns:
            str: ``completion_text`` when it already equals an option, or the
                single remaining option when short-circuiting.
            set: The valid next token ids; empty when no option can be reached.
        """
        if completion_text in self.options:
            return completion_text

        if completion_text and self.short_circuit:
            limited_options = set()
            for option in self.options:
                if option.startswith(completion_text):
                    limited_options.add(option)
                    if len(limited_options) > 1:
                        # Two candidates are enough to know we cannot short-circuit.
                        break
            if len(limited_options) == 0:
                # An empty *set* (``{}`` would be a dict) so callers that check
                # for an empty token mask recognise it.
                return set()
            if len(limited_options) == 1:
                return limited_options.pop()

        # The check runs serially over the vocabulary; the thread pool that used
        # to wrap this code was never given any work.
        return {
            token_id
            for token_id in model.tokens
            if self._is_valid_token(token_id, completion_text, model)
        }

<IPython.core.display.Javascript object>

In [7]:
class Completion(str):
    """Generated text that remembers where it sits inside its prompt.

    Args:
        text (str): the generated string
        start (int): index in the prompt at which the completion begins
        stop (int): index in the prompt at which the completion ends

    Returns:
        Completion (str)
    """

    def __new__(cls, text: str, start: int, stop: int):
        if not isinstance(text, Completion):
            instance = super().__new__(cls, text)
            instance.start, instance.stop = start, stop
            return instance
        # An existing Completion is handed back untouched, offsets included.
        return text

    def __repr__(self) -> str:
        return f"Completion(text = '{self}', start = {self.start}, stop = {self.stop})"

class Completions:
    """Container for the completions attached to a prompt.

    Unnamed completions are kept in insertion order and accessed by integer
    index. Named completions are accessed by key or as attributes.
    """

    def __init__(self):
        self._completions = []
        self._named_completions = {}

    def __repr__(self) -> str:
        return f"Completions({self._completions}, {self._named_completions})"

    def add(self, completion, name=None):
        """Store ``completion``, under ``name`` if given, else positionally."""
        if name is not None:
            self._named_completions[name] = completion
        else:
            self._completions.append(completion)

    def __getitem__(self, key):
        """Return the unnamed completion at an int index, else the named one."""
        if isinstance(key, int):
            return self._completions[key]
        else:
            return self._named_completions[key]

    def __getattr__(self, name):
        """Expose named completions as attributes.

        Raises:
            AttributeError: If no completion with that name exists.
        """
        # Read through __dict__: __getattr__ also runs on instances whose
        # __init__ has not executed (copy, unpickling), where an attribute
        # lookup of _named_completions would recurse back into here forever.
        named = self.__dict__.get("_named_completions")
        if named is not None and name in named:
            return named[name]
        raise AttributeError(f"'Completions' object has no attribute '{name}'")

    def __or__(self, other):
        """Merge two containers into a new one; ``other`` wins on name clashes.

        Raises:
            TypeError: If ``other`` is not a ``Completions``.
        """
        if isinstance(other, Completions):
            combined = Completions()
            combined._completions = self._completions + other._completions
            combined._named_completions = {
                **self._named_completions,
                **other._named_completions,
            }
            return combined
        else:
            raise TypeError(
                f"unsupported operand type(s) for |: 'Completions' and '{type(other).__name__}'"
            )

class Prompt(str):
    """A Prompt is a piece of text a model can generate off of

    Args:
        prompt (str): the string representing the current completion of the Prompt
        completions (Optional[Completions]): completions already attached to
            the text; a fresh empty container is used when omitted.

    Returns:
        Prompt (str)
    """

    def __new__(cls, prompt: str, completions: Optional[Completions] = None):
        # NOTE(review): a Completion argument is returned as-is, i.e. the result
        # is not a Prompt. This mirrors Completion.__new__ and may have been
        # meant to test for Prompt -- confirm the intent.
        if isinstance(prompt, Completion):
            return prompt
        obj = str.__new__(cls, prompt)
        obj.prompt = prompt
        obj.completions = completions if completions is not None else Completions()
        return obj

    def __repr__(self):
        return f"Prompt('{self.prompt}')"

    def __str__(self):
        return self.prompt

    def __add__(self, other):
        """Concatenate text; a Prompt operand also contributes its completions.

        Raises:
            TypeError: If ``other`` is not a string.
        """
        # Prompt must be tested before str: every Prompt is also a str, so the
        # reverse order would never reach this branch and drop its completions.
        if isinstance(other, Prompt):
            return Prompt(
                self.prompt + other.prompt, self.completions | other.completions
            )
        if isinstance(other, str):
            return Prompt(self.prompt + other, self.completions)
        raise TypeError(
            f"Cannot concatenate Prompt object with object of type {type(other)}"
        )

    def __radd__(self, other):
        """Concatenate with ``other`` on the left-hand side.

        Raises:
            TypeError: If ``other`` is not a string.
        """
        if isinstance(other, Prompt):
            return Prompt(
                other.prompt + self.prompt, other.completions | self.completions
            )
        if isinstance(other, str):
            return Prompt(other + self.prompt, self.completions)
        raise TypeError(
            f"Cannot concatenate object of type {type(other)} with Prompt object"
        )

    def token_length(self, model: Model) -> int:
        """Number of tokens ``model`` encodes this prompt into."""
        return len(model.encode(self.prompt))

    async def complete(
        self,
        model: Model,
        constraint: Optional[Constraint] = None,
        name: Optional[str] = None,
        max_tokens: Optional[int] = None,
        decoder: Optional[Decoder] = None,
        stream_queue: Optional[asyncio.Queue] = None,
        timeout: float = 10.0,
        truncate: bool = False,
    ):
        """Extend the prompt with text generated by ``model``.

        Args:
            model: The model to generate with.
            constraint: Optional constraint applied token by token. Without one
                the model generates freely.
            name: Name under which the completion is recorded in the result's
                ``completions``; recorded positionally when ``None``.
            max_tokens: Maximum number of tokens to generate. ``None`` means
                "as many as fit in ``model.max_total_tokens``".
            decoder: Decoding parameters forwarded to the model.
            stream_queue: If given, each generated piece is put on the queue as
                it is produced, followed by a final ``None``.
            timeout: Timeout forwarded to the model.
            truncate: If true and ``max_tokens`` is set, drop tokens from the
                start of the prompt sent to the model so that ``max_tokens``
                tokens still fit. The returned prompt keeps the full text.

        Returns:
            Prompt: A new prompt consisting of this prompt plus the completion,
            with the completion recorded in its ``completions``.
        """
        text = self.prompt
        prompt_tokens = model.encode(self.prompt)

        if (
            truncate
            and max_tokens is not None
            and len(prompt_tokens) + max_tokens > model.max_total_tokens
        ):
            # Keep only as many trailing prompt tokens as leave room for max_tokens.
            keep = max(model.max_total_tokens - max_tokens, 0)
            warnings.warn(
                f"Prompt plus `max_tokens` more than model `max_total_tokens` of {model.max_total_tokens}. "
                f"Keeping only the last {keep} prompt tokens."
            )
            prompt_tokens = prompt_tokens[len(prompt_tokens) - keep :]
            text = model.decode(prompt_tokens)

        remaining = max(model.max_total_tokens - len(prompt_tokens), 0)
        token_limit = remaining if max_tokens is None else min(max_tokens, remaining)

        if max_tokens is not None and max_tokens > token_limit:
            warnings.warn(
                f"Requested `max_tokens` of {max_tokens} "
                f"greater than remaining token limit {token_limit} "
                f"from model {str(model)[:10]}...) which has "
                f"`max_total_tokens` {model.max_total_tokens}. "
                f"will limit `max_tokens` to {token_limit}."
            )

        completion_text = ""
        if constraint is None:
            if token_limit > 0:
                # Pass the effective limit: `max_tokens` may be None or exceed
                # what still fits into the model.
                async for tok in model.generate(
                    text, max_tokens=token_limit, decoder=decoder, timeout=timeout
                ):
                    if stream_queue is not None:
                        await stream_queue.put(tok)
                    completion_text += tok
        else:
            token_count = 0
            while token_count < token_limit:
                selected_token_ids = constraint.constrain_tokens(
                    text, completion_text, model
                )
                if isinstance(selected_token_ids, str):
                    # The constraint is satisfied and dictates the final text.
                    completion_text = selected_token_ids
                    break
                if selected_token_ids is not None:
                    if len(selected_token_ids) == 0:
                        warnings.warn(
                            f"Empty token mask encountered with Constraint `{constraint}`. Ending completion."
                        )
                        break
                    if len(selected_token_ids) >= model.vocab_size:
                        # Every token is allowed, so no mask is needed.
                        selected_token_ids = None
                generation = await model.sample(
                    text + completion_text,
                    selected_tokens=selected_token_ids,
                    decoder=decoder,
                    timeout=timeout,
                )
                generated_ids = model.encode(generation)
                # An empty generation (e.g. a stripped special token) or an
                # end-of-sequence token ends the completion.
                if not generated_ids or generated_ids[-1] == model.eos_token_id:
                    break
                if stream_queue is not None:
                    await stream_queue.put(generation)
                completion_text += generation
                token_count += 1

        if stream_queue is not None:
            await stream_queue.put(None)

        # Give the result its own container so that recording the new
        # completion does not also modify this prompt's completions.
        ret = Prompt(self.prompt + completion_text, self.completions | Completions())
        ret.completions.add(
            Completion(
                completion_text,
                len(self.prompt),
                len(self.prompt) + len(completion_text),
            ),
            name,
        )

        return ret

<IPython.core.display.Javascript object>

In [10]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the pretrained "gpt2" checkpoint and its tokenizer; the cells that follow
# wrap them in a `Huggingface` model.
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

<IPython.core.display.Javascript object>

In [89]:
# Wrap the loaded model/tokenizer; chunk_size=3 requests at most three new
# tokens per underlying `model.generate` call.
hf = Huggingface(model=model, tokenizer=tokenizer, chunk_size=3)

# Left disabled: manual checks of streaming generation and single-token sampling.
# async for t in hf.generate("Hello,", 10):
#     print(t)

# await hf.sample("Hello,")

<IPython.core.display.Javascript object>

In [124]:
prompt = Prompt("There are (3-2) = ")

<IPython.core.display.Javascript object>

In [125]:
# set("789") yields the single-character options {"7", "8", "9"}; NotConstraint
# inverts the token set that OptionsConstraint selects for them.
constraint = NotConstraint(OptionsConstraint(set("789")))

<IPython.core.display.Javascript object>

In [126]:
constraint

NotConstraint(constraint=OptionsConstraint(options={'7', '9', '8'}, short_circuit=True))

<IPython.core.display.Javascript object>

In [127]:
# Constrained completion of at most 10 tokens, sampled with a low temperature
# and a narrow nucleus; the completion is recorded under the name "num_piggies".
prompt = await prompt.complete(
    model=hf,
    constraint=constraint,
    name="num_piggies",
    max_tokens=10,
    decoder=Decoder(temperature=0.5, top_p=0.1, strategy=DecodingStrategy.SAMPLE),
)

<IPython.core.display.Javascript object>

In [128]:
prompt

Prompt('There are (3-2) =   1-1 =   1-1 =')

<IPython.core.display.Javascript object>

In [17]:
# Regex-constrained completion: the generated text must be one of the
# alternatives "one", "15" or "three"; it is recorded as "num_piggies".
prompt = await Prompt("There are (3-2) = ").complete(
    model=hf,
    constraint=RegexConstraint(pattern=r"one|15|three"),
    name="num_piggies",
)

<IPython.core.display.Javascript object>

In [18]:
prompt + " piggys"

Prompt('There are (3-2) = 15 piggys')

<IPython.core.display.Javascript object>

In [21]:
int(prompt.completions.num_piggies)

15

<IPython.core.display.Javascript object>

In [10]:



def split_tags(
    text: str,
    tag_start: str = "%",
    tag_end: str = "%",
    default_role: str = "assistant",
    roles: List[str] = ["system", "user", "assistant"],
) -> List[Dict[str, str]]:
    """
    Splits a text string into a list of messages based on tags.

    A tagged message looks like ``%role%content%/role%``. Text outside of any
    tag pair becomes a message with ``default_role``. Message content is
    stripped of surrounding whitespace, and whitespace-only untagged text is
    dropped.

    Args:
        text (str): The input text to split into messages.
        tag_start (str, optional): The start delimiter for tags. Defaults to '%'.
        tag_end (str, optional): The end delimiter for tags. Defaults to '%'.
        default_role (str, optional): The default role to use for untagged messages. Defaults to 'assistant'.
        roles (List[str], optional): The list of valid roles for tagged messages. Defaults to ['system', 'user', 'assistant'].

    Returns:
        List[Dict[str, str]]: A list of messages, where each message is a dictionary with keys 'role' and 'content'.

    Raises:
        Exception: If an end tag is found with no start tag, if an unknown role
            is used, or if a tag is left unclosed or closed by a different tag.

    Examples:
        >>> split_tags('%system%You are a bot%/system%%user%Hi%/user%Hello!')
        [{'role': 'system', 'content': 'You are a bot'}, {'role': 'user', 'content': 'Hi'}, {'role': 'assistant', 'content': 'Hello!'}]
    """
    text = str(text)
    messages = []
    # The delimiters are literal text, so escape them before building the regex.
    tag_re = re.compile(
        rf"(?P<tag>{re.escape(tag_start)}/?(?P<role>.*?){re.escape(tag_end)}\s*)"
    )
    closing_prefix = f"{tag_start}/"

    while text:
        opening = tag_re.search(text)
        if opening is None:
            # No tag left: everything that remains is untagged text.
            if text.strip():
                messages.append({"role": default_role, "content": text.strip()})
            break
        if opening.group("tag").startswith(closing_prefix):
            raise Exception(f"Found end tag with no start `{opening.group('tag')}`.")
        role = opening.group("role")
        if role not in roles:
            raise Exception(f"Unknown role `{role}`.")
        # Slice the text up to the tag so that multi-line untagged text is
        # kept in full.
        untagged = text[: opening.start("tag")].strip()
        if untagged:
            messages.append({"role": default_role, "content": untagged})
        text = text[opening.end("tag") :]
        if not text:
            break

        # The opening tag must be followed by its own closing tag.
        closing = tag_re.search(text)
        if closing is None:
            raise Exception(f"Unclosed tag `{role}`. Found end of text.")
        if closing.group("role") != role or not closing.group("tag").startswith(
            closing_prefix
        ):
            raise Exception(f"Unclosed tag `{role}`. Found `{closing.group('role')}`.")
        messages.append(
            {"role": role, "content": text[: closing.start("tag")].strip()}
        )
        text = text[closing.end("tag") :]
    return messages


def strip_tags(
    prompt: Prompt,
    tag_start: str = "%",
    tag_end: str = "%",
    roles_seps: Dict[str, str] = {
        "system": "",
        "user": "User: ",
        "assistant": "Assistant: ",
    },
    sep: str = "\n",
) -> Prompt:
    """Replace role tags in a prompt with plain-text role prefixes.

    Args:
        prompt: The tagged prompt to convert.
        tag_start: The start delimiter for tags.
        tag_end: The end delimiter for tags.
        roles_seps: Prefix to put in front of each role's message; its keys are
            also the roles accepted while parsing.
        sep: Separator placed between the rendered messages.

    Returns:
        Prompt: The rendered text, carrying over the original prompt's completions.
    """
    known_roles = list(roles_seps.keys())
    parsed = split_tags(
        text=prompt, tag_start=tag_start, tag_end=tag_end, roles=known_roles
    )
    rendered = [roles_seps[entry["role"]] + entry["content"] for entry in parsed]
    return Prompt(sep.join(rendered), prompt.completions)

<IPython.core.display.Javascript object>

In [17]:


@dataclass
class OpenAIChat(Model):
    """``Model`` backed by the OpenAI chat completion API.

    The prompt text is split into chat messages using role tags (see
    ``split_tags``) before it is sent to the API.

    Attributes:
        model_name: Name of the OpenAI chat model.
        supported_decodings: Decoding strategies accepted by ``generate``.
        role_tag_start: Start delimiter of role tags in the prompt text.
        role_tag_end: End delimiter of role tags in the prompt text.
        default_role: Role given to untagged text.
        allowed_roles: Roles accepted in role tags.
        max_retries: Maximum number of attempts per ``generate`` call.
        retry_sleep_time: Seconds to wait between attempts.
        max_token_selection: Maximum number of token ids sent as ``logit_bias``.
    """

    model_name: str = "gpt-3.5-turbo"
    supported_decodings: Set[DecodingStrategy] = frozenset(
        (
            DecodingStrategy.GREEDY,
            DecodingStrategy.SAMPLE,
        )
    )
    role_tag_start: str = "%"
    role_tag_end: str = "%"
    default_role: str = "assistant"
    allowed_roles: Set[str] = field(
        default_factory=lambda: {"system", "user", "assistant"}
    )
    max_retries: int = 10
    retry_sleep_time: float = 1.0
    max_token_selection: int = 300

    def __post_init__(self):
        """Look up the tiktoken encoding for the model and build the vocab table."""
        self._tokenizer = tiktoken.encoding_for_model(self.model_name)
        self.tokens = openai_tokens(self._tokenizer)

    def encode(self, text: str) -> TokenIds:
        """Encode ``text`` with the model's tiktoken encoding."""
        return self._tokenizer.encode(text)

    def decode(self, ids: TokenIds) -> str:
        """Decode ``ids`` with the model's tiktoken encoding."""
        return self._tokenizer.decode(ids)

    @property
    def eos_token_id(self) -> int:
        """Token id of the encoding's end-of-text token."""
        return self._tokenizer.eot_token

    @property
    def bos_token_id(self) -> int:
        """Token id of the ``<|endofprompt|>`` special token."""
        return self._tokenizer.encode_single_token("<|endofprompt|>")

    async def _generate(
        self,
        text: str,
        max_tokens: int = 1,
        selected_tokens: Optional[SelectedTokens] = None,
        decoder: Optional[Decoder] = None,
        timeout: float = 10.0,
    ) -> AsyncGenerator[str, None]:
        """Issue one streaming chat completion request and yield its raw chunks.

        Args:
            text: Prompt text containing optional role tags.
            max_tokens: Maximum number of tokens to request.
            selected_tokens: Token ids to favour through ``logit_bias``.
            decoder: Decoding parameters; defaults to ``Decoder()``.
            timeout: Accepted for interface compatibility; not used here.

        Yields:
            The response chunks as returned by the OpenAI client.

        Raises:
            ValueError: If ``decoder.strategy`` is not in ``supported_decodings``.
        """
        decoder = decoder or Decoder()
        if decoder.strategy not in self.supported_decodings:
            raise ValueError(
                f"Unsupported decoding strategy for {self.__class__} model `{decoder.strategy}`."
            )
        messages = split_tags(
            text,
            self.role_tag_start,
            self.role_tag_end,
            self.default_role,
            self.allowed_roles,
        )

        temperature = decoder.temperature
        top_p = decoder.top_p
        if decoder.strategy == DecodingStrategy.GREEDY:
            # try to make the sampling as deterministic as possible
            # to select only the one top token
            top_p = 0.01  # select only n tokens to get over .01, should virtually always be a single token
            temperature = 0.0

        selected_tokens = selected_tokens or []
        payload = {
            "messages": messages,
            # A bias of 100 makes the selected tokens effectively the only candidates.
            "logit_bias": {str(token): 100 for token in selected_tokens},
            "model": self.model_name,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }

        async with aiohttp.ClientSession() as session:
            openai.aiosession.set(session)
            completion_stream = await openai.ChatCompletion.acreate(
                **payload, stream=True
            )
            async for chat_completion in completion_stream:
                yield chat_completion

    async def generate(
        self,
        text: str,
        max_tokens: int = 1,
        selected_tokens: Optional[SelectedTokens] = None,
        decoder: Optional[Decoder] = None,
        timeout: float = 10.0,
    ) -> AsyncGenerator[str, None]:
        """Stream generated text, retrying when the API reports an error.

        Args:
            text: Prompt text containing optional role tags.
            max_tokens: Maximum number of tokens to generate.
            selected_tokens: Token ids to restrict generation to; ``None`` for
                no restriction. Only the ``max_token_selection`` lowest ids are
                used when more are given.
            decoder: Decoding parameters; defaults to ``Decoder()``.
            timeout: Forwarded to ``_generate``.

        Yields:
            str: Each non-empty piece of generated content.
        """
        if (
            selected_tokens is not None
            and len(selected_tokens) > self.max_token_selection
        ):
            warnings.warn(
                f"Trying to mask {len(selected_tokens)} tokens which "
                f"is more than {self.max_token_selection} mask limit "
                f"of {self}. Consider stricter constraints. Will select "
                "lowest token ids up to this limit."
            )
            # Sets are unordered, so sort to really keep the lowest ids.
            selected_tokens = sorted(selected_tokens)[: self.max_token_selection]

        def result_handler(response):
            """Extract ``(content, finished)`` from one streamed chunk."""
            choice = response["choices"][0]
            delta = choice.get("delta") or {}
            # The API reports `finish_reason` on the choice; the delta is still
            # consulted as a fallback, as before.
            finish_reason = choice.get("finish_reason")
            if finish_reason is None:
                finish_reason = delta.get("finish_reason")
            return delta.get("content") or "", finish_reason is not None

        for attempt in range(1, self.max_retries + 1):
            error = False
            async for chat_completion in self._generate(
                text=text,
                max_tokens=max_tokens,
                selected_tokens=selected_tokens,
                decoder=decoder,
                timeout=timeout,
            ):
                if "error" in chat_completion.keys():
                    message = chat_completion["error"]["message"]
                    will_retry = attempt < self.max_retries
                    warnings.warn(
                        "OpenAI Chat Completion API raised an error: \n"
                        f"MESSAGE: {message}\n"
                        + (f"RETRYING {attempt}" if will_retry else "")
                    )
                    error = True
                    break
                content, done = result_handler(chat_completion)
                # A retry continues from the text received so far.
                text += content
                if content:
                    yield content
                if done:
                    break
            if not error:
                break
            if attempt < self.max_retries:
                await asyncio.sleep(self.retry_sleep_time)

@dataclass
class OpenAICompletion(Model):
    """Model backed by the OpenAI text Completion endpoint, consumed as a stream.

    ``generate`` is the public entry point: it caps the token mask, streams
    chunks through ``_generate`` and restarts the request when a streamed
    chunk carries an ``"error"`` entry.
    """

    model_name: str = "text-ada-001"
    # Kept as a regular dataclass field so the generated ``__init__`` signature
    # stays exactly as it was.
    supported_decodings: Set[DecodingStrategy] = frozenset(
        (
            DecodingStrategy.GREEDY,
            DecodingStrategy.SAMPLE,
        )
    )
    # The role settings below are not read by any method of this class.
    # NOTE(review): presumably consumed by prompt-handling code elsewhere — confirm.
    role_tag_start: str = "%"
    role_tag_end: str = "%"
    default_role: str = "assistant"
    allowed_roles: Set[str] = field(
        default_factory=lambda: {"system", "user", "assistant"}
    )
    max_retries: int = 10  # maximum number of request attempts per generate() call
    retry_sleep_time: float = 1.0  # seconds to wait between two attempts
    max_token_selection: int = 300  # cap on the number of logit_bias entries sent

    def __post_init__(self):
        """Load the tiktoken encoding for ``model_name`` and build the token table."""
        self._tokenizer = tiktoken.encoding_for_model(self.model_name)
        self.tokens = openai_tokens(self._tokenizer)

    def encode(self, text: str) -> TokenIds:
        """Return the token ids the tokenizer produces for ``text``."""
        return self._tokenizer.encode(text)

    def decode(self, ids: TokenIds) -> str:
        """Return the text the tokenizer decodes from ``ids``."""
        return self._tokenizer.decode(ids)

    @property
    def eos_token_id(self) -> int:
        """The tokenizer's ``eot_token`` id."""
        return self._tokenizer.eot_token

    @property
    def bos_token_id(self) -> int:
        """Id the tokenizer assigns to the single token ``<|endofprompt|>``.

        NOTE(review): not every tiktoken encoding may define this special
        token — confirm it resolves for the configured ``model_name``.
        """
        return self._tokenizer.encode_single_token("<|endofprompt|>")

    async def _generate(
        self,
        text: str,
        max_tokens: int = 1,
        selected_tokens: Optional[SelectedTokens] = None,
        decoder: Optional[Decoder] = None,
        timeout: float = 10.0,
    ) -> AsyncGenerator[Any, None]:
        """Open one streaming Completion request and yield its raw chunks.

        Args:
            text: prompt sent to the endpoint.
            max_tokens: forwarded as ``max_tokens``.
            selected_tokens: token ids that receive a ``logit_bias`` of 100;
                ``None`` or empty sends no bias.
            decoder: decoding settings; defaults to ``Decoder()``.
            timeout: accepted for interface compatibility; it is not forwarded
                to the API call by this method.

        Yields:
            The chunk objects produced by ``openai.Completion.acreate``
            (not plain strings).

        Raises:
            ValueError: if ``decoder.strategy`` is not in ``supported_decodings``.
        """
        decoder = decoder or Decoder()
        if decoder.strategy not in self.supported_decodings:
            raise ValueError(
                f"Unsupported decoding strategy for {self.__class__} model `{decoder.strategy}`."
            )

        temperature = decoder.temperature
        top_p = decoder.top_p
        if decoder.strategy == DecodingStrategy.GREEDY:
            # Make sampling as deterministic as possible: a tiny nucleus
            # (virtually always a single token) combined with zero temperature.
            top_p = 0.01
            temperature = 0.0

        selected_tokens = selected_tokens or []
        payload = {
            "prompt": text,
            "logit_bias": {str(token): 100 for token in selected_tokens},
            "model": self.model_name,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }

        async with aiohttp.ClientSession() as session:
            openai.aiosession.set(session)
            completion_stream = await openai.Completion.acreate(**payload, stream=True)

            async for completion in completion_stream:
                yield completion

    @staticmethod
    def _parse_chunk(completion):
        """Return ``(text, finished)`` extracted from one streamed chunk."""
        choice = completion["choices"][0]
        content = choice["text"] if "text" in choice else ""
        finished = "finish_reason" in choice and choice["finish_reason"] is not None
        return content, finished

    async def generate(
        self,
        text: str,
        max_tokens: int = 1,
        selected_tokens: Optional[SelectedTokens] = None,
        decoder: Optional[Decoder] = None,
        timeout: float = 10.0,
    ) -> AsyncGenerator[str, None]:
        """Stream generated text for ``text``, retrying on streamed API errors.

        Args:
            text: prompt to continue. Text already yielded is appended to it,
                so a retry resumes after the content produced so far.
            max_tokens: passed through to ``_generate``.
            selected_tokens: optional token ids to bias towards. When more than
                ``max_token_selection`` are given, a warning is issued and only
                the lowest ids up to that limit are kept.
            decoder: decoding settings passed through to ``_generate``.
            timeout: passed through to ``_generate``.

        Yields:
            Each non-empty text fragment as it arrives.

        Notes:
            At most ``max_retries`` attempts are made, with
            ``retry_sleep_time`` seconds between attempts. If every attempt
            fails the generator simply ends after the last warning.
        """
        # ``selected_tokens`` defaults to None, so guard before measuring it.
        if (
            selected_tokens is not None
            and len(selected_tokens) > self.max_token_selection
        ):
            warnings.warn(
                f"Trying to mask {len(selected_tokens)} tokens which "
                f"is more than {self.max_token_selection} mask limit "
                f"of {self}. Consider stricter constraints. Will select "
                "lowest token ids up to this limit."
            )
            # Sort first: slicing an unordered set would keep arbitrary ids.
            selected_tokens = sorted(selected_tokens)[: self.max_token_selection]

        for attempt in range(1, self.max_retries + 1):
            failed = False
            will_retry = attempt < self.max_retries
            stream = self._generate(
                text=text,
                max_tokens=max_tokens,
                selected_tokens=selected_tokens,
                decoder=decoder,
                timeout=timeout,
            )
            try:
                async for completion in stream:
                    if "error" in completion:
                        message = completion["error"]["message"]
                        warning = (
                            "OpenAI Completion API raised an error: \n"
                            f"MESSAGE: {message}\n"
                        )
                        if will_retry:
                            warning += f"RETRYING {attempt}"
                        warnings.warn(warning)
                        failed = True
                        break

                    content, done = self._parse_chunk(completion)
                    text += content
                    if content:
                        yield content
                    if done:
                        break
            finally:
                # Leaving the loop early does not close the inner generator;
                # close it explicitly so its HTTP session is released now.
                await stream.aclose()

            if not failed:
                return
            if will_retry:
                await asyncio.sleep(self.retry_sleep_time)

<IPython.core.display.Javascript object>

In [25]:
import os

# Read the key from the environment instead of committing it to the notebook.
# The key that used to be hard-coded here must be treated as leaked and revoked.
openai.api_key = os.environ["OPENAI_API_KEY"]

<IPython.core.display.Javascript object>

In [26]:
# Instantiate the completion-endpoint wrapper with its default settings
# (alternative kept for reference: `oai = OpenAIChat()`).
oai = OpenAICompletion()

<IPython.core.display.Javascript object>

In [28]:
# Have the model finish the sentence under a regex constraint that only admits
# an optionally negative integer; the completion is named "num_piggies".
prompt = await Prompt(
    "We need to count some pigs. I have 48 and slaughtered 23, so now I have "
).complete(
    model=oai,
    constraint=RegexConstraint(pattern=r"-?[0-9]+"),
    name="num_piggies",
)



> [0;32m/var/folders/vl/0mv20zzj0ld26z8h0ngg3wn00000gn/T/ipykernel_95361/911037778.py[0m(76)[0;36m_generate[0;34m()[0m
[0;32m     74 [0;31m[0;34m[0m[0m
[0m[0;32m     75 [0;31m            [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 76 [0;31m            [0;32masync[0m [0;32mfor[0m [0mcompletion[0m [0;32min[0m [0mcompletion_stream[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     77 [0;31m                [0;32myield[0m [0mcompletion[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     78 [0;31m[0;34m[0m[0m
[0m
ipdb> n
> [0;32m/var/folders/vl/0mv20zzj0ld26z8h0ngg3wn00000gn/T/ipykernel_95361/911037778.py[0m(77)[0;36m_generate[0;34m()[0m
[0;32m     75 [0;31m            [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     76 [0;31m            [0;32masync[0m [0;32mfor[0m [0mcompletion[0m [0;32min[0m [0mcompletion_stream[0m[0;34



> [0;32m/var/folders/vl/0mv20zzj0ld26z8h0ngg3wn00000gn/T/ipykernel_95361/911037778.py[0m(76)[0;36m_generate[0;34m()[0m
[0;32m     74 [0;31m[0;34m[0m[0m
[0m[0;32m     75 [0;31m            [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 76 [0;31m            [0;32masync[0m [0;32mfor[0m [0mcompletion[0m [0;32min[0m [0mcompletion_stream[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     77 [0;31m                [0;32myield[0m [0mcompletion[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     78 [0;31m[0;34m[0m[0m
[0m
ipdb> n
> [0;32m/var/folders/vl/0mv20zzj0ld26z8h0ngg3wn00000gn/T/ipykernel_95361/911037778.py[0m(77)[0;36m_generate[0;34m()[0m
[0;32m     75 [0;31m            [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     76 [0;31m            [0;32masync[0m [0;32mfor[0m [0mcompletion[0m [0;32min[0m [0mcompletion_stream[0m[0;34

<IPython.core.display.Javascript object>

In [138]:
isinstance(prompt, str)

True

<IPython.core.display.Javascript object>

In [144]:
prompt += "%user%yoyoyoy%/user%"

<IPython.core.display.Javascript object>

In [121]:
%pdb

Automatic pdb calling has been turned OFF


<IPython.core.display.Javascript object>

In [29]:
prompt

Prompt('We need to count some pigs. I have 48 and slaughtered 23, so now I have -36')

<IPython.core.display.Javascript object>

In [154]:
%pdb

Automatic pdb calling has been turned ON


<IPython.core.display.Javascript object>

In [162]:
strip_tags(prompt)

Prompt('Assistant: There are (3-2) =   1-1 =   1-1 =
User: yoyoyoy')

<IPython.core.display.Javascript object>

In [132]:
type(strip_tags(prompt))

<IPython.core.display.Javascript object>

In [77]:
prompt.completions

Completions([], {'num_piggies': Completion(text = '25', start = 72, stop = 74)})

<IPython.core.display.Javascript object>

In [78]:
int(prompt.completions.num_piggies)

25

<IPython.core.display.Javascript object>

In [62]:
# Stream a continuation of "Hello" (second argument is `max_tokens`) and print
# each fragment as soon as the generator yields it.
async for tok in oai.generate("Hello", 15):
    print(tok)

Hi
 there
!
 How
 can
 I
 assist
 you
 today
?


<IPython.core.display.Javascript object>

In [None]:
# NOTE(review): scratch snippet that was never executed (`In [None]`). It reads
# `self`, `messages` and `selected_tokens`, so it appears to have been lifted
# out of a method body and is unlikely to run as a stand-alone cell — confirm.
# Builds a chat request that biases every selected token id by 100 and awaits
# a single ChatCompletion response (no `stream` argument is passed).
payload = {
    "messages": messages,
    "logit_bias": {str(token): 100 for token in selected_tokens},
    "model": self.model_name,
    "max_tokens": self.chunksize,
    "temperature": self.temperature,
}
chat_completion = await openai.ChatCompletion.acreate(**payload)

In [35]:
oai = OpenAIChat(chunksize=1, validate_completion_buffer=True)

<IPython.core.display.Javascript object>

In [36]:
oai

OpenAIChat(model_name='gpt-3.5-turbo', chunksize=1, temperature=0.0, role_tag_start='%', role_tag_end='%', default_role='assistant', allowed_roles={'user', 'system', 'assistant'}, use_completion_buffer=True, validate_completion_buffer=True, clear_used_validated_buffer=True)

<IPython.core.display.Javascript object>

In [37]:
prompt = """
%system%
You are a friendly bot
%/system%
%user%
what day did napolean die?
%/user%
"""

<IPython.core.display.Javascript object>

In [38]:
completion = await oai.sample(prompt)
prompt += completion

prompt

'\n%system%\nYou are a friendly bot\n%/system%\n%user%\nwhat day did napolean die?\n%/user%\nN'

<IPython.core.display.Javascript object>

In [39]:
prompt = Prompt("There are (3-2) = ")

<IPython.core.display.Javascript object>

In [40]:
completion = await prompt.complete(oai, RegexConstraint("[0-9]"), "npigs")

<IPython.core.display.Javascript object>

In [54]:
completion

Prompt('There are (3-2) = 1')

<IPython.core.display.Javascript object>

In [41]:
completion.completions

{'npigs': Completion(text='1', start=18, stop=19)}

<IPython.core.display.Javascript object>

In [42]:
# Lark grammar for a JSON-like language: objects, arrays, strings and the
# literals true / false / null. Surrounding whitespace is stripped from the text.
# NOTE(review): `common.SIGNED_NUMBER` is imported but no rule references it, so
# numbers are not accepted as a `value` — confirm whether that is intentional.
grammar = """
?start: value

    ?value: object

          | array

          | string

          | "true"             -> true

          | "false"            -> false

          | "null"             -> null

    array  : "[" [value ("," value)*] "]"

    object : "{" [pair ("," pair)*] "}"

    pair   : string ":" value

    string : ESCAPED_STRING

    %import common.ESCAPED_STRING

    %import common.SIGNED_NUMBER

    %import common.WS

    %ignore WS
""".strip()

<IPython.core.display.Javascript object>

In [43]:
from lark import Lark, UnexpectedInput

<IPython.core.display.Javascript object>

In [44]:
parser = Lark(grammar)

<IPython.core.display.Javascript object>

In [None]:
def _extract_terminal_regex(parser, model: "Model"):
    regex_map = {}
    for term in parser.terminals:#type: ignore
        if term.pattern:
            regex_map[term.name] = re.compile(term.pattern.to_regexp())
    return regex_map

In [61]:
def next_lex(input_str: str, parser: Lark):
    """Return the terminals the parser would accept after ``input_str``.

    ``input_str`` is parsed with ``parser``. When parsing succeeds an empty
    list is returned. When Lark raises ``UnexpectedInput`` the terminals
    reported by the exception are stored on ``parser.last_expected`` and
    returned.

    Parser-level errors expose the candidates as ``expected``; lexer-level
    errors (``UnexpectedCharacters``) expose them as ``allowed`` instead, so
    both attributes are consulted rather than assuming ``expected`` exists.
    """
    try:
        parser.parse(input_str)  # type: ignore
    except UnexpectedInput as e:
        expected_tokens = getattr(e, "expected", None)
        if expected_tokens is None:
            # Lexer errors carry no ``expected``; fall back to ``allowed``.
            expected_tokens = getattr(e, "allowed", None) or []
        parser.last_expected = expected_tokens
        return expected_tokens
    return []

<IPython.core.display.Javascript object>

In [62]:
next_lex("{", parser)

['RBRACE', 'ESCAPED_STRING']

<IPython.core.display.Javascript object>

In [40]:
parser.terminals

[TerminalDef('ESCAPED_STRING', '".*?(?<!\\\\)(\\\\\\\\)*?"'),
 TerminalDef('WS', '(?:[ \t\x0c\r\n])+'),
 TerminalDef('TRUE', 'true'),
 TerminalDef('FALSE', 'false'),
 TerminalDef('NULL', 'null'),
 TerminalDef('COMMA', ','),
 TerminalDef('LSQB', '\\['),
 TerminalDef('RSQB', '\\]'),
 TerminalDef('LBRACE', '\\{'),
 TerminalDef('RBRACE', '\\}'),
 TerminalDef('COLON', ':')]

<IPython.core.display.Javascript object>

In [48]:
!pip install "lark==1.1.5"



<IPython.core.display.Javascript object>

In [64]:
["a"] == ["a"]

True

<IPython.core.display.Javascript object>