Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
640 changes: 465 additions & 175 deletions Cargo.lock

Large diffs are not rendered by default.

15 changes: 6 additions & 9 deletions src/db/models/enrollment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ impl Token {
}

pub async fn find_by_id(pool: &DbPool, id: &str) -> Result<Self, TokenError> {
match query_as!(
if let Some(enrollment) = query_as!(
Self,
"SELECT id, user_id, admin_id, email, created_at, expires_at, used_at, token_type \
FROM token WHERE id = $1",
Expand All @@ -201,14 +201,11 @@ impl Token {
.fetch_optional(pool)
.await?
{
Some(enrollment) => {
debug!("Fetch token {enrollment:?} from database.");
Ok(enrollment)
}
None => {
debug!("Token with id {} does not exist in database.", id);
Err(TokenError::NotFound)
}
debug!("Fetch token {enrollment:?} from database.");
Ok(enrollment)
} else {
debug!("Token with id {id} does not exist in database.");
Err(TokenError::NotFound)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/enterprise/handlers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ where

match validate_license(license.as_ref()) {
// Useless struct, but may come in handy later
Ok(_) => Ok(LicenseInfo { valid: true }),
Ok(()) => Ok(LicenseInfo { valid: true }),
Err(e) => Err(WebError::Forbidden(e.to_string())),
}
}
Expand Down
35 changes: 12 additions & 23 deletions src/enterprise/handlers/openid_login.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,32 +52,22 @@ async fn get_provider_metadata(url: &str) -> Result<ProvMeta, WebError> {
// Discover the provider metadata based on a known base issuer URL
// The url should be in the form of e.g. https://accounts.google.com
// The url shouldn't contain a .well-known part, it will be added automatically
let provider_metadata = match CoreProviderMetadata::discover_async(
issuer_url,
async_http_client,
)
.await
{
Ok(metadata) => metadata,
Err(_) => {
return Err(WebError::Authorization(format!(
"Failed to discover provider metadata, make sure the providers' url is correct: {url}",

)));
}
let Ok(provider_metadata) =
CoreProviderMetadata::discover_async(issuer_url, async_http_client).await
else {
return Err(WebError::Authorization(format!(
"Failed to discover provider metadata, make sure the providers' url is correct: {url}",
)));
};

Ok(provider_metadata)
}

async fn make_oidc_client(pool: &DbPool) -> Result<CoreClient, WebError> {
let provider = match OpenIdProvider::get_current(pool).await? {
Some(provider) => provider,
None => {
return Err(WebError::ObjectNotFound(
"OpenID provider not set".to_string(),
));
}
let Some(provider) = OpenIdProvider::get_current(pool).await? else {
return Err(WebError::ObjectNotFound(
"OpenID provider not set".to_string(),
));
};

let provider_metadata = get_provider_metadata(&provider.base_url).await?;
Expand All @@ -88,7 +78,7 @@ async fn make_oidc_client(pool: &DbPool) -> Result<CoreClient, WebError> {
let redirect_url = match RedirectUrl::new(url) {
Ok(url) => url,
Err(err) => {
error!("Failed to create redirect URL: {:?}", err);
error!("Failed to create redirect URL: {err:?}");
return Err(WebError::Authorization(
"Failed to create redirect URL".to_string(),
));
Expand Down Expand Up @@ -273,8 +263,7 @@ pub async fn auth_callback(
.is_some()
{
return Err(WebError::Authorization(format!(
"User with username {} already exists",
username
"User with username {username} already exists"
)));
}

Expand Down
1 change: 1 addition & 0 deletions src/enterprise/handlers/openid_providers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub struct DeleteProviderData {
}

impl AddProviderData {
#[must_use]
pub fn new(name: &str, base_url: &str, client_id: &str, client_secret: &str) -> Self {
Self {
name: name.to_string(),
Expand Down
57 changes: 30 additions & 27 deletions src/enterprise/license.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ use anyhow::Result;
use base64::prelude::*;
use chrono::{DateTime, TimeDelta, Utc};
use humantime::format_duration;
use pgp::{
types::KeyTrait, Deserializable, SignedPublicKey, SignedPublicSubKey, StandaloneSignature,
};
use pgp::{types::KeyTrait, Deserializable, SignedPublicKey, StandaloneSignature};
use prost::Message;
use sqlx::error::Error as SqlxError;
use thiserror::Error;
Expand All @@ -20,6 +18,9 @@ use crate::{
VERSION,
};

// FIXME: this should be a hardcoded IP, make sure to add appropriate host headers
const LICENSE_SERVER_URL: &str = "https://update-service-dev.defguard.net/api/license/renew";

static LICENSE: RwLock<Option<License>> = RwLock::new(None);

pub fn set_cached_license(license: Option<License>) {
Expand Down Expand Up @@ -250,7 +251,7 @@ impl License {
let signing_key = public_key
.public_subkeys
.into_iter()
.find(|subkey| subkey.is_signing_key())
.find(KeyTrait::is_signing_key)
.ok_or(LicenseError::LicenseServerError(
"Failed to find a signing key in the provided public key".to_string(),
))?;
Expand Down Expand Up @@ -282,7 +283,7 @@ impl License {
debug!("Deserialized the license object, verifying the license signature...");

match Self::verify_signature(metadata_bytes, signature_bytes) {
Ok(_) => {
Ok(()) => {
info!("Successfully decoded the license and validated the license signature");
let metadata = LicenseMetadata::decode(metadata_bytes).map_err(|_| {
LicenseError::DecodeError("Failed to decode the license metadata".to_string())
Expand Down Expand Up @@ -332,12 +333,11 @@ impl License {
/// Create the license object based on the license key stored in the database.
/// Automatically decodes and deserializes the keys and verifies the signature.
pub async fn load(pool: &DbPool) -> Result<Option<License>, LicenseError> {
match Self::get_key(pool).await? {
Some(key) => Ok(Some(Self::from_base64(&key)?)),
None => {
debug!("No license key found in the database");
Ok(None)
}
if let Some(key) = Self::get_key(pool).await? {
Ok(Some(Self::from_base64(&key)?))
} else {
debug!("No license key found in the database");
Ok(None)
}
}

Expand All @@ -347,7 +347,9 @@ impl License {
match Self::load(pool).await? {
Some(license) => {
if license.requires_renewal() {
if !license.is_max_overdue() {
if license.is_max_overdue() {
Err(LicenseError::LicenseExpired)
} else {
info!("License requires renewal, trying to renew it...");
match renew_license(pool).await {
Ok(new_key) => {
Expand All @@ -361,8 +363,6 @@ impl License {
Ok(Some(license))
}
}
} else {
Err(LicenseError::LicenseExpired)
}
} else {
info!("Successfully loaded the license from the database.");
Expand All @@ -377,6 +377,7 @@ impl License {
///
/// NOTE: license should be considered valid for an additional period of `MAX_OVERDUE_TIME`.
/// If you want to check if the license reached this point, use `is_max_overdue` instead.
#[must_use]
pub fn is_expired(&self) -> bool {
match self.valid_until {
Some(time) => time < Utc::now(),
Expand All @@ -385,12 +386,14 @@ impl License {
}

/// Checks how much time has left until the `valid_until` time.
#[must_use]
pub fn time_left(&self) -> Option<TimeDelta> {
self.valid_until.map(|time| time - Utc::now())
}

/// Gets the time the license is past its expiry date.
/// If the license doesn't have a `valid_until` field, it will return 0.
#[must_use]
pub fn time_overdue(&self) -> TimeDelta {
match self.valid_until {
Some(time) => {
Expand All @@ -406,6 +409,7 @@ impl License {
}

/// Checks whether we should try to renew the license.
#[must_use]
pub fn requires_renewal(&self) -> bool {
if self.subscription {
if let Some(remaining) = self.time_left() {
Expand All @@ -419,6 +423,7 @@ impl License {
}

/// Checks if the license has reached its maximum overdue time.
#[must_use]
pub fn is_max_overdue(&self) -> bool {
if !self.subscription {
// Non-subscription licenses are considered expired immediately, no grace period is required
Expand All @@ -434,9 +439,8 @@ impl License {
/// Doesn't update the cached license, nor does it save the new key in the database.
async fn renew_license(db_pool: &DbPool) -> Result<String, LicenseError> {
debug!("Exchanging license for a new one...");
let old_license_key = match Settings::get_settings(db_pool).await?.license {
Some(key) => key,
None => return Err(LicenseError::LicenseNotFound),
let Some(old_license_key) = Settings::get_settings(db_pool).await?.license else {
return Err(LicenseError::LicenseNotFound);
};

let client = reqwest::Client::new();
Expand All @@ -445,9 +449,6 @@ async fn renew_license(db_pool: &DbPool) -> Result<String, LicenseError> {
key: old_license_key,
};

// FIXME: this should be a hardcoded IP, make sure to add appropriate host headers
const LICENSE_SERVER_URL: &str = "https://update-service-dev.defguard.net/api/license/renew";

let new_license_key = match client
.post(LICENSE_SERVER_URL)
.json(&request_body)
Expand Down Expand Up @@ -576,21 +577,23 @@ pub async fn run_periodic_license_check(pool: DbPool) -> Result<(), LicenseError
if license.requires_renewal() {
// check if we are pass the maximum expiration date, after which we don't
// want to try to renew the license anymore
if !license.is_max_overdue() {
debug!("License requires renewal, as it is about to expire and is not past the maximum overdue time");
true
} else {
if license.is_max_overdue() {
check_period = CHECK_PERIOD;
warn!("Your license has expired and reached its maximum overdue date, please contact sales at sales<at>defguard.net");
debug!("Changing check period to {}", format_duration(check_period));
false
} else {
debug!("License requires renewal, as it is about to expire and is not past the maximum overdue time");
true
}
} else {
// This if is only for logging purposes, to provide more detailed information
if license.subscription {
debug!("License doesn't need to be renewed yet, skipping renewal check")
debug!(
"License doesn't need to be renewed yet, skipping renewal check"
);
} else {
debug!("License is not a subscription, skipping renewal check")
debug!("License is not a subscription, skipping renewal check");
}
false
}
Expand All @@ -608,7 +611,7 @@ pub async fn run_periodic_license_check(pool: DbPool) -> Result<(), LicenseError
debug!("Changing check period to {}", format_duration(check_period));
match renew_license(&pool).await {
Ok(new_license_key) => match save_license_key(&pool, &new_license_key).await {
Ok(_) => {
Ok(()) => {
update_cached_license(Some(&new_license_key))?;
check_period = CHECK_PERIOD;
debug!("Changing check period to {}", format_duration(check_period));
Expand Down
19 changes: 8 additions & 11 deletions src/handlers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,17 +321,14 @@ pub async fn user_for_admin_or_self(
) -> Result<User, WebError> {
if session.user.username == username || session.is_admin {
debug!("The user meets one or both of these conditions: 1) the user from the current session has admin privileges, 2) the user performs this operation on themself.");
match User::find_by_username(pool, username).await? {
Some(user) => {
debug!("User {} has been found in database.", user.username);
Ok(user)
}
None => {
debug!("User with {} does not exist in database.", username);
Err(WebError::ObjectNotFound(format!(
"user {username} not found"
)))
}
if let Some(user) = User::find_by_username(pool, username).await? {
debug!("User {} has been found in database.", user.username);
Ok(user)
} else {
debug!("User with {username} does not exist in database.");
Err(WebError::ObjectNotFound(format!(
"user {username} not found"
)))
}
} else {
debug!(
Expand Down
2 changes: 1 addition & 1 deletion src/handlers/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ pub fn prune_username(username: &str) -> String {
.to_string();

result.retain(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_');
result = result.replace(" ", "");
result = result.replace(' ', "");

result
}
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ Available actions:
components.add_security_scheme(
"api_key",
SecurityScheme::ApiKey(ApiKey::Header(ApiKeyValue::new("user_apikey"))),
)
);
}
}
}
Expand Down