In [3]:
import langchain
import os

OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')

In [5]:
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI

In [6]:
llm = OpenAI(openai_api_key=OPENAI_API_KEY)
chat_model = ChatOpenAI(openai_api_key=OPENAI_API_KEY)

In [7]:
text = "What would be a good company name for a company that makes colorful socks?"

print(llm.predict(text))

print(chat_model.predict(text))



Rainbow Socks Co.
Rainbow Threads


In [8]:
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.chains import LLMChain
from langchain.schema import BaseOutputParser

In [9]:
class CommaParser(BaseOutputParser):
    def parse(self, text: str):
        return text.strip().split(",")

In [10]:
template = """
You are a helpful assistant who generate comma separated lists. A user will pass in a category, and you should
generate 5 objects in that category in a comma separated lists. ONLY return a comma separated lists, and nothing more.
DO a fact check and return empty list if you are certain you do not know the answer.
"""

In [11]:
system_message_prompt = SystemMessagePromptTemplate.from_template(template)
human_template = "{country} {topic}"
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)

In [12]:
chat_prompt = ChatPromptTemplate.from_messages([
    system_message_prompt,
    human_message_prompt
])

In [13]:
chain = LLMChain(
    llm=ChatOpenAI(openai_api_key=OPENAI_API_KEY),
    prompt=chat_prompt,
    output_parser=CommaParser()
)

In [14]:
chain.run({"country": "generate", "topic": "random sentence"})

Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised ServiceUnavailableError: The server is overloaded or not ready yet..


["I'm sorry",
 ' but I cannot generate random sentences. My purpose is to assist with specific tasks or provide information. Is there anything else I can help you with?']

In [15]:
chain.run({"country": "generate", "topic": ""})

['United States', ' China', ' Japan', ' Germany', ' United Kingdom']