diff --git a/biome.json b/biome.json index c2ca6e40..a0a83836 100644 --- a/biome.json +++ b/biome.json @@ -27,9 +27,7 @@ "bracketSpacing": true, "expand": "auto", "useEditorconfig": true, - "includes": [ - "./src/**" - ] + "includes": ["./src/**"] }, "linter": { "enabled": true, @@ -70,9 +68,7 @@ "noArrayIndexKey": "off" } }, - "includes": [ - "src/**" - ] + "includes": ["src/**"] }, "javascript": { "formatter": { @@ -94,9 +90,7 @@ }, "overrides": [ { - "includes": [ - "**/*.js" - ] + "includes": ["**/*.js"] } ], "assist": { diff --git a/src-tauri/src/appstate.rs b/src-tauri/src/appstate.rs index f8fa0d96..578ef4f0 100644 --- a/src-tauri/src/appstate.rs +++ b/src-tauri/src/appstate.rs @@ -10,6 +10,7 @@ use crate::{ models::{connection::ActiveConnection, Id}, DB_POOL, }, + enterprise::provisioning::ProvisioningConfig, utils::stats_handler, ConnectionType, }; @@ -18,15 +19,17 @@ pub struct AppState { pub log_watchers: Mutex>, pub app_config: Mutex, stat_threads: Mutex>>, // location ID is the key + pub provisioning_config: Mutex>, } impl AppState { #[must_use] - pub fn new(config: AppConfig) -> Self { + pub fn new(config: AppConfig, provisioning_config: Option) -> Self { AppState { log_watchers: Mutex::new(HashMap::new()), app_config: Mutex::new(config), stat_threads: Mutex::new(HashMap::new()), + provisioning_config: Mutex::new(provisioning_config), } } diff --git a/src-tauri/src/bin/defguard-client.rs b/src-tauri/src/bin/defguard-client.rs index 9d1fa467..609cc4dc 100644 --- a/src-tauri/src/bin/defguard-client.rs +++ b/src-tauri/src/bin/defguard-client.rs @@ -18,6 +18,7 @@ use defguard_client::{ models::{location_stats::LocationStats, tunnel::TunnelStats}, DB_POOL, }, + enterprise::provisioning::handle_client_initialization, periodic::run_periodic_tasks, service, tray::{configure_tray_icon, setup_tray, show_main_window}, @@ -137,7 +138,8 @@ fn main() { start_global_logwatcher, stop_global_logwatcher, command_get_app_config, - command_set_app_config + command_set_app_config, + get_provisioning_config ]) .on_window_event(|window, event| { if let WindowEvent::CloseRequested { api, .. } = event { @@ -244,7 +246,12 @@ fn main() { .build(), )?; - let state = AppState::new(config); + // Check if client needs to be initialized + // and try to load provisioning config if necessary + let provisioning_config = + tauri::async_runtime::block_on(handle_client_initialization(app_handle)); + + let state = AppState::new(config, provisioning_config); app.manage(state); info!("App setup completed, log level: {log_level}"); @@ -271,9 +278,11 @@ fn main() { // Ensure directories have appropriate permissions (dg25-28). #[cfg(unix)] - set_perms(&data_dir); - #[cfg(unix)] - set_perms(&log_dir); + { + set_perms(&data_dir); + set_perms(&log_dir); + } + info!( "Application data (database file) will be stored in: {data_dir:?} and application \ logs in: {log_dir:?}. Logs of the background Defguard service responsible for \ @@ -293,6 +302,15 @@ fn main() { app_handle_clone.exit(0); }); debug!("Ctrl-C handler has been set up successfully"); + + let app_handle_clone = app_handle.clone(); + tauri::async_runtime::spawn(async move { + // Wait for frontend to be ready + tokio::time::sleep(std::time::Duration::from_secs(15)).await; + + // Handle client initialization if necessary + handle_client_initialization(&app_handle_clone).await; + }); } RunEvent::ExitRequested { code, api, .. } => { debug!("Received exit request"); diff --git a/src-tauri/src/commands.rs b/src-tauri/src/commands.rs index c638e105..2c440913 100644 --- a/src-tauri/src/commands.rs +++ b/src-tauri/src/commands.rs @@ -29,7 +29,7 @@ use crate::{ }, DB_POOL, }, - enterprise::periodic::config::poll_instance, + enterprise::{periodic::config::poll_instance, provisioning::ProvisioningConfig}, error::Error, events::EventKey, log_watcher::{ @@ -1120,3 +1120,20 @@ pub async fn command_set_app_config( } Ok(res) } + +#[tauri::command] +pub fn get_provisioning_config( + app_state: State<'_, AppState>, +) -> Result, Error> { + debug!("Running command get_provisioning_config."); + let res = app_state + .provisioning_config + .lock() + .map_err(|_err| { + error!("Failed to acquire lock on client provisioning config"); + Error::StateLockFail + })? + .clone(); + trace!("Returning config: {res:?}"); + Ok(res) +} diff --git a/src-tauri/src/database/models/instance.rs b/src-tauri/src/database/models/instance.rs index a9aac5b6..bcd74499 100644 --- a/src-tauri/src/database/models/instance.rs +++ b/src-tauri/src/database/models/instance.rs @@ -67,7 +67,7 @@ impl Instance { Ok(()) } - pub(crate) async fn all<'e, E>(executor: E) -> Result, sqlx::Error> + pub async fn all<'e, E>(executor: E) -> Result, sqlx::Error> where E: SqliteExecutor<'e>, { diff --git a/src-tauri/src/enterprise/mod.rs b/src-tauri/src/enterprise/mod.rs index f9a20825..98f9e5ef 100644 --- a/src-tauri/src/enterprise/mod.rs +++ b/src-tauri/src/enterprise/mod.rs @@ -1,2 +1,3 @@ pub mod models; pub mod periodic; +pub mod provisioning; diff --git a/src-tauri/src/enterprise/provisioning/mod.rs b/src-tauri/src/enterprise/provisioning/mod.rs new file mode 100644 index 00000000..41cc7caf --- /dev/null +++ b/src-tauri/src/enterprise/provisioning/mod.rs @@ -0,0 +1,76 @@ +use std::{fs::OpenOptions, path::Path}; + +use serde::{Deserialize, Serialize}; +use tauri::{AppHandle, Manager}; + +use crate::database::{models::instance::Instance, DB_POOL}; + +const CONFIG_FILE_NAME: &str = "enrollment.json"; + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ProvisioningConfig { + pub enrollment_url: String, + pub enrollment_token: String, +} + +impl ProvisioningConfig { + /// Load configuration from a file at `path`. + fn load(path: &Path) -> Option { + let file = match OpenOptions::new().read(true).open(path) { + Ok(file) => file, + Err(err) => { + warn!("Failed to open provisioning configuration file at {path:?}. Error details: {err}"); + return None; + } + }; + match serde_json::from_reader::<_, Self>(file) { + Ok(config) => Some(config), + Err(err) => { + warn!("Failed to parse provisioning configuration file at {path:?}. Error details: {err}"); + None + } + } + } +} + +pub fn try_get_provisioning_config(app_data_dir: &Path) -> Option { + debug!("Trying to find provisioning config in {app_data_dir:?}"); + + let config_file_path = app_data_dir.join(CONFIG_FILE_NAME); + ProvisioningConfig::load(&config_file_path) +} + +/// Checks if the client has already been initialized +/// and tries to load provisioning config from file if necessary +pub async fn handle_client_initialization(app_handle: &AppHandle) -> Option { + // check if client has already been initialized + // we assume that if any instances exist the client has been initialized + match Instance::all(&*DB_POOL).await { + Ok(instances) => { + if instances.is_empty() { + debug!( + "Client has not been initialized yet. Checking if provisioning config exists" + ); + let data_dir = app_handle + .path() + .app_data_dir() + .unwrap_or_else(|_| "UNDEFINED DATA DIRECTORY".into()); + match try_get_provisioning_config(&data_dir) { + Some(config) => { + info!("Provisioning config found in {data_dir:?}."); + debug!("Provisioning config: {config:?}"); + return Some(config); + } + None => { + debug!("Provisioning config not found in {data_dir:?}. Proceeding with normal startup.") + } + } + } + } + Err(err) => { + error!("Failed to verify if the client has already been initialized: {err}") + } + } + + None +} diff --git a/src-tauri/src/events.rs b/src-tauri/src/events.rs index 68be593b..e7267d08 100644 --- a/src-tauri/src/events.rs +++ b/src-tauri/src/events.rs @@ -4,7 +4,7 @@ use tauri_plugin_notification::NotificationExt; use crate::{tray::show_main_window, ConnectionType}; -// Match src/page/client/types.ts. +// Match src/pages/client/types.ts. #[non_exhaustive] pub enum EventKey { ConnectionChanged, @@ -95,9 +95,9 @@ impl DeadConnReconnected { } #[derive(Clone, Serialize)] -struct AddInstancePayload<'a> { - token: &'a str, - url: &'a str, +pub struct AddInstancePayload<'a> { + pub token: &'a str, + pub url: &'a str, } /// Handle deep-link URLs. diff --git a/src/components/AutoProvisioningManager.tsx b/src/components/AutoProvisioningManager.tsx new file mode 100644 index 00000000..536418d3 --- /dev/null +++ b/src/components/AutoProvisioningManager.tsx @@ -0,0 +1,46 @@ +import { useQuery } from '@tanstack/react-query'; +import { error } from '@tauri-apps/plugin-log'; +import { type PropsWithChildren, useEffect } from 'react'; +import { clientApi } from '../pages/client/clientAPI/clientApi'; +import type { ProvisioningConfig } from '../pages/client/clientAPI/types'; +import { clientQueryKeys } from '../pages/client/query'; +import { useToaster } from '../shared/defguard-ui/hooks/toasts/useToaster'; +import useAddInstance from '../shared/hooks/useAddInstance'; + +const { getProvisioningConfig } = clientApi; + +export default function AutoProvisioningManager({ children }: PropsWithChildren) { + const toaster = useToaster(); + const { handleAddInstance } = useAddInstance(); + const { data: provisioningConfig } = useQuery({ + queryFn: getProvisioningConfig, + queryKey: [clientQueryKeys.getProvisioningConfig], + refetchOnMount: false, + refetchOnWindowFocus: false, + }); + + const handleProvisioning = async (config: ProvisioningConfig) => { + try { + await handleAddInstance({ + url: config.enrollment_url, + token: config.enrollment_token, + }); + } catch (e) { + error( + `Failed to handle automatic client provisioning with ${JSON.stringify(config)}.\n Error: ${JSON.stringify(e)}`, + ); + toaster.error( + 'Automatic client provisioning failed, please contact your administrator.', + ); + } + }; + + // biome-ignore lint/correctness/useExhaustiveDependencies: migration, checkMeLater + useEffect(() => { + if (provisioningConfig) { + handleProvisioning(provisioningConfig); + } + }, [provisioningConfig]); + + return <>{children}; +} diff --git a/src/pages/client/ClientPage.tsx b/src/pages/client/ClientPage.tsx index 74b04bd4..9358bf44 100644 --- a/src/pages/client/ClientPage.tsx +++ b/src/pages/client/ClientPage.tsx @@ -5,7 +5,7 @@ import { listen } from '@tauri-apps/api/event'; import { useEffect } from 'react'; import { Outlet, useLocation, useNavigate } from 'react-router-dom'; import { shallow } from 'zustand/shallow'; - +import AutoProvisioningManager from '../../components/AutoProvisioningManager'; import { useI18nContext } from '../../i18n/i18n-react'; import { DeepLinkProvider } from '../../shared/components/providers/DeepLinkProvider'; import { useToaster } from '../../shared/defguard-ui/hooks/toasts/useToaster'; @@ -20,10 +20,10 @@ import { useClientStore } from './hooks/useClientStore'; import { useMFAModal } from './pages/ClientInstancePage/components/LocationsList/modals/MFAModal/useMFAModal'; import { clientQueryKeys } from './query'; import { + ClientConnectionType, type CommonWireguardFields, type DeadConDroppedPayload, TauriEventKey, - ClientConnectionType, } from './types'; const { getInstances, getTunnels, getAppConfig } = clientApi; @@ -235,12 +235,15 @@ export const ClientPage = () => { }, [navigate, listChecked, instances, tunnels]); return ( - - - - - - - + + + + + + + + + + ); }; diff --git a/src/pages/client/clientAPI/clientApi.ts b/src/pages/client/clientAPI/clientApi.ts index b689e1aa..4ca1506a 100644 --- a/src/pages/client/clientAPI/clientApi.ts +++ b/src/pages/client/clientAPI/clientApi.ts @@ -17,6 +17,7 @@ import type { GetLocationsRequest, LocationDetails, LocationDetailsRequest, + ProvisioningConfig, RoutingRequest, SaveConfigRequest, SaveDeviceConfigResponse, @@ -129,6 +130,9 @@ const stopGlobalLogWatcher = async (): Promise => const getAppConfig = async (): Promise => invokeWrapper('command_get_app_config'); +const getProvisioningConfig = async (): Promise => + invokeWrapper('get_provisioning_config'); + const setAppConfig = async ( appConfig: Partial, emitEvent: boolean, @@ -164,4 +168,5 @@ export const clientApi = { getLatestAppVersion, startGlobalLogWatcher, stopGlobalLogWatcher, + getProvisioningConfig, }; diff --git a/src/pages/client/clientAPI/types.ts b/src/pages/client/clientAPI/types.ts index 39d3ad1b..95ed9f93 100644 --- a/src/pages/client/clientAPI/types.ts +++ b/src/pages/client/clientAPI/types.ts @@ -1,6 +1,6 @@ import type { ThemeKey } from '../../../shared/defguard-ui/hooks/theme/types'; import type { CreateDeviceResponse } from '../../../shared/hooks/api/types'; -import type { DefguardInstance, DefguardLocation, ClientConnectionType } from '../types'; +import type { ClientConnectionType, DefguardInstance, DefguardLocation } from '../types'; export type GetLocationsRequest = { instanceId: number; @@ -82,6 +82,11 @@ export type AppConfig = { peer_alive_period: number; }; +export type ProvisioningConfig = { + enrollment_token: string; + enrollment_url: string; +}; + export type LocationDetails = { location_id: number; name: string; @@ -142,4 +147,5 @@ export type TauriCommandKey = | 'start_global_logwatcher' | 'stop_global_logwatcher' | 'command_get_app_config' - | 'command_set_app_config'; + | 'command_set_app_config' + | 'get_provisioning_config'; diff --git a/src/pages/client/components/ClientSideBar/ClientSideBar.tsx b/src/pages/client/components/ClientSideBar/ClientSideBar.tsx index 36a5efb5..7e8a9c94 100644 --- a/src/pages/client/components/ClientSideBar/ClientSideBar.tsx +++ b/src/pages/client/components/ClientSideBar/ClientSideBar.tsx @@ -60,8 +60,9 @@ export const ClientSideBar = () => { {instances.map((instance) => ( { }); useEffect(() => { - const isDefguardInstance = - selectedInstanceType === ClientConnectionType.LOCATION; + const isDefguardInstance = selectedInstanceType === ClientConnectionType.LOCATION; const isTunnelInstance = selectedInstanceType === ClientConnectionType.TUNNEL; if (isDefguardInstance && !selectedInstance) { diff --git a/src/pages/client/pages/ClientInstancePage/components/LocationsList/LocationsList.tsx b/src/pages/client/pages/ClientInstancePage/components/LocationsList/LocationsList.tsx index c708b45e..3c1546ab 100644 --- a/src/pages/client/pages/ClientInstancePage/components/LocationsList/LocationsList.tsx +++ b/src/pages/client/pages/ClientInstancePage/components/LocationsList/LocationsList.tsx @@ -9,9 +9,9 @@ import { useToaster } from '../../../../../../shared/defguard-ui/hooks/toasts/us import { routes } from '../../../../../../shared/routes'; import { useClientStore } from '../../../../hooks/useClientStore'; import { + ClientConnectionType, type CommonWireguardFields, type DefguardInstance, - ClientConnectionType, } from '../../../../types'; import { LocationsDetailView } from './components/LocationsDetailView/LocationsDetailView'; import { LocationsGridView } from './components/LocationsGridView/LocationsGridView'; diff --git a/src/pages/client/pages/ClientInstancePage/components/LocationsList/components/LocationsDetailView/LocationsDetailView.tsx b/src/pages/client/pages/ClientInstancePage/components/LocationsList/components/LocationsDetailView/LocationsDetailView.tsx index 54dfc3fe..72fb46f6 100644 --- a/src/pages/client/pages/ClientInstancePage/components/LocationsList/components/LocationsDetailView/LocationsDetailView.tsx +++ b/src/pages/client/pages/ClientInstancePage/components/LocationsList/components/LocationsDetailView/LocationsDetailView.tsx @@ -13,9 +13,9 @@ import { clientApi } from '../../../../../../clientAPI/clientApi'; import { useClientStore } from '../../../../../../hooks/useClientStore'; import { clientQueryKeys } from '../../../../../../query'; import { + ClientConnectionType, type CommonWireguardFields, type DefguardInstance, - ClientConnectionType, } from '../../../../../../types'; import { LocationConnectionHistory } from './components/LocationConnectionHistory/LocationConnectionHistory'; import { LocationDetailCard } from './components/LocationDetailCard/LocationDetailCard'; diff --git a/src/pages/client/pages/ClientInstancePage/components/LocationsList/components/LocationsDetailView/components/LocationConnectionHistory/LocationConnectionHistory.tsx b/src/pages/client/pages/ClientInstancePage/components/LocationsList/components/LocationsDetailView/components/LocationConnectionHistory/LocationConnectionHistory.tsx index 3e547527..ca76cba4 100644 --- a/src/pages/client/pages/ClientInstancePage/components/LocationsList/components/LocationsDetailView/components/LocationConnectionHistory/LocationConnectionHistory.tsx +++ b/src/pages/client/pages/ClientInstancePage/components/LocationsList/components/LocationsDetailView/components/LocationConnectionHistory/LocationConnectionHistory.tsx @@ -7,8 +7,8 @@ import { Card } from '../../../../../../../../../../shared/defguard-ui/component import { clientApi } from '../../../../../../../../clientAPI/clientApi'; import { clientQueryKeys } from '../../../../../../../../query'; import type { - DefguardLocation, ClientConnectionType, + DefguardLocation, } from '../../../../../../../../types'; import { LocationCardNeverConnected } from '../../../LocationCardNeverConnected/LocationCardNeverConnected'; import { LocationHistoryTable } from './LocationHistoryTable/LocationHistoryTable'; diff --git a/src/pages/client/pages/ClientInstancePage/components/LocationsList/components/LocationsDetailView/components/LocationDetails/LocationDetails.tsx b/src/pages/client/pages/ClientInstancePage/components/LocationsList/components/LocationsDetailView/components/LocationDetails/LocationDetails.tsx index 94c33000..4fa78d60 100644 --- a/src/pages/client/pages/ClientInstancePage/components/LocationsList/components/LocationsDetailView/components/LocationDetails/LocationDetails.tsx +++ b/src/pages/client/pages/ClientInstancePage/components/LocationsList/components/LocationsDetailView/components/LocationDetails/LocationDetails.tsx @@ -11,8 +11,8 @@ import { Label } from '../../../../../../../../../../shared/defguard-ui/componen import { clientApi } from '../../../../../../../../clientAPI/clientApi'; import { clientQueryKeys } from '../../../../../../../../query'; import type { - DefguardLocation, ClientConnectionType, + DefguardLocation, } from '../../../../../../../../types'; import { LocationLogs } from '../LocationLogs/LocationLogs'; diff --git a/src/pages/client/pages/ClientInstancePage/components/LocationsList/components/LocationsDetailView/components/LocationLogs/LocationLogs.tsx b/src/pages/client/pages/ClientInstancePage/components/LocationsList/components/LocationsDetailView/components/LocationLogs/LocationLogs.tsx index 94fd2f88..56cdfce1 100644 --- a/src/pages/client/pages/ClientInstancePage/components/LocationsList/components/LocationsDetailView/components/LocationLogs/LocationLogs.tsx +++ b/src/pages/client/pages/ClientInstancePage/components/LocationsList/components/LocationsDetailView/components/LocationLogs/LocationLogs.tsx @@ -13,8 +13,8 @@ import { Card } from '../../../../../../../../../../shared/defguard-ui/component import type { LogItem, LogLevel } from '../../../../../../../../clientAPI/types'; import { useClientStore } from '../../../../../../../../hooks/useClientStore'; import type { - DefguardLocation, ClientConnectionType, + DefguardLocation, } from '../../../../../../../../types'; import { LocationLogsSelect } from './LocationLogsSelect'; @@ -66,7 +66,7 @@ export const LocationLogs = ({ locationId, connectionType }: Props) => { const element = createLogLineElement(messageString); const scrollAfterAppend = logsContainerElement.current.scrollHeight - - logsContainerElement.current.scrollTop === + logsContainerElement.current.scrollTop === logsContainerElement.current.clientHeight; logsContainerElement.current.appendChild(element); // auto scroll to bottom if user didn't scroll up diff --git a/src/pages/client/pages/ClientInstancePage/components/LocationsList/components/LocationsGridView/LocationsGridView.tsx b/src/pages/client/pages/ClientInstancePage/components/LocationsList/components/LocationsGridView/LocationsGridView.tsx index 926e94dc..2503380b 100644 --- a/src/pages/client/pages/ClientInstancePage/components/LocationsList/components/LocationsGridView/LocationsGridView.tsx +++ b/src/pages/client/pages/ClientInstancePage/components/LocationsList/components/LocationsGridView/LocationsGridView.tsx @@ -13,9 +13,9 @@ import { clientApi } from '../../../../../../clientAPI/clientApi'; import { useClientStore } from '../../../../../../hooks/useClientStore'; import { clientQueryKeys } from '../../../../../../query'; import type { + ClientConnectionType, CommonWireguardFields, DefguardInstance, - ClientConnectionType, } from '../../../../../../types'; import { LocationUsageChart } from '../../../LocationUsageChart/LocationUsageChart'; import { LocationUsageChartType } from '../../../LocationUsageChart/types'; diff --git a/src/pages/client/query.ts b/src/pages/client/query.ts index 93d1c945..a19ec385 100644 --- a/src/pages/client/query.ts +++ b/src/pages/client/query.ts @@ -8,4 +8,5 @@ export const clientQueryKeys = { getLocationDetails: 'GET_LOCATION_DETAILS', getTunnels: 'GET_TUNNELS', getApplicationConfig: 'GET_APPLICATION_CONFIG', + getProvisioningConfig: 'GET_PROVISIONING_CONFIG', }; diff --git a/src/pages/client/types.ts b/src/pages/client/types.ts index 97136f0c..26f32900 100644 --- a/src/pages/client/types.ts +++ b/src/pages/client/types.ts @@ -92,6 +92,11 @@ export type DeadConDroppedPayload = { peer_alive_period: number; }; +export type AddInstancePayload = { + token: string; + url: string; +}; + export enum TauriEventKey { CONNECTION_CHANGED = 'connection-changed', INSTANCE_UPDATE = 'instance-update', diff --git a/src/shared/components/providers/DeepLinkProvider.tsx b/src/shared/components/providers/DeepLinkProvider.tsx index 3f8792b5..26658c13 100644 --- a/src/shared/components/providers/DeepLinkProvider.tsx +++ b/src/shared/components/providers/DeepLinkProvider.tsx @@ -1,27 +1,16 @@ -import { invoke } from '@tauri-apps/api/core'; import { getCurrent, onOpenUrl } from '@tauri-apps/plugin-deep-link'; -import { debug, error } from '@tauri-apps/plugin-log'; -import dayjs from 'dayjs'; +import { error } from '@tauri-apps/plugin-log'; import { type PropsWithChildren, useCallback, useEffect, useRef } from 'react'; -import { useNavigate } from 'react-router-dom'; import z, { string } from 'zod'; -import { clientApi } from '../../../pages/client/clientAPI/clientApi'; -import { useClientStore } from '../../../pages/client/hooks/useClientStore'; -import { AddInstanceFormStep } from '../../../pages/client/pages/ClientAddInstancePage/hooks/types'; -import { useAddInstanceStore } from '../../../pages/client/pages/ClientAddInstancePage/hooks/useAddInstanceStore'; -import { ClientConnectionType } from '../../../pages/client/types'; -import { useEnrollmentStore } from '../../../pages/enrollment/hooks/store/useEnrollmentStore'; -import { useEnrollmentApi } from '../../../pages/enrollment/hooks/useEnrollmentApi'; -import type { EnrollmentStartResponse } from '../../hooks/api/types'; -import { routes } from '../../routes'; +import useAddInstance from '../../hooks/useAddInstance'; enum DeepLink { AddInstance = 'addinstance', } -const linkStorageKey = 'lastSuccessfullyHandledDeepLink'; +export const linkStorageKey = 'lastSuccessfullyHandledDeepLink'; -const storeLink = (value: string) => { +export const storeLink = (value: string) => { sessionStorage.setItem(linkStorageKey, value); }; @@ -63,123 +52,25 @@ const linkIntoPayload = (link: URL | null): LinkPayload | null => { return null; }; -const prepareProxyUrl = (value: string) => { - let proxyUrl = value; - if (proxyUrl[proxyUrl.length - 1] === '/') { - proxyUrl = proxyUrl.slice(0, -1); - } - proxyUrl = `${proxyUrl}/api/v1`; - return proxyUrl; -}; - export const DeepLinkProvider = ({ children }: PropsWithChildren) => { const mounted = useRef(false); - const { - enrollment: { start, networkInfo }, - } = useEnrollmentApi(); - - const setEnrollmentState = useEnrollmentStore((s) => s.init); - const setAddInstanceState = useAddInstanceStore((s) => s.setState); - const setClientState = useClientStore((s) => s.setState); - - const navigate = useNavigate(); - - // biome-ignore lint/correctness/useExhaustiveDependencies: should init once - const handleValidLink = useCallback(async (payload: LinkPayload, rawLink?: string) => { - const { data, link } = payload; - switch (link) { - case DeepLink.AddInstance: - await start({ - token: data.token, - proxyUrl: prepareProxyUrl(data.url), - }).then(async (response) => { - if (response.ok) { - const authCookie = response.headers - .getSetCookie() - .find((cookie) => cookie.startsWith('defguard_proxy=')); - if (authCookie === undefined) { - error('Failed to open deep link, auth cookie missing from proxy response.'); - return; - } - const respData = (await response.json()) as EnrollmentStartResponse; - const instances = await clientApi.getInstances(); - const proxy_api_url = prepareProxyUrl( - respData.instance.proxy_url ?? respData.instance.url, - ); - const existingInstance = instances.find( - (instance) => instance.uuid === respData.instance.id, - ); - if (existingInstance) { - // update existing instance instead - const networkInfoResp = await networkInfo( - { - pubkey: existingInstance.pubkey, - }, - proxy_api_url, - authCookie, - ); - await invoke('update_instance', { - instanceId: existingInstance.id, - response: networkInfoResp, - }); - setClientState({ - selectedInstance: { - type: ClientConnectionType.LOCATION, - id: existingInstance.id, - }, - }); - if (rawLink) { - storeLink(rawLink); - } - debug(`Updated ${existingInstance.name} via deep link`); - navigate(routes.client.base, { replace: true }); - return; - } - if (!respData.user.enrolled) { - // user needs full enrollment - const sessionEnd = dayjs - .unix(respData.deadline_timestamp) - .utc() - .local() - .format(); - const sessionStart = dayjs().local().format(); - // set enrollment - setEnrollmentState({ - enrollmentSettings: respData.settings, - proxy_url: proxy_api_url, - userInfo: respData.user, - adminInfo: respData.admin, - endContent: respData.final_page_content, - cookie: authCookie, - sessionEnd, - sessionStart, - }); - navigate('/enrollment', { replace: true }); - } else { - // only needs to register this device - setAddInstanceState({ - step: AddInstanceFormStep.DEVICE, - response: { - cookie: authCookie, - device_names: respData.user.device_names, - url: proxy_api_url, - }, - }); - navigate('/client/add-instance', { replace: true }); - } - } else { - error( - `Add instance from deep link failed! Proxy enrollment start request failed! status: ${response.status}`, - ); - } - }); - break; - } - if (rawLink) { - storeLink(rawLink); - } - }, []); + const { handleAddInstance } = useAddInstance(); + + const handleValidLink = useCallback( + async (payload: LinkPayload, rawLink?: string) => { + const { data, link } = payload; + switch (link) { + case DeepLink.AddInstance: + await handleAddInstance(data, rawLink); + break; + } + if (rawLink) { + storeLink(rawLink); + } + }, + [handleAddInstance], + ); // biome-ignore lint/correctness/useExhaustiveDependencies: only on mount useEffect(() => { diff --git a/src/shared/hooks/useAddInstance.ts b/src/shared/hooks/useAddInstance.ts new file mode 100644 index 00000000..ea3fde7a --- /dev/null +++ b/src/shared/hooks/useAddInstance.ts @@ -0,0 +1,146 @@ +/** + * Hook which handles adding an instance in the background and triggering enrollment process (if necessary) + * in automated scenarios e.g. deep-link, client provisioning etc. + */ + +import { invoke } from '@tauri-apps/api/core'; +import { debug, error } from '@tauri-apps/plugin-log'; +import dayjs from 'dayjs'; +import { useCallback, useState } from 'react'; +import { useNavigate } from 'react-router-dom'; +import { clientApi } from '../../pages/client/clientAPI/clientApi'; +import { useClientStore } from '../../pages/client/hooks/useClientStore'; +import { AddInstanceFormStep } from '../../pages/client/pages/ClientAddInstancePage/hooks/types'; +import { useAddInstanceStore } from '../../pages/client/pages/ClientAddInstancePage/hooks/useAddInstanceStore'; +import { type AddInstancePayload, ClientConnectionType } from '../../pages/client/types'; +import { useEnrollmentStore } from '../../pages/enrollment/hooks/store/useEnrollmentStore'; +import { useEnrollmentApi } from '../../pages/enrollment/hooks/useEnrollmentApi'; +import { storeLink } from '../components/providers/DeepLinkProvider'; +import type { EnrollmentStartResponse } from '../hooks/api/types'; +import { routes } from '../routes'; + +const prepareProxyUrl = (value: string) => { + let proxyUrl = value; + if (proxyUrl[proxyUrl.length - 1] === '/') { + proxyUrl = proxyUrl.slice(0, -1); + } + proxyUrl = `${proxyUrl}/api/v1`; + return proxyUrl; +}; + +export default function useAddInstance() { + const [loading, setLoading] = useState(false); + + const setEnrollmentState = useEnrollmentStore((s) => s.init); + const setAddInstanceState = useAddInstanceStore((s) => s.setState); + const setClientState = useClientStore((s) => s.setState); + + const navigate = useNavigate(); + + const { + enrollment: { start, networkInfo }, + } = useEnrollmentApi(); + + const handleAddInstance = useCallback( + async (payload: AddInstancePayload, rawLink?: string) => { + setLoading(true); + + await start({ + token: payload.token, + proxyUrl: prepareProxyUrl(payload.url), + }).then(async (response) => { + if (response.ok) { + const authCookie = response.headers + .getSetCookie() + .find((cookie) => cookie.startsWith('defguard_proxy=')); + if (authCookie === undefined) { + error( + 'Failed to automatically add new instance, auth cookie missing from proxy response.', + ); + return; + } + const respData = (await response.json()) as EnrollmentStartResponse; + const instances = await clientApi.getInstances(); + const proxy_api_url = prepareProxyUrl( + respData.instance.proxy_url ?? respData.instance.url, + ); + const existingInstance = instances.find( + (instance) => instance.uuid === respData.instance.id, + ); + if (existingInstance) { + // update existing instance instead + const networkInfoResp = await networkInfo( + { + pubkey: existingInstance.pubkey, + }, + proxy_api_url, + authCookie, + ); + await invoke('update_instance', { + instanceId: existingInstance.id, + response: networkInfoResp, + }); + setClientState({ + selectedInstance: { + type: ClientConnectionType.LOCATION, + id: existingInstance.id, + }, + }); + if (rawLink) { + storeLink(rawLink); + } + debug(`Automatically updated ${existingInstance.name}`); + navigate(routes.client.base, { replace: true }); + return; + } + if (!respData.user.enrolled) { + // user needs full enrollment + const sessionEnd = dayjs + .unix(respData.deadline_timestamp) + .utc() + .local() + .format(); + const sessionStart = dayjs().local().format(); + // set enrollment + setEnrollmentState({ + enrollmentSettings: respData.settings, + proxy_url: proxy_api_url, + userInfo: respData.user, + adminInfo: respData.admin, + endContent: respData.final_page_content, + cookie: authCookie, + sessionEnd, + sessionStart, + }); + navigate('/enrollment', { replace: true }); + } else { + // only needs to register this device + setAddInstanceState({ + step: AddInstanceFormStep.DEVICE, + response: { + cookie: authCookie, + device_names: respData.user.device_names, + url: proxy_api_url, + }, + }); + navigate('/client/add-instance', { replace: true }); + } + } else { + error( + `Adding instance automatically failed. Proxy enrollment start request failed with status: ${response.status}`, + ); + } + }); + }, + [ + setClientState, + networkInfo, + start, + setEnrollmentState, + setAddInstanceState, + navigate, + ], + ); + + return { handleAddInstance, loading, error }; +}