In [2]:
import json
import random
from transformers import (
    PreTrainedTokenizer,
    PreTrainedModel,
    AutoModelForCausalLM,
    AutoTokenizer,
)
from typing import Any, Dict



model_name = "databricks/dolly-v2-3b"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [3]:
from jsonllm.logits_processors import NumberStoppingCriteria, OutputNumbersTokens

NumberStoppingCriteria(tokenizer, 2)

<jsonllm.logits_processors.NumberStoppingCriteria at 0x1777c4d60>

In [10]:

class JSONLLM:
    value: Dict[str, Any] = {}

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        json_schema: Dict[str, Any],
        prompt: str,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.json_schema = json_schema
        self.prompt = prompt

        self.number_logit_processor = OutputNumbersTokens(self.tokenizer, self.prompt)
        self.number_stop_criteria = NumberStoppingCriteria(self.tokenizer, 3)



    def generate_number(self, suffix="") -> float:
        print("generate_number", suffix)
        prompt = self.get_prompt() + suffix

        print("\033[91m {}\033[00m".format(prompt))
        response = self.model.generate(
            self.tokenizer.encode(prompt, return_tensors="pt"),
            max_new_tokens=6,
            num_return_sequences=1,
            logits_processor=[self.number_logit_processor],
            stopping_criteria=[self.number_stop_criteria],
            temperature=1.5,
            pad_token_id=tokenizer.eos_token_id
        )

        response = self.tokenizer.decode(response[0], skip_special_tokens=True)
        print("response is")
        print("\033[94m {}\033[00m".format(response))
        response = response.strip().rstrip(".").lstrip("0")
        try:
            return float(response)
        except ValueError:
            print("ValueError")
            return 

    def generate_boolean(self, suffix="") -> bool:
        prompt = self.get_prompt()
        true_token_id = self.tokenizer.encode("true", add_special_tokens=False)[0]
        false_token_id = self.tokenizer.encode("false", add_special_tokens=False)[0]

        response = self.generate(
            prompt, forced_bos_token_id=[true_token_id, false_token_id]
        ).lower()

        if response == "true":
            return True
        else:
            return False

    def generate_array(
        self, item_schema: Dict[str, Any], obj: Dict[str, Any], suffix=""
    ) -> list:
        array_length = random.randint(0, 5)
        return [self.generate_value(item_schema, obj) for _ in range(array_length)]

    # add stopping criteria with "
    def generate_string(self) -> str:
        prompt = self.get_prompt()
        response = self.generate(prompt)
        return response

    def generate_object(
        self, properties: Dict[str, Any], obj: Dict[str, Any], suffix=""
    ) -> Dict[str, Any]:
        print("generate_object", properties)

        for key, schema in properties.items():
            value = self.generate_value(schema, obj, suffix=f'"{key}": ')

            obj[key] = value
        return obj

    def generate_value(self, schema: Dict[str, Any], obj: Dict[str, Any], suffix=""):
        schema_type = schema["type"]
        if schema_type == "number":
            return self.generate_number(suffix=suffix)
        elif schema_type == "boolean":
            return self.generate_boolean(suffix=suffix)
        elif schema_type == "array":
            return self.generate_array(schema["items"], obj, suffix=suffix)
        elif schema_type == "object":
            return self.generate_object(schema["properties"], obj, suffix=suffix)
        else:
            raise ValueError(f"Unsupported schema type: {schema_type}")

    def get_prompt(self):
        template = """{prompt}\nMake sure to output in the following format:\n{schema}\n Result: {progress}"""
        progress = json.dumps(self.value)

        progress = progress.rstrip("}").rstrip("]").rstrip(",")

        prompt = template.format(
            prompt=self.prompt,
            schema=json.dumps(self.json_schema),
            progress=progress,
        )

        return prompt

    def __call__(self) -> Dict[str, Any]:
        self.value = {}
        generated_data = self.generate_object(
            self.json_schema["properties"], self.value
        )
        return generated_data


In [11]:
weather_schema = {
    "type": "object",
    "properties": {
        "temperature": {"type": "number"},
        # "humidity": {"type": "number"},
    },
}


jsonllm = JSONLLM(
    model=model,
    tokenizer=tokenizer,
    json_schema=weather_schema,
    prompt="Generate a weather object",
)

output = jsonllm()
print(output)


generate_object {'temperature': {'type': 'number'}}
generate_number "temperature": 
[91m Generate a weather object
Make sure to output in the following format:
{"type": "object", "properties": {"temperature": {"type": "number"}}}
 Result: {"temperature": [00m
Stopping because of multiple .
response is
[94m Generate a weather object
Make sure to output in the following format:
{"type": "object", "properties": {"temperature": {"type": "number"}}}
 Result: {"temperature": 000000000.0.[00m
ValueError
{'temperature': None}
