Skip to content

Commit

Permalink
refactor: bundle the lightweight axum test client (#3669)
Browse files Browse the repository at this point in the history
* refactor: bundle the lightweight axum test client

Signed-off-by: tison <wander4096@gmail.com>

* address comments

Signed-off-by: tison <wander4096@gmail.com>

---------

Signed-off-by: tison <wander4096@gmail.com>
  • Loading branch information
tisonkun committed Apr 9, 2024
1 parent ea9367f commit 883b7fc
Show file tree
Hide file tree
Showing 11 changed files with 269 additions and 31 deletions.
25 changes: 3 additions & 22 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ reqwest = { version = "0.11", default-features = false, features = [
"json",
"rustls-tls-native-roots",
"stream",
"multipart",
] }
rskafka = "0.5"
rust_decimal = "1.33"
Expand Down
4 changes: 2 additions & 2 deletions src/servers/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ derive_builder.workspace = true
futures = "0.3"
hashbrown = "0.14"
headers = "0.3"
hostname = "0.3.1"
hostname = "0.3"
http = "0.2"
http-body = "0.4"
humantime-serde.workspace = true
hyper = { version = "0.14", features = ["full"] }
Expand Down Expand Up @@ -109,7 +110,6 @@ tikv-jemalloc-ctl = { version = "0.5", features = ["use_std"] }

[dev-dependencies]
auth = { workspace = true, features = ["testing"] }
axum-test-helper = "0.3"
catalog = { workspace = true, features = ["testing"] }
client.workspace = true
common-base.workspace = true
Expand Down
5 changes: 4 additions & 1 deletion src/servers/src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ pub mod greptime_result_v1;
pub mod influxdb_result_v1;
pub mod table_result;

#[cfg(any(test, feature = "testing"))]
pub mod test_helpers;

pub const HTTP_API_VERSION: &str = "v1";
pub const HTTP_API_PREFIX: &str = "/v1/";
/// Default http body limit (64M).
Expand Down Expand Up @@ -824,7 +827,6 @@ mod test {
use axum::handler::Handler;
use axum::http::StatusCode;
use axum::routing::get;
use axum_test_helper::TestClient;
use common_query::Output;
use common_recordbatch::RecordBatches;
use datatypes::prelude::*;
Expand All @@ -838,6 +840,7 @@ mod test {

use super::*;
use crate::error::Error;
use crate::http::test_helpers::TestClient;
use crate::query_handler::grpc::GrpcQueryHandler;
use crate::query_handler::sql::{ServerSqlQueryHandlerAdapter, SqlQueryHandler};

Expand Down
254 changes: 254 additions & 0 deletions src/servers/src/http/test_helpers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Axum Test Client
//!
//! ```rust
//! use axum::Router;
//! use axum::http::StatusCode;
//! use axum::routing::get;
//! use crate::servers::http::test_helpers::TestClient;
//!
//! let async_block = async {
//! // you can replace this Router with your own app
//! let app = Router::new().route("/", get(|| async {}));
//!
//! // initiate the TestClient with the previous declared Router
//! let client = TestClient::new(app);
//!
//! let res = client.get("/").await;
//! assert_eq!(res.status(), StatusCode::OK);
//! };
//!
//! // Create a runtime for executing the async block. This runtime is local
//! // to the main function and does not require any global setup.
//! let runtime = tokio::runtime::Builder::new_current_thread()
//! .enable_all()
//! .build()
//! .unwrap();
//!
//! // Use the local runtime to block on the async block.
//! runtime.block_on(async_block);
//! ```

use std::convert::TryFrom;
use std::net::{SocketAddr, TcpListener};

use axum::body::HttpBody;
use axum::BoxError;
use bytes::Bytes;
use common_telemetry::info;
use http::header::{HeaderName, HeaderValue};
use http::{Request, StatusCode};
use hyper::service::Service;
use hyper::{Body, Server};
use tower::make::Shared;

/// Test client to Axum servers.
pub struct TestClient {
client: reqwest::Client,
addr: SocketAddr,
}

impl TestClient {
/// Create a new test client.
pub fn new<S, ResBody>(svc: S) -> Self
where
S: Service<Request<Body>, Response = http::Response<ResBody>> + Clone + Send + 'static,
ResBody: HttpBody + Send + 'static,
ResBody::Data: Send,
ResBody::Error: Into<BoxError>,
S::Future: Send,
S::Error: Into<BoxError>,
{
let listener = TcpListener::bind("127.0.0.1:0").expect("Could not bind ephemeral socket");
let addr = listener.local_addr().unwrap();
info!("Listening on {}", addr);

tokio::spawn(async move {
let server = Server::from_tcp(listener).unwrap().serve(Shared::new(svc));
server.await.expect("server error");
});

let client = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.build()
.unwrap();

TestClient { client, addr }
}

/// Returns the base URL (http://ip:port) for this TestClient
///
/// this is useful when trying to check if Location headers in responses
/// are generated correctly as Location contains an absolute URL
pub fn base_url(&self) -> String {
format!("http://{}", self.addr)
}

/// Create a GET request.
pub fn get(&self, url: &str) -> RequestBuilder {
RequestBuilder {
builder: self.client.get(format!("http://{}{}", self.addr, url)),
}
}

/// Create a HEAD request.
pub fn head(&self, url: &str) -> RequestBuilder {
RequestBuilder {
builder: self.client.head(format!("http://{}{}", self.addr, url)),
}
}

/// Create a POST request.
pub fn post(&self, url: &str) -> RequestBuilder {
RequestBuilder {
builder: self.client.post(format!("http://{}{}", self.addr, url)),
}
}

/// Create a PUT request.
pub fn put(&self, url: &str) -> RequestBuilder {
RequestBuilder {
builder: self.client.put(format!("http://{}{}", self.addr, url)),
}
}

/// Create a PATCH request.
pub fn patch(&self, url: &str) -> RequestBuilder {
RequestBuilder {
builder: self.client.patch(format!("http://{}{}", self.addr, url)),
}
}

/// Create a DELETE request.
pub fn delete(&self, url: &str) -> RequestBuilder {
RequestBuilder {
builder: self.client.delete(format!("http://{}{}", self.addr, url)),
}
}
}

/// Builder for test requests.
pub struct RequestBuilder {
builder: reqwest::RequestBuilder,
}

impl RequestBuilder {
pub async fn send(self) -> TestResponse {
TestResponse {
response: self.builder.send().await.unwrap(),
}
}

/// Set the request body.
pub fn body(mut self, body: impl Into<reqwest::Body>) -> Self {
self.builder = self.builder.body(body);
self
}

/// Set the request forms.
pub fn form<T: serde::Serialize + ?Sized>(mut self, form: &T) -> Self {
self.builder = self.builder.form(&form);
self
}

/// Set the request JSON body.
pub fn json<T>(mut self, json: &T) -> Self
where
T: serde::Serialize,
{
self.builder = self.builder.json(json);
self
}

/// Set a request header.
pub fn header<K, V>(mut self, key: K, value: V) -> Self
where
HeaderName: TryFrom<K>,
<HeaderName as TryFrom<K>>::Error: Into<http::Error>,
HeaderValue: TryFrom<V>,
<HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
{
self.builder = self.builder.header(key, value);
self
}

/// Set a request multipart form.
pub fn multipart(mut self, form: reqwest::multipart::Form) -> Self {
self.builder = self.builder.multipart(form);
self
}
}

/// A wrapper around [`reqwest::Response`] that provides common methods with internal `unwrap()`s.
///
/// This is convenient for tests where panics are what you want. For access to
/// non-panicking versions or the complete `Response` API use `into_inner()` or
/// `as_ref()`.
pub struct TestResponse {
response: reqwest::Response,
}

impl TestResponse {
/// Get the response body as text.
pub async fn text(self) -> String {
self.response.text().await.unwrap()
}

/// Get the response body as bytes.
pub async fn bytes(self) -> Bytes {
self.response.bytes().await.unwrap()
}

/// Get the response body as JSON.
pub async fn json<T>(self) -> T
where
T: serde::de::DeserializeOwned,
{
self.response.json().await.unwrap()
}

/// Get the response status.
pub fn status(&self) -> StatusCode {
self.response.status()
}

/// Get the response headers.
pub fn headers(&self) -> &http::HeaderMap {
self.response.headers()
}

/// Get the response in chunks.
pub async fn chunk(&mut self) -> Option<Bytes> {
self.response.chunk().await.unwrap()
}

/// Get the response in chunks as text.
pub async fn chunk_text(&mut self) -> Option<String> {
let chunk = self.chunk().await?;
Some(String::from_utf8(chunk.to_vec()).unwrap())
}

/// Get the inner [`reqwest::Response`] for less convenient but more complete access.
pub fn into_inner(self) -> reqwest::Response {
self.response
}
}

impl AsRef<reqwest::Response> for TestResponse {
fn as_ref(&self) -> &reqwest::Response {
&self.response
}
}
2 changes: 1 addition & 1 deletion src/servers/tests/http/http_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
// limitations under the License.

use axum::Router;
use axum_test_helper::TestClient;
use common_test_util::ports;
use servers::http::test_helpers::TestClient;
use servers::http::{HttpOptions, HttpServerBuilder};
use table::test_util::MemTable;

Expand Down
2 changes: 1 addition & 1 deletion src/servers/tests/http/influxdb_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ use api::v1::RowInsertRequests;
use async_trait::async_trait;
use auth::tests::{DatabaseAuthInfo, MockUserProvider};
use axum::{http, Router};
use axum_test_helper::TestClient;
use common_query::Output;
use common_test_util::ports;
use query::parser::PromQuery;
use query::plan::LogicalPlan;
use query::query_engine::DescribeResult;
use servers::error::{Error, Result};
use servers::http::header::constants::GREPTIME_DB_HEADER_NAME;
use servers::http::test_helpers::TestClient;
use servers::http::{HttpOptions, HttpServerBuilder};
use servers::influxdb::InfluxdbRequest;
use servers::query_handler::grpc::GrpcQueryHandler;
Expand Down
Loading

0 comments on commit 883b7fc

Please sign in to comment.