In [None]:
!pip install git+https://github.com/tdoehmen/jsonformer.git

In [None]:
import argparse
import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          GenerationConfig, set_seed)
from jsonformer import Jsonformer

# Checkpoint used for both the tokenizer and the model.
model_id = "tdoehmen/schemapile-fk-starcoder"

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Use the "<|end|>" turn terminator as the end-of-sequence token.
tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids("<|end|>")
# Pad with the EOS token. (This line used to assign eos_token_id to itself,
# which did nothing.)
tokenizer.pad_token_id = tokenizer.eos_token_id

# Sampling settings. eos_token_id previously referenced an undefined
# `dialogue_template` and raised NameError on a fresh kernel; the tokenizer's
# EOS id set above is the same "<|end|>" token.
generation_config = GenerationConfig(
    temperature=0.01,
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.2,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    min_new_tokens=32,
    max_new_tokens=512,
)

# Load the weights in 8-bit and let the library place them on the available
# device(s).
model = AutoModelForCausalLM.from_pretrained(
    model_id, load_in_8bit=True, device_map="auto", torch_dtype=torch.float16
)

In [15]:
import os
import json

# patch jsonformer
def patch_jsonformer(jsonformer):
    """Override ``get_prompt`` and ``generate_array`` on one Jsonformer-like object.

    The replacements are closures stored as instance attributes: they take no
    ``self`` and read everything they need from ``jsonformer`` at call time.

    Args:
        jsonformer: Object exposing ``prompt``, ``value``, ``generation_marker``,
            ``max_array_length``, ``tokenizer``, ``model`` and ``generate_value``.

    Returns:
        The same ``jsonformer`` object, modified in place.
    """

    def get_prompt():
        """Return the prompt followed by the JSON serialized so far.

        The serialized value is truncated right before the quoted generation
        marker, so the text ends where the next value has to be produced.
        Unlike a template that embeds the schema, only prompt and progress
        are emitted.

        Raises:
            ValueError: If the marker does not occur in the serialized value.
        """
        serialized = json.dumps(jsonformer.value)
        cut_at = serialized.find(f'"{jsonformer.generation_marker}"')
        if cut_at == -1:
            raise ValueError("Failed to find generation marker")
        return f"{jsonformer.prompt}{serialized[:cut_at]}"

    def another_item_expected(ranked_token_ids):
        """Tell whether a '{' token outranks every ']' token.

        Tokens are decoded in the given order; the first one containing '{'
        answers True, the first one containing ']' answers False. If neither
        shows up the answer is False.
        """
        for token_id in ranked_token_ids:
            piece = jsonformer.tokenizer.decode(token_id)
            if '{' in piece:
                return True
            if ']' in piece:
                return False
        return False

    def generate_array(item_schema, obj) -> list:
        """Fill ``obj`` with generated elements and return it.

        Every pass generates one element unconditionally, so a non-zero
        ``max_array_length`` yields at least one element. After each element
        the next-token logits decide whether to go on: generation continues
        only while a token containing '{' is ranked ahead of any token
        containing ']' among the 30 highest-scoring candidates.
        """
        for _ in range(jsonformer.max_array_length):
            # The generated value overwrites the last slot of the list
            # (presumably a marker placed there by generate_value -- confirm
            # against the installed jsonformer).
            obj[-1] = jsonformer.generate_value(item_schema, obj)

            # Temporarily append the marker so get_prompt() has a cut point,
            # then remove it again.
            obj.append(jsonformer.generation_marker)
            prompt_so_far = jsonformer.get_prompt()
            obj.pop()

            token_ids = jsonformer.tokenizer.encode(prompt_so_far, return_tensors="pt")
            model_output = jsonformer.model.forward(token_ids.to(jsonformer.model.device))
            next_logits = model_output.logits[0, -1]

            # Take the 30 best candidates and order them by descending score.
            candidates = next_logits.topk(30).indices
            ranked = candidates[next_logits[candidates].argsort(descending=True)]

            if not another_item_expected(ranked):
                break

        return obj

    jsonformer.get_prompt = get_prompt
    jsonformer.generate_array = generate_array
    return jsonformer

def _parse_table_signature(line):
    """Split one ``table(col_a, col_b, ...)`` line into ``(table, [columns])``."""
    open_idx = line.find("(")
    close_idx = line.find(")", open_idx + 1) if open_idx != -1 else -1
    if open_idx == -1 or close_idx == -1:
        raise ValueError(f"Expected 'table(col1, col2, ...)', got: {line!r}")
    table = line[:open_idx].strip()
    # Split on bare commas and strip, so "a,b" and "a, b" are treated alike.
    columns = [name.strip() for name in line[open_idx + 1:close_idx].split(",")]
    return table, [name for name in columns if name]


def get_enums_from_prompt_text(prompt_text):
    """Extract table and column names from the two table lines of a prompt.

    The prompt is expected to carry the source table on its second line and
    the referenced table on its third line, each written as
    ``table(col1, col2, ...)``.

    Args:
        prompt_text: Full prompt text, lines separated by ``"\\n"``.

    Returns:
        Tuple ``(src_table, src_cols, trg_table, trg_cols)`` where the table
        names are strings and the column collections are lists of strings.

    Raises:
        ValueError: If the prompt has fewer than three lines or a table line
            lacks a parenthesised column list.
    """
    lines = prompt_text.split("\n")
    if len(lines) < 3:
        raise ValueError(
            "Prompt must contain the two table definitions on its second and third lines"
        )
    src_table, src_cols = _parse_table_signature(lines[1])
    trg_table, trg_cols = _parse_table_signature(lines[2])
    return src_table, src_cols, trg_table, trg_cols

def generate_response_as_json_starcoder(prompt):
    """Generate a foreign-key description for the two tables named in ``prompt``.

    The table and column names parsed from the prompt become the only allowed
    values of the four output fields (this relies on the ``"enum"`` schema type
    being understood by the installed jsonformer -- verify against that
    package). The prompt is wrapped in the system/user/assistant chat markers
    before generation.

    Args:
        prompt: Prompt text accepted by ``get_enums_from_prompt_text``.

    Returns:
        Whatever the patched ``Jsonformer`` instance returns when called.
    """
    src_table, src_cols, trg_table, trg_cols = get_enums_from_prompt_text(prompt)

    # Insertion order is kept, so the fields appear in the schema in this order.
    allowed_values = {
        "table": [src_table],
        "column": src_cols,
        "referencedTable": [trg_table],
        "referencedColumn": trg_cols,
    }
    json_schema_fk = {
        "type": "object",
        "properties": {
            field: {"type": "enum", "values": choices}
            for field, choices in allowed_values.items()
        },
        "required": list(allowed_values),
    }

    formatted_prompt = f"<|system|>\n<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>"

    # patch_jsonformer returns the instance it was given, so it can be called directly.
    generator = patch_jsonformer(
        Jsonformer(model, tokenizer, json_schema_fk, formatted_prompt, debug=False)
    )
    return generator()

In [14]:
# Example prompt: an instruction line, the two table definitions, and the
# requested output format. Adjacent literals are concatenated into one string.
text = (
    "You are given the following SQL database tables: \n"
    "staff(staff_id, staff_address_id, nickname, first_name, middle_name, "
    "last_name, date_of_birth, date_joined_staff, date_left_staff)\n"
    "addresses(address_id, line_1_number_building, city, zip_postcode, "
    "state_province_county, country)\n"
    "Output a json string with the following schema "
    "{table, column, referencedTable, referencedColumn} "
    "that contains the foreign key relationship between the two tables."
)
print(text)

You are given the following SQL database tables: 
staff(staff_id, staff_address_id, nickname, first_name, middle_name, last_name, date_of_birth, date_joined_staff, date_left_staff)
addresses(address_id, line_1_number_building, city, zip_postcode, state_province_county, country)
Output a json string with the following schema {table, column, referencedTable, referencedColumn} that contains the foreign key relationship between the two tables.


In [16]:
# Run the generation on the example prompt; as the cell's last expression,
# the returned value is displayed as the cell output.
generate_response_as_json_starcoder(text)

{'table': 'staff',
 'column': 'staff_address_id',
 'referencedTable': 'addresses',
 'referencedColumn': 'address_id'}