simplifying llm with langchain for xlam inference (#7)
JimSalesforce committed Mar 13, 2024
Parents: 0ca9f06, 35eddda · Commit: 48893ca
Showing 3 changed files with 24 additions and 82 deletions.
5 changes: 3 additions & 2 deletions agentlite/llm/LLMConfig.py
@@ -12,5 +12,6 @@ def __init__(self, config_dict: dict) -> None:
         self.stop = ["\n"]
         self.max_tokens = 256
         self.end_of_prompt = ""
-        self.openai_api_key: str = os.environ.get("OPENAI_API_KEY", "")
-        self.__dict__.update(config_dict)
+        self.api_key: str = os.environ.get("OPENAI_API_KEY", "")
+        self.base_url = None
+        self.__dict__.update(config_dict)
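
Because __init__ ends with self.__dict__.update(config_dict), any key in the config dict overrides the defaults above, including the new api_key and base_url fields. A minimal usage sketch, assuming LLMConfig is importable from its file path (the local endpoint and "EMPTY" key mirror the benchmark change below, not anything required by this file):

from agentlite.llm.LLMConfig import LLMConfig

# Default behaviour: read OPENAI_API_KEY from the environment, no custom endpoint.
openai_config = LLMConfig({"llm_name": "gpt-3.5-turbo-16k-0613", "temperature": 0.0})

# Same class pointed at a local OpenAI-compatible server via the new fields.
local_config = LLMConfig(
    {
        "llm_name": "xlam_v2",
        "temperature": 0.0,
        "base_url": "http://localhost:8000/v1",  # placeholder local endpoint
        "api_key": "EMPTY",  # most local servers ignore the key value
    }
)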
85 changes: 8 additions & 77 deletions agentlite/llm/agent_llms.py
@@ -19,12 +19,6 @@
 ]
 OPENAI_LLM_MODELS = ["text-davinci-003", "text-ada-001"]
 
-FASTCHAT_MODELS = ["vicuna-7b", "zephyr-7b-beta", "lam-7b-v1", "lam-7b-v2"]
-
-XGEN_NAMES = ["Salesforce/xgen-7b-4k-base", "Salesforce/xgen-7b-8k-base"]
-INS_XGEN_NAMES = ["Salesforce/xgen-7b-8k-inst"]
-VLLM_NAMES = ["sfr"]
-
 
 class BaseLLM:
     def __init__(self, llm_config: LLMConfig) -> None:
@@ -46,7 +40,7 @@ def run(self, prompt: str):
 class OpenAIChatLLM(BaseLLM):
     def __init__(self, llm_config: LLMConfig):
         super().__init__(llm_config=llm_config)
-        self.client = OpenAI(api_key=llm_config.openai_api_key)
+        self.client = OpenAI(api_key=llm_config.api_key)
 
     def run(self, prompt: str):
         response = self.client.chat.completions.create(
@@ -66,9 +60,10 @@ def __init__(self, llm_config: LLMConfig):
         super().__init__(llm_config)
         llm = OpenAI(
             model_name=self.llm_name,
-            openai_api_key=llm_config.openai_api_key,
             temperature=self.temperature,
             max_tokens=self.max_tokens,
+            base_url=llm_config.base_url,
+            api_key=llm_config.api_key,
         )
         human_template = "{prompt}"
         prompt = PromptTemplate(template=human_template, input_variables=["prompt"])
@@ -85,9 +80,10 @@ def __init__(self, llm_config: LLMConfig):
         super().__init__(llm_config)
         llm = ChatOpenAI(
             model_name=self.llm_name,
-            openai_api_key=llm_config.openai_api_key,
             temperature=self.temperature,
             max_tokens=self.max_tokens,
+            base_url=llm_config.base_url,
+            api_key=llm_config.api_key
         )
         human_template = "{prompt}"
         prompt = PromptTemplate(template=human_template, input_variables=["prompt"])
@@ -96,77 +92,12 @@ def __init__(self, llm_config: LLMConfig):
     def run(self, prompt: str):
         return self.llm_chain.run(prompt)
 
 
-class langchain_local_llm(LangchainLLM):
-    def __init__(self, llm_config: LLMConfig):
-        super().__init__(llm_config=llm_config)
-        openai.api_key = "EMPTY"  # Not support yet
-        openai.base_url = "http://localhost:8000/v1"
-
-
-class fast_llm(BaseLLM):
-    # using fastchat llm server
-    def __init__(self, llm_config: LLMConfig):
-        super().__init__(llm_config=llm_config)
-
-    def format_prompt(self, prompt: str, end_of_prompt: str) -> str:
-        return prompt.strip() + " " + end_of_prompt
-
-    def run(self, prompt: str) -> str:
-        openai.api_key = "EMPTY"  # Not support yet
-        openai.base_url = "http://localhost:8000/v1/"
-        prompt = self.format_prompt(prompt, self.end_of_prompt)
-        completion = openai.completions.create(
-            model=self.llm_name,
-            temperature=self.temperature,
-            stop=self.stop,
-            prompt=prompt,
-            max_tokens=self.max_tokens,
-        )
-        output = completion.choices[0].text
-        return output
-
-
-class vllm_api_llm(BaseLLM):
-    def __init__(self, llm_config: LLMConfig):
-        super().__init__(llm_config=llm_config)
-        self.api_url = ""
-        self.n = 1
-        self.stream = False
-        self.trial = 0
-
-    def run(self, prompt: str):
-        # completion = openai.Completion.create(model=self.llm_name, temperature=temperature, stop=stop,
-        #                                       prompt=prompt, max_tokens=128)
-        done = False
-        while not done:
-            try:
-                response = post_http_request(
-                    prompt=prompt,
-                    api_url=self.api_url,
-                    n=self.n,
-                    use_beam_search=False,
-                    temperature=self.temperature,
-                    max_tokens=self.max_tokens,
-                )
-                output = get_response(response)
-                done = True
-                return output[0]
-            except BaseException as ex:
-                print(str(ex))
-                self.trial += 1
-                if self.trial > 5:
-                    done = True
-                    return "No response"
-
-
 def get_llm_backend(llm_config: LLMConfig):
     llm_name = llm_config.llm_name
 
     if llm_name in OPENAI_CHAT_MODELS:
         return LangchainChatModel(llm_config)
     elif llm_name in OPENAI_LLM_MODELS:
         return LangchainLLM(llm_config)
-    elif llm_name in VLLM_NAMES:
-        return vllm_api_llm(llm_config)
-    else:
-        return fast_llm(llm_config)
+    else:
+        return LangchainLLM(llm_config)
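
With the FastChat, xGen, and vLLM-specific classes removed, backend selection reduces to the two LangChain wrappers: names in OPENAI_CHAT_MODELS go to LangchainChatModel, and everything else falls back to LangchainLLM, which reaches a self-hosted model through the base_url and api_key now carried by LLMConfig. A rough end-to-end sketch (import paths assumed from the file layout; the model name and endpoint are placeholders):

from agentlite.llm.LLMConfig import LLMConfig
from agentlite.llm.agent_llms import get_llm_backend

# "xlam_v2" is not in OPENAI_CHAT_MODELS, so this resolves to LangchainLLM
# and talks to whatever OpenAI-compatible server is listening on base_url.
llm_config = LLMConfig(
    {
        "llm_name": "xlam_v2",
        "temperature": 0.0,
        "base_url": "http://localhost:8000/v1",
        "api_key": "EMPTY",
    }
)
llm = get_llm_backend(llm_config)
print(llm.run("Say hello."))  # plain string in, completion text out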
16 changes: 13 additions & 3 deletions benchmark/hotpotqa/evaluate_hotpot_qa.py
@@ -97,6 +97,16 @@ def run_hotpot_qa_agent(level="easy", llm_name="gpt-3.5-turbo-16k-0613", agent_a
 
     # build the search agent
     llm_config = LLMConfig({"llm_name": llm_name, "temperature": 0.0})
+
+    if llm_name == "xlam_v2":
+        llm_config = LLMConfig(
+            {
+                "llm_name": llm_name,
+                "temperature": 0.0,
+                "base_url": "http://localhost:8000/v1",
+                "api_key": "EMPTY",
+            }
+        )
     llm = get_llm_backend(llm_config)
     agent = WikiSearchAgent(llm=llm, agent_arch=agent_arch, PROMPT_DEBUG_FLAG=PROMPT_DEBUG_FLAG)
     # add several demo trajectories to the search agent for the HotPotQA benchmark
@@ -123,10 +133,10 @@ def run_hotpot_qa_agent(level="easy", llm_name="gpt-3.5-turbo-16k-0613", agent_a
     avg_f1 = np.mean(f1_list)
     acc = correct / len(task_instructions)
 
-    dump_str = f"{test_task}\t{answer}\t{response}\t{f1:.4f}\t{acc:.4f}"
+    dump_str = f"{test_task}\t{answer}\t{response}\t{f1:.4f}\t{acc:.4f}\n"
     with open(f"data/{agent_arch}_{llm_name}_results_{level}.csv", "a") as f:
-        f.write(dump_str, f, indent=4)
+        f.write(dump_str)
 
     return avg_f1, acc
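
The xlam_v2 branch above assumes an OpenAI-compatible server is already serving the model at http://localhost:8000/v1. A quick way to sanity-check such an endpoint before running the benchmark (the served model name is whatever the server registers; "xlam_v2" here is illustrative, not prescribed by this commit):

from openai import OpenAI

# Point the stock OpenAI client at the local endpoint; local servers
# typically accept any placeholder API key such as "EMPTY".
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

print([m.id for m in client.models.list().data])  # confirm the served model name

completion = client.completions.create(
    model="xlam_v2",  # must match the name the server exposes
    prompt="Question: Who wrote Hamlet?\nAnswer:",
    max_tokens=64,
    temperature=0.0,
)
print(completion.choices[0].text)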

