Skip to content

Commit

Permalink
refactor clients
Browse files Browse the repository at this point in the history
  • Loading branch information
Your Name committed Apr 4, 2024
1 parent 2bd816c commit b55d73a
Show file tree
Hide file tree
Showing 8 changed files with 264 additions and 232 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,8 @@
# Changelog

## [0.1.16] - 2024-04-04
Bug fixes and code refactoring in clients

## [0.1.15] - 2024-04-02
Bug fixes

Expand Down
17 changes: 2 additions & 15 deletions Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "rxqlite"
version = "0.1.15"
version = "0.1.16"
readme = "README.md"
edition = "2021"
authors = [
Expand Down Expand Up @@ -94,14 +94,6 @@ features = ["runtime-tokio-rustls" , "chrono" ]
#path = "../sqlx-sqlite-cipher"
version = "0.7"


[dependencies.sqlx-core]
version = "0.7"

[dependencies.rsa]
version = "0.9"
optional = true

[dependencies.aes-gcm-siv]
version = "0.11.1"
optional = true
Expand All @@ -114,9 +106,6 @@ optional = true
version = "0.22"
optional = true

[dependencies.rand]
version = "0.8"
optional = true

[target.'cfg(target_os = "windows")'.dependencies]
winapi = { version = "0" , features = [ "fileapi" , "winbase" ] }
Expand All @@ -131,11 +120,9 @@ version = "0.1"
[features]
default = [ "bundled-sqlcipher-vendored-openssl" ]
test-dependency = [ "rxqlite-tests-common" ]
sqlcipher = [ "sqlx-sqlite-cipher/sqlcipher" , "rsa-crate" , "ring" , "base64" , "aes-gcm-siv" ]
sqlcipher = [ "sqlx-sqlite-cipher/sqlcipher" , "ring" , "base64" , "aes-gcm-siv" ]
bundled-sqlcipher = [ "sqlx-sqlite-cipher/bundled-sqlcipher" , "sqlcipher" ]
bundled-sqlcipher-vendored-openssl = [ "sqlx-sqlite-cipher/bundled-sqlcipher-vendored-openssl" , "sqlcipher" ]

rsa-crate = [ "rsa" , "rand" ]

[package.metadata.docs.rs]
all-features = true
2 changes: 1 addition & 1 deletion ROADMAP.md
@@ -1,6 +1,6 @@
# Roadmap

## [0.1.16]
## [0.1.17]
Client authorization

## [0.1.12]
Expand Down
11 changes: 8 additions & 3 deletions crates/rxqlite-lite-common/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "rxqlite-lite-common"
version = "0.1.1"
version = "0.1.2"


edition = "2021"
Expand All @@ -17,10 +17,15 @@ repository = "https://github.com/HaHa421/rxqlite"

[dependencies]
anyhow = "1.0"
thiserror = "1.0.58"
derive_more = "0.99.9"
futures-util= "0.3"
rustls = {version = "0.22" }
serde = { version = "1" , features = [ "derive" ] }
chrono = { version = "0.4" , features = [ "serde" ] }
serde_json = "1.0.57"
thiserror = "1.0.58"
tokio = { version = "1.35.1", features = ["full"] }
tokio-rustls = "0.26"
tokio-util = { version = "0.7" , features = [ "codec" ] }

[dependencies.rxqlite-common]
version = "0.1"
Expand Down
11 changes: 10 additions & 1 deletion crates/rxqlite-lite-common/src/lib.rs
@@ -1,3 +1,9 @@
#![deny(warnings)]
#![deny(unused_crate_dependencies)]
#![deny(unused_extern_crates)]



use std::fmt::Display;
use std::time::Duration;
use std::error::Error;
Expand Down Expand Up @@ -383,4 +389,7 @@ impl ConnectOptions {
Ok(client::RXQLiteClient::with_options(self))
}
}
*/
*/

mod net_stream;
pub use net_stream::NetStream;
223 changes: 223 additions & 0 deletions crates/rxqlite-lite-common/src/net_stream.rs
@@ -0,0 +1,223 @@
use super::*;
use tokio::time::{timeout, Duration};
//use tokio::io::{AsyncReadExt,AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_rustls::TlsConnector;
use tokio_util::bytes::BytesMut;
use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};

use futures_util::stream::StreamExt;
use futures_util::SinkExt;
use tokio::io::split;
use tokio::io::{ReadHalf, WriteHalf};
use tokio_rustls::rustls::ClientConfig;
use tokio_rustls::rustls::RootCertStore;
use serde_json::{from_slice, to_vec};
use std::sync::Arc;

#[derive(Debug)]
struct AllowAnyCertVerifier;

impl tokio_rustls::rustls::client::danger::ServerCertVerifier for AllowAnyCertVerifier {
fn verify_server_cert(
&self,
_end_entity: &tokio_rustls::rustls::pki_types::CertificateDer<'_>,
_intermediates: &[tokio_rustls::rustls::pki_types::CertificateDer<'_>],
_server_name: &tokio_rustls::rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: tokio_rustls::rustls::pki_types::UnixTime,
) -> Result<tokio_rustls::rustls::client::danger::ServerCertVerified, tokio_rustls::rustls::Error>
{
Ok(tokio_rustls::rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &tokio_rustls::rustls::pki_types::CertificateDer<'_>,
_dss: &tokio_rustls::rustls::DigitallySignedStruct,
) -> Result<
tokio_rustls::rustls::client::danger::HandshakeSignatureValid,
tokio_rustls::rustls::Error,
> {
Ok(tokio_rustls::rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &tokio_rustls::rustls::pki_types::CertificateDer<'_>,
_dss: &tokio_rustls::rustls::DigitallySignedStruct,
) -> Result<
tokio_rustls::rustls::client::danger::HandshakeSignatureValid,
tokio_rustls::rustls::Error,
> {
Ok(tokio_rustls::rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<tokio_rustls::rustls::SignatureScheme> {
vec![
tokio_rustls::rustls::SignatureScheme::RSA_PKCS1_SHA1,
tokio_rustls::rustls::SignatureScheme::ECDSA_SHA1_Legacy,
tokio_rustls::rustls::SignatureScheme::RSA_PKCS1_SHA256,
tokio_rustls::rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
tokio_rustls::rustls::SignatureScheme::RSA_PKCS1_SHA384,
tokio_rustls::rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
tokio_rustls::rustls::SignatureScheme::RSA_PKCS1_SHA512,
tokio_rustls::rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
tokio_rustls::rustls::SignatureScheme::RSA_PSS_SHA256,
tokio_rustls::rustls::SignatureScheme::RSA_PSS_SHA384,
tokio_rustls::rustls::SignatureScheme::RSA_PSS_SHA512,
tokio_rustls::rustls::SignatureScheme::ED25519,
tokio_rustls::rustls::SignatureScheme::ED448,
]
}
}

pub enum NetStream {
Tls(
FramedWrite<WriteHalf<tokio_rustls::client::TlsStream<TcpStream>>, LengthDelimitedCodec>,
FramedRead<ReadHalf<tokio_rustls::client::TlsStream<TcpStream>>, LengthDelimitedCodec>,
),
Tcp(
FramedWrite<WriteHalf<TcpStream>, LengthDelimitedCodec>,
FramedRead<ReadHalf<TcpStream>, LengthDelimitedCodec>,
),
}

impl From<tokio_rustls::client::TlsStream<TcpStream>> for NetStream {
fn from(stream: tokio_rustls::client::TlsStream<TcpStream>) -> Self {
let (reader, writer) = split(stream);
Self::Tls(
FramedWrite::new(writer, LengthDelimitedCodec::new()),
FramedRead::new(reader, LengthDelimitedCodec::new()),
)
}
}

impl From<TcpStream> for NetStream {
fn from(stream: TcpStream) -> Self {
let (reader, writer) = split(stream);
Self::Tcp(
FramedWrite::new(writer, LengthDelimitedCodec::new()),
FramedRead::new(reader, LengthDelimitedCodec::new()),
)
}
}

impl NetStream {
pub async fn new(notifications_addr:&str, accept_invalid_certificates: Option<bool>)->anyhow::Result<Self> {
if let Some(accept_invalid_certificates) =accept_invalid_certificates {
let root_certs = RootCertStore::empty();
let mut config/*: rustls::ConfigBuilder<ClientConfig,rustls::WantsVersions>*/= ClientConfig::builder()
.with_root_certificates(root_certs)
.with_no_client_auth();
if accept_invalid_certificates {
config
.dangerous()
.set_certificate_verifier(Arc::new(AllowAnyCertVerifier));
}

let connector = TlsConnector::from(Arc::new(config));
let server_name = rustls::pki_types::ServerName::try_from(
notifications_addr.split(":").next().unwrap(),
)?;
let stream = TcpStream::connect(notifications_addr).await?;
let tls_stream = connector.connect(server_name.to_owned(), stream).await?;
let notification_stream = NetStream::from(tls_stream);
Ok(notification_stream)
} else {
let stream = TcpStream::connect(notifications_addr).await?;
let notification_stream = NetStream::from(stream);
Ok(notification_stream)
}
}


pub async fn write(&mut self, notification_request: NotificationRequest) -> anyhow::Result<()> {
let message = to_vec(&notification_request)?;
match self {
Self::Tls(framed_write, _) => {
framed_write
.send(BytesMut::from(message.as_slice()).freeze())
.await?;
}
Self::Tcp(framed_write, _) => {
framed_write
.send(BytesMut::from(message.as_slice()).freeze())
.await?;
}
}
Ok(())
}
pub async fn read(&mut self) -> anyhow::Result<NotificationEvent> {
match self {
Self::Tls(_, length_delimited_stream) => {
let message = length_delimited_stream.next().await;
if let Some(message) = message {
let message: BytesMut = message?;
let message: NotificationEvent = from_slice(&message)?;
Ok(message)
} else {
Err(anyhow::anyhow!("stream closed"))
}
}
Self::Tcp(_, length_delimited_stream) => {
let message = length_delimited_stream.next().await;
if let Some(message) = message {
let message: BytesMut = message?;
let message: NotificationEvent = from_slice(&message)?;
Ok(message)
} else {
Err(anyhow::anyhow!("stream closed"))
}
}
}
}
pub async fn read_timeout(
&mut self,
timeout_duration: Duration,
) -> anyhow::Result<Option<NotificationEvent>> {
match timeout(timeout_duration,self.read()).await {
Ok(notification)=>Ok(Some(notification?)),
Err(_)=>Ok(None),
}
}
/*
#[allow(dead_code)]
pub async fn read_timeout(
&mut self,
timeout_duration: Duration,
) -> anyhow::Result<Option<NotificationEvent>> {
match self {
Self::Tls(_, length_delimited_stream) => {
let res = timeout(timeout_duration, length_delimited_stream.next()).await;
match res {
Ok(message) => {
if let Some(message) = message {
let message: BytesMut = message?;
let message: NotificationEvent = from_slice(&message)?;
Ok(Some(message))
} else {
Err(anyhow::anyhow!("stream closed"))
}
}
Err(_) => Ok(None),
}
}
Self::Tcp(_, length_delimited_stream) => {
let res = timeout(timeout_duration, length_delimited_stream.next()).await;
match res {
Ok(message) => {
if let Some(message) = message {
let message: BytesMut = message?;
let message: NotificationEvent = from_slice(&message)?;
Ok(Some(message))
} else {
Err(anyhow::anyhow!("stream closed"))
}
}
Err(_) => Ok(None),
}
}
}
}
*/
}

0 comments on commit b55d73a

Please sign in to comment.