@@ -1,4 +1,4 @@
package com.cjcrafter.openai.chat
package com.cjcrafter.openai

/**
* [FinishReason] wraps the possible reasons that a generation model may stop
@@ -23,9 +23,9 @@ enum class FinishReason {
LENGTH,

/**
* [TEMPERATURE] is a rare occurrence, and only happens when the
* [ChatRequest.temperature] is low enough that it is impossible for the
* model to continue generating text.
* [CONTENT_FILTER] occurs due to a flag from OpenAI's content filters.
* This occurrence is rare, and usually only happens when you blatantly
* misuse/violate OpenAI's terms.
*/
TEMPERATURE
CONTENT_FILTER
}
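
Code consuming this enum can branch on the renamed constant. A minimal sketch, assuming a ChatChoice named choice taken from a ChatResponse; the STOP constant is not shown in this hunk, so an else branch covers it:

import com.cjcrafter.openai.FinishReason
import com.cjcrafter.openai.chat.ChatChoice

// Sketch only: maps each stop condition to a human-readable note.
fun describe(choice: ChatChoice): String = when (choice.finishReason) {
    FinishReason.LENGTH -> "Hit the token limit; consider raising max_tokens."
    FinishReason.CONTENT_FILTER -> "Flagged by OpenAI's content filter."
    else -> "Generation stopped normally." // e.g. STOP
}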
106 changes: 100 additions & 6 deletions src/main/kotlin/com/cjcrafter/openai/chat/ChatBot.kt
@@ -1,15 +1,14 @@
package com.cjcrafter.openai.chat

import com.google.gson.*
import okhttp3.MediaType
import okhttp3.*
import okhttp3.MediaType.Companion.toMediaType
import okhttp3.OkHttpClient
import okhttp3.OkHttpClient.Builder
import okhttp3.Request
import okhttp3.RequestBody
import okhttp3.RequestBody.Companion.toRequestBody
import java.io.IOException
import java.lang.IllegalArgumentException
import java.util.concurrent.TimeUnit
import java.util.function.Consumer

/**
* The ChatBot class wraps the OpenAI API and lets you send messages and
@@ -41,7 +40,9 @@ class ChatBot(private val apiKey: String) {
.readTimeout(0, TimeUnit.SECONDS).build()
private val mediaType: MediaType = "application/json; charset=utf-8".toMediaType()
private val gson: Gson = GsonBuilder()
.registerTypeAdapter(ChatUser::class.java, JsonSerializer<ChatUser> { src, _, context -> context!!.serialize(src!!.name.lowercase())!! })
.registerTypeAdapter(
ChatUser::class.java,
JsonSerializer<ChatUser> { src, _, context -> context!!.serialize(src!!.name.lowercase())!! })
.create()

/**
@@ -56,7 +57,9 @@ class ChatBot(private val apiKey: String) {
* @throws IllegalArgumentException If the input arguments are invalid.
*/
@Throws(IOException::class)
fun generateResponse(request: ChatRequest?): ChatResponse {
fun generateResponse(request: ChatRequest): ChatResponse {
request.stream = false // use streamResponse for stream=true

val json = gson.toJson(request)
val body: RequestBody = json.toRequestBody(mediaType)
val httpRequest: Request = Request.Builder()
@@ -83,4 +86,95 @@
throw ex
}
}

/**
* This is a helper method that calls [streamResponse], which lets you use
* the generated tokens in real time (as ChatGPT generates them).
*
* This method does not block the thread. Method calls to [onResponse] are
* not handled by the main thread. It is crucial to consider thread safety
* within the context of your program.
*
* @param request The input information for ChatGPT.
* @param onResponse The method to call for each chunk.
* @since 1.2.0
*/
fun streamResponseKotlin(request: ChatRequest, onResponse: ChatResponseChunk.() -> Unit) {
streamResponse(request, { it.onResponse() })
}

/**
* Uses ChatGPT to generate tokens in real time. As ChatGPT generates
* content, those tokens are sent in a stream in real time. This allows you
* to update the user without long delays between their input and OpenAI's
* response.
*
* For *"simpler"* calls, you can use [generateResponse] which will block
* the thread until the entire response is generated.
*
* Instead of [ChatResponse], this method uses [ChatResponseChunk]. This
* means that it is not possible to retrieve the number of tokens used
* from this method.
*
* This method does not block the thread. Method calls to [onResponse] are
* not handled by the main thread. It is crucial to consider thread safety
* within the context of your program.
*
* @param request The input information for ChatGPT.
* @param onResponse The method to call for each chunk.
* @param onFailure The method to call if the HTTP request fails. This
* method will not be called if OpenAI returns an error.
* @see generateResponse
* @see streamResponseKotlin
* @since 1.2.0
*/
@JvmOverloads
fun streamResponse(
request: ChatRequest,
onResponse: Consumer<ChatResponseChunk>, // use Consumer instead of a Kotlin lambda for better Java interop
onFailure: Consumer<IOException> = Consumer { it.printStackTrace() }
) {
request.stream = true // use generateResponse for stream=false

val json = gson.toJson(request)
val body: RequestBody = json.toRequestBody(mediaType)
val httpRequest: Request = Request.Builder()
.url("https://api.openai.com/v1/chat/completions")
.addHeader("Content-Type", "application/json")
.addHeader("Authorization", "Bearer $apiKey")
.post(body)
.build()

client.newCall(httpRequest).enqueue(object : Callback {
var cache: ChatResponseChunk? = null

override fun onFailure(call: Call, e: IOException) {
onFailure.accept(e)
}

override fun onResponse(call: Call, response: Response) {
response.body?.source()?.use { source ->
while (!source.exhausted()) {

// Each line of the event stream starts with "data: ", so
// strip that prefix before parsing the JSON.
var jsonResponse = source.readUtf8Line() ?: continue
if (jsonResponse.isEmpty())
continue
jsonResponse = jsonResponse.substring("data: ".length)
if (jsonResponse == "[DONE]")
continue

val rootObject = JsonParser.parseString(jsonResponse).asJsonObject
if (cache == null)
cache = ChatResponseChunk(rootObject)
else
cache!!.update(rootObject)

onResponse.accept(cache!!)
}
}
}
})
}
}
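
To see the new streaming entry points in context, here is a minimal usage sketch. The ChatRequest constructor shape and the ChatUser.USER constant are assumed, since neither appears in this diff, and the API key is a placeholder:

import com.cjcrafter.openai.chat.*

fun main() {
    val bot = ChatBot("sk-...") // placeholder key
    val messages = mutableListOf(ChatMessage(ChatUser.USER, "Write a haiku about Kotlin."))
    val request = ChatRequest("gpt-3.5-turbo", messages) // assumed constructor shape

    // The callback receives the *accumulated* chunk, so print only the
    // newly generated tokens from each event.
    bot.streamResponseKotlin(request) {
        print(get(0).delta)
        if (get(0).finishReason != null)
            println() // stream finished
    }
}

Since the call is enqueued rather than executed, the callbacks run on OkHttp's worker threads; a real caller should keep the process alive and synchronize any shared state, as the KDoc above warns.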
7 changes: 4 additions & 3 deletions src/main/kotlin/com/cjcrafter/openai/chat/ChatChoice.kt
@@ -1,5 +1,6 @@
package com.cjcrafter.openai.chat

import com.cjcrafter.openai.FinishReason
import com.google.gson.JsonObject

/**
@@ -15,18 +16,18 @@ import com.google.gson.JsonObject
*
* @property index The index in the array... 0 if [ChatRequest.n]=1.
* @property message The generated text.
* @property finishReason Why did the bot stop generating tokens?
* @property finishReason The reason the bot stopped generating tokens.
* @constructor Create a new chat choice, for internal usage.
* @see FinishReason
*/
data class ChatChoice(val index: Int, val message: ChatMessage, val finishReason: FinishReason?) {
data class ChatChoice(val index: Int, val message: ChatMessage, val finishReason: FinishReason) {

/**
* JSON constructor for internal usage.
*/
constructor(json: JsonObject) : this(
json["index"].asInt,
ChatMessage(json["message"].asJsonObject),
if (json["finish_reason"].isJsonNull) null else FinishReason.valueOf(json["finish_reason"].asString.uppercase())
FinishReason.valueOf(json["finish_reason"].asString.uppercase())
)
}
65 changes: 65 additions & 0 deletions src/main/kotlin/com/cjcrafter/openai/chat/ChatChoiceChunk.kt
@@ -0,0 +1,65 @@
package com.cjcrafter.openai.chat

import com.cjcrafter.openai.FinishReason
import com.google.gson.JsonObject

/**
*
* The OpenAI API returns a list of [ChatChoiceChunk]. The "new content" is
* saved to the [delta] property. To access everything that is currently
* generated, use [message].
*
* By default, only 1 [ChatChoiceChunk] is generated (since [ChatRequest.n] == 1).
* When you increase `n`, more options are generated. The more options you
* generate, the more tokens you use. In general, it is best to **ONLY**
* generate 1 response, and to let the user regenerate the response.
*
* @property index The index in the array... 0 if [ChatRequest.n]=1.
* @property message All tokens that are currently generated.
* @property delta The newly generated tokens (*can be empty!*)
* @property finishReason The reason the bot stopped generating tokens.
* @constructor Create a new chat choice, for internal usage.
* @see FinishReason
* @see ChatChoice
*/
data class ChatChoiceChunk(val index: Int, val message: ChatMessage, var delta: String, var finishReason: FinishReason?) {

/**
* JSON constructor for internal usage.
*/
constructor(json: JsonObject) : this(

// The first message from ChatGPT looks like this:
// data: {"id":"chatcmpl-6xUB4Vi8jEG8u4hMBTMeO8KXgA87z","object":"chat.completion.chunk","created":1679635374,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"role":"assistant"},"index":0,"finish_reason":null}]}
// So the only data we have so far is that ChatGPT will be responding.
json["index"].asInt,
ChatMessage(ChatUser.ASSISTANT, ""),
"",
null
)

internal fun update(json: JsonObject) {
val deltaJson = json["delta"].asJsonObject
delta = if (deltaJson.has("content")) deltaJson["content"].asString else ""
message.content += delta
finishReason = if (json["finish_reason"].isJsonNull) null else FinishReason.valueOf(json["finish_reason"].asString.uppercase())
}
}

/*
Below is a potential stream response from OpenAI. You can see that the first
message contains no generated content, and the last message (before "[DONE]")
adds the finish_reason.

data: {"id":"chatcmpl-6xUB4Vi8jEG8u4hMBTMeO8KXgA87z","object":"chat.completion.chunk","created":1679635374,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"role":"assistant"},"index":0,"finish_reason":null}]}

data: {"id":"chatcmpl-6xUB4Vi8jEG8u4hMBTMeO8KXgA87z","object":"chat.completion.chunk","created":1679635374,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":"Hello"},"index":0,"finish_reason":null}]}

data: {"id":"chatcmpl-6xUB4Vi8jEG8u4hMBTMeO8KXgA87z","object":"chat.completion.chunk","created":1679635374,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":" World"},"index":0,"finish_reason":null}]}

data: {"id":"chatcmpl-6xUB4Vi8jEG8u4hMBTMeO8KXgA87z","object":"chat.completion.chunk","created":1679635374,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":"."},"index":0,"finish_reason":null}]}

data: {"id":"chatcmpl-6xUB4Vi8jEG8u4hMBTMeO8KXgA87z","object":"chat.completion.chunk","created":1679635374,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{},"index":0,"finish_reason":"stop"}]}

data: [DONE]
*/
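
To make the accumulation concrete, the sketch below feeds two of the sample choice objects through the constructor and update. Since update is internal, this only compiles inside the library module; it illustrates the bookkeeping rather than a client-facing call:

import com.google.gson.JsonParser

fun main() {
    // First chunk: role only, no content yet.
    val choice = ChatChoiceChunk(JsonParser.parseString(
        """{"delta":{"role":"assistant"},"index":0,"finish_reason":null}"""
    ).asJsonObject)

    // Second chunk: appends "Hello" to the accumulated message.
    choice.update(JsonParser.parseString(
        """{"delta":{"content":"Hello"},"index":0,"finish_reason":null}"""
    ).asJsonObject)

    println(choice.delta)           // "Hello"
    println(choice.message.content) // "Hello"
    println(choice.finishReason)    // null until the final chunk arrives
}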
4 changes: 3 additions & 1 deletion src/main/kotlin/com/cjcrafter/openai/chat/ChatMessage.kt
@@ -7,11 +7,13 @@ import com.google.gson.JsonObject
* conversation, we need to map each message to who sent it. This data class
* wraps a message with the user who sent the message.
*
* Note that [role] and [content] are mutable (`var`) so that streaming
* responses can append newly generated tokens to an existing message.
*
* @property role The user who sent this message.
* @property content The string content of the message.
* @see ChatUser
*/
data class ChatMessage(val role: ChatUser, val content: String) {
data class ChatMessage(var role: ChatUser, var content: String) {

/**
* JSON constructor for internal usage.
2 changes: 0 additions & 2 deletions src/main/kotlin/com/cjcrafter/openai/chat/ChatResponse.kt
@@ -16,8 +16,6 @@ import java.util.*
* @property choices The list of generated messages.
* @property usage The number of tokens used in this request/response.
* @constructor Create Chat response (for internal usage).
* @see ChatChoice
* @see ChatUsage
*/
data class ChatResponse(
val id: String,
84 changes: 84 additions & 0 deletions src/main/kotlin/com/cjcrafter/openai/chat/ChatResponseChunk.kt
@@ -0,0 +1,84 @@
package com.cjcrafter.openai.chat

import com.google.gson.JsonObject
import java.time.Instant
import java.time.ZoneId
import java.time.ZonedDateTime
import java.util.*

/**
* The [ChatResponseChunk] contains all the data returned by the OpenAI Chat API.
* For most use cases, [ChatResponseChunk.get] (passing 0 to the index argument)
* is all you need.
*
* This class is similar to [ChatResponse], except that [ChatResponseChunk]
* does not include the number of generated tokens (token usage).
*
* @property id The unique id for your request.
* @property created The Unix timestamp (measured in seconds since 00:00:00 UTC on January 1, 1970) when the API response was created.
* @property choices The list of generated messages.
* @constructor Create Chat response (for internal usage).
* @see ChatResponse
*/
data class ChatResponseChunk(
val id: String,
val created: Long,
val choices: List<ChatChoiceChunk>,
) {

/**
* JSON constructor for internal usage.
*/
constructor(json: JsonObject) : this(
json["id"].asString,
json["created"].asLong,
json["choices"].asJsonArray.map { ChatChoiceChunk(it.asJsonObject) },
)

internal fun update(json: JsonObject) {
json["choices"].asJsonArray.forEachIndexed { index, jsonElement ->
choices[index].update(jsonElement.asJsonObject)
}
}

/**
* Returns the [Instant] time that the OpenAI Chat API sent this response.
* The time is measured as a unix timestamp (measured in seconds since
* 00:00:00 UTC on January 1, 1970).
*
* Note that users expect time to be measured in their timezone, so
* [getZonedTime] is preferred.
*
* @return The instant the api created this response.
* @see getZonedTime
*/
fun getTime(): Instant {
return Instant.ofEpochSecond(created)
}

/**
* Returns the time-zoned instant that the OpenAI Chat API sent this
* response. By default, this method uses the system's timezone.
*
* @param timezone The user's timezone.
* @return The timezone adjusted date time.
* @see TimeZone.getDefault
*/
@JvmOverloads
fun getZonedTime(timezone: ZoneId = TimeZone.getDefault().toZoneId()): ZonedDateTime {
return ZonedDateTime.ofInstant(getTime(), timezone)
}

// TODO add tokenizer so we can determine token count

/**
* Shorthand for accessing the generated messages (shorthand for
* [ChatResponseChunk.choices]).
*
* @param index The index of the message (`0` for most use cases).
* @return The generated [ChatChoiceChunk] at the index.
*/
operator fun get(index: Int): ChatChoiceChunk {
return choices[index]
}
}
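
From Java-friendly call sites, the Consumer overload pairs naturally with these accessors. A minimal sketch, assuming bot and request are built as in the earlier example:

import com.cjcrafter.openai.chat.ChatBot
import com.cjcrafter.openai.chat.ChatRequest
import java.util.function.Consumer

fun printStream(bot: ChatBot, request: ChatRequest) {
    bot.streamResponse(request, Consumer { chunk ->
        print(chunk[0].delta) // operator get(0), shorthand for chunk.choices[0]
        if (chunk[0].finishReason != null)
            println("\nFinished at ${chunk.getZonedTime()}") // system timezone by default
    })
}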