Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: 添加文心一言 #227

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

<groupId>com.unfbx</groupId>
<artifactId>chatgpt-java</artifactId>
<version>1.1.5</version>
<version>1.1.5-custom</version>
<name>chatgpt-java</name>
<description>OpenAI Java SDK, OpenAI Api for Java. ChatGPT Java SDK .</description>
<url>https://chatgpt-java.unfbx.com</url>
Expand Down Expand Up @@ -116,7 +116,7 @@
<dependency>
<groupId>com.knuddels</groupId>
<artifactId>jtokkit</artifactId>
<version>0.6.1</version>
<version>0.5.0</version>
</dependency>
</dependencies>

Expand Down
18 changes: 18 additions & 0 deletions src/main/java/com/unfbx/chatgpt/WenXinApi.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package com.unfbx.chatgpt;

import com.unfbx.chatgpt.entity.chat.wenxin.ChatCompletion;
import com.unfbx.chatgpt.entity.chat.wenxin.ChatCompletionResponse;
import io.reactivex.Single;
import retrofit2.http.Body;
import retrofit2.http.POST;

public interface WenXinApi {
/**
* 最新版的GPT-3.5 chat completion 更加贴近官方网站的问答模型
*
* @param chatCompletion chat completion
* @return 返回答案
*/
@POST("rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions")
Single<ChatCompletionResponse> chatCompletion(@Body ChatCompletion chatCompletion);
}
250 changes: 250 additions & 0 deletions src/main/java/com/unfbx/chatgpt/WenXinStreamClient.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
package com.unfbx.chatgpt;

import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.http.ContentType;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.unfbx.chatgpt.constant.OpenAIConst;
import com.unfbx.chatgpt.constant.WenXinConst;
import com.unfbx.chatgpt.entity.chat.wenxin.ChatCompletion;
import com.unfbx.chatgpt.exception.BaseException;
import com.unfbx.chatgpt.exception.CommonError;
import com.unfbx.chatgpt.function.KeyRandomStrategy;
import com.unfbx.chatgpt.function.KeyStrategyFunction;
import com.unfbx.chatgpt.interceptor.*;
import com.unfbx.chatgpt.sse.ConsoleEventSourceListener;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import okhttp3.sse.EventSources;
import org.jetbrains.annotations.NotNull;
import retrofit2.Retrofit;
import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory;
import retrofit2.converter.jackson.JacksonConverterFactory;

import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;


/**
* 描述: open ai 客户端
*
* @author https:www.unfbx.com
* 2023-02-28
*/

@Slf4j
public class WenXinStreamClient {
@Getter
@NotNull
private List<String> apiKey;
/**
* 自定义api host使用builder的方式构造client
*/
@Getter
private String apiHost;
/**
* 自定义的okHttpClient
* 如果不自定义 ,就是用sdk默认的OkHttpClient实例
*/
@Getter
private OkHttpClient okHttpClient;

/**
* api key的获取策略
*/
@Getter
private KeyStrategyFunction<List<String>, String> keyStrategy;

@Getter
private WenXinApi wenXinApi;

/**
* 自定义鉴权处理拦截器<br/>
* 可以不设置,默认实现:DefaultOpenAiAuthInterceptor <br/>
* 如需自定义实现参考:DealKeyWithOpenAiAuthInterceptor
*
* @see DynamicKeyOpenAiAuthInterceptor
* @see DefaultOpenAiAuthInterceptor
*/
@Getter
private WenXinAuthInterceptor authInterceptor;

/**
* 构造实例对象
*
* @param builder
*/
private WenXinStreamClient(Builder builder) {
if (CollectionUtil.isEmpty(builder.apiKey)) {
throw new BaseException(CommonError.API_KEYS_NOT_NUL);
}
apiKey = builder.apiKey;

if (StrUtil.isBlank(builder.apiHost)) {
builder.apiHost = WenXinConst.WenXin_HOST;
}
apiHost = builder.apiHost;

if (Objects.isNull(builder.keyStrategy)) {
builder.keyStrategy = new KeyRandomStrategy();
}
keyStrategy = builder.keyStrategy;

if (Objects.isNull(builder.authInterceptor)) {
builder.authInterceptor = new DefaultWenXinAuthInterceptor();
}
authInterceptor = builder.authInterceptor;
//设置apiKeys和key的获取策略
authInterceptor.setApiKey(this.apiKey);
authInterceptor.setKeyStrategy(this.keyStrategy);

if (Objects.isNull(builder.okHttpClient)) {
builder.okHttpClient = this.okHttpClient();
} else {
//自定义的okhttpClient 需要增加api keys
builder.okHttpClient = builder.okHttpClient
.newBuilder()
.addInterceptor(authInterceptor)
.build();
}
okHttpClient = builder.okHttpClient;

this.wenXinApi = new Retrofit.Builder()
.baseUrl(apiHost)
.client(okHttpClient)
.addCallAdapterFactory(RxJava2CallAdapterFactory.create())
.addConverterFactory(JacksonConverterFactory.create())
.build().create(WenXinApi.class);
}

/**
* 创建默认的OkHttpClient
*/
private OkHttpClient okHttpClient() {
if (Objects.isNull(this.authInterceptor)) {
this.authInterceptor = new DefaultWenXinAuthInterceptor();
}
this.authInterceptor.setApiKey(this.apiKey);
this.authInterceptor.setKeyStrategy(this.keyStrategy);
return new OkHttpClient
.Builder()
.addInterceptor(this.authInterceptor)
.connectTimeout(10, TimeUnit.SECONDS)
.writeTimeout(50, TimeUnit.SECONDS)
.readTimeout(50, TimeUnit.SECONDS)
.build();
}


/**
* 流式输出,最新版的GPT-3.5 chat completion 更加贴近官方网站的问答模型
*
* @param chatCompletion 问答参数
* @param eventSourceListener sse监听器
* @see ConsoleEventSourceListener
*/
public void streamChatCompletion(ChatCompletion chatCompletion, EventSourceListener eventSourceListener) {
if (Objects.isNull(eventSourceListener)) {
log.error("参数异常:EventSourceListener不能为空,可以参考:com.unfbx.chatgpt.sse.ConsoleEventSourceListener");
throw new BaseException(CommonError.PARAM_ERROR);
}
if (!chatCompletion.isStream()) {
chatCompletion.setStream(true);
}
try {
EventSource.Factory factory = EventSources.createFactory(this.okHttpClient);
ObjectMapper mapper = new ObjectMapper();
String requestBody = mapper.writeValueAsString(chatCompletion);
Request request = new Request.Builder()
.url(this.apiHost + "rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions")
.post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody))
.build();
//创建事件
factory.newEventSource(request, eventSourceListener);
} catch (JsonProcessingException e) {
log.error("请求参数解析异常:{}", e);
e.printStackTrace();
} catch (Exception e) {
log.error("请求参数解析异常:{}", e);
e.printStackTrace();
}
}


/**
* 构造
*
* @return Builder
*/
public static Builder builder() {
return new Builder();
}

public static final class Builder {
private @NotNull List<String> apiKey;
/**
* api请求地址,结尾处有斜杠
*
* @see OpenAIConst
*/
private String apiHost;

/**
* 自定义OkhttpClient
*/
private OkHttpClient okHttpClient;


/**
* api key的获取策略
*/
private KeyStrategyFunction keyStrategy;

/**
* 自定义鉴权拦截器
*/
private WenXinAuthInterceptor authInterceptor;

public Builder() {
}

public Builder apiKey(@NotNull List<String> val) {
apiKey = val;
return this;
}

/**
* @param val api请求地址,结尾处有斜杠
* @return Builder
* @see OpenAIConst
*/
public Builder apiHost(String val) {
apiHost = val;
return this;
}

public Builder keyStrategy(KeyStrategyFunction val) {
keyStrategy = val;
return this;
}

public Builder okHttpClient(OkHttpClient val) {
okHttpClient = val;
return this;
}

public Builder authInterceptor(WenXinAuthInterceptor val) {
authInterceptor = val;
return this;
}

public WenXinStreamClient build() {
return new WenXinStreamClient(this);
}
}
}
12 changes: 12 additions & 0 deletions src/main/java/com/unfbx/chatgpt/constant/WenXinConst.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package com.unfbx.chatgpt.constant;

/**
* 描述:
*
* @author https:www.unfbx.com
* @since 2023-03-06
*/
public class WenXinConst {

public final static String WenXin_HOST = "https://aip.baidubce.com/";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package com.unfbx.chatgpt.entity.chat.wenxin;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.unfbx.chatgpt.entity.chat.Functions;
import lombok.*;
import lombok.extern.slf4j.Slf4j;

import java.io.Serializable;
import java.util.List;
import java.util.Map;

/**
* 文心一言模型参数
*/
@Data
@Builder
@Slf4j
@JsonInclude(JsonInclude.Include.NON_NULL)
@AllArgsConstructor
public class ChatCompletion implements Serializable {
/**
* 使用什么取样温度,0到2之间。较高的值(如0.8)将使输出更加随机,而较低的值(如0.2)将使输出更加集中和确定。
* <p>
* We generally recommend altering this or but not both.top_p
*/
@Builder.Default
private double temperature = 0.2;

/**
* 使用温度采样的替代方法称为核心采样,其中模型考虑具有top_p概率质量的令牌的结果。因此,0.1 意味着只考虑包含前 10% 概率质量的代币。
* <p>
* 我们通常建议更改此设置,但不要同时更改两者。temperature
*/
@JsonProperty("top_p")
@Builder.Default
private Double topP = 0.8d;

/**
* 通过对已生成的token增加惩罚,减少重复生成的现象。说明:
* (1)值越大表示惩罚越大
* (2)默认1.0,取值范围:[1.0, 2.0]
*/
@JsonProperty("penalty_score")
@Builder.Default
private Double penalty_score = 1d;

/**
* 是否流式输出.
* default:false
*
* @see com.unfbx.chatgpt.OpenAiStreamClient
*/
@Builder.Default
private boolean stream = false;

/**
* 用户唯一值,确保接口不被重复调用
*/
private String user_id;

/**
* 问题描述
*/
@NonNull
private List<Message> messages;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package com.unfbx.chatgpt.entity.chat.wenxin;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.unfbx.chatgpt.entity.chat.ChatChoice;
import com.unfbx.chatgpt.entity.common.Usage;
import lombok.Data;

import java.io.Serializable;
import java.util.List;

/**
* 描述: chat答案类
*
* @author https:www.unfbx.com
* 2023-03-02
*/
@Data
@JsonIgnoreProperties(ignoreUnknown = true)
public class ChatCompletionResponse implements Serializable {
private String id;
private String object;
private long created;
private long sentence_id;
private Boolean is_end;
private Boolean is_truncated;
private String result;
private Boolean need_clear_history;
private int ban_round;
}
Loading