In [1]:
from typing import Dict, List, Type, Union
from typing import List
import json
import os
from openai import OpenAI
from dotenv import load_dotenv

import sglang as sgl
from pydantic import BaseModel, conlist
from typing import List
from sglang.srt.constrained import build_regex_from_object

# Load variables from env_variable.env into the process environment (presumably
# OPENAI_API_KEY — confirm) so the OpenAI() client constructed below can pick them up.
load_dotenv(dotenv_path="env_variable.env")
client = OpenAI()

In [2]:
class ConceptsList(BaseModel):
    # Output schema for the local model: build_regex_from_object(ConceptsList) turns it
    # into the regex that constrains sgl.gen, and local_llm_call reads the
    # "Concepts_List" key from the generated JSON (at most 10 strings).
    # The field name appears verbatim in the generated JSON and matches the
    # "Concepts_List:" cue used in the prompts, so renaming it changes the response —
    # choose it wisely. (Comments rather than a docstring: a docstring would become
    # the schema "description".)
    Concepts_List: conlist(str, max_length=10)


# The constrained-decoding regex depends only on the ConceptsList schema, so build it
# once here instead of on every generation call.
CONCEPTS_LIST_REGEX = build_regex_from_object(ConceptsList)  # Requires pydantic >= 2.0


@sgl.function
def pydantic_gen_ex(s, list_element):
    """SGLang program: append the prompt, then generate a schema-constrained answer.

    Args:
        s: SGLang program state, supplied by ``@sgl.function``.
        list_element: The complete prompt text. Despite the name, callers
            (local_llm_call) pass the whole prompt built by create_mistral_total_prompt.

    The generation is constrained by ``CONCEPTS_LIST_REGEX`` to match the
    ``ConceptsList`` JSON schema, is capped at 1024 tokens, and uses temperature 0.
    """
    s += list_element
    s += sgl.gen(
        "",
        max_tokens=1024,
        temperature=0,
        regex=CONCEPTS_LIST_REGEX,
    )


# Local SGLang server; set SGLANG_ENDPOINT to point at a different host/port.
SGLANG_ENDPOINT = os.environ.get("SGLANG_ENDPOINT", "http://localhost:30000")
sgl.set_default_backend(sgl.RuntimeEndpoint(SGLANG_ENDPOINT))

In [3]:
def create_mistral_total_prompt(system_prompt, input):
    """Assemble the full prompt sent to the local Mistral model.

    Layout: ``system_prompt``, newline, ``text:<input>``, newline, and a trailing
    ``Concepts_List:`` cue after which the constrained JSON is generated.

    NOTE(review): the few-shot examples in the system prompts use ``Text: "..."``
    (capitalised, quoted) while this query line uses ``text:...``; aligning the two
    formats may help the model, but it is left unchanged to keep results comparable.

    Args:
        system_prompt: Instruction text plus few-shot examples.
        input: The text to extract concepts from.

    Returns:
        The concatenated prompt string.
    """
    query_line = f"text:{input}"
    return "\n".join((system_prompt, query_line, "Concepts_List:"))

In [4]:
def local_llm_call(input):
    """Run the constrained SGLang program on ``input`` and return the concept list as text.

    The code assumes ``state.text()`` starts with the prompt verbatim, so the first
    ``len(input)`` characters are dropped before parsing the generated JSON object.

    Returns:
        ``str()`` of the parsed ``Concepts_List`` value, i.e. a Python list repr such
        as ``"['Gun violence', 'Kansas']"``.

    Raises:
        json.JSONDecodeError: if the generated text is not valid JSON (e.g. cut off).
        KeyError: if the generated object has no ``Concepts_List`` key.
    """
    full_text = pydantic_gen_ex.run(input).text()
    generated_json = full_text[len(input):]
    concepts = json.loads(generated_json)["Concepts_List"]
    return str(concepts)

In [5]:
from dataclasses import dataclass

@dataclass
class PromptsClass:
    """Experiment log for iteratively refining the Mistral system prompt.

    Entry ``i`` of ``mistral_system_prompts`` is the system prompt that produced
    entry ``i`` of ``mistral_responses``. ``target_input`` is the evaluation text
    sent to Mistral, and ``target_list`` holds the concepts expected back.
    """

    # System prompts tried so far, in order.
    mistral_system_prompts: List[str]
    # Mistral's answer for each prompt above (same index).
    mistral_responses: List[str]
    # Text that every experiment sends to Mistral.
    target_input: str
    # Expected concept list for target_input.
    target_list: List[str]

    def clear_state(self):
        """Start a fresh experiment log, leaving the evaluation target untouched.

        New list objects are bound rather than clearing the old ones in place, so any
        outside references to the previous lists keep their contents.
        """
        self.mistral_system_prompts, self.mistral_responses = [], []

In [6]:
# Evaluation example: the text sent to Mistral (quoted speech kept verbatim) and the
# concepts we expect back.
input_text_for_evaluation = (
    "She said: 'today was supposed to be a day of celebration and joy in Kansas, "
    "instead it is another day where America has experience senselense gun violence' "
    "in response to what happened in Kansas near coca-cola branch"
)

output_list_for_evaluation = [
    "Gun violence",
    "Coca-cola",
    "Kansas city",
]

In [7]:
# Fresh experiment log, targeting the evaluation example defined above.
promptTracker = PromptsClass(
    mistral_system_prompts=[],
    mistral_responses=[],
    target_input=input_text_for_evaluation,
    target_list=output_list_for_evaluation,
)

### Creating the initial Mistral input and output

In [8]:
# Seed system prompt for experiment 0. Its single few-shot example uses the
# `Text: "..."` / `Concepts_List: [...]` layout enclosed by ### markers.
# The literal is sent to the model verbatim, so any whitespace change alters the prompt.
init_sys_prompt = """You are an AI designed to find a LIMITED list of GENERAL concepts associated with a given piece of text. The list size should NOT exceed 10. You Must use standardized words.

###
Here are some examples:


Text: "israel supporters attacks female palestine activist"
Concepts_List: ["Hate speech", "Palestine"]
###
"""

In [9]:
# Reset the log first, so re-running this cell cannot leave a duplicate seed prompt.
# A duplicate would shift prompts out of step with responses. On a fresh run the
# lists are already empty, so the reset changes nothing.
promptTracker.clear_state()
promptTracker.mistral_system_prompts.append(init_sys_prompt)

In [10]:
# Experiment 0 always uses the seed prompt. Index 0 (rather than -1) keeps this cell
# correct even if it is re-run after the refinement loop has appended more prompts.
initial_mistral_inputs = create_mistral_total_prompt(
    promptTracker.mistral_system_prompts[0], promptTracker.target_input
)

In [11]:
# Experiment 0: run the seed prompt through the local model and log its answer
# at the same index as the seed prompt.
output_from_mistral = local_llm_call(initial_mistral_inputs)
promptTracker.mistral_responses.extend([output_from_mistral])

### Refining the system prompt with OpenAI GPT-4

In [12]:
# System prompt for the GPT-4 "prompt engineer". It explains the Mistral task and how
# the experiment log is laid out. "INPUT TO MISTRAL" and "### ###" must match the
# markers used in create_openai_user_prompt and in the Mistral system prompts.
openai_sys_prompt = """You are an AI assistant who is an expert in creating prompts for LLMs. Your job is to modify and enhance a prompt for a 7B Mistral Instruct model. The Mistral model is supposed to receive an input text and return a list of strings (entities, brand names, etc.) found in that input text. The prompt to the Mistral model can include some examples that guide the model's behavior. The Mistral model performs constrained decoding, meaning that it only generates a list of strings.

A number of experiments have been done on different system prompts for Mistral and their outputs. Those experiments, which include the tested system prompt, the tested INPUT TO MISTRAL, and the resulting output from Mistral, are provided to you. Your job is to observe the experiments and come up with a better system prompt for Mistral that achieves the expected output. You can add some examples, or remove some examples, in your suggested system prompt. Remember that the total number of examples should be limited, because each example adds extra computation and we can't afford it. Note that the examples given in the system prompt of Mistral should be enclosed by ### ###. Pay attention to the fact that you are not allowed to use the INPUT TO MISTRAL text in your examples for your suggested Mistral system prompt.
"""

In [13]:
def create_openai_user_prompt(prompttracker):
    """Render every logged experiment as one user message for the GPT-4 prompt refiner.

    Experiment ``i`` pairs ``mistral_system_prompts[i]`` with ``mistral_responses[i]``.
    The target input text and the expected list are repeated in every section, so
    GPT-4 can compare each output with the goal.

    Args:
        prompttracker: Object exposing ``mistral_system_prompts``,
            ``mistral_responses``, ``target_input`` and ``target_list``
            (e.g. a PromptsClass).

    Returns:
        The concatenated experiment sections; "" when no responses are logged.

    Raises:
        IndexError: if there are more responses than system prompts.
    """
    prompts = prompttracker.mistral_system_prompts
    sections = []
    for idx, response in enumerate(prompttracker.mistral_responses):
        sections.append(f"""\n\n
Experiment {idx}
Mistral System Prompt:
{prompts[idx]}


INPUT TO MISTRAL:
{prompttracker.target_input}


output from Mistral:
{response}


what was expected to be output from Mistral:
{str(prompttracker.target_list)} \n\n
""")
    return "".join(sections)

In [22]:
class EnhancedSystemPrompt(BaseModel):
    # Argument schema for the forced "Enhanced_System_Prompt" call in request_to_openai.
    # model_json_schema() of this class is sent to OpenAI as the function parameters,
    # and the reply is read back by this field's name, so renaming it changes both the
    # request and the parsing. (Comments rather than a docstring: a docstring would be
    # added to the schema sent to OpenAI as its "description".)
    Enhanced_System_Prompt: str


def request_to_openai(prompttracker):
    """Ask GPT-4 for an improved Mistral system prompt based on all logged experiments.

    The user message is built with create_openai_user_prompt and sent with
    ``openai_sys_prompt`` as the system message. GPT-4 is forced to reply through the
    "Enhanced_System_Prompt" tool, whose arguments follow the EnhancedSystemPrompt schema.

    Args:
        prompttracker: Experiment log (e.g. a PromptsClass) to summarise for GPT-4.

    Returns:
        The suggested system prompt string.

    Raises:
        RuntimeError: if the response contains no tool call.
        json.JSONDecodeError / KeyError: if the tool arguments are malformed.
    """
    openai_user_prompt = create_openai_user_prompt(prompttracker=prompttracker)
    response = client.chat.completions.create(
        temperature=0.1,
        model="gpt-4-0125-preview",
        messages=[
            {"role": "system", "content": openai_sys_prompt},
            {"role": "user", "content": openai_user_prompt},
        ],
        # `tools` / `tool_choice` replace the deprecated `functions` / `function_call`.
        tools=[
            {
                "type": "function",
                "function": {
                    "name": "Enhanced_System_Prompt",
                    "description": "Enhanced System Prompt for Mistral LLM",
                    "parameters": EnhancedSystemPrompt.model_json_schema(),
                },
            }
        ],
        # Force the answer through the tool so the reply is parseable JSON.
        tool_choice={"type": "function", "function": {"name": "Enhanced_System_Prompt"}},
    )
    tool_calls = response.choices[0].message.tool_calls
    if not tool_calls:
        raise RuntimeError(
            "OpenAI response did not contain the Enhanced_System_Prompt tool call"
        )
    return json.loads(tool_calls[0].function.arguments)["Enhanced_System_Prompt"]

In [23]:
def refine_system_prompt_with_gpt4(number_of_iterations, prompttracker=None):
    """Run ``number_of_iterations`` rounds of GPT-4 prompt refinement.

    Each round:
      1. asks GPT-4 (request_to_openai) for a new Mistral system prompt based on all
         experiments so far;
      2. runs that prompt through the local model;
      3. logs the prompt and Mistral's response at the same index.

    Args:
        number_of_iterations: How many refinement rounds to run.
        prompttracker: Experiment log to update. Defaults to the notebook-global
            ``promptTracker`` (the previous, hard-wired behavior).
    """
    tracker = promptTracker if prompttracker is None else prompttracker
    for _ in range(number_of_iterations):
        suggested_prompt = request_to_openai(prompttracker=tracker)
        tracker.mistral_system_prompts.append(suggested_prompt)
        mistral_inputs = create_mistral_total_prompt(suggested_prompt, tracker.target_input)
        tracker.mistral_responses.append(local_llm_call(mistral_inputs))

In [24]:
# Two refinement rounds: adds experiments 1 and 2 to the log.
refine_system_prompt_with_gpt4(number_of_iterations=2)

In [25]:
# Experiment 0: the hand-written seed prompt.
print(promptTracker.mistral_system_prompts[0])

You are an AI designed to find a LIMITED list of GENERAL concepts associated with a given piece of text. The list size should NOT exceed 10. You Must use standardized words.

###
Here are some examples:


Text: "israel supporters attacks female palestine activist"
Concepts_List: ["Hate speech", "Palestine"]
###



In [26]:
# Experiment 1: first prompt suggested by the GPT-4 refinement loop.
print(promptTracker.mistral_system_prompts[1])

You are an AI designed to extract a concise list of specific entities, brand names, and key concepts from a given piece of text. Your output should be a list of no more than 10 items, focusing on the most relevant and specific details mentioned in the text. Use standardized, recognizable terms for entities and concepts.

###
Here are some examples:

Text: "The new iPhone 12 was released yesterday, sparking excitement among Apple enthusiasts."
Concepts_List: ["iPhone 12", "Apple"]

Text: "A devastating earthquake hit Tokyo last night, causing widespread damage."
Concepts_List: ["Earthquake", "Tokyo"]

Text: "israel supporters attacks female palestine activist"
Concepts_List: ["Hate speech", "Palestine"]
###


In [27]:
# Experiment 2: second prompt suggested by the GPT-4 refinement loop.
print(promptTracker.mistral_system_prompts[2])

You are an AI designed to extract a precise list of specific entities, brand names, and key concepts from a given piece of text. Your output should consist of a concise list of no more than 10 items, prioritizing the most relevant and specific details mentioned in the text. Use standardized, recognizable terms for entities and concepts. Ensure that brand names are clearly identified and included when mentioned.

###
Here are some examples:

Text: "The new iPhone 12 was released yesterday, sparking excitement among Apple enthusiasts."
Concepts_List: ["iPhone 12", "Apple"]

Text: "A devastating earthquake hit Tokyo last night, causing widespread damage."
Concepts_List: ["Earthquake", "Tokyo"]

Text: "Protesters in Paris demand action on climate change."
Concepts_List: ["Climate change", "Paris"]
###


In [29]:
# Expected concepts for the evaluation text, to compare with the responses below.
promptTracker.target_list

['Gun violence', 'Coca-cola', 'Kansas city']

In [30]:
# One line per experiment, in the same order as mistral_system_prompts.
for response in promptTracker.mistral_responses:
    print(response)

['Gun violence', 'Celebration', 'Joy', 'America', 'Kansas']
['Kansas', 'America', 'Gun violence']
['Kansas', 'America', 'gun violence', 'Coca-Cola']
