-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathcreate_chat.R
More file actions
112 lines (101 loc) · 3.15 KB
/
create_chat.R
File metadata and controls
112 lines (101 loc) · 3.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#' Create a chat object.
#'
#' Builds an `httr2` request pointing at the chosen vendor's chat endpoint
#' (with the vendor's authentication headers) and wraps it in an S3 object of
#' class `chat`.
#'
#' @param vendor A single string naming the vendor. Currently, only 'openai',
#'   'mistral', 'anthropic' and 'ollama' are supported.
#' @param api_key The API key for the vendor's chat engine. If the vendor is
#'   'ollama', this parameter is not required (no auth header is sent).
#' @param port The port number for the local ollama chat engine. Defaults to
#'   ollama's standard port (11434). If the vendor is not 'ollama', this
#'   parameter is not used.
#' @param api_version The API version sent as the `anthropic-version` header.
#'   Required (non-empty) when the vendor is 'anthropic'; ignored otherwise.
#'
#' @return A list of class `chat` with elements `vendor_name`, `engine` (an
#'   `httr2` request), `messages` (an empty list) and, for 'ollama' only,
#'   `params` (`list(stream = FALSE)`).
#' @export
#'
#' @examples
#' \dontrun{
#' dotenv::load_dot_env()
#' chat_openai <- create_chat('openai', Sys.getenv('OAI_DEV_KEY'))
#' chat_mistral <- create_chat('mistral', Sys.getenv('MISTRAL_DEV_KEY'))
#' }
create_chat <- function(vendor,
                        api_key = '',
                        port = if (vendor == 'ollama') 11434 else NULL,
                        api_version = '') {
  supported_vendors <- c('openai', 'mistral', 'ollama', 'anthropic')
  # Scalar validation: `%in%` never returns NA, and the length check rejects
  # NULL or multi-element input before any `if (vendor == ...)` can fail on a
  # non-scalar condition.
  if (length(vendor) != 1L || !(vendor %in% supported_vendors)) {
    stop('Unsupported vendor')
  }
  # Validate before building any request. NULL/NA versions are rejected with
  # the same message as an empty string.
  if (vendor == 'anthropic' &&
      (length(api_version) != 1L || is.na(api_version) || api_version == '')) {
    stop('Anthropic requires API version')
  }

  engine <- if (vendor == 'openai') {
    # https://platform.openai.com/docs/api-reference/making-requests
    httr2::request(
      base_url = 'https://api.openai.com/v1/chat/completions'
    ) |>
      httr2::req_headers(
        'Authorization' = paste('Bearer', api_key),
        'Content-Type' = 'application/json'
      )
  } else if (vendor == 'mistral') {
    # https://docs.mistral.ai/
    httr2::request(
      base_url = 'https://api.mistral.ai/v1/chat/completions'
    ) |>
      httr2::req_headers(
        'Authorization' = paste('Bearer', api_key),
        'Content-Type' = 'application/json',
        'Accept' = 'application/json'
      )
  } else if (vendor == 'ollama') {
    # https://github.com/ollama/ollama/blob/main/docs/api.md
    # Local server: no authentication headers are attached.
    httr2::request(
      base_url = glue::glue('http://localhost:{port}/api/chat')
    )
  } else {
    # anthropic: https://docs.anthropic.com/en/api/messages
    httr2::request(
      base_url = 'https://api.anthropic.com/v1/messages'
    ) |>
      httr2::req_headers(
        'x-api-key' = api_key,
        'Content-Type' = 'application/json',
        'anthropic-version' = api_version
      )
  }

  chat <- list(
    vendor_name = vendor,
    engine = engine,
    messages = list()
  )
  # Only ollama carries a default parameter: `stream = FALSE` asks for a
  # single, non-streamed response.
  if (vendor == 'ollama') {
    chat$params <- list(stream = FALSE)
  }
  structure(chat, class = 'chat')
}
#' Print a chat object.
#'
#' Writes a short cli summary: the vendor name, the number of stored
#' messages, the model (if one is set) and a bulleted list of parameters
#' (if any).
#'
#' @param x A `chat` object.
#' @param ... Unused; present for compatibility with the `print()` generic.
#' @return `x`, invisibly, as is conventional for print methods.
#' @export
print.chat <- function(x, ...) {
  # Local theme so `{.param ...}` spans render in blue; cli closes this div
  # automatically when the function exits.
  cli::cli_div(
    theme = list(
      span.param = list(color = "blue")
    )
  )
  cli::cli_text("{.field Chat Engine}: {x$vendor_name}")
  cli::cli_text("{.field Messages}: {length(x$messages)}")
  # `[[` avoids `$`'s partial name matching on plain lists.
  if (length(x[["model"]]) > 0) {
    cli::cli_text("{.field Model}: {x$model}")
  }
  if (length(x[["params"]]) > 0) {
    cli::cli_text("{.field Parameters}:")
    ul <- cli::cli_ul()
    for (param in names(x[["params"]])) {
      cli::cli_li("{.param {param}}: {x$params[[param]]}")
    }
    cli::cli_end(ul)
  }
  invisible(x)
}
#' Render a chat object inside knitr documents.
#'
#' Delegates to `knitr::normal_print()`, which prints `x` with the regular
#' `print()` method.
#'
#' NOTE(review): this method carries no `@export`/`@exportS3Method` tag, so it
#' is presumably not registered for knitr's own dispatch; confirm whether
#' `@exportS3Method knitr::knit_print` is intended.
#' @keywords internal
#' @noRd
knit_print.chat <- function(x, ...) {
  # Do not call knitr::knit_print(x) here: `x` still has class 'chat', so the
  # generic can dispatch straight back to this method and recurse forever.
  knitr::normal_print(x, ...)
}