Add tokenize & detokenize to client, fix typos
- implemented client code for the `/tokenize` & `/detokenize` endpoints
- added docstring examples
andreaskoepf authored and pacman82 committed Nov 30, 2023
1 parent ececcef commit 65d764b
Showing 8 changed files with 296 additions and 17 deletions.
6 changes: 3 additions & 3 deletions src/completion.rs
@@ -76,7 +76,7 @@ pub struct Stopping<'a> {
/// List of strings which will stop generation if they are generated. Stop sequences are
/// helpful in structured texts. E.g.: In a question answering scenario a text may consist of
/// lines starting with either "Question: " or "Answer: " (alternating). After producing an
- /// answer, the model will be likely to generate "Question: ". "Question: " may therfore be used
+ /// answer, the model will be likely to generate "Question: ". "Question: " may therefore be used
/// as stop sequence in order not to have the model generate more questions but rather restrict
/// text generation to the answers.
pub stop_sequences: &'a [&'a str],
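To make the stop-sequence mechanic concrete, here is a minimal sketch of a `Stopping` value for the question-answering scenario described above. It assumes `Stopping` also exposes the `maximum_tokens` field used elsewhere in this crate:

```rust
use aleph_alpha_client::Stopping;

// Stop as soon as the model starts a new "Question: " line, so the
// completion contains only the answer.
let stopping = Stopping {
    maximum_tokens: 64,
    stop_sequences: &["Question: "],
};
```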
@@ -95,7 +95,7 @@ impl<'a> Stopping<'a> {
/// Body sent to the Aleph Alpha API on the POST `/completion` route
#[derive(Serialize, Debug)]
struct BodyCompletion<'a> {
- /// Name of the model tasked with completing the prompt. E.g. `luminus-base`.
+ /// Name of the model tasked with completing the prompt. E.g. `luminous-base`.
pub model: &'a str,
/// Prompt to complete. The modalities supported depend on `model`.
pub prompt: Prompt<'a>,
@@ -104,7 +104,7 @@ struct BodyCompletion<'a> {
/// List of strings which will stop generation if they are generated. Stop sequences are
/// helpful in structured texts. E.g.: In a question answering scenario a text may consist of
/// lines starting with either "Question: " or "Answer: " (alternating). After producing an
- /// answer, the model will be likely to generate "Question: ". "Question: " may therfore be used
+ /// answer, the model will be likely to generate "Question: ". "Question: " may therefore be used
/// as stop sequence in order not to have the model generate more questions but rather restrict
/// text generation to the answers.
#[serde(skip_serializing_if = "<[_]>::is_empty")]
57 changes: 57 additions & 0 deletions src/detokenization.rs
@@ -0,0 +1,57 @@
use crate::Task;
use serde::{Deserialize, Serialize};

/// Input for a [crate::Client::detokenize] request.
pub struct TaskDetokenization<'a> {
/// List of token ids which should be detokenized into text.
pub token_ids: &'a [u32],
}

/// Body sent to the Aleph Alpha API on the POST `/detokenize` route
#[derive(Serialize, Debug)]
struct BodyDetokenization<'a> {
/// Name of the model whose tokenizer should be used for detokenization. E.g. `luminous-base`.
pub model: &'a str,
/// List of ids to detokenize.
pub token_ids: &'a [u32],
}

#[derive(Deserialize, Debug, PartialEq, Eq)]
pub struct ResponseDetokenization {
pub result: String,
}

#[derive(Debug, PartialEq, Eq)]
pub struct DetokenizationOutput {
pub result: String,
}

impl From<ResponseDetokenization> for DetokenizationOutput {
fn from(response: ResponseDetokenization) -> Self {
Self {
result: response.result,
}
}
}

impl<'a> Task for TaskDetokenization<'a> {
type Output = DetokenizationOutput;
type ResponseBody = ResponseDetokenization;

fn build_request(
&self,
client: &reqwest::Client,
base: &str,
model: &str,
) -> reqwest::RequestBuilder {
let body = BodyDetokenization {
model,
token_ids: &self.token_ids,
};
client.post(format!("{base}/detokenize")).json(&body)
}

fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output {
DetokenizationOutput::from(response)
}
}
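To see what `build_request` puts on the wire, the body can be serialized directly. A sketch that would live inside this module, assuming `serde_json` is available as a dev-dependency:

```rust
let body = BodyDetokenization {
    model: "luminous-base",
    token_ids: &[556, 48741, 247, 2983],
};
// Fields serialize in declaration order: `model` first, then `token_ids`.
assert_eq!(
    serde_json::to_string(&body).unwrap(),
    r#"{"model":"luminous-base","token_ids":[556,48741,247,2983]}"#
);
```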
2 changes: 1 addition & 1 deletion src/explanation.rs
@@ -9,7 +9,7 @@ pub struct TaskExplanation<'a> {
/// The target string that should be explained. The influence of individual parts
/// of the prompt for generating this target string will be indicated in the response.
pub target: &'a str,
- /// Granularity paramaters for the explanation
+ /// Granularity parameters for the explanation
pub granularity: Granularity,
}

16 changes: 8 additions & 8 deletions src/http.rs
@@ -10,8 +10,8 @@ use crate::How;
/// for the Aleph Alpha API to specify its result. Notably it includes the model(s) the job is
/// executed on. This allows this trait to hold in the presence of services, which use more than one
/// model and task type to achieve their result. On the other hand a bare [`crate::TaskCompletion`]
- /// can not implement this trait directly, since its result would depend on what model is choosen to
- /// execute it. You can remidy this by turning completion task into a job, calling
+ /// cannot implement this trait directly, since its result would depend on what model is chosen to
+ /// execute it. You can remedy this by turning a completion task into a job, calling
/// [`Task::with_model`].
pub trait Job {
/// Output returned by [`crate::Client::output_of`]
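A sketch of the task-to-job conversion, assuming the `Task` trait and `Client::output_of` are publicly exported; the new `tokenize`/`detokenize` client methods in this commit follow exactly this pattern:

```rust
use aleph_alpha_client::{Client, Error, How, Task, TaskTokenization};

async fn run(client: &Client) -> Result<(), Error> {
    // A bare task plus a model name becomes a job which `output_of`
    // can execute against the API.
    let task = TaskTokenization::from("An apple a day");
    let job = task.with_model("luminous-base");
    let output = client.output_of(&job, &How::default()).await?;
    dbg!(output.token_ids);
    Ok(())
}
```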
@@ -130,7 +130,7 @@ impl HttpClient {
let query = if how.be_nice {
[("nice", "true")].as_slice()
} else {
- // nice=false is default, so we just ommit it.
+ // nice=false is default, so we just omit it.
[].as_slice()
};
let response = task
@@ -156,7 +156,7 @@ impl HttpClient {
async fn translate_http_error(response: reqwest::Response) -> Result<reqwest::Response, Error> {
let status = response.status();
if !status.is_success() {
- // Store body in a variable, so we can use it, even if it is not an Error emmitted by
+ // Store body in a variable, so we can use it, even if it is not an Error emitted by
// the API, but an intermediate Proxy like NGinx, so we can still forward the error
// message.
let body = response.text().await?;
@@ -174,14 +174,14 @@ async fn translate_http_error(response: reqwest::Response) -> Result<reqwest::Response, Error> {
}
}

- /// We are only interessted in the status codes of the API.
+ /// We are only interested in the status codes of the API.
#[derive(Deserialize)]
struct ApiError<'a> {
/// Unique string in capital letters emitted by the API to signal different kinds of errors in a
- /// finer granualrity then the HTTP status codes alone would allow for.
+ /// finer granularity than the HTTP status codes alone would allow for.
///
/// E.g. Differentiating between request rate limiting and parallel tasks limiting which both
- /// are 429 (the former is emmited by NGinx though).
+ /// are 429 (the former is emitted by NGinx though).
_code: Cow<'a, str>,
}

@@ -204,7 +204,7 @@ pub enum Error {
Busy,
#[error("No response received within given timeout: {0:?}")]
ClientTimeout(Duration),
- /// An error on the Http Protocl level.
+ /// An error on the HTTP protocol level.
#[error("HTTP request failed with status code {}. Body:\n{}", status, body)]
Http { status: u16, body: String },
/// Most likely either TLS errors creating the Client, or IO errors.
76 changes: 75 additions & 1 deletion src/lib.rs
@@ -24,11 +24,13 @@
//! ```

mod completion;
mod detokenization;
mod explanation;
mod http;
mod image_preprocessing;
mod prompt;
mod semantic_embedding;
mod tokenization;

use std::time::Duration;

@@ -37,6 +39,7 @@ use semantic_embedding::{BatchSemanticEmbeddingOutput, SemanticEmbeddingOutput};

pub use self::{
completion::{CompletionOutput, Sampling, Stopping, TaskCompletion},
detokenization::{DetokenizationOutput, TaskDetokenization},
explanation::{
Explanation, ExplanationOutput, Granularity, ImageScore, ItemExplanation,
PromptGranularity, TaskExplanation, TextScore,
@@ -46,6 +49,7 @@ pub use self::{
semantic_embedding::{
SemanticRepresentation, TaskBatchSemanticEmbedding, TaskSemanticEmbedding,
},
tokenization::{TaskTokenization, TokenizationOutput},
};

/// Execute Jobs against the Aleph Alpha API
@@ -215,6 +219,76 @@ impl Client {
.output_of(&task.with_model(model), how)
.await
}

/// Tokenize a prompt for a specific model.
///
/// ```no_run
/// use aleph_alpha_client::{Client, Error, How, TaskTokenization};
///
/// async fn tokenize() -> Result<(), Error> {
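/// // Assumes AA_API_TOKEN holds your Aleph Alpha API token.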
/// let client = Client::new(AA_API_TOKEN)?;
///
/// // Name of the model for which we want to tokenize text.
/// let model = "luminous-base";
///
/// // Text prompt to be tokenized.
/// let prompt = "An apple a day";
///
/// let task = TaskTokenization {
/// prompt,
/// tokens: true, // return text-tokens
/// token_ids: true, // return numeric token-ids
/// };
/// let response = client.tokenize(&task, model, &How::default()).await?;
///
/// dbg!(&response);
/// Ok(())
/// }
/// ```
pub async fn tokenize(
&self,
task: &TaskTokenization<'_>,
model: &str,
how: &How,
) -> Result<TokenizationOutput, Error> {
self.http_client
.output_of(&task.with_model(model), how)
.await
}

/// Detokenize a list of token ids into a string.
///
/// ```no_run
/// use aleph_alpha_client::{Client, Error, How, TaskDetokenization};
///
/// async fn detokenize() -> Result<(), Error> {
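/// // Assumes AA_API_TOKEN holds your Aleph Alpha API token.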
/// let client = Client::new(AA_API_TOKEN)?;
///
/// // Specify the name of the model whose tokenizer was used to generate the input token ids.
/// let model = "luminous-base";
///
/// // Token ids to convert into text.
/// let token_ids: Vec<u32> = vec![556, 48741, 247, 2983];
///
/// let task = TaskDetokenization {
/// token_ids: &token_ids,
/// };
/// let response = client.detokenize(&task, model, &How::default()).await?;
///
/// dbg!(&response);
/// Ok(())
/// }
/// ```
pub async fn detokenize(
&self,
task: &TaskDetokenization<'_>,
model: &str,
how: &How,
) -> Result<DetokenizationOutput, Error> {
self.http_client
.output_of(&task.with_model(model), how)
.await
}
}
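Together the two new endpoints allow a round trip. A sketch, reusing the hypothetical `AA_API_TOKEN` constant from the examples above; exact whitespace in the detokenized text depends on the tokenizer, hence the `trim`:

```rust
use aleph_alpha_client::{Client, Error, How, TaskDetokenization, TaskTokenization};

async fn round_trip() -> Result<(), Error> {
    let client = Client::new(AA_API_TOKEN)?;
    let model = "luminous-base";
    let how = How::default();

    // Tokenize a prompt, then feed the ids straight back into detokenize.
    let tokenized = client
        .tokenize(&TaskTokenization::from("An apple a day"), model, &how)
        .await?;
    let token_ids = tokenized.token_ids.expect("token ids were requested");

    let detokenized = client
        .detokenize(&TaskDetokenization { token_ids: &token_ids }, model, &how)
        .await?;
    // With a lossless tokenizer this reproduces the input text.
    assert_eq!(detokenized.result.trim(), "An apple a day");
    Ok(())
}
```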

/// Controls how to execute a task
@@ -254,7 +328,7 @@ impl Default for How {
/// Client, Prompt, TaskSemanticEmbedding, cosine_similarity, SemanticRepresentation, How
/// };
///
- /// async fn semanitc_search_with_luminous_base(client: &Client) {
+ /// async fn semantic_search_with_luminous_base(client: &Client) {
/// // Given
/// let robot_fact = Prompt::from_text(
/// "A robot is a machine—especially one programmable by a computer—capable of carrying out a \
91 changes: 91 additions & 0 deletions src/tokenization.rs
@@ -0,0 +1,91 @@
use crate::Task;
use serde::{Deserialize, Serialize};

/// Input for a [crate::Client::tokenize] request.
pub struct TaskTokenization<'a> {
/// The text prompt which should be converted into tokens
pub prompt: &'a str,

/// Specify `true` to return text-tokens.
pub tokens: bool,

/// Specify `true` to return numeric token-ids.
pub token_ids: bool,
}

impl<'a> From<&'a str> for TaskTokenization<'a> {
fn from(prompt: &'a str) -> TaskTokenization {
TaskTokenization {
prompt,
tokens: true,
token_ids: true,
}
}
}
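With this impl the common case is a one-liner; both token representations are requested by default:

```rust
let task: TaskTokenization = "An apple a day".into();
assert!(task.tokens && task.token_ids);
```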

impl TaskTokenization<'_> {
pub fn new(prompt: &str, tokens: bool, token_ids: bool) -> TaskTokenization {
TaskTokenization {
prompt,
tokens,
token_ids,
}
}
}

#[derive(Serialize, Debug)]
struct BodyTokenization<'a> {
/// Name of the model whose tokenizer should be used. E.g. `luminous-base`.
pub model: &'a str,
/// String to tokenize.
pub prompt: &'a str,
/// Set this value to `true` to return text-tokens.
pub tokens: bool,
/// Set this value to `true` to return numeric token-ids.
pub token_ids: bool,
}

#[derive(Deserialize, Debug, PartialEq, Eq)]
pub struct ResponseTokenization {
pub tokens: Option<Vec<String>>,
pub token_ids: Option<Vec<u32>>,
}
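Both fields are optional because the API omits whichever representation was not requested. A sketch of the resulting behaviour, again assuming `serde_json` as a dev-dependency:

```rust
// A response in which only token ids were requested: the missing
// `tokens` field deserializes to `None`.
let response: ResponseTokenization =
    serde_json::from_str(r#"{"token_ids":[556,48741,247,2983]}"#).unwrap();
assert_eq!(response.tokens, None);
assert_eq!(response.token_ids, Some(vec![556, 48741, 247, 2983]));
```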

#[derive(Debug, PartialEq)]
pub struct TokenizationOutput {
pub tokens: Option<Vec<String>>,
pub token_ids: Option<Vec<u32>>,
}

impl From<ResponseTokenization> for TokenizationOutput {
fn from(response: ResponseTokenization) -> Self {
Self {
tokens: response.tokens,
token_ids: response.token_ids,
}
}
}

impl Task for TaskTokenization<'_> {
type Output = TokenizationOutput;
type ResponseBody = ResponseTokenization;

fn build_request(
&self,
client: &reqwest::Client,
base: &str,
model: &str,
) -> reqwest::RequestBuilder {
let body = BodyTokenization {
model,
prompt: &self.prompt,
tokens: self.tokens,
token_ids: self.token_ids,
};
client.post(format!("{base}/tokenize")).json(&body)
}

fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output {
TokenizationOutput::from(response)
}
}