diff --git a/cli.py b/cli.py index 958465e..a595597 100644 --- a/cli.py +++ b/cli.py @@ -1,11 +1,16 @@ from dotenv import find_dotenv, load_dotenv +from langchain.chains import LLMChain +from langchain.chat_models import ChatOpenAI +from langchain.prompts import load_prompt from pathlib import Path from rich.console import Console +from rich.style import Style from setup import __version__ -import click, os, sys, webbrowser +import click, datetime, openai, os, random, sys, webbrowser # Create a Console object console = Console() +bot_style = Style(color="#30D5C8") # Load environment variables from .env file load_dotenv(find_dotenv()) @@ -13,6 +18,7 @@ def setup_environment(): if os.getenv("OPENAI_API_KEY"): + openai.api_key = os.getenv("OPENAI_API_KEY") return True openai_api_key_url = "https://platform.openai.com/account/api-keys" @@ -63,12 +69,21 @@ def version(): @cli.command() @click.option("-v", "--verbose", count=True) -def joke(verbose): - """Tell a [probably bad] programming joke.""" +def funfact(verbose): + """Tell me something interesting about programming or AI""" setup_environment() - api_key = os.getenv("OPENAI_API_KEY") - if verbose: - console.print(f"[bold yellow]Using API key: {api_key}[/bold yellow]") + + prompt = load_prompt(Path(__file__).parent / "prompts" / "fun_fact.yaml") + + llm = ChatOpenAI(temperature=1, max_tokens=1024) + # Set up the chain + chat_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose) + + with console.status("Thinking", spinner="point"): + # Select a random year from 1950 to current year so that we get a different answer each time + year = random.randint(1950, datetime.datetime.utcnow().year) + response = chat_chain.run(f"programming and artificial intelligence in the year {year}") + console.print(response, style=bot_style) if __name__ == "__main__": diff --git a/prompts/fun_fact.yaml b/prompts/fun_fact.yaml new file mode 100644 index 0000000..331f97e --- /dev/null +++ b/prompts/fun_fact.yaml @@ -0,0 +1,9 @@ +_type: prompt +template_format: f-string +input_variables: ["topic"] +template: | + You are history nerd who loves sharing information. + Your expertise is {topic}. + You love emojis. + + Tell me a fun fact. diff --git a/requirements/requirements.in b/requirements/requirements.in index 8607df7..a5c1e9f 100644 --- a/requirements/requirements.in +++ b/requirements/requirements.in @@ -1,3 +1,5 @@ click +langchain +openai python-dotenv rich diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 14511b6..92028f1 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -4,15 +4,100 @@ # # pip-compile # +aiohttp==3.8.4 + # via + # langchain + # openai +aiosignal==1.3.1 + # via aiohttp +async-timeout==4.0.2 + # via aiohttp +attrs==23.1.0 + # via aiohttp +certifi==2023.5.7 + # via requests +charset-normalizer==3.1.0 + # via + # aiohttp + # requests click==8.1.3 # via -r requirements.in +dataclasses-json==0.5.8 + # via langchain +frozenlist==1.3.3 + # via + # aiohttp + # aiosignal +idna==3.4 + # via + # requests + # yarl +langchain==0.0.207 + # via -r requirements.in +langchainplus-sdk==0.0.16 + # via langchain markdown-it-py==3.0.0 # via rich +marshmallow==3.19.0 + # via + # dataclasses-json + # marshmallow-enum +marshmallow-enum==1.5.1 + # via dataclasses-json mdurl==0.1.2 # via markdown-it-py +multidict==6.0.4 + # via + # aiohttp + # yarl +mypy-extensions==1.0.0 + # via typing-inspect +numexpr==2.8.4 + # via langchain +numpy==1.25.0 + # via + # langchain + # numexpr +openai==0.27.8 + # via -r requirements.in +openapi-schema-pydantic==1.2.4 + # via langchain +packaging==23.1 + # via marshmallow +pydantic==1.10.9 + # via + # langchain + # langchainplus-sdk + # openapi-schema-pydantic pygments==2.15.1 # via rich python-dotenv==1.0.0 # via -r requirements.in +pyyaml==6.0 + # via langchain +requests==2.31.0 + # via + # langchain + # langchainplus-sdk + # openai rich==13.4.2 # via -r requirements.in +sqlalchemy==2.0.16 + # via langchain +tenacity==8.2.2 + # via + # langchain + # langchainplus-sdk +tqdm==4.65.0 + # via openai +typing-extensions==4.6.3 + # via + # pydantic + # sqlalchemy + # typing-inspect +typing-inspect==0.9.0 + # via dataclasses-json +urllib3==2.0.3 + # via requests +yarl==1.9.2 + # via aiohttp diff --git a/test_cli.py b/test_cli.py index 0e7833e..4bf9012 100644 --- a/test_cli.py +++ b/test_cli.py @@ -1,4 +1,4 @@ -from cli import joke, version +from cli import funfact, version from click.testing import CliRunner from setup import __version__ import os, pytest @@ -12,7 +12,7 @@ def test_version(): @pytest.mark.skipif(os.getenv("OPENAI_API_KEY") is None, reason="Skipping live tests without an API key.") -def test_joke(): +def test_funfact(): runner = CliRunner() - result = runner.invoke(joke) + result = runner.invoke(funfact) assert result.exit_code == 0