-
Notifications
You must be signed in to change notification settings - Fork 128
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #44 from SamPink/llama
Added support for anthropic and others
- Loading branch information
Showing
6 changed files
with
86 additions
and
47 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,3 +5,4 @@ | |
|
||
from .robot import Robot | ||
from .observer import KEN_GREEN, KEN_RED | ||
from .llm import get_client |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
def get_client(model_str):
    """Return a llama-index LLM client for a ``provider:model`` string.

    Args:
        model_str: Either a bare model name (the provider then defaults to
            ``"openai"``) or a ``"provider:model"`` string. Model names may
            themselves contain ``":"`` — only the first colon separates the
            provider, everything after it is the model name.

    Returns:
        An ``OpenAI``, ``Anthropic``, or ``Groq`` client for ``model_name``.

    Raises:
        ValueError: If the provider prefix is not one of the supported
            providers (previously this fell through and returned ``None``
            silently, deferring the failure to the caller).
    """
    # Split on the FIRST ":" only; some model names contain ":" themselves.
    provider, sep, model_name = model_str.partition(":")
    if not sep:
        # No colon at all -> bare model name, assume the default provider.
        provider, model_name = "openai", model_str

    # Provider imports are deferred so only the selected provider's
    # optional dependency needs to be installed.
    if provider == "openai":
        from llama_index.llms.openai import OpenAI

        return OpenAI(model=model_name)
    elif provider == "anthropic":
        from llama_index.llms.anthropic import Anthropic

        return Anthropic(model=model_name)
    elif provider in ("mixtral", "groq"):
        from llama_index.llms.groq import Groq

        return Groq(model=model_name)
    raise ValueError(f"Unknown LLM provider: {provider!r}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters