Skip to content

Commit

Permalink
feat!: enable request-based authentication without default token on c…
Browse files Browse the repository at this point in the history
…lient
  • Loading branch information
moldhouse committed Jul 4, 2024
1 parent ce45c6c commit 4806b4b
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 41 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,5 @@ thiserror = "1.0.58"

[dev-dependencies]
dotenv = "0.15.0"
lazy_static = "1.4.0"
tokio = { version = "1.37.0", features = ["rt", "macros"] }
wiremock = "0.6.0"
5 changes: 5 additions & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

## Unreleased

## 0.10.0

* Add the option to have authentication exclusively on a per-request basis, without the need to specify a dummy token.
* Rename `Client::new` to `Client::with_authentication`.

## 0.9.0

* Add `How::api_token` to allow specifying API tokens for individual requests.
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use aleph_alpha_client::{Client, TaskCompletion, How, Task};
#[tokio::main]
async fn main() {
// Authenticate against API. Fetches token.
let client = Client::new("AA_API_TOKEN").unwrap();
let client = Client::with_authentication("AA_API_TOKEN").unwrap();

// Name of the model we want to use. Large models usually give better answers, but are also
// more costly.
Expand Down
14 changes: 9 additions & 5 deletions src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,19 +83,19 @@ where
pub struct HttpClient {
base: String,
http: reqwest::Client,
api_token: String,
api_token: Option<String>,
}

impl HttpClient {
/// In production you typically would want to set this to <https://api.aleph-alpha.com>. Yet you
/// may want to use a different instance for testing.
pub fn with_base_url(host: String, api_token: &str) -> Result<Self, Error> {
pub fn with_base_url(host: String, api_token: Option<String>) -> Result<Self, Error> {
let http = ClientBuilder::new().build()?;

Ok(Self {
base: host,
http,
api_token: api_token.to_owned(),
api_token,
})
}

Expand All @@ -106,7 +106,7 @@ impl HttpClient {
///
/// async fn print_completion() -> Result<(), Error> {
/// // Authenticate against API. Fetches token.
/// let client = Client::new("AA_API_TOKEN")?;
/// let client = Client::with_authentication("AA_API_TOKEN")?;
///
/// // Name of the model we want to use. Large models usually give better answers, but are
/// // also slower and more costly.
Expand All @@ -132,7 +132,11 @@ impl HttpClient {
[].as_slice()
};

let api_token = how.api_token.as_ref().unwrap_or(&self.api_token);
let api_token = how
.api_token
.as_ref()
.or(self.api_token.as_ref())
.expect("API token needs to be set on client construction or per request");
let response = task
.build_request(&self.http, &self.base)
.query(query)
Expand Down
34 changes: 24 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//! #[tokio::main(flavor = "current_thread")]
//! async fn main() {
//! // Authenticate against API. Fetches token.
//! let client = Client::new("AA_API_TOKEN").unwrap();
//! let client = Client::with_authentication("AA_API_TOKEN").unwrap();
//!
//! // Name of the model we want to use. Large models usually give better answers, but are also
//! // more costly.
Expand Down Expand Up @@ -63,15 +63,29 @@ pub struct Client {

impl Client {
/// A new instance of an Aleph Alpha client helping you interact with the Aleph Alpha API.
pub fn new(api_token: &str) -> Result<Self, Error> {
/// For "normal" client applications you will likely rather use [`Self::with_authentication`] or
/// [`Self::with_base_url`].
///
/// You may want to use only request-based authentication and skip default authentication. This
/// is useful when writing an application which invokes the client on behalf of many different
/// users. Having neither request nor default authentication is considered a bug and will cause
/// a panic.
pub fn new(host: String, api_token: Option<String>) -> Result<Self, Error> {
let http_client = HttpClient::with_base_url(host, api_token)?;
Ok(Self { http_client })
}

/// Use the Aleph Alpha SaaS offering with your API token for all requests.
pub fn with_authentication(api_token: impl Into<String>) -> Result<Self, Error> {
Self::with_base_url("https://api.aleph-alpha.com".to_owned(), api_token)
}

/// Use your on-premise inference with your API token for all requests.
///
/// In production you typically would want to set this to <https://api.aleph-alpha.com>. Yet
/// you may want to use a different instance for testing.
pub fn with_base_url(host: String, api_token: &str) -> Result<Self, Error> {
let http_client = HttpClient::with_base_url(host, api_token)?;
Ok(Self { http_client })
pub fn with_base_url(host: String, api_token: impl Into<String>) -> Result<Self, Error> {
Self::new(host, Some(api_token.into()))
}

/// Execute a task with the aleph alpha API and fetch its result.
Expand All @@ -81,7 +95,7 @@ impl Client {
///
/// async fn print_completion() -> Result<(), Error> {
/// // Authenticate against API. Fetches token.
/// let client = Client::new("AA_API_TOKEN")?;
/// let client = Client::with_authentication("AA_API_TOKEN")?;
///
/// // Name of the model we want to use. Large models usually give better answers, but are
/// // also slower and more costly.
Expand Down Expand Up @@ -145,7 +159,7 @@ impl Client {
///
/// async fn print_completion() -> Result<(), Error> {
/// // Authenticate against API. Fetches token.
/// let client = Client::new("AA_API_TOKEN")?;
/// let client = Client::with_authentication("AA_API_TOKEN")?;
///
/// // Name of the model we want to use. Large models usually give better answers, but are
/// // also slower and more costly.
Expand Down Expand Up @@ -182,7 +196,7 @@ impl Client {
/// use aleph_alpha_client::{Client, How, TaskCompletion, Task, Error, Granularity, TaskExplanation, Stopping, Prompt, Sampling};
///
/// async fn print_explanation() -> Result<(), Error> {
/// let client = Client::new("AA_API_TOKEN")?;
/// let client = Client::with_authentication("AA_API_TOKEN")?;
///
/// // Name of the model we want to use. Large models usually give better answers, but are
/// // also slower and more costly.
Expand Down Expand Up @@ -226,7 +240,7 @@ impl Client {
/// use aleph_alpha_client::{Client, Error, How, TaskTokenization};
///
/// async fn tokenize() -> Result<(), Error> {
/// let client = Client::new("AA_API_TOKEN")?;
/// let client = Client::with_authentication("AA_API_TOKEN")?;
///
/// // Name of the model for which we want to tokenize text.
/// let model = "luminous-base";
Expand Down Expand Up @@ -262,7 +276,7 @@ impl Client {
/// use aleph_alpha_client::{Client, Error, How, TaskDetokenization};
///
/// async fn detokenize() -> Result<(), Error> {
/// let client = Client::new("AA_API_TOKEN")?;
/// let client = Client::with_authentication("AA_API_TOKEN")?;
///
/// // Specify the name of the model whose tokenizer was used to generate the input token ids.
/// let model = "luminous-base";
Expand Down
2 changes: 1 addition & 1 deletion src/prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ impl<'a> Modality<'a> {
/// let _ = dotenv();
/// let aa_api_token = std::env::var("AA_API_TOKEN")
/// .expect("AA_API_TOKEN environment variable must be specified to run demo.");
/// let client = Client::new(&aa_api_token).unwrap();
/// let client = Client::with_authentication(aa_api_token).unwrap();
/// // Define task
/// let task = TaskCompletion {
/// prompt: Prompt::from_vec(vec![
Expand Down
86 changes: 63 additions & 23 deletions tests/integration.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{fs::File, io::BufReader};
use std::{fs::File, io::BufReader, sync::OnceLock};

use aleph_alpha_client::{
cosine_similarity, Client, Granularity, How, ImageScore, ItemExplanation, Modality, Prompt,
Expand All @@ -8,15 +8,14 @@ use aleph_alpha_client::{
};
use dotenv::dotenv;
use image::ImageFormat;
use lazy_static::lazy_static;

lazy_static! {
static ref AA_API_TOKEN: String = {
// Use `.env` file if it exists
let _ = dotenv();
fn api_token() -> &'static str {
static AA_API_TOKEN: OnceLock<String> = OnceLock::new();
AA_API_TOKEN.get_or_init(|| {
drop(dotenv());
std::env::var("AA_API_TOKEN")
.expect("AA_API_TOKEN environment variable must be specified to run tests.")
};
})
}

#[tokio::test]
Expand All @@ -25,7 +24,7 @@ async fn completion_with_luminous_base() {
let task = TaskCompletion::from_text("Hello", 1);

let model = "luminous-base";
let client = Client::new(&AA_API_TOKEN).unwrap();
let client = Client::with_authentication(api_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
Expand All @@ -38,17 +37,17 @@ async fn completion_with_luminous_base() {
}

#[tokio::test]
async fn completion_with_different_aa_api_token() {
async fn request_authentication_has_priority() {
let bad_aa_api_token = "DUMMY";
let task = TaskCompletion::from_text("Hello", 1);

let model = "luminous-base";
let client = Client::new(bad_aa_api_token).unwrap();
let client = Client::with_authentication(bad_aa_api_token).unwrap();
let response = client
.output_of(
&task.with_model(model),
&How {
api_token: Some(AA_API_TOKEN.to_owned()),
api_token: Some(api_token().to_owned()),
..Default::default()
},
)
Expand All @@ -60,6 +59,47 @@ async fn completion_with_different_aa_api_token() {
// Then
assert!(!response.completion.is_empty())
}

// Verify that a client constructed without a default API token (`None` on
// construction) can still execute a task, provided the token is supplied
// per request via `How::api_token`.
#[tokio::test]
async fn authentication_only_per_request() {
// Given a task and a client that carries no default API token
let model = "luminous-base";
let task = TaskCompletion::from_text("Hello", 1);

// When the request itself supplies the API token
let client = Client::new("https://api.aleph-alpha.com".to_owned(), None).unwrap();
let response = client
.output_of(
&task.with_model(model),
&How {
api_token: Some(api_token().to_owned()),
..Default::default()
},
)
.await
.unwrap();

// Then there is some successful completion
assert!(!response.completion.is_empty())
}

// Verify that invoking the client with neither a default token (client was
// constructed with `None`) nor a per-request token (`How::default()`) panics
// with the documented message instead of sending an unauthenticated request.
#[should_panic = "API token needs to be set on client construction or per request"]
#[tokio::test]
async fn must_panic_if_authentication_is_missing() {
// Given a client without a default API token
let model = "luminous-base";
let task = TaskCompletion::from_text("Hello", 1);

// When executing a task without supplying a per-request token
let client = Client::new("https://api.aleph-alpha.com".to_owned(), None).unwrap();
client
.output_of(&task.with_model(model), &How::default())
.await
.unwrap();

// Then the client panics on invocation
}

#[tokio::test]
async fn semanitc_search_with_luminous_base() {
// Given
Expand All @@ -75,7 +115,7 @@ async fn semanitc_search_with_luminous_base() {
temperature, traditionally in a wood-fired oven.",
);
let query = Prompt::from_text("What is Pizza?");
let client = Client::new(&AA_API_TOKEN).unwrap();
let client = Client::with_authentication(api_token()).unwrap();

// When
let robot_embedding_task = TaskSemanticEmbedding {
Expand Down Expand Up @@ -138,7 +178,7 @@ async fn complete_structured_prompt() {
sampling: Sampling::MOST_LIKELY,
};
let model = "luminous-base";
let client = Client::new(&AA_API_TOKEN).unwrap();
let client = Client::with_authentication(api_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
Expand All @@ -160,7 +200,7 @@ async fn explain_request() {
target: " How is it going?",
granularity: Granularity::default().with_prompt_granularity(PromptGranularity::Sentence),
};
let client = Client::new(&AA_API_TOKEN).unwrap();
let client = Client::with_authentication(api_token()).unwrap();

// When
let response = client
Expand Down Expand Up @@ -190,7 +230,7 @@ async fn explain_request_with_auto_granularity() {
target: " How is it going?",
granularity: Granularity::default(),
};
let client = Client::new(&AA_API_TOKEN).unwrap();
let client = Client::with_authentication(api_token()).unwrap();

// When
let response = client
Expand Down Expand Up @@ -222,7 +262,7 @@ async fn explain_request_with_image_modality() {
target: " a cat.",
granularity: Granularity::default().with_prompt_granularity(PromptGranularity::Paragraph),
};
let client = Client::new(&AA_API_TOKEN).unwrap();
let client = Client::with_authentication(api_token()).unwrap();

// When
let response = client
Expand Down Expand Up @@ -272,7 +312,7 @@ async fn describe_image_starting_from_a_path() {
sampling: Sampling::MOST_LIKELY,
};
let model = "luminous-base";
let client = Client::new(&AA_API_TOKEN).unwrap();
let client = Client::with_authentication(api_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
Expand Down Expand Up @@ -301,7 +341,7 @@ async fn describe_image_starting_from_a_dyn_image() {
sampling: Sampling::MOST_LIKELY,
};
let model = "luminous-base";
let client = Client::new(&AA_API_TOKEN).unwrap();
let client = Client::with_authentication(api_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
Expand All @@ -327,7 +367,7 @@ async fn only_answer_with_specific_animal() {
},
};
let model = "luminous-base";
let client = Client::new(&AA_API_TOKEN).unwrap();
let client = Client::with_authentication(api_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
Expand All @@ -354,7 +394,7 @@ async fn answer_should_continue() {
},
};
let model = "luminous-base";
let client = Client::new(&AA_API_TOKEN).unwrap();
let client = Client::with_authentication(api_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
Expand All @@ -381,7 +421,7 @@ async fn batch_semanitc_embed_with_luminous_base() {
temperature, traditionally in a wood-fired oven.",
);

let client = Client::new(&AA_API_TOKEN).unwrap();
let client = Client::with_authentication(api_token()).unwrap();

// When
let embedding_task = TaskBatchSemanticEmbedding {
Expand All @@ -406,7 +446,7 @@ async fn tokenization_with_luminous_base() {
// Given
let input = "Hello, World!";

let client = Client::new(&AA_API_TOKEN).unwrap();
let client = Client::with_authentication(api_token()).unwrap();

// When
let task1 = TaskTokenization::new(input, false, true);
Expand Down Expand Up @@ -443,7 +483,7 @@ async fn detokenization_with_luminous_base() {
// Given
let input = vec![49222, 15, 5390, 4];

let client = Client::new(&AA_API_TOKEN).unwrap();
let client = Client::with_authentication(api_token()).unwrap();

// When
let task = TaskDetokenization { token_ids: &input };
Expand Down

0 comments on commit 4806b4b

Please sign in to comment.