Refactor AI chat model for live streaming and precise responses
This commit introduces several changes to the AI chat model in the `aicodebot/cli.py` file. The main changes include:

- The introduction of two new temperature settings: `PRECISE_TEMPERATURE` and `CREATIVE_TEMPERATURE`. These settings allow for more control over the randomness of the AI's responses.
- The use of the `rich.live.Live` class for live streaming of the AI's responses. This is implemented in the `alignment`, `debug`, `fun_fact`, `review`, and `sidekick` commands.
- The creation of a `RichLiveCallbackHandler` class that updates the live stream with each new token generated by the AI.
- Adjustments to the `max_tokens` parameter in the `fun_fact` and `sidekick` commands.

These changes aim to improve the user experience by providing real-time feedback from the AI and allowing for more precise responses; a minimal sketch of the streaming pattern is shown below. 🤖💬
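The pattern boils down to a `BaseCallbackHandler` subclass that appends each new token to a buffer and re-renders a `rich.live.Live` region as Markdown. The sketch below is a self-contained illustration, assuming a LangChain release compatible with the imports in this diff; the prompt text, topic, and constant values are stand-ins rather than the CLI's actual prompts, and `OPENAI_API_KEY` must be set in the environment.

```python
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from rich.live import Live
from rich.markdown import Markdown

# Stand-in constants mirroring the new settings in this commit.
PRECISE_TEMPERATURE = 0
DEFAULT_MAX_TOKENS = 512


class RichLiveCallbackHandler(BaseCallbackHandler):
    """Collect streamed tokens and re-render the Live region as Markdown."""

    def __init__(self, live):
        self.live = live
        self.buffer = []

    def on_llm_new_token(self, token, **kwargs):
        self.buffer.append(token)
        self.live.update(Markdown("".join(self.buffer)))


if __name__ == "__main__":
    # Hypothetical prompt; the real CLI loads its prompts with load_prompt().
    prompt = PromptTemplate.from_template("Explain {topic} in one short paragraph.")

    with Live(Markdown(""), auto_refresh=True) as live:
        llm = ChatOpenAI(
            temperature=PRECISE_TEMPERATURE,
            max_tokens=DEFAULT_MAX_TOKENS,
            streaming=True,  # emit tokens as they are generated
            callbacks=[RichLiveCallbackHandler(live)],
        )
        chain = LLMChain(llm=llm, prompt=prompt)
        chain.run({"topic": "token streaming"})
```

In the diff that follows, the same handler receives the `Live` handle opened by each command's `with Live(...)` block, and `streaming=True` switches the model from returning a single final string to emitting tokens incrementally.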
TechNickAI committed Jul 3, 2023
1 parent 172bd5c commit 0ddc6d2
Showing 1 changed file with 69 additions and 31 deletions.
100 changes: 69 additions & 31 deletions aicodebot/cli.py
@@ -2,29 +2,32 @@
from aicodebot.agents import get_agent
from aicodebot.helpers import exec_and_get_output, get_token_length, git_diff_context
from dotenv import load_dotenv
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts import load_prompt
from openai.api_resources import engine
from pathlib import Path
from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown
from rich.style import Style
import click, datetime, openai, os, random, subprocess, sys, tempfile, webbrowser

# ----------------------------- Default settings ----------------------------- #

DEFAULT_MAX_TOKENS = 512
DEFAULT_TEMPERATURE = 0.1
PRECISE_TEMPERATURE = 0
CREATIVE_TEMPERATURE = 0.7
DEFAULT_SPINNER = "point"

# ----------------------- Setup for rich console output ---------------------- #

console = Console()
bot_style = Style(color="#30D5C8")
error_style = Style(color="#FF0000")
warning_style = Style(color="#FFA500")


# -------------------------- Top level command group ------------------------- #


@@ -55,14 +58,21 @@ def alignment(verbose):

# Set up the language model
model = get_llm_model(get_token_length(prompt.template))
llm = ChatOpenAI(model=model, temperature=DEFAULT_TEMPERATURE, max_tokens=DEFAULT_MAX_TOKENS, verbose=verbose)

# Set up the chain
chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
with Live(Markdown(""), auto_refresh=True) as live:
llm = ChatOpenAI(
model=model,
temperature=CREATIVE_TEMPERATURE,
max_tokens=DEFAULT_MAX_TOKENS,
verbose=verbose,
streaming=True,
callbacks=[RichLiveCallbackHandler(live)],
)

with console.status("Generating an inspirational message", spinner=DEFAULT_SPINNER):
response = chain.run({})
console.print(Markdown(response), style=bot_style)
# Set up the chain
chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)

chain.run({})


@cli.command()
@@ -109,13 +119,13 @@ def commit(verbose, response_token_size, yes, skip_pre_commit):
console.print(f"Diff context token size: {request_token_size}, using model: {model}")

# Set up the language model
llm = ChatOpenAI(model=model, temperature=DEFAULT_TEMPERATURE, max_tokens=DEFAULT_MAX_TOKENS, verbose=verbose)
llm = ChatOpenAI(model=model, temperature=PRECISE_TEMPERATURE, max_tokens=DEFAULT_MAX_TOKENS, verbose=verbose)

# Set up the chain
chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)

console.print("The following files will be committed:\n" + files)
with console.status("Generating the commit message", spinner=DEFAULT_SPINNER):
with console.status("Examining the diff and generating the commit message", spinner=DEFAULT_SPINNER):
response = chain.run(diff_context)

# Write the commit message to a temporary file
@@ -168,14 +178,20 @@ def debug(command, verbose):

# Set up the language model
model = get_llm_model(get_token_length(error_output) + get_token_length(prompt.template))
llm = ChatOpenAI(model=model, temperature=DEFAULT_TEMPERATURE, max_tokens=DEFAULT_MAX_TOKENS, verbose=verbose)

# Set up the chain
chat_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
with Live(Markdown(""), auto_refresh=True) as live:
llm = ChatOpenAI(
model=model,
temperature=PRECISE_TEMPERATURE,
max_tokens=DEFAULT_MAX_TOKENS,
verbose=verbose,
streaming=True,
callbacks=[RichLiveCallbackHandler(live)],
)

with console.status("Debugging", spinner=DEFAULT_SPINNER):
response = chat_chain.run(error_output)
console.print(Markdown(response), style=bot_style)
# Set up the chain
chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
chain.run(error_output)

sys.exit(process.returncode)

@@ -191,16 +207,22 @@ def fun_fact(verbose):

# Set up the language model
model = get_llm_model(get_token_length(prompt.template))
llm = ChatOpenAI(model=model, temperature=0.9, max_tokens=250, verbose=verbose)

# Set up the chain
chat_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
with Live(Markdown(""), auto_refresh=True) as live:
llm = ChatOpenAI(
model=model,
temperature=PRECISE_TEMPERATURE,
max_tokens=DEFAULT_MAX_TOKENS / 2,
verbose=verbose,
streaming=True,
callbacks=[RichLiveCallbackHandler(live)],
)

# Set up the chain
chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)

with console.status("Fetching a fun fact", spinner=DEFAULT_SPINNER):
# Select a random year so that we get a different answer each time
year = random.randint(1942, datetime.datetime.utcnow().year)
response = chat_chain.run(f"programming and artificial intelligence in the year {year}")
console.print(Markdown(response), style=bot_style)
chain.run(f"programming and artificial intelligence in the year {year}")


@cli.command
@@ -225,15 +247,20 @@ def review(commit, verbose):
if verbose:
console.print(f"Diff context token size: {request_token_size}, using model: {model}")

# Set up the language model
llm = ChatOpenAI(model=model, temperature=DEFAULT_TEMPERATURE, max_tokens=response_token_size, verbose=verbose)
with Live(Markdown(""), auto_refresh=True) as live:
llm = ChatOpenAI(
model=model,
temperature=PRECISE_TEMPERATURE,
max_tokens=response_token_size,
verbose=verbose,
streaming=True,
callbacks=[RichLiveCallbackHandler(live)],
)

# Set up the chain
chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
# Set up the chain
chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)

with console.status("Reviewing code", spinner=DEFAULT_SPINNER):
response = chain.run(diff_context)
console.print(Markdown(response), style=bot_style)
chain.run(diff_context)


@cli.command
@@ -254,7 +281,7 @@ def sidekick(task, verbose):
setup_environment()

model = get_llm_model()
llm = ChatOpenAI(model=model, temperature=DEFAULT_TEMPERATURE, max_tokens=3500, verbose=verbose)
llm = ChatOpenAI(model=model, temperature=PRECISE_TEMPERATURE, max_tokens=2000, verbose=verbose)

agent = get_agent("sidekick", llm, verbose)

@@ -358,5 +385,16 @@ def get_llm_model(token_size=0):
raise click.ClickException("🛑 The context is too large for the model. 😞")


class RichLiveCallbackHandler(BaseCallbackHandler):
buffer = []

def __init__(self, live):
self.live = live

def on_llm_new_token(self, token, **kwargs):
self.buffer.append(token)
self.live.update(Markdown("".join(self.buffer)))


if __name__ == "__main__":
cli()
