Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 3 additions & 9 deletions biome.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@
"bracketSpacing": true,
"expand": "auto",
"useEditorconfig": true,
"includes": [
"./src/**"
]
"includes": ["./src/**"]
},
"linter": {
"enabled": true,
Expand Down Expand Up @@ -70,9 +68,7 @@
"noArrayIndexKey": "off"
}
},
"includes": [
"src/**"
]
"includes": ["src/**"]
},
"javascript": {
"formatter": {
Expand All @@ -94,9 +90,7 @@
},
"overrides": [
{
"includes": [
"**/*.js"
]
"includes": ["**/*.js"]
}
],
"assist": {
Expand Down
5 changes: 4 additions & 1 deletion src-tauri/src/appstate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::{
models::{connection::ActiveConnection, Id},
DB_POOL,
},
enterprise::provisioning::ProvisioningConfig,
utils::stats_handler,
ConnectionType,
};
Expand All @@ -18,15 +19,17 @@ pub struct AppState {
pub log_watchers: Mutex<HashMap<String, CancellationToken>>,
pub app_config: Mutex<AppConfig>,
stat_threads: Mutex<HashMap<Id, JoinHandle<()>>>, // location ID is the key
pub provisioning_config: Mutex<Option<ProvisioningConfig>>,
}

impl AppState {
#[must_use]
pub fn new(config: AppConfig) -> Self {
pub fn new(config: AppConfig, provisioning_config: Option<ProvisioningConfig>) -> Self {
AppState {
log_watchers: Mutex::new(HashMap::new()),
app_config: Mutex::new(config),
stat_threads: Mutex::new(HashMap::new()),
provisioning_config: Mutex::new(provisioning_config),
}
}

Expand Down
28 changes: 23 additions & 5 deletions src-tauri/src/bin/defguard-client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use defguard_client::{
models::{location_stats::LocationStats, tunnel::TunnelStats},
DB_POOL,
},
enterprise::provisioning::handle_client_initialization,
periodic::run_periodic_tasks,
service,
tray::{configure_tray_icon, setup_tray, show_main_window},
Expand Down Expand Up @@ -137,7 +138,8 @@ fn main() {
start_global_logwatcher,
stop_global_logwatcher,
command_get_app_config,
command_set_app_config
command_set_app_config,
get_provisioning_config
])
.on_window_event(|window, event| {
if let WindowEvent::CloseRequested { api, .. } = event {
Expand Down Expand Up @@ -244,7 +246,12 @@ fn main() {
.build(),
)?;

let state = AppState::new(config);
// Check if client needs to be initialized
// and try to load provisioning config if necessary
let provisioning_config =
tauri::async_runtime::block_on(handle_client_initialization(app_handle));

let state = AppState::new(config, provisioning_config);
app.manage(state);

info!("App setup completed, log level: {log_level}");
Expand All @@ -271,9 +278,11 @@ fn main() {

// Ensure directories have appropriate permissions (dg25-28).
#[cfg(unix)]
set_perms(&data_dir);
#[cfg(unix)]
set_perms(&log_dir);
{
set_perms(&data_dir);
set_perms(&log_dir);
}

info!(
"Application data (database file) will be stored in: {data_dir:?} and application \
logs in: {log_dir:?}. Logs of the background Defguard service responsible for \
Expand All @@ -293,6 +302,15 @@ fn main() {
app_handle_clone.exit(0);
});
debug!("Ctrl-C handler has been set up successfully");

let app_handle_clone = app_handle.clone();
tauri::async_runtime::spawn(async move {
// Wait for frontend to be ready
tokio::time::sleep(std::time::Duration::from_secs(15)).await;

// Handle client initialization if necessary
handle_client_initialization(&app_handle_clone).await;
});
}
RunEvent::ExitRequested { code, api, .. } => {
debug!("Received exit request");
Expand Down
19 changes: 18 additions & 1 deletion src-tauri/src/commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use crate::{
},
DB_POOL,
},
enterprise::periodic::config::poll_instance,
enterprise::{periodic::config::poll_instance, provisioning::ProvisioningConfig},
error::Error,
events::EventKey,
log_watcher::{
Expand Down Expand Up @@ -1120,3 +1120,20 @@ pub async fn command_set_app_config(
}
Ok(res)
}

#[tauri::command]
pub fn get_provisioning_config(
app_state: State<'_, AppState>,
) -> Result<Option<ProvisioningConfig>, Error> {
debug!("Running command get_provisioning_config.");
let res = app_state
.provisioning_config
.lock()
.map_err(|_err| {
error!("Failed to acquire lock on client provisioning config");
Error::StateLockFail
})?
.clone();
trace!("Returning config: {res:?}");
Ok(res)
}
2 changes: 1 addition & 1 deletion src-tauri/src/database/models/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ impl Instance<Id> {
Ok(())
}

pub(crate) async fn all<'e, E>(executor: E) -> Result<Vec<Self>, sqlx::Error>
pub async fn all<'e, E>(executor: E) -> Result<Vec<Self>, sqlx::Error>
where
E: SqliteExecutor<'e>,
{
Expand Down
1 change: 1 addition & 0 deletions src-tauri/src/enterprise/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pub mod models;
pub mod periodic;
pub mod provisioning;
76 changes: 76 additions & 0 deletions src-tauri/src/enterprise/provisioning/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
use std::{fs::OpenOptions, path::Path};

use serde::{Deserialize, Serialize};
use tauri::{AppHandle, Manager};

use crate::database::{models::instance::Instance, DB_POOL};

const CONFIG_FILE_NAME: &str = "enrollment.json";

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ProvisioningConfig {
pub enrollment_url: String,
pub enrollment_token: String,
}

impl ProvisioningConfig {
/// Load configuration from a file at `path`.
fn load(path: &Path) -> Option<Self> {
let file = match OpenOptions::new().read(true).open(path) {
Ok(file) => file,
Err(err) => {
warn!("Failed to open provisioning configuration file at {path:?}. Error details: {err}");
return None;
}
};
match serde_json::from_reader::<_, Self>(file) {
Ok(config) => Some(config),
Err(err) => {
warn!("Failed to parse provisioning configuration file at {path:?}. Error details: {err}");
None
}
}
}
}

pub fn try_get_provisioning_config(app_data_dir: &Path) -> Option<ProvisioningConfig> {
debug!("Trying to find provisioning config in {app_data_dir:?}");

let config_file_path = app_data_dir.join(CONFIG_FILE_NAME);
ProvisioningConfig::load(&config_file_path)
}

/// Checks if the client has already been initialized
/// and tries to load provisioning config from file if necessary
pub async fn handle_client_initialization(app_handle: &AppHandle) -> Option<ProvisioningConfig> {
// check if client has already been initialized
// we assume that if any instances exist the client has been initialized
match Instance::all(&*DB_POOL).await {
Ok(instances) => {
if instances.is_empty() {
debug!(
"Client has not been initialized yet. Checking if provisioning config exists"
);
let data_dir = app_handle
.path()
.app_data_dir()
.unwrap_or_else(|_| "UNDEFINED DATA DIRECTORY".into());
match try_get_provisioning_config(&data_dir) {
Some(config) => {
info!("Provisioning config found in {data_dir:?}.");
debug!("Provisioning config: {config:?}");
return Some(config);
}
None => {
debug!("Provisioning config not found in {data_dir:?}. Proceeding with normal startup.")
}
}
}
}
Err(err) => {
error!("Failed to verify if the client has already been initialized: {err}")
}
}

None
}
8 changes: 4 additions & 4 deletions src-tauri/src/events.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use tauri_plugin_notification::NotificationExt;

use crate::{tray::show_main_window, ConnectionType};

// Match src/page/client/types.ts.
// Match src/pages/client/types.ts.
#[non_exhaustive]
pub enum EventKey {
ConnectionChanged,
Expand Down Expand Up @@ -95,9 +95,9 @@ impl DeadConnReconnected {
}

#[derive(Clone, Serialize)]
struct AddInstancePayload<'a> {
token: &'a str,
url: &'a str,
pub struct AddInstancePayload<'a> {
pub token: &'a str,
pub url: &'a str,
}

/// Handle deep-link URLs.
Expand Down
46 changes: 46 additions & 0 deletions src/components/AutoProvisioningManager.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import { useQuery } from '@tanstack/react-query';
import { error } from '@tauri-apps/plugin-log';
import { type PropsWithChildren, useEffect } from 'react';
import { clientApi } from '../pages/client/clientAPI/clientApi';
import type { ProvisioningConfig } from '../pages/client/clientAPI/types';
import { clientQueryKeys } from '../pages/client/query';
import { useToaster } from '../shared/defguard-ui/hooks/toasts/useToaster';
import useAddInstance from '../shared/hooks/useAddInstance';

const { getProvisioningConfig } = clientApi;

export default function AutoProvisioningManager({ children }: PropsWithChildren) {
const toaster = useToaster();
const { handleAddInstance } = useAddInstance();
const { data: provisioningConfig } = useQuery({
queryFn: getProvisioningConfig,
queryKey: [clientQueryKeys.getProvisioningConfig],
refetchOnMount: false,
refetchOnWindowFocus: false,
});

const handleProvisioning = async (config: ProvisioningConfig) => {
try {
await handleAddInstance({
url: config.enrollment_url,
token: config.enrollment_token,
});
} catch (e) {
error(
`Failed to handle automatic client provisioning with ${JSON.stringify(config)}.\n Error: ${JSON.stringify(e)}`,
);
toaster.error(
'Automatic client provisioning failed, please contact your administrator.',
);
}
};

// biome-ignore lint/correctness/useExhaustiveDependencies: migration, checkMeLater
useEffect(() => {
if (provisioningConfig) {
handleProvisioning(provisioningConfig);
}
}, [provisioningConfig]);

return <>{children}</>;
}
21 changes: 12 additions & 9 deletions src/pages/client/ClientPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { listen } from '@tauri-apps/api/event';
import { useEffect } from 'react';
import { Outlet, useLocation, useNavigate } from 'react-router-dom';
import { shallow } from 'zustand/shallow';

import AutoProvisioningManager from '../../components/AutoProvisioningManager';
import { useI18nContext } from '../../i18n/i18n-react';
import { DeepLinkProvider } from '../../shared/components/providers/DeepLinkProvider';
import { useToaster } from '../../shared/defguard-ui/hooks/toasts/useToaster';
Expand All @@ -20,10 +20,10 @@ import { useClientStore } from './hooks/useClientStore';
import { useMFAModal } from './pages/ClientInstancePage/components/LocationsList/modals/MFAModal/useMFAModal';
import { clientQueryKeys } from './query';
import {
ClientConnectionType,
type CommonWireguardFields,
type DeadConDroppedPayload,
TauriEventKey,
ClientConnectionType,
} from './types';

const { getInstances, getTunnels, getAppConfig } = clientApi;
Expand Down Expand Up @@ -235,12 +235,15 @@ export const ClientPage = () => {
}, [navigate, listChecked, instances, tunnels]);

return (
<DeepLinkProvider>
<MfaModalProvider>
<Outlet />
</MfaModalProvider>
<DeadConDroppedModal />
<ClientSideBar />
</DeepLinkProvider>
<AutoProvisioningManager>
<DeepLinkProvider>
<MfaModalProvider>
<Outlet />
</MfaModalProvider>
<DeadConDroppedModal />
<ClientSideBar />
<AutoProvisioningManager />
</DeepLinkProvider>
</AutoProvisioningManager>
);
};
5 changes: 5 additions & 0 deletions src/pages/client/clientAPI/clientApi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import type {
GetLocationsRequest,
LocationDetails,
LocationDetailsRequest,
ProvisioningConfig,
RoutingRequest,
SaveConfigRequest,
SaveDeviceConfigResponse,
Expand Down Expand Up @@ -129,6 +130,9 @@ const stopGlobalLogWatcher = async (): Promise<void> =>
const getAppConfig = async (): Promise<AppConfig> =>
invokeWrapper('command_get_app_config');

const getProvisioningConfig = async (): Promise<ProvisioningConfig | null> =>
invokeWrapper('get_provisioning_config');

const setAppConfig = async (
appConfig: Partial<AppConfig>,
emitEvent: boolean,
Expand Down Expand Up @@ -164,4 +168,5 @@ export const clientApi = {
getLatestAppVersion,
startGlobalLogWatcher,
stopGlobalLogWatcher,
getProvisioningConfig,
};
Loading